fennol.models.embeddings.hipnn

import jax
import jax.numpy as jnp
import flax.linen as nn
from ...utils.spherical_harmonics import generate_spherical_harmonics, CG_SO3
from ..misc.encodings import SpeciesEncoding, RadialBasis
import dataclasses
import numpy as np
from typing import Dict, Union, Callable, ClassVar
from ...utils.activations import activation_from_str, tssr3
from ..misc.nets import FullyConnectedNet


class HIPNNEmbedding(nn.Module):
    """Hierarchically Interacting Particle Neural Network (HIP-NN)

    ### Reference
    Adapted from N. Lubbers, J. S. Smith and K. Barros, "Hierarchical modeling of
    molecular energies using a deep neural network",
    J. Chem. Phys. 148, 241715 (2018) (https://doi.org/10.1063/1.5011181)

    """

    _graphs_properties: Dict
    dim: int = 80
    """The dimension of the embedding."""
    n_onsite: int = 3
    """The number of onsite layers per interaction layer."""
    nlayers: int = 2
    """The number of interaction layers."""
    lmax: int = 0
    """The maximum degree of the spherical harmonics."""
    n_message: int = 0
    """The number of layers for the message NN."""
    activation: Union[Callable, str] = "silu"
    """The activation function."""
    graph_key: str = "graph"
    """The key for the graph input."""
    embedding_key: str = "embedding"
    """The key for the embedding output."""
    species_encoding: dict = dataclasses.field(default_factory=dict)
    """The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`."""
    radial_basis: dict = dataclasses.field(default_factory=lambda: {"dim": 20})
    """The radial basis parameters. See `fennol.models.misc.encodings.RadialBasis`."""
    keep_all_layers: bool = True
    """Whether to keep embeddings from each layer in the output."""
    graph_l_key: str = "graph"
    """The key for the graph input for the spherical harmonics."""

    FID: ClassVar[str] = "HIPNN"

    @nn.compact
    def __call__(self, inputs):
        species = inputs["species"]
        assert (
            len(species.shape) == 1
        ), "Species must be a 1D array (batches must be flattened)"

        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]

        cutoff = self._graphs_properties[self.graph_key]["cutoff"]

        zi = SpeciesEncoding(**self.species_encoding, name="SpeciesEncoding")(species)

        act = (
            activation_from_str(self.activation)
            if isinstance(self.activation, str)
            else self.activation
        )

        distances = graph["distances"]
        if self.lmax > 0:
            filtered_l = "parent_graph" in self._graphs_properties[self.graph_l_key]

            correct_graph = (
                self.graph_l_key == self.graph_key
                or self._graphs_properties[self.graph_l_key]["parent_graph"]
                == self.graph_key
            )
            assert (
                correct_graph
            ), f"graph_l_key={self.graph_l_key} must be a subgraph of graph_key={self.graph_key}"

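            # Expand unit bond vectors in real spherical harmonics up to lmax;
            # Yij has shape (nedges_l, 1, (lmax + 1)**2). When graph_l is a
            # filtered subgraph, the l=0 (scalar) component is dropped below
            # since scalar messages are accumulated on the parent graph instead.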
            graph_l = inputs[self.graph_l_key]
            Yij = generate_spherical_harmonics(lmax=self.lmax, normalize=False)(
                graph_l["vec"] / graph_l["distances"][:, None]
            )[:, None, :]
            if filtered_l:
                Yij = Yij[:, :, 1:]
            # reps_l = np.array([2 * l + 1 for l in range(1, self.lmax + 1)])

        switch = graph["switch"][:, None]
        if self.keep_all_layers:
            zis = []
        for layer in range(self.nlayers):
            ### interaction layer
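            # Each interaction layer builds its own radial "sensitivity"
            # functions s(r_ij), which weight how strongly a neighbor at a
            # given distance contributes to the message.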
            s = RadialBasis(
                **{
                    "basis": "gaussian_rinv",
                    "end": cutoff,
                    **self.radial_basis,
                    "name": f"RadialBasis_{layer}",
                }
            )(distances)

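            # Edge messages: with n_message == 0, mij is a bilinear contraction
            # of the radial sensitivities with the neighbor features through the
            # weight tensor V (the original HIP-NN interaction); otherwise an
            # MLP acts on the concatenated neighbor features and sensitivities.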
            zself = nn.Dense(self.dim, name=f"self_int_{layer}", use_bias=True)(zi)
            if self.n_message == 0:
                V = self.param(
                    f"V_{layer}",
                    jax.nn.initializers.glorot_normal(batch_axis=0),
                    (s.shape[-1], zi.shape[-1], self.dim),
                )
                mij = jnp.einsum("...j,...k,jkl->...l", s, zi[edge_dst], V)
            else:
                mij = FullyConnectedNet(
                    [self.dim] * (self.n_message + 1),
                    activation=act,
                    name=f"mij_{layer}",
                )(jnp.concatenate([zi[edge_dst], s], axis=-1))

            if self.lmax == 0:
                zi = act(zself.at[edge_src].add(mij * switch, mode="drop"))
            else:
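                # Tensorial aggregation: scalar messages are modulated by the
                # spherical harmonics of the bond direction and summed per atom;
                # each degree l is then reduced to a rotation-invariant feature
                # via the squared norm over its 2l + 1 components.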
                if filtered_l:
                    zself = zself.at[edge_src].add(mij * switch, mode="drop")
                    filter_indices = graph_l["filter_indices"]
                    mij = mij[filter_indices]

                mij = mij * graph_l["switch"][:, None]
                Mij = mij[:, :, None] * Yij
                Mi = jax.ops.segment_sum(Mij, graph_l["edge_src"], zi.shape[0])

                if filtered_l:
                    zint = jnp.zeros_like(zself)
                else:
                    zint = jax.lax.index_in_dim(Mi, 0, axis=-1, keepdims=False)
                    Mi = Mi[:, :, 1:]

                ts = self.param(
                    f"ts_{layer}",
                    lambda key, shape: jax.random.normal(key, shape, dtype=zint.dtype),
                    (self.lmax, 3),
                )  # .repeat(reps_l)

                # zint = zint + jnp.sum(ts[None,None,:]*Mi**2,axis=-1)
                for l in range(1, self.lmax + 1):
                    Ml = jax.lax.dynamic_slice_in_dim(
                        Mi, start_index=l**2 - 1, slice_size=2 * l + 1, axis=-1
                    )
                    zint = zint + ts[l, 2] * tssr3(
                        nn.Dense(self.dim, name=f"linear_{layer}_l{l}")(
                            jnp.sum(Ml**2, axis=-1)
                        )
                    )
                zi = act(zself + zint)

            ### onsite layers
            # zi = zi + FullyConnectedNet(
            #     [self.dim] * (self.n_onsite + 1),
            #     activation=act,
            #     use_bias=True,
            #     name=f"onsite_{layer}",
            # )(zi)
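            # On-site refinement: n_onsite residual MLP blocks update each
            # atom's features independently of its neighbors.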
            for j in range(self.n_onsite):
                zi = zi + FullyConnectedNet(
                    (self.dim, self.dim),
                    activation=act,
                    use_bias=True,
                    name=f"onsite_{layer}_{j}",
                )(zi)
            if self.keep_all_layers:
                zis.append(zi)

        output = {
            **inputs,
            self.embedding_key: zi,
        }
        if self.keep_all_layers:
            output[self.embedding_key + "_layers"] = jnp.stack(zis, axis=1)
        return output
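
Example: a minimal usage sketch, not taken from the FeNNol documentation. It assumes only what the code above shows: the constructor takes _graphs_properties with a cutoff for the graph key, and __call__ expects a flattened species array plus a graph dictionary with edge indices, distances, and a switch value per edge. In practice these graph entries are produced by FeNNol's graph preprocessing; the hand-built toy inputs and numerical values below are illustrative only.

import jax
import jax.numpy as jnp
from fennol.models.embeddings.hipnn import HIPNNEmbedding

mod = HIPNNEmbedding(
    _graphs_properties={"graph": {"cutoff": 5.0}},  # assumed cutoff value
    dim=80,
    nlayers=2,
)

# Toy water-like system: 3 atoms, 4 directed edges (O<->H pairs).
inputs = {
    "species": jnp.array([8, 1, 1]),  # atomic numbers
    "graph": {
        "edge_src": jnp.array([0, 0, 1, 2]),
        "edge_dst": jnp.array([1, 2, 0, 0]),
        "distances": jnp.array([0.96, 0.96, 0.96, 0.96]),
        "switch": jnp.array([0.9, 0.9, 0.9, 0.9]),  # smooth cutoff switch per edge
    },
}

params = mod.init(jax.random.PRNGKey(0), inputs)
out = mod.apply(params, inputs)
print(out["embedding"].shape)         # (3, 80)
print(out["embedding_layers"].shape)  # (3, 2, 80) since keep_all_layers=True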