fennol.utils.spherical_harmonics
import numpy as np
import jax
import math
import jax.numpy as jnp
import sympy
from sympy.printing.pycode import pycode
from sympy.physics.wigner import clebsch_gordan
from functools import partial


def CG_SU2(j1: int, j2: int, j3: int) -> np.ndarray:
    r"""Clebsch-Gordan coefficients for the direct product of two irreducible representations of :math:`SU(2)`

    Parameters
    ----------
    j1, j2, j3 : int
        Degrees of the two factors and of the coupled representation.

    Returns
    -------
    `np.ndarray`
        tensor :math:`C` of shape :math:`(2j_1+1, 2j_2+1, 2j_3+1)` with
        ``C[m1 + j1, m2 + j2, m3 + j3]`` equal to sympy's
        ``clebsch_gordan(j1, j2, j3, m1, m2, m3)``.
    """
    C = np.zeros((2 * j1 + 1, 2 * j2 + 1, 2 * j3 + 1))
    for m1 in range(-j1, j1 + 1):
        for m2 in range(-j2, j2 + 1):
            for m3 in range(-j3, j3 + 1):
                C[m1 + j1, m2 + j2, m3 + j3] = float(
                    clebsch_gordan(j1, j2, j3, m1, m2, m3)
                )
    return C


def change_basis_real_to_complex(l: int) -> np.ndarray:
    r"""Change-of-basis matrix from real to complex spherical harmonics.

    See https://en.wikipedia.org/wiki/Spherical_harmonics#Real_form
    (adapted from ``e3nn.o3._wigner``).

    Parameters
    ----------
    l : int
        Degree of the representation.

    Returns
    -------
    `np.ndarray`
        complex matrix of shape :math:`(2l+1, 2l+1)`; row ``l + m``
        corresponds to order ``m``. The whole matrix is multiplied by
        :math:`(-i)^l` so that the Clebsch-Gordan coefficients derived from
        it are real.
    """
    q = np.zeros((2 * l + 1, 2 * l + 1), dtype=complex)
    for m in range(-l, 0):
        q[l + m, l + abs(m)] = 1 / 2**0.5
        q[l + m, l - abs(m)] = -1j / 2**0.5
    q[l, l] = 1
    for m in range(1, l + 1):
        q[l + m, l + abs(m)] = (-1) ** m / 2**0.5
        q[l + m, l - abs(m)] = 1j * (-1) ** m / 2**0.5

    # factor of (-i)**l to make the Clebsch-Gordan coefficients real
    return q * (-1j) ** l


def CG_SO3(j1: int, j2: int, j3: int) -> np.ndarray:
    r"""Clebsch-Gordan coefficients for the direct product of two irreducible representations of :math:`SO(3)`

    The :math:`SU(2)` coefficients from :func:`CG_SU2` are transformed with
    the real/complex change-of-basis matrix of each degree, the real part is
    kept, and the result is normalized to unit Frobenius norm.

    Parameters
    ----------
    j1, j2, j3 : int
        Degrees of the two factors and of the coupled representation.

    Returns
    -------
    `np.ndarray`
        real tensor :math:`C` of shape :math:`(2j_1+1, 2j_2+1, 2j_3+1)`
    """
    C = CG_SU2(j1, j2, j3)
    Q1 = change_basis_real_to_complex(j1)
    Q2 = change_basis_real_to_complex(j2)
    Q3 = change_basis_real_to_complex(j3)
    # C_real[j, l, m] = sum_{i,k,n} Q1[i, j] Q2[k, l] conj(Q3)[n, m] C[i, k, n]
    C = np.real(np.einsum("ij,kl,mn,ikn->jlm", Q1, Q2, np.conj(Q3.T), C))
    # NOTE: if every coefficient vanishes, this division produces NaN.
    return C / np.linalg.norm(C)


