fennol.models.embeddings.foam
import jax
import jax.numpy as jnp
import flax.linen as nn
from ...utils.spherical_harmonics import generate_spherical_harmonics
from ..misc.encodings import SpeciesEncoding, RadialBasis
import dataclasses
import numpy as np
from typing import Dict, Optional, ClassVar


class FOAMEmbedding(nn.Module):
    """Filtered Overlap of Atomic Moments

    FID : FOAM

    Similar to the SOAP embedding, but for each rank l, instead of taking all
    pairwise combinations of channels, the atomic moments are linearly
    projected onto 2 * nchannels elements, split into two halves, and the
    scalar product between the two halves is taken. This is thus a kind of
    linearly filtered SOAP embedding.

    """

    _graphs_properties: Dict
    lmax: int = 2
    """The maximum order of spherical tensors."""
    nchannels: Optional[int] = None
    """The number of projected channels per rank l. If None, defaults to the size
    of the density basis (species-encoding size times number of radial functions)."""
    graph_key: str = "graph"
    """The key for the graph input."""
    embedding_key: Optional[str] = "embedding"
    """The key for the embedding output. If None, the embedding array is returned
    directly instead of an updated copy of the inputs dict."""
    species_encoding: dict = dataclasses.field(default_factory=dict)
    """The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`"""
    radial_basis: dict = dataclasses.field(default_factory=dict)
    """The radial basis parameters. See `fennol.models.misc.encodings.RadialBasis`"""
    include_species: bool = True
    """Whether to include the species encoding in the embedding."""

    FID: ClassVar[str] = "FOAM"

    @nn.compact
    def __call__(self, inputs):
        """Compute the FOAM embedding of every atom.

        Args:
            inputs: dict holding "species" (a 1D array, batches flattened) and,
                under ``graph_key``, a graph dict providing "edge_src",
                "edge_dst", "switch", "distances" and "vec".

        Returns:
            ``{**inputs, embedding_key: xi}``, or ``xi`` itself when
            ``embedding_key`` is None. ``xi`` has one row per atom. Its features
            are the concatenation, in this order, of: the species encoding (if
            ``include_species``); the rank-0 moments; and ``nchannels``
            filtered invariants for each rank l = 0..lmax.
        """
        species = inputs["species"]
        assert (
            len(species.shape) == 1
        ), "Species must be a 1D array (batches must be flattened)"

        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        switch = graph["switch"][:, None]
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        # Radial functions on [.., cutoff], smoothly damped by the graph switch.
        radial_basis = (
            RadialBasis(**{**self.radial_basis, "end": cutoff, "name": "RadialBasis"})(
                graph["distances"]
            )
            * switch
        )

        species_encoding = SpeciesEncoding(
            **self.species_encoding, name="SpeciesEncoding"
        )(species)

        # Per-edge outer product of radial functions and the neighbour's species
        # encoding, flattened to (nedges, nradial * nspecies_features).
        Dij = (radial_basis[:, :, None] * species_encoding[edge_dst, None, :]).reshape(
            -1, species_encoding.shape[-1] * radial_basis.shape[-1]
        )

        # Spherical harmonics of the unit edge vectors (normalized explicitly here).
        # NOTE(review): a zero-length edge would yield NaN; confirm the graph
        # never contains zero distances.
        Yij = generate_spherical_harmonics(lmax=self.lmax, normalize=False)(
            graph["vec"] / graph["distances"][:, None]
        )[:, :, None]

        # Atomic moments: sum over each atom's edges. The components of rank l
        # occupy indices l**2:(l+1)**2 along axis 1.
        rhoi = jax.ops.segment_sum(Dij[:, None, :] * Yij, edge_src, species.shape[0])

        nbasis = rhoi.shape[-1]
        nchannels = self.nchannels if self.nchannels is not None else nbasis

        if self.include_species:
            xis = [species_encoding, rhoi[:, 0, :]]
        else:
            xis = [rhoi[:, 0, :]]

        for l in range(self.lmax + 1):
            rhoil = rhoi[:, l**2 : (l + 1) ** 2, :]
            # One linear map (no bias) produces both projections; split into halves.
            xl, yl = jnp.split(
                nn.Dense(2 * nchannels, use_bias=False, name=f"xy_l{l}")(rhoil),
                2,
                axis=-1,
            )
            # Contract over the 2l+1 components (scalar product), scaled by 1/sqrt(2l+1).
            xil = (xl * yl).sum(axis=1) / (2 * l + 1) ** 0.5
            xis.append(xil)
        xi = jnp.concatenate(xis, axis=-1)

        if self.embedding_key is None:
            return xi
        return {**inputs, self.embedding_key: xi}
