fennol.models.physics.dispersion

  1import jax
  2import jax.numpy as jnp
  3import flax.linen as nn
  4import numpy as np
  5from typing import Any, Dict, Union, Callable, Sequence, Optional, ClassVar
  6from ...utils import AtomicUnits as au
  7from ...utils.periodic_table import (
  8    D3_ELECTRONEGATIVITIES,
  9    D3_HARDNESSES,
 10    D3_VDW_RADII,
 11    D3_COV_RADII,
 12    D3_KAPPA,
 13    VDW_RADII,
 14    POLARIZABILITIES,
 15    C6_FREE,
 16    VALENCE_ELECTRONS,
 17)
 18import pathlib
 19import pickle
 20
 21class VdwOQDO(nn.Module):
 22    """ Dispersion and exchange based on the Optimized Quantum Drude Oscillator model.
 23    
 24    FID : VDW_OQDO
 25
 26    ### Reference
 27    A. Khabibrakhmanov, D. V. Fedorov, and A. Tkatchenko, Universal Pairwise Interatomic van der Waals Potentials Based on Quantum Drude Oscillators,
 28    J. Chem. Theory Comput. 2023, 19, 21, 7895–7907 (https://doi.org/10.1021/acs.jctc.3c00797)
 29    
 30    """
 31    graph_key: str = "graph"
 32    """ The key for the graph input."""
 33    include_exchange: bool = True
 34    """ Whether to compute the exchange part."""
 35    ratiovol_key: Optional[str] = None
 36    """ The key for the ratio between AIM volume and free-atom volume. 
 37         If None, the volume ratio is assumed to be 1.0."""
 38    energy_key: Optional[str] = None
 39    """ The key for the output energy. If None, the name of the module is used."""
 40    damped: bool = True
 41    """ Whether to use short-range damping."""
 42    _energy_unit: str = "Ha"
 43    """The energy unit of the model. **Automatically set by FENNIX**"""
 44
 45    FID: ClassVar[str]  = "VDW_OQDO"
 46
 47    @nn.compact
 48    def __call__(self, inputs):
 49        energy_unit = au.get_multiplier(self._energy_unit)
 50
 51        species = inputs["species"]
 52        graph = inputs[self.graph_key]
 53        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
 54        rij = graph["distances"] / au.BOHR
 55        switch = graph["switch"]
 56
 57        c6 = jnp.asarray(C6_FREE)[species]
 58        alpha = jnp.asarray(POLARIZABILITIES)[species]
 59
 60        if self.ratiovol_key is not None:
 61            ratiovol = inputs[self.ratiovol_key] + 1.0e-6
 62            if ratiovol.shape[-1] == 1:
 63                ratiovol = jnp.squeeze(ratiovol, axis=-1)
 64            c6 = c6 * ratiovol**2
 65            alpha = alpha * ratiovol
 66
 67        c6i, c6j = c6[edge_src], c6[edge_dst]
 68        alphai, alphaj = alpha[edge_src], alpha[edge_dst]
 69
 70        # combination rules
 71        alphaij = 0.5 * (alphai + alphaj)
 72        c6ij = 2 * alphai * alphaj * c6i * c6j / (c6i * alphaj**2 + c6j * alphai**2)
 73
 74        # equilibrium distance
 75        Re = (alphaij * (128.0 / au.FSC ** (4.0 / 3.0))) ** (1.0 / 7.0)
 76        Re2 = Re**2
 77        Re4 = Re**4
 78        # fit to largest root of eq (S33) of "Universal Pairwise Interatomic van der Waals Potentials Based On Quantum Drude Oscillators"
 79        if self.damped:
 80            muw = (
 81                4.83053463e-01
 82                - 3.76191669e-02 * Re
 83                + 1.27066988e-03 * Re2
 84                - 7.21940151e-07 * Re4
 85            ) / (3.84212120e-02 - 3.16915319e-02 * Re + 2.37410890e-02 * Re2)
 86        else:
 87            muw = (
 88                3.66316787e01
 89                - 5.79579187 * Re
 90                + 3.02674813e-01 * Re2
 91                - 3.65461255e-04 * Re4
 92            ) / (-1.46169102e01 + 7.32461225 * Re)
 93
 94        c8ij = 5 * c6ij / muw
 95        c10ij = 245 * c6ij / (8 * muw**2)
 96
 97        if self.damped:
 98            z = 0.5 * muw * rij**2
 99            ez = jnp.exp(-z)
