fennol.models.embeddings.gaussian_moments
import jax
import jax.numpy as jnp
import flax.linen as nn
from ...utils.spherical_harmonics import generate_spherical_harmonics, CG_SO3
from ..misc.encodings import SpeciesEncoding, RadialBasis
import dataclasses
import numpy as np
from typing import Dict, ClassVar


class GaussianMomentsEmbedding(nn.Module):
    """Gaussian moments embedding

    Per-atom densities are accumulated from spherical harmonics (fixed lmax=3)
    weighted by learned chemical-radial channels, then contracted along a
    hand-picked subset of tensor product paths (products of two and of three
    density components). The construction is similar to ACE.

    ### Reference
    adapted from J. Chem. Theory Comput. 2020, 16, 8, 5410–5421
    (https://pubs.acs.org/doi/full/10.1021/acs.jctc.0c00347)

    """

    _graphs_properties: Dict
    nchannels: int = 7
    """The number of chemical-radial (chemrad) channels for the density representation."""
    graph_key: str = "graph"
    """The key in the input dictionary that corresponds to the molecular graph."""
    embedding_key: str = "embedding"
    """The key in the output dictionary where the computed embedding will be stored."""
    species_encoding: dict = dataclasses.field(default_factory=dict)
    """A dictionary of parameters for the species encoding. See `fennol.models.misc.encodings.SpeciesEncoding`"""
    radial_basis: dict = dataclasses.field(default_factory=dict)
    """A dictionary of parameters for the radial basis. See `fennol.models.misc.encodings.RadialBasis`"""

    FID: ClassVar[str] = "GAUSSIAN_MOMENTS"

    @nn.compact
    def __call__(self, inputs):
        """Compute the Gaussian moments embedding.

        Reads `inputs["species"]` and the graph stored under `graph_key`
        (entries `edge_src`, `edge_dst`, `distances`, `vec` and `switch`).

        Returns the embedding array itself when `embedding_key` is None,
        otherwise a new dictionary holding every entry of `inputs` plus the
        embedding stored under `embedding_key`.
        """
        species = inputs["species"]
        assert (
            len(species.shape) == 1
        ), "Species must be a 1D array (batches must be flattened)"
        num_segments = species.shape[0]

        graph = inputs[self.graph_key]
        edge_src = graph["edge_src"]
        edge_dst = graph["edge_dst"]

        # "end" and "name" are written last so they override user-supplied values.
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        radial_kwargs = {**self.radial_basis, "end": cutoff, "name": "RadialBasis"}
        radial_terms = RadialBasis(**radial_kwargs)(graph["distances"])

        species_feats = SpeciesEncoding(
            **self.species_encoding, name="SpeciesEncoding"
        )(species)

        n_species_feats = species_feats.shape[-1]
        n_radial = radial_terms.shape[-1]
        coupling = self.param(
            "chemrad_coupling",
            nn.initializers.normal(stddev=1.0 / (n_species_feats * n_radial) ** 0.5),
            (n_species_feats, n_radial, self.nchannels),
        )

        # Per-edge channel weights: species features of the edge destination
        # mixed with the radial basis, damped by the switching function.
        edge_channels = jnp.einsum(
            "ai,aj,ijk->ak",
            species_feats[edge_dst],
            radial_terms,
            coupling,
        )
        edge_channels = edge_channels * graph["switch"][:, None]

        directions = graph["vec"] / graph["distances"][:, None]
        harmonics = generate_spherical_harmonics(lmax=3, normalize=False)(directions)

        # Outer product (channel x harmonic) per edge, summed onto edge sources.
        edge_density = edge_channels[:, :, None] * harmonics[:, None, :]
        density = jax.ops.segment_sum(edge_density, edge_src, num_segments)

        xi0 = jax.lax.index_in_dim(density, 0, axis=-1, keepdims=False)

        # Order `ell` occupies 2*ell+1 entries of the last axis, starting at ell**2.
        rho = {
            ell: jax.lax.dynamic_slice_in_dim(
                density, start_index=ell * ell, slice_size=2 * ell + 1, axis=-1
            )
            for ell in (1, 2, 3)
        }

        # Channel index combinations with i <= j (<= k), in lexicographic order.
        channels = range(self.nchannels)
        pair_idx = [[i, j] for i in channels for j in channels[i:]]
        triplet_idx = [
            [i, j, k] for i in channels for j in channels[i:] for k in channels[j:]
        ]

        p1, p2 = np.array(pair_idx).T
        p1, p2 = jnp.array(p1), jnp.array(p2)
        t1, t2, t3 = np.array(triplet_idx).T
        t1, t2, t3 = jnp.array(t1), jnp.array(t2), jnp.array(t3)

        def pair_contraction(field):
            # Dot product over the harmonic axis for every channel pair.
            return jnp.sum(field[:, p1, :] * field[:, p2, :], axis=-1)

        def triple_contraction(first, second, third, weights):
            # Contract three density components with a CG_SO3 coupling tensor.
            return jnp.einsum("...m,...n,...o,nom->...", first, second, third, weights)

        xi11 = pair_contraction(rho[1])
        xi22 = pair_contraction(rho[2])
        xi33 = pair_contraction(rho[3])

        rho2_t1 = rho[2][:, t1, :]
        rho1_t2 = rho[1][:, t2, :]
        rho1_t3 = rho[1][:, t3, :]
        xi211 = triple_contraction(
            rho2_t1, rho1_t2, rho1_t3, jnp.array(CG_SO3(1, 1, 2))
        )

        rho2_t2 = rho[2][:, t2, :]
        rho2_t3 = rho[2][:, t3, :]
        xi222 = triple_contraction(
            rho2_t1, rho2_t2, rho2_t3, jnp.array(CG_SO3(2, 2, 2))
        )

        rho3_t1 = rho[3][:, t1, :]
        xi312 = triple_contraction(
            rho3_t1, rho1_t2, rho2_t3, jnp.array(CG_SO3(1, 2, 3))
        )

        rho3_t3 = rho[3][:, t3, :]
        xi323 = triple_contraction(
            rho3_t1, rho2_t2, rho3_t3, jnp.array(CG_SO3(2, 3, 3))
        )

        embedding = jnp.concatenate(
            [species_feats, xi0, xi11, xi22, xi33, xi211, xi222, xi312, xi323],
            axis=-1,
        )

        if self.embedding_key is not None:
            return {**inputs, self.embedding_key: embedding}
        return embedding
