fennol.models.misc.misc
import flax.linen as nn
from typing import Any, Sequence, Callable, Union, ClassVar, Optional, Dict, List
import jax.numpy as jnp
import jax
import numpy as np
from functools import partial
from ...utils.activations import activation_from_str
from ...utils.periodic_table import (
    CHEMICAL_PROPERTIES,
    PERIODIC_TABLE,
    PERIODIC_TABLE_REV_IDX,
)


def apply_switch(x: jax.Array, switch: jax.Array):
    """@private Multiply an array of values by a switch array."""
    shape = x.shape
    return (
        jnp.expand_dims(x, axis=-1).reshape(*switch.shape, -1) * switch[..., None]
    ).reshape(shape)
class ApplySwitch(nn.Module):
    """Multiply an edge array by a switch array.

    FID: APPLY_SWITCH
    """

    key: str
    """The key of the input array."""
    switch_key: Optional[str] = None
    """The key of the switch array."""
    graph_key: Optional[str] = None
    """The key of the graph containing the switch."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the input key is used."""

    FID: ClassVar[str] = "APPLY_SWITCH"

    @nn.compact
    def __call__(self, inputs) -> Any:
        if self.graph_key is not None:
            graph = inputs[self.graph_key]
            switch = graph["switch"]
        elif self.switch_key is not None:
            switch = inputs[self.switch_key]
        else:
            raise ValueError("Either graph_key or switch_key must be specified")

        x = inputs[self.key]
        output = apply_switch(x, switch)
        output_key = self.key if self.output_key is None else self.output_key
        return {**inputs, output_key: output}
class AtomToEdge(nn.Module):
    """Map atom-wise values to edge-wise values.

    FID: ATOM_TO_EDGE

    By default, we map the destination atom value to the edge. This can be
    changed by setting `use_source` to True.
    """

    _graphs_properties: Dict
    key: str
    """The key of the input atom-wise array."""
    output_key: Optional[str] = None
    """The key of the output edge-wise array. If None, the input key is used."""
    graph_key: str = "graph"
    """The key of the graph containing the edges."""
    switch: bool = False
    """Whether to apply a switch to the edge values."""
    switch_key: Optional[str] = None
    """The key of the switch array. If None, the switch is taken from the graph."""
    use_source: bool = False
    """Whether to use the source atom value instead of the destination atom value."""

    FID: ClassVar[str] = "ATOM_TO_EDGE"

    @nn.compact
    def __call__(self, inputs) -> Any:
        graph = inputs[self.graph_key]
        nat = inputs["species"].shape[0]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]

        x = inputs[self.key]
        if self.use_source:
            x_edge = x[edge_src]
        else:
            x_edge = x[edge_dst]

        if self.switch:
            switch = (
                graph["switch"] if self.switch_key is None else inputs[self.switch_key]
            )
            x_edge = apply_switch(x_edge, switch)

        output_key = self.key if self.output_key is None else self.output_key
        return {**inputs, output_key: x_edge}
class ScatterEdges(nn.Module):
    """Reduce an edge array to atoms by summing over neighbors.

    FID: SCATTER_EDGES
    """

    _graphs_properties: Dict
    key: str
    """The key of the input edge-wise array."""
    output_key: Optional[str] = None
    """The key of the output atom-wise array. If None, the input key is used."""
    graph_key: str = "graph"
    """The key of the graph containing the edges."""
    switch: bool = False
    """Whether to apply a switch to the edge values before summing."""
    switch_key: Optional[str] = None
    """The key of the switch array. If None, the switch is taken from the graph."""
    antisymmetric: bool = False
    """Whether to subtract the sum over destination atoms from the sum over source atoms."""

    FID: ClassVar[str] = "SCATTER_EDGES"

    @nn.compact
    def __call__(self, inputs) -> Any:
        graph = inputs[self.graph_key]
        nat = inputs["species"].shape[0]
        x = inputs[self.key]

        if self.switch:
            switch = (
                graph["switch"] if self.switch_key is None else inputs[self.switch_key]
            )
            x = apply_switch(x, switch)

        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        output = jax.ops.segment_sum(
            x, edge_src, nat
        )  # jnp.zeros((nat, *x.shape[1:])).at[edge_src].add(x, mode="drop")
        if self.antisymmetric:
            output = output - jax.ops.segment_sum(x, edge_dst, nat)

        output_key = self.key if self.output_key is None else self.output_key
        return {**inputs, output_key: output}
