fennol.models.embeddings.caiman

  1import jax
  2import jax.numpy as jnp
  3import flax.linen as nn
  4from ...utils.spherical_harmonics import generate_spherical_harmonics
  5from ..misc.encodings import SpeciesEncoding, RadialBasis
  6import dataclasses
  7import numpy as np
  8from typing import Any, Dict, List, Union, Callable, Tuple, Sequence, Optional, ClassVar
  9from ..misc.nets import FullyConnectedNet
 10from ..misc.e3 import FilteredTensorProduct, ChannelMixingE3
 11
 12
 13class CaimanEmbedding(nn.Module):
 14    """Covariant Atom-In-Molecule Network
 15
 16    FID : CAIMAN
 17
 18    This is an E(3) equivariant embedding that forms an equivariant neighbor density
 19    and then uses multiple self-interaction tensor products to generate a tensorial embedding
 20    along with a scalar embedding (similar to the tensor/scalar tracks in allegro).
 21
 22    """
 23
 24    _graphs_properties: Dict
 25    dim: int = 128
 26    """ The dimension of the embedding. """
 27    nchannels: int = 16
 28    """ The number of channels. """
 29    nchannels_density: Optional[int] = None
 30    """ The number of channels for the neighborhood density. If None, it is equal to nchannels."""
 31    nlayers: int = 3
 32    """ The number of layers. """
 33    lmax: int = 2
 34    """ The maximum order of spherical tensors. """
 35    embedding_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [])
 36    """ The hidden layers for the embedding."""
 37    latent_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
 38    """ The hidden layers for the latent network."""
 39    activation: Union[Callable, str] = "silu"
 40    """ The activation function."""
 41    graph_key: str = "graph"
 42    """ The key for the graph input."""
 43    embedding_key: str = "embedding"
 44    """ The key for the embedding output."""
 45    tensor_embedding_key: str = "tensor_embedding"
 46    """ The key for the tensor embedding output."""
 47    species_encoding: dict = dataclasses.field(default_factory=dict)
 48    """ The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`"""
 49    radial_basis: dict = dataclasses.field(default_factory=dict)
 50    """ The radial basis parameters. See `fennol.models.misc.encodings.RadialBasis`"""
 51    message_passing: bool = False
 52    """ Whether to use message passing."""
 53
 54    FID: ClassVar[str] = "CAIMAN"
 55
 56    @nn.compact
 57    def __call__(self, inputs):
 58        species = inputs["species"]
 59        assert (
 60            len(species.shape) == 1
 61        ), "Species must be a 1D array (batches must be flattened)"
 62        nchannels_density = (
 63            self.nchannels_density
 64            if self.nchannels_density is not None
 65            else self.nchannels
 66        )
 67
 68        graph = inputs[self.graph_key]
 69        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
 70        switch = graph["switch"][:, None]
 71        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
 72        radial_basis = RadialBasis(
 73            **{**self.radial_basis, "end": cutoff, "name": "RadialBasis"}
 74        )(graph["distances"])
 75
 76        Dij = (
 77            nn.Dense(nchannels_density, use_bias=True, name="Dij")(radial_basis)
 78            * switch
 79        )
 80
 81        species_encoding = SpeciesEncoding(
 82            **self.species_encoding, name="SpeciesEncoding"
 83        )(species)
 84
 85        xi = FullyConnectedNet(
 86            neurons=[*self.embedding_hidden, self.dim],
 87            activation=self.activation,
 88            use_bias=True,
 89            name="species_embedding",
 90        )(species_encoding)
 91        Zs, Zd = jnp.split(
 92            nn.Dense(2 * nchannels_density, use_bias=True, name="species_linear")(xi),
 93            2,
 94            axis=-1,
 95        )
 96        xij = Zs[edge_src] * Zd[edge_dst] * Dij
 97
 98        Yij = generate_spherical_harmonics(lmax=self.lmax, normalize=False)(
 99            graph["vec"] / graph["distances"][:, None]
100        )[:, None, :]
101
102        rhoij = xij[:, :, None] * Yij
103
104        nrep = np.array([2 * l + 1 for l in range(self.lmax + 1)])
105        wsh = self.param(
106            "wsh",
107            lambda key, shape: jax.random.normal(key, shape, dtype=jnp.float32),
108            (nchannels_density, self.lmax + 1),
109        ).repeat(nrep, axis=-1)
110        density = (
111            jax.ops.segment_sum(rhoij, edge_src, species.shape[0]) * wsh[None, :, :]
112        )
113
114        nel = (self.lmax + 1) ** 2
115        Vi = ChannelMixingE3(self.lmax, nchannels_density, self.nchannels)(
116            density[..., :nel]
117        )
118        lambda_message = self.param(
119            "lambda_message",
120            lambda key: jnp.asarray(0.1, dtype=density.dtype),
121        )
122
123        for layer in range(self.nlayers):
124            if self.message_passing:
125                Zs, Zs = jnp.split(
126                    nn.Dense(
127                        2 * nchannels_density,
128                        use_bias=True,
129                        name=f"message_linear_{layer}",
130                    )(xi),
131                    2,
132                    axis=-1,
133                )
134                mij = (
135                    nn.Dense(
136                        nchannels_density, use_bias=False, name=f"radial_linear_{layer}"
137                    )(Dij)
138                    * Zs[edge_src]
139                    * Zd[edge_dst]
140                )
141                rhoij = (
142                    mij[:, :, None]
143                    * ChannelMixingE3(
144                        self.lmax,
145                        self.nchannels,
146                        nchannels_density,
147                        name=f"message_mixing_{layer}",
148                    )(Vi)[edge_dst]
149                )
150                rhoi = jax.ops.segment_sum(rhoij, edge_src, species.shape[0])
151                density = density + lambda_message * ChannelMixingE3(
152                    self.lmax,
153                    nchannels_density,
154                    nchannels_density,
155                    name=f"density_update_{layer}",
156                )(rhoi)
157
158            Hi = ChannelMixingE3(
159                self.lmax,
160                nchannels_density,
161                self.nchannels,
162                name=f"density_mixing_{layer}",
163            )(density)
164
165            Li = FilteredTensorProduct(
166                self.lmax, self.lmax, self.lmax, name=f"TP_{layer}"
167            )(Vi, Hi)
168            scals = jax.lax.index_in_dim(Li, 0, axis=-1, keepdims=False)
169            li = FullyConnectedNet(
170                [*self.latent_hidden, self.dim],
171                activation=self.activation,
172                use_bias=True,
173                name=f"latent_net_{layer}",
174            )(jnp.concatenate((xi, scals), axis=-1))
175
176            xi = xi + li
177            Vi = Vi + ChannelMixingE3(
178                self.lmax, self.nchannels, self.nchannels, name=f"mixing_{layer}"
179            )(Li)
180
181        if self.embedding_key is None:
182            return xi, Vi
183        return {**inputs, self.embedding_key: xi, self.tensor_embedding_key: Vi}
class CaimanEmbedding(flax.linen.module.Module):
 14class CaimanEmbedding(nn.Module):
 15    """Covariant Atom-In-Molecule Network
 16
 17    FID : CAIMAN
 18
 19    This is an E(3) equivariant embedding that forms an equivariant neighbor density
 20    and then uses multiple self-interaction tensor products to generate a tensorial embedding
 21    along with a scalar embedding (similar to the tensor/scalar tracks in allegro).
 22
 23    """
 24
 25    _graphs_properties: Dict
 26    dim: int = 128
 27    """ The dimension of the embedding. """
 28    nchannels: int = 16
 29    """ The number of channels. """
 30    nchannels_density: Optional[int] = None
 31    """ The number of channels for the neighborhood density. If None, it is equal to nchannels."""
 32    nlayers: int = 3
 33    """ The number of layers. """
 34    lmax: int = 2
 35    """ The maximum order of spherical tensors. """
 36    embedding_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [])
 37    """ The hidden layers for the embedding."""
 38    latent_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
 39    """ The hidden layers for the latent network."""
 40    activation: Union[Callable, str] = "silu"
 41    """ The activation function."""
 42    graph_key: str = "graph"
 43    """ The key for the graph input."""
 44    embedding_key: str = "embedding"
 45    """ The key for the embedding output."""
 46    tensor_embedding_key: str = "tensor_embedding"
 47    """ The key for the tensor embedding output."""
 48    species_encoding: dict = dataclasses.field(default_factory=dict)
 49    """ The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`"""
 50    radial_basis: dict = dataclasses.field(default_factory=dict)
 51    """ The radial basis parameters. See `fennol.models.misc.encodings.RadialBasis`"""
 52    message_passing: bool = False
 53    """ Whether to use message passing."""
 54
 55    FID: ClassVar[str] = "CAIMAN"
 56
 57    @nn.compact
 58    def __call__(self, inputs):
 59        species = inputs["species"]
 60        assert (
 61            len(species.shape) == 1
 62        ), "Species must be a 1D array (batches must be flattened)"
 63        nchannels_density = (
 64            self.nchannels_density
 65            if self.nchannels_density is not None
 66            else self.nchannels
 67        )
 68
 69        graph = inputs[self.graph_key]
 70        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
 71        switch = graph["switch"][:, None]
 72        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
 73        radial_basis = RadialBasis(
 74            **{**self.radial_basis, "end": cutoff, "name": "RadialBasis"}
 75        )(graph["distances"])
 76
 77        Dij = (
 78            nn.Dense(nchannels_density, use_bias=True, name="Dij")(radial_basis)
 79            * switch
 80        )
 81
 82        species_encoding = SpeciesEncoding(
 83            **self.species_encoding, name="SpeciesEncoding"
 84        )(species)
 85
 86        xi = FullyConnectedNet(
 87            neurons=[*self.embedding_hidden, self.dim],
 88            activation=self.activation,
 89            use_bias=True,
 90            name="species_embedding",
 91        )(species_encoding)
 92        Zs, Zd = jnp.split(
 93            nn.Dense(2 * nchannels_density, use_bias=True, name="species_linear")(xi),
 94            2,
 95            axis=-1,
 96        )
 97        xij = Zs[edge_src] * Zd[edge_dst] * Dij
 98
 99        Yij = generate_spherical_harmonics(lmax=self.lmax, normalize=False)(
100            graph["vec"] / graph["distances"][:, None]
101        )[:, None, :]
102
103        rhoij = xij[:, :, None] * Yij
104
105        nrep = np.array([2 * l + 1 for l in range(self.lmax + 1)])
106        wsh = self.param(
107            "wsh",
108            lambda key, shape: jax.random.normal(key, shape, dtype=jnp.float32),
109            (nchannels_density, self.lmax + 1),
110        ).repeat(nrep, axis=-1)
111        density = (
112            jax.ops.segment_sum(rhoij, edge_src, species.shape[0]) * wsh[None, :, :]
113        )
114
115        nel = (self.lmax + 1) ** 2
116        Vi = ChannelMixingE3(self.lmax, nchannels_density, self.nchannels)(
117            density[..., :nel]
118        )
119        lambda_message = self.param(
120            "lambda_message",
121            lambda key: jnp.asarray(0.1, dtype=density.dtype),
122        )
123
124        for layer in range(self.nlayers):
125            if self.message_passing:
126                Zs, Zs = jnp.split(
127                    nn.Dense(
128                        2 * nchannels_density,
129                        use_bias=True,
130                        name=f"message_linear_{layer}",
131                    )(xi),
132                    2,
133                    axis=-1,
134                )
135                mij = (
136                    nn.Dense(
137                        nchannels_density, use_bias=False, name=f"radial_linear_{layer}"
138                    )(Dij)
139                    * Zs[edge_src]
140                    * Zd[edge_dst]
141                )
142                rhoij = (
143                    mij[:, :, None]
144                    * ChannelMixingE3(
145                        self.lmax,
146                        self.nchannels,
147                        nchannels_density,
148                        name=f"message_mixing_{layer}",
149                    )(Vi)[edge_dst]
150                )
151                rhoi = jax.ops.segment_sum(rhoij, edge_src, species.shape[0])
152                density = density + lambda_message * ChannelMixingE3(
153                    self.lmax,
154                    nchannels_density,
155                    nchannels_density,
156                    name=f"density_update_{layer}",
157                )(rhoi)
158
159            Hi = ChannelMixingE3(
160                self.lmax,
161                nchannels_density,
162                self.nchannels,
163                name=f"density_mixing_{layer}",
164            )(density)
165
166            Li = FilteredTensorProduct(
167                self.lmax, self.lmax, self.lmax, name=f"TP_{layer}"
168            )(Vi, Hi)
169            scals = jax.lax.index_in_dim(Li, 0, axis=-1, keepdims=False)
170            li = FullyConnectedNet(
171                [*self.latent_hidden, self.dim],
172                activation=self.activation,
173                use_bias=True,
174                name=f"latent_net_{layer}",
175            )(jnp.concatenate((xi, scals), axis=-1))
176
177            xi = xi + li
178            Vi = Vi + ChannelMixingE3(
179                self.lmax, self.nchannels, self.nchannels, name=f"mixing_{layer}"
180            )(Li)
181
182        if self.embedding_key is None:
183            return xi, Vi
184        return {**inputs, self.embedding_key: xi, self.tensor_embedding_key: Vi}

