fennol.models.embeddings.mace

  1import functools
  2import math
  3import jax
  4import jax.numpy as jnp
  5import flax.linen as nn
  6from typing import Sequence, Dict, Union, ClassVar, Optional, Set
  7import dataclasses
  8
  9from ..misc.encodings import RadialBasis
 10from ...utils.activations import activation_from_str
 11
 12
 13try:
 14    import e3nn_jax as e3nn
 15
 16    E3NN_AVAILABLE = True
 17    E3NN_EXCEPTION = None
 18    Irreps = e3nn.Irreps
 19    Irrep = e3nn.Irrep
 20except Exception as e:
 21    E3NN_AVAILABLE = False
 22    E3NN_EXCEPTION = e
 23    e3nn = None
 24
 25    class Irreps(tuple):
 26        pass
 27
 28    class Irrep(tuple):
 29        pass
 30
 31
class MACE(nn.Module):
    """MACE equivariant message passing neural network.

    Adapted from the MACE-jax GitHub repository by M. Geiger and I. Batatia.

    T. Plé reordered some operations and changed defaults to match the recent
    mace-torch version.
    -> Compatibility with pretrained torch models requires some work on the parameters:
        - normalization of activation functions in e3nn differs between jax and pytorch => need rescaling
        - multiplicity ordering and signs of U matrices in SymmetricContraction differ => need to reorder and flip signs in the weight tensors
        - we use a maximum Z instead of a list of species => need to adapt species-dependent parameters

    Inputs (read from the dictionary passed to ``__call__``):
        - ``inputs["species"]``: integer indices used to select rows of the
          species-dependent parameter tables (which have ``zmax + 2`` rows).
        - ``inputs[graph_key]``: graph dictionary providing ``"distances"``,
          ``"vec"`` (wrapped as ``1o`` vectors), ``"switch"``, ``"edge_src"``
          and ``"edge_dst"``.
        - ``_graphs_properties[graph_key]["cutoff"]`` is used as the end of the
          radial basis.

    Outputs: a copy of the input dictionary extended with
        - ``output_key`` (or the module name if ``output_key`` is None): the
          readout of every interaction layer, stacked along axis 1 (a plain
          array if ``scalar_output`` is True, otherwise an ``e3nn.IrrepsArray``);
        - ``output_key + "_node_feats"``: the ``0e`` node features of every
          layer, concatenated along the last axis.

    References:
        - I. Batatia et al. "MACE: Higher order equivariant message passing neural networks for fast and accurate force fields." Advances in Neural Information Processing Systems 35 (2022): 11423-11436.
        https://doi.org/10.48550/arXiv.2206.07697
        - I. Batatia et al. "The design space of e(3)-equivariant atom-centered interatomic potentials." arXiv preprint arXiv:2205.06643 (2022).
        https://doi.org/10.48550/arXiv.2205.06643

    """
    # Per-graph metadata; only the "cutoff" entry of the `graph_key` graph is read here.
    _graphs_properties: Dict
    output_irreps: Union[Irreps, str] = "1x0e"
    """The output irreps of the model."""
    hidden_irreps: Union[Irreps, str] = "128x0e + 128x1o"
    """The hidden irreps of the model."""
    readout_mlp_irreps: Union[Irreps, str] = "16x0e"
    """The hidden irreps of the readout MLP."""
    graph_key: str = "graph"
    """The key in the input dictionary that corresponds to the molecular graph to use."""
    output_key: Optional[str] = None
    """The key of the embedding in the output dictionary."""
    avg_num_neighbors: float = 1.0
    """The expected average number of neighbors."""
    ninteractions: int = 2
    """The number of interaction layers."""
    # NOTE: when num_features is None, the hidden_irreps multiplicities are
    # divided by their gcd. When it is given, hidden_irreps is used as-is and
    # later multiplied by num_features, so it should then hold per-feature
    # multiplicities (e.g. "1x0e + 1x1o").
    num_features: Optional[int] = None
    """The number of features per node. default gcd of hidden_irreps multiplicities"""
    radial_basis: dict = dataclasses.field(
        default_factory=lambda: {"basis": "bessel", "dim": 8, "trainable": False}
    )
    """The dictionary of parameters for radial basis functions. See `fennol.models.misc.encodings.RadialBasis`."""
    lmax: int = 1
    """The maximum angular momentum to consider."""
    correlation: int = 3
    """The correlation order at each layer."""
    activation: str = "silu"
    """The activation function to use."""
    symmetric_tensor_product_basis: bool = False
    """Whether to use the symmetric tensor product basis."""
    # Irreps kept in the message-passing convolution (multiplied by num_features):
    # "o3_restricted" -> e3nn.Irreps.spherical_harmonics(lmax),
    # "o3_full" -> every irrep from e3nn.Irrep.iterator(lmax),
    # anything else is parsed with e3nn.Irreps.
    interaction_irreps: Union[Irreps, str] = "o3_restricted"
    # If True, every layer (including the first) gets a species-indexed linear
    # residual connection. If False, the first layer instead applies a
    # species-indexed linear to its aggregated messages and has no residual.
    skip_connection_first_layer: bool = True
    # Hidden layer sizes of the per-layer radial MLP producing edge weights.
    radial_network_hidden: Sequence[int] = dataclasses.field(
        default_factory=lambda: [64, 64, 64]
    )
    # If True, only the 0e part of each layer readout is returned, as a plain array.
    scalar_output: bool = False
    zmax: int = 86
    """The maximum atomic number to consider."""
    # How edge messages are built (must be 0, 1 or 2):
    # 0: tensor product with spherical harmonics of degrees 0..lmax;
    # 1: filtered node features (identity / l=0 path) concatenated with a
    #    tensor product with spherical harmonics of degrees 1..lmax;
    # 2: e3nn.tensor_product_with_spherical_harmonics, filtered to convol_irreps.
    convolution_mode: int = 1

    FID: ClassVar[str] = "MACE"

    @nn.compact
    def __call__(self, inputs):
        """Run the MACE interaction layers and return the extended input dictionary."""
        if not E3NN_AVAILABLE:
            # Re-raise the error captured when importing e3nn_jax failed.
            raise E3NN_EXCEPTION

        species_indices = inputs["species"]
        graph = inputs[self.graph_key]
        distances = graph["distances"]
        vec = e3nn.IrrepsArray("1o", graph["vec"])
        switch = graph["switch"]
        edge_src = graph["edge_src"]
        edge_dst = graph["edge_dst"]

        output_irreps = e3nn.Irreps(self.output_irreps)
        hidden_irreps = e3nn.Irreps(self.hidden_irreps)
        readout_mlp_irreps = e3nn.Irreps(self.readout_mlp_irreps)

        # extract or set num_features
        # (hidden_irreps becomes "per feature"; it is multiplied back by
        # num_features wherever full irreps are needed)
        if self.num_features is None:
            num_features = functools.reduce(math.gcd, (mul for mul, _ in hidden_irreps))
            hidden_irreps = e3nn.Irreps(
                [(mul // num_features, ir) for mul, ir in hidden_irreps]
            )
        else:
            num_features = self.num_features

        # get interaction irreps
        if self.interaction_irreps == "o3_restricted":
            interaction_irreps = e3nn.Irreps.spherical_harmonics(self.lmax)
        elif self.interaction_irreps == "o3_full":
            interaction_irreps = e3nn.Irreps(e3nn.Irrep.iterator(self.lmax))
        else:
            interaction_irreps = e3nn.Irreps(self.interaction_irreps)
        convol_irreps = num_features * interaction_irreps

        # convert species to internal indices
        # (legacy species_order mapping kept for reference; species are now
        # used directly as indices into tables of size zmax + 2)
        # maxidx = max(PERIODIC_TABLE_REV_IDX.values())
        # conv_tensor = [0] * (maxidx + 2)
        # if isinstance(self.species_order, str):
        #     species_order = [el.strip() for el in self.species_order.split(",")]
        # else:
        #     species_order = [el for el in self.species_order]
        # for i, s in enumerate(species_order):
        #     conv_tensor[PERIODIC_TABLE_REV_IDX[s]] = i
        # species_indices = jnp.asarray(conv_tensor, dtype=jnp.int32)[species]
        num_species = self.zmax + 2

        # species encoding
        # Learned per-species embedding covering the 0e part of the hidden
        # irreps; it provides the initial (scalar-only) node features.
        encoding_irreps: e3nn.Irreps = (
            (num_features * hidden_irreps).filter("0e").regroup()
        )
        species_encoding = self.param(
            "species_encoding",
            lambda key, shape: jax.nn.standardize(
                jax.random.normal(key, shape, dtype=jnp.float32)
            ),
            (num_species, encoding_irreps.dim),
        )[species_indices]
        # convert to IrrepsArray
        node_feats = e3nn.IrrepsArray(encoding_irreps, species_encoding)

        # radial embedding
        # Radial basis of the edge distances up to the graph cutoff, scaled by
        # the per-edge switch values.
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        radial_embedding = (
            RadialBasis(
                **{
                    **self.radial_basis,
                    "end": cutoff,
                    "name": f"RadialBasis",
                }
            )(distances)
            * switch[:, None]
        )

        # spherical harmonics
        # Mode 2 does not precompute Yij: it calls
        # e3nn.tensor_product_with_spherical_harmonics inside the layer loop.
        assert self.convolution_mode in [0,1,2], "convolution_mode must be 0, 1 or 2"
        if self.convolution_mode == 0:
            Yij = e3nn.spherical_harmonics(range(0, self.lmax + 1), vec, True)
        elif self.convolution_mode == 1:
            # l = 0 is omitted here; that path is the filtered identity term
            # concatenated to the messages below.
            Yij = e3nn.spherical_harmonics(range(1, self.lmax + 1), vec, True)

        outputs = []
        node_feats_all = []
        for layer in range(self.ninteractions):
            first = layer == 0
            last = layer == self.ninteractions - 1

            # The last layer only keeps the hidden irreps that appear in output_irreps.
            layer_irreps = num_features * (
                hidden_irreps if not last else hidden_irreps.filter(output_irreps)
            )

            # Linear skip connection
            # Species-indexed linear map of the incoming node features; added
            # back after the product basis block.
            sc = None
            if not first or self.skip_connection_first_layer:
                sc = e3nn.flax.Linear(
                    layer_irreps,
                    num_indexed_weights=num_species,
                    name=f"skip_tp_{layer}",
                    force_irreps_out=True,
                )(species_indices, node_feats)

            ################################################
            # Interaction block (Message passing convolution)
            node_feats = e3nn.flax.Linear(node_feats.irreps, name=f"linear_up_{layer}")(
                node_feats
            )

            # Gather sender features onto edges and couple them with the edge geometry.
            messages = node_feats[edge_src]
            if self.convolution_mode == 0:
                messages = e3nn.tensor_product(
                    messages,
                    Yij,
                    filter_ir_out=convol_irreps,
                    regroup_output=True,
                )
            elif self.convolution_mode == 1:
                messages = e3nn.concatenate(
                    [
                        messages.filter(convol_irreps),
                        e3nn.tensor_product(
                            messages,
                            Yij,
                            filter_ir_out=convol_irreps,
                        ),
                        # e3nn.tensor_product_with_spherical_harmonics(
                        #     messages, vectors, self.max_ell
                        # ).filter(convol_irreps),
                    ]
                ).regroup()
            else:
                messages = e3nn.tensor_product_with_spherical_harmonics(
                    messages, vec, self.lmax
                ).filter(convol_irreps).regroup()

            # mix = FullyConnectedNet(
            #     [*self.radial_network_hidden, messages.irreps.num_irreps],
            #     activation=activation_from_str(self.activation),
            #     name=f"radial_network_{layer}",
            #     use_bias=False,
            # )(radial_embedding)
            # Radial MLP: one scalar weight per irrep of the messages, per edge.
            mix = e3nn.flax.MultiLayerPerceptron(
                [*self.radial_network_hidden, messages.irreps.num_irreps],
                act=activation_from_str(self.activation),
                output_activation=False,
                name=f"radial_network_{layer}",
                gradient_normalization="element",
            )(
                radial_embedding
            )

            messages = messages * mix
            # Sum edge messages onto their receiving node (edge_dst).
            node_feats = (
                e3nn.IrrepsArray.zeros(
                    messages.irreps, node_feats.shape[:1], messages.dtype
                )
                .at[edge_dst]
                .add(messages)
            )
            # print("irreps_mid jax",node_feats.irreps)
            # jax.debug.print("node_feats={n}", n=jnp.sum(node_feats.array,axis=0)[550:570])

            # Project back to convol_irreps and normalize the neighbor sum by
            # the expected average number of neighbors.
            node_feats = (
                e3nn.flax.Linear(convol_irreps, name=f"linear_dn_{layer}")(node_feats)
                / self.avg_num_neighbors
            )

            # Without a first-layer residual, the species-indexed linear is
            # applied to the aggregated messages instead (same parameter name).
            if first and not self.skip_connection_first_layer:
                node_feats = e3nn.flax.Linear(
                    node_feats.irreps,
                    num_indexed_weights=num_species,
                    name=f"skip_tp_{layer}",
                )(species_indices, node_feats)

            ################################################
            # Equivariant product basis block

            # symmetric contractions
            # Species-dependent polynomial (up to `correlation` order) of the
            # node features, restricted to the irreps of layer_irreps.
            node_feats = SymmetricContraction(
                keep_irrep_out={ir for _, ir in layer_irreps},
                correlation=self.correlation,
                num_species=num_species,
                gradient_normalization="element",  # NOTE: This is to copy mace-torch
                symmetric_tensor_product_basis=self.symmetric_tensor_product_basis,
            )(
                node_feats, species_indices
            )

            node_feats = e3nn.flax.Linear(
                layer_irreps, name=f"linear_contraction_{layer}"
            )(node_feats)

            if sc is not None:
                # add skip connection
                node_feats = node_feats + sc

            ################################################

            # Readout block
            if last:
                # Gated nonlinear readout on the last layer. One extra 0e
                # channel is added per non-scalar irrep of readout_mlp_irreps,
                # presumably consumed by e3nn.gate as gates (confirm in e3nn docs).
                num_vectors = readout_mlp_irreps.filter(drop=["0e", "0o"]).num_irreps
                layer_out = e3nn.flax.Linear(
                    (readout_mlp_irreps + e3nn.Irreps(f"{num_vectors}x0e")).simplify(),
                    name=f"hidden_linear_readout_last",
                )(node_feats)
                layer_out = e3nn.gate(
                    layer_out,
                    even_act=activation_from_str(self.activation),
                    even_gate_act=None,
                )
                layer_out = e3nn.flax.Linear(
                    output_irreps, name=f"linear_readout_last"
                )(layer_out)
            else:
                # Intermediate layers use a single linear readout.
                layer_out = e3nn.flax.Linear(
                    output_irreps,
                    name=f"linear_readout_{layer}",
                )(node_feats)

            if self.scalar_output:
                layer_out = layer_out.filter("0e").array

            outputs.append(layer_out)
            node_feats_all.append(node_feats.filter("0e").array)

        # Stack the per-layer readouts along a new axis 1.
        if self.scalar_output:
            output = jnp.stack(outputs, axis=1)
        else:
            output = e3nn.stack(outputs, axis=1)

        node_feats_all = jnp.concatenate(node_feats_all, axis=-1)

        output_key = self.output_key if self.output_key is not None else self.name
        return {
            **inputs,
            output_key: output,
            output_key + "_node_feats": node_feats_all,
        }
332
333
class SymmetricContraction(nn.Module):
    """Species-dependent symmetric contraction of equivariant node features.

    For every output irrep kept, each feature channel ``x`` is mapped to the
    polynomial ``((w_c x + w_{c-1}) x + ... + w_1) x`` (c = ``correlation``).
    Each ``w_k`` is contracted through the reduced tensor-product basis ``U`` of
    ``k`` copies of the input irreps. The weights are selected by species
    ``index`` and are independent for each feature channel.

    Adapted from the MACE-jax implementation.
    """

    # Maximum number of copies of the input combined in a tensor product.
    correlation: int
    # Output irreps to keep: a collection of Irrep, or an irreps string whose
    # multiplicities must all be 1.
    keep_irrep_out: Set[Irrep]
    # Number of rows of the species-indexed weight tables.
    num_species: int
    # "element" (0.0), "path" (1.0), a float exponent, or None to use
    # e3nn.config("gradient_normalization").
    gradient_normalization: Union[str, float]
    # If True use e3nn.reduced_symmetric_tensor_product_basis, otherwise
    # e3nn.reduced_tensor_product_basis.
    symmetric_tensor_product_basis: bool

    @nn.compact
    def __call__(self, input, index):
        """Apply the contraction.

        Args:
            input: e3nn.IrrepsArray of node features. Its multiplicity is moved
                to a separate feature axis with ``mul_to_axis`` before use.
            index: integer array of species indices, broadcast against the
                leading (batch) dimensions of ``input``.

        Returns:
            e3nn.IrrepsArray whose irreps are the produced output irreps
            (sorted), with the feature axis folded back into the multiplicity.

        Raises:
            The exception captured at import time if e3nn_jax is unavailable.
        """
        if not E3NN_AVAILABLE:
            raise E3NN_EXCEPTION

        # Resolve the normalization to a float exponent g: weights are drawn
        # with stddev (mul**-0.5)**(1 - g) and multiplied by (mul**-0.5)**g,
        # so "element" (g=0) scales at init and "path" (g=1) scales in the
        # forward pass.
        if self.gradient_normalization is None:
            gradient_normalization = e3nn.config("gradient_normalization")
        else:
            gradient_normalization = self.gradient_normalization
        if isinstance(gradient_normalization, str):
            gradient_normalization = {"element": 0.0, "path": 1.0}[
                gradient_normalization
            ]

        keep_irrep_out = self.keep_irrep_out
        if isinstance(keep_irrep_out, str):
            keep_irrep_out = e3nn.Irreps(keep_irrep_out)
            assert all(mul == 1 for mul, _ in keep_irrep_out)

        keep_irrep_out = {e3nn.Irrep(ir) for ir in keep_irrep_out}

        # After this, input is [..., num_features, irreps_x.dim].
        input = input.mul_to_axis().remove_nones()

        ### PREPARE WEIGHTS
        # For each order k = 1..correlation: the reduced basis U_k of k copies of
        # the input irreps (restricted to keep_irrep_out), and one weight table
        # per output irrep of U_k, named f"w{k}_{ir_out}" and shaped
        # [num_species, mul, num_features].
        ws = []
        Us = []
        for order in range(1, self.correlation + 1):  # 1, ..., correlation
            if self.symmetric_tensor_product_basis:
                U = e3nn.reduced_symmetric_tensor_product_basis(
                    input.irreps, order, keep_ir=keep_irrep_out
                )
            else:
                U = e3nn.reduced_tensor_product_basis(
                    [input.irreps] * order, keep_ir=keep_irrep_out
                )
            # U = U / order  # normalization TODO(mario): put back after testing
            # NOTE(mario): The normalization constants (/order and /mul**0.5)
            # has been numerically checked to be correct.

            # TODO(mario) implement norm_p
            Us.append(U)

            wsorder = []
            for mul, ir_out in U.irreps:
                w = self.param(
                    f"w{order}_{ir_out}",
                    nn.initializers.normal(
                        stddev=(mul**-0.5) ** (1.0 - gradient_normalization)
                    ),
                    (self.num_species, mul, input.shape[-2]),
                )
                w = w * (mul**-0.5) ** gradient_normalization
                wsorder.append(w)
            ws.append(wsorder)

        def fn(input: e3nn.IrrepsArray, index: jnp.ndarray):
            # - This operation is parallel on the feature dimension (but each feature has its own parameters)
            # This operation is an efficient implementation of
            # vmap(lambda w, x: FunctionalLinear(irreps_out)(w, concatenate([x, tensor_product(x, x), tensor_product(x, x, x), ...])))(w, x)
            # up to x power self.correlation
            assert input.ndim == 2  # [num_features, irreps_x.dim]
            assert index.ndim == 0  # int

            out = dict()
            x_ = input.array

            for order in range(self.correlation, 0, -1):  # correlation, ..., 1

                U = Us[order - 1]

                # ((w3 x + w2) x + w1) x
                #  \-----------/
                #       out

                for ii, ((mul, ir_out), u) in enumerate(zip(U.irreps, U.list)):
                    u = u.astype(x_.dtype)
                    # u: ndarray [(irreps_x.dim)^order, multiplicity, ir_out.dim]

                    w = ws[order - 1][ii][index]
                    if ir_out not in out:
                        # First time this irrep appears: contract with x right
                        # away and tag the entry so the loop below skips it.
                        out[ir_out] = (
                            "special",
                            jnp.einsum("...jki,kc,cj->c...i", u, w, x_),
                        )  # [num_features, (irreps_x.dim)^(order-1), ir_out.dim]
                    else:
                        out[ir_out] += jnp.einsum(
                            "...ki,kc->c...i", u, w
                        )  # [num_features, (irreps_x.dim)^order, ir_out.dim]

                # ((w3 x + w2) x + w1) x
                #  \----------------/
                #         out (in the normal case)

                for ir_out in out:
                    if isinstance(out[ir_out], tuple):
                        out[ir_out] = out[ir_out][1]
                        continue  # already done (special case optimization above)

                    out[ir_out] = jnp.einsum(
                        "c...ji,cj->c...i", out[ir_out], x_
                    )  # [num_features, (irreps_x.dim)^(order-1), ir_out.dim]

                # ((w3 x + w2) x + w1) x
                #  \-------------------/
                #           out

            # out[irrep_out] : [num_features, ir_out.dim]
            irreps_out = e3nn.Irreps(sorted(out.keys()))
            return e3nn.IrrepsArray.from_list(
                irreps_out,
                [out[ir][:, None, :] for (_, ir) in irreps_out],
                (input.shape[0],),
            )

        # Treat batch indices using vmap
        shape = jnp.broadcast_shapes(input.shape[:-2], index.shape)
        input = input.broadcast_to(shape + input.shape[-2:])
        index = jnp.broadcast_to(index, shape)

        # One vmap per leading batch dimension, so fn sees a single node.
        fn_mapped = fn
        for _ in range(input.ndim - 2):
            fn_mapped = jax.vmap(fn_mapped)

        return fn_mapped(input, index).axis_to_mul()
469
470
471# class SymmetricContraction(nn.Module):
472
473#     correlation: int
474#     keep_irrep_out: Set[Irrep]
475#     num_species: int
476#     gradient_normalization: Union[str, float]
477#     symmetric_tensor_product_basis: bool
478
479#     @nn.compact
480#     def __call__(self, input: IrrepsArray, index: jnp.ndarray):
481#         if not E3NN_AVAILABLE:
482#             raise E3NN_EXCEPTION
483
484#         if self.gradient_normalization is None:
485#             gradient_normalization = e3nn.config("gradient_normalization")
486#         else:
487#             gradient_normalization = self.gradient_normalization
488#         if isinstance(gradient_normalization, str):
489#             gradient_normalization = {"element": 0.0, "path": 1.0}[
490#                 gradient_normalization
491#             ]
492
493#         keep_irrep_out = self.keep_irrep_out
494#         if isinstance(keep_irrep_out, str):
495#             keep_irrep_out = e3nn.Irreps(keep_irrep_out)
496#             assert all(mul == 1 for mul, _ in keep_irrep_out)
497
498#         keep_irrep_out = {e3nn.Irrep(ir) for ir in keep_irrep_out}
499
500#         onehot = jnp.eye(self.num_species)[index]
501
502#         ### PREPARE WEIGHTS
503#         ws = []
504#         us = []
505#         for ir_out in keep_irrep_out:
506#             usorder = []
507#             wsorder = []
508#             for order in range(1, self.correlation + 1):  # correlation, ..., 1
509#                 if self.symmetric_tensor_product_basis:
510#                     U = e3nn.reduced_symmetric_tensor_product_basis(
511#                         input.irreps, order, keep_ir=[ir_out]
512#                     )
513#                 else:
514#                     U = e3nn.reduced_tensor_product_basis(
515#                         [input.irreps] * order, keep_ir=[ir_out]
516#                     )
517#                 u = jnp.moveaxis(U.list[0].astype(input.array.dtype), -1, 0)
518#                 usorder.append(u)
519
520#                 mul, _ = U.irreps[0]
521#                 w = self.param(
522#                     f"w{order}_{ir_out}",
523#                     nn.initializers.normal(
524#                         stddev=(mul**-0.5) ** (1.0 - gradient_normalization)
525#                     ),
526#                     (self.num_species, mul, input.shape[-2]),
527#                 )
528#                 w = w * (mul**-0.5) ** gradient_normalization
529#                 wsorder.append(w)
530#             ws.append(wsorder)
531#             us.append(usorder)
532
533#         x = input.array
534
535#         outs = []
536#         for i, ir in enumerate(keep_irrep_out):
537#             w = ws[i][-1]  # [index]
538#             u = us[i][-1]
539#             out = jnp.einsum("...jk,ekc,bcj,be->bc...", u, w, x, onehot)
540
541#             for order in range(self.correlation - 1, 0, -1):
542#                 w = ws[i][order - 1]  # [index]
543#                 u = us[i][order - 1]
544
545#                 c_tensor = jnp.einsum("...k,ekc,be->bc...", u, w, onehot) + out
546#                 out = jnp.einsum("bc...j,bcj->bc...", c_tensor, x)
547
548#             outs.append(out.reshape(x.shape[0], -1))
549
550#         out = jnp.concatenate(outs, axis=-1)
551
552#         return e3nn.IrrepsArray(input.shape[1] * e3nn.Irreps(keep_irrep_out), out)
class MACE(flax.linen.module.Module):
 33class MACE(nn.Module):
 34    """MACE equivariant message passing neural network.
 35
 36    adapted from MACE-jax github repo by M. Geiger and I. Batatia
 37    
 38    T. Plé reordered some operations and changed defaults to match the recent mace-torch version 
 39    -> compatibility with pretrained torch models requires some work on the parameters:
 40        - normalization of activation functions in e3nn differ between jax and pytorch => need rescaling
 41        - multiplicity ordering and signs of U matrices in SymmetricContraction differ => need to reorder and flip signs in the weight tensors
 42        - we use a maximum Z instead of a list of species => need to adapt species-dependent parameters
 43    
 44    References:
 45        - I. Batatia et al. "MACE: Higher order equivariant message passing neural networks for fast and accurate force fields." Advances in Neural Information Processing Systems 35 (2022): 11423-11436.
 46        https://doi.org/10.48550/arXiv.2206.07697
 47        - I. Batatia et al. "The design space of e(3)-equivariant atom-centered interatomic potentials." arXiv preprint arXiv:2205.06643 (2022).
 48        https://doi.org/10.48550/arXiv.2205.06643
 49
 50    """
 51    _graphs_properties: Dict
 52    output_irreps: Union[Irreps, str] = "1x0e"
 53    """The output irreps of the model."""
 54    hidden_irreps: Union[Irreps, str] = "128x0e + 128x1o"
 55    """The hidden irreps of the model."""
 56    readout_mlp_irreps: Union[Irreps, str] = "16x0e"
 57    """The hidden irreps of the readout MLP."""
 58    graph_key: str = "graph"
 59    """The key in the input dictionary that corresponds to the molecular graph to use."""
 60    output_key: Optional[str] = None
 61    """The key of the embedding in the output dictionary."""
 62    avg_num_neighbors: float = 1.0
 63    """The expected average number of neighbors."""
 64    ninteractions: int = 2
 65    """The number of interaction layers."""
 66    num_features: Optional[int] = None
 67    """The number of features per node. default gcd of hidden_irreps multiplicities"""
 68    radial_basis: dict = dataclasses.field(
 69        default_factory=lambda: {"basis": "bessel", "dim": 8, "trainable": False}
 70    )
 71    """The dictionary of parameters for radial basis functions. See `fennol.models.misc.encodings.RadialBasis`."""
 72    lmax: int = 1
 73    """The maximum angular momentum to consider."""
 74    correlation: int = 3
 75    """The correlation order at each layer."""
 76    activation: str = "silu"
 77    """The activation function to use."""
 78    symmetric_tensor_product_basis: bool = False
 79    """Whether to use the symmetric tensor product basis."""
 80    interaction_irreps: Union[Irreps, str] = "o3_restricted"
 81    skip_connection_first_layer: bool = True
 82    radial_network_hidden: Sequence[int] = dataclasses.field(
 83        default_factory=lambda: [64, 64, 64]
 84    )
 85    scalar_output: bool = False
 86    zmax: int = 86
 87    """The maximum atomic number to consider."""
 88    convolution_mode: int = 1
 89
 90    FID: ClassVar[str] = "MACE"
 91
 92    @nn.compact
 93    def __call__(self, inputs):
 94        if not E3NN_AVAILABLE:
 95            raise E3NN_EXCEPTION
 96
 97        species_indices = inputs["species"]
 98        graph = inputs[self.graph_key]
 99        distances = graph["distances"]
100        vec = e3nn.IrrepsArray("1o", graph["vec"])
101        switch = graph["switch"]
102        edge_src = graph["edge_src"]
103        edge_dst = graph["edge_dst"]
104
105        output_irreps = e3nn.Irreps(self.output_irreps)
106        hidden_irreps = e3nn.Irreps(self.hidden_irreps)
107        readout_mlp_irreps = e3nn.Irreps(self.readout_mlp_irreps)
108
109        # extract or set num_features
110        if self.num_features is None:
111            num_features = functools.reduce(math.gcd, (mul for mul, _ in hidden_irreps))
112            hidden_irreps = e3nn.Irreps(
113                [(mul // num_features, ir) for mul, ir in hidden_irreps]
114            )
115        else:
116            num_features = self.num_features
117
118        # get interaction irreps
119        if self.interaction_irreps == "o3_restricted":
120            interaction_irreps = e3nn.Irreps.spherical_harmonics(self.lmax)
121        elif self.interaction_irreps == "o3_full":
122            interaction_irreps = e3nn.Irreps(e3nn.Irrep.iterator(self.lmax))
123        else:
124            interaction_irreps = e3nn.Irreps(self.interaction_irreps)
125        convol_irreps = num_features * interaction_irreps
126
127        # convert species to internal indices
128        # maxidx = max(PERIODIC_TABLE_REV_IDX.values())
129        # conv_tensor = [0] * (maxidx + 2)
130        # if isinstance(self.species_order, str):
131        #     species_order = [el.strip() for el in self.species_order.split(",")]
132        # else:
133        #     species_order = [el for el in self.species_order]
134        # for i, s in enumerate(species_order):
135        #     conv_tensor[PERIODIC_TABLE_REV_IDX[s]] = i
136        # species_indices = jnp.asarray(conv_tensor, dtype=jnp.int32)[species]
137        num_species = self.zmax + 2
138
139        # species encoding
140        encoding_irreps: e3nn.Irreps = (
141            (num_features * hidden_irreps).filter("0e").regroup()
142        )
143        species_encoding = self.param(
144            "species_encoding",
145            lambda key, shape: jax.nn.standardize(
146                jax.random.normal(key, shape, dtype=jnp.float32)
147            ),
148            (num_species, encoding_irreps.dim),
149        )[species_indices]
150        # convert to IrrepsArray
151        node_feats = e3nn.IrrepsArray(encoding_irreps, species_encoding)
152
153        # radial embedding
154        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
155        radial_embedding = (
156            RadialBasis(
157                **{
158                    **self.radial_basis,
159                    "end": cutoff,
160                    "name": f"RadialBasis",
161                }
162            )(distances)
163            * switch[:, None]
164        )
165
166        # spherical harmonics
167        assert self.convolution_mode in [0,1,2], "convolution_mode must be 0, 1 or 2"
168        if self.convolution_mode == 0:
169            Yij = e3nn.spherical_harmonics(range(0, self.lmax + 1), vec, True)
170        elif self.convolution_mode == 1:
171            Yij = e3nn.spherical_harmonics(range(1, self.lmax + 1), vec, True)
172
173        outputs = []
174        node_feats_all = []
175        for layer in range(self.ninteractions):
176            first = layer == 0
177            last = layer == self.ninteractions - 1
178
179            layer_irreps = num_features * (
180                hidden_irreps if not last else hidden_irreps.filter(output_irreps)
181            )
182
183            # Linear skip connection
184            sc = None
185            if not first or self.skip_connection_first_layer:
186                sc = e3nn.flax.Linear(
187                    layer_irreps,
188                    num_indexed_weights=num_species,
189                    name=f"skip_tp_{layer}",
190                    force_irreps_out=True,
191                )(species_indices, node_feats)
192
193            ################################################
194            # Interaction block (Message passing convolution)
195            node_feats = e3nn.flax.Linear(node_feats.irreps, name=f"linear_up_{layer}")(
196                node_feats
197            )
198
199
200            messages = node_feats[edge_src]
201            if self.convolution_mode == 0:
202                messages = e3nn.tensor_product(
203                    messages,
204                    Yij,
205                    filter_ir_out=convol_irreps,
206                    regroup_output=True,
207                )
208            elif self.convolution_mode == 1:
209                messages = e3nn.concatenate(
210                    [
211                        messages.filter(convol_irreps),
212                        e3nn.tensor_product(
213                            messages,
214                            Yij,
215                            filter_ir_out=convol_irreps,
216                        ),
217                        # e3nn.tensor_product_with_spherical_harmonics(
218                        #     messages, vectors, self.max_ell
219                        # ).filter(convol_irreps),
220                    ]
221                ).regroup()
222            else:
223                messages = e3nn.tensor_product_with_spherical_harmonics(
224                    messages, vec, self.lmax
225                ).filter(convol_irreps).regroup()
226
227            # mix = FullyConnectedNet(
228            #     [*self.radial_network_hidden, messages.irreps.num_irreps],
229            #     activation=activation_from_str(self.activation),
230            #     name=f"radial_network_{layer}",
231            #     use_bias=False,
232            # )(radial_embedding)
233            mix = e3nn.flax.MultiLayerPerceptron(
234                [*self.radial_network_hidden, messages.irreps.num_irreps],
235                act=activation_from_str(self.activation),
236                output_activation=False,
237                name=f"radial_network_{layer}",
238                gradient_normalization="element",
239            )(
240                radial_embedding
241            )
242
243            messages = messages * mix
244            node_feats = (
245                e3nn.IrrepsArray.zeros(
246                    messages.irreps, node_feats.shape[:1], messages.dtype
247                )
248                .at[edge_dst]
249                .add(messages)
250            )
251            # print("irreps_mid jax",node_feats.irreps)
252            # jax.debug.print("node_feats={n}", n=jnp.sum(node_feats.array,axis=0)[550:570])
253
254            node_feats = (
255                e3nn.flax.Linear(convol_irreps, name=f"linear_dn_{layer}")(node_feats)
256                / self.avg_num_neighbors
257            )
258
259            if first and not self.skip_connection_first_layer:
260                node_feats = e3nn.flax.Linear(
261                    node_feats.irreps,
262                    num_indexed_weights=num_species,
263                    name=f"skip_tp_{layer}",
264                )(species_indices, node_feats)
265
266            ################################################
267            # Equivariant product basis block
268
269            # symmetric contractions
270            node_feats = SymmetricContraction(
271                keep_irrep_out={ir for _, ir in layer_irreps},
272                correlation=self.correlation,
273                num_species=num_species,
274                gradient_normalization="element",  # NOTE: This is to copy mace-torch
275                symmetric_tensor_product_basis=self.symmetric_tensor_product_basis,
276            )(
277                node_feats, species_indices
278            )
279
280
281            node_feats = e3nn.flax.Linear(
282                layer_irreps, name=f"linear_contraction_{layer}"
283            )(node_feats)
284
285
286            if sc is not None:
287                # add skip connection
288                node_feats = node_feats + sc
289
290
291            ################################################
292            
293            # Readout block
294            if last:
295                num_vectors = readout_mlp_irreps.filter(drop=["0e", "0o"]).num_irreps
296                layer_out = e3nn.flax.Linear(
297                    (readout_mlp_irreps + e3nn.Irreps(f"{num_vectors}x0e")).simplify(),
298                    name=f"hidden_linear_readout_last",
299                )(node_feats)
300                layer_out = e3nn.gate(
301                    layer_out,
302                    even_act=activation_from_str(self.activation),
303                    even_gate_act=None,
304                )
305                layer_out = e3nn.flax.Linear(
306                    output_irreps, name=f"linear_readout_last"
307                )(layer_out)
308            else:
309                layer_out = e3nn.flax.Linear(
310                    output_irreps,
311                    name=f"linear_readout_{layer}",
312                )(node_feats)
313
314            if self.scalar_output:
315                layer_out = layer_out.filter("0e").array
316
317            outputs.append(layer_out)
318            node_feats_all.append(node_feats.filter("0e").array)
319
320        if self.scalar_output:
321            output = jnp.stack(outputs, axis=1)
322        else:
323            output = e3nn.stack(outputs, axis=1)
324
325        node_feats_all = jnp.concatenate(node_feats_all, axis=-1)
326
327        output_key = self.output_key if self.output_key is not None else self.name
328        return {
329            **inputs,
330            output_key: output,
331            output_key + "_node_feats": node_feats_all,
332        }

MACE equivariant message passing neural network.

adapted from MACE-jax github repo by M. Geiger and I. Batatia

T. Plé reordered some operations and changed defaults to match the recent mace-torch version. As a consequence, compatibility with pretrained torch models requires some work on the parameters:
- normalization of activation functions in e3nn differs between jax and pytorch => need rescaling
- multiplicity ordering and signs of U matrices in SymmetricContraction differ => need to reorder and flip signs in the weight tensors
- we use a maximum Z instead of a list of species => need to adapt species-dependent parameters

References:
- I. Batatia et al. "MACE: Higher order equivariant message passing neural networks for fast and accurate force fields." Advances in Neural Information Processing Systems 35 (2022): 11423-11436. https://doi.org/10.48550/arXiv.2206.07697
- I. Batatia et al. "The design space of e(3)-equivariant atom-centered interatomic potentials." arXiv preprint arXiv:2205.06643 (2022). https://doi.org/10.48550/arXiv.2205.06643

MACE( _graphs_properties: Dict, output_irreps: Union[e3nn_jax._src.irreps.Irreps, str] = '1x0e', hidden_irreps: Union[e3nn_jax._src.irreps.Irreps, str] = '128x0e + 128x1o', readout_mlp_irreps: Union[e3nn_jax._src.irreps.Irreps, str] = '16x0e', graph_key: str = 'graph', output_key: Optional[str] = None, avg_num_neighbors: float = 1.0, ninteractions: int = 2, num_features: Optional[int] = None, radial_basis: dict = <factory>, lmax: int = 1, correlation: int = 3, activation: str = 'silu', symmetric_tensor_product_basis: bool = False, interaction_irreps: Union[e3nn_jax._src.irreps.Irreps, str] = 'o3_restricted', skip_connection_first_layer: bool = True, radial_network_hidden: Sequence[int] = <factory>, scalar_output: bool = False, zmax: int = 86, convolution_mode: int = 1, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
output_irreps: Union[e3nn_jax._src.irreps.Irreps, str] = '1x0e'

The output irreps of the model.

hidden_irreps: Union[e3nn_jax._src.irreps.Irreps, str] = '128x0e + 128x1o'

The hidden irreps of the model.

readout_mlp_irreps: Union[e3nn_jax._src.irreps.Irreps, str] = '16x0e'

The hidden irreps of the readout MLP.

graph_key: str = 'graph'

The key in the input dictionary that corresponds to the molecular graph to use.

output_key: Optional[str] = None

The key of the embedding in the output dictionary.

avg_num_neighbors: float = 1.0

The expected average number of neighbors.

ninteractions: int = 2

The number of interaction layers.

num_features: Optional[int] = None

The number of features per node. Defaults to the greatest common divisor (GCD) of the hidden_irreps multiplicities.

radial_basis: dict

The dictionary of parameters for radial basis functions. See fennol.models.misc.encodings.RadialBasis.

lmax: int = 1

The maximum angular momentum to consider.

correlation: int = 3

The correlation order at each layer.

activation: str = 'silu'

The activation function to use.

symmetric_tensor_product_basis: bool = False

Whether to use the symmetric tensor product basis.

interaction_irreps: Union[e3nn_jax._src.irreps.Irreps, str] = 'o3_restricted'
skip_connection_first_layer: bool = True
radial_network_hidden: Sequence[int]
scalar_output: bool = False
zmax: int = 86

The maximum atomic number to consider.

convolution_mode: int = 1
FID: ClassVar[str] = 'MACE'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class SymmetricContraction(flax.linen.module.Module):
class SymmetricContraction(nn.Module):
    """Species-dependent symmetric contraction of equivariant features.

    For every feature channel independently, the input ``x`` is contracted
    with itself up to ``correlation`` times through the reduced tensor-product
    bases ``U`` returned by ``e3nn.reduced_tensor_product_basis`` (or
    ``e3nn.reduced_symmetric_tensor_product_basis`` when
    ``symmetric_tensor_product_basis`` is True), keeping only the output
    irreps listed in ``keep_irrep_out``. The basis outputs are mixed with
    learnable weights whose first axis is selected by ``index``.

    The result is equivalent to
        vmap(lambda w, x: FunctionalLinear(irreps_out)(
            w, concatenate([x, tensor_product(x, x), tensor_product(x, x, x), ...])
        ))(w, x)
    but is evaluated in the nested form ``((w3 x + w2) x + w1) x``.

    Attributes:
        correlation: Maximum number of copies of the input in a product.
        keep_irrep_out: Output irreps to keep: a set of ``Irrep``, or an irreps
            string in which every multiplicity is 1.
        num_species: Size of the first axis of every weight tensor; ``index``
            selects one entry along it.
        gradient_normalization: "element" (= 0.0), "path" (= 1.0), a float
            exponent, or None to use ``e3nn.config("gradient_normalization")``.
            It controls how the 1/sqrt(mul) scaling is split between the
            weight initializer and the forward pass.
        symmetric_tensor_product_basis: Use the symmetric reduced basis
            instead of the general reduced tensor-product basis.
    """

    correlation: int
    keep_irrep_out: Union[Set[Irrep], str]
    num_species: int
    gradient_normalization: Optional[Union[str, float]]
    symmetric_tensor_product_basis: bool

    @nn.compact
    def __call__(self, input, index):
        """Apply the contraction.

        Args:
            input: ``e3nn.IrrepsArray``; leading dimensions are treated as batch
                dimensions. ``mul_to_axis`` is applied to it, so its irreps are
                expected to share a single multiplicity (the number of feature
                channels).
            index: Integer array of weight-row (species) indices, broadcastable
                against the batch dimensions of ``input``.

        Returns:
            ``e3nn.IrrepsArray`` whose irreps are the kept output irreps in sorted
            order, each with multiplicity equal to the number of feature channels.

        Raises:
            The import error recorded at module load if e3nn_jax is unavailable.
            ValueError: if ``keep_irrep_out`` is a string containing an irrep
                with multiplicity other than 1.
            KeyError: if ``gradient_normalization`` is an unknown string.
        """
        if not E3NN_AVAILABLE:
            raise E3NN_EXCEPTION

        gradient_normalization = self.gradient_normalization
        if gradient_normalization is None:
            gradient_normalization = e3nn.config("gradient_normalization")
        if isinstance(gradient_normalization, str):
            # 0.0: all 1/sqrt(mul) scaling in the initializer ("element");
            # 1.0: all of it applied in the forward pass ("path").
            gradient_normalization = {"element": 0.0, "path": 1.0}[
                gradient_normalization
            ]

        keep_irrep_out = self.keep_irrep_out
        if isinstance(keep_irrep_out, str):
            keep_irrep_out = e3nn.Irreps(keep_irrep_out)
            # Explicit check instead of `assert`, which is stripped under -O.
            if any(mul != 1 for mul, _ in keep_irrep_out):
                raise ValueError(
                    "keep_irrep_out must have multiplicity 1 for every irrep, "
                    f"got {keep_irrep_out}"
                )

        keep_irrep_out = {e3nn.Irrep(ir) for ir in keep_irrep_out}

        # Move the channel multiplicity to its own axis:
        # [..., num_features, irreps_x.dim]
        input = input.mul_to_axis().remove_nones()

        ### PREPARE WEIGHTS
        # Us[order - 1]: reduced basis for `order` copies of the input.
        # ws[order - 1][k]: weights for the k-th (mul, ir_out) entry of that
        #   basis, shape [num_species, mul, num_features].
        ws = []
        Us = []
        for order in range(1, self.correlation + 1):  # 1, ..., correlation
            if self.symmetric_tensor_product_basis:
                U = e3nn.reduced_symmetric_tensor_product_basis(
                    input.irreps, order, keep_ir=keep_irrep_out
                )
            else:
                U = e3nn.reduced_tensor_product_basis(
                    [input.irreps] * order, keep_ir=keep_irrep_out
                )
            # U = U / order  # normalization TODO(mario): put back after testing
            # NOTE(mario): The normalization constants (/order and /mul**0.5)
            # have been numerically checked to be correct.

            # TODO(mario) implement norm_p
            Us.append(U)

            wsorder = []
            for mul, ir_out in U.irreps:
                w = self.param(
                    f"w{order}_{ir_out}",
                    nn.initializers.normal(
                        stddev=(mul**-0.5) ** (1.0 - gradient_normalization)
                    ),
                    (self.num_species, mul, input.shape[-2]),
                )
                w = w * (mul**-0.5) ** gradient_normalization
                wsorder.append(w)
            ws.append(wsorder)

        def fn(input: e3nn.IrrepsArray, index: jnp.ndarray):
            # Unbatched contraction for a single entry. It is parallel over the
            # feature dimension, but each feature has its own parameters.
            assert input.ndim == 2  # [num_features, irreps_x.dim]
            assert index.ndim == 0  # int

            out = dict()
            x_ = input.array

            for order in range(self.correlation, 0, -1):  # correlation, ..., 1

                U = Us[order - 1]

                # ((w3 x + w2) x + w1) x
                #  \-----------/
                #       out

                for (_, ir_out), u, w_all in zip(U.irreps, U.list, ws[order - 1]):
                    u = u.astype(x_.dtype)
                    # u: ndarray [(irreps_x.dim)^order, multiplicity, ir_out.dim]

                    w = w_all[index]  # [multiplicity, num_features]
                    if ir_out not in out:
                        # First occurrence of this irrep: contract with x
                        # immediately; the tuple tag tells the loop below to
                        # skip the extra contraction.
                        out[ir_out] = (
                            "special",
                            jnp.einsum("...jki,kc,cj->c...i", u, w, x_),
                        )  # [num_features, (irreps_x.dim)^(order-1), ir_out.dim]
                    else:
                        out[ir_out] += jnp.einsum(
                            "...ki,kc->c...i", u, w
                        )  # [num_features, (irreps_x.dim)^order, ir_out.dim]

                # ((w3 x + w2) x + w1) x
                #  \----------------/
                #         out (in the normal case)

                for ir_out in out:
                    if isinstance(out[ir_out], tuple):
                        out[ir_out] = out[ir_out][1]
                        continue  # already done (special case optimization above)

                    out[ir_out] = jnp.einsum(
                        "c...ji,cj->c...i", out[ir_out], x_
                    )  # [num_features, (irreps_x.dim)^(order-1), ir_out.dim]

                # ((w3 x + w2) x + w1) x
                #  \-------------------/
                #           out

            # out[irrep_out] : [num_features, ir_out.dim]
            irreps_out = e3nn.Irreps(sorted(out.keys()))
            return e3nn.IrrepsArray.from_list(
                irreps_out,
                [out[ir][:, None, :] for (_, ir) in irreps_out],
                (input.shape[0],),
            )

        # Treat batch indices using vmap: broadcast input and index to a common
        # batch shape, then map fn over every batch dimension.
        shape = jnp.broadcast_shapes(input.shape[:-2], index.shape)
        input = input.broadcast_to(shape + input.shape[-2:])
        index = jnp.broadcast_to(index, shape)

        fn_mapped = fn
        for _ in range(input.ndim - 2):
            fn_mapped = jax.vmap(fn_mapped)

        # Fold the feature axis back into the irreps multiplicity.
        return fn_mapped(input, index).axis_to_mul()
SymmetricContraction( correlation: int, keep_irrep_out: Set[e3nn_jax._src.irreps.Irrep], num_species: int, gradient_normalization: Union[str, float], symmetric_tensor_product_basis: bool, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
correlation: int
keep_irrep_out: Set[e3nn_jax._src.irreps.Irrep]
num_species: int
gradient_normalization: Union[str, float]
symmetric_tensor_product_basis: bool
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None