fennol.models.embeddings.newtonnet

  1import jax
  2import jax.numpy as jnp
  3import flax.linen as nn
  4from ...utils.spherical_harmonics import generate_spherical_harmonics, CG_SO3
  5from ..misc.encodings import SpeciesEncoding, RadialBasis
  6import dataclasses
  7import numpy as np
  8from typing import Dict, Union, Callable, Sequence, Optional, ClassVar
  9from ...utils.activations import activation_from_str
 10from ..misc.nets import FullyConnectedNet
 11
 12
 13class NewtonNetEmbedding(nn.Module):
 14    """ Newtonian message passing network
 15
 16    ### Reference
 17    Haghighatlari et al., NewtonNet: a Newtonian message passing network for deep learning of interatomic potentials and forces
 18    https://doi.org/10.1039/D2DD00008C
 19
 20    """
 21    
 22    _graphs_properties: Dict
 23    dim: int = 128
 24    """The dimension of the embedding."""
 25    nlayers: int = 3
 26    """The number of interaction layers."""
 27    nchannels: Optional[int] = None
 28    """The number of vector channels. If None, it is set to dim."""
 29    embedding_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
 30    """The hidden layers for the embedding networks."""
 31    latent_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
 32    """The hidden layers for the latent update network."""
 33    activation: Union[Callable, str] = "silu"
 34    """The activation function."""
 35    graph_key: str = "graph"
 36    """The key for the graph input."""
 37    embedding_key: str = "embedding"
 38    """The key for the output embedding."""
 39    species_encoding: dict = dataclasses.field(default_factory=dict)
 40    """The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`."""
 41    radial_basis: dict = dataclasses.field(default_factory=dict)
 42    """The radial basis parameters. See `fennol.models.misc.encodings.RadialBasis`."""
 43    keep_all_layers: bool = False
 44    """Whether to keep embeddings from each layer in the output."""
 45
 46    FID: ClassVar[str] = "NEWTONNET"
 47
 48
 49    @nn.compact
 50    def __call__(self, inputs):
 51        species = inputs["species"]
 52        assert (
 53            len(species.shape) == 1
 54        ), "Species must be a 1D array (batches must be flattened)"
 55
 56        graph = inputs[self.graph_key]
 57        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
 58
 59        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
 60
 61        onehot = SpeciesEncoding(**self.species_encoding, name="SpeciesEncoding")(
 62            species
 63        )
 64        xi = nn.Dense(self.dim, name="species_linear", use_bias=True)(onehot)
 65
 66        nchannels = self.nchannels if self.nchannels is not None else self.dim
 67
 68        distances = graph["distances"]
 69        switch = graph["switch"][:, None]
 70        dirij = graph["vec"] / distances[:, None] * switch
 71
 72        radial_basis = RadialBasis(
 73            **{
 74                **self.radial_basis,
 75                "end": cutoff,
 76                "name": f"RadialBasis",
 77            }
 78        )(distances)
 79
 80        if self.keep_all_layers:
 81            xis = []
 82        for layer in range(self.nlayers):
 83            ai = FullyConnectedNet(
 84                [*self.embedding_hidden, self.dim],
 85                activation=self.activation,
 86                name=f"phi_a_{layer}",
 87                use_bias=True,
 88            )(xi)
 89            Dij = nn.Dense(self.dim, name=f"radial_linear_{layer}", use_bias=True)(
 90                radial_basis
 91            )
 92            mij = ai[edge_src] * ai[edge_dst] * Dij * switch
 93
 94            mi = jax.ops.segment_sum(mij, edge_src, xi.shape[0])
 95            xi = xi + mi
 96
 97            Fij = (
 98                FullyConnectedNet(
 99                    [*self.embedding_hidden, 1],
100                    activation=self.activation,
101                    name=f"phi_F_{layer}",
102                    use_bias=True,
103                )(mij)
104                * dirij
105            )
106
107            fij = (
108                FullyConnectedNet(
109                    [*self.embedding_hidden, nchannels],
110                    activation=self.activation,
111                    name=f"phi_f_{layer}",
112                    use_bias=True,
113                )(mij)[:, :, None]
114                * Fij[:, None, :]
115            )
116
117            if layer == 0:
118                fi = jax.ops.segment_sum(fij, edge_src, xi.shape[0])
119            else:
120                fi = fi + jax.ops.segment_sum(fij, edge_src, xi.shape[0])
121
122            deltai = (
123                FullyConnectedNet(
124                    [*self.embedding_hidden, nchannels],
125                    activation=self.activation,
126                    name=f"phi_R_{layer}",
127                    use_bias=True,
128                )(xi)[:, :, None]
129                * fi
130            )
131            if layer == 0:
132                di = deltai
133            else:
134                phi_rij = FullyConnectedNet(
135                    [*self.embedding_hidden, nchannels],
136                    activation=self.activation,
137                    name=f"phi_r_{layer}",
138                    use_bias=True,
139                )(mij)
140                
141                phi_r = jax.ops.segment_sum(phi_rij * switch, edge_src, xi.shape[0])
142                di = phi_r[:, :, None] * di + deltai
143
144            scal = jnp.sum(fi * di, axis=-1)
145            dui = (
146                -FullyConnectedNet(
147                    [*self.latent_hidden, nchannels],
148                    activation=self.activation,
149                    name=f"phi_u_{layer}",
150                    use_bias=True,
151                )(xi)
152                * scal
153            )
154
155            if nchannels != self.dim:
156                dui = nn.Dense(self.dim, name=f"reshape_{layer}", use_bias=False)(dui)
157
158            xi = xi + dui
159            if self.keep_all_layers:
160                xis.append(xi)
161
162        output = {
163            **inputs,
164            self.embedding_key: xi,
165        }
166        if self.keep_all_layers:
167            output[self.embedding_key + "_layers"] = jnp.stack(xis, axis=1)
168        return output
class NewtonNetEmbedding(flax.linen.module.Module):
 14class NewtonNetEmbedding(nn.Module):
 15    """ Newtonian message passing network
 16
 17    ### Reference
 18    Haghighatlari et al., NewtonNet: a Newtonian message passing network for deep learning of interatomic potentials and forces
 19    https://doi.org/10.1039/D2DD00008C
 20
 21    """
 22    
 23    _graphs_properties: Dict
 24    dim: int = 128
 25    """The dimension of the embedding."""
 26    nlayers: int = 3
 27    """The number of interaction layers."""
 28    nchannels: Optional[int] = None
 29    """The number of vector channels. If None, it is set to dim."""
 30    embedding_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
 31    """The hidden layers for the embedding networks."""
 32    latent_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
 33    """The hidden layers for the latent update network."""
 34    activation: Union[Callable, str] = "silu"
 35    """The activation function."""
 36    graph_key: str = "graph"
 37    """The key for the graph input."""
 38    embedding_key: str = "embedding"
 39    """The key for the output embedding."""
 40    species_encoding: dict = dataclasses.field(default_factory=dict)
 41    """The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`."""
 42    radial_basis: dict = dataclasses.field(default_factory=dict)
 43    """The radial basis parameters. See `fennol.models.misc.encodings.RadialBasis`."""
 44    keep_all_layers: bool = False
 45    """Whether to keep embeddings from each layer in the output."""
 46
 47    FID: ClassVar[str] = "NEWTONNET"
 48
 49
 50    @nn.compact
 51    def __call__(self, inputs):
 52        species = inputs["species"]
 53        assert (
 54            len(species.shape) == 1
 55        ), "Species must be a 1D array (batches must be flattened)"
 56
 57        graph = inputs[self.graph_key]
 58        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
 59
 60        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
 61
 62        onehot = SpeciesEncoding(**self.species_encoding, name="SpeciesEncoding")(
 63            species
 64        )
 65        xi = nn.Dense(self.dim, name="species_linear", use_bias=True)(onehot)
 66
 67        nchannels = self.nchannels if self.nchannels is not None else self.dim
 68
 69        distances = graph["distances"]
 70        switch = graph["switch"][:, None]
 71        dirij = graph["vec"] / distances[:, None] * switch
 72
 73        radial_basis = RadialBasis(
 74            **{
 75                **self.radial_basis,
 76                "end": cutoff,
 77                "name": f"RadialBasis",
 78            }
 79        )(distances)
 80
 81        if self.keep_all_layers:
 82            xis = []
 83        for layer in range(self.nlayers):
 84            ai = FullyConnectedNet(
 85                [*self.embedding_hidden, self.dim],
 86                activation=self.activation,
 87                name=f"phi_a_{layer}",
 88                use_bias=True,
 89            )(xi)
 90            Dij = nn.Dense(self.dim, name=f"radial_linear_{layer}", use_bias=True)(
 91                radial_basis
 92            )
 93            mij = ai[edge_src] * ai[edge_dst] * Dij * switch
 94
 95            mi = jax.ops.segment_sum(mij, edge_src, xi.shape[0])
 96            xi = xi + mi
 97
 98            Fij = (
 99                FullyConnectedNet(
100                    [*self.embedding_hidden, 1],
101                    activation=self.activation,
102                    name=f"phi_F_{layer}",
103                    use_bias=True,
104                )(mij)
105                * dirij
106            )
107
108            fij = (
109                FullyConnectedNet(
110                    [*self.embedding_hidden, nchannels],
111                    activation=self.activation,
112                    name=f"phi_f_{layer}",
113                    use_bias=True,
114                )(mij)[:, :, None]
115                * Fij[:, None, :]
116            )
117
118            if layer == 0:
119                fi = jax.ops.segment_sum(fij, edge_src, xi.shape[0])
120            else:
121                fi = fi + jax.ops.segment_sum(fij, edge_src, xi.shape[0])
122
123            deltai = (
124                FullyConnectedNet(
125                    [*self.embedding_hidden, nchannels],
126                    activation=self.activation,
127                    name=f"phi_R_{layer}",
128                    use_bias=True,
129                )(xi)[:, :, None]
130                * fi
131            )
132            if layer == 0:
133                di = deltai
134            else:
135                phi_rij = FullyConnectedNet(
136                    [*self.embedding_hidden, nchannels],
137                    activation=self.activation,
138                    name=f"phi_r_{layer}",
139                    use_bias=True,
140                )(mij)
141                
142                phi_r = jax.ops.segment_sum(phi_rij * switch, edge_src, xi.shape[0])
143                di = phi_r[:, :, None] * di + deltai
144
145            scal = jnp.sum(fi * di, axis=-1)
146            dui = (
147                -FullyConnectedNet(
148                    [*self.latent_hidden, nchannels],
149                    activation=self.activation,
150                    name=f"phi_u_{layer}",
151                    use_bias=True,
152                )(xi)
153                * scal
154            )
155
156            if nchannels != self.dim:
157                dui = nn.Dense(self.dim, name=f"reshape_{layer}", use_bias=False)(dui)
158
159            xi = xi + dui
160            if self.keep_all_layers:
161                xis.append(xi)
162
163        output = {
164            **inputs,
165            self.embedding_key: xi,
166        }
167        if self.keep_all_layers:
168            output[self.embedding_key + "_layers"] = jnp.stack(xis, axis=1)
169        return output

