fennol.models.physics.electric_field

Electric field model for FENNOL.

Created by C. Cattin 2024

 1#!/usr/bin/env python3
 2"""Electric field model for FENNOL.
 3
 4Created by C. Cattin 2024
 5"""
 6
 7import flax.linen as nn
 8import jax
 9import jax.numpy as jnp
10from typing import ClassVar
11
12from fennol.utils import AtomicUnits as Au
13
14
class ElectricField(nn.Module):
    """Electric field from distributed point charges with short-range damping.

    FID: ELECTRIC_FIELD

    For every graph edge (i = edge_src, j = edge_dst), the charge on atom j
    contributes ``-q_j * damp(u_ij) * vec_ij / r_ij**3`` to the field on atom i,
    with ``u_ij = r_ij / (alpha_i * alpha_j) ** (1/6)`` and
    ``damp(u) = 1 - exp(-a * u**1.5)``. Lengths are divided by ``Au.BOHR`` and
    polarizabilities by ``Au.BOHR**3`` before use.

    The short-range damping is defined as in AMOEBA+.
    """

    damping_param: float = 0.7
    """Damping parameter for the electric field."""
    charges_key: str = 'charges'
    """Key of the charges in the input."""
    graph_key: str = 'graph'
    """Key of the graph in the input."""
    polarizability_key: str = 'polarizability'
    """Key of the polarizability in the input."""
    trainable: bool = False
    # When True, the damping parameter becomes a Flax parameter named
    # 'damping_param' (initialized from the field above; its abs() is used).

    FID: ClassVar[str] = 'ELECTRIC_FIELD'

    @nn.compact
    def __call__(self, inputs):
        """Return ``inputs`` extended with the damped electric field.

        Args:
            inputs: mapping that provides ``'species'`` (only its leading
                dimension is used, as the number of atoms), the graph under
                ``self.graph_key`` (with ``'edge_src'``, ``'edge_dst'``,
                ``'distances'`` and ``'vec'``), the charges under
                ``self.charges_key`` and the polarizabilities under
                ``self.polarizability_key``.

        Returns:
            A new dict with every entry of ``inputs`` plus ``'electric_field'``.
            That entry is the per-atom field summed over edges by source atom,
            then flattened. Presumably its shape is (n_atoms * 3,) when
            ``vec`` is (n_edges, 3); TODO confirm.
        """
        n_atoms = inputs['species'].shape[0]

        graph = inputs[self.graph_key]
        src = graph['edge_src']
        dst = graph['edge_dst']

        # Convert to atomic units: lengths in bohr, polarizabilities in bohr**3.
        r = graph['distances'] / Au.BOHR
        displacement = graph['vec'] / Au.BOHR
        alpha = inputs[self.polarizability_key] / Au.BOHR**3

        # Reduced (effective) distance scaled by the pair polarizability.
        u = r / (alpha[dst] * alpha[src]) ** (1 / 6)

        # Damping strength: learnable (kept positive via abs) or fixed.
        a = (
            jnp.abs(self.param('damping_param', lambda key: jnp.array(self.damping_param)))
            if self.trainable
            else self.damping_param
        )
        screening = 1 - jnp.exp(-a * u**1.5)[:, None]

        # Contribution of the charge on the destination atom to the field on the source atom.
        q_neighbor = inputs[self.charges_key][dst, None]
        contributions = -q_neighbor * (displacement / r[:, None] ** 3) * screening

        field = jax.ops.segment_sum(contributions, src, n_atoms).flatten()
        return {**inputs, 'electric_field': field}
82
83
if __name__ == "__main__":
    # Running the file directly does nothing; the module only defines ElectricField.
    ...
class ElectricField(flax.linen.module.Module):
class ElectricField(nn.Module):
    """Electric field from distributed point charges with short-range damping.

    FID: ELECTRIC_FIELD

    For every graph edge (i = edge_src, j = edge_dst), the charge on atom j
    contributes ``-q_j * damp(u_ij) * vec_ij / r_ij**3`` to the field on atom i,
    with ``u_ij = r_ij / (alpha_i * alpha_j) ** (1/6)`` and
    ``damp(u) = 1 - exp(-a * u**1.5)``. Lengths are divided by ``Au.BOHR`` and
    polarizabilities by ``Au.BOHR**3`` before use.

    The short-range damping is defined as in AMOEBA+.
    """

    damping_param: float = 0.7
    """Damping parameter for the electric field."""
    charges_key: str = 'charges'
    """Key of the charges in the input."""
    graph_key: str = 'graph'
    """Key of the graph in the input."""
    polarizability_key: str = 'polarizability'
    """Key of the polarizability in the input."""
    trainable: bool = False
    # When True, the damping parameter becomes a Flax parameter named
    # 'damping_param' (initialized from the field above; its abs() is used).

    FID: ClassVar[str] = 'ELECTRIC_FIELD'

    @nn.compact
    def __call__(self, inputs):
        """Return ``inputs`` extended with the damped electric field.

        Args:
            inputs: mapping that provides ``'species'`` (only its leading
                dimension is used, as the number of atoms), the graph under
                ``self.graph_key`` (with ``'edge_src'``, ``'edge_dst'``,
                ``'distances'`` and ``'vec'``), the charges under
                ``self.charges_key`` and the polarizabilities under
                ``self.polarizability_key``.

        Returns:
            A new dict with every entry of ``inputs`` plus ``'electric_field'``.
            That entry is the per-atom field summed over edges by source atom,
            then flattened. Presumably its shape is (n_atoms * 3,) when
            ``vec`` is (n_edges, 3); TODO confirm.
        """
        n_atoms = inputs['species'].shape[0]

        graph = inputs[self.graph_key]
        src = graph['edge_src']
        dst = graph['edge_dst']

        # Convert to atomic units: lengths in bohr, polarizabilities in bohr**3.
        r = graph['distances'] / Au.BOHR
        displacement = graph['vec'] / Au.BOHR
        alpha = inputs[self.polarizability_key] / Au.BOHR**3

        # Reduced (effective) distance scaled by the pair polarizability.
        u = r / (alpha[dst] * alpha[src]) ** (1 / 6)

        # Damping strength: learnable (kept positive via abs) or fixed.
        a = (
            jnp.abs(self.param('damping_param', lambda key: jnp.array(self.damping_param)))
            if self.trainable
            else self.damping_param
        )
        screening = 1 - jnp.exp(-a * u**1.5)[:, None]

        # Contribution of the charge on the destination atom to the field on the source atom.
        q_neighbor = inputs[self.charges_key][dst, None]
        contributions = -q_neighbor * (displacement / r[:, None] ** 3) * screening

        field = jax.ops.segment_sum(contributions, src, n_atoms).flatten()
        return {**inputs, 'electric_field': field}

Electric field from distributed point charges with short-range damping.

FID: ELECTRIC_FIELD

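The short-range damping is defined as in AMOEBA+: each pair contribution is scaled by $1 - \exp\left(-a\,u_{ij}^{3/2}\right)$, where $u_{ij} = r_{ij} / (\alpha_i \alpha_j)^{1/6}$ and $a$ is the damping parameter.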

ElectricField( damping_param: float = 0.7, charges_key: str = 'charges', graph_key: str = 'graph', polarizability_key: str = 'polarizability', trainable: bool = False, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
damping_param: float = 0.7

Damping parameter for the electric field.

charges_key: str = 'charges'

Key of the charges in the input.

graph_key: str = 'graph'

Key of the graph in the input.

polarizability_key: str = 'polarizability'

Key of the polarizability in the input.

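trainable: bool = False

If True, the damping parameter is registered as a trainable Flax parameter named `damping_param`. It is initialized from the `damping_param` attribute, and its absolute value is used. Otherwise, the fixed attribute value is used.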
FID: ClassVar[str] = 'ELECTRIC_FIELD'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None