fennol.models.embeddings.crate
import jax
import jax.numpy as jnp
import flax.linen as nn
import dataclasses
import numpy as np
from typing import Dict, Union, Callable, Sequence, Optional, Tuple, ClassVar

from ..misc.encodings import SpeciesEncoding, RadialBasis, positional_encoding
from ...utils.spherical_harmonics import generate_spherical_harmonics, CG_SO3
from ...utils.activations import activation_from_str
from ...utils.initializers import initializer_from_str
from ..misc.nets import FullyConnectedNet, BlockIndexNet
from ..misc.e3 import ChannelMixing, ChannelMixingE3, FilteredTensorProduct
from ...utils.periodic_table import D3_COV_RADII, D3_VDW_RADII, VALENCE_ELECTRONS
from ...utils import AtomicUnits as au


class CRATEmbedding(nn.Module):
    """Configurable Resources ATomic Environment

    FID : CRATE

    This class represents the CRATE (Configurable Resources ATomic Environment) embedding model.
    It is used to encode atomic environments using multiple sources of information
    (radial, angular, E(3), message-passing, LODE, etc...)
    """

    _graphs_properties: Dict

    dim: int = 256
    """The size of the embedding vectors."""
    nlayers: int = 2
    """The number of interaction layers in the model."""
    keep_all_layers: bool = False
    """Whether to output all layers."""

    dim_src: Union[int, Sequence[int]] = 64
    """The size of the source embedding vectors (an integer or one value per layer)."""
    dim_dst: int = 32
    """The size of the destination embedding vectors."""

    angle_style: str = "fourier"
    """The style of angle representation ("fourier", "fourier_full" or "ani")."""
    dim_angle: Union[int, Sequence[int]] = 8
    """The size of the pairwise vectors used for triplet combinations (an integer or one value per layer)."""
    nmax_angle: int = 4
    """The dimension of the angle representation (minus one)."""
    zeta: float = 14.1
    """The zeta parameter for the model ANI angular representation."""
    angle_combine_pairs: bool = True
    """Whether to combine angle pairs instead of average distance embedding like in ANI."""

    message_passing: Union[bool, Sequence[bool]] = True
    """Whether to use message passing in the model (a boolean or one value per layer)."""
    att_dim: int = 1
    """The hidden size for the attention mechanism (only used when message-passing is disabled)."""

    lmax: int = 0
    """The maximum order of spherical tensors."""
    nchannels_l: int = 16
    """The number of channels for spherical tensors."""
    n_tp: Union[int, Sequence[int]] = 1
    """The number of tensor products performed at each layer (an integer or one value per layer)."""
    ignore_irreps_parity: bool = False
    """Whether to ignore the parity of the irreps in the tensor product."""
    edge_tp: bool = False
    """Whether to perform a tensor product on edges before sending messages."""
    resolve_wij_l: bool = False
    """Equivariant message weights are l-dependent."""

    species_init: bool = False
    """Whether to initialize the embedding using the species encoding."""
    mixing_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [])
    """The hidden layer sizes for the mixing network."""
    pair_mixing_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [])
    """The hidden layer sizes for the pair mixing network."""
    activation: Union[Callable, str] = "silu"
    """The activation function for the mixing network."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization function for Dense operations."""
    activation_mixing: Union[Callable, str] = "tssr3"
    """The activation function applied after mixing."""
    layer_normalization: bool = False
    """Whether to apply layer normalization after each layer."""
    use_bias: bool = True
    """Whether to use bias in the Dense operations."""

    graph_key: str = "graph"
    """The key for the graph data in the inputs dictionary."""
    graph_angle_key: Optional[str] = None
    """The key for the angle graph data in the inputs dictionary."""
    embedding_key: Optional[str] = None
    """The key for the embedding data in the output dictionary."""
    pair_embedding_key: Optional[str] = None
    """The key for the pair embedding data in the output dictionary."""

    species_encoding: Union[dict, str] = dataclasses.field(default_factory=dict)
    """If `str`, it is the key in the inputs dictionary that contains species encodings. Else, it is the dictionary of parameters for species encoding. See `fennol.models.misc.encodings.SpeciesEncoding`."""
    radial_basis: dict = dataclasses.field(default_factory=dict)
    """The dictionary of parameters for radial basis functions. See `fennol.models.misc.encodings.RadialBasis`."""
    radial_basis_angle: Optional[dict] = None
    """The dictionary of parameters for radial basis functions for angle embedding.
        If None, the radial basis for angles is the same as the radial basis for distances."""

    graph_lode: Optional[str] = None
    """The key for the lode graph data in the inputs dictionary."""
    lode_channels: Union[int, Sequence[int]] = 8
    """The number of channels for lode (an integer or one value per layer)."""
    lmax_lode: int = 0
    """The maximum order of spherical tensors for lode."""
    a_lode: float = -1.0
    """Damping length of the long-range kernel 1/(r^2 + a_lode^2)^((l+1)/2).
        If not strictly positive, the value is trainable with starting value -a_lode."""
    lode_resolve_l: bool = True
    """Whether to resolve the lode channels by l."""
    lode_multipole_interaction: bool = True
    """Whether to interact with the multipole moments of the lode graph."""
    lode_direct_multipoles: bool = True
    """Whether to directly use the first local equivariants to interact with long-range equivariants. If false, local equivariants are mixed before interaction."""
    lode_equi_full_combine: bool = False
    """If True (together with `lode_use_field_norm`), l>0 long-range features are contracted with a channel-mixed copy of themselves instead of taking their squared norm."""
    lode_normalize_l: bool = False
    """Whether to scale the l>0 long-range invariants by 1/(2l+1)."""
    lode_use_field_norm: bool = True
    """Whether to build invariants from the l>0 components of the long-range field itself."""
    lode_rshort: Optional[float] = None
    """If set, the long-range kernel is zero below `lode_rshort` and smoothly switched on up to `lode_rshort + lode_dshort`."""
    lode_dshort: float = 0.5
    """Width of the short-range switching region (only used when `lode_rshort` is set)."""

    charge_embedding: bool = False
    """Whether to include charge embedding."""
    total_charge_key: str = "total_charge"
    """The key for the total charge data in the inputs dictionary."""

    block_index_key: Optional[str] = None
    """The key for the block index. If provided, will use a `BlockIndexNet` as a mixing network."""

    FID: ClassVar[str] = "CRATE"

    @nn.compact
    def __call__(self, inputs):
        """Compute the CRATE embedding.

        Args:
            inputs: dictionary of inputs. Always reads ``"species"`` (1D array) and
                the graph ``inputs[graph_key]`` (``edge_src``, ``edge_dst``,
                ``distances``, ``switch`` and, when ``lmax > 0``, ``vec``).
                Depending on the configuration it also reads
                ``inputs[graph_angle_key]``, ``inputs[graph_lode]``,
                ``"batch_index"``/``"natoms"``/``inputs[total_charge_key]``
                (charge embedding), ``inputs[block_index_key]``, a string
                ``species_encoding`` key and ``inputs["flags"]``.

        Returns:
            A copy of ``inputs`` extended with the atomic embedding under
            ``embedding_key`` (defaults to the module name) and, when enabled,
            ``<key>_tensor`` (equivariant features), ``<key>_layers`` (per-layer
            embeddings stacked on axis 1), ``<key>_charge`` (charge encoding) and
            ``inputs[pair_embedding_key]`` (pair embedding).

        Raises:
            NotImplementedError: for options not supported with the
                ``reduce_memory`` flag.
            ValueError: for an unknown ``angle_style``, or when equivariant LODE
                multipole interaction is used before local equivariants exist.
        """
        species = inputs["species"]
        assert (
            len(species.shape) == 1
        ), "Species must be a 1D array (batches must be flattened)"
        reduce_memory = "reduce_memory" in inputs.get("flags", {})

        kernel_init = (
            initializer_from_str(self.kernel_init)
            if isinstance(self.kernel_init, str)
            else self.kernel_init
        )

        actmix = activation_from_str(self.activation_mixing)

        ##################################################
        graph = inputs[self.graph_key]
        use_angles = self.graph_angle_key is not None
        if use_angles:
            graph_angle = inputs[self.graph_angle_key]

            # Check that the graph_angle is a subgraph of graph
            correct_graph = (
                self.graph_angle_key == self.graph_key
                or self._graphs_properties[self.graph_angle_key]["parent_graph"]
                == self.graph_key
            )
            assert (
                correct_graph
            ), f"graph_angle_key={self.graph_angle_key} must be a subgraph of graph_key={self.graph_key}"
            assert (
                "angles" in graph_angle
            ), f"Graph {self.graph_angle_key} must contain angles"
            # The angle graph needs to be mapped back onto `graph` edges only when
            # it is a distinct (filtered) subgraph. If both keys are equal, the
            # edges already coincide, even if `graph` itself has a parent graph
            # (its filter_indices would then point into that other graph).
            filtered = self.graph_angle_key != self.graph_key
            if filtered:
                filter_indices = graph_angle["filter_indices"]

        ##################################################
        ### SPECIES ENCODING ###
        if isinstance(self.species_encoding, str):
            zi = inputs[self.species_encoding]
        else:
            zi = SpeciesEncoding(**self.species_encoding, name="SpeciesEncoding")(
                species
            )

        if self.layer_normalization:

            def layer_norm(x):
                # parameter-free normalization over the feature axis
                mu = jnp.mean(x, axis=-1, keepdims=True)
                dx = x - mu
                var = jnp.mean(dx**2, axis=-1, keepdims=True)
                sig = (1.0e-6 + var) ** (-0.5)
                return dx * sig

        else:
            layer_norm = lambda x: x

        if self.charge_embedding:
            xi, qi = jnp.split(
                nn.Dense(self.dim + 1, use_bias=False, name="ChargeEncoding")(zi),
                [self.dim],
                axis=-1,
            )
            batch_index = inputs["batch_index"]
            natoms = inputs["natoms"]
            nsys = natoms.shape[0]
            Zi = jnp.asarray(VALENCE_ELECTRONS)[species]
            # number of valence electrons per system, corrected by the total charge
            Ntot = jax.ops.segment_sum(Zi, batch_index, nsys) - inputs.get(
                self.total_charge_key, jnp.zeros(nsys)
            )
            # distribute the electrons over atoms with learned positive weights
            ai = jax.nn.softplus(qi.squeeze(-1))
            A = jax.ops.segment_sum(ai, batch_index, nsys)
            Ni = ai * (Ntot / A)[batch_index]
            charge_embedding = positional_encoding(Ni, self.dim)
            xi = layer_norm(xi + charge_embedding)
        elif self.species_init:
            xi = layer_norm(nn.Dense(self.dim, use_bias=False, name="SpeciesInit")(zi))
        else:
            xi = zi

        ##################################################
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        distances = graph["distances"]
        switch = graph["switch"][:, None]

        ### COMPUTE RADIAL BASIS ###
        radial_basis = RadialBasis(
            **{
                **self.radial_basis,
                "end": cutoff,
                "name": "RadialBasis",
            }
        )(distances)

        do_lode = self.graph_lode is not None
        if do_lode:
            graph_lode = inputs[self.graph_lode]
            switch_lode = graph_lode["switch"][:, None]

            edge_src_lr, edge_dst_lr = graph_lode["edge_src"], graph_lode["edge_dst"]
            r = graph_lode["distances"][:, None]
            rc = self._graphs_properties[self.graph_lode]["cutoff"]

            lmax_lr = self.lmax_lode
            equivariant_lode = lmax_lr > 0
            assert lmax_lr >= 0, f"lmax_lode must be >= 0, got {lmax_lr}"
            if self.lode_multipole_interaction:
                assert (
                    lmax_lr <= self.lmax
                ), f"lmax_lode must be <= lmax for multipole interaction, got {lmax_lr} > {self.lmax}"
            nrep_lr = np.array([2 * l + 1 for l in range(lmax_lr + 1)], dtype=np.int32)
            if self.lode_resolve_l and equivariant_lode:
                ls_lr = np.arange(lmax_lr + 1)
            else:
                ls_lr = np.array([0])

            if self.a_lode > 0:
                a = self.a_lode**2
            else:
                a = (
                    self.param(
                        "a_lr",
                        lambda key: jnp.asarray([-self.a_lode] * ls_lr.shape[0])[
                            None, :
                        ],
                    )
                    ** 2
                )
            rc2a = rc**2 + a
            ls_lr = 0.5 * (ls_lr[None, :] + 1)
            ### minimal radial basis for long range (damped coulomb)
            # shifted-force form: f(r) - f(rc) - f'(rc) * (r - rc)
            # with f(r) = (r^2 + a)^(-(l+1)/2)
            eij_lr = (
                1.0 / (r**2 + a) ** ls_lr
                - 1.0 / rc2a**ls_lr
                + (r - rc) * rc * (2 * ls_lr) / rc2a ** (ls_lr + 1)
            ) * switch_lode

            if self.lode_rshort is not None:
                rs = self.lode_rshort
                d = self.lode_dshort
                # 0 below rs, cosine ramp on ]rs, rs+d[, 1 from rs+d on
                # (inclusive, so the switch is continuous at r == rs + d)
                switch_short = 0.5 * (1 - jnp.cos(jnp.pi * (r - rs) / d)) * (
                    r > rs
                ) * (r < rs + d) + (r >= rs + d)
                eij_lr = eij_lr * switch_short

            nchannels_lode = (
                [self.lode_channels] * self.nlayers
                if isinstance(self.lode_channels, int)
                else self.lode_channels
            )
            assert (
                len(nchannels_lode) == self.nlayers
            ), f"lode_channels must be an integer or a list of length {self.nlayers}"
            dim_lr = nchannels_lode

            if equivariant_lode:
                if self.lode_resolve_l:
                    eij_lr = eij_lr.repeat(nrep_lr, axis=-1)
                Yij_lr = generate_spherical_harmonics(lmax=lmax_lr, normalize=False)(
                    graph_lode["vec"] / r
                )
                eij_lr = (eij_lr * Yij_lr)[:, None, :]
                dim_lr = [nch * (lmax_lr + 1) for nch in dim_lr]

        ##################################################
        ### GET ANGLES ###
        if use_angles:
            angles = graph_angle["angles"][:, None]
            angle_src, angle_dst = graph_angle["angle_src"], graph_angle["angle_dst"]
            switch_angles = graph_angle["switch"][:, None]
            central_atom = graph_angle["central_atom"]

            if not self.angle_combine_pairs:
                assert (
                    self.radial_basis_angle is not None
                ), "radial_basis_angle must be specified if angle_combine_pairs=False"

            ### COMPUTE RADIAL BASIS FOR ANGLES ###
            if self.radial_basis_angle is not None:
                dangles = graph_angle["distances"]
                swang = switch_angles
                if not self.angle_combine_pairs:
                    # per-triplet average distance (ANI-like)
                    dangles = 0.5 * (dangles[angle_src] + dangles[angle_dst])
                    swang = switch_angles[angle_src] * switch_angles[angle_dst]
                radial_basis_angle = (
                    RadialBasis(
                        **{
                            **self.radial_basis_angle,
                            "end": self._graphs_properties[self.graph_angle_key][
                                "cutoff"
                            ],
                            "name": "RadialBasisAngle",
                        }
                    )(dangles)
                    * swang
                )
            elif filtered:
                radial_basis_angle = radial_basis[filter_indices] * switch_angles
            else:
                radial_basis_angle = radial_basis * switch

        radial_basis = radial_basis * switch

        ##################################################
        if self.lmax > 0:
            Yij = generate_spherical_harmonics(lmax=self.lmax, normalize=False)(
                graph["vec"] / graph["distances"][:, None]
            )[:, None, :]
            Yij = jnp.broadcast_to(Yij, (Yij.shape[0], self.nchannels_l, Yij.shape[2]))
            nrep_l = np.array([2 * l + 1 for l in range(self.lmax + 1)], dtype=np.int32)

        ##################################################
        if use_angles:
            ### ANGULAR BASIS ###
            if self.angle_style == "fourier":
                # build fourier series for angles
                nangles = self.param(
                    "nangles",
                    lambda key, dim: jnp.arange(dim, dtype=distances.dtype)[None, :],
                    self.nmax_angle + 1,
                )

                phi = self.param(
                    "phi",
                    lambda key, dim: jnp.zeros((1, dim), dtype=distances.dtype),
                    self.nmax_angle + 1,
                )
                xa = jnp.cos(nangles * angles + phi)
            elif self.angle_style == "fourier_full":
                # build fourier series for angles including sin terms
                nangles = self.param(
                    "nangles",
                    lambda key, dim: jnp.arange(dim, dtype=distances.dtype)[None, :],
                    self.nmax_angle + 1,
                )

                phi = self.param(
                    "phi",
                    lambda key, dim: jnp.zeros((1, dim), dtype=distances.dtype),
                    2 * self.nmax_angle + 1,
                )
                xac = jnp.cos(nangles * angles + phi[:, : self.nmax_angle + 1])
                xas = jnp.sin(nangles[:, 1:] * angles + phi[:, self.nmax_angle + 1 :])
                xa = jnp.concatenate([xac, xas], axis=-1)
            elif self.angle_style == "ani":
                # ANI-style angle embedding
                angle_start = np.pi / (2 * (self.nmax_angle + 1))
                shiftZ = self.param(
                    "shiftZ",
                    lambda key, dim: jnp.asarray(
                        (np.linspace(0, np.pi, dim + 1) + angle_start)[None, :-1],
                        dtype=distances.dtype,
                    ),
                    self.nmax_angle + 1,
                )
                zeta = self.param(
                    "zeta",
                    lambda key: jnp.asarray(self.zeta, dtype=distances.dtype),
                )
                xa = (0.5 + 0.5 * jnp.cos(angles - shiftZ)) ** zeta
            else:
                raise ValueError(f"Unknown angle style {self.angle_style}")
            xa = xa[:, None, :]
            if not self.angle_combine_pairs:
                if reduce_memory:
                    raise NotImplementedError(
                        "Angle embedding not implemented with reduce_memory"
                    )
                # per-triplet outer product of radial and angular features:
                # (nangles, 1, nxa) * (nangles, nrad, 1) -> (nangles, 1, nrad * nxa)
                # (the triplet axis must be kept so that it still lines up with Da)
                xa = (xa * radial_basis_angle[:, :, None]).reshape(
                    xa.shape[0], 1, radial_basis_angle.shape[1] * xa.shape[2]
                )

            if self.pair_embedding_key is not None:
                if filtered:
                    ang_pair_src = filter_indices[angle_src]
                    ang_pair_dst = filter_indices[angle_dst]
                else:
                    ang_pair_src = angle_src
                    ang_pair_dst = angle_dst
                ang_pairs = jnp.concatenate((ang_pair_src, ang_pair_dst))

        ##################################################
        ### DIMENSIONS ###
        dim_src = (
            [self.dim_src] * self.nlayers
            if isinstance(self.dim_src, int)
            else self.dim_src
        )
        assert (
            len(dim_src) == self.nlayers
        ), f"dim_src must be an integer or a list of length {self.nlayers}"
        dim_dst = self.dim_dst

        if use_angles:
            dim_angle = (
                [self.dim_angle] * self.nlayers
                if isinstance(self.dim_angle, int)
                else self.dim_angle
            )
            assert (
                len(dim_angle) == self.nlayers
            ), f"dim_angle must be an integer or a list of length {self.nlayers}"

        # becomes False once the equivariant features Vi have been created
        initialize_e3 = True
        if self.lmax > 0:
            n_tp = (
                [self.n_tp] * self.nlayers
                if isinstance(self.n_tp, int)
                else self.n_tp
            )
            assert (
                len(n_tp) == self.nlayers
            ), f"n_tp must be an integer or a list of length {self.nlayers}"

        message_passing = (
            [self.message_passing] * self.nlayers
            if isinstance(self.message_passing, bool)
            else self.message_passing
        )
        assert (
            len(message_passing) == self.nlayers
        ), f"message_passing must be a boolean or a list of length {self.nlayers}"

        ##################################################
        ### INITIALIZE PAIR EMBEDDING ###
        if self.pair_embedding_key is not None:
            xij_s, xij_d = jnp.split(
                nn.Dense(2 * dim_dst, name="pair_init_linear")(zi), [dim_dst], axis=-1
            )
            xij = layer_norm(xij_s[edge_src] * xij_d[edge_dst])

        ##################################################
        if self.keep_all_layers:
            xis = []

        ### LOOP OVER LAYERS ###
        for layer in range(self.nlayers):
            ##################################################
            ### COMPACT DESCRIPTORS ###
            si, si_dst = jnp.split(
                nn.Dense(
                    dim_src[layer] + dim_dst,
                    name=f"species_linear_{layer}",
                    use_bias=self.use_bias,
                )(xi),
                [dim_src[layer]],
                axis=-1,
            )

            ##################################################
            if message_passing[layer] or layer == 0:
                ### MESSAGE PASSING ###
                si_mp = si_dst[edge_dst]
            else:
                ### ATTENTION TO SIMULATE MP ###
                Q = nn.Dense(
                    dim_dst * self.att_dim, name=f"queries_{layer}", use_bias=False
                )(si_dst).reshape(-1, dim_dst, self.att_dim)[edge_src]
                K = nn.Dense(
                    dim_dst * self.att_dim, name=f"keys_{layer}", use_bias=False
                )(zi).reshape(-1, dim_dst, self.att_dim)[edge_dst]

                si_mp = (K * Q).sum(axis=-1) / self.att_dim**0.5

            if self.pair_embedding_key is not None:
                si_mp = si_mp + xij

            ##################################################
            ### PAIR EMBEDDING ###
            nrad = radial_basis.shape[1]
            if reduce_memory:
                # accumulate radial channel by radial channel to avoid
                # materializing the (nedges, nrad * dim) pair tensor
                Li = jnp.zeros(
                    (species.shape[0] * nrad, si_mp.shape[1]), dtype=si_mp.dtype
                )
                for i in range(nrad):
                    indices = i + edge_src * nrad
                    Li = Li.at[indices].add(si_mp * radial_basis[:, i, None])
                Li = Li.reshape(species.shape[0], nrad * si_mp.shape[1])
            else:
                Lij = (si_mp[:, None, :] * radial_basis[:, :, None]).reshape(
                    radial_basis.shape[0], si_mp.shape[1] * nrad
                )
                ### AGGREGATE PAIR EMBEDDING ###
                Li = jax.ops.segment_sum(Lij, edge_src, species.shape[0])

            ### CONCATENATE EMBEDDING COMPONENTS ###
            components = [si, Li]
            if self.pair_embedding_key is not None:
                if reduce_memory:
                    raise NotImplementedError(
                        "Pair embedding not implemented with reduce_memory"
                    )
                components_pair = [si[edge_src], xij, Lij]

            ##################################################
            ### ANGLE EMBEDDING ###
            if use_angles and dim_angle[layer] > 0:
                si_mp_ang = si_mp[filter_indices] if filtered else si_mp
                if self.angle_combine_pairs:
                    Wa = self.param(
                        f"Wa_{layer}",
                        nn.initializers.normal(
                            stddev=1.0
                            / (si_mp.shape[1] * radial_basis_angle.shape[1]) ** 0.5
                        ),
                        (si_mp.shape[1], radial_basis_angle.shape[1], dim_angle[layer]),
                    )
                    Da = jnp.einsum(
                        "...i,...j,ijk->...k",
                        si_mp_ang,
                        radial_basis_angle,
                        Wa,
                    )
                elif message_passing[layer] or layer == 0:
                    Da = nn.Dense(dim_angle[layer], name=f"angle_linear_{layer}")(xi)[
                        graph_angle["edge_dst"]
                    ]
                else:
                    Da = nn.Dense(dim_angle[layer], name=f"angle_linear_{layer}")(
                        si_mp_ang
                    )

                # pair features -> triplet features
                Da = Da[angle_dst] * Da[angle_src]
                ## combine pair and angle info
                if reduce_memory:
                    ndang = Da.shape[-1]
                    ang_embedding = jnp.zeros(
                        (species.shape[0] * ndang, xa.shape[-1]), dtype=Da.dtype
                    )
                    for i in range(ndang):
                        indices = i + central_atom * ndang
                        ang_embedding = ang_embedding.at[indices].add(
                            Da[:, i, None] * xa[:, 0, :]
                        )
                    ang_embedding = ang_embedding.reshape(
                        species.shape[0], xa.shape[-1] * ndang
                    )
                else:
                    radang = (xa * Da[:, :, None]).reshape(
                        (-1, Da.shape[1] * xa.shape[2])
                    )
                    ### AGGREGATE ANGLE EMBEDDING ###
                    ang_embedding = jax.ops.segment_sum(
                        radang, central_atom, species.shape[0]
                    )

                components.append(ang_embedding)

                if self.pair_embedding_key is not None:
                    ang_ij = jax.ops.segment_sum(
                        jnp.concatenate((radang, radang)),
                        ang_pairs,
                        edge_src.shape[0],
                    )
                    components_pair.append(ang_ij)

            ##################################################
            ### EQUIVARIANT EMBEDDING ###
            if self.lmax > 0 and n_tp[layer] >= 0:
                if reduce_memory:
                    # the channel weights below are computed from Lij, which is
                    # never materialized in the reduce_memory path
                    raise NotImplementedError(
                        "Equivariant embedding not implemented with reduce_memory"
                    )
                if initialize_e3 or not message_passing[layer]:
                    Vij = Yij
                elif self.edge_tp:
                    Vij = FilteredTensorProduct(
                        self.lmax,
                        self.lmax,
                        name=f"edge_tp_{layer}",
                        ignore_parity=self.ignore_irreps_parity,
                        weights_by_channel=False,
                    )(Vi[edge_dst], Yij)
                else:
                    Vij = Vi[edge_dst]

                ### compute channel weights
                if self.resolve_wij_l:
                    dim_wij = self.nchannels_l * (self.lmax + 1)
                else:
                    dim_wij = self.nchannels_l

                eij = (
                    Lij
                    if self.pair_embedding_key is None
                    else jnp.concatenate([Lij, xij * switch], axis=-1)
                )
                wij = nn.Dense(dim_wij, name=f"e3_channel_{layer}", use_bias=False)(
                    eij
                )
                if self.resolve_wij_l:
                    # one weight per (channel, l), repeated over the 2l+1 components
                    wij = jnp.repeat(
                        wij.reshape(-1, self.nchannels_l, self.lmax + 1),
                        nrep_l,
                        axis=-1,
                    )
                else:
                    wij = wij[:, :, None]

                ### aggregate equivariant messages
                drhoi = jax.ops.segment_sum(
                    wij * Vij,
                    edge_src,
                    species.shape[0],
                )

                if initialize_e3:
                    rhoi = drhoi
                    Vi = ChannelMixingE3(
                        self.lmax,
                        self.nchannels_l,
                        self.nchannels_l,
                        name=f"e3_initial_mixing_{layer}",
                    )(rhoi)
                else:
                    rhoi = rhoi + drhoi
                initialize_e3 = False

                if n_tp[layer] > 0:
                    Vi0 = []
                    for itp in range(n_tp[layer]):
                        dVi = FilteredTensorProduct(
                            self.lmax,
                            self.lmax,
                            name=f"tensor_product_{layer}_{itp}",
                            ignore_parity=self.ignore_irreps_parity,
                            weights_by_channel=False,
                        )(rhoi, Vi)
                        Vi = ChannelMixing(
                            self.lmax,
                            self.nchannels_l,
                            self.nchannels_l,
                            name=f"tp_mixing_{layer}_{itp}",
                        )(Vi + dVi)
                        # keep the scalar (l=0) part as invariant features
                        Vi0.append(dVi[:, :, 0])
                    components.append(jnp.concatenate(Vi0, axis=-1))

                if self.pair_embedding_key is not None:
                    Vij = Vi[edge_src] * Yij
                    # per-l contraction of the edge equivariants into invariants
                    Vij0 = [Vij[..., 0]]
                    for l in range(1, self.lmax + 1):
                        Vij0.append(Vij[..., l**2 : (l + 1) ** 2].sum(axis=-1))
                    components_pair.append(jnp.concatenate(Vij0, axis=-1))

            ##################################################
            ### LONG-RANGE (LODE) EMBEDDING ###
            if do_lode and nchannels_lode[layer] > 0:
                zj = nn.Dense(dim_lr[layer], use_bias=False, name=f"LODE_{layer}")(xi)
                if equivariant_lode:
                    zj = zj.reshape(
                        (species.shape[0], nchannels_lode[layer], lmax_lr + 1)
                    ).repeat(nrep_lr, axis=-1)
                xi_lr = jax.ops.segment_sum(
                    eij_lr * zj[edge_dst_lr], edge_src_lr, species.shape[0]
                )
                if equivariant_lode:
                    assert (
                        self.lode_use_field_norm or self.lode_multipole_interaction
                    ), "equivariant LODE requires field norm or multipole interaction"
                    if self.lode_multipole_interaction:
                        if initialize_e3:
                            raise ValueError(
                                "equivariant LODE used before local equivariants initialized"
                            )
                        size_l_lr = (lmax_lr + 1) ** 2
                        if self.lode_direct_multipoles:
                            assert nchannels_lode[layer] <= self.nchannels_l
                            Mi = Vi[:, : nchannels_lode[layer], :size_l_lr]
                        else:
                            Mi = ChannelMixingE3(
                                lmax_lr,
                                self.nchannels_l,
                                nchannels_lode[layer],
                                name=f"e3_LODE_{layer}",
                            )(Vi[..., :size_l_lr])
                        Mi_lr = Mi * xi_lr
                    components.append(xi_lr[:, :, 0])
                    if self.lode_use_field_norm and self.lode_equi_full_combine:
                        xi_lr1 = ChannelMixing(
                            lmax_lr,
                            nchannels_lode[layer],
                            nchannels_lode[layer],
                            name=f"LODE_mixing_{layer}",
                        )(xi_lr)
                    norm = 1.0
                    for l in range(1, lmax_lr + 1):
                        lslice = slice(l**2, (l + 1) ** 2)
                        if self.lode_normalize_l:
                            norm = 1.0 / (2 * l + 1)
                        if self.lode_multipole_interaction:
                            components.append(
                                Mi_lr[:, :, lslice].sum(axis=-1) * norm
                            )

                        if self.lode_use_field_norm:
                            if self.lode_equi_full_combine:
                                components.append(
                                    (
                                        xi_lr[:, :, lslice] * xi_lr1[:, :, lslice]
                                    ).sum(axis=-1)
                                    * norm
                                )
                            else:
                                components.append(
                                    (xi_lr[:, :, lslice] ** 2).sum(axis=-1) * norm
                                )
                else:
                    components.append(xi_lr)

            ##################################################
            ### CONCATENATE EMBEDDING COMPONENTS ###
            dxi = jnp.concatenate(components, axis=-1)

            if self.pair_embedding_key is not None:
                dxij = jnp.concatenate(components_pair, axis=-1)

            ##################################################
            ### MIX AND APPLY NONLINEARITY ###
            if self.block_index_key is not None:
                block_index = inputs[self.block_index_key]
                dxi = actmix(
                    BlockIndexNet(
                        output_dim=self.dim,
                        hidden_neurons=self.mixing_hidden,
                        activation=self.activation,
                        name=f"dxi_{layer}",
                        use_bias=self.use_bias,
                        kernel_init=kernel_init,
                    )((species, dxi, block_index))
                )
            else:
                dxi = actmix(
                    FullyConnectedNet(
                        [*self.mixing_hidden, self.dim],
                        activation=self.activation,
                        name=f"dxi_{layer}",
                        use_bias=self.use_bias,
                        kernel_init=kernel_init,
                    )(dxi)
                )

            if self.pair_embedding_key is not None:
                ### UPDATE PAIR EMBEDDING ###
                dxij = actmix(
                    FullyConnectedNet(
                        [*self.pair_mixing_hidden, dim_dst],
                        activation=self.activation,
                        name=f"dxij_{layer}",
                        use_bias=False,
                        kernel_init=kernel_init,
                    )(dxij)
                )
                xij = layer_norm(xij + dxij)

            ##################################################
            ### UPDATE EMBEDDING ###
            if layer == 0 and not (self.species_init or self.charge_embedding):
                # xi is still the raw species encoding (possibly of another size)
                xi = layer_norm(dxi)
            else:
                ### FORGET GATE ###
                R = jax.nn.sigmoid(
                    self.param(
                        f"retention_{layer}",
                        nn.initializers.normal(),
                        (xi.shape[-1],),
                    )
                )
                xi = layer_norm(R[None, :] * xi + dxi)

            if self.keep_all_layers:
                xis.append(xi)

        embedding_key = (
            self.embedding_key if self.embedding_key is not None else self.name
        )
        output = {
            **inputs,
            embedding_key: xi,
        }
        # Vi only exists if at least one layer built equivariant features
        if self.lmax > 0 and not initialize_e3:
            output[embedding_key + "_tensor"] = Vi
        if self.keep_all_layers:
            output[embedding_key + "_layers"] = jnp.stack(xis, axis=1)
        if self.charge_embedding:
            output[embedding_key + "_charge"] = charge_embedding
        if self.pair_embedding_key is not None:
            output[self.pair_embedding_key] = xij
        return output
