fennol.models.physics.electrostatics
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np

from typing import Any, Dict, Union, Callable, Sequence, Optional, ClassVar
from ...utils import AtomicUnits as au
import dataclasses
from ...utils.periodic_table import (
    D3_ELECTRONEGATIVITIES,
    D3_HARDNESSES,
    D3_VDW_RADII,
    D3_COV_RADII,
    D3_KAPPA,
    VDW_RADII,
    POLARIZABILITIES,
    VALENCE_ELECTRONS,
)
import math


def _squeeze_last(x):
    """Drop a trailing singleton axis, e.g. ``(nat, 1) -> (nat,)``.

    Only arrays of rank >= 2 are squeezed, so a rank-1 array of shape ``(1,)``
    (e.g. a single atom) is returned unchanged instead of collapsing to a scalar.
    """
    if x.ndim > 1 and x.shape[-1] == 1:
        return jnp.squeeze(x, axis=-1)
    return x


def prepare_reciprocal_space(
    cells, reciprocal_cells, coordinates, batch_index, k_points, bewald
):
    """Prepare variables for Ewald summation in reciprocal space.

    Shapes below are read from the einsum subscripts; confirm against the
    graph builder that produces ``k_points``.

    Args:
        cells: per-system cell matrices; only their determinant (volume) is used.
        reciprocal_cells: per-system reciprocal cell matrices ``(nsys, 3, 3)``.
        coordinates: atomic positions ``(nat, 3)``.
        batch_index: system index of each atom ``(nat,)``.
        k_points: k-points ``(nsys, nk, 3)``.
        bewald: Ewald splitting parameter.

    Returns:
        ``(batch_index, k_points, phiscale, selfscale, expfac, ks)``: the
        positional arguments (after ``q``) of :func:`ewald_reciprocal`.
    """
    A = reciprocal_cells
    if A.shape[0] == 1:
        s = coordinates @ A[0]
        ks = 2j * jnp.pi * jnp.einsum("ai,ki-> ak", s, k_points[0])  # nat x nk
    else:
        s = jnp.einsum("aj,aji->ai", coordinates, A[batch_index])
        ks = (
            2j * jnp.pi * jnp.einsum("ai,aki-> ak", s, k_points[batch_index])
        )  # nat x nk

    # squared norm of the Cartesian reciprocal vectors
    m2 = jnp.sum(
        jnp.einsum("ski,sji->skj", k_points, A) ** 2,
        axis=-1,
    )  # nsys x nk
    a2 = (jnp.pi / bewald) ** 2
    expfac = jnp.exp(-a2 * m2) / m2  # nsys x nk

    volume = jnp.abs(jnp.linalg.det(cells))  # nsys
    phiscale = (au.BOHR / jnp.pi) / volume
    selfscale = bewald * (2 * au.BOHR / jnp.pi**0.5)
    return batch_index, k_points, phiscale, selfscale, expfac, ks


def ewald_reciprocal(q, batch_index, k_points, phiscale, selfscale, expfac, ks):
    """Compute Coulomb interactions in reciprocal space using Ewald summation.

    Args:
        q: atomic charges ``(nat,)``.
        batch_index, k_points, phiscale, selfscale, expfac, ks: output of
            :func:`prepare_reciprocal_space`.

    Returns:
        ``(energy, phi)``: per-atom reciprocal energy ``0.5 * q * phi`` and the
        per-atom reciprocal potential ``phi`` (including the Ewald self term).
    """
    if phiscale.shape[0] == 1:
        Sm = (q[:, None] * jnp.exp(ks)).sum(axis=0)[None, :]  # nsys x nk
    else:
        Sm = jax.ops.segment_sum(
            q[:, None] * jnp.exp(ks), batch_index, k_points.shape[0]
        )  # nsys x nk

    ### compute reciprocal Coulomb potential (https://arxiv.org/abs/1805.10363)
    phi = (
        jnp.real(((Sm * expfac)[batch_index] * jnp.exp(-ks)).sum(axis=-1))
        * phiscale[batch_index]
        - q * selfscale
    )

    return 0.5 * q * phi, phi


