fennol.models.embeddings.gaussian_moments

  1import jax
  2import jax.numpy as jnp
  3import flax.linen as nn
  4from ...utils.spherical_harmonics import generate_spherical_harmonics, CG_SO3
  5from ..misc.encodings import SpeciesEncoding, RadialBasis
  6import dataclasses
  7import numpy as np
  8from typing import Dict, ClassVar
  9
 10
 11class GaussianMomentsEmbedding(nn.Module):
 12    """Gaussian moments embedding
 13
 14    The construction of this embedding is similar to ACE but with a fixed lmax=3 and
 15    a subset of tensor product paths chosen by hand.
 16
 17    ### Reference 
 18    adapted from J. Chem. Theory Comput. 2020, 16, 8, 5410–5421
 19    (https://pubs.acs.org/doi/full/10.1021/acs.jctc.0c00347)
 20
 21    """
 22
 23    _graphs_properties: Dict
 24    nchannels: int = 7
 25    """The number of chemical-radial (chemrad) channels for the density representation."""
 26    graph_key: str = "graph"
 27    """The key in the input dictionary that corresponds to the molecular graph."""
 28    embedding_key: str = "embedding"
 29    """The key in the output dictionary where the computed embedding will be stored."""
 30    species_encoding: dict = dataclasses.field(default_factory=dict)
 31    """A dictionary of parameters for the species encoding. See `fennol.models.misc.encodings.SpeciesEncoding`"""
 32    radial_basis: dict = dataclasses.field(default_factory=dict)
 33    """A dictionary of parameters for the radial basis. See `fennol.models.misc.encodings.RadialBasis`"""
 34
 35    FID: ClassVar[str] = "GAUSSIAN_MOMENTS"
 36
 37
 38    @nn.compact
 39    def __call__(self, inputs):
 40        species = inputs["species"]
 41        assert len(species.shape) == 1, "Species must be a 1D array (batches must be flattened)"
 42
 43        graph = inputs[self.graph_key]
 44        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
 45
 46        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
 47        radial_basis = RadialBasis(
 48            **{**self.radial_basis, "end": cutoff, "name": "RadialBasis"}
 49        )(graph["distances"])
 50        radial_size = radial_basis.shape[-1]
 51
 52        species_encoding = SpeciesEncoding(**self.species_encoding,name="SpeciesEncoding")(species)
 53        afvs_size = species_encoding.shape[-1]
 54
 55        chemrad_coupling = self.param(
 56            "chemrad_coupling",
 57            nn.initializers.normal(stddev=1.0 / (afvs_size * radial_size) ** 0.5),
 58            (afvs_size, radial_size, self.nchannels),
 59        )
 60        xij = (
 61            jnp.einsum(
 62                "ai,aj,ijk->ak",
 63                species_encoding[edge_dst],
 64                radial_basis,
 65                chemrad_coupling,
 66            )
 67            * graph["switch"][:, None]
 68        )
 69
 70        Yij = generate_spherical_harmonics(lmax=3, normalize=False)(
 71            graph["vec"] / graph["distances"][:, None]
 72        )
 73        rhoij = xij[:, :, None] * Yij[:, None, :]
 74
 75        rhoi = jax.ops.segment_sum(rhoij, edge_src, species.shape[0])
 76
 77        xi0 = jax.lax.index_in_dim(rhoi, 0, axis=-1, keepdims=False)
 78
 79        rhoi1 = jax.lax.dynamic_slice_in_dim(rhoi, start_index=1, slice_size=3, axis=-1)
 80        rhoi2 = jax.lax.dynamic_slice_in_dim(rhoi, start_index=4, slice_size=5, axis=-1)
 81        rhoi3 = jax.lax.dynamic_slice_in_dim(rhoi, start_index=9, slice_size=7, axis=-1)
 82
 83        pairs = []
 84        triplets = []
 85        for i in range(self.nchannels):
 86            for j in range(i, self.nchannels):
 87                pairs.append([i, j])
 88                for k in range(j, self.nchannels):
 89                    triplets.append([i, j, k])
 90
 91        p1, p2 = np.array(pairs).T
 92        p1, p2 = jnp.array(p1), jnp.array(p2)
 93        xi11 = jnp.sum(rhoi1[:, p1, :] * rhoi1[:, p2, :], axis=-1)
 94        xi22 = jnp.sum(rhoi2[:, p1, :] * rhoi2[:, p2, :], axis=-1)
 95        xi33 = jnp.sum(rhoi3[:, p1, :] * rhoi3[:, p2, :], axis=-1)
 96
 97        t1, t2, t3 = np.array(triplets).T
 98        t1, t2, t3 = jnp.array(t1), jnp.array(t2), jnp.array(t3)
 99        rhoi2t1 = rhoi2[:, t1, :]