class CRATEmbedding(nn.Module):
    """Configurable Resources ATomic Environment

    FID : CRATE

    This class represents the CRATE (Configurable Resources ATomic Environment) embedding model.
    It is used to encode atomic environments using multiple sources of information
    (radial, angular, E(3), message-passing, LODE, etc...)
    """

    # Properties of the graphs available in the inputs (cutoff, parent_graph, ...),
    # indexed by graph key.
    _graphs_properties: Dict

    dim: int = 256
    """The size of the embedding vectors."""
    nlayers: int = 2
    """The number of interaction layers in the model."""
    keep_all_layers: bool = False
    """Whether to output all layers."""

    dim_src: int = 64
    """The size of the source embedding vectors."""
    dim_dst: int = 32
    """The size of the destination embedding vectors."""

    angle_style: str = "fourier"
    """The style of angle representation."""
    dim_angle: int = 8
    """The size of the pairwise vectors used for triplet combinations."""
    nmax_angle: int = 4
    """The dimension of the angle representation (minus one)."""
    zeta: float = 14.1
    """The zeta parameter for the model ANI angular representation."""
    angle_combine_pairs: bool = True
    """Whether to combine angle pairs instead of average distance embedding like in ANI."""

    message_passing: bool = True
    """Whether to use message passing in the model."""
    att_dim: int = 1
    """The hidden size for the attention mechanism (only used when message-passing is disabled)."""

    lmax: int = 0
    """The maximum order of spherical tensors."""
    nchannels_l: int = 16
    """The number of channels for spherical tensors."""
    n_tp: int = 1
    """The number of tensor products performed at each layer."""
    ignore_irreps_parity: bool = False
    """Whether to ignore the parity of the irreps in the tensor product."""
    edge_tp: bool = False
    """Whether to perform a tensor product on edges before sending messages."""
    resolve_wij_l: bool = False
    """Equivariant message weights are l-dependent."""

    species_init: bool = False
    """Whether to initialize the embedding using the species encoding."""
    mixing_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [])
    """The hidden layer sizes for the mixing network."""
    pair_mixing_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [])
    """The hidden layer sizes for the pair mixing network."""
    activation: Union[Callable, str] = "silu"
    """The activation function for the mixing network."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization function for Dense operations."""
    activation_mixing: Union[Callable, str] = "tssr3"
    """The activation function applied after mixing."""
    layer_normalization: bool = False
    """Whether to apply layer normalization after each layer."""
    use_bias: bool = True
    """Whether to use bias in the Dense operations."""

    graph_key: str = "graph"
    """The key for the graph data in the inputs dictionary."""
    graph_angle_key: Optional[str] = None
    """The key for the angle graph data in the inputs dictionary."""
    embedding_key: Optional[str] = None
    """The key for the embedding data in the output dictionary."""
    pair_embedding_key: Optional[str] = None
    """The key for the pair embedding data in the output dictionary."""

    species_encoding: Union[dict, str] = dataclasses.field(default_factory=dict)
    """If `str`, it is the key in the inputs dictionary that contains species encodings. Else, it is the dictionary of parameters for species encoding. See `fennol.models.misc.encodings.SpeciesEncoding`."""
    radial_basis: dict = dataclasses.field(default_factory=dict)
    """The dictionary of parameters for radial basis functions. See `fennol.models.misc.encodings.RadialBasis`."""
    radial_basis_angle: Optional[dict] = None
    """The dictionary of parameters for radial basis functions for angle embedding.
    If None, the radial basis for angles is the same as the radial basis for distances."""

    graph_lode: Optional[str] = None
    """The key for the lode graph data in the inputs dictionary."""
    lode_channels: Union[int, Sequence[int]] = 8
    """The number of channels for lode."""
    lmax_lode: int = 0
    """The maximum order of spherical tensors for lode."""
    a_lode: float = -1.
    """The damping parameter of the long-range radial kernel (distances enter as r**2 + a_lode**2). If negative, the value is trainable with starting value -a_lode."""
    lode_resolve_l: bool = True
    """Whether to resolve the lode channels by l."""
    lode_multipole_interaction: bool = True
    """Whether to interact with the multipole moments of the lode graph."""
    lode_direct_multipoles: bool = True
    """Whether to directly use the first local equivariants to interact with long-range equivariants. If false, local equivariants are mixed before interaction."""
    # With lode_use_field_norm: contract each l-block of the LODE field with a
    # channel-mixed copy of itself instead of with itself.
    lode_equi_full_combine: bool = False
    # Scale the l-th invariant LODE features by 1/(2l+1).
    lode_normalize_l: bool = False
    # Add the per-l squared norms (or contractions) of the equivariant LODE field.
    lode_use_field_norm: bool = True
    # If set, the long-range kernel is smoothly switched on between lode_rshort
    # and lode_rshort + lode_dshort (zero below lode_rshort).
    lode_rshort: Optional[float] = None
    # Width of the short-range switching region (only used with lode_rshort).
    lode_dshort: float = 0.5

    charge_embedding: bool = False
    """Whether to include charge embedding."""
    total_charge_key: str = "total_charge"
    """The key for the total charge data in the inputs dictionary."""

    block_index_key: Optional[str] = None
    """The key for the block index. If provided, a `BlockIndexNet` is used as the mixing network."""

    FID: ClassVar[str] = "CRATE"

    @nn.compact
    def __call__(self, inputs):
        """Compute the CRATE atomic embedding.

        Args:
            inputs: dictionary of model inputs. Must contain ``"species"`` (1D array)
                and the graph under ``graph_key`` (``edge_src``, ``edge_dst``,
                ``distances``, ``switch`` and, when ``lmax > 0``, ``vec``).
                Depending on the options it also reads ``graph_angle_key``,
                ``graph_lode``, ``"batch_index"``/``"natoms"``/``total_charge_key``
                (charge embedding), ``block_index_key`` and ``"flags"``
                (``"reduce_memory"`` selects a lower-memory aggregation).

        Returns:
            A copy of ``inputs`` extended with the atomic embedding under
            ``embedding_key`` (defaults to the module name) and, depending on the
            options, ``<key>_tensor`` (equivariant features), ``<key>_layers``
            (embeddings of all layers stacked on axis 1), ``<key>_charge`` and the
            pair embedding under ``pair_embedding_key``.

        Raises:
            NotImplementedError: if ``reduce_memory`` is combined with pair
                embeddings, non-combined angle pairs or equivariant features.
            ValueError: for an unknown ``angle_style``, or when equivariant LODE
                multipoles are required before local equivariants are initialized.
        """
        species = inputs["species"]
        assert (
            len(species.shape) == 1
        ), "Species must be a 1D array (batches must be flattened)"
        reduce_memory = "reduce_memory" in inputs.get("flags", {})

        kernel_init = (
            initializer_from_str(self.kernel_init)
            if isinstance(self.kernel_init, str)
            else self.kernel_init
        )

        actmix = activation_from_str(self.activation_mixing)

        ##################################################
        graph = inputs[self.graph_key]
        use_angles = self.graph_angle_key is not None
        if use_angles:
            graph_angle = inputs[self.graph_angle_key]

            # Check that the graph_angle is a subgraph of graph
            # (.get so that a missing parent triggers the assertion, not a KeyError)
            correct_graph = (
                self.graph_angle_key == self.graph_key
                or self._graphs_properties[self.graph_angle_key].get("parent_graph")
                == self.graph_key
            )
            assert (
                correct_graph
            ), f"graph_angle_key={self.graph_angle_key} must be a subgraph of graph_key={self.graph_key}"
            assert (
                "angles" in graph_angle
            ), f"Graph {self.graph_angle_key} must contain angles"
            # The angle graph is filtered *relative to graph_key* only when it is a
            # different graph. If it is graph_key itself, its filter_indices (if any)
            # point into graph_key's own parent and must not be used here.
            filtered = self.graph_angle_key != self.graph_key
            if filtered:
                filter_indices = graph_angle["filter_indices"]

        ##################################################
        ### SPECIES ENCODING ###
        if isinstance(self.species_encoding, str):
            zi = inputs[self.species_encoding]
        else:
            zi = SpeciesEncoding(
                **self.species_encoding, name="SpeciesEncoding"
            )(species)

        if self.layer_normalization:

            def layer_norm(x):
                # parameter-free normalization over the feature axis
                mu = jnp.mean(x, axis=-1, keepdims=True)
                dx = x - mu
                var = jnp.mean(dx**2, axis=-1, keepdims=True)
                sig = (1.e-6 + var) ** (-0.5)
                return dx * sig

        else:

            def layer_norm(x):
                return x

        if self.charge_embedding:
            xi, qi = jnp.split(
                nn.Dense(self.dim + 1, use_bias=False, name="ChargeEncoding")(zi),
                [self.dim],
                axis=-1,
            )
            batch_index = inputs["batch_index"]
            natoms = inputs["natoms"]
            nsys = natoms.shape[0]
            Zi = jnp.asarray(VALENCE_ELECTRONS)[species]
            # number of valence electrons per system, corrected by the total charge
            Ntot = jax.ops.segment_sum(Zi, batch_index, nsys) - inputs.get(
                self.total_charge_key, jnp.zeros(nsys)
            )
            # distribute the electrons over atoms with positive learned weights
            ai = jax.nn.softplus(qi.squeeze(-1))
            A = jax.ops.segment_sum(ai, batch_index, nsys)
            Ni = ai * (Ntot / A)[batch_index]
            charge_embedding = positional_encoding(Ni, self.dim)
            xi = layer_norm(xi + charge_embedding)
        elif self.species_init:
            xi = layer_norm(nn.Dense(self.dim, use_bias=False, name="SpeciesInit")(zi))
        else:
            xi = zi

        ##################################################
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        distances = graph["distances"]
        switch = graph["switch"][:, None]

        ### COMPUTE RADIAL BASIS ###
        radial_basis = RadialBasis(
            **{
                **self.radial_basis,
                "end": cutoff,
                "name": "RadialBasis",
            }
        )(distances)

        do_lode = self.graph_lode is not None
        if do_lode:
            graph_lode = inputs[self.graph_lode]
            switch_lode = graph_lode["switch"][:, None]

            edge_src_lr, edge_dst_lr = graph_lode["edge_src"], graph_lode["edge_dst"]
            r = graph_lode["distances"][:, None]
            rc = self._graphs_properties[self.graph_lode]["cutoff"]

            lmax_lr = self.lmax_lode
            equivariant_lode = lmax_lr > 0
            assert lmax_lr >= 0, f"lmax_lode must be >= 0, got {lmax_lr}"
            if self.lode_multipole_interaction:
                assert lmax_lr <= self.lmax, f"lmax_lode must be <= lmax for multipole interaction, got {lmax_lr} > {self.lmax}"
            nrep_lr = np.array([2 * l + 1 for l in range(lmax_lr + 1)], dtype=np.int32)
            if self.lode_resolve_l and equivariant_lode:
                ls_lr = np.arange(lmax_lr + 1)
            else:
                ls_lr = np.array([0])

            if self.a_lode > 0:
                a = self.a_lode**2
            else:
                # trainable damping (one value per resolved l), squared to stay positive
                a = (
                    self.param(
                        "a_lr",
                        lambda key: jnp.asarray([-self.a_lode] * ls_lr.shape[0])[None, :],
                    )
                    ** 2
                )
            rc2a = rc**2 + a
            ls_lr = 0.5 * (ls_lr[None, :] + 1)
            ### minimal radial basis for long range (damped coulomb)
            # f(r) = (r^2+a)^-p shifted so that both f and f' vanish at r = rc
            eij_lr = (
                1.0 / (r**2 + a) ** ls_lr
                - 1.0 / rc2a**ls_lr
                + (r - rc) * rc * (2 * ls_lr) / rc2a ** (ls_lr + 1)
            ) * switch_lode

            if self.lode_rshort is not None:
                rs = self.lode_rshort
                dshort = self.lode_dshort
                # 0 below rs, cosine ramp on [rs, rs + dshort], 1 above
                switch_short = 0.5 * (1 - jnp.cos(jnp.pi * (r - rs) / dshort)) * (r > rs) * (
                    r < rs + dshort
                ) + (r > rs + dshort)
                eij_lr = eij_lr * switch_short

            nchannels_lode = (
                [self.lode_channels] * self.nlayers
                if isinstance(self.lode_channels, int)
                else self.lode_channels
            )
            dim_lr = nchannels_lode

            if equivariant_lode:
                if self.lode_resolve_l:
                    eij_lr = eij_lr.repeat(nrep_lr, axis=-1)
                Yij_lr = generate_spherical_harmonics(lmax=lmax_lr, normalize=False)(
                    graph_lode["vec"] / r
                )
                eij_lr = (eij_lr * Yij_lr)[:, None, :]
                dim_lr = [nch * (lmax_lr + 1) for nch in dim_lr]

        ##################################################
        ### GET ANGLES ###
        if use_angles:
            angles = graph_angle["angles"][:, None]
            angle_src, angle_dst = graph_angle["angle_src"], graph_angle["angle_dst"]
            switch_angles = graph_angle["switch"][:, None]
            central_atom = graph_angle["central_atom"]

            if not self.angle_combine_pairs:
                assert (
                    self.radial_basis_angle is not None
                ), "radial_basis_angle must be specified if angle_combine_pairs=False"

            ### COMPUTE RADIAL BASIS FOR ANGLES ###
            if self.radial_basis_angle is not None:
                dangles = graph_angle["distances"]
                swang = switch_angles
                if not self.angle_combine_pairs:
                    # ANI-like: one radial basis per angle from the mean pair distance
                    dangles = 0.5 * (dangles[angle_src] + dangles[angle_dst])
                    swang = switch_angles[angle_src] * switch_angles[angle_dst]
                radial_basis_angle = (
                    RadialBasis(
                        **{
                            **self.radial_basis_angle,
                            "end": self._graphs_properties[self.graph_angle_key][
                                "cutoff"
                            ],
                            "name": "RadialBasisAngle",
                        }
                    )(dangles)
                    * swang
                )
            elif filtered:
                radial_basis_angle = radial_basis[filter_indices] * switch_angles
            else:
                radial_basis_angle = radial_basis * switch

        radial_basis = radial_basis * switch

        ##################################################
        if self.lmax > 0:
            Yij = generate_spherical_harmonics(lmax=self.lmax, normalize=False)(
                graph["vec"] / graph["distances"][:, None]
            )[:, None, :]
            Yij = jnp.broadcast_to(Yij, (Yij.shape[0], self.nchannels_l, Yij.shape[2]))
            nrep_l = np.array([2 * l + 1 for l in range(self.lmax + 1)], dtype=np.int32)

        ##################################################
        if use_angles:
            ### ANGULAR BASIS ###
            if self.angle_style == "fourier":
                # build fourier series for angles
                nangles = self.param(
                    "nangles",
                    lambda key, dim: jnp.arange(dim, dtype=distances.dtype)[None, :],
                    self.nmax_angle + 1,
                )

                phi = self.param(
                    "phi",
                    lambda key, dim: jnp.zeros((1, dim), dtype=distances.dtype),
                    self.nmax_angle + 1,
                )
                xa = jnp.cos(nangles * angles + phi)
            elif self.angle_style == "fourier_full":
                # build fourier series for angles including sin terms
                nangles = self.param(
                    "nangles",
                    lambda key, dim: jnp.arange(dim, dtype=distances.dtype)[None, :],
                    self.nmax_angle + 1,
                )

                phi = self.param(
                    "phi",
                    lambda key, dim: jnp.zeros((1, dim), dtype=distances.dtype),
                    2 * self.nmax_angle + 1,
                )
                xac = jnp.cos(nangles * angles + phi[:, : self.nmax_angle + 1])
                xas = jnp.sin(nangles[:, 1:] * angles + phi[:, self.nmax_angle + 1 :])
                xa = jnp.concatenate([xac, xas], axis=-1)
            elif self.angle_style == "ani":
                # ANI-style angle embedding
                angle_start = np.pi / (2 * (self.nmax_angle + 1))
                shiftZ = self.param(
                    "shiftZ",
                    lambda key, dim: jnp.asarray(
                        (np.linspace(0, np.pi, dim + 1) + angle_start)[None, :-1],
                        dtype=distances.dtype,
                    ),
                    self.nmax_angle + 1,
                )
                zeta = self.param(
                    "zeta",
                    lambda key: jnp.asarray(self.zeta, dtype=distances.dtype),
                )
                xa = (0.5 + 0.5 * jnp.cos(angles - shiftZ)) ** zeta
            else:
                raise ValueError(f"Unknown angle style {self.angle_style}")
            xa = xa[:, None, :]
            if not self.angle_combine_pairs:
                if reduce_memory:
                    raise NotImplementedError("Angle embedding not implemented with reduce_memory")
                # outer product of radial and angular features per angle.
                # xa is (nangles, 1, nfeat) here: use the feature axis, not axis 1.
                xa = (xa * radial_basis_angle[:, :, None]).reshape(
                    -1, 1, xa.shape[-1] * radial_basis_angle.shape[-1]
                )

            if self.pair_embedding_key is not None:
                # indices of the two pairs (edges of graph_key) forming each angle
                if filtered:
                    ang_pair_src = filter_indices[angle_src]
                    ang_pair_dst = filter_indices[angle_dst]
                else:
                    ang_pair_src = angle_src
                    ang_pair_dst = angle_dst
                ang_pairs = jnp.concatenate((ang_pair_src, ang_pair_dst))

        ##################################################
        ### DIMENSIONS ###
        dim_src = (
            [self.dim_src] * self.nlayers
            if isinstance(self.dim_src, int)
            else self.dim_src
        )
        assert (
            len(dim_src) == self.nlayers
        ), f"dim_src must be an integer or a list of length {self.nlayers}"
        # dim_dst is a single width shared by all layers (the pair embedding xij
        # is added to the messages at every layer)
        dim_dst = self.dim_dst

        if use_angles:
            dim_angle = (
                [self.dim_angle] * self.nlayers
                if isinstance(self.dim_angle, int)
                else self.dim_angle
            )
            assert (
                len(dim_angle) == self.nlayers
            ), f"dim_angle must be an integer or a list of length {self.nlayers}"

        initialize_e3 = True
        if self.lmax > 0:
            n_tp = (
                [self.n_tp] * self.nlayers
                if isinstance(self.n_tp, int)
                else self.n_tp
            )
            assert (
                len(n_tp) == self.nlayers
            ), f"n_tp must be an integer or a list of length {self.nlayers}"

        message_passing = (
            [self.message_passing] * self.nlayers
            if isinstance(self.message_passing, bool)
            else self.message_passing
        )
        assert (
            len(message_passing) == self.nlayers
        ), f"message_passing must be a boolean or a list of length {self.nlayers}"

        ##################################################
        ### INITIALIZE PAIR EMBEDDING ###
        if self.pair_embedding_key is not None:
            xij_s, xij_d = jnp.split(
                nn.Dense(2 * dim_dst, name="pair_init_linear")(zi), [dim_dst], axis=-1
            )
            xij = layer_norm(xij_s[edge_src] * xij_d[edge_dst])

        ##################################################
        if self.keep_all_layers:
            xis = []

        ### LOOP OVER LAYERS ###
        for layer in range(self.nlayers):
            ##################################################
            ### COMPACT DESCRIPTORS ###
            si, si_dst = jnp.split(
                nn.Dense(
                    dim_src[layer] + dim_dst,
                    name=f"species_linear_{layer}",
                    use_bias=self.use_bias,
                )(xi),
                [
                    dim_src[layer],
                ],
                axis=-1,
            )

            ##################################################
            if message_passing[layer] or layer == 0:
                ### MESSAGE PASSING ###
                si_mp = si_dst[edge_dst]
            else:
                ### ATTENTION TO SIMULATE MP ###
                Q = nn.Dense(
                    dim_dst * self.att_dim, name=f"queries_{layer}", use_bias=False
                )(si_dst).reshape(-1, dim_dst, self.att_dim)[edge_src]
                K = nn.Dense(
                    dim_dst * self.att_dim, name=f"keys_{layer}", use_bias=False
                )(zi).reshape(-1, dim_dst, self.att_dim)[edge_dst]

                si_mp = (K * Q).sum(axis=-1) / self.att_dim**0.5

            if self.pair_embedding_key is not None:
                si_mp = si_mp + xij

            ##################################################
            ### PAIR EMBEDDING ###
            if reduce_memory:
                # accumulate one radial function at a time to avoid materializing Lij;
                # row (atom * nrad + irad) reshapes to the same layout as below
                nrad = radial_basis.shape[1]
                Li = jnp.zeros((species.shape[0] * nrad, si_mp.shape[1]), dtype=si_mp.dtype)
                for irad in range(nrad):
                    indices = irad + edge_src * nrad
                    Li = Li.at[indices].add(si_mp * radial_basis[:, irad, None])
                Li = Li.reshape(species.shape[0], nrad * si_mp.shape[1])
            else:
                Lij = (si_mp[:, None, :] * radial_basis[:, :, None]).reshape(
                    radial_basis.shape[0], si_mp.shape[1] * radial_basis.shape[1]
                )
                ### AGGREGATE PAIR EMBEDDING ###
                Li = jax.ops.segment_sum(Lij, edge_src, species.shape[0])

            ### CONCATENATE EMBEDDING COMPONENTS ###
            components = [si, Li]
            if self.pair_embedding_key is not None:
                if reduce_memory:
                    raise NotImplementedError("Pair embedding not implemented with reduce_memory")
                components_pair = [si[edge_src], xij, Lij]

            ##################################################
            ### ANGLE EMBEDDING ###
            if use_angles and dim_angle[layer] > 0:
                si_mp_ang = si_mp[filter_indices] if filtered else si_mp
                if self.angle_combine_pairs:
                    Wa = self.param(
                        f"Wa_{layer}",
                        nn.initializers.normal(
                            stddev=1.0
                            / (si_mp.shape[1] * radial_basis_angle.shape[1]) ** 0.5
                        ),
                        (si_mp.shape[1], radial_basis_angle.shape[1], dim_angle[layer]),
                    )
                    # bilinear combination of message and radial features per pair
                    Da = jnp.einsum(
                        "...i,...j,ijk->...k",
                        si_mp_ang,
                        radial_basis_angle,
                        Wa,
                    )
                elif message_passing[layer] or layer == 0:
                    Da = nn.Dense(dim_angle[layer], name=f"angle_linear_{layer}")(
                        xi
                    )[graph_angle["edge_dst"]]
                else:
                    Da = nn.Dense(dim_angle[layer], name=f"angle_linear_{layer}")(
                        si_mp_ang
                    )

                # pair features -> angle features (product of the two pairs)
                Da = Da[angle_dst] * Da[angle_src]
                ## combine pair and angle info
                if reduce_memory:
                    ndang = Da.shape[-1]
                    ang_embedding = jnp.zeros(
                        (species.shape[0] * ndang, xa.shape[-1]), dtype=Da.dtype
                    )
                    for ida in range(ndang):
                        indices = ida + central_atom * ndang
                        ang_embedding = ang_embedding.at[indices].add(
                            Da[:, ida, None] * xa[:, 0, :]
                        )
                    ang_embedding = ang_embedding.reshape(
                        species.shape[0], xa.shape[-1] * ndang
                    )
                else:
                    radang = (xa * Da[:, :, None]).reshape(
                        (-1, Da.shape[1] * xa.shape[2])
                    )
                    ### AGGREGATE ANGLE EMBEDDING ###
                    ang_embedding = jax.ops.segment_sum(
                        radang, central_atom, species.shape[0]
                    )

                components.append(ang_embedding)

                if self.pair_embedding_key is not None:
                    ang_ij = jax.ops.segment_sum(
                        jnp.concatenate((radang, radang)),
                        ang_pairs,
                        edge_src.shape[0],
                    )
                    components_pair.append(ang_ij)

            ##################################################
            ### EQUIVARIANT EMBEDDING ###
            if self.lmax > 0 and n_tp[layer] >= 0:
                if reduce_memory:
                    # the channel weights below are computed from the per-edge Lij,
                    # which the reduce_memory path never builds
                    raise NotImplementedError("Equivariant embedding not implemented with reduce_memory")
                if initialize_e3 or not message_passing[layer]:
                    Vij = Yij
                elif self.edge_tp:
                    Vij = FilteredTensorProduct(
                        self.lmax,
                        self.lmax,
                        name=f"edge_tp_{layer}",
                        ignore_parity=self.ignore_irreps_parity,
                        weights_by_channel=False,
                    )(Vi[edge_dst], Yij)
                else:
                    Vij = Vi[edge_dst]

                ### compute channel weights
                dim_wij = self.nchannels_l
                if self.resolve_wij_l:
                    dim_wij = self.nchannels_l * (self.lmax + 1)

                eij = (
                    Lij
                    if self.pair_embedding_key is None
                    else jnp.concatenate([Lij, xij * switch], axis=-1)
                )
                wij = nn.Dense(
                    dim_wij, name=f"e3_channel_{layer}", use_bias=False
                )(eij)
                if self.resolve_wij_l:
                    # one weight per (channel, l), repeated over the 2l+1 components
                    wij = jnp.repeat(
                        wij.reshape(-1, self.nchannels_l, self.lmax + 1), nrep_l, axis=-1
                    )
                else:
                    wij = wij[:, :, None]

                ### aggregate equivariant messages
                drhoi = jax.ops.segment_sum(
                    wij * Vij,
                    edge_src,
                    species.shape[0],
                )

                Vi0 = []
                if initialize_e3:
                    rhoi = drhoi
                    Vi = ChannelMixingE3(
                        self.lmax,
                        self.nchannels_l,
                        self.nchannels_l,
                        name=f"e3_initial_mixing_{layer}",
                    )(rhoi)
                else:
                    rhoi = rhoi + drhoi
                initialize_e3 = False
                if n_tp[layer] > 0:
                    for itp in range(n_tp[layer]):
                        dVi = FilteredTensorProduct(
                            self.lmax,
                            self.lmax,
                            name=f"tensor_product_{layer}_{itp}",
                            ignore_parity=self.ignore_irreps_parity,
                            weights_by_channel=False,
                        )(rhoi, Vi)
                        Vi = ChannelMixing(
                            self.lmax,
                            self.nchannels_l,
                            self.nchannels_l,
                            name=f"tp_mixing_{layer}_{itp}",
                        )(Vi + dVi)
                        # keep the scalar (l=0) part as invariant features
                        Vi0.append(dVi[:, :, 0])
                    # only concatenate when at least one tensor product was done
                    Vi0 = jnp.concatenate(Vi0, axis=-1)
                    components.append(Vi0)

                if self.pair_embedding_key is not None:
                    # per-l contraction of atomic tensors with edge harmonics
                    Vij = Vi[edge_src] * Yij
                    Vij0 = [Vij[..., 0]]
                    for l in range(1, self.lmax + 1):
                        Vij0.append(Vij[..., l**2 : (l + 1) ** 2].sum(axis=-1))
                    Vij0 = jnp.concatenate(Vij0, axis=-1)
                    components_pair.append(Vij0)

            ##################################################
            ### CONCATENATE EMBEDDING COMPONENTS ###
            if do_lode and nchannels_lode[layer] > 0:
                zj = nn.Dense(dim_lr[layer], use_bias=False, name=f"LODE_{layer}")(xi)
                if equivariant_lode:
                    zj = zj.reshape(
                        (species.shape[0], nchannels_lode[layer], lmax_lr + 1)
                    ).repeat(nrep_lr, axis=-1)
                xi_lr = jax.ops.segment_sum(
                    eij_lr * zj[edge_dst_lr], edge_src_lr, species.shape[0]
                )
                if equivariant_lode:
                    assert self.lode_use_field_norm or self.lode_multipole_interaction, "equivariant LODE requires field norm or multipole interaction"
                    if self.lode_multipole_interaction:
                        if initialize_e3:
                            raise ValueError("equivariant LODE used before local equivariants initialized")
                        size_l_lr = (lmax_lr + 1) ** 2
                        if self.lode_direct_multipoles:
                            assert nchannels_lode[layer] <= self.nchannels_l
                            Mi = Vi[:, : nchannels_lode[layer], :size_l_lr]
                        else:
                            Mi = ChannelMixingE3(
                                lmax_lr,
                                self.nchannels_l,
                                nchannels_lode[layer],
                                name=f"e3_LODE_{layer}",
                            )(Vi[..., :size_l_lr])
                        Mi_lr = Mi * xi_lr
                    components.append(xi_lr[:, :, 0])
                    if self.lode_use_field_norm and self.lode_equi_full_combine:
                        xi_lr1 = ChannelMixing(
                            lmax_lr,
                            nchannels_lode[layer],
                            nchannels_lode[layer],
                            name=f"LODE_mixing_{layer}",
                        )(xi_lr)
                    norm = 1.
                    for l in range(1, lmax_lr + 1):
                        if self.lode_normalize_l:
                            norm = 1. / (2 * l + 1)
                        sl = slice(l**2, (l + 1) ** 2)
                        if self.lode_multipole_interaction:
                            components.append(Mi_lr[:, :, sl].sum(axis=-1) * norm)

                        if self.lode_use_field_norm:
                            if self.lode_equi_full_combine:
                                components.append(
                                    (xi_lr[:, :, sl] * xi_lr1[:, :, sl]).sum(axis=-1) * norm
                                )
                            else:
                                components.append(
                                    ((xi_lr[:, :, sl]) ** 2).sum(axis=-1) * norm
                                )
                else:
                    components.append(xi_lr)

            dxi = jnp.concatenate(components, axis=-1)

            ##################################################
            ### CONCATENATE PAIR EMBEDDING COMPONENTS ###
            if self.pair_embedding_key is not None:
                dxij = jnp.concatenate(components_pair, axis=-1)

            ##################################################
            ### MIX AND APPLY NONLINEARITY ###
            if self.block_index_key is not None:
                block_index = inputs[self.block_index_key]
                dxi = actmix(
                    BlockIndexNet(
                        output_dim=self.dim,
                        hidden_neurons=self.mixing_hidden,
                        activation=self.activation,
                        name=f"dxi_{layer}",
                        use_bias=self.use_bias,
                        kernel_init=kernel_init,
                    )((species, dxi, block_index))
                )
            else:
                dxi = actmix(
                    FullyConnectedNet(
                        [*self.mixing_hidden, self.dim],
                        activation=self.activation,
                        name=f"dxi_{layer}",
                        use_bias=self.use_bias,
                        kernel_init=kernel_init,
                    )(dxi)
                )

            if self.pair_embedding_key is not None:
                ### UPDATE PAIR EMBEDDING ###
                dxij = actmix(
                    FullyConnectedNet(
                        [*self.pair_mixing_hidden, dim_dst],
                        activation=self.activation,
                        name=f"dxij_{layer}",
                        use_bias=False,
                        kernel_init=kernel_init,
                    )(dxij)
                )
                xij = layer_norm(xij + dxij)

            ##################################################
            ### UPDATE EMBEDDING ###
            if layer == 0 and not (self.species_init or self.charge_embedding):
                # xi is still the raw species encoding (possibly of another width)
                xi = layer_norm(dxi)
            else:
                ### FORGET GATE ###
                R = jax.nn.sigmoid(
                    self.param(
                        f"retention_{layer}",
                        nn.initializers.normal(),
                        (xi.shape[-1],),
                    )
                )
                xi = layer_norm(R[None, :] * xi + dxi)

            if self.keep_all_layers:
                xis.append(xi)

        embedding_key = (
            self.embedding_key if self.embedding_key is not None else self.name
        )
        output = {
            **inputs,
            embedding_key: xi,
        }
        if self.lmax > 0:
            output[embedding_key + "_tensor"] = Vi
        if self.keep_all_layers:
            output[embedding_key + "_layers"] = jnp.stack(xis, axis=1)
        if self.charge_embedding:
            output[embedding_key + "_charge"] = charge_embedding
        if self.pair_embedding_key is not None:
            output[self.pair_embedding_key] = xij
        return output