def generate_spherical_harmonics(
    lmax, normalize=False, print_code=False, jit=False, vmapped=False
):  # pragma: no cover
    r"""Return a function that computes real spherical harmonics up to ``lmax``
    (adapted from e3nn).

    The polynomials of each degree are built symbolically with sympy. Degree
    ``l + 1`` is obtained by coupling degree ``l`` with ``(x, y, z)`` through
    :func:`CG_SO3`. Each degree is scaled so that its components have a sum of
    squares equal to ``2l + 1`` at the point ``(0, 1, 0)``. The polynomials
    are then emitted as Python source, which is executed to create the
    function.

    Parameters
    ----------
    lmax : int
        Maximum degree (must be >= 0).
    normalize : bool
        If True, input vectors are first divided by their Euclidean norm.
    print_code : bool
        If True, print the generated source code.
    jit : bool
        If True, wrap the returned function with ``jax.jit``.
    vmapped : bool
        If True, the generated kernel works on scalar ``x, y, z`` and is
        mapped with ``jax.vmap`` over the leading axis of the input
        (presumably an array of shape ``(N, 3)`` -- TODO confirm).

    Returns
    -------
    callable
        ``f(vec)`` returning, along the last axis, the ``(lmax + 1) ** 2``
        components ordered ``sh_0_0, sh_1_0, sh_1_1, sh_1_2, sh_2_0, ...``.

    Raises
    ------
    ValueError
        If ``lmax`` is negative.
    """
    if lmax < 0:
        raise ValueError(f"lmax must be non-negative, got {lmax}")

    def to_frac(x: float):
        # Recover an exact +/- sqrt(p/q) sympy expression from a float
        # coefficient by rationalizing its square.
        from fractions import Fraction

        s = 1 if x >= 0 else -1
        x = x**2
        x = Fraction(x).limit_denominator()
        x = s * sympy.sqrt(x)
        x = sympy.simplify(x)
        return x

    if vmapped:
        fn_str = "def spherical_harmonics_(x,y,z):\n"
        fn_str += "    sh_0_0 = 1.\n"
    else:
        fn_str = "def spherical_harmonics_(vec):\n"
        if normalize:
            fn_str += "    vec = vec/jnp.linalg.norm(vec,axis=-1,keepdims=True)\n"
        fn_str += "    x,y,z = [jax.lax.index_in_dim(vec, i, axis=-1, keepdims=False) for i in range(3)]\n"
        fn_str += "    sh_0_0 = jnp.ones_like(x)\n"

    x_var, y_var, z_var = sympy.symbols("x y z")
    polynomials = [sympy.sqrt(3) * x_var, sympy.sqrt(3) * y_var, sympy.sqrt(3) * z_var]

    def sub_z1(p, names, polynormz):
        # Evaluate p at (x, y, z) = (0, 1, 0), replacing the previous degree's
        # symbols by their values at that same point.
        p = p.subs(x_var, 0).subs(y_var, 1).subs(z_var, 0)
        for n, c in zip(names, polynormz):
            p = p.subs(n, c)
        return p

    poly_evalz = [sub_z1(p, [], []) for p in polynomials]

    for l in range(1, lmax + 1):
        sh_variables = sympy.symbols(" ".join(f"sh_{l}_{m}" for m in range(2 * l + 1)))

        for n, p in zip(sh_variables, polynomials):
            fn_str += f"    {n} = {pycode(p.evalf())}\n"

        if l == lmax:
            break

        # Degree l+1 components: couple (x, y, z) with the degree-l symbols.
        polynomials = [
            sum(
                to_frac(c.item()) * v * sh
                for cj, v in zip(cij, [x_var, y_var, z_var])
                for c, sh in zip(cj, sh_variables)
            )
            for cij in CG_SO3(l + 1, 1, l)
        ]

        poly_evalz = [sub_z1(p, sh_variables, poly_evalz) for p in polynomials]
        norm = sympy.sqrt(sum(p**2 for p in poly_evalz))
        polynomials = [sympy.sqrt(2 * l + 3) * p / norm for p in polynomials]
        poly_evalz = [sympy.sqrt(2 * l + 3) * p / norm for p in poly_evalz]

        polynomials = [sympy.simplify(p, full=True) for p in polynomials]

    # Use lmax rather than the loop variable: for lmax == 0 the loop never
    # runs and `l` would be unbound.
    u = ",\n        ".join(
        ", ".join(f"sh_{j}_{m}" for m in range(2 * j + 1)) for j in range(lmax + 1)
    )
    if vmapped:
        fn_str += f"    return jnp.array([\n        {u}\n    ])\n"
    else:
        fn_str += f"    return jnp.stack([\n        {u}\n    ], axis=-1)\n"

    if print_code:
        print(fn_str)
    # Execute into an explicit namespace. Reading the new function back via
    # locals() after exec() is unreliable: on Python 3.13+ locals() in a
    # function returns a fresh snapshot that does not contain it.
    namespace = {"jax": jax, "jnp": jnp}
    exec(fn_str, namespace)
    sh = namespace["spherical_harmonics_"]
    if jit:
        sh = jax.jit(sh)
    if not vmapped:
        return sh

    def spherical_harmonics(vec):
        if normalize:
            vec = vec / jnp.linalg.norm(vec, axis=-1, keepdims=True)
        x, y, z = [
            jax.lax.index_in_dim(vec, i, axis=-1, keepdims=False) for i in range(3)
        ]
        return jax.vmap(sh)(x, y, z)

    if jit:
        spherical_harmonics = jax.jit(spherical_harmonics)
    return spherical_harmonics