Covariant Atom-In-Molecule Network

FID : CAIMAN

This is an E(3) equivariant embedding that forms an equivariant neighbor density and then uses multiple self-interaction tensor products to generate a tensorial embedding along with a scalar embedding (similar to the tensor/scalar tracks in Allegro).

CaimanEmbedding( _graphs_properties: Dict, dim: int = 128, nchannels: int = 16, nchannels_density: Optional[int] = None, nlayers: int = 3, lmax: int = 2, embedding_hidden: Sequence[int] = <factory>, latent_hidden: Sequence[int] = <factory>, activation: Union[Callable, str] = 'silu', graph_key: str = 'graph', embedding_key: str = 'embedding', tensor_embedding_key: str = 'tensor_embedding', species_encoding: dict = <factory>, radial_basis: dict = <factory>, message_passing: bool = False, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
dim: int = 128

The dimension of the embedding.

nchannels: int = 16

The number of channels.

nchannels_density: Optional[int] = None

The number of channels for the neighborhood density. If None, it is equal to nchannels.

nlayers: int = 3

The number of layers.

lmax: int = 2

The maximum order of spherical tensors.

embedding_hidden: Sequence[int]

The hidden layers for the embedding.

latent_hidden: Sequence[int]

The hidden layers for the latent network.

activation: Union[Callable, str] = 'silu'

The activation function.

graph_key: str = 'graph'

The key for the graph input.

embedding_key: str = 'embedding'

The key for the embedding output.

tensor_embedding_key: str = 'tensor_embedding'

The key for the tensor embedding output.

species_encoding: dict

The species encoding parameters. See fennol.models.misc.encodings.SpeciesEncoding

radial_basis: dict

The radial basis parameters. See fennol.models.misc.encodings.RadialBasis

message_passing: bool = False

Whether to use message passing.

FID: ClassVar[str] = 'CAIMAN'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None