fennol.models.physics.dispersion

  1import jax
  2import jax.numpy as jnp
  3import flax.linen as nn
  4import numpy as np
  5from typing import Any, Dict, Union, Callable, Sequence, Optional, ClassVar
  6from ...utils import AtomicUnits as au
  7from ...utils.periodic_table import (
  8    D3_ELECTRONEGATIVITIES,
  9    D3_HARDNESSES,
 10    D3_VDW_RADII,
 11    D3_COV_RADII,
 12    D3_KAPPA,
 13    VDW_RADII,
 14    POLARIZABILITIES,
 15    C6_FREE,
 16    VALENCE_ELECTRONS,
 17)
 18
 19
 20class VdwOQDO(nn.Module):
 21    """ Dispersion and exchange based on the Optimized Quantum Drude Oscillator model.
 22    
 23    FID : VDW_OQDO
 24
 25    ### Reference
 26    A. Khabibrakhmanov, D. V. Fedorov, and A. Tkatchenko, Universal Pairwise Interatomic van der Waals Potentials Based on Quantum Drude Oscillators,
 27    J. Chem. Theory Comput. 2023, 19, 21, 7895–7907 (https://doi.org/10.1021/acs.jctc.3c00797)
 28    
 29    """
 30    graph_key: str = "graph"
 31    """ The key for the graph input."""
 32    include_exchange: bool = True
 33    """ Whether to compute the exchange part."""
 34    ratiovol_key: Optional[str] = None
 35    """ The key for the ratio between AIM volume and free-atom volume. 
 36         If None, the volume ratio is assumed to be 1.0."""
 37    energy_key: Optional[str] = None
 38    """ The key for the output energy. If None, the name of the module is used."""
 39    damped: bool = True
 40    """ Whether to use short-range damping."""
 41    _energy_unit: str = "Ha"
 42    """The energy unit of the model. **Automatically set by FENNIX**"""
 43
 44    FID: ClassVar[str]  = "VDW_OQDO"
 45
 46    @nn.compact
 47    def __call__(self, inputs):
 48        energy_unit = au.get_multiplier(self._energy_unit)
 49
 50        species = inputs["species"]
 51        graph = inputs[self.graph_key]
 52        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
 53        rij = graph["distances"] / au.BOHR
 54        switch = graph["switch"]
 55
 56        c6 = jnp.asarray(C6_FREE)[species]
 57        alpha = jnp.asarray(POLARIZABILITIES)[species]
 58
 59        if self.ratiovol_key is not None:
 60            ratiovol = inputs[self.ratiovol_key] + 1.0e-6
 61            if ratiovol.shape[-1] == 1:
 62                ratiovol = jnp.squeeze(ratiovol, axis=-1)
 63            c6 = c6 * ratiovol**2
 64            alpha = alpha * ratiovol
 65
 66        c6i, c6j = c6[edge_src], c6[edge_dst]
 67        alphai, alphaj = alpha[edge_src], alpha[edge_dst]
 68
 69        # combination rules
 70        alphaij = 0.5 * (alphai + alphaj)
 71        c6ij = 2 * alphai * alphaj * c6i * c6j / (c6i * alphaj**2 + c6j * alphai**2)
 72
 73        # equilibrium distance
 74        Re = (alphaij * (128.0 / au.FSC ** (4.0 / 3.0))) ** (1.0 / 7.0)
 75        Re2 = Re**2
 76        Re4 = Re**4
 77        # fit to largest root of eq (S33) of "Universal Pairwise Interatomic van der Waals Potentials Based On Quantum Drude Oscillators"
 78        if self.damped:
 79            muw = (
 80                4.83053463e-01
 81                - 3.76191669e-02 * Re
 82                + 1.27066988e-03 * Re2
 83                - 7.21940151e-07 * Re4
 84            ) / (3.84212120e-02 - 3.16915319e-02 * Re + 2.37410890e-02 * Re2)
 85        else:
 86            muw = (
 87                3.66316787e01
 88                - 5.79579187 * Re
 89                + 3.02674813e-01 * Re2
 90                - 3.65461255e-04 * Re4
 91            ) / (-1.46169102e01 + 7.32461225 * Re)
 92
 93        c8ij = 5 * c6ij / muw
 94        c10ij = 245 * c6ij / (8 * muw**2)
 95
 96        if self.damped:
 97            z = 0.5 * muw * rij**2
 98            ez = jnp.exp(-z)
 99            f6 = 1.0 - ez * (1.0 + z + 0.5 * z**2 + (1.0 / 6.0) * z**3)