@partial(jax.jit, static_argnums=1)
def spherical_to_cartesian_tensor(Q, lmax):
    r"""Convert spherical multipole components to Cartesian form.

    Parameters
    ----------
    Q : array
        Spherical components along the last axis:
        - index 0 is the scalar ``q``;
        - indices 1-3 are the vector ``mu``;
        - for ``lmax == 2``, indices 4-8 are the degree-2 components
          ``Q22s, Q21s, Q20, Q21c, Q22c``.
    lmax : int
        Maximum degree to convert (static under ``jax.jit``). Only 0, 1 and 2
        are supported.

    Returns
    -------
    array
        Depending on ``lmax``:
        - ``lmax == 0``: ``[q]``;
        - ``lmax == 1``: ``[q, mu]`` (4 components), with ``mu`` passed
          through unchanged;
        - ``lmax == 2``: the 10 components
          ``[q, mu, Txx, Tyy, Tzz, Txy, Txz, Tyz]`` of a traceless tensor
          (``Txx + Tyy + Tzz == 0``), where ``Tyy = Q20``.

    Raises
    ------
    ValueError
        If ``lmax`` is negative.
    NotImplementedError
        If ``lmax > 2``.
    """
    if lmax < 0:
        raise ValueError(f"lmax must be non-negative, got {lmax}")
    if lmax > 2:
        raise NotImplementedError(
            f"spherical_to_cartesian_tensor only supports lmax <= 2, got {lmax}"
        )

    q = Q[..., 0]
    if lmax == 0:
        return q[..., None]

    mu = Q[..., 1:4]
    if lmax == 1:
        return jnp.concatenate([q[..., None], mu], axis=-1)

    Q22s = Q[..., 4]
    Q21s = Q[..., 5]
    Q20 = Q[..., 6]
    Q21c = Q[..., 7]
    Q22c = Q[..., 8]
    s3_2 = 0.5 * 3**0.5  # sqrt(3) / 2
    Tzz = -0.5 * Q20 + s3_2 * Q22c
    Txx = -0.5 * Q20 - s3_2 * Q22c
    Tyy = Q20
    Txz = s3_2 * Q22s
    Tyz = s3_2 * Q21c
    Txy = s3_2 * Q21s

    return jnp.concatenate(
        [q[..., None], mu] + [T[..., None] for T in (Txx, Tyy, Tzz, Txy, Txz, Tyz)],
        axis=-1,
    )
