fennol.models.embeddings.aimnet

  1import jax
  2import jax.numpy as jnp
  3import flax.linen as nn
  4from typing import Sequence, Dict, Callable, ClassVar,Union
  5import numpy as np
  6from ...utils.periodic_table import PERIODIC_TABLE
  7from ..misc.encodings import SpeciesEncoding
  8from ..misc.nets import FullyConnectedNet
  9import dataclasses
 10
 11
 12class AIMNet(nn.Module):
 13    """Atom-In-Molecule Network message-passing embedding
 14
 15    FID : AIMNET
 16
 17    ### Reference
 18    Roman Zubatyuk et al. ,Accurate and transferable multitask prediction of chemical properties with an atoms-in-molecules neural network.Sci. Adv.5,eaav6490(2019).DOI:10.1126/sciadv.aav6490
 19    """
 20
 21    _graphs_properties: Dict
 22    graph_angle_key: str
 23    """ The key in the input dictionary that corresponds to the angular graph. """
 24    nlayers: int = 3
 25    """ The number of message-passing layers."""
 26    zmax: int = 86
 27    """ The maximum atomic number to allocate AFV."""
 28    radial_eta: float = 16.0
 29    """ Controls the width of the gaussian sensity functions in radial AEV."""
 30    angular_eta: float = 8.0
 31    """ Controls the width of the gaussian sensity functions in angular AEV."""
 32    radial_dist_divisions: int = 16
 33    """ Number of basis function to encode ditance in radial AEV."""
 34    angular_dist_divisions: int = 4
 35    """ Number of basis function to encode ditance in angular AEV."""
 36    zeta: float = 32.0
 37    """ The power parameter in angle embedding."""
 38    angle_sections: int = 4
 39    """ The number of angle sections."""
 40    radial_start: float = 0.8
 41    """ The starting distance in radial AEV."""
 42    angular_start: float = 0.8
 43    """ The starting distance in angular AEV."""
 44    embedding_key: str = "embedding"
 45    """ The key to use for the output embedding in the returned dictionary."""
 46    graph_key: str = "graph"
 47    """ The key in the input dictionary that corresponds to the radial graph."""
 48    keep_all_layers: bool = False
 49    """ If True, the output will contain the embeddings from all layers."""
 50
 51    activation: Union[Callable, str] = "swish"
 52    """ The activation function to use."""
 53    combination_neurons: Sequence[int] = dataclasses.field(
 54        default_factory=lambda: [256, 128, 16]
 55    )
 56    """ The number of neurons in the AFV combination network."""
 57
 58    embedding_neurons: Sequence[int] = dataclasses.field(
 59        default_factory=lambda: [512, 256, 256]
 60    )
 61    """ The number of neurons in the embedding network."""
 62    interaction_neurons: Sequence[int] = dataclasses.field(
 63        default_factory=lambda: [256, 256, 128]
 64    )
 65    """ The number of neurons in the interaction network."""
 66    afv_neurons: Sequence[int] = dataclasses.field(
 67        default_factory=lambda: [256, 256, 16]
 68    )
 69    """ The number of neurons in the AFV update network. The last number of neurons defines the size of AFV."""
 70
 71    FID: ClassVar[str] = "AIMNET"
 72
 73    @nn.compact
 74    def __call__(self, inputs):
 75        """Forward pass of the AIMNet model."""
 76        species = inputs["species"]
 77
 78        # species encoding (AFV)
 79        afv_dim = self.afv_neurons[-1]
 80        afv = SpeciesEncoding(dim=afv_dim, zmax=self.zmax, encoding="random")(species)
 81
 82        # Radial graph
 83        graph = inputs[self.graph_key]
 84        distances = graph["distances"]
 85        switch = graph["switch"]
 86        edge_src = graph["edge_src"]
 87        edge_dst = graph["edge_dst"]
 88
 89        # Radial AEV
 90        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
 91        shiftR = jnp.asarray(
 92            np.linspace(self.radial_start, cutoff, self.radial_dist_divisions + 1)[
 93                None, :-1
 94            ],
 95            dtype=distances.dtype,
 96        )
 97        x2 = self.radial_eta * (distances[:, None] - shiftR) ** 2
 98        radial_terms = 0.25 * jnp.exp(-x2) * switch[:, None]
 99
