fennol.models.embeddings.caiman
import jax
import jax.numpy as jnp
import flax.linen as nn
from ...utils.spherical_harmonics import generate_spherical_harmonics
from ..misc.encodings import SpeciesEncoding, RadialBasis
import dataclasses
import numpy as np
from typing import Any, Dict, List, Union, Callable, Tuple, Sequence, Optional, ClassVar
from ..misc.nets import FullyConnectedNet
from ..misc.e3 import FilteredTensorProduct, ChannelMixingE3


class CaimanEmbedding(nn.Module):
    """Covariant Atom-In-Molecule Network

    FID : CAIMAN

    This is an E(3)-equivariant embedding that forms an equivariant neighbor density
    and then uses multiple self-interaction tensor products to generate a tensorial embedding
    along with a scalar embedding (similar to the tensor/scalar tracks in Allegro).

    """

    _graphs_properties: Dict
    dim: int = 128
    """ The dimension of the embedding. """
    nchannels: int = 16
    """ The number of channels. """
    nchannels_density: Optional[int] = None
    """ The number of channels for the neighborhood density. If None, it is equal to nchannels."""
    nlayers: int = 3
    """ The number of layers. """
    lmax: int = 2
    """ The maximum order of spherical tensors. """
    embedding_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [])
    """ The hidden layers for the embedding."""
    latent_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
    """ The hidden layers for the latent network."""
    activation: Union[Callable, str] = "silu"
    """ The activation function."""
    graph_key: str = "graph"
    """ The key for the graph input."""
    embedding_key: str = "embedding"
    """ The key for the embedding output."""
    tensor_embedding_key: str = "tensor_embedding"
    """ The key for the tensor embedding output."""
    species_encoding: dict = dataclasses.field(default_factory=dict)
    """ The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`"""
    radial_basis: dict = dataclasses.field(default_factory=dict)
    """ The radial basis parameters. See `fennol.models.misc.encodings.RadialBasis`"""
    message_passing: bool = False
    """ Whether to use message passing."""

    FID: ClassVar[str] = "CAIMAN"

    @nn.compact
    def __call__(self, inputs):
        species = inputs["species"]
        assert (
            len(species.shape) == 1
        ), "Species must be a 1D array (batches must be flattened)"
        nchannels_density = (
            self.nchannels_density
            if self.nchannels_density is not None
            else self.nchannels
        )

        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        switch = graph["switch"][:, None]
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        # expand pairwise distances in a radial basis (ending at the graph cutoff)
        radial_basis = RadialBasis(
            **{**self.radial_basis, "end": cutoff, "name": "RadialBasis"}
        )(graph["distances"])

        # per-channel radial weights, smoothly switched off at the cutoff
        Dij = (
            nn.Dense(nchannels_density, use_bias=True, name="Dij")(radial_basis)
            * switch
        )

        species_encoding = SpeciesEncoding(
            **self.species_encoding, name="SpeciesEncoding"
        )(species)

        # initial per-atom scalar embedding from the species encoding
        xi = FullyConnectedNet(
            neurons=[*self.embedding_hidden, self.dim],
            activation=self.activation,
            use_bias=True,
            name="species_embedding",
        )(species_encoding)
        Zs, Zd = jnp.split(
            nn.Dense(2 * nchannels_density, use_bias=True, name="species_linear")(xi),
            2,
            axis=-1,
        )
        xij = Zs[edge_src] * Zd[edge_dst] * Dij

        # spherical harmonics of the unit bond vectors, up to lmax
        Yij = generate_spherical_harmonics(lmax=self.lmax, normalize=False)(
            graph["vec"] / graph["distances"][:, None]
        )[:, None, :]

        rhoij = xij[:, :, None] * Yij

        # learnable per-l weights, repeated over the 2l+1 components of each order l
        nrep = np.array([2 * l + 1 for l in range(self.lmax + 1)])
        wsh = self.param(
            "wsh",
            lambda key, shape: jax.random.normal(key, shape, dtype=jnp.float32),
            (nchannels_density, self.lmax + 1),
        ).repeat(nrep, axis=-1)
        # equivariant neighbor density: sum edge contributions onto the central atom
        density = (
            jax.ops.segment_sum(rhoij, edge_src, species.shape[0]) * wsh[None, :, :]
        )

        nel = (self.lmax + 1) ** 2
        # initial tensorial track obtained by channel-mixing the density
        Vi = ChannelMixingE3(self.lmax, nchannels_density, self.nchannels)(
            density[..., :nel]
        )
        lambda_message = self.param(
            "lambda_message",
            lambda key: jnp.asarray(0.1, dtype=density.dtype),
        )

        for layer in range(self.nlayers):
            if self.message_passing:
                # optional message passing: refine the density with neighbor tensor features
                Zs, Zd = jnp.split(
                    nn.Dense(
                        2 * nchannels_density,
                        use_bias=True,
                        name=f"message_linear_{layer}",
                    )(xi),
                    2,
                    axis=-1,
                )
                mij = (
                    nn.Dense(
                        nchannels_density, use_bias=False, name=f"radial_linear_{layer}"
                    )(Dij)
                    * Zs[edge_src]
                    * Zd[edge_dst]
                )
                rhoij = (
                    mij[:, :, None]
                    * ChannelMixingE3(
                        self.lmax,
                        self.nchannels,
                        nchannels_density,
                        name=f"message_mixing_{layer}",
                    )(Vi)[edge_dst]
                )
                rhoi = jax.ops.segment_sum(rhoij, edge_src, species.shape[0])
                density = density + lambda_message * ChannelMixingE3(
                    self.lmax,
                    nchannels_density,
                    nchannels_density,
                    name=f"density_update_{layer}",
                )(rhoi)

            Hi = ChannelMixingE3(
                self.lmax,
                nchannels_density,
                self.nchannels,
                name=f"density_mixing_{layer}",
            )(density)

            # self-interaction tensor product between the tensor track and the mixed density
            Li = FilteredTensorProduct(
                self.lmax, self.lmax, self.lmax, name=f"TP_{layer}"
            )(Vi, Hi)
            # the l=0 (scalar) components of the tensor product feed the latent network
            scals = jax.lax.index_in_dim(Li, 0, axis=-1, keepdims=False)
            li = FullyConnectedNet(
                [*self.latent_hidden, self.dim],
                activation=self.activation,
                use_bias=True,
                name=f"latent_net_{layer}",
            )(jnp.concatenate((xi, scals), axis=-1))

            # residual updates of the scalar and tensor tracks
            xi = xi + li
            Vi = Vi + ChannelMixingE3(
                self.lmax, self.nchannels, self.nchannels, name=f"mixing_{layer}"
            )(Li)

        if self.embedding_key is None:
            return xi, Vi
        return {**inputs, self.embedding_key: xi, self.tensor_embedding_key: Vi}
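Reading off the density construction in __call__ above, the initial equivariant neighbor density accumulated on atom i can be written, per density channel c and spherical component (l, m), roughly as

\rho_{i,c}^{lm} = w_{c,l} \sum_{j \in \mathcal{N}(i)} Z^{s}_{i,c}\, Z^{d}_{j,c}\, D_{ij,c}\, Y^{lm}(\hat{r}_{ij})

where D_{ij,c} is the switched linear map of the radial basis (Dij in the code), Z^s and Z^d are the source/destination species weights split out of species_linear, and Y^{lm} are the (unnormalized) spherical harmonics of the unit bond vector. The layers then channel-mix this density into the tensor track Vi and feed the l = 0 components of each filtered tensor product into the scalar track xi.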
CaimanEmbedding: Covariant Atom-In-Molecule Network

FID: CAIMAN

This is an E(3)-equivariant embedding that forms an equivariant neighbor density and then uses multiple self-interaction tensor products to generate a tensorial embedding along with a scalar embedding (similar to the tensor/scalar tracks in Allegro).
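Below is a minimal usage sketch, not part of the FeNNol sources: it assumes a precomputed, already-flattened neighbor graph containing the keys read by __call__ (edge_src, edge_dst, distances, vec, switch), default SpeciesEncoding and RadialBasis parameters, and purely illustrative numbers. In FeNNol, modules like this are normally assembled by the model builder, which supplies _graphs_properties; here it is filled in by hand.

import jax
import jax.numpy as jnp

# Toy two-atom system (an O-H pair), with both edge directions listed.
# All graph quantities are assumed to come from an upstream neighbor-list module.
graph = {
    "edge_src": jnp.array([0, 1]),
    "edge_dst": jnp.array([1, 0]),
    "distances": jnp.array([1.0, 1.0]),
    "vec": jnp.array([[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]]),  # edge vectors for each (src, dst) pair
    "switch": jnp.array([1.0, 1.0]),  # cutoff switching function values
}
inputs = {"species": jnp.array([8, 1]), "graph": graph}

embedding = CaimanEmbedding(
    _graphs_properties={"graph": {"cutoff": 5.0}},  # normally provided by the model builder
    dim=64,
    nchannels=8,
    lmax=2,
)
params = embedding.init(jax.random.PRNGKey(0), inputs)
outputs = embedding.apply(params, inputs)
# outputs["embedding"]:        (natoms, dim) scalar track
# outputs["tensor_embedding"]: (natoms, nchannels, (lmax+1)**2) tensor track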