fennol.models.embeddings.deeppot
import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Sequence, Dict, Optional, Union, Callable, ClassVar
import numpy as np
import dataclasses

from ..misc.nets import FullyConnectedNet, ChemicalNet

from ...utils.periodic_table import PERIODIC_TABLE, VALENCE_STRUCTURE
from ..misc.encodings import SpeciesEncoding, RadialBasis


class DeepPotEmbedding(nn.Module):
    """Deep Potential embedding

    FID : DEEPPOT

    For every edge of the radial graph an embedding network produces a
    feature vector `G_ij`. It is combined (outer product) with the
    generalized coordinates `R_ij = (s_ij, s_ij * vec_ij / d_ij)` and summed
    over the edges of each source atom. The per-atom embedding is the
    contraction, over the coordinate axis, of that sum with a reduced copy
    of itself (first `subdim` channels, or a learned linear map).

    ### Reference
    Zhang, L., Han, J., Wang, H., Car, R., & E, W. (2018). Deep Potential Molecular dynamics: A scalable model with the accuracy of quantum mechanics. Phys. Rev. Lett., 120(14), 143001. https://doi.org/10.1103/PhysRevLett.120.143001

    """

    _graphs_properties: Dict
    dim: int = 64
    """The dimension of the embedding (last layer size of the edge embedding network)."""
    subdim: int = 8
    """The first dimensions to select for the embedding tensor product.
        If not strictly positive, a bias-free linear layer of size `dim` is used instead."""
    radial_dim: Optional[int] = None
    """The dimension of the radial embedding for tensor combination.
        If None, we use a neural net to combine chemical and radial information, like in the original DeepPot."""
    embedding_key: str = "embedding"
    """The key to use for the output embedding in the returned dictionary.
        If None, the embedding array itself is returned."""
    graph_key: str = "graph"
    """The key in the input dictionary that corresponds to the radial graph."""
    species_encoding: dict = dataclasses.field(default_factory=dict)
    """The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`"""
    radial_basis: Optional[dict] = None
    """The radial basis parameters. See `fennol.models.misc.encodings.RadialBasis`.
        If None, the radial basis is the s_ij like in the original DeepPot."""
    embedding_hidden: Sequence[int] = dataclasses.field(
        default_factory=lambda: [64, 64, 64]
    )
    """The hidden layers of the embedding network."""
    activation: Union[Callable, str] = "silu"
    """The activation function."""
    concatenate_species: bool = False
    """Whether to concatenate the species encoding with the embedding."""
    divide_distances: bool = True
    """Whether to divide the switch by the distance in s_ij."""
    species_order: Optional[Union[str, Sequence[str]]] = None
    """Species considered by the network when using a species-specialized embedding network."""

    FID: ClassVar[str] = "DEEPPOT"

    @nn.compact
    def __call__(self, inputs):
        """Compute the embedding from `inputs["species"]` and the radial graph.

        Returns the embedding array when `embedding_key` is None, otherwise
        a copy of `inputs` with the embedding stored under `embedding_key`.
        """
        species = inputs["species"]
        n_atoms = species.shape[0]
        species_specific = self.species_order is not None

        # The species encoding is only needed when it feeds the embedding
        # network or is appended to the output.
        if not species_specific or self.concatenate_species:
            encoder = SpeciesEncoding(**self.species_encoding, name="SpeciesEncoding")
            onehot = encoder(species)

        # Radial graph geometry (a trailing axis is added to the per-edge scalars).
        graph = inputs[self.graph_key]
        edge_src = graph["edge_src"]
        edge_dst = graph["edge_dst"]
        distances = graph["distances"][:, None]
        switch = graph["switch"][:, None]
        directions = graph["vec"] / distances
        if self.divide_distances:
            sij = switch / distances
        else:
            sij = switch
        # Generalized coordinates: s_ij followed by s_ij times the unit vector.
        Rij = jnp.concatenate((sij, sij * directions), axis=-1)

        # Radial input of the embedding network.
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        if self.radial_basis is None:
            radial_terms = sij
        else:
            # "end" and "name" override any value given in `radial_basis`.
            basis_kwargs = {**self.radial_basis, "end": cutoff, "name": "RadialBasis"}
            radial_terms = RadialBasis(**basis_kwargs)(graph["distances"])

        if species_specific:
            # One embedding network per species of the destination atom.
            chemical_net = ChemicalNet(
                self.species_order,
                [*self.embedding_hidden, self.dim],
                activation=self.activation,
            )
            Gij = chemical_net((species[edge_dst], radial_terms))
        elif self.radial_dim is not None:
            # Radial features and species encoding are mixed by a trainable
            # rank-3 tensor Wa.
            radial_net = FullyConnectedNet(
                [*self.embedding_hidden, self.radial_dim], activation=self.activation
            )
            radial_features = radial_net(radial_terms)
            n_radial = radial_features.shape[1]
            n_encoding = onehot.shape[1]
            Wa = self.param(
                "Wa",
                nn.initializers.normal(stddev=1.0 / (n_radial * n_encoding) ** 0.5),
                (n_encoding, n_radial, self.dim),
            )
            Gij = jnp.einsum(
                "...i,...j,ijk->...k",
                onehot[edge_dst],
                radial_features,
                Wa,
            )
        else:
            # A single network sees both the radial terms and the encoding
            # of the destination atom.
            mixed_net = FullyConnectedNet(
                [*self.embedding_hidden, self.dim], activation=self.activation
            )
            Gij = mixed_net(
                jnp.concatenate((radial_terms, onehot[edge_dst]), axis=-1)
            )

        # Sum over neighbours of the outer product R_ij (x) G_ij.
        edge_products = Gij[:, None, :] * Rij[:, :, None]
        GRi = jax.ops.segment_sum(edge_products, edge_src, n_atoms)

        if self.subdim > 0:
            GRisub = GRi[:, :, : self.subdim]
            pair_products = GRi[:, :, :, None] * GRisub[:, :, None, :]
            embedding = pair_products.sum(axis=1).reshape((n_atoms, -1))
        else:
            GRisub = nn.Dense(self.dim, use_bias=False, name="Gri_linear")(GRi)
            embedding = (GRi * GRisub).sum(axis=1)

        if self.concatenate_species:
            embedding = jnp.concatenate((onehot, embedding), axis=-1)

        if self.embedding_key is None:
            return embedding
        return {**inputs, self.embedding_key: embedding}


