fennol.models.embeddings.schnet
Implementation of SchNet embedding.
Done by Côme Cattin, 2024.
#!/usr/bin/env python3
"""Implementation of SchNet embedding.

Done by Côme Cattin, 2024.
"""

import dataclasses
from typing import Callable, Dict, Sequence, Union, ClassVar

import flax.linen as nn
import jax

from ...utils.activations import ssp
from ..misc.encodings import RadialBasis, SpeciesEncoding
from ..misc.nets import FullyConnectedNet


class SchNetEmbedding(nn.Module):
    """SchNet embedding.

    Continuous filter convolutional neural network for modeling quantum
    interactions.

    ### References
    SCHÜTT, Kristof, KINDERMANS, Pieter-Jan, SAUCEDA FELIX, Huziel Enoc, et al.
    Schnet: A continuous-filter convolutional neural network for
    modeling quantum interactions.
    Advances in neural information processing systems, 2017, vol. 30.
    https://proceedings.neurips.cc/paper_files/paper/2017/file/303ed4c69846ab36c2904d3ba8573050-Paper.pdf

    Parameters
    ----------
    dim : int, default=64
        The dimension of the embedding.
    nlayers : int, default=3
        The number of interaction layers.
    conv_hidden : Sequence[int], default=[64, 64]
        The hidden layer sizes of the edge (filter-generating) network.
    graph_key : str, default="graph"
        The key for the graph input.
    embedding_key : str, default="embedding"
        The key for the embedding output.
    radial_basis : dict, default={}
        The radial basis function parameters.
    species_encoding : dict, default={}
        The species encoding parameters.
    activation : Union[Callable, str], default="ssp"
        The activation function. A string is resolved to a callable:
        ``"ssp"`` selects the shifted softplus from
        ``fennol.utils.activations``; any other name is looked up in
        ``jax.nn`` (e.g. ``"silu"``). The value is forwarded unchanged to
        the filter network.
    """

    _graphs_properties: Dict
    dim: int = 64
    """The dimension of the embedding."""
    nlayers: int = 3
    """The number of interaction layers."""
    conv_hidden: Sequence[int] = dataclasses.field(
        default_factory=lambda: [64, 64]
    )
    """The hidden layers for the edge network."""
    graph_key: str = "graph"
    """The key for the graph input."""
    embedding_key: str = "embedding"
    """The key for the embedding output."""
    radial_basis: dict = dataclasses.field(default_factory=dict)
    """The radial basis function parameters. See `fennol.models.misc.encodings.RadialBasis`."""
    species_encoding: dict = dataclasses.field(default_factory=dict)
    """The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`."""
    activation: Union[Callable, str] = "ssp"
    """The activation function."""

    FID: ClassVar[str] = "SCHNET"

    @nn.compact
    def __call__(self, inputs):
        """Forward pass.

        Parameters
        ----------
        inputs : dict
            Must contain ``"species"`` and ``inputs[self.graph_key]``. The
            graph dict is read for ``"switch"``, ``"edge_src"``,
            ``"edge_dst"`` and ``"distances"``.

        Returns
        -------
        dict
            A new dict with all entries of ``inputs`` plus the per-atom
            embedding stored under ``self.embedding_key``.

        Raises
        ------
        ValueError
            If ``activation`` does not resolve to a callable.
        """
        # The default activation is the *string* "ssp"; it is applied
        # directly below, so it must first be resolved to a callable
        # (calling a str would raise a TypeError).
        activation = self.activation
        if isinstance(activation, str):
            name = activation.lower()
            activation = ssp if name == "ssp" else getattr(jax.nn, name, None)
        if not callable(activation):
            raise ValueError(
                f"Unknown activation {self.activation!r} for SchNetEmbedding"
            )

        species = inputs["species"]
        graph = inputs[self.graph_key]
        # Per-edge switching weights, broadcast over the feature axis.
        switch = graph["switch"][:, None]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        onehot = SpeciesEncoding(
            **self.species_encoding, name="SpeciesEncoding"
        )(species)

        # Initial atom features: bias-free linear map of the species encoding.
        xi_prev_layer = nn.Dense(
            self.dim, name="species_linear", use_bias=False
        )(onehot)

        distances = graph["distances"]
        # The graph cutoff is only a default "end": user-supplied
        # radial_basis entries override it, while the module name is forced.
        radial_basis = RadialBasis(
            **{
                "end": cutoff,
                **self.radial_basis,
                "name": "RadialBasis",
            }
        )(distances)

        def atom_wise(xi, i, layer):
            # Per-atom dense layer; (i, layer) gives each one a unique name.
            return nn.Dense(
                self.dim, name=f"atom_wise_{i}_{layer}", use_bias=True
            )(xi)

        # Interaction layers
        for layer in range(self.nlayers):
            # Atom-wise
            xi = atom_wise(xi_prev_layer, 1, layer)

            # cfconv: filters generated from the radial basis expansion
            w_l = FullyConnectedNet(
                [*self.conv_hidden, self.dim],
                activation=self.activation,
                name=f"filter_weight_{layer}",
                use_bias=True,
            )(radial_basis)
            # Gather neighbour features, weight them by the filter and the
            # switch, then sum them onto the source atom of each edge.
            xi_j = xi[edge_dst]
            xi = jax.ops.segment_sum(
                activation(w_l) * xi_j * switch, edge_src, species.shape[0]
            )

            # Atom-wise
            xi = atom_wise(xi, 2, layer)

            # Activation
            xi = activation(xi)

            # Atom-wise
            xi = atom_wise(xi, 3, layer)

            # Residual connection
            xi = xi + xi_prev_layer
            xi_prev_layer = xi

        output = {
            **inputs,
            self.embedding_key: xi,
        }
        return output
class SchNetEmbedding(nn.Module):
    """SchNet embedding.

    Continuous filter convolutional neural network for modeling quantum
    interactions.

    ### References
    SCHÜTT, Kristof, KINDERMANS, Pieter-Jan, SAUCEDA FELIX, Huziel Enoc, et al.
    Schnet: A continuous-filter convolutional neural network for
    modeling quantum interactions.
    Advances in neural information processing systems, 2017, vol. 30.
    https://proceedings.neurips.cc/paper_files/paper/2017/file/303ed4c69846ab36c2904d3ba8573050-Paper.pdf

    Parameters
    ----------
    dim : int, default=64
        The dimension of the embedding.
    nlayers : int, default=3
        The number of interaction layers.
    conv_hidden : Sequence[int], default=[64, 64]
        The hidden layer sizes of the edge (filter-generating) network.
    graph_key : str, default="graph"
        The key for the graph input.
    embedding_key : str, default="embedding"
        The key for the embedding output.
    radial_basis : dict, default={}
        The radial basis function parameters.
    species_encoding : dict, default={}
        The species encoding parameters.
    activation : Union[Callable, str], default="ssp"
        The activation function. A string is resolved to a callable:
        ``"ssp"`` selects the shifted softplus from
        ``fennol.utils.activations``; any other name is looked up in
        ``jax.nn`` (e.g. ``"silu"``). The value is forwarded unchanged to
        the filter network.
    """

    _graphs_properties: Dict
    dim: int = 64
    """The dimension of the embedding."""
    nlayers: int = 3
    """The number of interaction layers."""
    conv_hidden: Sequence[int] = dataclasses.field(
        default_factory=lambda: [64, 64]
    )
    """The hidden layers for the edge network."""
    graph_key: str = "graph"
    """The key for the graph input."""
    embedding_key: str = "embedding"
    """The key for the embedding output."""
    radial_basis: dict = dataclasses.field(default_factory=dict)
    """The radial basis function parameters. See `fennol.models.misc.encodings.RadialBasis`."""
    species_encoding: dict = dataclasses.field(default_factory=dict)
    """The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`."""
    activation: Union[Callable, str] = "ssp"
    """The activation function."""

    FID: ClassVar[str] = "SCHNET"

    @nn.compact
    def __call__(self, inputs):
        """Forward pass.

        Parameters
        ----------
        inputs : dict
            Must contain ``"species"`` and ``inputs[self.graph_key]``. The
            graph dict is read for ``"switch"``, ``"edge_src"``,
            ``"edge_dst"`` and ``"distances"``.

        Returns
        -------
        dict
            A new dict with all entries of ``inputs`` plus the per-atom
            embedding stored under ``self.embedding_key``.

        Raises
        ------
        ValueError
            If ``activation`` does not resolve to a callable.
        """
        # The default activation is the *string* "ssp"; it is applied
        # directly below, so it must first be resolved to a callable
        # (calling a str would raise a TypeError).
        activation = self.activation
        if isinstance(activation, str):
            name = activation.lower()
            activation = ssp if name == "ssp" else getattr(jax.nn, name, None)
        if not callable(activation):
            raise ValueError(
                f"Unknown activation {self.activation!r} for SchNetEmbedding"
            )

        species = inputs["species"]
        graph = inputs[self.graph_key]
        # Per-edge switching weights, broadcast over the feature axis.
        switch = graph["switch"][:, None]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        onehot = SpeciesEncoding(
            **self.species_encoding, name="SpeciesEncoding"
        )(species)

        # Initial atom features: bias-free linear map of the species encoding.
        xi_prev_layer = nn.Dense(
            self.dim, name="species_linear", use_bias=False
        )(onehot)

        distances = graph["distances"]
        # The graph cutoff is only a default "end": user-supplied
        # radial_basis entries override it, while the module name is forced.
        radial_basis = RadialBasis(
            **{
                "end": cutoff,
                **self.radial_basis,
                "name": "RadialBasis",
            }
        )(distances)

        def atom_wise(xi, i, layer):
            # Per-atom dense layer; (i, layer) gives each one a unique name.
            return nn.Dense(
                self.dim, name=f"atom_wise_{i}_{layer}", use_bias=True
            )(xi)

        # Interaction layers
        for layer in range(self.nlayers):
            # Atom-wise
            xi = atom_wise(xi_prev_layer, 1, layer)

            # cfconv: filters generated from the radial basis expansion
            w_l = FullyConnectedNet(
                [*self.conv_hidden, self.dim],
                activation=self.activation,
                name=f"filter_weight_{layer}",
                use_bias=True,
            )(radial_basis)
            # Gather neighbour features, weight them by the filter and the
            # switch, then sum them onto the source atom of each edge.
            xi_j = xi[edge_dst]
            xi = jax.ops.segment_sum(
                activation(w_l) * xi_j * switch, edge_src, species.shape[0]
            )

            # Atom-wise
            xi = atom_wise(xi, 2, layer)

            # Activation
            xi = activation(xi)

            # Atom-wise
            xi = atom_wise(xi, 3, layer)

            # Residual connection
            xi = xi + xi_prev_layer
            xi_prev_layer = xi

        output = {
            **inputs,
            self.embedding_key: xi,
        }
        return output
SchNet embedding.
Continuous filter convolutional neural network for modeling quantum interactions.
References
SCHÜTT, Kristof, KINDERMANS, Pieter-Jan, SAUCEDA FELIX, Huziel Enoc, et al. Schnet: A continuous-filter convolutional neural network for modeling quantum interactions. Advances in neural information processing systems, 2017, vol. 30. https://proceedings.neurips.cc/paper_files/paper/2017/file/303ed4c69846ab36c2904d3ba8573050-Paper.pdf
Parameters
- dim (int, default=64): The dimension of the embedding.
- nlayers (int, default=3): The number of interaction layers.
- graph_key (str, default="graph"): The key for the graph input.
- embedding_key (str, default="embedding"): The key for the embedding output.
- radial_basis (dict, default={}): The radial basis function parameters.
- species_encoding (dict, default={}): The species encoding parameters.
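- conv_hidden (Sequence[int], default=[64, 64]): The hidden layer sizes of the edge (filter-generating) network.
- activation (Union[Callable, str], default="ssp"): The activation function.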
The radial basis function parameters. See fennol.models.misc.encodings.RadialBasis.
The species encoding parameters. See fennol.models.misc.encodings.SpeciesEncoding.
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.