fennol.models.embeddings.raster
import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Sequence, Dict, Union, ClassVar, Optional
import numpy as np
from ...utils.periodic_table import D3_COV_RADII
import dataclasses
from ..misc.encodings import RadialBasis, SpeciesEncoding
from ...utils.spherical_harmonics import generate_spherical_harmonics
from ..misc.e3 import ChannelMixing
from ..misc.nets import FullyConnectedNet, BlockIndexNet
from ...utils.activations import activation_from_str
from ...utils import AtomicUnits as au


class RaSTER(nn.Module):
    """Range-Separated Transformer with Equivariant Representations

    FID : RASTER
    """

    _graphs_properties: Dict
    dim: int = 176
    """The dimension of the output embedding."""
    nlayers: int = 2
    """The number of message-passing layers."""
    att_dim: int = 16
    """The dimension of the attention heads."""
    scal_heads: int = 16
    """The number of scalar attention heads."""
    tens_heads: int = 4
    """The number of tensor attention heads."""
    lmax: int = 3
    """The maximum angular momentum to consider."""
    normalize_vec: bool = True
    """Whether to normalize the vector features before computing spherical harmonics."""
    att_activation: str = "identity"
    """The activation function to use for the attention coefficients."""
    activation: str = "swish"
    """The activation function to use for the update network."""
    update_hidden: Sequence[int] = ()
    """The hidden layers for the update network."""
    update_bias: bool = True
    """Whether to use bias in the update network."""
    positional_activation: str = "swish"
    """The activation function to use for the positional embedding network."""
    positional_bias: bool = True
    """Whether to use bias in the positional embedding network."""
    switch_before_net: bool = False
    """Whether to apply the switch function to the radial basis before the edge neural network."""
    ignore_parity: bool = False
    """Whether to ignore the parity of the spherical harmonics when constructing the relative positional encoding."""
    additive_positional: bool = False
    """Whether to use additive relative positional encoding. If False, multiplicative relative positional encoding is used."""
    edge_value: bool = False
    """Whether to use edge values in the attention mechanism."""
    layer_normalization: bool = True
    """Whether to use layer normalization of atomic embeddings."""
    graph_key: str = "graph"
    """The key in the input dictionary that corresponds to the radial graph."""
    embedding_key: str = "embedding"
    """The key in the output dictionary that corresponds to the embedding."""
    radial_basis: dict = dataclasses.field(
        default_factory=lambda: {"start": 0.8, "basis": "gaussian", "dim": 16}
    )
    """The dictionary of parameters for radial basis functions. See `fennol.models.misc.encodings.RadialBasis`."""
    species_encoding: str | dict = dataclasses.field(
        default_factory=lambda: {"dim": 16, "trainable": True, "encoding": "random"}
    )
    """The dictionary of parameters for species encoding. See `fennol.models.misc.encodings.SpeciesEncoding`."""
    graph_lode: Optional[str] = None
    """The key in the input dictionary that corresponds to the long-range graph."""
    lmax_lode: int = 0
    """The maximum angular momentum for the long-range features."""
    lode_rshort: Optional[float] = None
    """The short-range cutoff for the long-range features."""
    lode_dshort: float = 2.0
    """The width of the short-range cutoff for the long-range features."""
    lode_extra_powers: Sequence[int] = ()
    """The extra powers to include in the long-range features."""
    a_lode: float = -1.0
    """The damping parameter for the long-range features. If negative, the damping is trainable with initial value abs(a_lode)."""
    block_index_key: Optional[str] = None
    """The key in the input dictionary that corresponds to the block index for the MoE network. If None, a normal neural network is used."""
    lode_channels: int = 1
    """The number of channels for the long-range features."""
    switch_cov_start: float = 0.5
    """The start of the close-range covalent switch (in units of covalent radii)."""
    switch_cov_end: float = 0.6
    """The end of the close-range covalent switch (in units of covalent radii)."""
    normalize_keys: bool = False
    """Whether to normalize queries and keys in the attention mechanism."""
    keep_all_layers: bool = False
    """Whether to return the stacked scalar embeddings from all message-passing layers."""

    FID: ClassVar[str] = "RASTER"

    @nn.compact
    def __call__(self, inputs):
        species = inputs["species"]

        ## SETUP LAYER NORMALIZATION
        def _layer_norm(x):
            mu = jnp.mean(x, axis=-1, keepdims=True)
            dx = x - mu
            var = jnp.mean(dx**2, axis=-1, keepdims=True)
            sig = (1.0e-6 + var) ** (-0.5)
            return dx * sig

        if self.layer_normalization:
            layer_norm = _layer_norm
        else:
            layer_norm = lambda x: x

        if self.normalize_keys:
            ln_qk = _layer_norm
        else:
            ln_qk = lambda x: x

        ## SPECIES ENCODING
        if isinstance(self.species_encoding, str):
            Zi = inputs[self.species_encoding]
        else:
            Zi = SpeciesEncoding(**self.species_encoding)(species)

        ## INITIALIZE SCALAR FEATURES
        xi = layer_norm(nn.Dense(self.dim, use_bias=False, name="species_linear")(Zi))

        # RADIAL GRAPH
        graph = inputs[self.graph_key]
        distances = graph["distances"]
        switch = graph["switch"]
        edge_src = graph["edge_src"]
        edge_dst = graph["edge_dst"]
        vec = (
            graph["vec"] / graph["distances"][:, None]
            if self.normalize_vec
            else graph["vec"]
        )
        ## CLOSE-RANGE SWITCH
        use_switch_cov = False
        if self.switch_cov_end > 0 and self.switch_cov_start > 0:
            use_switch_cov = True
            assert (
                self.switch_cov_start < self.switch_cov_end
            ), f"switch_cov_start {self.switch_cov_start} must be smaller than switch_cov_end {self.switch_cov_end}"
            assert (
                self.switch_cov_start > 0 and self.switch_cov_end < 1
            ), f"switch_cov_start {self.switch_cov_start} and switch_cov_end {self.switch_cov_end} must be between 0 and 1"
            rc = jnp.array(D3_COV_RADII * au.BOHR)[species]
            rcij = rc[edge_src] + rc[edge_dst]
            rstart = rcij * self.switch_cov_start
            rend = rcij * self.switch_cov_end
            switch_short = (distances >= rend) + 0.5 * (
                1 - jnp.cos(jnp.pi * (distances - rstart) / (rend - rstart))
            ) * (distances > rstart) * (distances < rend)
            switch = switch * switch_short

        ## COMPUTE SPHERICAL HARMONICS ON EDGES
        Yij = generate_spherical_harmonics(lmax=self.lmax, normalize=False)(vec)[
            :, None, :
        ]
        nrep = np.array([2 * l + 1 for l in range(self.lmax + 1)])
        ls = np.arange(self.lmax + 1).repeat(nrep)

        parity = jnp.array((-1) ** ls[None, None, :])
        if self.ignore_parity:
            parity = -jnp.ones_like(parity)

        ## INITIALIZE TENSOR FEATURES
        Vi = 0.0  # jnp.zeros((Zi.shape[0], self.tens_heads, Yij.shape[1]))

        # RADIAL BASIS
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        radial_terms = RadialBasis(
            **{
                **self.radial_basis,
                "end": cutoff,
                "name": "RadialBasis",
            }
        )(distances)
        if self.switch_before_net:
            radial_terms = radial_terms * switch[:, None]
        elif use_switch_cov:
            radial_terms = radial_terms * switch_short[:, None]

        ## INITIALIZE LODE
        do_lode = self.graph_lode is not None
        if do_lode:
            ## LONG-RANGE GRAPH
            graph_lode = inputs[self.graph_lode]
            switch_lode = graph_lode["switch"][:, None]
            edge_src_lr, edge_dst_lr = graph_lode["edge_src"], graph_lode["edge_dst"]
            r = graph_lode["distances"][:, None]
            rc = self._graphs_properties[self.graph_lode]["cutoff"]

            lmax_lr = self.lmax_lode
            equivariant_lode = lmax_lr > 0
            assert lmax_lr >= 0, f"lmax_lode must be >= 0, got {lmax_lr}"
            assert (
                lmax_lr <= self.lmax
            ), f"lmax_lode must be <= lmax for multipole interaction, got {lmax_lr} > {self.lmax}"
            nrep_lr = np.array([2 * l + 1 for l in range(lmax_lr + 1)], dtype=np.int32)
            if equivariant_lode:
                ls_lr = np.arange(lmax_lr + 1)
            else:
                ls_lr = np.array([0])

            ## PARAMETERS FOR THE LR RADIAL BASIS
            nextra_powers = len(self.lode_extra_powers)
            if nextra_powers > 0:
                ls_lr = np.concatenate([self.lode_extra_powers, ls_lr])

            if self.a_lode > 0:
                a = self.a_lode**2
            else:
                a = (
                    self.param(
                        "a_lr",
                        lambda key: jnp.asarray([-self.a_lode] * ls_lr.shape[0])[
                            None, :
                        ],
                    )
                    ** 2
                )
            rc2a = rc**2 + a
            ls_lr = 0.5 * (ls_lr[None, :] + 1)
            ### minimal radial basis for long range (damped coulomb)
            eij_lr = (
                1.0 / (r**2 + a) ** ls_lr
                - 1.0 / rc2a**ls_lr
                + (r - rc) * rc * (2 * ls_lr) / rc2a ** (ls_lr + 1)
            ) * switch_lode

            if self.lode_rshort is not None:
                rs = self.lode_rshort
                d = self.lode_dshort
                switch_short = 0.5 * (1 - jnp.cos(jnp.pi * (r - rs) / d)) * (r > rs) * (
                    r < rs + d
                ) + (r >= rs + d)
                eij_lr = eij_lr * switch_short

            dim_lr = 1
            if nextra_powers > 0:
                eij_lr_extra = eij_lr[:, :nextra_powers]
                eij_lr = eij_lr[:, nextra_powers:]
                dim_lr += nextra_powers

            if equivariant_lode:
                ## SPHERICAL HARMONICS ON LONG-RANGE GRAPH
                eij_lr = eij_lr.repeat(nrep_lr, axis=-1)
                Yij_lr = generate_spherical_harmonics(lmax=lmax_lr, normalize=False)(
                    graph_lode["vec"] / r
                )
                dim_lr += lmax_lr
                eij_lr = eij_lr * Yij_lr
                del Yij_lr

        if self.keep_all_layers:
            fis = []

        ### START MESSAGE PASSING ITERATIONS
        for layer in range(self.nlayers):
            ## GATHER SCALAR EDGE FEATURES
            u = [radial_terms]
            if layer > 0:
                ## edge-tensor contraction
                xij2 = (Vi[edge_dst] + (parity * Vi)[edge_src]) * Yij
                for l in range(self.lmax + 1):
                    u.append((xij2[:, :, l**2 : (l + 1) ** 2]).sum(axis=-1))
            ur = jnp.concatenate(u, axis=-1)

            ## BUILD RELATIVE POSITIONAL ENCODING
            if self.edge_value:
                nout = 2
            else:
                nout = 1
            w = FullyConnectedNet(
                [2 * self.att_dim, nout * self.att_dim],
                activation=self.positional_activation,
                use_bias=self.positional_bias,
                name=f"positional_encoding_{layer}",
            )(ur).reshape(radial_terms.shape[0], nout, self.att_dim)
            if self.edge_value:
                w, vij = jnp.split(w, 2, axis=1)

            nls = self.lmax + 1 if layer == 0 else 2 * (self.lmax + 1)

            ## QUERY, KEY, VALUE
            q = ln_qk(
                nn.Dense(
                    (self.scal_heads + nls * self.tens_heads) * self.att_dim,
                    use_bias=False,
                    name=f"queries_{layer}",
                )(xi).reshape(
                    xi.shape[0], self.scal_heads + nls * self.tens_heads, self.att_dim
                )
            )
            k = nn.Dense(
                (self.scal_heads + nls * self.tens_heads) * self.att_dim,
                use_bias=False,
                name=f"keys_{layer}",
            )(xi).reshape(
                xi.shape[0], self.scal_heads + nls * self.tens_heads, self.att_dim
            )

            v = nn.Dense(
                self.scal_heads * self.att_dim, use_bias=False, name=f"values_{layer}"
            )(xi).reshape(xi.shape[0], self.scal_heads, self.att_dim)

            ## ATTENTION COEFFICIENTS
            if self.additive_positional:
                wk = ln_qk(w + k[edge_dst])
            else:
                wk = ln_qk(w * k[edge_dst])

            act = activation_from_str(self.att_activation)
            aij = (
                act((q[edge_src] * wk).sum(axis=-1) / (self.att_dim**0.5))
                * switch[:, None]
            )

            aijl = (
                aij[:, : self.tens_heads * (self.lmax + 1)]
                .reshape(-1, self.tens_heads, self.lmax + 1)
                .repeat(nrep, axis=-1)
            )
            if layer > 0:
                aijl1 = (
                    aij[:, self.tens_heads * (self.lmax + 1) : self.tens_heads * nls]
                    .reshape(-1, self.tens_heads, self.lmax + 1)
                    .repeat(nrep, axis=-1)
                )
            aij = aij[:, self.tens_heads * nls :, None]

            if self.edge_value:
                ## EDGE VALUES
                if self.additive_positional:
                    vij = vij + v[edge_dst]
                else:
                    vij = vij * v[edge_dst]
            else:
                ## MOVE DEST VALUES TO EDGE
                vij = v[edge_dst]

            ## SCALAR ATTENDED FEATURES
            vai = jax.ops.segment_sum(
                aij * vij,
                edge_src,
                num_segments=xi.shape[0],
            )
            vai = vai.reshape(xi.shape[0], -1)

            ### TENSOR ATTENDED FEATURES
            uij = aijl * Yij
            if layer > 0:
                uij = uij + aijl1 * Vi[edge_dst]
            Vi = Vi + jax.ops.segment_sum(uij, edge_src, num_segments=Zi.shape[0])

            ## SELF SCALAR FEATURES
            si = nn.Dense(self.att_dim, use_bias=False, name=f"self_values_{layer}")(xi)

            components = [si, vai]

            ### CONTRACT TENSOR FEATURES TO BUILD INVARIANTS
            if self.tens_heads == 1:
                Vi2 = Vi**2
            else:
                Vi2 = Vi * ChannelMixing(
                    self.lmax, self.tens_heads, name=f"extract_mixing_{layer}"
                )(Vi)
            for l in range(self.lmax + 1):
                norm = 1.0 / (2 * l + 1)
                components.append((Vi2[:, :, l**2 : (l + 1) ** 2]).sum(axis=-1) * norm)

            ### LODE (~ LONG-RANGE ATTENTION)
            if do_lode and layer == self.nlayers - 1:
                assert self.lode_channels <= self.tens_heads
                zj = nn.Dense(
                    self.lode_channels * dim_lr,
                    use_bias=False,
                    name=f"lode_values_{layer}",
                )(xi).reshape(xi.shape[0], self.lode_channels, dim_lr)
                if nextra_powers > 0:
                    zj_extra = zj[:, :, :nextra_powers]
                    zj = zj[:, :, nextra_powers:]
                    xi_lr_extra = jax.ops.segment_sum(
                        eij_lr_extra[:, None, :] * zj_extra[edge_dst_lr],
                        edge_src_lr,
                        species.shape[0],
                    ).reshape(species.shape[0], -1)
                    components.append(xi_lr_extra)
                if equivariant_lode:
                    zj = zj.repeat(nrep_lr, axis=-1)
                Vi_lr = jax.ops.segment_sum(
                    eij_lr[:, None, :] * zj[edge_dst_lr], edge_src_lr, species.shape[0]
                )
                components.append(Vi_lr[:, :, 0])
                if equivariant_lode:
                    Mi_lr = Vi[:, : self.lode_channels, : (lmax_lr + 1) ** 2] * Vi_lr
                    for l in range(1, lmax_lr + 1):
                        norm = 1.0 / (2 * l + 1)
                        components.append(
                            Mi_lr[:, :, l**2 : (l + 1) ** 2].sum(axis=-1) * norm
                        )

            ### CONCATENATE UPDATE COMPONENTS
            components = jnp.concatenate(components, axis=-1)
            ### COMPUTE UPDATE
            if self.block_index_key is not None:
                ## MoE neural network from block index
                block_index = inputs[self.block_index_key]
                updi = BlockIndexNet(
                    output_dim=self.dim + self.tens_heads * (self.lmax + 1),
                    hidden_neurons=self.update_hidden,
                    activation=self.activation,
                    use_bias=self.update_bias,
                    name=f"update_net_{layer}",
                )((species, components, block_index))
            else:
                updi = FullyConnectedNet(
                    [*self.update_hidden, self.dim + self.tens_heads * (self.lmax + 1)],
                    activation=self.activation,
                    use_bias=self.update_bias,
                    name=f"update_net_{layer}",
                )(components)

            ## UPDATE ATOM FEATURES
            xi = layer_norm(xi + updi[:, : self.dim])
            Vi = Vi * (1 + updi[:, self.dim :]).reshape(
                -1, self.tens_heads, self.lmax + 1
            ).repeat(nrep, axis=-1)
            if self.tens_heads > 1:
                Vi = ChannelMixing(
                    self.lmax, self.tens_heads, name=f"update_mixing_{layer}"
                )(Vi)

            if self.keep_all_layers:
                ## STORE ALL LAYERS
                fis.append(xi)

        output = {**inputs, self.embedding_key: xi, self.embedding_key + "_tensor": Vi}
        if self.keep_all_layers:
            output[self.embedding_key + "_layers"] = jnp.stack(fis, axis=1)
        return output
Range-Separated Transformer with Equivariant Representations
FID : RASTER
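Below is a minimal usage sketch, not taken from the FeNNol documentation: the toy geometry, hand-built edge list, and cosine switch are illustrative stand-ins for what FeNNol's graph preprocessing normally supplies, and `_graphs_properties` is ordinarily filled in by the model builder rather than by hand. It shows the input keys that `__call__` reads (`species` plus a radial graph with `distances`, `vec`, `switch`, `edge_src`, `edge_dst`) and the shapes of the returned embeddings.

import jax
import jax.numpy as jnp
from fennol.models.embeddings.raster import RaSTER

# Water-like toy system: atomic numbers for O, H, H.
species = jnp.array([8, 1, 1])
positions = jnp.array([[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]])

# Directed edge list (each pair appears in both directions).
edge_src = jnp.array([0, 0, 1, 1, 2, 2])
edge_dst = jnp.array([1, 2, 0, 2, 0, 1])
vec = positions[edge_dst] - positions[edge_src]
distances = jnp.linalg.norm(vec, axis=-1)

# Illustrative smooth cutoff; FeNNol's graph modules provide their own switch.
cutoff = 5.0
switch = 0.5 * (1.0 + jnp.cos(jnp.pi * distances / cutoff)) * (distances < cutoff)

inputs = {
    "species": species,
    "graph": {
        "distances": distances,
        "vec": vec,
        "switch": switch,
        "edge_src": edge_src,
        "edge_dst": edge_dst,
    },
}

model = RaSTER(_graphs_properties={"graph": {"cutoff": cutoff}})
params = model.init(jax.random.PRNGKey(0), inputs)
output = model.apply(params, inputs)
print(output["embedding"].shape)         # (natoms, dim) = (3, 176)
print(output["embedding_tensor"].shape)  # (natoms, tens_heads, (lmax+1)**2) = (3, 4, 16)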
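The close-range covalent switch in `__call__` masks out unphysically short edges: it is zero below `switch_cov_start` times the pairwise sum of covalent radii, one above `switch_cov_end` times that sum, with a raised-cosine ramp in between, and it multiplies the graph's regular switch. Here is a standalone sketch of the same formula (a hypothetical helper; the O+H radii sum is an illustrative number, whereas the real code looks the radii up in `D3_COV_RADII`):

import jax.numpy as jnp

def close_range_switch(d, rcij, start=0.5, end=0.6):
    # Mirrors switch_short in RaSTER.__call__: 0 below start*rcij,
    # 1 above end*rcij, raised-cosine ramp in between.
    rstart, rend = start * rcij, end * rcij
    ramp = 0.5 * (1.0 - jnp.cos(jnp.pi * (d - rstart) / (rend - rstart)))
    return (d >= rend) + ramp * (d > rstart) * (d < rend)

rcij_oh = 0.95  # illustrative O+H covalent radii sum (Angstrom)
print(close_range_switch(jnp.array([0.30, 0.52, 0.70]), rcij_oh))
# -> approximately [0., 0.459, 1.]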
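The long-range (LODE) features are built on the "minimal radial basis for long range (damped coulomb)" defined in `__call__`: for each effective power p = (l + 1)/2 it evaluates 1/(r^2 + a)^p, shifted by a constant and a linear term so that both the value and the first derivative vanish at the long-range cutoff rc. A quick numerical check of that property, my own sketch with illustrative values a = 1, l = 0, rc = 10:

import jax

def lr_basis(r, a=1.0, l=0, rc=10.0):
    # Same expression as eij_lr in RaSTER.__call__ (before the switch factors).
    p = 0.5 * (l + 1)
    rc2a = rc**2 + a
    return (
        1.0 / (r**2 + a) ** p
        - 1.0 / rc2a**p
        + (r - rc) * rc * (2 * p) / rc2a ** (p + 1)
    )

print(lr_basis(10.0))            # 0.0: the basis vanishes at the cutoff...
print(jax.grad(lr_basis)(10.0))  # 0.0: ...and so does its first derivative
print(lr_basis(2.0))             # ~0.269: finite, Coulomb-like at short range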