class FOAMEmbedding(nn.Module):
    """Filtered Overlap of Atomic Moments

    FID : FOAM

    Similar to the SOAP embedding, but for each rank l, instead of taking all
    pairwise combinations of channels, the atomic moments are linearly
    projected onto 2 * nchannels elements, split into two halves, and the
    scalar product between the two halves is taken. This is thus a kind of
    linearly filtered SOAP embedding.

    """

    _graphs_properties: Dict
    lmax: int = 2
    """The maximum order of spherical tensors."""
    nchannels: Optional[int] = None
    """The number of projected channels per rank l. If None, defaults to the size
    of the density basis (species-encoding size times number of radial functions)."""
    graph_key: str = "graph"
    """The key for the graph input."""
    embedding_key: Optional[str] = "embedding"
    """The key for the embedding output. If None, the embedding array is returned
    directly instead of an updated copy of the inputs dict."""
    species_encoding: dict = dataclasses.field(default_factory=dict)
    """The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`"""
    radial_basis: dict = dataclasses.field(default_factory=dict)
    """The radial basis parameters. See `fennol.models.misc.encodings.RadialBasis`"""
    include_species: bool = True
    """Whether to include the species encoding in the embedding."""

    FID: ClassVar[str] = "FOAM"

    @nn.compact
    def __call__(self, inputs):
        """Compute the FOAM embedding of every atom.

        Args:
            inputs: dict holding "species" (a 1D array, batches flattened) and,
                under ``graph_key``, a graph dict providing "edge_src",
                "edge_dst", "switch", "distances" and "vec".

        Returns:
            ``{**inputs, embedding_key: xi}``, or ``xi`` itself when
            ``embedding_key`` is None. ``xi`` has one row per atom. Its features
            are the concatenation, in this order, of: the species encoding (if
            ``include_species``); the rank-0 moments; and ``nchannels``
            filtered invariants for each rank l = 0..lmax.
        """
        species = inputs["species"]
        assert (
            len(species.shape) == 1
        ), "Species must be a 1D array (batches must be flattened)"

        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        switch = graph["switch"][:, None]
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        # Radial functions on [.., cutoff], smoothly damped by the graph switch.
        radial_basis = (
            RadialBasis(**{**self.radial_basis, "end": cutoff, "name": "RadialBasis"})(
                graph["distances"]
            )
            * switch
        )

        species_encoding = SpeciesEncoding(
            **self.species_encoding, name="SpeciesEncoding"
        )(species)

        # Per-edge outer product of radial functions and the neighbour's species
        # encoding, flattened to (nedges, nradial * nspecies_features).
        Dij = (radial_basis[:, :, None] * species_encoding[edge_dst, None, :]).reshape(
            -1, species_encoding.shape[-1] * radial_basis.shape[-1]
        )

        # Spherical harmonics of the unit edge vectors (normalized explicitly here).
        # NOTE(review): a zero-length edge would yield NaN; confirm the graph
        # never contains zero distances.
        Yij = generate_spherical_harmonics(lmax=self.lmax, normalize=False)(
            graph["vec"] / graph["distances"][:, None]
        )[:, :, None]

        # Atomic moments: sum over each atom's edges. The components of rank l
        # occupy indices l**2:(l+1)**2 along axis 1.
        rhoi = jax.ops.segment_sum(Dij[:, None, :] * Yij, edge_src, species.shape[0])

        nbasis = rhoi.shape[-1]
        nchannels = self.nchannels if self.nchannels is not None else nbasis

        if self.include_species:
            xis = [species_encoding, rhoi[:, 0, :]]
        else:
            xis = [rhoi[:, 0, :]]

        for l in range(self.lmax + 1):
            rhoil = rhoi[:, l**2 : (l + 1) ** 2, :]
            # One linear map (no bias) produces both projections; split into halves.
            xl, yl = jnp.split(
                nn.Dense(2 * nchannels, use_bias=False, name=f"xy_l{l}")(rhoil),
                2,
                axis=-1,
            )
            # Contract over the 2l+1 components (scalar product), scaled by 1/sqrt(2l+1).
            xil = (xl * yl).sum(axis=1) / (2 * l + 1) ** 0.5
            xis.append(xil)
        xi = jnp.concatenate(xis, axis=-1)

        if self.embedding_key is None:
            return xi
        return {**inputs, self.embedding_key: xi}
Filtered Overlap of Atomic Moments
FID : FOAM
Similar to the SOAP embedding, but for each rank l, instead of taking all pairwise combinations of channels, the atomic moments are linearly projected onto 2 × nchannels elements, split into two halves, and the scalar product between the two halves is taken. This is thus a kind of linearly filtered SOAP embedding.
The species encoding parameters. See fennol.models.misc.encodings.SpeciesEncoding
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying Python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator to ensure that the appropriate getter/setter logic applies in subclasses, even after various dataclass transforms.