fennol.models.embeddings.allegro
import jax
import jax.numpy as jnp
import flax.linen as nn
from ...utils.spherical_harmonics import generate_spherical_harmonics
from ..misc.encodings import SpeciesEncoding, RadialBasis
import dataclasses
import numpy as np
from typing import Any, Dict, List, Union, Callable, Tuple, Sequence, Optional, ClassVar
from ..misc.nets import FullyConnectedNet
from ..misc.e3 import (
    FilteredTensorProduct,
    ChannelMixingE3,
    ChannelMixing,
    E3NN_AVAILABLE,
    E3NN_EXCEPTION,
    Irreps,
)


class AllegroEmbedding(nn.Module):
    """Allegro equivariant pair embedding

    FID : ALLEGRO

    ### Reference
    Musaelian, A., Batzner, S., Johansson, A. et al. Learning local equivariant representations for large-scale atomistic dynamics.
    Nat Commun 14, 579 (2023). https://doi.org/10.1038/s41467-023-36329-y

    """

    _graphs_properties: Dict
    dim: int = 128
    """ The dimension of the embedding."""
    nchannels: int = 16
    """ The number of equivariant channels."""
    nlayers: int = 3
    """ The number of interaction layers."""
    lmax: int = 2
    """ The maximum degree of tensorial embedding."""
    lmax_density: Optional[int] = None
    """ The maximum degree of spherical harmonics for density.
    If None, it will be set to lmax. Must be greater than or equal to lmax."""
    twobody_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
    """ The number of hidden neurons in the two-body network."""
    embedding_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [])
    """ The number of hidden neurons in the embedding network."""
    latent_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
    """ The number of hidden neurons in the latent network."""
    activation: Union[Callable, str] = "silu"
    """ The activation function to use (in the two-body, embedding and latent networks)."""
    graph_key: str = "graph"
    """ The key in the input dictionary that corresponds to the graph."""
    embedding_key: Optional[str] = "embedding"
    """ The key to use for the output embedding in the returned dictionary.
    If None, the module returns the tuple (xij, Vij) instead of a dictionary."""
    tensor_embedding_key: str = "tensor_embedding"
    """ The key to use for the output tensor embedding in the returned dictionary."""
    species_encoding: dict = dataclasses.field(default_factory=dict)
    """ The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`"""
    radial_basis: dict = dataclasses.field(default_factory=dict)
    """ The radial basis parameters. See `fennol.models.misc.encodings.RadialBasis`"""

    FID: ClassVar[str] = "ALLEGRO"

    @nn.compact
    def __call__(self, inputs):
        """Forward pass of the Allegro model.

        Args:
            inputs: dictionary holding at least ``"species"`` and the graph
                stored under ``self.graph_key`` (read entries: ``edge_src``,
                ``edge_dst``, ``switch``, ``distances`` and ``vec``).

        Returns:
            A copy of ``inputs`` augmented with the scalar pair embedding under
            ``self.embedding_key`` and the equivariant pair embedding under
            ``self.tensor_embedding_key``; or the tuple ``(xij, Vij)`` when
            ``self.embedding_key`` is None.

        Raises:
            ValueError: if the effective ``lmax_density`` is smaller than ``lmax``.
        """
        species = inputs["species"]

        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        # per-edge cutoff weights, broadcast over the feature axis
        switch = graph["switch"][:, None]
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        radial_basis = RadialBasis(
            **{**self.radial_basis, "end": cutoff, "name": "RadialBasis"}
        )(graph["distances"])

        species_encoding = SpeciesEncoding(
            **self.species_encoding, name="SpeciesEncoding"
        )(species)

        # initial scalar pair features built from both species and the distance
        xij = (
            FullyConnectedNet(
                neurons=[*self.twobody_hidden, self.dim], activation=self.activation
            )(
                jnp.concatenate(
                    [
                        species_encoding[edge_src],
                        species_encoding[edge_dst],
                        radial_basis,
                    ],
                    axis=-1,
                )
            )
            * switch
        )

        lmax_density = self.lmax_density if self.lmax_density is not None else self.lmax
        if lmax_density < self.lmax:
            # explicit exception rather than `assert`, so the check survives `python -O`
            raise ValueError(
                f"lmax_density ({lmax_density}) must be greater than or equal to "
                f"lmax ({self.lmax})"
            )

        Yij = generate_spherical_harmonics(lmax=lmax_density, normalize=False)(
            graph["vec"] / graph["distances"][:, None]
        )[:, None, :]

        # number of spherical-harmonic components of degree <= lmax
        nel = (self.lmax + 1) ** 2
        Vij = (
            ChannelMixingE3(self.lmax, 1, self.nchannels)(Yij[..., :nel])
            * nn.Dense(self.nchannels, use_bias=False)(xij)[:, :, None]
        )

        for _ in range(self.nlayers):
            rhoij = (
                FullyConnectedNet(
                    neurons=[*self.embedding_hidden, self.nchannels],
                    activation=self.activation,
                )(xij)
                * switch
            )[:, :, None] * Yij
            # sum the edge contributions onto their source atom; keep rhoij's dtype
            # instead of the default float dtype to avoid silent precision changes
            density = (
                jnp.zeros((species.shape[0], *rhoij.shape[1:]), dtype=rhoij.dtype)
                .at[edge_src]
                .add(rhoij)
            )

            Lij = FilteredTensorProduct(self.lmax, lmax_density)(Vij, density[edge_src])
            # component 0 of the last axis is the degree-0 (scalar) part
            scals = jax.lax.index_in_dim(Lij, 0, axis=-1, keepdims=False)
            # use the configured activation, consistent with the other networks
            lij = FullyConnectedNet(
                neurons=[*self.latent_hidden, self.dim], activation=self.activation
            )(jnp.concatenate((xij, scals), axis=-1))

            xij = xij + lij * switch
            Vij = ChannelMixing(self.lmax, self.nchannels, self.nchannels)(Lij)

        if self.embedding_key is None:
            return xij, Vij
        return {**inputs, self.embedding_key: xij, self.tensor_embedding_key: Vij}
