fennol.models.embeddings.allegro

  1import jax
  2import jax.numpy as jnp
  3import flax.linen as nn
  4from ...utils.spherical_harmonics import generate_spherical_harmonics
  5from ..misc.encodings import SpeciesEncoding, RadialBasis
  6import dataclasses
  7import numpy as np
  8from typing import Any, Dict, List, Union, Callable, Tuple, Sequence, Optional, ClassVar
  9from ..misc.nets import FullyConnectedNet
 10from ..misc.e3 import (
 11    FilteredTensorProduct,
 12    ChannelMixingE3,
 13    ChannelMixing,
 14    E3NN_AVAILABLE,
 15    E3NN_EXCEPTION,
 16    Irreps,
 17)
 18
 19
 20class AllegroEmbedding(nn.Module):
 21    """Allegro equivariant pair embedding
 22
 23    FID : ALLEGRO
 24
 25    ### Reference
 26    Musaelian, A., Batzner, S., Johansson, A. et al. Learning local equivariant representations for large-scale atomistic dynamics.
 27      Nat Commun 14, 579 (2023). https://doi.org/10.1038/s41467-023-36329-y
 28
 29    """
 30
 31    _graphs_properties: Dict
 32    dim: int = 128
 33    """ The dimension of the embedding."""
 34    nchannels: int = 16
 35    """ The number of equivariant channels."""
 36    nlayers: int = 3
 37    """ The number of interaction layers."""
 38    lmax: int = 2
 39    """ The maximum degree of tensorial embedding."""
 40    lmax_density: Optional[int] = None
 41    """ The maximum degree of spherical harmonics for density.
 42        If None, it will be set to lmax. Must be greater or equal to lmax."""
 43    twobody_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
 44    """ The number of hidden neurons in the two-body network."""
 45    embedding_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [])
 46    """ The number of hidden neurons in the embedding network."""
 47    latent_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
 48    """ The number of hidden neurons in the latent network."""
 49    activation: Union[Callable, str] = "silu"
 50    """ The activation function to use."""
 51    graph_key: str = "graph"
 52    """ The key in the input dictionary that corresponds to the graph."""
 53    embedding_key: str = "embedding"
 54    """ The key to use for the output embedding in the returned dictionary."""
 55    tensor_embedding_key: str = "tensor_embedding"
 56    """ The key to use for the output tensor embedding in the returned dictionary."""
 57    species_encoding: dict = dataclasses.field(default_factory=dict)
 58    """ The species encoding parameters.  See `fennol.models.misc.encodings.SpeciesEncoding`"""
 59    radial_basis: dict = dataclasses.field(default_factory=dict)
 60    """ The radial basis parameters. See `fennol.models.misc.encodings.RadialBasis`"""
 61
 62    FID: ClassVar[str] = "ALLEGRO"
 63
 64    @nn.compact
 65    def __call__(self, inputs):
 66        """ Forward pass of the Allegro model. """
 67        species = inputs["species"]
 68
 69        graph = inputs[self.graph_key]
 70        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
 71        switch = graph["switch"][:, None]
 72        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
 73        radial_basis = RadialBasis(
 74            **{**self.radial_basis, "end": cutoff, "name": "RadialBasis"}
 75        )(graph["distances"])
 76
 77        species_encoding = SpeciesEncoding(
 78            **self.species_encoding, name="SpeciesEncoding"
 79        )(species)
 80
 81        xij = (
 82            FullyConnectedNet(
 83                neurons=[*self.twobody_hidden, self.dim], activation=self.activation
 84            )(
 85                jnp.concatenate(
 86                    [
 87                        species_encoding[edge_src],
 88                        species_encoding[edge_dst],
 89                        radial_basis,
 90                    ],
 91                    axis=-1,
 92                )
 93            )
 94            * switch
 95        )
 96
 97        lmax_density = self.lmax_density if self.lmax_density is not None else self.lmax
 98        assert lmax_density >= self.lmax
 99
100        Yij = generate_spherical_harmonics(lmax=lmax_density, normalize=False)(
101            graph["vec"] / graph["distances"][:, None]  
102        )[:, None, :]
103
104        nel = (self.lmax + 1) ** 2
105        Vij = (
106            ChannelMixingE3(self.lmax, 1, self.nchannels)(Yij[..., :nel])
107            * nn.Dense(self.nchannels, use_bias=False)(xij)[:, :, None]
108        )
109
110        for _ in range(self.nlayers):
111            rhoij = (
112                FullyConnectedNet(
113                    neurons=[*self.embedding_hidden, self.nchannels],
114                    activation=self.activation,
115                )(xij)
116                * switch
117            )[:, :, None] * Yij
118            density = (
119                jnp.zeros((species.shape[0], *rhoij.shape[1:])).at[edge_src].add(rhoij)
120            )
121
122            Lij = FilteredTensorProduct(self.lmax, lmax_density)(Vij, density[edge_src])
123            scals = jax.lax.index_in_dim(Lij, 0, axis=-1, keepdims=False)
124            lij = FullyConnectedNet(neurons=[*self.latent_hidden, self.dim])(
125                jnp.concatenate((xij, scals), axis=-1)
126            )
127
128            xij = xij + lij * switch
129            Vij = ChannelMixing(self.lmax, self.nchannels, self.nchannels)(Lij)
130
131        if self.embedding_key is None:
132            return xij, Vij
133        return {**inputs, self.embedding_key: xij, self.tensor_embedding_key: Vij}
134
135
if E3NN_AVAILABLE:
    import e3nn_jax as e3nn

    class AllegroE3NNEmbedding(nn.Module):
        """Allegro equivariant pair embedding

        FID : ALLEGRO_E3NN

        In this version, the equivariant operations are implemented with the
        e3nn-jax library (imported as ``e3nn``).

        Reference
        ---------
        Musaelian, A., Batzner, S., Johansson, A. et al. Learning local equivariant representations for large-scale atomistic dynamics.
        Nat Commun 14, 579 (2023). https://doi.org/10.1038/s41467-023-36329-y

        """

        _graphs_properties: Dict
        dim: int = 128
        """ The dimension of the embedding."""
        nchannels: int = 16
        """ The number of equivariant channels."""
        nlayers: int = 3
        """ The number of interaction layers."""
        irreps_Vij: Union[str, int, 'Irreps'] = 2
        """ Irreps used for the tensor embedding.
            An int l is expanded to the spherical-harmonics irreps up to degree l;
            a str is parsed with e3nn.Irreps."""
        lmax_density: Optional[int] = None
        """ The maximum degree of spherical harmonics for density.
            If None, it will be set to lmax (the highest degree in irreps_Vij).
            Must be greater than or equal to lmax."""
        twobody_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
        """ The sizes of the hidden layers of the two-body network."""
        embedding_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [])
        """ The sizes of the hidden layers of the embedding network."""
        latent_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
        """ The sizes of the hidden layers of the latent network."""
        activation: Union[Callable, str] = "silu"
        """ The activation function used by all the MLPs of the model."""
        graph_key: str = "graph"
        """ The key in the input dictionary that corresponds to the graph."""
        embedding_key: Optional[str] = "embedding"
        """ The key to use for the output embedding in the returned dictionary.
            If None, the tuple (xij, Vij) is returned instead of a dictionary."""
        tensor_embedding_key: str = "tensor_embedding"
        """ The key to use for the output tensor embedding in the returned dictionary."""
        species_encoding: dict = dataclasses.field(default_factory=dict)
        """ The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`"""
        radial_basis: dict = dataclasses.field(default_factory=dict)
        """ The radial basis parameters. See `fennol.models.misc.encodings.RadialBasis`"""

        FID: ClassVar[str] = "ALLEGRO_E3NN"
        """ Identification of the module when building a model."""

        @nn.compact
        def __call__(self, inputs):
            """Forward pass of the e3nn-based Allegro model.

            Args:
                inputs: dictionary holding a 1D ``"species"`` array and, under
                    ``graph_key``, a graph dictionary with ``edge_src``,
                    ``edge_dst``, ``switch``, ``distances`` and ``vec`` entries.

            Returns:
                ``(xij, Vij)`` if ``embedding_key`` is None, otherwise a copy of
                ``inputs`` with ``xij`` stored under ``embedding_key`` and
                ``Vij`` (an e3nn IrrepsArray) under ``tensor_embedding_key``.
            """
            species = inputs["species"]
            assert (
                len(species.shape) == 1
            ), "Species must be a 1D array (batches must be flattened)"

            graph = inputs[self.graph_key]
            edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
            # per-edge switching weight, broadcast over the feature axis
            switch = graph["switch"][:, None]
            cutoff = self._graphs_properties[self.graph_key]["cutoff"]
            radial_basis = RadialBasis(
                **{**self.radial_basis, "end": cutoff, "name": "RadialBasis"}
            )(graph["distances"])

            species_encoding = SpeciesEncoding(
                **self.species_encoding, name="SpeciesEncoding"
            )(species)

            # two-body scalar latent features from (species_i, species_j, radial basis)
            xij = (
                FullyConnectedNet(
                    neurons=[*self.twobody_hidden, self.dim], activation=self.activation
                )(
                    jnp.concatenate(
                        [
                            species_encoding[edge_src],
                            species_encoding[edge_dst],
                            radial_basis,
                        ],
                        axis=-1,
                    )
                )
                * switch
            )

            if isinstance(self.irreps_Vij, int):
                irreps_Vij = e3nn.Irreps.spherical_harmonics(self.irreps_Vij)
            elif isinstance(self.irreps_Vij, str):
                irreps_Vij = e3nn.Irreps(self.irreps_Vij)
            else:
                irreps_Vij = self.irreps_Vij
            lmax = max(irreps_Vij.ls)
            # explicit None test so that lmax_density=0 is not mistaken for "unset"
            lmax_density = lmax if self.lmax_density is None else self.lmax_density
            irreps_density = e3nn.Irreps.spherical_harmonics(lmax_density)

            # normalize=True: e3nn normalizes the raw bond vectors itself;
            # a singleton channel axis is inserted
            Yij = e3nn.spherical_harmonics(
                irreps_density, graph["vec"], normalize=True
            )[:, None, :]

            Vij = (
                e3nn.flax.Linear(irreps_Vij, channel_out=self.nchannels)(Yij)
                * nn.Dense(self.nchannels, use_bias=False)(xij)[:, :, None]
            )

            for _ in range(self.nlayers):
                # learned, switched per-channel weights applied to each edge's harmonics
                rhoij = (
                    FullyConnectedNet(
                        neurons=[*self.embedding_hidden, self.nchannels],
                        activation=self.activation,
                    )(xij)
                    * switch
                )[:, :, None] * Yij
                # accumulate the weighted harmonics on the source atom of each edge
                density = e3nn.scatter_sum(
                    rhoij, dst=edge_src, output_size=species_encoding.shape[0]
                )

                Lij = e3nn.tensor_product(
                    Vij, density[edge_src], filter_ir_out=irreps_Vij
                )
                # flatten the scalar (0e) channels into one feature vector per edge
                scals = Lij.filter(["0e"]).array.reshape(Lij.shape[0], -1)
                # the latent MLP uses the configured activation, like the other MLPs
                lij = FullyConnectedNet(
                    neurons=[*self.latent_hidden, self.dim],
                    activation=self.activation,
                )(jnp.concatenate((xij, scals), axis=-1))

                # residual update of the scalar features, damped by the switch
                xij = xij + lij * switch
                # filtering
                Lij = e3nn.flax.Linear(irreps_Vij)(Lij)
                # channel mixing
                Vij = e3nn.flax.Linear(irreps_Vij, channel_out=self.nchannels)(Lij)

            if self.embedding_key is None:
                return xij, Vij
            return {**inputs, self.embedding_key: xij, self.tensor_embedding_key: Vij}

else:

    class AllegroE3NNEmbedding(nn.Module):
        """Placeholder registered when e3nn_jax is not available.

        Calling it raises ``E3NN_EXCEPTION`` (presumably the error captured
        when importing e3nn_jax failed — see ``..misc.e3``).
        """

        FID: ClassVar[str] = "ALLEGRO_E3NN"

        def __call__(self, *args, **kwargs) -> Any:
            raise E3NN_EXCEPTION
class AllegroEmbedding(flax.linen.module.Module):
 21class AllegroEmbedding(nn.Module):
 22    """Allegro equivariant pair embedding
 23
 24    FID : ALLEGRO
 25
 26    ### Reference
 27    Musaelian, A., Batzner, S., Johansson, A. et al. Learning local equivariant representations for large-scale atomistic dynamics.
 28      Nat Commun 14, 579 (2023). https://doi.org/10.1038/s41467-023-36329-y
 29
 30    """
 31
 32    _graphs_properties: Dict
 33    dim: int = 128
 34    """ The dimension of the embedding."""
 35    nchannels: int = 16
 36    """ The number of equivariant channels."""
 37    nlayers: int = 3
 38    """ The number of interaction layers."""
 39    lmax: int = 2
 40    """ The maximum degree of tensorial embedding."""
 41    lmax_density: Optional[int] = None
 42    """ The maximum degree of spherical harmonics for density.
 43        If None, it will be set to lmax. Must be greater or equal to lmax."""
 44    twobody_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
 45    """ The number of hidden neurons in the two-body network."""
 46    embedding_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [])
 47    """ The number of hidden neurons in the embedding network."""
 48    latent_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
 49    """ The number of hidden neurons in the latent network."""
 50    activation: Union[Callable, str] = "silu"
 51    """ The activation function to use."""
 52    graph_key: str = "graph"
 53    """ The key in the input dictionary that corresponds to the graph."""
 54    embedding_key: str = "embedding"
 55    """ The key to use for the output embedding in the returned dictionary."""
 56    tensor_embedding_key: str = "tensor_embedding"
 57    """ The key to use for the output tensor embedding in the returned dictionary."""
 58    species_encoding: dict = dataclasses.field(default_factory=dict)
 59    """ The species encoding parameters.  See `fennol.models.misc.encodings.SpeciesEncoding`"""
 60    radial_basis: dict = dataclasses.field(default_factory=dict)
 61    """ The radial basis parameters. See `fennol.models.misc.encodings.RadialBasis`"""
 62
 63    FID: ClassVar[str] = "ALLEGRO"
 64
 65    @nn.compact
 66    def __call__(self, inputs):
 67        """ Forward pass of the Allegro model. """
 68        species = inputs["species"]
 69
 70        graph = inputs[self.graph_key]
 71        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
 72        switch = graph["switch"][:, None]
 73        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
 74        radial_basis = RadialBasis(
 75            **{**self.radial_basis, "end": cutoff, "name": "RadialBasis"}
 76        )(graph["distances"])
 77
 78        species_encoding = SpeciesEncoding(
 79            **self.species_encoding, name="SpeciesEncoding"
 80        )(species)
 81
 82        xij = (
 83            FullyConnectedNet(
 84                neurons=[*self.twobody_hidden, self.dim], activation=self.activation
 85            )(
 86                jnp.concatenate(
 87                    [
 88                        species_encoding[edge_src],
 89                        species_encoding[edge_dst],
 90                        radial_basis,
 91                    ],
 92                    axis=-1,
 93                )
 94            )
 95            * switch
 96        )
 97
 98        lmax_density = self.lmax_density if self.lmax_density is not None else self.lmax
 99        assert lmax_density >= self.lmax
100
101        Yij = generate_spherical_harmonics(lmax=lmax_density, normalize=False)(
102            graph["vec"] / graph["distances"][:, None]  
103        )[:, None, :]
104
105        nel = (self.lmax + 1) ** 2
106        Vij = (
107            ChannelMixingE3(self.lmax, 1, self.nchannels)(Yij[..., :nel])
108            * nn.Dense(self.nchannels, use_bias=False)(xij)[:, :, None]
109        )
110
111        for _ in range(self.nlayers):
112            rhoij = (
113                FullyConnectedNet(
114                    neurons=[*self.embedding_hidden, self.nchannels],
115                    activation=self.activation,
116                )(xij)
117                * switch
118            )[:, :, None] * Yij
119            density = (
120                jnp.zeros((species.shape[0], *rhoij.shape[1:])).at[edge_src].add(rhoij)
121            )
122
123            Lij = FilteredTensorProduct(self.lmax, lmax_density)(Vij, density[edge_src])
124            scals = jax.lax.index_in_dim(Lij, 0, axis=-1, keepdims=False)
125            lij = FullyConnectedNet(neurons=[*self.latent_hidden, self.dim])(
126                jnp.concatenate((xij, scals), axis=-1)
127            )
128
129            xij = xij + lij * switch
130            Vij = ChannelMixing(self.lmax, self.nchannels, self.nchannels)(Lij)
131
132        if self.embedding_key is None:
133            return xij, Vij
134        return {**inputs, self.embedding_key: xij, self.tensor_embedding_key: Vij}

Allegro equivariant pair embedding

FID : ALLEGRO

Reference

Musaelian, A., Batzner, S., Johansson, A. et al. Learning local equivariant representations for large-scale atomistic dynamics. Nat Commun 14, 579 (2023). https://doi.org/10.1038/s41467-023-36329-y

AllegroEmbedding( _graphs_properties: Dict, dim: int = 128, nchannels: int = 16, nlayers: int = 3, lmax: int = 2, lmax_density: Optional[int] = None, twobody_hidden: Sequence[int] = <factory>, embedding_hidden: Sequence[int] = <factory>, latent_hidden: Sequence[int] = <factory>, activation: Union[Callable, str] = 'silu', graph_key: str = 'graph', embedding_key: str = 'embedding', tensor_embedding_key: str = 'tensor_embedding', species_encoding: dict = <factory>, radial_basis: dict = <factory>, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
dim: int = 128

The dimension of the embedding.

nchannels: int = 16

The number of equivariant channels.

nlayers: int = 3

The number of interaction layers.

lmax: int = 2

The maximum degree of tensorial embedding.

lmax_density: Optional[int] = None

The maximum degree of the spherical harmonics used for the density. If None, it will be set to lmax. Must be greater than or equal to lmax.

twobody_hidden: Sequence[int]

The sizes of the hidden layers of the two-body network (one entry per hidden layer).

embedding_hidden: Sequence[int]

The sizes of the hidden layers of the embedding network (one entry per hidden layer; empty by default).

latent_hidden: Sequence[int]

The sizes of the hidden layers of the latent network (one entry per hidden layer).

activation: Union[Callable, str] = 'silu'

The activation function to use.

graph_key: str = 'graph'

The key in the input dictionary that corresponds to the graph.

embedding_key: str = 'embedding'

The key to use for the output embedding in the returned dictionary.

tensor_embedding_key: str = 'tensor_embedding'

The key to use for the output tensor embedding in the returned dictionary.

species_encoding: dict

The species encoding parameters. See fennol.models.misc.encodings.SpeciesEncoding

radial_basis: dict

The radial basis parameters. See fennol.models.misc.encodings.RadialBasis

FID: ClassVar[str] = 'ALLEGRO'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class AllegroE3NNEmbedding(flax.linen.module.Module):
class AllegroE3NNEmbedding(nn.Module):
    """Allegro equivariant pair embedding

    FID : ALLEGRO_E3NN

    In this version, the equivariant operations are implemented with the
    e3nn-jax library (imported as ``e3nn``).

    Reference
    ---------
    Musaelian, A., Batzner, S., Johansson, A. et al. Learning local equivariant representations for large-scale atomistic dynamics.
    Nat Commun 14, 579 (2023). https://doi.org/10.1038/s41467-023-36329-y

    """

    _graphs_properties: Dict
    dim: int = 128
    """ The dimension of the embedding."""
    nchannels: int = 16
    """ The number of equivariant channels."""
    nlayers: int = 3
    """ The number of interaction layers."""
    irreps_Vij: Union[str, int, 'Irreps'] = 2
    """ Irreps used for the tensor embedding.
        An int l is expanded to the spherical-harmonics irreps up to degree l;
        a str is parsed with e3nn.Irreps."""
    lmax_density: Optional[int] = None
    """ The maximum degree of spherical harmonics for density.
        If None, it will be set to lmax (the highest degree in irreps_Vij).
        Must be greater than or equal to lmax."""
    twobody_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
    """ The sizes of the hidden layers of the two-body network."""
    embedding_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [])
    """ The sizes of the hidden layers of the embedding network."""
    latent_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
    """ The sizes of the hidden layers of the latent network."""
    activation: Union[Callable, str] = "silu"
    """ The activation function used by all the MLPs of the model."""
    graph_key: str = "graph"
    """ The key in the input dictionary that corresponds to the graph."""
    embedding_key: Optional[str] = "embedding"
    """ The key to use for the output embedding in the returned dictionary.
        If None, the tuple (xij, Vij) is returned instead of a dictionary."""
    tensor_embedding_key: str = "tensor_embedding"
    """ The key to use for the output tensor embedding in the returned dictionary."""
    species_encoding: dict = dataclasses.field(default_factory=dict)
    """ The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`"""
    radial_basis: dict = dataclasses.field(default_factory=dict)
    """ The radial basis parameters. See `fennol.models.misc.encodings.RadialBasis`"""

    FID: ClassVar[str] = "ALLEGRO_E3NN"
    """ Identification of the module when building a model."""

    @nn.compact
    def __call__(self, inputs):
        """Forward pass of the e3nn-based Allegro model.

        Args:
            inputs: dictionary holding a 1D ``"species"`` array and, under
                ``graph_key``, a graph dictionary with ``edge_src``,
                ``edge_dst``, ``switch``, ``distances`` and ``vec`` entries.

        Returns:
            ``(xij, Vij)`` if ``embedding_key`` is None, otherwise a copy of
            ``inputs`` with ``xij`` stored under ``embedding_key`` and
            ``Vij`` (an e3nn IrrepsArray) under ``tensor_embedding_key``.
        """
        species = inputs["species"]
        assert (
            len(species.shape) == 1
        ), "Species must be a 1D array (batches must be flattened)"

        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        # per-edge switching weight, broadcast over the feature axis
        switch = graph["switch"][:, None]
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        radial_basis = RadialBasis(
            **{**self.radial_basis, "end": cutoff, "name": "RadialBasis"}
        )(graph["distances"])

        species_encoding = SpeciesEncoding(
            **self.species_encoding, name="SpeciesEncoding"
        )(species)

        # two-body scalar latent features from (species_i, species_j, radial basis)
        xij = (
            FullyConnectedNet(
                neurons=[*self.twobody_hidden, self.dim], activation=self.activation
            )(
                jnp.concatenate(
                    [
                        species_encoding[edge_src],
                        species_encoding[edge_dst],
                        radial_basis,
                    ],
                    axis=-1,
                )
            )
            * switch
        )

        if isinstance(self.irreps_Vij, int):
            irreps_Vij = e3nn.Irreps.spherical_harmonics(self.irreps_Vij)
        elif isinstance(self.irreps_Vij, str):
            irreps_Vij = e3nn.Irreps(self.irreps_Vij)
        else:
            irreps_Vij = self.irreps_Vij
        lmax = max(irreps_Vij.ls)
        # explicit None test so that lmax_density=0 is not mistaken for "unset"
        lmax_density = lmax if self.lmax_density is None else self.lmax_density
        irreps_density = e3nn.Irreps.spherical_harmonics(lmax_density)

        # normalize=True: e3nn normalizes the raw bond vectors itself;
        # a singleton channel axis is inserted
        Yij = e3nn.spherical_harmonics(
            irreps_density, graph["vec"], normalize=True
        )[:, None, :]

        Vij = (
            e3nn.flax.Linear(irreps_Vij, channel_out=self.nchannels)(Yij)
            * nn.Dense(self.nchannels, use_bias=False)(xij)[:, :, None]
        )

        for _ in range(self.nlayers):
            # learned, switched per-channel weights applied to each edge's harmonics
            rhoij = (
                FullyConnectedNet(
                    neurons=[*self.embedding_hidden, self.nchannels],
                    activation=self.activation,
                )(xij)
                * switch
            )[:, :, None] * Yij
            # accumulate the weighted harmonics on the source atom of each edge
            density = e3nn.scatter_sum(
                rhoij, dst=edge_src, output_size=species_encoding.shape[0]
            )

            Lij = e3nn.tensor_product(
                Vij, density[edge_src], filter_ir_out=irreps_Vij
            )
            # flatten the scalar (0e) channels into one feature vector per edge
            scals = Lij.filter(["0e"]).array.reshape(Lij.shape[0], -1)
            # the latent MLP uses the configured activation, like the other MLPs
            lij = FullyConnectedNet(
                neurons=[*self.latent_hidden, self.dim],
                activation=self.activation,
            )(jnp.concatenate((xij, scals), axis=-1))

            # residual update of the scalar features, damped by the switch
            xij = xij + lij * switch
            # filtering
            Lij = e3nn.flax.Linear(irreps_Vij)(Lij)
            # channel mixing
            Vij = e3nn.flax.Linear(irreps_Vij, channel_out=self.nchannels)(Lij)

        if self.embedding_key is None:
            return xij, Vij
        return {**inputs, self.embedding_key: xij, self.tensor_embedding_key: Vij}

Allegro equivariant pair embedding

FID : ALLEGRO_E3NN

In this version, the equivariant operations are implemented with the e3nn-jax library.

Reference

Musaelian, A., Batzner, S., Johansson, A. et al. Learning local equivariant representations for large-scale atomistic dynamics. Nat Commun 14, 579 (2023). https://doi.org/10.1038/s41467-023-36329-y

AllegroE3NNEmbedding( _graphs_properties: Dict, dim: int = 128, nchannels: int = 16, nlayers: int = 3, irreps_Vij: Union[str, int, e3nn_jax._src.irreps.Irreps] = 2, lmax_density: int = None, twobody_hidden: Sequence[int] = <factory>, embedding_hidden: Sequence[int] = <factory>, latent_hidden: Sequence[int] = <factory>, activation: Union[Callable, str] = 'silu', graph_key: str = 'graph', embedding_key: str = 'embedding', tensor_embedding_key: str = 'tensor_embedding', species_encoding: dict = <factory>, radial_basis: dict = <factory>, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
dim: int = 128

The dimension of the embedding.

nchannels: int = 16

The number of equivariant channels.

nlayers: int = 3

The number of interaction layers.

irreps_Vij: Union[str, int, e3nn_jax._src.irreps.Irreps] = 2

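Irreps used for the tensor embedding. An int l is expanded to the spherical-harmonics irreps up to degree l, and a string is parsed as an Irreps specification.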

lmax_density: int = None

The maximum degree of the spherical harmonics used for the density. If None, it will be set to lmax, the highest degree present in irreps_Vij. Must be greater than or equal to lmax.

twobody_hidden: Sequence[int]

The sizes of the hidden layers of the two-body network (one entry per hidden layer).

embedding_hidden: Sequence[int]

The sizes of the hidden layers of the embedding network (one entry per hidden layer; empty by default).

latent_hidden: Sequence[int]

The sizes of the hidden layers of the latent network (one entry per hidden layer).

activation: Union[Callable, str] = 'silu'

The activation function to use.

graph_key: str = 'graph'

The key in the input dictionary that corresponds to the graph.

embedding_key: str = 'embedding'

The key to use for the output embedding in the returned dictionary.

tensor_embedding_key: str = 'tensor_embedding'

The key to use for the output tensor embedding in the returned dictionary.

species_encoding: dict

The species encoding parameters. See fennol.models.misc.encodings.SpeciesEncoding

radial_basis: dict

The radial basis parameters. See fennol.models.misc.encodings.RadialBasis

FID: ClassVar[str] = 'ALLEGRO_E3NN'

Identification of the module when building a model.

parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None