class Coulomb(nn.Module):
    """Coulomb interaction between distributed point charges

    FID: COULOMB

    """

    _graphs_properties: Dict
    graph_key: str = "graph"
    """Key for the graph in the inputs"""
    charges_key: str = "charges"
    """Key for the charges in the inputs"""
    energy_key: Optional[str] = None
    """Key for the energy in the outputs"""
    scale: Optional[float] = None
    """Scaling factor for the energy"""
    charge_scale: Optional[float] = None
    """Scaling factor for the charges"""
    damp_style: Optional[str] = None
    """Damping style. Available options are: None, 'TS', 'OQDO', 'D3', 'SPOOKY', 'CP', 'KEY'"""
    damp_params: Dict = dataclasses.field(default_factory=dict)
    """Damping parameters"""
    bscreen: float = -1.0
    """Screening parameter. If >0, the Coulomb potential becomes a Yukawa potential and the reciprocal space is not computed"""
    trainable: bool = True
    """Whether the parameters are trainable"""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str] = "COULOMB"

    @nn.compact
    def __call__(self, inputs):
        species = inputs["species"]
        nat = species.shape[0]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        distances = graph["distances"]
        switch = graph["switch"]

        # distances converted with au.BOHR; `distances` itself is still used by
        # the screening/Ewald factors and some damping functions
        rij = distances / au.BOHR
        q = _squeeze_last(inputs[self.charges_key])
        if self.charge_scale is not None:
            if self.trainable:
                charge_scale = jnp.abs(
                    self.param(
                        "charge_scale", lambda key: jnp.asarray(self.charge_scale)
                    )
                )
            else:
                charge_scale = self.charge_scale
            q = q * charge_scale

        damp_style = self.damp_style.upper() if self.damp_style is not None else None

        # reciprocal-space Ewald sum only without screening and when k-points exist
        do_recip = self.bscreen <= 0.0 and "k_points" in graph

        if self.bscreen > 0.0:
            # Yukawa screening of the direct-space interaction
            dirfact = jnp.exp(-self.bscreen * distances)
        elif do_recip:
            bewald = graph["b_ewald"]
            erec, _ = ewald_reciprocal(
                q,
                *prepare_reciprocal_space(
                    inputs["cells"],
                    inputs["reciprocal_cells"],
                    inputs["coordinates"],
                    inputs["batch_index"],
                    graph["k_points"],
                    bewald,
                ),
            )
            # direct-space part of the Ewald split
            dirfact = jax.scipy.special.erfc(bewald * distances)
        else:
            dirfact = 1.0

        # Each damping style defines a pair kernel Aij that is contracted with
        # q_i q_j below, except "CP" which builds the per-atom energy directly.
        eat = None
        if damp_style is None:
            Aij = switch * dirfact / rij

        elif damp_style == "TS":
            cpB = self.damp_params.get("cpB", 3.5)
            s = self.damp_params.get("s", 2.4)

            if self.trainable:
                cpB = jnp.abs(self.param("cpB", lambda key: jnp.asarray(cpB)))
                s = jnp.abs(self.param("s", lambda key: jnp.asarray(s)))

            rvdw = jnp.asarray(VDW_RADII)[species]
            ratiovol_key = self.damp_params.get("ratiovol_key", None)
            if ratiovol_key is not None:
                ratiovol = _squeeze_last(inputs[ratiovol_key])
                rvdw = rvdw * ratiovol ** (1.0 / 3.0)
            Rij = rvdw[edge_src] + rvdw[edge_dst]
            Bij = cpB * (rij / Rij) ** s

            # exp(-Bij) treated as zero once Bij >= 20
            eBij = jnp.where(Bij < 20.0, jnp.exp(-Bij), 0.0)
            Aij = (dirfact - eBij) / rij * switch

        elif damp_style == "OQDO":
            ratiovol_key = self.damp_params.get("ratiovol_key", None)
            alpha = jnp.asarray(POLARIZABILITIES)[species]
            if ratiovol_key is not None:
                ratiovol = _squeeze_last(inputs[ratiovol_key])
                alpha = alpha * ratiovol

            alphaij = 0.5 * (alpha[edge_src] + alpha[edge_dst])
            Re = (alphaij * (128.0 / au.FSC ** (4.0 / 3.0))) ** (1.0 / 7.0)
            Re2 = Re**2
            Re4 = Re**4
            muw = (
                3.66316787e01
                - 5.79579187 * Re
                + 3.02674813e-01 * Re2
                - 3.65461255e-04 * Re4
            ) / (-1.46169102e01 + 7.32461225 * Re)
            Bij = 0.5 * muw * rij**2

            eBij = jnp.where(Bij < 20.0, jnp.exp(-Bij), 0.0)
            Aij = (dirfact - eBij) / rij * switch

        elif damp_style == "D3":
            ratiovol_key = self.damp_params.get("ratiovol_key", None)
            if ratiovol_key is not None:
                ratiovol = _squeeze_last(inputs[ratiovol_key])
            else:
                ratiovol = 1.0

            gamma_scheme = self.damp_params.get("gamma_scheme", "D3")
            if gamma_scheme == "D3":
                if self.trainable:
                    rvdw = jnp.abs(
                        self.param("rvdw", lambda key: jnp.asarray(VDW_RADII))
                    )[species]
                else:
                    rvdw = jnp.asarray(VDW_RADII)[species]
                rvdw = rvdw * ratiovol ** (1.0 / 3.0)
                ai2 = rvdw**2
                gamma_ij = (ai2[edge_src] + ai2[edge_dst] + 1.0e-3) ** (-0.5)

            elif gamma_scheme == "QDO":
                gscale = self.damp_params.get("gamma_scale", 2.0)
                if self.trainable:
                    gscale = jnp.abs(
                        self.param("gamma_scale", lambda key: jnp.asarray(gscale))
                    )
                    alpha = jnp.abs(
                        self.param("alpha", lambda key: jnp.asarray(POLARIZABILITIES))
                    )[species]
                else:
                    alpha = jnp.asarray(POLARIZABILITIES)[species]
                alpha = alpha * ratiovol
                alphaij = 0.5 * (alpha[edge_src] + alpha[edge_dst])
                rvdwij = (alphaij * (128.0 / au.FSC ** (4.0 / 3.0))) ** (1.0 / 7.0)
                gamma_ij = gscale / rvdwij
            else:
                raise NotImplementedError(
                    f"gamma_scheme {gamma_scheme} not implemented"
                )

            Aij = (dirfact - jax.scipy.special.erfc(gamma_ij * rij)) / rij * switch

        elif damp_style == "SPOOKY":
            shortrange_cutoff = self.damp_params.get("shortrange_cutoff", 5.0)
            r_on = self.damp_params.get("r_on", 0.25) * shortrange_cutoff
            r_off = self.damp_params.get("r_off", 0.75) * shortrange_cutoff
            # smooth switch Bij going from 1 (r <= r_on) to 0 (r >= r_off);
            # masked values avoid evaluating exp(-1/x) at x <= 0
            x1 = (distances - r_on) / (r_off - r_on)
            x2 = 1.0 - x1
            mask1 = x1 <= 1.0e-6
            mask2 = x2 <= 1.0e-6
            x1 = jnp.where(mask1, 1.0, x1)
            x2 = jnp.where(mask2, 1.0, x2)
            s1 = jnp.where(mask1, 0.0, jnp.exp(-1.0 / x1))
            s2 = jnp.where(mask2, 0.0, jnp.exp(-1.0 / x2))
            Bij = s2 / (s1 + s2)

            Aij = Bij / (rij**2 + 1) ** 0.5 + (dirfact - Bij) / rij * switch

        elif damp_style == "CP":
            cpA = self.damp_params.get("cpA", 4.42)
            cpB = self.damp_params.get("cpB", 4.12)
            gamma = self.damp_params.get("gamma", 0.5)

            if self.trainable:
                cpA = jnp.abs(self.param("cpA", lambda key: jnp.asarray(cpA)))
                cpB = jnp.abs(self.param("cpB", lambda key: jnp.asarray(cpB)))
                gamma = jnp.abs(self.param("gamma", lambda key: jnp.asarray(gamma)))

            rvdw = jnp.asarray(VDW_RADII)[species]
            ratiovol_key = self.damp_params.get("ratiovol_key", None)
            if ratiovol_key is not None:
                ratiovol = _squeeze_last(inputs[ratiovol_key])
                rvdw = rvdw * ratiovol ** (1.0 / 3.0)

            Zv = jnp.asarray(VALENCE_ELECTRONS)[species]
            Zi, Zj = Zv[edge_src], Zv[edge_dst]
            qi, qj = q[edge_src], q[edge_dst]
            rvdwi, rvdwj = rvdw[edge_src], rvdw[edge_dst]

            eAi = jnp.exp(-cpA * rij / rvdwi)
            eAj = jnp.exp(-cpA * rij / rvdwj)
            eBi = jnp.exp(-cpB * rij / rvdwi)
            eBj = jnp.exp(-cpB * rij / rvdwj)
            eBij = eBi * eBj - eBi - eBj

            Bshort = jnp.exp(-gamma * distances**4)
            Dshort = 1.0 - Bshort
            ecp = Dshort * (
                Zi * Zj * (eAi + eAj + eBij)
                - qi * Zj * (eAi + eBij)
                - qj * Zi * (eAj + eBij)
            )
            eqq = qi * qj * (dirfact - Bshort + eBij * Dshort)

            epair = (ecp + eqq) * switch / rij
            eat = 0.5 * jax.ops.segment_sum(epair, edge_src, nat)

        elif damp_style == "KEY":
            # user-provided per-pair damping read from inputs; contracted with
            # the charges like every other damping style
            damp = inputs[self.damp_params["key"]]
            Aij = (dirfact - damp) * switch / rij
        else:
            raise NotImplementedError(f"damp_style {self.damp_style} not implemented")

        if eat is None:
            eat = 0.5 * q * jax.ops.segment_sum(Aij * q[edge_dst], edge_src, nat)

        if do_recip:
            eat = eat + erec

        if self.scale is not None:
            if self.trainable:
                scale = jnp.abs(
                    self.param("scale", lambda key: jnp.asarray(self.scale))
                )
            else:
                scale = self.scale
            eat = eat * scale

        energy_key = self.energy_key if self.energy_key is not None else self.name
        energy_unit = au.get_multiplier(self._energy_unit)
        out = {**inputs, energy_key: eat * energy_unit}
        if do_recip:
            out[energy_key + "_reciprocal"] = erec * energy_unit
        return out