if E3NN_AVAILABLE:
    import e3nn_jax as e3nn

    class AllegroE3NNEmbedding(nn.Module):
        """Allegro equivariant pair embedding

        FID : ALLEGRO_E3NN

        In this version, equivariant operations use the e3nn library.

        Reference
        ---------
        Musaelian, A., Batzner, S., Johansson, A. et al. Learning local equivariant representations for large-scale atomistic dynamics.
        Nat Commun 14, 579 (2023). https://doi.org/10.1038/s41467-023-36329-y

        """

        _graphs_properties: Dict
        dim: int = 128
        """ The dimension of the embedding."""
        nchannels: int = 16
        """ The number of equivariant channels."""
        nlayers: int = 3
        """ The number of interaction layers."""
        irreps_Vij: Union[str, int, 'Irreps'] = 2
        """ Irreps used for the tensor embedding.
        An int is expanded to the spherical-harmonics irreps up to that degree,
        a str is parsed with e3nn.Irreps, any other value is used as-is."""
        lmax_density: Optional[int] = None
        """ The maximum degree of spherical harmonics for density.
        If None, it will be set to lmax (the maximum degree of irreps_Vij).
        Must be greater than or equal to lmax."""
        twobody_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
        """ The number of hidden neurons in the two-body network."""
        embedding_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [])
        """ The number of hidden neurons in the embedding network."""
        latent_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
        """ The number of hidden neurons in the latent network."""
        activation: Union[Callable, str] = "silu"
        """ The activation function to use (in the two-body, embedding and latent networks)."""
        graph_key: str = "graph"
        """ The key in the input dictionary that corresponds to the graph."""
        embedding_key: Optional[str] = "embedding"
        """ The key to use for the output embedding in the returned dictionary.
        If None, the module returns the tuple (xij, Vij) instead of a dictionary."""
        tensor_embedding_key: str = "tensor_embedding"
        """ The key to use for the output tensor embedding in the returned dictionary."""
        species_encoding: dict = dataclasses.field(default_factory=dict)
        """ The species encoding parameters."""
        radial_basis: dict = dataclasses.field(default_factory=dict)
        """ The radial basis parameters."""

        FID: ClassVar[str] = "ALLEGRO_E3NN"
        """ Identification of the module when building a model."""

        @nn.compact
        def __call__(self, inputs):
            """Forward pass of the Allegro model (e3nn implementation).

            Args:
                inputs: dictionary holding at least ``"species"`` (a 1D array) and
                    the graph stored under ``self.graph_key`` (read entries:
                    ``edge_src``, ``edge_dst``, ``switch``, ``distances``, ``vec``).

            Returns:
                A copy of ``inputs`` augmented with the scalar pair embedding under
                ``self.embedding_key`` and the equivariant pair embedding (an
                ``e3nn.IrrepsArray``) under ``self.tensor_embedding_key``; or the
                tuple ``(xij, Vij)`` when ``self.embedding_key`` is None.

            Raises:
                ValueError: if ``species`` is not 1D, or if the effective
                    ``lmax_density`` is smaller than the maximum degree of
                    ``irreps_Vij``.
            """
            species = inputs["species"]
            if len(species.shape) != 1:
                raise ValueError("Species must be a 1D array (batches must be flattened)")

            graph = inputs[self.graph_key]
            edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
            # per-edge cutoff weights, broadcast over the feature axis
            switch = graph["switch"][:, None]
            cutoff = self._graphs_properties[self.graph_key]["cutoff"]
            radial_basis = RadialBasis(
                **{**self.radial_basis, "end": cutoff, "name": "RadialBasis"}
            )(graph["distances"])

            species_encoding = SpeciesEncoding(
                **self.species_encoding, name="SpeciesEncoding"
            )(species)

            # initial scalar pair features built from both species and the distance
            xij = (
                FullyConnectedNet(
                    neurons=[*self.twobody_hidden, self.dim], activation=self.activation
                )(
                    jnp.concatenate(
                        [
                            species_encoding[edge_src],
                            species_encoding[edge_dst],
                            radial_basis,
                        ],
                        axis=-1,
                    )
                )
                * switch
            )

            if isinstance(self.irreps_Vij, int):
                irreps_Vij = e3nn.Irreps.spherical_harmonics(self.irreps_Vij)
            elif isinstance(self.irreps_Vij, str):
                irreps_Vij = e3nn.Irreps(self.irreps_Vij)
            else:
                irreps_Vij = self.irreps_Vij
            lmax = max(irreps_Vij.ls)
            # `is not None` rather than `or`: an explicit 0 must not be silently replaced
            lmax_density = self.lmax_density if self.lmax_density is not None else lmax
            if lmax_density < lmax:
                raise ValueError(
                    f"lmax_density ({lmax_density}) must be greater than or equal to "
                    f"the maximum degree of irreps_Vij ({lmax})"
                )
            irreps_density = e3nn.Irreps.spherical_harmonics(lmax_density)

            Yij = e3nn.spherical_harmonics(
                irreps_density, graph["vec"], normalize=True
            )[:, None, :]

            Vij = (
                e3nn.flax.Linear(irreps_Vij, channel_out=self.nchannels)(Yij)
                * nn.Dense(self.nchannels, use_bias=False)(xij)[:, :, None]
            )

            for _ in range(self.nlayers):
                rhoij = (
                    FullyConnectedNet(
                        neurons=[*self.embedding_hidden, self.nchannels],
                        activation=self.activation,
                    )(xij)
                    * switch
                )[:, :, None] * Yij
                # sum the edge contributions onto their source atom
                density = e3nn.scatter_sum(
                    rhoij, dst=edge_src, output_size=species_encoding.shape[0]
                )

                Lij = e3nn.tensor_product(
                    Vij, density[edge_src], filter_ir_out=irreps_Vij
                )
                # flatten the invariant (0e) part of every channel into scalar features
                scals = Lij.filter(["0e"]).array.reshape(Lij.shape[0], -1)
                # use the configured activation, consistent with the other networks
                lij = FullyConnectedNet(
                    neurons=[*self.latent_hidden, self.dim], activation=self.activation
                )(jnp.concatenate((xij, scals), axis=-1))

                xij = xij + lij * switch
                # filtering
                Lij = e3nn.flax.Linear(irreps_Vij)(Lij)
                # channel mixing
                Vij = e3nn.flax.Linear(irreps_Vij, channel_out=self.nchannels)(Lij)

            if self.embedding_key is None:
                return xij, Vij
            return {**inputs, self.embedding_key: xij, self.tensor_embedding_key: Vij}

