fennol.models.embeddings.deeppot

  1import jax
  2import jax.numpy as jnp
  3import flax.linen as nn
  4from typing import Sequence, Dict, Optional, Union, Callable, ClassVar
  5import numpy as np
  6import dataclasses
  7
  8from ..misc.nets import FullyConnectedNet, ChemicalNet
  9
 10from ...utils.periodic_table import PERIODIC_TABLE, VALENCE_STRUCTURE
 11from ..misc.encodings import SpeciesEncoding, RadialBasis
 12
 13
 14class DeepPotEmbedding(nn.Module):
 15    """Deep Potential embedding
 16
 17    FID : DEEPPOT
 18
 19    ### Reference
 20    Zhang, L., Han, J., Wang, H., Car, R., & E, W. (2018). Deep Potential Molecular dynamics: A scalable model with the accuracy of quantum mechanics. Phys. Rev. Lett., 120(14), 143001. https://doi.org/10.1103/PhysRevLett.120.143001
 21
 22    """
 23    _graphs_properties: Dict
 24    dim: int = 64
 25    """The dimension of the embedding."""
 26    subdim: int = 8
 27    """The first dimensions to select for the embedding tensor product."""
 28    radial_dim: Optional[int] = None
 29    """The dimension of the radial embedding for tensor combination. 
 30        If None, we use a neural net to combine chemical and radial information, like in the original DeepPot."""
 31    embedding_key: str = "embedding"
 32    """The key to use for the output embedding in the returned dictionary."""
 33    graph_key: str = "graph"
 34    """The key in the input dictionary that corresponds to the radial graph."""
 35    species_encoding: dict = dataclasses.field(default_factory=dict)
 36    """The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`"""
 37    radial_basis: Optional[dict] = None
 38    """The radial basis parameters. See `fennol.models.misc.encodings.RadialBasis`. 
 39        If None, the radial basis is the s_ij like in the original DeepPot."""
 40    embedding_hidden: Sequence[int] = dataclasses.field(
 41        default_factory=lambda: [64, 64, 64]
 42    )
 43    """The hidden layers of the embedding network."""
 44    activation: Union[Callable, str] = "silu"
 45    """The activation function."""
 46    concatenate_species: bool = False
 47    """Whether to concatenate the species encoding with the embedding."""
 48    divide_distances: bool = True
 49    """Whether to divide the switch by the distance in s_ij."""
 50    species_order: Optional[Union[str,Sequence[str]]] = None
 51    """Species considered by the network when using species-specialized embedding network."""
 52
 53    FID: ClassVar[str] = "DEEPPOT"
 54
 55    @nn.compact
 56    def __call__(self, inputs):
 57        species = inputs["species"]
 58
 59        # species encoding
 60        if self.species_order is None or self.concatenate_species:
 61            onehot = SpeciesEncoding(**self.species_encoding, name="SpeciesEncoding")(
 62                species
 63            )
 64
 65        # Radial graph
 66        graph = inputs[self.graph_key]
 67        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
 68        distances = graph["distances"][:, None]
 69        switch = graph["switch"][:, None]
 70        vec = graph["vec"] / distances
 71        sij = switch / distances if self.divide_distances else switch
 72        Rij = jnp.concatenate((sij, sij * vec), axis=-1)
 73
 74        # Radial BASIS
 75        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
 76        if self.radial_basis is not None:
 77            radial_terms = RadialBasis(
 78                **{
 79                    **self.radial_basis,
 80                    "end": cutoff,
 81                    "name": f"RadialBasis",
 82                }
 83            )(graph["distances"])
 84        else:
 85            radial_terms = sij
 86
 87        if self.species_order is not None:
 88            Gij = ChemicalNet(
 89                self.species_order,
 90                [*self.embedding_hidden, self.dim],
 91                activation=self.activation,
 92            )((species[edge_dst], radial_terms))
 93        elif self.radial_dim is not None:
 94            Gij = FullyConnectedNet(
 95                [*self.embedding_hidden, self.radial_dim], activation=self.activation
 96            )(radial_terms)
 97            Wa = self.param(
 98                f"Wa",
 99                nn.initializers.normal(
100                    stddev=1.0 / (Gij.shape[1] * onehot.shape[1]) ** 0.5
101                ),
102                (onehot.shape[1], Gij.shape[1], self.dim),
103            )
104            Gij = jnp.einsum(
105                "...i,...j,ijk->...k",
106                onehot[edge_dst],
107                Gij,
108                Wa,
109            )
110        else:
111            Gij = FullyConnectedNet(
112                [*self.embedding_hidden, self.dim], activation=self.activation
113            )(jnp.concatenate((radial_terms, onehot[edge_dst]), axis=-1))
114
115        GRi = jax.ops.segment_sum(
116            Gij[:, None, :] * Rij[:, :, None], edge_src, species.shape[0]
117        )
118        if self.subdim > 0:
119            GRisub = GRi[:, :, : self.subdim]
120
121            embedding = (
122                (GRi[:, :, :, None] * GRisub[:, :, None, :])
123                .sum(axis=1)
124                .reshape((species.shape[0], -1))
125            )
126        else:
127            GRisub = nn.Dense(self.dim, use_bias=False, name="Gri_linear")(GRi)
128            embedding = (GRi * GRisub).sum(axis=1)
129
130        if self.concatenate_species:
131            embedding = jnp.concatenate((onehot, embedding), axis=-1)
132
133        if self.embedding_key is None:
134            return embedding
135        return {**inputs, self.embedding_key: embedding}
136
137
class DeepPotE3Embedding(nn.Module):
    """Deep Potential embedding with angle information

    FID : DEEPPOT_E3

    For every angle (pair of edges sharing a central atom), the dot product
    theta of the generalized coordinates R = (s, s * unit vector) of the two edges
    is fed, together with the species encodings of both neighbours, to a
    network Ne3. Ne3 is evaluated on both neighbour orderings, which makes it
    symmetric under their exchange. Its output is weighted by theta and summed
    over the angles of each central atom.

    ### Reference
    L. Zhang, J. Han, H. Wang, W. A. Saidi, R. Car, Weinan E, End-to-end Symmetry Preserving Inter-atomic Potential Energy Model for Finite and Extended Systems,
    Conference on Neural Information Processing Systems (NeurIPS), 2018,
    https://doi.org/10.48550/arXiv.1805.09003

    """

    _graphs_properties: Dict
    dim: int = 64
    """The dimension of the embedding."""
    embedding_key: Optional[str] = "embedding"
    """The key to use for the output embedding in the returned dictionary.
        If None, the embedding array is returned directly."""
    graph_key: str = "graph"
    """The key in the input dictionary that corresponds to the graph.
        The graph must include the angle extension (GRAPH_ANGLE_EXTENSION)."""
    species_encoding: dict = dataclasses.field(default_factory=dict)
    """The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`"""
    embedding_hidden: Sequence[int] = dataclasses.field(
        default_factory=lambda: [64, 64, 64]
    )
    """The hidden layers of the embedding network."""
    activation: Union[Callable, str] = "silu"
    """The activation function."""
    concatenate_species: bool = False
    """Whether to concatenate the species encoding with the embedding."""
    divide_distances: bool = True
    """Whether to divide the switch by the distance in s_ij."""

    FID: ClassVar[str] = "DEEPPOT_E3"

    @nn.compact
    def __call__(self, inputs):
        """Compute the three-body DeepPot embedding.

        Args:
            inputs: dictionary containing ``"species"`` and the graph stored under
                ``self.graph_key`` (read entries: ``angles`` (presence only),
                ``edge_dst``, ``distances``, ``switch``, ``vec``, ``angle_src``,
                ``angle_dst``, ``central_atom``).

        Returns:
            The embedding array if ``embedding_key`` is None, otherwise a new
            dictionary with the entries of ``inputs`` plus the embedding stored
            under ``embedding_key``.

        Raises:
            AssertionError: if the graph has no ``"angles"`` entry.
        """
        species = inputs["species"]
        graph = inputs[self.graph_key]
        # Validate before reading any angle-related key, so that a missing angle
        # extension reports this message instead of a bare KeyError.
        assert (
            "angles" in graph
        ), "Error: DeepPotE3 requires angles (GRAPH_ANGLE_EXTENSION)"

        # species encoding
        onehot = SpeciesEncoding(**self.species_encoding, name="SpeciesEncoding")(
            species
        )

        # Generalized coordinates R_ij = (s_ij, s_ij * unit vector) of each edge
        edge_dst = graph["edge_dst"]
        distances = graph["distances"][:, None]
        switch = graph["switch"][:, None]
        vec = graph["vec"] / distances
        sij = switch / distances if self.divide_distances else switch
        Rij = jnp.concatenate((sij, sij * vec), axis=-1)

        # angle_src / angle_dst are edge indices: pick the species encoding of
        # the neighbour at the end (edge_dst) of each of the two edges
        zdest = onehot[edge_dst]
        angle_src, angle_dst = graph["angle_src"], graph["angle_dst"]
        z_angsrc = zdest[angle_src]
        z_angdst = zdest[angle_dst]

        # Angular term: dot product of the generalized coordinates of both edges
        theta = (Rij[angle_src] * Rij[angle_dst]).sum(axis=-1, keepdims=True)

        # A single Ne3 instance is applied to both neighbour orderings (shared
        # parameters), making Gijk symmetric under exchange of the two neighbours
        Ne3 = FullyConnectedNet(
            [*self.embedding_hidden, self.dim],
            activation=self.activation,
            name="Ne3",
        )
        Gijk = Ne3(jnp.concatenate((theta, z_angsrc, z_angdst), axis=-1)) + Ne3(
            jnp.concatenate((theta, z_angdst, z_angsrc), axis=-1)
        )

        # Sum over all angles centered on each atom
        embedding = jax.ops.segment_sum(
            Gijk * theta, graph["central_atom"], species.shape[0]
        )

        if self.concatenate_species:
            embedding = jnp.concatenate((onehot, embedding), axis=-1)

        if self.embedding_key is None:
            return embedding
        return {**inputs, self.embedding_key: embedding}
class DeepPotEmbedding(flax.linen.module.Module):
 15class DeepPotEmbedding(nn.Module):
 16    """Deep Potential embedding
 17
 18    FID : DEEPPOT
 19
 20    ### Reference
 21    Zhang, L., Han, J., Wang, H., Car, R., & E, W. (2018). Deep Potential Molecular dynamics: A scalable model with the accuracy of quantum mechanics. Phys. Rev. Lett., 120(14), 143001. https://doi.org/10.1103/PhysRevLett.120.143001
 22
 23    """
 24    _graphs_properties: Dict
 25    dim: int = 64
 26    """The dimension of the embedding."""
 27    subdim: int = 8
 28    """The first dimensions to select for the embedding tensor product."""
 29    radial_dim: Optional[int] = None
 30    """The dimension of the radial embedding for tensor combination. 
 31        If None, we use a neural net to combine chemical and radial information, like in the original DeepPot."""
 32    embedding_key: str = "embedding"
 33    """The key to use for the output embedding in the returned dictionary."""
 34    graph_key: str = "graph"
 35    """The key in the input dictionary that corresponds to the radial graph."""
 36    species_encoding: dict = dataclasses.field(default_factory=dict)
 37    """The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`"""
 38    radial_basis: Optional[dict] = None
 39    """The radial basis parameters. See `fennol.models.misc.encodings.RadialBasis`. 
 40        If None, the radial basis is the s_ij like in the original DeepPot."""
 41    embedding_hidden: Sequence[int] = dataclasses.field(
 42        default_factory=lambda: [64, 64, 64]
 43    )
 44    """The hidden layers of the embedding network."""
 45    activation: Union[Callable, str] = "silu"
 46    """The activation function."""
 47    concatenate_species: bool = False
 48    """Whether to concatenate the species encoding with the embedding."""
 49    divide_distances: bool = True
 50    """Whether to divide the switch by the distance in s_ij."""
 51    species_order: Optional[Union[str,Sequence[str]]] = None
 52    """Species considered by the network when using species-specialized embedding network."""
 53
 54    FID: ClassVar[str] = "DEEPPOT"
 55
 56    @nn.compact
 57    def __call__(self, inputs):
 58        species = inputs["species"]
 59
 60        # species encoding
 61        if self.species_order is None or self.concatenate_species:
 62            onehot = SpeciesEncoding(**self.species_encoding, name="SpeciesEncoding")(
 63                species
 64            )
 65
 66        # Radial graph
 67        graph = inputs[self.graph_key]
 68        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
 69        distances = graph["distances"][:, None]
 70        switch = graph["switch"][:, None]
 71        vec = graph["vec"] / distances
 72        sij = switch / distances if self.divide_distances else switch
 73        Rij = jnp.concatenate((sij, sij * vec), axis=-1)
 74
 75        # Radial BASIS
 76        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
 77        if self.radial_basis is not None:
 78            radial_terms = RadialBasis(
 79                **{
 80                    **self.radial_basis,
 81                    "end": cutoff,
 82                    "name": f"RadialBasis",
 83                }
 84            )(graph["distances"])
 85        else:
 86            radial_terms = sij
 87
 88        if self.species_order is not None:
 89            Gij = ChemicalNet(
 90                self.species_order,
 91                [*self.embedding_hidden, self.dim],
 92                activation=self.activation,
 93            )((species[edge_dst], radial_terms))
 94        elif self.radial_dim is not None:
 95            Gij = FullyConnectedNet(
 96                [*self.embedding_hidden, self.radial_dim], activation=self.activation
 97            )(radial_terms)
 98            Wa = self.param(
 99                f"Wa",
100                nn.initializers.normal(
101                    stddev=1.0 / (Gij.shape[1] * onehot.shape[1]) ** 0.5
102                ),
103                (onehot.shape[1], Gij.shape[1], self.dim),
104            )
105            Gij = jnp.einsum(
106                "...i,...j,ijk->...k",
107                onehot[edge_dst],
108                Gij,
109                Wa,
110            )
111        else:
112            Gij = FullyConnectedNet(
113                [*self.embedding_hidden, self.dim], activation=self.activation
114            )(jnp.concatenate((radial_terms, onehot[edge_dst]), axis=-1))
115
116        GRi = jax.ops.segment_sum(
117            Gij[:, None, :] * Rij[:, :, None], edge_src, species.shape[0]
118        )
119        if self.subdim > 0:
120            GRisub = GRi[:, :, : self.subdim]
121
122            embedding = (
123                (GRi[:, :, :, None] * GRisub[:, :, None, :])
124                .sum(axis=1)
125                .reshape((species.shape[0], -1))
126            )
127        else:
128            GRisub = nn.Dense(self.dim, use_bias=False, name="Gri_linear")(GRi)
129            embedding = (GRi * GRisub).sum(axis=1)
130
131        if self.concatenate_species:
132            embedding = jnp.concatenate((onehot, embedding), axis=-1)
133
134        if self.embedding_key is None:
135            return embedding
136        return {**inputs, self.embedding_key: embedding}

Deep Potential embedding

FID : DEEPPOT

Reference

Zhang, L., Han, J., Wang, H., Car, R., & E, W. (2018). Deep Potential Molecular Dynamics: A Scalable Model with the Accuracy of Quantum Mechanics. Phys. Rev. Lett., 120(14), 143001. https://doi.org/10.1103/PhysRevLett.120.143001

DeepPotEmbedding( _graphs_properties: Dict, dim: int = 64, subdim: int = 8, radial_dim: Optional[int] = None, embedding_key: str = 'embedding', graph_key: str = 'graph', species_encoding: dict = <factory>, radial_basis: Optional[dict] = None, embedding_hidden: Sequence[int] = <factory>, activation: Union[Callable, str] = 'silu', concatenate_species: bool = False, divide_distances: bool = True, species_order: Union[str, Sequence[str], NoneType] = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
dim: int = 64

The dimension of the embedding.

subdim: int = 8

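The number of leading channels selected for the embedding tensor product. If subdim > 0, the embedding has dim * subdim features. If subdim <= 0, a learned linear projection to dim channels is used instead, and the embedding has dim features.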

radial_dim: Optional[int] = None

The dimension of the radial embedding, which is combined with the species encoding through a learned tensor product. If None, a single neural net combines chemical and radial information, as in the original DeepPot. Ignored when species_order is set.

embedding_key: str = 'embedding'

The key to use for the output embedding in the returned dictionary.

graph_key: str = 'graph'

The key in the input dictionary that corresponds to the radial graph.

species_encoding: dict

The species encoding parameters. See fennol.models.misc.encodings.SpeciesEncoding

radial_basis: Optional[dict] = None

The radial basis parameters. See fennol.models.misc.encodings.RadialBasis. The basis end is always set to the graph cutoff. If None, s_ij is used directly as the radial input, as in the original DeepPot.

embedding_hidden: Sequence[int]

The hidden layers of the embedding network.

activation: Union[Callable, str] = 'silu'

The activation function.

concatenate_species: bool = False

Whether to concatenate the species encoding with the embedding.

divide_distances: bool = True

Whether to divide the switch by the distance in s_ij.

species_order: Union[str, Sequence[str], NoneType] = None

Species handled by the species-specialized embedding network (ChemicalNet). When set, this network is used instead of the generic embedding network, and radial_dim is ignored.

FID: ClassVar[str] = 'DEEPPOT'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class DeepPotE3Embedding(flax.linen.module.Module):
class DeepPotE3Embedding(nn.Module):
    """Deep Potential embedding with angle information

    FID : DEEPPOT_E3

    For every angle (pair of edges sharing a central atom), the dot product
    theta of the generalized coordinates R = (s, s * unit vector) of the two edges
    is fed, together with the species encodings of both neighbours, to a
    network Ne3. Ne3 is evaluated on both neighbour orderings, which makes it
    symmetric under their exchange. Its output is weighted by theta and summed
    over the angles of each central atom.

    ### Reference
    L. Zhang, J. Han, H. Wang, W. A. Saidi, R. Car, Weinan E, End-to-end Symmetry Preserving Inter-atomic Potential Energy Model for Finite and Extended Systems,
    Conference on Neural Information Processing Systems (NeurIPS), 2018,
    https://doi.org/10.48550/arXiv.1805.09003

    """

    _graphs_properties: Dict
    dim: int = 64
    """The dimension of the embedding."""
    embedding_key: Optional[str] = "embedding"
    """The key to use for the output embedding in the returned dictionary.
        If None, the embedding array is returned directly."""
    graph_key: str = "graph"
    """The key in the input dictionary that corresponds to the graph.
        The graph must include the angle extension (GRAPH_ANGLE_EXTENSION)."""
    species_encoding: dict = dataclasses.field(default_factory=dict)
    """The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`"""
    embedding_hidden: Sequence[int] = dataclasses.field(
        default_factory=lambda: [64, 64, 64]
    )
    """The hidden layers of the embedding network."""
    activation: Union[Callable, str] = "silu"
    """The activation function."""
    concatenate_species: bool = False
    """Whether to concatenate the species encoding with the embedding."""
    divide_distances: bool = True
    """Whether to divide the switch by the distance in s_ij."""

    FID: ClassVar[str] = "DEEPPOT_E3"

    @nn.compact
    def __call__(self, inputs):
        """Compute the three-body DeepPot embedding.

        Args:
            inputs: dictionary containing ``"species"`` and the graph stored under
                ``self.graph_key`` (read entries: ``angles`` (presence only),
                ``edge_dst``, ``distances``, ``switch``, ``vec``, ``angle_src``,
                ``angle_dst``, ``central_atom``).

        Returns:
            The embedding array if ``embedding_key`` is None, otherwise a new
            dictionary with the entries of ``inputs`` plus the embedding stored
            under ``embedding_key``.

        Raises:
            AssertionError: if the graph has no ``"angles"`` entry.
        """
        species = inputs["species"]
        graph = inputs[self.graph_key]
        # Validate before reading any angle-related key, so that a missing angle
        # extension reports this message instead of a bare KeyError.
        assert (
            "angles" in graph
        ), "Error: DeepPotE3 requires angles (GRAPH_ANGLE_EXTENSION)"

        # species encoding
        onehot = SpeciesEncoding(**self.species_encoding, name="SpeciesEncoding")(
            species
        )

        # Generalized coordinates R_ij = (s_ij, s_ij * unit vector) of each edge
        edge_dst = graph["edge_dst"]
        distances = graph["distances"][:, None]
        switch = graph["switch"][:, None]
        vec = graph["vec"] / distances
        sij = switch / distances if self.divide_distances else switch
        Rij = jnp.concatenate((sij, sij * vec), axis=-1)

        # angle_src / angle_dst are edge indices: pick the species encoding of
        # the neighbour at the end (edge_dst) of each of the two edges
        zdest = onehot[edge_dst]
        angle_src, angle_dst = graph["angle_src"], graph["angle_dst"]
        z_angsrc = zdest[angle_src]
        z_angdst = zdest[angle_dst]

        # Angular term: dot product of the generalized coordinates of both edges
        theta = (Rij[angle_src] * Rij[angle_dst]).sum(axis=-1, keepdims=True)

        # A single Ne3 instance is applied to both neighbour orderings (shared
        # parameters), making Gijk symmetric under exchange of the two neighbours
        Ne3 = FullyConnectedNet(
            [*self.embedding_hidden, self.dim],
            activation=self.activation,
            name="Ne3",
        )
        Gijk = Ne3(jnp.concatenate((theta, z_angsrc, z_angdst), axis=-1)) + Ne3(
            jnp.concatenate((theta, z_angdst, z_angsrc), axis=-1)
        )

        # Sum over all angles centered on each atom
        embedding = jax.ops.segment_sum(
            Gijk * theta, graph["central_atom"], species.shape[0]
        )

        if self.concatenate_species:
            embedding = jnp.concatenate((onehot, embedding), axis=-1)

        if self.embedding_key is None:
            return embedding
        return {**inputs, self.embedding_key: embedding}

Deep Potential embedding with angle information

FID : DEEPPOT_E3

Reference

L. Zhang, J. Han, H. Wang, W. A. Saidi, R. Car, Weinan E, End-to-end Symmetry Preserving Inter-atomic Potential Energy Model for Finite and Extended Systems, Conference on Neural Information Processing Systems (NeurIPS), 2018, https://doi.org/10.48550/arXiv.1805.09003

DeepPotE3Embedding( _graphs_properties: Dict, dim: int = 64, embedding_key: str = 'embedding', graph_key: str = 'graph', species_encoding: dict = <factory>, embedding_hidden: Sequence[int] = <factory>, activation: Union[Callable, str] = 'silu', concatenate_species: bool = False, divide_distances: bool = True, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
dim: int = 64

The dimension of the embedding.

embedding_key: str = 'embedding'

The key to use for the output embedding in the returned dictionary.

graph_key: str = 'graph'

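The key in the input dictionary that corresponds to the graph. The graph must have been built with the angle extension (GRAPH_ANGLE_EXTENSION); the embedding reads its angle_src, angle_dst and central_atom entries.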

species_encoding: dict

The species encoding parameters. See fennol.models.misc.encodings.SpeciesEncoding

embedding_hidden: Sequence[int]

The hidden layers of the embedding network.

activation: Union[Callable, str] = 'silu'

The activation function.

concatenate_species: bool = False

Whether to concatenate the species encoding with the embedding.

divide_distances: bool = True

Whether to divide the switch by the distance in s_ij.

FID: ClassVar[str] = 'DEEPPOT_E3'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None