class EdgeConcatenate(nn.Module):
    """Concatenate the source and destination atom values of an edge.

    FID: EDGE_CONCATENATE
    """

    _graphs_properties: Dict
    key: str
    """The key of the input atom-wise array."""
    output_key: Optional[str] = None
    """The key of the output edge-wise array. If None, the name of the module is used."""
    graph_key: str = "graph"
    """The key of the graph containing the edges."""
    switch: bool = False
    """Whether to apply a switch to the edge values."""
    switch_key: Optional[str] = None
    """The key of the switch array. If None, the switch is taken from the graph."""
    axis: int = -1
    """The axis along which to concatenate the atom values."""

    FID: ClassVar[str] = "EDGE_CONCATENATE"

    @nn.compact
    def __call__(self, inputs) -> Any:
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        nat = inputs["species"].shape[0]
        xi = inputs[self.key]

        assert self._graphs_properties[self.graph_key][
            "directed"
        ], "EdgeConcatenate only works for directed graphs"
        assert xi.shape[0] == nat, "Shape mismatch, xi.shape[0] != nat"

        xij = jnp.concatenate([xi[edge_src], xi[edge_dst]], axis=self.axis)

        if self.switch:
            switch = (
                graph["switch"] if self.switch_key is None else inputs[self.switch_key]
            )
            xij = apply_switch(xij, switch)

        output_key = self.name if self.output_key is None else self.output_key
        return {**inputs, output_key: xij}
class ScatterSystem(nn.Module):
    """Reduce an atom-wise array to a system-wise array by summing over atoms (in the batch).

    FID: SCATTER_SYSTEM
    """

    key: str
    """The key of the input atom-wise array."""
    output_key: Optional[str] = None
    """The key of the output system-wise array. If None, the input key is used."""
    average: bool = False
    """Whether to divide by the number of atoms in the system."""

    FID: ClassVar[str] = "SCATTER_SYSTEM"

    @nn.compact
    def __call__(self, inputs) -> Any:
        batch_index = inputs["batch_index"]
        x = inputs[self.key]
        assert (
            x.shape[0] == batch_index.shape[0]
        ), f"Shape mismatch {x.shape[0]} != {batch_index.shape[0]}"
        nsys = inputs["natoms"].shape[0]
        if self.average:
            shape = [batch_index.shape[0]] + (x.ndim - 1) * [1]
            x = x / inputs["natoms"][batch_index].reshape(shape)

        output = jax.ops.segment_sum(x, batch_index, nsys)

        output_key = self.key if self.output_key is None else self.output_key
        return {**inputs, output_key: output}
class SystemToAtoms(nn.Module):
    """Broadcast a system-wise array to an atom-wise array.

    FID: SYSTEM_TO_ATOMS
    """

    key: str
    """The key of the input system-wise array."""
    output_key: Optional[str] = None
    """The key of the output atom-wise array. If None, the input key is used."""

    FID: ClassVar[str] = "SYSTEM_TO_ATOMS"

    @nn.compact
    def __call__(self, inputs) -> Any:
        batch_index = inputs["batch_index"]
        x = inputs[self.key]
        output = x[batch_index]

        output_key = self.key if self.output_key is None else self.output_key
        return {**inputs, output_key: output}
