fennol.models.physics.bond
"""Coordination-number style descriptors and a flat-bottom restraint.

Every module here is a flax ``nn.Module`` that reads a dictionary of inputs
and returns a copy of that dictionary extended with its own output entry.
"""
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
from typing import Any, Dict, Union, Callable, Sequence, Optional, ClassVar
from ...utils import AtomicUnits as au
import dataclasses
from ...utils.periodic_table import (
    D3_ELECTRONEGATIVITIES,
    D3_HARDNESSES,
    D3_VDW_RADII,
    D3_COV_RADII,
    D3_KAPPA,
    VDW_RADII,
    VALENCE_ELECTRONS,
    PAULING_ELECTRONEGATIVITY,
)


class CND4(nn.Module):
    """Coordination number as defined in the D4 dispersion correction.

    Every edge contributes ``0.5 * (1 + erf(-k0 * (r_ij / rc_ij - 1)))``
    multiplied by the graph switch value, where ``rc_ij`` is the sum of the
    ``D3_COV_RADII`` entries of both atoms and ``r_ij`` is the edge distance
    divided by ``au.BOHR``. Edge contributions are summed per source atom.

    The result dictionary gains ``output_key`` (per-atom values) and
    ``output_key + "_pair"`` (per-edge values).

    FID : CN_D4
    """

    graph_key: str = "graph"
    """The key for the graph input."""
    output_key: Optional[str] = None
    """The key for the output."""
    # Steepness of the error-function counting term.
    k0: float = 7.5
    # Prefactor, offset and width of the electronegativity factor
    # ``k1 * exp(-(|en_i - en_j| + k2)**2 / k3)``.
    k1: float = 4.1
    k2: float = 19.09
    k3: float = 254.56
    electronegativity_factor: bool = False
    """Whether to include electronegativity factor."""
    trainable: bool = False
    """Whether the parameters are trainable."""

    FID: ClassVar[str] = "CN_D4"

    @nn.compact
    def __call__(self, inputs):
        """Return ``inputs`` extended with per-atom and per-edge values."""
        graph = inputs[self.graph_key]
        src = graph["edge_src"]
        dst = graph["edge_dst"]
        species = inputs["species"]

        # Per-species radii: a learnable table when trainable, else constant.
        radii = (
            self.param("rc", lambda key: jnp.asarray(D3_COV_RADII))
            if self.trainable
            else jnp.asarray(D3_COV_RADII)
        )
        rc = radii[species]
        pair_radius = rc[src] + rc[dst]
        r_bohr = graph["distances"] / au.BOHR

        steepness = self.k0
        if self.trainable:
            # The learnable part is a positive multiplier starting at 1.
            steepness = steepness * jnp.abs(
                self.param("k0", lambda key: jnp.asarray(1.0))
            )

        counting = 0.5 * (
            1 + jax.scipy.special.erf(-steepness * (r_bohr / pair_radius - 1.0))
        )
        cn_pair = counting * graph["switch"]

        if self.electronegativity_factor:
            if self.trainable:
                # NOTE(review): k1 is a scaled multiplier of self.k1, whereas
                # k2 and k3 are raw parameters initialised to 1.0 (self.k2 and
                # self.k3 are not used here), and the table starts from
                # D3_ELECTRONEGATIVITIES while the fixed branch uses
                # PAULING_ELECTRONEGATIVITY -- confirm this is intended.
                k1 = self.k1 * jnp.abs(
                    self.param("k1", lambda key: jnp.asarray(1.0))
                )
                k2 = self.param("k2", lambda key: jnp.asarray(1.0))
                k3 = jnp.abs(self.param("k3", lambda key: jnp.asarray(1.0)))
                en_table = self.param(
                    "en", lambda key: jnp.asarray(D3_ELECTRONEGATIVITIES)
                )
            else:
                k1, k2, k3 = self.k1, self.k2, self.k3
                en_table = jnp.asarray(PAULING_ELECTRONEGATIVITY)
            en = en_table[species]
            delta_en = jnp.abs(en[src] - en[dst])
            en_weight = k1 * jnp.exp(-((delta_en + k2) ** 2) / k3)
            cn_pair = cn_pair * en_weight

        cn_atom = jax.ops.segment_sum(cn_pair, src, species.shape[0])

        output_key = self.output_key if self.output_key is not None else self.name
        return {**inputs, output_key: cn_atom, output_key + "_pair": cn_pair}


