fennol.models.physics.polarisation

Polarization model for FENNOL.

Created by C. Cattin 2024

  1#!/usr/bin/env python3
  2"""Polarization model for FENNOL.
  3
  4Created by C. Cattin 2024
  5"""
  6
  7import flax.linen as nn
  8import jax
  9import jax.numpy as jnp
 10from typing import Optional, ClassVar
 11
 12from fennol.utils import AtomicUnits as Au
 13
 14
 15class Polarization(nn.Module):
 16    """Polarization model with Thole damping scheme.
 17    
 18    FID: POLARIZATION
 19    """
 20
 21    energy_key: Optional[str] = None
 22    """Key of the energy in the outputs."""
 23    graph_key: str = 'graph'
 24    """Key of the graph in the inputs."""
 25    polarizability_key: str = 'polarizability'
 26    """Key of the polarizability in the inputs."""
 27    electric_field_key: str = 'electric_field'
 28    """Key of the electric field in the inputs."""
 29    induced_dipoles_key: str = 'induced_dipoles'
 30    """Key of the induced dipole in the outputs."""
 31    damping_param_mutual: float = 0.39
 32    """Damping parameter for mutual polarization."""
 33    neglect_mutual: bool = False
 34    """Neglect the mutual polarization term (like in iAMOEBA)."""
 35    _energy_unit: str = 'Ha'
 36    """The energy unit of the model. **Automatically set by FENNIX**"""
 37
 38    FID: ClassVar[str] = 'POLARIZATION'
 39
 40    @nn.compact
 41    def __call__(self, inputs):
 42        """Forward pass of the polarization model.
 43
 44        Parameters
 45        ----------
 46        inputs : dict
 47            Input dictionary containing all the info about the system.
 48            This dictionary is given from the FENNIX class.
 49        """
 50        # Species
 51        species = inputs['species']
 52        graph = inputs[self.graph_key]
 53        # Graph
 54        edge_src, edge_dst = graph['edge_src'], graph['edge_dst']
 55        # Distances and vector between each pair of atoms in atomic units
 56        distances = graph['distances']
 57        rij = distances / Au.BOHR
 58        vec_ij = graph['vec'] / Au.BOHR
 59        # Polarizability
 60        polarizability = (
 61            inputs[self.polarizability_key] / Au.BOHR**3
 62        )
 63
 64        pol_src = polarizability[edge_src]
 65        pol_dst = polarizability[edge_dst]
 66        alpha_ij = pol_dst * pol_src
 67
 68        # The output is a dictionary with the polarization energy
 69        output = {}
 70
 71        ######################
 72        # Interaction matrix #
 73        ######################
 74        # Diagonal terms
 75        tii = 1 / polarizability[:, None]
 76        
 77
 78        ##################
 79        # Electric field #
 80        ##################
 81        electric_field = inputs[self.electric_field_key]
 82
 83        ###############################
 84        # Electric point dipole moment#
 85        ###############################
 86        if self.neglect_mutual:
 87            electric_field = electric_field.reshape(-1, 3)
 88            mu = polarizability[:,None]*electric_field
 89            mu_ = jax.lax.stop_gradient(mu)
 90            tmu = tii*mu_
 91            
 92        else:
 93            # Effective distance
 94            uij = rij / alpha_ij ** (1 / 6)
 95            # Damping terms
 96            exp = jnp.exp(-self.damping_param_mutual * uij**3)
 97            lambda_3 = 1 - exp
 98            lambda_5 = 1 - (1 + self.damping_param_mutual * uij**3) * exp
 99            # Off-diagonal terms
