fennol.models.embeddings.newtonnet
import jax
import jax.numpy as jnp
import flax.linen as nn
from ...utils.spherical_harmonics import generate_spherical_harmonics, CG_SO3
from ..misc.encodings import SpeciesEncoding, RadialBasis
import dataclasses
import numpy as np
from typing import Dict, Union, Callable, Sequence, Optional, ClassVar
from ...utils.activations import activation_from_str
from ..misc.nets import FullyConnectedNet


class NewtonNetEmbedding(nn.Module):
    """Newtonian message passing network.

    ### Reference
    Haghighatlari et al., NewtonNet: a Newtonian message passing network for
    deep learning of interatomic potentials and forces
    https://doi.org/10.1039/D2DD00008C
    """

    _graphs_properties: Dict
    dim: int = 128
    """The dimension of the embedding."""
    nlayers: int = 3
    """The number of interaction layers."""
    nchannels: Optional[int] = None
    """The number of vector channels. If None, it is set to dim."""
    embedding_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
    """The hidden layers for the embedding networks."""
    latent_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
    """The hidden layers for the latent update network."""
    activation: Union[Callable, str] = "silu"
    """The activation function."""
    graph_key: str = "graph"
    """The key for the graph input."""
    embedding_key: str = "embedding"
    """The key for the output embedding."""
    species_encoding: dict = dataclasses.field(default_factory=dict)
    """The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`."""
    radial_basis: dict = dataclasses.field(default_factory=dict)
    """The radial basis parameters. See `fennol.models.misc.encodings.RadialBasis`."""
    keep_all_layers: bool = False
    """Whether to keep embeddings from each layer in the output."""

    FID: ClassVar[str] = "NEWTONNET"

    @nn.compact
    def __call__(self, inputs):
        """Compute NewtonNet atomic embeddings.

        Args:
            inputs: dict carrying at least ``"species"`` (1D int array of
                atomic species, batches flattened) and the graph under
                ``self.graph_key`` with keys ``edge_src``, ``edge_dst``,
                ``distances``, ``vec`` and ``switch``.

        Returns:
            A copy of ``inputs`` with the final scalar embedding stored at
            ``self.embedding_key`` and, if ``keep_all_layers`` is set, the
            per-layer embeddings stacked at ``self.embedding_key + "_layers"``.
        """
        species = inputs["species"]
        assert (
            len(species.shape) == 1
        ), "Species must be a 1D array (batches must be flattened)"

        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]

        cutoff = self._graphs_properties[self.graph_key]["cutoff"]

        # Initial scalar features: linear projection of the species encoding.
        onehot = SpeciesEncoding(**self.species_encoding, name="SpeciesEncoding")(
            species
        )
        xi = nn.Dense(self.dim, name="species_linear", use_bias=True)(onehot)

        nchannels = self.nchannels if self.nchannels is not None else self.dim

        distances = graph["distances"]
        switch = graph["switch"][:, None]
        # Switch-weighted unit vectors along each edge.
        dirij = graph["vec"] / distances[:, None] * switch

        radial_basis = RadialBasis(
            **{
                **self.radial_basis,
                "end": cutoff,
                # plain string: no placeholders, f-string was unnecessary
                "name": "RadialBasis",
            }
        )(distances)

        if self.keep_all_layers:
            xis = []
        for layer in range(self.nlayers):
            # --- scalar message passing ---
            ai = FullyConnectedNet(
                [*self.embedding_hidden, self.dim],
                activation=self.activation,
                name=f"phi_a_{layer}",
                use_bias=True,
            )(xi)
            Dij = nn.Dense(self.dim, name=f"radial_linear_{layer}", use_bias=True)(
                radial_basis
            )
            # Edge message: filtered product of sender/receiver features.
            mij = ai[edge_src] * ai[edge_dst] * Dij * switch

            mi = jax.ops.segment_sum(mij, edge_src, xi.shape[0])
            xi = xi + mi

            # --- vector ("force") channels ---
            # Scalar magnitude along each (switch-weighted) edge direction.
            Fij = (
                FullyConnectedNet(
                    [*self.embedding_hidden, 1],
                    activation=self.activation,
                    name=f"phi_F_{layer}",
                    use_bias=True,
                )(mij)
                * dirij
            )

            # Per-channel vector contributions, shape (edges, nchannels, 3).
            fij = (
                FullyConnectedNet(
                    [*self.embedding_hidden, nchannels],
                    activation=self.activation,
                    name=f"phi_f_{layer}",
                    use_bias=True,
                )(mij)[:, :, None]
                * Fij[:, None, :]
            )

            # Accumulate atomic "forces" across layers.
            if layer == 0:
                fi = jax.ops.segment_sum(fij, edge_src, xi.shape[0])
            else:
                fi = fi + jax.ops.segment_sum(fij, edge_src, xi.shape[0])

            # --- displacement update ---
            deltai = (
                FullyConnectedNet(
                    [*self.embedding_hidden, nchannels],
                    activation=self.activation,
                    name=f"phi_R_{layer}",
                    use_bias=True,
                )(xi)[:, :, None]
                * fi
            )
            if layer == 0:
                di = deltai
            else:
                phi_rij = FullyConnectedNet(
                    [*self.embedding_hidden, nchannels],
                    activation=self.activation,
                    name=f"phi_r_{layer}",
                    use_bias=True,
                )(mij)

                phi_r = jax.ops.segment_sum(phi_rij * switch, edge_src, xi.shape[0])
                # Gated recurrence: mix previous displacement with the new one.
                di = phi_r[:, :, None] * di + deltai

            # --- latent (energy-like) scalar update ---
            # Contraction of force and displacement channels ("work" term).
            scal = jnp.sum(fi * di, axis=-1)
            dui = (
                -FullyConnectedNet(
                    [*self.latent_hidden, nchannels],
                    activation=self.activation,
                    name=f"phi_u_{layer}",
                    use_bias=True,
                )(xi)
                * scal
            )

            # Project back to the embedding dimension if channels differ.
            if nchannels != self.dim:
                dui = nn.Dense(self.dim, name=f"reshape_{layer}", use_bias=False)(dui)

            xi = xi + dui
            if self.keep_all_layers:
                xis.append(xi)

        output = {
            **inputs,
            self.embedding_key: xi,
        }
        if self.keep_all_layers:
            output[self.embedding_key + "_layers"] = jnp.stack(xis, axis=1)
        return output
