fennol.models.embeddings.painn
import jax
import jax.numpy as jnp
import flax.linen as nn
from ...utils.spherical_harmonics import generate_spherical_harmonics, CG_SO3
from ..misc.encodings import SpeciesEncoding, RadialBasis
import dataclasses
import numpy as np
from typing import Dict, Union, Callable, Sequence, Optional, ClassVar
from ...utils.activations import activation_from_str, tssr3
from ..misc.nets import FullyConnectedNet


class PAINNEmbedding(nn.Module):
    """Polarizable atom interaction neural network (PaiNN) embedding.

    Builds per-atom scalar embeddings of size `dim` and vector embeddings of
    shape (natoms, 3, nchannels) by alternating message and update steps over
    the graph stored under `graph_key`.

    ### Reference
    K. T. Schütt, O. T. Unke, M. Gastegger. Equivariant message passing for the
    prediction of tensorial properties and molecular spectra. Proceedings of the
    38th International Conference on Machine Learning (ICML), PMLR 139 (2021)
    https://arxiv.org/abs/2102.03150

    PaiNN builds on SchNet:
    K. T. Schütt, P.-J. Kindermans, H. E. Sauceda, S. Chmiela, A. Tkatchenko,
    K.-R. Müller. SchNet - a deep learning architecture for molecules and
    materials. The Journal of Chemical Physics 148(24), 241722 (2018)
    https://doi.org/10.1063/1.5019779

    """

    _graphs_properties: Dict
    dim: int = 128
    """ The dimension of the embedding. """
    nlayers: int = 3
    """ The number of interaction layers. """
    nchannels: Optional[int] = None
    """ The number of equivariant channels. If None, it is set to dim. """
    message_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
    """ The hidden layers for the message network."""
    update_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
    """ The hidden layers for the update network."""
    activation: Union[Callable, str] = "silu"
    """ The activation function."""
    graph_key: str = "graph"
    """ The key for the graph input."""
    embedding_key: str = "embedding"
    """ The key for the embedding output."""
    tensor_embedding_key: str = "embedding_vectors"
    """ The key for the tensor embedding output."""
    species_encoding: dict = dataclasses.field(default_factory=dict)
    """ The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`."""
    radial_basis: dict = dataclasses.field(default_factory=dict)
    """ The radial basis function parameters. See `fennol.models.misc.encodings.RadialBasis`."""
    keep_all_layers: bool = False
    """ Whether to keep the embedding from each layer in the output."""

    FID: ClassVar[str] = "PAINN"

    @nn.compact
    def __call__(self, inputs):
        """Compute scalar and vector atomic embeddings.

        Args:
            inputs: dict containing ``"species"`` (a 1D array; batches must be
                flattened) and ``inputs[self.graph_key]``, a graph dict with
                ``"edge_src"``, ``"edge_dst"``, ``"distances"``, ``"vec"`` and
                ``"switch"`` entries.

        Returns:
            A new dict with every entry of ``inputs`` plus:
            - ``self.embedding_key``: scalar embeddings, shape (natoms, dim);
            - ``self.tensor_embedding_key``: vector embeddings, shape
              (natoms, 3, nchannels);
            - ``self.embedding_key + "_layers"`` (only when ``keep_all_layers``):
              scalar embeddings after each layer, shape (natoms, nlayers, dim).
        """
        species = inputs["species"]
        assert (
            len(species.shape) == 1
        ), "Species must be a 1D array (batches must be flattened)"

        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]

        cutoff = self._graphs_properties[self.graph_key]["cutoff"]

        onehot = SpeciesEncoding(**self.species_encoding, name="SpeciesEncoding")(
            species
        )
        xi = nn.Dense(self.dim, name="species_linear", use_bias=True)(onehot)

        # Resolve the channel count once; the field itself may be None (default).
        nchannels = self.nchannels if self.nchannels is not None else self.dim

        distances = graph["distances"]
        switch = graph["switch"][:, None]
        # Unit edge vectors with a trailing axis so they broadcast over channels.
        dirij = (graph["vec"] / distances[:, None])[:, :, None]
        # Vector features start at zero and are built up layer by layer.
        Vi = jnp.zeros((xi.shape[0], 3, nchannels), dtype=xi.dtype)

        radial_basis = RadialBasis(
            **{
                **self.radial_basis,
                "end": cutoff,
                "name": "RadialBasis",
            }
        )(distances)

        xis = []
        for layer in range(self.nlayers):
            # ---- message step ----
            # Per-atom filter later split into (dim | nchannels | nchannels) parts.
            phi = FullyConnectedNet(
                [*self.message_hidden, self.dim + 2 * nchannels],
                activation=self.activation,
                name=f"message_{layer}",
                use_bias=True,
            )(xi)
            # Radial filter, multiplied by the graph's switching values.
            w = (
                nn.Dense(
                    self.dim + 2 * nchannels,
                    name=f"radial_linear_{layer}",
                    use_bias=True,
                )(radial_basis)
                * switch
            )
            # Features of the edge_dst atom, modulated per edge.
            dxij, hvv, hvs = jnp.split(
                phi[edge_dst] * w, [self.dim, self.dim + nchannels], axis=-1
            )

            dvij = dirij * hvs[:, None, :]
            if layer > 0:
                # Vi is still all zeros in the first layer, so the term is skipped there.
                dvij = dvij + Vi[edge_dst] * hvv[:, None, :]

            # Sum edge messages onto their edge_src atom, in residual form.
            v_message = Vi + jax.ops.segment_sum(dvij, edge_src, Vi.shape[0])
            x_message = xi + jax.ops.segment_sum(dxij, edge_src, xi.shape[0])

            # ---- update step ----
            # Two bias-free channel mixings (u, v) of the vector features.
            # Must use the resolved `nchannels`: `self.nchannels` may be None.
            u, v = jnp.split(
                nn.Dense(2 * nchannels, use_bias=False, name=f"UV_{layer}")(
                    v_message
                ),
                2,
                axis=-1,
            )

            # Per-channel <u, v> and squared norm of v, summed over the spatial axis.
            scals = (u * v).sum(axis=1)
            # NOTE(review): tssr3 is presumably a smooth sqrt-like map; see utils.activations.
            norms = tssr3((v**2).sum(axis=1))

            A = FullyConnectedNet(
                [*self.update_hidden, self.dim + 2 * nchannels],
                activation=self.activation,
                name=f"update_{layer}",
                use_bias=True,
            )(jnp.concatenate((x_message, norms), axis=-1))

            ass, asv, avv = jnp.split(
                A,
                [self.dim, self.dim + nchannels],
                axis=-1,
            )

            # NOTE(review): the residuals below start from the pre-message Vi / xi,
            # so v_message / x_message only enter through u, norms and A. The PaiNN
            # paper also adds the message residual — confirm this is intended.
            Vi = Vi + u * avv[:, None, :]
            if self.dim != nchannels:
                # Project the nchannels-sized scalar term back to `dim`.
                dxi = nn.Dense(self.dim, name=f"resize_{layer}", use_bias=False)(
                    scals * asv
                )
            else:
                dxi = scals * asv

            xi = xi + ass + dxi

            if self.keep_all_layers:
                xis.append(xi)

        output = {
            **inputs,
            self.embedding_key: xi,
            self.tensor_embedding_key: Vi,
        }
        if self.keep_all_layers:
            output[self.embedding_key + "_layers"] = jnp.stack(xis, axis=1)
        return output
class PAINNEmbedding(nn.Module):
    """Polarizable atom interaction neural network (PaiNN) embedding.

    Builds per-atom scalar embeddings of size `dim` and vector embeddings of
    shape (natoms, 3, nchannels) by alternating message and update steps over
    the graph stored under `graph_key`.

    ### Reference
    K. T. Schütt, O. T. Unke, M. Gastegger. Equivariant message passing for the
    prediction of tensorial properties and molecular spectra. Proceedings of the
    38th International Conference on Machine Learning (ICML), PMLR 139 (2021)
    https://arxiv.org/abs/2102.03150

    PaiNN builds on SchNet:
    K. T. Schütt, P.-J. Kindermans, H. E. Sauceda, S. Chmiela, A. Tkatchenko,
    K.-R. Müller. SchNet - a deep learning architecture for molecules and
    materials. The Journal of Chemical Physics 148(24), 241722 (2018)
    https://doi.org/10.1063/1.5019779

    """

    _graphs_properties: Dict
    dim: int = 128
    """ The dimension of the embedding. """
    nlayers: int = 3
    """ The number of interaction layers. """
    nchannels: Optional[int] = None
    """ The number of equivariant channels. If None, it is set to dim. """
    message_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
    """ The hidden layers for the message network."""
    update_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [128])
    """ The hidden layers for the update network."""
    activation: Union[Callable, str] = "silu"
    """ The activation function."""
    graph_key: str = "graph"
    """ The key for the graph input."""
    embedding_key: str = "embedding"
    """ The key for the embedding output."""
    tensor_embedding_key: str = "embedding_vectors"
    """ The key for the tensor embedding output."""
    species_encoding: dict = dataclasses.field(default_factory=dict)
    """ The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`."""
    radial_basis: dict = dataclasses.field(default_factory=dict)
    """ The radial basis function parameters. See `fennol.models.misc.encodings.RadialBasis`."""
    keep_all_layers: bool = False
    """ Whether to keep the embedding from each layer in the output."""

    FID: ClassVar[str] = "PAINN"

    @nn.compact
    def __call__(self, inputs):
        """Compute scalar and vector atomic embeddings.

        Args:
            inputs: dict containing ``"species"`` (a 1D array; batches must be
                flattened) and ``inputs[self.graph_key]``, a graph dict with
                ``"edge_src"``, ``"edge_dst"``, ``"distances"``, ``"vec"`` and
                ``"switch"`` entries.

        Returns:
            A new dict with every entry of ``inputs`` plus:
            - ``self.embedding_key``: scalar embeddings, shape (natoms, dim);
            - ``self.tensor_embedding_key``: vector embeddings, shape
              (natoms, 3, nchannels);
            - ``self.embedding_key + "_layers"`` (only when ``keep_all_layers``):
              scalar embeddings after each layer, shape (natoms, nlayers, dim).
        """
        species = inputs["species"]
        assert (
            len(species.shape) == 1
        ), "Species must be a 1D array (batches must be flattened)"

        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]

        cutoff = self._graphs_properties[self.graph_key]["cutoff"]

        onehot = SpeciesEncoding(**self.species_encoding, name="SpeciesEncoding")(
            species
        )
        xi = nn.Dense(self.dim, name="species_linear", use_bias=True)(onehot)

        # Resolve the channel count once; the field itself may be None (default).
        nchannels = self.nchannels if self.nchannels is not None else self.dim

        distances = graph["distances"]
        switch = graph["switch"][:, None]
        # Unit edge vectors with a trailing axis so they broadcast over channels.
        dirij = (graph["vec"] / distances[:, None])[:, :, None]
        # Vector features start at zero and are built up layer by layer.
        Vi = jnp.zeros((xi.shape[0], 3, nchannels), dtype=xi.dtype)

        radial_basis = RadialBasis(
            **{
                **self.radial_basis,
                "end": cutoff,
                "name": "RadialBasis",
            }
        )(distances)

        xis = []
        for layer in range(self.nlayers):
            # ---- message step ----
            # Per-atom filter later split into (dim | nchannels | nchannels) parts.
            phi = FullyConnectedNet(
                [*self.message_hidden, self.dim + 2 * nchannels],
                activation=self.activation,
                name=f"message_{layer}",
                use_bias=True,
            )(xi)
            # Radial filter, multiplied by the graph's switching values.
            w = (
                nn.Dense(
                    self.dim + 2 * nchannels,
                    name=f"radial_linear_{layer}",
                    use_bias=True,
                )(radial_basis)
                * switch
            )
            # Features of the edge_dst atom, modulated per edge.
            dxij, hvv, hvs = jnp.split(
                phi[edge_dst] * w, [self.dim, self.dim + nchannels], axis=-1
            )

            dvij = dirij * hvs[:, None, :]
            if layer > 0:
                # Vi is still all zeros in the first layer, so the term is skipped there.
                dvij = dvij + Vi[edge_dst] * hvv[:, None, :]

            # Sum edge messages onto their edge_src atom, in residual form.
            v_message = Vi + jax.ops.segment_sum(dvij, edge_src, Vi.shape[0])
            x_message = xi + jax.ops.segment_sum(dxij, edge_src, xi.shape[0])

            # ---- update step ----
            # Two bias-free channel mixings (u, v) of the vector features.
            # Must use the resolved `nchannels`: `self.nchannels` may be None.
            u, v = jnp.split(
                nn.Dense(2 * nchannels, use_bias=False, name=f"UV_{layer}")(
                    v_message
                ),
                2,
                axis=-1,
            )

            # Per-channel <u, v> and squared norm of v, summed over the spatial axis.
            scals = (u * v).sum(axis=1)
            # NOTE(review): tssr3 is presumably a smooth sqrt-like map; see utils.activations.
            norms = tssr3((v**2).sum(axis=1))

            A = FullyConnectedNet(
                [*self.update_hidden, self.dim + 2 * nchannels],
                activation=self.activation,
                name=f"update_{layer}",
                use_bias=True,
            )(jnp.concatenate((x_message, norms), axis=-1))

            ass, asv, avv = jnp.split(
                A,
                [self.dim, self.dim + nchannels],
                axis=-1,
            )

            # NOTE(review): the residuals below start from the pre-message Vi / xi,
            # so v_message / x_message only enter through u, norms and A. The PaiNN
            # paper also adds the message residual — confirm this is intended.
            Vi = Vi + u * avv[:, None, :]
            if self.dim != nchannels:
                # Project the nchannels-sized scalar term back to `dim`.
                dxi = nn.Dense(self.dim, name=f"resize_{layer}", use_bias=False)(
                    scals * asv
                )
            else:
                dxi = scals * asv

            xi = xi + ass + dxi

            if self.keep_all_layers:
                xis.append(xi)

        output = {
            **inputs,
            self.embedding_key: xi,
            self.tensor_embedding_key: Vi,
        }
        if self.keep_all_layers:
            output[self.embedding_key + "_layers"] = jnp.stack(xis, axis=1)
        return output
Polarizable atom interaction neural network (PaiNN).
Reference
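K. T. Schütt, O. T. Unke, M. Gastegger. Equivariant message passing for the prediction of tensorial properties and molecular spectra. Proceedings of the 38th International Conference on Machine Learning (ICML), PMLR 139 (2021) https://arxiv.org/abs/2102.03150
PaiNN builds on SchNet: K. T. Schütt, P.-J. Kindermans, H. E. Sauceda, S. Chmiela, A. Tkatchenko, K.-R. Müller. SchNet - a deep learning architecture for molecules and materials. The Journal of Chemical Physics 148(24), 241722 (2018) https://doi.org/10.1063/1.5019779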
The species encoding parameters. See fennol.models.misc.encodings.SpeciesEncoding.
The radial basis function parameters. See fennol.models.misc.encodings.RadialBasis.
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying Python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator to ensure that the appropriate getter/setter logic applies in subclasses, even after various dataclass transforms.