100        # Angular graph
101        graph = inputs[self.graph_angle_key]
102        edge_dst_a = graph["edge_dst"]
103        angles = graph["angles"]
104        distances = graph["distances"]
105        central_atom = graph["central_atom"]
106        angle_src, angle_dst = graph["angle_src"], graph["angle_dst"]
107        angle_atom_1 = edge_dst_a[angle_src]
108        angle_atom_2 = edge_dst_a[angle_dst]
109        switch = graph["switch"]
110        d12 = 0.5 * (distances[angle_src] + distances[angle_dst])[:, None]
111
112        # Angular AEV parameters
113        angular_cutoff = self._graphs_properties[self.graph_angle_key]["cutoff"]
114        angle_start = np.pi / (2 * self.angle_sections)
115        shiftZ = jnp.asarray(
116            (np.linspace(0, np.pi, self.angle_sections + 1) + angle_start)[None, :-1],
117            dtype=distances.dtype,
118        )
119        shiftA = jnp.asarray(
120            np.linspace(
121                self.angular_start, angular_cutoff, self.angular_dist_divisions + 1
122            )[None, :-1],
123            dtype=distances.dtype,
124        )
125
126        # Angular AEV
127        factor1 = (0.5 + 0.5 * jnp.cos(angles[:, None] - shiftZ)) ** self.zeta
128        factor2 = jnp.exp(-self.angular_eta * (d12 - shiftA) ** 2)
129        angular_terms = (
130            (factor1[:, None, :] * factor2[:, :, None]).reshape(
131                -1, self.angle_sections * self.angular_dist_divisions
132            )
133            * 2
134            * (switch[angle_src] * switch[angle_dst])[:, None]
135        )
136
137        if self.keep_all_layers:
138            mis = []
139        for layer in range(self.nlayers):
140            # combine pair info
141            Gij = (radial_terms[:, None, :] * afv[edge_dst, :, None]).reshape(
142                radial_terms.shape[0], -1
143            )
144            Gri = jax.ops.segment_sum(Gij, edge_src, species.shape[0])
145
146            # combine triplet info
147            afv1 = afv[angle_atom_1]
148            afv2 = afv[angle_atom_2]
149            afv12 = jnp.concatenate((afv1 * afv2, afv1 + afv2), axis=-1)
150            afv_ang = FullyConnectedNet(
151                self.combination_neurons,
152                activation=self.activation,
153                name=f"combination_net_{layer}",
154            )(afv12)
155            Gijk = (angular_terms[:, None, :] * afv_ang[:, :, None]).reshape(
156                angular_terms.shape[0], -1
157            )
158            Gai = jax.ops.segment_sum(Gijk, central_atom, species.shape[0])
159
160            # environment field
161            fi = FullyConnectedNet(
162                self.embedding_neurons,
163                activation=self.activation,
164                name=f"embedding_net_{layer}",
165            )(jnp.concatenate((Gri, Gai), axis=-1))
166
167            # update AFV
168            dafv = FullyConnectedNet(
169                self.afv_neurons,
170                activation=self.activation,
171                name=f"afv_update_net_{layer}",
172            )(fi)
173            afv = afv + dafv
174
175            if self.keep_all_layers or layer == self.nlayers - 1:
176                # embedding
177                mi = FullyConnectedNet(
178                    self.interaction_neurons,
179                    activation=self.activation,
180                    name=f"interaction_net_{layer}",
181                )(fi)
182                if self.keep_all_layers:
183                    mis.append(mi)
184
185        embedding_key = (
186            self.embedding_key if self.embedding_key is not None else self.name
187        )
188
189        output = {**inputs, embedding_key: mi, embedding_key + "_afv": afv}
190        if self.keep_all_layers:
191            output[embedding_key + "_layers"] = jnp.stack(mis, axis=1)
192
193        return output
class AIMNet(flax.linen.module.Module):
 13class AIMNet(nn.Module):
 14    """Atom-In-Molecule Network message-passing embedding
 15
 16    FID : AIMNET
 17
 18    ### Reference
 19    Roman Zubatyuk et al. ,Accurate and transferable multitask prediction of chemical properties with an atoms-in-molecules neural network.Sci. Adv.5,eaav6490(2019).DOI:10.1126/sciadv.aav6490
 20    """
 21
 22    _graphs_properties: Dict
 23    graph_angle_key: str
 24    """ The key in the input dictionary that corresponds to the angular graph. """
 25    nlayers: int = 3
 26    """ The number of message-passing layers."""
 27    zmax: int = 86
 28    """ The maximum atomic number to allocate AFV."""
 29    radial_eta: float = 16.0
 30    """ Controls the width of the gaussian sensity functions in radial AEV."""
 31    angular_eta: float = 8.0
 32    """ Controls the width of the gaussian sensity functions in angular AEV."""
 33    radial_dist_divisions: int = 16
 34    """ Number of basis function to encode ditance in radial AEV."""
 35    angular_dist_divisions: int = 4
 36    """ Number of basis function to encode ditance in angular AEV."""
 37    zeta: float = 32.0
 38    """ The power parameter in angle embedding."""
 39    angle_sections: int = 4
 40    """ The number of angle sections."""
 41    radial_start: float = 0.8
 42    """ The starting distance in radial AEV."""
 43    angular_start: float = 0.8
 44    """ The starting distance in angular AEV."""
 45    embedding_key: str = "embedding"
 46    """ The key to use for the output embedding in the returned dictionary."""
 47    graph_key: str = "graph"
 48    """ The key in the input dictionary that corresponds to the radial graph."""
 49    keep_all_layers: bool = False
 50    """ If True, the output will contain the embeddings from all layers."""
 51
 52    activation: Union[Callable, str] = "swish"
 53    """ The activation function to use."""
 54    combination_neurons: Sequence[int] = dataclasses.field(
 55        default_factory=lambda: [256, 128, 16]
 56    )
 57    """ The number of neurons in the AFV combination network."""
 58
 59    embedding_neurons: Sequence[int] = dataclasses.field(
 60        default_factory=lambda: [512, 256, 256]
 61    )
 62    """ The number of neurons in the embedding network."""
 63    interaction_neurons: Sequence[int] = dataclasses.field(
 64        default_factory=lambda: [256, 256, 128]
 65    )
 66    """ The number of neurons in the interaction network."""
 67    afv_neurons: Sequence[int] = dataclasses.field(
 68        default_factory=lambda: [256, 256, 16]
 69    )
 70    """ The number of neurons in the AFV update network. The last number of neurons defines the size of AFV."""
 71
 72    FID: ClassVar[str] = "AIMNET"
 73
 74    @nn.compact
 75    def __call__(self, inputs):
 76        """Forward pass of the AIMNet model."""
 77        species = inputs["species"]
 78
 79        # species encoding (AFV)
 80        afv_dim = self.afv_neurons[-1]
 81        afv = SpeciesEncoding(dim=afv_dim, zmax=self.zmax, encoding="random")(species)
 82
 83        # Radial graph
 84        graph = inputs[self.graph_key]
 85        distances = graph["distances"]
 86        switch = graph["switch"]
 87        edge_src = graph["edge_src"]
 88        edge_dst = graph["edge_dst"]
 89
 90        # Radial AEV
 91        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
 92        shiftR = jnp.asarray(
 93            np.linspace(self.radial_start, cutoff, self.radial_dist_divisions + 1)[
 94                None, :-1
 95            ],
 96            dtype=distances.dtype,
 97        )
 98        x2 = self.radial_eta * (distances[:, None] - shiftR) ** 2
 99        radial_terms = 0.25 * jnp.exp(-x2) * switch[:, None]