class DeepPotE3Embedding(nn.Module):
    """Deep Potential embedding with angle information

    FID : DEEPPOT_E3

    For every angle (a pair of edges listed in `angle_src` / `angle_dst`)
    the scalar product of the two generalized edge coordinates
    `R_ij = (s_ij, s_ij * vec_ij / d_ij)` is fed, together with the species
    encodings of the two neighbours, to the network `Ne3`. The network is
    evaluated for both orderings of the neighbours and the results are
    added. The outputs, weighted by the scalar product, are summed over the
    angles of each central atom.

    ### Reference
    L. Zhang, J. Han, H. Wang, W. A. Saidi, R. Car, Weinan E, End-to-end Symmetry Preserving Inter-atomic Potential Energy Model for Finite and Extended Systems,
    Conference on Neural Information Processing Systems (NeurIPS), 2018,
    https://doi.org/10.48550/arXiv.1805.09003

    """

    _graphs_properties: Dict
    dim: int = 64
    """The dimension of the embedding."""
    embedding_key: str = "embedding"
    """The key to use for the output embedding in the returned dictionary.
        If None, the embedding array itself is returned."""
    graph_key: str = "graph"
    """The key in the input dictionary that corresponds to the graph."""
    species_encoding: dict = dataclasses.field(default_factory=dict)
    """The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`"""
    embedding_hidden: Sequence[int] = dataclasses.field(
        default_factory=lambda: [64, 64, 64]
    )
    """The hidden layers of the embedding network."""
    activation: Union[Callable, str] = "silu"
    """The activation function."""
    concatenate_species: bool = False
    """Whether to concatenate the species encoding with the embedding."""
    divide_distances: bool = True
    """Whether to divide the switch by the distance in s_ij."""

    FID: ClassVar[str] = "DEEPPOT_E3"

    @nn.compact
    def __call__(self, inputs):
        """Compute the angular embedding from `inputs["species"]` and the graph.

        Returns the embedding array when `embedding_key` is None, otherwise
        a copy of `inputs` with the embedding stored under `embedding_key`.

        Raises:
            AssertionError: if the graph has no "angles" entry.
        """
        species = inputs["species"]

        # species encoding
        onehot = SpeciesEncoding(**self.species_encoding, name="SpeciesEncoding")(
            species
        )

        graph = inputs[self.graph_key]
        # Check for the angle extension before touching any angle key, so a
        # graph without it fails with this message and not a bare KeyError.
        assert (
            "angles" in graph
        ), "Error: DeepPotE3 requires angles (GRAPH_ANGLE_EXTENSION)"

        # Radial graph geometry (a trailing axis is added to the per-edge scalars).
        edge_dst = graph["edge_dst"]
        distances = graph["distances"][:, None]
        switch = graph["switch"][:, None]
        vec = graph["vec"] / distances
        sij = switch / distances if self.divide_distances else switch
        # Generalized coordinates: s_ij followed by s_ij times the unit vector.
        Rij = jnp.concatenate((sij, sij * vec), axis=-1)

        # Species encoding of the destination atom of every edge, then of
        # the two edges forming every angle.
        zdest = onehot[edge_dst]
        angle_src, angle_dst = graph["angle_src"], graph["angle_dst"]
        z_angsrc = zdest[angle_src]
        z_angdst = zdest[angle_dst]

        # Angular descriptor: scalar product of the two generalized coordinates.
        theta = (Rij[angle_src] * Rij[angle_dst]).sum(axis=-1, keepdims=True)

        Ne3 = FullyConnectedNet(
            [*self.embedding_hidden, self.dim],
            activation=self.activation,
            name="Ne3",
        )
        # The same network (shared parameters) is applied to both orderings
        # of the two neighbours.
        Gijk = Ne3(jnp.concatenate((theta, z_angsrc, z_angdst), axis=-1)) + Ne3(
            jnp.concatenate((theta, z_angdst, z_angsrc), axis=-1)
        )

        embedding = jax.ops.segment_sum(
            Gijk * theta, graph["central_atom"], species.shape[0]
        )

        if self.concatenate_species:
            embedding = jnp.concatenate((onehot, embedding), axis=-1)

        if self.embedding_key is None:
            return embedding
        return {**inputs, self.embedding_key: embedding}
