fennol.models.embeddings.minimace

import jax
import jax.numpy as jnp
import flax.linen as nn
from ...utils.spherical_harmonics import generate_spherical_harmonics
from ..misc.encodings import SpeciesEncoding, RadialBasis
import dataclasses
from typing import Dict, Union, Callable, Sequence, ClassVar
from ..misc.nets import FullyConnectedNet
from ..misc.e3 import FilteredTensorProduct, ChannelMixingE3, ChannelMixing


class MiniMaceEmbedding(nn.Module):
    """Minimal MACE Embedding

    FID : MINIMACE

    This is a simplified version of the MACE embedding from the paper:
    Batatia et al., MACE: Higher Order Equivariant Message Passing Neural Networks for Fast and Accurate Force Fields
    https://doi.org/10.48550/arXiv.2206.07697

    It is designed to avoid the most costly operations (such as edge-wise
    tensor products) and to filter the results of each atomic tensor product
    to control the number of tensors. It may not match the accuracy of the
    full MACE embedding, but it should be faster.
    """
    _graphs_properties: Dict
    dim: int = 128
    """The dimension of the embedding."""
    nchannels: int = 16
    """The number of tensor channels."""
    message_dim: int = 16
    """The dimension of the message formed from the current embedding."""
    nlayers: int = 2
    """The number of interaction layers."""
    ntp: int = 2
    """The number of tensor products per layer."""
    lmax: int = 2
    """The maximum angular momentum of spherical tensors."""
    embedding_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [])
    """The hidden layers for the species embedding network."""
    latent_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
    """The hidden layers for the latent update network."""
    activation: Union[Callable, str] = "silu"
    """The activation function."""
    graph_key: str = "graph"
    """The key for the graph input."""
    embedding_key: str = "embedding"
    """The key for the embedding output."""
    tensor_embedding_key: str = "tensor_embedding"
    """The key for the tensor embedding output."""
    species_encoding: dict = dataclasses.field(default_factory=dict)
    """The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`."""
    radial_basis: dict = dataclasses.field(default_factory=dict)
    """The radial basis parameters. See `fennol.models.misc.encodings.RadialBasis`."""
    ignore_parity: bool = True
    """Whether to ignore parity of irreps in the tensor products."""

    FID: ClassVar[str] = "MINIMACE"

    @nn.compact
    def __call__(self, inputs):
        species = inputs["species"]
        assert (
            len(species.shape) == 1
        ), "Species must be a 1D array (batches must be flattened)"

        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        switch = graph["switch"][:, None]
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
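
        # expand edge distances in a radial basis and scale by the graph switch
        # so that edge features decay smoothly to zero at the cutoff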
        radial_basis = (
            RadialBasis(**{**self.radial_basis, "end": cutoff, "name": "RadialBasis"})(
                graph["distances"]
            )
            * switch
        )

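        # spherical harmonics of the unit edge vectors up to lmax, with a
        # broadcastable channel axis: shape (nedges, 1, (lmax+1)**2)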
        Yij = generate_spherical_harmonics(lmax=self.lmax, normalize=False)(
            graph["vec"] / graph["distances"][:, None]
        )[:, None, :]

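        # encode atomic species and project the encoding to the initial
        # scalar embedding xi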
        species_encoding = SpeciesEncoding(
            **self.species_encoding, name="SpeciesEncoding"
        )(species)

        xi = FullyConnectedNet(
            neurons=[*self.embedding_hidden, self.dim],
            activation=self.activation,
            use_bias=True,
            name="species_embedding",
        )(species_encoding)

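        # density channels pair each message channel with each radial basis
        # function (an outer product flattened to message_dim * nbasis)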
        nchannels_density = self.message_dim * radial_basis.shape[1]

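        # interaction loop: each layer accumulates an equivariant atomic
        # density, contracts it with filtered tensor products, and updates
        # the scalar embedding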
        for layer in range(self.nlayers):
            mi = nn.Dense(
                self.message_dim,
                use_bias=True,
                name=f"species_linear_{layer}",
            )(xi)
            xij = (mi[edge_dst, :, None] * radial_basis[:, None, :]).reshape(
                -1, nchannels_density
            )
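            # first layer: build the initial density from scalar edge features
            # and spherical harmonics; later layers: propagate the current
            # tensors Vi back onto the edges before accumulating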
            if layer == 0:
                rhoij = xij[:, :, None] * Yij
                density = jax.ops.segment_sum(rhoij, edge_src, species.shape[0])
                Vi = ChannelMixingE3(
                    self.lmax,
                    nchannels_density,
                    self.nchannels,
                    name="Vi_initial",
                )(density)
            else:
                rhoi = ChannelMixingE3(
                    self.lmax,
                    self.nchannels,
                    nchannels_density,
                    name=f"rho_mixing_{layer}",
                )(Vi)
                rhoij = xij[:, :, None] * rhoi[edge_dst]
                density = density + jax.ops.segment_sum(
                    rhoij, edge_src, species.shape[0]
                )

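            # keep the l=0 (scalar) part of the density, then apply ntp filtered
            # tensor products; each product adds higher body-order features at
            # fixed lmax, and its scalar part is collected for the latent update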
            scals = [jax.lax.index_in_dim(density, 0, axis=-1, keepdims=False)]
            for i in range(self.ntp):
                Hi = ChannelMixing(
                    self.lmax,
                    nchannels_density,
                    self.nchannels,
                    name=f"density_mixing_{layer}_{i}",
                )(density)
                Li = FilteredTensorProduct(
                    self.lmax,
                    self.lmax,
                    self.lmax,
                    name=f"TP_{layer}_{i}",
                    ignore_parity=self.ignore_parity,
                )(Vi, Hi)
                scals.append(jax.lax.index_in_dim(Li, 0, axis=-1, keepdims=False))
                Vi = Vi + Li

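            # residual update of the scalar embedding from the collected scalars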
            dxi = FullyConnectedNet(
                [*self.latent_hidden, self.dim],
                activation=self.activation,
                use_bias=True,
                name=f"latent_net_{layer}",
            )(jnp.concatenate([xi, *scals], axis=-1))
            xi = xi + dxi

        if self.embedding_key is None:
            return xi, Vi
        return {**inputs, self.embedding_key: xi, self.tensor_embedding_key: Vi}

class MiniMaceEmbedding(flax.linen.module.Module):

Minimal MACE Embedding

FID : MINIMACE

This is a simplified version of the MACE embedding from the paper: Batatia et al., MACE: Higher Order Equivariant Message Passing Neural Networks for Fast and Accurate Force Fields https://doi.org/10.48550/arXiv.2206.07697

It is designed to avoid the most costly operations (such as edge-wise tensor products) and to filter the results of each atomic tensor product to control the number of tensors. It may not match the accuracy of the full MACE embedding, but it should be faster.

MiniMaceEmbedding(
    _graphs_properties: Dict,
    dim: int = 128,
    nchannels: int = 16,
    message_dim: int = 16,
    nlayers: int = 2,
    ntp: int = 2,
    lmax: int = 2,
    embedding_hidden: Sequence[int] = <factory>,
    latent_hidden: Sequence[int] = <factory>,
    activation: Union[Callable, str] = 'silu',
    graph_key: str = 'graph',
    embedding_key: str = 'embedding',
    tensor_embedding_key: str = 'tensor_embedding',
    species_encoding: dict = <factory>,
    radial_basis: dict = <factory>,
    ignore_parity: bool = True,
    parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>,
    name: Optional[str] = None,
)
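
A minimal usage sketch (the graph contents, shapes, and cutoff below are illustrative; in a full FeNNol model the neighbor graph, switch values, and graph properties are produced by the preprocessing modules rather than built by hand):

    import jax
    import jax.numpy as jnp
    from fennol.models.embeddings.minimace import MiniMaceEmbedding

    # hypothetical 3-atom system with 4 directed edges and a cutoff of 5.0
    graph = {
        "edge_src": jnp.array([0, 1, 1, 2]),
        "edge_dst": jnp.array([1, 0, 2, 1]),
        "distances": jnp.array([1.0, 1.0, 1.5, 1.5]),
        # edge vectors whose norms equal the distances above
        "vec": jnp.array(
            [[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [0.0, 1.5, 0.0], [0.0, -1.5, 0.0]]
        ),
        "switch": jnp.array([0.9, 0.9, 0.7, 0.7]),  # smooth cutoff switch values
    }
    inputs = {"species": jnp.array([8, 1, 1]), "graph": graph}

    mod = MiniMaceEmbedding(_graphs_properties={"graph": {"cutoff": 5.0}})
    params = mod.init(jax.random.PRNGKey(0), inputs)
    out = mod.apply(params, inputs)
    print(out["embedding"].shape)         # (natoms, dim) = (3, 128)
    print(out["tensor_embedding"].shape)  # (natoms, nchannels, (lmax+1)**2) = (3, 16, 9)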
dim: int = 128

The dimension of the embedding.

nchannels: int = 16

The number of tensor channels.

message_dim: int = 16

The dimension of the message formed from the current embedding.

nlayers: int = 2

The number of interaction layers.

ntp: int = 2

The number of tensor products per layer.

lmax: int = 2

The maximum angular momentum of spherical tensors.

embedding_hidden: Sequence[int]

The hidden layers for the species embedding network.

latent_hidden: Sequence[int]

The hidden layers for the latent update network.

activation: Union[Callable, str] = 'silu'

The activation function.

graph_key: str = 'graph'

The key for the graph input.

embedding_key: str = 'embedding'

The key for the embedding output.

tensor_embedding_key: str = 'tensor_embedding'

The key for the tensor embedding output.

species_encoding: dict

The species encoding parameters. See fennol.models.misc.encodings.SpeciesEncoding.

radial_basis: dict

The radial basis parameters. See fennol.models.misc.encodings.RadialBasis.

ignore_parity: bool = True

Whether to ignore parity of irreps in the tensor products.

FID: ClassVar[str] = 'MINIMACE'
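
Modules in FeNNol are registered by their FID, so a model specification would select this embedding through the MINIMACE identifier and pass the fields documented above as parameters. A hypothetical configuration fragment (the nested dicts are forwarded to the encoding modules; their keys are left to the respective classes):

    embedding_spec = {
        "FID": "MINIMACE",       # selects MiniMaceEmbedding
        "dim": 128,              # size of the scalar embedding
        "nchannels": 16,         # number of tensor channels
        "lmax": 2,               # maximum angular momentum
        "nlayers": 2,            # number of interaction layers
        "ntp": 2,                # tensor products per layer
        "radial_basis": {},      # see fennol.models.misc.encodings.RadialBasis
        "species_encoding": {},  # see fennol.models.misc.encodings.SpeciesEncoding
    }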