100            tij = (
101                3 * lambda_5[:, None, None]
102                    * vec_ij[:, :, None] * vec_ij[:, None, :]
103                    / rij[:, None, None]**5
104                - jnp.eye(3)[None, :, :] * lambda_3[:, None, None]
105                    / rij[:, None, None]**3
106            )
107
108            def matvec(mui):
109                """Compute the matrix vector product of T and mu."""
110                mui = mui.reshape(-1, 3)
111                tmu_self = tii * mui
112                tmu_pair = jnp.einsum("jab,jb->ja", tij, mui[edge_dst])
113                tmu = (
114                    jax.ops.segment_sum(tmu_pair, edge_src, species.shape[0])
115                    + tmu_self
116                )
117                return tmu.flatten()
118            mu = jax.scipy.sparse.linalg.cg(matvec, electric_field)[0]
119            mu_ = jax.lax.stop_gradient(mu)
120
121            # Matrix vector product
122            tmu = matvec(mu_)
123
124        # Polarization energy
125        pol_energy = (
126            (0.5 * tmu - electric_field) * mu_
127        ).reshape(-1, 3).sum(axis=1)
128
129        # Output
130        output[self.electric_field_key] = electric_field.reshape(-1, 3)
131        output[self.induced_dipoles_key] = mu.reshape(-1, 3) * Au.BOHR
132        energy_key = (
133            self.energy_key if self.energy_key is not None else self.name
134        )
135        energy_unit = Au.get_multiplier(self._energy_unit)
136        output[energy_key] = pol_energy*energy_unit
137        output['tmu'] = tmu.reshape(-1, 3)
138
139        return {**inputs, **output}
140
141
142if __name__ == "__main__":
143    pass
class Polarization(flax.linen.module.Module):
 16class Polarization(nn.Module):
 17    """Polarization model with Thole damping scheme.
 18    
 19    FID: POLARIZATION
 20    """
 21
 22    energy_key: Optional[str] = None
 23    """Key of the energy in the outputs."""
 24    graph_key: str = 'graph'
 25    """Key of the graph in the inputs."""
 26    polarizability_key: str = 'polarizability'
 27    """Key of the polarizability in the inputs."""
 28    electric_field_key: str = 'electric_field'
 29    """Key of the electric field in the inputs."""
 30    induced_dipoles_key: str = 'induced_dipoles'
 31    """Key of the induced dipole in the outputs."""
 32    damping_param_mutual: float = 0.39
 33    """Damping parameter for mutual polarization."""
 34    neglect_mutual: bool = False
 35    """Neglect the mutual polarization term (like in iAMOEBA)."""
 36    _energy_unit: str = 'Ha'
 37    """The energy unit of the model. **Automatically set by FENNIX**"""
 38
 39    FID: ClassVar[str] = 'POLARIZATION'
 40
 41    @nn.compact
 42    def __call__(self, inputs):
 43        """Forward pass of the polarization model.
 44
 45        Parameters
 46        ----------
 47        inputs : dict
 48            Input dictionary containing all the info about the system.
 49            This dictionary is given from the FENNIX class.
 50        """
 51        # Species
 52        species = inputs['species']
 53        graph = inputs[self.graph_key]
 54        # Graph
 55        edge_src, edge_dst = graph['edge_src'], graph['edge_dst']
 56        # Distances and vector between each pair of atoms in atomic units
 57        distances = graph['distances']
 58        rij = distances / Au.BOHR
 59        vec_ij = graph['vec'] / Au.BOHR
 60        # Polarizability
 61        polarizability = (
 62            inputs[self.polarizability_key] / Au.BOHR**3
 63        )
 64
 65        pol_src = polarizability[edge_src]
 66        pol_dst = polarizability[edge_dst]
 67        alpha_ij = pol_dst * pol_src
 68
 69        # The output is a dictionary with the polarization energy
 70        output = {}
 71
 72        ######################
 73        # Interaction matrix #
 74        ######################
 75        # Diagonal terms
 76        tii = 1 / polarizability[:, None]
 77        
 78
 79        ##################
 80        # Electric field #
 81        ##################
 82        electric_field = inputs[self.electric_field_key]
 83
 84        ###############################
 85        # Electric point dipole moment#
 86        ###############################
 87        if self.neglect_mutual:
 88            electric_field = electric_field.reshape(-1, 3)
 89            mu = polarizability[:,None]*electric_field
 90            mu_ = jax.lax.stop_gradient(mu)
 91            tmu = tii*mu_
 92            
 93        else:
 94            # Effective distance
 95            uij = rij / alpha_ij ** (1 / 6)
 96            # Damping terms
 97            exp = jnp.exp(-self.damping_param_mutual * uij**3)
 98            lambda_3 = 1 - exp
 99            lambda_5 = 1 - (1 + self.damping_param_mutual * uij**3) * exp
100            # Off-diagonal terms
101            tij = (
102                3 * lambda_5[:, None, None]
103                    * vec_ij[:, :, None] * vec_ij[:, None, :]
104                    / rij[:, None, None]**5
105                - jnp.eye(3)[None, :, :] * lambda_3[:, None, None]
106                    / rij[:, None, None]**3
107            )
108
109            def matvec(mui):
110                """Compute the matrix vector product of T and mu."""
111                mui = mui.reshape(-1, 3)
112                tmu_self = tii * mui
113                tmu_pair = jnp.einsum("jab,jb->ja", tij, mui[edge_dst])
114                tmu = (
115                    jax.ops.segment_sum(tmu_pair, edge_src, species.shape[0])
116                    + tmu_self
117                )
118                return tmu.flatten()
119            mu = jax.scipy.sparse.linalg.cg(matvec, electric_field)[0]
120            mu_ = jax.lax.stop_gradient(mu)
121
122            # Matrix vector product
123            tmu = matvec(mu_)
124
125        # Polarization energy
126        pol_energy = (
127            (0.5 * tmu - electric_field) * mu_
128        ).reshape(-1, 3).sum(axis=1)
129
130        # Output
131        output[self.electric_field_key] = electric_field.reshape(-1, 3)
132        output[self.induced_dipoles_key] = mu.reshape(-1, 3) * Au.BOHR
133        energy_key = (
134            self.energy_key if self.energy_key is not None else self.name
135        )
136        energy_unit = Au.get_multiplier(self._energy_unit)
137        output[energy_key] = pol_energy*energy_unit
138        output['tmu'] = tmu.reshape(-1, 3)
139
140        return {**inputs, **output}

Polarization model with Thole damping scheme.

FID: POLARIZATION

Polarization( energy_key: Optional[str] = None, graph_key: str = 'graph', polarizability_key: str = 'polarizability', electric_field_key: str = 'electric_field', induced_dipoles_key: str = 'induced_dipoles', damping_param_mutual: float = 0.39, neglect_mutual: bool = False, _energy_unit: str = 'Ha', parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
energy_key: Optional[str] = None

Key of the energy in the outputs.

graph_key: str = 'graph'

Key of the graph in the inputs.

polarizability_key: str = 'polarizability'

Key of the polarizability in the inputs.

electric_field_key: str = 'electric_field'

Key of the electric field in the inputs.

induced_dipoles_key: str = 'induced_dipoles'

Key of the induced dipoles in the outputs.

damping_param_mutual: float = 0.39

Thole damping parameter for mutual polarization.

neglect_mutual: bool = False

Neglect the mutual polarization term (like in iAMOEBA).

FID: ClassVar[str] = 'POLARIZATION'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None