Configurable Resources ATomic Environment
FID : CRATE
This class represents the CRATE (Configurable Resources ATomic Environment) embedding model. It is used to encode atomic environments using multiple sources of information (radial, angular, E(3), message-passing, LODE, etc...)
Whether to combine angle pairs instead of average distance embedding like in ANI.
The hidden size for the attention mechanism (only used when message-passing is disabled).
Whether to ignore the parity of the irreps in the tensor product.
The kernel initialization function for Dense operations.
The key for the pair embedding data in the output dictionary.
If `str`, it is the key in the inputs dictionary that contains species encodings. Else, it is the dictionary of parameters for species encoding. See `fennol.models.misc.encodings.SpeciesEncoding`.
The dictionary of parameters for radial basis functions. See `fennol.models.misc.encodings.RadialBasis`.
The dictionary of parameters for radial basis functions for angle embedding. If None, the radial basis for angles is the same as the radial basis for distances.
The damping parameter of the long-range (LODE) radial kernel: distances enter the kernel as r**2 + a_lode**2. If negative, the value is trainable with starting value -a_lode.
Whether to interact with the multipole moments of the lode graph.
Whether to directly use the first local equivariants to interact with long-range equivariants. If false, local equivariants are mixed before interaction.
The key for the block index. If provided, a `BlockIndexNet` is used as the mixing network.
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.