fennol.models.embeddings.schnet

Implementation of SchNet embedding.

Done by Côme Cattin, 2024.

  1#!/usr/bin/env python3
  2"""Implementation of SchNet embedding.
  3
  4Done by Côme Cattin, 2024.
  5"""
  6
  7import dataclasses
  8from typing import Callable, Dict, Sequence, Union, ClassVar
  9
 10import flax.linen as nn
 11import jax
 12
 13from ...utils.activations import ssp
 14from ..misc.encodings import RadialBasis, SpeciesEncoding
 15from ..misc.nets import FullyConnectedNet
 16
 17
 18class SchNetEmbedding(nn.Module):
 19    """SchNet embedding.
 20
 21    Continuous filter convolutional neural network for modeling quantum
 22    interactions.
 23
 24    ### References
 25    SCHÜTT, Kristof, KINDERMANS, Pieter-Jan, SAUCEDA FELIX, Huziel Enoc, et al.
 26    Schnet: A continuous-filter convolutional neural network for
 27    modeling quantum interactions.
 28    Advances in neural information processing systems, 2017, vol. 30.
 29    https://proceedings.neurips.cc/paper_files/paper/2017/file/303ed4c69846ab36c2904d3ba8573050-Paper.pdf
 30
 31    Parameters
 32    ----------
 33    dim : int, default=64
 34        The dimension of the embedding.
 35    nlayers : int, default=3
 36        The number of interaction layers.
 37    graph_key : str, default="graph"
 38        The key for the graph input.
 39    embedding_key : str, default="embedding"
 40        The key for the embedding output.
 41    radial_basis : dict, default={}
 42        The radial basis function parameters.
 43    species_encoding : dict, default={}
 44        The species encoding parameters.
 45    activation : Union[Callable, str], default=ssp
 46        The activation function.
 47    """
 48
 49    _graphs_properties: Dict
 50    dim: int = 64
 51    """The dimension of the embedding."""
 52    nlayers: int = 3
 53    """The number of interaction layers."""
 54    conv_hidden: Sequence[int] = dataclasses.field(
 55        default_factory=lambda: [64, 64]
 56    )
 57    """The hidden layers for the edge network."""
 58    graph_key: str = "graph"
 59    """The key for the graph input."""
 60    embedding_key: str = "embedding"
 61    """The key for the embedding output."""
 62    radial_basis: dict = dataclasses.field(default_factory=dict)
 63    """The radial basis function parameters. See `fennol.models.misc.encodings.RadialBasis`."""
 64    species_encoding: dict = dataclasses.field(default_factory=dict)
 65    """The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`."""
 66    activation: Union[Callable, str] = "ssp"
 67    """The activation function."""
 68
 69    FID: ClassVar[str] = "SCHNET"
 70
 71    @nn.compact
 72    def __call__(self, inputs):
 73        """Forward pass."""
 74        species = inputs["species"]
 75        graph = inputs[self.graph_key]
 76        switch = graph["switch"][:, None]
 77        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
 78        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
 79        onehot = SpeciesEncoding(
 80            **self.species_encoding, name="SpeciesEncoding"
 81        )(species)
 82
 83        xi_prev_layer = nn.Dense(
 84            self.dim, name="species_linear", use_bias=False
 85        )(onehot)
 86
 87        distances = graph["distances"]
 88        radial_basis = RadialBasis(
 89            **{
 90                "end": cutoff,
 91                **self.radial_basis,
 92                "name": "RadialBasis",
 93            }
 94        )(distances)
 95
 96        def atom_wise(xi, i, layer):
 97            return nn.Dense(
 98                self.dim, name=f"atom_wise_{i}_{layer}", use_bias=True
 99            )(xi)
100
101        # Interaction layer
102        for layer in range(self.nlayers):
103            # Atom-wise
104            xi = atom_wise(xi_prev_layer, 1, layer)
105
106            # cfconv
107            w_l = FullyConnectedNet(
108                [*self.conv_hidden, self.dim],
109                activation=self.activation,
110                name=f"filter_weight_{layer}",
111                use_bias=True,
112            )(radial_basis)
113            xi_j = xi[edge_dst]
114            xi = jax.ops.segment_sum(
115                self.activation(w_l) * xi_j * switch, edge_src, species.shape[0]
116            )
117
118            # Atom-wise
119            xi = atom_wise(xi, 2, layer)
120
121            # Activation
122            xi = self.activation(xi)
123
124            # Atom-wise
125            xi = atom_wise(xi, 3, layer)
126
127            # Residual connection
128            xi = xi + xi_prev_layer
129            xi_prev_layer = xi
130
131        output = {
132            **inputs,
133            self.embedding_key: xi,
134        }
135        return output
136
137
if __name__ == "__main__":
    # This module only defines the embedding; running it as a script is a no-op.
    pass
