fennol.models.misc.misc

  1import flax.linen as nn
  2from typing import Any, Sequence, Callable, Union, ClassVar, Optional, Dict, List
  3import jax.numpy as jnp
  4import jax
  5import numpy as np
  6from functools import partial
  7from ...utils.activations import activation_from_str
  8from ...utils.periodic_table import (
  9    CHEMICAL_PROPERTIES,
 10    PERIODIC_TABLE,
 11    PERIODIC_TABLE_REV_IDX,
 12)
 13
 14
 15def apply_switch(x: jax.Array, switch: jax.Array):
 16    """@private Multiply a switch array to an array of values."""
 17    shape = x.shape
 18    return (
 19        jnp.expand_dims(x, axis=-1).reshape(*switch.shape, -1) * switch[..., None]
 20    ).reshape(shape)
 21
 22
 23class ApplySwitch(nn.Module):
 24    """Multiply an edge array by a switch array.
 25
 26    FID: APPLY_SWITCH
 27    """
 28
 29    key: str
 30    """The key of the input array."""
 31    switch_key: Optional[str] = None
 32    """The key of the switch array."""
 33    graph_key: Optional[str] = None
 34    """The key of the graph containing the switch."""
 35    output_key: Optional[str] = None
 36    """The key of the output array. If None, the input key is used."""
 37
 38    FID: ClassVar[str] = "APPLY_SWITCH"
 39
 40    @nn.compact
 41    def __call__(self, inputs) -> Any:
 42        if self.graph_key is not None:
 43            graph = inputs[self.graph_key]
 44            switch = graph["switch"]
 45        elif self.switch_key is not None:
 46            switch = inputs[self.switch_key]
 47        else:
 48            raise ValueError("Either graph_key or switch_key must be specified")
 49
 50        x = inputs[self.key]
 51        output = apply_switch(x, switch)
 52        output_key = self.key if self.output_key is None else self.output_key
 53        return {**inputs, output_key: output}
 54
 55
 56class AtomToEdge(nn.Module):
 57    """Map atom-wise values to edge-wise values.
 58
 59    FID: ATOM_TO_EDGE
 60
 61    By default, we map the destination atom value to the edge. This can be changed by setting `use_source` to True.
 62    """
 63
 64    _graphs_properties: Dict
 65    key: str
 66    """The key of the input atom-wise array."""
 67    output_key: Optional[str] = None
 68    """The key of the output edge-wise array. If None, the input key is used."""
 69    graph_key: str = "graph"
 70    """The key of the graph containing the edges."""
 71    switch: bool = False
 72    """Whether to apply a switch to the edge values."""
 73    switch_key: Optional[str] = None
 74    """The key of the switch array. If None, the switch is taken from the graph."""
 75    use_source: bool = False
 76    """Whether to use the source atom value instead of the destination atom value."""
 77
 78    FID: ClassVar[str] = "ATOM_TO_EDGE"
 79
 80    @nn.compact
 81    def __call__(self, inputs) -> Any:
 82        graph = inputs[self.graph_key]
 83        nat = inputs["species"].shape[0]
 84        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
 85
 86        x = inputs[self.key]
 87        if self.use_source:
 88            x_edge = x[edge_src]
 89        else:
 90            x_edge = x[edge_dst]
 91
 92        if self.switch:
 93            switch = (
 94                graph["switch"] if self.switch_key is None else inputs[self.switch_key]
 95            )
 96            x_edge = apply_switch(x_edge, switch)
 97
 98        output_key = self.key if self.output_key is None else self.output_key
 99        return {**inputs, output_key: x_edge}
100
101
class ScatterEdges(nn.Module):
    """Reduce an edge array to atoms by summing over neighbors.

    FID: SCATTER_EDGES

    Each edge value is accumulated onto its source atom (``edge_src``). With
    `antisymmetric` set, the values accumulated onto destination atoms
    (``edge_dst``) are subtracted from that sum.
    """

    _graphs_properties: Dict
    key: str
    """The key of the input edge-wise array."""
    output_key: Optional[str] = None
    """The key of the output atom-wise array. If None, the input key is used."""
    graph_key: str = "graph"
    """The key of the graph containing the edges."""
    switch: bool = False
    """Whether to apply a switch to the edge values before summing."""
    switch_key: Optional[str] = None
    """The key of the switch array. If None, the switch is taken from the graph."""
    antisymmetric: bool = False
    """Whether to subtract the sum over destination atoms from the sum over source atoms."""

    FID: ClassVar[str] = "SCATTER_EDGES"

    @nn.compact
    def __call__(self, inputs) -> Any:
        graph = inputs[self.graph_key]
        num_atoms = inputs["species"].shape[0]
        edge_values = inputs[self.key]

        if self.switch:
            if self.switch_key is None:
                switch = graph["switch"]
            else:
                switch = inputs[self.switch_key]
            edge_values = apply_switch(edge_values, switch)

        src = graph["edge_src"]
        dst = graph["edge_dst"]
        reduced = jax.ops.segment_sum(edge_values, src, num_atoms)
        if self.antisymmetric:
            reduced = reduced - jax.ops.segment_sum(edge_values, dst, num_atoms)

        target = self.key if self.output_key is None else self.output_key
        return {**inputs, target: reduced}
144
145
class EdgeConcatenate(nn.Module):
    """Concatenate the source and destination atom values of an edge.

    FID: EDGE_CONCATENATE

    For every edge the output is ``concatenate([x[edge_src], x[edge_dst]], axis)``,
    optionally multiplied by a switch array. The graph must be declared
    directed in `_graphs_properties`.
    """

    # Per-graph metadata; only the "directed" flag of `graph_key` is read here.
    _graphs_properties: Dict
    key: str
    """The key of the input atom-wise array."""
    output_key: Optional[str] = None
    """The key of the output edge-wise array. If None, the name of the module is used."""
    graph_key: str = "graph"
    """The key of the graph containing the edges."""
    switch: bool = False
    """Whether to apply a switch to the edge values."""
    switch_key: Optional[str] = None
    """The key of the switch array. If None, the switch is taken from the graph."""
    axis: int = -1
    """The axis along which to concatenate the atom values."""

    FID: ClassVar[str] = "EDGE_CONCATENATE"

    @nn.compact
    def __call__(self, inputs) -> Any:
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        nat = inputs["species"].shape[0]
        xi = inputs[self.key]

        assert self._graphs_properties[self.graph_key][
            "directed"
        ], "EdgeConcatenate only works for directed graphs"
        assert xi.shape[0] == nat, "Shape mismatch, xi.shape[0] != nat"

        xij = jnp.concatenate([xi[edge_src], xi[edge_dst]], axis=self.axis)

        if self.switch:
            switch = (
                graph["switch"] if self.switch_key is None else inputs[self.switch_key]
            )
            xij = apply_switch(xij, switch)

        # Like Concatenate, the default output key is the module name (the
        # input key holds atom-wise values and is left untouched).
        output_key = self.name if self.output_key is None else self.output_key
        return {**inputs, output_key: xij}
190
191
class ScatterSystem(nn.Module):
    """Reduce an atom-wise array to a system-wise array by summing over atoms (in the batch).

    FID: SCATTER_SYSTEM

    Atoms are grouped by ``inputs["batch_index"]``; the number of systems is
    ``inputs["natoms"].shape[0]``.
    """

    key: str
    """The key of the input atom-wise array."""
    output_key: Optional[str] = None
    """The key of the output system-wise array. If None, the input key is used."""
    average: bool = False
    """Whether to divide by the number of atoms in the system."""

    FID: ClassVar[str] = "SCATTER_SYSTEM"

    @nn.compact
    def __call__(self, inputs) -> Any:
        batch_index = inputs["batch_index"]
        values = inputs[self.key]
        n_atoms_total = batch_index.shape[0]
        assert (
            values.shape[0] == n_atoms_total
        ), f"Shape mismatch {values.shape[0]} != {batch_index.shape[0]}"

        natoms = inputs["natoms"]
        n_systems = natoms.shape[0]
        if self.average:
            # Divide each atom's contribution by the size of its own system,
            # broadcasting over any trailing feature axes.
            broadcast_shape = (n_atoms_total,) + (1,) * (values.ndim - 1)
            values = values / natoms[batch_index].reshape(broadcast_shape)

        per_system = jax.ops.segment_sum(values, batch_index, n_systems)
        target = self.output_key if self.output_key is not None else self.key
        return {**inputs, target: per_system}
223
224
class SystemToAtoms(nn.Module):
    """Broadcast a system-wise array to an atom-wise array.

    FID: SYSTEM_TO_ATOMS

    Atom ``i`` receives the row of the system ``inputs["batch_index"][i]``.
    """

    key: str
    """The key of the input system-wise array."""
    output_key: Optional[str] = None
    """The key of the output atom-wise array. If None, the input key is used."""

    FID: ClassVar[str] = "SYSTEM_TO_ATOMS"

    @nn.compact
    def __call__(self, inputs) -> Any:
        batch_index = inputs["batch_index"]
        per_system = inputs[self.key]
        target = self.output_key if self.output_key is not None else self.key
        return {**inputs, target: per_system[batch_index]}
246
247
class SumAxis(nn.Module):
    """Sum an array along an axis.

    FID: SUM_AXIS
    """

    key: str
    """The key of the input array."""
    axis: Union[None, int, Sequence[int]] = None
    """The axis (or axes) along which to sum the array. If None, sum over all axes."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the input key is used."""
    norm: Optional[str] = None
    """Normalization of the sum. Can be 'dim' (divide by the number of summed
    elements), 'sqrt' (divide by its square root), or 'none'."""

    FID: ClassVar[str] = "SUM_AXIS"

    @nn.compact
    def __call__(self, inputs) -> Any:
        x = inputs[self.key]
        # jnp.sum expects None, an int, or a tuple of ints for `axis`.
        if self.axis is None or isinstance(self.axis, int):
            axis = self.axis
        else:
            axis = tuple(self.axis)
        output = jnp.sum(x, axis=axis)

        if self.norm is not None:
            norm = self.norm.lower()
            if norm in ("dim", "sqrt"):
                # Number of input elements reduced into each output entry.
                # Indexing x.shape with None or a tuple (the previous approach)
                # raised TypeError, so each axis form is handled explicitly.
                if axis is None:
                    dim = np.prod(x.shape)
                elif isinstance(axis, int):
                    dim = np.prod(x.shape[axis])
                else:
                    dim = np.prod([x.shape[a] for a in axis])
                output = output / dim if norm == "dim" else output / dim**0.5
            elif norm != "none":
                raise ValueError(f"Unknown norm {norm}")

        output_key = self.key if self.output_key is None else self.output_key
        return {**inputs, output_key: output}
283
284
class Split(nn.Module):
    """Split an array along an axis.

    FID: SPLIT

    `sizes` may be a single int (every output gets that size), a sequence with
    one size per output (which must add up to the axis length), or a sequence
    with one size fewer than outputs (the last output takes the remainder).
    """

    key: str
    """The key of the input array."""
    output_keys: Sequence[str]
    """The keys of the output arrays."""
    axis: int = -1
    """The axis along which to split the array."""
    sizes: Union[int, Sequence[int]] = 1
    """The sizes of the splits."""
    squeeze: bool = True
    """Whether to remove the axis in the output if the size is 1."""

    FID: ClassVar[str] = "SPLIT"

    @nn.compact
    def __call__(self, inputs) -> Any:
        x = inputs[self.key]
        n_out = len(self.output_keys)

        sizes = [self.sizes] * n_out if isinstance(self.sizes, int) else self.sizes
        if len(sizes) == n_out:
            # A full list must cover the axis exactly; the last size is then
            # implied and dropped, since jnp.split only needs the split points.
            assert (
                sum(sizes) == x.shape[self.axis]
            ), f"Split sizes {sizes} do not match input shape"
            sizes = sizes[:-1]
        assert (
            len(sizes) == n_out - 1
        ), f"Wrong number of split sizes {sizes} for {len(self.output_keys)} outputs"

        pieces = jnp.split(x, np.cumsum(sizes), axis=self.axis)
        outs = {}
        for name, piece in zip(self.output_keys, pieces):
            if self.squeeze and piece.shape[self.axis] == 1:
                piece = jnp.squeeze(piece, axis=self.axis)
            outs[name] = piece

        return {**inputs, **outs}
331
332
class Concatenate(nn.Module):
    """Concatenate a list of arrays along an axis.

    FID: CONCATENATE

    Arrays are taken from `inputs` in the order given by `keys`.
    """

    keys: Sequence[str]
    """The keys of the input arrays."""
    axis: int = -1
    """The axis along which to concatenate the arrays."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the name of the module is used."""

    FID: ClassVar[str] = "CONCATENATE"

    @nn.compact
    def __call__(self, inputs) -> Any:
        arrays = [inputs[name] for name in self.keys]
        joined = jnp.concatenate(arrays, axis=self.axis)
        target = self.name if self.output_key is None else self.output_key
        return {**inputs, target: joined}
353
354
class Activation(nn.Module):
    """Apply an element-wise activation function to an array.

    FID: ACTIVATION

    Computes ``scale_out * activation(x) + shift_out``. A string activation is
    resolved with `activation_from_str`.
    """

    key: str
    """The key of the input array."""
    activation: Union[Callable, str]
    """The activation function or its name."""
    scale_out: float = 1.0
    """Output scaling factor."""
    shift_out: float = 0.0
    """Output shift."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the input key is used."""

    FID: ClassVar[str] = "ACTIVATION"

    @nn.compact
    def __call__(self, inputs) -> Any:
        x = inputs[self.key]
        if isinstance(self.activation, str):
            fn = activation_from_str(self.activation)
        else:
            fn = self.activation

        result = self.scale_out * fn(x) + self.shift_out
        target = self.key if self.output_key is None else self.output_key
        return {**inputs, target: result}
385
386
class Scale(nn.Module):
    """Scale an array by a constant factor.

    FID: SCALE

    When `trainable` is set, the factor is a parameter named "scale"
    initialized to `scale`.
    """

    key: str
    """The key of the input array."""
    scale: float
    """The (initial) scaling factor."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the input key is used."""
    trainable: bool = False
    """Whether the scaling factor is trainable."""

    FID: ClassVar[str] = "SCALE"

    @nn.compact
    def __call__(self, inputs) -> Any:
        x = inputs[self.key]

        factor = (
            self.param("scale", lambda rng: jnp.asarray(self.scale))
            if self.trainable
            else self.scale
        )

        target = self.key if self.output_key is None else self.output_key
        return {**inputs, target: factor * x}
416
417
class Add(nn.Module):
    """Add together a list of arrays.

    FID: ADD

    The arrays are summed left to right in the order of `keys`, starting from 0
    (so an empty `keys` yields 0).
    """

    keys: Sequence[str]
    """The keys of the input arrays."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the name of the module is used."""

    FID: ClassVar[str] = "ADD"

    @nn.compact
    def __call__(self, inputs) -> Any:
        total = 0
        for value in (inputs[name] for name in self.keys):
            total = total + value

        target = self.name if self.output_key is None else self.output_key
        return {**inputs, target: total}
439
440
class Multiply(nn.Module):
    """Element-wise-multiply together a list of arrays.

    FID: MULTIPLY

    The arrays are multiplied left to right in the order of `keys`, starting
    from 1 (so an empty `keys` yields 1).
    """

    keys: Sequence[str]
    """The keys of the input arrays."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the name of the module is used."""

    FID: ClassVar[str] = "MULTIPLY"

    @nn.compact
    def __call__(self, inputs) -> Any:
        product = 1
        for value in (inputs[name] for name in self.keys):
            product = product * value

        target = self.name if self.output_key is None else self.output_key
        return {**inputs, target: product}
462
463
class Transpose(nn.Module):
    """Transpose an array.

    FID: TRANSPOSE
    """

    key: str
    """The key of the input array."""
    axes: Sequence[int]
    """The permutation of the axes. See `jax.numpy.transpose` for more details."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the input key is used."""

    FID: ClassVar[str] = "TRANSPOSE"

    @nn.compact
    def __call__(self, inputs) -> Any:
        permuted = jnp.transpose(inputs[self.key], axes=self.axes)
        target = self.key if self.output_key is None else self.output_key
        return {**inputs, target: permuted}
484
485
class Reshape(nn.Module):
    """Reshape an array.

    FID: RESHAPE
    """

    key: str
    """The key of the input array."""
    shape: Sequence[int]
    """The shape of the output array."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the input key is used."""

    FID: ClassVar[str] = "RESHAPE"

    @nn.compact
    def __call__(self, inputs) -> Any:
        reshaped = jnp.reshape(inputs[self.key], self.shape)
        target = self.key if self.output_key is None else self.output_key
        return {**inputs, target: reshaped}
506
507
class ChemicalConstant(nn.Module):
    """Map atomic species to a constant value.

    FID: CHEMICAL_CONSTANT

    `value` may be:

    - a string: name of a table in `CHEMICAL_PROPERTIES` (looked up after
      upper-casing);
    - a list or tuple: one value per entry of the periodic table;
    - an int or float: the same value for every element;
    - a mapping: keys looked up in `PERIODIC_TABLE_REV_IDX`, missing elements
      default to 0.0.

    The output is the table indexed by ``inputs["species"]``.
    """

    value: Union[str, List[float], float, Dict]
    """The constant value or a dictionary of values for each element."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the name of the module is used."""
    trainable: bool = False
    """Whether the constant is trainable."""

    FID: ClassVar[str] = "CHEMICAL_CONSTANT"

    @nn.compact
    def __call__(self, inputs) -> Any:
        value = self.value
        if isinstance(value, str):
            constant = CHEMICAL_PROPERTIES[value.upper()]
        elif isinstance(value, (list, tuple)):
            # Tuples are accepted too: Flax may freeze list attributes into
            # tuples, which previously fell through to "Unknown constant type".
            constant = value
        elif isinstance(value, (int, float)) and not isinstance(value, bool):
            # Plain ints (e.g. value=1) are accepted as scalars as well.
            constant = [float(value)] * len(PERIODIC_TABLE)
        elif hasattr(value, "items"):
            constant = [0.0] * len(PERIODIC_TABLE)
            for k, v in value.items():
                constant[PERIODIC_TABLE_REV_IDX[k]] = v
        else:
            raise ValueError(f"Unknown constant type {type(self.value)}")

        if self.trainable:
            constant = self.param(
                "constant", lambda rng: jnp.asarray(constant, dtype=jnp.float32)
            )
        else:
            constant = jnp.asarray(constant, dtype=jnp.float32)
        output = constant[inputs["species"]]
        output_key = self.output_key if self.output_key is not None else self.name
        return {**inputs, output_key: output}
547
548
class SwitchFunction(nn.Module):
    """Compute a switch array from an array of distances and a cutoff.

    FID: SWITCH_FUNCTION

    With a graph (`graph_key` not None), distances and edge mask are read from
    ``inputs[graph_key]`` and the switch is stored in the output dict. Without
    a graph, `inputs` must be a ``(distances, edge_mask)`` pair, `cutoff` must
    be set, and the raw switch array is returned.

    The switch is forced to 1 below ``switch_start * cutoff`` (when
    ``switch_start > 1e-5``) and to 0 wherever the edge mask is False.
    """

    cutoff: Optional[float] = None
    """The cutoff distance. If None, the cutoff is taken from the graph."""
    switch_start: float = 0.0
    """The proportion of the cutoff distance at which the switch function starts."""
    graph_key: Optional[str] = "graph"
    """The key of the graph containing the distances and edge mask."""
    output_key: Optional[str] = None
    """The key of the output switch array. If None, it is added to the graph."""
    switch_type: str = "cosine"
    """The type of switch function. Can be 'cosine', 'polynomial', 'exponential', or 'hard'."""
    p: Optional[float] = None
    """ The parameter of the switch function. If None, it is fixed to the default for each `switch_type`
    (1.0 for 'cosine', 3.0 for 'polynomial', 1.0 for 'exponential'; unused for 'hard')."""
    trainable: bool = False
    """Whether the switch parameter is trainable."""

    FID: ClassVar[str] = "SWITCH_FUNCTION"

    @nn.compact
    def __call__(self, inputs) -> Any:
        if self.graph_key is not None:
            graph = inputs[self.graph_key]
            distances, edge_mask = graph["distances"], graph["edge_mask"]
            if self.cutoff is not None:
                edge_mask = jnp.logical_and(edge_mask, (distances < self.cutoff))
                cutoff = self.cutoff
            else:
                cutoff = graph["cutoff"]
        else:
            distances, edge_mask = inputs
            assert (
                self.cutoff is not None
            ), "cutoff must be specified if no graph is given"
            cutoff = self.cutoff

        # Optionally shift the switching region to [switch_start*cutoff, cutoff].
        if self.switch_start > 1.0e-5:
            assert (
                self.switch_start < 1.0
            ), "switch_start is a proportion of cutoff and must be smaller than 1."
            cutoff_in = self.switch_start * cutoff
            x = distances - cutoff_in
            end = cutoff - cutoff_in
        else:
            x = distances
            end = cutoff

        switch_type = self.switch_type.lower()
        if switch_type == "cosine":
            p = self.p if self.p is not None else 1.0
            if self.trainable:
                p = self.param("p", lambda rng: jnp.asarray(p, dtype=jnp.float32))
            switch = (0.5 * jnp.cos(x * (jnp.pi / end)) + 0.5) ** p

        elif switch_type == "polynomial":
            p = self.p if self.p is not None else 3.0
            if self.trainable:
                p = self.param("p", lambda rng: jnp.asarray(p, dtype=jnp.float32))
            d = x / end
            switch = (
                1.0
                - 0.5 * (p + 1) * (p + 2) * d**p
                + p * (p + 2) * d ** (p + 1)
                - 0.5 * p * (p + 1) * d ** (p + 2)
            )

        elif switch_type == "exponential":
            p = self.p if self.p is not None else 1.0
            if self.trainable:
                p = self.param("p", lambda rng: jnp.asarray(p, dtype=jnp.float32))
            r2 = x**2
            c2 = end**2
            # exp(-p r^2 / (c^2 - r^2)) divides by zero at r == c and blows up
            # beyond it. Masking the result afterwards is not enough: the
            # inf/NaN still poisons gradients. Evaluate with a safe dummy
            # argument outside the cutoff and select 0 there instead.
            inside = r2 < c2
            r2_safe = jnp.where(inside, r2, 0.0)
            switch = jnp.where(inside, jnp.exp(-p * r2_safe / (c2 - r2_safe)), 0.0)

        elif switch_type == "hard":
            switch = jnp.where(distances < cutoff, 1.0, 0.0)
        else:
            raise ValueError(f"Unknown switch function {switch_type}")

        if self.switch_start > 1.0e-5:
            switch = jnp.where(distances < cutoff_in, 1.0, switch)

        switch = jnp.where(edge_mask, switch, 0.0)

        if self.graph_key is not None:
            if self.output_key is not None:
                return {**inputs, self.output_key: switch}
            else:
                return {**inputs, self.graph_key: {**graph, "switch": switch}}
        else:
            return switch
class ApplySwitch(flax.linen.module.Module):
class ApplySwitch(nn.Module):
    """Multiply an edge array by a switch array.

    FID: APPLY_SWITCH

    When `graph_key` is set, the switch is read from
    ``inputs[graph_key]["switch"]`` and `switch_key` is ignored; otherwise it
    is read from ``inputs[switch_key]``.
    """

    key: str
    """The key of the input array."""
    switch_key: Optional[str] = None
    """The key of the switch array (used only when `graph_key` is None)."""
    graph_key: Optional[str] = None
    """The key of the graph containing the switch. Takes precedence over `switch_key`."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the input key is used."""

    FID: ClassVar[str] = "APPLY_SWITCH"

    @nn.compact
    def __call__(self, inputs) -> Any:
        # Guard clause: at least one switch source must be configured.
        if self.graph_key is None and self.switch_key is None:
            raise ValueError("Either graph_key or switch_key must be specified")

        if self.graph_key is not None:
            switch = inputs[self.graph_key]["switch"]
        else:
            switch = inputs[self.switch_key]

        switched = apply_switch(inputs[self.key], switch)
        target = self.output_key if self.output_key is not None else self.key
        return {**inputs, target: switched}

Multiply an edge array by a switch array.

FID: APPLY_SWITCH

ApplySwitch( key: str, switch_key: Optional[str] = None, graph_key: Optional[str] = None, output_key: Optional[str] = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
key: str

The key of the input array.

switch_key: Optional[str] = None

The key of the switch array.

graph_key: Optional[str] = None

The key of the graph containing the switch.

output_key: Optional[str] = None

The key of the output array. If None, the input key is used.

FID: ClassVar[str] = 'APPLY_SWITCH'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class AtomToEdge(flax.linen.module.Module):
 57class AtomToEdge(nn.Module):
 58    """Map atom-wise values to edge-wise values.
 59
 60    FID: ATOM_TO_EDGE
 61
 62    By default, we map the destination atom value to the edge. This can be changed by setting `use_source` to True.
 63    """
 64
 65    _graphs_properties: Dict
 66    key: str
 67    """The key of the input atom-wise array."""
 68    output_key: Optional[str] = None
 69    """The key of the output edge-wise array. If None, the input key is used."""
 70    graph_key: str = "graph"
 71    """The key of the graph containing the edges."""
 72    switch: bool = False
 73    """Whether to apply a switch to the edge values."""
 74    switch_key: Optional[str] = None
 75    """The key of the switch array. If None, the switch is taken from the graph."""
 76    use_source: bool = False
 77    """Whether to use the source atom value instead of the destination atom value."""
 78
 79    FID: ClassVar[str] = "ATOM_TO_EDGE"
 80
 81    @nn.compact
 82    def __call__(self, inputs) -> Any:
 83        graph = inputs[self.graph_key]
 84        nat = inputs["species"].shape[0]
 85        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
 86
 87        x = inputs[self.key]
 88        if self.use_source:
 89            x_edge = x[edge_src]
 90        else:
 91            x_edge = x[edge_dst]
 92
 93        if self.switch:
 94            switch = (
 95                graph["switch"] if self.switch_key is None else inputs[self.switch_key]
 96            )
 97            x_edge = apply_switch(x_edge, switch)
 98
 99        output_key = self.key if self.output_key is None else self.output_key
100        return {**inputs, output_key: x_edge}

Map atom-wise values to edge-wise values.

FID: ATOM_TO_EDGE

By default, we map the destination atom value to the edge. This can be changed by setting use_source to True.

AtomToEdge( _graphs_properties: Dict, key: str, output_key: Optional[str] = None, graph_key: str = 'graph', switch: bool = False, switch_key: Optional[str] = None, use_source: bool = False, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
key: str

The key of the input atom-wise array.

output_key: Optional[str] = None

The key of the output edge-wise array. If None, the input key is used.

graph_key: str = 'graph'

The key of the graph containing the edges.

switch: bool = False

Whether to apply a switch to the edge values.

switch_key: Optional[str] = None

The key of the switch array. If None, the switch is taken from the graph.

use_source: bool = False

Whether to use the source atom value instead of the destination atom value.

FID: ClassVar[str] = 'ATOM_TO_EDGE'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class ScatterEdges(flax.linen.module.Module):
class ScatterEdges(nn.Module):
    """Reduce an edge array to atoms by summing over neighbors.

    FID: SCATTER_EDGES

    Each edge value is accumulated onto its source atom (``edge_src``). With
    `antisymmetric` set, the values accumulated onto destination atoms
    (``edge_dst``) are subtracted from that sum.
    """

    _graphs_properties: Dict
    key: str
    """The key of the input edge-wise array."""
    output_key: Optional[str] = None
    """The key of the output atom-wise array. If None, the input key is used."""
    graph_key: str = "graph"
    """The key of the graph containing the edges."""
    switch: bool = False
    """Whether to apply a switch to the edge values before summing."""
    switch_key: Optional[str] = None
    """The key of the switch array. If None, the switch is taken from the graph."""
    antisymmetric: bool = False
    """Whether to subtract the sum over destination atoms from the sum over source atoms."""

    FID: ClassVar[str] = "SCATTER_EDGES"

    @nn.compact
    def __call__(self, inputs) -> Any:
        graph = inputs[self.graph_key]
        num_atoms = inputs["species"].shape[0]
        edge_values = inputs[self.key]

        if self.switch:
            if self.switch_key is None:
                switch = graph["switch"]
            else:
                switch = inputs[self.switch_key]
            edge_values = apply_switch(edge_values, switch)

        src = graph["edge_src"]
        dst = graph["edge_dst"]
        reduced = jax.ops.segment_sum(edge_values, src, num_atoms)
        if self.antisymmetric:
            reduced = reduced - jax.ops.segment_sum(edge_values, dst, num_atoms)

        target = self.key if self.output_key is None else self.output_key
        return {**inputs, target: reduced}

Reduce an edge array to atoms by summing over neighbors.

FID: SCATTER_EDGES

ScatterEdges( _graphs_properties: Dict, key: str, output_key: Optional[str] = None, graph_key: str = 'graph', switch: bool = False, switch_key: Optional[str] = None, antisymmetric: bool = False, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
key: str

The key of the input edge-wise array.

output_key: Optional[str] = None

The key of the output atom-wise array. If None, the input key is used.

graph_key: str = 'graph'

The key of the graph containing the edges.

switch: bool = False

Whether to apply a switch to the edge values before summing.

switch_key: Optional[str] = None

The key of the switch array. If None, the switch is taken from the graph.

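antisymmetric: bool = False

If True, the edge values summed onto their destination atoms are subtracted from the values summed onto their source atoms. Each atom therefore receives the sum over the edges it starts minus the sum over the edges it ends.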
FID: ClassVar[str] = 'SCATTER_EDGES'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class EdgeConcatenate(flax.linen.module.Module):
class EdgeConcatenate(nn.Module):
    """Concatenate the source and destination atom values of an edge.

    For every edge, the atom-wise values of its source atom (``x[edge_src]``)
    and of its destination atom (``x[edge_dst]``) are concatenated along
    `axis`, producing an edge-wise array. Optionally, the result is multiplied
    by an edge switch.

    FID: EDGE_CONCATENATE
    """

    # Properties of the available graphs, indexed by graph key. Only the
    # "directed" flag of `graph_key` is read; non-directed graphs are rejected.
    _graphs_properties: Dict
    key: str
    """The key of the input atom-wise array."""
    output_key: Optional[str] = None
    """The key of the output edge-wise array. If None, the name of the module is used."""
    graph_key: str = "graph"
    """The key of the graph containing the edges."""
    switch: bool = False
    """Whether to apply a switch to the edge values."""
    switch_key: Optional[str] = None
    """The key of the switch array. If None, the switch is taken from the graph."""
    axis: int = -1
    """The axis along which to concatenate the atom values."""

    FID: ClassVar[str] = "EDGE_CONCATENATE"

    @nn.compact
    def __call__(self, inputs) -> Any:
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        nat = inputs["species"].shape[0]
        xi = inputs[self.key]

        assert self._graphs_properties[self.graph_key][
            "directed"
        ], "EdgeConcatenate only works for directed graphs"
        assert xi.shape[0] == nat, "Shape mismatch, xi.shape[0] != nat"

        # Gather the values of both endpoints of every edge and join them.
        xij = jnp.concatenate([xi[edge_src], xi[edge_dst]], axis=self.axis)

        if self.switch:
            switch = (
                graph["switch"] if self.switch_key is None else inputs[self.switch_key]
            )
            xij = apply_switch(xij, switch)

        # The result is edge-wise, so it is not stored under the atom-wise input
        # key by default; the module name is used instead.
        output_key = self.name if self.output_key is None else self.output_key
        return {**inputs, output_key: xij}

Concatenate the source and destination atom values of an edge.

FID: EDGE_CONCATENATE

EdgeConcatenate( _graphs_properties: Dict, key: str, output_key: Optional[str] = None, graph_key: str = 'graph', switch: bool = False, switch_key: Optional[str] = None, axis: int = -1, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
key: str

The key of the input atom-wise array.

output_key: Optional[str] = None

The key of the output edge-wise array. If None, the name of the module is used.

graph_key: str = 'graph'

The key of the graph containing the edges.

switch: bool = False

Whether to apply a switch to the edge values.

switch_key: Optional[str] = None

The key of the switch array. If None, the switch is taken from the graph.

axis: int = -1

The axis along which to concatenate the atom values.

FID: ClassVar[str] = 'EDGE_CONCATENATE'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class ScatterSystem(flax.linen.module.Module):
class ScatterSystem(nn.Module):
    """Sum an atom-wise array over the atoms of each system in the batch.

    Atoms are assigned to systems through ``inputs["batch_index"]``; the number
    of systems is the length of ``inputs["natoms"]``.

    FID: SCATTER_SYSTEM
    """

    key: str
    """The key of the input atom-wise array."""
    output_key: Optional[str] = None
    """The key of the output system-wise array. If None, the input key is used."""
    average: bool = False
    """Whether to divide by the number of atoms in the system."""

    FID: ClassVar[str] = "SCATTER_SYSTEM"

    @nn.compact
    def __call__(self, inputs) -> Any:
        system_of_atom = inputs["batch_index"]
        values = inputs[self.key]
        n_atoms = system_of_atom.shape[0]
        assert (
            values.shape[0] == n_atoms
        ), f"Shape mismatch {values.shape[0]} != {n_atoms}"
        natoms = inputs["natoms"]
        n_systems = natoms.shape[0]

        if self.average:
            # Broadcast each atom's system size over the trailing axes.
            count_shape = (n_atoms,) + (1,) * (values.ndim - 1)
            values = values / natoms[system_of_atom].reshape(count_shape)

        summed = jax.ops.segment_sum(values, system_of_atom, n_systems)

        target = self.key if self.output_key is None else self.output_key
        return {**inputs, target: summed}

Reduce an atom-wise array to a system-wise array by summing over atoms (in the batch).

FID: SCATTER_SYSTEM

ScatterSystem( key: str, output_key: Optional[str] = None, average: bool = False, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
key: str

The key of the input atom-wise array.

output_key: Optional[str] = None

The key of the output system-wise array. If None, the input key is used.

average: bool = False

Whether to divide by the number of atoms in the system.

FID: ClassVar[str] = 'SCATTER_SYSTEM'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class SystemToAtoms(flax.linen.module.Module):
class SystemToAtoms(nn.Module):
    """Broadcast a system-wise array to an atom-wise array.

    Each atom receives the entry of the system it belongs to, as given by
    ``inputs["batch_index"]``.

    FID: SYSTEM_TO_ATOMS
    """

    key: str
    """The key of the input system-wise array."""
    output_key: Optional[str] = None
    """The key of the output atom-wise array. If None, the input key is used."""

    FID: ClassVar[str] = "SYSTEM_TO_ATOMS"

    @nn.compact
    def __call__(self, inputs) -> Any:
        system_of_atom = inputs["batch_index"]
        per_system = inputs[self.key]
        target = self.output_key if self.output_key is not None else self.key
        return {**inputs, target: per_system[system_of_atom]}

Broadcast a system-wise array to an atom-wise array.

FID: SYSTEM_TO_ATOMS

SystemToAtoms( key: str, output_key: Optional[str] = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
key: str

The key of the input system-wise array.

output_key: Optional[str] = None

The key of the output atom-wise array. If None, the input key is used.

FID: ClassVar[str] = 'SYSTEM_TO_ATOMS'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class SumAxis(flax.linen.module.Module):
class SumAxis(nn.Module):
    """Sum an array along an axis.

    Optionally, the sum is normalized by the number of summed elements
    ('dim') or by its square root ('sqrt').

    FID: SUM_AXIS
    """

    key: str
    """The key of the input array."""
    axis: Union[None, int, Sequence[int]] = None
    """The axis along which to sum the array. None sums over all axes."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the input key is used."""
    norm: Optional[str] = None
    """Normalization of the sum. Can be 'dim', 'sqrt', or 'none'."""

    FID: ClassVar[str] = "SUM_AXIS"

    @nn.compact
    def __call__(self, inputs) -> Any:
        x = inputs[self.key]
        output = jnp.sum(x, axis=self.axis)
        if self.norm is not None:
            norm = self.norm.lower()
            if norm in ("dim", "sqrt"):
                # Number of elements summed together. Indexing `x.shape` with
                # `self.axis` only works for an int axis, so None (all axes)
                # and sequences of axes are handled explicitly.
                if self.axis is None:
                    summed_dims = tuple(x.shape)
                elif isinstance(self.axis, (int, np.integer)):
                    summed_dims = (x.shape[self.axis],)
                else:
                    summed_dims = tuple(x.shape[a] for a in self.axis)
                dim = np.prod(summed_dims)
                if norm == "dim":
                    output = output / dim
                else:
                    output = output / dim**0.5
            elif norm == "none":
                pass
            else:
                raise ValueError(f"Unknown norm {norm}")
        output_key = self.key if self.output_key is None else self.output_key
        return {**inputs, output_key: output}

Sum an array along an axis.

FID: SUM_AXIS

SumAxis( key: str, axis: Union[NoneType, int, Sequence[int]] = None, output_key: Optional[str] = None, norm: Optional[str] = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
key: str

The key of the input array.

axis: Union[NoneType, int, Sequence[int]] = None

The axis along which to sum the array.

output_key: Optional[str] = None

The key of the output array. If None, the input key is used.

norm: Optional[str] = None

Normalization of the sum. Can be 'dim', 'sqrt', or 'none'.

FID: ClassVar[str] = 'SUM_AXIS'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class Split(flax.linen.module.Module):
class Split(nn.Module):
    """Split an array along an axis into several named arrays.

    `sizes` is either a single size used for every output, or one size per
    output (in which case they must add up to the axis length), or one size
    per output except the last (which takes the remainder).

    FID: SPLIT
    """

    key: str
    """The key of the input array."""
    output_keys: Sequence[str]
    """The keys of the output arrays."""
    axis: int = -1
    """The axis along which to split the array."""
    sizes: Union[int, Sequence[int]] = 1
    """The sizes of the splits."""
    squeeze: bool = True
    """Whether to remove the axis in the output if the size is 1."""

    FID: ClassVar[str] = "SPLIT"

    @nn.compact
    def __call__(self, inputs) -> Any:
        x = inputs[self.key]
        n_out = len(self.output_keys)

        sizes = [self.sizes] * n_out if isinstance(self.sizes, int) else self.sizes
        if len(sizes) == n_out:
            # A full list of sizes must exactly cover the axis; the last one is
            # then implied by the split points and is dropped.
            assert (
                sum(sizes) == x.shape[self.axis]
            ), f"Split sizes {sizes} do not match input shape"
            sizes = sizes[:-1]
        assert (
            len(sizes) == n_out - 1
        ), f"Wrong number of split sizes {sizes} for {n_out} outputs"

        pieces = jnp.split(x, np.cumsum(sizes), axis=self.axis)
        outs = {}
        for out_key, piece in zip(self.output_keys, pieces):
            if self.squeeze and piece.shape[self.axis] == 1:
                piece = jnp.squeeze(piece, axis=self.axis)
            outs[out_key] = piece

        return {**inputs, **outs}

Split an array along an axis.

FID: SPLIT

Split( key: str, output_keys: Sequence[str], axis: int = -1, sizes: Union[int, Sequence[int]] = 1, squeeze: bool = True, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
key: str

The key of the input array.

output_keys: Sequence[str]

The keys of the output arrays.

axis: int = -1

The axis along which to split the array.

sizes: Union[int, Sequence[int]] = 1

The sizes of the splits.

squeeze: bool = True

Whether to remove the axis in the output if the size is 1.

FID: ClassVar[str] = 'SPLIT'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class Concatenate(flax.linen.module.Module):
class Concatenate(nn.Module):
    """Concatenate several arrays, given by their keys, along one axis.

    FID: CONCATENATE
    """

    keys: Sequence[str]
    """The keys of the input arrays."""
    axis: int = -1
    """The axis along which to concatenate the arrays."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the name of the module is used."""

    FID: ClassVar[str] = "CONCATENATE"

    @nn.compact
    def __call__(self, inputs) -> Any:
        arrays = [inputs[name] for name in self.keys]
        joined = jnp.concatenate(arrays, axis=self.axis)
        target = self.name if self.output_key is None else self.output_key
        return {**inputs, target: joined}

Concatenate a list of arrays along an axis.

FID: CONCATENATE

Concatenate( keys: Sequence[str], axis: int = -1, output_key: Optional[str] = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
keys: Sequence[str]

The keys of the input arrays.

axis: int = -1

The axis along which to concatenate the arrays.

output_key: Optional[str] = None

The key of the output array. If None, the name of the module is used.

FID: ClassVar[str] = 'CONCATENATE'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class Activation(flax.linen.module.Module):
class Activation(nn.Module):
    """Apply an element-wise activation function to an array.

    The result is ``scale_out * activation(x) + shift_out``.

    FID: ACTIVATION
    """

    key: str
    """The key of the input array."""
    activation: Union[Callable, str]
    """The activation function or its name."""
    scale_out: float = 1.0
    """Output scaling factor."""
    shift_out: float = 0.0
    """Output shift."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the input key is used."""

    FID: ClassVar[str] = "ACTIVATION"

    @nn.compact
    def __call__(self, inputs) -> Any:
        x = inputs[self.key]
        fn = self.activation
        if isinstance(fn, str):
            # Names are resolved through the project's activation registry.
            fn = activation_from_str(fn)
        target = self.key if self.output_key is None else self.output_key
        return {**inputs, target: self.scale_out * fn(x) + self.shift_out}

Apply an element-wise activation function to an array.

FID: ACTIVATION

Activation( key: str, activation: Union[Callable, str], scale_out: float = 1.0, shift_out: float = 0.0, output_key: Optional[str] = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
key: str

The key of the input array.

activation: Union[Callable, str]

The activation function or its name.

scale_out: float = 1.0

Output scaling factor.

shift_out: float = 0.0

Output shift.

output_key: Optional[str] = None

The key of the output array. If None, the input key is used.

FID: ClassVar[str] = 'ACTIVATION'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class Scale(flax.linen.module.Module):
class Scale(nn.Module):
    """Multiply an array by a constant (optionally trainable) factor.

    FID: SCALE
    """

    key: str
    """The key of the input array."""
    scale: float
    """The (initial) scaling factor."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the input key is used."""
    trainable: bool = False
    """Whether the scaling factor is trainable."""

    FID: ClassVar[str] = "SCALE"

    @nn.compact
    def __call__(self, inputs) -> Any:
        x = inputs[self.key]
        # When trainable, the factor is a "scale" parameter initialized to `self.scale`.
        factor = (
            self.param("scale", lambda rng: jnp.asarray(self.scale))
            if self.trainable
            else self.scale
        )
        target = self.key if self.output_key is None else self.output_key
        return {**inputs, target: factor * x}

Scale an array by a constant factor.

FID: SCALE

Scale( key: str, scale: float, output_key: Optional[str] = None, trainable: bool = False, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
key: str

The key of the input array.

scale: float

The (initial) scaling factor.

output_key: Optional[str] = None

The key of the output array. If None, the input key is used.

trainable: bool = False

Whether the scaling factor is trainable.

FID: ClassVar[str] = 'SCALE'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class Add(flax.linen.module.Module):
class Add(nn.Module):
    """Sum together the arrays stored under several keys.

    FID: ADD
    """

    keys: Sequence[str]
    """The keys of the input arrays."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the name of the module is used."""

    FID: ClassVar[str] = "ADD"

    @nn.compact
    def __call__(self, inputs) -> Any:
        # Left fold starting from 0, so broadcasting follows the key order.
        total = sum((inputs[name] for name in self.keys), 0)
        target = self.name if self.output_key is None else self.output_key
        return {**inputs, target: total}

Add together a list of arrays.

FID: ADD

Add( keys: Sequence[str], output_key: Optional[str] = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
keys: Sequence[str]

The keys of the input arrays.

output_key: Optional[str] = None

The key of the output array. If None, the name of the module is used.

FID: ClassVar[str] = 'ADD'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class Multiply(flax.linen.module.Module):
class Multiply(nn.Module):
    """Multiply together, element-wise, the arrays stored under several keys.

    FID: MULTIPLY
    """

    keys: Sequence[str]
    """The keys of the input arrays."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the name of the module is used."""

    FID: ClassVar[str] = "MULTIPLY"

    @nn.compact
    def __call__(self, inputs) -> Any:
        product = 1
        for factor in (inputs[name] for name in self.keys):
            product = product * factor
        target = self.name if self.output_key is None else self.output_key
        return {**inputs, target: product}

Element-wise-multiply together a list of arrays.

FID: MULTIPLY

Multiply( keys: Sequence[str], output_key: Optional[str] = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
keys: Sequence[str]

The keys of the input arrays.

output_key: Optional[str] = None

The key of the output array. If None, the name of the module is used.

FID: ClassVar[str] = 'MULTIPLY'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class Transpose(flax.linen.module.Module):
class Transpose(nn.Module):
    """Permute the axes of an array.

    FID: TRANSPOSE
    """

    key: str
    """The key of the input array."""
    axes: Sequence[int]
    """The permutation of the axes. See `jax.numpy.transpose` for more details."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the input key is used."""

    FID: ClassVar[str] = "TRANSPOSE"

    @nn.compact
    def __call__(self, inputs) -> Any:
        target = self.key if self.output_key is None else self.output_key
        permuted = jnp.transpose(inputs[self.key], axes=self.axes)
        return {**inputs, target: permuted}

Transpose an array.

FID: TRANSPOSE

Transpose( key: str, axes: Sequence[int], output_key: Optional[str] = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
key: str

The key of the input array.

axes: Sequence[int]

The permutation of the axes. See jax.numpy.transpose for more details.

output_key: Optional[str] = None

The key of the output array. If None, the input key is used.

FID: ClassVar[str] = 'TRANSPOSE'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class Reshape(flax.linen.module.Module):
class Reshape(nn.Module):
    """Give an array a new shape.

    FID: RESHAPE
    """

    key: str
    """The key of the input array."""
    shape: Sequence[int]
    """The shape of the output array."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the input key is used."""

    FID: ClassVar[str] = "RESHAPE"

    @nn.compact
    def __call__(self, inputs) -> Any:
        target = self.key if self.output_key is None else self.output_key
        reshaped = jnp.reshape(inputs[self.key], self.shape)
        return {**inputs, target: reshaped}

Reshape an array.

FID: RESHAPE

Reshape( key: str, shape: Sequence[int], output_key: Optional[str] = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
key: str

The key of the input array.

shape: Sequence[int]

The shape of the output array.

output_key: Optional[str] = None

The key of the output array. If None, the input key is used.

FID: ClassVar[str] = 'RESHAPE'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class ChemicalConstant(flax.linen.module.Module):
class ChemicalConstant(nn.Module):
    """Map atomic species to a constant value.

    `value` may be:
      - a string: name of a tabulated chemical property (looked up upper-cased
        in `CHEMICAL_PROPERTIES`);
      - a list or tuple: one value per element, indexed by species;
      - a scalar (int or float): the same value for every element;
      - a dict-like mapping element keys to values: missing elements get 0.0.

    FID: CHEMICAL_CONSTANT
    """

    value: Union[str, List[float], float, Dict]
    """The constant value or a dictionary of values for each element."""
    output_key: Optional[str] = None
    """The key of the output array. If None, the name of the module is used."""
    trainable: bool = False
    """Whether the constant is trainable."""

    FID: ClassVar[str] = "CHEMICAL_CONSTANT"

    @nn.compact
    def __call__(self, inputs) -> Any:
        value = self.value
        if isinstance(value, str):
            constant = CHEMICAL_PROPERTIES[value.upper()]
        elif isinstance(value, (list, tuple)):
            constant = list(value)
        elif isinstance(value, (int, float, np.integer, np.floating)) and not isinstance(
            value, bool
        ):
            # Integer scalars are accepted too; bool is still rejected below.
            constant = [float(value)] * len(PERIODIC_TABLE)
        elif hasattr(value, "items"):
            constant = [0.0] * len(PERIODIC_TABLE)
            for k, v in value.items():
                constant[PERIODIC_TABLE_REV_IDX[k]] = v
        else:
            raise ValueError(f"Unknown constant type {type(value)}")

        if self.trainable:
            constant = self.param(
                "constant", lambda rng: jnp.asarray(constant, dtype=jnp.float32)
            )
        else:
            constant = jnp.asarray(constant, dtype=jnp.float32)
        # One value per atom, selected by its species index.
        output = constant[inputs["species"]]
        output_key = self.output_key if self.output_key is not None else self.name
        return {**inputs, output_key: output}

Map atomic species to a constant value.

FID: CHEMICAL_CONSTANT

ChemicalConstant( value: Union[str, List[float], float, Dict], output_key: Optional[str] = None, trainable: bool = False, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
value: Union[str, List[float], float, Dict]

The constant value or a dictionary of values for each element.

output_key: Optional[str] = None

The key of the output array. If None, the name of the module is used.

trainable: bool = False

Whether the constant is trainable.

FID: ClassVar[str] = 'CHEMICAL_CONSTANT'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class SwitchFunction(flax.linen.module.Module):
class SwitchFunction(nn.Module):
    """Compute a switch array from an array of distances and a cutoff.

    FID: SWITCH_FUNCTION

    The switch equals 1 for distances below ``switch_start * cutoff`` and
    decays to 0 at ``cutoff``. The ``"hard"`` type is instead a step
    function that is 1 inside the cutoff. Entries that are masked out, or
    whose distance is at or beyond ``cutoff``, are set to 0.

    Input conventions:
    - If ``graph_key`` is set, ``inputs[graph_key]`` must provide
      ``"distances"`` and ``"edge_mask"``, plus ``"cutoff"`` when
      ``self.cutoff`` is None. The switch is returned inside the inputs
      dict, either under ``output_key`` or as ``graph["switch"]``.
    - Otherwise ``inputs`` is unpacked as ``(distances, edge_mask)`` and the
      bare switch array is returned.

    Raises:
        ValueError: if no graph is used and ``cutoff`` is None, if
            ``switch_start`` is not smaller than 1, or if ``switch_type``
            is unknown.
    """

    cutoff: Optional[float] = None
    """The cutoff distance. If None, the cutoff is taken from the graph."""
    switch_start: float = 0.0
    """The proportion of the cutoff distance at which the switch function starts."""
    graph_key: Optional[str] = "graph"
    """The key of the graph containing the distances and edge mask."""
    output_key: Optional[str] = None
    """The key of the output switch array. If None, it is added to the graph."""
    switch_type: str = "cosine"
    """The type of switch function. Can be 'cosine', 'polynomial', 'exponential' or 'hard'."""
    p: Optional[float] = None
    """ The parameter of the switch function. If None, it is fixed to the default for each `switch_type`."""
    trainable: bool = False
    """Whether the switch parameter is trainable."""

    FID: ClassVar[str] = "SWITCH_FUNCTION"

    @nn.compact
    def __call__(self, inputs) -> Any:
        if self.graph_key is not None:
            graph = inputs[self.graph_key]
            distances, edge_mask = graph["distances"], graph["edge_mask"]
            cutoff = self.cutoff if self.cutoff is not None else graph["cutoff"]
        else:
            distances, edge_mask = inputs
            if self.cutoff is None:
                raise ValueError("cutoff must be specified if no graph is given")
            cutoff = self.cutoff

        # Edges at or beyond the cutoff never contribute. Previously this was
        # only enforced for the "hard" switch and for an explicit cutoff on a
        # graph; without it, smooth switches take nonzero or diverging values
        # past the cutoff.
        edge_mask = jnp.logical_and(edge_mask, distances < cutoff)

        use_start = self.switch_start > 1.0e-5
        if use_start:
            if self.switch_start >= 1.0:
                raise ValueError(
                    "switch_start is a proportion of cutoff and must be smaller than 1."
                )
            cutoff_in = self.switch_start * cutoff
            x = distances - cutoff_in
            end = cutoff - cutoff_in
        else:
            x = distances
            end = cutoff

        # Keep the switch argument inside its domain [0, end]. Values outside
        # are overwritten by the `where`s below anyway. Evaluating them
        # unclipped produces NaN/inf (e.g. a negative base raised to a float
        # power), which poisons gradients even through `jnp.where`.
        x = jnp.clip(x, 0.0, end)

        switch_type = self.switch_type.lower()
        if switch_type == "cosine":
            p = self.p if self.p is not None else 1.0
            if self.trainable:
                p = self.param("p", lambda rng: jnp.asarray(p, dtype=jnp.float32))
            switch = (0.5 * jnp.cos(x * (jnp.pi / end)) + 0.5) ** p

        elif switch_type == "polynomial":
            p = self.p if self.p is not None else 3.0
            if self.trainable:
                p = self.param("p", lambda rng: jnp.asarray(p, dtype=jnp.float32))
            d = x / end
            switch = (
                1.0
                - 0.5 * (p + 1) * (p + 2) * d**p
                + p * (p + 2) * d ** (p + 1)
                - 0.5 * p * (p + 1) * d ** (p + 2)
            )

        elif switch_type == "exponential":
            p = self.p if self.p is not None else 1.0
            if self.trainable:
                p = self.param("p", lambda rng: jnp.asarray(p, dtype=jnp.float32))
            r2 = x**2
            c2 = end**2
            inside = r2 < c2
            # Double-where: avoid dividing by zero at x == end. That division
            # would give the correct value 0 in the forward pass but NaN
            # gradients.
            safe_den = jnp.where(inside, c2 - r2, 1.0)
            switch = jnp.where(inside, jnp.exp(-p * r2 / safe_den), 0.0)

        elif switch_type == "hard":
            # Step function; `p` and `trainable` are not used.
            switch = jnp.where(distances < cutoff, 1.0, 0.0)
        else:
            raise ValueError(f"Unknown switch function {switch_type}")

        if use_start:
            # Fully on below the start of the switching region.
            switch = jnp.where(distances < cutoff_in, 1.0, switch)

        switch = jnp.where(edge_mask, switch, 0.0)

        if self.graph_key is not None:
            if self.output_key is not None:
                return {**inputs, self.output_key: switch}
            return {**inputs, self.graph_key: {**graph, "switch": switch}}
        return switch

Compute a switch array from an array of distances and a cutoff.

FID: SWITCH_FUNCTION

SwitchFunction( cutoff: Optional[float] = None, switch_start: float = 0.0, graph_key: Optional[str] = 'graph', output_key: Optional[str] = None, switch_type: str = 'cosine', p: Optional[float] = None, trainable: bool = False, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
cutoff: Optional[float] = None

The cutoff distance. If None, the cutoff is taken from the graph.

switch_start: float = 0.0

The proportion of the cutoff distance at which the switch function starts.

graph_key: Optional[str] = 'graph'

The key of the graph containing the distances and edge mask.

output_key: Optional[str] = None

The key of the output switch array. If None, the switch is stored in the graph under the key "switch".

switch_type: str = 'cosine'

The type of switch function. Can be 'cosine', 'polynomial', 'exponential', or 'hard'.

p: Optional[float] = None

The parameter of the switch function. If None, the default for the chosen switch_type is used (1.0 for 'cosine' and 'exponential', 3.0 for 'polynomial'); it is not used by 'hard'.

trainable: bool = False

Whether the switch parameter is trainable.

FID: ClassVar[str] = 'SWITCH_FUNCTION'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying Python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None