def
CG_SU2(j1: int, j2: int, j3: int) -> numpy.ndarray:
def CG_SU2(j1: int, j2: int, j3: int) -> np.ndarray:
    r"""Clebsch-Gordan coefficients for the direct product of two irreducible representations of :math:`SU(2)`

    Parameters
    ----------
    j1, j2, j3 : int
        Degrees of the two factors and of the coupled representation.

    Returns
    -------
    `np.ndarray`
        tensor :math:`C` of shape :math:`(2j_1+1, 2j_2+1, 2j_3+1)` with
        ``C[m1 + j1, m2 + j2, m3 + j3]`` equal to sympy's
        ``clebsch_gordan(j1, j2, j3, m1, m2, m3)``.
    """
    C = np.zeros((2 * j1 + 1, 2 * j2 + 1, 2 * j3 + 1))
    for m1 in range(-j1, j1 + 1):
        for m2 in range(-j2, j2 + 1):
            for m3 in range(-j3, j3 + 1):
                C[m1 + j1, m2 + j2, m3 + j3] = float(
                    clebsch_gordan(j1, j2, j3, m1, m2, m3)
                )
    return C
Clebsch-Gordan coefficients for the direct product of two irreducible representations of \( SU(2) \)
Returns
np.ndarray
: tensor \( C \) of shape \( (2j_1+1, 2j_2+1, 2j_3+1) \)
def
change_basis_real_to_complex(l: int) -> numpy.ndarray:
def change_basis_real_to_complex(l: int) -> np.ndarray:
    r"""Change-of-basis matrix from real to complex spherical harmonics.

    See https://en.wikipedia.org/wiki/Spherical_harmonics#Real_form
    (adapted from ``e3nn.o3._wigner``).

    Parameters
    ----------
    l : int
        Degree of the representation.

    Returns
    -------
    `np.ndarray`
        complex matrix of shape :math:`(2l+1, 2l+1)`; row ``l + m``
        corresponds to order ``m``. The whole matrix is multiplied by
        :math:`(-i)^l` so that the Clebsch-Gordan coefficients derived from
        it are real.
    """
    q = np.zeros((2 * l + 1, 2 * l + 1), dtype=complex)
    for m in range(-l, 0):
        q[l + m, l + abs(m)] = 1 / 2**0.5
        q[l + m, l - abs(m)] = -1j / 2**0.5
    q[l, l] = 1
    for m in range(1, l + 1):
        q[l + m, l + abs(m)] = (-1) ** m / 2**0.5
        q[l + m, l - abs(m)] = 1j * (-1) ** m / 2**0.5

    # factor of (-i)**l to make the Clebsch-Gordan coefficients real
    return q * (-1j) ** l
