fennol.models.embeddings.spookynet

  1import jax
  2import jax.numpy as jnp
  3import flax.linen as nn
  4from ...utils.spherical_harmonics import generate_spherical_harmonics, CG_SO3
  5from ..misc.encodings import SpeciesEncoding, RadialBasis
  6import dataclasses
  7import numpy as np
  8from typing import Dict, Union, Callable, Sequence, Optional, ClassVar
  9from ...utils.initializers import initializer_from_str, scaled_orthogonal
 10from ..misc.nets import  ResMLP
 11
 12
 13class SpookyNetEmbedding(nn.Module):
 14    """SpookyNet equivariant message-passing embedding with electronic encodings (charge + spin).
 15
 16    FID : SPOOKYNET
 17
 18    ### Reference
 19    Unke, O.T., Chmiela, S., Gastegger, M. et al. SpookyNet: Learning force fields with electronic degrees of freedom and nonlocal effects. Nat Commun 12, 7273 (2021). https://doi.org/10.1038/s41467-021-27504-0
 20
 21    
 22    ### Warning
 23    non-local attention interaction is not yet implemented !
 24
 25    """
 26
 27    _graphs_properties: Dict
 28    dim: int = 128
 29    """ The dimension of the embedding."""
 30    nlayers: int = 3
 31    """ The number of interaction layers."""
 32    graph_key: str = "graph"
 33    """ The key for the graph input."""
 34    embedding_key: str = "embedding"
 35    """ The key for the embedding output."""
 36    species_encoding: dict = dataclasses.field(
 37        default_factory=lambda: {"encoding": "electronic_structure"}
 38    )
 39    """ The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`. """
 40    radial_basis: dict = dataclasses.field(default_factory=lambda: {"basis": "spooky"})
 41    """ The radial basis parameters. See `fennol.models.misc.encodings.RadialBasis`. """
 42    kernel_init: Union[Callable, str] = "scaled_orthogonal(scale=1.0, mode='fan_avg')"
 43    """ The kernel initializer for Dense operations."""
 44    use_spin_encoding: bool = True
 45    """ Whether to use spin encoding."""
 46    use_charge_encoding: bool = True
 47    """ Whether to use charge encoding."""
 48    total_charge_key: str = "total_charge"
 49    """ The key for the total charge input."""
 50
 51    FID: ClassVar[str] = "SPOOKYNET"
 52
 53    @nn.compact
 54    def __call__(self, inputs):
 55        species = inputs["species"]
 56        assert (
 57            len(species.shape) == 1
 58        ), "Species must be a 1D array (batches must be flattened)"
 59
 60        graph = inputs[self.graph_key]
 61        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
 62
 63        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
 64
 65        kernel_init = (
 66            initializer_from_str(self.kernel_init)
 67            if isinstance(self.kernel_init, str)
 68            else self.kernel_init
 69        )
 70
 71        onehot = SpeciesEncoding(**self.species_encoding, name="SpeciesEncoding")(
 72            species
 73        )
 74        zrand = SpeciesEncoding(
 75            encoding="random", dim=self.dim, name="RandSpeciesEncoding"
 76        )(species)
 77
 78        eZ = (
 79            nn.Dense(
 80                self.dim,
 81                name="species_linear",
 82                use_bias=False,
 83                kernel_init=nn.initializers.zeros,
 84            )(onehot)
 85            + zrand
 86        )
 87        xi = eZ
 88
 89        # encode charge information
 90        batch_index = inputs["batch_index"]
 91        natoms = inputs["natoms"]
 92        if self.use_charge_encoding and (
 93            self.total_charge_key in inputs or self.is_initializing()
 94        ):
 95            Q = inputs.get(self.total_charge_key, jnp.zeros(natoms.shape[0], dtype=xi.dtype))
 96            kq_pos, kq_neg, vq_pos, vq_neg = self.param(
 97                "kv_charge",
 98                lambda key, shape: jax.random.normal(key, shape, dtype=xi.dtype),
 99                (4, self.dim),
100            )
101            qi = nn.Dense(self.dim, kernel_init=kernel_init, name="q_linear")(eZ)
102            pos_mask = Q >= 0
103            kq = jnp.where(pos_mask[:,None], kq_pos[None, :], kq_neg[None, :])
104            vq = jnp.where(pos_mask[:,None], vq_pos[None, :], vq_neg[None, :])
105            qik = (qi * kq[batch_index]).sum(axis=-1) / self.dim**0.5
106            wi =  jax.nn.softplus(qik)
107            wnorm = jax.ops.segment_sum(wi, batch_index, Q.shape[0])
108            avi = wi[:,None] * ((Q / wnorm)[:, None] * vq)[batch_index]
109            eQ = ResMLP(use_bias=False, name="eQ", kernel_init=kernel_init)(avi)
110            xi = xi + eQ
111
112        # encode spin information
113        if self.use_spin_encoding and (
114            "total_spin" in inputs or self.is_initializing()
115        ):
116            S = inputs.get("total_spin", jnp.zeros(natoms.shape[0], dtype=xi.dtype))
117            ks, vs = self.param(
118                "kv_spin",
119                lambda key, shape: jax.random.normal(key, shape, dtype=xi.dtype),
120                (2, self.dim),
121            )
122            si = nn.Dense(self.dim, kernel_init=kernel_init, name="s_linear")(eZ)
123            sik = (si * ks[None, :]).sum(axis=-1) / self.dim**0.5
124            wi = jax.nn.softplus(sik)
125            wnorm = jax.ops.segment_sum(wi, batch_index, S.shape[0])
126            avi = wi[:,None] * ((S / wnorm)[:, None] * vs[None, :])[batch_index]
127            eS = ResMLP(use_bias=False, name="eS", kernel_init=kernel_init)(avi)
128            xi = xi + eS
129
130        distances = graph["distances"]
131        switch = graph["switch"][:, None]
132        dirij = graph["vec"] / distances[:, None]
133        Yij = generate_spherical_harmonics(lmax=2, normalize=False)(dirij)[:, None, :]
134
135        radial_basis = (
136            RadialBasis(
137                **{
138                    **self.radial_basis,
139                    "end": cutoff,
140                    "name": f"RadialBasis",
141                }
142            )(distances)
143            * switch
144        )
145
146        gij = radial_basis[:, :, None] * Yij
147
148        gsij = gij[:, :, 0]
149        gpij = jnp.transpose(gij[:, :, 1:4], (0, 2, 1))
150        gdij = jnp.transpose(gij[:, :, 4:], (0, 2, 1))
151
152        y = 0.0
153
154        for layer in range(self.nlayers):
155            xtilde = ResMLP(res_only=True, name=f"xtilde_{layer}")(xi)
156
157            ### compute local update
158            c = ResMLP(name=f"c_{layer}", kernel_init=kernel_init)(xtilde)
159            sj = ResMLP(name=f"s_{layer}", kernel_init=kernel_init)(xtilde)
160            pj = ResMLP(name=f"p_{layer}", kernel_init=kernel_init)(xtilde)
161            dj = ResMLP(name=f"d_{layer}", kernel_init=kernel_init)(xtilde)
162
163            Gs = nn.Dense(
164                self.dim, use_bias=False, name=f"Gs_{layer}", kernel_init=kernel_init
165            )(gsij)
166            Gp = nn.Dense(
167                self.dim, use_bias=False, name=f"Gp_{layer}", kernel_init=kernel_init
168            )(gpij)
169            Gd = nn.Dense(
170                self.dim, use_bias=False, name=f"Gd_{layer}", kernel_init=kernel_init
171            )(gdij)
172
173            si = jax.ops.segment_sum(sj[edge_dst] * Gs, edge_src, xi.shape[0])
174            pi = jax.ops.segment_sum(pj[edge_dst, None, :] * Gp, edge_src, xi.shape[0])
175            di = jax.ops.segment_sum(dj[edge_dst, None, :] * Gd, edge_src, xi.shape[0])
176
177            P1, P2 = jnp.split(
178                nn.Dense(
179                    2 * self.dim,
180                    use_bias=False,
181                    name=f"P12_{layer}",
182                    kernel_init=kernel_init,
183                )(pi),
184                2,
185                axis=-1,
186            )
187            D1, D2 = jnp.split(
188                nn.Dense(
189                    2 * self.dim,
190                    use_bias=False,
191                    name=f"D12_{layer}",
192                    kernel_init=kernel_init,
193                )(di),
194                2,
195                axis=-1,
196            )
197
198            P12 = (P1 * P2).sum(axis=1)
199            D12 = (D1 * D2).sum(axis=1)
200
201            l = ResMLP(name=f"l_{layer}", kernel_init=kernel_init)(c + si + P12 + D12)
202
203            ### aggregate and update
204            xi = ResMLP(name=f"xi_{layer}", kernel_init=kernel_init)(xtilde + l)
205            y = y + ResMLP(name=f"y_{layer}", kernel_init=kernel_init)(xi)
206
207        output = {
208            **inputs,
209            self.embedding_key: y,
210        }
211        if self.use_charge_encoding and "total_charge" in inputs:
212            output[self.embedding_key + "_eQ"] = eQ
213        if self.use_spin_encoding and "total_spin" in inputs:
214            output[self.embedding_key + "_eS"] = eS
215        return output
class SpookyNetEmbedding(flax.linen.module.Module):
 14class SpookyNetEmbedding(nn.Module):
 15    """SpookyNet equivariant message-passing embedding with electronic encodings (charge + spin).
 16
 17    FID : SPOOKYNET
 18
 19    ### Reference
 20    Unke, O.T., Chmiela, S., Gastegger, M. et al. SpookyNet: Learning force fields with electronic degrees of freedom and nonlocal effects. Nat Commun 12, 7273 (2021). https://doi.org/10.1038/s41467-021-27504-0
 21
 22    
 23    ### Warning
 24    non-local attention interaction is not yet implemented !
 25
 26    """
 27
 28    _graphs_properties: Dict
 29    dim: int = 128
 30    """ The dimension of the embedding."""
 31    nlayers: int = 3
 32    """ The number of interaction layers."""
 33    graph_key: str = "graph"
 34    """ The key for the graph input."""
 35    embedding_key: str = "embedding"
 36    """ The key for the embedding output."""
 37    species_encoding: dict = dataclasses.field(
 38        default_factory=lambda: {"encoding": "electronic_structure"}
 39    )
 40    """ The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`. """
 41    radial_basis: dict = dataclasses.field(default_factory=lambda: {"basis": "spooky"})
 42    """ The radial basis parameters. See `fennol.models.misc.encodings.RadialBasis`. """
 43    kernel_init: Union[Callable, str] = "scaled_orthogonal(scale=1.0, mode='fan_avg')"
 44    """ The kernel initializer for Dense operations."""
 45    use_spin_encoding: bool = True
 46    """ Whether to use spin encoding."""
 47    use_charge_encoding: bool = True
 48    """ Whether to use charge encoding."""
 49    total_charge_key: str = "total_charge"
 50    """ The key for the total charge input."""
 51
 52    FID: ClassVar[str] = "SPOOKYNET"
 53
 54    @nn.compact
 55    def __call__(self, inputs):
 56        species = inputs["species"]
 57        assert (
 58            len(species.shape) == 1
 59        ), "Species must be a 1D array (batches must be flattened)"
 60
 61        graph = inputs[self.graph_key]
 62        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
 63
 64        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
 65
 66        kernel_init = (
 67            initializer_from_str(self.kernel_init)
 68            if isinstance(self.kernel_init, str)
 69            else self.kernel_init
 70        )
 71
 72        onehot = SpeciesEncoding(**self.species_encoding, name="SpeciesEncoding")(
 73            species
 74        )
 75        zrand = SpeciesEncoding(
 76            encoding="random", dim=self.dim, name="RandSpeciesEncoding"
 77        )(species)
 78
 79        eZ = (
 80            nn.Dense(
 81                self.dim,
 82                name="species_linear",
 83                use_bias=False,
 84                kernel_init=nn.initializers.zeros,
 85            )(onehot)
 86            + zrand
 87        )
 88        xi = eZ
 89
 90        # encode charge information
 91        batch_index = inputs["batch_index"]
 92        natoms = inputs["natoms"]
 93        if self.use_charge_encoding and (
 94            self.total_charge_key in inputs or self.is_initializing()
 95        ):
 96            Q = inputs.get(self.total_charge_key, jnp.zeros(natoms.shape[0], dtype=xi.dtype))
 97            kq_pos, kq_neg, vq_pos, vq_neg = self.param(
 98                "kv_charge",
 99                lambda key, shape: jax.random.normal(key, shape, dtype=xi.dtype),
100                (4, self.dim),
101            )
102            qi = nn.Dense(self.dim, kernel_init=kernel_init, name="q_linear")(eZ)
103            pos_mask = Q >= 0
104            kq = jnp.where(pos_mask[:,None], kq_pos[None, :], kq_neg[None, :])
105            vq = jnp.where(pos_mask[:,None], vq_pos[None, :], vq_neg[None, :])
106            qik = (qi * kq[batch_index]).sum(axis=-1) / self.dim**0.5
107            wi =  jax.nn.softplus(qik)
108            wnorm = jax.ops.segment_sum(wi, batch_index, Q.shape[0])
109            avi = wi[:,None] * ((Q / wnorm)[:, None] * vq)[batch_index]
110            eQ = ResMLP(use_bias=False, name="eQ", kernel_init=kernel_init)(avi)
111            xi = xi + eQ
112
113        # encode spin information
114        if self.use_spin_encoding and (
115            "total_spin" in inputs or self.is_initializing()
116        ):
117            S = inputs.get("total_spin", jnp.zeros(natoms.shape[0], dtype=xi.dtype))
118            ks, vs = self.param(
119                "kv_spin",
120                lambda key, shape: jax.random.normal(key, shape, dtype=xi.dtype),
121                (2, self.dim),
122            )
123            si = nn.Dense(self.dim, kernel_init=kernel_init, name="s_linear")(eZ)
124            sik = (si * ks[None, :]).sum(axis=-1) / self.dim**0.5
125            wi = jax.nn.softplus(sik)
126            wnorm = jax.ops.segment_sum(wi, batch_index, S.shape[0])
127            avi = wi[:,None] * ((S / wnorm)[:, None] * vs[None, :])[batch_index]
128            eS = ResMLP(use_bias=False, name="eS", kernel_init=kernel_init)(avi)
129            xi = xi + eS
130
131        distances = graph["distances"]
132        switch = graph["switch"][:, None]
133        dirij = graph["vec"] / distances[:, None]
134        Yij = generate_spherical_harmonics(lmax=2, normalize=False)(dirij)[:, None, :]
135
136        radial_basis = (
137            RadialBasis(
138                **{
139                    **self.radial_basis,
140                    "end": cutoff,
141                    "name": f"RadialBasis",
142                }
143            )(distances)
144            * switch
145        )
146
147        gij = radial_basis[:, :, None] * Yij
148
149        gsij = gij[:, :, 0]
150        gpij = jnp.transpose(gij[:, :, 1:4], (0, 2, 1))
151        gdij = jnp.transpose(gij[:, :, 4:], (0, 2, 1))
152
153        y = 0.0
154
155        for layer in range(self.nlayers):
156            xtilde = ResMLP(res_only=True, name=f"xtilde_{layer}")(xi)
157
158            ### compute local update
159            c = ResMLP(name=f"c_{layer}", kernel_init=kernel_init)(xtilde)
160            sj = ResMLP(name=f"s_{layer}", kernel_init=kernel_init)(xtilde)
161            pj = ResMLP(name=f"p_{layer}", kernel_init=kernel_init)(xtilde)
162            dj = ResMLP(name=f"d_{layer}", kernel_init=kernel_init)(xtilde)
163
164            Gs = nn.Dense(
165                self.dim, use_bias=False, name=f"Gs_{layer}", kernel_init=kernel_init
166            )(gsij)
167            Gp = nn.Dense(
168                self.dim, use_bias=False, name=f"Gp_{layer}", kernel_init=kernel_init
169            )(gpij)
170            Gd = nn.Dense(
171                self.dim, use_bias=False, name=f"Gd_{layer}", kernel_init=kernel_init
172            )(gdij)
173
174            si = jax.ops.segment_sum(sj[edge_dst] * Gs, edge_src, xi.shape[0])
175            pi = jax.ops.segment_sum(pj[edge_dst, None, :] * Gp, edge_src, xi.shape[0])
176            di = jax.ops.segment_sum(dj[edge_dst, None, :] * Gd, edge_src, xi.shape[0])
177
178            P1, P2 = jnp.split(
179                nn.Dense(
180                    2 * self.dim,
181                    use_bias=False,
182                    name=f"P12_{layer}",
183                    kernel_init=kernel_init,
184                )(pi),
185                2,
186                axis=-1,
187            )
188            D1, D2 = jnp.split(
189                nn.Dense(
190                    2 * self.dim,
191                    use_bias=False,
192                    name=f"D12_{layer}",
193                    kernel_init=kernel_init,
194                )(di),
195                2,
196                axis=-1,
197            )
198
199            P12 = (P1 * P2).sum(axis=1)
200            D12 = (D1 * D2).sum(axis=1)
201
202            l = ResMLP(name=f"l_{layer}", kernel_init=kernel_init)(c + si + P12 + D12)
203
204            ### aggregate and update
205            xi = ResMLP(name=f"xi_{layer}", kernel_init=kernel_init)(xtilde + l)
206            y = y + ResMLP(name=f"y_{layer}", kernel_init=kernel_init)(xi)
207
208        output = {
209            **inputs,
210            self.embedding_key: y,
211        }
212        if self.use_charge_encoding and "total_charge" in inputs:
213            output[self.embedding_key + "_eQ"] = eQ
214        if self.use_spin_encoding and "total_spin" in inputs:
215            output[self.embedding_key + "_eS"] = eS
216        return output

