fennol.models.embeddings.painn

import jax
import jax.numpy as jnp
import flax.linen as nn
from ..misc.encodings import SpeciesEncoding, RadialBasis
import dataclasses
from typing import Dict, Union, Callable, Sequence, Optional, ClassVar
from ...utils.activations import tssr3
from ..misc.nets import FullyConnectedNet


class PAINNEmbedding(nn.Module):
    """Polarizable Atom Interaction Neural Network (PaiNN)

    ### Reference
    K. Schütt, O. Unke, M. Gastegger. Equivariant message passing for the prediction of tensorial properties and molecular spectra. Proceedings of the 38th International Conference on Machine Learning (ICML), PMLR 139 (2021)
    https://arxiv.org/abs/2102.03150

    """

    _graphs_properties: Dict
    dim: int = 128
    """ The dimension of the embedding. """
    nlayers: int = 3
    """ The number of interaction layers. """
    nchannels: Optional[int] = None
    """ The number of equivariant channels. If None, it is set to dim. """
    message_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
    """ The hidden layers for the message network."""
    update_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
    """ The hidden layers for the update network."""
    activation: Union[Callable, str] = "silu"
    """ The activation function."""
    graph_key: str = "graph"
    """ The key for the graph input."""
    embedding_key: str = "embedding"
    """ The key for the embedding output."""
    tensor_embedding_key: str = "embedding_vectors"
    """ The key for the tensor embedding output."""
    species_encoding: dict = dataclasses.field(default_factory=dict)
    """ The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`."""
    radial_basis: dict = dataclasses.field(default_factory=dict)
    """ The radial basis function parameters. See `fennol.models.misc.encodings.RadialBasis`."""
    keep_all_layers: bool = False
    """ Whether to keep the embedding from each layer in the output."""

    FID: ClassVar[str] = "PAINN"

    @nn.compact
    def __call__(self, inputs):
        species = inputs["species"]
        assert (
            len(species.shape) == 1
        ), "Species must be a 1D array (batches must be flattened)"

        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]

        cutoff = self._graphs_properties[self.graph_key]["cutoff"]

        onehot = SpeciesEncoding(**self.species_encoding, name="SpeciesEncoding")(
            species
        )
        xi = nn.Dense(self.dim, name="species_linear", use_bias=True)(onehot)

        nchannels = self.nchannels if self.nchannels is not None else self.dim

        distances = graph["distances"]
        switch = graph["switch"][:, None]
        # unit edge vectors, shape (nedges, 3, 1)
        dirij = (graph["vec"] / distances[:, None])[:, :, None]
        # equivariant (vector) features, shape (natoms, 3, nchannels)
        Vi = jnp.zeros((xi.shape[0], 3, nchannels), dtype=xi.dtype)

        radial_basis = RadialBasis(
            **{
                **self.radial_basis,
                "end": cutoff,
                "name": "RadialBasis",
            }
        )(distances)

        if self.keep_all_layers:
            xis = []
        for layer in range(self.nlayers):
            # compute messages
            phi = FullyConnectedNet(
                [*self.message_hidden, self.dim + 2 * nchannels],
                activation=self.activation,
                name=f"message_{layer}",
                use_bias=True,
            )(xi)
            w = (
                nn.Dense(
                    self.dim + 2 * nchannels,
                    name=f"radial_linear_{layer}",
                    use_bias=True,
                )(radial_basis)
                * switch
            )
            dxij, hvv, hvs = jnp.split(
                phi[edge_dst] * w, [self.dim, self.dim + nchannels], axis=-1
            )

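            # hvs gates the equivariant edge direction, hvv gates the
            # neighbor's vector features; dxij is the invariant message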
            dvij = dirij * hvs[:, None, :]
            if layer > 0:
                dvij = dvij + Vi[edge_dst] * hvv[:, None, :]

            # aggregate messages
            v_message = Vi + jax.ops.segment_sum(dvij, edge_src, Vi.shape[0])
            x_message = xi + jax.ops.segment_sum(dxij, edge_src, xi.shape[0])

            # update
            u, v = jnp.split(
                nn.Dense(
                    2 * nchannels,
                    use_bias=False,
                    name=f"UV_{layer}",
                )(v_message),
                2,
                axis=-1,
            )

            # rotation-invariant contractions of the vector features
            scals = (u * v).sum(axis=1)
            norms = tssr3((v**2).sum(axis=1))

            A = FullyConnectedNet(
                [*self.update_hidden, self.dim + 2 * nchannels],
                activation=self.activation,
                name=f"update_{layer}",
                use_bias=True,
            )(jnp.concatenate((x_message, norms), axis=-1))

            ass, asv, avv = jnp.split(
                A,
                [self.dim, self.dim + nchannels],
                axis=-1,
            )

            Vi = Vi + u * avv[:, None, :]
            if self.dim != nchannels:
                dxi = nn.Dense(self.dim, name=f"resize_{layer}", use_bias=False)(
                    scals * asv
                )
            else:
                dxi = scals * asv

            xi = xi + ass + dxi

            if self.keep_all_layers:
                xis.append(xi)

        output = {
            **inputs,
            self.embedding_key: xi,
            self.tensor_embedding_key: Vi,
        }
        if self.keep_all_layers:
            output[self.embedding_key + "_layers"] = jnp.stack(xis, axis=1)
        return output

class PAINNEmbedding(flax.linen.module.Module):

Polarizable Atom Interaction Neural Network (PaiNN)

Reference

K. Schütt, O. Unke, M. Gastegger. Equivariant message passing for the prediction of tensorial properties and molecular spectra. Proceedings of the 38th International Conference on Machine Learning (ICML), PMLR 139 (2021). https://arxiv.org/abs/2102.03150
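
In the notation of the implementation above, each interaction layer schematically computes the following ($j$ runs over the neighbors of atom $i$, $\hat{r}_{ij}$ is the unit edge vector, $s(r_{ij})$ the cutoff switch, $\odot$ channel-wise multiplication; the splits have sizes dim, nchannels, nchannels):

$$(\Delta x_{ij},\, h^{vv}_{ij},\, h^{vs}_{ij}) = \phi(x_j) \odot W\,\mathrm{RBF}(r_{ij})\, s(r_{ij}), \qquad \Delta \vec{v}_{ij} = \hat{r}_{ij}\, h^{vs}_{ij} + \vec{v}_j \odot h^{vv}_{ij}$$

$$x_i^{m} = x_i + \textstyle\sum_j \Delta x_{ij}, \qquad \vec{v}_i^{\,m} = \vec{v}_i + \textstyle\sum_j \Delta \vec{v}_{ij}$$

$$(a_{ss},\, a_{sv},\, a_{vv}) = A\big(\big[\,x_i^{m},\; \lVert V \vec{v}_i^{\,m} \rVert\,\big]\big), \qquad \vec{v}_i \leftarrow \vec{v}_i + a_{vv} \odot U \vec{v}_i^{\,m}, \qquad x_i \leftarrow x_i + a_{ss} + a_{sv} \odot \langle U \vec{v}_i^{\,m},\, V \vec{v}_i^{\,m} \rangle$$

The $\vec{v}_j$ term is skipped at the first layer (where $\vec{v} = 0$), and the scalar update is passed through an extra resize linear layer when dim differs from nchannels.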

PAINNEmbedding( _graphs_properties: Dict, dim: int = 128, nlayers: int = 3, nchannels: Optional[int] = None, message_hidden: Sequence[int] = <factory>, update_hidden: Sequence[int] = <factory>, activation: Union[Callable, str] = 'silu', graph_key: str = 'graph', embedding_key: str = 'embedding', tensor_embedding_key: str = 'embedding_vectors', species_encoding: dict = <factory>, radial_basis: dict = <factory>, keep_all_layers: bool = False, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
dim: int = 128

The dimension of the embedding.

nlayers: int = 3

The number of interaction layers.

nchannels: Optional[int] = None

The number of equivariant channels. If None, it is set to dim.

message_hidden: Sequence[int]

The hidden layers for the message network.

update_hidden: Sequence[int]

The hidden layers for the update network.

activation: Union[Callable, str] = 'silu'

The activation function.

graph_key: str = 'graph'

The key for the graph input.

embedding_key: str = 'embedding'

The key for the embedding output.

tensor_embedding_key: str = 'embedding_vectors'

The key for the tensor embedding output.

species_encoding: dict

The species encoding parameters. See fennol.models.misc.encodings.SpeciesEncoding.

radial_basis: dict

The radial basis function parameters. See fennol.models.misc.encodings.RadialBasis.

keep_all_layers: bool = False

Whether to keep the embedding from each layer in the output.

FID: ClassVar[str] = 'PAINN'
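
A minimal usage sketch (hypothetical setup: the graph dictionary is built by hand here, whereas in a FeNNol model it is normally produced by a neighbor-list preprocessing module, and the default SpeciesEncoding and RadialBasis parameters are assumed to be usable as-is):

    import jax
    import jax.numpy as jnp
    from fennol.models.embeddings.painn import PAINNEmbedding

    # three atoms (O, H, H) with a hand-built, fully-connected edge list
    coords = jnp.array([[0.0, 0.0, 0.0], [0.8, 0.6, 0.0], [-0.8, 0.6, 0.0]])
    edge_src = jnp.array([0, 0, 1, 1, 2, 2])
    edge_dst = jnp.array([1, 2, 0, 2, 0, 1])
    vec = coords[edge_dst] - coords[edge_src]
    distances = jnp.linalg.norm(vec, axis=-1)

    inputs = {
        "species": jnp.array([8, 1, 1]),
        "graph": {
            "edge_src": edge_src,
            "edge_dst": edge_dst,
            "vec": vec,
            "distances": distances,
            "switch": jnp.ones_like(distances),  # cutoff switch, ~1 inside cutoff
        },
    }

    module = PAINNEmbedding(_graphs_properties={"graph": {"cutoff": 5.0}})
    params = module.init(jax.random.PRNGKey(0), inputs)
    output = module.apply(params, inputs)
    # output["embedding"]:         (natoms, dim)          invariant features
    # output["embedding_vectors"]: (natoms, 3, nchannels) equivariant features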