fennol.models.misc.misc
Module imports and the shared `apply_switch` helper used by several of the modules below (each module is then listed with its source):

```python
import flax.linen as nn
from typing import Any, Sequence, Callable, Union, ClassVar, Optional, Dict, List
import jax.numpy as jnp
import jax
import numpy as np
from functools import partial
from ...utils.activations import activation_from_str
from ...utils.periodic_table import (
    CHEMICAL_PROPERTIES,
    PERIODIC_TABLE,
    PERIODIC_TABLE_REV_IDX,
)


def apply_switch(x: jax.Array, switch: jax.Array):
    """@private Multiply an array of values by a switch array."""
    shape = x.shape
    return (
        jnp.expand_dims(x, axis=-1).reshape(*switch.shape, -1) * switch[..., None]
    ).reshape(shape)
```
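The reshape trick in `apply_switch` lets a per-edge switch of shape `(nedges,)` scale an edge array with arbitrary trailing feature dimensions. A minimal sketch (shapes are illustrative; the import path is assumed):

```python
import jax.numpy as jnp
# from fennol.models.misc.misc import apply_switch  # assumed import path

x = jnp.ones((5, 3))                # edge-wise values: 5 edges, 3 features each
switch = jnp.linspace(1.0, 0.0, 5)  # one switch value per edge, in [0, 1]

y = apply_switch(x, switch)         # for these shapes, same as x * switch[:, None]
assert y.shape == x.shape           # each edge's features scaled by its switch value
```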
```python
class ApplySwitch(nn.Module):
    """Multiply an edge array by a switch array.

    FID: APPLY_SWITCH
    """

    key: str
    """The key of the input array."""
    switch_key: Optional[str] = None
    """The key of the switch array."""
    graph_key: Optional[str] = None
    """The key of the graph containing the switch."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the input key is used."""

    FID: ClassVar[str] = "APPLY_SWITCH"

    @nn.compact
    def __call__(self, inputs) -> Any:
        if self.graph_key is not None:
            graph = inputs[self.graph_key]
            switch = graph["switch"]
        elif self.switch_key is not None:
            switch = inputs[self.switch_key]
        else:
            raise ValueError("Either graph_key or switch_key must be specified")

        x = inputs[self.key]
        output = apply_switch(x, switch)
        output_key = self.key if self.output_key is None else self.output_key
        return {**inputs, output_key: output}
```
```python
class AtomToEdge(nn.Module):
    """Map atom-wise values to edge-wise values.

    FID: ATOM_TO_EDGE

    By default, we map the destination atom value to the edge. This can be
    changed by setting `use_source` to True.
    """

    _graphs_properties: Dict
    key: str
    """The key of the input atom-wise array."""
    output_key: Optional[str] = None
    """The key of the output edge-wise array. If None, the input key is used."""
    graph_key: str = "graph"
    """The key of the graph containing the edges."""
    switch: bool = False
    """Whether to apply a switch to the edge values."""
    switch_key: Optional[str] = None
    """The key of the switch array. If None, the switch is taken from the graph."""
    use_source: bool = False
    """Whether to use the source atom value instead of the destination atom value."""

    FID: ClassVar[str] = "ATOM_TO_EDGE"

    @nn.compact
    def __call__(self, inputs) -> Any:
        graph = inputs[self.graph_key]
        nat = inputs["species"].shape[0]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]

        x = inputs[self.key]
        if self.use_source:
            x_edge = x[edge_src]
        else:
            x_edge = x[edge_dst]

        if self.switch:
            switch = (
                graph["switch"] if self.switch_key is None else inputs[self.switch_key]
            )
            x_edge = apply_switch(x_edge, switch)

        output_key = self.key if self.output_key is None else self.output_key
        return {**inputs, output_key: x_edge}
```
```python
class ScatterEdges(nn.Module):
    """Reduce an edge array to atoms by summing over neighbors.

    FID: SCATTER_EDGES
    """

    _graphs_properties: Dict
    key: str
    """The key of the input edge-wise array."""
    output_key: Optional[str] = None
    """The key of the output atom-wise array. If None, the input key is used."""
    graph_key: str = "graph"
    """The key of the graph containing the edges."""
    switch: bool = False
    """Whether to apply a switch to the edge values before summing."""
    switch_key: Optional[str] = None
    """The key of the switch array. If None, the switch is taken from the graph."""
    antisymmetric: bool = False
    """Whether to subtract the sum over destination atoms from the sum over source atoms."""

    FID: ClassVar[str] = "SCATTER_EDGES"

    @nn.compact
    def __call__(self, inputs) -> Any:
        graph = inputs[self.graph_key]
        nat = inputs["species"].shape[0]
        x = inputs[self.key]

        if self.switch:
            switch = (
                graph["switch"] if self.switch_key is None else inputs[self.switch_key]
            )
            x = apply_switch(x, switch)

        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        output = jax.ops.segment_sum(x, edge_src, nat)
        if self.antisymmetric:
            output = output - jax.ops.segment_sum(x, edge_dst, nat)

        output_key = self.key if self.output_key is None else self.output_key
        return {**inputs, output_key: output}
```
```python
class EdgeConcatenate(nn.Module):
    """Concatenate the source and destination atom values of an edge.

    FID: EDGE_CONCATENATE
    """

    _graphs_properties: Dict
    key: str
    """The key of the input atom-wise array."""
    output_key: Optional[str] = None
    """The key of the output edge-wise array. If None, the name of the module is used."""
    graph_key: str = "graph"
    """The key of the graph containing the edges."""
    switch: bool = False
    """Whether to apply a switch to the edge values."""
    switch_key: Optional[str] = None
    """The key of the switch array. If None, the switch is taken from the graph."""
    axis: int = -1
    """The axis along which to concatenate the atom values."""

    FID: ClassVar[str] = "EDGE_CONCATENATE"

    @nn.compact
    def __call__(self, inputs) -> Any:
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        nat = inputs["species"].shape[0]
        xi = inputs[self.key]

        assert self._graphs_properties[self.graph_key][
            "directed"
        ], "EdgeConcatenate only works for directed graphs"
        assert xi.shape[0] == nat, "Shape mismatch, xi.shape[0] != nat"

        xij = jnp.concatenate([xi[edge_src], xi[edge_dst]], axis=self.axis)

        if self.switch:
            switch = (
                graph["switch"] if self.switch_key is None else inputs[self.switch_key]
            )
            xij = apply_switch(xij, switch)

        output_key = self.name if self.output_key is None else self.output_key
        return {**inputs, output_key: xij}
```
```python
class ScatterSystem(nn.Module):
    """Reduce an atom-wise array to a system-wise array by summing over atoms (in the batch).

    FID: SCATTER_SYSTEM
    """

    key: str
    """The key of the input atom-wise array."""
    output_key: Optional[str] = None
    """The key of the output system-wise array. If None, the input key is used."""
    average: bool = False
    """Whether to divide by the number of atoms in the system."""

    FID: ClassVar[str] = "SCATTER_SYSTEM"

    @nn.compact
    def __call__(self, inputs) -> Any:
        batch_index = inputs["batch_index"]
        x = inputs[self.key]
        assert (
            x.shape[0] == batch_index.shape[0]
        ), f"Shape mismatch {x.shape[0]} != {batch_index.shape[0]}"
        nsys = inputs["natoms"].shape[0]
        if self.average:
            shape = [batch_index.shape[0]] + (x.ndim - 1) * [1]
            x = x / inputs["natoms"][batch_index].reshape(shape)

        output = jax.ops.segment_sum(x, batch_index, nsys)

        output_key = self.key if self.output_key is None else self.output_key
        return {**inputs, output_key: output}
```
```python
class SystemToAtoms(nn.Module):
    """Broadcast a system-wise array to an atom-wise array.

    FID: SYSTEM_TO_ATOMS
    """

    key: str
    """The key of the input system-wise array."""
    output_key: Optional[str] = None
    """The key of the output atom-wise array. If None, the input key is used."""

    FID: ClassVar[str] = "SYSTEM_TO_ATOMS"

    @nn.compact
    def __call__(self, inputs) -> Any:
        batch_index = inputs["batch_index"]
        x = inputs[self.key]
        output = x[batch_index]

        output_key = self.key if self.output_key is None else self.output_key
        return {**inputs, output_key: output}
```
```python
class SumAxis(nn.Module):
    """Sum an array along an axis.

    FID: SUM_AXIS
    """

    key: str
    """The key of the input array."""
    axis: Union[None, int, Sequence[int]] = None
    """The axis along which to sum the array."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the input key is used."""
    norm: Optional[str] = None
    """Normalization of the sum. Can be 'dim', 'sqrt', or 'none'."""

    FID: ClassVar[str] = "SUM_AXIS"

    @nn.compact
    def __call__(self, inputs) -> Any:
        x = inputs[self.key]
        output = jnp.sum(x, axis=self.axis)
        if self.norm is not None:
            norm = self.norm.lower()
            if norm in ("dim", "sqrt"):
                # number of summed elements (axis may be None, an int,
                # or a sequence of ints)
                if self.axis is None:
                    dim = x.size
                elif isinstance(self.axis, int):
                    dim = x.shape[self.axis]
                else:
                    dim = np.prod([x.shape[i] for i in self.axis])
                output = output / (dim**0.5 if norm == "sqrt" else dim)
            elif norm == "none":
                pass
            else:
                raise ValueError(f"Unknown norm {norm}")
        output_key = self.key if self.output_key is None else self.output_key
        return {**inputs, output_key: output}
```
```python
class Split(nn.Module):
    """Split an array along an axis.

    FID: SPLIT
    """

    key: str
    """The key of the input array."""
    output_keys: Sequence[str]
    """The keys of the output arrays."""
    axis: int = -1
    """The axis along which to split the array."""
    sizes: Union[int, Sequence[int]] = 1
    """The sizes of the splits."""
    squeeze: bool = True
    """Whether to remove the axis in the output if the size is 1."""

    FID: ClassVar[str] = "SPLIT"

    @nn.compact
    def __call__(self, inputs) -> Any:
        x = inputs[self.key]

        if isinstance(self.sizes, int):
            split_size = [self.sizes] * len(self.output_keys)
        else:
            split_size = self.sizes
        if len(split_size) == len(self.output_keys):
            assert (
                sum(split_size) == x.shape[self.axis]
            ), f"Split sizes {split_size} do not match input shape"
            split_size = split_size[:-1]
        assert (
            len(split_size) == len(self.output_keys) - 1
        ), f"Wrong number of split sizes {split_size} for {len(self.output_keys)} outputs"
        split_indices = np.cumsum(split_size)
        outs = {}

        for k, v in zip(self.output_keys, jnp.split(x, split_indices, axis=self.axis)):
            outs[k] = (
                jnp.squeeze(v, axis=self.axis)
                if self.squeeze and v.shape[self.axis] == 1
                else v
            )

        return {**inputs, **outs}
```
```python
class Concatenate(nn.Module):
    """Concatenate a list of arrays along an axis.

    FID: CONCATENATE
    """

    keys: Sequence[str]
    """The keys of the input arrays."""
    axis: int = -1
    """The axis along which to concatenate the arrays."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the name of the module is used."""

    FID: ClassVar[str] = "CONCATENATE"

    @nn.compact
    def __call__(self, inputs) -> Any:
        output = jnp.concatenate([inputs[k] for k in self.keys], axis=self.axis)
        output_key = self.output_key if self.output_key is not None else self.name
        return {**inputs, output_key: output}
```
```python
class Activation(nn.Module):
    """Apply an element-wise activation function to an array.

    FID: ACTIVATION
    """

    key: str
    """The key of the input array."""
    activation: Union[Callable, str]
    """The activation function or its name."""
    scale_out: float = 1.0
    """Output scaling factor."""
    shift_out: float = 0.0
    """Output shift."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the input key is used."""

    FID: ClassVar[str] = "ACTIVATION"

    @nn.compact
    def __call__(self, inputs) -> Any:
        x = inputs[self.key]
        activation = (
            activation_from_str(self.activation)
            if isinstance(self.activation, str)
            else self.activation
        )
        output = self.scale_out * activation(x) + self.shift_out
        output_key = self.output_key if self.output_key is not None else self.key
        return {**inputs, output_key: output}
```
```python
class Scale(nn.Module):
    """Scale an array by a constant factor.

    FID: SCALE
    """

    key: str
    """The key of the input array."""
    scale: float
    """The (initial) scaling factor."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the input key is used."""
    trainable: bool = False
    """Whether the scaling factor is trainable."""

    FID: ClassVar[str] = "SCALE"

    @nn.compact
    def __call__(self, inputs) -> Any:
        x = inputs[self.key]

        if self.trainable:
            scale = self.param("scale", lambda rng: jnp.asarray(self.scale))
        else:
            scale = self.scale

        output = scale * x
        output_key = self.output_key if self.output_key is not None else self.key
        return {**inputs, output_key: output}
```
```python
class Add(nn.Module):
    """Add together a list of arrays.

    FID: ADD
    """

    keys: Sequence[str]
    """The keys of the input arrays."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the name of the module is used."""

    FID: ClassVar[str] = "ADD"

    @nn.compact
    def __call__(self, inputs) -> Any:
        output = 0
        for k in self.keys:
            output = output + inputs[k]

        output_key = self.output_key if self.output_key is not None else self.name
        return {**inputs, output_key: output}
```
```python
class Multiply(nn.Module):
    """Element-wise-multiply together a list of arrays.

    FID: MULTIPLY
    """

    keys: Sequence[str]
    """The keys of the input arrays."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the name of the module is used."""

    FID: ClassVar[str] = "MULTIPLY"

    @nn.compact
    def __call__(self, inputs) -> Any:
        output = 1
        for k in self.keys:
            output = output * inputs[k]

        output_key = self.output_key if self.output_key is not None else self.name
        return {**inputs, output_key: output}
```
```python
class Transpose(nn.Module):
    """Transpose an array.

    FID: TRANSPOSE
    """

    key: str
    """The key of the input array."""
    axes: Sequence[int]
    """The permutation of the axes. See `jax.numpy.transpose` for more details."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the input key is used."""

    FID: ClassVar[str] = "TRANSPOSE"

    @nn.compact
    def __call__(self, inputs) -> Any:
        output = jnp.transpose(inputs[self.key], axes=self.axes)
        output_key = self.output_key if self.output_key is not None else self.key
        return {**inputs, output_key: output}
```
```python
class Reshape(nn.Module):
    """Reshape an array.

    FID: RESHAPE
    """

    key: str
    """The key of the input array."""
    shape: Sequence[int]
    """The shape of the output array."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the input key is used."""

    FID: ClassVar[str] = "RESHAPE"

    @nn.compact
    def __call__(self, inputs) -> Any:
        output = jnp.reshape(inputs[self.key], self.shape)
        output_key = self.output_key if self.output_key is not None else self.key
        return {**inputs, output_key: output}
```
```python
class ChemicalConstant(nn.Module):
    """Map atomic species to a constant value.

    FID: CHEMICAL_CONSTANT
    """

    value: Union[str, List[float], float, Dict]
    """The constant value or a dictionary of values for each element."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the name of the module is used."""
    trainable: bool = False
    """Whether the constant is trainable."""

    FID: ClassVar[str] = "CHEMICAL_CONSTANT"

    @nn.compact
    def __call__(self, inputs) -> Any:
        if isinstance(self.value, str):
            constant = CHEMICAL_PROPERTIES[self.value.upper()]
        elif isinstance(self.value, list):
            constant = self.value
        elif isinstance(self.value, float):
            constant = [self.value] * len(PERIODIC_TABLE)
        elif hasattr(self.value, "items"):
            constant = [0.0] * len(PERIODIC_TABLE)
            for k, v in self.value.items():
                constant[PERIODIC_TABLE_REV_IDX[k]] = v
        else:
            raise ValueError(f"Unknown constant type {type(self.value)}")

        if self.trainable:
            constant = self.param(
                "constant", lambda rng: jnp.asarray(constant, dtype=jnp.float32)
            )
        else:
            constant = jnp.asarray(constant, dtype=jnp.float32)
        output = constant[inputs["species"]]
        output_key = self.output_key if self.output_key is not None else self.name
        return {**inputs, output_key: output}
```
```python
class SwitchFunction(nn.Module):
    """Compute a switch array from an array of distances and a cutoff.

    FID: SWITCH_FUNCTION
    """

    cutoff: Optional[float] = None
    """The cutoff distance. If None, the cutoff is taken from the graph."""
    switch_start: float = 0.0
    """The proportion of the cutoff distance at which the switch function starts."""
    graph_key: Optional[str] = "graph"
    """The key of the graph containing the distances and edge mask."""
    output_key: Optional[str] = None
    """The key of the output switch array. If None, it is added to the graph."""
    switch_type: str = "cosine"
    """The type of switch function. Can be 'cosine', 'polynomial', 'exponential', or 'hard'."""
    p: Optional[float] = None
    """The parameter of the switch function. If None, it is fixed to the default for each `switch_type`."""
    trainable: bool = False
    """Whether the switch parameter is trainable."""

    FID: ClassVar[str] = "SWITCH_FUNCTION"

    @nn.compact
    def __call__(self, inputs) -> Any:
        if self.graph_key is not None:
            graph = inputs[self.graph_key]
            distances, edge_mask = graph["distances"], graph["edge_mask"]
            if self.cutoff is not None:
                edge_mask = jnp.logical_and(edge_mask, (distances < self.cutoff))
                cutoff = self.cutoff
            else:
                cutoff = graph["cutoff"]
        else:
            # standalone mode: inputs is a (distances, edge_mask) tuple
            distances, edge_mask = inputs
            assert (
                self.cutoff is not None
            ), "cutoff must be specified if no graph is given"
            cutoff = self.cutoff

        if self.switch_start > 1.0e-5:
            assert (
                self.switch_start < 1.0
            ), "switch_start is a proportion of cutoff and must be smaller than 1."
            cutoff_in = self.switch_start * cutoff
            x = distances - cutoff_in
            end = cutoff - cutoff_in
        else:
            x = distances
            end = cutoff

        switch_type = self.switch_type.lower()
        if switch_type == "cosine":
            p = self.p if self.p is not None else 1.0
            if self.trainable:
                p = self.param("p", lambda rng: jnp.asarray(p, dtype=jnp.float32))
            switch = (0.5 * jnp.cos(x * (jnp.pi / end)) + 0.5) ** p

        elif switch_type == "polynomial":
            p = self.p if self.p is not None else 3.0
            if self.trainable:
                p = self.param("p", lambda rng: jnp.asarray(p, dtype=jnp.float32))
            d = x / end
            switch = (
                1.0
                - 0.5 * (p + 1) * (p + 2) * d**p
                + p * (p + 2) * d ** (p + 1)
                - 0.5 * p * (p + 1) * d ** (p + 2)
            )

        elif switch_type == "exponential":
            p = self.p if self.p is not None else 1.0
            if self.trainable:
                p = self.param("p", lambda rng: jnp.asarray(p, dtype=jnp.float32))
            r2 = x**2
            c2 = end**2
            switch = jnp.exp(-p * r2 / (c2 - r2))

        elif switch_type == "hard":
            switch = jnp.where(distances < cutoff, 1.0, 0.0)
        else:
            raise ValueError(f"Unknown switch function {switch_type}")

        if self.switch_start > 1.0e-5:
            switch = jnp.where(distances < cutoff_in, 1.0, switch)

        switch = jnp.where(edge_mask, switch, 0.0)

        if self.graph_key is not None:
            if self.output_key is not None:
                return {**inputs, self.output_key: switch}
            else:
                return {**inputs, self.graph_key: {**graph, "switch": switch}}
        else:
            return switch
```