fennol.utils.spherical_harmonics

  1import numpy as np
  2import jax
  3import math
  4import jax.numpy as jnp
  5import sympy
  6from sympy.printing.pycode import pycode
  7from sympy.physics.wigner import clebsch_gordan
  8from functools import partial
  9
 10
 11def CG_SU2(j1: int, j2: int, j3: int) -> np.array:
 12    r"""Clebsch-Gordan coefficients for the direct product of two irreducible representations of :math:`SU(2)`
 13    Returns
 14    -------
 15    `np.Array`
 16        tensor :math:`C` of shape :math:`(2j_1+1, 2j_2+1, 2j_3+1)`
 17    """
 18    C = np.zeros((2 * j1 + 1, 2 * j2 + 1, 2 * j3 + 1))
 19    for m1 in range(-j1, j1 + 1):
 20        for m2 in range(-j2, j2 + 1):
 21            for m3 in range(-j3, j3 + 1):
 22                C[m1 + j1, m2 + j2, m3 + j3] = float(
 23                    clebsch_gordan(j1, j2, j3, m1, m2, m3)
 24                )
 25    return C
 26
 27
 28def change_basis_real_to_complex(l: int) -> np.array:
 29    r"""Change of basis matrix from real to complex spherical harmonics
 30    https://en.wikipedia.org/wiki/Spherical_harmonics#Real_form
 31    adapted from e3nn.o3._wigner
 32    """
 33    q = np.zeros((2 * l + 1, 2 * l + 1), dtype=complex)
 34    for m in range(-l, 0):
 35        q[l + m, l + abs(m)] = 1 / 2**0.5
 36        q[l + m, l - abs(m)] = -1j / 2**0.5
 37    q[l, l] = 1
 38    for m in range(1, l + 1):
 39        q[l + m, l + abs(m)] = (-1) ** m / 2**0.5
 40        q[l + m, l - abs(m)] = 1j * (-1) ** m / 2**0.5
 41
 42    # factor of (-i)**l to make the Clebsch-Gordan coefficients real
 43    return q * (-1j) ** l
 44
 45
 46def CG_SO3(j1: int, j2: int, j3: int) -> np.array:
 47    r"""Clebsch-Gordan coefficients for the direct product of two irreducible representations of :math:`SO(3)`
 48    Returns
 49    -------
 50    `torch.Tensor`
 51        tensor :math:`C` of shape :math:`(2l_1+1, 2l_2+1, 2l_3+1)`
 52    """
 53    C = CG_SU2(j1, j2, j3)
 54    Q1 = change_basis_real_to_complex(j1)
 55    Q2 = change_basis_real_to_complex(j2)
 56    Q3 = change_basis_real_to_complex(j3)
 57    C = np.real(np.einsum("ij,kl,mn,ikn->jlm", Q1, Q2, np.conj(Q3.T), C))
 58    return C / np.linalg.norm(C)
 59
 60
 61def generate_spherical_harmonics(
 62    lmax, normalize=False, print_code=False, jit=False, vmapped=False
 63):  # pragma: no cover
 64    r"""returns a function that computes spherical harmonic up to lmax
 65    (adapted from e3nn)
 66    """
 67
 68    def to_frac(x: float):
 69        from fractions import Fraction
 70
 71        s = 1 if x >= 0 else -1
 72        x = x**2
 73        x = Fraction(x).limit_denominator()
 74        x = s * sympy.sqrt(x)
 75        x = sympy.simplify(x)
 76        return x
 77
 78    if vmapped:
 79        fn_str = "def spherical_harmonics_(x,y,z):\n"
 80        fn_str += "  sh_0_0 = 1.\n"
 81    else:
 82        fn_str = "def spherical_harmonics_(vec):\n"
 83        if normalize:
 84            fn_str += "  vec = vec/jnp.linalg.norm(vec,axis=-1,keepdims=True)\n"
 85        fn_str += "  x,y,z = [jax.lax.index_in_dim(vec, i, axis=-1, keepdims=False) for i in range(3)]\n"
 86        fn_str += "  sh_0_0 = jnp.ones_like(x)\n"
 87
 88    x_var, y_var, z_var = sympy.symbols("x y z")
 89    polynomials = [sympy.sqrt(3) * x_var, sympy.sqrt(3) * y_var, sympy.sqrt(3) * z_var]
 90
 91    def sub_z1(p, names, polynormz):
 92        p = p.subs(x_var, 0).subs(y_var, 1).subs(z_var, 0)
 93        for n, c in zip(names, polynormz):
 94            p = p.subs(n, c)
 95        return p
 96
 97    poly_evalz = [sub_z1(p, [], []) for p in polynomials]
 98
 99    for l in range(1, lmax + 1):
