fennol.models.physics.repulsion

import pathlib
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
from typing import Any, Dict, Union, Callable, Sequence, Optional, ClassVar
from ...utils import AtomicUnits as au
from ...utils.periodic_table import D3_COV_RADII, UFF_VDW_RADII


class RepulsionZBL(nn.Module):
 12    """Repulsion energy based on the Ziegler-Biersack-Littmark potential
 13
 14    FID: REPULSION_ZBL
 15
 16    ### Reference
 17    J. F. Ziegler & J. P. Biersack , The Stopping and Range of Ions in Matter
 18
 19    """

    _graphs_properties: Dict
    graph_key: str = "graph"
    """The key for the graph input."""
    energy_key: Optional[str] = None
    """The key for the output energy."""
    trainable: bool = True
    """Whether the parameters are trainable."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""
    proportional_regularization: bool = True
    d: float = 0.46850 / au.BOHR
    p: float = 0.23
    alphas: Sequence[float] = (3.19980, 0.94229, 0.40290, 0.20162)
    cs: Sequence[float] = (0.18175273, 0.5098655, 0.28021213, 0.0281697)
    cs_logits: Sequence[float] = (0.1130, 1.1445, 0.5459, -1.7514)

    FID: ClassVar[str] = "REPULSION_ZBL"

    @nn.compact
    def __call__(self, inputs):
        species = inputs["species"]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]

        training = "training" in inputs.get("flags", {})

        rijs = graph["distances"] / au.BOHR

        d_ = self.d
        p_ = self.p
        assert len(self.alphas) == 4, "alphas must be a sequence of length 4"
        alphas_ = np.array(self.alphas, dtype=rijs.dtype)
        assert len(self.cs) == 4, "cs must be a sequence of length 4"
        cs_ = np.array(self.cs, dtype=rijs.dtype)
        cs_ = 0.5 * cs_ / np.sum(cs_)
        if self.trainable:
            d = jnp.abs(
                self.param(
                    "d",
                    lambda key, d: jnp.asarray(d, dtype=rijs.dtype),
                    d_,
                )
            )
            p = jnp.abs(
                self.param(
                    "p",
                    lambda key, p: jnp.asarray(p, dtype=rijs.dtype),
                    p_,
                )
            )
            cs = 0.5 * jax.nn.softmax(
                self.param(
                    "cs",
                    lambda key, cs: jnp.asarray(cs, dtype=rijs.dtype),
                    np.array(self.cs_logits, dtype=rijs.dtype),
                )
            )
            alphas = jnp.abs(
                self.param(
                    "alphas",
                    lambda key, alphas: jnp.asarray(alphas, dtype=rijs.dtype),
                    alphas_,
                )
            )

            if training:
                if self.proportional_regularization:
                    reg = jnp.asarray(
                        ((1 - alphas / alphas_) ** 2).sum()
                        + ((1 - cs / cs_) ** 2).sum()
                        + (1 - p / p_) ** 2
                        + (1 - d / d_) ** 2
                    ).reshape(1)
                else:
                    reg = jnp.asarray(
                        ((alphas_ - alphas) ** 2).sum()
                        + ((cs_ - cs) ** 2).sum()
                        + (p_ - p) ** 2
                        + (d_ - d) ** 2
                    ).reshape(1)
        else:
            cs = jnp.asarray(cs_)
            alphas = jnp.asarray(alphas_)
            d = d_
            p = p_

        if "alch_group" in inputs:
            switch = graph["switch_raw"]
            lambda_v = inputs["alch_vlambda"]
            alch_group = inputs["alch_group"]
            alch_alpha = inputs.get("alch_alpha", 0.5)
            alch_m = inputs.get("alch_m", 2)

            mask = alch_group[edge_src] == alch_group[edge_dst]

            rijs = jnp.where(mask, rijs, (rijs**2 + alch_alpha**2 * (1 - lambda_v)) ** 0.5)
            lambda_v = 0.5 * (1 - jnp.cos(jnp.pi * lambda_v))
            switch = jnp.where(
                mask,
                switch,
                (lambda_v**alch_m) * switch,
            )
        else:
            switch = graph["switch"]

        Z = jnp.where(species > 0, species.astype(rijs.dtype), 0.0)
        Zij = Z[edge_src] * Z[edge_dst]
        Zp = Z**p / d
        x = rijs * (Zp[edge_src] + Zp[edge_dst])
        phi = (cs[None, :] * jnp.exp(-alphas[None, :] * x[:, None])).sum(axis=-1)

        ereppair = Zij * phi / rijs * switch

        erep_atomic = jax.ops.segment_sum(ereppair, edge_src, species.shape[0])

        energy_unit = au.get_multiplier(self._energy_unit)
        energy_key = self.energy_key if self.energy_key is not None else self.name
        output = {**inputs, energy_key: erep_atomic * energy_unit}
        if self.trainable and training:
            output[energy_key + "_regularization"] = reg

        return output

class RepulsionNLH(nn.Module):
    """NLH pairwise repulsive potential with pair-specific coefficients up to Z=92

    FID: REPULSION_NLH

    ### Reference
    K. Nordlund, S. Lehtola, G. Hobler, Repulsive interatomic potentials calculated at three levels of theory, Phys. Rev. A 111, 032818
    https://doi.org/10.1103/PhysRevA.111.032818
    """

    _graphs_properties: Dict
    graph_key: str = "graph"
    """The key for the graph input."""
    energy_key: Optional[str] = None
    """The key for the output energy."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""
    trainable: bool = False
    direct_forces_key: Optional[str] = None

    FID: ClassVar[str] = "REPULSION_NLH"

    @nn.compact
    def __call__(self, inputs):

        path = str(pathlib.Path(__file__).parent.resolve()) + "/nlh_coeffs.dat"
        DATA_NLH = np.loadtxt(path, usecols=np.arange(0, 8))
        zmax = int(np.max(DATA_NLH[:, 0]))
        AB = np.zeros(((zmax + 1) ** 2, 6), dtype=np.float32)
        for i in range(DATA_NLH.shape[0]):
            z1 = int(DATA_NLH[i, 0])
            z2 = int(DATA_NLH[i, 1])
            AB[z1 + zmax * z2] = DATA_NLH[i, 2:8]
            AB[z2 + zmax * z1] = DATA_NLH[i, 2:8]
        AB = AB.reshape((zmax + 1) ** 2, 3, 2)

        species = inputs["species"]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        rijs = graph["distances"]

        # coefficients (a1,a2,a3)
        CS = jnp.array(AB[:, :, 0], dtype=rijs.dtype)
        # exponents (b1,b2,b3)
        ALPHAS = jnp.array(AB[:, :, 1], dtype=rijs.dtype)

        if self.trainable:
            cfact = jnp.abs(self.param(
                "c_fact",
                lambda key: jnp.ones(CS.shape[1], dtype=CS.dtype),
            ))
            CS = CS * cfact[None, :]
            CS = CS / jnp.sum(CS, axis=1, keepdims=True)
            alphas_fact = jnp.abs(self.param(
                "alpha_fact",
                lambda key: jnp.ones(ALPHAS.shape[1], dtype=ALPHAS.dtype),
            ))
            ALPHAS = ALPHAS * alphas_fact[None, :]

        s12 = species[edge_src] + zmax * species[edge_dst]
        cs = CS[s12]
        alphas = ALPHAS[s12]

        if "alch_group" in inputs:
            switch = graph["switch_raw"]
            lambda_v = inputs["alch_vlambda"]
            alch_group = inputs["alch_group"]
            alch_alpha = inputs.get("alch_alpha", 0.0)
            alch_m = inputs.get("alch_m", 2)

            mask = alch_group[edge_src] == alch_group[edge_dst]

            rijs = jnp.where(mask, rijs, (rijs**2 + alch_alpha**2 * (1 - lambda_v)) ** 0.5)
            lambda_v = 0.5 * (1 - jnp.cos(jnp.pi * lambda_v))
            switch = jnp.where(
                mask,
                switch,
                (lambda_v**alch_m) * switch,
            )
            # alphas = jnp.where(
            #     mask[:, None],
            #     alphas,
            #     lambda_v * alphas,
            # )
        else:
            switch = graph["switch"]

        Z = jnp.where(species > 0, species.astype(rijs.dtype), 0.0)
        phi = (cs * jnp.exp(-alphas * rijs[:, None])).sum(axis=-1)
        Zij = Z[edge_src] * Z[edge_dst] * switch

        ereppair = Zij * phi / rijs

        energy_unit = au.get_multiplier(self._energy_unit)
        erep_atomic = (energy_unit * 0.5 * au.BOHR) * jax.ops.segment_sum(ereppair, edge_src, species.shape[0])

        energy_key = self.energy_key if self.energy_key is not None else self.name
        output = {**inputs, energy_key: erep_atomic}

        if self.direct_forces_key is not None:
            dphidr = -(alphas * cs * jnp.exp(-alphas * rijs[:, None])).sum(axis=-1)
            dedr = Zij * (dphidr / rijs - phi / (rijs**2))
            dedij = (dedr / rijs)[:, None] * graph["vec"]
            fi = (energy_unit * au.BOHR) * jax.ops.segment_sum(dedij, edge_src, species.shape[0])
            output[self.direct_forces_key] = fi

        return output
class RepulsionZBL(flax.linen.module.Module):

Repulsion energy based on the Ziegler-Biersack-Littmark potential

FID: REPULSION_ZBL

Reference

J. F. Ziegler & J. P. Biersack, The Stopping and Range of Ions in Matter

RepulsionZBL(_graphs_properties: Dict, graph_key: str = 'graph', energy_key: Optional[str] = None, trainable: bool = True, _energy_unit: str = 'Ha', proportional_regularization: bool = True, d: float = 0.885336690897932, p: float = 0.23, alphas: Sequence[float] = (3.1998, 0.94229, 0.4029, 0.20162), cs: Sequence[float] = (0.18175273, 0.5098655, 0.28021213, 0.0281697), cs_logits: Sequence[float] = (0.113, 1.1445, 0.5459, -1.7514), name: Optional[str] = None)
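
For reference, the pair term evaluated in __call__ above can be restated (in atomic units, after the code converts distances to Bohr) as

$$E_{ij} = \frac{Z_i Z_j}{r_{ij}}\,\phi\!\left(\frac{r_{ij}}{a_{ij}}\right), \qquad a_{ij} = \frac{d}{Z_i^{\,p} + Z_j^{\,p}}, \qquad \phi(x) = \sum_{k=1}^{4} c_k\, e^{-\alpha_k x},$$

where d, p, alphas and cs are the constructor parameters listed below. In the code the coefficients c_k are normalized to sum to 1/2 rather than 1 because every pair appears twice in the directed edge list, so the two per-atom contributions add up to the full pair repulsion.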
graph_key: str = 'graph'

The key for the graph input.

energy_key: Optional[str] = None

The key for the output energy.

trainable: bool = True

Whether the parameters are trainable.

proportional_regularization: bool = True

Whether the training regularization penalizes relative (rather than absolute) deviations of the trained parameters from the ZBL defaults.

d: float = 0.885336690897932

ZBL screening-length prefactor (0.46850 Angstrom expressed in Bohr).

p: float = 0.23

Exponent applied to the atomic numbers in the screening length.

alphas: Sequence[float] = (3.1998, 0.94229, 0.4029, 0.20162)

Exponents of the four-term ZBL screening function.

cs: Sequence[float] = (0.18175273, 0.5098655, 0.28021213, 0.0281697)

Coefficients of the four-term ZBL screening function (renormalized internally).

cs_logits: Sequence[float] = (0.113, 1.1445, 0.5459, -1.7514)

Softmax logits used to initialize the coefficients when the module is trainable.
FID: ClassVar[str] = 'REPULSION_ZBL'
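
As a rough illustration of the attributes above, a minimal standalone sketch of initializing and applying the module with flax follows. It assumes direct use outside the FENNIX pipeline: the toy graph dictionary, the dimer geometry and the "e_rep" key are illustrative only; in practice the graph (edge lists, distances, switching function) is built by FENNIX preprocessing.

import jax
import jax.numpy as jnp
from fennol.models.physics.repulsion import RepulsionZBL

# Toy O-H dimer with a double-counted (directed) edge list, as expected by the
# segment_sum over edge_src in __call__.
inputs = {
    "species": jnp.array([8, 1]),
    "graph": {
        "edge_src": jnp.array([0, 1]),
        "edge_dst": jnp.array([1, 0]),
        "distances": jnp.array([0.96, 0.96]),  # Angstrom
        "switch": jnp.array([1.0, 1.0]),       # cutoff switching function (no cutoff applied here)
    },
}

model = RepulsionZBL(_graphs_properties={}, energy_key="e_rep")
params = model.init(jax.random.PRNGKey(0), inputs)  # trainable=True -> d, p, cs, alphas parameters
out = model.apply(params, inputs)
print(out["e_rep"])  # per-atom repulsion energies in Ha

With trainable=True (the default) the four ZBL quantities become flax parameters; if a "training" flag is present in inputs["flags"], the output additionally contains the "<energy_key>_regularization" entry computed above.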
class RepulsionNLH(flax.linen.module.Module):

NLH pairwise repulsive potential with pair-specific coefficients up to Z=92

FID: REPULSION_NLH

Reference

K. Nordlund, S. Lehtola, G. Hobler, Repulsive interatomic potentials calculated at three levels of theory, Phys. Rev. A 111, 032818 https://doi.org/10.1103/PhysRevA.111.032818

RepulsionNLH(_graphs_properties: Dict, graph_key: str = 'graph', energy_key: Optional[str] = None, _energy_unit: str = 'Ha', trainable: bool = False, direct_forces_key: Optional[str] = None, name: Optional[str] = None)
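
The pair term evaluated by __call__ is the tabulated three-exponential screened Coulomb repulsion

$$E_{ij} = \frac{Z_i Z_j e^2}{r_{ij}} \sum_{k=1}^{3} a_k^{(Z_i,Z_j)}\, e^{-b_k^{(Z_i,Z_j)} r_{ij}},$$

with distances in Angstrom and the pair-specific coefficients a_k and exponents b_k read from nlh_coeffs.dat (the exponents are assumed here to be tabulated in inverse Angstrom, which is what the unit handling in the code implies). The factor energy_unit * 0.5 * au.BOHR converts e^2/Angstrom to the requested energy unit and halves each contribution because pairs are double counted in the directed edge list.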
graph_key: str = 'graph'

The key for the graph input.

energy_key: Optional[str] = None

The key for the output energy.

trainable: bool = False

Whether global scaling factors on the tabulated coefficients and exponents are trained.

direct_forces_key: Optional[str] = None

If set, analytic per-atom force contributions from this pair potential are returned under this key (the graph must then provide edge vectors under "vec").
FID: ClassVar[str] = 'REPULSION_NLH'
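
A standalone sketch analogous to the RepulsionZBL example above, showing the optional direct-force output. The "e_rep_nlh" and "f_rep_nlh" keys and the Fe-Fe dimer are illustrative assumptions, and it presumes the bundled nlh_coeffs.dat table is available in the installed package.

import jax
import jax.numpy as jnp
from fennol.models.physics.repulsion import RepulsionNLH

inputs = {
    "species": jnp.array([26, 26]),  # Fe-Fe dimer
    "graph": {
        "edge_src": jnp.array([0, 1]),
        "edge_dst": jnp.array([1, 0]),
        "distances": jnp.array([2.0, 2.0]),                     # Angstrom
        "vec": jnp.array([[2.0, 0.0, 0.0], [-2.0, 0.0, 0.0]]),  # edge vectors, needed for direct forces
        "switch": jnp.array([1.0, 1.0]),
    },
}

model = RepulsionNLH(_graphs_properties={}, energy_key="e_rep_nlh",
                     direct_forces_key="f_rep_nlh")
params = model.init(jax.random.PRNGKey(0), inputs)  # no parameters when trainable=False
out = model.apply(params, inputs)
print(out["e_rep_nlh"], out["f_rep_nlh"])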