fennol.models.physics.dispersion
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
from typing import Any, Dict, Union, Callable, Sequence, Optional, ClassVar
from ...utils import AtomicUnits as au
from ...utils.periodic_table import (
    D3_ELECTRONEGATIVITIES,
    D3_HARDNESSES,
    D3_VDW_RADII,
    D3_COV_RADII,
    D3_KAPPA,
    VDW_RADII,
    POLARIZABILITIES,
    C6_FREE,
    VALENCE_ELECTRONS,
)
import pathlib
import pickle


class VdwOQDO(nn.Module):
    """ Dispersion and exchange based on the Optimized Quantum Drude Oscillator model.

    FID : VDW_OQDO

    Pair parameters are obtained from tabulated free-atom C6 coefficients and
    polarizabilities (optionally rescaled by an atom-in-molecule volume ratio).
    Pair energies are evaluated on the edges of the input graph, multiplied by
    the graph switching function, and accumulated with a factor 0.5 on the
    source atom of each edge, giving per-atom energies.

    ### Reference
    A. Khabibrakhmanov, D. V. Fedorov, and A. Tkatchenko, Universal Pairwise Interatomic van der Waals Potentials Based on Quantum Drude Oscillators,
    J. Chem. Theory Comput. 2023, 19, 21, 7895–7907 (https://doi.org/10.1021/acs.jctc.3c00797)

    """
    graph_key: str = "graph"
    """ The key for the graph input."""
    include_exchange: bool = True
    """ Whether to compute the exchange part."""
    ratiovol_key: Optional[str] = None
    """ The key for the ratio between AIM volume and free-atom volume.
    If None, the volume ratio is assumed to be 1.0."""
    energy_key: Optional[str] = None
    """ The key for the output energy. If None, the name of the module is used."""
    damped: bool = True
    """ Whether to use short-range damping."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str] = "VDW_OQDO"

    @nn.compact
    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Compute the per-atom OQDO dispersion (and optionally exchange) energy.

        Args:
            inputs: dictionary with "species", the graph stored under `graph_key`
                (providing "edge_src", "edge_dst", "distances" and "switch"), and
                the volume ratios under `ratiovol_key` when that key is set.

        Returns:
            A copy of `inputs` with the per-atom energy stored under `energy_key`
            (or the module name if None). When `include_exchange` is True, the
            dispersion and exchange parts are also stored under
            "<key>_dispersion" and "<key>_exchange", and "<key>" holds their sum.
        """
        # multiplier applied to the energies below, which are computed in atomic units
        energy_unit = au.get_multiplier(self._energy_unit)

        species = inputs["species"]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        # distances in bohr so that all formulas below are in atomic units
        rij = graph["distances"] / au.BOHR
        switch = graph["switch"]

        # free-atom C6 coefficients and polarizabilities, looked up per atom
        c6 = jnp.asarray(C6_FREE)[species]
        alpha = jnp.asarray(POLARIZABILITIES)[species]

        if self.ratiovol_key is not None:
            # small offset, presumably to avoid an exactly vanishing volume ratio
            ratiovol = inputs[self.ratiovol_key] + 1.0e-6
            # drop a trailing singleton axis, e.g. (natoms, 1) -> (natoms,)
            if ratiovol.shape[-1] == 1:
                ratiovol = jnp.squeeze(ratiovol, axis=-1)
            # volume rescaling: C6 scales as V^2, alpha as V
            c6 = c6 * ratiovol**2
            alpha = alpha * ratiovol

        c6i, c6j = c6[edge_src], c6[edge_dst]
        alphai, alphaj = alpha[edge_src], alpha[edge_dst]

        # combination rules
        alphaij = 0.5 * (alphai + alphaj)
        c6ij = 2 * alphai * alphaj * c6i * c6j / (c6i * alphaj**2 + c6j * alphai**2)

        # equilibrium distance
        Re = (alphaij * (128.0 / au.FSC ** (4.0 / 3.0))) ** (1.0 / 7.0)
        Re2 = Re**2
        Re4 = Re**4
        # fit to largest root of eq (S33) of "Universal Pairwise Interatomic van der Waals Potentials Based On Quantum Drude Oscillators"
        # (a different rational fit is used with and without damping)
        if self.damped:
            muw = (
                4.83053463e-01
                - 3.76191669e-02 * Re
                + 1.27066988e-03 * Re2
                - 7.21940151e-07 * Re4
            ) / (3.84212120e-02 - 3.16915319e-02 * Re + 2.37410890e-02 * Re2)
        else:
            muw = (
                3.66316787e01
                - 5.79579187 * Re
                + 3.02674813e-01 * Re2
                - 3.65461255e-04 * Re4
            ) / (-1.46169102e01 + 7.32461225 * Re)

        # higher-order dispersion coefficients derived from c6ij and muw
        c8ij = 5 * c6ij / muw
        c10ij = 245 * c6ij / (8 * muw**2)

        if self.damped:
            # damping functions f_2n = 1 - exp(-z) * sum_{k=0}^{n} z^k / k!,
            # with z = muw * r^2 / 2
            z = 0.5 * muw * rij**2
            ez = jnp.exp(-z)
            f6 = 1.0 - ez * (1.0 + z + 0.5 * z**2 + (1.0 / 6.0) * z**3)
            f8 = f6 - (1.0 / 24.0) * ez * z**4
            f10 = f8 - (1.0 / 120.0) * ez * z**5
            epair = (
                f6 * c6ij / rij**6 + f8 * c8ij / rij**8 + f10 * c10ij / rij**10
            )
        else:
            epair = c6ij / rij**6 + c8ij / rij**8 + c10ij / rij**10

        # factor 0.5: each pair presumably appears in both directions in the graph
        edisp = (-0.5*energy_unit) * jax.ops.segment_sum(epair * switch, edge_src, species.shape[0])

        output_key = self.name if self.energy_key is None else self.energy_key

        if not self.include_exchange:
            return {**inputs, output_key: edisp}

        ### exchange
        # oscillator frequency from c6ij = (3/4) * alphaij^2 * w
        w = 4 * c6ij / (3 * alphaij**2)
        # q = (alphaij * mu*w**2)**0.5
        q2 = alphaij * muw * w
        if self.damped:
            # damping functions (same form as above) and their radial
            # derivatives, evaluated at the equilibrium distance Re
            ze = 0.5 * muw * Re2
            eze = jnp.exp(-ze)

            s6 = eze * (1.0 + ze + 0.5 * ze**2 + (1.0 / 6.0) * ze**3)
            f6e = 1.0 - s6
            muwRe = muw * Re
            df6e = muwRe * s6 - eze * (
                muwRe + 0.5 * Re * muwRe**2 + (1.0 / 8.0) * Re2 * muwRe**3
            )

            s8 = (1.0 / 24.0) * eze * ze**4
            f8e = f6e - s8
            df8e = df6e + muwRe * s8 - (1.0 / 48.0) * eze * Re2 * Re * muwRe**4

            s10 = (1.0 / 120.0) * eze * ze**5
            f10e = f8e - s10
            df10e = df8e + muwRe * s10 - (1.0 / 384.0) * eze * Re2 * Re2 * muwRe**5

            # exchange prefactor A built from the damped dispersion terms at Re
            den = 2 * c6ij * Re2 * (6 * f6e - Re * df6e)
            A = (
                0.5
                + c8ij * (8 * f8e - Re * df8e) / den
                + c10ij * (10 * f10e - Re * df10e) / (den * Re2)
            )
        else:
            # undamped case: limit of the expression above with f = 1, df = 0
            A = 0.5 + 2 * c8ij / (3 * c6ij * Re2) + 5 * c10ij / (6 * c6ij * Re4)
            ez = jnp.exp(-0.5 * muw * rij**2)
        # in the damped case `ez` = exp(-0.5 * muw * rij**2) comes from the
        # dispersion damping computed above

        exij = A * q2 * ez / rij
        ex = (0.5*energy_unit) * jax.ops.segment_sum(exij * switch, edge_src, species.shape[0])

        return {
            **inputs,
            output_key + "_dispersion": edisp,
            output_key + "_exchange": ex,
            output_key: edisp + ex,
        }