class QeqD4(nn.Module):
    """QEq-D4 charge equilibration scheme

    FID: QEQ_D4

    ### Reference
    E. Caldeweyher et al.,A generally applicable atomic-charge dependent London dispersion correction,
    J Chem Phys. 2019 Apr 21;150(15):154122. (https://doi.org/10.1063/1.5090222)
    """

    graph_key: str = "graph"
    """Key for the graph in the inputs"""
    trainable: bool = False
    """Whether the parameters are trainable"""
    charges_key: str = "charges"
    """Key for the charges in the outputs.
    If charges are provided in the inputs, they are not re-optimized and we only compute the energy"""
    energy_key: Optional[str] = None
    """Key for the energy in the outputs"""
    chi_key: Optional[str] = None
    """Key for additional electronegativity in the inputs"""
    c3_key: Optional[str] = None
    """Key for additional c3 in the inputs. Only used if charges are provided in the inputs"""
    c4_key: Optional[str] = None
    """Key for additional c4 in the inputs. Only used if charges are provided in the inputs"""
    total_charge_key: str = "total_charge"
    """Key for the total charge in the inputs"""
    non_interacting_guess: bool = False
    """Whether to use the non-interacting limit as an initial guess."""
    solver: str = "gmres"
    """Iterative linear solver: 'gmres', 'bicg' or 'cg' (cg not recommended)"""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str] = "QEQ_D4"

    @nn.compact
    def __call__(self, inputs):
        species = inputs["species"]
        nat = species.shape[0]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        switch = graph["switch"]
        training = "training" in inputs.get("flags", {})

        rij = graph["distances"] / au.BOHR

        do_recip = "k_points" in graph
        if do_recip:
            bewald = graph["b_ewald"]
            dirfact = jax.scipy.special.erfc(bewald * graph["distances"])
            ewald_params = prepare_reciprocal_space(
                inputs["cells"],
                inputs["reciprocal_cells"],
                inputs["coordinates"],
                inputs["batch_index"],
                graph["k_points"],
                bewald,
            )
        else:
            dirfact = 1.0

        # default diagonal term: eta = J_ii + sqrt(2/pi) / a_i
        ETA = [
            jii + (2.0 / np.pi) ** 0.5 / aii
            for jii, aii in zip(D3_HARDNESSES, D3_VDW_RADII)
        ]

        # D3 parameters
        if self.trainable:
            ENi = self.param("EN", lambda key: jnp.asarray(D3_ELECTRONEGATIVITIES))[
                species
            ]
            eta = jnp.abs(self.param("eta", lambda key: jnp.asarray(ETA)))[species]
            ai = jnp.abs(self.param("a", lambda key: jnp.asarray(D3_VDW_RADII)))[
                species
            ]
            rci = jnp.abs(self.param("rc", lambda key: jnp.asarray(D3_COV_RADII)))[
                species
            ]
            c3 = self.param("c3", lambda key: jnp.zeros(len(D3_ELECTRONEGATIVITIES)))[
                species
            ]
            c4 = jnp.abs(
                self.param("c4", lambda key: jnp.zeros(len(D3_ELECTRONEGATIVITIES)))[
                    species
                ]
            )
            kappai = self.param("kappa", lambda key: jnp.asarray(D3_KAPPA))[species]
            k1 = jnp.abs(self.param("k1", lambda key: jnp.asarray(7.5)))
            if training:
                # penalize drift of the trainable parameters from their defaults
                regularization = (
                    (ENi - jnp.asarray(D3_ELECTRONEGATIVITIES)[species]) ** 2
                    + (eta - jnp.asarray(ETA)[species]) ** 2
                    + (ai - jnp.asarray(D3_VDW_RADII)[species]) ** 2
                    + (rci - jnp.asarray(D3_COV_RADII)[species]) ** 2
                    + (kappai - jnp.asarray(D3_KAPPA)[species]) ** 2
                    + (k1 - 7.5) ** 2
                )
        else:
            c3 = jnp.zeros_like(species, dtype=jnp.float32)
            c4 = jnp.zeros_like(species, dtype=jnp.float32)
            ENi = jnp.asarray(D3_ELECTRONEGATIVITIES)[species]
            eta = jnp.asarray(ETA)[species]
            ai = jnp.asarray(D3_VDW_RADII)[species]
            rci = jnp.asarray(D3_COV_RADII)[species]
            kappai = jnp.asarray(D3_KAPPA)[species]
            k1 = 7.5

        ai2 = ai**2
        rcij = (
            rci.at[edge_src].get(mode="fill", fill_value=1.0)
            + rci.at[edge_dst].get(mode="fill", fill_value=1.0)
            + 1.0e-3
        )
        # coordination-number dependent electronegativity
        mCNij = 1.0 + jax.scipy.special.erf(-k1 * (rij / rcij - 1))
        mCNi = 0.5 * jax.ops.segment_sum(mCNij * switch, edge_src, nat)
        chi = ENi - kappai * (mCNi + 1.0e-3) ** 0.5
        if self.chi_key is not None:
            chi = chi + inputs[self.chi_key]

        gamma_ij = (
            ai2.at[edge_src].get(mode="fill", fill_value=1.0)
            + ai2.at[edge_dst].get(mode="fill", fill_value=1.0)
            + 1.0e-3
        ) ** (-0.5)

        Aii = eta
        Aij = (dirfact - jax.scipy.special.erfc(gamma_ij * rij)) / rij * switch

        if self.charges_key in inputs:
            q = _squeeze_last(inputs[self.charges_key])
            q_ = q
        else:
            nsys = inputs["natoms"].shape[0]
            batch_index = inputs["batch_index"]

            def matvec(x):
                # x = [lambda (nsys,), q (nat,)]; returns [per-system sum(q), A q + lambda]
                l, q = jnp.split(x, (nsys,))
                Aq_self = Aii * q
                qdest = q.at[edge_dst].get(mode="fill", fill_value=0.0)
                Aq_pair = jax.ops.segment_sum(Aij * qdest, edge_src, nat)
                Aq = (
                    Aq_self
                    + Aq_pair
                    + l.at[batch_index].get(mode="fill", fill_value=0.0)
                )
                if do_recip:
                    _, phirec = ewald_reciprocal(q, *ewald_params)
                    Aq = Aq + phirec
                Al = jax.ops.segment_sum(q, batch_index, nsys)
                return jnp.concatenate((Al, Aq))

            Qtot = (
                inputs[self.total_charge_key].astype(chi.dtype).reshape(nsys)
                if self.total_charge_key in inputs
                else jnp.zeros(nsys, dtype=chi.dtype)
            )
            b = jnp.concatenate([Qtot, -chi])

            if self.non_interacting_guess:
                # Exact solution without pair interactions:
                # q = -(chi + l) / Aii, with l fixed by sum(q) = Qtot per system.
                si = 1.0 / Aii
                q0 = -chi * si
                qtot = jax.ops.segment_sum(q0, batch_index, nsys)
                sisum = jax.ops.segment_sum(si, batch_index, nsys)
                l0 = (qtot - Qtot) / sisum
                q0 = q0 - si * l0[batch_index]
                x0 = jnp.concatenate((l0, q0))
            else:
                x0 = None

            solver = self.solver.lower()
            if solver == "bicg":
                x = jax.scipy.sparse.linalg.bicgstab(matvec, b, x0=x0)[0]
            elif solver == "gmres":
                x = jax.scipy.sparse.linalg.gmres(matvec, b, x0=x0)[0]
            elif solver == "cg":
                print("Warning: Use of cg solver for Qeq is not recommended")
                x = jax.scipy.sparse.linalg.cg(matvec, b, x0=x0)[0]
            else:
                raise NotImplementedError(
                    f"solver '{solver}' is not implemented. Choose one of [bicg, gmres, cg]"
                )

            q = x[nsys:]
            # the energy below is evaluated at fixed (equilibrated) charges
            q_ = jax.lax.stop_gradient(q)

        eself = 0.5 * Aii * q_**2 + chi * q_

        phi = jax.ops.segment_sum(Aij * q_[edge_dst], edge_src, nat)
        if do_recip:
            erec, _ = ewald_reciprocal(q_, *ewald_params)
        epair = 0.5 * q_ * phi

        if self.charges_key in inputs:
            if self.c3_key is not None:
                c3 = c3 + inputs[self.c3_key]
            if self.c4_key is not None:
                c4 = c4 + inputs[self.c4_key]
            eself = eself + c3 * q_**3 + c4 * q_**4
            if self.trainable and training:
                # gradients flow only through the provided charges here
                Aii_ = jax.lax.stop_gradient(Aii)
                chi_ = jax.lax.stop_gradient(chi)
                c3_ = jax.lax.stop_gradient(c3)
                c4_ = jax.lax.stop_gradient(c4)
                switch_ = jax.lax.stop_gradient(switch)
                Aij_ = jax.lax.stop_gradient(Aij)
                phi_ = jax.ops.segment_sum(Aij_ * q[edge_dst], edge_src, nat)

                dedq = Aii_ * q + chi_ + phi_ + 3 * c3_ * q**2 + 4 * c4_ * q**3
                dedq = jax.ops.segment_sum(
                    switch_ * (dedq[edge_src] - dedq[edge_dst]) ** 2,
                    edge_src,
                    nat,
                )
                etrain = (
                    0.5 * Aii_ * q**2
                    + chi_ * q
                    + 0.5 * q * phi_
                    + c3_ * q**3
                    + c4_ * q**4
                )
                if do_recip:
                    etrain = etrain + erec

        energy = eself + epair
        if do_recip:
            energy = energy + erec

        energy_key = self.energy_key if self.energy_key is not None else self.name
        energy_unit = au.get_multiplier(self._energy_unit)
        output = {
            **inputs,
            self.charges_key: q,
            energy_key: energy * energy_unit,
        }
        if do_recip:
            output[energy_key + "_reciprocal"] = erec * energy_unit

        if self.charges_key in inputs and self.trainable and training:
            output[energy_key + "_regularization"] = regularization
            output[energy_key + "_dedq"] = dedq * energy_unit
            output[energy_key + "_etrain"] = etrain * energy_unit
        return output


