fennol.models.misc.misc

  1import flax.linen as nn
  2from typing import Any, Sequence, Callable, Union, ClassVar, Optional, Dict, List
  3import jax.numpy as jnp
  4import jax
  5import numpy as np
  6from functools import partial
  7from ...utils.activations import activation_from_str
  8from ...utils.periodic_table import (
  9    CHEMICAL_PROPERTIES,
 10    PERIODIC_TABLE,
 11    PERIODIC_TABLE_REV_IDX,
 12)
 13
 14
 15def apply_switch(x: jax.Array, switch: jax.Array):
 16    """@private Multiply a switch array to an array of values."""
 17    shape = x.shape
 18    return (
 19        jnp.expand_dims(x, axis=-1).reshape(*switch.shape, -1) * switch[..., None]
 20    ).reshape(shape)
 21
 22
 23class ApplySwitch(nn.Module):
 24    """Multiply an edge array by a switch array.
 25
 26    FID: APPLY_SWITCH
 27    """
 28
 29    key: str
 30    """The key of the input array."""
 31    switch_key: Optional[str] = None
 32    """The key of the switch array."""
 33    graph_key: Optional[str] = None
 34    """The key of the graph containing the switch."""
 35    output_key: Optional[str] = None
 36    """The key of the output array. If None, the input key is used."""
 37
 38    FID: ClassVar[str] = "APPLY_SWITCH"
 39
 40    @nn.compact
 41    def __call__(self, inputs) -> Any:
 42        if self.graph_key is not None:
 43            graph = inputs[self.graph_key]
 44            switch = graph["switch"]
 45        elif self.switch_key is not None:
 46            switch = inputs[self.switch_key]
 47        else:
 48            raise ValueError("Either graph_key or switch_key must be specified")
 49
 50        x = inputs[self.key]
 51        output = apply_switch(x, switch)
 52        output_key = self.key if self.output_key is None else self.output_key
 53        return {**inputs, output_key: output}
 54
 55
 56class AtomToEdge(nn.Module):
 57    """Map atom-wise values to edge-wise values.
 58
 59    FID: ATOM_TO_EDGE
 60
 61    By default, we map the destination atom value to the edge. This can be changed by setting `use_source` to True.
 62    """
 63
 64    _graphs_properties: Dict
 65    key: str
 66    """The key of the input atom-wise array."""
 67    output_key: Optional[str] = None
 68    """The key of the output edge-wise array. If None, the input key is used."""
 69    graph_key: str = "graph"
 70    """The key of the graph containing the edges."""
 71    switch: bool = False
 72    """Whether to apply a switch to the edge values."""
 73    switch_key: Optional[str] = None
 74    """The key of the switch array. If None, the switch is taken from the graph."""
 75    use_source: bool = False
 76    """Whether to use the source atom value instead of the destination atom value."""
 77
 78    FID: ClassVar[str] = "ATOM_TO_EDGE"
 79
 80    @nn.compact
 81    def __call__(self, inputs) -> Any:
 82        graph = inputs[self.graph_key]
 83        nat = inputs["species"].shape[0]
 84        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
 85
 86        x = inputs[self.key]
 87        if self.use_source:
 88            x_edge = x[edge_src]
 89        else:
 90            x_edge = x[edge_dst]
 91
 92        if self.switch:
 93            switch = (
 94                graph["switch"] if self.switch_key is None else inputs[self.switch_key]
 95            )
 96            x_edge = apply_switch(x_edge, switch)
 97
 98        output_key = self.key if self.output_key is None else self.output_key
 99        return {**inputs, output_key: x_edge}
100
101
102class ScatterEdges(nn.Module):
103    """Reduce an edge array to atoms by summing over neighbors.
104
105    FID: SCATTER_EDGES
106    """
107
108    _graphs_properties: Dict
109    key: str
110    """The key of the input edge-wise array."""
111    output_key: Optional[str] = None
112    """The key of the output atom-wise array. If None, the input key is used."""
113    graph_key: str = "graph"
114    """The key of the graph containing the edges."""
115    switch: bool = False
116    """Whether to apply a switch to the edge values before summing."""
117    switch_key: Optional[str] = None
118    """The key of the switch array. If None, the switch is taken from the graph."""
119    antisymmetric: bool = False
120
121    FID: ClassVar[str] = "SCATTER_EDGES"
122
123    @nn.compact
124    def __call__(self, inputs) -> Any:
125        graph = inputs[self.graph_key]
126        nat = inputs["species"].shape[0]
127        x = inputs[self.key]
128
129        if self.switch:
130            switch = (
131                graph["switch"] if self.switch_key is None else inputs[self.switch_key]
132            )
133            x = apply_switch(x, switch)
134
135        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
136        output = jax.ops.segment_sum(
137            x, edge_src, nat
138        )  # jnp.zeros((nat, *x.shape[1:])).at[edge_src].add(x,mode="drop")
139        if self.antisymmetric:
140            output = output - jax.ops.segment_sum(x, edge_dst, nat)
141
142        output_key = self.key if self.output_key is None else self.output_key
143        return {**inputs, output_key: output}
144
145
146class EdgeConcatenate(nn.Module):
147    """Concatenate the source and destination atom values of an edge.
148
149    FID: EDGE_CONCATENATE
150    """
151
152    _graphs_properties: Dict
153    key: str
154    """The key of the input atom-wise array."""
155    output_key: Optional[str] = None
156    """The key of the output edge-wise array. If None, the input key is used."""
157    graph_key: str = "graph"
158    """The key of the graph containing the edges."""
159    switch: bool = False
160    """Whether to apply a switch to the edge values."""
161    switch_key: Optional[str] = None
162    """The key of the switch array. If None, the switch is taken from the graph."""
163    axis: int = -1
164    """The axis along which to concatenate the atom values."""
165
166    FID: ClassVar[str] = "EDGE_CONCATENATE"
167
168    @nn.compact
169    def __call__(self, inputs) -> Any:
170        graph = inputs[self.graph_key]
171        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
172        nat = inputs["species"].shape[0]
173        xi = inputs[self.key]
174
175        assert self._graphs_properties[self.graph_key][
176            "directed"
177        ], "EdgeConcatenate only works for directed graphs"
178        assert xi.shape[0] == nat, "Shape mismatch, xi.shape[0] != nat"
179
180        xij = jnp.concatenate([xi[edge_src], xi[edge_dst]], axis=self.axis)
181
182        if self.switch:
183            switch = (
184                graph["switch"] if self.switch_key is None else inputs[self.switch_key]
185            )
186            xij = apply_switch(xij, switch)
187
188        output_key = self.name if self.output_key is None else self.output_key
189        return {**inputs, output_key: xij}
190
191
192class ScatterSystem(nn.Module):
193    """Reduce an atom-wise array to a system-wise array by summing over atoms (in the batch).
194
195    FID: SCATTER_SYSTEM
196    """
197
198    key: str
199    """The key of the input atom-wise array."""
200    output_key: Optional[str] = None
201    """The key of the output system-wise array. If None, the input key is used."""
202    average: bool = False
203    """Wether to divide by the number of atoms in the system."""
204
205    FID: ClassVar[str] = "SCATTER_SYSTEM"
206
207    @nn.compact
208    def __call__(self, inputs) -> Any:
209        batch_index = inputs["batch_index"]
210        x = inputs[self.key]
211        assert (
212            x.shape[0] == batch_index.shape[0]
213        ), f"Shape mismatch {x.shape[0]} != {batch_index.shape[0]}"
214        nsys = inputs["natoms"].shape[0]
215        if self.average:
216            shape = [batch_index.shape[0]] + (x.ndim - 1) * [1]
217            x = x / inputs["natoms"][batch_index].reshape(shape)
218
219        output = jax.ops.segment_sum(x, batch_index, nsys)
220
221        output_key = self.key if self.output_key is None else self.output_key
222        return {**inputs, output_key: output}
223
224
225class SystemToAtoms(nn.Module):
226    """Broadcast a system-wise array to an atom-wise array.
227
228    FID: SYSTEM_TO_ATOMS
229    """
230
231    key: str
232    """The key of the input system-wise array."""
233    output_key: Optional[str] = None
234    """The key of the output atom-wise array. If None, the input key is used."""
235
236    FID: ClassVar[str] = "SYSTEM_TO_ATOMS"
237
238    @nn.compact
239    def __call__(self, inputs) -> Any:
240        batch_index = inputs["batch_index"]
241        x = inputs[self.key]
242        output = x[batch_index]
243
244        output_key = self.key if self.output_key is None else self.output_key
245        return {**inputs, output_key: output}
246
247
248class SumAxis(nn.Module):
249    """Sum an array along an axis.
250
251    FID: SUM_AXIS
252    """
253
254    key: str
255    """The key of the input array."""
256    axis: Union[None, int, Sequence[int]] = None
257    """The axis along which to sum the array."""
258    output_key: Optional[str] = None
259    """The key of the output array. If None, the input key is used."""
260    norm: Optional[str] = None
261    """Normalization of the sum. Can be 'dim', 'sqrt', or 'none'."""
262
263    FID: ClassVar[str] = "SUM_AXIS"
264
265    @nn.compact
266    def __call__(self, inputs) -> Any:
267        x = inputs[self.key]
268        output = jnp.sum(x, axis=self.axis)
269        if self.norm is not None:
270            norm = self.norm.lower()
271            if norm == "dim":
272                dim = np.prod(x.shape[self.axis])
273                output = output / dim
274            elif norm == "sqrt":
275                dim = np.prod(x.shape[self.axis])
276                output = output / dim**0.5
277            elif norm == "none":
278                pass
279            else:
280                raise ValueError(f"Unknown norm {norm}")
281        output_key = self.key if self.output_key is None else self.output_key
282        return {**inputs, output_key: output}
283
284
285class Split(nn.Module):
286    """Split an array along an axis.
287
288    FID: SPLIT
289    """
290
291    key: str
292    """The key of the input array."""
293    output_keys: Sequence[str]
294    """The keys of the output arrays."""
295    axis: int = -1
296    """The axis along which to split the array."""
297    sizes: Union[int, Sequence[int]] = 1
298    """The sizes of the splits."""
299    squeeze: bool = True
300    """Whether to remove the axis in the output if the size is 1."""
301
302    FID: ClassVar[str] = "SPLIT"
303
304    @nn.compact
305    def __call__(self, inputs) -> Any:
306        x = inputs[self.key]
307
308        if isinstance(self.sizes, int):
309            split_size = [self.sizes] * len(self.output_keys)
310        else:
311            split_size = self.sizes
312        if len(split_size) == len(self.output_keys):
313            assert (
314                sum(split_size) == x.shape[self.axis]
315            ), f"Split sizes {split_size} do not match input shape"
316            split_size = split_size[:-1]
317        assert (
318            len(split_size) == len(self.output_keys) - 1
319        ), f"Wrong number of split sizes {split_size} for {len(self.output_keys)} outputs"
320        split_indices = np.cumsum(split_size)
321        outs = {}
322
323        for k, v in zip(self.output_keys, jnp.split(x, split_indices, axis=self.axis)):
324            outs[k] = (
325                jnp.squeeze(v, axis=self.axis)
326                if self.squeeze and v.shape[self.axis] == 1
327                else v
328            )
329
330        return {**inputs, **outs}
331
332
333class Concatenate(nn.Module):
334    """Concatenate a list of arrays along an axis.
335
336    FID: CONCATENATE
337    """
338
339    keys: Sequence[str]
340    """The keys of the input arrays."""
341    axis: int = -1
342    """The axis along which to concatenate the arrays."""
343    output_key: Optional[str] = None
344    """The key of the output array. If None, the name of the module is used."""
345
346    FID: ClassVar[str] = "CONCATENATE"
347
348    @nn.compact
349    def __call__(self, inputs) -> Any:
350        output = jnp.concatenate([inputs[k] for k in self.keys], axis=self.axis)
351        output_key = self.output_key if self.output_key is not None else self.name
352        return {**inputs, output_key: output}
353
354
355class Activation(nn.Module):
356    """Apply an element-wise activation function to an array.
357
358    FID: ACTIVATION
359    """
360
361    key: str
362    """The key of the input array."""
363    activation: Union[Callable, str]
364    """The activation function or its name."""
365    scale_out: float = 1.0
366    """Output scaling factor."""
367    shift_out: float = 0.0
368    """Output shift."""
369    output_key: Optional[str] = None
370    """The key of the output array. If None, the input key is used."""
371
372    FID: ClassVar[str] = "ACTIVATION"
373
374    @nn.compact
375    def __call__(self, inputs) -> Any:
376        x = inputs[self.key]
377        activation = (
378            activation_from_str(self.activation)
379            if isinstance(self.activation, str)
380            else self.activation
381        )
382        output = self.scale_out * activation(x) + self.shift_out
383        output_key = self.output_key if self.output_key is not None else self.key
384        return {**inputs, output_key: output}
385
386class LayerNorm(nn.Module):
387    """Layer normalization module.
388
389    FID: LAYER_NORM
390    """
391
392    key: Optional[str] = None
393    """The key of the input."""
394    output_key: Optional[str] = None
395    """The key of the output. If None, it is the same as the key."""
396    axis: Union[int, Sequence[int]] = -1
397    """The axis to normalize."""
398    epsilon: float = 1e-6
399    """The epsilon for numerical stability."""
400    scale: float = 1.0
401    """The scale for the normalization."""
402    shift: float = 0.0
403    """The shift for the normalization."""
404
405    FID: ClassVar[str] = "LAYER_NORM"
406
407    @nn.compact
408    def __call__(self, inputs):
409        if isinstance(inputs, dict):
410            if self.key is None:
411                raise ValueError("Key must be specified for LayerNorm")
412            x = inputs[self.key]
413        else:
414            x = inputs
415        mu = jnp.mean(x, axis=self.axis, keepdims=True)
416        dx = x-mu
417        var = jnp.mean(dx ** 2, axis=self.axis, keepdims=True)
418        sig = (self.epsilon + var) ** (-0.5)
419        out = self.scale * (sig * dx) + self.shift
420
421        output_key = self.output_key if self.output_key is not None else self.key
422        
423        if isinstance(inputs, dict):
424            return {**inputs, output_key: out}
425        
426        return out
427        
428
429
430class Scale(nn.Module):
431    """Scale an array by a constant factor.
432
433    FID: SCALE
434    """
435
436    key: str
437    """The key of the input array."""
438    scale: float
439    """The (initial) scaling factor."""
440    output_key: Optional[str] = None
441    """The key of the output array. If None, the input key is used."""
442    trainable: bool = False
443    """Whether the scaling factor is trainable."""
444
445    FID: ClassVar[str] = "SCALE"
446
447    @nn.compact
448    def __call__(self, inputs) -> Any:
449        x = inputs[self.key]
450
451        if self.trainable:
452            scale = self.param("scale", lambda rng: jnp.asarray(self.scale))
453        else:
454            scale = self.scale
455
456        output = scale * x
457        output_key = self.output_key if self.output_key is not None else self.key
458        return {**inputs, output_key: output}
459
460
461class Add(nn.Module):
462    """Add together a list of arrays.
463
464    FID: ADD
465    """
466
467    keys: Sequence[str]
468    """The keys of the input arrays."""
469    output_key: Optional[str] = None
470    """The key of the output array. If None, the name of the module is used."""
471
472    FID: ClassVar[str] = "ADD"
473
474    @nn.compact
475    def __call__(self, inputs) -> Any:
476        output = 0
477        for k in self.keys:
478            output = output + inputs[k]
479
480        output_key = self.output_key if self.output_key is not None else self.name
481        return {**inputs, output_key: output}
482
483
484class Multiply(nn.Module):
485    """Element-wise-multiply together a list of arrays.
486
487    FID: MULTIPLY
488    """
489
490    keys: Sequence[str]
491    """The keys of the input arrays."""
492    output_key: Optional[str] = None
493    """The key of the output array. If None, the name of the module is used."""
494
495    FID: ClassVar[str] = "MULTIPLY"
496
497    @nn.compact
498    def __call__(self, inputs) -> Any:
499        output = 1
500        for k in self.keys:
501            output = output * inputs[k]
502
503        output_key = self.output_key if self.output_key is not None else self.name
504        return {**inputs, output_key: output}
505
506
507class Transpose(nn.Module):
508    """Transpose an array.
509
510    FID: TRANSPOSE
511    """
512
513    key: str
514    """The key of the input array."""
515    axes: Sequence[int]
516    """The permutation of the axes. See `jax.numpy.transpose` for more details."""
517    output_key: Optional[str] = None
518    """The key of the output array. If None, the input key is used."""
519
520    FID: ClassVar[str] = "TRANSPOSE"
521
522    @nn.compact
523    def __call__(self, inputs) -> Any:
524        output = jnp.transpose(inputs[self.key], axes=self.axes)
525        output_key = self.output_key if self.output_key is not None else self.key
526        return {**inputs, output_key: output}
527
528
529class Reshape(nn.Module):
530    """Reshape an array.
531
532    FID: RESHAPE
533    """
534
535    key: str
536    """The key of the input array."""
537    shape: Sequence[Union[int,str]]
538    """The shape of the output array."""
539    output_key: Optional[str] = None
540    """The key of the output array. If None, the input key is used."""
541
542    FID: ClassVar[str] = "RESHAPE"
543
544    @nn.compact
545    def __call__(self, inputs) -> Any:
546        shape = []
547        for s in self.shape:
548            if isinstance(s,int):
549                shape.append(s)
550                continue
551
552            if isinstance(s,str):
553                s_ = s.lower().strip()
554                if s_ in ["natoms", "nat", "natom", "n_atoms", "atoms"]:
555                    shape.append(inputs["species"].shape[0])
556                    continue
557                
558                if s_ in ["nsys","nbatch","nsystems","n_sys","n_systems","n_batch"]:
559                    shape.append(inputs["natoms"].shape[0])
560                    continue
561                
562                s_ = s.strip().split("[")
563                key = s_[0]
564                if key in inputs:
565                    axis = int(s_[1].split("]")[0])
566                    shape.append(inputs[key].shape[axis])
567                    continue
568
569            raise ValueError(f"Error parsing shape component {s}")
570
571        output = jnp.reshape(inputs[self.key], shape)
572        output_key = self.output_key if self.output_key is not None else self.key
573        return {**inputs, output_key: output}
574
575
576class ChemicalConstant(nn.Module):
577    """Map atomic species to a constant value.
578
579    FID: CHEMICAL_CONSTANT
580    """
581
582    value: Union[str, List[float], float, Dict]
583    """The constant value or a dictionary of values for each element."""
584    output_key: Optional[str] = None
585    """The key of the output array. If None, the name of the module is used."""
586    trainable: bool = False
587    """Whether the constant is trainable."""
588
589    FID: ClassVar[str] = "CHEMICAL_CONSTANT"
590
591    @nn.compact
592    def __call__(self, inputs) -> Any:
593        if isinstance(self.value, str):
594            constant = CHEMICAL_PROPERTIES[self.value.upper()]
595        elif isinstance(self.value, list) or isinstance(self.value, tuple):
596            constant = list(self.value)
597        elif isinstance(self.value, float):
598            constant = [self.value] * len(PERIODIC_TABLE)
599        elif hasattr(self.value, "items"):
600            constant = [0.0] * len(PERIODIC_TABLE)
601            for k, v in self.value.items():
602                constant[PERIODIC_TABLE_REV_IDX[k]] = v
603        else:
604            raise ValueError(f"Unknown constant type {type(self.value)}")
605
606        if self.trainable:
607            constant = self.param(
608                "constant", lambda rng: jnp.asarray(constant, dtype=jnp.float32)
609            )
610        else:
611            constant = jnp.asarray(constant, dtype=jnp.float32)
612        output = constant[inputs["species"]]
613        output_key = self.output_key if self.output_key is not None else self.name
614        return {**inputs, output_key: output}
615
616
617class SwitchFunction(nn.Module):
618    """Compute a switch array from an array of distances and a cutoff.
619
620    FID: SWITCH_FUNCTION
621    """
622
623    cutoff: Optional[float] = None
624    """The cutoff distance. If None, the cutoff is taken from the graph."""
625    switch_start: float = 0.0
626    """The proportion of the cutoff distance at which the switch function starts."""
627    graph_key: Optional[str] = "graph"
628    """The key of the graph containing the distances and edge mask."""
629    output_key: Optional[str] = None
630    """The key of the output switch array. If None, it is added to the graph."""
631    switch_type: str = "cosine"
632    """The type of switch function. Can be 'cosine', 'polynomial', or 'exponential'."""
633    p: Optional[float] = None
634    """ The parameter of the switch function. If None, it is fixed to the default for each `switch_type`."""
635    trainable: bool = False
636    """Whether the switch parameter is trainable."""
637
638    FID: ClassVar[str] = "SWITCH_FUNCTION"
639
640    @nn.compact
641    def __call__(self, inputs) -> Any:
642        if self.graph_key is not None:
643            graph = inputs[self.graph_key]
644            distances, edge_mask = graph["distances"], graph["edge_mask"]
645            if self.cutoff is not None:
646                edge_mask = jnp.logical_and(edge_mask, (distances < self.cutoff))
647                cutoff = self.cutoff
648            else:
649                cutoff = graph["cutoff"]
650        else:
651            # distances = inputs
652            if len(inputs) == 3:
653                distances, edge_mask, cutoff = inputs
654            else:
655                distances, edge_mask = inputs
656                assert (
657                    self.cutoff is not None
658                ), "cutoff must be specified if no graph is given"
659                # edge_mask = distances < self.cutoff
660                cutoff = self.cutoff
661
662        if self.switch_start > 1.0e-5:
663            assert (
664                self.switch_start < 1.0
665            ), "switch_start is a proportion of cutoff and must be smaller than 1."
666            cutoff_in = self.switch_start * cutoff
667            x = distances - cutoff_in
668            end = cutoff - cutoff_in
669        else:
670            x = distances
671            end = cutoff
672
673        switch_type = self.switch_type.lower()
674        if switch_type == "cosine":
675            p = self.p if self.p is not None else 1.0
676            if self.trainable:
677                p = self.param("p", lambda rng: jnp.asarray(p, dtype=jnp.float32))
678            switch = (0.5 * jnp.cos(x * (jnp.pi / end)) + 0.5) ** p
679
680        elif switch_type == "polynomial":
681            p = self.p if self.p is not None else 3.0
682            if self.trainable:
683                p = self.param("p", lambda rng: jnp.asarray(p, dtype=jnp.float32))
684            d = x / end
685            switch = (
686                1.0
687                - 0.5 * (p + 1) * (p + 2) * d**p
688                + p * (p + 2) * d ** (p + 1)
689                - 0.5 * p * (p + 1) * d ** (p + 2)
690            )
691
692        elif switch_type == "exponential":
693            p = self.p if self.p is not None else 1.0
694            if self.trainable:
695                p = self.param("p", lambda rng: jnp.asarray(p, dtype=jnp.float32))
696            r2 = x**2
697            c2 = end**2
698            switch = jnp.exp(-p * r2 / (c2 - r2))
699        
700        elif switch_type == "hard":
701            switch = jnp.where(distances < cutoff, 1.0, 0.0)
702        else:
703            raise ValueError(f"Unknown switch function {switch_type}")
704
705        if self.switch_start > 1.0e-5:
706            switch = jnp.where(distances < cutoff_in, 1.0, switch)
707
708        switch = jnp.where(edge_mask, switch, 0.0)
709
710        if self.graph_key is not None:
711            if self.output_key is not None:
712                return {**inputs, self.output_key: switch}
713            else:
714                return {**inputs, self.graph_key: {**graph, "switch": switch}}
715        else:
716            return switch  # , edge_mask
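
Usage note (not part of the module source): all of these modules follow the same FeNNol convention of reading arrays from an input dictionary and returning a copy of that dictionary with the result stored under an output key. The sketch below illustrates this pattern with Concatenate, SumAxis and Scale, plus the standalone use of SwitchFunction. It is a minimal example assuming the modules are instantiated and applied directly through flax's init/apply, outside of a full FeNNol model; the key names and toy arrays are purely illustrative.

import jax
import jax.numpy as jnp
from fennol.models.misc.misc import Concatenate, SumAxis, Scale, SwitchFunction

# Toy input dictionary: 5 "atoms" with two feature blocks (illustrative keys).
inputs = {
    "feat_a": jnp.ones((5, 3)),
    "feat_b": jnp.zeros((5, 2)),
}

# Concatenate the two feature blocks along the last axis.
concat = Concatenate(keys=["feat_a", "feat_b"], output_key="features")
out = concat.apply(concat.init(jax.random.PRNGKey(0), inputs), inputs)
# out["features"].shape == (5, 5)

# Sum the concatenated features over the last axis, normalized by the dimension.
sum_axis = SumAxis(key="features", axis=-1, output_key="feat_sum", norm="dim")
out = sum_axis.apply(sum_axis.init(jax.random.PRNGKey(0), out), out)
# out["feat_sum"].shape == (5,)

# Scale the result by a trainable factor (registered as a flax parameter).
scale = Scale(key="feat_sum", scale=2.0, output_key="scaled", trainable=True)
variables = scale.init(jax.random.PRNGKey(0), out)
out = scale.apply(variables, out)
# out["scaled"].shape == (5,)

# SwitchFunction can also be used standalone (graph_key=None) on a
# (distances, edge_mask) pair; the cutoff must then be given explicitly.
sw = SwitchFunction(graph_key=None, cutoff=5.0, switch_type="cosine")
distances = jnp.linspace(0.5, 6.0, 8)
edge_mask = distances < 5.0
switch = sw.apply({}, (distances, edge_mask))
# switch decays smoothly from ~1 at short range to 0 at the cutoff and is
# exactly 0 wherever edge_mask is False.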
class ApplySwitch(flax.linen.module.Module):
24class ApplySwitch(nn.Module):
25    """Multiply an edge array by a switch array.
26
27    FID: APPLY_SWITCH
28    """
29
30    key: str
31    """The key of the input array."""
32    switch_key: Optional[str] = None
33    """The key of the switch array."""
34    graph_key: Optional[str] = None
35    """The key of the graph containing the switch."""
36    output_key: Optional[str] = None
37    """The key of the output array. If None, the input key is used."""
38
39    FID: ClassVar[str] = "APPLY_SWITCH"
40
41    @nn.compact
42    def __call__(self, inputs) -> Any:
43        if self.graph_key is not None:
44            graph = inputs[self.graph_key]
45            switch = graph["switch"]
46        elif self.switch_key is not None:
47            switch = inputs[self.switch_key]
48        else:
49            raise ValueError("Either graph_key or switch_key must be specified")
50
51        x = inputs[self.key]
52        output = apply_switch(x, switch)
53        output_key = self.key if self.output_key is None else self.output_key
54        return {**inputs, output_key: output}

Multiply an edge array by a switch array.

FID: APPLY_SWITCH

ApplySwitch( key: str, switch_key: Optional[str] = None, graph_key: Optional[str] = None, output_key: Optional[str] = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
key: str

The key of the input array.

switch_key: Optional[str] = None

The key of the switch array.

graph_key: Optional[str] = None

The key of the graph containing the switch.

output_key: Optional[str] = None

The key of the output array. If None, the input key is used.

FID: ClassVar[str] = 'APPLY_SWITCH'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class AtomToEdge(flax.linen.module.Module):
 57class AtomToEdge(nn.Module):
 58    """Map atom-wise values to edge-wise values.
 59
 60    FID: ATOM_TO_EDGE
 61
 62    By default, we map the destination atom value to the edge. This can be changed by setting `use_source` to True.
 63    """
 64
 65    _graphs_properties: Dict
 66    key: str
 67    """The key of the input atom-wise array."""
 68    output_key: Optional[str] = None
 69    """The key of the output edge-wise array. If None, the input key is used."""
 70    graph_key: str = "graph"
 71    """The key of the graph containing the edges."""
 72    switch: bool = False
 73    """Whether to apply a switch to the edge values."""
 74    switch_key: Optional[str] = None
 75    """The key of the switch array. If None, the switch is taken from the graph."""
 76    use_source: bool = False
 77    """Whether to use the source atom value instead of the destination atom value."""
 78
 79    FID: ClassVar[str] = "ATOM_TO_EDGE"
 80
 81    @nn.compact
 82    def __call__(self, inputs) -> Any:
 83        graph = inputs[self.graph_key]
 84        nat = inputs["species"].shape[0]
 85        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
 86
 87        x = inputs[self.key]
 88        if self.use_source:
 89            x_edge = x[edge_src]
 90        else:
 91            x_edge = x[edge_dst]
 92
 93        if self.switch:
 94            switch = (
 95                graph["switch"] if self.switch_key is None else inputs[self.switch_key]
 96            )
 97            x_edge = apply_switch(x_edge, switch)
 98
 99        output_key = self.key if self.output_key is None else self.output_key
100        return {**inputs, output_key: x_edge}

Map atom-wise values to edge-wise values.

FID: ATOM_TO_EDGE

By default, we map the destination atom value to the edge. This can be changed by setting use_source to True.

AtomToEdge( _graphs_properties: Dict, key: str, output_key: Optional[str] = None, graph_key: str = 'graph', switch: bool = False, switch_key: Optional[str] = None, use_source: bool = False, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
key: str

The key of the input atom-wise array.

output_key: Optional[str] = None

The key of the output edge-wise array. If None, the input key is used.

graph_key: str = 'graph'

The key of the graph containing the edges.

switch: bool = False

Whether to apply a switch to the edge values.

switch_key: Optional[str] = None

The key of the switch array. If None, the switch is taken from the graph.

use_source: bool = False

Whether to use the source atom value instead of the destination atom value.

FID: ClassVar[str] = 'ATOM_TO_EDGE'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class ScatterEdges(flax.linen.module.Module):
103class ScatterEdges(nn.Module):
104    """Reduce an edge array to atoms by summing over neighbors.
105
106    FID: SCATTER_EDGES
107    """
108
109    _graphs_properties: Dict
110    key: str
111    """The key of the input edge-wise array."""
112    output_key: Optional[str] = None
113    """The key of the output atom-wise array. If None, the input key is used."""
114    graph_key: str = "graph"
115    """The key of the graph containing the edges."""
116    switch: bool = False
117    """Whether to apply a switch to the edge values before summing."""
118    switch_key: Optional[str] = None
119    """The key of the switch array. If None, the switch is taken from the graph."""
120    antisymmetric: bool = False
121
122    FID: ClassVar[str] = "SCATTER_EDGES"
123
124    @nn.compact
125    def __call__(self, inputs) -> Any:
126        graph = inputs[self.graph_key]
127        nat = inputs["species"].shape[0]
128        x = inputs[self.key]
129
130        if self.switch:
131            switch = (
132                graph["switch"] if self.switch_key is None else inputs[self.switch_key]
133            )
134            x = apply_switch(x, switch)
135
136        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
137        output = jax.ops.segment_sum(
138            x, edge_src, nat
139        )  # jnp.zeros((nat, *x.shape[1:])).at[edge_src].add(x,mode="drop")
140        if self.antisymmetric:
141            output = output - jax.ops.segment_sum(x, edge_dst, nat)
142
143        output_key = self.key if self.output_key is None else self.output_key
144        return {**inputs, output_key: output}

Reduce an edge array to atoms by summing over neighbors.

FID: SCATTER_EDGES

ScatterEdges( _graphs_properties: Dict, key: str, output_key: Optional[str] = None, graph_key: str = 'graph', switch: bool = False, switch_key: Optional[str] = None, antisymmetric: bool = False, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
key: str

The key of the input edge-wise array.

output_key: Optional[str] = None

The key of the output atom-wise array. If None, the input key is used.

graph_key: str = 'graph'

The key of the graph containing the edges.

switch: bool = False

Whether to apply a switch to the edge values before summing.

switch_key: Optional[str] = None

The key of the switch array. If None, the switch is taken from the graph.

antisymmetric: bool = False
FID: ClassVar[str] = 'SCATTER_EDGES'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class EdgeConcatenate(flax.linen.module.Module):
147class EdgeConcatenate(nn.Module):
148    """Concatenate the source and destination atom values of an edge.
149
150    FID: EDGE_CONCATENATE
151    """
152
153    _graphs_properties: Dict
154    key: str
155    """The key of the input atom-wise array."""
156    output_key: Optional[str] = None
157    """The key of the output edge-wise array. If None, the input key is used."""
158    graph_key: str = "graph"
159    """The key of the graph containing the edges."""
160    switch: bool = False
161    """Whether to apply a switch to the edge values."""
162    switch_key: Optional[str] = None
163    """The key of the switch array. If None, the switch is taken from the graph."""
164    axis: int = -1
165    """The axis along which to concatenate the atom values."""
166
167    FID: ClassVar[str] = "EDGE_CONCATENATE"
168
169    @nn.compact
170    def __call__(self, inputs) -> Any:
171        graph = inputs[self.graph_key]
172        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
173        nat = inputs["species"].shape[0]
174        xi = inputs[self.key]
175
176        assert self._graphs_properties[self.graph_key][
177            "directed"
178        ], "EdgeConcatenate only works for directed graphs"
179        assert xi.shape[0] == nat, "Shape mismatch, xi.shape[0] != nat"
180
181        xij = jnp.concatenate([xi[edge_src], xi[edge_dst]], axis=self.axis)
182
183        if self.switch:
184            switch = (
185                graph["switch"] if self.switch_key is None else inputs[self.switch_key]
186            )
187            xij = apply_switch(xij, switch)
188
189        output_key = self.name if self.output_key is None else self.output_key
190        return {**inputs, output_key: xij}

Concatenate the source and destination atom values of an edge.

FID: EDGE_CONCATENATE

EdgeConcatenate( _graphs_properties: Dict, key: str, output_key: Optional[str] = None, graph_key: str = 'graph', switch: bool = False, switch_key: Optional[str] = None, axis: int = -1, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
key: str

The key of the input atom-wise array.

output_key: Optional[str] = None

The key of the output edge-wise array. If None, the input key is used.

graph_key: str = 'graph'

The key of the graph containing the edges.

switch: bool = False

Whether to apply a switch to the edge values.

switch_key: Optional[str] = None

The key of the switch array. If None, the switch is taken from the graph.

axis: int = -1

The axis along which to concatenate the atom values.

FID: ClassVar[str] = 'EDGE_CONCATENATE'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class ScatterSystem(flax.linen.module.Module):
193class ScatterSystem(nn.Module):
194    """Reduce an atom-wise array to a system-wise array by summing over atoms (in the batch).
195
196    FID: SCATTER_SYSTEM
197    """
198
199    key: str
200    """The key of the input atom-wise array."""
201    output_key: Optional[str] = None
202    """The key of the output system-wise array. If None, the input key is used."""
203    average: bool = False
204    """Wether to divide by the number of atoms in the system."""
205
206    FID: ClassVar[str] = "SCATTER_SYSTEM"
207
208    @nn.compact
209    def __call__(self, inputs) -> Any:
210        batch_index = inputs["batch_index"]
211        x = inputs[self.key]
212        assert (
213            x.shape[0] == batch_index.shape[0]
214        ), f"Shape mismatch {x.shape[0]} != {batch_index.shape[0]}"
215        nsys = inputs["natoms"].shape[0]
216        if self.average:
217            shape = [batch_index.shape[0]] + (x.ndim - 1) * [1]
218            x = x / inputs["natoms"][batch_index].reshape(shape)
219
220        output = jax.ops.segment_sum(x, batch_index, nsys)
221
222        output_key = self.key if self.output_key is None else self.output_key
223        return {**inputs, output_key: output}

Reduce an atom-wise array to a system-wise array by summing over atoms (in the batch).

FID: SCATTER_SYSTEM

ScatterSystem( key: str, output_key: Optional[str] = None, average: bool = False, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
key: str

The key of the input atom-wise array.

output_key: Optional[str] = None

The key of the output system-wise array. If None, the input key is used.

average: bool = False

Wether to divide by the number of atoms in the system.

FID: ClassVar[str] = 'SCATTER_SYSTEM'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class SystemToAtoms(flax.linen.module.Module):
226class SystemToAtoms(nn.Module):
227    """Broadcast a system-wise array to an atom-wise array.
228
229    FID: SYSTEM_TO_ATOMS
230    """
231
232    key: str
233    """The key of the input system-wise array."""
234    output_key: Optional[str] = None
235    """The key of the output atom-wise array. If None, the input key is used."""
236
237    FID: ClassVar[str] = "SYSTEM_TO_ATOMS"
238
239    @nn.compact
240    def __call__(self, inputs) -> Any:
241        batch_index = inputs["batch_index"]
242        x = inputs[self.key]
243        output = x[batch_index]
244
245        output_key = self.key if self.output_key is None else self.output_key
246        return {**inputs, output_key: output}

Broadcast a system-wise array to an atom-wise array.

FID: SYSTEM_TO_ATOMS

SystemToAtoms( key: str, output_key: Optional[str] = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
key: str

The key of the input system-wise array.

output_key: Optional[str] = None

The key of the output atom-wise array. If None, the input key is used.

FID: ClassVar[str] = 'SYSTEM_TO_ATOMS'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class SumAxis(flax.linen.module.Module):
249class SumAxis(nn.Module):
250    """Sum an array along an axis.
251
252    FID: SUM_AXIS
253    """
254
255    key: str
256    """The key of the input array."""
257    axis: Union[None, int, Sequence[int]] = None
258    """The axis along which to sum the array."""
259    output_key: Optional[str] = None
260    """The key of the output array. If None, the input key is used."""
261    norm: Optional[str] = None
262    """Normalization of the sum. Can be 'dim', 'sqrt', or 'none'."""
263
264    FID: ClassVar[str] = "SUM_AXIS"
265
266    @nn.compact
267    def __call__(self, inputs) -> Any:
268        x = inputs[self.key]
269        output = jnp.sum(x, axis=self.axis)
270        if self.norm is not None:
271            norm = self.norm.lower()
272            if norm == "dim":
273                dim = np.prod(x.shape[self.axis])
274                output = output / dim
275            elif norm == "sqrt":
276                dim = np.prod(x.shape[self.axis])
277                output = output / dim**0.5
278            elif norm == "none":
279                pass
280            else:
281                raise ValueError(f"Unknown norm {norm}")
282        output_key = self.key if self.output_key is None else self.output_key
283        return {**inputs, output_key: output}

Sum an array along an axis.

FID: SUM_AXIS

SumAxis( key: str, axis: Union[NoneType, int, Sequence[int]] = None, output_key: Optional[str] = None, norm: Optional[str] = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
key: str

The key of the input array.

axis: Union[NoneType, int, Sequence[int]] = None

The axis along which to sum the array.

output_key: Optional[str] = None

The key of the output array. If None, the input key is used.

norm: Optional[str] = None

Normalization of the sum. Can be 'dim', 'sqrt', or 'none'.

FID: ClassVar[str] = 'SUM_AXIS'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class Split(flax.linen.module.Module):
286class Split(nn.Module):
287    """Split an array along an axis.
288
289    FID: SPLIT
290    """
291
292    key: str
293    """The key of the input array."""
294    output_keys: Sequence[str]
295    """The keys of the output arrays."""
296    axis: int = -1
297    """The axis along which to split the array."""
298    sizes: Union[int, Sequence[int]] = 1
299    """The sizes of the splits."""
300    squeeze: bool = True
301    """Whether to remove the axis in the output if the size is 1."""
302
303    FID: ClassVar[str] = "SPLIT"
304
305    @nn.compact
306    def __call__(self, inputs) -> Any:
307        x = inputs[self.key]
308
309        if isinstance(self.sizes, int):
310            split_size = [self.sizes] * len(self.output_keys)
311        else:
312            split_size = self.sizes
313        if len(split_size) == len(self.output_keys):
314            assert (
315                sum(split_size) == x.shape[self.axis]
316            ), f"Split sizes {split_size} do not match input shape"
317            split_size = split_size[:-1]
318        assert (
319            len(split_size) == len(self.output_keys) - 1
320        ), f"Wrong number of split sizes {split_size} for {len(self.output_keys)} outputs"
321        split_indices = np.cumsum(split_size)
322        outs = {}
323
324        for k, v in zip(self.output_keys, jnp.split(x, split_indices, axis=self.axis)):
325            outs[k] = (
326                jnp.squeeze(v, axis=self.axis)
327                if self.squeeze and v.shape[self.axis] == 1
328                else v
329            )
330
331        return {**inputs, **outs}

Split an array along an axis.

FID: SPLIT

Split( key: str, output_keys: Sequence[str], axis: int = -1, sizes: Union[int, Sequence[int]] = 1, squeeze: bool = True, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
key: str

The key of the input array.

output_keys: Sequence[str]

The keys of the output arrays.

axis: int = -1

The axis along which to split the array.

sizes: Union[int, Sequence[int]] = 1

The sizes of the splits.

squeeze: bool = True

Whether to remove the axis in the output if the size is 1.

FID: ClassVar[str] = 'SPLIT'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class Concatenate(flax.linen.module.Module):
334class Concatenate(nn.Module):
335    """Concatenate a list of arrays along an axis.
336
337    FID: CONCATENATE
338    """
339
340    keys: Sequence[str]
341    """The keys of the input arrays."""
342    axis: int = -1
343    """The axis along which to concatenate the arrays."""
344    output_key: Optional[str] = None
345    """The key of the output array. If None, the name of the module is used."""
346
347    FID: ClassVar[str] = "CONCATENATE"
348
349    @nn.compact
350    def __call__(self, inputs) -> Any:
351        output = jnp.concatenate([inputs[k] for k in self.keys], axis=self.axis)
352        output_key = self.output_key if self.output_key is not None else self.name
353        return {**inputs, output_key: output}

Concatenate a list of arrays along an axis.

FID: CONCATENATE

Concatenate( keys: Sequence[str], axis: int = -1, output_key: Optional[str] = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
keys: Sequence[str]

The keys of the input arrays.

axis: int = -1

The axis along which to concatenate the arrays.

output_key: Optional[str] = None

The key of the output array. If None, the name of the module is used.

FID: ClassVar[str] = 'CONCATENATE'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class Activation(flax.linen.module.Module):
356class Activation(nn.Module):
357    """Apply an element-wise activation function to an array.
358
359    FID: ACTIVATION
360    """
361
362    key: str
363    """The key of the input array."""
364    activation: Union[Callable, str]
365    """The activation function or its name."""
366    scale_out: float = 1.0
367    """Output scaling factor."""
368    shift_out: float = 0.0
369    """Output shift."""
370    output_key: Optional[str] = None
371    """The key of the output array. If None, the input key is used."""
372
373    FID: ClassVar[str] = "ACTIVATION"
374
375    @nn.compact
376    def __call__(self, inputs) -> Any:
377        x = inputs[self.key]
378        activation = (
379            activation_from_str(self.activation)
380            if isinstance(self.activation, str)
381            else self.activation
382        )
383        output = self.scale_out * activation(x) + self.shift_out
384        output_key = self.output_key if self.output_key is not None else self.key
385        return {**inputs, output_key: output}

Apply an element-wise activation function to an array.

FID: ACTIVATION

Activation( key: str, activation: Union[Callable, str], scale_out: float = 1.0, shift_out: float = 0.0, output_key: Optional[str] = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
key: str

The key of the input array.

activation: Union[Callable, str]

The activation function or its name.

scale_out: float = 1.0

Output scaling factor.

shift_out: float = 0.0

Output shift.

output_key: Optional[str] = None

The key of the output array. If None, the input key is used.

FID: ClassVar[str] = 'ACTIVATION'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class LayerNorm(flax.linen.module.Module):
387class LayerNorm(nn.Module):
388    """Layer normalization module.
389
390    FID: LAYER_NORM
391    """
392
393    key: Optional[str] = None
394    """The key of the input."""
395    output_key: Optional[str] = None
396    """The key of the output. If None, it is the same as the key."""
397    axis: Union[int, Sequence[int]] = -1
398    """The axis to normalize."""
399    epsilon: float = 1e-6
400    """The epsilon for numerical stability."""
401    scale: float = 1.0
402    """The scale for the normalization."""
403    shift: float = 0.0
404    """The shift for the normalization."""
405
406    FID: ClassVar[str] = "LAYER_NORM"
407
408    @nn.compact
409    def __call__(self, inputs):
410        if isinstance(inputs, dict):
411            if self.key is None:
412                raise ValueError("Key must be specified for LayerNorm")
413            x = inputs[self.key]
414        else:
415            x = inputs
416        mu = jnp.mean(x, axis=self.axis, keepdims=True)
417        dx = x-mu
418        var = jnp.mean(dx ** 2, axis=self.axis, keepdims=True)
419        sig = (self.epsilon + var) ** (-0.5)
420        out = self.scale * (sig * dx) + self.shift
421
422        output_key = self.output_key if self.output_key is not None else self.key
423        
424        if isinstance(inputs, dict):
425            return {**inputs, output_key: out}
426        
427        return out

Layer normalization module.

FID: LAYER_NORM

LayerNorm( key: Optional[str] = None, output_key: Optional[str] = None, axis: Union[int, Sequence[int]] = -1, epsilon: float = 1e-06, scale: float = 1.0, shift: float = 0.0, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
key: Optional[str] = None

The key of the input.

output_key: Optional[str] = None

The key of the output. If None, it is the same as the key.

axis: Union[int, Sequence[int]] = -1

The axis to normalize.

epsilon: float = 1e-06

The epsilon for numerical stability.

scale: float = 1.0

The scale for the normalization.

shift: float = 0.0

The shift for the normalization.

FID: ClassVar[str] = 'LAYER_NORM'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class Scale(flax.linen.module.Module):
431class Scale(nn.Module):
432    """Scale an array by a constant factor.
433
434    FID: SCALE
435    """
436
437    key: str
438    """The key of the input array."""
439    scale: float
440    """The (initial) scaling factor."""
441    output_key: Optional[str] = None
442    """The key of the output array. If None, the input key is used."""
443    trainable: bool = False
444    """Whether the scaling factor is trainable."""
445
446    FID: ClassVar[str] = "SCALE"
447
448    @nn.compact
449    def __call__(self, inputs) -> Any:
450        x = inputs[self.key]
451
452        if self.trainable:
453            scale = self.param("scale", lambda rng: jnp.asarray(self.scale))
454        else:
455            scale = self.scale
456
457        output = scale * x
458        output_key = self.output_key if self.output_key is not None else self.key
459        return {**inputs, output_key: output}

Scale an array by a constant factor.

FID: SCALE

Scale( key: str, scale: float, output_key: Optional[str] = None, trainable: bool = False, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
key: str

The key of the input array.

scale: float

The (initial) scaling factor.

output_key: Optional[str] = None

The key of the output array. If None, the input key is used.

trainable: bool = False

Whether the scaling factor is trainable.

FID: ClassVar[str] = 'SCALE'
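
A minimal usage sketch with a trainable scaling factor; the keys "energies" and "energies_scaled" are hypothetical.

    import jax
    import jax.numpy as jnp
    from fennol.models.misc.misc import Scale

    scale = Scale(key="energies", scale=0.5, output_key="energies_scaled", trainable=True)
    inputs = {"energies": jnp.array([1.0, 2.0, 3.0])}

    params = scale.init(jax.random.PRNGKey(0), inputs)   # holds the trainable "scale" parameter
    out = scale.apply(params, inputs)
    # at initialization, out["energies_scaled"] == 0.5 * inputs["energies"]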
class Add(flax.linen.module.Module):
462class Add(nn.Module):
463    """Add together a list of arrays.
464
465    FID: ADD
466    """
467
468    keys: Sequence[str]
469    """The keys of the input arrays."""
470    output_key: Optional[str] = None
471    """The key of the output array. If None, the name of the module is used."""
472
473    FID: ClassVar[str] = "ADD"
474
475    @nn.compact
476    def __call__(self, inputs) -> Any:
477        output = 0
478        for k in self.keys:
479            output = output + inputs[k]
480
481        output_key = self.output_key if self.output_key is not None else self.name
482        return {**inputs, output_key: output}

Add together a list of arrays.

FID: ADD

Add( keys: Sequence[str], output_key: Optional[str] = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
keys: Sequence[str]

The keys of the input arrays.

output_key: Optional[str] = None

The key of the output array. If None, the name of the module is used.

FID: ClassVar[str] = 'ADD'
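
A minimal sketch with hypothetical keys; passing an explicit output_key avoids relying on the module name.

    import jax
    import jax.numpy as jnp
    from fennol.models.misc.misc import Add

    add = Add(keys=["e_short", "e_long"], output_key="e_total")
    inputs = {"e_short": jnp.array([1.0, 2.0]), "e_long": jnp.array([0.1, 0.2])}

    out = add.apply(add.init(jax.random.PRNGKey(0), inputs), inputs)   # no parameters
    # out["e_total"] == inputs["e_short"] + inputs["e_long"]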
class Multiply(flax.linen.module.Module):
485class Multiply(nn.Module):
486    """Element-wise-multiply together a list of arrays.
487
488    FID: MULTIPLY
489    """
490
491    keys: Sequence[str]
492    """The keys of the input arrays."""
493    output_key: Optional[str] = None
494    """The key of the output array. If None, the name of the module is used."""
495
496    FID: ClassVar[str] = "MULTIPLY"
497
498    @nn.compact
499    def __call__(self, inputs) -> Any:
500        output = 1
501        for k in self.keys:
502            output = output * inputs[k]
503
504        output_key = self.output_key if self.output_key is not None else self.name
505        return {**inputs, output_key: output}

Element-wise-multiply together a list of arrays.

FID: MULTIPLY

Multiply( keys: Sequence[str], output_key: Optional[str] = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
keys: Sequence[str]

The keys of the input arrays.

output_key: Optional[str] = None

The key of the output array. If None, the name of the module is used.

FID: ClassVar[str] = 'MULTIPLY'
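
A minimal sketch with hypothetical keys; since the product is element-wise, the input arrays must be broadcastable against each other.

    import jax
    import jax.numpy as jnp
    from fennol.models.misc.misc import Multiply

    mul = Multiply(keys=["charges", "damping"], output_key="charges_damped")
    inputs = {"charges": jnp.array([0.5, -0.5]), "damping": jnp.array([1.0, 0.2])}

    out = mul.apply(mul.init(jax.random.PRNGKey(0), inputs), inputs)
    # out["charges_damped"] == inputs["charges"] * inputs["damping"]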
class Transpose(flax.linen.module.Module):
508class Transpose(nn.Module):
509    """Transpose an array.
510
511    FID: TRANSPOSE
512    """
513
514    key: str
515    """The key of the input array."""
516    axes: Sequence[int]
517    """The permutation of the axes. See `jax.numpy.transpose` for more details."""
518    output_key: Optional[str] = None
519    """The key of the output array. If None, the input key is used."""
520
521    FID: ClassVar[str] = "TRANSPOSE"
522
523    @nn.compact
524    def __call__(self, inputs) -> Any:
525        output = jnp.transpose(inputs[self.key], axes=self.axes)
526        output_key = self.output_key if self.output_key is not None else self.key
527        return {**inputs, output_key: output}

Transpose an array.

FID: TRANSPOSE

Transpose( key: str, axes: Sequence[int], output_key: Optional[str] = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
key: str

The key of the input array.

axes: Sequence[int]

The permutation of the axes. See jax.numpy.transpose for more details.

output_key: Optional[str] = None

The key of the output array. If None, the input key is used.

FID: ClassVar[str] = 'TRANSPOSE'
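
A minimal sketch with a hypothetical key.

    import jax
    import jax.numpy as jnp
    from fennol.models.misc.misc import Transpose

    tr = Transpose(key="tensor", axes=(1, 0, 2), output_key="tensor_t")
    inputs = {"tensor": jnp.zeros((4, 3, 2))}

    out = tr.apply(tr.init(jax.random.PRNGKey(0), inputs), inputs)
    # out["tensor_t"].shape == (3, 4, 2)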
class Reshape(flax.linen.module.Module):
530class Reshape(nn.Module):
531    """Reshape an array.
532
533    FID: RESHAPE
534    """
535
536    key: str
537    """The key of the input array."""
538    shape: Sequence[Union[int,str]]
539    """The shape of the output array."""
540    output_key: Optional[str] = None
541    """The key of the output array. If None, the input key is used."""
542
543    FID: ClassVar[str] = "RESHAPE"
544
545    @nn.compact
546    def __call__(self, inputs) -> Any:
547        shape = []
548        for s in self.shape:
549            if isinstance(s,int):
550                shape.append(s)
551                continue
552
553            if isinstance(s,str):
554                s_=s.lower().strip()
555                if s_ in ["natoms" ,"nat","natom","n_atoms","atoms"]:
556                    shape.append(inputs["species"].shape[0])
557                    continue
558                
559                if s_ in ["nsys","nbatch","nsystems","n_sys","n_systems","n_batch"]:
560                    shape.append(inputs["natoms"].shape[0])
561                    continue
562                
563                s_ = s.strip().split("[")
564                key = s_[0]
565                if key in inputs:
566                    axis = int(s_[1].split("]")[0])
567                    shape.append(inputs[key].shape[axis])
568                    continue
569
570            raise ValueError(f"Error parsing shape component {s}")
571
572        output = jnp.reshape(inputs[self.key], shape)
573        output_key = self.output_key if self.output_key is not None else self.key
574        return {**inputs, output_key: output}

Reshape an array.

FID: RESHAPE

Reshape( key: str, shape: Sequence[Union[int, str]], output_key: Optional[str] = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
key: str

The key of the input array.

shape: Sequence[Union[int, str]]

The shape of the output array.

output_key: Optional[str] = None

The key of the output array. If None, the input key is used.

FID: ClassVar[str] = 'RESHAPE'
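
Shape entries can be plain integers, the symbolic strings 'natoms'/'nsys' (number of atoms / number of systems in the batch), or a string of the form 'key[axis]' that copies a dimension from another input array. A minimal sketch with hypothetical keys:

    import jax
    import jax.numpy as jnp
    from fennol.models.misc.misc import Reshape

    # reshape a flat per-atom array into (natoms, 3)
    reshape = Reshape(key="forces_flat", shape=["natoms", 3], output_key="forces")
    inputs = {
        "species": jnp.array([8, 1, 1]),   # 3 atoms
        "forces_flat": jnp.zeros((9,)),
    }

    out = reshape.apply(reshape.init(jax.random.PRNGKey(0), inputs), inputs)
    # out["forces"].shape == (3, 3)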
class ChemicalConstant(flax.linen.module.Module):
577class ChemicalConstant(nn.Module):
578    """Map atomic species to a constant value.
579
580    FID: CHEMICAL_CONSTANT
581    """
582
583    value: Union[str, List[float], float, Dict]
584    """The constant value or a dictionary of values for each element."""
585    output_key: Optional[str] = None
586    """The key of the output array. If None, the name of the module is used."""
587    trainable: bool = False
588    """Whether the constant is trainable."""
589
590    FID: ClassVar[str] = "CHEMICAL_CONSTANT"
591
592    @nn.compact
593    def __call__(self, inputs) -> Any:
594        if isinstance(self.value, str):
595            constant = CHEMICAL_PROPERTIES[self.value.upper()]
596        elif isinstance(self.value, list) or isinstance(self.value, tuple):
597            constant = list(self.value)
598        elif isinstance(self.value, float):
599            constant = [self.value] * len(PERIODIC_TABLE)
600        elif hasattr(self.value, "items"):
601            constant = [0.0] * len(PERIODIC_TABLE)
602            for k, v in self.value.items():
603                constant[PERIODIC_TABLE_REV_IDX[k]] = v
604        else:
605            raise ValueError(f"Unknown constant type {type(self.value)}")
606
607        if self.trainable:
608            constant = self.param(
609                "constant", lambda rng: jnp.asarray(constant, dtype=jnp.float32)
610            )
611        else:
612            constant = jnp.asarray(constant, dtype=jnp.float32)
613        output = constant[inputs["species"]]
614        output_key = self.output_key if self.output_key is not None else self.name
615        return {**inputs, output_key: output}

Map atomic species to a constant value.

FID: CHEMICAL_CONSTANT

ChemicalConstant( value: Union[str, List[float], float, Dict], output_key: Optional[str] = None, trainable: bool = False, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
value: Union[str, List[float], float, Dict]

The constant value or a dictionary of values for each element.

output_key: Optional[str] = None

The key of the output array. If None, the name of the module is used.

trainable: bool = False

Whether the constant is trainable.

FID: ClassVar[str] = 'CHEMICAL_CONSTANT'
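
A minimal sketch assuming species are stored as atomic numbers, as elsewhere in FeNNol; the element values and the output key are hypothetical. Elements not listed in the dictionary fall back to 0.

    import jax
    import jax.numpy as jnp
    from fennol.models.misc.misc import ChemicalConstant

    const = ChemicalConstant(value={"H": 0.31, "O": 0.66}, output_key="radius")
    inputs = {"species": jnp.array([8, 1, 1])}   # O, H, H

    out = const.apply(const.init(jax.random.PRNGKey(0), inputs), inputs)
    # expected: each atom gets its element's value, e.g. [0.66, 0.31, 0.31] for O, H, H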
class SwitchFunction(flax.linen.module.Module):
618class SwitchFunction(nn.Module):
619    """Compute a switch array from an array of distances and a cutoff.
620
621    FID: SWITCH_FUNCTION
622    """
623
624    cutoff: Optional[float] = None
625    """The cutoff distance. If None, the cutoff is taken from the graph."""
626    switch_start: float = 0.0
627    """The proportion of the cutoff distance at which the switch function starts."""
628    graph_key: Optional[str] = "graph"
629    """The key of the graph containing the distances and edge mask."""
630    output_key: Optional[str] = None
631    """The key of the output switch array. If None, it is added to the graph."""
632    switch_type: str = "cosine"
633    """The type of switch function. Can be 'cosine', 'polynomial', 'exponential', or 'hard'."""
634    p: Optional[float] = None
635    """ The parameter of the switch function. If None, it is fixed to the default for each `switch_type`."""
636    trainable: bool = False
637    """Whether the switch parameter is trainable."""
638
639    FID: ClassVar[str] = "SWITCH_FUNCTION"
640
641    @nn.compact
642    def __call__(self, inputs) -> Any:
643        if self.graph_key is not None:
644            graph = inputs[self.graph_key]
645            distances, edge_mask = graph["distances"], graph["edge_mask"]
646            if self.cutoff is not None:
647                edge_mask = jnp.logical_and(edge_mask, (distances < self.cutoff))
648                cutoff = self.cutoff
649            else:
650                cutoff = graph["cutoff"]
651        else:
652            # distances = inputs
653            if len(inputs) == 3:
654                distances, edge_mask, cutoff = inputs
655            else:
656                distances, edge_mask = inputs
657                assert (
658                    self.cutoff is not None
659                ), "cutoff must be specified if no graph is given"
660                # edge_mask = distances < self.cutoff
661                cutoff = self.cutoff
662
663        if self.switch_start > 1.0e-5:
664            assert (
665                self.switch_start < 1.0
666            ), "switch_start is a proportion of cutoff and must be smaller than 1."
667            cutoff_in = self.switch_start * cutoff
668            x = distances - cutoff_in
669            end = cutoff - cutoff_in
670        else:
671            x = distances
672            end = cutoff
673
674        switch_type = self.switch_type.lower()
675        if switch_type == "cosine":
676            p = self.p if self.p is not None else 1.0
677            if self.trainable:
678                p = self.param("p", lambda rng: jnp.asarray(p, dtype=jnp.float32))
679            switch = (0.5 * jnp.cos(x * (jnp.pi / end)) + 0.5) ** p
680
681        elif switch_type == "polynomial":
682            p = self.p if self.p is not None else 3.0
683            if self.trainable:
684                p = self.param("p", lambda rng: jnp.asarray(p, dtype=jnp.float32))
685            d = x / end
686            switch = (
687                1.0
688                - 0.5 * (p + 1) * (p + 2) * d**p
689                + p * (p + 2) * d ** (p + 1)
690                - 0.5 * p * (p + 1) * d ** (p + 2)
691            )
692
693        elif switch_type == "exponential":
694            p = self.p if self.p is not None else 1.0
695            if self.trainable:
696                p = self.param("p", lambda rng: jnp.asarray(p, dtype=jnp.float32))
697            r2 = x**2
698            c2 = end**2
699            switch = jnp.exp(-p * r2 / (c2 - r2))
700        
701        elif switch_type == "hard":
702            switch = jnp.where(distances < cutoff, 1.0, 0.0)
703        else:
704            raise ValueError(f"Unknown switch function {switch_type}")
705
706        if self.switch_start > 1.0e-5:
707            switch = jnp.where(distances < cutoff_in, 1.0, switch)
708
709        switch = jnp.where(edge_mask, switch, 0.0)
710
711        if self.graph_key is not None:
712            if self.output_key is not None:
713                return {**inputs, self.output_key: switch}
714            else:
715                return {**inputs, self.graph_key: {**graph, "switch": switch}}
716        else:
717            return switch  # , edge_mask

Compute a switch array from an array of distances and a cutoff.

FID: SWITCH_FUNCTION

SwitchFunction( cutoff: Optional[float] = None, switch_start: float = 0.0, graph_key: Optional[str] = 'graph', output_key: Optional[str] = None, switch_type: str = 'cosine', p: Optional[float] = None, trainable: bool = False, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
cutoff: Optional[float] = None

The cutoff distance. If None, the cutoff is taken from the graph.

switch_start: float = 0.0

The proportion of the cutoff distance at which the switch function starts.

graph_key: Optional[str] = 'graph'

The key of the graph containing the distances and edge mask.

output_key: Optional[str] = None

The key of the output switch array. If None, it is added to the graph.

switch_type: str = 'cosine'

The type of switch function. Can be 'cosine', 'polynomial', 'exponential', or 'hard'.

p: Optional[float] = None

The parameter of the switch function. If None, it is fixed to the default for each switch_type.

trainable: bool = False

Whether the switch parameter is trainable.

FID: ClassVar[str] = 'SWITCH_FUNCTION'
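
When graph_key is set to None, the module can be called directly on a (distances, edge_mask) pair (or (distances, edge_mask, cutoff)) and returns the switch array instead of updating a graph. A minimal sketch of this direct form with the default cosine switch, f(r) = (0.5*cos(pi*r/rc) + 0.5)**p; the cutoff and distance values are hypothetical.

    import jax
    import jax.numpy as jnp
    from fennol.models.misc.misc import SwitchFunction

    sw = SwitchFunction(cutoff=5.0, switch_type="cosine", graph_key=None)
    distances = jnp.array([1.0, 3.0, 4.9, 6.0])
    edge_mask = distances < 5.0

    switch = sw.apply(sw.init(jax.random.PRNGKey(0), (distances, edge_mask)), (distances, edge_mask))
    # the switch decays smoothly from ~1 at short range to 0 at the cutoff; masked edges are exactly 0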