fennol.models.embeddings.hipnn
import jax
import jax.numpy as jnp
import flax.linen as nn
from ...utils.spherical_harmonics import generate_spherical_harmonics, CG_SO3
from ..misc.encodings import SpeciesEncoding, RadialBasis
import dataclasses
import numpy as np
from typing import Dict, Union, Callable, ClassVar
from ...utils.activations import activation_from_str, tssr3
from ..misc.nets import FullyConnectedNet


class HIPNNEmbedding(nn.Module):
    """Hierarchically Interacting Particle Neural Network

    ### Reference
    Adapted from N. Lubbers, J. S. Smith and K. Barros, "Hierarchical modeling of
    molecular energies using a deep neural network",
    J. Chem. Phys. 148, 241715 (2018) (https://doi.org/10.1063/1.5011181)

    """

    _graphs_properties: Dict
    dim: int = 80
    """The dimension of the embedding."""
    n_onsite: int = 3
    """The number of onsite layers per interaction layer."""
    nlayers: int = 2
    """The number of interaction layers."""
    lmax: int = 0
    """The maximum degree of the spherical harmonics."""
    n_message: int = 0
    """The number of layers for the message NN."""
    activation: Union[Callable, str] = "silu"
    """The activation function."""
    graph_key: str = "graph"
    """The key for the graph input."""
    embedding_key: str = "embedding"
    """The key for the embedding output."""
    species_encoding: dict = dataclasses.field(default_factory=dict)
    """The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`."""
    radial_basis: dict = dataclasses.field(default_factory=lambda: {"dim": 20})
    """The radial basis parameters. See `fennol.models.misc.encodings.RadialBasis`."""
    keep_all_layers: bool = True
    """Whether to keep embeddings from each layer in the output."""
    graph_l_key: str = "graph"
    """The key for the graph input for the spherical harmonics."""

    FID: ClassVar[str] = "HIPNN"

    @nn.compact
    def __call__(self, inputs):
        """Compute HIP-NN atomic embeddings.

        Args:
            inputs: dictionary holding ``"species"`` (must be 1D) and the graph
                ``inputs[self.graph_key]`` with ``"edge_src"``, ``"edge_dst"``,
                ``"distances"`` and ``"switch"``. When ``lmax > 0``,
                ``inputs[self.graph_l_key]`` must also provide ``"vec"``,
                ``"distances"``, ``"switch"``, ``"edge_src"`` and, when it differs
                from ``graph_key``, ``"filter_indices"`` selecting its edges among
                the edges of ``graph_key``.

        Returns:
            A shallow copy of ``inputs`` with the final embedding stored under
            ``self.embedding_key`` and, if ``keep_all_layers`` is True, the
            per-layer embeddings stacked along axis 1 under
            ``self.embedding_key + "_layers"``.
        """
        species = inputs["species"]
        assert (
            len(species.shape) == 1
        ), "Species must be a 1D array (batches must be flattened)"

        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]

        cutoff = self._graphs_properties[self.graph_key]["cutoff"]

        zi = SpeciesEncoding(**self.species_encoding, name="SpeciesEncoding")(species)

        act = (
            activation_from_str(self.activation)
            if isinstance(self.activation, str)
            else self.activation
        )

        distances = graph["distances"]
        if self.lmax > 0:
            # graph_l is either graph_key itself or a subgraph filtered from it.
            # Decide this from the keys, not from the mere presence of
            # "parent_graph": if graph_key is itself a filtered graph, graph_l
            # would otherwise be wrongly treated as filtered w.r.t. graph_key.
            filtered_l = self.graph_l_key != self.graph_key
            if filtered_l:
                parent_graph = self._graphs_properties[self.graph_l_key].get(
                    "parent_graph"
                )
                assert (
                    parent_graph == self.graph_key
                ), f"graph_l_key={self.graph_l_key} must be a subgraph of graph_key={self.graph_key}"

            graph_l = inputs[self.graph_l_key]
            Yij = generate_spherical_harmonics(lmax=self.lmax, normalize=False)(
                graph_l["vec"] / graph_l["distances"][:, None]
            )[:, None, :]
            if filtered_l:
                # The l=0 messages are accumulated on the full graph (see below),
                # so only the l>0 harmonics are kept for the subgraph.
                Yij = Yij[:, :, 1:]

        switch = graph["switch"][:, None]
        zis = []
        for layer in range(self.nlayers):
            ### interaction layer
            s = RadialBasis(
                **{
                    "basis": "gaussian_rinv",
                    "end": cutoff,
                    **self.radial_basis,
                    "name": f"RadialBasis_{layer}",
                }
            )(distances)

            zself = nn.Dense(self.dim, name=f"self_int_{layer}", use_bias=True)(zi)
            if self.n_message == 0:
                # Bilinear message: m_ij[l] = sum_{j,k} s_ij[j] * z_dst[k] * V[j,k,l]
                V = self.param(
                    f"V_{layer}",
                    jax.nn.initializers.glorot_normal(batch_axis=0),
                    (s.shape[-1], zi.shape[-1], self.dim),
                )
                mij = jnp.einsum("...j,...k,jkl->...l", s, zi[edge_dst], V)
            else:
                mij = FullyConnectedNet(
                    [self.dim] * (self.n_message + 1),
                    activation=act,
                    name=f"mij_{layer}",
                )(jnp.concatenate([zi[edge_dst], s], axis=-1))

            if self.lmax == 0:
                zi = act(zself.at[edge_src].add(mij * switch, mode="drop"))
            else:
                if filtered_l:
                    # Scalar (l=0) messages use every edge of the full graph;
                    # the l>0 messages are restricted to the subgraph edges.
                    zself = zself.at[edge_src].add(mij * switch, mode="drop")
                    filter_indices = graph_l["filter_indices"]
                    mij = mij[filter_indices]

                mij = mij * graph_l["switch"][:, None]
                Mij = mij[:, :, None] * Yij
                Mi = jax.ops.segment_sum(Mij, graph_l["edge_src"], zi.shape[0])

                if filtered_l:
                    zint = jnp.zeros_like(zself)
                else:
                    # Take the l=0 component as the scalar interaction, keep l>0.
                    zint = jax.lax.index_in_dim(Mi, 0, axis=-1, keepdims=False)
                    Mi = Mi[:, :, 1:]

                ts = self.param(
                    f"ts_{layer}",
                    lambda key, shape: jax.random.normal(key, shape, dtype=zint.dtype),
                    (self.lmax, 3),
                )

                for l in range(1, self.lmax + 1):
                    # Components of degree l start at l**2 - 1 once l=0 is removed.
                    Ml = jax.lax.dynamic_slice_in_dim(
                        Mi, start_index=l**2 - 1, slice_size=2 * l + 1, axis=-1
                    )
                    # ts has one row per degree l = 1..lmax, i.e. row l - 1.
                    # (Indexing with l skipped row 0 and ran past the last row,
                    # which JAX silently clamps.)
                    zint = zint + ts[l - 1, 2] * tssr3(
                        nn.Dense(self.dim, name=f"linear_{layer}_l{l}")(
                            jnp.sum(Ml**2, axis=-1)
                        )
                    )
                zi = act(zself + zint)

            ### onsite layers (residual updates)
            for j in range(self.n_onsite):
                zi = zi + FullyConnectedNet(
                    (self.dim, self.dim),
                    activation=act,
                    use_bias=True,
                    name=f"onsite_{layer}_{j}",
                )(zi)
            if self.keep_all_layers:
                zis.append(zi)

        output = {
            **inputs,
            self.embedding_key: zi,
        }
        if self.keep_all_layers:
            output[self.embedding_key + "_layers"] = jnp.stack(zis, axis=1)
        return output
