fennol.models.embeddings.chgnet

import dataclasses
from typing import Any, Dict, Union, Callable, Sequence, Optional, ClassVar

import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np

from ...utils.initializers import initializer_from_str
from ..misc.encodings import SpeciesEncoding, RadialBasis
from ..misc.nets import GatedPerceptron
 10
 11
 12class CHGNetEmbedding(nn.Module):
 13    """ Crystal Hamiltonian Graph Neural Network
 14
 15    FID: CHGNET
 16
 17    ### Reference
 18    Deng, B., Zhong, P., Jun, K. et al. CHGNet as a pretrained universal neural network potential for charge-informed atomistic modelling.
 19        Nat Mach Intell 5, 1031–1041 (2023). https://doi.org/10.1038/s42256-023-00716-3
 20    
 21    """
 22    _graphs_properties: Dict
 23    dim: int = 64
 24    """The dimension of the embedding."""
 25    nmax_angle: int = 4
 26    """ The maximum fourier order for the angle representation."""
 27    nlayers: int = 3
 28    """The number of layers."""
 29    graph_key: str = "graph"
 30    """The key for the graph input."""
 31    graph_angle_key: Optional[str] = None
 32    """The key for the angular graph input."""
 33    embedding_key: str = "embedding"
 34    """The key for the embedding output."""
 35    species_encoding: dict = dataclasses.field(default_factory=dict)
 36    """The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`"""
 37    radial_basis: dict = dataclasses.field(default_factory=dict)
 38    """The radial basis parameters. See `fennol.models.misc.encodings.RadialBasis`"""
 39    radial_basis_angle: Optional[dict] = None
 40    """The radial basis parameters for angle embedding. See `fennol.models.misc.encodings.RadialBasis`"""
 41    keep_all_layers: bool = False
 42    """Whether to keep all layers in the output."""
 43    kernel_init: Union[str, Callable] = "lecun_normal()"
 44    """The kernel initializer for Dense operations."""
 45
 46    FID: ClassVar[str]  = "CHGNET"
 47
 48    @nn.compact
 49    def __call__(self, inputs):
 50        species = inputs["species"]
 51        assert (
 52            len(species.shape) == 1
 53        ), "Species must be a 1D array (batches must be flattened)"
 54
 55        kernel_init = (
 56            initializer_from_str(self.kernel_init)
 57            if isinstance(self.kernel_init, str)
 58            else self.kernel_init
 59        )
 60
 61        ##################################################
 62        # Check that the graph_angle is a subgraph of graph
 63        graph = inputs[self.graph_key]
 64        graph_angle_key = (
 65            self.graph_angle_key if self.graph_angle_key is not None else self.graph_key
 66        )
 67        graph_angle = inputs[graph_angle_key]
 68
 69        correct_graph = (
 70            graph_angle_key == self.graph_key
 71            or self._graphs_properties[graph_angle_key]["parent_graph"]
 72            == self.graph_key
 73        )
 74        assert (
 75            correct_graph
 76        ), f"graph_angle_key={graph_angle_key} must be a subgraph of graph_key={self.graph_key}"
 77        assert "angles" in graph_angle, f"Graph {graph_angle_key} must contain angles"
 78        # check if graph_angle is a filtered graph
 79        filtered = "parent_graph" in self._graphs_properties[graph_angle_key]
 80        if filtered:
 81            filter_indices = graph_angle["filter_indices"]
 82
 83        ##################################################
 84        ### SPECIES ENCODING ###
 85        zi = SpeciesEncoding(**self.species_encoding, name="SpeciesEncoding")(species)
 86
 87        vi = nn.Dense(self.dim, name="vi0", use_bias=True, kernel_init=kernel_init)(zi)
 88
 89        ##################################################
 90        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
 91        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
 92        distances = graph["distances"]
 93        switch = graph["switch"][:, None]
 94
 95        ### COMPUTE RADIAL BASIS ###
 96        radial_basis = RadialBasis(
 97            **{
 98                **self.radial_basis,
 99                "end": cutoff,
100                "name": f"RadialBasis",
101            }
102        )(distances)
103
104        ##################################################
105        ### GET ANGLES ###
106        angles = graph_angle["angles"][:, None]
107        angle_src, angle_dst = graph_angle["angle_src"], graph_angle["angle_dst"]
108        switch_angles = graph_angle["switch"][:, None]
109        central_atom = graph_angle["central_atom"]
110
111        ### COMPUTE RADIAL BASIS FOR ANGLES ###
112        if self.radial_basis_angle is not None:
113            dangles = graph_angle["distances"]
114            radial_basis_angle = (
115                RadialBasis(
116                    **{
117                        **self.radial_basis_angle,
118                        "end": self._graphs_properties[graph_angle_key]["cutoff"],
119                        "name": f"RadialBasisAngle",
120                    }
121                )(dangles)
122                * switch_angles
123            )
124
125        else:
126            if filtered:
127                radial_basis_angle = radial_basis[filter_indices] * switch_angles
128            else:
129                radial_basis_angle = radial_basis * switch
130
131        radial_basis = radial_basis * switch
132
133        eij, eija = jnp.split(
134            nn.Dense(
135                2 * self.dim, use_bias=False, kernel_init=kernel_init, name="eij0"
136            )(radial_basis),
137            2,
138            axis=-1,
139        )
140
141        eijb = nn.Dense(self.dim, use_bias=False, kernel_init=kernel_init, name="eijb")(
142            radial_basis_angle
143        )
144        eijkb = eijb[angle_src] * eijb[angle_dst]
145
146        ##################################################
147        ### ANGULAR BASIS ###
148        # build fourier series for angles
149        nangles = jnp.asarray(
150            np.arange(self.nmax_angle + 1, dtype=distances.dtype)[None, :]
151        )
152
153        Ac = jnp.cos(nangles * angles)
154        As = jnp.sin(nangles[:, 1:] * angles)
155        aijk = nn.Dense(
156            self.dim, use_bias=False, kernel_init=kernel_init, name="aijk0"
157        )(jnp.concatenate([Ac, As], axis=-1))
158        aikj = aijk
159
160        ##################################################
161        if self.keep_all_layers:
162            vis = []
163
164        ### LOOP OVER LAYERS ###
165        for layer in range(self.nlayers):
166            phiv = GatedPerceptron(
167                self.dim,
168                use_bias=True,
169                kernel_init=kernel_init,
170                activation=nn.silu,
171                name=f"phiv{layer+1}",
172            )(jnp.concatenate([vi[edge_src], vi[edge_dst], eij], axis=-1))
173
174            vi = vi + nn.Dense(
175                self.dim, use_bias=True, kernel_init=kernel_init, name=f"vi{layer+1}"
176            )(jax.ops.segment_sum(phiv * eija, edge_src, vi.shape[0]))
177
178            if self.keep_all_layers:
179                vis.append(vi)
180                
181            if layer == self.nlayers - 1:
182                break
183                
184            eij_ang = eij[filter_indices] if filtered else eij
185            eij_angsrc = eij_ang[angle_src]
186            eij_angdst = eij_ang[angle_dst]
187
188            phie = GatedPerceptron(
189                self.dim,
190                use_bias=True,
191                kernel_init=kernel_init,
192                activation=nn.silu,
193                name=f"phie{layer+1}",
194            )
195
196            vi_ang = vi[central_atom]
197            phie_ijk = phie(
198                jnp.concatenate(
199                    [eij_angsrc, eij_angdst, aijk, vi_ang],
200                    axis=-1,
201                )
202            )
203            phie_ikj = phie(
204                jnp.concatenate(
205                    [eij_angdst, eij_angsrc, aikj, vi_ang],
206                    axis=-1,
207                )
208            )
209
210            eij_ang = eij_ang + nn.Dense(
211                self.dim, use_bias=False, kernel_init=kernel_init, name=f"Le{layer+1}"
212            )(
213                jax.ops.segment_sum(phie_ikj * eijkb, angle_dst, eij_ang.shape[0])
214                + jax.ops.segment_sum(phie_ijk * eijkb, angle_src, eij_ang.shape[0])
215            )
216            eij_angsrc = eij_ang[angle_src]
217            eij_angdst = eij_ang[angle_dst]
218            eij = eij.at[filter_indices].set(eij_ang)
219
220            phia = GatedPerceptron(
221                self.dim,
222                use_bias=True,
223                kernel_init=kernel_init,
224                activation=nn.silu,
225                name=f"phia{layer+1}",
226            )
227            aijk = aijk + phia(
228                jnp.concatenate([eij_angsrc, eij_angdst, aijk, vi_ang], axis=-1)
229            )
230            aikj = aikj + phia(
231                jnp.concatenate([eij_angdst, eij_angsrc, aikj, vi_ang], axis=-1)
232            )
233
234            
235
236        output = {
237            **inputs,
238            self.embedding_key: vi,
239        }
240        if self.keep_all_layers:
241            output[self.embedding_key + "_layers"] = jnp.stack(vis, axis=1)
242        return output
class CHGNetEmbedding(flax.linen.module.Module):
 13class CHGNetEmbedding(nn.Module):
 14    """ Crystal Hamiltonian Graph Neural Network
 15
 16    FID: CHGNET
 17
 18    ### Reference
 19    Deng, B., Zhong, P., Jun, K. et al. CHGNet as a pretrained universal neural network potential for charge-informed atomistic modelling.
 20        Nat Mach Intell 5, 1031–1041 (2023). https://doi.org/10.1038/s42256-023-00716-3
 21    
 22    """
 23    _graphs_properties: Dict
 24    dim: int = 64
 25    """The dimension of the embedding."""
 26    nmax_angle: int = 4
 27    """ The maximum fourier order for the angle representation."""
 28    nlayers: int = 3
 29    """The number of layers."""
 30    graph_key: str = "graph"
 31    """The key for the graph input."""
 32    graph_angle_key: Optional[str] = None
 33    """The key for the angular graph input."""
 34    embedding_key: str = "embedding"
 35    """The key for the embedding output."""
 36    species_encoding: dict = dataclasses.field(default_factory=dict)
 37    """The species encoding parameters. See `fennol.models.misc.encodings.SpeciesEncoding`"""
 38    radial_basis: dict = dataclasses.field(default_factory=dict)
 39    """The radial basis parameters. See `fennol.models.misc.encodings.RadialBasis`"""
 40    radial_basis_angle: Optional[dict] = None
 41    """The radial basis parameters for angle embedding. See `fennol.models.misc.encodings.RadialBasis`"""
 42    keep_all_layers: bool = False
 43    """Whether to keep all layers in the output."""
 44    kernel_init: Union[str, Callable] = "lecun_normal()"
 45    """The kernel initializer for Dense operations."""
 46
 47    FID: ClassVar[str]  = "CHGNET"
 48
 49    @nn.compact
 50    def __call__(self, inputs):
 51        species = inputs["species"]
 52        assert (
 53            len(species.shape) == 1
 54        ), "Species must be a 1D array (batches must be flattened)"
 55
 56        kernel_init = (
 57            initializer_from_str(self.kernel_init)
 58            if isinstance(self.kernel_init, str)
 59            else self.kernel_init
 60        )
 61
 62        ##################################################
 63        # Check that the graph_angle is a subgraph of graph
 64        graph = inputs[self.graph_key]
 65        graph_angle_key = (
 66            self.graph_angle_key if self.graph_angle_key is not None else self.graph_key
 67        )
 68        graph_angle = inputs[graph_angle_key]
 69
 70        correct_graph = (
 71            graph_angle_key == self.graph_key
 72            or self._graphs_properties[graph_angle_key]["parent_graph"]
 73            == self.graph_key
 74        )
 75        assert (
 76            correct_graph
 77        ), f"graph_angle_key={graph_angle_key} must be a subgraph of graph_key={self.graph_key}"
 78        assert "angles" in graph_angle, f"Graph {graph_angle_key} must contain angles"
 79        # check if graph_angle is a filtered graph
 80        filtered = "parent_graph" in self._graphs_properties[graph_angle_key]
 81        if filtered:
 82            filter_indices = graph_angle["filter_indices"]
 83
 84        ##################################################
 85        ### SPECIES ENCODING ###
 86        zi = SpeciesEncoding(**self.species_encoding, name="SpeciesEncoding")(species)
 87
 88        vi = nn.Dense(self.dim, name="vi0", use_bias=True, kernel_init=kernel_init)(zi)
 89
 90        ##################################################
 91        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
 92        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
 93        distances = graph["distances"]
 94        switch = graph["switch"][:, None]
 95
 96        ### COMPUTE RADIAL BASIS ###
 97        radial_basis = RadialBasis(
 98            **{
 99                **self.radial_basis,
100                "end": cutoff,
101                "name": f"RadialBasis",
102            }
103        )(distances)
104
105        ##################################################
106        ### GET ANGLES ###
107        angles = graph_angle["angles"][:, None]
108        angle_src, angle_dst = graph_angle["angle_src"], graph_angle["angle_dst"]
109        switch_angles = graph_angle["switch"][:, None]
110        central_atom = graph_angle["central_atom"]
111
112        ### COMPUTE RADIAL BASIS FOR ANGLES ###
113        if self.radial_basis_angle is not None:
114            dangles = graph_angle["distances"]
115            radial_basis_angle = (
116                RadialBasis(
117                    **{
118                        **self.radial_basis_angle,
119                        "end": self._graphs_properties[graph_angle_key]["cutoff"],
120                        "name": f"RadialBasisAngle",
121                    }
122                )(dangles)
123                * switch_angles
124            )
125
126        else:
127            if filtered:
128                radial_basis_angle = radial_basis[filter_indices] * switch_angles
129            else:
130                radial_basis_angle = radial_basis * switch
131
132        radial_basis = radial_basis * switch
133
134        eij, eija = jnp.split(
135            nn.Dense(
136                2 * self.dim, use_bias=False, kernel_init=kernel_init, name="eij0"
137            )(radial_basis),
138            2,
139            axis=-1,
140        )
141
142        eijb = nn.Dense(self.dim, use_bias=False, kernel_init=kernel_init, name="eijb")(
143            radial_basis_angle
144        )
145        eijkb = eijb[angle_src] * eijb[angle_dst]
146
147        ##################################################
148        ### ANGULAR BASIS ###
149        # build fourier series for angles
150        nangles = jnp.asarray(
151            np.arange(self.nmax_angle + 1, dtype=distances.dtype)[None, :]
152        )
153
154        Ac = jnp.cos(nangles * angles)
155        As = jnp.sin(nangles[:, 1:] * angles)
156        aijk = nn.Dense(
157            self.dim, use_bias=False, kernel_init=kernel_init, name="aijk0"
158        )(jnp.concatenate([Ac, As], axis=-1))
159        aikj = aijk
160
161        ##################################################
162        if self.keep_all_layers:
163            vis = []
164
165        ### LOOP OVER LAYERS ###
166        for layer in range(self.nlayers):
167            phiv = GatedPerceptron(
168                self.dim,
169                use_bias=True,
170                kernel_init=kernel_init,
171                activation=nn.silu,
172                name=f"phiv{layer+1}",
173            )(jnp.concatenate([vi[edge_src], vi[edge_dst], eij], axis=-1))
174
175            vi = vi + nn.Dense(
176                self.dim, use_bias=True, kernel_init=kernel_init, name=f"vi{layer+1}"
177            )(jax.ops.segment_sum(phiv * eija, edge_src, vi.shape[0]))
178
179            if self.keep_all_layers:
180                vis.append(vi)
181                
182            if layer == self.nlayers - 1:
183                break
184                
185            eij_ang = eij[filter_indices] if filtered else eij
186            eij_angsrc = eij_ang[angle_src]
187            eij_angdst = eij_ang[angle_dst]
188
189            phie = GatedPerceptron(
190                self.dim,
191                use_bias=True,
192                kernel_init=kernel_init,
193                activation=nn.silu,
194                name=f"phie{layer+1}",
195            )
196
197            vi_ang = vi[central_atom]
198            phie_ijk = phie(
199                jnp.concatenate(
200                    [eij_angsrc, eij_angdst, aijk, vi_ang],
201                    axis=-1,
202                )
203            )
204            phie_ikj = phie(
205                jnp.concatenate(
206                    [eij_angdst, eij_angsrc, aikj, vi_ang],
207                    axis=-1,
208                )
209            )
210
211            eij_ang = eij_ang + nn.Dense(
212                self.dim, use_bias=False, kernel_init=kernel_init, name=f"Le{layer+1}"
213            )(
214                jax.ops.segment_sum(phie_ikj * eijkb, angle_dst, eij_ang.shape[0])
215                + jax.ops.segment_sum(phie_ijk * eijkb, angle_src, eij_ang.shape[0])
216            )
217            eij_angsrc = eij_ang[angle_src]
218            eij_angdst = eij_ang[angle_dst]
219            eij = eij.at[filter_indices].set(eij_ang)
220
221            phia = GatedPerceptron(
222                self.dim,
223                use_bias=True,
224                kernel_init=kernel_init,
225                activation=nn.silu,
226                name=f"phia{layer+1}",
227            )
228            aijk = aijk + phia(
229                jnp.concatenate([eij_angsrc, eij_angdst, aijk, vi_ang], axis=-1)
230            )
231            aikj = aikj + phia(
232                jnp.concatenate([eij_angdst, eij_angsrc, aikj, vi_ang], axis=-1)
233            )
234
235            
236
237        output = {
238            **inputs,
239            self.embedding_key: vi,
240        }
241        if self.keep_all_layers:
242            output[self.embedding_key + "_layers"] = jnp.stack(vis, axis=1)
243        return output

