fennol.models.embeddings.ani
import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Sequence, Dict, Union, ClassVar
import numpy as np
from ...utils.periodic_table import PERIODIC_TABLE


class ANIAEV(nn.Module):
    """Computes the Atomic Environment Vector (AEV) for a given molecular system using the ANI model.

    The per-atom output concatenates a radial part
    (``num_species * radial_dist_divisions`` features) and an angular part
    (``num_species * (num_species + 1) / 2 * angle_sections * angular_dist_divisions``
    features), where ``num_species = len(species_order)``.

    FID : ANI_AEV

    ### Reference
    J. S. Smith, O. Isayev and A. E. Roitberg, ANI-1: an extensible neural network potential with DFT accuracy at force field computational cost, Chem. Sci., 2017, 8, 3192
    """

    # Graph metadata keyed by graph name; only the "cutoff" entry is read here.
    _graphs_properties: Dict
    species_order: Union[str, Sequence[str]]
    """ The chemical species considered by the model, either as a comma-separated string or as a sequence of symbols from the periodic table."""
    graph_angle_key: str
    """ The key in the input dictionary that corresponds to the angular graph."""
    radial_eta: float = 16.0
    """ Controls the width of the gaussian sensitivity functions in radial AEV."""
    angular_eta: float = 8.0
    """ Controls the width of the gaussian sensitivity functions in angular AEV."""
    radial_dist_divisions: int = 16
    """ Number of basis functions used to encode distance in radial AEV."""
    angular_dist_divisions: int = 4
    """ Number of basis functions used to encode distance in angular AEV."""
    zeta: float = 32.0
    """ The power parameter in angle embedding."""
    angle_sections: int = 4
    """ The number of angle sections."""
    radial_start: float = 0.8
    """ The starting distance in radial AEV."""
    angular_start: float = 0.8
    """ The starting distance in angular AEV."""
    embedding_key: Union[str, None] = "embedding"
    """ The key to use for the output embedding in the returned dictionary. If None, the embedding array itself is returned instead of a dictionary."""
    graph_key: str = "graph"
    """ The key in the input dictionary that corresponds to the radial graph."""

    FID: ClassVar[str] = "ANI_AEV"

    @nn.compact
    def __call__(self, inputs):
        """Compute the ANI AEV of every atom.

        Args:
            inputs: dictionary holding ``"species"`` (per-atom values used to
                index a table laid out like ``PERIODIC_TABLE`` -- presumably
                atomic numbers, TODO confirm), the radial graph under
                ``graph_key`` (``"distances"``, ``"switch"``, ``"edge_src"``,
                ``"edge_dst"``) and the angular graph under ``graph_angle_key``
                (``"angles"``, ``"distances"``, ``"switch"``,
                ``"central_atom"``, ``"angle_src"``, ``"angle_dst"``,
                ``"edge_dst"``).

        Returns:
            The ``(n_atoms, n_features)`` AEV array if ``embedding_key`` is
            None, otherwise a shallow copy of ``inputs`` with the array stored
            under ``embedding_key``.
        """
        species = inputs["species"]
        n_atoms = species.shape[0]

        # --- map periodic-table indices to internal species indices -------
        rev_idx = {s: k for k, s in enumerate(PERIODIC_TABLE)}
        maxidx = max(rev_idx.values())
        if isinstance(self.species_order, str):
            species_order = [el.strip() for el in self.species_order.split(",")]
        else:
            species_order = list(self.species_order)
        # Lookup table: periodic-table index -> position in species_order.
        # Elements not listed in species_order keep the value 0, i.e. they are
        # silently treated as the first species. The table has one slot past
        # maxidx (NOTE(review): presumably to absorb padding species values --
        # confirm against the preprocessing code).
        conv_tensor = [0] * (maxidx + 2)
        for i, s in enumerate(species_order):
            conv_tensor[rev_idx[s]] = i
        indices = jnp.asarray(conv_tensor, dtype=jnp.int32)[species]
        num_species = len(species_order)
        num_species_pair = (num_species * (num_species + 1)) // 2

        # --- radial AEV ---------------------------------------------------
        graph = inputs[self.graph_key]
        distances = graph["distances"]
        switch = graph["switch"]
        edge_src = graph["edge_src"]
        edge_dst = graph["edge_dst"]

        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        # Shifts evenly spaced on [radial_start, cutoff), stored negated so
        # that distances + shiftR is the distance to each shift.
        shiftR = jnp.asarray(
            -np.linspace(self.radial_start, cutoff, self.radial_dist_divisions + 1)[
                None, :-1
            ],
            dtype=distances.dtype,
        )
        x2 = self.radial_eta * (distances[:, None] + shiftR) ** 2
        # constant 0.25 prefactor times the per-edge switching value
        radial_terms = jnp.exp(-x2) * (0.25 * switch)[:, None]
        # Each edge contributes to slot (edge_src, species of edge_dst).
        radial_index = edge_src * num_species + indices[edge_dst]
        radial_aev = jax.ops.segment_sum(
            radial_terms, radial_index, num_species * n_atoms
        ).reshape(n_atoms, num_species * radial_terms.shape[-1])

        # --- angular AEV --------------------------------------------------
        graph = inputs[self.graph_angle_key]
        angles = graph["angles"]
        central_atom = graph["central_atom"]
        angle_src, angle_dst = graph["angle_src"], graph["angle_dst"]
        sqrt_eta_a = self.angular_eta**0.5
        # Pre-scale distances by sqrt(eta)/2 so that, together with shiftA
        # below, the gaussian equals exp(-eta * ((r_ij + r_ik)/2 - shift)**2).
        distances = (0.5 * sqrt_eta_a) * graph["distances"]
        d12 = (distances[angle_src] + distances[angle_dst])[:, None]

        angular_cutoff = self._graphs_properties[self.graph_angle_key]["cutoff"]
        # angle shifts pi/(2*sections) + k*pi/sections, k = 0..sections-1 (negated)
        angle_start = np.pi / (2 * self.angle_sections)
        shiftZ = jnp.asarray(
            -(np.linspace(0, np.pi, self.angle_sections + 1) + angle_start)[None, :-1],
            dtype=distances.dtype,
        )
        # distance shifts on [angular_start, angular_cutoff), scaled by -sqrt(eta)
        shiftA = jnp.asarray(
            (-sqrt_eta_a)
            * np.linspace(
                self.angular_start, angular_cutoff, self.angular_dist_divisions + 1
            )[None, :-1],
            dtype=distances.dtype,
        )

        # Each switch is scaled by sqrt(2 * 0.5**zeta), so the product of the
        # two switches of a triplet carries the factor 2 * 0.5**zeta.
        switch = graph["switch"] * (2.0 * (0.5**self.zeta)) ** 0.5
        factor1 = switch[angle_src] * switch[angle_dst]
        factor1 = (
            factor1[:, None] * (1 + jnp.cos(angles[:, None] + shiftZ)) ** self.zeta
        )
        factor2 = jnp.exp(-((d12 + shiftA) ** 2))
        # outer product: distance shifts on the slow axis, angle shifts on the fast one
        angular_terms = (factor1[:, None, :] * factor2[:, :, None]).reshape(
            -1, self.angle_sections * self.angular_dist_divisions
        )

        # Symmetric (species_a, species_b) -> unordered-pair index lookup.
        index_dest = indices[graph["edge_dst"]]
        species1, species2 = np.triu_indices(num_species, 0)
        pair_index = np.arange(species1.shape[0], dtype=np.int32)
        triu_index = np.zeros((num_species, num_species), dtype=np.int32)
        triu_index[species1, species2] = pair_index
        triu_index[species2, species1] = pair_index
        triu_index = jnp.asarray(triu_index, dtype=jnp.int32)
        angular_index = (
            central_atom * num_species_pair
            + triu_index[index_dest[angle_src], index_dest[angle_dst]]
        )
        angular_aev = jax.ops.segment_sum(
            angular_terms, angular_index, num_species_pair * n_atoms
        ).reshape(n_atoms, num_species_pair * angular_terms.shape[-1])

        embedding = jnp.concatenate((radial_aev, angular_aev), axis=-1)
        if self.embedding_key is None:
            return embedding
        return {**inputs, self.embedding_key: embedding}
