fennol.models.embeddings.eeacsf
import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Sequence, Dict, ClassVar
import numpy as np
import dataclasses

from ...utils.periodic_table import PERIODIC_TABLE, VALENCE_STRUCTURE
from ..misc.encodings import SpeciesEncoding, RadialBasis


class EEACSF(nn.Module):
    """Element-embracing Atom-Centered Symmetry Functions.

    FID : EEACSF

    This is an embedding similar to ANI that includes simple chemical information in the AEVs
    without trainable parameters (fixed embedding).
    The angle embedding is computed using a low-order Fourier expansion.

    ### Reference
    Loosely inspired by M. Eckhoff and M. Reiher, Lifelong Machine Learning Potentials,
    J. Chem. Theory Comput. 2023, 19, 12, 3509–3525, https://doi.org/10.1021/acs.jctc.3c00279

    """

    # Per-graph properties; `__call__` reads the "cutoff" entry of the radial
    # and of the angular graph from this mapping.
    _graphs_properties: Dict
    graph_angle_key: str
    """ The key in the input dictionary that corresponds to the angular graph."""
    nmax_angle: int = 4
    """ The maximum fourier order for the angle representation."""
    embedding_key: str = "embedding"
    """ The key to use for the output embedding in the returned dictionary.
        If set to `None`, the embedding array itself is returned."""
    graph_key: str = "graph"
    """ The key in the input dictionary that corresponds to the radial graph."""
    species_encoding: dict = dataclasses.field(default_factory=dict)
    """ The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`"""
    radial_basis: dict = dataclasses.field(default_factory=dict)
    """ The radial basis parameters for the radial AEV. See `fennol.models.misc.encodings.RadialBasis`"""
    radial_basis_angle: dict = dataclasses.field(default_factory=dict)
    """ The radial basis parameters for the angular AEV.
        See `fennol.models.misc.encodings.RadialBasis`.
        If explicitly set to `None`, the `radial_basis` parameters are reused."""
    angle_combine_pairs: bool = False
    """ If True, the angular AEV is computed by combining pairs of radial AEV."""

    FID: ClassVar[str] = "EEACSF"

    @nn.compact
    def __call__(self, inputs):
        """Compute the EEACSF embedding of every atom.

        `inputs` must provide `"species"`, the radial graph under `graph_key`
        (entries `"distances"`, `"switch"`, `"edge_src"`, `"edge_dst"`) and the
        angular graph under `graph_angle_key` (entries `"angles"`,
        `"distances"`, `"central_atom"`, `"angle_src"`, `"angle_dst"`,
        `"switch"`, `"edge_dst"`).

        The embedding is the concatenation, along the last axis, of the species
        encoding, the radial AEV and the angular AEV. It is returned directly
        when `embedding_key` is `None`; otherwise a copy of `inputs` with the
        embedding stored under `embedding_key` is returned.
        """
        species = inputs["species"]
        num_atoms = species.shape[0]

        # species encoding
        onehot = SpeciesEncoding(**self.species_encoding, name="SpeciesEncoding")(
            species
        )

        # Radial graph
        graph = inputs[self.graph_key]
        distances = graph["distances"]
        switch = graph["switch"][:, None]
        edge_src = graph["edge_src"]
        edge_dst = graph["edge_dst"]

        # Radial basis: "end" is always overridden by the graph cutoff, and the
        # basis values are damped by the per-edge switching function.
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        radial_kwargs = {**self.radial_basis, "end": cutoff, "name": "RadialBasis"}
        radial_terms = RadialBasis(**radial_kwargs)(distances) * switch

        # aggregate radial AEV: outer product (radial basis x neighbour species
        # encoding) summed over the edges of each source atom
        radial_aev = jax.ops.segment_sum(
            radial_terms[:, :, None] * onehot[edge_dst, None, :],
            edge_src,
            num_atoms,
        ).reshape(num_atoms, -1)

        # Angular graph
        graph_angle = inputs[self.graph_angle_key]
        angles = graph_angle["angles"]
        dang = graph_angle["distances"]
        central_atom = graph_angle["central_atom"]
        angle_src, angle_dst = graph_angle["angle_src"], graph_angle["angle_dst"]
        switch_angles = graph_angle["switch"][:, None]
        angular_cutoff = self._graphs_properties[self.graph_angle_key]["cutoff"]
        edge_dst_ang = graph_angle["edge_dst"]

        # NOTE: the field default is an empty dict, so the fallback to
        # `radial_basis` only happens when `radial_basis_angle=None` is passed.
        radial_basis_angle = (
            self.radial_basis_angle
            if self.radial_basis_angle is not None
            else self.radial_basis
        )
        angular_kwargs = {
            **radial_basis_angle,
            "end": angular_cutoff,
            "name": "RadialBasisAng",
        }

        # Fourier expansion cos(n * angle) for n = 0..nmax_angle; it is the
        # same for both angular variants below, so it is built only once.
        nangles = jnp.asarray(
            np.arange(self.nmax_angle + 1)[None, :], dtype=angles.dtype
        )
        fourier = jnp.cos(nangles * angles[:, None])

        if self.angle_combine_pairs:
            # Per-edge radial terms (switch included) tagged with the neighbour
            # species encoding, flattened to one feature vector per edge.
            factor2 = RadialBasis(**angular_kwargs)(dang) * switch_angles
            radial_ang = (factor2[:, :, None] * onehot[edge_dst_ang, None, :]).reshape(
                -1, onehot.shape[1] * factor2.shape[1]
            )

            # Element-wise product of the two edges forming each angle.
            radial_aev_ang = radial_ang[angle_src] * radial_ang[angle_dst]

            # Angular AEV
            angular_aev = jax.ops.segment_sum(
                fourier[:, None, :] * radial_aev_ang[:, :, None],
                central_atom,
                num_atoms,
            ).reshape(num_atoms, -1)

        else:
            valence = SpeciesEncoding(
                encoding="sjs_coordinates", name="SJSEncoding", trainable=False
            )(species)
            # One radial descriptor per angle: the basis is evaluated at the
            # mean of the two edge lengths, damped by the product of switches.
            d12 = 0.5 * (dang[angle_src] + dang[angle_dst])
            switch12 = switch_angles[angle_src] * switch_angles[angle_dst]

            factor2 = RadialBasis(**angular_kwargs)(d12)

            # Angular AEV
            factor1 = fourier * switch12

            angular_terms = (factor1[:, None, :] * factor2[:, :, None]).reshape(
                -1, factor1.shape[1] * factor2.shape[1]
            )

            # Chemical information of the two neighbours of each angle: outer
            # product of the sum and of the product of their encodings.
            valence_dst = valence[edge_dst_ang]
            vangsrc = valence_dst[angle_src]
            vangdst = valence_dst[angle_dst]
            valence_ang_p = vangsrc + vangdst
            valence_ang_m = vangsrc * vangdst
            valence_ang = (
                valence_ang_p[:, :, None] * valence_ang_m[:, None, :]
            ).reshape(-1, valence_ang_p.shape[1] * valence_ang_m.shape[1])

            angular_aev = jax.ops.segment_sum(
                angular_terms[:, :, None] * valence_ang[:, None, :],
                central_atom,
                num_atoms,
            ).reshape(num_atoms, -1)

        embedding = jnp.concatenate((onehot, radial_aev, angular_aev), axis=-1)
        if self.embedding_key is None:
            return embedding
        return {**inputs, self.embedding_key: embedding}