class SumSwitch(nn.Module):
    """Sum (a power of) the switch values for each neighbor.

    Each edge contributes ``(1e-3 + switch)**pow - (1e-3)**pow``, so an edge
    whose switch value is zero contributes exactly zero.

    FID : SUM_SWITCH
    """

    graph_key: str = "graph"
    """The key for the graph input."""
    output_key: Optional[str] = None
    """The key for the output."""
    pow: float = 1.0
    """The power to raise the switch values to."""
    trainable: bool = False
    """Whether the pow parameter is trainable."""

    FID: ClassVar[str] = "SUM_SWITCH"

    @nn.compact
    def __call__(self, inputs):
        """Return ``inputs`` extended with the per-atom sum."""
        graph = inputs[self.graph_key]
        src = graph["edge_src"]
        switch = graph["switch"]

        if self.trainable:
            # The absolute value keeps the learned exponent non-negative.
            exponent = jnp.abs(
                self.param("pow", lambda key: jnp.asarray(self.pow))
            )
        else:
            exponent = self.pow

        eps = 1.0e-3
        offset = eps**exponent
        per_edge = (eps + switch) ** exponent - offset
        total = jax.ops.segment_sum(per_edge, src, inputs["species"].shape[0])

        output_key = self.output_key if self.output_key is not None else self.name
        return {**inputs, output_key: total}


class CNShift(nn.Module):
    """Shift a reference value by a term that grows with a coordination number.

    The output is ``ref_value + kappa * (cn + sqrt_shift) ** cn_pow``. When
    ``enforce_positive`` is set, the shift first goes through
    ``jax.nn.celu`` with ``alpha=ref_value``.

    FID : CN_SHIFT
    """

    # Key of the coordination number in the inputs.
    cn_key: str
    # Key of the output; the module name is used when None.
    output_key: Optional[str] = None
    # Key of per-atom kappa values in the inputs; when None, a per-species
    # table initialised to zero is learned instead.
    kappa_key: Optional[str] = None
    # Offset added to the coordination number before taking the power.
    sqrt_shift: float = 1.0e-6
    # Reference value: a constant, or the key of an array in the inputs.
    ref_value: Union[str, float] = 1.0
    # Whether to pass the shift through celu before adding the reference.
    enforce_positive: bool = False
    # Exponent applied to the shifted coordination number.
    cn_pow: float = 0.5

    FID: ClassVar[str] = "CN_SHIFT"

    @nn.compact
    def __call__(self, inputs):
        """Return ``inputs`` extended with the shifted reference value."""
        cn = inputs[self.cn_key]

        if self.kappa_key is None:
            kappa_table = self.param(
                "kappa", nn.initializers.zeros, (len(D3_COV_RADII),)
            )
            kappa = kappa_table[inputs["species"]]
        else:
            kappa = inputs[self.kappa_key]
            assert kappa.shape == cn.shape

        shift = kappa * (cn + self.sqrt_shift) ** self.cn_pow

        if isinstance(self.ref_value, str):
            reference = inputs[self.ref_value]
            assert reference.shape == shift.shape
        else:
            reference = self.ref_value

        if self.enforce_positive:
            shift = jax.nn.celu(shift, alpha=reference)

        output_key = self.output_key if self.output_key is not None else self.name
        return {**inputs, output_key: reference + shift}