class ANIAEV(nn.Module):
    """Computes the Atomic Environment Vector (AEV) for a given molecular system using the ANI model.

    The per-atom output concatenates a radial part
    (``num_species * radial_dist_divisions`` features) and an angular part
    (``num_species * (num_species + 1) / 2 * angle_sections * angular_dist_divisions``
    features), where ``num_species = len(species_order)``.

    FID : ANI_AEV

    ### Reference
    J. S. Smith, O. Isayev and A. E. Roitberg, ANI-1: an extensible neural network potential with DFT accuracy at force field computational cost, Chem. Sci., 2017, 8, 3192
    """

    # Graph metadata keyed by graph name; only the "cutoff" entry is read here.
    _graphs_properties: Dict
    species_order: Union[str, Sequence[str]]
    """ The chemical species considered by the model, either as a comma-separated string or as a sequence of symbols from the periodic table."""
    graph_angle_key: str
    """ The key in the input dictionary that corresponds to the angular graph."""
    radial_eta: float = 16.0
    """ Controls the width of the gaussian sensitivity functions in radial AEV."""
    angular_eta: float = 8.0
    """ Controls the width of the gaussian sensitivity functions in angular AEV."""
    radial_dist_divisions: int = 16
    """ Number of basis functions used to encode distance in radial AEV."""
    angular_dist_divisions: int = 4
    """ Number of basis functions used to encode distance in angular AEV."""
    zeta: float = 32.0
    """ The power parameter in angle embedding."""
    angle_sections: int = 4
    """ The number of angle sections."""
    radial_start: float = 0.8
    """ The starting distance in radial AEV."""
    angular_start: float = 0.8
    """ The starting distance in angular AEV."""
    embedding_key: Union[str, None] = "embedding"
    """ The key to use for the output embedding in the returned dictionary. If None, the embedding array itself is returned instead of a dictionary."""
    graph_key: str = "graph"
    """ The key in the input dictionary that corresponds to the radial graph."""

    FID: ClassVar[str] = "ANI_AEV"

    @nn.compact
    def __call__(self, inputs):
        """Compute the ANI AEV of every atom.

        Args:
            inputs: dictionary holding ``"species"`` (per-atom values used to
                index a table laid out like ``PERIODIC_TABLE`` -- presumably
                atomic numbers, TODO confirm), the radial graph under
                ``graph_key`` (``"distances"``, ``"switch"``, ``"edge_src"``,
                ``"edge_dst"``) and the angular graph under ``graph_angle_key``
                (``"angles"``, ``"distances"``, ``"switch"``,
                ``"central_atom"``, ``"angle_src"``, ``"angle_dst"``,
                ``"edge_dst"``).

        Returns:
            The ``(n_atoms, n_features)`` AEV array if ``embedding_key`` is
            None, otherwise a shallow copy of ``inputs`` with the array stored
            under ``embedding_key``.
        """
        species = inputs["species"]
        n_atoms = species.shape[0]

        # --- map periodic-table indices to internal species indices -------
        rev_idx = {s: k for k, s in enumerate(PERIODIC_TABLE)}
        maxidx = max(rev_idx.values())
        if isinstance(self.species_order, str):
            species_order = [el.strip() for el in self.species_order.split(",")]
        else:
            species_order = list(self.species_order)
        # Lookup table: periodic-table index -> position in species_order.
        # Elements not listed in species_order keep the value 0, i.e. they are
        # silently treated as the first species. The table has one slot past
        # maxidx (NOTE(review): presumably to absorb padding species values --
        # confirm against the preprocessing code).
        conv_tensor = [0] * (maxidx + 2)
        for i, s in enumerate(species_order):
            conv_tensor[rev_idx[s]] = i
        indices = jnp.asarray(conv_tensor, dtype=jnp.int32)[species]
        num_species = len(species_order)
        num_species_pair = (num_species * (num_species + 1)) // 2

        # --- radial AEV ---------------------------------------------------
        graph = inputs[self.graph_key]
        distances = graph["distances"]
        switch = graph["switch"]
        edge_src = graph["edge_src"]
        edge_dst = graph["edge_dst"]

        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        # Shifts evenly spaced on [radial_start, cutoff), stored negated so
        # that distances + shiftR is the distance to each shift.
        shiftR = jnp.asarray(
            -np.linspace(self.radial_start, cutoff, self.radial_dist_divisions + 1)[
                None, :-1
            ],
            dtype=distances.dtype,
        )
        x2 = self.radial_eta * (distances[:, None] + shiftR) ** 2
        # constant 0.25 prefactor times the per-edge switching value
        radial_terms = jnp.exp(-x2) * (0.25 * switch)[:, None]
        # Each edge contributes to slot (edge_src, species of edge_dst).
        radial_index = edge_src * num_species + indices[edge_dst]
        radial_aev = jax.ops.segment_sum(
            radial_terms, radial_index, num_species * n_atoms
        ).reshape(n_atoms, num_species * radial_terms.shape[-1])

        # --- angular AEV --------------------------------------------------
        graph = inputs[self.graph_angle_key]
        angles = graph["angles"]
        central_atom = graph["central_atom"]
        angle_src, angle_dst = graph["angle_src"], graph["angle_dst"]
        sqrt_eta_a = self.angular_eta**0.5
        # Pre-scale distances by sqrt(eta)/2 so that, together with shiftA
        # below, the gaussian equals exp(-eta * ((r_ij + r_ik)/2 - shift)**2).
        distances = (0.5 * sqrt_eta_a) * graph["distances"]
        d12 = (distances[angle_src] + distances[angle_dst])[:, None]

        angular_cutoff = self._graphs_properties[self.graph_angle_key]["cutoff"]
        # angle shifts pi/(2*sections) + k*pi/sections, k = 0..sections-1 (negated)
        angle_start = np.pi / (2 * self.angle_sections)
        shiftZ = jnp.asarray(
            -(np.linspace(0, np.pi, self.angle_sections + 1) + angle_start)[None, :-1],
            dtype=distances.dtype,
        )
        # distance shifts on [angular_start, angular_cutoff), scaled by -sqrt(eta)
        shiftA = jnp.asarray(
            (-sqrt_eta_a)
            * np.linspace(
                self.angular_start, angular_cutoff, self.angular_dist_divisions + 1
            )[None, :-1],
            dtype=distances.dtype,
        )

        # Each switch is scaled by sqrt(2 * 0.5**zeta), so the product of the
        # two switches of a triplet carries the factor 2 * 0.5**zeta.
        switch = graph["switch"] * (2.0 * (0.5**self.zeta)) ** 0.5
        factor1 = switch[angle_src] * switch[angle_dst]
        factor1 = (
            factor1[:, None] * (1 + jnp.cos(angles[:, None] + shiftZ)) ** self.zeta
        )
        factor2 = jnp.exp(-((d12 + shiftA) ** 2))
        # outer product: distance shifts on the slow axis, angle shifts on the fast one
        angular_terms = (factor1[:, None, :] * factor2[:, :, None]).reshape(
            -1, self.angle_sections * self.angular_dist_divisions
        )

        # Symmetric (species_a, species_b) -> unordered-pair index lookup.
        index_dest = indices[graph["edge_dst"]]
        species1, species2 = np.triu_indices(num_species, 0)
        pair_index = np.arange(species1.shape[0], dtype=np.int32)
        triu_index = np.zeros((num_species, num_species), dtype=np.int32)
        triu_index[species1, species2] = pair_index
        triu_index[species2, species1] = pair_index
        triu_index = jnp.asarray(triu_index, dtype=jnp.int32)
        angular_index = (
            central_atom * num_species_pair
            + triu_index[index_dest[angle_src], index_dest[angle_dst]]
        )
        angular_aev = jax.ops.segment_sum(
            angular_terms, angular_index, num_species_pair * n_atoms
        ).reshape(n_atoms, num_species_pair * angular_terms.shape[-1])

        embedding = jnp.concatenate((radial_aev, angular_aev), axis=-1)
        if self.embedding_key is None:
            return embedding
        return {**inputs, self.embedding_key: embedding}
Computes the Atomic Environment Vector (AEV) for a given molecular system using the ANI model.
FID : ANI_AEV
Reference
J. S. Smith, O. Isayev and A. E. Roitberg, ANI-1: an extensible neural network potential with DFT accuracy at force field computational cost, Chem. Sci., 2017, 8, 3192
The key to use for the output embedding in the returned dictionary. If None, the embedding array itself is returned instead of a dictionary.
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links, which can lead to accidental OOMs in eager mode due to slow garbage collection as well as to spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying Python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator to ensure that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.