class EEACSF(nn.Module):
    """Element-embracing Atom-Centered Symmetry Functions.

    FID : EEACSF

    This is an embedding similar to ANI that includes simple chemical information in the AEVs
    without trainable parameters (fixed embedding).
    The angle embedding is computed using a low-order Fourier expansion.

    ### Reference
    Loosely inspired by M. Eckhoff and M. Reiher, Lifelong Machine Learning Potentials,
    J. Chem. Theory Comput. 2023, 19, 12, 3509–3525, https://doi.org/10.1021/acs.jctc.3c00279

    """

    # Per-graph properties; `__call__` reads the "cutoff" entry of the radial
    # and of the angular graph from this mapping.
    _graphs_properties: Dict
    graph_angle_key: str
    """ The key in the input dictionary that corresponds to the angular graph."""
    nmax_angle: int = 4
    """ The maximum fourier order for the angle representation."""
    embedding_key: str = "embedding"
    """ The key to use for the output embedding in the returned dictionary.
        If set to `None`, the embedding array itself is returned."""
    graph_key: str = "graph"
    """ The key in the input dictionary that corresponds to the radial graph."""
    species_encoding: dict = dataclasses.field(default_factory=dict)
    """ The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`"""
    radial_basis: dict = dataclasses.field(default_factory=dict)
    """ The radial basis parameters for the radial AEV. See `fennol.models.misc.encodings.RadialBasis`"""
    radial_basis_angle: dict = dataclasses.field(default_factory=dict)
    """ The radial basis parameters for the angular AEV.
        See `fennol.models.misc.encodings.RadialBasis`.
        If explicitly set to `None`, the `radial_basis` parameters are reused."""
    angle_combine_pairs: bool = False
    """ If True, the angular AEV is computed by combining pairs of radial AEV."""

    FID: ClassVar[str] = "EEACSF"

    @nn.compact
    def __call__(self, inputs):
        """Compute the EEACSF embedding of every atom.

        `inputs` must provide `"species"`, the radial graph under `graph_key`
        (entries `"distances"`, `"switch"`, `"edge_src"`, `"edge_dst"`) and the
        angular graph under `graph_angle_key` (entries `"angles"`,
        `"distances"`, `"central_atom"`, `"angle_src"`, `"angle_dst"`,
        `"switch"`, `"edge_dst"`).

        The embedding is the concatenation, along the last axis, of the species
        encoding, the radial AEV and the angular AEV. It is returned directly
        when `embedding_key` is `None`; otherwise a copy of `inputs` with the
        embedding stored under `embedding_key` is returned.
        """
        species = inputs["species"]
        num_atoms = species.shape[0]

        # species encoding
        onehot = SpeciesEncoding(**self.species_encoding, name="SpeciesEncoding")(
            species
        )

        # Radial graph
        graph = inputs[self.graph_key]
        distances = graph["distances"]
        switch = graph["switch"][:, None]
        edge_src = graph["edge_src"]
        edge_dst = graph["edge_dst"]

        # Radial basis: "end" is always overridden by the graph cutoff, and the
        # basis values are damped by the per-edge switching function.
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        radial_kwargs = {**self.radial_basis, "end": cutoff, "name": "RadialBasis"}
        radial_terms = RadialBasis(**radial_kwargs)(distances) * switch

        # aggregate radial AEV: outer product (radial basis x neighbour species
        # encoding) summed over the edges of each source atom
        radial_aev = jax.ops.segment_sum(
            radial_terms[:, :, None] * onehot[edge_dst, None, :],
            edge_src,
            num_atoms,
        ).reshape(num_atoms, -1)

        # Angular graph
        graph_angle = inputs[self.graph_angle_key]
        angles = graph_angle["angles"]
        dang = graph_angle["distances"]
        central_atom = graph_angle["central_atom"]
        angle_src, angle_dst = graph_angle["angle_src"], graph_angle["angle_dst"]
        switch_angles = graph_angle["switch"][:, None]
        angular_cutoff = self._graphs_properties[self.graph_angle_key]["cutoff"]
        edge_dst_ang = graph_angle["edge_dst"]

        # NOTE: the field default is an empty dict, so the fallback to
        # `radial_basis` only happens when `radial_basis_angle=None` is passed.
        radial_basis_angle = (
            self.radial_basis_angle
            if self.radial_basis_angle is not None
            else self.radial_basis
        )
        angular_kwargs = {
            **radial_basis_angle,
            "end": angular_cutoff,
            "name": "RadialBasisAng",
        }

        # Fourier expansion cos(n * angle) for n = 0..nmax_angle; it is the
        # same for both angular variants below, so it is built only once.
        nangles = jnp.asarray(
            np.arange(self.nmax_angle + 1)[None, :], dtype=angles.dtype
        )
        fourier = jnp.cos(nangles * angles[:, None])

        if self.angle_combine_pairs:
            # Per-edge radial terms (switch included) tagged with the neighbour
            # species encoding, flattened to one feature vector per edge.
            factor2 = RadialBasis(**angular_kwargs)(dang) * switch_angles
            radial_ang = (factor2[:, :, None] * onehot[edge_dst_ang, None, :]).reshape(
                -1, onehot.shape[1] * factor2.shape[1]
            )

            # Element-wise product of the two edges forming each angle.
            radial_aev_ang = radial_ang[angle_src] * radial_ang[angle_dst]

            # Angular AEV
            angular_aev = jax.ops.segment_sum(
                fourier[:, None, :] * radial_aev_ang[:, :, None],
                central_atom,
                num_atoms,
            ).reshape(num_atoms, -1)

        else:
            valence = SpeciesEncoding(
                encoding="sjs_coordinates", name="SJSEncoding", trainable=False
            )(species)
            # One radial descriptor per angle: the basis is evaluated at the
            # mean of the two edge lengths, damped by the product of switches.
            d12 = 0.5 * (dang[angle_src] + dang[angle_dst])
            switch12 = switch_angles[angle_src] * switch_angles[angle_dst]

            factor2 = RadialBasis(**angular_kwargs)(d12)

            # Angular AEV
            factor1 = fourier * switch12

            angular_terms = (factor1[:, None, :] * factor2[:, :, None]).reshape(
                -1, factor1.shape[1] * factor2.shape[1]
            )

            # Chemical information of the two neighbours of each angle: outer
            # product of the sum and of the product of their encodings.
            valence_dst = valence[edge_dst_ang]
            vangsrc = valence_dst[angle_src]
            vangdst = valence_dst[angle_dst]
            valence_ang_p = vangsrc + vangdst
            valence_ang_m = vangsrc * vangdst
            valence_ang = (
                valence_ang_p[:, :, None] * valence_ang_m[:, None, :]
            ).reshape(-1, valence_ang_p.shape[1] * valence_ang_m.shape[1])

            angular_aev = jax.ops.segment_sum(
                angular_terms[:, :, None] * valence_ang[:, None, :],
                central_atom,
                num_atoms,
            ).reshape(num_atoms, -1)

        embedding = jnp.concatenate((onehot, radial_aev, angular_aev), axis=-1)
        if self.embedding_key is None:
            return embedding
        return {**inputs, self.embedding_key: embedding}
element-embracing Atom-Centered Symmetry Functions
FID : EEACSF
This is an embedding similar to ANI that includes simple chemical information in the AEVs without trainable parameters (fixed embedding). The angle embedding is computed using a low-order Fourier expansion.
Reference
Loosely inspired by M. Eckhoff and M. Reiher, Lifelong Machine Learning Potentials, J. Chem. Theory Comput. 2023, 19, 12, 3509–3525, https://doi.org/10.1021/acs.jctc.3c00279
The key to use for the output embedding in the returned dictionary.
The species encoding parameters. See fennol.models.misc.encodings.SpeciesEncoding
The radial basis parameters for the radial AEV. See fennol.models.misc.encodings.RadialBasis
The radial basis parameters for the angular AEV. See fennol.models.misc.encodings.RadialBasis
If True, the angular AEV is computed by combining pairs of radial AEV.
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.