fennol.models.physics.polarisation
Polarization model for FENNOL.
Created by C. Cattin 2024
#!/usr/bin/env python3
"""Polarization model for FENNOL.

Created by C. Cattin 2024
"""

import flax.linen as nn
import jax
import jax.numpy as jnp
from typing import Optional, ClassVar

from fennol.utils import AtomicUnits as Au


class Polarization(nn.Module):
    """Polarization model with Thole damping scheme.

    The induced dipoles ``mu`` solve the linear system

        (alpha^-1 - T) mu = E

    where ``alpha`` is the atomic polarizability, ``E`` is the input electric
    field and ``T`` is the Thole-damped dipole field tensor coupling the atoms
    connected in the graph (``mu_i = alpha_i (E_i + sum_j T_ij mu_j)``).
    The per-atom polarization energy is ``0.5 * mu . (alpha^-1 - T) mu - mu . E``.
    With ``neglect_mutual`` the coupling ``T`` is dropped and ``mu = alpha * E``.

    FID: POLARIZATION
    """

    energy_key: Optional[str] = None
    """Key of the energy in the outputs."""
    graph_key: str = 'graph'
    """Key of the graph in the inputs."""
    polarizability_key: str = 'polarizability'
    """Key of the polarizability in the inputs."""
    electric_field_key: str = 'electric_field'
    """Key of the electric field in the inputs."""
    induced_dipoles_key: str = 'induced_dipoles'
    """Key of the induced dipole in the outputs."""
    damping_param_mutual: float = 0.39
    """Damping parameter for mutual polarization."""
    neglect_mutual: bool = False
    """Neglect the mutual polarization term (like in iAMOEBA)."""
    _energy_unit: str = 'Ha'
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str] = 'POLARIZATION'

    @nn.compact
    def __call__(self, inputs):
        """Forward pass of the polarization model.

        Parameters
        ----------
        inputs : dict
            Input dictionary containing all the info about the system.
            This dictionary is given from the FENNIX class. It must provide
            ``'species'``, the graph (``edge_src``, ``edge_dst``,
            ``distances``, ``vec``), the polarizability and the electric
            field. The field may be given flat ``(3 * nat,)`` or as
            ``(nat, 3)``.

        Returns
        -------
        dict
            ``inputs`` updated with the electric field reshaped to
            ``(nat, 3)``, the induced dipoles (scaled by ``Au.BOHR``), the
            per-atom polarization energy under ``energy_key`` (or the module
            name when ``energy_key`` is None) and ``'tmu'``, the product of
            the interaction matrix with the (gradient-stopped) dipoles.
        """
        species = inputs['species']
        nat = species.shape[0]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph['edge_src'], graph['edge_dst']
        # Distances and pair vectors converted to atomic units (bohr).
        rij = graph['distances'] / Au.BOHR
        vec_ij = graph['vec'] / Au.BOHR
        # Polarizability rescaled by BOHR**3 (presumably A^3 -> bohr^3).
        polarizability = inputs[self.polarizability_key] / Au.BOHR**3

        output = {}

        # Diagonal blocks of the interaction matrix: alpha_i^-1 * I.
        tii = 1 / polarizability[:, None]

        # Work with a (nat, 3) field in both branches; the CG solver below
        # receives an explicitly flattened copy so its shapes always match.
        electric_field = inputs[self.electric_field_key].reshape(-1, 3)

        if self.neglect_mutual:
            # Direct polarization only: mu_i = alpha_i * E_i.
            mu = polarizability[:, None] * electric_field
            mu_ = jax.lax.stop_gradient(mu)
            tmu = tii * mu_

        else:
            # Thole effective distance u_ij = r_ij / (alpha_i alpha_j)^(1/6).
            alpha_ij = polarizability[edge_src] * polarizability[edge_dst]
            uij = rij / alpha_ij ** (1 / 6)
            # Thole damping factors for the 1/r^3 and 1/r^5 terms.
            au3 = self.damping_param_mutual * uij**3
            exp = jnp.exp(-au3)
            lambda_3 = 1 - exp
            lambda_5 = 1 - (1 + au3) * exp
            # Damped dipole field tensor, one (3, 3) block per edge.
            tij = (
                3 * lambda_5[:, None, None]
                * vec_ij[:, :, None] * vec_ij[:, None, :]
                / rij[:, None, None]**5
                - jnp.eye(3)[None, :, :] * lambda_3[:, None, None]
                / rij[:, None, None]**3
            )

            def matvec(mu_flat):
                """Apply (alpha^-1 - T) to a flat ``(3 * nat,)`` dipole vector."""
                mu_atoms = mu_flat.reshape(-1, 3)
                # Field at each edge source due to the dipole at the edge
                # destination, accumulated per atom.
                field_pair = jnp.einsum("jab,jb->ja", tij, mu_atoms[edge_dst])
                mutual_field = jax.ops.segment_sum(field_pair, edge_src, nat)
                # The mutual field adds to the external one
                # (mu = alpha (E + T mu)), hence the minus sign here.
                return (tii * mu_atoms - mutual_field).flatten()

            def apply_alpha(vec_flat):
                """Multiply a flat ``(3 * nat,)`` vector by the polarizability."""
                return (polarizability[:, None] * vec_flat.reshape(-1, 3)).flatten()

            field_flat = electric_field.flatten()
            # alpha inverts the diagonal of the system matrix (assuming the
            # graph holds no self-edges), so it is used as a Jacobi
            # preconditioner; alpha * E (direct polarization) is the start guess.
            mu = jax.scipy.sparse.linalg.cg(
                matvec, field_flat, x0=apply_alpha(field_flat), M=apply_alpha
            )[0].reshape(-1, 3)
            mu_ = jax.lax.stop_gradient(mu)

            # Matrix vector product with the converged dipoles
            tmu = matvec(mu_).reshape(-1, 3)

        # Per-atom polarization energy 0.5 mu.(A mu) - mu.E (atomic units).
        pol_energy = ((0.5 * tmu - electric_field) * mu_).sum(axis=1)

        # Output
        output[self.electric_field_key] = electric_field
        output[self.induced_dipoles_key] = mu * Au.BOHR
        energy_key = (
            self.energy_key if self.energy_key is not None else self.name
        )
        energy_unit = Au.get_multiplier(self._energy_unit)
        output[energy_key] = pol_energy * energy_unit
        output['tmu'] = tmu

        return {**inputs, **output}


