fennol.models.embeddings.eeacsf

  1import jax
  2import jax.numpy as jnp
  3import flax.linen as nn
  4from typing import Sequence, Dict, ClassVar
  5import numpy as np
  6import dataclasses
  7
  8from ...utils.periodic_table import PERIODIC_TABLE, VALENCE_STRUCTURE
  9from ..misc.encodings import SpeciesEncoding, RadialBasis
 10
 11
 12class EEACSF(nn.Module):
 13    """element-embracing Atom-Centered Symmetry Functions 
 14
 15    FID : EEACSF
 16
 17    This is an embedding similar to ANI that include simple chemical information in the AEVs
 18    without trainable parameters (fixed embedding).
 19    The angle embedding is computed using a low-order Fourier expansion.
 20
 21    ### Reference
 22    Loosely inspired from M. Eckhoff and M. Reiher, Lifelong Machine Learning Potentials,
 23    J. Chem. Theory Comput. 2023, 19, 12, 3509–3525, https://doi.org/10.1021/acs.jctc.3c00279
 24
 25    """
 26
 27    _graphs_properties: Dict
 28    graph_angle_key: str
 29    """ The key in the input dictionary that corresponds to the angular graph."""
 30    nmax_angle: int = 4
 31    """ The maximum fourier order for the angle representation."""
 32    embedding_key: str = "embedding"
 33    """ The key to use for the output embedding in the returned dictionary."""
 34    graph_key: str = "graph"
 35    """ The key in the input dictionary that corresponds to the radial graph."""
 36    species_encoding: dict = dataclasses.field(default_factory=dict)
 37    """ The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`"""
 38    radial_basis: dict = dataclasses.field(default_factory=dict)
 39    """ The radial basis parameters the radial AEV. See `fennol.models.misc.encodings.RadialBasis`"""
 40    radial_basis_angle: dict = dataclasses.field(default_factory=dict)
 41    """ The radial basis parameters for the angular AEV. See `fennol.models.misc.encodings.RadialBasis`"""
 42    angle_combine_pairs: bool = False
 43    """ If True, the angular AEV is computed by combining pairs of radial AEV."""
 44
 45    FID: ClassVar[str] = "EEACSF"
 46
 47    @nn.compact
 48    def __call__(self, inputs):
 49        species = inputs["species"]
 50
 51        # species encoding
 52        
 53        onehot = SpeciesEncoding(**self.species_encoding, name="SpeciesEncoding")(
 54            species
 55        )
 56
 57        # Radial graph
 58        graph = inputs[self.graph_key]
 59        distances = graph["distances"]
 60        switch = graph["switch"][:,None]
 61        edge_src = graph["edge_src"]
 62        edge_dst = graph["edge_dst"]
 63
 64        # Radial BASIS
 65        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
 66        radial_terms = (
 67            RadialBasis(
 68                **{
 69                    **self.radial_basis,
 70                    "end": cutoff,
 71                    "name": f"RadialBasis",
 72                }
 73            )(distances)*switch
 74        )
 75        # aggregate radial AEV
 76        radial_aev = jax.ops.segment_sum(
 77            radial_terms[:, :, None] * onehot[edge_dst, None, :],
 78            edge_src,
 79            species.shape[0],
 80        ).reshape(species.shape[0], -1)
 81
 82        # Angular graph
 83        graph_angle = inputs[self.graph_angle_key]     
 84        angles = graph_angle["angles"]
 85        dang = graph_angle["distances"]
 86        central_atom = graph_angle["central_atom"]
 87        angle_src, angle_dst = graph_angle["angle_src"], graph_angle["angle_dst"]
 88        switch_angles = graph_angle["switch"][:, None]
 89        angular_cutoff = self._graphs_properties[self.graph_angle_key]["cutoff"]
 90        edge_dst_ang = graph_angle["edge_dst"]
 91
 92
 93        radial_basis_angle = (
 94            self.radial_basis_angle
 95            if self.radial_basis_angle is not None
 96            else self.radial_basis
 97        )
 98
 99        # Angular AEV parameters
