fennol.models.physics.dispersion
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
from typing import Any, Dict, Union, Callable, Sequence, Optional, ClassVar
from ...utils import AtomicUnits as au
from ...utils.periodic_table import (
    D3_ELECTRONEGATIVITIES,
    D3_HARDNESSES,
    D3_VDW_RADII,
    D3_COV_RADII,
    D3_KAPPA,
    VDW_RADII,
    POLARIZABILITIES,
    C6_FREE,
    VALENCE_ELECTRONS,
)


class VdwOQDO(nn.Module):
    """Dispersion and exchange based on the Optimized Quantum Drude Oscillator model.

    FID : VDW_OQDO

    ### Reference
    A. Khabibrakhmanov, D. V. Fedorov, and A. Tkatchenko, Universal Pairwise Interatomic van der Waals Potentials Based on Quantum Drude Oscillators,
    J. Chem. Theory Comput. 2023, 19, 21, 7895–7907 (https://doi.org/10.1021/acs.jctc.3c00797)

    """

    graph_key: str = "graph"
    """ The key for the graph input."""
    include_exchange: bool = True
    """ Whether to compute the exchange part."""
    ratiovol_key: Optional[str] = None
    """ The key for the ratio between the atom-in-molecule (AIM) volume and the free-atom volume.
        If None, the volume ratio is assumed to be 1.0."""
    energy_key: Optional[str] = None
    """ The key for the output energy. If None, the name of the module is used."""
    damped: bool = True
    """ Whether to use short-range damping."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str] = "VDW_OQDO"

    @nn.compact
    def __call__(self, inputs):
        """Compute per-atom OQDO dispersion (and optionally exchange) energies.

        Reads ``inputs["species"]``, the graph stored under ``self.graph_key``
        (keys ``edge_src``, ``edge_dst``, ``distances``, ``switch``) and, when
        ``ratiovol_key`` is set, the per-atom volume ratios.

        Returns:
            A shallow copy of ``inputs`` with energies added, scaled to
            ``_energy_unit``. ``output_key`` is ``energy_key`` or, if that is
            None, the module name. If ``include_exchange`` is False, only
            ``output_key`` (the dispersion energy) is added. Otherwise
            ``output_key + "_dispersion"``, ``output_key + "_exchange"`` and
            their sum under ``output_key`` are added.
        """
        energy_unit = au.get_multiplier(self._energy_unit)

        species = inputs["species"]
        nat = species.shape[0]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        # work in bohr (assumes au.BOHR is one bohr expressed in the input length unit)
        rij = graph["distances"] / au.BOHR
        switch = graph["switch"]

        # free-atom reference C6 coefficients and polarizabilities, per species
        c6 = jnp.asarray(C6_FREE)[species]
        alpha = jnp.asarray(POLARIZABILITIES)[species]

        if self.ratiovol_key is not None:
            # small offset, presumably to keep the ratio strictly positive
            ratiovol = inputs[self.ratiovol_key] + 1.0e-6
            if ratiovol.shape[-1] == 1:
                ratiovol = jnp.squeeze(ratiovol, axis=-1)
            # volume rescaling: C6 scales with ratio**2, alpha with ratio
            c6 = c6 * ratiovol**2
            alpha = alpha * ratiovol

        c6i, c6j = c6[edge_src], c6[edge_dst]
        alphai, alphaj = alpha[edge_src], alpha[edge_dst]

        # combination rules
        alphaij = 0.5 * (alphai + alphaj)
        c6ij = 2 * alphai * alphaj * c6i * c6j / (c6i * alphaj**2 + c6j * alphai**2)

        # equilibrium distance
        Re = (alphaij * (128.0 / au.FSC ** (4.0 / 3.0))) ** (1.0 / 7.0)
        Re2 = Re**2
        Re4 = Re**4
        # muw = mu*w of the pair oscillator: fit to largest root of eq (S33) of
        # "Universal Pairwise Interatomic van der Waals Potentials Based On Quantum Drude Oscillators"
        if self.damped:
            muw = (
                4.83053463e-01
                - 3.76191669e-02 * Re
                + 1.27066988e-03 * Re2
                - 7.21940151e-07 * Re4
            ) / (3.84212120e-02 - 3.16915319e-02 * Re + 2.37410890e-02 * Re2)
        else:
            muw = (
                3.66316787e01
                - 5.79579187 * Re
                + 3.02674813e-01 * Re2
                - 3.65461255e-04 * Re4
            ) / (-1.46169102e01 + 7.32461225 * Re)

        # higher-order dispersion coefficients from C6 and mu*w
        c8ij = 5 * c6ij / muw
        c10ij = 245 * c6ij / (8 * muw**2)

        if self.damped:
            # f_2n = 1 - exp(-z) * sum_{k<=n} z^k / k!, with z = mu*w*r^2 / 2
            z = 0.5 * muw * rij**2
            ez = jnp.exp(-z)
            f6 = 1.0 - ez * (1.0 + z + 0.5 * z**2 + (1.0 / 6.0) * z**3)
            f8 = f6 - (1.0 / 24.0) * ez * z**4
            f10 = f8 - (1.0 / 120.0) * ez * z**5
            epair = (
                f6 * c6ij / rij**6 + f8 * c8ij / rij**8 + f10 * c10ij / rij**10
            )
        else:
            epair = c6ij / rij**6 + c8ij / rij**8 + c10ij / rij**10

        # 0.5: each pair is presumably stored as both (i, j) and (j, i) in the
        # graph, so each atom receives half of every pair energy
        edisp = (-0.5 * energy_unit) * jax.ops.segment_sum(epair * switch, edge_src, nat)

        output_key = self.name if self.energy_key is None else self.energy_key

        if not self.include_exchange:
            return {**inputs, output_key: edisp}

        ### exchange
        # oscillator frequency from C6 = 3/4 * alpha^2 * w
        w = 4 * c6ij / (3 * alphaij**2)
        # q = (alphaij * mu*w**2)**0.5
        q2 = alphaij * muw * w
        if self.damped:
            # damping functions f_2n and their r-derivatives df_2n, evaluated at Re
            ze = 0.5 * muw * Re2
            eze = jnp.exp(-ze)

            s6 = eze * (1.0 + ze + 0.5 * ze**2 + (1.0 / 6.0) * ze**3)
            f6e = 1.0 - s6
            muwRe = muw * Re  # dz/dr at Re
            df6e = muwRe * s6 - eze * (
                muwRe + 0.5 * Re * muwRe**2 + (1.0 / 8.0) * Re2 * muwRe**3
            )

            s8 = (1.0 / 24.0) * eze * ze**4
            f8e = f6e - s8
            df8e = df6e + muwRe * s8 - (1.0 / 48.0) * eze * Re2 * Re * muwRe**4

            s10 = (1.0 / 120.0) * eze * ze**5
            f10e = f8e - s10
            df10e = df8e + muwRe * s10 - (1.0 / 384.0) * eze * Re2 * Re2 * muwRe**5

            # prefactor built from (n*f_n - Re*f_n') terms, i.e. the derivative
            # structure of the damped dispersion at Re (see reference)
            den = 2 * c6ij * Re2 * (6 * f6e - Re * df6e)
            A = (
                0.5
                + c8ij * (8 * f8e - Re * df8e) / den
                + c10ij * (10 * f10e - Re * df10e) / (den * Re2)
            )
            # ez (= exp(-mu*w*r^2/2)) is reused from the dispersion branch
        else:
            # undamped case: same prefactor with f_n = 1 and f_n' = 0
            A = 0.5 + 2 * c8ij / (3 * c6ij * Re2) + 5 * c10ij / (6 * c6ij * Re4)
            ez = jnp.exp(-0.5 * muw * rij**2)

        exij = A * q2 * ez / rij
        # same 0.5 pair-splitting factor as for dispersion
        ex = (0.5 * energy_unit) * jax.ops.segment_sum(exij * switch, edge_src, nat)

        return {
            **inputs,
            output_key + "_dispersion": edisp,
            output_key + "_exchange": ex,
            output_key: edisp + ex,
        }
class VdwOQDO(nn.Module):
    """Dispersion and exchange based on the Optimized Quantum Drude Oscillator model.

    FID : VDW_OQDO

    ### Reference
    A. Khabibrakhmanov, D. V. Fedorov, and A. Tkatchenko, Universal Pairwise Interatomic van der Waals Potentials Based on Quantum Drude Oscillators,
    J. Chem. Theory Comput. 2023, 19, 21, 7895–7907 (https://doi.org/10.1021/acs.jctc.3c00797)

    """

    graph_key: str = "graph"
    """ The key for the graph input."""
    include_exchange: bool = True
    """ Whether to compute the exchange part."""
    ratiovol_key: Optional[str] = None
    """ The key for the ratio between the atom-in-molecule (AIM) volume and the free-atom volume.
        If None, the volume ratio is assumed to be 1.0."""
    energy_key: Optional[str] = None
    """ The key for the output energy. If None, the name of the module is used."""
    damped: bool = True
    """ Whether to use short-range damping."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str] = "VDW_OQDO"

    @nn.compact
    def __call__(self, inputs):
        """Compute per-atom OQDO dispersion (and optionally exchange) energies.

        Reads ``inputs["species"]``, the graph stored under ``self.graph_key``
        (keys ``edge_src``, ``edge_dst``, ``distances``, ``switch``) and, when
        ``ratiovol_key`` is set, the per-atom volume ratios.

        Returns:
            A shallow copy of ``inputs`` with energies added, scaled to
            ``_energy_unit``. ``output_key`` is ``energy_key`` or, if that is
            None, the module name. If ``include_exchange`` is False, only
            ``output_key`` (the dispersion energy) is added. Otherwise
            ``output_key + "_dispersion"``, ``output_key + "_exchange"`` and
            their sum under ``output_key`` are added.
        """
        energy_unit = au.get_multiplier(self._energy_unit)

        species = inputs["species"]
        nat = species.shape[0]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        # work in bohr (assumes au.BOHR is one bohr expressed in the input length unit)
        rij = graph["distances"] / au.BOHR
        switch = graph["switch"]

        # free-atom reference C6 coefficients and polarizabilities, per species
        c6 = jnp.asarray(C6_FREE)[species]
        alpha = jnp.asarray(POLARIZABILITIES)[species]

        if self.ratiovol_key is not None:
            # small offset, presumably to keep the ratio strictly positive
            ratiovol = inputs[self.ratiovol_key] + 1.0e-6
            if ratiovol.shape[-1] == 1:
                ratiovol = jnp.squeeze(ratiovol, axis=-1)
            # volume rescaling: C6 scales with ratio**2, alpha with ratio
            c6 = c6 * ratiovol**2
            alpha = alpha * ratiovol

        c6i, c6j = c6[edge_src], c6[edge_dst]
        alphai, alphaj = alpha[edge_src], alpha[edge_dst]

        # combination rules
        alphaij = 0.5 * (alphai + alphaj)
        c6ij = 2 * alphai * alphaj * c6i * c6j / (c6i * alphaj**2 + c6j * alphai**2)

        # equilibrium distance
        Re = (alphaij * (128.0 / au.FSC ** (4.0 / 3.0))) ** (1.0 / 7.0)
        Re2 = Re**2
        Re4 = Re**4
        # muw = mu*w of the pair oscillator: fit to largest root of eq (S33) of
        # "Universal Pairwise Interatomic van der Waals Potentials Based On Quantum Drude Oscillators"
        if self.damped:
            muw = (
                4.83053463e-01
                - 3.76191669e-02 * Re
                + 1.27066988e-03 * Re2
                - 7.21940151e-07 * Re4
            ) / (3.84212120e-02 - 3.16915319e-02 * Re + 2.37410890e-02 * Re2)
        else:
            muw = (
                3.66316787e01
                - 5.79579187 * Re
                + 3.02674813e-01 * Re2
                - 3.65461255e-04 * Re4
            ) / (-1.46169102e01 + 7.32461225 * Re)

        # higher-order dispersion coefficients from C6 and mu*w
        c8ij = 5 * c6ij / muw
        c10ij = 245 * c6ij / (8 * muw**2)

        if self.damped:
            # f_2n = 1 - exp(-z) * sum_{k<=n} z^k / k!, with z = mu*w*r^2 / 2
            z = 0.5 * muw * rij**2
            ez = jnp.exp(-z)
            f6 = 1.0 - ez * (1.0 + z + 0.5 * z**2 + (1.0 / 6.0) * z**3)
            f8 = f6 - (1.0 / 24.0) * ez * z**4
            f10 = f8 - (1.0 / 120.0) * ez * z**5
            epair = (
                f6 * c6ij / rij**6 + f8 * c8ij / rij**8 + f10 * c10ij / rij**10
            )
        else:
            epair = c6ij / rij**6 + c8ij / rij**8 + c10ij / rij**10

        # 0.5: each pair is presumably stored as both (i, j) and (j, i) in the
        # graph, so each atom receives half of every pair energy
        edisp = (-0.5 * energy_unit) * jax.ops.segment_sum(epair * switch, edge_src, nat)

        output_key = self.name if self.energy_key is None else self.energy_key

        if not self.include_exchange:
            return {**inputs, output_key: edisp}

        ### exchange
        # oscillator frequency from C6 = 3/4 * alpha^2 * w
        w = 4 * c6ij / (3 * alphaij**2)
        # q = (alphaij * mu*w**2)**0.5
        q2 = alphaij * muw * w
        if self.damped:
            # damping functions f_2n and their r-derivatives df_2n, evaluated at Re
            ze = 0.5 * muw * Re2
            eze = jnp.exp(-ze)

            s6 = eze * (1.0 + ze + 0.5 * ze**2 + (1.0 / 6.0) * ze**3)
            f6e = 1.0 - s6
            muwRe = muw * Re  # dz/dr at Re
            df6e = muwRe * s6 - eze * (
                muwRe + 0.5 * Re * muwRe**2 + (1.0 / 8.0) * Re2 * muwRe**3
            )

            s8 = (1.0 / 24.0) * eze * ze**4
            f8e = f6e - s8
            df8e = df6e + muwRe * s8 - (1.0 / 48.0) * eze * Re2 * Re * muwRe**4

            s10 = (1.0 / 120.0) * eze * ze**5
            f10e = f8e - s10
            df10e = df8e + muwRe * s10 - (1.0 / 384.0) * eze * Re2 * Re2 * muwRe**5

            # prefactor built from (n*f_n - Re*f_n') terms, i.e. the derivative
            # structure of the damped dispersion at Re (see reference)
            den = 2 * c6ij * Re2 * (6 * f6e - Re * df6e)
            A = (
                0.5
                + c8ij * (8 * f8e - Re * df8e) / den
                + c10ij * (10 * f10e - Re * df10e) / (den * Re2)
            )
            # ez (= exp(-mu*w*r^2/2)) is reused from the dispersion branch
        else:
            # undamped case: same prefactor with f_n = 1 and f_n' = 0
            A = 0.5 + 2 * c8ij / (3 * c6ij * Re2) + 5 * c10ij / (6 * c6ij * Re4)
            ez = jnp.exp(-0.5 * muw * rij**2)

        exij = A * q2 * ez / rij
        # same 0.5 pair-splitting factor as for dispersion
        ex = (0.5 * energy_unit) * jax.ops.segment_sum(exij * switch, edge_src, nat)

        return {
            **inputs,
            output_key + "_dispersion": edisp,
            output_key + "_exchange": ex,
            output_key: edisp + ex,
        }
Dispersion and exchange based on the Optimized Quantum Drude Oscillator model.
FID : VDW_OQDO
Reference
A. Khabibrakhmanov, D. V. Fedorov, and A. Tkatchenko, Universal Pairwise Interatomic van der Waals Potentials Based on Quantum Drude Oscillators, J. Chem. Theory Comput. 2023, 19, 21, 7895–7907 (https://doi.org/10.1021/acs.jctc.3c00797)
The key for the ratio between the atom-in-molecule (AIM) volume and the free-atom volume. If None, the volume ratio is assumed to be 1.0.
The key for the output energy. If None, the name of the module is used.
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.