class ChargeCorrection(nn.Module):
    """Charge correction scheme

    FID: CHARGE_CORRECTION

    Used to correct the provided charges to sum to the total charge of the system.
    """

    key: str = "charges"
    """Key for the charges in the inputs"""
    output_key: Optional[str] = None
    """Key for the corrected charges in the outputs. If None, it is the same as the input key"""
    dq_key: str = "delta_qtot"
    """Key for the deviation of the raw charge sum in the outputs"""
    ratioeta_key: Optional[str] = None
    """Key for the ratio of hardness between AIM and free atom in the inputs. Used to adjust charge redistribution."""
    trainable: bool = False
    """Whether the parameters are trainable"""
    cn_key: Optional[str] = None
    """Key for the coordination number in the inputs. Used to adjust charge redistribution."""
    total_charge_key: str = "total_charge"
    """Key for the total charge in the inputs"""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str] = "CHARGE_CORRECTION"

    @nn.compact
    def __call__(self, inputs) -> Any:
        species = inputs["species"]
        batch_index = inputs["batch_index"]
        nsys = inputs["natoms"].shape[0]
        q = _squeeze_last(inputs[self.key])
        qtot = jax.ops.segment_sum(q, batch_index, nsys)
        # flatten so a (nsys, 1) total charge does not broadcast to (nsys, nsys)
        Qtot = (
            inputs[self.total_charge_key].astype(q.dtype).reshape(-1)
            if self.total_charge_key in inputs
            else jnp.zeros(nsys, dtype=q.dtype)
        )
        dq = Qtot - qtot

        eta = [
            jii + (2.0 / np.pi) ** 0.5 / aii
            for jii, aii in zip(D3_HARDNESSES, D3_VDW_RADII)
        ]
        if self.trainable:
            eta = jnp.abs(self.param("eta", lambda key: jnp.asarray(eta)))[species]
        else:
            eta = jnp.asarray(eta)[species]

        if self.ratioeta_key is not None:
            eta = eta * _squeeze_last(inputs[self.ratioeta_key])

        # redistribution weights ~ 1 / (2 eta) (softer atoms take more charge)
        s = (1.0e-6 + 2 * jnp.abs(eta)) ** (-1)
        if self.cn_key is not None:
            s = s * _squeeze_last(inputs[self.cn_key])

        f = dq / jax.ops.segment_sum(s, batch_index, nsys)
        qf = q + s * f[batch_index]

        energy_unit = au.get_multiplier(self._energy_unit)
        ecorr = (0.5 * energy_unit) * eta * (qf - q) ** 2
        output_key = self.output_key if self.output_key is not None else self.key
        return {
            **inputs,
            output_key: qf,
            self.dq_key: dq,
            "charge_correction_energy": ecorr,
        }


class DistributeElectrons(nn.Module):
    """Distribute valence electrons between the atoms

    FID: DISTRIBUTE_ELECTRONS

    Used to predict charges that sum to the total charge of the system.
    """

    embedding_key: str
    """Key for the embedding in the inputs that is used to predict an 'electron affinity' weight"""
    output_key: Union[str, None] = None
    """Key for the charges in the outputs"""
    total_charge_key: str = "total_charge"
    """Key for the total charge in the inputs"""

    FID: ClassVar[str] = "DISTRIBUTE_ELECTRONS"

    @nn.compact
    def __call__(self, inputs) -> Any:
        species = inputs["species"]
        Nel = jnp.asarray(VALENCE_ELECTRONS)[species]

        # positive per-atom weights predicted from the embedding
        ei = nn.Dense(1, use_bias=True, name="wi")(inputs[self.embedding_key]).squeeze(-1)
        wi = jax.nn.softplus(ei)

        batch_index = inputs["batch_index"]
        nsys = inputs["natoms"].shape[0]
        wtot = jax.ops.segment_sum(wi, batch_index, nsys)

        # flatten so a (nsys, 1) total charge does not broadcast to (nsys, nsys)
        Qtot = (
            inputs[self.total_charge_key].astype(ei.dtype).reshape(-1)
            if self.total_charge_key in inputs
            else jnp.zeros(nsys, dtype=ei.dtype)
        )
        Neltot = jax.ops.segment_sum(Nel, batch_index, nsys) - Qtot

        # share the system's electrons proportionally to the weights
        f = Neltot / wtot
        Ni = wi * f[batch_index]
        q = Nel - Ni

        output_key = self.output_key if self.output_key is not None else self.name
        return {
            **inputs,
            output_key: q,
        }
def prepare_reciprocal_space(
    cells, reciprocal_cells, coordinates, batch_index, k_points, bewald
):
    """Precompute the quantities reused by ``ewald_reciprocal``.

    Shapes are inferred from the einsum subscripts (confirm with the caller):
    ``reciprocal_cells`` is ``(nsys, 3, 3)``, ``k_points`` is ``(nsys, nk, 3)``,
    ``coordinates`` is ``(nat, 3)`` and ``batch_index`` maps atoms to systems.

    Returns:
        ``(batch_index, k_points, phiscale, selfscale, expfac, ks)`` in the
        positional order expected by ``ewald_reciprocal`` after ``q``.
    """
    if reciprocal_cells.shape[0] == 1:
        # one system: the same reciprocal cell and k-points for every atom
        frac = coordinates @ reciprocal_cells[0]
        ks = 2j * jnp.pi * jnp.einsum("ai,ki-> ak", frac, k_points[0])
    else:
        frac = jnp.einsum("aj,aji->ai", coordinates, reciprocal_cells[batch_index])
        ks = 2j * jnp.pi * jnp.einsum("ai,aki-> ak", frac, k_points[batch_index])

    # squared norms of the Cartesian reciprocal vectors, one row per system
    kvec = jnp.einsum("ski,sji->skj", k_points, reciprocal_cells)
    m2 = jnp.sum(kvec**2, axis=-1)
    gauss = (jnp.pi / bewald) ** 2
    expfac = jnp.exp(-gauss * m2) / m2

    vol = jnp.abs(jnp.linalg.det(cells))
    phiscale = (au.BOHR / jnp.pi) / vol
    selfscale = bewald * (2 * au.BOHR / jnp.pi**0.5)
    return batch_index, k_points, phiscale, selfscale, expfac, ks