100            f8 = f6 - (1.0 / 24.0) * ez * z**4
101            f10 = f8 - (1.0 / 120.0) * ez * z**5
102            epair = (
103                f6 * c6ij / rij**6 + f8 * c8ij / rij**8 + f10 * c10ij / rij**10
104            )
105        else:
106            epair = c6ij / rij**6 + c8ij / rij**8 + c10ij / rij**10
107
108        edisp = (-0.5*energy_unit) * jax.ops.segment_sum(epair * switch, edge_src, species.shape[0])
109
110        output_key = self.name if self.energy_key is None else self.energy_key
111
112        if not self.include_exchange:
113            return {**inputs, output_key: edisp}
114
115        ### exchange
116        w = 4 * c6ij / (3 * alphaij**2)
117        # q = (alphaij * mu*w**2)**0.5
118        q2 = alphaij * muw * w
119        # undamped case
120        if self.damped:
121            ze = 0.5 * muw * Re2
122            eze = jnp.exp(-ze)
123
124            s6 = eze * (1.0 + ze + 0.5 * ze**2 + (1.0 / 6.0) * ze**3)
125            f6e = 1.0 - s6
126            muwRe = muw * Re
127            df6e = muwRe * s6 - eze * (
128                muwRe + 0.5 * Re * muwRe**2 + (1.0 / 8.0) * Re2 * muwRe**3
129            )
130
131            s8 = (1.0 / 24.0) * eze * ze**4
132            f8e = f6e - s8
133            df8e = df6e + muwRe * s8 - (1.0 / 48.0) * eze * Re2 * Re * muwRe**4
134
135            s10 = (1.0 / 120.0) * eze * ze**5
136            f10e = f8e - s10
137            df10e = df8e + muwRe * s10 - (1.0 / 384.0) * eze * Re2 * Re2 * muwRe**5
138
139            den = 2 * c6ij * Re2 * (6 * f6e - Re * df6e)
140            A = (
141                0.5
142                + c8ij * (8 * f8e - Re * df8e) / den
143                + c10ij * (10 * f10e - Re * df10e) / (den * Re2)
144            )
145        else:
146            A = 0.5 + 2 * c8ij / (3 * c6ij * Re2) + 5 * c10ij / (6 * c6ij * Re4)
147            ez = jnp.exp(-0.5 * muw * rij**2)
148
149        exij = A * q2 * ez / rij
150        ex = (0.5*energy_unit) * jax.ops.segment_sum(exij * switch, edge_src, species.shape[0])
151
152        return {
153            **inputs,
154            output_key + "_dispersion": edisp,
155            output_key + "_exchange": ex,
156            output_key: edisp + ex,
157        }
class VdwOQDO(flax.linen.module.Module):
 21class VdwOQDO(nn.Module):
 22    """ Dispersion and exchange based on the Optimized Quantum Drude Oscillator model.
 23    
 24    FID : VDW_OQDO
 25
 26    ### Reference
 27    A. Khabibrakhmanov, D. V. Fedorov, and A. Tkatchenko, Universal Pairwise Interatomic van der Waals Potentials Based on Quantum Drude Oscillators,
 28    J. Chem. Theory Comput. 2023, 19, 21, 7895–7907 (https://doi.org/10.1021/acs.jctc.3c00797)
 29    
 30    """
 31    graph_key: str = "graph"
 32    """ The key for the graph input."""
 33    include_exchange: bool = True
 34    """ Whether to compute the exchange part."""
 35    ratiovol_key: Optional[str] = None
 36    """ The key for the ratio between AIM volume and free-atom volume. 
 37         If None, the volume ratio is assumed to be 1.0."""
 38    energy_key: Optional[str] = None
 39    """ The key for the output energy. If None, the name of the module is used."""
 40    damped: bool = True
 41    """ Whether to use short-range damping."""
 42    _energy_unit: str = "Ha"
 43    """The energy unit of the model. **Automatically set by FENNIX**"""
 44
 45    FID: ClassVar[str]  = "VDW_OQDO"
 46
 47    @nn.compact
 48    def __call__(self, inputs):
 49        energy_unit = au.get_multiplier(self._energy_unit)
 50
 51        species = inputs["species"]
 52        graph = inputs[self.graph_key]
 53        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
 54        rij = graph["distances"] / au.BOHR
 55        switch = graph["switch"]
 56
 57        c6 = jnp.asarray(C6_FREE)[species]
 58        alpha = jnp.asarray(POLARIZABILITIES)[species]
 59
 60        if self.ratiovol_key is not None:
 61            ratiovol = inputs[self.ratiovol_key] + 1.0e-6
 62            if ratiovol.shape[-1] == 1:
 63                ratiovol = jnp.squeeze(ratiovol, axis=-1)
 64            c6 = c6 * ratiovol**2
 65            alpha = alpha * ratiovol
 66
 67        c6i, c6j = c6[edge_src], c6[edge_dst]
 68        alphai, alphaj = alpha[edge_src], alpha[edge_dst]
 69
 70        # combination rules
 71        alphaij = 0.5 * (alphai + alphaj)
 72        c6ij = 2 * alphai * alphaj * c6i * c6j / (c6i * alphaj**2 + c6j * alphai**2)
 73
 74        # equilibrium distance
 75        Re = (alphaij * (128.0 / au.FSC ** (4.0 / 3.0))) ** (1.0 / 7.0)
 76        Re2 = Re**2
 77        Re4 = Re**4
 78        # fit to largest root of eq (S33) of "Universal Pairwise Interatomic van der Waals Potentials Based On Quantum Drude Oscillators"
 79        if self.damped:
 80            muw = (
 81                4.83053463e-01
 82                - 3.76191669e-02 * Re
 83                + 1.27066988e-03 * Re2
 84                - 7.21940151e-07 * Re4
 85            ) / (3.84212120e-02 - 3.16915319e-02 * Re + 2.37410890e-02 * Re2)
 86        else:
 87            muw = (
 88                3.66316787e01
 89                - 5.79579187 * Re
 90                + 3.02674813e-01 * Re2
 91                - 3.65461255e-04 * Re4
 92            ) / (-1.46169102e01 + 7.32461225 * Re)
 93
 94        c8ij = 5 * c6ij / muw
 95        c10ij = 245 * c6ij / (8 * muw**2)
 96
 97        if self.damped:
 98            z = 0.5 * muw * rij**2
 99            ez = jnp.exp(-z)
