fennol.models.embeddings.aimnet
import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Sequence, Dict, Callable, ClassVar, Union
import numpy as np
from ...utils.periodic_table import PERIODIC_TABLE
from ..misc.encodings import SpeciesEncoding
from ..misc.nets import FullyConnectedNet
import dataclasses


class AIMNet(nn.Module):
    """Atom-In-Molecule Network message-passing embedding

    FID : AIMNET

    Each layer combines radial and angular atomic environment vectors (AEV)
    with the current atomic feature vectors (AFV), feeds the result through
    an embedding network and uses its output to update the AFV. The output
    of the interaction network of the last layer is returned as embedding.

    ### Reference
    Roman Zubatyuk et al., Accurate and transferable multitask prediction of chemical properties with an atoms-in-molecules neural network. Sci. Adv. 5, eaav6490 (2019). DOI: 10.1126/sciadv.aav6490
    """

    _graphs_properties: Dict
    graph_angle_key: str
    """ The key in the input dictionary that corresponds to the angular graph. """
    nlayers: int = 3
    """ The number of message-passing layers. Must be at least 1."""
    zmax: int = 86
    """ The maximum atomic number to allocate AFV."""
    radial_eta: float = 16.0
    """ Controls the width of the Gaussian density functions in radial AEV."""
    angular_eta: float = 8.0
    """ Controls the width of the Gaussian density functions in angular AEV."""
    radial_dist_divisions: int = 16
    """ Number of basis functions used to encode distances in radial AEV."""
    angular_dist_divisions: int = 4
    """ Number of basis functions used to encode distances in angular AEV."""
    zeta: float = 32.0
    """ The power parameter in angle embedding."""
    angle_sections: int = 4
    """ The number of angle sections."""
    radial_start: float = 0.8
    """ The starting distance in radial AEV."""
    angular_start: float = 0.8
    """ The starting distance in angular AEV."""
    embedding_key: str = "embedding"
    """ The key to use for the output embedding in the returned dictionary."""
    graph_key: str = "graph"
    """ The key in the input dictionary that corresponds to the radial graph."""
    keep_all_layers: bool = False
    """ If True, the output will contain the embeddings from all layers."""

    activation: Union[Callable, str] = "swish"
    """ The activation function to use."""
    combination_neurons: Sequence[int] = dataclasses.field(
        default_factory=lambda: [256, 128, 16]
    )
    """ The number of neurons in the AFV combination network."""

    embedding_neurons: Sequence[int] = dataclasses.field(
        default_factory=lambda: [512, 256, 256]
    )
    """ The number of neurons in the embedding network."""
    interaction_neurons: Sequence[int] = dataclasses.field(
        default_factory=lambda: [256, 256, 128]
    )
    """ The number of neurons in the interaction network."""
    afv_neurons: Sequence[int] = dataclasses.field(
        default_factory=lambda: [256, 256, 16]
    )
    """ The number of neurons in the AFV update network. The last number of neurons defines the size of AFV."""

    FID: ClassVar[str] = "AIMNET"

    @nn.compact
    def __call__(self, inputs):
        """Forward pass of the AIMNet model.

        Args:
            inputs: Dictionary that must provide
                - ``"species"``: per-atom species indices;
                - ``inputs[self.graph_key]``: radial graph with the entries
                  ``"distances"``, ``"switch"``, ``"edge_src"`` and
                  ``"edge_dst"``;
                - ``inputs[self.graph_angle_key]``: angular graph with the
                  entries ``"edge_dst"``, ``"angles"``, ``"distances"``,
                  ``"central_atom"``, ``"angle_src"``, ``"angle_dst"`` and
                  ``"switch"``.

        Returns:
            A new dictionary containing every entry of ``inputs`` plus
            - ``<embedding_key>``: output of the last interaction network;
            - ``<embedding_key>_afv``: AFV after the last update;
            - ``<embedding_key>_layers``: interaction outputs of all layers
              stacked along axis 1 (only if ``keep_all_layers`` is True).

        Raises:
            ValueError: If ``nlayers`` is smaller than 1.
        """
        if self.nlayers < 1:
            # With no layer the embedding would never be computed and the
            # method would otherwise fail later with an UnboundLocalError.
            raise ValueError(
                f"AIMNet requires nlayers >= 1 (got nlayers={self.nlayers})."
            )

        species = inputs["species"]
        natoms = species.shape[0]

        # species encoding (AFV); its size is set by the last AFV-update layer
        afv_dim = self.afv_neurons[-1]
        afv = SpeciesEncoding(dim=afv_dim, zmax=self.zmax, encoding="random")(species)

        # Radial graph
        graph = inputs[self.graph_key]
        distances = graph["distances"]
        switch = graph["switch"]
        edge_src = graph["edge_src"]
        edge_dst = graph["edge_dst"]

        # Radial AEV: Gaussians centered on `radial_dist_divisions` evenly
        # spaced shifts from `radial_start` up to (excluding) the cutoff.
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        shiftR = jnp.asarray(
            np.linspace(self.radial_start, cutoff, self.radial_dist_divisions + 1)[
                None, :-1
            ],
            dtype=distances.dtype,
        )
        x2 = self.radial_eta * (distances[:, None] - shiftR) ** 2
        radial_terms = 0.25 * jnp.exp(-x2) * switch[:, None]

        # Angular graph (`distances` and `switch` now refer to this graph)
        graph = inputs[self.graph_angle_key]
        edge_dst_a = graph["edge_dst"]
        angles = graph["angles"]
        distances = graph["distances"]
        central_atom = graph["central_atom"]
        angle_src, angle_dst = graph["angle_src"], graph["angle_dst"]
        # the two neighbor atoms forming each angle
        angle_atom_1 = edge_dst_a[angle_src]
        angle_atom_2 = edge_dst_a[angle_dst]
        switch = graph["switch"]
        # mean length of the two edges forming each angle
        d12 = 0.5 * (distances[angle_src] + distances[angle_dst])[:, None]

        # Angular AEV parameters
        angular_cutoff = self._graphs_properties[self.graph_angle_key]["cutoff"]
        # half a section width, so that shiftZ holds the center of each section
        angle_start = np.pi / (2 * self.angle_sections)
        shiftZ = jnp.asarray(
            (np.linspace(0, np.pi, self.angle_sections + 1) + angle_start)[None, :-1],
            dtype=distances.dtype,
        )
        shiftA = jnp.asarray(
            np.linspace(
                self.angular_start, angular_cutoff, self.angular_dist_divisions + 1
            )[None, :-1],
            dtype=distances.dtype,
        )

        # Angular AEV: outer product of the angle and distance factors,
        # flattened to one row per angle.
        factor1 = (0.5 + 0.5 * jnp.cos(angles[:, None] - shiftZ)) ** self.zeta
        factor2 = jnp.exp(-self.angular_eta * (d12 - shiftA) ** 2)
        angular_terms = (
            (factor1[:, None, :] * factor2[:, :, None]).reshape(
                -1, self.angle_sections * self.angular_dist_divisions
            )
            * 2
            * (switch[angle_src] * switch[angle_dst])[:, None]
        )

        if self.keep_all_layers:
            mis = []
        for layer in range(self.nlayers):
            # combine pair info: radial terms weighted by the neighbor AFV,
            # summed over the edges of each source atom
            Gij = (radial_terms[:, None, :] * afv[edge_dst, :, None]).reshape(
                radial_terms.shape[0], -1
            )
            Gri = jax.ops.segment_sum(Gij, edge_src, natoms)

            # combine triplet info: product and sum of the two neighbor AFVs
            # are symmetric under exchange of the two neighbors
            afv1 = afv[angle_atom_1]
            afv2 = afv[angle_atom_2]
            afv12 = jnp.concatenate((afv1 * afv2, afv1 + afv2), axis=-1)
            afv_ang = FullyConnectedNet(
                self.combination_neurons,
                activation=self.activation,
                name=f"combination_net_{layer}",
            )(afv12)
            Gijk = (angular_terms[:, None, :] * afv_ang[:, :, None]).reshape(
                angular_terms.shape[0], -1
            )
            Gai = jax.ops.segment_sum(Gijk, central_atom, natoms)

            # environment field
            fi = FullyConnectedNet(
                self.embedding_neurons,
                activation=self.activation,
                name=f"embedding_net_{layer}",
            )(jnp.concatenate((Gri, Gai), axis=-1))

            # update AFV
            dafv = FullyConnectedNet(
                self.afv_neurons,
                activation=self.activation,
                name=f"afv_update_net_{layer}",
            )(fi)
            afv = afv + dafv

            # The interaction network is only needed for the last layer,
            # unless the embeddings of all layers are requested.
            if self.keep_all_layers or layer == self.nlayers - 1:
                # embedding
                mi = FullyConnectedNet(
                    self.interaction_neurons,
                    activation=self.activation,
                    name=f"interaction_net_{layer}",
                )(fi)
                if self.keep_all_layers:
                    mis.append(mi)

        embedding_key = (
            self.embedding_key if self.embedding_key is not None else self.name
        )

        output = {**inputs, embedding_key: mi, embedding_key + "_afv": afv}
        if self.keep_all_layers:
            output[embedding_key + "_layers"] = jnp.stack(mis, axis=1)

        return output