Prepare variables for Ewald summation in reciprocal space
def ewald_reciprocal(q, batch_index, k_points, phiscale, selfscale, expfac, ks):
    """Reciprocal-space Ewald contribution to the Coulomb interaction.

    Args:
        q: per-atom charges.
        batch_index, k_points, phiscale, selfscale, expfac, ks: the tuple
            returned by ``prepare_reciprocal_space``.

    Returns:
        ``(energy, phi)`` where ``phi`` is the per-atom reciprocal potential
        (self term ``q * selfscale`` removed) and ``energy = 0.5 * q * phi``.
    """
    weighted = q[:, None] * jnp.exp(ks)  # nat x nk
    if phiscale.shape[0] != 1:
        structure = jax.ops.segment_sum(weighted, batch_index, k_points.shape[0])
    else:
        structure = weighted.sum(axis=0)[None, :]  # 1 x nk

    # reciprocal Coulomb potential (https://arxiv.org/abs/1805.10363)
    per_atom = (structure * expfac)[batch_index] * jnp.exp(-ks)
    phi = jnp.real(per_atom.sum(axis=-1)) * phiscale[batch_index] - q * selfscale

    return 0.5 * q * phi, phi
Compute Coulomb interactions in reciprocal space using Ewald summation
class Coulomb(nn.Module):
    """Coulomb interaction between distributed point charges

    FID: COULOMB

    """

    _graphs_properties: Dict
    graph_key: str = "graph"
    """Key for the graph in the inputs"""
    charges_key: str = "charges"
    """Key for the charges in the inputs"""
    energy_key: Optional[str] = None
    """Key for the energy in the outputs"""
    scale: Optional[float] = None
    """Scaling factor for the energy"""
    charge_scale: Optional[float] = None
    """Scaling factor for the charges"""
    damp_style: Optional[str] = None
    """Damping style. Available options are: None, 'TS', 'OQDO', 'D3', 'SPOOKY', 'CP', 'KEY'"""
    damp_params: Dict = dataclasses.field(default_factory=dict)
    """Damping parameters"""
    bscreen: float = -1.0
    """Screening parameter. If >0, the Coulomb potential becomes a Yukawa potential and the reciprocal space is not computed"""
    trainable: bool = True
    """Whether the parameters are trainable"""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str] = "COULOMB"

    @nn.compact
    def __call__(self, inputs):
        def squeeze_last(x):
            # (nat, 1) -> (nat,); rank-1 arrays (incl. shape (1,)) are kept as is
            if x.ndim > 1 and x.shape[-1] == 1:
                return jnp.squeeze(x, axis=-1)
            return x

        species = inputs["species"]
        nat = species.shape[0]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        distances = graph["distances"]
        switch = graph["switch"]

        # distances converted with au.BOHR; `distances` itself is still used by
        # the screening/Ewald factors and some damping functions
        rij = distances / au.BOHR
        q = squeeze_last(inputs[self.charges_key])
        if self.charge_scale is not None:
            if self.trainable:
                charge_scale = jnp.abs(
                    self.param(
                        "charge_scale", lambda key: jnp.asarray(self.charge_scale)
                    )
                )
            else:
                charge_scale = self.charge_scale
            q = q * charge_scale

        damp_style = self.damp_style.upper() if self.damp_style is not None else None

        # reciprocal-space Ewald sum only without screening and when k-points exist
        do_recip = self.bscreen <= 0.0 and "k_points" in graph

        if self.bscreen > 0.0:
            # Yukawa screening of the direct-space interaction
            dirfact = jnp.exp(-self.bscreen * distances)
        elif do_recip:
            bewald = graph["b_ewald"]
            erec, _ = ewald_reciprocal(
                q,
                *prepare_reciprocal_space(
                    inputs["cells"],
                    inputs["reciprocal_cells"],
                    inputs["coordinates"],
                    inputs["batch_index"],
                    graph["k_points"],
                    bewald,
                ),
            )
            # direct-space part of the Ewald split
            dirfact = jax.scipy.special.erfc(bewald * distances)
        else:
            dirfact = 1.0

        # Each damping style defines a pair kernel Aij that is contracted with
        # q_i q_j below, except "CP" which builds the per-atom energy directly.
        eat = None
        if damp_style is None:
            Aij = switch * dirfact / rij

        elif damp_style == "TS":
            cpB = self.damp_params.get("cpB", 3.5)
            s = self.damp_params.get("s", 2.4)

            if self.trainable:
                cpB = jnp.abs(self.param("cpB", lambda key: jnp.asarray(cpB)))
                s = jnp.abs(self.param("s", lambda key: jnp.asarray(s)))

            rvdw = jnp.asarray(VDW_RADII)[species]
            ratiovol_key = self.damp_params.get("ratiovol_key", None)
            if ratiovol_key is not None:
                ratiovol = squeeze_last(inputs[ratiovol_key])
                rvdw = rvdw * ratiovol ** (1.0 / 3.0)
            Rij = rvdw[edge_src] + rvdw[edge_dst]
            Bij = cpB * (rij / Rij) ** s

            # exp(-Bij) treated as zero once Bij >= 20
            eBij = jnp.where(Bij < 20.0, jnp.exp(-Bij), 0.0)
            Aij = (dirfact - eBij) / rij * switch

        elif damp_style == "OQDO":
            ratiovol_key = self.damp_params.get("ratiovol_key", None)
            alpha = jnp.asarray(POLARIZABILITIES)[species]
            if ratiovol_key is not None:
                ratiovol = squeeze_last(inputs[ratiovol_key])
                alpha = alpha * ratiovol

            alphaij = 0.5 * (alpha[edge_src] + alpha[edge_dst])
            Re = (alphaij * (128.0 / au.FSC ** (4.0 / 3.0))) ** (1.0 / 7.0)
            Re2 = Re**2
            Re4 = Re**4
            muw = (
                3.66316787e01
                - 5.79579187 * Re
                + 3.02674813e-01 * Re2
                - 3.65461255e-04 * Re4
            ) / (-1.46169102e01 + 7.32461225 * Re)
            Bij = 0.5 * muw * rij**2

            eBij = jnp.where(Bij < 20.0, jnp.exp(-Bij), 0.0)
            Aij = (dirfact - eBij) / rij * switch

        elif damp_style == "D3":
            ratiovol_key = self.damp_params.get("ratiovol_key", None)
            if ratiovol_key is not None:
                ratiovol = squeeze_last(inputs[ratiovol_key])
            else:
                ratiovol = 1.0

            gamma_scheme = self.damp_params.get("gamma_scheme", "D3")
            if gamma_scheme == "D3":
                if self.trainable:
                    rvdw = jnp.abs(
                        self.param("rvdw", lambda key: jnp.asarray(VDW_RADII))
                    )[species]
                else:
                    rvdw = jnp.asarray(VDW_RADII)[species]
                rvdw = rvdw * ratiovol ** (1.0 / 3.0)
                ai2 = rvdw**2
                gamma_ij = (ai2[edge_src] + ai2[edge_dst] + 1.0e-3) ** (-0.5)

            elif gamma_scheme == "QDO":
                gscale = self.damp_params.get("gamma_scale", 2.0)
                if self.trainable:
                    gscale = jnp.abs(
                        self.param("gamma_scale", lambda key: jnp.asarray(gscale))
                    )
                    alpha = jnp.abs(
                        self.param("alpha", lambda key: jnp.asarray(POLARIZABILITIES))
                    )[species]
                else:
                    alpha = jnp.asarray(POLARIZABILITIES)[species]
                alpha = alpha * ratiovol
                alphaij = 0.5 * (alpha[edge_src] + alpha[edge_dst])
                rvdwij = (alphaij * (128.0 / au.FSC ** (4.0 / 3.0))) ** (1.0 / 7.0)
                gamma_ij = gscale / rvdwij
            else:
                raise NotImplementedError(
                    f"gamma_scheme {gamma_scheme} not implemented"
                )

            Aij = (dirfact - jax.scipy.special.erfc(gamma_ij * rij)) / rij * switch

        elif damp_style == "SPOOKY":
            shortrange_cutoff = self.damp_params.get("shortrange_cutoff", 5.0)
            r_on = self.damp_params.get("r_on", 0.25) * shortrange_cutoff
            r_off = self.damp_params.get("r_off", 0.75) * shortrange_cutoff
            # smooth switch Bij going from 1 (r <= r_on) to 0 (r >= r_off);
            # masked values avoid evaluating exp(-1/x) at x <= 0
            x1 = (distances - r_on) / (r_off - r_on)
            x2 = 1.0 - x1
            mask1 = x1 <= 1.0e-6
            mask2 = x2 <= 1.0e-6
            x1 = jnp.where(mask1, 1.0, x1)
            x2 = jnp.where(mask2, 1.0, x2)
            s1 = jnp.where(mask1, 0.0, jnp.exp(-1.0 / x1))
            s2 = jnp.where(mask2, 0.0, jnp.exp(-1.0 / x2))
            Bij = s2 / (s1 + s2)

            Aij = Bij / (rij**2 + 1) ** 0.5 + (dirfact - Bij) / rij * switch

        elif damp_style == "CP":
            cpA = self.damp_params.get("cpA", 4.42)
            cpB = self.damp_params.get("cpB", 4.12)
            gamma = self.damp_params.get("gamma", 0.5)

            if self.trainable:
                cpA = jnp.abs(self.param("cpA", lambda key: jnp.asarray(cpA)))
                cpB = jnp.abs(self.param("cpB", lambda key: jnp.asarray(cpB)))
                gamma = jnp.abs(self.param("gamma", lambda key: jnp.asarray(gamma)))

            rvdw = jnp.asarray(VDW_RADII)[species]
            ratiovol_key = self.damp_params.get("ratiovol_key", None)
            if ratiovol_key is not None:
                ratiovol = squeeze_last(inputs[ratiovol_key])
                rvdw = rvdw * ratiovol ** (1.0 / 3.0)

            Zv = jnp.asarray(VALENCE_ELECTRONS)[species]
            Zi, Zj = Zv[edge_src], Zv[edge_dst]
            qi, qj = q[edge_src], q[edge_dst]
            rvdwi, rvdwj = rvdw[edge_src], rvdw[edge_dst]

            eAi = jnp.exp(-cpA * rij / rvdwi)
            eAj = jnp.exp(-cpA * rij / rvdwj)
            eBi = jnp.exp(-cpB * rij / rvdwi)
            eBj = jnp.exp(-cpB * rij / rvdwj)
            eBij = eBi * eBj - eBi - eBj

            Bshort = jnp.exp(-gamma * distances**4)
            Dshort = 1.0 - Bshort
            ecp = Dshort * (
                Zi * Zj * (eAi + eAj + eBij)
                - qi * Zj * (eAi + eBij)
                - qj * Zi * (eAj + eBij)
            )
            eqq = qi * qj * (dirfact - Bshort + eBij * Dshort)

            epair = (ecp + eqq) * switch / rij
            eat = 0.5 * jax.ops.segment_sum(epair, edge_src, nat)

        elif damp_style == "KEY":
            # user-provided per-pair damping read from inputs; contracted with
            # the charges like every other damping style
            damp = inputs[self.damp_params["key"]]
            Aij = (dirfact - damp) * switch / rij
        else:
            raise NotImplementedError(f"damp_style {self.damp_style} not implemented")

        if eat is None:
            eat = 0.5 * q * jax.ops.segment_sum(Aij * q[edge_dst], edge_src, nat)

        if do_recip:
            eat = eat + erec

        if self.scale is not None:
            if self.trainable:
                scale = jnp.abs(
                    self.param("scale", lambda key: jnp.asarray(self.scale))
                )
            else:
                scale = self.scale
            eat = eat * scale

        energy_key = self.energy_key if self.energy_key is not None else self.name
        energy_unit = au.get_multiplier(self._energy_unit)
        out = {**inputs, energy_key: eat * energy_unit}
        if do_recip:
            out[energy_key + "_reciprocal"] = erec * energy_unit
        return out