100            f6 = 1.0 - ez * (1.0 + z + 0.5 * z**2 + (1.0 / 6.0) * z**3)
101            f8 = f6 - (1.0 / 24.0) * ez * z**4
102            f10 = f8 - (1.0 / 120.0) * ez * z**5
103            epair = (
104                f6 * c6ij / rij**6 + f8 * c8ij / rij**8 + f10 * c10ij / rij**10
105            )
106        else:
107            epair = c6ij / rij**6 + c8ij / rij**8 + c10ij / rij**10
108
109        edisp = (-0.5*energy_unit) * jax.ops.segment_sum(epair * switch, edge_src, species.shape[0])
110
111        output_key = self.name if self.energy_key is None else self.energy_key
112
113        if not self.include_exchange:
114            return {**inputs, output_key: edisp}
115
116        ### exchange
117        w = 4 * c6ij / (3 * alphaij**2)
118        # q = (alphaij * mu*w**2)**0.5
119        q2 = alphaij * muw * w
120        # undamped case
121        if self.damped:
122            ze = 0.5 * muw * Re2
123            eze = jnp.exp(-ze)
124
125            s6 = eze * (1.0 + ze + 0.5 * ze**2 + (1.0 / 6.0) * ze**3)
126            f6e = 1.0 - s6
127            muwRe = muw * Re
128            df6e = muwRe * s6 - eze * (
129                muwRe + 0.5 * Re * muwRe**2 + (1.0 / 8.0) * Re2 * muwRe**3
130            )
131
132            s8 = (1.0 / 24.0) * eze * ze**4
133            f8e = f6e - s8
134            df8e = df6e + muwRe * s8 - (1.0 / 48.0) * eze * Re2 * Re * muwRe**4
135
136            s10 = (1.0 / 120.0) * eze * ze**5
137            f10e = f8e - s10
138            df10e = df8e + muwRe * s10 - (1.0 / 384.0) * eze * Re2 * Re2 * muwRe**5
139
140            den = 2 * c6ij * Re2 * (6 * f6e - Re * df6e)
141            A = (
142                0.5
143                + c8ij * (8 * f8e - Re * df8e) / den
144                + c10ij * (10 * f10e - Re * df10e) / (den * Re2)
145            )
146        else:
147            A = 0.5 + 2 * c8ij / (3 * c6ij * Re2) + 5 * c10ij / (6 * c6ij * Re4)
148            ez = jnp.exp(-0.5 * muw * rij**2)
149
150        exij = A * q2 * ez / rij
151        ex = (0.5*energy_unit) * jax.ops.segment_sum(exij * switch, edge_src, species.shape[0])
152
153        return {
154            **inputs,
155            output_key + "_dispersion": edisp,
156            output_key + "_exchange": ex,
157            output_key: edisp + ex,
158        }

Dispersion and exchange based on the Optimized Quantum Drude Oscillator model.

FID : VDW_OQDO

Reference

A. Khabibrakhmanov, D. V. Fedorov, and A. Tkatchenko, Universal Pairwise Interatomic van der Waals Potentials Based on Quantum Drude Oscillators, J. Chem. Theory Comput. 2023, 19, 21, 7895–7907 (https://doi.org/10.1021/acs.jctc.3c00797)

VdwOQDO( graph_key: str = 'graph', include_exchange: bool = True, ratiovol_key: Optional[str] = None, energy_key: Optional[str] = None, damped: bool = True, _energy_unit: str = 'Ha', parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
graph_key: str = 'graph'

The key for the graph input.

include_exchange: bool = True

Whether to compute the exchange part.

ratiovol_key: Optional[str] = None

The key for the ratio between the atom-in-molecule (AIM) volume and the free-atom volume. If None, the volume ratio is assumed to be 1.0.

energy_key: Optional[str] = None

The key for the output energy. If None, the name of the module is used.

damped: bool = True

Whether to use short-range damping.

FID: ClassVar[str] = 'VDW_OQDO'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying Python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator to ensure that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None