fennol.models.physics.electric_field
Electric field model for FENNOL.
Created by C. Cattin 2024
#!/usr/bin/env python3
"""Electric field model for FENNOL.

Created by C. Cattin 2024
"""

import flax.linen as nn
import jax
import jax.numpy as jnp
from typing import ClassVar

from fennol.utils import AtomicUnits as Au


class ElectricField(nn.Module):
    """Electric field from distributed point charges with short-range damping.

    FID: ELECTRIC_FIELD

    For every edge ``(i = edge_src, j = edge_dst)`` of the neighbour graph, the
    charge ``q_j`` contributes ``-q_j * vec_ij / r_ij**3`` to the field at atom
    ``i``. Each contribution is scaled by the AMOEBA+-style damping factor
    ``1 - exp(-a * u_ij**1.5)``, where
    ``u_ij = r_ij / (alpha_i * alpha_j)**(1/6)`` and ``a`` is ``damping_param``.

    Before use, distances and vectors are divided by ``Au.BOHR`` and
    polarizabilities by ``Au.BOHR**3``. The computation therefore runs in
    bohr-based units (presumably the inputs are in Angstrom; confirm).

    The per-atom field vectors are flattened into one 1-D array and stored
    under the ``'electric_field'`` key of the returned dictionary.
    """

    damping_param: float = 0.7
    """Damping parameter for the electric field (initial value when trainable)."""
    charges_key: str = "charges"
    """Key of the charges in the input."""
    graph_key: str = "graph"
    """Key of the graph in the input."""
    polarizability_key: str = "polarizability"
    """Key of the polarizability in the input."""
    trainable: bool = False
    """If True, the damping parameter is learned as a flax parameter."""

    FID: ClassVar[str] = "ELECTRIC_FIELD"

    @nn.compact
    def __call__(self, inputs):
        """Compute the damped electric field acting on every atom.

        Args:
            inputs: dictionary holding at least:
                ``'species'``, whose leading dimension gives the number of atoms;
                the graph under ``graph_key``, with ``'edge_src'``,
                ``'edge_dst'``, ``'distances'`` and ``'vec'``;
                the charges under ``charges_key``;
                the polarizabilities under ``polarizability_key``.

        Returns:
            A new dictionary with every entry of ``inputs`` plus
            ``'electric_field'``, the flattened per-atom field.
        """
        species = inputs["species"]

        graph = inputs[self.graph_key]
        src = graph["edge_src"]
        dst = graph["edge_dst"]

        # Pair geometry and polarizabilities rescaled by the bohr conversion.
        r = graph["distances"] / Au.BOHR
        r_vec = graph["vec"] / Au.BOHR
        alpha = inputs[self.polarizability_key] / Au.BOHR**3

        # Reduced distance u_ij = r_ij / (alpha_i * alpha_j)^(1/6).
        u = r / (alpha[dst] * alpha[src]) ** (1 / 6)

        # Charge sitting on the neighbour (destination) atom of each edge.
        q_dst = inputs[self.charges_key][dst, None]

        # Optionally learnable damping strength; abs() keeps it non-negative.
        a = (
            jnp.abs(
                self.param(
                    "damping_param", lambda _rng: jnp.array(self.damping_param)
                )
            )
            if self.trainable
            else self.damping_param
        )
        damping = (1 - jnp.exp(-a * u**1.5))[:, None]

        # NOTE(review): the minus sign yields the Coulomb field at the source
        # atom if graph["vec"] points from source to destination -- confirm
        # against the graph builder.
        r3 = r[:, None] ** 3
        e_pairs = -q_dst * (r_vec / r3) * damping

        # Accumulate the pair contributions onto their source atoms.
        field = jax.ops.segment_sum(e_pairs, src, species.shape[0]).flatten()

        return {**inputs, "electric_field": field}


if __name__ == "__main__":
    pass
class
ElectricField(flax.linen.module.Module):
class ElectricField(nn.Module):
    """Electric field from distributed point charges with short-range damping.

    FID: ELECTRIC_FIELD

    For every edge ``(i = edge_src, j = edge_dst)`` of the neighbour graph, the
    charge ``q_j`` contributes ``-q_j * vec_ij / r_ij**3`` to the field at atom
    ``i``. Each contribution is scaled by the AMOEBA+-style damping factor
    ``1 - exp(-a * u_ij**1.5)``, where
    ``u_ij = r_ij / (alpha_i * alpha_j)**(1/6)`` and ``a`` is ``damping_param``.

    Before use, distances and vectors are divided by ``Au.BOHR`` and
    polarizabilities by ``Au.BOHR**3``. The computation therefore runs in
    bohr-based units (presumably the inputs are in Angstrom; confirm).

    The per-atom field vectors are flattened into one 1-D array and stored
    under the ``'electric_field'`` key of the returned dictionary.
    """

    damping_param: float = 0.7
    """Damping parameter for the electric field (initial value when trainable)."""
    charges_key: str = "charges"
    """Key of the charges in the input."""
    graph_key: str = "graph"
    """Key of the graph in the input."""
    polarizability_key: str = "polarizability"
    """Key of the polarizability in the input."""
    trainable: bool = False
    """If True, the damping parameter is learned as a flax parameter."""

    FID: ClassVar[str] = "ELECTRIC_FIELD"

    @nn.compact
    def __call__(self, inputs):
        """Compute the damped electric field acting on every atom.

        Args:
            inputs: dictionary holding at least:
                ``'species'``, whose leading dimension gives the number of atoms;
                the graph under ``graph_key``, with ``'edge_src'``,
                ``'edge_dst'``, ``'distances'`` and ``'vec'``;
                the charges under ``charges_key``;
                the polarizabilities under ``polarizability_key``.

        Returns:
            A new dictionary with every entry of ``inputs`` plus
            ``'electric_field'``, the flattened per-atom field.
        """
        species = inputs["species"]

        graph = inputs[self.graph_key]
        src = graph["edge_src"]
        dst = graph["edge_dst"]

        # Pair geometry and polarizabilities rescaled by the bohr conversion.
        r = graph["distances"] / Au.BOHR
        r_vec = graph["vec"] / Au.BOHR
        alpha = inputs[self.polarizability_key] / Au.BOHR**3

        # Reduced distance u_ij = r_ij / (alpha_i * alpha_j)^(1/6).
        u = r / (alpha[dst] * alpha[src]) ** (1 / 6)

        # Charge sitting on the neighbour (destination) atom of each edge.
        q_dst = inputs[self.charges_key][dst, None]

        # Optionally learnable damping strength; abs() keeps it non-negative.
        a = (
            jnp.abs(
                self.param(
                    "damping_param", lambda _rng: jnp.array(self.damping_param)
                )
            )
            if self.trainable
            else self.damping_param
        )
        damping = (1 - jnp.exp(-a * u**1.5))[:, None]

        # NOTE(review): the minus sign yields the Coulomb field at the source
        # atom if graph["vec"] points from source to destination -- confirm
        # against the graph builder.
        r3 = r[:, None] ** 3
        e_pairs = -q_dst * (r_vec / r3) * damping

        # Accumulate the pair contributions onto their source atoms.
        field = jax.ops.segment_sum(e_pairs, src, species.shape[0]).flatten()

        return {**inputs, "electric_field": field}
Electric field from distributed point charges with short-range damping.
FID: ELECTRIC_FIELD
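The short-range damping follows AMOEBA+: each pair contribution is scaled by $1 - \exp(-a\,u_{ij}^{3/2})$, where $u_{ij} = r_{ij} / (\alpha_i \alpha_j)^{1/6}$ and $a$ is the damping parameter.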
ElectricField( damping_param: float = 0.7, charges_key: str = 'charges', graph_key: str = 'graph', polarizability_key: str = 'polarizability', trainable: bool = False, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying Python mechanism for implementing dynamic @property decorators. We use a raw descriptor instead of the more common decorator to ensure that the appropriate getter/setter logic still applies in subclasses after the various dataclass transforms.