fennol.models.physics.repulsion

  1import jax
  2import jax.numpy as jnp
  3import flax.linen as nn
  4import numpy as np
  5from typing import Any, Dict, Union, Callable, Sequence, Optional, ClassVar
  6from ...utils import AtomicUnits as au
  7
  8
  9class RepulsionZBL(nn.Module):
 10    """Repulsion energy based on the Ziegler-Biersack-Littmark potential
 11
 12    FID: REPULSION_ZBL
 13
 14    ### Reference
 15    J. F. Ziegler & J. P. Biersack , The Stopping and Range of Ions in Matter
 16
 17    """
 18
 19    _graphs_properties: Dict
 20    graph_key: str = "graph"
 21    """The key for the graph input."""
 22    energy_key: Optional[str] = None
 23    """The key for the output energy."""
 24    trainable: bool = True
 25    """Whether the parameters are trainable."""
 26    _energy_unit: str = "Ha"
 27    """The energy unit of the model. **Automatically set by FENNIX**"""
 28
 29    FID: ClassVar[str] = "REPULSION_ZBL"
 30
 31    @nn.compact
 32    def __call__(self, inputs):
 33        species = inputs["species"]
 34        graph = inputs[self.graph_key]
 35        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
 36
 37        training = "training" in inputs.get("flags", {})
 38
 39        rijs = graph["distances"] / au.BOHR
 40
 41        d_ = 0.46850 / au.BOHR
 42        p_ = 0.23
 43        alphas_ = np.array([3.19980, 0.94229, 0.40290, 0.20162])
 44        cs_ = 0.5 * np.array([0.18175273, 0.5098655, 0.28021213, 0.0281697])
 45        if self.trainable:
 46            d = jnp.abs(
 47                self.param(
 48                    "d",
 49                    lambda key, d: jnp.asarray(d, dtype=rijs.dtype),
 50                    d_,
 51                )
 52            )
 53            p = jnp.abs(
 54                self.param(
 55                    "p",
 56                    lambda key, p: jnp.asarray(p, dtype=rijs.dtype),
 57                    p_,
 58                )
 59            )
 60            cs = 0.5 * jax.nn.softmax(
 61                self.param(
 62                    "cs",
 63                    lambda key, cs: jnp.asarray(cs, dtype=rijs.dtype),
 64                    [0.1130, 1.1445, 0.5459, -1.7514],
 65                )
 66            )
 67            alphas = jnp.abs(
 68                self.param(
 69                    "alphas",
 70                    lambda key, alphas: jnp.asarray(alphas, dtype=rijs.dtype),
 71                    alphas_,
 72                )
 73            )
 74
 75            if training:
 76                reg = jnp.asarray(
 77                    ((alphas_ - alphas) ** 2).sum()
 78                    + ((cs_ - cs) ** 2).sum()
 79                    + (p_ - p) ** 2
 80                    + (d_ - d) ** 2
 81                ).reshape(1)
 82        else:
 83            cs = jnp.asarray(cs_)
 84            alphas = jnp.asarray(alphas_)
 85            d = d_
 86            p = p_
 87
 88        if "alch_group" in inputs:
 89            switch = graph["switch_raw"]
 90            lambda_v = inputs["alch_vlambda"]
 91            alch_group = inputs["alch_group"]
 92            alch_alpha = inputs.get("alch_alpha", 0.5)
 93            alch_m = inputs.get("alch_m", 2)
 94
 95            mask = alch_group[edge_src] == alch_group[edge_dst]
 96
 97            rijs = jnp.where(mask, rijs, (rijs**2 + alch_alpha**2 * (1 - lambda_v))**0.5)
 98            lambda_v = 0.5*(1-jnp.cos(jnp.pi*lambda_v))
 99            switch = jnp.where(
100                mask,
101                switch,
102                (lambda_v**alch_m) * switch ,
103            )
104        else:
105            switch = graph["switch"]
106
107        Z = jnp.where(species > 0, species.astype(rijs.dtype), 0.0)
108        Zij = Z[edge_src]*Z[edge_dst]
109        Zp = Z**p / d
110        x = rijs * (Zp[edge_src] + Zp[edge_dst])
111        phi = (cs[None, :] * jnp.exp(-alphas[None, :] * x[:, None])).sum(axis=-1)
112
113        ereppair = Zij * phi / rijs * switch
114
115        erep_atomic = jax.ops.segment_sum(ereppair, edge_src, species.shape[0])
116
117        energy_unit = au.get_multiplier(self._energy_unit)
118        energy_key = self.energy_key if self.energy_key is not None else self.name
119        output = {**inputs, energy_key: erep_atomic * energy_unit}
120        if self.trainable and training:
121            output[energy_key + "_regularization"] = reg
122
123        return output
class RepulsionZBL(flax.linen.module.Module):
 10class RepulsionZBL(nn.Module):
 11    """Repulsion energy based on the Ziegler-Biersack-Littmark potential
 12
 13    FID: REPULSION_ZBL
 14
 15    ### Reference
 16    J. F. Ziegler & J. P. Biersack , The Stopping and Range of Ions in Matter
 17
 18    """
 19
 20    _graphs_properties: Dict
 21    graph_key: str = "graph"
 22    """The key for the graph input."""
 23    energy_key: Optional[str] = None
 24    """The key for the output energy."""
 25    trainable: bool = True
 26    """Whether the parameters are trainable."""
 27    _energy_unit: str = "Ha"
 28    """The energy unit of the model. **Automatically set by FENNIX**"""
 29
 30    FID: ClassVar[str] = "REPULSION_ZBL"
 31
 32    @nn.compact
 33    def __call__(self, inputs):
 34        species = inputs["species"]
 35        graph = inputs[self.graph_key]
 36        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
 37
 38        training = "training" in inputs.get("flags", {})
 39
 40        rijs = graph["distances"] / au.BOHR
 41
 42        d_ = 0.46850 / au.BOHR
 43        p_ = 0.23
 44        alphas_ = np.array([3.19980, 0.94229, 0.40290, 0.20162])
 45        cs_ = 0.5 * np.array([0.18175273, 0.5098655, 0.28021213, 0.0281697])
 46        if self.trainable:
 47            d = jnp.abs(
 48                self.param(
 49                    "d",
 50                    lambda key, d: jnp.asarray(d, dtype=rijs.dtype),
 51                    d_,
 52                )
 53            )
 54            p = jnp.abs(
 55                self.param(
 56                    "p",
 57                    lambda key, p: jnp.asarray(p, dtype=rijs.dtype),
 58                    p_,
 59                )
 60            )
 61            cs = 0.5 * jax.nn.softmax(
 62                self.param(
 63                    "cs",
 64                    lambda key, cs: jnp.asarray(cs, dtype=rijs.dtype),
 65                    [0.1130, 1.1445, 0.5459, -1.7514],
 66                )
 67            )
 68            alphas = jnp.abs(
 69                self.param(
 70                    "alphas",
 71                    lambda key, alphas: jnp.asarray(alphas, dtype=rijs.dtype),
 72                    alphas_,
 73                )
 74            )
 75
 76            if training:
 77                reg = jnp.asarray(
 78                    ((alphas_ - alphas) ** 2).sum()
 79                    + ((cs_ - cs) ** 2).sum()
 80                    + (p_ - p) ** 2
 81                    + (d_ - d) ** 2
 82                ).reshape(1)
 83        else:
 84            cs = jnp.asarray(cs_)
 85            alphas = jnp.asarray(alphas_)
 86            d = d_
 87            p = p_
 88
 89        if "alch_group" in inputs:
 90            switch = graph["switch_raw"]
 91            lambda_v = inputs["alch_vlambda"]
 92            alch_group = inputs["alch_group"]
 93            alch_alpha = inputs.get("alch_alpha", 0.5)
 94            alch_m = inputs.get("alch_m", 2)
 95
 96            mask = alch_group[edge_src] == alch_group[edge_dst]
 97
 98            rijs = jnp.where(mask, rijs, (rijs**2 + alch_alpha**2 * (1 - lambda_v))**0.5)
 99            lambda_v = 0.5*(1-jnp.cos(jnp.pi*lambda_v))