Coulomb interaction between distributed point charges
FID: COULOMB
Damping style. Available options are: None, 'TS', 'OQDO', 'D3', 'SPOOKY', 'CP', 'KEY'
Screening parameter. If > 0, the Coulomb potential becomes a screened (Yukawa) potential and the reciprocal-space contribution is not computed.
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.
class QeqD4(nn.Module):
    """QEq-D4 charge equilibration scheme

    FID: QEQ_D4

    For each system, the charges ``q`` and one Lagrange multiplier ``l`` are
    obtained by iteratively solving the linear system

        sum_i q_i                              = Q_tot
        A_ii q_i + sum_j A_ij q_j + l          = -chi_i

    where ``A_ii`` is the atomic hardness, ``A_ij`` the damped Coulomb kernel
    (with an Ewald reciprocal-space term when ``k_points`` are in the graph) and
    ``chi`` a coordination-number-dependent electronegativity.

    ### Reference
    E. Caldeweyher et al., A generally applicable atomic-charge dependent London dispersion correction,
    J Chem Phys. 2019 Apr 21;150(15):154122. (https://doi.org/10.1063/1.5090222)
    """

    graph_key: str = "graph"
    """Key for the graph in the inputs"""
    trainable: bool = False
    """Whether the parameters are trainable"""
    charges_key: str = "charges"
    """Key for the charges in the outputs.
    If charges are provided in the inputs, they are not re-optimized and we only compute the energy"""
    energy_key: Optional[str] = None
    """Key for the energy in the outputs"""
    chi_key: Optional[str] = None
    """Key for additional electronegativity in the inputs"""
    c3_key: Optional[str] = None
    """Key for additional c3 in the inputs. Only used if charges are provided in the inputs"""
    c4_key: Optional[str] = None
    """Key for additional c4 in the inputs. Only used if charges are provided in the inputs"""
    total_charge_key: str = "total_charge"
    """Key for the total charge in the inputs"""
    non_interacting_guess: bool = False
    """Whether to use the non-interacting limit as an initial guess."""
    solver: str = "gmres"
    """Iterative linear solver used for the equilibration: 'gmres', 'bicg' or 'cg'."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str] = "QEQ_D4"

    @nn.compact
    def __call__(self, inputs):
        species = inputs["species"]
        nat = species.shape[0]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        switch = graph["switch"]
        training = "training" in inputs.get("flags", {})

        # distances rescaled by au.BOHR (presumably Angstrom -> Bohr; TODO confirm)
        rij = graph["distances"] / au.BOHR

        do_recip = "k_points" in graph
        if do_recip:
            k_points = graph["k_points"]
            bewald = graph["b_ewald"]
            cells = inputs["cells"]
            reciprocal_cells = inputs["reciprocal_cells"]
            batch_index = inputs["batch_index"]
            # Ewald splitting: the direct-space kernel keeps only the erfc part
            dirfact = jax.scipy.special.erfc(bewald * graph["distances"])
            ewald_params = prepare_reciprocal_space(
                cells,
                reciprocal_cells,
                inputs["coordinates"],
                batch_index,
                k_points,
                bewald,
            )
        else:
            dirfact = 1.0

        # reference diagonal hardness: eta = J + sqrt(2/pi) / a
        ETA = [
            jii + (2.0 / np.pi) ** 0.5 / aii
            for jii, aii in zip(D3_HARDNESSES, D3_VDW_RADII)
        ]

        # D3 parameters
        if self.trainable:
            ENi = self.param("EN", lambda key: jnp.asarray(D3_ELECTRONEGATIVITIES))[
                species
            ]
            eta = jnp.abs(self.param("eta", lambda key: jnp.asarray(ETA)))[species]
            ai = jnp.abs(self.param("a", lambda key: jnp.asarray(D3_VDW_RADII)))[
                species
            ]
            rci = jnp.abs(self.param("rc", lambda key: jnp.asarray(D3_COV_RADII)))[
                species
            ]
            c3 = self.param("c3", lambda key: jnp.zeros(len(D3_ELECTRONEGATIVITIES)))[
                species
            ]
            c4 = jnp.abs(
                self.param("c4", lambda key: jnp.zeros(len(D3_ELECTRONEGATIVITIES)))[
                    species
                ]
            )
            kappai = self.param("kappa", lambda key: jnp.asarray(D3_KAPPA))[species]
            k1 = jnp.abs(self.param("k1", lambda key: jnp.asarray(7.5)))
            if training:
                # penalize drift of the trainable parameters from their reference values
                regularization = (
                    (ENi - jnp.asarray(D3_ELECTRONEGATIVITIES)[species]) ** 2
                    + (eta - jnp.asarray(ETA)[species]) ** 2
                    + (ai - jnp.asarray(D3_VDW_RADII)[species]) ** 2
                    + (rci - jnp.asarray(D3_COV_RADII)[species]) ** 2
                    + (kappai - jnp.asarray(D3_KAPPA)[species]) ** 2
                    + (k1 - 7.5) ** 2
                )
        else:
            c3 = jnp.zeros_like(species, dtype=jnp.float32)
            c4 = jnp.zeros_like(species, dtype=jnp.float32)
            ENi = jnp.asarray(D3_ELECTRONEGATIVITIES)[species]
            eta = jnp.asarray(ETA)[species]
            ai = jnp.asarray(D3_VDW_RADII)[species]
            rci = jnp.asarray(D3_COV_RADII)[species]
            kappai = jnp.asarray(D3_KAPPA)[species]
            k1 = 7.5

        # coordination-number-dependent electronegativity
        ai2 = ai**2
        rcij = (
            rci.at[edge_src].get(mode="fill", fill_value=1.0)
            + rci.at[edge_dst].get(mode="fill", fill_value=1.0)
            + 1.0e-3
        )
        mCNij = 1.0 + jax.scipy.special.erf(-k1 * (rij / rcij - 1))
        mCNi = 0.5 * jax.ops.segment_sum(mCNij * switch, edge_src, nat)
        chi = ENi - kappai * (mCNi + 1.0e-3) ** 0.5
        if self.chi_key is not None:
            chi = chi + inputs[self.chi_key]

        # Gaussian-charge damping width per pair
        gamma_ij = (
            ai2.at[edge_src].get(mode="fill", fill_value=1.0)
            + ai2.at[edge_dst].get(mode="fill", fill_value=1.0)
            + 1.0e-3
        ) ** (-0.5)

        Aii = eta
        Aij = (dirfact - jax.scipy.special.erfc(gamma_ij * rij)) / rij * switch

        if self.charges_key in inputs:
            q = inputs[self.charges_key]
            q_ = q
        else:
            nsys = inputs["natoms"].shape[0]
            batch_index = inputs["batch_index"]

            def matvec(x):
                # x = [l (nsys), q (nat)]; returns [sum_i q_i, A q + l]
                l, q = jnp.split(x, (nsys,))
                Aq_self = Aii * q
                qdest = q.at[edge_dst].get(mode="fill", fill_value=0.0)
                Aq_pair = jax.ops.segment_sum(Aij * qdest, edge_src, nat)
                Aq = (
                    Aq_self
                    + Aq_pair
                    + l.at[batch_index].get(mode="fill", fill_value=0.0)
                )
                if do_recip:
                    _, phirec = ewald_reciprocal(q, *ewald_params)
                    Aq = Aq + phirec
                Al = jax.ops.segment_sum(q, batch_index, nsys)
                return jnp.concatenate((Al, Aq))

            Qtot = (
                inputs[self.total_charge_key].astype(chi.dtype).reshape(nsys)
                if self.total_charge_key in inputs
                else jnp.zeros(nsys, dtype=chi.dtype)
            )
            b = jnp.concatenate([Qtot, -chi])

            if self.non_interacting_guess:
                # Without pair terms: q_i = -(chi_i + l) * s_i with s_i = 1/A_ii.
                # Enforcing sum_i q_i = Qtot gives l = (sum_i -chi_i s_i - Qtot) / sum_i s_i.
                si = 1.0 / Aii
                q0 = -chi * si
                qtot = jax.ops.segment_sum(q0, batch_index, nsys)
                sisum = jax.ops.segment_sum(si, batch_index, nsys)
                l0 = (qtot - Qtot) / sisum
                q0 = q0 - si * l0[batch_index]
                x0 = jnp.concatenate((l0, q0))
            else:
                x0 = None

            solver = self.solver.lower()
            if solver == "bicg":
                x = jax.scipy.sparse.linalg.bicgstab(matvec, b, x0=x0)[0]
            elif solver == "gmres":
                x = jax.scipy.sparse.linalg.gmres(matvec, b, x0=x0)[0]
            elif solver == "cg":
                print("Warning: Use of cg solver for Qeq is not recommended")
                x = jax.scipy.sparse.linalg.cg(matvec, b, x0=x0)[0]
            else:
                raise NotImplementedError(
                    f"solver '{solver}' is not implemented. Choose one of [bicg, gmres, cg]"
                )

            q = x[nsys:]
            # the energy is evaluated at fixed (equilibrated) charges
            q_ = jax.lax.stop_gradient(q)

        eself = 0.5 * Aii * q_**2 + chi * q_

        phi = jax.ops.segment_sum(Aij * q_[edge_dst], edge_src, nat)
        if do_recip:
            erec, _ = ewald_reciprocal(q_, *ewald_params)
        epair = 0.5 * q_ * phi

        if self.charges_key in inputs:
            if self.c3_key is not None:
                c3 = c3 + inputs[self.c3_key]
            if self.c4_key is not None:
                c4 = c4 + inputs[self.c4_key]
            eself = eself + c3 * q_**3 + c4 * q_**4
            if self.trainable and training:
                Aii_ = jax.lax.stop_gradient(Aii)
                chi_ = jax.lax.stop_gradient(chi)
                c3_ = jax.lax.stop_gradient(c3)
                c4_ = jax.lax.stop_gradient(c4)
                switch_ = jax.lax.stop_gradient(switch)
                Aij_ = jax.lax.stop_gradient(Aij)
                phi_ = jax.ops.segment_sum(Aij_ * q[edge_dst], edge_src, nat)

                # dE/dq should be equal on all atoms at equilibrium: penalize
                # squared differences between graph neighbours
                dedq = Aii_ * q + chi_ + phi_ + 3 * c3_ * q**2 + 4 * c4_ * q**3
                dedq = jax.ops.segment_sum(
                    switch_ * (dedq[edge_src] - dedq[edge_dst]) ** 2,
                    edge_src,
                    nat,
                )
                etrain = (
                    0.5 * Aii_ * q**2
                    + chi_ * q
                    + 0.5 * q * phi_
                    + c3_ * q**3
                    + c4_ * q**4
                )
                if do_recip:
                    etrain = etrain + erec

        energy = eself + epair
        if do_recip:
            energy = energy + erec

        energy_key = self.energy_key if self.energy_key is not None else self.name
        energy_unit = au.get_multiplier(self._energy_unit)
        output = {
            **inputs,
            self.charges_key: q,
            energy_key: energy * energy_unit,
        }
        if do_recip:
            output[energy_key + "_reciprocal"] = erec * energy_unit

        if self.charges_key in inputs and self.trainable and training:
            output[energy_key + "_regularization"] = regularization
            output[energy_key + "_dedq"] = dedq * energy_unit
            output[energy_key + "_etrain"] = etrain * energy_unit
        return output
QEq-D4 charge equilibration scheme
FID: QEQ_D4
Reference
E. Caldeweyher et al., A generally applicable atomic-charge dependent London dispersion correction, J. Chem. Phys. 2019 Apr 21;150(15):154122. (https://doi.org/10.1063/1.5090222)
Key for the charges in the outputs. If charges are provided in the inputs, they are not re-optimized and only the energy is computed.
Key for additional c3 in the inputs. Only used if charges are provided in the inputs
Key for additional c4 in the inputs. Only used if charges are provided in the inputs
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.
class ChargeCorrection(nn.Module):
    """Charge correction scheme

    FID: CHARGE_CORRECTION

    Used to correct the provided charges so that they sum to the total charge of the system.

    The deviation ``dq = Q_tot - sum_i q_i`` of each system is redistributed over
    its atoms with weights ``s_i = 1 / (1e-6 + 2|eta_i|)``. The weights are
    optionally multiplied by a coordination number. A per-atom correction energy
    ``0.5 * eta_i * (q_i' - q_i)**2`` is also returned.
    """

    key: str = "charges"
    """Key for the charges in the inputs"""
    output_key: Optional[str] = None
    """Key for the corrected charges in the outputs. If None, it is the same as the input key"""
    dq_key: str = "delta_qtot"
    """Key for the deviation of the raw charge sum in the outputs"""
    ratioeta_key: Optional[str] = None
    """Key for the ratio of hardness between AIM and free atom in the inputs. Used to adjust charge redistribution."""
    trainable: bool = False
    """Whether the parameters are trainable"""
    cn_key: Optional[str] = None
    """Key for the coordination number in the inputs. Used to adjust charge redistribution."""
    total_charge_key: str = "total_charge"
    """Key for the total charge in the inputs"""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str] = "CHARGE_CORRECTION"

    @nn.compact
    def __call__(self, inputs) -> Any:
        species = inputs["species"]
        batch_index = inputs["batch_index"]
        nsys = inputs["natoms"].shape[0]

        q = inputs[self.key]
        if q.shape[-1] == 1:
            q = jnp.squeeze(q, axis=-1)
        qtot = jax.ops.segment_sum(q, batch_index, nsys)

        if self.total_charge_key in inputs:
            # Flatten to one value per system: a (nsys, 1) total charge would
            # otherwise broadcast dq to (nsys, nsys).
            Qtot = jnp.broadcast_to(
                jnp.reshape(inputs[self.total_charge_key].astype(q.dtype), (-1,)),
                (nsys,),
            )
        else:
            Qtot = jnp.zeros(nsys, dtype=q.dtype)
        dq = Qtot - qtot

        # reference hardness: eta = J + sqrt(2/pi) / a
        eta_ref = [
            jii + (2.0 / np.pi) ** 0.5 / aii
            for jii, aii in zip(D3_HARDNESSES, D3_VDW_RADII)
        ]
        if self.trainable:
            eta = jnp.abs(self.param("eta", lambda key: jnp.asarray(eta_ref)))[species]
        else:
            eta = jnp.asarray(eta_ref)[species]

        if self.ratioeta_key is not None:
            ratioeta = inputs[self.ratioeta_key]
            if ratioeta.shape[-1] == 1:
                ratioeta = jnp.squeeze(ratioeta, axis=-1)
            eta = eta * ratioeta

        # softer atoms (small eta) absorb a larger share of the correction
        s = (1.0e-6 + 2 * jnp.abs(eta)) ** (-1)
        if self.cn_key is not None:
            cn = inputs[self.cn_key]
            if cn.shape[-1] == 1:
                cn = jnp.squeeze(cn, axis=-1)
            s = s * cn

        f = dq / jax.ops.segment_sum(s, batch_index, nsys)
        qf = q + s * f[batch_index]

        energy_unit = au.get_multiplier(self._energy_unit)
        ecorr = (0.5 * energy_unit) * eta * (qf - q) ** 2
        output_key = self.output_key if self.output_key is not None else self.key
        return {
            **inputs,
            output_key: qf,
            self.dq_key: dq,
            "charge_correction_energy": ecorr,
        }
Charge correction scheme
FID: CHARGE_CORRECTION
Used to correct the provided charges so that they sum to the total charge of the system.
Key for the corrected charges in the outputs. If None, it is the same as the input key
Key for the ratio of hardness between AIM and free atom in the inputs. Used to adjust charge redistribution.
Key for the coordination number in the inputs. Used to adjust charge redistribution.
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.
class DistributeElectrons(nn.Module):
    """Distribute valence electrons between the atoms

    FID: DISTRIBUTE_ELECTRONS

    Used to predict charges that sum to the total charge of the system.

    A dense layer maps each atom's embedding to a positive weight ``w_i``
    (softplus). The system's electrons ``sum_i N_i^val - Q_tot`` are shared
    proportionally to ``w_i``. The charge is ``q_i = N_i^val - N_i``.
    """

    embedding_key: str
    """Key for the embedding in the inputs that is used to predict an 'electron affinity' weight"""
    output_key: Union[str, None] = None
    """Key for the charges in the outputs"""
    total_charge_key: str = "total_charge"
    """Key for the total charge in the inputs"""

    FID: ClassVar[str] = "DISTRIBUTE_ELECTRONS"

    @nn.compact
    def __call__(self, inputs) -> Any:
        species = inputs["species"]
        batch_index = inputs["batch_index"]
        nsys = inputs["natoms"].shape[0]

        Nel = jnp.asarray(VALENCE_ELECTRONS)[species]

        ei = nn.Dense(1, use_bias=True, name="wi")(inputs[self.embedding_key]).squeeze(-1)
        wi = jax.nn.softplus(ei)
        wtot = jax.ops.segment_sum(wi, batch_index, nsys)

        if self.total_charge_key in inputs:
            # Flatten to one value per system so that a (nsys, 1) total charge
            # does not turn Neltot into an (nsys, nsys) array.
            Qtot = jnp.broadcast_to(
                jnp.reshape(inputs[self.total_charge_key].astype(ei.dtype), (-1,)),
                (nsys,),
            )
        else:
            Qtot = jnp.zeros(nsys, dtype=ei.dtype)
        Neltot = jax.ops.segment_sum(Nel, batch_index, nsys) - Qtot

        # electrons per unit weight in each system
        f = Neltot / wtot
        Ni = wi * f[batch_index]
        q = Nel - Ni

        output_key = self.output_key if self.output_key is not None else self.name
        return {
            **inputs,
            output_key: q,
        }
Distribute valence electrons between the atoms
FID: DISTRIBUTE_ELECTRONS
Used to predict charges that sum to the total charge of the system.
Key for the embedding in the inputs that is used to predict an 'electron affinity' weight
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.