100        if self.angle_combine_pairs:
101            factor2 = RadialBasis(
102                **{
103                    **radial_basis_angle,
104                    "end": angular_cutoff,
105                    "name": f"RadialBasisAng",
106                }
107            )(dang)*switch_angles
108            radial_ang = (factor2[:, :, None] * onehot[edge_dst_ang, None, :]).reshape(-1, onehot.shape[1]*factor2.shape[1])
109
110            radial_aev_ang = radial_ang[angle_src]*radial_ang[angle_dst]
111
112            # Angular AEV
113            nangles = jnp.asarray(
114                np.arange(self.nmax_angle + 1)[None, :], dtype=angles.dtype
115            )
116            factor1 = jnp.cos(nangles * angles[:, None])
117
118            angular_aev = jax.ops.segment_sum(
119                factor1[:, None, :] * radial_aev_ang[:, :, None],
120                central_atom,
121                species.shape[0],
122            ).reshape(species.shape[0], -1)
123
124        else:
125            valence = SpeciesEncoding(
126                encoding="sjs_coordinates", name="SJSEncoding", trainable=False
127            )(species)
128            d12 = 0.5 * (dang[angle_src] + dang[angle_dst])
129            switch12 = switch_angles[angle_src] * switch_angles[angle_dst]
130
131            
132            factor2 = RadialBasis(
133                **{
134                    **radial_basis_angle,
135                    "end": angular_cutoff,
136                    "name": f"RadialBasisAng",
137                }
138            )(d12)
139
140            # Angular AEV
141            nangles = jnp.asarray(
142                np.arange(self.nmax_angle + 1)[None, :], dtype=angles.dtype
143            )
144            factor1 = jnp.cos(nangles * angles[:, None]) * switch12
145
146            angular_terms = (factor1[:, None, :] * factor2[:, :, None]).reshape(
147                -1, factor1.shape[1] * factor2.shape[1]
148            )
149
150            valence_dst = valence[edge_dst_ang]
151            vangsrc = valence_dst[angle_src]
152            vangdst = valence_dst[angle_dst]
153            valence_ang_p = vangsrc + vangdst
154            valence_ang_m = vangsrc * vangdst
155            valence_ang = (valence_ang_p[:, :, None] * valence_ang_m[:, None, :]).reshape(
156                -1, valence_ang_p.shape[1] * valence_ang_m.shape[1]
157            )
158
159            angular_aev = jax.ops.segment_sum(
160                angular_terms[:, :, None] * valence_ang[:, None, :],
161                central_atom,
162                species.shape[0],
163            ).reshape(species.shape[0], -1)
164
165        
166        embedding = jnp.concatenate((onehot, radial_aev, angular_aev), axis=-1)
167        if self.embedding_key is None:
168            return embedding
169        return {**inputs, self.embedding_key: embedding}
class EEACSF(flax.linen.module.Module):
class EEACSF(nn.Module):
    """element-embracing Atom-Centered Symmetry Functions

    FID : EEACSF

    This is an embedding similar to ANI that includes simple chemical information in the AEVs
    without trainable parameters (fixed embedding).
    The angle embedding is computed using a low-order Fourier expansion.

    The returned embedding concatenates, along the last axis, the species
    encoding, the radial AEV and the angular AEV.

    ### Reference
    Loosely inspired from M. Eckhoff and M. Reiher, Lifelong Machine Learning Potentials,
    J. Chem. Theory Comput. 2023, 19, 12, 3509–3525, https://doi.org/10.1021/acs.jctc.3c00279

    """

    # Per-graph properties; only the "cutoff" entry of the radial and angular
    # graphs is read here.
    _graphs_properties: Dict
    graph_angle_key: str
    """ The key in the input dictionary that corresponds to the angular graph."""
    nmax_angle: int = 4
    """ The maximum fourier order for the angle representation."""
    embedding_key: str = "embedding"
    """ The key to use for the output embedding in the returned dictionary."""
    graph_key: str = "graph"
    """ The key in the input dictionary that corresponds to the radial graph."""
    species_encoding: dict = dataclasses.field(default_factory=dict)
    """ The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`"""
    radial_basis: dict = dataclasses.field(default_factory=dict)
    """ The radial basis parameters the radial AEV. See `fennol.models.misc.encodings.RadialBasis`"""
    radial_basis_angle: dict = dataclasses.field(default_factory=dict)
    """ The radial basis parameters for the angular AEV. See `fennol.models.misc.encodings.RadialBasis`"""
    angle_combine_pairs: bool = False
    """ If True, the angular AEV is computed by combining pairs of radial AEV."""

    FID: ClassVar[str] = "EEACSF"

    @nn.compact
    def __call__(self, inputs):
        """Build the embedding from the radial and angular graphs in `inputs`.

        Returns the embedding array if `self.embedding_key` is None, otherwise
        a copy of `inputs` with the embedding stored under that key.
        """
        species = inputs["species"]

        # species encoding
        # (named `onehot`, but the actual encoding is whatever
        # `self.species_encoding` configures)
        onehot = SpeciesEncoding(**self.species_encoding, name="SpeciesEncoding")(
            species
        )

        # Radial graph
        graph = inputs[self.graph_key]
        distances = graph["distances"]
        # trailing axis added so the switch broadcasts over the basis functions
        switch = graph["switch"][:,None]
        edge_src = graph["edge_src"]
        edge_dst = graph["edge_dst"]

        # Radial BASIS
        # "end" and "name" are placed last so they override any user-supplied
        # values: the basis always ends at the graph cutoff.
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        radial_terms = (
            RadialBasis(
                **{
                    **self.radial_basis,
                    "end": cutoff,
                    "name": f"RadialBasis",
                }
            )(distances)*switch
        )
        # aggregate radial AEV
        # outer product (radial basis x species encoding of the destination
        # atom), summed over the edges of each source atom, then flattened
        radial_aev = jax.ops.segment_sum(
            radial_terms[:, :, None] * onehot[edge_dst, None, :],
            edge_src,
            species.shape[0],
        ).reshape(species.shape[0], -1)

        # Angular graph
        graph_angle = inputs[self.graph_angle_key]
        angles = graph_angle["angles"]
        dang = graph_angle["distances"]
        central_atom = graph_angle["central_atom"]
        # angle_src / angle_dst index into the angular-graph edge arrays
        # (they are used below to index `dang`, `switch_angles` and per-edge terms)
        angle_src, angle_dst = graph_angle["angle_src"], graph_angle["angle_dst"]
        switch_angles = graph_angle["switch"][:, None]
        angular_cutoff = self._graphs_properties[self.graph_angle_key]["cutoff"]
        edge_dst_ang = graph_angle["edge_dst"]


        # NOTE(review): the field default is an empty dict, not None, so this
        # fallback to `radial_basis` only triggers when None is passed
        # explicitly — confirm this is the intended behaviour.
        radial_basis_angle = (
            self.radial_basis_angle
            if self.radial_basis_angle is not None
            else self.radial_basis
        )

        # Angular AEV parameters
        if self.angle_combine_pairs:
            # per-edge radial terms (basis x species encoding), flattened
            factor2 = RadialBasis(
                **{
                    **radial_basis_angle,
                    "end": angular_cutoff,
                    "name": f"RadialBasisAng",
                }
            )(dang)*switch_angles
            radial_ang = (factor2[:, :, None] * onehot[edge_dst_ang, None, :]).reshape(-1, onehot.shape[1]*factor2.shape[1])

            # element-wise product of the terms of the two edges of each angle
            radial_aev_ang = radial_ang[angle_src]*radial_ang[angle_dst]

            # Angular AEV
            # cos(n * angle) for n = 0..nmax_angle
            nangles = jnp.asarray(
                np.arange(self.nmax_angle + 1)[None, :], dtype=angles.dtype
            )
            factor1 = jnp.cos(nangles * angles[:, None])

            # sum the per-angle outer products onto the central atom
            angular_aev = jax.ops.segment_sum(
                factor1[:, None, :] * radial_aev_ang[:, :, None],
                central_atom,
                species.shape[0],
            ).reshape(species.shape[0], -1)

        else:
            # second species encoding, used for the chemical part of the angles
            valence = SpeciesEncoding(
                encoding="sjs_coordinates", name="SJSEncoding", trainable=False
            )(species)
            # mean length of the two edges and product of their switches
            d12 = 0.5 * (dang[angle_src] + dang[angle_dst])
            switch12 = switch_angles[angle_src] * switch_angles[angle_dst]


            # radial basis evaluated on the mean distance (the switch is
            # applied to the Fourier factor below instead)
            factor2 = RadialBasis(
                **{
                    **radial_basis_angle,
                    "end": angular_cutoff,
                    "name": f"RadialBasisAng",
                }
            )(d12)

            # Angular AEV
            # cos(n * angle) for n = 0..nmax_angle, damped by the joint switch
            nangles = jnp.asarray(
                np.arange(self.nmax_angle + 1)[None, :], dtype=angles.dtype
            )
            factor1 = jnp.cos(nangles * angles[:, None]) * switch12

            # flattened outer product (radial basis x Fourier terms) per angle
            angular_terms = (factor1[:, None, :] * factor2[:, :, None]).reshape(
                -1, factor1.shape[1] * factor2.shape[1]
            )

            # chemical descriptor of the neighbour pair: outer product of the
            # sum and the element-wise product of the two neighbours' encodings
            valence_dst = valence[edge_dst_ang]
            vangsrc = valence_dst[angle_src]
            vangdst = valence_dst[angle_dst]
            valence_ang_p = vangsrc + vangdst
            valence_ang_m = vangsrc * vangdst
            valence_ang = (valence_ang_p[:, :, None] * valence_ang_m[:, None, :]).reshape(
                -1, valence_ang_p.shape[1] * valence_ang_m.shape[1]
            )

            # sum the per-angle (geometry x chemistry) terms onto the central atom
            angular_aev = jax.ops.segment_sum(
                angular_terms[:, :, None] * valence_ang[:, None, :],
                central_atom,
                species.shape[0],
            ).reshape(species.shape[0], -1)


        embedding = jnp.concatenate((onehot, radial_aev, angular_aev), axis=-1)
        # embedding_key=None returns the raw array instead of an updated dict
        if self.embedding_key is None:
            return embedding
        return {**inputs, self.embedding_key: embedding}