class DeepPotEmbedding(nn.Module):
    """Deep Potential embedding

    FID : DEEPPOT

    For every edge of the radial graph an embedding network produces a
    feature vector `G_ij`. It is combined (outer product) with the
    generalized coordinates `R_ij = (s_ij, s_ij * vec_ij / d_ij)` and summed
    over the edges of each source atom. The per-atom embedding is the
    contraction, over the coordinate axis, of that sum with a reduced copy
    of itself (first `subdim` channels, or a learned linear map).

    ### Reference
    Zhang, L., Han, J., Wang, H., Car, R., & E, W. (2018). Deep Potential Molecular dynamics: A scalable model with the accuracy of quantum mechanics. Phys. Rev. Lett., 120(14), 143001. https://doi.org/10.1103/PhysRevLett.120.143001

    """

    _graphs_properties: Dict
    dim: int = 64
    """The dimension of the embedding (last layer size of the edge embedding network)."""
    subdim: int = 8
    """The first dimensions to select for the embedding tensor product.
        If not strictly positive, a bias-free linear layer of size `dim` is used instead."""
    radial_dim: Optional[int] = None
    """The dimension of the radial embedding for tensor combination.
        If None, we use a neural net to combine chemical and radial information, like in the original DeepPot."""
    embedding_key: str = "embedding"
    """The key to use for the output embedding in the returned dictionary.
        If None, the embedding array itself is returned."""
    graph_key: str = "graph"
    """The key in the input dictionary that corresponds to the radial graph."""
    species_encoding: dict = dataclasses.field(default_factory=dict)
    """The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`"""
    radial_basis: Optional[dict] = None
    """The radial basis parameters. See `fennol.models.misc.encodings.RadialBasis`.
        If None, the radial basis is the s_ij like in the original DeepPot."""
    embedding_hidden: Sequence[int] = dataclasses.field(
        default_factory=lambda: [64, 64, 64]
    )
    """The hidden layers of the embedding network."""
    activation: Union[Callable, str] = "silu"
    """The activation function."""
    concatenate_species: bool = False
    """Whether to concatenate the species encoding with the embedding."""
    divide_distances: bool = True
    """Whether to divide the switch by the distance in s_ij."""
    species_order: Optional[Union[str, Sequence[str]]] = None
    """Species considered by the network when using a species-specialized embedding network."""

    FID: ClassVar[str] = "DEEPPOT"

    @nn.compact
    def __call__(self, inputs):
        """Compute the embedding from `inputs["species"]` and the radial graph.

        Returns the embedding array when `embedding_key` is None, otherwise
        a copy of `inputs` with the embedding stored under `embedding_key`.
        """
        species = inputs["species"]
        n_atoms = species.shape[0]
        species_specific = self.species_order is not None

        # The species encoding is only needed when it feeds the embedding
        # network or is appended to the output.
        if not species_specific or self.concatenate_species:
            encoder = SpeciesEncoding(**self.species_encoding, name="SpeciesEncoding")
            onehot = encoder(species)

        # Radial graph geometry (a trailing axis is added to the per-edge scalars).
        graph = inputs[self.graph_key]
        edge_src = graph["edge_src"]
        edge_dst = graph["edge_dst"]
        distances = graph["distances"][:, None]
        switch = graph["switch"][:, None]
        directions = graph["vec"] / distances
        if self.divide_distances:
            sij = switch / distances
        else:
            sij = switch
        # Generalized coordinates: s_ij followed by s_ij times the unit vector.
        Rij = jnp.concatenate((sij, sij * directions), axis=-1)

        # Radial input of the embedding network.
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        if self.radial_basis is None:
            radial_terms = sij
        else:
            # "end" and "name" override any value given in `radial_basis`.
            basis_kwargs = {**self.radial_basis, "end": cutoff, "name": "RadialBasis"}
            radial_terms = RadialBasis(**basis_kwargs)(graph["distances"])

        if species_specific:
            # One embedding network per species of the destination atom.
            chemical_net = ChemicalNet(
                self.species_order,
                [*self.embedding_hidden, self.dim],
                activation=self.activation,
            )
            Gij = chemical_net((species[edge_dst], radial_terms))
        elif self.radial_dim is not None:
            # Radial features and species encoding are mixed by a trainable
            # rank-3 tensor Wa.
            radial_net = FullyConnectedNet(
                [*self.embedding_hidden, self.radial_dim], activation=self.activation
            )
            radial_features = radial_net(radial_terms)
            n_radial = radial_features.shape[1]
            n_encoding = onehot.shape[1]
            Wa = self.param(
                "Wa",
                nn.initializers.normal(stddev=1.0 / (n_radial * n_encoding) ** 0.5),
                (n_encoding, n_radial, self.dim),
            )
            Gij = jnp.einsum(
                "...i,...j,ijk->...k",
                onehot[edge_dst],
                radial_features,
                Wa,
            )
        else:
            # A single network sees both the radial terms and the encoding
            # of the destination atom.
            mixed_net = FullyConnectedNet(
                [*self.embedding_hidden, self.dim], activation=self.activation
            )
            Gij = mixed_net(
                jnp.concatenate((radial_terms, onehot[edge_dst]), axis=-1)
            )

        # Sum over neighbours of the outer product R_ij (x) G_ij.
        edge_products = Gij[:, None, :] * Rij[:, :, None]
        GRi = jax.ops.segment_sum(edge_products, edge_src, n_atoms)

        if self.subdim > 0:
            GRisub = GRi[:, :, : self.subdim]
            pair_products = GRi[:, :, :, None] * GRisub[:, :, None, :]
            embedding = pair_products.sum(axis=1).reshape((n_atoms, -1))
        else:
            GRisub = nn.Dense(self.dim, use_bias=False, name="Gri_linear")(GRi)
            embedding = (GRi * GRisub).sum(axis=1)

        if self.concatenate_species:
            embedding = jnp.concatenate((onehot, embedding), axis=-1)

        if self.embedding_key is None:
            return embedding
        return {**inputs, self.embedding_key: embedding}