class SumAxis(nn.Module):
    """Sum an array along an axis.

    FID: SUM_AXIS
    """

    key: str
    """The key of the input array."""
    axis: Union[None, int, Sequence[int]] = None
    """The axis (or axes) along which to sum the array."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the input key is used."""
    norm: Optional[str] = None
    """Normalization of the sum. Can be 'dim', 'sqrt', or 'none'."""

    FID: ClassVar[str] = "SUM_AXIS"

    @nn.compact
    def __call__(self, inputs) -> Any:
        x = inputs[self.key]
        output = jnp.sum(x, axis=self.axis)
        if self.norm is not None:
            norm = self.norm.lower()
            if norm in ("dim", "sqrt"):
                # number of summed elements; handles int, sequence and None axes
                if self.axis is None:
                    dim = x.size
                else:
                    axes = (self.axis,) if isinstance(self.axis, int) else self.axis
                    dim = np.prod([x.shape[i] for i in axes])
                output = output / (dim if norm == "dim" else dim**0.5)
            elif norm == "none":
                pass
            else:
                raise ValueError(f"Unknown norm {norm}")
        output_key = self.key if self.output_key is None else self.output_key
        return {**inputs, output_key: output}
class Split(nn.Module):
    """Split an array along an axis.

    FID: SPLIT
    """

    key: str
    """The key of the input array."""
    output_keys: Sequence[str]
    """The keys of the output arrays."""
    axis: int = -1
    """The axis along which to split the array."""
    sizes: Union[int, Sequence[int]] = 1
    """The sizes of the splits."""
    squeeze: bool = True
    """Whether to remove the axis in the output if the size is 1."""

    FID: ClassVar[str] = "SPLIT"

    @nn.compact
    def __call__(self, inputs) -> Any:
        x = inputs[self.key]

        if isinstance(self.sizes, int):
            split_size = [self.sizes] * len(self.output_keys)
        else:
            split_size = self.sizes
        if len(split_size) == len(self.output_keys):
            assert (
                sum(split_size) == x.shape[self.axis]
            ), f"Split sizes {split_size} do not match input shape"
            split_size = split_size[:-1]
        assert (
            len(split_size) == len(self.output_keys) - 1
        ), f"Wrong number of split sizes {split_size} for {len(self.output_keys)} outputs"
        split_indices = np.cumsum(split_size)
        outs = {}

        for k, v in zip(self.output_keys, jnp.split(x, split_indices, axis=self.axis)):
            outs[k] = (
                jnp.squeeze(v, axis=self.axis)
                if self.squeeze and v.shape[self.axis] == 1
                else v
            )

        return {**inputs, **outs}
class Concatenate(nn.Module):
    """Concatenate a list of arrays along an axis.

    FID: CONCATENATE
    """

    keys: Sequence[str]
    """The keys of the input arrays."""
    axis: int = -1
    """The axis along which to concatenate the arrays."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the name of the module is used."""

    FID: ClassVar[str] = "CONCATENATE"

    @nn.compact
    def __call__(self, inputs) -> Any:
        output = jnp.concatenate([inputs[k] for k in self.keys], axis=self.axis)
        output_key = self.output_key if self.output_key is not None else self.name
        return {**inputs, output_key: output}
class Activation(nn.Module):
    """Apply an element-wise activation function to an array.

    FID: ACTIVATION
    """

    key: str
    """The key of the input array."""
    activation: Union[Callable, str]
    """The activation function or its name."""
    scale_out: float = 1.0
    """Output scaling factor."""
    shift_out: float = 0.0
    """Output shift."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the input key is used."""

    FID: ClassVar[str] = "ACTIVATION"

    @nn.compact
    def __call__(self, inputs) -> Any:
        x = inputs[self.key]
        activation = (
            activation_from_str(self.activation)
            if isinstance(self.activation, str)
            else self.activation
        )
        output = self.scale_out * activation(x) + self.shift_out
        output_key = self.output_key if self.output_key is not None else self.key
        return {**inputs, output_key: output}
class Scale(nn.Module):
    """Scale an array by a constant factor.

    FID: SCALE
    """

    key: str
    """The key of the input array."""
    scale: float
    """The (initial) scaling factor."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the input key is used."""
    trainable: bool = False
    """Whether the scaling factor is trainable."""

    FID: ClassVar[str] = "SCALE"

    @nn.compact
    def __call__(self, inputs) -> Any:
        x = inputs[self.key]

        if self.trainable:
            scale = self.param("scale", lambda rng: jnp.asarray(self.scale))
        else:
            scale = self.scale

        output = scale * x
        output_key = self.output_key if self.output_key is not None else self.key
        return {**inputs, output_key: output}
