fennol.models.misc.misc
import flax.linen as nn
from typing import Any, Sequence, Callable, Union, ClassVar, Optional, Dict, List
import jax.numpy as jnp
import jax
import numpy as np
from functools import partial
from ...utils.activations import activation_from_str
from ...utils.periodic_table import (
    CHEMICAL_PROPERTIES,
    PERIODIC_TABLE,
    PERIODIC_TABLE_REV_IDX,
)


def apply_switch(x: jax.Array, switch: jax.Array):
    """@private Multiply an array of values by a switch array."""
    shape = x.shape
    return (
        jnp.expand_dims(x, axis=-1).reshape(*switch.shape, -1) * switch[..., None]
    ).reshape(shape)


class ApplySwitch(nn.Module):
    """Multiply an edge array by a switch array.

    FID: APPLY_SWITCH
    """

    key: str
    """The key of the input array."""
    switch_key: Optional[str] = None
    """The key of the switch array."""
    graph_key: Optional[str] = None
    """The key of the graph containing the switch."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the input key is used."""

    FID: ClassVar[str] = "APPLY_SWITCH"

    @nn.compact
    def __call__(self, inputs) -> Any:
        if self.graph_key is not None:
            graph = inputs[self.graph_key]
            switch = graph["switch"]
        elif self.switch_key is not None:
            switch = inputs[self.switch_key]
        else:
            raise ValueError("Either graph_key or switch_key must be specified")

        x = inputs[self.key]
        output = apply_switch(x, switch)
        output_key = self.key if self.output_key is None else self.output_key
        return {**inputs, output_key: output}


class AtomToEdge(nn.Module):
    """Map atom-wise values to edge-wise values.

    FID: ATOM_TO_EDGE

    By default, we map the destination atom value to the edge. This can be changed by setting `use_source` to True.
    """

    _graphs_properties: Dict
    key: str
    """The key of the input atom-wise array."""
    output_key: Optional[str] = None
    """The key of the output edge-wise array. If None, the input key is used."""
    graph_key: str = "graph"
    """The key of the graph containing the edges."""
    switch: bool = False
    """Whether to apply a switch to the edge values."""
    switch_key: Optional[str] = None
    """The key of the switch array. If None, the switch is taken from the graph."""
    use_source: bool = False
    """Whether to use the source atom value instead of the destination atom value."""

    FID: ClassVar[str] = "ATOM_TO_EDGE"

    @nn.compact
    def __call__(self, inputs) -> Any:
        graph = inputs[self.graph_key]
        nat = inputs["species"].shape[0]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]

        x = inputs[self.key]
        if self.use_source:
            x_edge = x[edge_src]
        else:
            x_edge = x[edge_dst]

        if self.switch:
            switch = (
                graph["switch"] if self.switch_key is None else inputs[self.switch_key]
            )
            x_edge = apply_switch(x_edge, switch)

        output_key = self.key if self.output_key is None else self.output_key
        return {**inputs, output_key: x_edge}


class ScatterEdges(nn.Module):
    """Reduce an edge array to atoms by summing over neighbors.

    FID: SCATTER_EDGES
    """

    _graphs_properties: Dict
    key: str
    """The key of the input edge-wise array."""
    output_key: Optional[str] = None
    """The key of the output atom-wise array. If None, the input key is used."""
    graph_key: str = "graph"
    """The key of the graph containing the edges."""
    switch: bool = False
    """Whether to apply a switch to the edge values before summing."""
    switch_key: Optional[str] = None
    """The key of the switch array. If None, the switch is taken from the graph."""
    antisymmetric: bool = False
    """Whether to subtract the sum scattered on destination atoms from the sum scattered on source atoms."""

    FID: ClassVar[str] = "SCATTER_EDGES"

    @nn.compact
    def __call__(self, inputs) -> Any:
        graph = inputs[self.graph_key]
        nat = inputs["species"].shape[0]
        x = inputs[self.key]

        if self.switch:
            switch = (
                graph["switch"] if self.switch_key is None else inputs[self.switch_key]
            )
            x = apply_switch(x, switch)

        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        output = jax.ops.segment_sum(
            x, edge_src, nat
        )  # jnp.zeros((nat, *x.shape[1:])).at[edge_src].add(x, mode="drop")
        if self.antisymmetric:
            output = output - jax.ops.segment_sum(x, edge_dst, nat)

        output_key = self.key if self.output_key is None else self.output_key
        return {**inputs, output_key: output}
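A minimal usage sketch for the edge modules (hypothetical: the "edge_feat" key and the numeric values are invented for illustration; the dictionary layout mirrors the fields ScatterEdges reads, and the empty `_graphs_properties` stub only fills the required dataclass field):

# Hypothetical example: apply the graph switch to an edge feature, then
# scatter-sum it onto the source atoms.
import jax
import jax.numpy as jnp

inputs = {
    "species": jnp.array([8, 1, 1]),               # three atoms (O, H, H)
    "edge_feat": jnp.ones((4, 2)),                 # one feature vector per directed edge
    "graph": {
        "edge_src": jnp.array([0, 0, 1, 2]),       # edge i runs src -> dst
        "edge_dst": jnp.array([1, 2, 0, 0]),
        "switch": jnp.array([1.0, 1.0, 0.5, 0.5]),
    },
}

mod = ScatterEdges(_graphs_properties={}, key="edge_feat", switch=True)
out, _ = mod.init_with_output(jax.random.PRNGKey(0), inputs)
print(out["edge_feat"].shape)  # (3, 2): one summed vector per atom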
class EdgeConcatenate(nn.Module):
    """Concatenate the source and destination atom values of an edge.

    FID: EDGE_CONCATENATE
    """

    _graphs_properties: Dict
    key: str
    """The key of the input atom-wise array."""
    output_key: Optional[str] = None
    """The key of the output edge-wise array. If None, the name of the module is used."""
    graph_key: str = "graph"
    """The key of the graph containing the edges."""
    switch: bool = False
    """Whether to apply a switch to the edge values."""
    switch_key: Optional[str] = None
    """The key of the switch array. If None, the switch is taken from the graph."""
    axis: int = -1
    """The axis along which to concatenate the atom values."""

    FID: ClassVar[str] = "EDGE_CONCATENATE"

    @nn.compact
    def __call__(self, inputs) -> Any:
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        nat = inputs["species"].shape[0]
        xi = inputs[self.key]

        assert self._graphs_properties[self.graph_key][
            "directed"
        ], "EdgeConcatenate only works for directed graphs"
        assert xi.shape[0] == nat, "Shape mismatch, xi.shape[0] != nat"

        xij = jnp.concatenate([xi[edge_src], xi[edge_dst]], axis=self.axis)

        if self.switch:
            switch = (
                graph["switch"] if self.switch_key is None else inputs[self.switch_key]
            )
            xij = apply_switch(xij, switch)

        output_key = self.name if self.output_key is None else self.output_key
        return {**inputs, output_key: xij}


class ScatterSystem(nn.Module):
    """Reduce an atom-wise array to a system-wise array by summing over atoms (in the batch).

    FID: SCATTER_SYSTEM
    """

    key: str
    """The key of the input atom-wise array."""
    output_key: Optional[str] = None
    """The key of the output system-wise array. If None, the input key is used."""
    average: bool = False
    """Whether to divide by the number of atoms in the system."""

    FID: ClassVar[str] = "SCATTER_SYSTEM"

    @nn.compact
    def __call__(self, inputs) -> Any:
        batch_index = inputs["batch_index"]
        x = inputs[self.key]
        assert (
            x.shape[0] == batch_index.shape[0]
        ), f"Shape mismatch {x.shape[0]} != {batch_index.shape[0]}"
        nsys = inputs["natoms"].shape[0]
        if self.average:
            shape = [batch_index.shape[0]] + (x.ndim - 1) * [1]
            x = x / inputs["natoms"][batch_index].reshape(shape)

        output = jax.ops.segment_sum(x, batch_index, nsys)

        output_key = self.key if self.output_key is None else self.output_key
        return {**inputs, output_key: output}
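For the batched reduction, a hypothetical sketch ("energy_terms" is an invented key; "batch_index" and "natoms" are the batching fields the code above reads):

# Hypothetical example: per-system average of an atom-wise scalar.
import jax
import jax.numpy as jnp

inputs = {
    "batch_index": jnp.array([0, 0, 0, 1]),  # atoms 0-2 belong to system 0, atom 3 to system 1
    "natoms": jnp.array([3, 1]),             # atom count per system
    "energy_terms": jnp.array([1.0, 2.0, 3.0, 4.0]),
}

mod = ScatterSystem(key="energy_terms", output_key="energy", average=True)
out, _ = mod.init_with_output(jax.random.PRNGKey(0), inputs)
print(out["energy"])  # [2. 4.]: the mean over each system's atoms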
class SystemToAtoms(nn.Module):
    """Broadcast a system-wise array to an atom-wise array.

    FID: SYSTEM_TO_ATOMS
    """

    key: str
    """The key of the input system-wise array."""
    output_key: Optional[str] = None
    """The key of the output atom-wise array. If None, the input key is used."""

    FID: ClassVar[str] = "SYSTEM_TO_ATOMS"

    @nn.compact
    def __call__(self, inputs) -> Any:
        batch_index = inputs["batch_index"]
        x = inputs[self.key]
        output = x[batch_index]

        output_key = self.key if self.output_key is None else self.output_key
        return {**inputs, output_key: output}


class SumAxis(nn.Module):
    """Sum an array along an axis.

    FID: SUM_AXIS
    """

    key: str
    """The key of the input array."""
    axis: Union[None, int, Sequence[int]] = None
    """The axis along which to sum the array."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the input key is used."""
    norm: Optional[str] = None
    """Normalization of the sum. Can be 'dim', 'sqrt', or 'none'."""

    FID: ClassVar[str] = "SUM_AXIS"

    @nn.compact
    def __call__(self, inputs) -> Any:
        x = inputs[self.key]
        output = jnp.sum(x, axis=self.axis)
        if self.norm is not None:
            norm = self.norm.lower()
            if norm == "dim":
                dim = np.prod(x.shape[self.axis])
                output = output / dim
            elif norm == "sqrt":
                dim = np.prod(x.shape[self.axis])
                output = output / dim**0.5
            elif norm == "none":
                pass
            else:
                raise ValueError(f"Unknown norm {norm}")
        output_key = self.key if self.output_key is None else self.output_key
        return {**inputs, output_key: output}
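A quick sketch of the normalized sum (the "feat" key is invented; with norm='sqrt' the sum is divided by the square root of the reduced dimension, which keeps the scale of sums of i.i.d. features roughly constant):

# Hypothetical example: sum 16 features per atom, normalized by sqrt(16).
import jax
import jax.numpy as jnp

inputs = {"feat": jnp.ones((5, 16))}
mod = SumAxis(key="feat", axis=-1, output_key="feat_sum", norm="sqrt")
out, _ = mod.init_with_output(jax.random.PRNGKey(0), inputs)
print(out["feat_sum"])  # [4. 4. 4. 4. 4.]: 16 / sqrt(16)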
class Split(nn.Module):
    """Split an array along an axis.

    FID: SPLIT
    """

    key: str
    """The key of the input array."""
    output_keys: Sequence[str]
    """The keys of the output arrays."""
    axis: int = -1
    """The axis along which to split the array."""
    sizes: Union[int, Sequence[int]] = 1
    """The sizes of the splits."""
    squeeze: bool = True
    """Whether to remove the axis in the output if the size is 1."""

    FID: ClassVar[str] = "SPLIT"

    @nn.compact
    def __call__(self, inputs) -> Any:
        x = inputs[self.key]

        if isinstance(self.sizes, int):
            split_size = [self.sizes] * len(self.output_keys)
        else:
            split_size = self.sizes
        if len(split_size) == len(self.output_keys):
            assert (
                sum(split_size) == x.shape[self.axis]
            ), f"Split sizes {split_size} do not match input shape"
            split_size = split_size[:-1]
        assert (
            len(split_size) == len(self.output_keys) - 1
        ), f"Wrong number of split sizes {split_size} for {len(self.output_keys)} outputs"
        split_indices = np.cumsum(split_size)
        outs = {}

        for k, v in zip(self.output_keys, jnp.split(x, split_indices, axis=self.axis)):
            outs[k] = (
                jnp.squeeze(v, axis=self.axis)
                if self.squeeze and v.shape[self.axis] == 1
                else v
            )

        return {**inputs, **outs}


class Concatenate(nn.Module):
    """Concatenate a list of arrays along an axis.

    FID: CONCATENATE
    """

    keys: Sequence[str]
    """The keys of the input arrays."""
    axis: int = -1
    """The axis along which to concatenate the arrays."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the name of the module is used."""

    FID: ClassVar[str] = "CONCATENATE"

    @nn.compact
    def __call__(self, inputs) -> Any:
        output = jnp.concatenate([inputs[k] for k in self.keys], axis=self.axis)
        output_key = self.output_key if self.output_key is not None else self.name
        return {**inputs, output_key: output}


class Activation(nn.Module):
    """Apply an element-wise activation function to an array.

    FID: ACTIVATION
    """

    key: str
    """The key of the input array."""
    activation: Union[Callable, str]
    """The activation function or its name."""
    scale_out: float = 1.0
    """Output scaling factor."""
    shift_out: float = 0.0
    """Output shift."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the input key is used."""

    FID: ClassVar[str] = "ACTIVATION"

    @nn.compact
    def __call__(self, inputs) -> Any:
        x = inputs[self.key]
        activation = (
            activation_from_str(self.activation)
            if isinstance(self.activation, str)
            else self.activation
        )
        output = self.scale_out * activation(x) + self.shift_out
        output_key = self.output_key if self.output_key is not None else self.key
        return {**inputs, output_key: output}
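A round-trip sketch for Split and Concatenate (the keys "pred", "charge" and "vol" are invented; squeeze=False keeps the split axis so the pieces can be concatenated back):

# Hypothetical example: split a (10, 3) prediction into two channels and reassemble it.
import jax
import jax.numpy as jnp

inputs = {"pred": jnp.arange(30.0).reshape(10, 3)}

split = Split(key="pred", output_keys=["charge", "vol"], sizes=[1, 2], squeeze=False)
out, _ = split.init_with_output(jax.random.PRNGKey(0), inputs)
print(out["charge"].shape, out["vol"].shape)  # (10, 1) (10, 2)

concat = Concatenate(keys=["charge", "vol"], output_key="pred_again")
out, _ = concat.init_with_output(jax.random.PRNGKey(0), out)
assert jnp.allclose(out["pred_again"], inputs["pred"])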
class LayerNorm(nn.Module):
    """Layer normalization module.

    FID: LAYER_NORM
    """

    key: Optional[str] = None
    """The key of the input."""
    output_key: Optional[str] = None
    """The key of the output. If None, it is the same as the key."""
    axis: Union[int, Sequence[int]] = -1
    """The axis to normalize."""
    epsilon: float = 1e-6
    """The epsilon for numerical stability."""
    scale: float = 1.0
    """The scale for the normalization."""
    shift: float = 0.0
    """The shift for the normalization."""

    FID: ClassVar[str] = "LAYER_NORM"

    @nn.compact
    def __call__(self, inputs):
        if isinstance(inputs, dict):
            if self.key is None:
                raise ValueError("Key must be specified for LayerNorm")
            x = inputs[self.key]
        else:
            x = inputs
        mu = jnp.mean(x, axis=self.axis, keepdims=True)
        dx = x - mu
        var = jnp.mean(dx**2, axis=self.axis, keepdims=True)
        sig = (self.epsilon + var) ** (-0.5)
        out = self.scale * (sig * dx) + self.shift

        output_key = self.output_key if self.output_key is not None else self.key

        if isinstance(inputs, dict):
            return {**inputs, output_key: out}

        return out


class Scale(nn.Module):
    """Scale an array by a constant factor.

    FID: SCALE
    """

    key: str
    """The key of the input array."""
    scale: float
    """The (initial) scaling factor."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the input key is used."""
    trainable: bool = False
    """Whether the scaling factor is trainable."""

    FID: ClassVar[str] = "SCALE"

    @nn.compact
    def __call__(self, inputs) -> Any:
        x = inputs[self.key]

        if self.trainable:
            scale = self.param("scale", lambda rng: jnp.asarray(self.scale))
        else:
            scale = self.scale

        output = scale * x
        output_key = self.output_key if self.output_key is not None else self.key
        return {**inputs, output_key: output}


class Add(nn.Module):
    """Add together a list of arrays.

    FID: ADD
    """

    keys: Sequence[str]
    """The keys of the input arrays."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the name of the module is used."""

    FID: ClassVar[str] = "ADD"

    @nn.compact
    def __call__(self, inputs) -> Any:
        output = 0
        for k in self.keys:
            output = output + inputs[k]

        output_key = self.output_key if self.output_key is not None else self.name
        return {**inputs, output_key: output}


class Multiply(nn.Module):
    """Multiply a list of arrays together element-wise.

    FID: MULTIPLY
    """

    keys: Sequence[str]
    """The keys of the input arrays."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the name of the module is used."""

    FID: ClassVar[str] = "MULTIPLY"

    @nn.compact
    def __call__(self, inputs) -> Any:
        output = 1
        for k in self.keys:
            output = output * inputs[k]

        output_key = self.output_key if self.output_key is not None else self.name
        return {**inputs, output_key: output}
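A sketch combining LayerNorm with a trainable Scale (the keys "emb" and "emb_norm" are invented; when trainable=True, Scale's factor becomes a learnable parameter initialized from the given value):

# Hypothetical example: normalize embeddings, then rescale them with a trainable factor.
import jax
import jax.numpy as jnp

inputs = {"emb": jnp.arange(12.0).reshape(3, 4)}

ln = LayerNorm(key="emb", output_key="emb_norm")
out, _ = ln.init_with_output(jax.random.PRNGKey(0), inputs)

scale = Scale(key="emb_norm", scale=2.0, trainable=True)
out, params = scale.init_with_output(jax.random.PRNGKey(0), out)
print(params["params"]["scale"])  # 2.0, trainable alongside the rest of the model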
class Transpose(nn.Module):
    """Transpose an array.

    FID: TRANSPOSE
    """

    key: str
    """The key of the input array."""
    axes: Sequence[int]
    """The permutation of the axes. See `jax.numpy.transpose` for more details."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the input key is used."""

    FID: ClassVar[str] = "TRANSPOSE"

    @nn.compact
    def __call__(self, inputs) -> Any:
        output = jnp.transpose(inputs[self.key], axes=self.axes)
        output_key = self.output_key if self.output_key is not None else self.key
        return {**inputs, output_key: output}


class Reshape(nn.Module):
    """Reshape an array.

    FID: RESHAPE
    """

    key: str
    """The key of the input array."""
    shape: Sequence[Union[int, str]]
    """The shape of the output array."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the input key is used."""

    FID: ClassVar[str] = "RESHAPE"

    @nn.compact
    def __call__(self, inputs) -> Any:
        shape = []
        for s in self.shape:
            if isinstance(s, int):
                shape.append(s)
                continue

            if isinstance(s, str):
                s_ = s.lower().strip()
                if s_ in ["natoms", "nat", "natom", "n_atoms", "atoms"]:
                    shape.append(inputs["species"].shape[0])
                    continue

                if s_ in ["nsys", "nbatch", "nsystems", "n_sys", "n_systems", "n_batch"]:
                    shape.append(inputs["natoms"].shape[0])
                    continue

                s_ = s.strip().split("[")
                key = s_[0]
                if key in inputs:
                    axis = int(s_[1].split("]")[0])
                    shape.append(inputs[key].shape[axis])
                    continue

            raise ValueError(f"Error parsing shape component {s}")

        output = jnp.reshape(inputs[self.key], shape)
        output_key = self.output_key if self.output_key is not None else self.key
        return {**inputs, output_key: output}


class ChemicalConstant(nn.Module):
    """Map atomic species to a constant value.

    FID: CHEMICAL_CONSTANT
    """

    value: Union[str, List[float], float, Dict]
    """The constant value or a dictionary of values for each element."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the name of the module is used."""
    trainable: bool = False
    """Whether the constant is trainable."""

    FID: ClassVar[str] = "CHEMICAL_CONSTANT"

    @nn.compact
    def __call__(self, inputs) -> Any:
        if isinstance(self.value, str):
            constant = CHEMICAL_PROPERTIES[self.value.upper()]
        elif isinstance(self.value, (list, tuple)):
            constant = list(self.value)
        elif isinstance(self.value, float):
            constant = [self.value] * len(PERIODIC_TABLE)
        elif hasattr(self.value, "items"):
            constant = [0.0] * len(PERIODIC_TABLE)
            for k, v in self.value.items():
                constant[PERIODIC_TABLE_REV_IDX[k]] = v
        else:
            raise ValueError(f"Unknown constant type {type(self.value)}")

        if self.trainable:
            constant = self.param(
                "constant", lambda rng: jnp.asarray(constant, dtype=jnp.float32)
            )
        else:
            constant = jnp.asarray(constant, dtype=jnp.float32)
        output = constant[inputs["species"]]
        output_key = self.output_key if self.output_key is not None else self.name
        return {**inputs, output_key: output}
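As the code above shows, Reshape resolves string shape entries at call time: 'natoms' (and its aliases) expands to the number of atoms, 'nsys' (and its aliases) to the number of systems, and 'key[i]' to axis i of inputs[key]. A hypothetical sketch ("feat" is an invented key):

# Hypothetical example: reshape a (2, 3, 4) array to (natoms, feat.shape[2]) = (6, 4).
import jax
import jax.numpy as jnp

inputs = {
    "species": jnp.zeros(6, dtype=jnp.int32),  # 6 atoms in total
    "feat": jnp.zeros((2, 3, 4)),
}
mod = Reshape(key="feat", shape=["natoms", "feat[2]"], output_key="feat_flat")
out, _ = mod.init_with_output(jax.random.PRNGKey(0), inputs)
print(out["feat_flat"].shape)  # (6, 4)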
class SwitchFunction(nn.Module):
    """Compute a switch array from an array of distances and a cutoff.

    FID: SWITCH_FUNCTION
    """

    cutoff: Optional[float] = None
    """The cutoff distance. If None, the cutoff is taken from the graph."""
    switch_start: float = 0.0
    """The proportion of the cutoff distance at which the switch function starts."""
    graph_key: Optional[str] = "graph"
    """The key of the graph containing the distances and edge mask."""
    output_key: Optional[str] = None
    """The key of the output switch array. If None, it is added to the graph."""
    switch_type: str = "cosine"
    """The type of switch function. Can be 'cosine', 'polynomial', 'exponential', or 'hard'."""
    p: Optional[float] = None
    """The parameter of the switch function. If None, it is fixed to the default for each `switch_type`."""
    trainable: bool = False
    """Whether the switch parameter is trainable."""

    FID: ClassVar[str] = "SWITCH_FUNCTION"

    @nn.compact
    def __call__(self, inputs) -> Any:
        if self.graph_key is not None:
            graph = inputs[self.graph_key]
            distances, edge_mask = graph["distances"], graph["edge_mask"]
            if self.cutoff is not None:
                edge_mask = jnp.logical_and(edge_mask, (distances < self.cutoff))
                cutoff = self.cutoff
            else:
                cutoff = graph["cutoff"]
        else:
            # no graph: inputs is (distances, edge_mask) or (distances, edge_mask, cutoff)
            if len(inputs) == 3:
                distances, edge_mask, cutoff = inputs
            else:
                distances, edge_mask = inputs
                assert (
                    self.cutoff is not None
                ), "cutoff must be specified if no graph is given"
                cutoff = self.cutoff

        if self.switch_start > 1.0e-5:
            assert (
                self.switch_start < 1.0
            ), "switch_start is a proportion of cutoff and must be smaller than 1."
            cutoff_in = self.switch_start * cutoff
            x = distances - cutoff_in
            end = cutoff - cutoff_in
        else:
            x = distances
            end = cutoff

        switch_type = self.switch_type.lower()
        if switch_type == "cosine":
            p = self.p if self.p is not None else 1.0
            if self.trainable:
                p = self.param("p", lambda rng: jnp.asarray(p, dtype=jnp.float32))
            switch = (0.5 * jnp.cos(x * (jnp.pi / end)) + 0.5) ** p

        elif switch_type == "polynomial":
            p = self.p if self.p is not None else 3.0
            if self.trainable:
                p = self.param("p", lambda rng: jnp.asarray(p, dtype=jnp.float32))
            d = x / end
            switch = (
                1.0
                - 0.5 * (p + 1) * (p + 2) * d**p
                + p * (p + 2) * d ** (p + 1)
                - 0.5 * p * (p + 1) * d ** (p + 2)
            )

        elif switch_type == "exponential":
            p = self.p if self.p is not None else 1.0
            if self.trainable:
                p = self.param("p", lambda rng: jnp.asarray(p, dtype=jnp.float32))
            r2 = x**2
            c2 = end**2
            switch = jnp.exp(-p * r2 / (c2 - r2))

        elif switch_type == "hard":
            switch = jnp.where(distances < cutoff, 1.0, 0.0)
        else:
            raise ValueError(f"Unknown switch function {switch_type}")

        if self.switch_start > 1.0e-5:
            switch = jnp.where(distances < cutoff_in, 1.0, switch)

        switch = jnp.where(edge_mask, switch, 0.0)

        if self.graph_key is not None:
            if self.output_key is not None:
                return {**inputs, self.output_key: switch}
            else:
                return {**inputs, self.graph_key: {**graph, "switch": switch}}
        else:
            return switch
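Finally, a sketch of the graph-free calling convention, where the module receives a (distances, edge_mask) tuple and returns the raw switch array (the values are illustrative):

# Hypothetical example: polynomial switch evaluated on raw distances, no graph dict.
import jax
import jax.numpy as jnp

distances = jnp.linspace(0.0, 6.0, 7)
edge_mask = distances < 5.0

mod = SwitchFunction(graph_key=None, cutoff=5.0, switch_type="polynomial")
switch, _ = mod.init_with_output(jax.random.PRNGKey(0), (distances, edge_mask))
# switch is 1 at r = 0, decays smoothly toward the cutoff,
# and is exactly 0 wherever edge_mask is False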