fennol.models.embeddings.crate
import jax
import jax.numpy as jnp
import flax.linen as nn
import dataclasses
import numpy as np
from typing import Dict, Union, Callable, Sequence, Optional, Tuple, ClassVar

from ..misc.encodings import SpeciesEncoding, RadialBasis, positional_encoding
from ...utils.spherical_harmonics import generate_spherical_harmonics, CG_SO3
from ...utils.activations import activation_from_str
from ...utils.initializers import initializer_from_str
from ..misc.nets import FullyConnectedNet, BlockIndexNet
from ..misc.e3 import ChannelMixing, ChannelMixingE3, FilteredTensorProduct
from ...utils.periodic_table import D3_COV_RADII, D3_VDW_RADII, VALENCE_ELECTRONS
from ...utils import AtomicUnits as au


class CRATEmbedding(nn.Module):
    """Configurable Resources ATomic Environment

    FID : CRATE

    This class represents the CRATE (Configurable Resources ATomic Environment) embedding model.
    It is used to encode atomic environments using multiple sources of information
    (radial, angular, E(3), message-passing, LODE, etc...)
    """

    _graphs_properties: Dict

    dim: int = 256
    """The size of the embedding vectors."""
    nlayers: int = 2
    """The number of interaction layers in the model."""
    keep_all_layers: bool = False
    """Whether to output all layers."""

    dim_src: int = 64
    """The size of the source embedding vectors (an int, or one value per layer)."""
    dim_dst: int = 32
    """The size of the destination embedding vectors."""

    angle_style: str = "fourier"
    """The style of angle representation (`fourier`, `fourier_full` or `ani`)."""
    dim_angle: int = 8
    """The size of the pairwise vectors used for triplet combinations (an int, or one value per layer)."""
    nmax_angle: int = 4
    """The dimension of the angle representation (minus one)."""
    zeta: float = 14.1
    """The zeta parameter for the model ANI angular representation."""
    angle_combine_pairs: bool = True
    """Whether to combine angle pairs instead of average distance embedding like in ANI."""

    message_passing: bool = True
    """Whether to use message passing in the model (a bool, or one value per layer)."""
    att_dim: int = 1
    """The hidden size for the attention mechanism (only used when message-passing is disabled)."""

    lmax: int = 0
    """The maximum order of spherical tensors."""
    nchannels_l: int = 16
    """The number of channels for spherical tensors."""
    n_tp: int = 1
    """The number of tensor products performed at each layer (an int, or one value per layer).
    A negative value skips the equivariant update of that layer."""
    ignore_irreps_parity: bool = False
    """Whether to ignore the parity of the irreps in the tensor product."""
    edge_tp: bool = False
    """Whether to perform a tensor product on edges before sending messages."""
    resolve_wij_l: bool = False
    """Equivariant message weights are l-dependent."""

    species_init: bool = False
    """Whether to initialize the embedding using the species encoding."""
    mixing_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [])
    """The hidden layer sizes for the mixing network."""
    pair_mixing_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [])
    """The hidden layer sizes for the pair mixing network."""
    activation: Union[Callable, str] = "silu"
    """The activation function for the mixing network."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization function for Dense operations."""
    activation_mixing: Union[Callable, str] = "tssr3"
    """The activation function applied after mixing."""
    layer_normalization: bool = False
    """Whether to apply layer normalization after each layer."""
    use_bias: bool = True
    """Whether to use bias in the Dense operations."""

    graph_key: str = "graph"
    """The key for the graph data in the inputs dictionary."""
    graph_angle_key: Optional[str] = None
    """The key for the angle graph data in the inputs dictionary."""
    embedding_key: Optional[str] = None
    """The key for the embedding data in the output dictionary."""
    pair_embedding_key: Optional[str] = None
    """The key for the pair embedding data in the output dictionary."""

    species_encoding: Union[dict, str] = dataclasses.field(default_factory=dict)
    """If `str`, it is the key in the inputs dictionary that contains species encodings. Else, it is the dictionary of parameters for species encoding. See `fennol.models.misc.encodings.SpeciesEncoding`."""
    radial_basis: dict = dataclasses.field(default_factory=dict)
    """The dictionary of parameters for radial basis functions. See `fennol.models.misc.encodings.RadialBasis`."""
    radial_basis_angle: Optional[dict] = None
    """The dictionary of parameters for radial basis functions for angle embedding.
        If None, the radial basis for angles is the same as the radial basis for distances."""

    graph_lode: Optional[str] = None
    """The key for the lode graph data in the inputs dictionary."""
    lode_channels: Union[int, Sequence[int]] = 8
    """The number of channels for lode (an int, or one value per layer)."""
    lmax_lode: int = 0
    """The maximum order of spherical tensors for lode."""
    a_lode: float = -1.
    """The damping length of the long-range kernel: `a_lode**2` is added to the squared distance.
    If not strictly positive, the damping is a trainable parameter with starting value -a_lode."""
    lode_resolve_l: bool = True
    """Whether to resolve the lode channels by l."""
    lode_multipole_interaction: bool = True
    """Whether to interact with the multipole moments of the lode graph."""
    lode_direct_multipoles: bool = True
    """Whether to directly use the first local equivariants to interact with long-range equivariants. If false, local equivariants are mixed before interaction."""
    lode_equi_full_combine: bool = False
    """If True (and `lode_use_field_norm`), the l>0 invariants contract the long-range field with a channel-mixed copy of itself instead of squaring it."""
    lode_normalize_l: bool = False
    """Whether to scale the l>0 long-range invariants by 1/(2l+1)."""
    lode_use_field_norm: bool = True
    """Whether to add the per-l invariants of the long-range field (l>0) to the embedding components."""
    lode_rshort: Optional[float] = None
    """If provided, the long-range kernel is smoothly switched on between `lode_rshort` and `lode_rshort + lode_dshort`."""
    lode_dshort: float = 0.5
    """The width of the short-range switch (only used when `lode_rshort` is provided)."""
    lode_extra_powers: Sequence[int] = ()
    """Additional kernel orders, handled as extra scalar long-range channels."""

    charge_embedding: bool = False
    """Whether to include charge embedding."""
    total_charge_key: str = "total_charge"
    """The key for the total charge data in the inputs dictionary."""

    block_index_key: Optional[str] = None
    """The key for the block index. If provided, will use a BLOCK_INDEX_NET as a mixing network."""

    FID: ClassVar[str] = "CRATE"

    @nn.compact
    def __call__(self, inputs):
        """Compute the CRATE atomic (and optionally pair) embedding.

        Args:
            inputs: dictionary that must contain `"species"` (1D array) and the
                graph stored under `graph_key`. Depending on the configuration
                it is also read for `graph_angle_key`, `graph_lode`,
                `species_encoding` (when it is a string), `"batch_index"` and
                `"natoms"` (charge embedding), `total_charge_key` (optional),
                `block_index_key` and `"flags"` (optional; the presence of
                `"reduce_memory"` selects the memory-saving aggregation).

        Returns:
            A new dictionary holding every entry of `inputs` plus the embedding
            under `embedding_key` (or the module name), and, when enabled,
            `<key>_tensor`, `<key>_layers`, `<key>_charge` and the pair
            embedding under `pair_embedding_key`.

        Raises:
            ValueError: unknown `angle_style`, or equivariant LODE with
                multipole interaction requested before the local equivariants
                have been initialized.
            NotImplementedError: `reduce_memory` combined with a feature that
                needs the full per-edge/per-triplet tensors.
        """
        species = inputs["species"]
        assert (
            len(species.shape) == 1
        ), "Species must be a 1D array (batches must be flattened)"
        nat = species.shape[0]
        reduce_memory = "reduce_memory" in inputs.get("flags", {})

        kernel_init = (
            initializer_from_str(self.kernel_init)
            if isinstance(self.kernel_init, str)
            else self.kernel_init
        )

        actmix = activation_from_str(self.activation_mixing)

        ##################################################
        graph = inputs[self.graph_key]
        use_angles = self.graph_angle_key is not None
        if use_angles:
            graph_angle = inputs[self.graph_angle_key]
            angle_props = self._graphs_properties[self.graph_angle_key]

            # Check that the graph_angle is a subgraph of graph.
            # `.get` so that an unrelated graph triggers the assertion below
            # rather than a KeyError.
            correct_graph = (
                self.graph_angle_key == self.graph_key
                or angle_props.get("parent_graph") == self.graph_key
            )
            assert (
                correct_graph
            ), f"graph_angle_key={self.graph_angle_key} must be a subgraph of graph_key={self.graph_key}"
            assert (
                "angles" in graph_angle
            ), f"Graph {self.graph_angle_key} must contain angles"
            # check if graph_angle is a filtered graph
            filtered = "parent_graph" in angle_props
            if filtered:
                filter_indices = graph_angle["filter_indices"]

        ##################################################
        ### SPECIES ENCODING ###
        if isinstance(self.species_encoding, str):
            zi = inputs[self.species_encoding]
        else:
            zi = SpeciesEncoding(
                **self.species_encoding, name="SpeciesEncoding"
            )(species)

        if self.layer_normalization:

            def layer_norm(x):
                # parameter-free normalization over the feature axis
                mu = jnp.mean(x, axis=-1, keepdims=True)
                dx = x - mu
                var = jnp.mean(dx**2, axis=-1, keepdims=True)
                sig = (1.e-6 + var) ** (-0.5)
                return dx * sig

        else:

            def layer_norm(x):
                return x

        if self.charge_embedding:
            xi, qi = jnp.split(
                nn.Dense(self.dim + 1, use_bias=False, name="ChargeEncoding")(zi),
                [self.dim],
                axis=-1,
            )
            batch_index = inputs["batch_index"]
            natoms = inputs["natoms"]
            nsys = natoms.shape[0]
            Zi = jnp.asarray(VALENCE_ELECTRONS)[species]
            # electrons per system = valence electrons - total charge
            Ntot = jax.ops.segment_sum(Zi, batch_index, nsys) - inputs.get(
                self.total_charge_key, jnp.zeros(nsys)
            )
            # distribute the electrons over atoms with positive weights
            ai = jax.nn.softplus(qi.squeeze(-1))
            A = jax.ops.segment_sum(ai, batch_index, nsys)
            Ni = ai * (Ntot / A)[batch_index]
            charge_embedding = positional_encoding(Ni, self.dim)
            xi = layer_norm(xi + charge_embedding)
        elif self.species_init:
            xi = layer_norm(nn.Dense(self.dim, use_bias=False, name="SpeciesInit")(zi))
        else:
            xi = zi

        ##################################################
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        distances = graph["distances"]
        switch = graph["switch"][:, None]

        ### COMPUTE RADIAL BASIS ###
        radial_basis = RadialBasis(
            **{
                **self.radial_basis,
                "end": cutoff,
                "name": "RadialBasis",
            }
        )(distances)

        ##################################################
        ### LONG-RANGE (LODE) KERNELS ###
        do_lode = self.graph_lode is not None
        if do_lode:
            graph_lode = inputs[self.graph_lode]
            switch_lode = graph_lode["switch"][:, None]

            edge_src_lr, edge_dst_lr = graph_lode["edge_src"], graph_lode["edge_dst"]
            r = graph_lode["distances"][:, None]
            rc = self._graphs_properties[self.graph_lode]["cutoff"]

            lmax_lr = self.lmax_lode
            equivariant_lode = lmax_lr > 0
            assert lmax_lr >= 0, f"lmax_lode must be >= 0, got {lmax_lr}"
            if self.lode_multipole_interaction:
                assert lmax_lr <= self.lmax, f"lmax_lode must be <= lmax for multipole interaction, got {lmax_lr} > {self.lmax}"
            nrep_lr = np.array([2 * l + 1 for l in range(lmax_lr + 1)], dtype=np.int32)
            if self.lode_resolve_l and equivariant_lode:
                ls_lr = np.arange(lmax_lr + 1)
            else:
                ls_lr = np.array([0])

            nextra_powers = len(self.lode_extra_powers)
            if nextra_powers > 0:
                # extra powers come first so they can be sliced off below
                ls_lr = np.concatenate([self.lode_extra_powers, ls_lr])

            if self.a_lode > 0:
                a = self.a_lode**2
            else:
                n_kernels = ls_lr.shape[0]
                a = (
                    self.param(
                        "a_lr",
                        lambda key: jnp.asarray([-self.a_lode] * n_kernels)[None, :],
                    )
                    ** 2
                )
            rc2a = rc**2 + a
            ls_lr = 0.5 * (ls_lr[None, :] + 1)
            ### minimal radial basis for long range (damped coulomb)
            eij_lr = (
                1.0 / (r**2 + a) ** ls_lr
                - 1.0 / rc2a**ls_lr
                + (r - rc) * rc * (2 * ls_lr) / rc2a ** (ls_lr + 1)
            ) * switch_lode

            if self.lode_rshort is not None:
                rs = self.lode_rshort
                d = self.lode_dshort
                switch_short = 0.5 * (1 - jnp.cos(jnp.pi * (r - rs) / d)) * (r > rs) * (
                    r < rs + d
                ) + (r >= rs + d)
                eij_lr = eij_lr * switch_short

            if nextra_powers > 0:
                eij_lr_extra = eij_lr[:, :nextra_powers]
                eij_lr = eij_lr[:, nextra_powers:]

            nchannels_lode = (
                [self.lode_channels] * self.nlayers
                if isinstance(self.lode_channels, int)
                else self.lode_channels
            )
            dim_lr = nchannels_lode

            if equivariant_lode:
                if self.lode_resolve_l:
                    eij_lr = eij_lr.repeat(nrep_lr, axis=-1)
                Yij = generate_spherical_harmonics(lmax=lmax_lr, normalize=False)(
                    graph_lode["vec"] / r
                )
                eij_lr = (eij_lr * Yij)[:, None, :]
                dim_lr = [d * (lmax_lr + 1) for d in dim_lr]

            if nextra_powers > 0:
                eij_lr_extra = eij_lr_extra[:, None, :]
                extra_dims = [nextra_powers * d for d in nchannels_lode]
                dim_lr = [d + ed for d, ed in zip(dim_lr, extra_dims)]

        ##################################################
        ### GET ANGLES ###
        if use_angles:
            angles = graph_angle["angles"][:, None]
            angle_src, angle_dst = graph_angle["angle_src"], graph_angle["angle_dst"]
            switch_angles = graph_angle["switch"][:, None]
            central_atom = graph_angle["central_atom"]

            if not self.angle_combine_pairs:
                assert (
                    self.radial_basis_angle is not None
                ), "radial_basis_angle must be specified if angle_combine_pairs=False"

            ### COMPUTE RADIAL BASIS FOR ANGLES ###
            if self.radial_basis_angle is not None:
                dangles = graph_angle["distances"]
                swang = switch_angles
                if not self.angle_combine_pairs:
                    # ANI-like: one radial basis per triplet, on the mean distance
                    dangles = 0.5 * (dangles[angle_src] + dangles[angle_dst])
                    swang = switch_angles[angle_src] * switch_angles[angle_dst]
                radial_basis_angle = (
                    RadialBasis(
                        **{
                            **self.radial_basis_angle,
                            "end": self._graphs_properties[self.graph_angle_key][
                                "cutoff"
                            ],
                            "name": "RadialBasisAngle",
                        }
                    )(dangles)
                    * swang
                )
            else:
                if filtered:
                    radial_basis_angle = radial_basis[filter_indices] * switch_angles
                else:
                    radial_basis_angle = radial_basis * switch

        # apply the switch only now: the angle branch above needs the raw basis
        radial_basis = radial_basis * switch

        ##################################################
        if self.lmax > 0:
            Yij = generate_spherical_harmonics(lmax=self.lmax, normalize=False)(
                graph["vec"] / graph["distances"][:, None]
            )[:, None, :]
            Yij = jnp.broadcast_to(Yij, (Yij.shape[0], self.nchannels_l, Yij.shape[2]))
            nrep_l = np.array([2 * l + 1 for l in range(self.lmax + 1)], dtype=np.int32)

        ##################################################
        if use_angles:
            ### ANGULAR BASIS ###
            if self.angle_style == "fourier":
                # build fourier series for angles
                nangles = self.param(
                    "nangles",
                    lambda key, dim: jnp.arange(dim, dtype=distances.dtype)[None, :],
                    self.nmax_angle + 1,
                )

                phi = self.param(
                    "phi",
                    lambda key, dim: jnp.zeros((1, dim), dtype=distances.dtype),
                    self.nmax_angle + 1,
                )
                xa = jnp.cos(nangles * angles + phi)
            elif self.angle_style == "fourier_full":
                # build fourier series for angles including sin terms
                nangles = self.param(
                    "nangles",
                    lambda key, dim: jnp.arange(dim, dtype=distances.dtype)[None, :],
                    self.nmax_angle + 1,
                )

                phi = self.param(
                    "phi",
                    lambda key, dim: jnp.zeros((1, dim), dtype=distances.dtype),
                    2 * self.nmax_angle + 1,
                )
                xac = jnp.cos(nangles * angles + phi[:, : self.nmax_angle + 1])
                xas = jnp.sin(nangles[:, 1:] * angles + phi[:, self.nmax_angle + 1 :])
                xa = jnp.concatenate([xac, xas], axis=-1)
            elif self.angle_style == "ani":
                # ANI-style angle embedding
                angle_start = np.pi / (2 * (self.nmax_angle + 1))
                shiftZ = self.param(
                    "shiftZ",
                    lambda key, dim: jnp.asarray(
                        (np.linspace(0, np.pi, dim + 1) + angle_start)[None, :-1],
                        dtype=distances.dtype,
                    ),
                    self.nmax_angle + 1,
                )
                zeta = self.param(
                    "zeta",
                    lambda key: jnp.asarray(self.zeta, dtype=distances.dtype),
                )
                xa = (0.5 + 0.5 * jnp.cos(angles - shiftZ)) ** zeta
            else:
                raise ValueError(f"Unknown angle style {self.angle_style}")
            xa = xa[:, None, :]
            if not self.angle_combine_pairs:
                if reduce_memory:
                    raise NotImplementedError("Angle embedding not implemented with reduce_memory")
                # xa is (ntriplets, 1, nang) here: the flattened feature size is
                # nrad * nang, so the angular size must be read from the LAST axis
                # (axis 1 is the singleton added just above).
                xa = (xa * radial_basis_angle[:, :, None]).reshape(
                    -1, 1, xa.shape[-1] * radial_basis_angle.shape[1]
                )

            if self.pair_embedding_key is not None:
                # map each triplet to the two edges (in `graph`) that form it
                if filtered:
                    ang_pair_src = filter_indices[angle_src]
                    ang_pair_dst = filter_indices[angle_dst]
                else:
                    ang_pair_src = angle_src
                    ang_pair_dst = angle_dst
                ang_pairs = jnp.concatenate((ang_pair_src, ang_pair_dst))

        ##################################################
        ### DIMENSIONS ###
        dim_src = (
            [self.dim_src] * self.nlayers
            if isinstance(self.dim_src, int)
            else self.dim_src
        )
        assert (
            len(dim_src) == self.nlayers
        ), f"dim_src must be an integer or a list of length {self.nlayers}"
        dim_dst = self.dim_dst

        if use_angles:
            dim_angle = (
                [self.dim_angle] * self.nlayers
                if isinstance(self.dim_angle, int)
                else self.dim_angle
            )
            assert (
                len(dim_angle) == self.nlayers
            ), f"dim_angle must be an integer or a list of length {self.nlayers}"

        # True until the first equivariant layer has created rhoi / Vi
        initialize_e3 = True
        if self.lmax > 0:
            n_tp = (
                [self.n_tp] * self.nlayers
                if isinstance(self.n_tp, int)
                else self.n_tp
            )
            assert (
                len(n_tp) == self.nlayers
            ), f"n_tp must be an integer or a list of length {self.nlayers}"

        message_passing = (
            [self.message_passing] * self.nlayers
            if isinstance(self.message_passing, bool)
            else self.message_passing
        )
        assert (
            len(message_passing) == self.nlayers
        ), f"message_passing must be a boolean or a list of length {self.nlayers}"

        ##################################################
        ### INITIALIZE PAIR EMBEDDING ###
        if self.pair_embedding_key is not None:
            xij_s, xij_d = jnp.split(
                nn.Dense(2 * dim_dst, name="pair_init_linear")(zi), [dim_dst], axis=-1
            )
            xij = layer_norm(xij_s[edge_src] * xij_d[edge_dst])

        ##################################################
        if self.keep_all_layers:
            xis = []

        ### LOOP OVER LAYERS ###
        for layer in range(self.nlayers):
            ##################################################
            ### COMPACT DESCRIPTORS ###
            si, si_dst = jnp.split(
                nn.Dense(
                    dim_src[layer] + dim_dst,
                    name=f"species_linear_{layer}",
                    use_bias=self.use_bias,
                )(xi),
                [
                    dim_src[layer],
                ],
                axis=-1,
            )

            ##################################################
            if message_passing[layer] or layer == 0:
                ### MESSAGE PASSING ###
                si_mp = si_dst[edge_dst]
            else:
                ### ATTENTION TO SIMULATE MP ###
                Q = nn.Dense(
                    dim_dst * self.att_dim, name=f"queries_{layer}", use_bias=False
                )(si_dst).reshape(-1, dim_dst, self.att_dim)[edge_src]
                K = nn.Dense(
                    dim_dst * self.att_dim, name=f"keys_{layer}", use_bias=False
                )(zi).reshape(-1, dim_dst, self.att_dim)[edge_dst]

                si_mp = (K * Q).sum(axis=-1) / self.att_dim**0.5

            if self.pair_embedding_key is not None:
                si_mp = si_mp + xij

            ##################################################
            ### PAIR EMBEDDING ###
            nrad = radial_basis.shape[1]
            if reduce_memory:
                # accumulate one radial function at a time to avoid storing
                # the full (nedges, nrad * dim_dst) tensor
                Li = jnp.zeros((nat * nrad, si_mp.shape[1]), dtype=si_mp.dtype)
                for i in range(nrad):
                    indices = i + edge_src * nrad
                    Li = Li.at[indices].add(si_mp * radial_basis[:, i, None])
                Li = Li.reshape(nat, nrad * si_mp.shape[1])
            else:
                Lij = (si_mp[:, None, :] * radial_basis[:, :, None]).reshape(
                    radial_basis.shape[0], si_mp.shape[1] * nrad
                )
                ### AGGREGATE PAIR EMBEDDING ###
                Li = jax.ops.segment_sum(Lij, edge_src, nat)

            ### CONCATENATE EMBEDDING COMPONENTS ###
            components = [si, Li]
            if self.pair_embedding_key is not None:
                if reduce_memory:
                    raise NotImplementedError("Pair embedding not implemented with reduce_memory")
                components_pair = [si[edge_src], xij, Lij]

            ##################################################
            ### ANGLE EMBEDDING ###
            if use_angles and dim_angle[layer] > 0:
                si_mp_ang = si_mp[filter_indices] if filtered else si_mp
                if self.angle_combine_pairs:
                    Wa = self.param(
                        f"Wa_{layer}",
                        nn.initializers.normal(
                            stddev=1.0
                            / (si_mp.shape[1] * radial_basis_angle.shape[1]) ** 0.5
                        ),
                        (si_mp.shape[1], radial_basis_angle.shape[1], dim_angle[layer]),
                    )
                    Da = jnp.einsum(
                        "...i,...j,ijk->...k",
                        si_mp_ang,
                        radial_basis_angle,
                        Wa,
                    )
                else:
                    if message_passing[layer] or layer == 0:
                        Da = nn.Dense(dim_angle[layer], name=f"angle_linear_{layer}")(
                            xi
                        )[graph_angle["edge_dst"]]
                    else:
                        Da = nn.Dense(dim_angle[layer], name=f"angle_linear_{layer}")(
                            si_mp_ang
                        )

                Da = Da[angle_dst] * Da[angle_src]
                ## combine pair and angle info
                if reduce_memory:
                    nda = Da.shape[-1]
                    ang_embedding = jnp.zeros(
                        (nat * nda, xa.shape[-1]), dtype=Da.dtype
                    )
                    for i in range(nda):
                        indices = i + central_atom * nda
                        ang_embedding = ang_embedding.at[indices].add(
                            Da[:, i, None] * xa[:, 0, :]
                        )
                    ang_embedding = ang_embedding.reshape(nat, xa.shape[-1] * nda)
                else:
                    radang = (xa * Da[:, :, None]).reshape(
                        (-1, Da.shape[1] * xa.shape[2])
                    )
                    ### AGGREGATE ANGLE EMBEDDING ###
                    ang_embedding = jax.ops.segment_sum(radang, central_atom, nat)

                components.append(ang_embedding)

                if self.pair_embedding_key is not None:
                    # each triplet contributes to both of its edges
                    ang_ij = jax.ops.segment_sum(
                        jnp.concatenate((radang, radang)),
                        ang_pairs,
                        edge_src.shape[0],
                    )
                    components_pair.append(ang_ij)

            ##################################################
            ### EQUIVARIANT EMBEDDING ###
            if self.lmax > 0 and n_tp[layer] >= 0:
                if reduce_memory:
                    # the channel weights below are built from the per-edge
                    # tensor Lij, which the reduce_memory path never creates
                    raise NotImplementedError(
                        "Equivariant embedding not implemented with reduce_memory"
                    )
                if initialize_e3 or not message_passing[layer]:
                    Vij = Yij
                elif self.edge_tp:
                    Vij = FilteredTensorProduct(
                        self.lmax,
                        self.lmax,
                        name=f"edge_tp_{layer}",
                        ignore_parity=self.ignore_irreps_parity,
                        weights_by_channel=False,
                    )(Vi[edge_dst], Yij)
                else:
                    Vij = Vi[edge_dst]

                ### compute channel weights
                dim_wij = self.nchannels_l
                if self.resolve_wij_l:
                    dim_wij = self.nchannels_l * (self.lmax + 1)

                eij = (
                    Lij
                    if self.pair_embedding_key is None
                    else jnp.concatenate([Lij, xij * switch], axis=-1)
                )
                wij = nn.Dense(
                    dim_wij, name=f"e3_channel_{layer}", use_bias=False
                )(eij)
                if self.resolve_wij_l:
                    wij = jnp.repeat(
                        wij.reshape(-1, self.nchannels_l, self.lmax + 1),
                        nrep_l,
                        axis=-1,
                    )
                else:
                    wij = wij[:, :, None]

                ### aggregate equivariant messages
                drhoi = jax.ops.segment_sum(
                    wij * Vij,
                    edge_src,
                    nat,
                )

                Vi0 = []
                if initialize_e3:
                    rhoi = drhoi
                    Vi = ChannelMixingE3(
                        self.lmax,
                        self.nchannels_l,
                        self.nchannels_l,
                        name=f"e3_initial_mixing_{layer}",
                    )(rhoi)
                else:
                    rhoi = rhoi + drhoi
                initialize_e3 = False
                if n_tp[layer] > 0:
                    for itp in range(n_tp[layer]):
                        dVi = FilteredTensorProduct(
                            self.lmax,
                            self.lmax,
                            name=f"tensor_product_{layer}_{itp}",
                            ignore_parity=self.ignore_irreps_parity,
                            weights_by_channel=False,
                        )(rhoi, Vi)
                        Vi = ChannelMixing(
                            self.lmax,
                            self.nchannels_l,
                            self.nchannels_l,
                            name=f"tp_mixing_{layer}_{itp}",
                        )(Vi + dVi)
                        # scalar (l=0) part feeds the invariant embedding
                        Vi0.append(dVi[:, :, 0])
                    Vi0 = jnp.concatenate(Vi0, axis=-1)
                    components.append(Vi0)

                if self.pair_embedding_key is not None:
                    Vij = Vi[edge_src] * Yij
                    Vij0 = [Vij[..., 0]]
                    for l in range(1, self.lmax + 1):
                        Vij0.append(Vij[..., l**2 : (l + 1) ** 2].sum(axis=-1))
                    Vij0 = jnp.concatenate(Vij0, axis=-1)
                    components_pair.append(Vij0)

            ##################################################
            ### LONG-RANGE (LODE) COMPONENTS ###
            if do_lode and nchannels_lode[layer] > 0:
                nch_lr = nchannels_lode[layer]
                zj = nn.Dense(dim_lr[layer], use_bias=False, name=f"LODE_{layer}")(xi)
                if nextra_powers > 0:
                    zj_extra = zj[:, : nextra_powers * nch_lr].reshape(
                        (nat, nch_lr, nextra_powers)
                    )
                    zj = zj[:, nextra_powers * nch_lr :]
                    xi_lr_extra = jax.ops.segment_sum(
                        eij_lr_extra * zj_extra[edge_dst_lr], edge_src_lr, nat
                    )
                    components.append(xi_lr_extra.reshape(nat, -1))

                if equivariant_lode:
                    zj = zj.reshape((nat, nch_lr, lmax_lr + 1)).repeat(
                        nrep_lr, axis=-1
                    )
                xi_lr = jax.ops.segment_sum(
                    eij_lr * zj[edge_dst_lr], edge_src_lr, nat
                )
                if equivariant_lode:
                    assert self.lode_use_field_norm or self.lode_multipole_interaction, "equivariant LODE requires field norm or multipole interaction"
                    if self.lode_multipole_interaction:
                        if initialize_e3:
                            raise ValueError("equivariant LODE used before local equivariants initialized")
                        size_l_lr = (lmax_lr + 1) ** 2
                        if self.lode_direct_multipoles:
                            assert nch_lr <= self.nchannels_l
                            Mi = Vi[:, :nch_lr, :size_l_lr]
                        else:
                            Mi = ChannelMixingE3(
                                lmax_lr,
                                self.nchannels_l,
                                nch_lr,
                                name=f"e3_LODE_{layer}",
                            )(Vi[..., :size_l_lr])
                        Mi_lr = Mi * xi_lr
                    components.append(xi_lr[:, :, 0])
                    if self.lode_use_field_norm and self.lode_equi_full_combine:
                        xi_lr1 = ChannelMixing(
                            lmax_lr,
                            nch_lr,
                            nch_lr,
                            name=f"LODE_mixing_{layer}",
                        )(xi_lr)
                    norm = 1.
                    for l in range(1, lmax_lr + 1):
                        block_l = slice(l**2, (l + 1) ** 2)
                        if self.lode_normalize_l:
                            norm = 1. / (2 * l + 1)
                        if self.lode_multipole_interaction:
                            components.append(
                                Mi_lr[:, :, block_l].sum(axis=-1) * norm
                            )

                        if self.lode_use_field_norm:
                            if self.lode_equi_full_combine:
                                components.append(
                                    (xi_lr[:, :, block_l] * xi_lr1[:, :, block_l]).sum(
                                        axis=-1
                                    )
                                    * norm
                                )
                            else:
                                components.append(
                                    ((xi_lr[:, :, block_l]) ** 2).sum(axis=-1) * norm
                                )
                else:
                    components.append(xi_lr)

            ##################################################
            ### CONCATENATE EMBEDDING COMPONENTS ###
            dxi = jnp.concatenate(components, axis=-1)

            ##################################################
            ### CONCATENATE PAIR EMBEDDING COMPONENTS ###
            if self.pair_embedding_key is not None:
                dxij = jnp.concatenate(components_pair, axis=-1)

            ##################################################
            ### MIX AND APPLY NONLINEARITY ###
            if self.block_index_key is not None:
                block_index = inputs[self.block_index_key]
                dxi = actmix(
                    BlockIndexNet(
                        output_dim=self.dim,
                        hidden_neurons=self.mixing_hidden,
                        activation=self.activation,
                        name=f"dxi_{layer}",
                        use_bias=self.use_bias,
                        kernel_init=kernel_init,
                    )((species, dxi, block_index))
                )
            else:
                dxi = actmix(
                    FullyConnectedNet(
                        [*self.mixing_hidden, self.dim],
                        activation=self.activation,
                        name=f"dxi_{layer}",
                        use_bias=self.use_bias,
                        kernel_init=kernel_init,
                    )(dxi)
                )

            if self.pair_embedding_key is not None:
                ### UPDATE PAIR EMBEDDING ###
                dxij = actmix(
                    FullyConnectedNet(
                        [*self.pair_mixing_hidden, dim_dst],
                        activation=self.activation,
                        name=f"dxij_{layer}",
                        use_bias=False,
                        kernel_init=kernel_init,
                    )(dxij)
                )
                xij = layer_norm(xij + dxij)

            ##################################################
            ### UPDATE EMBEDDING ###
            if layer == 0 and not (self.species_init or self.charge_embedding):
                # xi is still the raw species encoding: replace it
                xi = layer_norm(dxi)
            else:
                ### FORGET GATE ###
                R = jax.nn.sigmoid(
                    self.param(
                        f"retention_{layer}",
                        nn.initializers.normal(),
                        (xi.shape[-1],),
                    )
                )
                xi = layer_norm(R[None, :] * xi + dxi)

            if self.keep_all_layers:
                xis.append(xi)

        embedding_key = (
            self.embedding_key if self.embedding_key is not None else self.name
        )
        output = {
            **inputs,
            embedding_key: xi,
        }
        # Vi only exists once an equivariant layer has run (n_tp[layer] >= 0)
        if self.lmax > 0 and not initialize_e3:
            output[embedding_key + "_tensor"] = Vi
        if self.keep_all_layers:
            output[embedding_key + "_layers"] = jnp.stack(xis, axis=1)
        if self.charge_embedding:
            output[embedding_key + "_charge"] = charge_embedding
        if self.pair_embedding_key is not None:
            output[self.pair_embedding_key] = xij
        return output
class CRATEmbedding(nn.Module):
    """Configurable Resources ATomic Environment

    FID : CRATE

    This class represents the CRATE (Configurable Resources ATomic Environment) embedding model.
    It is used to encode atomic environments using multiple sources of information
    (radial, angular, E(3), message-passing, LODE, etc...)
    """

    _graphs_properties: Dict

    dim: int = 256
    """The size of the embedding vectors."""
    nlayers: int = 2
    """The number of interaction layers in the model."""
    keep_all_layers: bool = False
    """Whether to output all layers."""

    dim_src: int = 64
    """The size of the source embedding vectors (an int, or one value per layer)."""
    dim_dst: int = 32
    """The size of the destination embedding vectors."""

    angle_style: str = "fourier"
    """The style of angle representation ("fourier", "fourier_full" or "ani")."""
    dim_angle: int = 8
    """The size of the pairwise vectors use for triplet combinations (an int, or one value per layer)."""
    nmax_angle: int = 4
    """The dimension of the angle representation (minus one)."""
    zeta: float = 14.1
    """The zeta parameter for the model ANI angular representation."""
    angle_combine_pairs: bool = True
    """Whether to combine angle pairs instead of average distance embedding like in ANI."""

    message_passing: bool = True
    """Whether to use message passing in the model (a bool, or one value per layer)."""
    att_dim: int = 1
    """The hidden size for the attention mechanism (only used when message-passing is disabled)."""

    lmax: int = 0
    """The maximum order of spherical tensors."""
    nchannels_l: int = 16
    """The number of channels for spherical tensors."""
    n_tp: int = 1
    """The number of tensor products performed at each layer (an int, or one value per layer)."""
    ignore_irreps_parity: bool = False
    """Whether to ignore the parity of the irreps in the tensor product."""
    edge_tp: bool = False
    """Whether to perform a tensor product on edges before sending messages."""
    resolve_wij_l: bool = False
    """Equivariant message weights are l-dependent."""

    species_init: bool = False
    """Whether to initialize the embedding using the species encoding."""
    mixing_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [])
    """The hidden layer sizes for the mixing network."""
    pair_mixing_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [])
    """The hidden layer sizes for the pair mixing network."""
    activation: Union[Callable, str] = "silu"
    """The activation function for the mixing network."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization function for Dense operations."""
    activation_mixing: Union[Callable, str] = "tssr3"
    """The activation function applied after mixing."""
    layer_normalization: bool = False
    """Whether to apply layer normalization after each layer."""
    use_bias: bool = True
    """Whether to use bias in the Dense operations."""

    graph_key: str = "graph"
    """The key for the graph data in the inputs dictionary."""
    graph_angle_key: Optional[str] = None
    """The key for the angle graph data in the inputs dictionary."""
    embedding_key: Optional[str] = None
    """The key for the embedding data in the output dictionary."""
    pair_embedding_key: Optional[str] = None
    """The key for the pair embedding data in the output dictionary."""

    species_encoding: Union[dict, str] = dataclasses.field(default_factory=dict)
    """If `str`, it is the key in the inputs dictionary that contains species encodings. Else, it is the dictionary of parameters for species encoding. See `fennol.models.misc.encodings.SpeciesEncoding`."""
    radial_basis: dict = dataclasses.field(default_factory=dict)
    """The dictionary of parameters for radial basis functions. See `fennol.models.misc.encodings.RadialBasis`."""
    radial_basis_angle: Optional[dict] = None
    """The dictionary of parameters for radial basis functions for angle embedding.
        If None, the radial basis for angles is the same as the radial basis for distances."""

    graph_lode: Optional[str] = None
    """The key for the lode graph data in the inputs dictionary."""
    lode_channels: Union[int, Sequence[int]] = 8
    """The number of channels for lode (an int, or one value per layer)."""
    lmax_lode: int = 0
    """The maximum order of spherical tensors for lode."""
    a_lode: float = -1.
    """The damping length `a` of the lode radial functions 1/(r**2 + a**2)**((l+1)/2). If negative, the value is trainable with starting value -a_lode."""
    lode_resolve_l: bool = True
    """Whether to resolve the lode channels by l."""
    lode_multipole_interaction: bool = True
    """Whether to interact with the multipole moments of the lode graph."""
    lode_direct_multipoles: bool = True
    """Whether to directly use the first local equivariants to interact with long-range equivariants. If false, local equivariants are mixed before interaction."""
    lode_equi_full_combine: bool = False
    """If True (and `lode_use_field_norm`), the l>0 field invariants are dot products of the long-range field with a channel-mixed copy of itself instead of its squared norm."""
    lode_normalize_l: bool = False
    """Whether to scale the l>0 lode invariants by 1/(2l+1)."""
    lode_use_field_norm: bool = True
    """Whether to add per-l quadratic invariants of the long-range field to the embedding components."""
    lode_rshort: Optional[float] = None
    """If not None, lode pair functions are zeroed below this distance and smoothly switched on over `lode_dshort`."""
    lode_dshort: float = 0.5
    """The width of the short-range switch-on region starting at `lode_rshort`."""
    lode_extra_powers: Sequence[int] = ()
    """Extra values p giving additional scalar lode radial functions 1/(r**2 + a**2)**((p+1)/2)."""

    charge_embedding: bool = False
    """Whether to include charge embedding."""
    total_charge_key: str = "total_charge"
    """The key for the total charge data in the inputs dictionary."""

    block_index_key: Optional[str] = None
    """The key for the block index. If provided, will use a BLOCK_INDEX_NET as a mixing network."""

    FID: ClassVar[str] = "CRATE"

    @nn.compact
    def __call__(self, inputs):
        """Compute the CRATE atomic embedding.

        Args:
            inputs: dictionary that must contain "species" (1D array) and the
                graph stored under `graph_key` (with "edge_src", "edge_dst",
                "distances", "switch", and "vec" when `lmax > 0`). Depending
                on the configuration it is also read for the angle graph
                (`graph_angle_key`), the lode graph (`graph_lode`),
                "batch_index"/"natoms"/`total_charge_key` (charge embedding),
                `block_index_key`, a precomputed species encoding, and the
                optional "flags" entry ("reduce_memory").

        Returns:
            A new dictionary containing every entry of `inputs` plus the
            embedding stored under `embedding_key` (or the module name), and
            optionally "<key>_tensor", "<key>_layers", "<key>_charge" and the
            pair embedding under `pair_embedding_key`.

        Raises:
            ValueError: unknown `angle_style`, or equivariant lode multipole
                interaction requested before local equivariants exist.
            NotImplementedError: "reduce_memory" flag combined with a feature
                that needs the explicit per-edge / per-angle tensors.
        """
        species = inputs["species"]
        assert (
            len(species.shape) == 1
        ), "Species must be a 1D array (batches must be flattened)"
        reduce_memory = "reduce_memory" in inputs.get("flags", {})

        kernel_init = (
            initializer_from_str(self.kernel_init)
            if isinstance(self.kernel_init, str)
            else self.kernel_init
        )

        actmix = activation_from_str(self.activation_mixing)

        ##################################################
        graph = inputs[self.graph_key]
        use_angles = self.graph_angle_key is not None
        if use_angles:
            graph_angle = inputs[self.graph_angle_key]

            # Check that the graph_angle is a subgraph of graph
            correct_graph = (
                self.graph_angle_key == self.graph_key
                or self._graphs_properties[self.graph_angle_key]["parent_graph"]
                == self.graph_key
            )
            assert (
                correct_graph
            ), f"graph_angle_key={self.graph_angle_key} must be a subgraph of graph_key={self.graph_key}"
            assert (
                "angles" in graph_angle
            ), f"Graph {self.graph_angle_key} must contain angles"
            # check if graph_angle is a filtered graph
            filtered = "parent_graph" in self._graphs_properties[self.graph_angle_key]
            if filtered:
                filter_indices = graph_angle["filter_indices"]

        ##################################################
        ### SPECIES ENCODING ###
        if isinstance(self.species_encoding, str):
            zi = inputs[self.species_encoding]
        else:
            zi = SpeciesEncoding(
                **self.species_encoding, name="SpeciesEncoding"
            )(species)

        if self.layer_normalization:
            # parameter-free layer normalization over the feature axis
            def layer_norm(x):
                mu = jnp.mean(x, axis=-1, keepdims=True)
                dx = x - mu
                var = jnp.mean(dx**2, axis=-1, keepdims=True)
                sig = (1.e-6 + var) ** (-0.5)
                return dx * sig
        else:
            def layer_norm(x):
                return x

        if self.charge_embedding:
            xi, qi = jnp.split(
                nn.Dense(self.dim + 1, use_bias=False, name="ChargeEncoding")(zi),
                [self.dim],
                axis=-1,
            )
            batch_index = inputs["batch_index"]
            natoms = inputs["natoms"]
            nsys = natoms.shape[0]
            Zi = jnp.asarray(VALENCE_ELECTRONS)[species]
            # per-system electron count: sum of VALENCE_ELECTRONS minus total charge
            Ntot = jax.ops.segment_sum(Zi, batch_index, nsys) - inputs.get(
                self.total_charge_key, jnp.zeros(nsys)
            )
            # distribute Ntot over atoms with positive, normalized weights
            ai = jax.nn.softplus(qi.squeeze(-1))
            A = jax.ops.segment_sum(ai, batch_index, nsys)
            Ni = ai * (Ntot / A)[batch_index]
            charge_embedding = positional_encoding(Ni, self.dim)
            xi = layer_norm(xi + charge_embedding)
        elif self.species_init:
            xi = layer_norm(nn.Dense(self.dim, use_bias=False, name="SpeciesInit")(zi))
        else:
            xi = zi

        ##################################################
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        distances = graph["distances"]
        switch = graph["switch"][:, None]

        ### COMPUTE RADIAL BASIS ###
        radial_basis = RadialBasis(
            **{
                **self.radial_basis,
                "end": cutoff,
                "name": "RadialBasis",
            }
        )(distances)

        do_lode = self.graph_lode is not None
        if do_lode:
            graph_lode = inputs[self.graph_lode]
            switch_lode = graph_lode["switch"][:, None]

            edge_src_lr, edge_dst_lr = graph_lode["edge_src"], graph_lode["edge_dst"]
            r = graph_lode["distances"][:, None]
            rc = self._graphs_properties[self.graph_lode]["cutoff"]

            lmax_lr = self.lmax_lode
            equivariant_lode = lmax_lr > 0
            assert lmax_lr >= 0, f"lmax_lode must be >= 0, got {lmax_lr}"
            if self.lode_multipole_interaction:
                assert lmax_lr <= self.lmax, f"lmax_lode must be <= lmax for multipole interaction, got {lmax_lr} > {self.lmax}"
            nrep_lr = np.array([2 * l + 1 for l in range(lmax_lr + 1)], dtype=np.int32)
            if self.lode_resolve_l and equivariant_lode:
                ls_lr = np.arange(lmax_lr + 1)
            else:
                ls_lr = np.array([0])

            nextra_powers = len(self.lode_extra_powers)
            if nextra_powers > 0:
                # extra powers come first so they can be sliced off below
                ls_lr = np.concatenate([self.lode_extra_powers, ls_lr])

            if self.a_lode > 0:
                a = self.a_lode**2
            else:
                # trainable damping, one value per radial function
                a = (
                    self.param(
                        "a_lr",
                        lambda key: jnp.asarray([-self.a_lode] * ls_lr.shape[0])[None, :],
                    )
                    ** 2
                )
            rc2a = rc**2 + a
            ls_lr = 0.5 * (ls_lr[None, :] + 1)
            ### minimal radial basis for long range (damped coulomb)
            eij_lr = (
                1.0 / (r**2 + a) ** ls_lr
                - 1.0 / rc2a**ls_lr
                + (r - rc) * rc * (2 * ls_lr) / rc2a ** (ls_lr + 1)
            ) * switch_lode

            if self.lode_rshort is not None:
                rs = self.lode_rshort
                d = self.lode_dshort
                # cosine switch: 0 for r <= rs, 1 for r >= rs + d
                switch_short = 0.5 * (1 - jnp.cos(jnp.pi * (r - rs) / d)) * (r > rs) * (
                    r < rs + d
                ) + (r >= rs + d)
                eij_lr = eij_lr * switch_short

            if nextra_powers > 0:
                eij_lr_extra = eij_lr[:, :nextra_powers]
                eij_lr = eij_lr[:, nextra_powers:]

            nchannels_lode = (
                [self.lode_channels] * self.nlayers
                if isinstance(self.lode_channels, int)
                else self.lode_channels
            )
            dim_lr = nchannels_lode

            if equivariant_lode:
                if self.lode_resolve_l:
                    eij_lr = eij_lr.repeat(nrep_lr, axis=-1)
                Yij = generate_spherical_harmonics(lmax=lmax_lr, normalize=False)(
                    graph_lode["vec"] / r
                )
                eij_lr = (eij_lr * Yij)[:, None, :]
                dim_lr = [d * (lmax_lr + 1) for d in dim_lr]

            if nextra_powers > 0:
                eij_lr_extra = eij_lr_extra[:, None, :]
                extra_dims = [nextra_powers * d for d in nchannels_lode]
                dim_lr = [d + ed for d, ed in zip(dim_lr, extra_dims)]

        ##################################################
        ### GET ANGLES ###
        if use_angles:
            angles = graph_angle["angles"][:, None]
            angle_src, angle_dst = graph_angle["angle_src"], graph_angle["angle_dst"]
            switch_angles = graph_angle["switch"][:, None]
            central_atom = graph_angle["central_atom"]

            if not self.angle_combine_pairs:
                assert (
                    self.radial_basis_angle is not None
                ), "radial_basis_angle must be specified if angle_combine_pairs=False"

            ### COMPUTE RADIAL BASIS FOR ANGLES ###
            if self.radial_basis_angle is not None:
                dangles = graph_angle["distances"]
                swang = switch_angles
                if not self.angle_combine_pairs:
                    # one radial basis per angle, from the mean of its two edge lengths
                    dangles = 0.5 * (dangles[angle_src] + dangles[angle_dst])
                    swang = switch_angles[angle_src] * switch_angles[angle_dst]
                radial_basis_angle = (
                    RadialBasis(
                        **{
                            **self.radial_basis_angle,
                            "end": self._graphs_properties[self.graph_angle_key][
                                "cutoff"
                            ],
                            "name": "RadialBasisAngle",
                        }
                    )(dangles)
                    * swang
                )

            else:
                if filtered:
                    radial_basis_angle = radial_basis[filter_indices] * switch_angles
                else:
                    radial_basis_angle = radial_basis * switch

        radial_basis = radial_basis * switch

        ##################################################
        if self.lmax > 0:
            Yij = generate_spherical_harmonics(lmax=self.lmax, normalize=False)(
                graph["vec"] / graph["distances"][:, None]
            )[:, None, :]
            Yij = jnp.broadcast_to(Yij, (Yij.shape[0], self.nchannels_l, Yij.shape[2]))
            nrep_l = np.array([2 * l + 1 for l in range(self.lmax + 1)], dtype=np.int32)

        ##################################################
        if use_angles:
            ### ANGULAR BASIS ###
            if self.angle_style == "fourier":
                # build fourier series for angles
                nangles = self.param(
                    "nangles",
                    lambda key, dim: jnp.arange(dim, dtype=distances.dtype)[None, :],
                    self.nmax_angle + 1,
                )

                phi = self.param(
                    "phi",
                    lambda key, dim: jnp.zeros((1, dim), dtype=distances.dtype),
                    self.nmax_angle + 1,
                )
                xa = jnp.cos(nangles * angles + phi)
            elif self.angle_style == "fourier_full":
                # build fourier series for angles including sin terms
                nangles = self.param(
                    "nangles",
                    lambda key, dim: jnp.arange(dim, dtype=distances.dtype)[None, :],
                    self.nmax_angle + 1,
                )

                phi = self.param(
                    "phi",
                    lambda key, dim: jnp.zeros((1, dim), dtype=distances.dtype),
                    2 * self.nmax_angle + 1,
                )
                xac = jnp.cos(nangles * angles + phi[:, : self.nmax_angle + 1])
                xas = jnp.sin(nangles[:, 1:] * angles + phi[:, self.nmax_angle + 1 :])
                xa = jnp.concatenate([xac, xas], axis=-1)
            elif self.angle_style == "ani":
                # ANI-style angle embedding
                angle_start = np.pi / (2 * (self.nmax_angle + 1))
                shiftZ = self.param(
                    "shiftZ",
                    lambda key, dim: jnp.asarray(
                        (np.linspace(0, np.pi, dim + 1) + angle_start)[None, :-1],
                        dtype=distances.dtype,
                    ),
                    self.nmax_angle + 1,
                )
                zeta = self.param(
                    "zeta",
                    lambda key: jnp.asarray(self.zeta, dtype=distances.dtype),
                )
                xa = (0.5 + 0.5 * jnp.cos(angles - shiftZ)) ** zeta
            else:
                raise ValueError(f"Unknown angle style {self.angle_style}")
            xa = xa[:, None, :]
            if not self.angle_combine_pairs:
                if reduce_memory:
                    raise NotImplementedError("Angle embedding not implemented with reduce_memory")
                # Outer product of the angular and radial bases, flattened on
                # the last axis. The angular size is xa.shape[2]: axis 1 is the
                # singleton inserted just above, so using it here would fold
                # the angular features into the leading (angle) axis.
                xa = (xa * radial_basis_angle[:, :, None]).reshape(
                    -1, 1, xa.shape[2] * radial_basis_angle.shape[1]
                )

            if self.pair_embedding_key is not None:
                if filtered:
                    ang_pair_src = filter_indices[angle_src]
                    ang_pair_dst = filter_indices[angle_dst]
                else:
                    ang_pair_src = angle_src
                    ang_pair_dst = angle_dst
                ang_pairs = jnp.concatenate((ang_pair_src, ang_pair_dst))

        ##################################################
        ### DIMENSIONS ###
        dim_src = (
            [self.dim_src] * self.nlayers
            if isinstance(self.dim_src, int)
            else self.dim_src
        )
        assert (
            len(dim_src) == self.nlayers
        ), f"dim_src must be an integer or a list of length {self.nlayers}"
        dim_dst = self.dim_dst

        if use_angles:
            dim_angle = (
                [self.dim_angle] * self.nlayers
                if isinstance(self.dim_angle, int)
                else self.dim_angle
            )
            assert (
                len(dim_angle) == self.nlayers
            ), f"dim_angle must be an integer or a list of length {self.nlayers}"

        # True until the first equivariant layer has created rhoi / Vi
        initialize_e3 = True
        if self.lmax > 0:
            n_tp = (
                [self.n_tp] * self.nlayers
                if isinstance(self.n_tp, int)
                else self.n_tp
            )
            assert (
                len(n_tp) == self.nlayers
            ), f"n_tp must be an integer or a list of length {self.nlayers}"

        message_passing = (
            [self.message_passing] * self.nlayers
            if isinstance(self.message_passing, bool)
            else self.message_passing
        )
        assert (
            len(message_passing) == self.nlayers
        ), f"message_passing must be a boolean or a list of length {self.nlayers}"

        ##################################################
        ### INITIALIZE PAIR EMBEDDING ###
        if self.pair_embedding_key is not None:
            xij_s, xij_d = jnp.split(
                nn.Dense(2 * dim_dst, name="pair_init_linear")(zi), [dim_dst], axis=-1
            )
            xij = layer_norm(xij_s[edge_src] * xij_d[edge_dst])

        ##################################################
        if self.keep_all_layers:
            xis = []

        ### LOOP OVER LAYERS ###
        for layer in range(self.nlayers):
            ##################################################
            ### COMPACT DESCRIPTORS ###
            si, si_dst = jnp.split(
                nn.Dense(
                    dim_src[layer] + dim_dst,
                    name=f"species_linear_{layer}",
                    use_bias=self.use_bias,
                )(xi),
                [
                    dim_src[layer],
                ],
                axis=-1,
            )

            ##################################################
            if message_passing[layer] or layer == 0:
                ### MESSAGE PASSING ###
                si_mp = si_dst[edge_dst]
            else:
                ### ATTENTION TO SIMULATE MP ###
                Q = nn.Dense(
                    dim_dst * self.att_dim, name=f"queries_{layer}", use_bias=False
                )(si_dst).reshape(-1, dim_dst, self.att_dim)[edge_src]
                K = nn.Dense(
                    dim_dst * self.att_dim, name=f"keys_{layer}", use_bias=False
                )(zi).reshape(-1, dim_dst, self.att_dim)[edge_dst]

                si_mp = (K * Q).sum(axis=-1) / self.att_dim**0.5

            if self.pair_embedding_key is not None:
                si_mp = si_mp + xij

            ##################################################
            ### PAIR EMBEDDING ###
            if reduce_memory:
                # accumulate one radial channel at a time to avoid
                # materializing the full per-edge outer product
                Li = jnp.zeros(
                    (species.shape[0] * radial_basis.shape[1], si_mp.shape[1]),
                    dtype=si_mp.dtype,
                )
                for i in range(radial_basis.shape[1]):
                    indices = i + edge_src * radial_basis.shape[1]
                    Li = Li.at[indices].add(si_mp * radial_basis[:, i, None])
                Li = Li.reshape(species.shape[0], radial_basis.shape[1] * si_mp.shape[1])
            else:
                Lij = (si_mp[:, None, :] * radial_basis[:, :, None]).reshape(
                    radial_basis.shape[0], si_mp.shape[1] * radial_basis.shape[1]
                )
                ### AGGREGATE PAIR EMBEDDING ###
                Li = jax.ops.segment_sum(Lij, edge_src, species.shape[0])

            ### CONCATENATE EMBEDDING COMPONENTS ###
            components = [si, Li]
            if self.pair_embedding_key is not None:
                if reduce_memory:
                    raise NotImplementedError("Pair embedding not implemented with reduce_memory")
                components_pair = [si[edge_src], xij, Lij]

            ##################################################
            ### ANGLE EMBEDDING ###
            if use_angles and dim_angle[layer] > 0:
                si_mp_ang = si_mp[filter_indices] if filtered else si_mp
                if self.angle_combine_pairs:
                    Wa = self.param(
                        f"Wa_{layer}",
                        nn.initializers.normal(
                            stddev=1.0
                            / (si_mp.shape[1] * radial_basis_angle.shape[1]) ** 0.5
                        ),
                        (si_mp.shape[1], radial_basis_angle.shape[1], dim_angle[layer]),
                    )
                    Da = jnp.einsum(
                        "...i,...j,ijk->...k",
                        si_mp_ang,
                        radial_basis_angle,
                        Wa,
                    )

                else:
                    if message_passing[layer] or layer == 0:
                        Da = nn.Dense(dim_angle[layer], name=f"angle_linear_{layer}")(
                            xi
                        )[graph_angle["edge_dst"]]
                    else:
                        Da = nn.Dense(dim_angle[layer], name=f"angle_linear_{layer}")(
                            si_mp_ang
                        )

                Da = Da[angle_dst] * Da[angle_src]
                ## combine pair and angle info
                if reduce_memory:
                    ang_embedding = jnp.zeros(
                        (species.shape[0] * Da.shape[-1], xa.shape[-1]), dtype=Da.dtype
                    )
                    for i in range(Da.shape[-1]):
                        indices = i + central_atom * Da.shape[-1]
                        ang_embedding = ang_embedding.at[indices].add(
                            Da[:, i, None] * xa[:, 0, :]
                        )
                    ang_embedding = ang_embedding.reshape(
                        species.shape[0], xa.shape[-1] * Da.shape[-1]
                    )
                else:
                    radang = (xa * Da[:, :, None]).reshape(
                        (-1, Da.shape[1] * xa.shape[2])
                    )
                    ### AGGREGATE ANGLE EMBEDDING ###
                    ang_embedding = jax.ops.segment_sum(
                        radang, central_atom, species.shape[0]
                    )

                components.append(ang_embedding)

                if self.pair_embedding_key is not None:
                    ang_ij = jax.ops.segment_sum(
                        jnp.concatenate((radang, radang)),
                        ang_pairs,
                        edge_src.shape[0],
                    )
                    components_pair.append(ang_ij)

            ##################################################
            ### EQUIVARIANT EMBEDDING ###
            if self.lmax > 0 and n_tp[layer] >= 0:
                if reduce_memory:
                    # the channel weights below need the per-edge Lij, which
                    # the reduce_memory path never builds
                    raise NotImplementedError("Equivariant embedding not implemented with reduce_memory")
                if initialize_e3 or not message_passing[layer]:
                    Vij = Yij
                elif self.edge_tp:
                    Vij = FilteredTensorProduct(
                        self.lmax,
                        self.lmax,
                        name=f"edge_tp_{layer}",
                        ignore_parity=self.ignore_irreps_parity,
                        weights_by_channel=False,
                    )(Vi[edge_dst], Yij)
                else:
                    Vij = Vi[edge_dst]

                ### compute channel weights
                dim_wij = self.nchannels_l
                if self.resolve_wij_l:
                    dim_wij = self.nchannels_l * (self.lmax + 1)

                eij = (
                    Lij
                    if self.pair_embedding_key is None
                    else jnp.concatenate([Lij, xij * switch], axis=-1)
                )
                wij = nn.Dense(
                    dim_wij, name=f"e3_channel_{layer}", use_bias=False
                )(eij)
                if self.resolve_wij_l:
                    wij = jnp.repeat(
                        wij.reshape(-1, self.nchannels_l, self.lmax + 1), nrep_l, axis=-1
                    )
                else:
                    wij = wij[:, :, None]

                ### aggregate equivariant messages
                drhoi = jax.ops.segment_sum(
                    wij * Vij,
                    edge_src,
                    species.shape[0],
                )

                Vi0 = []
                if initialize_e3:
                    rhoi = drhoi
                    Vi = ChannelMixingE3(
                        self.lmax,
                        self.nchannels_l,
                        self.nchannels_l,
                        name=f"e3_initial_mixing_{layer}",
                    )(rhoi)
                else:
                    rhoi = rhoi + drhoi
                initialize_e3 = False
                if n_tp[layer] > 0:
                    for itp in range(n_tp[layer]):
                        dVi = FilteredTensorProduct(
                            self.lmax,
                            self.lmax,
                            name=f"tensor_product_{layer}_{itp}",
                            ignore_parity=self.ignore_irreps_parity,
                            weights_by_channel=False,
                        )(rhoi, Vi)
                        Vi = ChannelMixing(
                            self.lmax,
                            self.nchannels_l,
                            self.nchannels_l,
                            name=f"tp_mixing_{layer}_{itp}",
                        )(Vi + dVi)
                        # keep the first (index 0) component of each product
                        Vi0.append(dVi[:, :, 0])
                    Vi0 = jnp.concatenate(Vi0, axis=-1)
                    components.append(Vi0)

                if self.pair_embedding_key is not None:
                    Vij = Vi[edge_src] * Yij
                    Vij0 = [Vij[..., 0]]
                    for l in range(1, self.lmax + 1):
                        Vij0.append(Vij[..., l**2 : (l + 1) ** 2].sum(axis=-1))
                    Vij0 = jnp.concatenate(Vij0, axis=-1)
                    components_pair.append(Vij0)

            ##################################################
            ### LONG-RANGE (LODE) COMPONENTS ###
            if do_lode and nchannels_lode[layer] > 0:
                zj = nn.Dense(dim_lr[layer], use_bias=False, name=f"LODE_{layer}")(xi)
                if nextra_powers > 0:
                    zj_extra = zj[:, : nextra_powers * nchannels_lode[layer]].reshape(
                        (species.shape[0], nchannels_lode[layer], nextra_powers)
                    )
                    zj = zj[:, nextra_powers * nchannels_lode[layer] :]
                    xi_lr_extra = jax.ops.segment_sum(
                        eij_lr_extra * zj_extra[edge_dst_lr], edge_src_lr, species.shape[0]
                    )
                    components.append(xi_lr_extra.reshape(species.shape[0], -1))

                if equivariant_lode:
                    zj = zj.reshape(
                        (species.shape[0], nchannels_lode[layer], lmax_lr + 1)
                    ).repeat(nrep_lr, axis=-1)
                xi_lr = jax.ops.segment_sum(
                    eij_lr * zj[edge_dst_lr], edge_src_lr, species.shape[0]
                )
                if equivariant_lode:
                    assert self.lode_use_field_norm or self.lode_multipole_interaction, "equivariant LODE requires field norm or multipole interaction"
                    if self.lode_multipole_interaction:
                        if initialize_e3:
                            raise ValueError("equivariant LODE used before local equivariants initialized")
                        size_l_lr = (lmax_lr + 1) ** 2
                        if self.lode_direct_multipoles:
                            assert nchannels_lode[layer] <= self.nchannels_l
                            Mi = Vi[:, : nchannels_lode[layer], :size_l_lr]
                        else:
                            Mi = ChannelMixingE3(
                                lmax_lr,
                                self.nchannels_l,
                                nchannels_lode[layer],
                                name=f"e3_LODE_{layer}",
                            )(Vi[..., :size_l_lr])
                        Mi_lr = Mi * xi_lr
                    components.append(xi_lr[:, :, 0])
                    if self.lode_use_field_norm and self.lode_equi_full_combine:
                        xi_lr1 = ChannelMixing(
                            lmax_lr,
                            nchannels_lode[layer],
                            nchannels_lode[layer],
                            name=f"LODE_mixing_{layer}",
                        )(xi_lr)
                    norm = 1.
                    for l in range(1, lmax_lr + 1):
                        if self.lode_normalize_l:
                            norm = 1. / (2 * l + 1)
                        if self.lode_multipole_interaction:
                            components.append(
                                Mi_lr[:, :, l**2 : (l + 1) ** 2].sum(axis=-1) * norm
                            )

                        if self.lode_use_field_norm:
                            if self.lode_equi_full_combine:
                                components.append(
                                    (
                                        xi_lr[:, :, l**2 : (l + 1) ** 2]
                                        * xi_lr1[:, :, l**2 : (l + 1) ** 2]
                                    ).sum(axis=-1)
                                    * norm
                                )
                            else:
                                components.append(
                                    ((xi_lr[:, :, l**2 : (l + 1) ** 2]) ** 2).sum(axis=-1) * norm
                                )
                else:
                    components.append(xi_lr)

            dxi = jnp.concatenate(components, axis=-1)

            ##################################################
            ### CONCATENATE PAIR EMBEDDING COMPONENTS ###
            if self.pair_embedding_key is not None:
                dxij = jnp.concatenate(components_pair, axis=-1)

            ##################################################
            ### MIX AND APPLY NONLINEARITY ###
            if self.block_index_key is not None:
                block_index = inputs[self.block_index_key]
                dxi = actmix(
                    BlockIndexNet(
                        output_dim=self.dim,
                        hidden_neurons=self.mixing_hidden,
                        activation=self.activation,
                        name=f"dxi_{layer}",
                        use_bias=self.use_bias,
                        kernel_init=kernel_init,
                    )((species, dxi, block_index))
                )
            else:
                dxi = actmix(
                    FullyConnectedNet(
                        [*self.mixing_hidden, self.dim],
                        activation=self.activation,
                        name=f"dxi_{layer}",
                        use_bias=self.use_bias,
                        kernel_init=kernel_init,
                    )(dxi)
                )

            if self.pair_embedding_key is not None:
                ### UPDATE PAIR EMBEDDING ###
                dxij = actmix(
                    FullyConnectedNet(
                        [*self.pair_mixing_hidden, dim_dst],
                        activation=self.activation,
                        name=f"dxij_{layer}",
                        use_bias=False,
                        kernel_init=kernel_init,
                    )(dxij)
                )
                xij = layer_norm(xij + dxij)

            ##################################################
            ### UPDATE EMBEDDING ###
            if layer == 0 and not (self.species_init or self.charge_embedding):
                # xi is still the raw species encoding here: replace it
                xi = layer_norm(dxi)
            else:
                ### FORGET GATE ###
                R = jax.nn.sigmoid(
                    self.param(
                        f"retention_{layer}",
                        nn.initializers.normal(),
                        (xi.shape[-1],),
                    )
                )
                xi = layer_norm(R[None, :] * xi + dxi)

            if self.keep_all_layers:
                xis.append(xi)

        embedding_key = (
            self.embedding_key if self.embedding_key is not None else self.name
        )
        output = {
            **inputs,
            embedding_key: xi,
        }
        if self.lmax > 0:
            output[embedding_key + "_tensor"] = Vi
        if self.keep_all_layers:
            output[embedding_key + "_layers"] = jnp.stack(xis, axis=1)
        if self.charge_embedding:
            output[embedding_key + "_charge"] = charge_embedding
        if self.pair_embedding_key is not None:
            output[self.pair_embedding_key] = xij
        return output
Configurable Resources ATomic Environment
FID : CRATE
This class represents the CRATE (Configurable Resources ATomic Environment) embedding model. It is used to encode atomic environments using multiple sources of information (radial, angular, E(3), message-passing, LODE, etc...)
Whether to combine angle pairs instead of average distance embedding like in ANI.
The hidden size for the attention mechanism (only used when message-passing is disabled).
Whether to ignore the parity of the irreps in the tensor product.
The kernel initialization function for Dense operations.
The key for the pair embedding data in the output dictionary.
If `str`, it is the key in the inputs dictionary that contains species encodings. Else, it is the dictionary of parameters for species encoding. See `fennol.models.misc.encodings.SpeciesEncoding`.
The dictionary of parameters for radial basis functions. See `fennol.models.misc.encodings.RadialBasis`.
The dictionary of parameters for radial basis functions for angle embedding. If None, the radial basis for angles is the same as the radial basis for distances.
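The damping length `a` of the LODE long-range radial functions, which have the form 1/(r² + a²)^((l+1)/2) (the cutoff itself is taken from the lode graph properties). If negative, the value is trainable with starting value -a_lode.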
Whether to interact with the multipole moments of the lode graph.
Whether to directly use the first local equivariants to interact with long-range equivariants. If false, local equivariants are mixed before interaction.
The key for the block index. If provided, will use a BLOCK_INDEX_NET as a mixing network.
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.