class Add(nn.Module):
    """Add together a list of arrays.

    FID: ADD
    """

    keys: Sequence[str]
    """The keys of the input arrays."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the name of the module is used."""

    FID: ClassVar[str] = "ADD"

    @nn.compact
    def __call__(self, inputs) -> Any:
        output = 0
        for k in self.keys:
            output = output + inputs[k]

        output_key = self.output_key if self.output_key is not None else self.name
        return {**inputs, output_key: output}
class Multiply(nn.Module):
    """Element-wise-multiply together a list of arrays.

    FID: MULTIPLY
    """

    keys: Sequence[str]
    """The keys of the input arrays."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the name of the module is used."""

    FID: ClassVar[str] = "MULTIPLY"

    @nn.compact
    def __call__(self, inputs) -> Any:
        output = 1
        for k in self.keys:
            output = output * inputs[k]

        output_key = self.output_key if self.output_key is not None else self.name
        return {**inputs, output_key: output}
class Transpose(nn.Module):
    """Transpose an array.

    FID: TRANSPOSE
    """

    key: str
    """The key of the input array."""
    axes: Sequence[int]
    """The permutation of the axes. See `jax.numpy.transpose` for more details."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the input key is used."""

    FID: ClassVar[str] = "TRANSPOSE"

    @nn.compact
    def __call__(self, inputs) -> Any:
        output = jnp.transpose(inputs[self.key], axes=self.axes)
        output_key = self.output_key if self.output_key is not None else self.key
        return {**inputs, output_key: output}
class Reshape(nn.Module):
    """Reshape an array.

    FID: RESHAPE
    """

    key: str
    """The key of the input array."""
    shape: Sequence[Union[int, str]]
    """The shape of the output array. Each entry is an int, an alias for the
    number of atoms ('natoms') or systems ('nsys'), or a string of the form
    'key[axis]' referring to the shape of another input array."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the input key is used."""

    FID: ClassVar[str] = "RESHAPE"

    @nn.compact
    def __call__(self, inputs) -> Any:
        shape = []
        for s in self.shape:
            if isinstance(s, int):
                shape.append(s)
                continue

            if isinstance(s, str):
                s_ = s.lower().strip()
                if s_ in ["natoms", "nat", "natom", "n_atoms", "atoms"]:
                    shape.append(inputs["species"].shape[0])
                    continue

                if s_ in ["nsys", "nbatch", "nsystems", "n_sys", "n_systems", "n_batch"]:
                    shape.append(inputs["natoms"].shape[0])
                    continue

                s_ = s.strip().split("[")
                key = s_[0]
                if key in inputs:
                    axis = int(s_[1].split("]")[0])
                    shape.append(inputs[key].shape[axis])
                    continue

            raise ValueError(f"Error parsing shape component {s}")

        output = jnp.reshape(inputs[self.key], shape)
        output_key = self.output_key if self.output_key is not None else self.key
        return {**inputs, output_key: output}
class ChemicalConstant(nn.Module):
    """Map atomic species to a constant value.

    FID: CHEMICAL_CONSTANT
    """

    value: Union[str, List[float], float, Dict]
    """The constant value, a list of values indexed by atomic number, or a
    dictionary mapping element symbols to values."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the name of the module is used."""
    trainable: bool = False
    """Whether the constant is trainable."""

    FID: ClassVar[str] = "CHEMICAL_CONSTANT"

    @nn.compact
    def __call__(self, inputs) -> Any:
        if isinstance(self.value, str):
            constant = CHEMICAL_PROPERTIES[self.value.upper()]
        elif isinstance(self.value, list) or isinstance(self.value, tuple):
            constant = list(self.value)
        elif isinstance(self.value, float):
            constant = [self.value] * len(PERIODIC_TABLE)
        elif hasattr(self.value, "items"):
            constant = [0.0] * len(PERIODIC_TABLE)
            for k, v in self.value.items():
                constant[PERIODIC_TABLE_REV_IDX[k]] = v
        else:
            raise ValueError(f"Unknown constant type {type(self.value)}")

        if self.trainable:
            constant = self.param(
                "constant", lambda rng: jnp.asarray(constant, dtype=jnp.float32)
            )
        else:
            constant = jnp.asarray(constant, dtype=jnp.float32)
        output = constant[inputs["species"]]
        output_key = self.output_key if self.output_key is not None else self.name
        return {**inputs, output_key: output}