Newtonian message passing network

Reference

Haghighatlari et al., NewtonNet: a Newtonian message passing network for deep learning of interatomic potentials and forces https://doi.org/10.1039/D2DD00008C

NewtonNetEmbedding( _graphs_properties: Dict, dim: int = 128, nlayers: int = 3, nchannels: Optional[int] = None, embedding_hidden: Sequence[int] = <factory>, latent_hidden: Sequence[int] = <factory>, activation: Union[Callable, str] = 'silu', graph_key: str = 'graph', embedding_key: str = 'embedding', species_encoding: dict = <factory>, radial_basis: dict = <factory>, keep_all_layers: bool = False, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
dim: int = 128

The dimension of the embedding.

nlayers: int = 3

The number of interaction layers.

nchannels: Optional[int] = None

The number of vector channels. If None, it is set to dim.

embedding_hidden: Sequence[int]

The hidden layers for the embedding networks.

latent_hidden: Sequence[int]

The hidden layers for the latent update network.

activation: Union[Callable, str] = 'silu'

The activation function.

graph_key: str = 'graph'

The key for the graph input.

embedding_key: str = 'embedding'

The key for the output embedding.

species_encoding: dict

The species encoding parameters. See fennol.models.misc.encodings.SpeciesEncoding.

radial_basis: dict

The radial basis parameters. See fennol.models.misc.encodings.RadialBasis.

keep_all_layers: bool = False

Whether to keep embeddings from each layer in the output.

FID: ClassVar[str] = 'NEWTONNET'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None