fennol.models.embeddings.ani

  1import jax
  2import jax.numpy as jnp
  3import flax.linen as nn
  4from typing import Sequence, Dict, Union, ClassVar
  5import numpy as np
  6from ...utils.periodic_table import PERIODIC_TABLE
  7
  8
  9class ANIAEV(nn.Module):
 10    """Computes the Atomic Environment Vector (AEV) for a given molecular system using the ANI model.
 11
 12    FID : ANI_AEV
 13
 14    ### Reference
 15    J. S. Smith, O. Isayev and A. E. Roitberg, ANI-1: an extensible neural network potential with DFT accuracy at force field computational cost, Chem. Sci., 2017, 8, 3192
 16    """
 17
 18    _graphs_properties: Dict
 19    species_order: Union[str, Sequence[str]]
 20    """ The chemical species which are considered by the model."""
 21    graph_angle_key: str
 22    """ The key in the input dictionary that corresponds to the angular graph."""
 23    radial_eta: float = 16.0
 24    """ Controls the width of the gaussian sensitivity functions in radial AEV."""
 25    angular_eta: float = 8.0
 26    """ Controls the width of the gaussian sensitivity functions in angular AEV."""
 27    radial_dist_divisions: int = 16
 28    """ Number of basis function to encode distance in radial AEV."""
 29    angular_dist_divisions: int = 4
 30    """ Number of basis function to encode distance in angular AEV."""
 31    zeta: float = 32.0
 32    """ The power parameter in angle embedding."""
 33    angle_sections: int = 4
 34    """ The number of angle sections."""
 35    radial_start: float = 0.8
 36    """ The starting distance in radial AEV."""
 37    angular_start: float = 0.8
 38    """ The starting distance in angular AEV."""
 39    embedding_key: str = "embedding"
 40    """ The key to use for the output embedding in the returned dictionary."""
 41    graph_key: str = "graph"
 42    """ The key in the input dictionary that corresponds to the radial graph."""
 43
 44    FID: ClassVar[str] = "ANI_AEV"
 45
 46    @nn.compact
 47    def __call__(self, inputs):
 48        species = inputs["species"]
 49        rev_idx = {s: k for k, s in enumerate(PERIODIC_TABLE)}
 50        maxidx = max(rev_idx.values())
 51
 52        # convert species to internal indices
 53        conv_tensor = [0] * (maxidx + 2)
 54        if isinstance(self.species_order, str):
 55            species_order = [el.strip() for el in self.species_order.split(",")]
 56        else:
 57            species_order = [el for el in self.species_order]
 58        for i, s in enumerate(species_order):
 59            conv_tensor[rev_idx[s]] = i
 60        indices = jnp.asarray(conv_tensor, dtype=jnp.int32)[species]
 61        num_species = len(species_order)
 62        num_species_pair = (num_species * (num_species + 1)) // 2
 63
 64        # Radial graph
 65        graph = inputs[self.graph_key]
 66        distances = graph["distances"]
 67        switch = graph["switch"]
 68        edge_src = graph["edge_src"]
 69        edge_dst = graph["edge_dst"]
 70
 71        # Radial AEV
 72        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
 73        shiftR = jnp.asarray(
 74            - np.linspace(self.radial_start, cutoff, self.radial_dist_divisions + 1)[
 75                None, :-1
 76            ],
 77            dtype=distances.dtype,
 78        )
 79        x2 = self.radial_eta * (distances[:, None] + shiftR) ** 2
 80        radial_terms = jnp.exp(-x2) * (0.25*switch)[:, None]
 81        # aggregate radial AEV
 82        radial_index = edge_src * num_species + indices[edge_dst]
 83
 84        radial_aev = jax.ops.segment_sum(
 85            radial_terms, radial_index, num_species * species.shape[0]
 86        ).reshape(species.shape[0], num_species * radial_terms.shape[-1])
 87
 88        # Angular graph
 89        # eta2 = self.angular_eta ** 0.5
 90        graph = inputs[self.graph_angle_key]
 91        angles = graph["angles"]
 92        distances = (0.5*self.angular_eta**0.5)*graph["distances"]
 93        central_atom = graph["central_atom"]
 94        angle_src, angle_dst = graph["angle_src"], graph["angle_dst"]
 95        d12 = (distances[angle_src] + distances[angle_dst])[:, None]
 96
 97        # Angular AEV parameters
 98        angular_cutoff = self._graphs_properties[self.graph_angle_key]["cutoff"]
 99        angle_start = np.pi / (2 * self.angle_sections)
