fennol.models.physics.repulsion
import pathlib
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
from typing import Any, Dict, Union, Callable, Sequence, Optional, ClassVar
from ...utils import AtomicUnits as au
from ...utils.periodic_table import D3_COV_RADII, UFF_VDW_RADII


class RepulsionZBL(nn.Module):
    """Repulsion energy based on the Ziegler-Biersack-Littmark potential

    FID: REPULSION_ZBL

    For every edge (i, j) of the graph the pair energy is

        E_ij = Z_i * Z_j * phi(x_ij) / r_ij * s_ij
        x_ij = r_ij * (Z_i**p + Z_j**p) / d
        phi(x) = sum_k c_k * exp(-alpha_k * x)

    where ``s_ij`` is the graph switching function (possibly rescaled by the
    alchemical inputs). ``r_ij`` is the graph distance divided by ``au.BOHR``
    (presumably Angstrom -> bohr; confirm against AtomicUnits). The ``c_k``
    are rescaled to sum to 0.5, presumably because every pair appears twice
    in the edge list. Pair energies are summed onto the ``edge_src`` atom and
    multiplied by ``au.get_multiplier(_energy_unit)``.

    When ``trainable`` is True, ``d``, ``p`` and ``alphas`` are learned
    (through ``abs``), and ``c_k = 0.5 * softmax(cs_logits)``. The initial
    coefficients therefore come from ``cs_logits``; ``cs`` is only the
    reference used by the regularization term emitted during training.

    ### Reference
    J. F. Ziegler & J. P. Biersack, The Stopping and Range of Ions in Matter

    """

    _graphs_properties: Dict
    graph_key: str = "graph"
    """The key for the graph input."""
    energy_key: Optional[str] = None
    """The key for the output energy."""
    trainable: bool = True
    """Whether the parameters are trainable."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""
    # Regularize relative deviations (True) or absolute deviations (False).
    proportional_regularization: bool = True
    # Screening length prefactor, divided by au.BOHR.
    d: float = 0.46850 / au.BOHR
    # Exponent applied to the atomic numbers in the screening length.
    p: float = 0.23
    # Exponents of the 4-term screening function.
    alphas: Sequence[float] = (3.19980, 0.94229, 0.40290, 0.20162)
    # Reference coefficients (normalized to sum 0.5 internally).
    cs: Sequence[float] = (0.18175273, 0.5098655, 0.28021213, 0.0281697)
    # Initial logits of the trainable coefficients.
    cs_logits: Sequence[float] = (0.1130, 1.1445, 0.5459, -1.7514)

    FID: ClassVar[str] = "REPULSION_ZBL"

    @nn.compact
    def __call__(self, inputs):
        species = inputs["species"]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]

        training = "training" in inputs.get("flags", {})

        rijs = graph["distances"] / au.BOHR

        d_ = self.d
        p_ = self.p
        # Explicit checks instead of `assert` so they survive `python -O`.
        if len(self.alphas) != 4:
            raise ValueError("alphas must be a sequence of length 4")
        alphas_ = np.array(self.alphas, dtype=rijs.dtype)
        if len(self.cs) != 4:
            raise ValueError("cs must be a sequence of length 4")
        cs_ = np.array(self.cs, dtype=rijs.dtype)
        cs_ = 0.5 * cs_ / np.sum(cs_)
        if self.trainable:
            if len(self.cs_logits) != 4:
                raise ValueError("cs_logits must be a sequence of length 4")
            d = jnp.abs(
                self.param(
                    "d",
                    lambda key, d: jnp.asarray(d, dtype=rijs.dtype),
                    d_,
                )
            )
            p = jnp.abs(
                self.param(
                    "p",
                    lambda key, p: jnp.asarray(p, dtype=rijs.dtype),
                    p_,
                )
            )
            # softmax keeps the coefficients positive and summing to 0.5.
            cs = 0.5 * jax.nn.softmax(
                self.param(
                    "cs",
                    lambda key, cs: jnp.asarray(cs, dtype=rijs.dtype),
                    np.array(self.cs_logits, dtype=rijs.dtype),
                )
            )
            alphas = jnp.abs(
                self.param(
                    "alphas",
                    lambda key, alphas: jnp.asarray(alphas, dtype=rijs.dtype),
                    alphas_,
                )
            )

            if training:
                # Penalize drift of the learned parameters from the reference values.
                if self.proportional_regularization:
                    reg = jnp.asarray(
                        ((1 - alphas / alphas_) ** 2).sum()
                        + ((1 - cs / cs_) ** 2).sum()
                        + (1 - p / p_) ** 2
                        + (1 - d / d_) ** 2
                    ).reshape(1)
                else:
                    reg = jnp.asarray(
                        ((alphas_ - alphas) ** 2).sum()
                        + ((cs_ - cs) ** 2).sum()
                        + (p_ - p) ** 2
                        + (d_ - d) ** 2
                    ).reshape(1)
        else:
            cs = jnp.asarray(cs_)
            alphas = jnp.asarray(alphas_)
            d = d_
            p = p_

        if "alch_group" in inputs:
            # Alchemical decoupling: edges between atoms of different groups are
            # scaled by a lambda-dependent factor (optionally with soft-core r).
            switch = graph["switch_raw"]
            lambda_v = inputs["alch_vlambda"]
            alch_group = inputs["alch_group"]
            alch_m = inputs.get("alch_m", 2)

            mask = alch_group[edge_src] == alch_group[edge_dst]

            if "alch_softcore_rep" in inputs:
                alch_alpha = inputs["alch_softcore_rep"] ** 2 * (1 - lambda_v)
                rijs = jnp.where(mask, rijs, (rijs**2 + alch_alpha) ** 0.5)
            lambda_v = 0.5 * (1 - jnp.cos(jnp.pi * lambda_v))
            switch = jnp.where(
                mask,
                switch,
                (lambda_v**alch_m) * switch,
            )
        else:
            switch = graph["switch"]

        # Non-positive species (padding) get Z = 0 and thus zero energy.
        Z = jnp.where(species > 0, species.astype(rijs.dtype), 0.0)
        Zij = Z[edge_src] * Z[edge_dst]
        Zp = Z**p / d
        x = rijs * (Zp[edge_src] + Zp[edge_dst])
        phi = (cs[None, :] * jnp.exp(-alphas[None, :] * x[:, None])).sum(axis=-1)

        ereppair = Zij * phi / rijs * switch

        erep_atomic = jax.ops.segment_sum(ereppair, edge_src, species.shape[0])

        energy_unit = au.get_multiplier(self._energy_unit)
        energy_key = self.energy_key if self.energy_key is not None else self.name
        output = {**inputs, energy_key: erep_atomic * energy_unit}
        if self.trainable and training:
            output[energy_key + "_regularization"] = reg

        return output


def _load_nlh_table():
    """Build the symmetric NLH pair-coefficient table from ``nlh_coeffs.dat``.

    Each data row is read as ``Z1 Z2 a1 b1 a2 b2 a3 b3`` (coefficients a_k,
    exponents b_k). Returns ``(table, stride)`` where ``table`` has shape
    ``(stride**2, 3, 2)`` and the pair (z1, z2) lives at row
    ``z1 + stride * z2`` (and symmetrically at ``z2 + stride * z1``).
    Rows of pairs absent from the file are all zeros.
    """
    path = pathlib.Path(__file__).parent.resolve() / "nlh_coeffs.dat"
    data = np.atleast_2d(np.loadtxt(path, usecols=np.arange(0, 8)))
    # Use both Z columns so the table is large enough whatever the row ordering.
    zmax = int(np.max(data[:, :2]))
    # The stride must be zmax + 1 so that every (z1, z2) in [0, zmax]^2 gets a
    # unique row; the same stride is used for the lookup in RepulsionNLH.
    stride = zmax + 1
    # float64 so that float64 runs do not inherit float32 rounding.
    table = np.zeros((stride**2, 6), dtype=np.float64)
    for row in data:
        z1, z2 = int(row[0]), int(row[1])
        table[z1 + stride * z2] = row[2:8]
        table[z2 + stride * z1] = row[2:8]
    return table.reshape(stride**2, 3, 2), stride


class RepulsionNLH(nn.Module):
    """NLH pairwise repulsive potential with pair-specific coefficients up to Z=92

    FID: REPULSION_NLH

    For every edge (i, j) the pair energy is

        E_ij = Z_i * Z_j * s_ij * phi_ij(r_ij) / r_ij
        phi_ij(r) = sum_{k=1..3} a_k^{ij} * exp(-b_k^{ij} * r)

    with pair-specific (a_k, b_k) read from ``nlh_coeffs.dat`` located next to
    this file, and ``s_ij`` the graph switching function (possibly rescaled by
    the alchemical inputs). Pair energies are summed onto the ``edge_src``
    atom and multiplied by ``0.5 * au.BOHR`` and the energy-unit multiplier
    (presumably 0.5 for the doubly-listed edges and BOHR for Angstrom -> bohr
    in 1/r; confirm).

    ### Reference
    K. Nordlund, S. Lehtola, G. Hobler, Repulsive interatomic potentials calculated at three levels of theory, Phys. Rev. A 111, 032818
    https://doi.org/10.1103/PhysRevA.111.032818
    """

    _graphs_properties: Dict
    graph_key: str = "graph"
    """The key for the graph input."""
    energy_key: Optional[str] = None
    """The key for the output energy."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""
    # Learn one global scaling factor per term for coefficients and exponents.
    trainable: bool = False
    # If set, analytic forces are also written under this key.
    direct_forces_key: Optional[str] = None

    FID: ClassVar[str] = "REPULSION_NLH"

    @nn.compact
    def __call__(self, inputs):
        AB, stride = _load_nlh_table()

        species = inputs["species"]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        rijs = graph["distances"]

        # coefficients (a1,a2,a3)
        CS = jnp.array(AB[:, :, 0], dtype=rijs.dtype)
        # exponents (b1,b2,b3)
        ALPHAS = jnp.array(AB[:, :, 1], dtype=rijs.dtype)

        if self.trainable:
            cfact = jnp.abs(
                self.param(
                    "c_fact",
                    lambda key: jnp.ones(CS.shape[1], dtype=CS.dtype),
                )
            )
            CS = CS * cfact[None, :]
            # Renormalize each pair's coefficients. Pairs missing from the data
            # file (and species-0 pairs) have all-zero rows: guard the division
            # so they stay zero instead of becoming NaN, which would otherwise
            # survive the Z_i*Z_j = 0 product and poison energies and gradients.
            norm = jnp.sum(CS, axis=1, keepdims=True)
            CS = CS / jnp.where(norm > 0, norm, 1.0)
            alphas_fact = jnp.abs(
                self.param(
                    "alpha_fact",
                    lambda key: jnp.ones(ALPHAS.shape[1], dtype=ALPHAS.dtype),
                )
            )
            ALPHAS = ALPHAS * alphas_fact[None, :]

        # Same row layout as _load_nlh_table. NOTE(review): species above the
        # table's zmax index out of range (JAX clamps silently).
        s12 = species[edge_src] + stride * species[edge_dst]
        cs = CS[s12]
        alphas = ALPHAS[s12]

        if "alch_group" in inputs:
            # Alchemical decoupling: edges between atoms of different groups are
            # scaled by a lambda-dependent factor (optionally with soft-core r).
            switch = graph["switch_raw"]
            lambda_v = inputs["alch_vlambda"]
            alch_group = inputs["alch_group"]
            alch_m = inputs.get("alch_m", 2)

            mask = alch_group[edge_src] == alch_group[edge_dst]
            if "alch_softcore_rep" in inputs:
                alch_alpha = inputs["alch_softcore_rep"] ** 2 * (1 - lambda_v)
                rijs = jnp.where(mask, rijs, (rijs**2 + alch_alpha) ** 0.5)
            lambda_v = 0.5 * (1 - jnp.cos(jnp.pi * lambda_v))
            switch = jnp.where(
                mask,
                switch,
                (lambda_v**alch_m) * switch,
            )
        else:
            switch = graph["switch"]

        # Non-positive species (padding) get Z = 0 and thus zero energy.
        Z = jnp.where(species > 0, species.astype(rijs.dtype), 0.0)
        phi = (cs * jnp.exp(-alphas * rijs[:, None])).sum(axis=-1)
        Zij = Z[edge_src] * Z[edge_dst] * switch

        ereppair = Zij * phi / rijs

        energy_unit = au.get_multiplier(self._energy_unit)
        erep_atomic = (energy_unit * 0.5 * au.BOHR) * jax.ops.segment_sum(
            ereppair, edge_src, species.shape[0]
        )

        energy_key = self.energy_key if self.energy_key is not None else self.name
        output = {**inputs, energy_key: erep_atomic}

        if self.direct_forces_key is not None:
            # Analytic pair forces from d(E_ij)/dr.
            # NOTE(review): the switch is treated as a constant (its r-derivative
            # is not included) and, for soft-cored alchemical edges, the chain
            # rule through the modified distance is not applied, so these forces
            # may differ from -dE/dx where either matters. Sign assumes
            # graph["vec"] points from edge_src to edge_dst -- TODO confirm.
            dphidr = -(alphas * cs * jnp.exp(-alphas * rijs[:, None])).sum(axis=-1)
            dedr = Zij * (dphidr / rijs - phi / (rijs**2))
            dedij = (dedr / rijs)[:, None] * graph["vec"]
            fi = (energy_unit * au.BOHR) * jax.ops.segment_sum(
                dedij, edge_src, species.shape[0]
            )
            output[self.direct_forces_key] = fi

        return output
