fennol.models.embeddings.spookynet
import jax
import jax.numpy as jnp
import flax.linen as nn
from ...utils.spherical_harmonics import generate_spherical_harmonics, CG_SO3
from ..misc.encodings import SpeciesEncoding, RadialBasis
import dataclasses
import numpy as np
from typing import Dict, Union, Callable, Sequence, Optional, ClassVar
from ...utils.initializers import initializer_from_str, scaled_orthogonal
from ..misc.nets import ResMLP


class SpookyNetEmbedding(nn.Module):
    """SpookyNet equivariant message-passing embedding with electronic encodings (charge + spin).

    FID : SPOOKYNET

    ### Reference
    Unke, O.T., Chmiela, S., Gastegger, M. et al. SpookyNet: Learning force fields with electronic degrees of freedom and nonlocal effects. Nat Commun 12, 7273 (2021). https://doi.org/10.1038/s41467-021-27504-0

    ### Warning
    Non-local attention interaction is not yet implemented!

    """

    _graphs_properties: Dict
    dim: int = 128
    """ The dimension of the embedding."""
    nlayers: int = 3
    """ The number of interaction layers."""
    graph_key: str = "graph"
    """ The key for the graph input."""
    embedding_key: str = "embedding"
    """ The key for the embedding output."""
    species_encoding: dict = dataclasses.field(
        default_factory=lambda: {"encoding": "electronic_structure"}
    )
    """ The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`. """
    radial_basis: dict = dataclasses.field(default_factory=lambda: {"basis": "spooky"})
    """ The radial basis parameters. See `fennol.models.misc.encodings.RadialBasis`. """
    kernel_init: Union[Callable, str] = "scaled_orthogonal(scale=1.0, mode='fan_avg')"
    """ The kernel initializer for Dense operations."""
    use_spin_encoding: bool = True
    """ Whether to use spin encoding."""
    use_charge_encoding: bool = True
    """ Whether to use charge encoding."""
    total_charge_key: str = "total_charge"
    """ The key for the total charge input."""

    FID: ClassVar[str] = "SPOOKYNET"

    @nn.compact
    def __call__(self, inputs):
        species = inputs["species"]
        assert (
            len(species.shape) == 1
        ), "Species must be a 1D array (batches must be flattened)"

        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]

        cutoff = self._graphs_properties[self.graph_key]["cutoff"]

        kernel_init = (
            initializer_from_str(self.kernel_init)
            if isinstance(self.kernel_init, str)
            else self.kernel_init
        )

        onehot = SpeciesEncoding(**self.species_encoding, name="SpeciesEncoding")(
            species
        )
        zrand = SpeciesEncoding(
            encoding="random", dim=self.dim, name="RandSpeciesEncoding"
        )(species)

        eZ = (
            nn.Dense(
                self.dim,
                name="species_linear",
                use_bias=False,
                kernel_init=nn.initializers.zeros,
            )(onehot)
            + zrand
        )
        xi = eZ

        # encode charge information: attention-style weights distribute the
        # total charge of each system over its atoms
        batch_index = inputs["batch_index"]
        natoms = inputs["natoms"]
        if self.use_charge_encoding and (
            self.total_charge_key in inputs or self.is_initializing()
        ):
            Q = inputs.get(
                self.total_charge_key, jnp.zeros(natoms.shape[0], dtype=xi.dtype)
            )
            kq_pos, kq_neg, vq_pos, vq_neg = self.param(
                "kv_charge",
                lambda key, shape: jax.random.normal(key, shape, dtype=xi.dtype),
                (4, self.dim),
            )
            qi = nn.Dense(self.dim, kernel_init=kernel_init, name="q_linear")(eZ)
            # select key/value vectors depending on the sign of the total charge
            pos_mask = Q >= 0
            kq = jnp.where(pos_mask[:, None], kq_pos[None, :], kq_neg[None, :])
            vq = jnp.where(pos_mask[:, None], vq_pos[None, :], vq_neg[None, :])
            qik = (qi * kq[batch_index]).sum(axis=-1) / self.dim**0.5
            wi = jax.nn.softplus(qik)
            wnorm = jax.ops.segment_sum(wi, batch_index, Q.shape[0])
            avi = wi[:, None] * ((Q / wnorm)[:, None] * vq)[batch_index]
            eQ = ResMLP(use_bias=False, name="eQ", kernel_init=kernel_init)(avi)
            xi = xi + eQ

        # encode spin information (same attention-style mechanism as for charge)
        if self.use_spin_encoding and (
            "total_spin" in inputs or self.is_initializing()
        ):
            S = inputs.get("total_spin", jnp.zeros(natoms.shape[0], dtype=xi.dtype))
            ks, vs = self.param(
                "kv_spin",
                lambda key, shape: jax.random.normal(key, shape, dtype=xi.dtype),
                (2, self.dim),
            )
            si = nn.Dense(self.dim, kernel_init=kernel_init, name="s_linear")(eZ)
            sik = (si * ks[None, :]).sum(axis=-1) / self.dim**0.5
            wi = jax.nn.softplus(sik)
            wnorm = jax.ops.segment_sum(wi, batch_index, S.shape[0])
            avi = wi[:, None] * ((S / wnorm)[:, None] * vs[None, :])[batch_index]
            eS = ResMLP(use_bias=False, name="eS", kernel_init=kernel_init)(avi)
            xi = xi + eS

        distances = graph["distances"]
        switch = graph["switch"][:, None]
        dirij = graph["vec"] / distances[:, None]
        # real spherical harmonics up to l=2 evaluated on edge directions
        Yij = generate_spherical_harmonics(lmax=2, normalize=False)(dirij)[:, None, :]

        radial_basis = (
            RadialBasis(
                **{
                    **self.radial_basis,
                    "end": cutoff,
                    "name": "RadialBasis",
                }
            )(distances)
            * switch
        )

        # combine radial and angular features, then split into
        # s (l=0), p (l=1) and d (l=2) channels
        gij = radial_basis[:, :, None] * Yij

        gsij = gij[:, :, 0]
        gpij = jnp.transpose(gij[:, :, 1:4], (0, 2, 1))
        gdij = jnp.transpose(gij[:, :, 4:], (0, 2, 1))

        y = 0.0

        for layer in range(self.nlayers):
            xtilde = ResMLP(res_only=True, name=f"xtilde_{layer}")(xi)

            ### compute local update
            c = ResMLP(name=f"c_{layer}", kernel_init=kernel_init)(xtilde)
            sj = ResMLP(name=f"s_{layer}", kernel_init=kernel_init)(xtilde)
            pj = ResMLP(name=f"p_{layer}", kernel_init=kernel_init)(xtilde)
            dj = ResMLP(name=f"d_{layer}", kernel_init=kernel_init)(xtilde)

            Gs = nn.Dense(
                self.dim, use_bias=False, name=f"Gs_{layer}", kernel_init=kernel_init
            )(gsij)
            Gp = nn.Dense(
                self.dim, use_bias=False, name=f"Gp_{layer}", kernel_init=kernel_init
            )(gpij)
            Gd = nn.Dense(
                self.dim, use_bias=False, name=f"Gd_{layer}", kernel_init=kernel_init
            )(gdij)

            # aggregate neighbor features for each angular channel
            si = jax.ops.segment_sum(sj[edge_dst] * Gs, edge_src, xi.shape[0])
            pi = jax.ops.segment_sum(pj[edge_dst, None, :] * Gp, edge_src, xi.shape[0])
            di = jax.ops.segment_sum(dj[edge_dst, None, :] * Gd, edge_src, xi.shape[0])

            P1, P2 = jnp.split(
                nn.Dense(
                    2 * self.dim,
                    use_bias=False,
                    name=f"P12_{layer}",
                    kernel_init=kernel_init,
                )(pi),
                2,
                axis=-1,
            )
            D1, D2 = jnp.split(
                nn.Dense(
                    2 * self.dim,
                    use_bias=False,
                    name=f"D12_{layer}",
                    kernel_init=kernel_init,
                )(di),
                2,
                axis=-1,
            )

            # contract the equivariant p and d channels into invariant features
            P12 = (P1 * P2).sum(axis=1)
            D12 = (D1 * D2).sum(axis=1)

            l = ResMLP(name=f"l_{layer}", kernel_init=kernel_init)(c + si + P12 + D12)

            ### aggregate and update
            xi = ResMLP(name=f"xi_{layer}", kernel_init=kernel_init)(xtilde + l)
            y = y + ResMLP(name=f"y_{layer}", kernel_init=kernel_init)(xi)

        output = {
            **inputs,
            self.embedding_key: y,
        }
        if self.use_charge_encoding and self.total_charge_key in inputs:
            output[self.embedding_key + "_eQ"] = eQ
        if self.use_spin_encoding and "total_spin" in inputs:
            output[self.embedding_key + "_eS"] = eS
        return output
SpookyNet equivariant message-passing embedding with electronic encodings (charge + spin).
FID : SPOOKYNET
Reference
Unke, O.T., Chmiela, S., Gastegger, M. et al. SpookyNet: Learning force fields with electronic degrees of freedom and nonlocal effects. Nat Commun 12, 7273 (2021). https://doi.org/10.1038/s41467-021-27504-0
Warning
Non-local attention interaction is not yet implemented!
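To make the expected input format concrete, here is a minimal usage sketch. In practice the graph quantities (edge_src, edge_dst, vec, distances, switch) are produced by FeNNol's preprocessing; the hand-built two-atom dictionary below is illustrative only and exists just to show the expected keys and shapes from the source above.

import jax
import jax.numpy as jnp
from fennol.models.embeddings.spookynet import SpookyNetEmbedding

# two atoms (H, O) in a single system, with one edge in each direction
vec = jnp.array([[0.0, 0.0, 1.0], [0.0, 0.0, -1.0]])  # edge displacement vectors
distances = jnp.linalg.norm(vec, axis=-1)
inputs = {
    "species": jnp.array([1, 8]),       # atomic numbers, flattened over batches
    "batch_index": jnp.array([0, 0]),   # system index of each atom
    "natoms": jnp.array([2]),           # atoms per system
    "total_charge": jnp.array([0.0]),   # one total charge per system
    "total_spin": jnp.array([0.0]),     # one total spin per system
    "graph": {
        "edge_src": jnp.array([0, 1]),
        "edge_dst": jnp.array([1, 0]),
        "vec": vec,
        "distances": distances,
        "switch": jnp.ones_like(distances),  # toy cutoff switching function
    },
}

emb = SpookyNetEmbedding(_graphs_properties={"graph": {"cutoff": 5.0}})
params = emb.init(jax.random.PRNGKey(0), inputs)
out = emb.apply(params, inputs)
# expected (2, 128), assuming ResMLP preserves the feature dimension
print(out["embedding"].shape)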
Attributes (reconstructed from the field docstrings in the source above):

dim
    The dimension of the embedding.
nlayers
    The number of interaction layers.
graph_key
    The key for the graph input.
embedding_key
    The key for the embedding output.
species_encoding
    The species encoding parameters. See fennol.models.misc.encodings.SpeciesEncoding.
radial_basis
    The radial basis parameters. See fennol.models.misc.encodings.RadialBasis.
kernel_init
    The kernel initializer for Dense operations.
use_spin_encoding
    Whether to use spin encoding.
use_charge_encoding
    Whether to use charge encoding.
total_charge_key
    The key for the total charge input.
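These attributes are plain constructor arguments of the Flax module. A brief sketch of how they might be overridden, continuing from the previous example (the specific values are illustrative; valid encoding and basis names are defined in fennol.models.misc.encodings):

import flax.linen as nn
from fennol.models.embeddings.spookynet import SpookyNetEmbedding

emb = SpookyNetEmbedding(
    _graphs_properties={"graph": {"cutoff": 5.0}},
    dim=64,        # smaller embedding dimension
    nlayers=2,     # fewer interaction layers
    # kernel_init accepts either a string parsed by initializer_from_str ...
    kernel_init="scaled_orthogonal(scale=1.0, mode='fan_avg')",
    # ... or any flax-style initializer callable, e.g.:
    # kernel_init=nn.initializers.lecun_normal(),
    use_charge_encoding=True,
    use_spin_encoding=False,  # disable the spin encoding branch
)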