if __name__ == "__main__":
    pass
class
Polarization(flax.linen.module.Module):
class Polarization(nn.Module):
    """Polarization model with Thole damping scheme.

    The induced dipoles ``mu`` solve the linear system

        (alpha^-1 - T) mu = E

    where ``alpha`` is the atomic polarizability, ``E`` is the input electric
    field and ``T`` is the Thole-damped dipole field tensor coupling the atoms
    connected in the graph (``mu_i = alpha_i (E_i + sum_j T_ij mu_j)``).
    The per-atom polarization energy is ``0.5 * mu . (alpha^-1 - T) mu - mu . E``.
    With ``neglect_mutual`` the coupling ``T`` is dropped and ``mu = alpha * E``.

    FID: POLARIZATION
    """

    energy_key: Optional[str] = None
    """Key of the energy in the outputs."""
    graph_key: str = 'graph'
    """Key of the graph in the inputs."""
    polarizability_key: str = 'polarizability'
    """Key of the polarizability in the inputs."""
    electric_field_key: str = 'electric_field'
    """Key of the electric field in the inputs."""
    induced_dipoles_key: str = 'induced_dipoles'
    """Key of the induced dipole in the outputs."""
    damping_param_mutual: float = 0.39
    """Damping parameter for mutual polarization."""
    neglect_mutual: bool = False
    """Neglect the mutual polarization term (like in iAMOEBA)."""
    _energy_unit: str = 'Ha'
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str] = 'POLARIZATION'

    @nn.compact
    def __call__(self, inputs):
        """Forward pass of the polarization model.

        Parameters
        ----------
        inputs : dict
            Input dictionary containing all the info about the system.
            This dictionary is given from the FENNIX class. It must provide
            ``'species'``, the graph (``edge_src``, ``edge_dst``,
            ``distances``, ``vec``), the polarizability and the electric
            field. The field may be given flat ``(3 * nat,)`` or as
            ``(nat, 3)``.

        Returns
        -------
        dict
            ``inputs`` updated with the electric field reshaped to
            ``(nat, 3)``, the induced dipoles (scaled by ``Au.BOHR``), the
            per-atom polarization energy under ``energy_key`` (or the module
            name when ``energy_key`` is None) and ``'tmu'``, the product of
            the interaction matrix with the (gradient-stopped) dipoles.
        """
        species = inputs['species']
        nat = species.shape[0]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph['edge_src'], graph['edge_dst']
        # Distances and pair vectors converted to atomic units (bohr).
        rij = graph['distances'] / Au.BOHR
        vec_ij = graph['vec'] / Au.BOHR
        # Polarizability rescaled by BOHR**3 (presumably A^3 -> bohr^3).
        polarizability = inputs[self.polarizability_key] / Au.BOHR**3

        output = {}

        # Diagonal blocks of the interaction matrix: alpha_i^-1 * I.
        tii = 1 / polarizability[:, None]

        # Work with a (nat, 3) field in both branches; the CG solver below
        # receives an explicitly flattened copy so its shapes always match.
        electric_field = inputs[self.electric_field_key].reshape(-1, 3)

        if self.neglect_mutual:
            # Direct polarization only: mu_i = alpha_i * E_i.
            mu = polarizability[:, None] * electric_field
            mu_ = jax.lax.stop_gradient(mu)
            tmu = tii * mu_

        else:
            # Thole effective distance u_ij = r_ij / (alpha_i alpha_j)^(1/6).
            alpha_ij = polarizability[edge_src] * polarizability[edge_dst]
            uij = rij / alpha_ij ** (1 / 6)
            # Thole damping factors for the 1/r^3 and 1/r^5 terms.
            au3 = self.damping_param_mutual * uij**3
            exp = jnp.exp(-au3)
            lambda_3 = 1 - exp
            lambda_5 = 1 - (1 + au3) * exp
            # Damped dipole field tensor, one (3, 3) block per edge.
            tij = (
                3 * lambda_5[:, None, None]
                * vec_ij[:, :, None] * vec_ij[:, None, :]
                / rij[:, None, None]**5
                - jnp.eye(3)[None, :, :] * lambda_3[:, None, None]
                / rij[:, None, None]**3
            )

            def matvec(mu_flat):
                """Apply (alpha^-1 - T) to a flat ``(3 * nat,)`` dipole vector."""
                mu_atoms = mu_flat.reshape(-1, 3)
                # Field at each edge source due to the dipole at the edge
                # destination, accumulated per atom.
                field_pair = jnp.einsum("jab,jb->ja", tij, mu_atoms[edge_dst])
                mutual_field = jax.ops.segment_sum(field_pair, edge_src, nat)
                # The mutual field adds to the external one
                # (mu = alpha (E + T mu)), hence the minus sign here.
                return (tii * mu_atoms - mutual_field).flatten()

            def apply_alpha(vec_flat):
                """Multiply a flat ``(3 * nat,)`` vector by the polarizability."""
                return (polarizability[:, None] * vec_flat.reshape(-1, 3)).flatten()

            field_flat = electric_field.flatten()
            # alpha inverts the diagonal of the system matrix (assuming the
            # graph holds no self-edges), so it is used as a Jacobi
            # preconditioner; alpha * E (direct polarization) is the start guess.
            mu = jax.scipy.sparse.linalg.cg(
                matvec, field_flat, x0=apply_alpha(field_flat), M=apply_alpha
            )[0].reshape(-1, 3)
            mu_ = jax.lax.stop_gradient(mu)

            # Matrix vector product with the converged dipoles
            tmu = matvec(mu_).reshape(-1, 3)

        # Per-atom polarization energy 0.5 mu.(A mu) - mu.E (atomic units).
        pol_energy = ((0.5 * tmu - electric_field) * mu_).sum(axis=1)

        # Output
        output[self.electric_field_key] = electric_field
        output[self.induced_dipoles_key] = mu * Au.BOHR
        energy_key = (
            self.energy_key if self.energy_key is not None else self.name
        )
        energy_unit = Au.get_multiplier(self._energy_unit)
        output[energy_key] = pol_energy * energy_unit
        output['tmu'] = tmu

        return {**inputs, **output}
Polarization model with Thole damping scheme.
FID: POLARIZATION
Polarization( energy_key: Optional[str] = None, graph_key: str = 'graph', polarizability_key: str = 'polarizability', electric_field_key: str = 'electric_field', induced_dipoles_key: str = 'induced_dipoles', damping_param_mutual: float = 0.39, neglect_mutual: bool = False, _energy_unit: str = 'Ha', parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.