SpookyNet equivariant message-passing embedding with electronic encodings (charge + spin).

FID : SPOOKYNET

Reference

Unke, O.T., Chmiela, S., Gastegger, M. et al. SpookyNet: Learning force fields with electronic degrees of freedom and nonlocal effects. Nat Commun 12, 7273 (2021). https://doi.org/10.1038/s41467-021-27504-0

Warning

The non-local attention interaction is not yet implemented.

SpookyNetEmbedding( _graphs_properties: Dict, dim: int = 128, nlayers: int = 3, graph_key: str = 'graph', embedding_key: str = 'embedding', species_encoding: dict = <factory>, radial_basis: dict = <factory>, kernel_init: Union[Callable, str] = "scaled_orthogonal(scale=1.0, mode='fan_avg')", use_spin_encoding: bool = True, use_charge_encoding: bool = True, total_charge_key: str = 'total_charge', parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
dim: int = 128

The dimension of the embedding.

nlayers: int = 3

The number of interaction layers.

graph_key: str = 'graph'

The key for the graph input.

embedding_key: str = 'embedding'

The key for the embedding output.

species_encoding: dict

The species encoding parameters. See fennol.models.misc.encodings.SpeciesEncoding.

radial_basis: dict

The radial basis parameters. See fennol.models.misc.encodings.RadialBasis.

kernel_init: Union[Callable, str] = "scaled_orthogonal(scale=1.0, mode='fan_avg')"

The kernel initializer for Dense operations.

use_spin_encoding: bool = True

Whether to use spin encoding.

use_charge_encoding: bool = True

Whether to use charge encoding.

total_charge_key: str = 'total_charge'

The key for the total charge input.

FID: ClassVar[str] = 'SPOOKYNET'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None