100
101        # Angular graph
102        graph = inputs[self.graph_angle_key]
103        edge_dst_a = graph["edge_dst"]
104        angles = graph["angles"]
105        distances = graph["distances"]
106        central_atom = graph["central_atom"]
107        angle_src, angle_dst = graph["angle_src"], graph["angle_dst"]
108        angle_atom_1 = edge_dst_a[angle_src]
109        angle_atom_2 = edge_dst_a[angle_dst]
110        switch = graph["switch"]
111        d12 = 0.5 * (distances[angle_src] + distances[angle_dst])[:, None]
112
113        # Angular AEV parameters
114        angular_cutoff = self._graphs_properties[self.graph_angle_key]["cutoff"]
115        angle_start = np.pi / (2 * self.angle_sections)
116        shiftZ = jnp.asarray(
117            (np.linspace(0, np.pi, self.angle_sections + 1) + angle_start)[None, :-1],
118            dtype=distances.dtype,
119        )
120        shiftA = jnp.asarray(
121            np.linspace(
122                self.angular_start, angular_cutoff, self.angular_dist_divisions + 1
123            )[None, :-1],
124            dtype=distances.dtype,
125        )
126
127        # Angular AEV
128        factor1 = (0.5 + 0.5 * jnp.cos(angles[:, None] - shiftZ)) ** self.zeta
129        factor2 = jnp.exp(-self.angular_eta * (d12 - shiftA) ** 2)
130        angular_terms = (
131            (factor1[:, None, :] * factor2[:, :, None]).reshape(
132                -1, self.angle_sections * self.angular_dist_divisions
133            )
134            * 2
135            * (switch[angle_src] * switch[angle_dst])[:, None]
136        )
137
138        if self.keep_all_layers:
139            mis = []
140        for layer in range(self.nlayers):
141            # combine pair info
142            Gij = (radial_terms[:, None, :] * afv[edge_dst, :, None]).reshape(
143                radial_terms.shape[0], -1
144            )
145            Gri = jax.ops.segment_sum(Gij, edge_src, species.shape[0])
146
147            # combine triplet info
148            afv1 = afv[angle_atom_1]
149            afv2 = afv[angle_atom_2]
150            afv12 = jnp.concatenate((afv1 * afv2, afv1 + afv2), axis=-1)
151            afv_ang = FullyConnectedNet(
152                self.combination_neurons,
153                activation=self.activation,
154                name=f"combination_net_{layer}",
155            )(afv12)
156            Gijk = (angular_terms[:, None, :] * afv_ang[:, :, None]).reshape(
157                angular_terms.shape[0], -1
158            )
159            Gai = jax.ops.segment_sum(Gijk, central_atom, species.shape[0])
160
161            # environment field
162            fi = FullyConnectedNet(
163                self.embedding_neurons,
164                activation=self.activation,
165                name=f"embedding_net_{layer}",
166            )(jnp.concatenate((Gri, Gai), axis=-1))
167
168            # update AFV
169            dafv = FullyConnectedNet(
170                self.afv_neurons,
171                activation=self.activation,
172                name=f"afv_update_net_{layer}",
173            )(fi)
174            afv = afv + dafv
175
176            if self.keep_all_layers or layer == self.nlayers - 1:
177                # embedding
178                mi = FullyConnectedNet(
179                    self.interaction_neurons,
180                    activation=self.activation,
181                    name=f"interaction_net_{layer}",
182                )(fi)
183                if self.keep_all_layers:
184                    mis.append(mi)
185
186        embedding_key = (
187            self.embedding_key if self.embedding_key is not None else self.name
188        )
189
190        output = {**inputs, embedding_key: mi, embedding_key + "_afv": afv}
191        if self.keep_all_layers:
192            output[embedding_key + "_layers"] = jnp.stack(mis, axis=1)
193
194        return output

