fennol.models.physics.repulsion
"""Short-range nuclear repulsion models (ZBL and NLH pair potentials)."""
import pathlib
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
from typing import Any, Dict, Union, Callable, Sequence, Optional, ClassVar
from ...utils import AtomicUnits as au
from ...utils.periodic_table import D3_COV_RADII, UFF_VDW_RADII


class RepulsionZBL(nn.Module):
    """Repulsion energy based on the Ziegler-Biersack-Littmark potential

    FID: REPULSION_ZBL

    For every edge (i, j) of the graph, with r_ij converted to bohr::

        x_ij   = r_ij * (Z_i**p + Z_j**p) / d
        phi(x) = sum_k c_k * exp(-alpha_k * x)      (the c_k sum to 0.5)
        E_ij   = Z_i * Z_j * phi(x_ij) / r_ij * switch_ij

    The edge energies are summed onto ``edge_src`` to give per-atom energies,
    which are then scaled by ``au.get_multiplier(_energy_unit)``.

    ### Reference
    J. F. Ziegler & J. P. Biersack, The Stopping and Range of Ions in Matter

    """

    _graphs_properties: Dict
    graph_key: str = "graph"
    """The key for the graph input."""
    energy_key: Optional[str] = None
    """The key for the output energy."""
    trainable: bool = True
    """Whether the parameters are trainable."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""
    # If True, regularize relative deviations sum((1 - param/ref)**2),
    # otherwise absolute deviations sum((ref - param)**2).
    proportional_regularization: bool = True
    # Screening length: 0.46850 divided by au.BOHR (bohr, assuming au.BOHR is Å/bohr).
    d: float = 0.46850/au.BOHR
    # Exponent applied to the atomic numbers in the reduced distance.
    p: float = 0.23
    # Exponents of the four screening terms.
    alphas: Sequence[float] = (3.19980, 0.94229, 0.40290, 0.20162)
    # Reference weights of the four screening terms (renormalized to sum to 0.5).
    cs: Sequence[float] = (0.18175273, 0.5098655, 0.28021213, 0.0281697)
    # Initial logits of the trainable weights; 0.5 * softmax(cs_logits) ~ normalized cs.
    cs_logits: Sequence[float] = (0.1130, 1.1445, 0.5459, -1.7514)

    FID: ClassVar[str] = "REPULSION_ZBL"

    @nn.compact
    def __call__(self, inputs):
        """Compute per-atom ZBL repulsion energies.

        Reads ``species``, the graph ``inputs[self.graph_key]`` (``edge_src``,
        ``edge_dst``, ``distances`` and ``switch`` or, in alchemical mode,
        ``switch_raw``) and the optional alchemical keys ``alch_group``,
        ``alch_vlambda``, ``alch_alpha`` and ``alch_m``.

        Returns:
            A shallow copy of ``inputs`` with the per-atom energy stored under
            ``energy_key`` (or the module name). When the parameters are
            trainable and the ``training`` flag is set, a scalar penalty is also
            stored under ``<key>_regularization`` (shape ``(1,)``).

        Raises:
            ValueError: if ``alphas``, ``cs`` or (when trainable) ``cs_logits``
                do not have exactly 4 entries.
        """
        species = inputs["species"]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]

        training = "training" in inputs.get("flags", {})

        # Distances divided by au.BOHR (bohr, assuming graph distances are in Å).
        rijs = graph["distances"] / au.BOHR

        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if len(self.alphas) != 4:
            raise ValueError("alphas must be a sequence of length 4")
        if len(self.cs) != 4:
            raise ValueError("cs must be a sequence of length 4")
        if self.trainable and len(self.cs_logits) != 4:
            # A length-1 cs_logits would otherwise broadcast silently.
            raise ValueError("cs_logits must be a sequence of length 4")

        # Reference values (also the regularization targets).
        d_ = self.d
        p_ = self.p
        alphas_ = np.array(self.alphas, dtype=rijs.dtype)
        cs_ = np.array(self.cs, dtype=rijs.dtype)
        cs_ = 0.5 * cs_ / np.sum(cs_)

        if self.trainable:
            # abs() keeps d, p and alphas non-negative whatever the raw parameter.
            d = jnp.abs(
                self.param(
                    "d",
                    lambda key, d: jnp.asarray(d, dtype=rijs.dtype),
                    d_,
                )
            )
            p = jnp.abs(
                self.param(
                    "p",
                    lambda key, p: jnp.asarray(p, dtype=rijs.dtype),
                    p_,
                )
            )
            # softmax keeps the weights positive, with their sum fixed to 0.5.
            cs = 0.5 * jax.nn.softmax(
                self.param(
                    "cs",
                    lambda key, cs: jnp.asarray(cs, dtype=rijs.dtype),
                    np.array(self.cs_logits, dtype=rijs.dtype),
                )
            )
            alphas = jnp.abs(
                self.param(
                    "alphas",
                    lambda key, alphas: jnp.asarray(alphas, dtype=rijs.dtype),
                    alphas_,
                )
            )

            if training:
                # Penalize drift of the parameters away from the reference values.
                if self.proportional_regularization:
                    reg = jnp.asarray(
                        ((1 - alphas / alphas_) ** 2).sum()
                        + ((1 - cs / cs_) ** 2).sum()
                        + (1 - p / p_) ** 2
                        + (1 - d / d_) ** 2
                    ).reshape(1)
                else:
                    reg = jnp.asarray(
                        ((alphas_ - alphas) ** 2).sum()
                        + ((cs_ - cs) ** 2).sum()
                        + (p_ - p) ** 2
                        + (d_ - d) ** 2
                    ).reshape(1)
        else:
            cs = jnp.asarray(cs_)
            alphas = jnp.asarray(alphas_)
            d = d_
            p = p_

        if "alch_group" in inputs:
            # Alchemical mode: edges between different alchemical groups get a
            # softened distance and a lambda-dependent switch.
            switch = graph["switch_raw"]
            lambda_v = inputs["alch_vlambda"]
            alch_group = inputs["alch_group"]
            alch_alpha = inputs.get("alch_alpha", 0.5)
            alch_m = inputs.get("alch_m", 2)

            mask = alch_group[edge_src] == alch_group[edge_dst]

            rijs = jnp.where(mask, rijs, (rijs**2 + alch_alpha**2 * (1 - lambda_v))**0.5)
            # Smoothed schedule: maps 0 -> 0 and 1 -> 1 with zero slope at both ends.
            lambda_v = 0.5*(1-jnp.cos(jnp.pi*lambda_v))
            switch = jnp.where(
                mask,
                switch,
                (lambda_v**alch_m) * switch,
            )
        else:
            switch = graph["switch"]

        # Non-positive species get zero nuclear charge.
        Z = jnp.where(species > 0, species.astype(rijs.dtype), 0.0)
        Zij = Z[edge_src]*Z[edge_dst]
        # Reduced distance x = r * (Z_i**p + Z_j**p) / d.
        Zp = Z**p / d
        x = rijs * (Zp[edge_src] + Zp[edge_dst])
        phi = (cs[None, :] * jnp.exp(-alphas[None, :] * x[:, None])).sum(axis=-1)

        ereppair = Zij * phi / rijs * switch

        erep_atomic = jax.ops.segment_sum(ereppair, edge_src, species.shape[0])

        energy_unit = au.get_multiplier(self._energy_unit)
        energy_key = self.energy_key if self.energy_key is not None else self.name
        output = {**inputs, energy_key: erep_atomic * energy_unit}
        if self.trainable and training:
            output[energy_key + "_regularization"] = reg

        return output


class RepulsionNLH(nn.Module):
    """NLH pairwise repulsive potential with pair-specific coefficients up to Z=92

    FID: REPULSION_NLH

    Nordlund-Lehtola-Hobler (NLH) potential. For every edge (i, j), with r_ij
    taken directly from ``graph["distances"]``::

        phi(r) = sum_k a_k(Z_i, Z_j) * exp(-b_k(Z_i, Z_j) * r)     (k = 1..3)
        E_ij   = 0.5 * au.BOHR * Z_i * Z_j * switch_ij * phi(r_ij) / r_ij

    The pair coefficients (a_k, b_k) are read from ``nlh_coeffs.dat`` next to
    this file. The edge energies are summed onto ``edge_src`` and scaled by
    ``au.get_multiplier(_energy_unit)``.

    ### Reference
    K. Nordlund, S. Lehtola, G. Hobler, Repulsive interatomic potentials calculated at three levels of theory, Phys. Rev. A 111, 032818
    https://doi.org/10.1103/PhysRevA.111.032818
    """

    _graphs_properties: Dict
    graph_key: str = "graph"
    """The key for the graph input."""
    energy_key: Optional[str] = None
    """The key for the output energy."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""
    # If True, learn one global non-negative scale per term for the
    # coefficients (renormalized per pair) and for the exponents.
    trainable: bool = False
    # If set, also output analytic per-atom forces under this key.
    direct_forces_key: Optional[str] = None

    FID: ClassVar[str] = "REPULSION_NLH"

    @nn.compact
    def __call__(self, inputs):
        """Compute per-atom NLH repulsion energies (and optionally forces).

        Reads ``species``, the graph ``inputs[self.graph_key]`` (``edge_src``,
        ``edge_dst``, ``distances``, ``switch`` or ``switch_raw``, and ``vec``
        when ``direct_forces_key`` is set) and the optional alchemical keys
        ``alch_group``, ``alch_vlambda``, ``alch_alpha`` and ``alch_m``.

        Returns:
            A shallow copy of ``inputs`` with the per-atom energy under
            ``energy_key`` (or the module name) and, if ``direct_forces_key``
            is set, per-atom force vectors under that key.
        """
        # Each data row is z1, z2, a1, b1, a2, b2, a3, b3 (layout implied by
        # the (3, 2) reshape below).
        path = str(pathlib.Path(__file__).parent.resolve()) + "/nlh_coeffs.dat"
        DATA_NLH = np.loadtxt(path, usecols=np.arange(0, 8), ndmin=2)
        # Look at both species columns so the table is large enough whatever
        # the pair ordering used in the data file.
        zmax = int(np.max(DATA_NLH[:, :2]))
        nz = zmax + 1
        # Symmetric nz x nz table, flattened: pair (z1, z2) -> row z1 + nz * z2.
        # Built in float64 and cast to the input dtype below, so float64 runs
        # keep the full precision of the data file.
        AB = np.zeros((nz * nz, 6), dtype=np.float64)
        for z1, z2, row in zip(
            DATA_NLH[:, 0].astype(int), DATA_NLH[:, 1].astype(int), DATA_NLH[:, 2:8]
        ):
            AB[z1 + nz * z2] = row
            AB[z2 + nz * z1] = row
        AB = AB.reshape(nz * nz, 3, 2)

        species = inputs["species"]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        rijs = graph["distances"]

        # coefficients (a1,a2,a3)
        CS = jnp.array(AB[:, :, 0], dtype=rijs.dtype)
        # exponents (b1,b2,b3)
        ALPHAS = jnp.array(AB[:, :, 1], dtype=rijs.dtype)

        if self.trainable:
            cfact = jnp.abs(self.param(
                "c_fact",
                lambda key: jnp.ones(CS.shape[1], dtype=CS.dtype),
            ))
            CS = CS * cfact[None, :]
            # Renormalize each pair's coefficients to sum to 1. Rows for pairs
            # missing from the data file are all zero. Dividing them by their
            # zero sum yields NaN, and NaN gradients for c_fact even when those
            # rows are never gathered, so divide them by 1 instead.
            norm = jnp.sum(CS, axis=1, keepdims=True)
            CS = CS / jnp.where(norm == 0, 1.0, norm)
            alphas_fact = jnp.abs(self.param(
                "alpha_fact",
                lambda key: jnp.ones(ALPHAS.shape[1], dtype=ALPHAS.dtype),
            ))
            ALPHAS = ALPHAS * alphas_fact[None, :]

        s12 = species[edge_src] + nz * species[edge_dst]
        cs = CS[s12]
        alphas = ALPHAS[s12]

        if "alch_group" in inputs:
            # Alchemical mode: edges between different alchemical groups get a
            # softened distance and a lambda-dependent switch.
            switch = graph["switch_raw"]
            lambda_v = inputs["alch_vlambda"]
            alch_group = inputs["alch_group"]
            alch_alpha = inputs.get("alch_alpha", 0.)
            alch_m = inputs.get("alch_m", 2)

            mask = alch_group[edge_src] == alch_group[edge_dst]

            rijs = jnp.where(mask, rijs, (rijs**2 + alch_alpha**2 * (1 - lambda_v))**0.5)
            # Smoothed schedule: maps 0 -> 0 and 1 -> 1 with zero slope at both ends.
            lambda_v = 0.5*(1-jnp.cos(jnp.pi*lambda_v))
            switch = jnp.where(
                mask,
                switch,
                (lambda_v**alch_m) * switch,
            )
        else:
            switch = graph["switch"]

        # Non-positive species get zero nuclear charge.
        Z = jnp.where(species > 0, species.astype(rijs.dtype), 0.0)
        phi = (cs * jnp.exp(-alphas*rijs[:, None])).sum(axis=-1)
        Zij = Z[edge_src]*Z[edge_dst]*switch

        ereppair = Zij * phi / rijs

        energy_unit = au.get_multiplier(self._energy_unit)
        # au.BOHR presumably converts Z_i Z_j / r with r in Å into Hartree; the
        # 0.5 presumably compensates for each pair appearing in both directions.
        erep_atomic = (energy_unit*0.5*au.BOHR)*jax.ops.segment_sum(ereppair, edge_src, species.shape[0])

        energy_key = self.energy_key if self.energy_key is not None else self.name
        output = {**inputs, energy_key: erep_atomic}

        if self.direct_forces_key is not None:
            # Analytic radial derivative of Zij * phi(r) / r.
            # NOTE(review): `switch` is folded into Zij and treated as constant,
            # so its radial derivative is not included -- confirm this is
            # intended near the cutoff.
            dphidr = -(alphas*cs * jnp.exp(-alphas*rijs[:, None])).sum(axis=-1)
            dedr = Zij * (dphidr/rijs - phi/(rijs**2))
            # Presumably graph["vec"] = r_dst - r_src, making this the force on
            # edge_src -- TODO confirm the sign convention.
            dedij = (dedr/rijs)[:, None] * graph["vec"]
            fi = (energy_unit*au.BOHR)*jax.ops.segment_sum(dedij, edge_src, species.shape[0])
            output[self.direct_forces_key] = fi

        return output