# Key under which the precomputed fallback weights are stored in the cached D3 data.
_D3_EXCEPTIONAL_WEIGHTS_KEY = "_EXCEPTIONAL_WEIGHTS"
# Per-process cache of the D3 reference tables, keyed by file path.
_DATA_D3_CACHE: Dict[str, Dict[str, Any]] = {}


def _load_d3_reference_data() -> Dict[str, Any]:
    """Return the D3 reference tables stored in `ref_data_d3.pkl` next to this module.

    The file is read on first use and memoized in `_DATA_D3_CACHE`, so repeated
    calls of `DispersionD3.__call__` do not re-read it from disk. On loading,
    one-hot fallback weights are precomputed and stored under
    `_D3_EXCEPTIONAL_WEIGHTS_KEY`. For each element, they select the reference
    with the largest reference CN.

    The returned dict is shared by all callers and must not be mutated.
    """
    path = str(pathlib.Path(__file__).parent.resolve()) + "/ref_data_d3.pkl"
    data = _DATA_D3_CACHE.get(path)
    if data is None:
        # pickle is acceptable here only because the file is trusted package data.
        with open(path, "rb") as f:
            data = dict(pickle.load(f))
        ref_cn = np.asarray(data["REF_CN"])
        exweight = np.zeros_like(ref_cn)
        exweight[np.arange(ref_cn.shape[0]), np.argmax(ref_cn, axis=1)] = 1.0
        data[_D3_EXCEPTIONAL_WEIGHTS_KEY] = exweight
        _DATA_D3_CACHE[path] = data
    return data