class RepulsionZBL(nn.Module):
    """Repulsion energy based on the Ziegler-Biersack-Littmark potential

    FID: REPULSION_ZBL

    For every edge (i, j) of the graph the pair energy is

        E_ij = Z_i * Z_j * phi(x_ij) / r_ij * s_ij
        x_ij = r_ij * (Z_i**p + Z_j**p) / d
        phi(x) = sum_k c_k * exp(-alpha_k * x)

    where ``s_ij`` is the graph switching function (possibly rescaled by the
    alchemical inputs). ``r_ij`` is the graph distance divided by ``au.BOHR``
    (presumably Angstrom -> bohr; confirm against AtomicUnits). The ``c_k``
    are rescaled to sum to 0.5, presumably because every pair appears twice
    in the edge list. Pair energies are summed onto the ``edge_src`` atom and
    multiplied by ``au.get_multiplier(_energy_unit)``.

    When ``trainable`` is True, ``d``, ``p`` and ``alphas`` are learned
    (through ``abs``), and ``c_k = 0.5 * softmax(cs_logits)``. The initial
    coefficients therefore come from ``cs_logits``; ``cs`` is only the
    reference used by the regularization term emitted during training.

    ### Reference
    J. F. Ziegler & J. P. Biersack, The Stopping and Range of Ions in Matter

    """

    _graphs_properties: Dict
    graph_key: str = "graph"
    """The key for the graph input."""
    energy_key: Optional[str] = None
    """The key for the output energy."""
    trainable: bool = True
    """Whether the parameters are trainable."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""
    # Regularize relative deviations (True) or absolute deviations (False).
    proportional_regularization: bool = True
    # Screening length prefactor, divided by au.BOHR.
    d: float = 0.46850 / au.BOHR
    # Exponent applied to the atomic numbers in the screening length.
    p: float = 0.23
    # Exponents of the 4-term screening function.
    alphas: Sequence[float] = (3.19980, 0.94229, 0.40290, 0.20162)
    # Reference coefficients (normalized to sum 0.5 internally).
    cs: Sequence[float] = (0.18175273, 0.5098655, 0.28021213, 0.0281697)
    # Initial logits of the trainable coefficients.
    cs_logits: Sequence[float] = (0.1130, 1.1445, 0.5459, -1.7514)

    FID: ClassVar[str] = "REPULSION_ZBL"

    @nn.compact
    def __call__(self, inputs):
        species = inputs["species"]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]

        training = "training" in inputs.get("flags", {})

        rijs = graph["distances"] / au.BOHR

        d_ = self.d
        p_ = self.p
        # Explicit checks instead of `assert` so they survive `python -O`.
        if len(self.alphas) != 4:
            raise ValueError("alphas must be a sequence of length 4")
        alphas_ = np.array(self.alphas, dtype=rijs.dtype)
        if len(self.cs) != 4:
            raise ValueError("cs must be a sequence of length 4")
        cs_ = np.array(self.cs, dtype=rijs.dtype)
        cs_ = 0.5 * cs_ / np.sum(cs_)
        if self.trainable:
            if len(self.cs_logits) != 4:
                raise ValueError("cs_logits must be a sequence of length 4")
            d = jnp.abs(
                self.param(
                    "d",
                    lambda key, d: jnp.asarray(d, dtype=rijs.dtype),
                    d_,
                )
            )
            p = jnp.abs(
                self.param(
                    "p",
                    lambda key, p: jnp.asarray(p, dtype=rijs.dtype),
                    p_,
                )
            )
            # softmax keeps the coefficients positive and summing to 0.5.
            cs = 0.5 * jax.nn.softmax(
                self.param(
                    "cs",
                    lambda key, cs: jnp.asarray(cs, dtype=rijs.dtype),
                    np.array(self.cs_logits, dtype=rijs.dtype),
                )
            )
            alphas = jnp.abs(
                self.param(
                    "alphas",
                    lambda key, alphas: jnp.asarray(alphas, dtype=rijs.dtype),
                    alphas_,
                )
            )

            if training:
                # Penalize drift of the learned parameters from the reference values.
                if self.proportional_regularization:
                    reg = jnp.asarray(
                        ((1 - alphas / alphas_) ** 2).sum()
                        + ((1 - cs / cs_) ** 2).sum()
                        + (1 - p / p_) ** 2
                        + (1 - d / d_) ** 2
                    ).reshape(1)
                else:
                    reg = jnp.asarray(
                        ((alphas_ - alphas) ** 2).sum()
                        + ((cs_ - cs) ** 2).sum()
                        + (p_ - p) ** 2
                        + (d_ - d) ** 2
                    ).reshape(1)
        else:
            cs = jnp.asarray(cs_)
            alphas = jnp.asarray(alphas_)
            d = d_
            p = p_

        if "alch_group" in inputs:
            # Alchemical decoupling: edges between atoms of different groups are
            # scaled by a lambda-dependent factor (optionally with soft-core r).
            switch = graph["switch_raw"]
            lambda_v = inputs["alch_vlambda"]
            alch_group = inputs["alch_group"]
            alch_m = inputs.get("alch_m", 2)

            mask = alch_group[edge_src] == alch_group[edge_dst]

            if "alch_softcore_rep" in inputs:
                alch_alpha = inputs["alch_softcore_rep"] ** 2 * (1 - lambda_v)
                rijs = jnp.where(mask, rijs, (rijs**2 + alch_alpha) ** 0.5)
            lambda_v = 0.5 * (1 - jnp.cos(jnp.pi * lambda_v))
            switch = jnp.where(
                mask,
                switch,
                (lambda_v**alch_m) * switch,
            )
        else:
            switch = graph["switch"]

        # Non-positive species (padding) get Z = 0 and thus zero energy.
        Z = jnp.where(species > 0, species.astype(rijs.dtype), 0.0)
        Zij = Z[edge_src] * Z[edge_dst]
        Zp = Z**p / d
        x = rijs * (Zp[edge_src] + Zp[edge_dst])
        phi = (cs[None, :] * jnp.exp(-alphas[None, :] * x[:, None])).sum(axis=-1)

        ereppair = Zij * phi / rijs * switch

        erep_atomic = jax.ops.segment_sum(ereppair, edge_src, species.shape[0])

        energy_unit = au.get_multiplier(self._energy_unit)
        energy_key = self.energy_key if self.energy_key is not None else self.name
        output = {**inputs, energy_key: erep_atomic * energy_unit}
        if self.trainable and training:
            output[energy_key + "_regularization"] = reg

        return output
Repulsion energy based on the Ziegler-Biersack-Littmark potential
FID: REPULSION_ZBL
Reference
J. F. Ziegler & J. P. Biersack, The Stopping and Range of Ions in Matter
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.
def _load_nlh_table():
    """Build the symmetric NLH pair-coefficient table from ``nlh_coeffs.dat``.

    Each data row is read as ``Z1 Z2 a1 b1 a2 b2 a3 b3`` (coefficients a_k,
    exponents b_k). Returns ``(table, stride)`` where ``table`` has shape
    ``(stride**2, 3, 2)`` and the pair (z1, z2) lives at row
    ``z1 + stride * z2`` (and symmetrically at ``z2 + stride * z1``).
    Rows of pairs absent from the file are all zeros.
    """
    path = pathlib.Path(__file__).parent.resolve() / "nlh_coeffs.dat"
    data = np.atleast_2d(np.loadtxt(path, usecols=np.arange(0, 8)))
    # Use both Z columns so the table is large enough whatever the row ordering.
    zmax = int(np.max(data[:, :2]))
    # The stride must be zmax + 1 so that every (z1, z2) in [0, zmax]^2 gets a
    # unique row; the same stride is used for the lookup in RepulsionNLH.
    stride = zmax + 1
    # float64 so that float64 runs do not inherit float32 rounding.
    table = np.zeros((stride**2, 6), dtype=np.float64)
    for row in data:
        z1, z2 = int(row[0]), int(row[1])
        table[z1 + stride * z2] = row[2:8]
        table[z2 + stride * z1] = row[2:8]
    return table.reshape(stride**2, 3, 2), stride


class RepulsionNLH(nn.Module):
    """NLH pairwise repulsive potential with pair-specific coefficients up to Z=92

    FID: REPULSION_NLH

    For every edge (i, j) the pair energy is

        E_ij = Z_i * Z_j * s_ij * phi_ij(r_ij) / r_ij
        phi_ij(r) = sum_{k=1..3} a_k^{ij} * exp(-b_k^{ij} * r)

    with pair-specific (a_k, b_k) read from ``nlh_coeffs.dat`` located next to
    this file, and ``s_ij`` the graph switching function (possibly rescaled by
    the alchemical inputs). Pair energies are summed onto the ``edge_src``
    atom and multiplied by ``0.5 * au.BOHR`` and the energy-unit multiplier
    (presumably 0.5 for the doubly-listed edges and BOHR for Angstrom -> bohr
    in 1/r; confirm).

    ### Reference
    K. Nordlund, S. Lehtola, G. Hobler, Repulsive interatomic potentials calculated at three levels of theory, Phys. Rev. A 111, 032818
    https://doi.org/10.1103/PhysRevA.111.032818
    """

    _graphs_properties: Dict
    graph_key: str = "graph"
    """The key for the graph input."""
    energy_key: Optional[str] = None
    """The key for the output energy."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""
    # Learn one global scaling factor per term for coefficients and exponents.
    trainable: bool = False
    # If set, analytic forces are also written under this key.
    direct_forces_key: Optional[str] = None

    FID: ClassVar[str] = "REPULSION_NLH"

    @nn.compact
    def __call__(self, inputs):
        AB, stride = _load_nlh_table()

        species = inputs["species"]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        rijs = graph["distances"]

        # coefficients (a1,a2,a3)
        CS = jnp.array(AB[:, :, 0], dtype=rijs.dtype)
        # exponents (b1,b2,b3)
        ALPHAS = jnp.array(AB[:, :, 1], dtype=rijs.dtype)

        if self.trainable:
            cfact = jnp.abs(
                self.param(
                    "c_fact",
                    lambda key: jnp.ones(CS.shape[1], dtype=CS.dtype),
                )
            )
            CS = CS * cfact[None, :]
            # Renormalize each pair's coefficients. Pairs missing from the data
            # file (and species-0 pairs) have all-zero rows: guard the division
            # so they stay zero instead of becoming NaN, which would otherwise
            # survive the Z_i*Z_j = 0 product and poison energies and gradients.
            norm = jnp.sum(CS, axis=1, keepdims=True)
            CS = CS / jnp.where(norm > 0, norm, 1.0)
            alphas_fact = jnp.abs(
                self.param(
                    "alpha_fact",
                    lambda key: jnp.ones(ALPHAS.shape[1], dtype=ALPHAS.dtype),
                )
            )
            ALPHAS = ALPHAS * alphas_fact[None, :]

        # Same row layout as _load_nlh_table. NOTE(review): species above the
        # table's zmax index out of range (JAX clamps silently).
        s12 = species[edge_src] + stride * species[edge_dst]
        cs = CS[s12]
        alphas = ALPHAS[s12]

        if "alch_group" in inputs:
            # Alchemical decoupling: edges between atoms of different groups are
            # scaled by a lambda-dependent factor (optionally with soft-core r).
            switch = graph["switch_raw"]
            lambda_v = inputs["alch_vlambda"]
            alch_group = inputs["alch_group"]
            alch_m = inputs.get("alch_m", 2)

            mask = alch_group[edge_src] == alch_group[edge_dst]
            if "alch_softcore_rep" in inputs:
                alch_alpha = inputs["alch_softcore_rep"] ** 2 * (1 - lambda_v)
                rijs = jnp.where(mask, rijs, (rijs**2 + alch_alpha) ** 0.5)
            lambda_v = 0.5 * (1 - jnp.cos(jnp.pi * lambda_v))
            switch = jnp.where(
                mask,
                switch,
                (lambda_v**alch_m) * switch,
            )
        else:
            switch = graph["switch"]

        # Non-positive species (padding) get Z = 0 and thus zero energy.
        Z = jnp.where(species > 0, species.astype(rijs.dtype), 0.0)
        phi = (cs * jnp.exp(-alphas * rijs[:, None])).sum(axis=-1)
        Zij = Z[edge_src] * Z[edge_dst] * switch

        ereppair = Zij * phi / rijs

        energy_unit = au.get_multiplier(self._energy_unit)
        erep_atomic = (energy_unit * 0.5 * au.BOHR) * jax.ops.segment_sum(
            ereppair, edge_src, species.shape[0]
        )

        energy_key = self.energy_key if self.energy_key is not None else self.name
        output = {**inputs, energy_key: erep_atomic}

        if self.direct_forces_key is not None:
            # Analytic pair forces from d(E_ij)/dr.
            # NOTE(review): the switch is treated as a constant (its r-derivative
            # is not included) and, for soft-cored alchemical edges, the chain
            # rule through the modified distance is not applied, so these forces
            # may differ from -dE/dx where either matters. Sign assumes
            # graph["vec"] points from edge_src to edge_dst -- TODO confirm.
            dphidr = -(alphas * cs * jnp.exp(-alphas * rijs[:, None])).sum(axis=-1)
            dedr = Zij * (dphidr / rijs - phi / (rijs**2))
            dedij = (dedr / rijs)[:, None] * graph["vec"]
            fi = (energy_unit * au.BOHR) * jax.ops.segment_sum(
                dedij, edge_src, species.shape[0]
            )
            output[self.direct_forces_key] = fi

        return output
NLH pairwise repulsive potential with pair-specific coefficients up to Z=92
FID: REPULSION_NLH
Reference
K. Nordlund, S. Lehtola, G. Hobler, Repulsive interatomic potentials calculated at three levels of theory, Phys. Rev. A 111, 032818 https://doi.org/10.1103/PhysRevA.111.032818
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.