class CNStore(nn.Module):
    """Interpolate learned per-species values from a coordination number.

    Each species owns ``store_size`` reference coordination numbers and
    matching reference values. The output is a softmax-weighted average of
    the reference values, the weights depending on the squared distance
    between the input coordination number and each reference.

    FID : CN_STORE
    """

    # Key of the coordination number in the inputs.
    cn_key: str
    # Key of the output; the module name is used when None.
    output_key: Optional[str] = None
    # Number of reference entries stored per species.
    store_size: int = 10
    # Number of terms j = 0 .. n_gaussians - 1 in the Gaussian sum.
    n_gaussians: int = 4
    # NOTE(review): not read anywhere in the visible code.
    isolated_value: float = 0.0
    # Scale of the uniform initialiser for the reference coordination numbers.
    init_scale_cn: float = 5.0
    # Scale of the uniform initialiser for the reference values.
    init_scale_values: float = 1.0
    # Width coefficient of the Gaussians (initial value when trainable).
    beta: float = 6.0
    # Whether beta is a learnable parameter.
    trainable: bool = True
    # Size of the last axis of the stored values.
    output_dim: int = 1
    # Whether to drop the last axis when output_dim == 1.
    squeeze: bool = True

    FID: ClassVar[str] = "CN_STORE"

    @nn.compact
    def __call__(self, inputs):
        """Return ``inputs`` extended with the interpolated values."""
        cn = inputs[self.cn_key]
        species = inputs["species"]
        n_species = len(D3_COV_RADII)

        # Parameters are declared in the same order as before: cn_refs,
        # values_refs, then beta.
        cn_refs = self.param(
            "cn_refs",
            nn.initializers.uniform(self.init_scale_cn),
            (n_species, self.store_size),
        )[species]
        values_refs = self.param(
            "values_refs",
            nn.initializers.uniform(self.init_scale_values),
            (n_species, self.store_size, self.output_dim),
        )[species]

        if self.trainable:
            beta = self.param("beta", lambda key: jnp.asarray(self.beta))
        else:
            beta = self.beta

        # The j = 0 term is exp(0) = 1 for every reference.
        orders = jnp.asarray(
            np.arange(self.n_gaussians)[None, None, :], dtype=jnp.float32
        )
        sq_dist = (cn[:, None] - cn_refs) ** 2
        gaussians = jnp.exp(-beta * orders * sq_dist[:, :, None])
        logits = jnp.log(jnp.sum(gaussians, axis=-1))
        weights = jax.nn.softmax(logits, axis=-1)

        values = jnp.sum(weights[:, :, None] * values_refs, axis=1)
        if self.output_dim == 1 and self.squeeze:
            values = jnp.squeeze(values, axis=-1)

        output_key = self.output_key if self.output_key is not None else self.name
        return {**inputs, output_key: values}