element-embracing Atom-Centered Symmetry Functions

FID : EEACSF

This is an embedding similar to ANI that includes simple chemical information in the AEVs without trainable parameters (fixed embedding). The angle embedding is computed using a low-order Fourier expansion.

Reference

Loosely inspired from M. Eckhoff and M. Reiher, Lifelong Machine Learning Potentials, J. Chem. Theory Comput. 2023, 19, 12, 3509–3525, https://doi.org/10.1021/acs.jctc.3c00279

EEACSF( _graphs_properties: Dict, graph_angle_key: str, nmax_angle: int = 4, embedding_key: str = 'embedding', graph_key: str = 'graph', species_encoding: dict = <factory>, radial_basis: dict = <factory>, radial_basis_angle: dict = <factory>, angle_combine_pairs: bool = False, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
graph_angle_key: str

The key in the input dictionary that corresponds to the angular graph.

nmax_angle: int = 4

The maximum Fourier order for the angle representation.

embedding_key: str = 'embedding'

The key to use for the output embedding in the returned dictionary.

graph_key: str = 'graph'

The key in the input dictionary that corresponds to the radial graph.

species_encoding: dict

The species encoding parameters. See fennol.models.misc.encodings.SpeciesEncoding

radial_basis: dict

The radial basis parameters for the radial AEV. See fennol.models.misc.encodings.RadialBasis

radial_basis_angle: dict

The radial basis parameters for the angular AEV. See fennol.models.misc.encodings.RadialBasis

angle_combine_pairs: bool = False

If True, the angular AEV is computed by combining pairs of radial AEV.

FID: ClassVar[str] = 'EEACSF'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None