100        shiftZ = jnp.asarray(
101            -(np.linspace(0, np.pi, self.angle_sections + 1) + angle_start)[None, :-1],
102            dtype=distances.dtype,
103        )
104        shiftA = jnp.asarray(
105            (-self.angular_eta**0.5)*np.linspace(
106                self.angular_start, angular_cutoff, self.angular_dist_divisions + 1
107            )[None, :-1],
108            dtype=distances.dtype,
109        )
110
111        # Angular AEV
112        switch = graph["switch"] *(2.0 *(0.5**self.zeta))**0.5
113        factor1 = switch[angle_src] * switch[angle_dst]
114        factor1 = (
115            factor1[:, None] * (1 + jnp.cos(angles[:, None] + shiftZ)) ** self.zeta
116        )
117        factor2 = jnp.exp(-(d12 + shiftA) ** 2)
118        angular_terms = (factor1[:, None, :] * factor2[:, :, None]).reshape(
119            -1, self.angle_sections * self.angular_dist_divisions
120        )
121
122        # aggregate angular AEV
123        index_dest = indices[graph["edge_dst"]]
124        species1, species2 = np.triu_indices(num_species, 0)
125        pair_index = np.arange(species1.shape[0], dtype=np.int32)
126        triu_index = np.zeros((num_species, num_species), dtype=np.int32)
127        triu_index[species1, species2] = pair_index
128        triu_index[species2, species1] = pair_index
129        triu_index = jnp.asarray(triu_index, dtype=jnp.int32)
130        angular_index = (
131            central_atom * num_species_pair
132            + triu_index[index_dest[angle_src], index_dest[angle_dst]]
133        )
134
135        angular_aev = jax.ops.segment_sum(
136            angular_terms, angular_index, num_species_pair * species.shape[0]
137        ).reshape(species.shape[0], num_species_pair * angular_terms.shape[-1])
138
139        embedding = jnp.concatenate((radial_aev, angular_aev), axis=-1)
140        if self.embedding_key is None:
141            return embedding
142        return {**inputs, self.embedding_key: embedding}
class ANIAEV(flax.linen.module.Module):
 10class ANIAEV(nn.Module):
 11    """Computes the Atomic Environment Vector (AEV) for a given molecular system using the ANI model.
 12
 13    FID : ANI_AEV
 14
 15    ### Reference
 16    J. S. Smith, O. Isayev and A. E. Roitberg, ANI-1: an extensible neural network potential with DFT accuracy at force field computational cost, Chem. Sci., 2017, 8, 3192
 17    """
 18
 19    _graphs_properties: Dict
 20    species_order: Union[str, Sequence[str]]
 21    """ The chemical species which are considered by the model."""
 22    graph_angle_key: str
 23    """ The key in the input dictionary that corresponds to the angular graph."""
 24    radial_eta: float = 16.0
 25    """ Controls the width of the gaussian sensitivity functions in radial AEV."""
 26    angular_eta: float = 8.0
 27    """ Controls the width of the gaussian sensitivity functions in angular AEV."""
 28    radial_dist_divisions: int = 16
 29    """ Number of basis function to encode distance in radial AEV."""
 30    angular_dist_divisions: int = 4
 31    """ Number of basis function to encode distance in angular AEV."""
 32    zeta: float = 32.0
 33    """ The power parameter in angle embedding."""
 34    angle_sections: int = 4
 35    """ The number of angle sections."""
 36    radial_start: float = 0.8
 37    """ The starting distance in radial AEV."""
 38    angular_start: float = 0.8
 39    """ The starting distance in angular AEV."""
 40    embedding_key: str = "embedding"
 41    """ The key to use for the output embedding in the returned dictionary."""
 42    graph_key: str = "graph"
 43    """ The key in the input dictionary that corresponds to the radial graph."""
 44
 45    FID: ClassVar[str] = "ANI_AEV"
 46
 47    @nn.compact
 48    def __call__(self, inputs):
 49        species = inputs["species"]
 50        rev_idx = {s: k for k, s in enumerate(PERIODIC_TABLE)}
 51        maxidx = max(rev_idx.values())
 52
 53        # convert species to internal indices
 54        conv_tensor = [0] * (maxidx + 2)
 55        if isinstance(self.species_order, str):
 56            species_order = [el.strip() for el in self.species_order.split(",")]
 57        else:
 58            species_order = [el for el in self.species_order]
 59        for i, s in enumerate(species_order):
 60            conv_tensor[rev_idx[s]] = i
 61        indices = jnp.asarray(conv_tensor, dtype=jnp.int32)[species]
 62        num_species = len(species_order)
 63        num_species_pair = (num_species * (num_species + 1)) // 2
 64
 65        # Radial graph
 66        graph = inputs[self.graph_key]
 67        distances = graph["distances"]
 68        switch = graph["switch"]
 69        edge_src = graph["edge_src"]
 70        edge_dst = graph["edge_dst"]
 71
 72        # Radial AEV
 73        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
 74        shiftR = jnp.asarray(
 75            - np.linspace(self.radial_start, cutoff, self.radial_dist_divisions + 1)[
 76                None, :-1
 77            ],
 78            dtype=distances.dtype,
 79        )
 80        x2 = self.radial_eta * (distances[:, None] + shiftR) ** 2
 81        radial_terms = jnp.exp(-x2) * (0.25*switch)[:, None]
 82        # aggregate radial AEV
 83        radial_index = edge_src * num_species + indices[edge_dst]
 84
 85        radial_aev = jax.ops.segment_sum(
 86            radial_terms, radial_index, num_species * species.shape[0]
 87        ).reshape(species.shape[0], num_species * radial_terms.shape[-1])
 88
 89        # Angular graph
 90        # eta2 = self.angular_eta ** 0.5
 91        graph = inputs[self.graph_angle_key]
 92        angles = graph["angles"]
 93        distances = (0.5*self.angular_eta**0.5)*graph["distances"]
 94        central_atom = graph["central_atom"]
 95        angle_src, angle_dst = graph["angle_src"], graph["angle_dst"]
 96        d12 = (distances[angle_src] + distances[angle_dst])[:, None]
 97
 98        # Angular AEV parameters
 99        angular_cutoff = self._graphs_properties[self.graph_angle_key]["cutoff"]