class FlatBottom(nn.Module):
    """Flat bottom potential energy surface.

    Realized by Côme Cattin, 2024.

    Flat bottom potential energy:
        E = alpha * (r - req) ** 2    if r >= req
        E = 0                         if r < req

    ``req`` is ``r_eq_factor`` times the sum of the ``D3_COV_RADII`` entries
    of the two atoms. Edge energies are summed per source atom. When the
    ``"training"`` flag is present in ``inputs["flags"]`` the module returns
    zeros.

    FID: FLAT_BOTTOM
    """

    energy_key: Optional[str] = None
    """Key of the energy in the outputs."""
    graph_key: str = "graph"
    """Key of the graph in the inputs."""
    alpha: float = 400.0
    """Force constant of the flat bottom potential (in kcal/mol/A^2)."""
    r_eq_factor: float = 1.3
    """Factor to multiply the sum of the covalent radii of the two atoms."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str] = "FLAT_BOTTOM"

    @nn.compact
    def __call__(self, inputs):
        """Return ``inputs`` extended with the per-atom restraint energy."""
        species = inputs["species"]
        graph = inputs[self.graph_key]
        src = graph["edge_src"]
        dst = graph["edge_dst"]
        distances = graph["distances"]
        n_atoms = species.shape[0]

        energy_key = self.name if self.energy_key is None else self.energy_key

        # The restraint is switched off whenever the training flag is set.
        if "training" in inputs.get("flags", {}):
            zeros = jnp.zeros(n_atoms, dtype=distances.dtype)
            return {**inputs, energy_key: zeros}

        r_bohr = distances / au.BOHR

        # req is the scaled sum of the covalent radii of the two atoms.
        rcov = jnp.asarray(D3_COV_RADII)[species]
        req = self.r_eq_factor * (rcov[src] + rcov[dst])

        # An "alpha" entry in the inputs overrides the module attribute.
        # Presumably this converts kcal/mol/A^2 to atomic units -- verify
        # against the AtomicUnits definitions.
        alpha = inputs.get("alpha", self.alpha) / au.KCALPERMOL * au.BOHR**2

        pair_energy = jnp.where(r_bohr > req, alpha * (r_bohr - req) ** 2, 0.0)
        atom_energy = jax.ops.segment_sum(pair_energy, src, num_segments=n_atoms)

        unit_multiplier = au.get_multiplier(self._energy_unit)
        return {**inputs, energy_key: atom_energy * unit_multiplier}
class CND4(nn.Module):
    """Coordination number as defined in the D4 dispersion correction.

    Every edge contributes ``0.5 * (1 + erf(-k0 * (r_ij / rc_ij - 1)))``
    multiplied by the graph switch value, where ``rc_ij`` is the sum of the
    ``D3_COV_RADII`` entries of both atoms and ``r_ij`` is the edge distance
    divided by ``au.BOHR``. Edge contributions are summed per source atom.

    The result dictionary gains ``output_key`` (per-atom values) and
    ``output_key + "_pair"`` (per-edge values).

    FID : CN_D4
    """

    graph_key: str = "graph"
    """The key for the graph input."""
    output_key: Optional[str] = None
    """The key for the output."""
    # Steepness of the error-function counting term.
    k0: float = 7.5
    # Prefactor, offset and width of the electronegativity factor
    # ``k1 * exp(-(|en_i - en_j| + k2)**2 / k3)``.
    k1: float = 4.1
    k2: float = 19.09
    k3: float = 254.56
    electronegativity_factor: bool = False
    """Whether to include electronegativity factor."""
    trainable: bool = False
    """Whether the parameters are trainable."""

    FID: ClassVar[str] = "CN_D4"

    @nn.compact
    def __call__(self, inputs):
        """Return ``inputs`` extended with per-atom and per-edge values."""
        graph = inputs[self.graph_key]
        src = graph["edge_src"]
        dst = graph["edge_dst"]
        species = inputs["species"]

        # Per-species radii: a learnable table when trainable, else constant.
        radii = (
            self.param("rc", lambda key: jnp.asarray(D3_COV_RADII))
            if self.trainable
            else jnp.asarray(D3_COV_RADII)
        )
        rc = radii[species]
        pair_radius = rc[src] + rc[dst]
        r_bohr = graph["distances"] / au.BOHR

        steepness = self.k0
        if self.trainable:
            # The learnable part is a positive multiplier starting at 1.
            steepness = steepness * jnp.abs(
                self.param("k0", lambda key: jnp.asarray(1.0))
            )

        counting = 0.5 * (
            1 + jax.scipy.special.erf(-steepness * (r_bohr / pair_radius - 1.0))
        )
        cn_pair = counting * graph["switch"]

        if self.electronegativity_factor:
            if self.trainable:
                # NOTE(review): k1 is a scaled multiplier of self.k1, whereas
                # k2 and k3 are raw parameters initialised to 1.0 (self.k2 and
                # self.k3 are not used here), and the table starts from
                # D3_ELECTRONEGATIVITIES while the fixed branch uses
                # PAULING_ELECTRONEGATIVITY -- confirm this is intended.
                k1 = self.k1 * jnp.abs(
                    self.param("k1", lambda key: jnp.asarray(1.0))
                )
                k2 = self.param("k2", lambda key: jnp.asarray(1.0))
                k3 = jnp.abs(self.param("k3", lambda key: jnp.asarray(1.0)))
                en_table = self.param(
                    "en", lambda key: jnp.asarray(D3_ELECTRONEGATIVITIES)
                )
            else:
                k1, k2, k3 = self.k1, self.k2, self.k3
                en_table = jnp.asarray(PAULING_ELECTRONEGATIVITY)
            en = en_table[species]
            delta_en = jnp.abs(en[src] - en[dst])
            en_weight = k1 * jnp.exp(-((delta_en + k2) ** 2) / k3)
            cn_pair = cn_pair * en_weight

        cn_atom = jax.ops.segment_sum(cn_pair, src, species.shape[0])

        output_key = self.output_key if self.output_key is not None else self.name
        return {**inputs, output_key: cn_atom, output_key + "_pair": cn_pair}
Coordination number as defined in the D4 dispersion correction.
FID : CN_D4
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.
class SumSwitch(nn.Module):
    """Sum (a power of) the switch values for each neighbor.

    Each edge contributes ``(1e-3 + switch)**pow - (1e-3)**pow``, so an edge
    whose switch value is zero contributes exactly zero.

    FID : SUM_SWITCH
    """

    graph_key: str = "graph"
    """The key for the graph input."""
    output_key: Optional[str] = None
    """The key for the output."""
    pow: float = 1.0
    """The power to raise the switch values to."""
    trainable: bool = False
    """Whether the pow parameter is trainable."""

    FID: ClassVar[str] = "SUM_SWITCH"

    @nn.compact
    def __call__(self, inputs):
        """Return ``inputs`` extended with the per-atom sum."""
        graph = inputs[self.graph_key]
        src = graph["edge_src"]
        switch = graph["switch"]

        if self.trainable:
            # The absolute value keeps the learned exponent non-negative.
            exponent = jnp.abs(
                self.param("pow", lambda key: jnp.asarray(self.pow))
            )
        else:
            exponent = self.pow

        eps = 1.0e-3
        offset = eps**exponent
        per_edge = (eps + switch) ** exponent - offset
        total = jax.ops.segment_sum(per_edge, src, inputs["species"].shape[0])

        output_key = self.output_key if self.output_key is not None else self.name
        return {**inputs, output_key: total}
Sum (a power of) the switch values for each neighbor.
FID : SUM_SWITCH
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.
class CNShift(nn.Module):
    """Shift a reference value by a term that grows with a coordination number.

    The output is ``ref_value + kappa * (cn + sqrt_shift) ** cn_pow``. When
    ``enforce_positive`` is set, the shift first goes through
    ``jax.nn.celu`` with ``alpha=ref_value``.

    FID : CN_SHIFT
    """

    # Key of the coordination number in the inputs.
    cn_key: str
    # Key of the output; the module name is used when None.
    output_key: Optional[str] = None
    # Key of per-atom kappa values in the inputs; when None, a per-species
    # table initialised to zero is learned instead.
    kappa_key: Optional[str] = None
    # Offset added to the coordination number before taking the power.
    sqrt_shift: float = 1.0e-6
    # Reference value: a constant, or the key of an array in the inputs.
    ref_value: Union[str, float] = 1.0
    # Whether to pass the shift through celu before adding the reference.
    enforce_positive: bool = False
    # Exponent applied to the shifted coordination number.
    cn_pow: float = 0.5

    FID: ClassVar[str] = "CN_SHIFT"

    @nn.compact
    def __call__(self, inputs):
        """Return ``inputs`` extended with the shifted reference value."""
        cn = inputs[self.cn_key]

        if self.kappa_key is None:
            kappa_table = self.param(
                "kappa", nn.initializers.zeros, (len(D3_COV_RADII),)
            )
            kappa = kappa_table[inputs["species"]]
        else:
            kappa = inputs[self.kappa_key]
            assert kappa.shape == cn.shape

        shift = kappa * (cn + self.sqrt_shift) ** self.cn_pow

        if isinstance(self.ref_value, str):
            reference = inputs[self.ref_value]
            assert reference.shape == shift.shape
        else:
            reference = self.ref_value

        if self.enforce_positive:
            shift = jax.nn.celu(shift, alpha=reference)

        output_key = self.output_key if self.output_key is not None else self.name
        return {**inputs, output_key: reference + shift}
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.
class CNStore(nn.Module):
    """Interpolate learned per-species values from a coordination number.

    Each species owns ``store_size`` reference coordination numbers and
    matching reference values. The output is a softmax-weighted average of
    the reference values, the weights depending on the squared distance
    between the input coordination number and each reference.

    FID : CN_STORE
    """

    # Key of the coordination number in the inputs.
    cn_key: str
    # Key of the output; the module name is used when None.
    output_key: Optional[str] = None
    # Number of reference entries stored per species.
    store_size: int = 10
    # Number of terms j = 0 .. n_gaussians - 1 in the Gaussian sum.
    n_gaussians: int = 4
    # NOTE(review): not read anywhere in the visible code.
    isolated_value: float = 0.0
    # Scale of the uniform initialiser for the reference coordination numbers.
    init_scale_cn: float = 5.0
    # Scale of the uniform initialiser for the reference values.
    init_scale_values: float = 1.0
    # Width coefficient of the Gaussians (initial value when trainable).
    beta: float = 6.0
    # Whether beta is a learnable parameter.
    trainable: bool = True
    # Size of the last axis of the stored values.
    output_dim: int = 1
    # Whether to drop the last axis when output_dim == 1.
    squeeze: bool = True

    FID: ClassVar[str] = "CN_STORE"

    @nn.compact
    def __call__(self, inputs):
        """Return ``inputs`` extended with the interpolated values."""
        cn = inputs[self.cn_key]
        species = inputs["species"]
        n_species = len(D3_COV_RADII)

        # Parameters are declared in the same order as before: cn_refs,
        # values_refs, then beta.
        cn_refs = self.param(
            "cn_refs",
            nn.initializers.uniform(self.init_scale_cn),
            (n_species, self.store_size),
        )[species]
        values_refs = self.param(
            "values_refs",
            nn.initializers.uniform(self.init_scale_values),
            (n_species, self.store_size, self.output_dim),
        )[species]

        if self.trainable:
            beta = self.param("beta", lambda key: jnp.asarray(self.beta))
        else:
            beta = self.beta

        # The j = 0 term is exp(0) = 1 for every reference.
        orders = jnp.asarray(
            np.arange(self.n_gaussians)[None, None, :], dtype=jnp.float32
        )
        sq_dist = (cn[:, None] - cn_refs) ** 2
        gaussians = jnp.exp(-beta * orders * sq_dist[:, :, None])
        logits = jnp.log(jnp.sum(gaussians, axis=-1))
        weights = jax.nn.softmax(logits, axis=-1)

        values = jnp.sum(weights[:, :, None] * values_refs, axis=1)
        if self.output_dim == 1 and self.squeeze:
            values = jnp.squeeze(values, axis=-1)

        output_key = self.output_key if self.output_key is not None else self.name
        return {**inputs, output_key: values}
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.
class FlatBottom(nn.Module):
    """Flat bottom potential energy surface.

    Realized by Côme Cattin, 2024.

    Flat bottom potential energy:
        E = alpha * (r - req) ** 2    if r >= req
        E = 0                         if r < req

    ``req`` is ``r_eq_factor`` times the sum of the ``D3_COV_RADII`` entries
    of the two atoms. Edge energies are summed per source atom. When the
    ``"training"`` flag is present in ``inputs["flags"]`` the module returns
    zeros.

    FID: FLAT_BOTTOM
    """

    energy_key: Optional[str] = None
    """Key of the energy in the outputs."""
    graph_key: str = "graph"
    """Key of the graph in the inputs."""
    alpha: float = 400.0
    """Force constant of the flat bottom potential (in kcal/mol/A^2)."""
    r_eq_factor: float = 1.3
    """Factor to multiply the sum of the covalent radii of the two atoms."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str] = "FLAT_BOTTOM"

    @nn.compact
    def __call__(self, inputs):
        """Return ``inputs`` extended with the per-atom restraint energy."""
        species = inputs["species"]
        graph = inputs[self.graph_key]
        src = graph["edge_src"]
        dst = graph["edge_dst"]
        distances = graph["distances"]
        n_atoms = species.shape[0]

        energy_key = self.name if self.energy_key is None else self.energy_key

        # The restraint is switched off whenever the training flag is set.
        if "training" in inputs.get("flags", {}):
            zeros = jnp.zeros(n_atoms, dtype=distances.dtype)
            return {**inputs, energy_key: zeros}

        r_bohr = distances / au.BOHR

        # req is the scaled sum of the covalent radii of the two atoms.
        rcov = jnp.asarray(D3_COV_RADII)[species]
        req = self.r_eq_factor * (rcov[src] + rcov[dst])

        # An "alpha" entry in the inputs overrides the module attribute.
        # Presumably this converts kcal/mol/A^2 to atomic units -- verify
        # against the AtomicUnits definitions.
        alpha = inputs.get("alpha", self.alpha) / au.KCALPERMOL * au.BOHR**2

        pair_energy = jnp.where(r_bohr > req, alpha * (r_bohr - req) ** 2, 0.0)
        atom_energy = jax.ops.segment_sum(pair_energy, src, num_segments=n_atoms)

        unit_multiplier = au.get_multiplier(self._energy_unit)
        return {**inputs, energy_key: atom_energy * unit_multiplier}
Flat bottom potential energy surface.
Realized by Côme Cattin, 2024.
Flat bottom potential energy: $E = \alpha\,(r - r_{eq})^2$ if $r \ge r_{eq}$, and $E = 0$ if $r < r_{eq}$.
FID: FLAT_BOTTOM
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.