fennol.models.physics.repulsion

  1import pathlib
  2import jax
  3import jax.numpy as jnp
  4import flax.linen as nn
  5import numpy as np
  6from typing import Any, Dict, Union, Callable, Sequence, Optional, ClassVar
  7from ...utils import AtomicUnits as au
  8from ...utils.periodic_table import D3_COV_RADII, UFF_VDW_RADII
  9
 10
 11class RepulsionZBL(nn.Module):
 12    """Repulsion energy based on the Ziegler-Biersack-Littmark potential
 13
 14    FID: REPULSION_ZBL
 15
 16    ### Reference
 17    J. F. Ziegler & J. P. Biersack , The Stopping and Range of Ions in Matter
 18
 19    """
 20
 21    _graphs_properties: Dict
 22    graph_key: str = "graph"
 23    """The key for the graph input."""
 24    energy_key: Optional[str] = None
 25    """The key for the output energy."""
 26    trainable: bool = True
 27    """Whether the parameters are trainable."""
 28    _energy_unit: str = "Ha"
 29    """The energy unit of the model. **Automatically set by FENNIX**"""
 30    proportional_regularization: bool = True
 31    d: float = 0.46850 / au.BOHR
 32    p: float = 0.23
 33    alphas: Sequence[float] = (3.19980, 0.94229, 0.40290, 0.20162)
 34    cs: Sequence[float] = (0.18175273, 0.5098655, 0.28021213, 0.0281697)
 35    cs_logits: Sequence[float] = (0.1130, 1.1445, 0.5459, -1.7514)
 36
 37    FID: ClassVar[str] = "REPULSION_ZBL"
 38
 39    @nn.compact
 40    def __call__(self, inputs):
 41        species = inputs["species"]
 42        graph = inputs[self.graph_key]
 43        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
 44
 45        training = "training" in inputs.get("flags", {})
 46
 47        rijs = graph["distances"] / au.BOHR
 48
 49        d_ = self.d
 50        p_ = self.p
 51        assert len(self.alphas) == 4, "alphas must be a sequence of length 4"
 52        alphas_ = np.array(self.alphas, dtype=rijs.dtype)
 53        assert len(self.cs) == 4, "cs must be a sequence of length 4"
 54        cs_ = np.array(self.cs, dtype=rijs.dtype)
 55        cs_ = 0.5 * cs_ / np.sum(cs_)
 56        if self.trainable:
 57            d = jnp.abs(
 58                self.param(
 59                    "d",
 60                    lambda key, d: jnp.asarray(d, dtype=rijs.dtype),
 61                    d_,
 62                )
 63            )
 64            p = jnp.abs(
 65                self.param(
 66                    "p",
 67                    lambda key, p: jnp.asarray(p, dtype=rijs.dtype),
 68                    p_,
 69                )
 70            )
 71            cs = 0.5 * jax.nn.softmax(
 72                self.param(
 73                    "cs",
 74                    lambda key, cs: jnp.asarray(cs, dtype=rijs.dtype),
 75                    np.array(self.cs_logits, dtype=rijs.dtype),
 76                )
 77            )
 78            alphas = jnp.abs(
 79                self.param(
 80                    "alphas",
 81                    lambda key, alphas: jnp.asarray(alphas, dtype=rijs.dtype),
 82                    alphas_,
 83                )
 84            )
 85
 86            if training:
 87                if self.proportional_regularization:
 88                    reg = jnp.asarray(
 89                        ((1 - alphas / alphas_) ** 2).sum()
 90                        + ((1 - cs / cs_) ** 2).sum()
 91                        + (1 - p / p_) ** 2
 92                        + (1 - d / d_) ** 2
 93                    ).reshape(1)
 94                else:
 95                    reg = jnp.asarray(
 96                        ((alphas_ - alphas) ** 2).sum()
 97                        + ((cs_ - cs) ** 2).sum()
 98                        + (p_ - p) ** 2
 99                        + (d_ - d) ** 2
100                    ).reshape(1)
101        else:
102            cs = jnp.asarray(cs_)
103            alphas = jnp.asarray(alphas_)
104            d = d_
105            p = p_
106
107        if "alch_group" in inputs:
108            switch = graph["switch_raw"]
109            lambda_v = inputs["alch_vlambda"]
110            alch_group = inputs["alch_group"]
111            alch_m = inputs.get("alch_m", 2)
112
113            mask = alch_group[edge_src] == alch_group[edge_dst]
114
115            if "alch_softcore_rep" in inputs:
116                alch_alpha = inputs["alch_softcore_rep"] ** 2 * (1 - lambda_v)
117                rijs = jnp.where(mask, rijs, (rijs**2 + alch_alpha) ** 0.5)
118            lambda_v = 0.5 * (1 - jnp.cos(jnp.pi * lambda_v))
119            switch = jnp.where(
120                mask,
121                switch,
122                (lambda_v**alch_m) * switch,
123            )
124        else:
125            switch = graph["switch"]
126
127        Z = jnp.where(species > 0, species.astype(rijs.dtype), 0.0)
128        Zij = Z[edge_src] * Z[edge_dst]
129        Zp = Z**p / d
130        x = rijs * (Zp[edge_src] + Zp[edge_dst])
131        phi = (cs[None, :] * jnp.exp(-alphas[None, :] * x[:, None])).sum(axis=-1)
132
133        ereppair = Zij * phi / rijs * switch
134
135        erep_atomic = jax.ops.segment_sum(ereppair, edge_src, species.shape[0])
136
137        energy_unit = au.get_multiplier(self._energy_unit)
138        energy_key = self.energy_key if self.energy_key is not None else self.name
139        output = {**inputs, energy_key: erep_atomic * energy_unit}
140        if self.trainable and training:
141            output[energy_key + "_regularization"] = reg
142
143        return output
144
145
class RepulsionNLH(nn.Module):
    """NLH pairwise repulsive potential with pair-specific coefficients up to Z=92

    FID: REPULSION_NLH

    For every edge (i, j) the pair energy is ``Z_i Z_j / r_ij * phi_ij(r_ij)``
    (scaled by the graph switch) with ``phi_ij(r) = sum_k a_k exp(-b_k r)``,
    where the three (a_k, b_k) pairs depend on the element pair and are read
    from ``nlh_coeffs.dat`` located next to this module.

    ### Reference
    K. Nordlund, S. Lehtola, G. Hobler, Repulsive interatomic potentials calculated at three levels of theory, Phys. Rev. A 111, 032818
    https://doi.org/10.1103/PhysRevA.111.032818
    """

    _graphs_properties: Dict
    graph_key: str = "graph"
    """The key for the graph input."""
    energy_key: Optional[str] = None
    """The key for the output energy."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""
    # If True, learn per-term multiplicative factors on the a_k and b_k.
    trainable: bool = False
    # If set, analytical forces are also written to output[direct_forces_key].
    direct_forces_key: Optional[str] = None

    FID: ClassVar[str] = "REPULSION_NLH"

    @staticmethod
    def _load_nlh_table():
        """Load the NLH coefficient table shipped next to this module.

        Returns:
            (table, stride): ``table`` is a float64 array of shape
            ``(stride**2, 3, 2)`` where ``table[z1 + stride * z2, k]`` holds
            ``(a_k, b_k)`` for the element pair (z1, z2), filled symmetrically;
            rows of pairs absent from the file are zero. ``stride = zmax + 1``.
        """
        path = pathlib.Path(__file__).parent.resolve() / "nlh_coeffs.dat"
        # Columns: z1, z2, then six coefficients (presumably a1, b1, a2, b2, a3, b3,
        # given the (3, 2) reshape below).
        data = np.loadtxt(path, usecols=np.arange(0, 8), ndmin=2)
        # Use both atomic-number columns in case each pair is listed only once.
        zmax = int(np.max(data[:, :2]))
        # A stride of zmax + 1 makes z1 + stride * z2 collision-free for all
        # 0 <= z1, z2 <= zmax (a stride of zmax aliases (zmax, z) and (0, z + 1)).
        stride = zmax + 1
        # float64 so that float64 runs do not inherit float32 rounding.
        table = np.zeros((stride**2, 6), dtype=np.float64)
        for row in data:
            z1, z2 = int(row[0]), int(row[1])
            table[z1 + stride * z2] = row[2:8]
            table[z2 + stride * z1] = row[2:8]
        return table.reshape(stride**2, 3, 2), stride

    @nn.compact
    def __call__(self, inputs):
        table, stride = self._load_nlh_table()

        species = inputs["species"]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        rijs = graph["distances"]

        # coefficients (a1,a2,a3), one row per pair index
        CS = jnp.asarray(table[:, :, 0], dtype=rijs.dtype)
        # exponents (b1,b2,b3), one row per pair index
        ALPHAS = jnp.asarray(table[:, :, 1], dtype=rijs.dtype)

        if self.trainable:
            cfact = jnp.abs(
                self.param(
                    "c_fact",
                    lambda key: jnp.ones(CS.shape[1], dtype=CS.dtype),
                )
            )
            CS = CS * cfact[None, :]
            # Renormalize each row. Rows of pairs missing from the table sum to
            # zero; dividing by that zero would give NaN values, and NaN gradients
            # for c_fact (0 * inf in the backward pass) even for unused rows.
            norm = jnp.sum(CS, axis=1, keepdims=True)
            CS = CS / jnp.where(norm > 0, norm, 1.0)
            alphas_fact = jnp.abs(
                self.param(
                    "alpha_fact",
                    lambda key: jnp.ones(ALPHAS.shape[1], dtype=ALPHAS.dtype),
                )
            )
            ALPHAS = ALPHAS * alphas_fact[None, :]

        # Non-positive species (padding / dummy atoms) are mapped to element 0
        # rather than relying on negative-index wrap-around; their Z = 0 below
        # removes their energy contribution anyway.
        zidx = jnp.where(species > 0, species, 0)
        s12 = zidx[edge_src] + stride * zidx[edge_dst]
        cs = CS[s12]
        alphas = ALPHAS[s12]

        if "alch_group" in inputs:
            # Alchemical mode: interactions between different alchemical groups
            # are scaled by a smooth function of lambda (and optionally softened).
            switch = graph["switch_raw"]
            lambda_v = inputs["alch_vlambda"]
            alch_group = inputs["alch_group"]
            alch_m = inputs.get("alch_m", 2)

            mask = alch_group[edge_src] == alch_group[edge_dst]
            if "alch_softcore_rep" in inputs:
                alch_alpha = inputs["alch_softcore_rep"] ** 2 * (1 - lambda_v)
                rijs = jnp.where(mask, rijs, (rijs**2 + alch_alpha) ** 0.5)
            lambda_v = 0.5 * (1 - jnp.cos(jnp.pi * lambda_v))
            switch = jnp.where(
                mask,
                switch,
                (lambda_v**alch_m) * switch,
            )
        else:
            switch = graph["switch"]

        Z = jnp.where(species > 0, species.astype(rijs.dtype), 0.0)
        phi = (cs * jnp.exp(-alphas * rijs[:, None])).sum(axis=-1)
        Zij = Z[edge_src] * Z[edge_dst] * switch

        ereppair = Zij * phi / rijs

        energy_unit = au.get_multiplier(self._energy_unit)
        # au.BOHR converts Z_i Z_j / r (r in distance units) to Hartree; the 0.5
        # presumably compensates for each pair being stored in both directions.
        erep_atomic = (energy_unit * 0.5 * au.BOHR) * jax.ops.segment_sum(
            ereppair, edge_src, species.shape[0]
        )

        energy_key = self.energy_key if self.energy_key is not None else self.name
        output = {**inputs, energy_key: erep_atomic}

        if self.direct_forces_key is not None:
            # F_i = sum_j dE_ij/dr * vec_ij / r_ij, assuming graph["vec"] points
            # from the source to the destination atom.
            # NOTE(review): the distance dependence of `switch` (and of the
            # softcore distance in alchemical mode) is not differentiated, so
            # these forces equal -dE/dx only where the switch is constant —
            # confirm this approximation is intended.
            dphidr = -(alphas * cs * jnp.exp(-alphas * rijs[:, None])).sum(axis=-1)
            dedr = Zij * (dphidr / rijs - phi / (rijs**2))
            dedij = (dedr / rijs)[:, None] * graph["vec"]
            fi = (energy_unit * au.BOHR) * jax.ops.segment_sum(
                dedij, edge_src, species.shape[0]
            )
            output[self.direct_forces_key] = fi

        return output
class RepulsionZBL(flax.linen.module.Module):
 12class RepulsionZBL(nn.Module):
 13    """Repulsion energy based on the Ziegler-Biersack-Littmark potential
 14
 15    FID: REPULSION_ZBL
 16
 17    ### Reference
 18    J. F. Ziegler & J. P. Biersack , The Stopping and Range of Ions in Matter
 19
 20    """
 21
 22    _graphs_properties: Dict
 23    graph_key: str = "graph"
 24    """The key for the graph input."""
 25    energy_key: Optional[str] = None
 26    """The key for the output energy."""
 27    trainable: bool = True
 28    """Whether the parameters are trainable."""
 29    _energy_unit: str = "Ha"
 30    """The energy unit of the model. **Automatically set by FENNIX**"""
 31    proportional_regularization: bool = True
 32    d: float = 0.46850 / au.BOHR
 33    p: float = 0.23
 34    alphas: Sequence[float] = (3.19980, 0.94229, 0.40290, 0.20162)
 35    cs: Sequence[float] = (0.18175273, 0.5098655, 0.28021213, 0.0281697)
 36    cs_logits: Sequence[float] = (0.1130, 1.1445, 0.5459, -1.7514)
 37
 38    FID: ClassVar[str] = "REPULSION_ZBL"
 39
 40    @nn.compact
 41    def __call__(self, inputs):
 42        species = inputs["species"]
 43        graph = inputs[self.graph_key]
 44        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
 45
 46        training = "training" in inputs.get("flags", {})
 47
 48        rijs = graph["distances"] / au.BOHR
 49
 50        d_ = self.d
 51        p_ = self.p
 52        assert len(self.alphas) == 4, "alphas must be a sequence of length 4"
 53        alphas_ = np.array(self.alphas, dtype=rijs.dtype)
 54        assert len(self.cs) == 4, "cs must be a sequence of length 4"
 55        cs_ = np.array(self.cs, dtype=rijs.dtype)
 56        cs_ = 0.5 * cs_ / np.sum(cs_)
 57        if self.trainable:
 58            d = jnp.abs(
 59                self.param(
 60                    "d",
 61                    lambda key, d: jnp.asarray(d, dtype=rijs.dtype),
 62                    d_,
 63                )
 64            )
 65            p = jnp.abs(
 66                self.param(
 67                    "p",
 68                    lambda key, p: jnp.asarray(p, dtype=rijs.dtype),
 69                    p_,
 70                )
 71            )
 72            cs = 0.5 * jax.nn.softmax(
 73                self.param(
 74                    "cs",
 75                    lambda key, cs: jnp.asarray(cs, dtype=rijs.dtype),
 76                    np.array(self.cs_logits, dtype=rijs.dtype),
 77                )
 78            )
 79            alphas = jnp.abs(
 80                self.param(
 81                    "alphas",
 82                    lambda key, alphas: jnp.asarray(alphas, dtype=rijs.dtype),
 83                    alphas_,
 84                )
 85            )
 86
 87            if training:
 88                if self.proportional_regularization:
 89                    reg = jnp.asarray(
 90                        ((1 - alphas / alphas_) ** 2).sum()
 91                        + ((1 - cs / cs_) ** 2).sum()
 92                        + (1 - p / p_) ** 2
 93                        + (1 - d / d_) ** 2
 94                    ).reshape(1)
 95                else:
 96                    reg = jnp.asarray(
 97                        ((alphas_ - alphas) ** 2).sum()
 98                        + ((cs_ - cs) ** 2).sum()
 99                        + (p_ - p) ** 2
100                        + (d_ - d) ** 2
101                    ).reshape(1)
102        else:
103            cs = jnp.asarray(cs_)
104            alphas = jnp.asarray(alphas_)
105            d = d_
106            p = p_
107
108        if "alch_group" in inputs:
109            switch = graph["switch_raw"]
110            lambda_v = inputs["alch_vlambda"]
111            alch_group = inputs["alch_group"]
112            alch_m = inputs.get("alch_m", 2)
113
114            mask = alch_group[edge_src] == alch_group[edge_dst]
115
116            if "alch_softcore_rep" in inputs:
117                alch_alpha = inputs["alch_softcore_rep"] ** 2 * (1 - lambda_v)
118                rijs = jnp.where(mask, rijs, (rijs**2 + alch_alpha) ** 0.5)
119            lambda_v = 0.5 * (1 - jnp.cos(jnp.pi * lambda_v))
120            switch = jnp.where(
121                mask,
122                switch,
123                (lambda_v**alch_m) * switch,
124            )
125        else:
126            switch = graph["switch"]
127
128        Z = jnp.where(species > 0, species.astype(rijs.dtype), 0.0)
129        Zij = Z[edge_src] * Z[edge_dst]
130        Zp = Z**p / d
131        x = rijs * (Zp[edge_src] + Zp[edge_dst])
132        phi = (cs[None, :] * jnp.exp(-alphas[None, :] * x[:, None])).sum(axis=-1)
133
134        ereppair = Zij * phi / rijs * switch
135
136        erep_atomic = jax.ops.segment_sum(ereppair, edge_src, species.shape[0])
137
138        energy_unit = au.get_multiplier(self._energy_unit)
139        energy_key = self.energy_key if self.energy_key is not None else self.name
140        output = {**inputs, energy_key: erep_atomic * energy_unit}
141        if self.trainable and training:
142            output[energy_key + "_regularization"] = reg
143
144        return output

Repulsion energy based on the Ziegler-Biersack-Littmark potential

FID: REPULSION_ZBL

Reference

J. F. Ziegler & J. P. Biersack, The Stopping and Range of Ions in Matter

RepulsionZBL( _graphs_properties: Dict, graph_key: str = 'graph', energy_key: Optional[str] = None, trainable: bool = True, _energy_unit: str = 'Ha', proportional_regularization: bool = True, d: float = 0.885336690897932, p: float = 0.23, alphas: Sequence[float] = (3.1998, 0.94229, 0.4029, 0.20162), cs: Sequence[float] = (0.18175273, 0.5098655, 0.28021213, 0.0281697), cs_logits: Sequence[float] = (0.113, 1.1445, 0.5459, -1.7514), parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
graph_key: str = 'graph'

The key for the graph input.

energy_key: Optional[str] = None

The key for the output energy.

trainable: bool = True

Whether the parameters are trainable.

proportional_regularization: bool = True
d: float = 0.885336690897932
p: float = 0.23
alphas: Sequence[float] = (3.1998, 0.94229, 0.4029, 0.20162)
cs: Sequence[float] = (0.18175273, 0.5098655, 0.28021213, 0.0281697)
cs_logits: Sequence[float] = (0.113, 1.1445, 0.5459, -1.7514)
FID: ClassVar[str] = 'REPULSION_ZBL'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class RepulsionNLH(flax.linen.module.Module):
class RepulsionNLH(nn.Module):
    """NLH pairwise repulsive potential with pair-specific coefficients up to Z=92

    FID: REPULSION_NLH

    For every edge (i, j) the pair energy is ``Z_i Z_j / r_ij * phi_ij(r_ij)``
    (scaled by the graph switch) with ``phi_ij(r) = sum_k a_k exp(-b_k r)``,
    where the three (a_k, b_k) pairs depend on the element pair and are read
    from ``nlh_coeffs.dat`` located next to this module.

    ### Reference
    K. Nordlund, S. Lehtola, G. Hobler, Repulsive interatomic potentials calculated at three levels of theory, Phys. Rev. A 111, 032818
    https://doi.org/10.1103/PhysRevA.111.032818
    """

    _graphs_properties: Dict
    graph_key: str = "graph"
    """The key for the graph input."""
    energy_key: Optional[str] = None
    """The key for the output energy."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""
    # If True, learn per-term multiplicative factors on the a_k and b_k.
    trainable: bool = False
    # If set, analytical forces are also written to output[direct_forces_key].
    direct_forces_key: Optional[str] = None

    FID: ClassVar[str] = "REPULSION_NLH"

    @staticmethod
    def _load_nlh_table():
        """Load the NLH coefficient table shipped next to this module.

        Returns:
            (table, stride): ``table`` is a float64 array of shape
            ``(stride**2, 3, 2)`` where ``table[z1 + stride * z2, k]`` holds
            ``(a_k, b_k)`` for the element pair (z1, z2), filled symmetrically;
            rows of pairs absent from the file are zero. ``stride = zmax + 1``.
        """
        path = pathlib.Path(__file__).parent.resolve() / "nlh_coeffs.dat"
        # Columns: z1, z2, then six coefficients (presumably a1, b1, a2, b2, a3, b3,
        # given the (3, 2) reshape below).
        data = np.loadtxt(path, usecols=np.arange(0, 8), ndmin=2)
        # Use both atomic-number columns in case each pair is listed only once.
        zmax = int(np.max(data[:, :2]))
        # A stride of zmax + 1 makes z1 + stride * z2 collision-free for all
        # 0 <= z1, z2 <= zmax (a stride of zmax aliases (zmax, z) and (0, z + 1)).
        stride = zmax + 1
        # float64 so that float64 runs do not inherit float32 rounding.
        table = np.zeros((stride**2, 6), dtype=np.float64)
        for row in data:
            z1, z2 = int(row[0]), int(row[1])
            table[z1 + stride * z2] = row[2:8]
            table[z2 + stride * z1] = row[2:8]
        return table.reshape(stride**2, 3, 2), stride

    @nn.compact
    def __call__(self, inputs):
        table, stride = self._load_nlh_table()

        species = inputs["species"]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        rijs = graph["distances"]

        # coefficients (a1,a2,a3), one row per pair index
        CS = jnp.asarray(table[:, :, 0], dtype=rijs.dtype)
        # exponents (b1,b2,b3), one row per pair index
        ALPHAS = jnp.asarray(table[:, :, 1], dtype=rijs.dtype)

        if self.trainable:
            cfact = jnp.abs(
                self.param(
                    "c_fact",
                    lambda key: jnp.ones(CS.shape[1], dtype=CS.dtype),
                )
            )
            CS = CS * cfact[None, :]
            # Renormalize each row. Rows of pairs missing from the table sum to
            # zero; dividing by that zero would give NaN values, and NaN gradients
            # for c_fact (0 * inf in the backward pass) even for unused rows.
            norm = jnp.sum(CS, axis=1, keepdims=True)
            CS = CS / jnp.where(norm > 0, norm, 1.0)
            alphas_fact = jnp.abs(
                self.param(
                    "alpha_fact",
                    lambda key: jnp.ones(ALPHAS.shape[1], dtype=ALPHAS.dtype),
                )
            )
            ALPHAS = ALPHAS * alphas_fact[None, :]

        # Non-positive species (padding / dummy atoms) are mapped to element 0
        # rather than relying on negative-index wrap-around; their Z = 0 below
        # removes their energy contribution anyway.
        zidx = jnp.where(species > 0, species, 0)
        s12 = zidx[edge_src] + stride * zidx[edge_dst]
        cs = CS[s12]
        alphas = ALPHAS[s12]

        if "alch_group" in inputs:
            # Alchemical mode: interactions between different alchemical groups
            # are scaled by a smooth function of lambda (and optionally softened).
            switch = graph["switch_raw"]
            lambda_v = inputs["alch_vlambda"]
            alch_group = inputs["alch_group"]
            alch_m = inputs.get("alch_m", 2)

            mask = alch_group[edge_src] == alch_group[edge_dst]
            if "alch_softcore_rep" in inputs:
                alch_alpha = inputs["alch_softcore_rep"] ** 2 * (1 - lambda_v)
                rijs = jnp.where(mask, rijs, (rijs**2 + alch_alpha) ** 0.5)
            lambda_v = 0.5 * (1 - jnp.cos(jnp.pi * lambda_v))
            switch = jnp.where(
                mask,
                switch,
                (lambda_v**alch_m) * switch,
            )
        else:
            switch = graph["switch"]

        Z = jnp.where(species > 0, species.astype(rijs.dtype), 0.0)
        phi = (cs * jnp.exp(-alphas * rijs[:, None])).sum(axis=-1)
        Zij = Z[edge_src] * Z[edge_dst] * switch

        ereppair = Zij * phi / rijs

        energy_unit = au.get_multiplier(self._energy_unit)
        # au.BOHR converts Z_i Z_j / r (r in distance units) to Hartree; the 0.5
        # presumably compensates for each pair being stored in both directions.
        erep_atomic = (energy_unit * 0.5 * au.BOHR) * jax.ops.segment_sum(
            ereppair, edge_src, species.shape[0]
        )

        energy_key = self.energy_key if self.energy_key is not None else self.name
        output = {**inputs, energy_key: erep_atomic}

        if self.direct_forces_key is not None:
            # F_i = sum_j dE_ij/dr * vec_ij / r_ij, assuming graph["vec"] points
            # from the source to the destination atom.
            # NOTE(review): the distance dependence of `switch` (and of the
            # softcore distance in alchemical mode) is not differentiated, so
            # these forces equal -dE/dx only where the switch is constant —
            # confirm this approximation is intended.
            dphidr = -(alphas * cs * jnp.exp(-alphas * rijs[:, None])).sum(axis=-1)
            dedr = Zij * (dphidr / rijs - phi / (rijs**2))
            dedij = (dedr / rijs)[:, None] * graph["vec"]
            fi = (energy_unit * au.BOHR) * jax.ops.segment_sum(
                dedij, edge_src, species.shape[0]
            )
            output[self.direct_forces_key] = fi

        return output

NLH pairwise repulsive potential with pair-specific coefficients up to Z=92

FID: REPULSION_NLH

Reference

K. Nordlund, S. Lehtola, G. Hobler, Repulsive interatomic potentials calculated at three levels of theory, Phys. Rev. A 111, 032818 https://doi.org/10.1103/PhysRevA.111.032818

RepulsionNLH( _graphs_properties: Dict, graph_key: str = 'graph', energy_key: Optional[str] = None, _energy_unit: str = 'Ha', trainable: bool = False, direct_forces_key: Optional[str] = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
graph_key: str = 'graph'

The key for the graph input.

energy_key: Optional[str] = None

The key for the output energy.

trainable: bool = False
direct_forces_key: Optional[str] = None
FID: ClassVar[str] = 'REPULSION_NLH'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None