100        sh_variables = sympy.symbols(" ".join(f"sh_{l}_{m}" for m in range(2 * l + 1)))
101
102        for n, p in zip(sh_variables, polynomials):
103            fn_str += f"  {n} = {pycode(p.evalf())}\n"
104
105        if l == lmax:
106            break
107
108        polynomials = [
109            sum(
110                to_frac(c.item()) * v * sh
111                for cj, v in zip(cij, [x_var, y_var, z_var])
112                for c, sh in zip(cj, sh_variables)
113            )
114            for cij in CG_SO3(l + 1, 1, l)
115        ]
116
117        poly_evalz = [sub_z1(p, sh_variables, poly_evalz) for p in polynomials]
118        norm = sympy.sqrt(sum(p**2 for p in poly_evalz))
119        polynomials = [sympy.sqrt(2 * l + 3) * p / norm for p in polynomials]
120        poly_evalz = [sympy.sqrt(2 * l + 3) * p / norm for p in poly_evalz]
121
122        polynomials = [sympy.simplify(p, full=True) for p in polynomials]
123
124    u = ",\n        ".join(
125        ", ".join(f"sh_{j}_{m}" for m in range(2 * j + 1)) for j in range(l + 1)
126    )
127    if vmapped:
128        fn_str += f"  return jnp.array([\n        {u}\n    ])\n"
129    else:
130        fn_str += f"  return jnp.stack([\n        {u}\n    ], axis=-1)\n"
131
132    if print_code:
133        print(fn_str)
134    exec(fn_str)
135    sh = locals()["spherical_harmonics_"]
136    if jit:
137        sh = jax.jit(sh)
138    if not vmapped:
139        return sh
140
141    if normalize:
142
143        def spherical_harmonics(vec):
144            vec = vec / jnp.linalg.norm(vec, axis=-1, keepdims=True)
145            x, y, z = [
146                jax.lax.index_in_dim(vec, i, axis=-1, keepdims=False) for i in range(3)
147            ]
148            return jax.vmap(sh)(x, y, z)
149
150    else:
151
152        def spherical_harmonics(vec):
153            x, y, z = [
154                jax.lax.index_in_dim(vec, i, axis=-1, keepdims=False) for i in range(3)
155            ]
156            return jax.vmap(sh)(x, y, z)
157
158    if jit:
159        spherical_harmonics = jax.jit(spherical_harmonics)
160    return spherical_harmonics
161
162
@partial(jax.jit, static_argnums=1)
def spherical_to_cartesian_tensor(Q, lmax):
    r"""Convert stacked real spherical components to Cartesian components.

    Parameters
    ----------
    Q : array
        Spherical components along the last axis: index 0 is the l=0 term,
        indices 1-3 the l=1 terms and indices 4-8 the l=2 terms. Only the
        first ``(lmax + 1) ** 2`` entries are read.
    lmax : int (static)
        Maximum degree to convert; must be 0, 1 or 2.

    Returns
    -------
    array
        Last axis holds ``[q]`` for ``lmax == 0`` and ``[q, mu_1, mu_2, mu_3]``
        (l=1 terms copied unchanged) for ``lmax == 1``. For ``lmax == 2`` the
        six components ``Txx, Tyy, Tzz, Txy, Txz, Tyz`` of a traceless
        symmetric tensor are appended.

    Raises
    ------
    ValueError
        if ``lmax`` is not 0, 1 or 2.
    """
    if lmax not in (0, 1, 2):
        raise ValueError(
            f"spherical_to_cartesian_tensor supports lmax in (0, 1, 2), got {lmax}"
        )

    q = Q[..., 0]
    if lmax == 0:
        return q[..., None]

    mu = Q[..., 1:4]
    if lmax == 1:
        return jnp.concatenate([q[..., None], mu], axis=-1)

    # l = 2 components in storage order (index 6 is the m = 0 term)
    Q22s = Q[..., 4]
    Q21s = Q[..., 5]
    Q20 = Q[..., 6]
    Q21c = Q[..., 7]
    Q22c = Q[..., 8]
    # The m = 0 term maps onto Tyy; Txx + Tyy + Tzz == 0 by construction.
    Tzz = -0.5 * Q20 + (0.5 * 3**0.5) * Q22c
    Txx = -0.5 * Q20 - (0.5 * 3**0.5) * Q22c
    Tyy = Q20
    Txz = 0.5 * (3**0.5) * Q22s
    Tyz = 0.5 * (3**0.5) * Q21c
    Txy = 0.5 * (3**0.5) * Q21s

    return jnp.concatenate(
        [
            q[..., None],
            mu,
            Txx[..., None],
            Tyy[..., None],
            Tzz[..., None],
            Txy[..., None],
            Txz[..., None],
            Tyz[..., None],
        ],
        axis=-1,
    )