else:

    class AllegroE3NNEmbedding(nn.Module):
        """Placeholder used when E3NN_AVAILABLE is False; calling it raises E3NN_EXCEPTION."""

        FID: ClassVar[str] = "ALLEGRO_E3NN"

        def __call__(self, *args, **kwargs) -> Any:
            raise E3NN_EXCEPTION
class AllegroEmbedding(nn.Module):
    """Allegro equivariant pair embedding

    FID : ALLEGRO

    ### Reference
    Musaelian, A., Batzner, S., Johansson, A. et al. Learning local equivariant representations for large-scale atomistic dynamics.
    Nat Commun 14, 579 (2023). https://doi.org/10.1038/s41467-023-36329-y

    """

    _graphs_properties: Dict
    dim: int = 128
    """ The dimension of the embedding."""
    nchannels: int = 16
    """ The number of equivariant channels."""
    nlayers: int = 3
    """ The number of interaction layers."""
    lmax: int = 2
    """ The maximum degree of tensorial embedding."""
    lmax_density: Optional[int] = None
    """ The maximum degree of spherical harmonics for density.
    If None, it will be set to lmax. Must be greater than or equal to lmax."""
    twobody_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
    """ The number of hidden neurons in the two-body network."""
    embedding_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [])
    """ The number of hidden neurons in the embedding network."""
    latent_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
    """ The number of hidden neurons in the latent network."""
    activation: Union[Callable, str] = "silu"
    """ The activation function to use (in the two-body, embedding and latent networks)."""
    graph_key: str = "graph"
    """ The key in the input dictionary that corresponds to the graph."""
    embedding_key: Optional[str] = "embedding"
    """ The key to use for the output embedding in the returned dictionary.
    If None, the module returns the tuple (xij, Vij) instead of a dictionary."""
    tensor_embedding_key: str = "tensor_embedding"
    """ The key to use for the output tensor embedding in the returned dictionary."""
    species_encoding: dict = dataclasses.field(default_factory=dict)
    """ The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`"""
    radial_basis: dict = dataclasses.field(default_factory=dict)
    """ The radial basis parameters. See `fennol.models.misc.encodings.RadialBasis`"""

    FID: ClassVar[str] = "ALLEGRO"

    @nn.compact
    def __call__(self, inputs):
        """Forward pass of the Allegro model.

        Args:
            inputs: dictionary holding at least ``"species"`` and the graph
                stored under ``self.graph_key`` (read entries: ``edge_src``,
                ``edge_dst``, ``switch``, ``distances`` and ``vec``).

        Returns:
            A copy of ``inputs`` augmented with the scalar pair embedding under
            ``self.embedding_key`` and the equivariant pair embedding under
            ``self.tensor_embedding_key``; or the tuple ``(xij, Vij)`` when
            ``self.embedding_key`` is None.

        Raises:
            ValueError: if the effective ``lmax_density`` is smaller than ``lmax``.
        """
        species = inputs["species"]

        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        # per-edge cutoff weights, broadcast over the feature axis
        switch = graph["switch"][:, None]
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        radial_basis = RadialBasis(
            **{**self.radial_basis, "end": cutoff, "name": "RadialBasis"}
        )(graph["distances"])

        species_encoding = SpeciesEncoding(
            **self.species_encoding, name="SpeciesEncoding"
        )(species)

        # initial scalar pair features built from both species and the distance
        xij = (
            FullyConnectedNet(
                neurons=[*self.twobody_hidden, self.dim], activation=self.activation
            )(
                jnp.concatenate(
                    [
                        species_encoding[edge_src],
                        species_encoding[edge_dst],
                        radial_basis,
                    ],
                    axis=-1,
                )
            )
            * switch
        )

        lmax_density = self.lmax_density if self.lmax_density is not None else self.lmax
        if lmax_density < self.lmax:
            # explicit exception rather than `assert`, so the check survives `python -O`
            raise ValueError(
                f"lmax_density ({lmax_density}) must be greater than or equal to "
                f"lmax ({self.lmax})"
            )

        Yij = generate_spherical_harmonics(lmax=lmax_density, normalize=False)(
            graph["vec"] / graph["distances"][:, None]
        )[:, None, :]

        # number of spherical-harmonic components of degree <= lmax
        nel = (self.lmax + 1) ** 2
        Vij = (
            ChannelMixingE3(self.lmax, 1, self.nchannels)(Yij[..., :nel])
            * nn.Dense(self.nchannels, use_bias=False)(xij)[:, :, None]
        )

        for _ in range(self.nlayers):
            rhoij = (
                FullyConnectedNet(
                    neurons=[*self.embedding_hidden, self.nchannels],
                    activation=self.activation,
                )(xij)
                * switch
            )[:, :, None] * Yij
            # sum the edge contributions onto their source atom; keep rhoij's dtype
            # instead of the default float dtype to avoid silent precision changes
            density = (
                jnp.zeros((species.shape[0], *rhoij.shape[1:]), dtype=rhoij.dtype)
                .at[edge_src]
                .add(rhoij)
            )

            Lij = FilteredTensorProduct(self.lmax, lmax_density)(Vij, density[edge_src])
            # component 0 of the last axis is the degree-0 (scalar) part
            scals = jax.lax.index_in_dim(Lij, 0, axis=-1, keepdims=False)
            # use the configured activation, consistent with the other networks
            lij = FullyConnectedNet(
                neurons=[*self.latent_hidden, self.dim], activation=self.activation
            )(jnp.concatenate((xij, scals), axis=-1))

            xij = xij + lij * switch
            Vij = ChannelMixing(self.lmax, self.nchannels, self.nchannels)(Lij)

        if self.embedding_key is None:
            return xij, Vij
        return {**inputs, self.embedding_key: xij, self.tensor_embedding_key: Vij}
