fennol.models.embeddings.chgnet
import jax
import jax.numpy as jnp
import flax.linen as nn
import dataclasses
import numpy as np
from typing import Any, Dict, Union, Callable, Sequence, Optional, ClassVar

from ...utils.initializers import initializer_from_str
from ..misc.encodings import SpeciesEncoding, RadialBasis
from ..misc.nets import GatedPerceptron


class CHGNetEmbedding(nn.Module):
    """Crystal Hamiltonian Graph Neural Network

    FID: CHGNET

    Builds a per-atom embedding by alternating three updates over `nlayers`
    layers: atoms are updated from edge messages, edges are updated from
    angle messages, and angle features are updated from the refreshed edges.
    After the last atom update the edge/angle updates are skipped because
    only the atom embedding is returned.

    The module reads from `inputs`:

    - `inputs["species"]`: 1D array (asserted).
    - `inputs[graph_key]`: mapping with `edge_src`, `edge_dst`, `distances`
      and `switch`.
    - `inputs[graph_angle_key]` (defaults to `graph_key`): mapping with
      `angles`, `angle_src`, `angle_dst`, `switch`, `central_atom`, plus
      `filter_indices` when the angle graph is a filtered subgraph and
      `distances` when `radial_basis_angle` is set.

    It returns a copy of `inputs` with `embedding_key` added and, when
    `keep_all_layers` is true, `embedding_key + "_layers"` holding the atom
    embedding after each layer stacked along axis 1.

    ### Reference
    Deng, B., Zhong, P., Jun, K. et al. CHGNet as a pretrained universal neural network potential for charge-informed atomistic modelling.
    Nat Mach Intell 5, 1031–1041 (2023). https://doi.org/10.1038/s42256-023-00716-3

    """

    _graphs_properties: Dict
    dim: int = 64
    """The dimension of the embedding."""
    nmax_angle: int = 4
    """The maximum Fourier order for the angle representation."""
    nlayers: int = 3
    """The number of layers."""
    graph_key: str = "graph"
    """The key for the graph input."""
    graph_angle_key: Optional[str] = None
    """The key for the angular graph input."""
    embedding_key: str = "embedding"
    """The key for the embedding output."""
    species_encoding: dict = dataclasses.field(default_factory=dict)
    """The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`"""
    radial_basis: dict = dataclasses.field(default_factory=dict)
    """The radial basis parameters. See `fennol.models.misc.encodings.RadialBasis`"""
    radial_basis_angle: Optional[dict] = None
    """The radial basis parameters for angle embedding.
        See `fennol.models.misc.encodings.RadialBasis`"""
    keep_all_layers: bool = False
    """Whether to keep all layers in the output."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initializer for Dense operations."""

    FID: ClassVar[str] = "CHGNET"

    @nn.compact
    def __call__(self, inputs):
        """Compute the atom embedding and return `inputs` extended with it.

        Raises:
            AssertionError: if `species` is not 1D, if the angle graph is
                neither the main graph nor declared as its child through
                `parent_graph`, or if the angle graph has no `angles` entry.
        """
        species = inputs["species"]
        assert (
            len(species.shape) == 1
        ), "Species must be a 1D array (batches must be flattened)"

        kernel_init = (
            initializer_from_str(self.kernel_init)
            if isinstance(self.kernel_init, str)
            else self.kernel_init
        )

        ##################################################
        # Check that the graph_angle is a subgraph of graph
        graph = inputs[self.graph_key]
        graph_angle_key = (
            self.graph_angle_key if self.graph_angle_key is not None else self.graph_key
        )
        graph_angle = inputs[graph_angle_key]
        angle_graph_props = self._graphs_properties[graph_angle_key]

        # `.get` so that a graph without a declared parent reaches the
        # descriptive assertion below instead of failing with a bare KeyError.
        correct_graph = (
            graph_angle_key == self.graph_key
            or angle_graph_props.get("parent_graph") == self.graph_key
        )
        assert (
            correct_graph
        ), f"graph_angle_key={graph_angle_key} must be a subgraph of graph_key={self.graph_key}"
        assert "angles" in graph_angle, f"Graph {graph_angle_key} must contain angles"
        # check if graph_angle is a filtered graph
        # NOTE(review): when graph_angle_key is unset and the main graph itself
        # declares a parent_graph, this flag is also True and filter_indices
        # presumably index that parent rather than `graph` -- confirm upstream.
        filtered = "parent_graph" in angle_graph_props
        filter_indices = graph_angle["filter_indices"] if filtered else None

        ##################################################
        ### SPECIES ENCODING ###
        zi = SpeciesEncoding(**self.species_encoding, name="SpeciesEncoding")(species)

        vi = nn.Dense(self.dim, name="vi0", use_bias=True, kernel_init=kernel_init)(zi)

        ##################################################
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        distances = graph["distances"]
        switch = graph["switch"][:, None]

        ### COMPUTE RADIAL BASIS ###
        # "end" and "name" are placed last so they override user-supplied values.
        radial_basis = RadialBasis(
            **{
                **self.radial_basis,
                "end": cutoff,
                "name": "RadialBasis",
            }
        )(distances)

        ##################################################
        ### GET ANGLES ###
        angles = graph_angle["angles"][:, None]
        angle_src, angle_dst = graph_angle["angle_src"], graph_angle["angle_dst"]
        switch_angles = graph_angle["switch"][:, None]
        central_atom = graph_angle["central_atom"]

        ### COMPUTE RADIAL BASIS FOR ANGLES ###
        if self.radial_basis_angle is not None:
            # dedicated basis evaluated on the angle graph's own distances
            dangles = graph_angle["distances"]
            radial_basis_angle = (
                RadialBasis(
                    **{
                        **self.radial_basis_angle,
                        "end": angle_graph_props["cutoff"],
                        "name": "RadialBasisAngle",
                    }
                )(dangles)
                * switch_angles
            )

        else:
            # reuse the main radial basis, restricted to the angle-graph edges
            if filtered:
                radial_basis_angle = radial_basis[filter_indices] * switch_angles
            else:
                radial_basis_angle = radial_basis * switch

        radial_basis = radial_basis * switch

        # one projection, split into edge features (eij) and the weights
        # applied to atom messages (eija)
        eij, eija = jnp.split(
            nn.Dense(
                2 * self.dim, use_bias=False, kernel_init=kernel_init, name="eij0"
            )(radial_basis),
            2,
            axis=-1,
        )

        eijb = nn.Dense(self.dim, use_bias=False, kernel_init=kernel_init, name="eijb")(
            radial_basis_angle
        )
        # per-angle weight: product of the weights of the two edges of the angle
        eijkb = eijb[angle_src] * eijb[angle_dst]

        ##################################################
        ### ANGULAR BASIS ###
        # build fourier series for angles
        nangles = jnp.asarray(
            np.arange(self.nmax_angle + 1, dtype=distances.dtype)[None, :]
        )

        # cos(n*theta) for n = 0..nmax_angle, sin(n*theta) for n = 1..nmax_angle
        Ac = jnp.cos(nangles * angles)
        As = jnp.sin(nangles[:, 1:] * angles)
        aijk = nn.Dense(
            self.dim, use_bias=False, kernel_init=kernel_init, name="aijk0"
        )(jnp.concatenate([Ac, As], axis=-1))
        aikj = aijk

        ##################################################
        if self.keep_all_layers:
            vis = []

        ### LOOP OVER LAYERS ###
        for layer in range(self.nlayers):
            # --- atom update: sum weighted edge messages onto edge_src ---
            phiv = GatedPerceptron(
                self.dim,
                use_bias=True,
                kernel_init=kernel_init,
                activation=nn.silu,
                name=f"phiv{layer+1}",
            )(jnp.concatenate([vi[edge_src], vi[edge_dst], eij], axis=-1))

            vi = vi + nn.Dense(
                self.dim, use_bias=True, kernel_init=kernel_init, name=f"vi{layer+1}"
            )(jax.ops.segment_sum(phiv * eija, edge_src, vi.shape[0]))

            if self.keep_all_layers:
                vis.append(vi)

            # edges and angles only feed later atom updates, so the last
            # layer stops here
            if layer == self.nlayers - 1:
                break

            # --- edge update from angle messages ---
            eij_ang = eij[filter_indices] if filtered else eij
            eij_angsrc = eij_ang[angle_src]
            eij_angdst = eij_ang[angle_dst]

            # the same module instance handles both edge orderings of an angle
            phie = GatedPerceptron(
                self.dim,
                use_bias=True,
                kernel_init=kernel_init,
                activation=nn.silu,
                name=f"phie{layer+1}",
            )

            vi_ang = vi[central_atom]
            phie_ijk = phie(
                jnp.concatenate(
                    [eij_angsrc, eij_angdst, aijk, vi_ang],
                    axis=-1,
                )
            )
            phie_ikj = phie(
                jnp.concatenate(
                    [eij_angdst, eij_angsrc, aikj, vi_ang],
                    axis=-1,
                )
            )

            eij_ang = eij_ang + nn.Dense(
                self.dim, use_bias=False, kernel_init=kernel_init, name=f"Le{layer+1}"
            )(
                jax.ops.segment_sum(phie_ikj * eijkb, angle_dst, eij_ang.shape[0])
                + jax.ops.segment_sum(phie_ijk * eijkb, angle_src, eij_ang.shape[0])
            )
            eij_angsrc = eij_ang[angle_src]
            eij_angdst = eij_ang[angle_dst]
            if filtered:
                # write the updated subset back into the full edge array
                eij = eij.at[filter_indices].set(eij_ang)
            else:
                # the angle graph is the main graph: eij_ang covers every
                # edge and filter_indices does not exist
                eij = eij_ang

            # --- angle update from the refreshed edges ---
            phia = GatedPerceptron(
                self.dim,
                use_bias=True,
                kernel_init=kernel_init,
                activation=nn.silu,
                name=f"phia{layer+1}",
            )
            aijk = aijk + phia(
                jnp.concatenate([eij_angsrc, eij_angdst, aijk, vi_ang], axis=-1)
            )
            aikj = aikj + phia(
                jnp.concatenate([eij_angdst, eij_angsrc, aikj, vi_ang], axis=-1)
            )

        output = {
            **inputs,
            self.embedding_key: vi,
        }
        if self.keep_all_layers:
            output[self.embedding_key + "_layers"] = jnp.stack(vis, axis=1)
        return output
