fennol.models.embeddings.foam

 1import jax
 2import jax.numpy as jnp
 3import flax.linen as nn
 4from ...utils.spherical_harmonics import generate_spherical_harmonics
 5from ..misc.encodings import SpeciesEncoding, RadialBasis
 6import dataclasses
 7import numpy as np
 8from typing import  Dict, Optional, ClassVar
 9
10
class FOAMEmbedding(nn.Module):
    """Filtered Overlap of Atomic Moments

    FID : FOAM

    A SOAP-like atomic embedding. For every atom, neighbor contributions
    (radial basis x species encoding of the neighbor x spherical harmonics of
    the edge direction) are summed into atomic moments. For each rank l the
    moments are then linearly projected onto 2 * nchannels features, split in
    two halves, and the two halves are contracted over the 2l + 1 components.
    Compared to SOAP, which combines all pairs of channels, this amounts to a
    linearly filtered power spectrum.

    The output feature axis is the concatenation of: the species encoding
    (only when `include_species` is True), the raw rank-0 moments, and one
    block of `nchannels` filtered overlaps per rank l = 0..lmax.
    """

    _graphs_properties: Dict
    lmax: int = 2
    """The maximum order of spherical tensors."""
    nchannels: Optional[int] = None
    """The number of channels."""
    graph_key: str = "graph"
    """The key for the graph input."""
    embedding_key: str = "embedding"
    """The key for the embedding output."""
    species_encoding: dict = dataclasses.field(default_factory=dict)
    """The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`"""
    radial_basis: dict = dataclasses.field(default_factory=dict)
    """The radial basis parameters. See `fennol.models.misc.encodings.RadialBasis`"""
    include_species: bool = True
    """Whether to include the species encoding in the embedding."""

    FID: ClassVar[str] = "FOAM"

    @nn.compact
    def __call__(self, inputs):
        """Compute the FOAM embedding of every atom.

        Reads `inputs["species"]` (must be 1D) and the graph stored under
        `graph_key` (keys `edge_src`, `edge_dst`, `switch`, `distances` and
        `vec`). Returns the embedding array itself when `embedding_key` is
        None, otherwise a new dict holding every entry of `inputs` plus the
        embedding stored under `embedding_key`.
        """
        species = inputs["species"]
        assert len(species.shape) == 1, (
            "Species must be a 1D array (batches must be flattened)"
        )
        n_atoms = species.shape[0]

        graph = inputs[self.graph_key]
        src = graph["edge_src"]
        dst = graph["edge_dst"]
        switch = graph["switch"][:, None]
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]

        # The graph cutoff and the submodule name always take precedence over
        # whatever the user-provided radial basis parameters contain.
        radial_kwargs = {**self.radial_basis, "end": cutoff, "name": "RadialBasis"}
        radial_module = RadialBasis(**radial_kwargs)
        radial_basis = radial_module(graph["distances"]) * switch

        species_module = SpeciesEncoding(
            **self.species_encoding, name="SpeciesEncoding"
        )
        species_encoding = species_module(species)

        # Per-edge outer product of the radial features with the encoding of
        # the destination atom, flattened into a single basis axis.
        n_radial = radial_basis.shape[-1]
        n_encoding = species_encoding.shape[-1]
        pair_features = radial_basis[:, :, None] * species_encoding[dst, None, :]
        Dij = pair_features.reshape(-1, n_encoding * n_radial)

        # Spherical harmonics of the unit edge vectors, with a trailing axis
        # added so that they broadcast against the basis axis of Dij.
        harmonics = generate_spherical_harmonics(lmax=self.lmax, normalize=False)
        unit_vectors = graph["vec"] / graph["distances"][:, None]
        Yij = harmonics(unit_vectors)[:, :, None]

        # Atomic moments: sum of the edge contributions over each source atom.
        moments = jax.ops.segment_sum(Dij[:, None, :] * Yij, src, n_atoms)

        nbasis = moments.shape[-1]
        nchannels = nbasis if self.nchannels is None else self.nchannels

        pieces = [species_encoding] if self.include_species else []
        pieces.append(moments[:, 0, :])

        for rank in range(self.lmax + 1):
            # Components of rank `rank` occupy indices rank**2 .. (rank+1)**2 - 1.
            first, last = rank**2, (rank + 1) ** 2
            projector = nn.Dense(2 * nchannels, use_bias=False, name=f"xy_l{rank}")
            projected = projector(moments[:, first:last, :])
            left, right = jnp.split(projected, 2, axis=-1)
            # Contract the two projections over the 2*rank + 1 components.
            overlap = jnp.sum(left * right, axis=1) / (2 * rank + 1) ** 0.5
            pieces.append(overlap)

        embedding = jnp.concatenate(pieces, axis=-1)

        if self.embedding_key is not None:
            return {**inputs, self.embedding_key: embedding}
        return embedding
class FOAMEmbedding(flax.linen.module.Module):
class FOAMEmbedding(nn.Module):
    """Filtered Overlap of Atomic Moments

    FID : FOAM

    Related to the SOAP embedding: atomic moments are accumulated per atom
    from its edges, but instead of combining every pair of channels, each
    rank-l block is passed through a bias-free linear layer producing
    2 * nchannels features. The two halves of that projection are multiplied
    and summed over the components of the rank, which gives a linearly
    filtered variant of the SOAP power spectrum.
    """

    _graphs_properties: Dict
    lmax: int = 2
    """The maximum order of spherical tensors."""
    nchannels: Optional[int] = None
    """The number of channels."""
    graph_key: str = "graph"
    """The key for the graph input."""
    embedding_key: str = "embedding"
    """The key for the embedding output."""
    species_encoding: dict = dataclasses.field(default_factory=dict)
    """The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`"""
    radial_basis: dict = dataclasses.field(default_factory=dict)
    """The radial basis parameters. See `fennol.models.misc.encodings.RadialBasis`"""
    include_species: bool = True
    """Whether to include the species encoding in the embedding."""

    FID: ClassVar[str] = "FOAM"

    @nn.compact
    def __call__(self, inputs):
        """Build the per-atom embedding from `inputs`.

        `inputs` must provide a 1D `"species"` entry and, under `graph_key`,
        a graph with `edge_src`, `edge_dst`, `switch`, `distances` and `vec`.
        When `embedding_key` is None the embedding array is returned
        directly; otherwise the result is a copy of `inputs` extended with
        the embedding under `embedding_key`.
        """
        species = inputs["species"]
        assert (
            len(species.shape) == 1
        ), "Species must be a 1D array (batches must be flattened)"

        graph = inputs[self.graph_key]
        edge_src = graph["edge_src"]
        edge_dst = graph["edge_dst"]
        switch = graph["switch"][:, None]

        # Start from the user parameters, then force the cutoff of the graph
        # and the submodule name.
        rb_params = dict(self.radial_basis)
        rb_params["end"] = self._graphs_properties[self.graph_key]["cutoff"]
        rb_params["name"] = "RadialBasis"
        radial = RadialBasis(**rb_params)(graph["distances"]) * switch

        encoding = SpeciesEncoding(
            **self.species_encoding, name="SpeciesEncoding"
        )(species)

        # Edge features: radial basis times encoding of the destination
        # atom, with both feature axes merged into one.
        flat_size = encoding.shape[-1] * radial.shape[-1]
        Dij = (radial[:, :, None] * encoding[edge_dst, None, :]).reshape(
            -1, flat_size
        )

        sph = generate_spherical_harmonics(lmax=self.lmax, normalize=False)
        Yij = sph(graph["vec"] / graph["distances"][:, None])[:, :, None]

        # Sum the edge contributions onto their source atoms.
        rhoi = jax.ops.segment_sum(
            Dij[:, None, :] * Yij, edge_src, species.shape[0]
        )

        nchannels = self.nchannels
        if nchannels is None:
            # Default: as many channels as basis functions.
            nchannels = rhoi.shape[-1]

        scalar_moments = rhoi[:, 0, :]
        if self.include_species:
            features = [encoding, scalar_moments]
        else:
            features = [scalar_moments]

        for order in range(self.lmax + 1):
            lo = order**2
            hi = (order + 1) ** 2
            xy = nn.Dense(2 * nchannels, use_bias=False, name=f"xy_l{order}")(
                rhoi[:, lo:hi, :]
            )
            x_part, y_part = jnp.split(xy, 2, axis=-1)
            # Scalar product over the 2*order + 1 components of this rank.
            contracted = (x_part * y_part).sum(axis=1)
            features.append(contracted / (2 * order + 1) ** 0.5)

        xi = jnp.concatenate(features, axis=-1)

        return xi if self.embedding_key is None else {**inputs, self.embedding_key: xi}

Filtered Overlap of Atomic Moments

FID : FOAM

Similar to a SOAP embedding, but for each rank l we do not take all pairwise combinations of channels. Instead, the channels are linearly projected onto 2 × nchannels elements, which are split into two halves whose scalar product is taken. The result is effectively a linearly filtered SOAP embedding.

FOAMEmbedding( _graphs_properties: Dict, lmax: int = 2, nchannels: Optional[int] = None, graph_key: str = 'graph', embedding_key: str = 'embedding', species_encoding: dict = <factory>, radial_basis: dict = <factory>, include_species: bool = True, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
lmax: int = 2

The maximum order of spherical tensors.

nchannels: Optional[int] = None

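The number of channels. If None, it defaults to the number of basis functions (the radial basis size times the species encoding size).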

graph_key: str = 'graph'

The key for the graph input.

embedding_key: str = 'embedding'

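The key for the embedding output. If None, the embedding array is returned directly instead of a copy of the inputs extended with the embedding.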

species_encoding: dict

The species encoding parameters. See fennol.models.misc.encodings.SpeciesEncoding

radial_basis: dict

The radial basis parameters. See fennol.models.misc.encodings.RadialBasis

include_species: bool = True

Whether to include the species encoding in the embedding.

FID: ClassVar[str] = 'FOAM'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None