100            switch = jnp.where(
101                mask,
102                switch,
103                (lambda_v**alch_m) * switch ,
104            )
105        else:
106            switch = graph["switch"]
107
108        Z = jnp.where(species > 0, species.astype(rijs.dtype), 0.0)
109        Zij = Z[edge_src]*Z[edge_dst]
110        Zp = Z**p / d
111        x = rijs * (Zp[edge_src] + Zp[edge_dst])
112        phi = (cs[None, :] * jnp.exp(-alphas[None, :] * x[:, None])).sum(axis=-1)
113
114        ereppair = Zij * phi / rijs * switch
115
116        erep_atomic = jax.ops.segment_sum(ereppair, edge_src, species.shape[0])
117
118        energy_unit = au.get_multiplier(self._energy_unit)
119        energy_key = self.energy_key if self.energy_key is not None else self.name
120        output = {**inputs, energy_key: erep_atomic * energy_unit}
121        if self.trainable and training:
122            output[energy_key + "_regularization"] = reg
123
124        return output

Repulsion energy based on the Ziegler-Biersack-Littmark potential

FID: REPULSION_ZBL

Reference

J. F. Ziegler & J. P. Biersack, The Stopping and Range of Ions in Matter

RepulsionZBL( _graphs_properties: Dict, graph_key: str = 'graph', energy_key: Optional[str] = None, trainable: bool = True, _energy_unit: str = 'Ha', parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
graph_key: str = 'graph'

The key for the graph input.

energy_key: Optional[str] = None

The key for the output energy.

trainable: bool = True

Whether the parameters are trainable.

FID: ClassVar[str] = 'REPULSION_ZBL'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None