100        rhoi1t2 = rhoi1[:, t2, :]
101        rhoi1t3 = rhoi1[:, t3, :]
102        w112 = jnp.array(CG_SO3(1, 1, 2))
103        xi211 = jnp.einsum("...m,...n,...o,nom->...", rhoi2t1, rhoi1t2, rhoi1t3, w112)
104
105        rhoi2t2 = rhoi2[:, t2, :]
106        rhoi2t3 = rhoi2[:, t3, :]
107        w222 = jnp.array(CG_SO3(2, 2, 2))
108        xi222 = jnp.einsum("...m,...n,...o,nom->...", rhoi2t1, rhoi2t2, rhoi2t3, w222)
109
110        rhoi3t1 = rhoi3[:, t1, :]
111        w123 = jnp.array(CG_SO3(1, 2, 3))
112        xi312 = jnp.einsum("...m,...n,...o,nom->...", rhoi3t1, rhoi1t2, rhoi2t3, w123)
113
114        rhoi3t3 = rhoi3[:, t3, :]
115        w233 = jnp.array(CG_SO3(2, 3, 3))
116        xi323 = jnp.einsum("...m,...n,...o,nom->...", rhoi3t1, rhoi2t2, rhoi3t3, w233)
117
118        embedding = jnp.concatenate(
119            [species_encoding, xi0, xi11, xi22, xi33, xi211, xi222, xi312, xi323],
120            axis=-1,
121        )
122
123        if self.embedding_key is None:
124            return embedding
125        return {**inputs, self.embedding_key: embedding}
class GaussianMomentsEmbedding(flax.linen.module.Module):
 12class GaussianMomentsEmbedding(nn.Module):
 13    """Gaussian moments embedding
 14
 15    The construction of this embedding is similar to ACE but with a fixed lmax=3 and
 16    a subset of tensor product paths chosen by hand.
 17
 18    ### Reference 
 19    adapted from J. Chem. Theory Comput. 2020, 16, 8, 5410–5421
 20    (https://pubs.acs.org/doi/full/10.1021/acs.jctc.0c00347)
 21
 22    """
 23
 24    _graphs_properties: Dict
 25    nchannels: int = 7
 26    """The number of chemical-radial (chemrad) channels for the density representation."""
 27    graph_key: str = "graph"
 28    """The key in the input dictionary that corresponds to the molecular graph."""
 29    embedding_key: str = "embedding"
 30    """The key in the output dictionary where the computed embedding will be stored."""
 31    species_encoding: dict = dataclasses.field(default_factory=dict)
 32    """A dictionary of parameters for the species encoding. See `fennol.models.misc.encodings.SpeciesEncoding`"""
 33    radial_basis: dict = dataclasses.field(default_factory=dict)
 34    """A dictionary of parameters for the radial basis. See `fennol.models.misc.encodings.RadialBasis`"""
 35
 36    FID: ClassVar[str] = "GAUSSIAN_MOMENTS"
 37
 38
 39    @nn.compact
 40    def __call__(self, inputs):
 41        species = inputs["species"]
 42        assert len(species.shape) == 1, "Species must be a 1D array (batches must be flattened)"
 43
 44        graph = inputs[self.graph_key]
 45        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
 46
 47        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
 48        radial_basis = RadialBasis(
 49            **{**self.radial_basis, "end": cutoff, "name": "RadialBasis"}
 50        )(graph["distances"])
 51        radial_size = radial_basis.shape[-1]
 52
 53        species_encoding = SpeciesEncoding(**self.species_encoding,name="SpeciesEncoding")(species)
 54        afvs_size = species_encoding.shape[-1]
 55
 56        chemrad_coupling = self.param(
 57            "chemrad_coupling",
 58            nn.initializers.normal(stddev=1.0 / (afvs_size * radial_size) ** 0.5),
 59            (afvs_size, radial_size, self.nchannels),
 60        )
 61        xij = (
 62            jnp.einsum(
 63                "ai,aj,ijk->ak",
 64                species_encoding[edge_dst],
 65                radial_basis,
 66                chemrad_coupling,
 67            )
 68            * graph["switch"][:, None]
 69        )
 70
 71        Yij = generate_spherical_harmonics(lmax=3, normalize=False)(
 72            graph["vec"] / graph["distances"][:, None]
 73        )
 74        rhoij = xij[:, :, None] * Yij[:, None, :]
 75
 76        rhoi = jax.ops.segment_sum(rhoij, edge_src, species.shape[0])
 77
 78        xi0 = jax.lax.index_in_dim(rhoi, 0, axis=-1, keepdims=False)
 79
 80        rhoi1 = jax.lax.dynamic_slice_in_dim(rhoi, start_index=1, slice_size=3, axis=-1)
 81        rhoi2 = jax.lax.dynamic_slice_in_dim(rhoi, start_index=4, slice_size=5, axis=-1)
 82        rhoi3 = jax.lax.dynamic_slice_in_dim(rhoi, start_index=9, slice_size=7, axis=-1)
 83
 84        pairs = []
 85        triplets = []
 86        for i in range(self.nchannels):
 87            for j in range(i, self.nchannels):
 88                pairs.append([i, j])
 89                for k in range(j, self.nchannels):
 90                    triplets.append([i, j, k])
 91
 92        p1, p2 = np.array(pairs).T
 93        p1, p2 = jnp.array(p1), jnp.array(p2)
 94        xi11 = jnp.sum(rhoi1[:, p1, :] * rhoi1[:, p2, :], axis=-1)
 95        xi22 = jnp.sum(rhoi2[:, p1, :] * rhoi2[:, p2, :], axis=-1)
 96        xi33 = jnp.sum(rhoi3[:, p1, :] * rhoi3[:, p2, :], axis=-1)
 97
 98        t1, t2, t3 = np.array(triplets).T
 99        t1, t2, t3 = jnp.array(t1), jnp.array(t2), jnp.array(t3)