class RepulsionZBL(nn.Module):
    """Repulsion energy based on the Ziegler-Biersack-Littmark potential

    FID: REPULSION_ZBL

    For every edge (i, j) of the graph, with r_ij converted to bohr::

        x_ij   = r_ij * (Z_i**p + Z_j**p) / d
        phi(x) = sum_k c_k * exp(-alpha_k * x)      (the c_k sum to 0.5)
        E_ij   = Z_i * Z_j * phi(x_ij) / r_ij * switch_ij

    The edge energies are summed onto ``edge_src`` to give per-atom energies,
    which are then scaled by ``au.get_multiplier(_energy_unit)``.

    ### Reference
    J. F. Ziegler & J. P. Biersack, The Stopping and Range of Ions in Matter

    """

    _graphs_properties: Dict
    graph_key: str = "graph"
    """The key for the graph input."""
    energy_key: Optional[str] = None
    """The key for the output energy."""
    trainable: bool = True
    """Whether the parameters are trainable."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""
    # If True, regularize relative deviations sum((1 - param/ref)**2),
    # otherwise absolute deviations sum((ref - param)**2).
    proportional_regularization: bool = True
    # Screening length: 0.46850 divided by au.BOHR (bohr, assuming au.BOHR is Å/bohr).
    d: float = 0.46850/au.BOHR
    # Exponent applied to the atomic numbers in the reduced distance.
    p: float = 0.23
    # Exponents of the four screening terms.
    alphas: Sequence[float] = (3.19980, 0.94229, 0.40290, 0.20162)
    # Reference weights of the four screening terms (renormalized to sum to 0.5).
    cs: Sequence[float] = (0.18175273, 0.5098655, 0.28021213, 0.0281697)
    # Initial logits of the trainable weights; 0.5 * softmax(cs_logits) ~ normalized cs.
    cs_logits: Sequence[float] = (0.1130, 1.1445, 0.5459, -1.7514)

    FID: ClassVar[str] = "REPULSION_ZBL"

    @nn.compact
    def __call__(self, inputs):
        """Compute per-atom ZBL repulsion energies.

        Reads ``species``, the graph ``inputs[self.graph_key]`` (``edge_src``,
        ``edge_dst``, ``distances`` and ``switch`` or, in alchemical mode,
        ``switch_raw``) and the optional alchemical keys ``alch_group``,
        ``alch_vlambda``, ``alch_alpha`` and ``alch_m``.

        Returns:
            A shallow copy of ``inputs`` with the per-atom energy stored under
            ``energy_key`` (or the module name). When the parameters are
            trainable and the ``training`` flag is set, a scalar penalty is also
            stored under ``<key>_regularization`` (shape ``(1,)``).

        Raises:
            ValueError: if ``alphas``, ``cs`` or (when trainable) ``cs_logits``
                do not have exactly 4 entries.
        """
        species = inputs["species"]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]

        training = "training" in inputs.get("flags", {})

        # Distances divided by au.BOHR (bohr, assuming graph distances are in Å).
        rijs = graph["distances"] / au.BOHR

        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if len(self.alphas) != 4:
            raise ValueError("alphas must be a sequence of length 4")
        if len(self.cs) != 4:
            raise ValueError("cs must be a sequence of length 4")
        if self.trainable and len(self.cs_logits) != 4:
            # A length-1 cs_logits would otherwise broadcast silently.
            raise ValueError("cs_logits must be a sequence of length 4")

        # Reference values (also the regularization targets).
        d_ = self.d
        p_ = self.p
        alphas_ = np.array(self.alphas, dtype=rijs.dtype)
        cs_ = np.array(self.cs, dtype=rijs.dtype)
        cs_ = 0.5 * cs_ / np.sum(cs_)

        if self.trainable:
            # abs() keeps d, p and alphas non-negative whatever the raw parameter.
            d = jnp.abs(
                self.param(
                    "d",
                    lambda key, d: jnp.asarray(d, dtype=rijs.dtype),
                    d_,
                )
            )
            p = jnp.abs(
                self.param(
                    "p",
                    lambda key, p: jnp.asarray(p, dtype=rijs.dtype),
                    p_,
                )
            )
            # softmax keeps the weights positive, with their sum fixed to 0.5.
            cs = 0.5 * jax.nn.softmax(
                self.param(
                    "cs",
                    lambda key, cs: jnp.asarray(cs, dtype=rijs.dtype),
                    np.array(self.cs_logits, dtype=rijs.dtype),
                )
            )
            alphas = jnp.abs(
                self.param(
                    "alphas",
                    lambda key, alphas: jnp.asarray(alphas, dtype=rijs.dtype),
                    alphas_,
                )
            )

            if training:
                # Penalize drift of the parameters away from the reference values.
                if self.proportional_regularization:
                    reg = jnp.asarray(
                        ((1 - alphas / alphas_) ** 2).sum()
                        + ((1 - cs / cs_) ** 2).sum()
                        + (1 - p / p_) ** 2
                        + (1 - d / d_) ** 2
                    ).reshape(1)
                else:
                    reg = jnp.asarray(
                        ((alphas_ - alphas) ** 2).sum()
                        + ((cs_ - cs) ** 2).sum()
                        + (p_ - p) ** 2
                        + (d_ - d) ** 2
                    ).reshape(1)
        else:
            cs = jnp.asarray(cs_)
            alphas = jnp.asarray(alphas_)
            d = d_
            p = p_

        if "alch_group" in inputs:
            # Alchemical mode: edges between different alchemical groups get a
            # softened distance and a lambda-dependent switch.
            switch = graph["switch_raw"]
            lambda_v = inputs["alch_vlambda"]
            alch_group = inputs["alch_group"]
            alch_alpha = inputs.get("alch_alpha", 0.5)
            alch_m = inputs.get("alch_m", 2)

            mask = alch_group[edge_src] == alch_group[edge_dst]

            rijs = jnp.where(mask, rijs, (rijs**2 + alch_alpha**2 * (1 - lambda_v))**0.5)
            # Smoothed schedule: maps 0 -> 0 and 1 -> 1 with zero slope at both ends.
            lambda_v = 0.5*(1-jnp.cos(jnp.pi*lambda_v))
            switch = jnp.where(
                mask,
                switch,
                (lambda_v**alch_m) * switch,
            )
        else:
            switch = graph["switch"]

        # Non-positive species get zero nuclear charge.
        Z = jnp.where(species > 0, species.astype(rijs.dtype), 0.0)
        Zij = Z[edge_src]*Z[edge_dst]
        # Reduced distance x = r * (Z_i**p + Z_j**p) / d.
        Zp = Z**p / d
        x = rijs * (Zp[edge_src] + Zp[edge_dst])
        phi = (cs[None, :] * jnp.exp(-alphas[None, :] * x[:, None])).sum(axis=-1)

        ereppair = Zij * phi / rijs * switch

        erep_atomic = jax.ops.segment_sum(ereppair, edge_src, species.shape[0])

        energy_unit = au.get_multiplier(self._energy_unit)
        energy_key = self.energy_key if self.energy_key is not None else self.name
        output = {**inputs, energy_key: erep_atomic * energy_unit}
        if self.trainable and training:
            output[energy_key + "_regularization"] = reg

        return output
Repulsion energy based on the Ziegler-Biersack-Littmark potential
FID: REPULSION_ZBL
Reference
J. F. Ziegler & J. P. Biersack, The Stopping and Range of Ions in Matter
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.
class RepulsionNLH(nn.Module):
    """NLH pairwise repulsive potential with pair-specific coefficients up to Z=92

    FID: REPULSION_NLH

    Nordlund-Lehtola-Hobler (NLH) potential. For every edge (i, j), with r_ij
    taken directly from ``graph["distances"]``::

        phi(r) = sum_k a_k(Z_i, Z_j) * exp(-b_k(Z_i, Z_j) * r)     (k = 1..3)
        E_ij   = 0.5 * au.BOHR * Z_i * Z_j * switch_ij * phi(r_ij) / r_ij

    The pair coefficients (a_k, b_k) are read from ``nlh_coeffs.dat`` next to
    this file. The edge energies are summed onto ``edge_src`` and scaled by
    ``au.get_multiplier(_energy_unit)``.

    ### Reference
    K. Nordlund, S. Lehtola, G. Hobler, Repulsive interatomic potentials calculated at three levels of theory, Phys. Rev. A 111, 032818
    https://doi.org/10.1103/PhysRevA.111.032818
    """

    _graphs_properties: Dict
    graph_key: str = "graph"
    """The key for the graph input."""
    energy_key: Optional[str] = None
    """The key for the output energy."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""
    # If True, learn one global non-negative scale per term for the
    # coefficients (renormalized per pair) and for the exponents.
    trainable: bool = False
    # If set, also output analytic per-atom forces under this key.
    direct_forces_key: Optional[str] = None

    FID: ClassVar[str] = "REPULSION_NLH"

    @nn.compact
    def __call__(self, inputs):
        """Compute per-atom NLH repulsion energies (and optionally forces).

        Reads ``species``, the graph ``inputs[self.graph_key]`` (``edge_src``,
        ``edge_dst``, ``distances``, ``switch`` or ``switch_raw``, and ``vec``
        when ``direct_forces_key`` is set) and the optional alchemical keys
        ``alch_group``, ``alch_vlambda``, ``alch_alpha`` and ``alch_m``.

        Returns:
            A shallow copy of ``inputs`` with the per-atom energy under
            ``energy_key`` (or the module name) and, if ``direct_forces_key``
            is set, per-atom force vectors under that key.
        """
        # Each data row is z1, z2, a1, b1, a2, b2, a3, b3 (layout implied by
        # the (3, 2) reshape below).
        path = str(pathlib.Path(__file__).parent.resolve()) + "/nlh_coeffs.dat"
        DATA_NLH = np.loadtxt(path, usecols=np.arange(0, 8), ndmin=2)
        # Look at both species columns so the table is large enough whatever
        # the pair ordering used in the data file.
        zmax = int(np.max(DATA_NLH[:, :2]))
        nz = zmax + 1
        # Symmetric nz x nz table, flattened: pair (z1, z2) -> row z1 + nz * z2.
        # Built in float64 and cast to the input dtype below, so float64 runs
        # keep the full precision of the data file.
        AB = np.zeros((nz * nz, 6), dtype=np.float64)
        for z1, z2, row in zip(
            DATA_NLH[:, 0].astype(int), DATA_NLH[:, 1].astype(int), DATA_NLH[:, 2:8]
        ):
            AB[z1 + nz * z2] = row
            AB[z2 + nz * z1] = row
        AB = AB.reshape(nz * nz, 3, 2)

        species = inputs["species"]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        rijs = graph["distances"]

        # coefficients (a1,a2,a3)
        CS = jnp.array(AB[:, :, 0], dtype=rijs.dtype)
        # exponents (b1,b2,b3)
        ALPHAS = jnp.array(AB[:, :, 1], dtype=rijs.dtype)

        if self.trainable:
            cfact = jnp.abs(self.param(
                "c_fact",
                lambda key: jnp.ones(CS.shape[1], dtype=CS.dtype),
            ))
            CS = CS * cfact[None, :]
            # Renormalize each pair's coefficients to sum to 1. Rows for pairs
            # missing from the data file are all zero. Dividing them by their
            # zero sum yields NaN, and NaN gradients for c_fact even when those
            # rows are never gathered, so divide them by 1 instead.
            norm = jnp.sum(CS, axis=1, keepdims=True)
            CS = CS / jnp.where(norm == 0, 1.0, norm)
            alphas_fact = jnp.abs(self.param(
                "alpha_fact",
                lambda key: jnp.ones(ALPHAS.shape[1], dtype=ALPHAS.dtype),
            ))
            ALPHAS = ALPHAS * alphas_fact[None, :]

        s12 = species[edge_src] + nz * species[edge_dst]
        cs = CS[s12]
        alphas = ALPHAS[s12]

        if "alch_group" in inputs:
            # Alchemical mode: edges between different alchemical groups get a
            # softened distance and a lambda-dependent switch.
            switch = graph["switch_raw"]
            lambda_v = inputs["alch_vlambda"]
            alch_group = inputs["alch_group"]
            alch_alpha = inputs.get("alch_alpha", 0.)
            alch_m = inputs.get("alch_m", 2)

            mask = alch_group[edge_src] == alch_group[edge_dst]

            rijs = jnp.where(mask, rijs, (rijs**2 + alch_alpha**2 * (1 - lambda_v))**0.5)
            # Smoothed schedule: maps 0 -> 0 and 1 -> 1 with zero slope at both ends.
            lambda_v = 0.5*(1-jnp.cos(jnp.pi*lambda_v))
            switch = jnp.where(
                mask,
                switch,
                (lambda_v**alch_m) * switch,
            )
        else:
            switch = graph["switch"]

        # Non-positive species get zero nuclear charge.
        Z = jnp.where(species > 0, species.astype(rijs.dtype), 0.0)
        phi = (cs * jnp.exp(-alphas*rijs[:, None])).sum(axis=-1)
        Zij = Z[edge_src]*Z[edge_dst]*switch

        ereppair = Zij * phi / rijs

        energy_unit = au.get_multiplier(self._energy_unit)
        # au.BOHR presumably converts Z_i Z_j / r with r in Å into Hartree; the
        # 0.5 presumably compensates for each pair appearing in both directions.
        erep_atomic = (energy_unit*0.5*au.BOHR)*jax.ops.segment_sum(ereppair, edge_src, species.shape[0])

        energy_key = self.energy_key if self.energy_key is not None else self.name
        output = {**inputs, energy_key: erep_atomic}

        if self.direct_forces_key is not None:
            # Analytic radial derivative of Zij * phi(r) / r.
            # NOTE(review): `switch` is folded into Zij and treated as constant,
            # so its radial derivative is not included -- confirm this is
            # intended near the cutoff.
            dphidr = -(alphas*cs * jnp.exp(-alphas*rijs[:, None])).sum(axis=-1)
            dedr = Zij * (dphidr/rijs - phi/(rijs**2))
            # Presumably graph["vec"] = r_dst - r_src, making this the force on
            # edge_src -- TODO confirm the sign convention.
            dedij = (dedr/rijs)[:, None] * graph["vec"]
            fi = (energy_unit*au.BOHR)*jax.ops.segment_sum(dedij, edge_src, species.shape[0])
            output[self.direct_forces_key] = fi

        return output
Nordlund–Lehtola–Hobler (NLH) pairwise repulsive potential with pair-specific coefficients for atomic numbers up to Z = 92
FID: REPULSION_NLH
Reference
K. Nordlund, S. Lehtola, G. Hobler, Repulsive interatomic potentials calculated at three levels of theory, Phys. Rev. A 111, 032818 https://doi.org/10.1103/PhysRevA.111.032818
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.