class SwitchFunction(nn.Module):
    """Compute a switch array from an array of distances and a cutoff.

    FID: SWITCH_FUNCTION
    """

    cutoff: Optional[float] = None
    """The cutoff distance. If None, the cutoff is taken from the graph."""
    switch_start: float = 0.0
    """The proportion of the cutoff distance at which the switch function starts."""
    graph_key: Optional[str] = "graph"
    """The key of the graph containing the distances and edge mask."""
    output_key: Optional[str] = None
    """The key of the output switch array. If None, it is added to the graph."""
    switch_type: str = "cosine"
    """The type of switch function. Can be 'cosine', 'polynomial', 'exponential', or 'hard'."""
    p: Optional[float] = None
    """The parameter of the switch function. If None, it is fixed to the default for each `switch_type`."""
    trainable: bool = False
    """Whether the switch parameter is trainable."""

    FID: ClassVar[str] = "SWITCH_FUNCTION"

    @nn.compact
    def __call__(self, inputs) -> Any:
        if self.graph_key is not None:
            graph = inputs[self.graph_key]
            distances, edge_mask = graph["distances"], graph["edge_mask"]
            if self.cutoff is not None:
                edge_mask = jnp.logical_and(edge_mask, (distances < self.cutoff))
                cutoff = self.cutoff
            else:
                cutoff = graph["cutoff"]
        else:
            if len(inputs) == 3:
                distances, edge_mask, cutoff = inputs
            else:
                distances, edge_mask = inputs
                assert (
                    self.cutoff is not None
                ), "cutoff must be specified if no graph is given"
                cutoff = self.cutoff

        if self.switch_start > 1.0e-5:
            assert (
                self.switch_start < 1.0
            ), "switch_start is a proportion of cutoff and must be smaller than 1."
            cutoff_in = self.switch_start * cutoff
            x = distances - cutoff_in
            end = cutoff - cutoff_in
        else:
            x = distances
            end = cutoff

        switch_type = self.switch_type.lower()
        if switch_type == "cosine":
            p = self.p if self.p is not None else 1.0
            if self.trainable:
                p = self.param("p", lambda rng: jnp.asarray(p, dtype=jnp.float32))
            switch = (0.5 * jnp.cos(x * (jnp.pi / end)) + 0.5) ** p

        elif switch_type == "polynomial":
            p = self.p if self.p is not None else 3.0
            if self.trainable:
                p = self.param("p", lambda rng: jnp.asarray(p, dtype=jnp.float32))
            d = x / end
            switch = (
                1.0
                - 0.5 * (p + 1) * (p + 2) * d**p
                + p * (p + 2) * d ** (p + 1)
                - 0.5 * p * (p + 1) * d ** (p + 2)
            )

        elif switch_type == "exponential":
            p = self.p if self.p is not None else 1.0
            if self.trainable:
                p = self.param("p", lambda rng: jnp.asarray(p, dtype=jnp.float32))
            r2 = x**2
            c2 = end**2
            switch = jnp.exp(-p * r2 / (c2 - r2))

        elif switch_type == "hard":
            switch = jnp.where(distances < cutoff, 1.0, 0.0)
        else:
            raise ValueError(f"Unknown switch function {switch_type}")

        if self.switch_start > 1.0e-5:
            switch = jnp.where(distances < cutoff_in, 1.0, switch)

        switch = jnp.where(edge_mask, switch, 0.0)

        if self.graph_key is not None:
            if self.output_key is not None:
                return {**inputs, self.output_key: switch}
            else:
                return {**inputs, self.graph_key: {**graph, "switch": switch}}
        else:
            return switch  # , edge_mask