fennol.models.physics.bond

import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
from typing import Any, Dict, Union, Callable, Sequence, Optional, ClassVar
from ...utils import AtomicUnits as au
import dataclasses
from ...utils.periodic_table import (
    D3_ELECTRONEGATIVITIES,
    D3_HARDNESSES,
    D3_VDW_RADII,
    D3_COV_RADII,
    D3_KAPPA,
    VDW_RADII,
    VALENCE_ELECTRONS,
    PAULING_ELECTRONEGATIVITY,
)

class CND4(nn.Module):
    """Coordination number as defined in the D4 dispersion correction.

    FID : CN_D4
    """
    graph_key: str = "graph"
    """The key for the graph input."""
    output_key: Optional[str] = None
    """The key for the output."""
    k0: float = 7.5
    """Steepness of the erf counting function."""
    k1: float = 4.1
    """Amplitude of the electronegativity factor."""
    k2: float = 19.09
    """Offset of the electronegativity difference in the factor."""
    k3: float = 254.56
    """Width of the electronegativity factor."""
    electronegativity_factor: bool = False
    """Whether to include the electronegativity factor."""
    trainable: bool = False
    """Whether the parameters are trainable."""

    FID: ClassVar[str] = "CN_D4"

    @nn.compact
    def __call__(self, inputs):
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        species = inputs["species"]

        if self.trainable:
            rc = self.param("rc", lambda key: jnp.asarray(D3_COV_RADII))[species]
        else:
            rc = jnp.asarray(D3_COV_RADII)[species]
        rcij = rc[edge_src] + rc[edge_dst]
        rij = graph["distances"] / au.BOHR

        if self.trainable:
            k0 = self.k0 * jnp.abs(self.param("k0", lambda key: jnp.asarray(1.0)))
        else:
            k0 = self.k0

        # smooth erf counting function, damped by the graph switch
        CNij = (
            0.5 * (1 + jax.scipy.special.erf(-k0 * (rij / rcij - 1.0))) * graph["switch"]
        )

        if self.electronegativity_factor:
            k1 = self.k1
            k2 = self.k2
            k3 = self.k3
            if self.trainable:
                k1 = k1 * jnp.abs(self.param("k1", lambda key: jnp.asarray(1.0)))
                k2 = self.param("k2", lambda key: jnp.asarray(1.0))
                k3 = jnp.abs(self.param("k3", lambda key: jnp.asarray(1.0)))
                en = self.param("en", lambda key: jnp.asarray(D3_ELECTRONEGATIVITIES))[
                    species
                ]
            else:
                en = jnp.asarray(PAULING_ELECTRONEGATIVITY)[species]
            en_ij = jnp.abs(en[edge_src] - en[edge_dst])
            dij = k1 * jnp.exp(-((en_ij + k2) ** 2) / k3)
            CNij = CNij * dij
        CNi = jax.ops.segment_sum(CNij, edge_src, species.shape[0])

        output_key = self.name if self.output_key is None else self.output_key
        return {**inputs, output_key: CNi, output_key + "_pair": CNij}

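A minimal sketch of how CND4 can be exercised on its own, outside a full FENNIX pipeline. The input layout (a `species` array plus a graph carrying `edge_src`, `edge_dst`, `distances` in Angstrom and precomputed `switch` values) follows the fields read by the module above; the toy water geometry, the unit switch values and the `cn` output key are made up for illustration:

import jax
import jax.numpy as jnp

# hypothetical inputs: a water molecule (O, H, H) with each O-H bond
# listed in both directions; "switch" normally comes from the
# neighbor-list machinery and is set to 1 here for simplicity
inputs = {
    "species": jnp.array([8, 1, 1]),
    "graph": {
        "edge_src": jnp.array([0, 0, 1, 2]),
        "edge_dst": jnp.array([1, 2, 0, 0]),
        "distances": jnp.array([0.96, 0.96, 0.96, 0.96]),  # Angstrom
        "switch": jnp.ones(4),
    },
}

model = CND4(output_key="cn")
out, variables = model.init_with_output(jax.random.PRNGKey(0), inputs)
print(out["cn"])       # per-atom coordination numbers, shape (3,)
print(out["cn_pair"])  # per-edge contributions, shape (4,)
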
class SumSwitch(nn.Module):
    """Sum (a power of) the switch values for each neighbor.

    FID : SUM_SWITCH
    """
    graph_key: str = "graph"
    """The key for the graph input."""
    output_key: Optional[str] = None
    """The key for the output."""
    pow: float = 1.0
    """The power to raise the switch values to."""
    trainable: bool = False
    """Whether the pow parameter is trainable."""

    FID: ClassVar[str] = "SUM_SWITCH"

    @nn.compact
    def __call__(self, inputs):
        graph = inputs[self.graph_key]
        edge_src = graph["edge_src"]
        switch = graph["switch"]

        if self.trainable:
            p = jnp.abs(self.param("pow", lambda key: jnp.asarray(self.pow)))
        else:
            p = self.pow
        # the small offset keeps the gradient of switch**p finite at
        # switch == 0 when p < 1; subtracting shift removes the resulting
        # constant contribution of each edge
        shift = (1.0e-3) ** p

        cn = jax.ops.segment_sum(
            (1.0e-3 + switch) ** p - shift, edge_src, inputs["species"].shape[0]
        )

        output_key = self.name if self.output_key is None else self.output_key
        return {**inputs, output_key: cn}

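Continuing from the CND4 sketch above (same hand-built `inputs`, illustrative `n_neigh` key): with `pow=1` the offset and shift cancel exactly, so the output is simply the switch-weighted neighbor count.

model = SumSwitch(output_key="n_neigh")
out, _ = model.init_with_output(jax.random.PRNGKey(0), inputs)
print(out["n_neigh"])  # [2., 1., 1.] with unit switch values
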
class CNShift(nn.Module):
    """Shift a reference value by a species-dependent function of the coordination number.

    Computes ``ref_value + kappa * (CN + sqrt_shift) ** cn_pow`` with a
    per-species (or input-provided) coefficient ``kappa``.

    FID : CN_SHIFT
    """
    cn_key: str
    """The key for the coordination-number input."""
    output_key: Optional[str] = None
    """The key for the output."""
    kappa_key: Optional[str] = None
    """Optional key for input-provided kappa coefficients. If None, a per-species parameter is learned."""
    sqrt_shift: float = 1.0e-6
    """Small offset keeping the power of the coordination number differentiable at CN = 0."""
    ref_value: Union[str, float] = 1.0
    """The reference value to shift: a constant, or the key of an input array."""
    enforce_positive: bool = False
    """Whether to enforce a positive output."""
    cn_pow: float = 0.5
    """The power to raise the coordination number to."""

    FID: ClassVar[str] = "CN_SHIFT"

    @nn.compact
    def __call__(self, inputs):
        CNi = inputs[self.cn_key]
        if self.kappa_key is not None:
            kappai = inputs[self.kappa_key]
            assert kappai.shape == CNi.shape
        else:
            species = inputs["species"]
            kappai = self.param("kappa", nn.initializers.zeros, (len(D3_COV_RADII),))[
                species
            ]
        shift = kappai * (CNi + self.sqrt_shift) ** self.cn_pow

        if isinstance(self.ref_value, str):
            ref_value = inputs[self.ref_value]
            assert ref_value.shape == shift.shape
        else:
            ref_value = self.ref_value

        if self.enforce_positive:
            # celu(x, alpha) > -alpha, so ref_value + shift stays positive
            shift = jax.nn.celu(shift, alpha=ref_value)

        out = ref_value + shift

        output_key = self.name if self.output_key is None else self.output_key
        return {**inputs, output_key: out}

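A sketch of CNShift stacked on a CND4 output (the `cn` and `q_eff` keys are illustrative). With `kappa_key=None` the per-species `kappa` parameter is zero-initialized, so a freshly initialized module returns exactly `ref_value` and only departs from it once `kappa` is trained:

out, _ = CND4(output_key="cn").init_with_output(jax.random.PRNGKey(0), inputs)
shift = CNShift(cn_key="cn", output_key="q_eff", ref_value=1.0)
out, variables = shift.init_with_output(jax.random.PRNGKey(1), out)
print(out["q_eff"])  # [1., 1., 1.] at initialization (kappa starts at zero)
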
class CNStore(nn.Module):
    """Interpolate per-species stored values as a smooth function of the coordination number.

    Stored reference values are mixed with a softmax over Gaussian
    weights centered on learned reference coordination numbers.

    FID : CN_STORE
    """
    cn_key: str
    """The key for the coordination-number input."""
    output_key: Optional[str] = None
    """The key for the output."""
    store_size: int = 10
    """Number of reference coordination numbers per species."""
    n_gaussians: int = 4
    """Number of Gaussians in the weighting function."""
    isolated_value: float = 0.0
    init_scale_cn: float = 5.0
    """Scale of the uniform initialization of the reference coordination numbers."""
    init_scale_values: float = 1.0
    """Scale of the uniform initialization of the stored values."""
    beta: float = 6.0
    """Width parameter of the Gaussian weights."""
    trainable: bool = True
    """Whether beta is trainable."""
    output_dim: int = 1
    """Dimension of the stored values."""
    squeeze: bool = True
    """Whether to squeeze the last axis when output_dim == 1."""

    FID: ClassVar[str] = "CN_STORE"

    @nn.compact
    def __call__(self, inputs):
        cn = inputs[self.cn_key]
        species = inputs["species"]

        cn_refs = self.param(
            "cn_refs",
            nn.initializers.uniform(self.init_scale_cn),
            (len(D3_COV_RADII), self.store_size),
        )[species]
        values_refs = self.param(
            "values_refs",
            nn.initializers.uniform(self.init_scale_values),
            (len(D3_COV_RADII), self.store_size, self.output_dim),
        )[species]

        beta = self.beta
        if self.trainable:
            beta = self.param("beta", lambda key: jnp.asarray(self.beta))
        # log-sum-exp of Gaussians of increasing steepness around each
        # reference CN (the j = 0 term is constant)
        j = jnp.asarray(np.arange(self.n_gaussians)[None, None, :], dtype=jnp.float32)
        delta_cns = jnp.log(
            jnp.sum(
                jnp.exp(-beta * j * ((cn[:, None] - cn_refs) ** 2)[:, :, None]), axis=-1
            )
        )
        w = jax.nn.softmax(delta_cns, axis=-1)

        values = jnp.sum(w[:, :, None] * values_refs, axis=1)
        if self.output_dim == 1 and self.squeeze:
            values = jnp.squeeze(values, axis=-1)

        output_key = self.name if self.output_key is None else self.output_key
        return {**inputs, output_key: values}

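CNStore plugs in the same way (continuing from the sketch above, with an illustrative `values` key). At initialization the reference coordination numbers and stored values are random, so only the shapes and the plumbing are meaningful here:

store = CNStore(cn_key="cn", output_key="values")
out, variables = store.init_with_output(jax.random.PRNGKey(2), out)
print(out["values"].shape)  # (3,): one interpolated scalar per atom
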
class FlatBottom(nn.Module):
    """Flat-bottom potential energy surface.

    Implemented by Côme Cattin, 2024.

    Flat-bottom potential energy:
        E = alpha * (r - req) ** 2 if r >= req
        E = 0                      if r < req

    FID: FLAT_BOTTOM
    """

    energy_key: Optional[str] = None
    """Key of the energy in the outputs."""
    graph_key: str = "graph"
    """Key of the graph in the inputs."""
    alpha: float = 400.0
    """Force constant of the flat bottom potential (in kcal/mol/A^2)."""
    r_eq_factor: float = 1.3
    """Factor to multiply the sum of the covalent radii of the two atoms."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str] = "FLAT_BOTTOM"

    @nn.compact
    def __call__(self, inputs):
        species = inputs["species"]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        distances = graph["distances"]
        rij = distances / au.BOHR
        training = "training" in inputs.get("flags", {})

        output = {}
        energy_key = self.energy_key if self.energy_key is not None else self.name

        # the restraint is disabled during training
        if training:
            output[energy_key] = jnp.zeros(species.shape[0], dtype=distances.dtype)
            return {**inputs, **output}

        # req is a scaled sum of the covalent radii of the two atoms
        rcov = jnp.asarray(D3_COV_RADII)[species]
        req = self.r_eq_factor * (rcov[edge_src] + rcov[edge_dst])

        # convert alpha from kcal/mol/A^2 to atomic units
        alpha = inputs.get("alpha", self.alpha) / au.KCALPERMOL * au.BOHR**2

        flat_bottom_energy = jnp.where(rij > req, alpha * (rij - req) ** 2, 0.0)

        flat_bottom_energy = jax.ops.segment_sum(
            flat_bottom_energy, edge_src, num_segments=species.shape[0]
        )

        energy_unit = au.get_multiplier(self._energy_unit)
        output[energy_key] = flat_bottom_energy * energy_unit

        return {**inputs, **output}
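Finally, a sketch of the flat-bottom restraint on the same toy `inputs`, with an illustrative `e_flat` output key. The 0.96 A O-H distances sit below `r_eq_factor` times the summed covalent radii, so all pairs fall in the flat region and the per-atom energies come out zero; stretching a bond past that threshold would switch on the harmonic wall:

model = FlatBottom(energy_key="e_flat")
out, _ = model.init_with_output(jax.random.PRNGKey(0), inputs)
print(out["e_flat"])  # [0., 0., 0.] for the unstretched geometry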