def CG_SU2(j1: int, j2: int, j3: int) -> np.ndarray:
    r"""Clebsch-Gordan coefficients for the direct product of two irreducible representations of :math:`SU(2)`

    Parameters
    ----------
    j1, j2 : int
        Angular momenta of the two coupled representations.
    j3 : int
        Angular momentum of the resulting representation.

    Returns
    -------
    `np.ndarray`
        tensor :math:`C` of shape :math:`(2j_1+1, 2j_2+1, 2j_3+1)` with
        ``C[m1 + j1, m2 + j2, m3 + j3]`` equal to
        ``sympy.physics.wigner.clebsch_gordan(j1, j2, j3, m1, m2, m3)``.
    """
    C = np.zeros((2 * j1 + 1, 2 * j2 + 1, 2 * j3 + 1))
    for m1 in range(-j1, j1 + 1):
        for m2 in range(-j2, j2 + 1):
            # The coefficient vanishes unless m3 == m1 + m2, so only that
            # entry is evaluated (each sympy evaluation is expensive).
            m3 = m1 + m2
            if abs(m3) > j3:
                continue
            C[m1 + j1, m2 + j2, m3 + j3] = float(
                clebsch_gordan(j1, j2, j3, m1, m2, m3)
            )
    return C

Clebsch-Gordan coefficients for the direct product of two irreducible representations of \( SU(2) \)

Returns
  • np.ndarray: tensor \( C \) of shape \( (2j_1+1, 2j_2+1, 2j_3+1) \)
def change_basis_real_to_complex(l: int) -> np.ndarray:
    r"""Change of basis matrix from real to complex spherical harmonics

    https://en.wikipedia.org/wiki/Spherical_harmonics#Real_form
    adapted from e3nn.o3._wigner

    Parameters
    ----------
    l : int
        Degree of the spherical harmonics.

    Returns
    -------
    `np.ndarray`
        complex matrix of shape :math:`(2l+1, 2l+1)` whose rows and columns
        are indexed by ``m + l``; it includes an overall phase :math:`(-i)^l`.
    """
    q = np.zeros((2 * l + 1, 2 * l + 1), dtype=complex)
    # negative m: mixes the +|m| and -|m| columns
    for m in range(-l, 0):
        q[l + m, l + abs(m)] = 1 / 2**0.5
        q[l + m, l - abs(m)] = -1j / 2**0.5
    q[l, l] = 1
    # positive m: same columns with a (-1)**m phase
    for m in range(1, l + 1):
        q[l + m, l + abs(m)] = (-1) ** m / 2**0.5
        q[l + m, l - abs(m)] = 1j * (-1) ** m / 2**0.5

    # factor of (-i)**l to make the Clebsch-Gordan coefficients real
    return q * (-1j) ** l