Deep Potential embedding
FID : DEEPPOT
Reference
Zhang, L., Han, J., Wang, H., Car, R., & E, W. (2018). Deep Potential Molecular dynamics: A scalable model with the accuracy of quantum mechanics. Phys. Rev. Lett., 120(14), 143001. https://doi.org/10.1103/PhysRevLett.120.143001
The dimension of the radial embedding for tensor combination. If None, we use a neural net to combine chemical and radial information, like in the original DeepPot.
The key to use for the output embedding in the returned dictionary.
The species encoding parameters. See fennol.models.misc.encodings.SpeciesEncoding
The radial basis parameters. See fennol.models.misc.encodings.RadialBasis.
If None, the radial basis is the s_ij like in the original DeepPot.
Species considered by the network when using a species-specialized embedding network.
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.
class DeepPotE3Embedding(nn.Module):
    """Deep Potential embedding with angle information

    FID : DEEPPOT_E3

    For every angle (a pair of edges listed in `angle_src` / `angle_dst`)
    the scalar product of the two generalized edge coordinates
    `R_ij = (s_ij, s_ij * vec_ij / d_ij)` is fed, together with the species
    encodings of the two neighbours, to the network `Ne3`. The network is
    evaluated for both orderings of the neighbours and the results are
    added. The outputs, weighted by the scalar product, are summed over the
    angles of each central atom.

    ### Reference
    L. Zhang, J. Han, H. Wang, W. A. Saidi, R. Car, Weinan E, End-to-end Symmetry Preserving Inter-atomic Potential Energy Model for Finite and Extended Systems,
    Conference on Neural Information Processing Systems (NeurIPS), 2018,
    https://doi.org/10.48550/arXiv.1805.09003

    """

    _graphs_properties: Dict
    dim: int = 64
    """The dimension of the embedding."""
    embedding_key: str = "embedding"
    """The key to use for the output embedding in the returned dictionary.
        If None, the embedding array itself is returned."""
    graph_key: str = "graph"
    """The key in the input dictionary that corresponds to the graph."""
    species_encoding: dict = dataclasses.field(default_factory=dict)
    """The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`"""
    embedding_hidden: Sequence[int] = dataclasses.field(
        default_factory=lambda: [64, 64, 64]
    )
    """The hidden layers of the embedding network."""
    activation: Union[Callable, str] = "silu"
    """The activation function."""
    concatenate_species: bool = False
    """Whether to concatenate the species encoding with the embedding."""
    divide_distances: bool = True
    """Whether to divide the switch by the distance in s_ij."""

    FID: ClassVar[str] = "DEEPPOT_E3"

    @nn.compact
    def __call__(self, inputs):
        """Compute the angular embedding from `inputs["species"]` and the graph.

        Returns the embedding array when `embedding_key` is None, otherwise
        a copy of `inputs` with the embedding stored under `embedding_key`.

        Raises:
            AssertionError: if the graph has no "angles" entry.
        """
        species = inputs["species"]

        # species encoding
        onehot = SpeciesEncoding(**self.species_encoding, name="SpeciesEncoding")(
            species
        )

        graph = inputs[self.graph_key]
        # Check for the angle extension before touching any angle key, so a
        # graph without it fails with this message and not a bare KeyError.
        assert (
            "angles" in graph
        ), "Error: DeepPotE3 requires angles (GRAPH_ANGLE_EXTENSION)"

        # Radial graph geometry (a trailing axis is added to the per-edge scalars).
        edge_dst = graph["edge_dst"]
        distances = graph["distances"][:, None]
        switch = graph["switch"][:, None]
        vec = graph["vec"] / distances
        sij = switch / distances if self.divide_distances else switch
        # Generalized coordinates: s_ij followed by s_ij times the unit vector.
        Rij = jnp.concatenate((sij, sij * vec), axis=-1)

        # Species encoding of the destination atom of every edge, then of
        # the two edges forming every angle.
        zdest = onehot[edge_dst]
        angle_src, angle_dst = graph["angle_src"], graph["angle_dst"]
        z_angsrc = zdest[angle_src]
        z_angdst = zdest[angle_dst]

        # Angular descriptor: scalar product of the two generalized coordinates.
        theta = (Rij[angle_src] * Rij[angle_dst]).sum(axis=-1, keepdims=True)

        Ne3 = FullyConnectedNet(
            [*self.embedding_hidden, self.dim],
            activation=self.activation,
            name="Ne3",
        )
        # The same network (shared parameters) is applied to both orderings
        # of the two neighbours.
        Gijk = Ne3(jnp.concatenate((theta, z_angsrc, z_angdst), axis=-1)) + Ne3(
            jnp.concatenate((theta, z_angdst, z_angsrc), axis=-1)
        )

        embedding = jax.ops.segment_sum(
            Gijk * theta, graph["central_atom"], species.shape[0]
        )

        if self.concatenate_species:
            embedding = jnp.concatenate((onehot, embedding), axis=-1)

        if self.embedding_key is None:
            return embedding
        return {**inputs, self.embedding_key: embedding}
Deep Potential embedding with angle information
FID : DEEPPOT_E3
Reference
L. Zhang, J. Han, H. Wang, W. A. Saidi, R. Car, Weinan E, End-to-end Symmetry Preserving Inter-atomic Potential Energy Model for Finite and Extended Systems, Conference on Neural Information Processing Systems (NeurIPS), 2018, https://doi.org/10.48550/arXiv.1805.09003
The key to use for the output embedding in the returned dictionary.
The species encoding parameters. See fennol.models.misc.encodings.SpeciesEncoding
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.