class DispersionD3(nn.Module):
    """ Grimme's D3 dispersion correction with Becke-Johnson (rational) damping.

    FID : DISPERSION_D3

    For each edge (i, j) of the graph:
    - C6ij is interpolated from tabulated reference values, using Gaussian
      weights of the atomic coordination numbers.
    - C8ij = 3 * C6ij * r4r2_i * r4r2_j.
    - The pair energy is -(s6 * C6ij / (r^6 + R0^6) + s8 * C8ij / (r^8 + R0^8)),
      with R0 = a1 * sqrt(C8ij / C6ij) + a2 (in bohr).

    Pair energies are multiplied by the graph switching function and
    accumulated with a factor 0.5 on the source atom of each edge.
    """

    # not used by this module
    _graphs_properties: Dict
    # key of the neighbor graph in the inputs
    graph_key: str = "graph"
    # key of the per-atom output energy; the module name is used if None
    output_key: Optional[str] = None
    # scaling of the C6 term
    s6: float = 1.0
    # scaling of the C8 term
    s8: float = 1.0
    # damping radius parameters: R0 = a1 * sqrt(C8 / C6) + a2 (a2 in bohr)
    a1: float = 0.4
    a2: float = 5.0
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str] = "DISPERSION_D3"

    @nn.compact
    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Compute the per-atom D3(BJ) dispersion energy.

        Args:
            inputs: dictionary with "species" and the graph under `graph_key`
                (providing "edge_src", "edge_dst", "distances" and "switch").

        Returns:
            A copy of `inputs` with the per-atom dispersion energy stored under
            `output_key` (or the module name if `output_key` is None).
        """
        # reference tables are read from disk only once per process
        DATA_D3 = _load_d3_reference_data()

        graph = inputs[self.graph_key]
        edge_src = graph["edge_src"]
        edge_dst = graph["edge_dst"]
        switch = graph["switch"]
        species = inputs["species"]
        natoms = species.shape[0]

        # distances in bohr; the lower clip avoids division by zero in rcij / rij
        rij = jnp.clip(graph["distances"] / au.BOHR, 1e-6, None)

        ## RADII (in BOHR)
        rcov = jnp.array(DATA_D3["COV_D3"])[species]
        r4r2 = jnp.array(DATA_D3["R4R2"])[species]

        rcij = rcov[edge_src] + rcov[edge_dst]

        ## COORDINATION NUMBER
        # smooth neighbour count: cn_i = sum_j sigmoid(KCN * (rcij / rij - 1))
        KCN = 16.0
        cnij = jax.nn.sigmoid(KCN * (rcij / rij - 1.0))
        cn = jax.ops.segment_sum(cnij, edge_src, natoms)

        ## WEIGHTS
        # Gaussian weights exp(-KW * (CN_ref - CN)^2), normalized over the valid
        # references. Negative reference CNs are excluded (presumably padding
        # for elements with fewer reference systems).
        refcn = jnp.array(DATA_D3["REF_CN"])[species]
        mask = refcn >= 0
        dcn = refcn - cn[:, None]
        KW = 4.0
        weights = jnp.where(mask, jnp.exp(-KW * dcn**2), 0.0)
        norm = weights.sum(axis=1, keepdims=True)
        weights = jnp.where(mask, weights / jnp.clip(norm, 1e-6, None), 0.0)

        ## correct for all null weights
        # if every weight underflowed, fall back to the reference with the largest CN
        exweight = jnp.array(DATA_D3[_D3_EXCEPTIONAL_WEIGHTS_KEY])[species]
        exceptional = norm < 1.0e-6
        weights = jnp.where(exceptional, exweight, weights)

        ## C6 coefficients
        # the reference table is reshaped to (nz * nz, nref, nref) so that each
        # edge can gather the table of its element pair with a single index
        REF_C6 = DATA_D3["REF_C6"]
        nz = REF_C6.shape[0]
        nref = REF_C6.shape[-1]
        REF_C6 = jnp.array(REF_C6.reshape((nz * nz, nref, nref)))
        pair_num = species[edge_src] * nz + species[edge_dst]
        rc6 = REF_C6[pair_num]
        # C6_ij = sum_ab w_i[a] * w_j[b] * C6_ref[a, b]
        c6 = jnp.einsum("iab,ia,ib->i", rc6, weights[edge_src], weights[edge_dst])

        ## DISPERSION
        qq = 3 * r4r2[edge_src] * r4r2[edge_dst]
        c8 = c6 * qq

        # damping radius R0 = a1 * sqrt(C8 / C6) + a2
        r0 = self.a1 * jnp.sqrt(qq) + self.a2

        t6 = self.s6 / (rij**6 + r0**6)
        t8 = self.s8 / (rij**8 + r0**8)

        energy_unit = au.get_multiplier(self._energy_unit)
        # factor 0.5: each pair presumably appears in both directions in the graph
        energy = (-0.5 * energy_unit) * jax.ops.segment_sum(
            (c6 * t6 + c8 * t8) * switch, edge_src, natoms
        )

        output_key = self.output_key if self.output_key is not None else self.name
        return {**inputs, output_key: energy}
class VdwOQDO(nn.Module):
    """ Dispersion and exchange based on the Optimized Quantum Drude Oscillator model.

    FID : VDW_OQDO

    Pair parameters are obtained from tabulated free-atom C6 coefficients and
    polarizabilities (optionally rescaled by an atom-in-molecule volume ratio).
    Pair energies are evaluated on the edges of the input graph, multiplied by
    the graph switching function, and accumulated with a factor 0.5 on the
    source atom of each edge, giving per-atom energies.

    ### Reference
    A. Khabibrakhmanov, D. V. Fedorov, and A. Tkatchenko, Universal Pairwise Interatomic van der Waals Potentials Based on Quantum Drude Oscillators,
    J. Chem. Theory Comput. 2023, 19, 21, 7895–7907 (https://doi.org/10.1021/acs.jctc.3c00797)

    """
    graph_key: str = "graph"
    """ The key for the graph input."""
    include_exchange: bool = True
    """ Whether to compute the exchange part."""
    ratiovol_key: Optional[str] = None
    """ The key for the ratio between AIM volume and free-atom volume.
    If None, the volume ratio is assumed to be 1.0."""
    energy_key: Optional[str] = None
    """ The key for the output energy. If None, the name of the module is used."""
    damped: bool = True
    """ Whether to use short-range damping."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str] = "VDW_OQDO"

    @nn.compact
    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Compute the per-atom OQDO dispersion (and optionally exchange) energy.

        Args:
            inputs: dictionary with "species", the graph stored under `graph_key`
                (providing "edge_src", "edge_dst", "distances" and "switch"), and
                the volume ratios under `ratiovol_key` when that key is set.

        Returns:
            A copy of `inputs` with the per-atom energy stored under `energy_key`
            (or the module name if None). When `include_exchange` is True, the
            dispersion and exchange parts are also stored under
            "<key>_dispersion" and "<key>_exchange", and "<key>" holds their sum.
        """
        # multiplier applied to the energies below, which are computed in atomic units
        energy_unit = au.get_multiplier(self._energy_unit)

        species = inputs["species"]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        # distances in bohr so that all formulas below are in atomic units
        rij = graph["distances"] / au.BOHR
        switch = graph["switch"]

        # free-atom C6 coefficients and polarizabilities, looked up per atom
        c6 = jnp.asarray(C6_FREE)[species]
        alpha = jnp.asarray(POLARIZABILITIES)[species]

        if self.ratiovol_key is not None:
            # small offset, presumably to avoid an exactly vanishing volume ratio
            ratiovol = inputs[self.ratiovol_key] + 1.0e-6
            # drop a trailing singleton axis, e.g. (natoms, 1) -> (natoms,)
            if ratiovol.shape[-1] == 1:
                ratiovol = jnp.squeeze(ratiovol, axis=-1)
            # volume rescaling: C6 scales as V^2, alpha as V
            c6 = c6 * ratiovol**2
            alpha = alpha * ratiovol

        c6i, c6j = c6[edge_src], c6[edge_dst]
        alphai, alphaj = alpha[edge_src], alpha[edge_dst]

        # combination rules
        alphaij = 0.5 * (alphai + alphaj)
        c6ij = 2 * alphai * alphaj * c6i * c6j / (c6i * alphaj**2 + c6j * alphai**2)

        # equilibrium distance
        Re = (alphaij * (128.0 / au.FSC ** (4.0 / 3.0))) ** (1.0 / 7.0)
        Re2 = Re**2
        Re4 = Re**4
        # fit to largest root of eq (S33) of "Universal Pairwise Interatomic van der Waals Potentials Based On Quantum Drude Oscillators"
        # (a different rational fit is used with and without damping)
        if self.damped:
            muw = (
                4.83053463e-01
                - 3.76191669e-02 * Re
                + 1.27066988e-03 * Re2
                - 7.21940151e-07 * Re4
            ) / (3.84212120e-02 - 3.16915319e-02 * Re + 2.37410890e-02 * Re2)
        else:
            muw = (
                3.66316787e01
                - 5.79579187 * Re
                + 3.02674813e-01 * Re2
                - 3.65461255e-04 * Re4
            ) / (-1.46169102e01 + 7.32461225 * Re)

        # higher-order dispersion coefficients derived from c6ij and muw
        c8ij = 5 * c6ij / muw
        c10ij = 245 * c6ij / (8 * muw**2)

        if self.damped:
            # damping functions f_2n = 1 - exp(-z) * sum_{k=0}^{n} z^k / k!,
            # with z = muw * r^2 / 2
            z = 0.5 * muw * rij**2
            ez = jnp.exp(-z)
            f6 = 1.0 - ez * (1.0 + z + 0.5 * z**2 + (1.0 / 6.0) * z**3)
            f8 = f6 - (1.0 / 24.0) * ez * z**4
            f10 = f8 - (1.0 / 120.0) * ez * z**5
            epair = (
                f6 * c6ij / rij**6 + f8 * c8ij / rij**8 + f10 * c10ij / rij**10
            )
        else:
            epair = c6ij / rij**6 + c8ij / rij**8 + c10ij / rij**10

        # factor 0.5: each pair presumably appears in both directions in the graph
        edisp = (-0.5*energy_unit) * jax.ops.segment_sum(epair * switch, edge_src, species.shape[0])

        output_key = self.name if self.energy_key is None else self.energy_key

        if not self.include_exchange:
            return {**inputs, output_key: edisp}

        ### exchange
        # oscillator frequency from c6ij = (3/4) * alphaij^2 * w
        w = 4 * c6ij / (3 * alphaij**2)
        # q = (alphaij * mu*w**2)**0.5
        q2 = alphaij * muw * w
        if self.damped:
            # damping functions (same form as above) and their radial
            # derivatives, evaluated at the equilibrium distance Re
            ze = 0.5 * muw * Re2
            eze = jnp.exp(-ze)

            s6 = eze * (1.0 + ze + 0.5 * ze**2 + (1.0 / 6.0) * ze**3)
            f6e = 1.0 - s6
            muwRe = muw * Re
            df6e = muwRe * s6 - eze * (
                muwRe + 0.5 * Re * muwRe**2 + (1.0 / 8.0) * Re2 * muwRe**3
            )

            s8 = (1.0 / 24.0) * eze * ze**4
            f8e = f6e - s8
            df8e = df6e + muwRe * s8 - (1.0 / 48.0) * eze * Re2 * Re * muwRe**4

            s10 = (1.0 / 120.0) * eze * ze**5
            f10e = f8e - s10
            df10e = df8e + muwRe * s10 - (1.0 / 384.0) * eze * Re2 * Re2 * muwRe**5

            # exchange prefactor A built from the damped dispersion terms at Re
            den = 2 * c6ij * Re2 * (6 * f6e - Re * df6e)
            A = (
                0.5
                + c8ij * (8 * f8e - Re * df8e) / den
                + c10ij * (10 * f10e - Re * df10e) / (den * Re2)
            )
        else:
            # undamped case: limit of the expression above with f = 1, df = 0
            A = 0.5 + 2 * c8ij / (3 * c6ij * Re2) + 5 * c10ij / (6 * c6ij * Re4)
            ez = jnp.exp(-0.5 * muw * rij**2)
        # in the damped case `ez` = exp(-0.5 * muw * rij**2) comes from the
        # dispersion damping computed above

        exij = A * q2 * ez / rij
        ex = (0.5*energy_unit) * jax.ops.segment_sum(exij * switch, edge_src, species.shape[0])

        return {
            **inputs,
            output_key + "_dispersion": edisp,
            output_key + "_exchange": ex,
            output_key: edisp + ex,
        }
Dispersion and exchange based on the Optimized Quantum Drude Oscillator model.
FID : VDW_OQDO
Reference
A. Khabibrakhmanov, D. V. Fedorov, and A. Tkatchenko, Universal Pairwise Interatomic van der Waals Potentials Based on Quantum Drude Oscillators, J. Chem. Theory Comput. 2023, 19, 21, 7895–7907 (https://doi.org/10.1021/acs.jctc.3c00797)
The key for the ratio between the atom-in-molecule (AIM) volume and the free-atom volume. If None, the volume ratio is assumed to be 1.0.
The key for the output energy. If None, the name of the module is used.
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.
# Key under which the precomputed fallback weights are stored in the cached D3 data.
_D3_EXCEPTIONAL_WEIGHTS_KEY = "_EXCEPTIONAL_WEIGHTS"
# Per-process cache of the D3 reference tables, keyed by file path.
_DATA_D3_CACHE: Dict[str, Dict[str, Any]] = {}


def _load_d3_reference_data() -> Dict[str, Any]:
    """Return the D3 reference tables stored in `ref_data_d3.pkl` next to this module.

    The file is read on first use and memoized in `_DATA_D3_CACHE`, so repeated
    calls of `DispersionD3.__call__` do not re-read it from disk. On loading,
    one-hot fallback weights are precomputed and stored under
    `_D3_EXCEPTIONAL_WEIGHTS_KEY`. For each element, they select the reference
    with the largest reference CN.

    The returned dict is shared by all callers and must not be mutated.
    """
    path = str(pathlib.Path(__file__).parent.resolve()) + "/ref_data_d3.pkl"
    data = _DATA_D3_CACHE.get(path)
    if data is None:
        # pickle is acceptable here only because the file is trusted package data.
        with open(path, "rb") as f:
            data = dict(pickle.load(f))
        ref_cn = np.asarray(data["REF_CN"])
        exweight = np.zeros_like(ref_cn)
        exweight[np.arange(ref_cn.shape[0]), np.argmax(ref_cn, axis=1)] = 1.0
        data[_D3_EXCEPTIONAL_WEIGHTS_KEY] = exweight
        _DATA_D3_CACHE[path] = data
    return data


class DispersionD3(nn.Module):
    """ Grimme's D3 dispersion correction with Becke-Johnson (rational) damping.

    FID : DISPERSION_D3

    For each edge (i, j) of the graph:
    - C6ij is interpolated from tabulated reference values, using Gaussian
      weights of the atomic coordination numbers.
    - C8ij = 3 * C6ij * r4r2_i * r4r2_j.
    - The pair energy is -(s6 * C6ij / (r^6 + R0^6) + s8 * C8ij / (r^8 + R0^8)),
      with R0 = a1 * sqrt(C8ij / C6ij) + a2 (in bohr).

    Pair energies are multiplied by the graph switching function and
    accumulated with a factor 0.5 on the source atom of each edge.
    """

    # not used by this module
    _graphs_properties: Dict
    # key of the neighbor graph in the inputs
    graph_key: str = "graph"
    # key of the per-atom output energy; the module name is used if None
    output_key: Optional[str] = None
    # scaling of the C6 term
    s6: float = 1.0
    # scaling of the C8 term
    s8: float = 1.0
    # damping radius parameters: R0 = a1 * sqrt(C8 / C6) + a2 (a2 in bohr)
    a1: float = 0.4
    a2: float = 5.0
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str] = "DISPERSION_D3"

    @nn.compact
    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Compute the per-atom D3(BJ) dispersion energy.

        Args:
            inputs: dictionary with "species" and the graph under `graph_key`
                (providing "edge_src", "edge_dst", "distances" and "switch").

        Returns:
            A copy of `inputs` with the per-atom dispersion energy stored under
            `output_key` (or the module name if `output_key` is None).
        """
        # reference tables are read from disk only once per process
        DATA_D3 = _load_d3_reference_data()

        graph = inputs[self.graph_key]
        edge_src = graph["edge_src"]
        edge_dst = graph["edge_dst"]
        switch = graph["switch"]
        species = inputs["species"]
        natoms = species.shape[0]

        # distances in bohr; the lower clip avoids division by zero in rcij / rij
        rij = jnp.clip(graph["distances"] / au.BOHR, 1e-6, None)

        ## RADII (in BOHR)
        rcov = jnp.array(DATA_D3["COV_D3"])[species]
        r4r2 = jnp.array(DATA_D3["R4R2"])[species]

        rcij = rcov[edge_src] + rcov[edge_dst]

        ## COORDINATION NUMBER
        # smooth neighbour count: cn_i = sum_j sigmoid(KCN * (rcij / rij - 1))
        KCN = 16.0
        cnij = jax.nn.sigmoid(KCN * (rcij / rij - 1.0))
        cn = jax.ops.segment_sum(cnij, edge_src, natoms)

        ## WEIGHTS
        # Gaussian weights exp(-KW * (CN_ref - CN)^2), normalized over the valid
        # references. Negative reference CNs are excluded (presumably padding
        # for elements with fewer reference systems).
        refcn = jnp.array(DATA_D3["REF_CN"])[species]
        mask = refcn >= 0
        dcn = refcn - cn[:, None]
        KW = 4.0
        weights = jnp.where(mask, jnp.exp(-KW * dcn**2), 0.0)
        norm = weights.sum(axis=1, keepdims=True)
        weights = jnp.where(mask, weights / jnp.clip(norm, 1e-6, None), 0.0)

        ## correct for all null weights
        # if every weight underflowed, fall back to the reference with the largest CN
        exweight = jnp.array(DATA_D3[_D3_EXCEPTIONAL_WEIGHTS_KEY])[species]
        exceptional = norm < 1.0e-6
        weights = jnp.where(exceptional, exweight, weights)

        ## C6 coefficients
        # the reference table is reshaped to (nz * nz, nref, nref) so that each
        # edge can gather the table of its element pair with a single index
        REF_C6 = DATA_D3["REF_C6"]
        nz = REF_C6.shape[0]
        nref = REF_C6.shape[-1]
        REF_C6 = jnp.array(REF_C6.reshape((nz * nz, nref, nref)))
        pair_num = species[edge_src] * nz + species[edge_dst]
        rc6 = REF_C6[pair_num]
        # C6_ij = sum_ab w_i[a] * w_j[b] * C6_ref[a, b]
        c6 = jnp.einsum("iab,ia,ib->i", rc6, weights[edge_src], weights[edge_dst])

        ## DISPERSION
        qq = 3 * r4r2[edge_src] * r4r2[edge_dst]
        c8 = c6 * qq

        # damping radius R0 = a1 * sqrt(C8 / C6) + a2
        r0 = self.a1 * jnp.sqrt(qq) + self.a2

        t6 = self.s6 / (rij**6 + r0**6)
        t8 = self.s8 / (rij**8 + r0**8)

        energy_unit = au.get_multiplier(self._energy_unit)
        # factor 0.5: each pair presumably appears in both directions in the graph
        energy = (-0.5 * energy_unit) * jax.ops.segment_sum(
            (c6 * t6 + c8 * t8) * switch, edge_src, natoms
        )

        output_key = self.output_key if self.output_key is not None else self.name
        return {**inputs, output_key: energy}
Grimme's D3 dispersion correction with Becke–Johnson (rational) damping.
FID : DISPERSION_D3
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.