class HIPNNEmbedding(nn.Module):
    """Hierarchically Interacting Particle Neural Network

    ### Reference
    Adapted from N. Lubbers, J. S. Smith and K. Barros, "Hierarchical modeling of
    molecular energies using a deep neural network",
    J. Chem. Phys. 148, 241715 (2018) (https://doi.org/10.1063/1.5011181)

    """

    _graphs_properties: Dict
    dim: int = 80
    """The dimension of the embedding."""
    n_onsite: int = 3
    """The number of onsite layers per interaction layer."""
    nlayers: int = 2
    """The number of interaction layers."""
    lmax: int = 0
    """The maximum degree of the spherical harmonics."""
    n_message: int = 0
    """The number of layers for the message NN."""
    activation: Union[Callable, str] = "silu"
    """The activation function."""
    graph_key: str = "graph"
    """The key for the graph input."""
    embedding_key: str = "embedding"
    """The key for the embedding output."""
    species_encoding: dict = dataclasses.field(default_factory=dict)
    """The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`."""
    radial_basis: dict = dataclasses.field(default_factory=lambda: {"dim": 20})
    """The radial basis parameters. See `fennol.models.misc.encodings.RadialBasis`."""
    keep_all_layers: bool = True
    """Whether to keep embeddings from each layer in the output."""
    graph_l_key: str = "graph"
    """The key for the graph input for the spherical harmonics."""

    FID: ClassVar[str] = "HIPNN"

    @nn.compact
    def __call__(self, inputs):
        """Compute HIP-NN atomic embeddings.

        Args:
            inputs: dictionary holding ``"species"`` (must be 1D) and the graph
                ``inputs[self.graph_key]`` with ``"edge_src"``, ``"edge_dst"``,
                ``"distances"`` and ``"switch"``. When ``lmax > 0``,
                ``inputs[self.graph_l_key]`` must also provide ``"vec"``,
                ``"distances"``, ``"switch"``, ``"edge_src"`` and, when it differs
                from ``graph_key``, ``"filter_indices"`` selecting its edges among
                the edges of ``graph_key``.

        Returns:
            A shallow copy of ``inputs`` with the final embedding stored under
            ``self.embedding_key`` and, if ``keep_all_layers`` is True, the
            per-layer embeddings stacked along axis 1 under
            ``self.embedding_key + "_layers"``.
        """
        species = inputs["species"]
        assert (
            len(species.shape) == 1
        ), "Species must be a 1D array (batches must be flattened)"

        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]

        cutoff = self._graphs_properties[self.graph_key]["cutoff"]

        zi = SpeciesEncoding(**self.species_encoding, name="SpeciesEncoding")(species)

        act = (
            activation_from_str(self.activation)
            if isinstance(self.activation, str)
            else self.activation
        )

        distances = graph["distances"]
        if self.lmax > 0:
            # graph_l is either graph_key itself or a subgraph filtered from it.
            # Decide this from the keys, not from the mere presence of
            # "parent_graph": if graph_key is itself a filtered graph, graph_l
            # would otherwise be wrongly treated as filtered w.r.t. graph_key.
            filtered_l = self.graph_l_key != self.graph_key
            if filtered_l:
                parent_graph = self._graphs_properties[self.graph_l_key].get(
                    "parent_graph"
                )
                assert (
                    parent_graph == self.graph_key
                ), f"graph_l_key={self.graph_l_key} must be a subgraph of graph_key={self.graph_key}"

            graph_l = inputs[self.graph_l_key]
            Yij = generate_spherical_harmonics(lmax=self.lmax, normalize=False)(
                graph_l["vec"] / graph_l["distances"][:, None]
            )[:, None, :]
            if filtered_l:
                # The l=0 messages are accumulated on the full graph (see below),
                # so only the l>0 harmonics are kept for the subgraph.
                Yij = Yij[:, :, 1:]

        switch = graph["switch"][:, None]
        zis = []
        for layer in range(self.nlayers):
            ### interaction layer
            s = RadialBasis(
                **{
                    "basis": "gaussian_rinv",
                    "end": cutoff,
                    **self.radial_basis,
                    "name": f"RadialBasis_{layer}",
                }
            )(distances)

            zself = nn.Dense(self.dim, name=f"self_int_{layer}", use_bias=True)(zi)
            if self.n_message == 0:
                # Bilinear message: m_ij[l] = sum_{j,k} s_ij[j] * z_dst[k] * V[j,k,l]
                V = self.param(
                    f"V_{layer}",
                    jax.nn.initializers.glorot_normal(batch_axis=0),
                    (s.shape[-1], zi.shape[-1], self.dim),
                )
                mij = jnp.einsum("...j,...k,jkl->...l", s, zi[edge_dst], V)
            else:
                mij = FullyConnectedNet(
                    [self.dim] * (self.n_message + 1),
                    activation=act,
                    name=f"mij_{layer}",
                )(jnp.concatenate([zi[edge_dst], s], axis=-1))

            if self.lmax == 0:
                zi = act(zself.at[edge_src].add(mij * switch, mode="drop"))
            else:
                if filtered_l:
                    # Scalar (l=0) messages use every edge of the full graph;
                    # the l>0 messages are restricted to the subgraph edges.
                    zself = zself.at[edge_src].add(mij * switch, mode="drop")
                    filter_indices = graph_l["filter_indices"]
                    mij = mij[filter_indices]

                mij = mij * graph_l["switch"][:, None]
                Mij = mij[:, :, None] * Yij
                Mi = jax.ops.segment_sum(Mij, graph_l["edge_src"], zi.shape[0])

                if filtered_l:
                    zint = jnp.zeros_like(zself)
                else:
                    # Take the l=0 component as the scalar interaction, keep l>0.
                    zint = jax.lax.index_in_dim(Mi, 0, axis=-1, keepdims=False)
                    Mi = Mi[:, :, 1:]

                ts = self.param(
                    f"ts_{layer}",
                    lambda key, shape: jax.random.normal(key, shape, dtype=zint.dtype),
                    (self.lmax, 3),
                )

                for l in range(1, self.lmax + 1):
                    # Components of degree l start at l**2 - 1 once l=0 is removed.
                    Ml = jax.lax.dynamic_slice_in_dim(
                        Mi, start_index=l**2 - 1, slice_size=2 * l + 1, axis=-1
                    )
                    # ts has one row per degree l = 1..lmax, i.e. row l - 1.
                    # (Indexing with l skipped row 0 and ran past the last row,
                    # which JAX silently clamps.)
                    zint = zint + ts[l - 1, 2] * tssr3(
                        nn.Dense(self.dim, name=f"linear_{layer}_l{l}")(
                            jnp.sum(Ml**2, axis=-1)
                        )
                    )
                zi = act(zself + zint)

            ### onsite layers (residual updates)
            for j in range(self.n_onsite):
                zi = zi + FullyConnectedNet(
                    (self.dim, self.dim),
                    activation=act,
                    use_bias=True,
                    name=f"onsite_{layer}_{j}",
                )(zi)
            if self.keep_all_layers:
                zis.append(zi)

        output = {
            **inputs,
            self.embedding_key: zi,
        }
        if self.keep_all_layers:
            output[self.embedding_key + "_layers"] = jnp.stack(zis, axis=1)
        return output
Hierarchically Interacting Particle Neural Network
Reference
Adapted from N. Lubbers, J. S. Smith and K. Barros, "Hierarchical modeling of molecular energies using a deep neural network", J. Chem. Phys. 148, 241715 (2018) (https://doi.org/10.1063/1.5011181)
The species encoding parameters. See fennol.models.misc.encodings.SpeciesEncoding.
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying Python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator to ensure that the appropriate getter/setter logic applies in subclasses, even after various dataclass transforms.