class NewtonNetEmbedding(nn.Module):
    """Newtonian message passing network.

    Implements the NewtonNet interaction scheme: scalar messages augmented
    with per-channel "force" vectors and displacement recurrences.

    ### Reference
    Haghighatlari et al., NewtonNet: a Newtonian message passing network for
    deep learning of interatomic potentials and forces
    https://doi.org/10.1039/D2DD00008C
    """

    _graphs_properties: Dict
    dim: int = 128
    """The dimension of the embedding."""
    nlayers: int = 3
    """The number of interaction layers."""
    nchannels: Optional[int] = None
    """The number of vector channels. If None, it is set to dim."""
    embedding_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
    """The hidden layers for the embedding networks."""
    latent_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
    """The hidden layers for the latent update network."""
    activation: Union[Callable, str] = "silu"
    """The activation function."""
    graph_key: str = "graph"
    """The key for the graph input."""
    embedding_key: str = "embedding"
    """The key for the output embedding."""
    species_encoding: dict = dataclasses.field(default_factory=dict)
    """The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`."""
    radial_basis: dict = dataclasses.field(default_factory=dict)
    """The radial basis parameters. See `fennol.models.misc.encodings.RadialBasis`."""
    keep_all_layers: bool = False
    """Whether to keep embeddings from each layer in the output."""

    FID: ClassVar[str] = "NEWTONNET"

    @nn.compact
    def __call__(self, inputs):
        """Apply the NewtonNet interaction stack and return enriched inputs."""
        species = inputs["species"]
        assert (
            len(species.shape) == 1
        ), "Species must be a 1D array (batches must be flattened)"

        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]

        # Project the species encoding to the working embedding dimension.
        encoded_species = SpeciesEncoding(
            **self.species_encoding, name="SpeciesEncoding"
        )(species)
        xi = nn.Dense(self.dim, name="species_linear", use_bias=True)(encoded_species)
        natoms = xi.shape[0]

        nchannels = self.dim if self.nchannels is None else self.nchannels

        distances = graph["distances"]
        switch = graph["switch"][:, None]
        # Edge unit vectors, already damped by the cutoff switch.
        dirij = (graph["vec"] / distances[:, None]) * switch

        rb_params = {**self.radial_basis, "end": cutoff, "name": f"RadialBasis"}
        radial_basis = RadialBasis(**rb_params)(distances)

        layer_embeddings = [] if self.keep_all_layers else None
        for layer in range(self.nlayers):
            # Scalar message: symmetric atom filter times a radial filter.
            atom_filter = FullyConnectedNet(
                [*self.embedding_hidden, self.dim],
                activation=self.activation,
                name=f"phi_a_{layer}",
                use_bias=True,
            )(xi)
            radial_filter = nn.Dense(
                self.dim, name=f"radial_linear_{layer}", use_bias=True
            )(radial_basis)
            mij = atom_filter[edge_src] * atom_filter[edge_dst] * radial_filter * switch
            xi = xi + jax.ops.segment_sum(mij, edge_src, natoms)

            # Directional edge term scaled by a learned scalar magnitude.
            Fij = (
                FullyConnectedNet(
                    [*self.embedding_hidden, 1],
                    activation=self.activation,
                    name=f"phi_F_{layer}",
                    use_bias=True,
                )(mij)
                * dirij
            )

            # Expand to vector channels: (edges, nchannels, 3).
            channel_gate = FullyConnectedNet(
                [*self.embedding_hidden, nchannels],
                activation=self.activation,
                name=f"phi_f_{layer}",
                use_bias=True,
            )(mij)
            fij = channel_gate[:, :, None] * Fij[:, None, :]

            # Running sum of per-atom force channels over layers.
            force_contrib = jax.ops.segment_sum(fij, edge_src, natoms)
            fi = force_contrib if layer == 0 else fi + force_contrib

            # Displacement update gated by the current scalar features.
            deltai = (
                FullyConnectedNet(
                    [*self.embedding_hidden, nchannels],
                    activation=self.activation,
                    name=f"phi_R_{layer}",
                    use_bias=True,
                )(xi)[:, :, None]
                * fi
            )
            if layer == 0:
                di = deltai
            else:
                phi_rij = FullyConnectedNet(
                    [*self.embedding_hidden, nchannels],
                    activation=self.activation,
                    name=f"phi_r_{layer}",
                    use_bias=True,
                )(mij)
                phi_r = jax.ops.segment_sum(phi_rij * switch, edge_src, natoms)
                # Recurrence: reweight the previous displacement, add the new.
                di = phi_r[:, :, None] * di + deltai

            # Latent scalar update from the force/displacement contraction.
            work_term = jnp.sum(fi * di, axis=-1)
            dui = (
                -FullyConnectedNet(
                    [*self.latent_hidden, nchannels],
                    activation=self.activation,
                    name=f"phi_u_{layer}",
                    use_bias=True,
                )(xi)
                * work_term
            )
            if nchannels != self.dim:
                dui = nn.Dense(self.dim, name=f"reshape_{layer}", use_bias=False)(dui)

            xi = xi + dui
            if self.keep_all_layers:
                layer_embeddings.append(xi)

        output = dict(inputs)
        output[self.embedding_key] = xi
        if self.keep_all_layers:
            output[self.embedding_key + "_layers"] = jnp.stack(
                layer_embeddings, axis=1
            )
        return output
Newtonian message passing network
Reference
Haghighatlari et al., NewtonNet: a Newtonian message passing network for deep learning of interatomic potentials and forces https://doi.org/10.1039/D2DD00008C
The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`.
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links, which can lead to accidental OOMs in eager mode due to slow garbage collection, as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying Python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to ensure that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.