Allegro equivariant pair embedding
FID : ALLEGRO
Reference
Musaelian, A., Batzner, S., Johansson, A. et al. Learning local equivariant representations for large-scale atomistic dynamics. Nat Commun 14, 579 (2023). https://doi.org/10.1038/s41467-023-36329-y
The maximum degree of spherical harmonics for density. If None, it will be set to lmax. Must be greater than or equal to lmax.
The key to use for the output embedding in the returned dictionary.
The key to use for the output tensor embedding in the returned dictionary.
The species encoding parameters. See fennol.models.misc.encodings.SpeciesEncoding
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying Python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator to ensure that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.
class AllegroE3NNEmbedding(nn.Module):
    """Allegro equivariant pair embedding

    FID : ALLEGRO_E3NN

    In this version, equivariant operations use the e3nn library.

    Reference
    ---------
    Musaelian, A., Batzner, S., Johansson, A. et al. Learning local equivariant representations for large-scale atomistic dynamics.
    Nat Commun 14, 579 (2023). https://doi.org/10.1038/s41467-023-36329-y

    """

    _graphs_properties: Dict
    dim: int = 128
    """ The dimension of the embedding."""
    nchannels: int = 16
    """ The number of equivariant channels."""
    nlayers: int = 3
    """ The number of interaction layers."""
    irreps_Vij: Union[str, int, 'Irreps'] = 2
    """ Irreps used for the tensor embedding.
    An int is expanded to the spherical-harmonics irreps up to that degree,
    a str is parsed with e3nn.Irreps, any other value is used as-is."""
    lmax_density: Optional[int] = None
    """ The maximum degree of spherical harmonics for density.
    If None, it will be set to lmax (the maximum degree of irreps_Vij).
    Must be greater than or equal to lmax."""
    twobody_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
    """ The number of hidden neurons in the two-body network."""
    embedding_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [])
    """ The number of hidden neurons in the embedding network."""
    latent_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
    """ The number of hidden neurons in the latent network."""
    activation: Union[Callable, str] = "silu"
    """ The activation function to use (in the two-body, embedding and latent networks)."""
    graph_key: str = "graph"
    """ The key in the input dictionary that corresponds to the graph."""
    embedding_key: Optional[str] = "embedding"
    """ The key to use for the output embedding in the returned dictionary.
    If None, the module returns the tuple (xij, Vij) instead of a dictionary."""
    tensor_embedding_key: str = "tensor_embedding"
    """ The key to use for the output tensor embedding in the returned dictionary."""
    species_encoding: dict = dataclasses.field(default_factory=dict)
    """ The species encoding parameters."""
    radial_basis: dict = dataclasses.field(default_factory=dict)
    """ The radial basis parameters."""

    FID: ClassVar[str] = "ALLEGRO_E3NN"
    """ Identification of the module when building a model."""

    @nn.compact
    def __call__(self, inputs):
        """Forward pass of the Allegro model (e3nn implementation).

        Args:
            inputs: dictionary holding at least ``"species"`` (a 1D array) and
                the graph stored under ``self.graph_key`` (read entries:
                ``edge_src``, ``edge_dst``, ``switch``, ``distances``, ``vec``).

        Returns:
            A copy of ``inputs`` augmented with the scalar pair embedding under
            ``self.embedding_key`` and the equivariant pair embedding (an
            ``e3nn.IrrepsArray``) under ``self.tensor_embedding_key``; or the
            tuple ``(xij, Vij)`` when ``self.embedding_key`` is None.

        Raises:
            ValueError: if ``species`` is not 1D, or if the effective
                ``lmax_density`` is smaller than the maximum degree of
                ``irreps_Vij``.
        """
        species = inputs["species"]
        if len(species.shape) != 1:
            raise ValueError("Species must be a 1D array (batches must be flattened)")

        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        # per-edge cutoff weights, broadcast over the feature axis
        switch = graph["switch"][:, None]
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        radial_basis = RadialBasis(
            **{**self.radial_basis, "end": cutoff, "name": "RadialBasis"}
        )(graph["distances"])

        species_encoding = SpeciesEncoding(
            **self.species_encoding, name="SpeciesEncoding"
        )(species)

        # initial scalar pair features built from both species and the distance
        xij = (
            FullyConnectedNet(
                neurons=[*self.twobody_hidden, self.dim], activation=self.activation
            )(
                jnp.concatenate(
                    [
                        species_encoding[edge_src],
                        species_encoding[edge_dst],
                        radial_basis,
                    ],
                    axis=-1,
                )
            )
            * switch
        )

        if isinstance(self.irreps_Vij, int):
            irreps_Vij = e3nn.Irreps.spherical_harmonics(self.irreps_Vij)
        elif isinstance(self.irreps_Vij, str):
            irreps_Vij = e3nn.Irreps(self.irreps_Vij)
        else:
            irreps_Vij = self.irreps_Vij
        lmax = max(irreps_Vij.ls)
        # `is not None` rather than `or`: an explicit 0 must not be silently replaced
        lmax_density = self.lmax_density if self.lmax_density is not None else lmax
        if lmax_density < lmax:
            raise ValueError(
                f"lmax_density ({lmax_density}) must be greater than or equal to "
                f"the maximum degree of irreps_Vij ({lmax})"
            )
        irreps_density = e3nn.Irreps.spherical_harmonics(lmax_density)

        Yij = e3nn.spherical_harmonics(
            irreps_density, graph["vec"], normalize=True
        )[:, None, :]

        Vij = (
            e3nn.flax.Linear(irreps_Vij, channel_out=self.nchannels)(Yij)
            * nn.Dense(self.nchannels, use_bias=False)(xij)[:, :, None]
        )

        for _ in range(self.nlayers):
            rhoij = (
                FullyConnectedNet(
                    neurons=[*self.embedding_hidden, self.nchannels],
                    activation=self.activation,
                )(xij)
                * switch
            )[:, :, None] * Yij
            # sum the edge contributions onto their source atom
            density = e3nn.scatter_sum(
                rhoij, dst=edge_src, output_size=species_encoding.shape[0]
            )

            Lij = e3nn.tensor_product(
                Vij, density[edge_src], filter_ir_out=irreps_Vij
            )
            # flatten the invariant (0e) part of every channel into scalar features
            scals = Lij.filter(["0e"]).array.reshape(Lij.shape[0], -1)
            # use the configured activation, consistent with the other networks
            lij = FullyConnectedNet(
                neurons=[*self.latent_hidden, self.dim], activation=self.activation
            )(jnp.concatenate((xij, scals), axis=-1))

            xij = xij + lij * switch
            # filtering
            Lij = e3nn.flax.Linear(irreps_Vij)(Lij)
            # channel mixing
            Vij = e3nn.flax.Linear(irreps_Vij, channel_out=self.nchannels)(Lij)

        if self.embedding_key is None:
            return xij, Vij
        return {**inputs, self.embedding_key: xij, self.tensor_embedding_key: Vij}
Allegro equivariant pair embedding
FID : ALLEGRO_E3NN
In this version, equivariant operations use the e3nn library.
Reference
Musaelian, A., Batzner, S., Johansson, A. et al. Learning local equivariant representations for large-scale atomistic dynamics. Nat Commun 14, 579 (2023). https://doi.org/10.1038/s41467-023-36329-y
The maximum degree of spherical harmonics for density. If None, it will be set to lmax. Must be greater than or equal to lmax.
The key to use for the output embedding in the returned dictionary.
The key to use for the output tensor embedding in the returned dictionary.
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying Python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator to ensure that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.