100        rhoi2t1 = rhoi2[:, t1, :]
101        rhoi1t2 = rhoi1[:, t2, :]
102        rhoi1t3 = rhoi1[:, t3, :]
103        w112 = jnp.array(CG_SO3(1, 1, 2))
104        xi211 = jnp.einsum("...m,...n,...o,nom->...", rhoi2t1, rhoi1t2, rhoi1t3, w112)
105
106        rhoi2t2 = rhoi2[:, t2, :]
107        rhoi2t3 = rhoi2[:, t3, :]
108        w222 = jnp.array(CG_SO3(2, 2, 2))
109        xi222 = jnp.einsum("...m,...n,...o,nom->...", rhoi2t1, rhoi2t2, rhoi2t3, w222)
110
111        rhoi3t1 = rhoi3[:, t1, :]
112        w123 = jnp.array(CG_SO3(1, 2, 3))
113        xi312 = jnp.einsum("...m,...n,...o,nom->...", rhoi3t1, rhoi1t2, rhoi2t3, w123)
114
115        rhoi3t3 = rhoi3[:, t3, :]
116        w233 = jnp.array(CG_SO3(2, 3, 3))
117        xi323 = jnp.einsum("...m,...n,...o,nom->...", rhoi3t1, rhoi2t2, rhoi3t3, w233)
118
119        embedding = jnp.concatenate(
120            [species_encoding, xi0, xi11, xi22, xi33, xi211, xi222, xi312, xi323],
121            axis=-1,
122        )
123
124        if self.embedding_key is None:
125            return embedding
126        return {**inputs, self.embedding_key: embedding}

Gaussian moments embedding

The construction of this embedding is similar to ACE, but with a fixed lmax=3 and a hand-picked subset of tensor-product paths.

Reference

Adapted from J. Chem. Theory Comput. 2020, 16, 8, 5410–5421 (https://pubs.acs.org/doi/full/10.1021/acs.jctc.0c00347)

GaussianMomentsEmbedding( _graphs_properties: Dict, nchannels: int = 7, graph_key: str = 'graph', embedding_key: str = 'embedding', species_encoding: dict = <factory>, radial_basis: dict = <factory>, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
nchannels: int = 7

The number of chemical-radial (chemrad) channels for the density representation.

graph_key: str = 'graph'

The key in the input dictionary that corresponds to the molecular graph.

embedding_key: str = 'embedding'

The key in the output dictionary where the computed embedding will be stored.

species_encoding: dict

A dictionary of parameters for the species encoding. See fennol.models.misc.encodings.SpeciesEncoding.

radial_basis: dict

A dictionary of parameters for the radial basis. See fennol.models.misc.encodings.RadialBasis.

FID: ClassVar[str] = 'GAUSSIAN_MOMENTS'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying Python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None