Change-of-basis matrix from real to complex spherical harmonics (see https://en.wikipedia.org/wiki/Spherical_harmonics#Real_form); adapted from e3nn.o3._wigner.
def
CG_SO3(j1: int, j2: int, j3: int) -> numpy.ndarray:
def CG_SO3(j1: int, j2: int, j3: int) -> np.ndarray:
    r"""Clebsch-Gordan coefficients for the direct product of two irreducible representations of :math:`SO(3)`

    The :math:`SU(2)` coefficients from :func:`CG_SU2` are transformed with
    the real/complex change-of-basis matrix of each degree, the real part is
    kept, and the result is normalized to unit Frobenius norm.

    Parameters
    ----------
    j1, j2, j3 : int
        Degrees of the two factors and of the coupled representation.

    Returns
    -------
    `np.ndarray`
        real tensor :math:`C` of shape :math:`(2j_1+1, 2j_2+1, 2j_3+1)`
    """
    C = CG_SU2(j1, j2, j3)
    Q1 = change_basis_real_to_complex(j1)
    Q2 = change_basis_real_to_complex(j2)
    Q3 = change_basis_real_to_complex(j3)
    # C_real[j, l, m] = sum_{i,k,n} Q1[i, j] Q2[k, l] conj(Q3)[n, m] C[i, k, n]
    C = np.real(np.einsum("ij,kl,mn,ikn->jlm", Q1, Q2, np.conj(Q3.T), C))
    # NOTE: if every coefficient vanishes, this division produces NaN.
    return C / np.linalg.norm(C)
Clebsch-Gordan coefficients for the direct product of two irreducible representations of \( SO(3) \)
Returns
np.ndarray
: tensor \( C \) of shape \( (2j_1+1, 2j_2+1, 2j_3+1) \)
def
generate_spherical_harmonics(lmax, normalize=False, print_code=False, jit=False, vmapped=False):
def generate_spherical_harmonics(
    lmax, normalize=False, print_code=False, jit=False, vmapped=False
):  # pragma: no cover
    r"""Return a function that computes real spherical harmonics up to ``lmax``
    (adapted from e3nn).

    The polynomials of each degree are built symbolically with sympy. Degree
    ``l + 1`` is obtained by coupling degree ``l`` with ``(x, y, z)`` through
    :func:`CG_SO3`. Each degree is scaled so that its components have a sum of
    squares equal to ``2l + 1`` at the point ``(0, 1, 0)``. The polynomials
    are then emitted as Python source, which is executed to create the
    function.

    Parameters
    ----------
    lmax : int
        Maximum degree (must be >= 0).
    normalize : bool
        If True, input vectors are first divided by their Euclidean norm.
    print_code : bool
        If True, print the generated source code.
    jit : bool
        If True, wrap the returned function with ``jax.jit``.
    vmapped : bool
        If True, the generated kernel works on scalar ``x, y, z`` and is
        mapped with ``jax.vmap`` over the leading axis of the input
        (presumably an array of shape ``(N, 3)`` -- TODO confirm).

    Returns
    -------
    callable
        ``f(vec)`` returning, along the last axis, the ``(lmax + 1) ** 2``
        components ordered ``sh_0_0, sh_1_0, sh_1_1, sh_1_2, sh_2_0, ...``.

    Raises
    ------
    ValueError
        If ``lmax`` is negative.
    """
    if lmax < 0:
        raise ValueError(f"lmax must be non-negative, got {lmax}")

    def to_frac(x: float):
        # Recover an exact +/- sqrt(p/q) sympy expression from a float
        # coefficient by rationalizing its square.
        from fractions import Fraction

        s = 1 if x >= 0 else -1
        x = x**2
        x = Fraction(x).limit_denominator()
        x = s * sympy.sqrt(x)
        x = sympy.simplify(x)
        return x

    if vmapped:
        fn_str = "def spherical_harmonics_(x,y,z):\n"
        fn_str += "    sh_0_0 = 1.\n"
    else:
        fn_str = "def spherical_harmonics_(vec):\n"
        if normalize:
            fn_str += "    vec = vec/jnp.linalg.norm(vec,axis=-1,keepdims=True)\n"
        fn_str += "    x,y,z = [jax.lax.index_in_dim(vec, i, axis=-1, keepdims=False) for i in range(3)]\n"
        fn_str += "    sh_0_0 = jnp.ones_like(x)\n"

    x_var, y_var, z_var = sympy.symbols("x y z")
    polynomials = [sympy.sqrt(3) * x_var, sympy.sqrt(3) * y_var, sympy.sqrt(3) * z_var]

    def sub_z1(p, names, polynormz):
        # Evaluate p at (x, y, z) = (0, 1, 0), replacing the previous degree's
        # symbols by their values at that same point.
        p = p.subs(x_var, 0).subs(y_var, 1).subs(z_var, 0)
        for n, c in zip(names, polynormz):
            p = p.subs(n, c)
        return p

    poly_evalz = [sub_z1(p, [], []) for p in polynomials]

    for l in range(1, lmax + 1):
        sh_variables = sympy.symbols(" ".join(f"sh_{l}_{m}" for m in range(2 * l + 1)))

        for n, p in zip(sh_variables, polynomials):
            fn_str += f"    {n} = {pycode(p.evalf())}\n"

        if l == lmax:
            break

        # Degree l+1 components: couple (x, y, z) with the degree-l symbols.
        polynomials = [
            sum(
                to_frac(c.item()) * v * sh
                for cj, v in zip(cij, [x_var, y_var, z_var])
                for c, sh in zip(cj, sh_variables)
            )
            for cij in CG_SO3(l + 1, 1, l)
        ]

        poly_evalz = [sub_z1(p, sh_variables, poly_evalz) for p in polynomials]
        norm = sympy.sqrt(sum(p**2 for p in poly_evalz))
        polynomials = [sympy.sqrt(2 * l + 3) * p / norm for p in polynomials]
        poly_evalz = [sympy.sqrt(2 * l + 3) * p / norm for p in poly_evalz]

        polynomials = [sympy.simplify(p, full=True) for p in polynomials]

    # Use lmax rather than the loop variable: for lmax == 0 the loop never
    # runs and `l` would be unbound.
    u = ",\n        ".join(
        ", ".join(f"sh_{j}_{m}" for m in range(2 * j + 1)) for j in range(lmax + 1)
    )
    if vmapped:
        fn_str += f"    return jnp.array([\n        {u}\n    ])\n"
    else:
        fn_str += f"    return jnp.stack([\n        {u}\n    ], axis=-1)\n"

    if print_code:
        print(fn_str)
    # Execute into an explicit namespace. Reading the new function back via
    # locals() after exec() is unreliable: on Python 3.13+ locals() in a
    # function returns a fresh snapshot that does not contain it.
    namespace = {"jax": jax, "jnp": jnp}
    exec(fn_str, namespace)
    sh = namespace["spherical_harmonics_"]
    if jit:
        sh = jax.jit(sh)
    if not vmapped:
        return sh

    def spherical_harmonics(vec):
        if normalize:
            vec = vec / jnp.linalg.norm(vec, axis=-1, keepdims=True)
        x, y, z = [
            jax.lax.index_in_dim(vec, i, axis=-1, keepdims=False) for i in range(3)
        ]
        return jax.vmap(sh)(x, y, z)

    if jit:
        spherical_harmonics = jax.jit(spherical_harmonics)
    return spherical_harmonics
Returns a function that computes the real spherical harmonics up to degree lmax (adapted from e3nn).
@partial(jax.jit, static_argnums=1)
def
spherical_to_cartesian_tensor(Q, lmax):
@partial(jax.jit, static_argnums=1)
def spherical_to_cartesian_tensor(Q, lmax):
    r"""Convert spherical multipole components to Cartesian form.

    Parameters
    ----------
    Q : array
        Spherical components along the last axis:
        - index 0 is the scalar ``q``;
        - indices 1-3 are the vector ``mu``;
        - for ``lmax == 2``, indices 4-8 are the degree-2 components
          ``Q22s, Q21s, Q20, Q21c, Q22c``.
    lmax : int
        Maximum degree to convert (static under ``jax.jit``). Only 0, 1 and 2
        are supported.

    Returns
    -------
    array
        Depending on ``lmax``:
        - ``lmax == 0``: ``[q]``;
        - ``lmax == 1``: ``[q, mu]`` (4 components), with ``mu`` passed
          through unchanged;
        - ``lmax == 2``: the 10 components
          ``[q, mu, Txx, Tyy, Tzz, Txy, Txz, Tyz]`` of a traceless tensor
          (``Txx + Tyy + Tzz == 0``), where ``Tyy = Q20``.

    Raises
    ------
    ValueError
        If ``lmax`` is negative.
    NotImplementedError
        If ``lmax > 2``.
    """
    if lmax < 0:
        raise ValueError(f"lmax must be non-negative, got {lmax}")
    if lmax > 2:
        raise NotImplementedError(
            f"spherical_to_cartesian_tensor only supports lmax <= 2, got {lmax}"
        )

    q = Q[..., 0]
    if lmax == 0:
        return q[..., None]

    mu = Q[..., 1:4]
    if lmax == 1:
        return jnp.concatenate([q[..., None], mu], axis=-1)

    Q22s = Q[..., 4]
    Q21s = Q[..., 5]
    Q20 = Q[..., 6]
    Q21c = Q[..., 7]
    Q22c = Q[..., 8]
    s3_2 = 0.5 * 3**0.5  # sqrt(3) / 2
    Tzz = -0.5 * Q20 + s3_2 * Q22c
    Txx = -0.5 * Q20 - s3_2 * Q22c
    Tyy = Q20
    Txz = s3_2 * Q22s
    Tyz = s3_2 * Q21c
    Txy = s3_2 * Q21s

    return jnp.concatenate(
        [q[..., None], mu] + [T[..., None] for T in (Txx, Tyy, Tzz, Txy, Txz, Tyz)],
        axis=-1,
    )