fennol.models.embeddings.raster
import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Sequence, Dict, Union, ClassVar, Optional
import numpy as np
from ...utils.periodic_table import D3_COV_RADII
import dataclasses
from ..misc.encodings import RadialBasis, SpeciesEncoding
from ...utils.spherical_harmonics import generate_spherical_harmonics
from ..misc.e3 import ChannelMixing
from ..misc.nets import FullyConnectedNet, BlockIndexNet
from ...utils.activations import activation_from_str
from ...utils.atomic_units import au
from ...utils.initializers import initializer_from_str


class RaSTER(nn.Module):
    """Range-Separated Transformer with Equivariant Representations

    FID : RASTER

    """

    _graphs_properties: Dict
    dim: int = 176
    """The dimension of the output embedding."""
    nlayers: int = 2
    """The number of message-passing layers."""
    att_dim: int = 16
    """The dimension of the attention heads."""
    scal_heads: int = 16
    """The number of scalar attention heads."""
    tens_heads: int = 4
    """The number of tensor attention heads."""
    lmax: int = 3
    """The maximum angular momentum to consider."""
    normalize_vec: bool = True
    """Whether to normalize the vector features before computing spherical harmonics."""
    att_activation: str = "identity"
    """The activation function to use for the attention coefficients."""
    activation: str = "swish"
    """The activation function to use for the update network."""
    update_hidden: Sequence[int] = ()
    """The hidden layers for the update network."""
    update_bias: bool = True
    """Whether to use bias in the update network."""
    positional_activation: str = "swish"
    """The activation function to use for the positional embedding network."""
    positional_bias: bool = True
    """Whether to use bias in the positional embedding network."""
    switch_before_net: bool = False
    """Whether to apply the switch function to the radial basis before the edge neural network."""
    ignore_parity: bool = False
    """Whether to ignore the parity of the spherical harmonics when constructing the relative positional encoding."""
    additive_positional: bool = False
    """Whether to use additive relative positional encoding. If False, multiplicative relative positional encoding is used."""
    edge_value: bool = False
    """Whether to use edge values in the attention mechanism."""
    layer_normalization: bool = True
    """Whether to use layer normalization of atomic embeddings."""
    graph_key: str = "graph"
    """The key in the input dictionary that corresponds to the radial graph."""
    embedding_key: str = "embedding"
    """The key in the output dictionary that corresponds to the embedding."""
    radial_basis: dict = dataclasses.field(
        default_factory=lambda: {"start": 0.8, "basis": "gaussian", "dim": 16}
    )
    """The dictionary of parameters for radial basis functions. See `fennol.models.misc.encodings.RadialBasis`."""
    species_encoding: str | dict = dataclasses.field(
        default_factory=lambda: {"dim": 16, "trainable": True, "encoding": "random"}
    )
    """The dictionary of parameters for species encoding. See `fennol.models.misc.encodings.SpeciesEncoding`."""
    graph_lode: Optional[str] = None
    """The key in the input dictionary that corresponds to the long-range graph."""
    lmax_lode: int = 0
    """The maximum angular momentum for the long-range features."""
    lode_rshort: Optional[float] = None
    """The short-range cutoff for the long-range features."""
    lode_dshort: float = 2.0
    """The width of the short-range cutoff for the long-range features."""
    lode_extra_powers: Sequence[int] = ()
    """The extra powers to include in the long-range features."""
    a_lode: float = -1.0
    """The damping parameter for the long-range features. If negative, the damping is trainable with initial value abs(a_lode)."""
    block_index_key: Optional[str] = None
    """The key in the input dictionary that corresponds to the block index for the MoE network. If None, a normal neural network is used."""
    lode_channels: int = 1
    """The number of channels for the long-range features."""
    switch_cov_start: float = 0.5
    """The start of the close-range covalent switch (in units of covalent radii)."""
    switch_cov_end: float = 0.6
    """The end of the close-range covalent switch (in units of covalent radii)."""
    normalize_keys: bool = False
    """Whether to normalize queries and keys in the attention mechanism."""
    normalize_components: bool = False
    """Whether to normalize the components before the update network."""
    keep_all_layers: bool = False
    """Whether to return the stacked scalar embeddings from all message-passing layers."""
    kernel_init: Optional[str] = None

    FID: ClassVar[str] = "RASTER"

    @nn.compact
    def __call__(self, inputs):
        species = inputs["species"]

        ## SETUP LAYER NORMALIZATION
        def _layer_norm(x):
            mu = jnp.mean(x, axis=-1, keepdims=True)
            dx = x - mu
            var = jnp.mean(dx**2, axis=-1, keepdims=True)
            sig = (1.0e-6 + var) ** (-0.5)
            return dx * sig

        if self.layer_normalization:
            layer_norm = _layer_norm
        else:
            layer_norm = lambda x: x

        if self.normalize_keys:
            ln_qk = _layer_norm
        else:
            ln_qk = lambda x: x

        kernel_init = initializer_from_str(self.kernel_init)

        ## SPECIES ENCODING
        if isinstance(self.species_encoding, str):
            Zi = inputs[self.species_encoding]
        else:
            Zi = SpeciesEncoding(**self.species_encoding)(species)

        ## INITIALIZE SCALAR FEATURES
        xi = layer_norm(
            nn.Dense(
                self.dim, use_bias=False, name="species_linear", kernel_init=kernel_init
            )(Zi)
        )

        # RADIAL GRAPH
        graph = inputs[self.graph_key]
        distances = graph["distances"]
        switch = graph["switch"]
        edge_src = graph["edge_src"]
        edge_dst = graph["edge_dst"]
        vec = (
            graph["vec"] / graph["distances"][:, None]
            if self.normalize_vec
            else graph["vec"]
        )

        ## CLOSE-RANGE SWITCH
        use_switch_cov = False
        if self.switch_cov_end > 0 and self.switch_cov_start > 0:
            use_switch_cov = True
            assert (
                self.switch_cov_start < self.switch_cov_end
            ), f"switch_cov_start {self.switch_cov_start} must be smaller than switch_cov_end {self.switch_cov_end}"
            assert (
                self.switch_cov_start > 0 and self.switch_cov_end < 1
            ), f"switch_cov_start {self.switch_cov_start} and switch_cov_end {self.switch_cov_end} must be between 0 and 1"
            rc = jnp.array(D3_COV_RADII * au.ANG)[species]
            rcij = rc[edge_src] + rc[edge_dst]
            rstart = rcij * self.switch_cov_start
            rend = rcij * self.switch_cov_end
            switch_short = (distances >= rend) + 0.5 * (
                1 - jnp.cos(jnp.pi * (distances - rstart) / (rend - rstart))
            ) * (distances > rstart) * (distances < rend)
            switch = switch * switch_short

        ## COMPUTE SPHERICAL HARMONICS ON EDGES
        Yij = generate_spherical_harmonics(lmax=self.lmax, normalize=False)(vec)[
            :, None, :
        ]
        nrep = np.array([2 * l + 1 for l in range(self.lmax + 1)])
        ls = np.arange(self.lmax + 1).repeat(nrep)

        parity = jnp.array((-1) ** ls[None, None, :])
        if self.ignore_parity:
            parity = -jnp.ones_like(parity)

        ## INITIALIZE TENSOR FEATURES
        Vi = 0.0  # jnp.zeros((Zi.shape[0], self.tens_heads, Yij.shape[1]))

        # RADIAL BASIS
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        radial_terms = RadialBasis(
            **{
                **self.radial_basis,
                "end": cutoff,
                "name": "RadialBasis",
            }
        )(distances)
        if self.switch_before_net:
            radial_terms = radial_terms * switch[:, None]
        elif use_switch_cov:
            radial_terms = radial_terms * switch_short[:, None]

        ## INITIALIZE LODE
        do_lode = self.graph_lode is not None
        if do_lode:
            ## LONG-RANGE GRAPH
            graph_lode = inputs[self.graph_lode]
            switch_lode = graph_lode["switch"][:, None]
            edge_src_lr, edge_dst_lr = graph_lode["edge_src"], graph_lode["edge_dst"]
            r = graph_lode["distances"][:, None]
            rc = self._graphs_properties[self.graph_lode]["cutoff"]

            lmax_lr = self.lmax_lode
            equivariant_lode = lmax_lr > 0
            assert lmax_lr >= 0, f"lmax_lode must be >= 0, got {lmax_lr}"
            assert (
                lmax_lr <= self.lmax
            ), f"lmax_lode must be <= lmax for multipole interaction, got {lmax_lr} > {self.lmax}"
            nrep_lr = np.array([2 * l + 1 for l in range(lmax_lr + 1)], dtype=np.int32)
            if equivariant_lode:
                ls_lr = np.arange(lmax_lr + 1)
            else:
                ls_lr = np.array([0])

            ## PARAMETERS FOR THE LR RADIAL BASIS
            nextra_powers = len(self.lode_extra_powers)
            if nextra_powers > 0:
                ls_lr = np.concatenate([self.lode_extra_powers, ls_lr])

            if self.a_lode > 0:
                a = self.a_lode**2
            else:
                a = (
                    self.param(
                        "a_lr",
                        lambda key: jnp.asarray([-self.a_lode] * ls_lr.shape[0])[
                            None, :
                        ],
                    )
                    ** 2
                )
            rc2a = rc**2 + a
            ls_lr = 0.5 * (ls_lr[None, :] + 1)
            ### minimal radial basis for long range (damped coulomb)
            eij_lr = (
                1.0 / (r**2 + a) ** ls_lr
                - 1.0 / rc2a**ls_lr
                + (r - rc) * rc * (2 * ls_lr) / rc2a ** (ls_lr + 1)
            ) * switch_lode

            if self.lode_rshort is not None:
                rs = self.lode_rshort
                d = self.lode_dshort
                switch_short = 0.5 * (1 - jnp.cos(jnp.pi * (r - rs) / d)) * (r > rs) * (
                    r < rs + d
                ) + (r >= rs + d)
                eij_lr = eij_lr * switch_short

            dim_lr = 1
            if nextra_powers > 0:
                eij_lr_extra = eij_lr[:, :nextra_powers]
                eij_lr = eij_lr[:, nextra_powers:]
                dim_lr += nextra_powers

            if equivariant_lode:
                ## SPHERICAL HARMONICS ON LONG-RANGE GRAPH
                eij_lr = eij_lr.repeat(nrep_lr, axis=-1)
                Yij_lr = generate_spherical_harmonics(lmax=lmax_lr, normalize=False)(
                    graph_lode["vec"] / r
                )
                dim_lr += lmax_lr
                eij_lr = eij_lr * Yij_lr
                del Yij_lr

        if self.keep_all_layers:
            xis = []

        ### START MESSAGE PASSING ITERATIONS
        for layer in range(self.nlayers):
            ## GATHER SCALAR EDGE FEATURES
            u = [radial_terms]
            if layer > 0:
                ## edge-tensor contraction
                xij2 = (Vi[edge_dst] + (parity * Vi)[edge_src]) * Yij
                for l in range(self.lmax + 1):
                    u.append((xij2[:, :, l**2 : (l + 1) ** 2]).sum(axis=-1))
            ur = jnp.concatenate(u, axis=-1)

            ## BUILD RELATIVE POSITIONAL ENCODING
            if self.edge_value:
                nout = 2
            else:
                nout = 1
            w = FullyConnectedNet(
                [2 * self.att_dim, nout * self.att_dim],
                activation=self.positional_activation,
                use_bias=self.positional_bias,
                name=f"positional_encoding_{layer}",
            )(ur).reshape(radial_terms.shape[0], nout, self.att_dim)
            if self.edge_value:
                w, vij = jnp.split(w, 2, axis=1)

            nls = self.lmax + 1 if layer == 0 else 2 * (self.lmax + 1)

            ## QUERY, KEY, VALUE
            q = ln_qk(
                nn.Dense(
                    (self.scal_heads + nls * self.tens_heads) * self.att_dim,
                    use_bias=False,
                    name=f"queries_{layer}",
                    kernel_init=kernel_init,
                )(xi).reshape(
                    xi.shape[0], self.scal_heads + nls * self.tens_heads, self.att_dim
                )
            )
            k = nn.Dense(
                (self.scal_heads + nls * self.tens_heads) * self.att_dim,
                use_bias=False,
                name=f"keys_{layer}",
                kernel_init=kernel_init,
            )(xi).reshape(
                xi.shape[0], self.scal_heads + nls * self.tens_heads, self.att_dim
            )
            v = nn.Dense(
                self.scal_heads * self.att_dim,
                use_bias=False,
                name=f"values_{layer}",
                kernel_init=kernel_init,
            )(xi).reshape(xi.shape[0], self.scal_heads, self.att_dim)

            ## ATTENTION COEFFICIENTS
            if self.additive_positional:
                wk = ln_qk(w + k[edge_dst])
            else:
                wk = ln_qk(w * k[edge_dst])

            act = activation_from_str(self.att_activation)
            aij = (
                act((q[edge_src] * wk).sum(axis=-1) / (self.att_dim**0.5))
                * switch[:, None]
            )

            aijl = (
                aij[:, : self.tens_heads * (self.lmax + 1)]
                .reshape(-1, self.tens_heads, self.lmax + 1)
                .repeat(nrep, axis=-1)
            )
            if layer > 0:
                aijl1 = (
                    aij[:, self.tens_heads * (self.lmax + 1) : self.tens_heads * nls]
                    .reshape(-1, self.tens_heads, self.lmax + 1)
                    .repeat(nrep, axis=-1)
                )
            aij = aij[:, self.tens_heads * nls :, None]

            if self.edge_value:
                ## EDGE VALUES
                if self.additive_positional:
                    vij = vij + v[edge_dst]
                else:
                    vij = vij * v[edge_dst]
            else:
                ## MOVE DEST VALUES TO EDGE
                vij = v[edge_dst]

            ## SCALAR ATTENDED FEATURES
            vai = jax.ops.segment_sum(
                aij * vij,
                edge_src,
                num_segments=xi.shape[0],
            )
            vai = vai.reshape(xi.shape[0], -1)

            ### TENSOR ATTENDED FEATURES
            uij = aijl * Yij
            if layer > 0:
                uij = uij + aijl1 * Vi[edge_dst]
            Vi = Vi + jax.ops.segment_sum(uij, edge_src, num_segments=Zi.shape[0])

            ## SELF SCALAR FEATURES
            si = nn.Dense(
                self.att_dim,
                use_bias=False,
                name=f"self_values_{layer}",
                kernel_init=kernel_init,
            )(xi)

            components = [si, vai]

            ### CONTRACT TENSOR FEATURES TO BUILD INVARIANTS
            if self.tens_heads == 1:
                Vi2 = Vi**2
            else:
                Vi2 = Vi * ChannelMixing(
                    self.lmax, self.tens_heads, name=f"extract_mixing_{layer}"
                )(Vi)
            for l in range(self.lmax + 1):
                norm = 1.0 / (2 * l + 1)
                components.append(
                    (Vi2[:, :, l**2 : (l + 1) ** 2]).sum(axis=-1) * norm
                )

            ### LODE (~ LONG-RANGE ATTENTION)
            if do_lode and layer == self.nlayers - 1:
                assert self.lode_channels <= self.tens_heads
                zj = nn.Dense(
                    self.lode_channels * dim_lr,
                    use_bias=False,
                    name=f"lode_values_{layer}",
                    kernel_init=kernel_init,
                )(xi).reshape(xi.shape[0], self.lode_channels, dim_lr)
                if nextra_powers > 0:
                    zj_extra = zj[:, :, :nextra_powers]
                    zj = zj[:, :, nextra_powers:]
                    xi_lr_extra = jax.ops.segment_sum(
                        eij_lr_extra[:, None, :] * zj_extra[edge_dst_lr],
                        edge_src_lr,
                        species.shape[0],
                    ).reshape(species.shape[0], -1)
                    components.append(xi_lr_extra)
                if equivariant_lode:
                    zj = zj.repeat(nrep_lr, axis=-1)
                Vi_lr = jax.ops.segment_sum(
                    eij_lr[:, None, :] * zj[edge_dst_lr], edge_src_lr, species.shape[0]
                )
                components.append(Vi_lr[:, :, 0])
                if equivariant_lode:
                    Mi_lr = Vi[:, : self.lode_channels, : (lmax_lr + 1) ** 2] * Vi_lr
                    for l in range(1, lmax_lr + 1):
                        norm = 1.0 / (2 * l + 1)
                        components.append(
                            Mi_lr[:, :, l**2 : (l + 1) ** 2].sum(axis=-1) * norm
                        )

            ### CONCATENATE UPDATE COMPONENTS
            components = jnp.concatenate(components, axis=-1)
            if self.normalize_components:
                components = _layer_norm(components)

            ### COMPUTE UPDATE
            if self.block_index_key is not None:
                ## MoE neural network from block index
                block_index = inputs[self.block_index_key]
                updi = BlockIndexNet(
                    output_dim=self.dim + self.tens_heads * (self.lmax + 1),
                    hidden_neurons=self.update_hidden,
                    activation=self.activation,
                    use_bias=self.update_bias,
                    name=f"update_net_{layer}",
                    kernel_init=kernel_init,
                )((species, components, block_index))
            else:
                updi = FullyConnectedNet(
                    [*self.update_hidden, self.dim + self.tens_heads * (self.lmax + 1)],
                    activation=self.activation,
                    use_bias=self.update_bias,
                    name=f"update_net_{layer}",
                    kernel_init=kernel_init,
                )(components)

            ## UPDATE ATOM FEATURES
            xi = layer_norm(xi + updi[:, : self.dim])
            Vi = Vi * (1 + updi[:, self.dim :]).reshape(
                -1, self.tens_heads, self.lmax + 1
            ).repeat(nrep, axis=-1)
            if self.tens_heads > 1:
                Vi = ChannelMixing(
                    self.lmax, self.tens_heads, name=f"update_mixing_{layer}"
                )(Vi)

            if self.keep_all_layers:
                ## STORE ALL LAYERS
                xis.append(xi)

        output = {**inputs, self.embedding_key: xi, self.embedding_key + "_tensor": Vi}
        if self.keep_all_layers:
            output[self.embedding_key + "_layers"] = jnp.stack(xis, axis=1)
        return output
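A minimal standalone sketch of how the module can be exercised, reading the required input keys ("species" and the radial-graph arrays) off the source above. The toy geometry, the hand-built edge lists, the trivial switch, and the direct construction of _graphs_properties are all assumptions for illustration: in normal FeNNol usage these come from the neighbor-list preprocessing and the model builder, and the vec sign convention here (dst minus src) is only a guess at the library's convention.

    import jax
    import jax.numpy as jnp
    from fennol.models.embeddings.raster import RaSTER

    # toy water-like system: O at origin, two H; fully connected directed edges
    species = jnp.array([8, 1, 1])
    coords = jnp.array([[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]])
    edge_src = jnp.array([0, 0, 1, 1, 2, 2])
    edge_dst = jnp.array([1, 2, 0, 2, 0, 1])
    vec = coords[edge_dst] - coords[edge_src]  # assumed sign convention
    distances = jnp.linalg.norm(vec, axis=-1)

    inputs = {
        "species": species,
        "graph": {
            "edge_src": edge_src,
            "edge_dst": edge_dst,
            "vec": vec,
            "distances": distances,
            # trivial cutoff switch; FeNNol normally provides a smooth one
            "switch": jnp.ones_like(distances),
        },
    }

    # only the cutoff of the radial graph is read from _graphs_properties here
    module = RaSTER(_graphs_properties={"graph": {"cutoff": 5.0}})
    params = module.init(jax.random.PRNGKey(0), inputs)
    out = module.apply(params, inputs)
    print(out["embedding"].shape)         # (3, 176): scalar features, dim=176
    print(out["embedding_tensor"].shape)  # (3, 4, 16): (atoms, tens_heads, (lmax+1)**2)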
class RaSTER(nn.Module)

Range-Separated Transformer with Equivariant Representations

FID : RASTER
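The FID is the identifier used to select this module in a FeNNol model definition. A hypothetical configuration fragment is shown below; the exact schema of the model file depends on the FeNNol version, and everything here other than FID and the parameters documented on this page is illustrative:

    modules:
      embedding:
        FID: RASTER
        dim: 176
        nlayers: 2
        lmax: 3
        graph_key: graph
        embedding_key: embedding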
normalize_vec: bool = True
Whether to normalize the vector features before computing spherical harmonics.

positional_activation: str = "swish"
The activation function to use for the positional embedding network.

switch_before_net: bool = False
Whether to apply the switch function to the radial basis before the edge neural network.

ignore_parity: bool = False
Whether to ignore the parity of the spherical harmonics when constructing the relative positional encoding.

additive_positional: bool = False
Whether to use additive relative positional encoding. If False, multiplicative relative positional encoding is used.

embedding_key: str = "embedding"
The key in the output dictionary that corresponds to the embedding.

radial_basis: dict = {"start": 0.8, "basis": "gaussian", "dim": 16}
The dictionary of parameters for radial basis functions. See `fennol.models.misc.encodings.RadialBasis`.

species_encoding: str | dict = {"dim": 16, "trainable": True, "encoding": "random"}
The dictionary of parameters for species encoding. See `fennol.models.misc.encodings.SpeciesEncoding`.

graph_lode: Optional[str] = None
The key in the input dictionary that corresponds to the long-range graph.

a_lode: float = -1.0
The damping parameter for the long-range features. If negative, the damping is trainable with initial value abs(a_lode).
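As a reading aid, the "minimal radial basis for long range (damped coulomb)" computed in the source above can be written out explicitly; this restates the code, it is not an additional API. For angular momentum l, damping a, and long-range cutoff r_c, each long-range pair term is

$$
e_{ij}^{(l)}(r) = \left[ \frac{1}{(r^2+a)^{\frac{l+1}{2}}} - \frac{1}{(r_c^2+a)^{\frac{l+1}{2}}} + (r - r_c)\,\frac{(l+1)\,r_c}{(r_c^2+a)^{\frac{l+3}{2}}} \right] \mathrm{switch}(r),
$$

i.e. a damped $1/r^{l+1}$ Coulomb kernel shifted so that both its value and its first derivative vanish at $r_c$. Entries of lode_extra_powers enter through the same expression with l replaced by the extra power.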
block_index_key: Optional[str] = None
The key in the input dictionary that corresponds to the block index for the MoE network. If None, a normal neural network is used.

switch_cov_start: float = 0.5
The start of the close-range covalent switch (in units of covalent radii).

switch_cov_end: float = 0.6
The end of the close-range covalent switch (in units of covalent radii).

keep_all_layers: bool = False
Whether to return the stacked scalar embeddings from all message-passing layers.
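For reference, the close-range covalent switch computed in the source (again a restatement of the code, not extra API): with $r_\text{start} = \texttt{switch\_cov\_start}\,(r^\text{cov}_i + r^\text{cov}_j)$ and $r_\text{end} = \texttt{switch\_cov\_end}\,(r^\text{cov}_i + r^\text{cov}_j)$, where the $r^\text{cov}$ are D3 covalent radii, each edge is weighted by

$$
s_{ij}(r) = \begin{cases} 0, & r \le r_\text{start} \\ \tfrac{1}{2}\left[1 - \cos\!\left(\pi\,\dfrac{r - r_\text{start}}{r_\text{end} - r_\text{start}}\right)\right], & r_\text{start} < r < r_\text{end} \\ 1, & r \ge r_\text{end} \end{cases}
$$

so that unphysically short contacts are smoothly switched off.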