class SchNetEmbedding(flax.linen.module.Module):
 19class SchNetEmbedding(nn.Module):
 20    """SchNet embedding.
 21
 22    Continuous filter convolutional neural network for modeling quantum
 23    interactions.
 24
 25    ### References
 26    SCHÜTT, Kristof, KINDERMANS, Pieter-Jan, SAUCEDA FELIX, Huziel Enoc, et al.
 27    Schnet: A continuous-filter convolutional neural network for
 28    modeling quantum interactions.
 29    Advances in neural information processing systems, 2017, vol. 30.
 30    https://proceedings.neurips.cc/paper_files/paper/2017/file/303ed4c69846ab36c2904d3ba8573050-Paper.pdf
 31
 32    Parameters
 33    ----------
 34    dim : int, default=64
 35        The dimension of the embedding.
 36    nlayers : int, default=3
 37        The number of interaction layers.
 38    graph_key : str, default="graph"
 39        The key for the graph input.
 40    embedding_key : str, default="embedding"
 41        The key for the embedding output.
 42    radial_basis : dict, default={}
 43        The radial basis function parameters.
 44    species_encoding : dict, default={}
 45        The species encoding parameters.
 46    activation : Union[Callable, str], default=ssp
 47        The activation function.
 48    """
 49
 50    _graphs_properties: Dict
 51    dim: int = 64
 52    """The dimension of the embedding."""
 53    nlayers: int = 3
 54    """The number of interaction layers."""
 55    conv_hidden: Sequence[int] = dataclasses.field(
 56        default_factory=lambda: [64, 64]
 57    )
 58    """The hidden layers for the edge network."""
 59    graph_key: str = "graph"
 60    """The key for the graph input."""
 61    embedding_key: str = "embedding"
 62    """The key for the embedding output."""
 63    radial_basis: dict = dataclasses.field(default_factory=dict)
 64    """The radial basis function parameters. See `fennol.models.misc.encodings.RadialBasis`."""
 65    species_encoding: dict = dataclasses.field(default_factory=dict)
 66    """The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`."""
 67    activation: Union[Callable, str] = "ssp"
 68    """The activation function."""
 69
 70    FID: ClassVar[str] = "SCHNET"
 71
 72    @nn.compact
 73    def __call__(self, inputs):
 74        """Forward pass."""
 75        species = inputs["species"]
 76        graph = inputs[self.graph_key]
 77        switch = graph["switch"][:, None]
 78        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
 79        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
 80        onehot = SpeciesEncoding(
 81            **self.species_encoding, name="SpeciesEncoding"
 82        )(species)
 83
 84        xi_prev_layer = nn.Dense(
 85            self.dim, name="species_linear", use_bias=False
 86        )(onehot)
 87
 88        distances = graph["distances"]
 89        radial_basis = RadialBasis(
 90            **{
 91                "end": cutoff,
 92                **self.radial_basis,
 93                "name": "RadialBasis",
 94            }
 95        )(distances)
 96
 97        def atom_wise(xi, i, layer):
 98            return nn.Dense(
 99                self.dim, name=f"atom_wise_{i}_{layer}", use_bias=True
100            )(xi)
101
102        # Interaction layer
103        for layer in range(self.nlayers):
104            # Atom-wise
105            xi = atom_wise(xi_prev_layer, 1, layer)
106
107            # cfconv
108            w_l = FullyConnectedNet(
109                [*self.conv_hidden, self.dim],
110                activation=self.activation,
111                name=f"filter_weight_{layer}",
112                use_bias=True,
113            )(radial_basis)
114            xi_j = xi[edge_dst]
115            xi = jax.ops.segment_sum(
116                self.activation(w_l) * xi_j * switch, edge_src, species.shape[0]
117            )
118
119            # Atom-wise
120            xi = atom_wise(xi, 2, layer)
121
122            # Activation
123            xi = self.activation(xi)
124
125            # Atom-wise
126            xi = atom_wise(xi, 3, layer)
127
128            # Residual connection
129            xi = xi + xi_prev_layer
130            xi_prev_layer = xi
131
132        output = {
133            **inputs,
134            self.embedding_key: xi,
135        }
136        return output

SchNet embedding.

Continuous filter convolutional neural network for modeling quantum interactions.

References

SCHÜTT, Kristof, KINDERMANS, Pieter-Jan, SAUCEDA FELIX, Huziel Enoc, et al. SchNet: A continuous-filter convolutional neural network for modeling quantum interactions. Advances in neural information processing systems, 2017, vol. 30. https://proceedings.neurips.cc/paper_files/paper/2017/file/303ed4c69846ab36c2904d3ba8573050-Paper.pdf

Parameters
  • dim (int, default=64): The dimension of the embedding.
  • nlayers (int, default=3): The number of interaction layers.
  • conv_hidden (Sequence[int], default=[64, 64]): The hidden layer sizes of the filter-generating (edge) network.
  • graph_key (str, default="graph"): The key for the graph input.
  • embedding_key (str, default="embedding"): The key for the embedding output.
  • radial_basis (dict, default={}): The radial basis function parameters.
  • species_encoding (dict, default={}): The species encoding parameters.
  • activation (Union[Callable, str], default="ssp"): The activation function, given as a callable or by name.
SchNetEmbedding( _graphs_properties: Dict, dim: int = 64, nlayers: int = 3, conv_hidden: Sequence[int] = <factory>, graph_key: str = 'graph', embedding_key: str = 'embedding', radial_basis: dict = <factory>, species_encoding: dict = <factory>, activation: Union[Callable, str] = 'ssp', parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
dim: int = 64

The dimension of the embedding.

nlayers: int = 3

The number of interaction layers.

conv_hidden: Sequence[int]

The hidden layers for the edge network.

graph_key: str = 'graph'

The key for the graph input.

embedding_key: str = 'embedding'

The key for the embedding output.

radial_basis: dict

The radial basis function parameters. See fennol.models.misc.encodings.RadialBasis.

species_encoding: dict

The species encoding parameters. See fennol.models.misc.encodings.SpeciesEncoding.

activation: Union[Callable, str] = 'ssp'

The activation function.

FID: ClassVar[str] = 'SCHNET'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None