class CHGNetEmbedding(nn.Module):
    """Crystal Hamiltonian Graph Neural Network

    FID: CHGNET

    Builds a per-atom embedding by alternating three updates over `nlayers`
    layers: atoms are updated from edge messages, edges are updated from
    angle messages, and angle features are updated from the refreshed edges.
    After the last atom update the edge/angle updates are skipped because
    only the atom embedding is returned.

    The module reads from `inputs`:

    - `inputs["species"]`: 1D array (asserted).
    - `inputs[graph_key]`: mapping with `edge_src`, `edge_dst`, `distances`
      and `switch`.
    - `inputs[graph_angle_key]` (defaults to `graph_key`): mapping with
      `angles`, `angle_src`, `angle_dst`, `switch`, `central_atom`, plus
      `filter_indices` when the angle graph is a filtered subgraph and
      `distances` when `radial_basis_angle` is set.

    It returns a copy of `inputs` with `embedding_key` added and, when
    `keep_all_layers` is true, `embedding_key + "_layers"` holding the atom
    embedding after each layer stacked along axis 1.

    ### Reference
    Deng, B., Zhong, P., Jun, K. et al. CHGNet as a pretrained universal neural network potential for charge-informed atomistic modelling.
    Nat Mach Intell 5, 1031–1041 (2023). https://doi.org/10.1038/s42256-023-00716-3

    """

    _graphs_properties: Dict
    dim: int = 64
    """The dimension of the embedding."""
    nmax_angle: int = 4
    """The maximum Fourier order for the angle representation."""
    nlayers: int = 3
    """The number of layers."""
    graph_key: str = "graph"
    """The key for the graph input."""
    graph_angle_key: Optional[str] = None
    """The key for the angular graph input."""
    embedding_key: str = "embedding"
    """The key for the embedding output."""
    species_encoding: dict = dataclasses.field(default_factory=dict)
    """The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`"""
    radial_basis: dict = dataclasses.field(default_factory=dict)
    """The radial basis parameters. See `fennol.models.misc.encodings.RadialBasis`"""
    radial_basis_angle: Optional[dict] = None
    """The radial basis parameters for angle embedding.
        See `fennol.models.misc.encodings.RadialBasis`"""
    keep_all_layers: bool = False
    """Whether to keep all layers in the output."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initializer for Dense operations."""

    FID: ClassVar[str] = "CHGNET"

    @nn.compact
    def __call__(self, inputs):
        """Compute the atom embedding and return `inputs` extended with it.

        Raises:
            AssertionError: if `species` is not 1D, if the angle graph is
                neither the main graph nor declared as its child through
                `parent_graph`, or if the angle graph has no `angles` entry.
        """
        species = inputs["species"]
        assert (
            len(species.shape) == 1
        ), "Species must be a 1D array (batches must be flattened)"

        kernel_init = (
            initializer_from_str(self.kernel_init)
            if isinstance(self.kernel_init, str)
            else self.kernel_init
        )

        ##################################################
        # Check that the graph_angle is a subgraph of graph
        graph = inputs[self.graph_key]
        graph_angle_key = (
            self.graph_angle_key if self.graph_angle_key is not None else self.graph_key
        )
        graph_angle = inputs[graph_angle_key]
        angle_graph_props = self._graphs_properties[graph_angle_key]

        # `.get` so that a graph without a declared parent reaches the
        # descriptive assertion below instead of failing with a bare KeyError.
        correct_graph = (
            graph_angle_key == self.graph_key
            or angle_graph_props.get("parent_graph") == self.graph_key
        )
        assert (
            correct_graph
        ), f"graph_angle_key={graph_angle_key} must be a subgraph of graph_key={self.graph_key}"
        assert "angles" in graph_angle, f"Graph {graph_angle_key} must contain angles"
        # check if graph_angle is a filtered graph
        # NOTE(review): when graph_angle_key is unset and the main graph itself
        # declares a parent_graph, this flag is also True and filter_indices
        # presumably index that parent rather than `graph` -- confirm upstream.
        filtered = "parent_graph" in angle_graph_props
        filter_indices = graph_angle["filter_indices"] if filtered else None

        ##################################################
        ### SPECIES ENCODING ###
        zi = SpeciesEncoding(**self.species_encoding, name="SpeciesEncoding")(species)

        vi = nn.Dense(self.dim, name="vi0", use_bias=True, kernel_init=kernel_init)(zi)

        ##################################################
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        distances = graph["distances"]
        switch = graph["switch"][:, None]

        ### COMPUTE RADIAL BASIS ###
        # "end" and "name" are placed last so they override user-supplied values.
        radial_basis = RadialBasis(
            **{
                **self.radial_basis,
                "end": cutoff,
                "name": "RadialBasis",
            }
        )(distances)

        ##################################################
        ### GET ANGLES ###
        angles = graph_angle["angles"][:, None]
        angle_src, angle_dst = graph_angle["angle_src"], graph_angle["angle_dst"]
        switch_angles = graph_angle["switch"][:, None]
        central_atom = graph_angle["central_atom"]

        ### COMPUTE RADIAL BASIS FOR ANGLES ###
        if self.radial_basis_angle is not None:
            # dedicated basis evaluated on the angle graph's own distances
            dangles = graph_angle["distances"]
            radial_basis_angle = (
                RadialBasis(
                    **{
                        **self.radial_basis_angle,
                        "end": angle_graph_props["cutoff"],
                        "name": "RadialBasisAngle",
                    }
                )(dangles)
                * switch_angles
            )

        else:
            # reuse the main radial basis, restricted to the angle-graph edges
            if filtered:
                radial_basis_angle = radial_basis[filter_indices] * switch_angles
            else:
                radial_basis_angle = radial_basis * switch

        radial_basis = radial_basis * switch

        # one projection, split into edge features (eij) and the weights
        # applied to atom messages (eija)
        eij, eija = jnp.split(
            nn.Dense(
                2 * self.dim, use_bias=False, kernel_init=kernel_init, name="eij0"
            )(radial_basis),
            2,
            axis=-1,
        )

        eijb = nn.Dense(self.dim, use_bias=False, kernel_init=kernel_init, name="eijb")(
            radial_basis_angle
        )
        # per-angle weight: product of the weights of the two edges of the angle
        eijkb = eijb[angle_src] * eijb[angle_dst]

        ##################################################
        ### ANGULAR BASIS ###
        # build fourier series for angles
        nangles = jnp.asarray(
            np.arange(self.nmax_angle + 1, dtype=distances.dtype)[None, :]
        )

        # cos(n*theta) for n = 0..nmax_angle, sin(n*theta) for n = 1..nmax_angle
        Ac = jnp.cos(nangles * angles)
        As = jnp.sin(nangles[:, 1:] * angles)
        aijk = nn.Dense(
            self.dim, use_bias=False, kernel_init=kernel_init, name="aijk0"
        )(jnp.concatenate([Ac, As], axis=-1))
        aikj = aijk

        ##################################################
        if self.keep_all_layers:
            vis = []

        ### LOOP OVER LAYERS ###
        for layer in range(self.nlayers):
            # --- atom update: sum weighted edge messages onto edge_src ---
            phiv = GatedPerceptron(
                self.dim,
                use_bias=True,
                kernel_init=kernel_init,
                activation=nn.silu,
                name=f"phiv{layer+1}",
            )(jnp.concatenate([vi[edge_src], vi[edge_dst], eij], axis=-1))

            vi = vi + nn.Dense(
                self.dim, use_bias=True, kernel_init=kernel_init, name=f"vi{layer+1}"
            )(jax.ops.segment_sum(phiv * eija, edge_src, vi.shape[0]))

            if self.keep_all_layers:
                vis.append(vi)

            # edges and angles only feed later atom updates, so the last
            # layer stops here
            if layer == self.nlayers - 1:
                break

            # --- edge update from angle messages ---
            eij_ang = eij[filter_indices] if filtered else eij
            eij_angsrc = eij_ang[angle_src]
            eij_angdst = eij_ang[angle_dst]

            # the same module instance handles both edge orderings of an angle
            phie = GatedPerceptron(
                self.dim,
                use_bias=True,
                kernel_init=kernel_init,
                activation=nn.silu,
                name=f"phie{layer+1}",
            )

            vi_ang = vi[central_atom]
            phie_ijk = phie(
                jnp.concatenate(
                    [eij_angsrc, eij_angdst, aijk, vi_ang],
                    axis=-1,
                )
            )
            phie_ikj = phie(
                jnp.concatenate(
                    [eij_angdst, eij_angsrc, aikj, vi_ang],
                    axis=-1,
                )
            )

            eij_ang = eij_ang + nn.Dense(
                self.dim, use_bias=False, kernel_init=kernel_init, name=f"Le{layer+1}"
            )(
                jax.ops.segment_sum(phie_ikj * eijkb, angle_dst, eij_ang.shape[0])
                + jax.ops.segment_sum(phie_ijk * eijkb, angle_src, eij_ang.shape[0])
            )
            eij_angsrc = eij_ang[angle_src]
            eij_angdst = eij_ang[angle_dst]
            if filtered:
                # write the updated subset back into the full edge array
                eij = eij.at[filter_indices].set(eij_ang)
            else:
                # the angle graph is the main graph: eij_ang covers every
                # edge and filter_indices does not exist
                eij = eij_ang

            # --- angle update from the refreshed edges ---
            phia = GatedPerceptron(
                self.dim,
                use_bias=True,
                kernel_init=kernel_init,
                activation=nn.silu,
                name=f"phia{layer+1}",
            )
            aijk = aijk + phia(
                jnp.concatenate([eij_angsrc, eij_angdst, aijk, vi_ang], axis=-1)
            )
            aikj = aikj + phia(
                jnp.concatenate([eij_angdst, eij_angsrc, aikj, vi_ang], axis=-1)
            )

        output = {
            **inputs,
            self.embedding_key: vi,
        }
        if self.keep_all_layers:
            output[self.embedding_key + "_layers"] = jnp.stack(vis, axis=1)
        return output
Crystal Hamiltonian Graph Neural Network
FID: CHGNET
Reference
Deng, B., Zhong, P., Jun, K. et al. CHGNet as a pretrained universal neural network potential for charge-informed atomistic modelling. Nat Mach Intell 5, 1031–1041 (2023). https://doi.org/10.1038/s42256-023-00716-3
The species encoding parameters. See fennol.models.misc.encodings.SpeciesEncoding
The radial basis parameters for angle embedding. See fennol.models.misc.encodings.RadialBasis
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.