100        angle_start = np.pi / (2 * self.angle_sections)
101        shiftZ = jnp.asarray(
102            -(np.linspace(0, np.pi, self.angle_sections + 1) + angle_start)[None, :-1],
103            dtype=distances.dtype,
104        )
105        shiftA = jnp.asarray(
106            (-self.angular_eta**0.5)*np.linspace(
107                self.angular_start, angular_cutoff, self.angular_dist_divisions + 1
108            )[None, :-1],
109            dtype=distances.dtype,
110        )
111
112        # Angular AEV
113        switch = graph["switch"] *(2.0 *(0.5**self.zeta))**0.5
114        factor1 = switch[angle_src] * switch[angle_dst]
115        factor1 = (
116            factor1[:, None] * (1 + jnp.cos(angles[:, None] + shiftZ)) ** self.zeta
117        )
118        factor2 = jnp.exp(-(d12 + shiftA) ** 2)
119        angular_terms = (factor1[:, None, :] * factor2[:, :, None]).reshape(
120            -1, self.angle_sections * self.angular_dist_divisions
121        )
122
123        # aggregate angular AEV
124        index_dest = indices[graph["edge_dst"]]
125        species1, species2 = np.triu_indices(num_species, 0)
126        pair_index = np.arange(species1.shape[0], dtype=np.int32)
127        triu_index = np.zeros((num_species, num_species), dtype=np.int32)
128        triu_index[species1, species2] = pair_index
129        triu_index[species2, species1] = pair_index
130        triu_index = jnp.asarray(triu_index, dtype=jnp.int32)
131        angular_index = (
132            central_atom * num_species_pair
133            + triu_index[index_dest[angle_src], index_dest[angle_dst]]
134        )
135
136        angular_aev = jax.ops.segment_sum(
137            angular_terms, angular_index, num_species_pair * species.shape[0]
138        ).reshape(species.shape[0], num_species_pair * angular_terms.shape[-1])
139
140        embedding = jnp.concatenate((radial_aev, angular_aev), axis=-1)
141        if self.embedding_key is None:
142            return embedding
143        return {**inputs, self.embedding_key: embedding}

Computes the Atomic Environment Vector (AEV) for a given molecular system using the ANI model.

FID : ANI_AEV

Reference

J. S. Smith, O. Isayev and A. E. Roitberg, ANI-1: an extensible neural network potential with DFT accuracy at force field computational cost, Chem. Sci., 2017, 8, 3192

ANIAEV( _graphs_properties: Dict, species_order: Union[str, Sequence[str]], graph_angle_key: str, radial_eta: float = 16.0, angular_eta: float = 8.0, radial_dist_divisions: int = 16, angular_dist_divisions: int = 4, zeta: float = 32.0, angle_sections: int = 4, radial_start: float = 0.8, angular_start: float = 0.8, embedding_key: str = 'embedding', graph_key: str = 'graph', parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
species_order: Union[str, Sequence[str]]

The chemical species which are considered by the model.

graph_angle_key: str

The key in the input dictionary that corresponds to the angular graph.

radial_eta: float = 16.0

Controls the width of the Gaussian sensitivity functions in the radial AEV.

angular_eta: float = 8.0

Controls the width of the Gaussian sensitivity functions in the angular AEV.

radial_dist_divisions: int = 16

Number of basis functions used to encode distances in the radial AEV.

angular_dist_divisions: int = 4

Number of basis functions used to encode distances in the angular AEV.

zeta: float = 32.0

The exponent applied to the angular terms of the angular AEV.

angle_sections: int = 4

The number of angle sections.

radial_start: float = 0.8

The starting distance in radial AEV.

angular_start: float = 0.8

The starting distance in angular AEV.

embedding_key: str = 'embedding'

The key to use for the output embedding in the returned dictionary.

graph_key: str = 'graph'

The key in the input dictionary that corresponds to the radial graph.

FID: ClassVar[str] = 'ANI_AEV'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None