Crystal Hamiltonian Graph Neural Network

FID: CHGNET

Reference

Deng, B., Zhong, P., Jun, K. et al. CHGNet as a pretrained universal neural network potential for charge-informed atomistic modelling. Nat Mach Intell 5, 1031–1041 (2023). https://doi.org/10.1038/s42256-023-00716-3

CHGNetEmbedding( _graphs_properties: Dict, dim: int = 64, nmax_angle: int = 4, nlayers: int = 3, graph_key: str = 'graph', graph_angle_key: Optional[str] = None, embedding_key: str = 'embedding', species_encoding: dict = <factory>, radial_basis: dict = <factory>, radial_basis_angle: Optional[dict] = None, keep_all_layers: bool = False, kernel_init: Union[str, Callable] = 'lecun_normal()', parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
dim: int = 64

The dimension of the embedding.

nmax_angle: int = 4

The maximum Fourier order for the angle representation.

nlayers: int = 3

The number of layers.

graph_key: str = 'graph'

The key for the graph input.

graph_angle_key: Optional[str] = None

The key for the angular graph input. Defaults to graph_key when None; otherwise it must be a subgraph of graph_key.

embedding_key: str = 'embedding'

The key for the embedding output.

species_encoding: dict

The species encoding parameters. See fennol.models.misc.encodings.SpeciesEncoding

radial_basis: dict

The radial basis parameters. See fennol.models.misc.encodings.RadialBasis

radial_basis_angle: Optional[dict] = None

The radial basis parameters for angle embedding. See fennol.models.misc.encodings.RadialBasis

keep_all_layers: bool = False

Whether to also output the embeddings of every layer, stacked along axis 1 under the key embedding_key + "_layers".

kernel_init: Union[str, Callable] = 'lecun_normal()'

The kernel initializer for Dense operations.

FID: ClassVar[str] = 'CHGNET'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None