fennol.models.physics.repulsion
from typing import Any, Dict, Union, Callable, Sequence, Optional, ClassVar

import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np

from ...utils import AtomicUnits as au


class RepulsionZBL(nn.Module):
    """Repulsion energy based on the Ziegler-Biersack-Littmark potential

    FID: REPULSION_ZBL

    Each graph edge (i, j) contributes the screened Coulomb term
    ``Z_i * Z_j / r_ij * phi(x_ij) * switch_ij`` (atomic units, ``r_ij`` in
    Bohr), with ``x_ij = r_ij * (Z_i**p + Z_j**p) / d`` and
    ``phi(x) = sum_k c_k * exp(-alpha_k * x)``. Edge terms are summed onto
    the source atom of each edge. The reference ``c_k`` are pre-multiplied
    by 0.5, presumably because every pair appears twice in the edge list
    (i->j and j->i). TODO: confirm against the graph builder.

    With ``trainable=True``, ``d``, ``p`` and ``alphas`` (kept non-negative
    through ``abs``) and ``cs`` (``0.5 * softmax`` of logits) become Flax
    parameters starting near the reference values. In training mode a
    squared-deviation penalty from those reference values is also returned.

    ### Reference
    J. F. Ziegler & J. P. Biersack, The Stopping and Range of Ions in Matter

    """

    _graphs_properties: Dict
    graph_key: str = "graph"
    """The key for the graph input."""
    energy_key: Optional[str] = None
    """The key for the output energy."""
    trainable: bool = True
    """Whether the parameters are trainable."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str] = "REPULSION_ZBL"

    @nn.compact
    def __call__(self, inputs):
        """Compute per-atom ZBL repulsion energies.

        Args:
            inputs: dict holding at least ``"species"`` and ``self.graph_key``.
                The graph dict must provide ``"edge_src"``, ``"edge_dst"`` and
                ``"distances"``. It must also provide ``"switch"``, or
                ``"switch_raw"`` when ``"alch_group"`` is present in
                ``inputs``. Optional keys read from ``inputs``:
                - ``"flags"``: training mode if it contains ``"training"``;
                - ``"alch_group"``, ``"alch_vlambda"``;
                - ``"alch_alpha"`` (default 0.5);
                - ``"alch_m"`` (default 2).

        Returns:
            A shallow copy of ``inputs``. The per-atom energy, multiplied by
            the factor for ``self._energy_unit``, is stored under
            ``self.energy_key`` (or the module name when that is None). When
            trainable and in training mode, ``<key>_regularization`` (shape
            ``(1,)``) is added as well.
        """
        species = inputs["species"]
        graph = inputs[self.graph_key]
        src = graph["edge_src"]
        dst = graph["edge_dst"]

        is_training = "training" in inputs.get("flags", {})

        # Distances divided by au.BOHR, i.e. converted to Bohr
        # (presumably provided in Angstrom).
        r = graph["distances"] / au.BOHR
        param_dtype = r.dtype

        # Reference ZBL constants; d is the screening length in Bohr.
        d_ref = 0.46850 / au.BOHR
        p_ref = 0.23
        alphas_ref = np.array([3.19980, 0.94229, 0.40290, 0.20162])
        cs_ref = 0.5 * np.array([0.18175273, 0.5098655, 0.28021213, 0.0281697])

        if not self.trainable:
            cs = jnp.asarray(cs_ref)
            alphas = jnp.asarray(alphas_ref)
            d = d_ref
            p = p_ref
        else:

            def init_from(_rng, value):
                # The RNG is ignored: parameters start from the given value.
                return jnp.asarray(value, dtype=param_dtype)

            d = jnp.abs(self.param("d", init_from, d_ref))
            p = jnp.abs(self.param("p", init_from, p_ref))
            # The softmax of these logits approximately reproduces cs_ref / 0.5.
            # This keeps the coefficients positive and summing to 0.5.
            cs_logits = self.param(
                "cs", init_from, [0.1130, 1.1445, 0.5459, -1.7514]
            )
            cs = 0.5 * jax.nn.softmax(cs_logits)
            alphas = jnp.abs(self.param("alphas", init_from, alphas_ref))

            if is_training:
                # Squared deviation of the trained parameters from the reference ones.
                penalty = ((alphas_ref - alphas) ** 2).sum()
                penalty = penalty + ((cs_ref - cs) ** 2).sum()
                penalty = penalty + (p_ref - p) ** 2
                penalty = penalty + (d_ref - d) ** 2
                reg = jnp.asarray(penalty).reshape(1)

        if "alch_group" in inputs:
            # Alchemical mode: pairs spanning two different groups are
            # softened according to lambda.
            switch_raw = graph["switch_raw"]
            lam = inputs["alch_vlambda"]
            groups = inputs["alch_group"]
            soft_alpha = inputs.get("alch_alpha", 0.5)
            power = inputs.get("alch_m", 2)

            same_group = groups[src] == groups[dst]

            # Soft-core distance for inter-group pairs.
            r_soft = (r**2 + soft_alpha**2 * (1 - lam)) ** 0.5
            r = jnp.where(same_group, r, r_soft)
            # Cosine remapping of lambda before it scales the switch.
            lam_smooth = 0.5 * (1 - jnp.cos(jnp.pi * lam))
            switch = jnp.where(
                same_group, switch_raw, (lam_smooth**power) * switch_raw
            )
        else:
            switch = graph["switch"]

        # Non-positive species are given zero nuclear charge.
        Z = jnp.where(species > 0, species.astype(r.dtype), 0.0)
        Zp = Z**p / d
        x = r * (Zp[src] + Zp[dst])
        phi = jnp.sum(
            cs[None, :] * jnp.exp(-alphas[None, :] * x[:, None]), axis=-1
        )

        pair_energy = Z[src] * Z[dst] * phi / r * switch
        erep_atomic = jax.ops.segment_sum(
            pair_energy, src, num_segments=species.shape[0]
        )

        energy_unit = au.get_multiplier(self._energy_unit)
        out_key = self.name if self.energy_key is None else self.energy_key
        output = {**inputs, out_key: erep_atomic * energy_unit}
        if self.trainable and is_training:
            output[out_key + "_regularization"] = reg
        return output
class
RepulsionZBL(flax.linen.module.Module):
class RepulsionZBL(nn.Module):
    """Repulsion energy based on the Ziegler-Biersack-Littmark potential

    FID: REPULSION_ZBL

    Each graph edge (i, j) contributes the screened Coulomb term
    ``Z_i * Z_j / r_ij * phi(x_ij) * switch_ij`` (atomic units, ``r_ij`` in
    Bohr), with ``x_ij = r_ij * (Z_i**p + Z_j**p) / d`` and
    ``phi(x) = sum_k c_k * exp(-alpha_k * x)``. Edge terms are summed onto
    the source atom of each edge. The reference ``c_k`` are pre-multiplied
    by 0.5, presumably because every pair appears twice in the edge list
    (i->j and j->i). TODO: confirm against the graph builder.

    With ``trainable=True``, ``d``, ``p`` and ``alphas`` (kept non-negative
    through ``abs``) and ``cs`` (``0.5 * softmax`` of logits) become Flax
    parameters starting near the reference values. In training mode a
    squared-deviation penalty from those reference values is also returned.

    ### Reference
    J. F. Ziegler & J. P. Biersack, The Stopping and Range of Ions in Matter

    """

    _graphs_properties: Dict
    graph_key: str = "graph"
    """The key for the graph input."""
    energy_key: Optional[str] = None
    """The key for the output energy."""
    trainable: bool = True
    """Whether the parameters are trainable."""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str] = "REPULSION_ZBL"

    @nn.compact
    def __call__(self, inputs):
        """Compute per-atom ZBL repulsion energies.

        Args:
            inputs: dict holding at least ``"species"`` and ``self.graph_key``.
                The graph dict must provide ``"edge_src"``, ``"edge_dst"`` and
                ``"distances"``. It must also provide ``"switch"``, or
                ``"switch_raw"`` when ``"alch_group"`` is present in
                ``inputs``. Optional keys read from ``inputs``:
                - ``"flags"``: training mode if it contains ``"training"``;
                - ``"alch_group"``, ``"alch_vlambda"``;
                - ``"alch_alpha"`` (default 0.5);
                - ``"alch_m"`` (default 2).

        Returns:
            A shallow copy of ``inputs``. The per-atom energy, multiplied by
            the factor for ``self._energy_unit``, is stored under
            ``self.energy_key`` (or the module name when that is None). When
            trainable and in training mode, ``<key>_regularization`` (shape
            ``(1,)``) is added as well.
        """
        species = inputs["species"]
        graph = inputs[self.graph_key]
        src = graph["edge_src"]
        dst = graph["edge_dst"]

        is_training = "training" in inputs.get("flags", {})

        # Distances divided by au.BOHR, i.e. converted to Bohr
        # (presumably provided in Angstrom).
        r = graph["distances"] / au.BOHR
        param_dtype = r.dtype

        # Reference ZBL constants; d is the screening length in Bohr.
        d_ref = 0.46850 / au.BOHR
        p_ref = 0.23
        alphas_ref = np.array([3.19980, 0.94229, 0.40290, 0.20162])
        cs_ref = 0.5 * np.array([0.18175273, 0.5098655, 0.28021213, 0.0281697])

        if not self.trainable:
            cs = jnp.asarray(cs_ref)
            alphas = jnp.asarray(alphas_ref)
            d = d_ref
            p = p_ref
        else:

            def init_from(_rng, value):
                # The RNG is ignored: parameters start from the given value.
                return jnp.asarray(value, dtype=param_dtype)

            d = jnp.abs(self.param("d", init_from, d_ref))
            p = jnp.abs(self.param("p", init_from, p_ref))
            # The softmax of these logits approximately reproduces cs_ref / 0.5.
            # This keeps the coefficients positive and summing to 0.5.
            cs_logits = self.param(
                "cs", init_from, [0.1130, 1.1445, 0.5459, -1.7514]
            )
            cs = 0.5 * jax.nn.softmax(cs_logits)
            alphas = jnp.abs(self.param("alphas", init_from, alphas_ref))

            if is_training:
                # Squared deviation of the trained parameters from the reference ones.
                penalty = ((alphas_ref - alphas) ** 2).sum()
                penalty = penalty + ((cs_ref - cs) ** 2).sum()
                penalty = penalty + (p_ref - p) ** 2
                penalty = penalty + (d_ref - d) ** 2
                reg = jnp.asarray(penalty).reshape(1)

        if "alch_group" in inputs:
            # Alchemical mode: pairs spanning two different groups are
            # softened according to lambda.
            switch_raw = graph["switch_raw"]
            lam = inputs["alch_vlambda"]
            groups = inputs["alch_group"]
            soft_alpha = inputs.get("alch_alpha", 0.5)
            power = inputs.get("alch_m", 2)

            same_group = groups[src] == groups[dst]

            # Soft-core distance for inter-group pairs.
            r_soft = (r**2 + soft_alpha**2 * (1 - lam)) ** 0.5
            r = jnp.where(same_group, r, r_soft)
            # Cosine remapping of lambda before it scales the switch.
            lam_smooth = 0.5 * (1 - jnp.cos(jnp.pi * lam))
            switch = jnp.where(
                same_group, switch_raw, (lam_smooth**power) * switch_raw
            )
        else:
            switch = graph["switch"]

        # Non-positive species are given zero nuclear charge.
        Z = jnp.where(species > 0, species.astype(r.dtype), 0.0)
        Zp = Z**p / d
        x = r * (Zp[src] + Zp[dst])
        phi = jnp.sum(
            cs[None, :] * jnp.exp(-alphas[None, :] * x[:, None]), axis=-1
        )

        pair_energy = Z[src] * Z[dst] * phi / r * switch
        erep_atomic = jax.ops.segment_sum(
            pair_energy, src, num_segments=species.shape[0]
        )

        energy_unit = au.get_multiplier(self._energy_unit)
        out_key = self.name if self.energy_key is None else self.energy_key
        output = {**inputs, out_key: erep_atomic * energy_unit}
        if self.trainable and is_training:
            output[out_key + "_regularization"] = reg
        return output
Repulsion energy based on the Ziegler-Biersack-Littmark potential
FID: REPULSION_ZBL
Reference
J. F. Ziegler & J. P. Biersack, The Stopping and Range of Ions in Matter
RepulsionZBL( _graphs_properties: Dict, graph_key: str = 'graph', energy_key: Optional[str] = None, trainable: bool = True, _energy_unit: str = 'Ha', parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying Python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator to ensure that the appropriate getter/setter logic still applies in subclasses, even after various dataclass transforms.