class GaussianMomentsEmbedding(nn.Module):
    """Gaussian moments embedding

    Per-atom densities are accumulated from spherical harmonics (fixed lmax=3)
    weighted by learned chemical-radial channels, then contracted along a
    hand-picked subset of tensor product paths (products of two and of three
    density components). The construction is similar to ACE.

    ### Reference
    adapted from J. Chem. Theory Comput. 2020, 16, 8, 5410–5421
    (https://pubs.acs.org/doi/full/10.1021/acs.jctc.0c00347)

    """

    _graphs_properties: Dict
    nchannels: int = 7
    """The number of chemical-radial (chemrad) channels for the density representation."""
    graph_key: str = "graph"
    """The key in the input dictionary that corresponds to the molecular graph."""
    embedding_key: str = "embedding"
    """The key in the output dictionary where the computed embedding will be stored."""
    species_encoding: dict = dataclasses.field(default_factory=dict)
    """A dictionary of parameters for the species encoding. See `fennol.models.misc.encodings.SpeciesEncoding`"""
    radial_basis: dict = dataclasses.field(default_factory=dict)
    """A dictionary of parameters for the radial basis. See `fennol.models.misc.encodings.RadialBasis`"""

    FID: ClassVar[str] = "GAUSSIAN_MOMENTS"

    @nn.compact
    def __call__(self, inputs):
        """Compute the Gaussian moments embedding.

        Reads `inputs["species"]` and the graph stored under `graph_key`
        (entries `edge_src`, `edge_dst`, `distances`, `vec` and `switch`).

        Returns the embedding array itself when `embedding_key` is None,
        otherwise a new dictionary holding every entry of `inputs` plus the
        embedding stored under `embedding_key`.
        """
        species = inputs["species"]
        assert (
            len(species.shape) == 1
        ), "Species must be a 1D array (batches must be flattened)"
        num_segments = species.shape[0]

        graph = inputs[self.graph_key]
        edge_src = graph["edge_src"]
        edge_dst = graph["edge_dst"]

        # "end" and "name" are written last so they override user-supplied values.
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        radial_kwargs = {**self.radial_basis, "end": cutoff, "name": "RadialBasis"}
        radial_terms = RadialBasis(**radial_kwargs)(graph["distances"])

        species_feats = SpeciesEncoding(
            **self.species_encoding, name="SpeciesEncoding"
        )(species)

        n_species_feats = species_feats.shape[-1]
        n_radial = radial_terms.shape[-1]
        coupling = self.param(
            "chemrad_coupling",
            nn.initializers.normal(stddev=1.0 / (n_species_feats * n_radial) ** 0.5),
            (n_species_feats, n_radial, self.nchannels),
        )

        # Per-edge channel weights: species features of the edge destination
        # mixed with the radial basis, damped by the switching function.
        edge_channels = jnp.einsum(
            "ai,aj,ijk->ak",
            species_feats[edge_dst],
            radial_terms,
            coupling,
        )
        edge_channels = edge_channels * graph["switch"][:, None]

        directions = graph["vec"] / graph["distances"][:, None]
        harmonics = generate_spherical_harmonics(lmax=3, normalize=False)(directions)

        # Outer product (channel x harmonic) per edge, summed onto edge sources.
        edge_density = edge_channels[:, :, None] * harmonics[:, None, :]
        density = jax.ops.segment_sum(edge_density, edge_src, num_segments)

        xi0 = jax.lax.index_in_dim(density, 0, axis=-1, keepdims=False)

        # Order `ell` occupies 2*ell+1 entries of the last axis, starting at ell**2.
        rho = {
            ell: jax.lax.dynamic_slice_in_dim(
                density, start_index=ell * ell, slice_size=2 * ell + 1, axis=-1
            )
            for ell in (1, 2, 3)
        }

        # Channel index combinations with i <= j (<= k), in lexicographic order.
        channels = range(self.nchannels)
        pair_idx = [[i, j] for i in channels for j in channels[i:]]
        triplet_idx = [
            [i, j, k] for i in channels for j in channels[i:] for k in channels[j:]
        ]

        p1, p2 = np.array(pair_idx).T
        p1, p2 = jnp.array(p1), jnp.array(p2)
        t1, t2, t3 = np.array(triplet_idx).T
        t1, t2, t3 = jnp.array(t1), jnp.array(t2), jnp.array(t3)

        def pair_contraction(field):
            # Dot product over the harmonic axis for every channel pair.
            return jnp.sum(field[:, p1, :] * field[:, p2, :], axis=-1)

        def triple_contraction(first, second, third, weights):
            # Contract three density components with a CG_SO3 coupling tensor.
            return jnp.einsum("...m,...n,...o,nom->...", first, second, third, weights)

        xi11 = pair_contraction(rho[1])
        xi22 = pair_contraction(rho[2])
        xi33 = pair_contraction(rho[3])

        rho2_t1 = rho[2][:, t1, :]
        rho1_t2 = rho[1][:, t2, :]
        rho1_t3 = rho[1][:, t3, :]
        xi211 = triple_contraction(
            rho2_t1, rho1_t2, rho1_t3, jnp.array(CG_SO3(1, 1, 2))
        )

        rho2_t2 = rho[2][:, t2, :]
        rho2_t3 = rho[2][:, t3, :]
        xi222 = triple_contraction(
            rho2_t1, rho2_t2, rho2_t3, jnp.array(CG_SO3(2, 2, 2))
        )

        rho3_t1 = rho[3][:, t1, :]
        xi312 = triple_contraction(
            rho3_t1, rho1_t2, rho2_t3, jnp.array(CG_SO3(1, 2, 3))
        )

        rho3_t3 = rho[3][:, t3, :]
        xi323 = triple_contraction(
            rho3_t1, rho2_t2, rho3_t3, jnp.array(CG_SO3(2, 3, 3))
        )

        embedding = jnp.concatenate(
            [species_feats, xi0, xi11, xi22, xi33, xi211, xi222, xi312, xi323],
            axis=-1,
        )

        if self.embedding_key is not None:
            return {**inputs, self.embedding_key: embedding}
        return embedding
Gaussian moments embedding
The construction of this embedding is similar to ACE but with a fixed lmax=3 and a subset of tensor product paths chosen by hand.
Reference
adapted from J. Chem. Theory Comput. 2020, 16, 8, 5410–5421 (https://pubs.acs.org/doi/full/10.1021/acs.jctc.0c00347)
The key in the output dictionary where the computed embedding will be stored.
A dictionary of parameters for the species encoding. See fennol.models.misc.encodings.SpeciesEncoding
A dictionary of parameters for the radial basis. See fennol.models.misc.encodings.RadialBasis
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying Python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator to ensure that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.