fennol.models.embeddings.minimace
import jax
import jax.numpy as jnp
import flax.linen as nn
from ...utils.spherical_harmonics import generate_spherical_harmonics
from ..misc.encodings import SpeciesEncoding, RadialBasis
import dataclasses
import numpy as np
from typing import Any, Dict, List, Union, Callable, Tuple, Sequence, Optional, ClassVar
from ..misc.nets import FullyConnectedNet
from ..misc.e3 import FilteredTensorProduct, ChannelMixingE3, ChannelMixing


class MiniMaceEmbedding(nn.Module):
    """Minimal MACE Embedding

    FID : MINIMACE

    This is a simplified version of the MACE embedding from the paper:
    Batatia et al., "MACE: Higher Order Equivariant Message Passing Neural Networks for Fast and Accurate Force Fields",
    https://doi.org/10.48550/arXiv.2206.07697

    It is designed to avoid the most costly operations (such as edge-wise tensor products)
    and to filter the results of each atomic tensor product to control the number of tensors.
    It may not match the accuracy of the full MACE embedding, but it should be faster.
    """

    _graphs_properties: Dict
    dim: int = 128
    """The dimension of the embedding."""
    nchannels: int = 16
    """The number of tensor channels."""
    message_dim: int = 16
    """The dimension of the message formed from the current embedding."""
    nlayers: int = 2
    """The number of interaction layers."""
    ntp: int = 2
    """The number of tensor products per layer."""
    lmax: int = 2
    """The maximum angular momentum of spherical tensors."""
    embedding_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [])
    """The hidden layers for the species embedding network."""
    latent_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
    """The hidden layers for the latent update network."""
    activation: Union[Callable, str] = "silu"
    """The activation function."""
    graph_key: str = "graph"
    """The key for the graph input."""
    embedding_key: str = "embedding"
    """The key for the embedding output."""
    tensor_embedding_key: str = "tensor_embedding"
    """The key for the tensor embedding output."""
    species_encoding: dict = dataclasses.field(default_factory=dict)
    """The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`."""
    radial_basis: dict = dataclasses.field(default_factory=dict)
    """The radial basis parameters. See `fennol.models.misc.encodings.RadialBasis`."""
    ignore_parity: bool = True
    """Whether to ignore the parity of irreps in the tensor products."""

    FID: ClassVar[str] = "MINIMACE"

    @nn.compact
    def __call__(self, inputs):
        species = inputs["species"]
        assert (
            len(species.shape) == 1
        ), "Species must be a 1D array (batches must be flattened)"

        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        switch = graph["switch"][:, None]
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        # radial basis expansion of edge distances, smoothly switched off at the cutoff
        radial_basis = (
            RadialBasis(**{**self.radial_basis, "end": cutoff, "name": "RadialBasis"})(
                graph["distances"]
            )
            * switch
        )

        # spherical harmonics of the normalized edge vectors, shape (nedges, 1, (lmax+1)**2)
        Yij = generate_spherical_harmonics(lmax=self.lmax, normalize=False)(
            graph["vec"] / graph["distances"][:, None]
        )[:, None, :]

        species_encoding = SpeciesEncoding(
            **self.species_encoding, name="SpeciesEncoding"
        )(species)

        # initial scalar embedding from the species encoding
        xi = FullyConnectedNet(
            neurons=[*self.embedding_hidden, self.dim],
            activation=self.activation,
            use_bias=True,
            name="species_embedding",
        )(species_encoding)

        nchannels_density = self.message_dim * radial_basis.shape[1]

        for layer in range(self.nlayers):
            mi = nn.Dense(
                self.message_dim,
                use_bias=True,
                name=f"species_linear_{layer}",
            )(xi)
            # edge weights: outer product of neighbor messages and radial basis
            xij = (mi[edge_dst, :, None] * radial_basis[:, None, :]).reshape(
                -1, nchannels_density
            )
            if layer == 0:
                # first layer: build the atomic density from spherical harmonics
                rhoij = xij[:, :, None] * Yij
                density = jax.ops.segment_sum(rhoij, edge_src, species.shape[0])
                Vi = ChannelMixingE3(
                    self.lmax,
                    nchannels_density,
                    self.nchannels,
                    name="Vi_initial",
                )(density)
            else:
                # later layers: propagate the current tensor embedding of the neighbors
                rhoi = ChannelMixingE3(
                    self.lmax,
                    self.nchannels,
                    nchannels_density,
                    name=f"rho_mixing_{layer}",
                )(Vi)
                rhoij = xij[:, :, None] * rhoi[edge_dst]
                density = density + jax.ops.segment_sum(
                    rhoij, edge_src, species.shape[0]
                )

            # collect scalar (l=0) invariants for the latent update
            scals = [jax.lax.index_in_dim(density, 0, axis=-1, keepdims=False)]
            for i in range(self.ntp):
                Hi = ChannelMixing(
                    self.lmax,
                    nchannels_density,
                    self.nchannels,
                    name=f"density_mixing_{layer}_{i}",
                )(density)
                # filtered atomic tensor product keeps the number of tensors fixed
                Li = FilteredTensorProduct(
                    self.lmax,
                    self.lmax,
                    self.lmax,
                    name=f"TP_{layer}_{i}",
                    ignore_parity=self.ignore_parity,
                )(Vi, Hi)
                scals.append(jax.lax.index_in_dim(Li, 0, axis=-1, keepdims=False))
                Vi = Vi + Li

            # residual update of the scalar embedding from the collected invariants
            dxi = FullyConnectedNet(
                [*self.latent_hidden, self.dim],
                activation=self.activation,
                use_bias=True,
                name=f"latent_net_{layer}",
            )(jnp.concatenate([xi, *scals], axis=-1))
            xi = xi + dxi

        if self.embedding_key is None:
            return xi, Vi
        return {**inputs, self.embedding_key: xi, self.tensor_embedding_key: Vi}
Minimal MACE Embedding
FID : MINIMACE
This is a simplified version of the MACE embedding from the paper: Batatia et al., "MACE: Higher Order Equivariant Message Passing Neural Networks for Fast and Accurate Force Fields", https://doi.org/10.48550/arXiv.2206.07697
It is designed to avoid the most costly operations (such as edge-wise tensor products) and to filter the results of each atomic tensor product to control the number of tensors. It may not match the accuracy of the full MACE embedding, but it should be faster.
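The main saving is visible in how neighbor information is accumulated: per-edge tensors are first summed into per-atom densities with a plain segment sum, and the tensor products are then taken atom-wise, so their cost scales with the number of atoms rather than the (usually much larger) number of edges. Below is a minimal, self-contained sketch of this accumulation pattern; the sizes nat, nedge, nchan and the random data are illustrative placeholders, not values from the source.

import jax

# Illustrative sizes: nrep = (lmax + 1)**2 = 9 spherical components at lmax = 2
nat, nedge, nchan, nrep = 4, 10, 8, 9
key = jax.random.PRNGKey(0)
edge_src = jax.random.randint(key, (nedge,), 0, nat)  # source atom of each edge
rhoij = jax.random.normal(key, (nedge, nchan, nrep))  # per-edge spherical tensors

# Cheap edge-wise sums build the atomic density...
density = jax.ops.segment_sum(rhoij, edge_src, num_segments=nat)
print(density.shape)  # (4, 8, 9)

# ...and the expensive tensor products are then applied per atom, as in
# FilteredTensorProduct(lmax, lmax, lmax)(Vi, Hi) in the listing above.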
The species encoding parameters. See fennol.models.misc.encodings.SpeciesEncoding.
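For orientation, here is a hypothetical end-to-end sketch of initializing and applying the module. The cutoff, the toy 3-atom graph, and the seed are placeholders; it assumes the default (empty) species_encoding and radial_basis dicts are valid configurations, and the printed shapes are those implied by the code above (dim scalars and nchannels × (lmax + 1)² tensor components per atom).

import jax
import jax.numpy as jnp
from fennol.models.embeddings.minimace import MiniMaceEmbedding

# Hypothetical setup: a 5.0 cutoff and a 3-atom, 2-edge toy graph
emb = MiniMaceEmbedding(_graphs_properties={"graph": {"cutoff": 5.0}})

inputs = {
    "species": jnp.array([8, 1, 1]),  # flattened 1D species array
    "graph": {
        "edge_src": jnp.array([0, 0]),
        "edge_dst": jnp.array([1, 2]),
        "vec": jnp.ones((2, 3)),                     # edge vectors
        "distances": jnp.full((2,), jnp.sqrt(3.0)),  # norms of vec
        "switch": jnp.ones((2,)),                    # smooth cutoff switch values
    },
}

params = emb.init(jax.random.PRNGKey(0), inputs)
out = emb.apply(params, inputs)
print(out["embedding"].shape)         # expected: (3, 128)
print(out["tensor_embedding"].shape)  # expected: (3, 16, 9) at lmax = 2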