Atoms-in-Molecules Network (AIMNet) message-passing embedding

FID : AIMNET

Reference

Roman Zubatyuk et al., Accurate and transferable multitask prediction of chemical properties with an atoms-in-molecules neural network. Sci. Adv. 5, eaav6490 (2019). DOI: 10.1126/sciadv.aav6490

AIMNet( _graphs_properties: Dict, graph_angle_key: str, nlayers: int = 3, zmax: int = 86, radial_eta: float = 16.0, angular_eta: float = 8.0, radial_dist_divisions: int = 16, angular_dist_divisions: int = 4, zeta: float = 32.0, angle_sections: int = 4, radial_start: float = 0.8, angular_start: float = 0.8, embedding_key: str = 'embedding', graph_key: str = 'graph', keep_all_layers: bool = False, activation: Union[Callable, str] = 'swish', combination_neurons: Sequence[int] = <factory>, embedding_neurons: Sequence[int] = <factory>, interaction_neurons: Sequence[int] = <factory>, afv_neurons: Sequence[int] = <factory>, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
graph_angle_key: str

The key in the input dictionary that corresponds to the angular graph.

nlayers: int = 3

The number of message-passing layers.

zmax: int = 86

The maximum atomic number to allocate AFV.

radial_eta: float = 16.0

Controls the width of the Gaussian functions in the radial AEV.

angular_eta: float = 8.0

Controls the width of the Gaussian functions in the angular AEV.

radial_dist_divisions: int = 16

Number of basis functions used to encode distances in the radial AEV.

angular_dist_divisions: int = 4

Number of basis functions used to encode distances in the angular AEV.

zeta: float = 32.0

The power parameter in angle embedding.

angle_sections: int = 4

The number of angle sections.

radial_start: float = 0.8

The starting distance in radial AEV.

angular_start: float = 0.8

The starting distance in angular AEV.

embedding_key: str = 'embedding'

The key to use for the output embedding in the returned dictionary.

graph_key: str = 'graph'

The key in the input dictionary that corresponds to the radial graph.

keep_all_layers: bool = False

If True, the output will contain the embeddings from all layers.

activation: Union[Callable, str] = 'swish'

The activation function to use.

combination_neurons: Sequence[int]

The number of neurons in the AFV combination network.

embedding_neurons: Sequence[int]

The number of neurons in the embedding network.

interaction_neurons: Sequence[int]

The number of neurons in the interaction network.

afv_neurons: Sequence[int]

The number of neurons in the AFV update network. The last number of neurons defines the size of AFV.

FID: ClassVar[str] = 'AIMNET'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None