Change-of-basis matrix from real to complex spherical harmonics (see https://en.wikipedia.org/wiki/Spherical_harmonics#Real_form). Adapted from e3nn.o3._wigner.

def CG_SO3(j1: int, j2: int, j3: int) -> np.ndarray:
    r"""Clebsch-Gordan coefficients for the direct product of two irreducible representations of :math:`SO(3)`

    The :math:`SU(2)` coefficients from :func:`CG_SU2` are rotated into the
    real spherical-harmonic basis with :func:`change_basis_real_to_complex`,
    and the real part is kept.

    Parameters
    ----------
    j1, j2 : int
        Degrees of the two coupled representations.
    j3 : int
        Degree of the resulting representation.

    Returns
    -------
    `np.ndarray`
        real tensor :math:`C` of shape :math:`(2l_1+1, 2l_2+1, 2l_3+1)`,
        normalized to unit Frobenius norm. If the coefficients vanish
        identically (e.g. the triangle inequality is violated) the all-zero
        tensor is returned instead of NaNs.
    """
    C = CG_SU2(j1, j2, j3)
    Q1 = change_basis_real_to_complex(j1)
    Q2 = change_basis_real_to_complex(j2)
    Q3 = change_basis_real_to_complex(j3)
    C = np.real(np.einsum("ij,kl,mn,ikn->jlm", Q1, Q2, np.conj(Q3.T), C))
    norm = np.linalg.norm(C)
    # A forbidden coupling yields an all-zero tensor; avoid 0/0 -> NaN.
    if norm == 0:
        return C
    return C / norm

Clebsch-Gordan coefficients for the direct product of two irreducible representations of \( SO(3) \)

Returns
  • np.ndarray: tensor \( C \) of shape \( (2l_1+1, 2l_2+1, 2l_3+1) \), normalized to unit Frobenius norm
def generate_spherical_harmonics(
    lmax, normalize=False, print_code=False, jit=False, vmapped=False
):  # pragma: no cover
    r"""returns a function that computes spherical harmonic up to lmax
    (adapted from e3nn)

    The polynomials are built symbolically with sympy, starting from
    ``sqrt(3) * (x, y, z)`` for l=1 and recursively coupling degree l with
    (x, y, z) through :func:`CG_SO3`. The result is turned into Python source
    and compiled with ``exec``.

    Parameters
    ----------
    lmax : int
        Maximum degree. The output has ``(lmax + 1) ** 2`` components,
        ordered by degree and then by m-index (``sh_0_0, sh_1_0, ...``).
    normalize : bool
        If True, input vectors are divided by their norm first.
    print_code : bool
        If True, print the generated source code.
    jit : bool
        If True, wrap the returned function(s) with ``jax.jit``.
    vmapped : bool
        If True, the generated kernel works on scalars ``x, y, z`` and is
        mapped with ``jax.vmap`` over the leading axis of the x/y/z
        components of the input.

    Returns
    -------
    callable
        function taking ``vec`` (last axis of size 3) and returning the
        spherical harmonics.
    """

    def to_frac(x: float):
        # Turn a float coefficient into the exact sympy number sign*sqrt(p/q),
        # assuming its square is a simple rational.
        from fractions import Fraction

        s = 1 if x >= 0 else -1
        x = x**2
        x = Fraction(x).limit_denominator()
        x = s * sympy.sqrt(x)
        x = sympy.simplify(x)
        return x

    if vmapped:
        # scalar kernel; batching is added below with jax.vmap
        fn_str = "def spherical_harmonics_(x,y,z):\n"
        fn_str += "  sh_0_0 = 1.\n"
    else:
        fn_str = "def spherical_harmonics_(vec):\n"
        if normalize:
            fn_str += "  vec = vec/jnp.linalg.norm(vec,axis=-1,keepdims=True)\n"
        fn_str += "  x,y,z = [jax.lax.index_in_dim(vec, i, axis=-1, keepdims=False) for i in range(3)]\n"
        fn_str += "  sh_0_0 = jnp.ones_like(x)\n"

    x_var, y_var, z_var = sympy.symbols("x y z")
    polynomials = [sympy.sqrt(3) * x_var, sympy.sqrt(3) * y_var, sympy.sqrt(3) * z_var]

    def sub_z1(p, names, polynormz):
        # Evaluate p at the unit vector (0, 1, 0), replacing the previous
        # degree's sh_* symbols by their values at that point.
        p = p.subs(x_var, 0).subs(y_var, 1).subs(z_var, 0)
        for n, c in zip(names, polynormz):
            p = p.subs(n, c)
        return p

    poly_evalz = [sub_z1(p, [], []) for p in polynomials]

    for l in range(1, lmax + 1):
        sh_variables = sympy.symbols(" ".join(f"sh_{l}_{m}" for m in range(2 * l + 1)))

        # one generated assignment per component of degree l
        for n, p in zip(sh_variables, polynomials):
            fn_str += f"  {n} = {pycode(p.evalf())}\n"

        if l == lmax:
            break

        # degree l+1: couple degree l with (x, y, z) via real CG coefficients
        polynomials = [
            sum(
                to_frac(c.item()) * v * sh
                for cj, v in zip(cij, [x_var, y_var, z_var])
                for c, sh in zip(cj, sh_variables)
            )
            for cij in CG_SO3(l + 1, 1, l)
        ]

        # rescale so the degree l+1 block has squared norm 2l+3 at (0, 1, 0)
        poly_evalz = [sub_z1(p, sh_variables, poly_evalz) for p in polynomials]
        norm = sympy.sqrt(sum(p**2 for p in poly_evalz))
        polynomials = [sympy.sqrt(2 * l + 3) * p / norm for p in polynomials]
        poly_evalz = [sympy.sqrt(2 * l + 3) * p / norm for p in poly_evalz]

        polynomials = [sympy.simplify(p, full=True) for p in polynomials]

    # Use lmax, not the loop variable l: l is never bound when lmax == 0.
    u = ",\n        ".join(
        ", ".join(f"sh_{j}_{m}" for m in range(2 * j + 1)) for j in range(lmax + 1)
    )
    if vmapped:
        fn_str += f"  return jnp.array([\n        {u}\n    ])\n"
    else:
        fn_str += f"  return jnp.stack([\n        {u}\n    ], axis=-1)\n"

    if print_code:
        print(fn_str)
    # Execute into an explicit namespace: since Python 3.13 (PEP 667),
    # exec() writes to a locals() snapshot that a later locals() call does
    # not see. Module globals are passed so the generated code resolves
    # jax/jnp at call time.
    namespace = {}
    exec(fn_str, globals(), namespace)
    sh = namespace["spherical_harmonics_"]
    if jit:
        sh = jax.jit(sh)
    if not vmapped:
        return sh

    if normalize:

        def spherical_harmonics(vec):
            vec = vec / jnp.linalg.norm(vec, axis=-1, keepdims=True)
            x, y, z = [
                jax.lax.index_in_dim(vec, i, axis=-1, keepdims=False) for i in range(3)
            ]
            return jax.vmap(sh)(x, y, z)

    else:

        def spherical_harmonics(vec):
            x, y, z = [
                jax.lax.index_in_dim(vec, i, axis=-1, keepdims=False) for i in range(3)
            ]
            return jax.vmap(sh)(x, y, z)

    if jit:
        spherical_harmonics = jax.jit(spherical_harmonics)
    return spherical_harmonics

Returns a function that computes the spherical harmonics up to degree lmax (adapted from e3nn).

@partial(jax.jit, static_argnums=1)
def spherical_to_cartesian_tensor(Q, lmax):
    r"""Convert stacked real spherical components to Cartesian components.

    Parameters
    ----------
    Q : array
        Spherical components along the last axis: index 0 is the l=0 term,
        indices 1-3 the l=1 terms and indices 4-8 the l=2 terms. Only the
        first ``(lmax + 1) ** 2`` entries are read.
    lmax : int (static)
        Maximum degree to convert; must be 0, 1 or 2.

    Returns
    -------
    array
        Last axis holds ``[q]`` for ``lmax == 0`` and ``[q, mu_1, mu_2, mu_3]``
        (l=1 terms copied unchanged) for ``lmax == 1``. For ``lmax == 2`` the
        six components ``Txx, Tyy, Tzz, Txy, Txz, Tyz`` of a traceless
        symmetric tensor are appended.

    Raises
    ------
    ValueError
        if ``lmax`` is not 0, 1 or 2.
    """
    if lmax not in (0, 1, 2):
        raise ValueError(
            f"spherical_to_cartesian_tensor supports lmax in (0, 1, 2), got {lmax}"
        )

    q = Q[..., 0]
    if lmax == 0:
        return q[..., None]

    mu = Q[..., 1:4]
    if lmax == 1:
        return jnp.concatenate([q[..., None], mu], axis=-1)

    # l = 2 components in storage order (index 6 is the m = 0 term)
    Q22s = Q[..., 4]
    Q21s = Q[..., 5]
    Q20 = Q[..., 6]
    Q21c = Q[..., 7]
    Q22c = Q[..., 8]
    # The m = 0 term maps onto Tyy; Txx + Tyy + Tzz == 0 by construction.
    Tzz = -0.5 * Q20 + (0.5 * 3**0.5) * Q22c
    Txx = -0.5 * Q20 - (0.5 * 3**0.5) * Q22c
    Tyy = Q20
    Txz = 0.5 * (3**0.5) * Q22s
    Tyz = 0.5 * (3**0.5) * Q21c
    Txy = 0.5 * (3**0.5) * Q21s

    return jnp.concatenate(
        [
            q[..., None],
            mu,
            Txx[..., None],
            Tyy[..., None],
            Tzz[..., None],
            Txy[..., None],
            Txz[..., None],
            Tyz[..., None],
        ],
        axis=-1,
    )