100            f6 = 1.0 - ez * (1.0 + z + 0.5 * z**2 + (1.0 / 6.0) * z**3)
101            f8 = f6 - (1.0 / 24.0) * ez * z**4
102            f10 = f8 - (1.0 / 120.0) * ez * z**5
103            epair = (
104                f6 * c6ij / rij**6 + f8 * c8ij / rij**8 + f10 * c10ij / rij**10
105            )
106        else:
107            epair = c6ij / rij**6 + c8ij / rij**8 + c10ij / rij**10
108
109        edisp = (-0.5*energy_unit) * jax.ops.segment_sum(epair * switch, edge_src, species.shape[0])
110
111        output_key = self.name if self.energy_key is None else self.energy_key
112
113        if not self.include_exchange:
114            return {**inputs, output_key: edisp}
115
116        ### exchange
117        w = 4 * c6ij / (3 * alphaij**2)
118        # q = (alphaij * mu*w**2)**0.5
119        q2 = alphaij * muw * w
120        # undamped case
121        if self.damped:
122            ze = 0.5 * muw * Re2
123            eze = jnp.exp(-ze)
124
125            s6 = eze * (1.0 + ze + 0.5 * ze**2 + (1.0 / 6.0) * ze**3)
126            f6e = 1.0 - s6
127            muwRe = muw * Re
128            df6e = muwRe * s6 - eze * (
129                muwRe + 0.5 * Re * muwRe**2 + (1.0 / 8.0) * Re2 * muwRe**3
130            )
131
132            s8 = (1.0 / 24.0) * eze * ze**4
133            f8e = f6e - s8
134            df8e = df6e + muwRe * s8 - (1.0 / 48.0) * eze * Re2 * Re * muwRe**4
135
136            s10 = (1.0 / 120.0) * eze * ze**5
137            f10e = f8e - s10
138            df10e = df8e + muwRe * s10 - (1.0 / 384.0) * eze * Re2 * Re2 * muwRe**5
139
140            den = 2 * c6ij * Re2 * (6 * f6e - Re * df6e)
141            A = (
142                0.5
143                + c8ij * (8 * f8e - Re * df8e) / den
144                + c10ij * (10 * f10e - Re * df10e) / (den * Re2)
145            )
146        else:
147            A = 0.5 + 2 * c8ij / (3 * c6ij * Re2) + 5 * c10ij / (6 * c6ij * Re4)
148            ez = jnp.exp(-0.5 * muw * rij**2)
149
150        exij = A * q2 * ez / rij
151        ex = (0.5*energy_unit) * jax.ops.segment_sum(exij * switch, edge_src, species.shape[0])
152
153        return {
154            **inputs,
155            output_key + "_dispersion": edisp,
156            output_key + "_exchange": ex,
157            output_key: edisp + ex,
158        }
159
160
class DispersionD3(nn.Module):
    """D3-style dispersion energy with rational damping.

    FID : DISPERSION_D3

    For every edge (i, j) of the graph the module does the following:

    1. Builds a coordination number per atom,
       ``CN_i = sum_j sigmoid(16 * (Rcov_ij / r_ij - 1))``, where ``Rcov_ij`` is
       the sum of the tabulated covalent radii.
    2. Interpolates the tabulated reference C6 coefficients. It uses normalized
       weights ``exp(-4 * (CN_ref - CN_i)**2)`` over each element's reference
       CNs. Atoms whose weights all (nearly) vanish fall back to the reference
       with the largest CN.
    3. Derives ``C8 = 3 * r4r2_i * r4r2_j * C6``.
    4. Evaluates the pair term
       ``-(s6 * C6 / (r**6 + R0**6) + s8 * C8 / (r**8 + R0**8))``,
       with ``R0 = a1 * sqrt(C8 / C6) + a2``.

    Pair terms are multiplied by the graph ``switch`` and summed onto the source
    atom with a factor 1/2. They are then scaled by
    ``au.get_multiplier(_energy_unit)``. Distances are divided by ``au.BOHR``
    and clipped below at 1e-6 before use, so ``a2`` shares the units of the
    tabulated radii (Bohr).
    """

    _graphs_properties: Dict
    graph_key: str = "graph"
    """ The key for the graph input."""
    output_key: Optional[str] = None
    """ The key for the output energy. If None, the name of the module is used."""
    s6: float = 1.0
    """ Global scaling of the C6 term."""
    s8: float = 1.0
    """ Global scaling of the C8 term."""
    a1: float = 0.4
    """ Damping parameter multiplying sqrt(C8/C6) in the damping radius."""
    a2: float = 5.0
    """ Constant offset of the damping radius (same units as the distances, i.e. Bohr)."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str] = "DISPERSION_D3"

    # Per-process cache of the preprocessed reference tables, keyed by file path.
    # Without it, every call re-read and unpickled the data file.
    _REFERENCE_CACHE: ClassVar[Dict[str, Dict[str, Any]]] = {}

    @classmethod
    def _load_reference_tables(cls) -> Dict[str, Any]:
        """Load and preprocess ``ref_data_d3.pkl`` once, then reuse it.

        The pickle ships next to this module and is trusted package data. Never
        point this loader at untrusted files, because unpickling can execute
        arbitrary code.

        Returns a dict with these entries:

        - ``COV_D3`` and ``R4R2``: per-element tables, as stored in the file.
        - ``REF_CN``: reference coordination numbers per element.
        - ``EX_WEIGHT``: a one-hot fallback weight on each element's largest
          reference CN.
        - ``REF_C6``: reshaped to ``(nz * nz, nref, nref)``.
        - ``NZ``: the number of elements in the tables.
        """
        path = pathlib.Path(__file__).parent.resolve() / "ref_data_d3.pkl"
        key = str(path)
        cached = cls._REFERENCE_CACHE.get(key)
        if cached is not None:
            return cached

        with open(path, "rb") as f:
            raw = pickle.load(f)

        ref_cn = np.asarray(raw["REF_CN"])
        # One-hot weights on the reference with the largest CN, used when all
        # Gaussian weights of an atom vanish.
        ex_weight = np.zeros_like(ref_cn)
        ex_weight[np.arange(ref_cn.shape[0]), np.argmax(ref_cn, axis=1)] = 1.0

        ref_c6 = np.asarray(raw["REF_C6"])
        nz = ref_c6.shape[0]
        nref = ref_c6.shape[-1]

        tables = {
            "COV_D3": raw["COV_D3"],
            "R4R2": raw["R4R2"],
            "REF_CN": ref_cn,
            "EX_WEIGHT": ex_weight,
            # Flatten the two element axes so a pair (Zi, Zj) is one index.
            "REF_C6": ref_c6.reshape((nz * nz, nref, nref)),
            "NZ": nz,
        }
        cls._REFERENCE_CACHE[key] = tables
        return tables

    @nn.compact
    def __call__(self, inputs):
        """Compute per-atom D3 dispersion energies.

        Reads ``inputs["species"]`` and, from ``inputs[graph_key]``, the arrays
        ``edge_src``, ``edge_dst``, ``distances`` and ``switch``. Returns a copy
        of ``inputs`` with the energies stored under ``output_key`` (or the
        module name when ``output_key`` is None).
        """
        tables = self._load_reference_tables()

        graph = inputs[self.graph_key]
        edge_src = graph["edge_src"]
        edge_dst = graph["edge_dst"]
        switch = graph["switch"]
        species = inputs["species"]
        n_atoms = species.shape[0]

        # Lower clip keeps rcij / rij finite.
        rij = jnp.clip(graph["distances"] / au.BOHR, 1e-6, None)

        ## RADII (in BOHR)
        rcov = jnp.array(tables["COV_D3"])[species]
        r4r2 = jnp.array(tables["R4R2"])[species]

        rcij = rcov[edge_src] + rcov[edge_dst]

        ## COORDINATION NUMBER
        KCN = 16.0
        cnij = jax.nn.sigmoid(KCN * (rcij / rij - 1.0))
        cn = jax.ops.segment_sum(cnij, edge_src, n_atoms)

        ## WEIGHTS
        refcn = jnp.array(tables["REF_CN"])[species]
        # Reference slots with a negative CN are excluded from the interpolation.
        mask = refcn >= 0
        dcn = refcn - cn[:, None]
        KW = 4.0
        weights = jnp.where(mask, jnp.exp(-KW * dcn**2), 0.0)
        norm = weights.sum(axis=1, keepdims=True)  # (n_atoms, 1)
        weights = jnp.where(mask, weights / jnp.clip(norm, 1e-6, None), 0.0)

        ## correct for all null weights
        exweight = jnp.array(tables["EX_WEIGHT"])[species]
        exceptional = norm < 1.0e-6
        weights = jnp.where(exceptional, exweight, weights)

        ## C6 coefficients
        nz = tables["NZ"]
        REF_C6 = jnp.array(tables["REF_C6"])
        pair_num = species[edge_src] * nz + species[edge_dst]
        rc6 = REF_C6[pair_num]
        c6 = jnp.einsum("iab,ia,ib->i", rc6, weights[edge_src], weights[edge_dst])

        ## DISPERSION
        qq = 3 * r4r2[edge_src] * r4r2[edge_dst]  # C8 / C6
        c8 = c6 * qq

        # Damping radius R0 = a1 * sqrt(C8 / C6) + a2.
        r0 = self.a1 * jnp.sqrt(qq) + self.a2

        t6 = self.s6 / (rij**6 + r0**6)
        t8 = self.s8 / (rij**8 + r0**8)

        energy_unit = au.get_multiplier(self._energy_unit)
        energy = (-0.5 * energy_unit) * jax.ops.segment_sum(
            (c6 * t6 + c8 * t8) * switch, edge_src, n_atoms
        )

        output_key = self.output_key if self.output_key is not None else self.name
        return {**inputs, output_key: energy}
class VdwOQDO(flax.linen.module.Module):
 22class VdwOQDO(nn.Module):
 23    """ Dispersion and exchange based on the Optimized Quantum Drude Oscillator model.
 24    
 25    FID : VDW_OQDO
 26
 27    ### Reference
 28    A. Khabibrakhmanov, D. V. Fedorov, and A. Tkatchenko, Universal Pairwise Interatomic van der Waals Potentials Based on Quantum Drude Oscillators,
 29    J. Chem. Theory Comput. 2023, 19, 21, 7895–7907 (https://doi.org/10.1021/acs.jctc.3c00797)
 30    
 31    """
 32    graph_key: str = "graph"
 33    """ The key for the graph input."""
 34    include_exchange: bool = True
 35    """ Whether to compute the exchange part."""
 36    ratiovol_key: Optional[str] = None
 37    """ The key for the ratio between AIM volume and free-atom volume. 
 38         If None, the volume ratio is assumed to be 1.0."""
 39    energy_key: Optional[str] = None
 40    """ The key for the output energy. If None, the name of the module is used."""
 41    damped: bool = True
 42    """ Whether to use short-range damping."""
 43    _energy_unit: str = "Ha"
 44    """The energy unit of the model. **Automatically set by FENNIX**"""
 45
 46    FID: ClassVar[str]  = "VDW_OQDO"
 47
 48    @nn.compact
 49    def __call__(self, inputs):
 50        energy_unit = au.get_multiplier(self._energy_unit)
 51
 52        species = inputs["species"]
 53        graph = inputs[self.graph_key]
 54        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
 55        rij = graph["distances"] / au.BOHR
 56        switch = graph["switch"]
 57
 58        c6 = jnp.asarray(C6_FREE)[species]
 59        alpha = jnp.asarray(POLARIZABILITIES)[species]
 60
 61        if self.ratiovol_key is not None:
 62            ratiovol = inputs[self.ratiovol_key] + 1.0e-6
 63            if ratiovol.shape[-1] == 1:
 64                ratiovol = jnp.squeeze(ratiovol, axis=-1)
 65            c6 = c6 * ratiovol**2
 66            alpha = alpha * ratiovol
 67
 68        c6i, c6j = c6[edge_src], c6[edge_dst]
 69        alphai, alphaj = alpha[edge_src], alpha[edge_dst]
 70
 71        # combination rules
 72        alphaij = 0.5 * (alphai + alphaj)
 73        c6ij = 2 * alphai * alphaj * c6i * c6j / (c6i * alphaj**2 + c6j * alphai**2)
 74
 75        # equilibrium distance
 76        Re = (alphaij * (128.0 / au.FSC ** (4.0 / 3.0))) ** (1.0 / 7.0)
 77        Re2 = Re**2
 78        Re4 = Re**4
 79        # fit to largest root of eq (S33) of "Universal Pairwise Interatomic van der Waals Potentials Based On Quantum Drude Oscillators"
 80        if self.damped:
 81            muw = (
 82                4.83053463e-01
 83                - 3.76191669e-02 * Re
 84                + 1.27066988e-03 * Re2
 85                - 7.21940151e-07 * Re4
 86            ) / (3.84212120e-02 - 3.16915319e-02 * Re + 2.37410890e-02 * Re2)
 87        else:
 88            muw = (
 89                3.66316787e01
 90                - 5.79579187 * Re
 91                + 3.02674813e-01 * Re2
 92                - 3.65461255e-04 * Re4
 93            ) / (-1.46169102e01 + 7.32461225 * Re)
 94
 95        c8ij = 5 * c6ij / muw
 96        c10ij = 245 * c6ij / (8 * muw**2)
 97
 98        if self.damped:
 99            z = 0.5 * muw * rij**2
100            ez = jnp.exp(-z)
101            f6 = 1.0 - ez * (1.0 + z + 0.5 * z**2 + (1.0 / 6.0) * z**3)
102            f8 = f6 - (1.0 / 24.0) * ez * z**4
103            f10 = f8 - (1.0 / 120.0) * ez * z**5
104            epair = (
105                f6 * c6ij / rij**6 + f8 * c8ij / rij**8 + f10 * c10ij / rij**10
106            )
107        else:
108            epair = c6ij / rij**6 + c8ij / rij**8 + c10ij / rij**10
109
110        edisp = (-0.5*energy_unit) * jax.ops.segment_sum(epair * switch, edge_src, species.shape[0])
111
112        output_key = self.name if self.energy_key is None else self.energy_key
113
114        if not self.include_exchange:
115            return {**inputs, output_key: edisp}
116
117        ### exchange
118        w = 4 * c6ij / (3 * alphaij**2)
119        # q = (alphaij * mu*w**2)**0.5
120        q2 = alphaij * muw * w
121        # undamped case
122        if self.damped:
123            ze = 0.5 * muw * Re2
124            eze = jnp.exp(-ze)
125
126            s6 = eze * (1.0 + ze + 0.5 * ze**2 + (1.0 / 6.0) * ze**3)
127            f6e = 1.0 - s6
128            muwRe = muw * Re
129            df6e = muwRe * s6 - eze * (
130                muwRe + 0.5 * Re * muwRe**2 + (1.0 / 8.0) * Re2 * muwRe**3
131            )
132
133            s8 = (1.0 / 24.0) * eze * ze**4
134            f8e = f6e - s8
135            df8e = df6e + muwRe * s8 - (1.0 / 48.0) * eze * Re2 * Re * muwRe**4
136
137            s10 = (1.0 / 120.0) * eze * ze**5
138            f10e = f8e - s10
139            df10e = df8e + muwRe * s10 - (1.0 / 384.0) * eze * Re2 * Re2 * muwRe**5
140
141            den = 2 * c6ij * Re2 * (6 * f6e - Re * df6e)
142            A = (
143                0.5
144                + c8ij * (8 * f8e - Re * df8e) / den
145                + c10ij * (10 * f10e - Re * df10e) / (den * Re2)
146            )
147        else:
148            A = 0.5 + 2 * c8ij / (3 * c6ij * Re2) + 5 * c10ij / (6 * c6ij * Re4)
149            ez = jnp.exp(-0.5 * muw * rij**2)
150
151        exij = A * q2 * ez / rij
152        ex = (0.5*energy_unit) * jax.ops.segment_sum(exij * switch, edge_src, species.shape[0])
153
154        return {
155            **inputs,
156            output_key + "_dispersion": edisp,
157            output_key + "_exchange": ex,
158            output_key: edisp + ex,
159        }

Dispersion and exchange based on the Optimized Quantum Drude Oscillator model.

FID : VDW_OQDO

Reference

A. Khabibrakhmanov, D. V. Fedorov, and A. Tkatchenko, Universal Pairwise Interatomic van der Waals Potentials Based on Quantum Drude Oscillators, J. Chem. Theory Comput. 2023, 19, 21, 7895–7907 (https://doi.org/10.1021/acs.jctc.3c00797)

VdwOQDO( graph_key: str = 'graph', include_exchange: bool = True, ratiovol_key: Optional[str] = None, energy_key: Optional[str] = None, damped: bool = True, _energy_unit: str = 'Ha', parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
graph_key: str = 'graph'

The key for the graph input.

include_exchange: bool = True

Whether to compute the exchange part.

ratiovol_key: Optional[str] = None

The key for the ratio between AIM volume and free-atom volume. If None, the volume ratio is assumed to be 1.0.

energy_key: Optional[str] = None

The key for the output energy. If None, the name of the module is used.

damped: bool = True

Whether to use short-range damping.

FID: ClassVar[str] = 'VDW_OQDO'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class DispersionD3(flax.linen.module.Module):
class DispersionD3(nn.Module):
    """D3-style dispersion energy with rational damping.

    FID : DISPERSION_D3

    For every edge (i, j) of the graph the module does the following:

    1. Builds a coordination number per atom,
       ``CN_i = sum_j sigmoid(16 * (Rcov_ij / r_ij - 1))``, where ``Rcov_ij`` is
       the sum of the tabulated covalent radii.
    2. Interpolates the tabulated reference C6 coefficients. It uses normalized
       weights ``exp(-4 * (CN_ref - CN_i)**2)`` over each element's reference
       CNs. Atoms whose weights all (nearly) vanish fall back to the reference
       with the largest CN.
    3. Derives ``C8 = 3 * r4r2_i * r4r2_j * C6``.
    4. Evaluates the pair term
       ``-(s6 * C6 / (r**6 + R0**6) + s8 * C8 / (r**8 + R0**8))``,
       with ``R0 = a1 * sqrt(C8 / C6) + a2``.

    Pair terms are multiplied by the graph ``switch`` and summed onto the source
    atom with a factor 1/2. They are then scaled by
    ``au.get_multiplier(_energy_unit)``. Distances are divided by ``au.BOHR``
    and clipped below at 1e-6 before use, so ``a2`` shares the units of the
    tabulated radii (Bohr).
    """

    _graphs_properties: Dict
    graph_key: str = "graph"
    """ The key for the graph input."""
    output_key: Optional[str] = None
    """ The key for the output energy. If None, the name of the module is used."""
    s6: float = 1.0
    """ Global scaling of the C6 term."""
    s8: float = 1.0
    """ Global scaling of the C8 term."""
    a1: float = 0.4
    """ Damping parameter multiplying sqrt(C8/C6) in the damping radius."""
    a2: float = 5.0
    """ Constant offset of the damping radius (same units as the distances, i.e. Bohr)."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str] = "DISPERSION_D3"

    # Per-process cache of the preprocessed reference tables, keyed by file path.
    # Without it, every call re-read and unpickled the data file.
    _REFERENCE_CACHE: ClassVar[Dict[str, Dict[str, Any]]] = {}

    @classmethod
    def _load_reference_tables(cls) -> Dict[str, Any]:
        """Load and preprocess ``ref_data_d3.pkl`` once, then reuse it.

        The pickle ships next to this module and is trusted package data. Never
        point this loader at untrusted files, because unpickling can execute
        arbitrary code.

        Returns a dict with these entries:

        - ``COV_D3`` and ``R4R2``: per-element tables, as stored in the file.
        - ``REF_CN``: reference coordination numbers per element.
        - ``EX_WEIGHT``: a one-hot fallback weight on each element's largest
          reference CN.
        - ``REF_C6``: reshaped to ``(nz * nz, nref, nref)``.
        - ``NZ``: the number of elements in the tables.
        """
        path = pathlib.Path(__file__).parent.resolve() / "ref_data_d3.pkl"
        key = str(path)
        cached = cls._REFERENCE_CACHE.get(key)
        if cached is not None:
            return cached

        with open(path, "rb") as f:
            raw = pickle.load(f)

        ref_cn = np.asarray(raw["REF_CN"])
        # One-hot weights on the reference with the largest CN, used when all
        # Gaussian weights of an atom vanish.
        ex_weight = np.zeros_like(ref_cn)
        ex_weight[np.arange(ref_cn.shape[0]), np.argmax(ref_cn, axis=1)] = 1.0

        ref_c6 = np.asarray(raw["REF_C6"])
        nz = ref_c6.shape[0]
        nref = ref_c6.shape[-1]

        tables = {
            "COV_D3": raw["COV_D3"],
            "R4R2": raw["R4R2"],
            "REF_CN": ref_cn,
            "EX_WEIGHT": ex_weight,
            # Flatten the two element axes so a pair (Zi, Zj) is one index.
            "REF_C6": ref_c6.reshape((nz * nz, nref, nref)),
            "NZ": nz,
        }
        cls._REFERENCE_CACHE[key] = tables
        return tables

    @nn.compact
    def __call__(self, inputs):
        """Compute per-atom D3 dispersion energies.

        Reads ``inputs["species"]`` and, from ``inputs[graph_key]``, the arrays
        ``edge_src``, ``edge_dst``, ``distances`` and ``switch``. Returns a copy
        of ``inputs`` with the energies stored under ``output_key`` (or the
        module name when ``output_key`` is None).
        """
        tables = self._load_reference_tables()

        graph = inputs[self.graph_key]
        edge_src = graph["edge_src"]
        edge_dst = graph["edge_dst"]
        switch = graph["switch"]
        species = inputs["species"]
        n_atoms = species.shape[0]

        # Lower clip keeps rcij / rij finite.
        rij = jnp.clip(graph["distances"] / au.BOHR, 1e-6, None)

        ## RADII (in BOHR)
        rcov = jnp.array(tables["COV_D3"])[species]
        r4r2 = jnp.array(tables["R4R2"])[species]

        rcij = rcov[edge_src] + rcov[edge_dst]

        ## COORDINATION NUMBER
        KCN = 16.0
        cnij = jax.nn.sigmoid(KCN * (rcij / rij - 1.0))
        cn = jax.ops.segment_sum(cnij, edge_src, n_atoms)

        ## WEIGHTS
        refcn = jnp.array(tables["REF_CN"])[species]
        # Reference slots with a negative CN are excluded from the interpolation.
        mask = refcn >= 0
        dcn = refcn - cn[:, None]
        KW = 4.0
        weights = jnp.where(mask, jnp.exp(-KW * dcn**2), 0.0)
        norm = weights.sum(axis=1, keepdims=True)  # (n_atoms, 1)
        weights = jnp.where(mask, weights / jnp.clip(norm, 1e-6, None), 0.0)

        ## correct for all null weights
        exweight = jnp.array(tables["EX_WEIGHT"])[species]
        exceptional = norm < 1.0e-6
        weights = jnp.where(exceptional, exweight, weights)

        ## C6 coefficients
        nz = tables["NZ"]
        REF_C6 = jnp.array(tables["REF_C6"])
        pair_num = species[edge_src] * nz + species[edge_dst]
        rc6 = REF_C6[pair_num]
        c6 = jnp.einsum("iab,ia,ib->i", rc6, weights[edge_src], weights[edge_dst])

        ## DISPERSION
        qq = 3 * r4r2[edge_src] * r4r2[edge_dst]  # C8 / C6
        c8 = c6 * qq

        # Damping radius R0 = a1 * sqrt(C8 / C6) + a2.
        r0 = self.a1 * jnp.sqrt(qq) + self.a2

        t6 = self.s6 / (rij**6 + r0**6)
        t8 = self.s8 / (rij**8 + r0**8)

        energy_unit = au.get_multiplier(self._energy_unit)
        energy = (-0.5 * energy_unit) * jax.ops.segment_sum(
            (c6 * t6 + c8 * t8) * switch, edge_src, n_atoms
        )

        output_key = self.output_key if self.output_key is not None else self.name
        return {**inputs, output_key: energy}

FID : DISPERSION_D3

DispersionD3( _graphs_properties: Dict, graph_key: str = 'graph', output_key: Optional[str] = None, s6: float = 1.0, s8: float = 1.0, a1: float = 0.4, a2: float = 5.0, _energy_unit: str = 'Ha', parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
graph_key: str = 'graph'
output_key: Optional[str] = None
s6: float = 1.0
s8: float = 1.0
a1: float = 0.4
a2: float = 5.0
FID: ClassVar[str] = 'DISPERSION_D3'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None