class AIMNet(nn.Module):
    """Atom-In-Molecule Network message-passing embedding

    FID : AIMNET

    Each layer combines radial and angular atomic environment vectors (AEV)
    with the current atomic feature vectors (AFV), feeds the result through
    an embedding network and uses its output to update the AFV. The output
    of the interaction network of the last layer is returned as embedding.

    ### Reference
    Roman Zubatyuk et al., Accurate and transferable multitask prediction of chemical properties with an atoms-in-molecules neural network. Sci. Adv. 5, eaav6490 (2019). DOI: 10.1126/sciadv.aav6490
    """

    _graphs_properties: Dict
    graph_angle_key: str
    """ The key in the input dictionary that corresponds to the angular graph. """
    nlayers: int = 3
    """ The number of message-passing layers. Must be at least 1."""
    zmax: int = 86
    """ The maximum atomic number to allocate AFV."""
    radial_eta: float = 16.0
    """ Controls the width of the Gaussian density functions in radial AEV."""
    angular_eta: float = 8.0
    """ Controls the width of the Gaussian density functions in angular AEV."""
    radial_dist_divisions: int = 16
    """ Number of basis functions used to encode distances in radial AEV."""
    angular_dist_divisions: int = 4
    """ Number of basis functions used to encode distances in angular AEV."""
    zeta: float = 32.0
    """ The power parameter in angle embedding."""
    angle_sections: int = 4
    """ The number of angle sections."""
    radial_start: float = 0.8
    """ The starting distance in radial AEV."""
    angular_start: float = 0.8
    """ The starting distance in angular AEV."""
    embedding_key: str = "embedding"
    """ The key to use for the output embedding in the returned dictionary."""
    graph_key: str = "graph"
    """ The key in the input dictionary that corresponds to the radial graph."""
    keep_all_layers: bool = False
    """ If True, the output will contain the embeddings from all layers."""

    activation: Union[Callable, str] = "swish"
    """ The activation function to use."""
    combination_neurons: Sequence[int] = dataclasses.field(
        default_factory=lambda: [256, 128, 16]
    )
    """ The number of neurons in the AFV combination network."""

    embedding_neurons: Sequence[int] = dataclasses.field(
        default_factory=lambda: [512, 256, 256]
    )
    """ The number of neurons in the embedding network."""
    interaction_neurons: Sequence[int] = dataclasses.field(
        default_factory=lambda: [256, 256, 128]
    )
    """ The number of neurons in the interaction network."""
    afv_neurons: Sequence[int] = dataclasses.field(
        default_factory=lambda: [256, 256, 16]
    )
    """ The number of neurons in the AFV update network. The last number of neurons defines the size of AFV."""

    FID: ClassVar[str] = "AIMNET"

    @nn.compact
    def __call__(self, inputs):
        """Forward pass of the AIMNet model.

        Args:
            inputs: Dictionary that must provide
                - ``"species"``: per-atom species indices;
                - ``inputs[self.graph_key]``: radial graph with the entries
                  ``"distances"``, ``"switch"``, ``"edge_src"`` and
                  ``"edge_dst"``;
                - ``inputs[self.graph_angle_key]``: angular graph with the
                  entries ``"edge_dst"``, ``"angles"``, ``"distances"``,
                  ``"central_atom"``, ``"angle_src"``, ``"angle_dst"`` and
                  ``"switch"``.

        Returns:
            A new dictionary containing every entry of ``inputs`` plus
            - ``<embedding_key>``: output of the last interaction network;
            - ``<embedding_key>_afv``: AFV after the last update;
            - ``<embedding_key>_layers``: interaction outputs of all layers
              stacked along axis 1 (only if ``keep_all_layers`` is True).

        Raises:
            ValueError: If ``nlayers`` is smaller than 1.
        """
        if self.nlayers < 1:
            # With no layer the embedding would never be computed and the
            # method would otherwise fail later with an UnboundLocalError.
            raise ValueError(
                f"AIMNet requires nlayers >= 1 (got nlayers={self.nlayers})."
            )

        species = inputs["species"]
        natoms = species.shape[0]

        # species encoding (AFV); its size is set by the last AFV-update layer
        afv_dim = self.afv_neurons[-1]
        afv = SpeciesEncoding(dim=afv_dim, zmax=self.zmax, encoding="random")(species)

        # Radial graph
        graph = inputs[self.graph_key]
        distances = graph["distances"]
        switch = graph["switch"]
        edge_src = graph["edge_src"]
        edge_dst = graph["edge_dst"]

        # Radial AEV: Gaussians centered on `radial_dist_divisions` evenly
        # spaced shifts from `radial_start` up to (excluding) the cutoff.
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        shiftR = jnp.asarray(
            np.linspace(self.radial_start, cutoff, self.radial_dist_divisions + 1)[
                None, :-1
            ],
            dtype=distances.dtype,
        )
        x2 = self.radial_eta * (distances[:, None] - shiftR) ** 2
        radial_terms = 0.25 * jnp.exp(-x2) * switch[:, None]

        # Angular graph (`distances` and `switch` now refer to this graph)
        graph = inputs[self.graph_angle_key]
        edge_dst_a = graph["edge_dst"]
        angles = graph["angles"]
        distances = graph["distances"]
        central_atom = graph["central_atom"]
        angle_src, angle_dst = graph["angle_src"], graph["angle_dst"]
        # the two neighbor atoms forming each angle
        angle_atom_1 = edge_dst_a[angle_src]
        angle_atom_2 = edge_dst_a[angle_dst]
        switch = graph["switch"]
        # mean length of the two edges forming each angle
        d12 = 0.5 * (distances[angle_src] + distances[angle_dst])[:, None]

        # Angular AEV parameters
        angular_cutoff = self._graphs_properties[self.graph_angle_key]["cutoff"]
        # half a section width, so that shiftZ holds the center of each section
        angle_start = np.pi / (2 * self.angle_sections)
        shiftZ = jnp.asarray(
            (np.linspace(0, np.pi, self.angle_sections + 1) + angle_start)[None, :-1],
            dtype=distances.dtype,
        )
        shiftA = jnp.asarray(
            np.linspace(
                self.angular_start, angular_cutoff, self.angular_dist_divisions + 1
            )[None, :-1],
            dtype=distances.dtype,
        )

        # Angular AEV: outer product of the angle and distance factors,
        # flattened to one row per angle.
        factor1 = (0.5 + 0.5 * jnp.cos(angles[:, None] - shiftZ)) ** self.zeta
        factor2 = jnp.exp(-self.angular_eta * (d12 - shiftA) ** 2)
        angular_terms = (
            (factor1[:, None, :] * factor2[:, :, None]).reshape(
                -1, self.angle_sections * self.angular_dist_divisions
            )
            * 2
            * (switch[angle_src] * switch[angle_dst])[:, None]
        )

        if self.keep_all_layers:
            mis = []
        for layer in range(self.nlayers):
            # combine pair info: radial terms weighted by the neighbor AFV,
            # summed over the edges of each source atom
            Gij = (radial_terms[:, None, :] * afv[edge_dst, :, None]).reshape(
                radial_terms.shape[0], -1
            )
            Gri = jax.ops.segment_sum(Gij, edge_src, natoms)

            # combine triplet info: product and sum of the two neighbor AFVs
            # are symmetric under exchange of the two neighbors
            afv1 = afv[angle_atom_1]
            afv2 = afv[angle_atom_2]
            afv12 = jnp.concatenate((afv1 * afv2, afv1 + afv2), axis=-1)
            afv_ang = FullyConnectedNet(
                self.combination_neurons,
                activation=self.activation,
                name=f"combination_net_{layer}",
            )(afv12)
            Gijk = (angular_terms[:, None, :] * afv_ang[:, :, None]).reshape(
                angular_terms.shape[0], -1
            )
            Gai = jax.ops.segment_sum(Gijk, central_atom, natoms)

            # environment field
            fi = FullyConnectedNet(
                self.embedding_neurons,
                activation=self.activation,
                name=f"embedding_net_{layer}",
            )(jnp.concatenate((Gri, Gai), axis=-1))

            # update AFV
            dafv = FullyConnectedNet(
                self.afv_neurons,
                activation=self.activation,
                name=f"afv_update_net_{layer}",
            )(fi)
            afv = afv + dafv

            # The interaction network is only needed for the last layer,
            # unless the embeddings of all layers are requested.
            if self.keep_all_layers or layer == self.nlayers - 1:
                # embedding
                mi = FullyConnectedNet(
                    self.interaction_neurons,
                    activation=self.activation,
                    name=f"interaction_net_{layer}",
                )(fi)
                if self.keep_all_layers:
                    mis.append(mi)

        embedding_key = (
            self.embedding_key if self.embedding_key is not None else self.name
        )

        output = {**inputs, embedding_key: mi, embedding_key + "_afv": afv}
        if self.keep_all_layers:
            output[embedding_key + "_layers"] = jnp.stack(mis, axis=1)

        return output
Atom-In-Molecule Network message-passing embedding
FID : AIMNET
Reference
Roman Zubatyuk et al., Accurate and transferable multitask prediction of chemical properties with an atoms-in-molecules neural network. Sci. Adv. 5, eaav6490 (2019). DOI: 10.1126/sciadv.aav6490
The key to use for the output embedding in the returned dictionary.
The number of neurons in the AFV update network. The last number of neurons defines the size of AFV.
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links, which can lead to accidental OOMs in eager mode due to slow garbage collection, as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.