fennol.models.embeddings.crate

  1import jax
  2import jax.numpy as jnp
  3import flax.linen as nn
  4import dataclasses
  5import numpy as np
  6from typing import Dict, Union, Callable, Sequence, Optional, Tuple, ClassVar
  7
  8from ..misc.encodings import SpeciesEncoding, RadialBasis, positional_encoding
  9from ...utils.spherical_harmonics import generate_spherical_harmonics, CG_SO3
 10from ...utils.activations import activation_from_str
 11from ...utils.initializers import initializer_from_str
 12from ..misc.nets import FullyConnectedNet, BlockIndexNet
 13from ..misc.e3 import ChannelMixing, ChannelMixingE3, FilteredTensorProduct
 14from ...utils.periodic_table import D3_COV_RADII, D3_VDW_RADII, VALENCE_ELECTRONS
 15from ...utils import AtomicUnits as au
 16
 17class CRATEmbedding(nn.Module):
 18    """Configurable Resources ATomic Environment
 19
 20    FID : CRATE
 21
 22    This class represents the CRATE (Configurable Resources ATomic Environment) embedding model.
 23    It is used to encode atomic environments using multiple sources of information
 24      (radial, angular, E(3), message-passing, LODE, etc...)
 25    """
 26
 27    _graphs_properties: Dict
 28
 29    dim: int = 256
 30    """The size of the embedding vectors."""
 31    nlayers: int = 2
 32    """The number of interaction layers in the model."""
 33    keep_all_layers: bool = False
 34    """Whether to output all layers."""
 35
 36    dim_src: int = 64
 37    """The size of the source embedding vectors."""
 38    dim_dst: int = 32
 39    """The size of the destination embedding vectors."""
 40
 41    angle_style: str = "fourier"
 42    """The style of angle representation."""
 43    dim_angle: int = 8
 44    """The size of the pairwise vectors use for triplet combinations."""
 45    nmax_angle: int = 4
 46    """The dimension of the angle representation (minus one)."""
 47    zeta: float = 14.1
 48    """The zeta parameter for the model ANI angular representation."""
 49    angle_combine_pairs: bool = True
 50    """Whether to combine angle pairs instead of average distance embedding like in ANI."""
 51
 52    message_passing: bool = True
 53    """Whether to use message passing in the model."""
 54    att_dim: int = 1
 55    """The hidden size for the attention mechanism (only used when message-passing is disabled)."""
 56
 57    lmax: int = 0
 58    """The maximum order of spherical tensors."""
 59    nchannels_l: int = 16
 60    """The number of channels for spherical tensors."""
 61    n_tp: int = 1
 62    """The number of tensor products performed at each layer."""
 63    ignore_irreps_parity: bool = False
 64    """Whether to ignore the parity of the irreps in the tensor product."""
 65    edge_tp: bool = False
 66    """Whether to perform a tensor product on edges before sending messages."""
 67    resolve_wij_l: bool = False
 68    """Equivariant message weights are l-dependent."""
 69
 70    species_init: bool = False
 71    """Whether to initialize the embedding using the species encoding."""
 72    mixing_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [])
 73    """The hidden layer sizes for the mixing network."""
 74    pair_mixing_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [])
 75    """The hidden layer sizes for the pair mixing network."""
 76    activation: Union[Callable, str] = "silu"
 77    """The activation function for the mixing network."""
 78    kernel_init: Union[str, Callable] = "lecun_normal()"
 79    """The kernel initialization function for Dense operations."""
 80    activation_mixing: Union[Callable, str] = "tssr3"
 81    """The activation function applied after mixing."""
 82    layer_normalization: bool = False
 83    """Whether to apply layer normalization after each layer."""
 84    use_bias: bool = True
 85    """Whether to use bias in the Dense operations."""
 86
 87    graph_key: str = "graph"
 88    """The key for the graph data in the inputs dictionary."""
 89    graph_angle_key: Optional[str] = None
 90    """The key for the angle graph data in the inputs dictionary."""
 91    embedding_key: Optional[str] = None
 92    """The key for the embedding data in the output dictionary."""
 93    pair_embedding_key: Optional[str] = None
 94    """The key for the pair embedding data in the output dictionary."""
 95
 96    species_encoding: Union[dict, str] = dataclasses.field(default_factory=dict)
 97    """If `str`, it is the key in the inputs dictionary that contains species encodings. Else, it is the dictionary of parameters for species encoding. See `fennol.models.misc.encodings.SpeciesEncoding`."""
 98    radial_basis: dict = dataclasses.field(default_factory=dict)
 99    """The dictionary of parameters for radial basis functions. See `fennol.models.misc.encodings.RadialBasis`."""
100    radial_basis_angle: Optional[dict] = None
101    """The dictionary of parameters for radial basis functions for angle embedding. 
102        If None, the radial basis for angles is the same as the radial basis for distances."""
103
104    graph_lode: Optional[str] = None
105    """The key for the lode graph data in the inputs dictionary."""
106    lode_channels: Union[int, Sequence[int]] = 8
107    """The number of channels for lode."""
108    lmax_lode: int = 0
109    """The maximum order of spherical tensors for lode."""
110    a_lode: float = -1.
111    """The cutoff for the lode graph. If negative, the value is trainable with starting value -a_lode."""
112    lode_resolve_l: bool = True
113    """Whether to resolve the lode channels by l."""
114    lode_multipole_interaction: bool = True
115    """Whether to interact with the multipole moments of the lode graph."""
116    lode_direct_multipoles: bool = True
117    """Whether to directly use the first local equivariants to interact with long-range equivariants. If false, local equivariants are mixed before interaction."""
118    lode_equi_full_combine: bool = False
119    lode_normalize_l: bool = False
120    lode_use_field_norm: bool = True
121    lode_rshort: Optional[float] = None
122    lode_dshort: float = 0.5
123    
124
125    charge_embedding: bool = False
126    """Whether to include charge embedding."""
127    total_charge_key: str = "total_charge"
128    """The key for the total charge data in the inputs dictionary."""
129
130    block_index_key: Optional[str] = None
131    """The key for the block index. If provided, will use a BLOCK_INDEX_NET as a mixing network."""
132
133    FID: ClassVar[str] = "CRATE"
134
135    @nn.compact
136    def __call__(self, inputs):
137        species = inputs["species"]
138        assert (
139            len(species.shape) == 1
140        ), "Species must be a 1D array (batches must be flattened)"
141        reduce_memory =  "reduce_memory" in inputs.get("flags", {})
142
143        kernel_init = (
144            initializer_from_str(self.kernel_init)
145            if isinstance(self.kernel_init, str)
146            else self.kernel_init
147        )
148
149        actmix = activation_from_str(self.activation_mixing)
150
151        ##################################################
152        graph = inputs[self.graph_key]
153        use_angles = self.graph_angle_key is not None
154        if use_angles:
155            graph_angle = inputs[self.graph_angle_key]
156
157            # Check that the graph_angle is a subgraph of graph
158            correct_graph = (
159                self.graph_angle_key == self.graph_key
160                or self._graphs_properties[self.graph_angle_key]["parent_graph"]
161                == self.graph_key
162            )
163            assert (
164                correct_graph
165            ), f"graph_angle_key={self.graph_angle_key} must be a subgraph of graph_key={self.graph_key}"
166            assert (
167                "angles" in graph_angle
168            ), f"Graph {self.graph_angle_key} must contain angles"
169            # check if graph_angle is a filtered graph
170            filtered = "parent_graph" in self._graphs_properties[self.graph_angle_key]
171            if filtered:
172                filter_indices = graph_angle["filter_indices"]
173
174        ##################################################
175        ### SPECIES ENCODING ###
176        if isinstance(self.species_encoding, str):
177            zi = inputs[self.species_encoding]
178        else:
179            zi = SpeciesEncoding(
180                **self.species_encoding, name="SpeciesEncoding"
181            )(species)
182
183        
184        if self.layer_normalization:
185            def layer_norm(x):
186                mu = jnp.mean(x,axis=-1,keepdims=True)
187                dx = x-mu
188                var = jnp.mean(dx**2,axis=-1,keepdims=True)
189                sig = (1.e-6 + var)**(-0.5)
190                return dx*sig
191        else:
192            layer_norm = lambda x:x
193                
194
195        if self.charge_embedding:
196            xi, qi = jnp.split(
197                nn.Dense(self.dim + 1, use_bias=False, name="ChargeEncoding")(zi),
198                [self.dim],
199                axis=-1,
200            )
201            batch_index = inputs["batch_index"]
202            natoms = inputs["natoms"]
203            nsys = natoms.shape[0]
204            Zi = jnp.asarray(VALENCE_ELECTRONS)[species]
205            Ntot = jax.ops.segment_sum(Zi, batch_index, nsys) - inputs.get(
206                self.total_charge_key, jnp.zeros(nsys)
207            )
208            ai = jax.nn.softplus(qi.squeeze(-1))
209            A = jax.ops.segment_sum(ai, batch_index, nsys)
210            Ni = ai * (Ntot / A)[batch_index]
211            charge_embedding = positional_encoding(Ni, self.dim)
212            xi = layer_norm(xi + charge_embedding)
213        elif self.species_init:
214            xi = layer_norm(nn.Dense(self.dim, use_bias=False, name="SpeciesInit")(zi))
215        else:
216            xi = zi
217
218        ##################################################
219        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
220        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
221        distances = graph["distances"]
222        switch = graph["switch"][:, None]
223
224        ### COMPUTE RADIAL BASIS ###
225        radial_basis = RadialBasis(
226            **{
227                **self.radial_basis,
228                "end": cutoff,
229                "name": f"RadialBasis",
230            }
231        )(distances)
232
233        do_lode = self.graph_lode is not None
234        if do_lode:
235            graph_lode = inputs[self.graph_lode]
236            switch_lode = graph_lode["switch"][:, None]
237
238            edge_src_lr, edge_dst_lr = graph_lode["edge_src"], graph_lode["edge_dst"]
239            r = graph_lode["distances"][:, None]
240            rc = self._graphs_properties[self.graph_lode]["cutoff"]
241            
242            lmax_lr = self.lmax_lode
243            equivariant_lode = lmax_lr > 0
244            assert lmax_lr >=0, f"lmax_lode must be >= 0, got {lmax_lr}"
245            if self.lode_multipole_interaction:
246                assert lmax_lr <= self.lmax, f"lmax_lode must be <= lmax for multipole interaction, got {lmax_lr} > {self.lmax}"
247            nrep_lr = np.array([2 * l + 1 for l in range(lmax_lr + 1)], dtype=np.int32)
248            if self.lode_resolve_l and equivariant_lode:
249                ls_lr = np.arange(lmax_lr + 1)
250            else:
251                ls_lr = np.array([0])
252
253            if self.a_lode > 0:
254                a = self.a_lode**2
255            else:
256                a = (
257                    self.param(
258                        "a_lr",
259                        lambda key: jnp.asarray([-self.a_lode] * ls_lr.shape[0])[None, :],
260                    )
261                    ** 2
262                )
263            rc2a = rc**2 + a
264            ls_lr = 0.5 * (ls_lr[None, :] + 1)
265            ### minimal radial basis for long range (damped coulomb)
266            eij_lr = (
267                1.0 / (r**2 + a) ** ls_lr
268                - 1.0 / rc2a**ls_lr
269                + (r - rc) * rc * (2 * ls_lr) / rc2a ** (ls_lr + 1)
270            ) * switch_lode
271
272            if self.lode_rshort is not None:
273                rs = self.lode_rshort
274                d = self.lode_dshort
275                switch_short = 0.5 * (1 - jnp.cos(jnp.pi * (r - rs) / d)) * (r > rs) * (
276                    r < rs + d
277                ) + (r > rs + d)
278                eij_lr = eij_lr * switch_short
279
280            # dim_lr = self.nchannels_lode
281            nchannels_lode = (
282                [self.lode_channels] * self.nlayers
283                if isinstance(self.lode_channels, int)
284                else self.lode_channels
285            )
286            dim_lr = nchannels_lode
287            
288            if equivariant_lode:
289                if self.lode_resolve_l:
290                    eij_lr = eij_lr.repeat(nrep_lr, axis=-1)
291                Yij = generate_spherical_harmonics(lmax=lmax_lr, normalize=False)(
292                    graph_lode["vec"] / r
293                )
294                eij_lr = (eij_lr * Yij)[:, None, :]
295                dim_lr = [d * (lmax_lr + 1) for d in dim_lr]
296            
297
298        ##################################################
299        ### GET ANGLES ###
300        if use_angles:
301            angles = graph_angle["angles"][:, None]
302            angle_src, angle_dst = graph_angle["angle_src"], graph_angle["angle_dst"]
303            switch_angles = graph_angle["switch"][:, None]
304            central_atom = graph_angle["central_atom"]
305
306            if not self.angle_combine_pairs:
307                assert (
308                    self.radial_basis_angle is not None
309                ), "radial_basis_angle must be specified if angle_combine_pairs=False"
310
311            ### COMPUTE RADIAL BASIS FOR ANGLES ###
312            if self.radial_basis_angle is not None:
313                dangles = graph_angle["distances"]
314                swang = switch_angles
315                if not self.angle_combine_pairs:
316                    dangles = 0.5 * (dangles[angle_src] + dangles[angle_dst])
317                    swang = switch_angles[angle_src] * switch_angles[angle_dst]
318                radial_basis_angle = (
319                    RadialBasis(
320                        **{
321                            **self.radial_basis_angle,
322                            "end": self._graphs_properties[self.graph_angle_key][
323                                "cutoff"
324                            ],
325                            "name": f"RadialBasisAngle",
326                        }
327                    )(dangles)
328                    * swang
329                )
330
331            else:
332                if filtered:
333                    radial_basis_angle = radial_basis[filter_indices] * switch_angles
334                else:
335                    radial_basis_angle = radial_basis * switch
336
337        radial_basis = radial_basis * switch
338
339        # # add covalent indicator
340        # rc = jnp.asarray([d/au.BOHR for d in D3_COV_RADII])[species]
341        # rcij = rc[edge_src] + rc[edge_dst]
342        # fact = graph["switch"]*(2*distances/rcij)*jnp.exp(-0.5 * ((distances - rcij)/(0.1*rcij)) ** 2)
343        # radial_basis = jnp.concatenate([radial_basis,fact[:,None]],axis=-1)
344        # if use_angles:
345        #     rcij = rc[graph_angle["edge_src"]] + rc[graph_angle["edge_dst"]]
346        #     dangles = graph_angle["distances"]
347        #     fact = graph_angle["switch"]*((2*dangles/rcij))*jnp.exp(-0.5 * ((dangles - rcij)/(0.1*rcij))**2)
348        #     radial_basis_angle = jnp.concatenate([radial_basis_angle,fact[:,None]],axis=-1)
349
350
351        ##################################################
352        if self.lmax > 0:
353            Yij = generate_spherical_harmonics(lmax=self.lmax, normalize=False)(
354                graph["vec"] / graph["distances"][:, None]
355            )[:, None, :]
356            Yij = jnp.broadcast_to(Yij, (Yij.shape[0], self.nchannels_l, Yij.shape[2]))
357            nrep_l = np.array([2 * l + 1 for l in range(self.lmax + 1)], dtype=np.int32)
358            # ls = [0]
359            # for l in range(1, self.lmax + 1):
360            #    ls = ls + [l] * (2 * l + 1)
361            #ls = jnp.asarray(np.array(ls)[None, :], dtype=distances.dtype)
362            #lcut = (0.5 + 0.5 * jnp.cos((np.pi / cutoff) * distances[:, #None])) ** (
363            #    ls + 1
364            #)
365            # lcut = jnp.where(graph["edge_mask"][:, None], lcut, 0.0)
366            # rijl1 = (lcut * distances[:, None] ** ls)[:, None, :]
367
368        ##################################################
369        if use_angles:
370            ### ANGULAR BASIS ###
371            if self.angle_style == "fourier":
372                # build fourier series for angles
373                nangles = self.param(
374                    f"nangles",
375                    lambda key, dim: jnp.arange(dim, dtype=distances.dtype)[None, :],
376                    self.nmax_angle + 1,
377                )
378
379                phi = self.param(
380                    f"phi",
381                    lambda key, dim: jnp.zeros((1, dim), dtype=distances.dtype),
382                    self.nmax_angle + 1,
383                )
384                xa = jnp.cos(nangles * angles + phi)
385            elif self.angle_style == "fourier_full":
386                # build fourier series for angles including sin terms
387                nangles = self.param(
388                    f"nangles",
389                    lambda key, dim: jnp.arange(dim, dtype=distances.dtype)[None, :],
390                    self.nmax_angle + 1,
391                )
392
393                phi = self.param(
394                    f"phi",
395                    lambda key, dim: jnp.zeros((1, dim), dtype=distances.dtype),
396                    2 * self.nmax_angle + 1,
397                )
398                xac = jnp.cos(nangles * angles + phi[:, : self.nmax_angle + 1])
399                xas = jnp.sin(nangles[:, 1:] * angles + phi[:, self.nmax_angle + 1 :])
400                xa = jnp.concatenate([xac, xas], axis=-1)
401            elif self.angle_style == "ani":
402                # ANI-style angle embedding
403                angle_start = np.pi / (2 * (self.nmax_angle + 1))
404                shiftZ = self.param(
405                    f"shiftZ",
406                    lambda key, dim: jnp.asarray(
407                        (np.linspace(0, np.pi, dim + 1) + angle_start)[None, :-1],
408                        dtype=distances.dtype,
409                    ),
410                    self.nmax_angle + 1,
411                )
412                zeta = self.param(
413                    f"zeta",
414                    lambda key: jnp.asarray(self.zeta, dtype=distances.dtype),
415                )
416                xa = (0.5 + 0.5 * jnp.cos(angles - shiftZ)) ** zeta
417            else:
418                raise ValueError(f"Unknown angle style {self.angle_style}")
419            xa = xa[:, None, :]
420            if not self.angle_combine_pairs:
421                if reduce_memory: raise NotImplementedError("Angle embedding not implemented with reduce_memory")
422                xa = (xa * radial_basis_angle[:, :, None]).reshape(
423                    -1, 1, xa.shape[1] * radial_basis_angle.shape[1]
424                )
425
426            if self.pair_embedding_key is not None:
427                if filtered:
428                    ang_pair_src = filter_indices[angle_src]
429                    ang_pair_dst = filter_indices[angle_dst]
430                else:
431                    ang_pair_src = angle_src
432                    ang_pair_dst = angle_dst
433                ang_pairs = jnp.concatenate((ang_pair_src, ang_pair_dst))
434
435        ##################################################
436        ### DIMENSIONS ###
437        dim_src = (
438            [self.dim_src] * self.nlayers
439            if isinstance(self.dim_src, int)
440            else self.dim_src
441        )
442        assert (
443            len(dim_src) == self.nlayers
444        ), f"dim_src must be an integer or a list of length {self.nlayers}"
445        dim_dst = self.dim_dst
446        # dim_dst = (
447        #     [self.dim_dst] * self.nlayers
448        #     if isinstance(self.dim_dst, int)
449        #     else self.dim_dst
450        # )
451        # assert (
452        #     len(dim_dst) == self.nlayers
453        # ), f"dim_dst must be an integer or a list of length {self.nlayers}"
454
455        if use_angles:
456            dim_angle = (
457                [self.dim_angle] * self.nlayers
458                if isinstance(self.dim_angle, int)
459                else self.dim_angle
460            )
461            assert (
462                len(dim_angle) == self.nlayers
463            ), f"dim_angle must be an integer or a list of length {self.nlayers}"
464            # nmax_angle = [self.nmax_angle]*self.nlayers if isinstance(self.nmax_angle, int) else self.nmax_angle
465            # assert len(nmax_angle) == self.nlayers, f"nmax_angle must be an integer or a list of length {self.nlayers}"
466        
467        initialize_e3 = True
468        if self.lmax > 0:
469            n_tp = (
470                [self.n_tp] * self.nlayers
471                if isinstance(self.n_tp, int)
472                else self.n_tp
473            )
474            assert (
475                len(n_tp) == self.nlayers
476            ), f"n_tp must be an integer or a list of length {self.nlayers}"
477
478
479        message_passing = (
480            [self.message_passing] * self.nlayers
481            if isinstance(self.message_passing, bool)
482            else self.message_passing
483        )
484        assert (
485            len(message_passing) == self.nlayers
486        ), f"message_passing must be a boolean or a list of length {self.nlayers}"
487
488        ##################################################
489        ### INITIALIZE PAIR EMBEDDING ###
490        if self.pair_embedding_key is not None:
491            xij_s,xij_d = jnp.split(nn.Dense(2*dim_dst, name="pair_init_linear")(zi), [dim_dst], axis=-1)
492            xij = layer_norm(xij_s[edge_src]*xij_d[edge_dst])
493
494        ##################################################
495        if self.keep_all_layers:
496            xis = []
497        
498        ### LOOP OVER LAYERS ###
499        for layer in range(self.nlayers):
500            ##################################################
501            ### COMPACT DESCRIPTORS ###
502            si, si_dst = jnp.split(
503                nn.Dense(
504                    dim_src[layer] + dim_dst,
505                    name=f"species_linear_{layer}",
506                    use_bias=self.use_bias,
507                )(xi),
508                [
509                    dim_src[layer],
510                ],
511                axis=-1,
512            )
513
514            ##################################################
515            if message_passing[layer] or layer == 0:
516                ### MESSAGE PASSING ###
517                si_mp = si_dst[edge_dst]
518            else:
519                # if layer == 0:
520                #     si_mp = si_dst[edge_dst]
521                ### ATTENTION TO SIMULATE MP ###
522                Q = nn.Dense(
523                    dim_dst * self.att_dim, name=f"queries_{layer}", use_bias=False
524                )(si_dst).reshape(-1, dim_dst, self.att_dim)[edge_src]
525                K = nn.Dense(
526                    dim_dst * self.att_dim, name=f"keys_{layer}", use_bias=False
527                )(zi).reshape(-1, dim_dst, self.att_dim)[edge_dst]
528
529                si_mp = (K * Q).sum(axis=-1) / self.att_dim**0.5
530                # Vmp = jax.ops.segment_sum(
531                #     (KQ * switch)[:, :, None] * Yij, edge_src, species.shape[0]
532                # )
533                # si_mp = (Vmp[edge_src] * Yij).sum(axis=-1)
534                # Q = nn.Dense(
535                #     dim_dst * dim_dst, name=f"queries_{layer}", use_bias=False
536                # )(si_dst).reshape(-1, dim_dst, dim_dst)
537                # si_mp = (
538                #     si_mp + jax.vmap(jnp.dot)(Q[edge_src], si_mp) / self.dim_dst**0.5
539                # )
540
541            if self.pair_embedding_key is not None:
542                si_mp = si_mp + xij
543
544            ##################################################
545            ### PAIR EMBEDDING ###
546            if reduce_memory:
547                Li = jnp.zeros((species.shape[0]* radial_basis.shape[1],si_mp.shape[1]),dtype=si_mp.dtype)
548                for i in range(radial_basis.shape[1]):
549                    indices = i + edge_src*radial_basis.shape[1]
550                    Li = Li.at[indices].add(si_mp*radial_basis[:,i,None])
551                Li = Li.reshape(species.shape[0], radial_basis.shape[1]*si_mp.shape[1])
552            else:
553                Lij = (si_mp[:, None, :] * radial_basis[:, :, None]).reshape(
554                    radial_basis.shape[0], si_mp.shape[1] * radial_basis.shape[1]
555                )
556                ### AGGREGATE PAIR EMBEDDING ###
557                Li = jax.ops.segment_sum(Lij, edge_src, species.shape[0])
558
559            ### CONCATENATE EMBEDDING COMPONENTS ###
560            components = [si, Li]
561            if self.pair_embedding_key is not None:
562                if reduce_memory: raise NotImplementedError("Pair embedding not implemented with reduce_memory")
563                components_pair = [si[edge_src], xij, Lij]
564
565
566            ##################################################
567            ### ANGLE EMBEDDING ###
568            if use_angles and dim_angle[layer]>0:
569                si_mp_ang = si_mp[filter_indices] if filtered else si_mp
570                if self.angle_combine_pairs:
571                    Wa = self.param(
572                        f"Wa_{layer}",
573                        nn.initializers.normal(
574                            stddev=1.0
575                            / (si_mp.shape[1] * radial_basis_angle.shape[1]) ** 0.5
576                        ),
577                        (si_mp.shape[1], radial_basis_angle.shape[1], dim_angle[layer]),
578                    )
579                    Da = jnp.einsum(
580                        "...i,...j,ijk->...k",
581                        si_mp_ang,
582                        radial_basis_angle,
583                        Wa,
584                    )
585
586                else:
587                    if message_passing[layer] or layer == 0:
588                        Da = nn.Dense(dim_angle[layer], name=f"angle_linear_{layer}")(
589                            xi
590                        )[graph_angle["edge_dst"]]
591                    else:
592                        Da = nn.Dense(dim_angle[layer], name=f"angle_linear_{layer}")(
593                            si_mp_ang
594                        )
595
596                Da = Da[angle_dst] * Da[angle_src]
597                ## combine pair and angle info
598                if reduce_memory:
599                    ang_embedding = jnp.zeros((species.shape[0]* Da.shape[-1],xa.shape[-1]),dtype=Da.dtype)
600                    for i in range(Da.shape[-1]):
601                        indices = i + central_atom*Da.shape[-1]
602                        ang_embedding = ang_embedding.at[indices].add(Da[:,i,None]*xa[:,0,:])
603                    ang_embedding = ang_embedding.reshape(species.shape[0], xa.shape[-1]*Da.shape[-1])
604                else:
605                    radang = (xa * Da[:, :, None]).reshape(
606                        (-1, Da.shape[1] * xa.shape[2])
607                    )
608                    ### AGGREGATE  ANGLE EMBEDDING ###
609                    ang_embedding = jax.ops.segment_sum(
610                        radang, central_atom, species.shape[0]
611                    )
612                    
613
614                components.append(ang_embedding)
615
616                if self.pair_embedding_key is not None:
617                    ang_ij = jax.ops.segment_sum(
618                        jnp.concatenate((radang, radang)),
619                        ang_pairs,
620                        edge_src.shape[0],
621                    )
622                    components_pair.append(ang_ij)
623            
624            ##################################################
625            ### EQUIVARIANT EMBEDDING ###
626            if self.lmax > 0 and n_tp[layer] >= 0:
627                if initialize_e3 or not message_passing[layer]:
628                    Vij = Yij
629                elif self.edge_tp:
630                    Vij = FilteredTensorProduct(
631                            self.lmax, self.lmax, name=f"edge_tp_{layer}",ignore_parity=self.ignore_irreps_parity,weights_by_channel=False
632                        )(Vi[edge_dst], Yij)
633                else:
634                    Vij = Vi[edge_dst]
635
636                ### compute channel weights
637                dim_wij = self.nchannels_l
638                if self.resolve_wij_l:
639                    dim_wij=self.nchannels_l*(self.lmax+1)
640
641                eij = Lij if self.pair_embedding_key is None else jnp.concatenate([Lij,xij*switch],axis=-1)
642                wij = nn.Dense(
643                        dim_wij, name=f"e3_channel_{layer}", use_bias=False
644                    )(eij)
645                if self.resolve_wij_l:
646                    wij = jnp.repeat(wij.reshape(-1,self.nchannels_l,self.lmax+1),nrep_l,axis=-1)
647                else:
648                    wij = wij[:,:,None]
649                
650                ### aggregate equivariant messages
651                drhoi = jax.ops.segment_sum(
652                    wij * Vij,
653                    edge_src,
654                    species.shape[0],
655                )
656
657                Vi0 = []
658                if initialize_e3:
659                    rhoi = drhoi
660                    Vi = ChannelMixingE3(
661                        self.lmax,
662                        self.nchannels_l,
663                        self.nchannels_l,
664                        name=f"e3_initial_mixing_{layer}",
665                    )(rhoi)
666                    # assert n_tp[layer] > 0, "n_tp must be > 0 for the first equivariant layer."
667                else:
668                    rhoi = rhoi + drhoi
669                    # if message_passing[layer]:
670                        # Vi0.append(drhoi[:, :, 0])
671                initialize_e3 = False
672                if n_tp[layer] > 0:
673                    for itp in range(n_tp[layer]):
674                        dVi = FilteredTensorProduct(
675                            self.lmax, self.lmax, name=f"tensor_product_{layer}_{itp}",ignore_parity=self.ignore_irreps_parity,weights_by_channel=False
676                        )(rhoi, Vi)
677                        Vi = ChannelMixing(
678                                self.lmax,
679                                self.nchannels_l,
680                                self.nchannels_l,
681                                name=f"tp_mixing_{layer}_{itp}",
682                            )(Vi + dVi)
683                        Vi0.append(dVi[:, :, 0])
684                    Vi0 = jnp.concatenate(Vi0, axis=-1)
685                    components.append(Vi0)
686
687                if self.pair_embedding_key is not None:
688                    Vij = Vi[edge_src]*Yij
689                    Vij0 = [Vij[...,0]]
690                    for l in range(1,self.lmax+1):
691                        Vij0.append(Vij[...,l**2:(l+1)**2].sum(axis=-1))
692                    Vij0 = jnp.concatenate(Vij0,axis=-1)
693                    components_pair.append(Vij0)
694
695            ##################################################
696            ### CONCATENATE EMBEDDING COMPONENTS ###
697            if do_lode and nchannels_lode[layer] > 0:
698                zj = nn.Dense(dim_lr[layer], use_bias=False, name=f"LODE_{layer}")(xi)
699                if equivariant_lode:
700                    zj = zj.reshape(
701                        (species.shape[0], nchannels_lode[layer], lmax_lr + 1)
702                    ).repeat(nrep_lr, axis=-1)
703                xi_lr = jax.ops.segment_sum(
704                    eij_lr * zj[edge_dst_lr], edge_src_lr, species.shape[0]
705                )
706                if equivariant_lode:
707                    assert self.lode_use_field_norm or self.lode_multipole_interaction, "equivariant LODE requires field norm or multipole interaction"
708                    if self.lode_multipole_interaction:
709                        if initialize_e3:
710                            raise ValueError("equivariant LODE used before local equivariants initialized")
711                        size_l_lr = (lmax_lr+1)**2
712                        if self.lode_direct_multipoles:
713                            assert nchannels_lode[layer] <= self.nchannels_l
714                            Mi = Vi[:, : nchannels_lode[layer], :size_l_lr]
715                        else:
716                            Mi = ChannelMixingE3(
717                                lmax_lr,
718                                self.nchannels_l,
719                                nchannels_lode[layer],
720                                name=f"e3_LODE_{layer}",
721                            )(Vi[...,:size_l_lr])
722                        Mi_lr = Mi * xi_lr
723                    components.append(xi_lr[:, :, 0])
724                    if self.lode_use_field_norm and self.lode_equi_full_combine:
725                        xi_lr1 = ChannelMixing(
726                            lmax_lr,
727                            nchannels_lode[layer],
728                            nchannels_lode[layer],
729                            name=f"LODE_mixing_{layer}",
730                        )(xi_lr)
731                    norm = 1.
732                    for l in range(1, lmax_lr + 1):
733                        if self.lode_normalize_l:
734                            norm = 1. / (2 * l + 1)
735                        if self.lode_multipole_interaction:
736                            components.append(Mi_lr[:, :, l**2 : (l + 1) ** 2].sum(axis=-1)*norm)
737
738                        if self.lode_use_field_norm:
739                            if self.lode_equi_full_combine:
740                                components.append((xi_lr[:,:,l**2 : (l + 1) ** 2]*xi_lr1[:,:,l**2 : (l + 1) ** 2]).sum(axis=-1)*norm)
741                            else:
742                                components.append(
743                                    ((xi_lr[:, :, l**2 : (l + 1) ** 2]) ** 2).sum(axis=-1)*norm
744                                )
745                else:
746                    components.append(xi_lr)
747
748            dxi = jnp.concatenate(components, axis=-1)
749
750            ##################################################
751            ### CONCATENATE PAIR EMBEDDING COMPONENTS ###
752            if self.pair_embedding_key is not None:
753                dxij = jnp.concatenate(components_pair, axis=-1)
754
755            ##################################################
756            ### MIX AND APPLY NONLINEARITY ###
757            if self.block_index_key is not None:
758                block_index = inputs[self.block_index_key]
759                dxi = actmix(BlockIndexNet(
760                        output_dim=self.dim,
761                        hidden_neurons=self.mixing_hidden,
762                        activation=self.activation,
763                        name=f"dxi_{layer}",
764                        use_bias=self.use_bias,
765                        kernel_init=kernel_init,
766                    )((species,dxi, block_index))
767                )
768            else:
769                dxi = actmix(
770                    FullyConnectedNet(
771                        [*self.mixing_hidden, self.dim],
772                        activation=self.activation,
773                        name=f"dxi_{layer}",
774                        use_bias=self.use_bias,
775                        kernel_init=kernel_init,
776                    )(dxi)
777                )
778
779            if self.pair_embedding_key is not None:
780                ### UPDATE PAIR EMBEDDING ###
781                # dxij = tssr3(nn.Dense(dim_dst, name=f"dxij_{layer}",use_bias=False)(dxij))
782                dxij = actmix(
783                    FullyConnectedNet(
784                        [*self.pair_mixing_hidden, dim_dst],
785                        activation=self.activation,
786                        name=f"dxij_{layer}",
787                        use_bias=False,
788                        kernel_init=kernel_init,
789                    )(dxij)
790                )
791                xij = layer_norm(xij + dxij)
792
793            ##################################################
794            ### UPDATE EMBEDDING ###
795            if layer == 0 and not (self.species_init or self.charge_embedding):
796                xi = layer_norm(dxi)
797            else:
798                ### FORGET GATE ###
799                R = jax.nn.sigmoid(
800                    self.param(
801                        f"retention_{layer}",
802                        nn.initializers.normal(),
803                        (xi.shape[-1],),
804                    )
805                )
806                xi = layer_norm(R[None, :] * xi + dxi)
807
808            if self.keep_all_layers:
809                xis.append(xi)
810
811        embedding_key = (
812            self.embedding_key if self.embedding_key is not None else self.name
813        )
814        output = {
815            **inputs,
816            embedding_key: xi,
817        }
818        if self.lmax > 0:
819            output[embedding_key + "_tensor"] = Vi
820        if self.keep_all_layers:
821            output[embedding_key + "_layers"] = jnp.stack(xis, axis=1)
822        if self.charge_embedding:
823            output[embedding_key + "_charge"] = charge_embedding
824        if self.pair_embedding_key is not None:
825            output[self.pair_embedding_key] = xij
826        return output
class CRATEmbedding(flax.linen.module.Module):
 18class CRATEmbedding(nn.Module):
 19    """Configurable Resources ATomic Environment
 20
 21    FID : CRATE
 22
 23    This class represents the CRATE (Configurable Resources ATomic Environment) embedding model.
 24    It is used to encode atomic environments using multiple sources of information
 25      (radial, angular, E(3), message-passing, LODE, etc...)
 26    """
 27
 28    _graphs_properties: Dict
 29
 30    dim: int = 256
 31    """The size of the embedding vectors."""
 32    nlayers: int = 2
 33    """The number of interaction layers in the model."""
 34    keep_all_layers: bool = False
 35    """Whether to output all layers."""
 36
 37    dim_src: int = 64
 38    """The size of the source embedding vectors."""
 39    dim_dst: int = 32
 40    """The size of the destination embedding vectors."""
 41
 42    angle_style: str = "fourier"
 43    """The style of angle representation."""
 44    dim_angle: int = 8
 45    """The size of the pairwise vectors used for triplet combinations."""
 46    nmax_angle: int = 4
 47    """The dimension of the angle representation (minus one)."""
 48    zeta: float = 14.1
 49    """The zeta parameter for the ANI-style angular representation."""
 50    angle_combine_pairs: bool = True
 51    """Whether to combine the pair embeddings of both angle edges instead of using a radial embedding of their average distance (as in ANI)."""
 52
 53    message_passing: bool = True
 54    """Whether to use message passing in the model."""
 55    att_dim: int = 1
 56    """The hidden size for the attention mechanism (only used when message-passing is disabled)."""
 57
 58    lmax: int = 0
 59    """The maximum order of spherical tensors."""
 60    nchannels_l: int = 16
 61    """The number of channels for spherical tensors."""
 62    n_tp: int = 1
 63    """The number of tensor products performed at each layer."""
 64    ignore_irreps_parity: bool = False
 65    """Whether to ignore the parity of the irreps in the tensor product."""
 66    edge_tp: bool = False
 67    """Whether to perform a tensor product on edges before sending messages."""
 68    resolve_wij_l: bool = False
 69    """Whether equivariant message weights are l-dependent."""
 70
 71    species_init: bool = False
 72    """Whether to initialize the embedding using the species encoding."""
 73    mixing_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [])
 74    """The hidden layer sizes for the mixing network."""
 75    pair_mixing_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [])
 76    """The hidden layer sizes for the pair mixing network."""
 77    activation: Union[Callable, str] = "silu"
 78    """The activation function for the mixing network."""
 79    kernel_init: Union[str, Callable] = "lecun_normal()"
 80    """The kernel initialization function for Dense operations."""
 81    activation_mixing: Union[Callable, str] = "tssr3"
 82    """The activation function applied after mixing."""
 83    layer_normalization: bool = False
 84    """Whether to apply layer normalization after each layer."""
 85    use_bias: bool = True
 86    """Whether to use bias in the Dense operations."""
 87
 88    graph_key: str = "graph"
 89    """The key for the graph data in the inputs dictionary."""
 90    graph_angle_key: Optional[str] = None
 91    """The key for the angle graph data in the inputs dictionary."""
 92    embedding_key: Optional[str] = None
 93    """The key for the embedding data in the output dictionary."""
 94    pair_embedding_key: Optional[str] = None
 95    """The key for the pair embedding data in the output dictionary."""
 96
 97    species_encoding: Union[dict, str] = dataclasses.field(default_factory=dict)
 98    """If `str`, it is the key in the inputs dictionary that contains species encodings. Else, it is the dictionary of parameters for species encoding. See `fennol.models.misc.encodings.SpeciesEncoding`."""
 99    radial_basis: dict = dataclasses.field(default_factory=dict)
100    """The dictionary of parameters for radial basis functions. See `fennol.models.misc.encodings.RadialBasis`."""
101    radial_basis_angle: Optional[dict] = None
102    """The dictionary of parameters for radial basis functions for angle embedding. 
103        If None, the radial basis for angles is the same as the radial basis for distances."""
104
105    graph_lode: Optional[str] = None
106    """The key for the lode graph data in the inputs dictionary."""
107    lode_channels: Union[int, Sequence[int]] = 8
108    """The number of channels for lode."""
109    lmax_lode: int = 0
110    """The maximum order of spherical tensors for lode."""
111    a_lode: float = -1.
112    """The damping length of the LODE long-range kernel (a = a_lode**2 is added to r**2). If not positive, the value is trainable with starting value -a_lode."""
113    lode_resolve_l: bool = True
114    """Whether to resolve the lode channels by l."""
115    lode_multipole_interaction: bool = True
116    """Whether to interact with the multipole moments of the lode graph."""
117    lode_direct_multipoles: bool = True
118    """Whether to directly use the first local equivariants to interact with long-range equivariants. If false, local equivariants are mixed before interaction."""
119    lode_equi_full_combine: bool = False
120    lode_normalize_l: bool = False
121    lode_use_field_norm: bool = True
122    lode_rshort: Optional[float] = None
123    lode_dshort: float = 0.5
124    
125
126    charge_embedding: bool = False
127    """Whether to include charge embedding."""
128    total_charge_key: str = "total_charge"
129    """The key for the total charge data in the inputs dictionary."""
130
131    block_index_key: Optional[str] = None
132    """The key for the block index. If provided, will use a BLOCK_INDEX_NET as a mixing network."""
133
134    FID: ClassVar[str] = "CRATE"
135
136    @nn.compact
137    def __call__(self, inputs):
138        species = inputs["species"]
139        assert (
140            len(species.shape) == 1
141        ), "Species must be a 1D array (batches must be flattened)"
142        reduce_memory =  "reduce_memory" in inputs.get("flags", {})
143
144        kernel_init = (
145            initializer_from_str(self.kernel_init)
146            if isinstance(self.kernel_init, str)
147            else self.kernel_init
148        )
149
150        actmix = activation_from_str(self.activation_mixing)
151
152        ##################################################
153        graph = inputs[self.graph_key]
154        use_angles = self.graph_angle_key is not None
155        if use_angles:
156            graph_angle = inputs[self.graph_angle_key]
157
158            # Check that the graph_angle is a subgraph of graph
159            correct_graph = (
160                self.graph_angle_key == self.graph_key
161                or self._graphs_properties[self.graph_angle_key]["parent_graph"]
162                == self.graph_key
163            )
164            assert (
165                correct_graph
166            ), f"graph_angle_key={self.graph_angle_key} must be a subgraph of graph_key={self.graph_key}"
167            assert (
168                "angles" in graph_angle
169            ), f"Graph {self.graph_angle_key} must contain angles"
170            # check if graph_angle is a filtered graph
171            filtered = "parent_graph" in self._graphs_properties[self.graph_angle_key]
172            if filtered:
173                filter_indices = graph_angle["filter_indices"]
174
175        ##################################################
176        ### SPECIES ENCODING ###
177        if isinstance(self.species_encoding, str):
178            zi = inputs[self.species_encoding]
179        else:
180            zi = SpeciesEncoding(
181                **self.species_encoding, name="SpeciesEncoding"
182            )(species)
183
184        
185        if self.layer_normalization:
186            def layer_norm(x):
187                mu = jnp.mean(x,axis=-1,keepdims=True)
188                dx = x-mu
189                var = jnp.mean(dx**2,axis=-1,keepdims=True)
190                sig = (1.e-6 + var)**(-0.5)
191                return dx*sig
192        else:
193            layer_norm = lambda x:x
194                
195
196        if self.charge_embedding:
197            xi, qi = jnp.split(
198                nn.Dense(self.dim + 1, use_bias=False, name="ChargeEncoding")(zi),
199                [self.dim],
200                axis=-1,
201            )
202            batch_index = inputs["batch_index"]
203            natoms = inputs["natoms"]
204            nsys = natoms.shape[0]
205            Zi = jnp.asarray(VALENCE_ELECTRONS)[species]
206            Ntot = jax.ops.segment_sum(Zi, batch_index, nsys) - inputs.get(
207                self.total_charge_key, jnp.zeros(nsys)
208            )
209            ai = jax.nn.softplus(qi.squeeze(-1))
210            A = jax.ops.segment_sum(ai, batch_index, nsys)
211            Ni = ai * (Ntot / A)[batch_index]
212            charge_embedding = positional_encoding(Ni, self.dim)
213            xi = layer_norm(xi + charge_embedding)
214        elif self.species_init:
215            xi = layer_norm(nn.Dense(self.dim, use_bias=False, name="SpeciesInit")(zi))
216        else:
217            xi = zi
218
219        ##################################################
220        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
221        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
222        distances = graph["distances"]
223        switch = graph["switch"][:, None]
224
225        ### COMPUTE RADIAL BASIS ###
226        radial_basis = RadialBasis(
227            **{
228                **self.radial_basis,
229                "end": cutoff,
230                "name": f"RadialBasis",
231            }
232        )(distances)
233
234        do_lode = self.graph_lode is not None
235        if do_lode:
236            graph_lode = inputs[self.graph_lode]
237            switch_lode = graph_lode["switch"][:, None]
238
239            edge_src_lr, edge_dst_lr = graph_lode["edge_src"], graph_lode["edge_dst"]
240            r = graph_lode["distances"][:, None]
241            rc = self._graphs_properties[self.graph_lode]["cutoff"]
242            
243            lmax_lr = self.lmax_lode
244            equivariant_lode = lmax_lr > 0
245            assert lmax_lr >=0, f"lmax_lode must be >= 0, got {lmax_lr}"
246            if self.lode_multipole_interaction:
247                assert lmax_lr <= self.lmax, f"lmax_lode must be <= lmax for multipole interaction, got {lmax_lr} > {self.lmax}"
248            nrep_lr = np.array([2 * l + 1 for l in range(lmax_lr + 1)], dtype=np.int32)
249            if self.lode_resolve_l and equivariant_lode:
250                ls_lr = np.arange(lmax_lr + 1)
251            else:
252                ls_lr = np.array([0])
253
254            if self.a_lode > 0:
255                a = self.a_lode**2
256            else:
257                a = (
258                    self.param(
259                        "a_lr",
260                        lambda key: jnp.asarray([-self.a_lode] * ls_lr.shape[0])[None, :],
261                    )
262                    ** 2
263                )
264            rc2a = rc**2 + a
265            ls_lr = 0.5 * (ls_lr[None, :] + 1)
266            ### minimal radial basis for long range (damped coulomb)
267            eij_lr = (
268                1.0 / (r**2 + a) ** ls_lr
269                - 1.0 / rc2a**ls_lr
270                + (r - rc) * rc * (2 * ls_lr) / rc2a ** (ls_lr + 1)
271            ) * switch_lode
272
273            if self.lode_rshort is not None:
274                rs = self.lode_rshort
275                d = self.lode_dshort
276                switch_short = 0.5 * (1 - jnp.cos(jnp.pi * (r - rs) / d)) * (r > rs) * (
277                    r < rs + d
278                ) + (r > rs + d)
279                eij_lr = eij_lr * switch_short
280
281            # dim_lr = self.nchannels_lode
282            nchannels_lode = (
283                [self.lode_channels] * self.nlayers
284                if isinstance(self.lode_channels, int)
285                else self.lode_channels
286            )
287            dim_lr = nchannels_lode
288            
289            if equivariant_lode:
290                if self.lode_resolve_l:
291                    eij_lr = eij_lr.repeat(nrep_lr, axis=-1)
292                Yij = generate_spherical_harmonics(lmax=lmax_lr, normalize=False)(
293                    graph_lode["vec"] / r
294                )
295                eij_lr = (eij_lr * Yij)[:, None, :]
296                dim_lr = [d * (lmax_lr + 1) for d in dim_lr]
297            
298
299        ##################################################
300        ### GET ANGLES ###
301        if use_angles:
302            angles = graph_angle["angles"][:, None]
303            angle_src, angle_dst = graph_angle["angle_src"], graph_angle["angle_dst"]
304            switch_angles = graph_angle["switch"][:, None]
305            central_atom = graph_angle["central_atom"]
306
307            if not self.angle_combine_pairs:
308                assert (
309                    self.radial_basis_angle is not None
310                ), "radial_basis_angle must be specified if angle_combine_pairs=False"
311
312            ### COMPUTE RADIAL BASIS FOR ANGLES ###
313            if self.radial_basis_angle is not None:
314                dangles = graph_angle["distances"]
315                swang = switch_angles
316                if not self.angle_combine_pairs:
317                    dangles = 0.5 * (dangles[angle_src] + dangles[angle_dst])
318                    swang = switch_angles[angle_src] * switch_angles[angle_dst]
319                radial_basis_angle = (
320                    RadialBasis(
321                        **{
322                            **self.radial_basis_angle,
323                            "end": self._graphs_properties[self.graph_angle_key][
324                                "cutoff"
325                            ],
326                            "name": f"RadialBasisAngle",
327                        }
328                    )(dangles)
329                    * swang
330                )
331
332            else:
333                if filtered:
334                    radial_basis_angle = radial_basis[filter_indices] * switch_angles
335                else:
336                    radial_basis_angle = radial_basis * switch
337
338        radial_basis = radial_basis * switch
339
340        # # add covalent indicator
341        # rc = jnp.asarray([d/au.BOHR for d in D3_COV_RADII])[species]
342        # rcij = rc[edge_src] + rc[edge_dst]
343        # fact = graph["switch"]*(2*distances/rcij)*jnp.exp(-0.5 * ((distances - rcij)/(0.1*rcij)) ** 2)
344        # radial_basis = jnp.concatenate([radial_basis,fact[:,None]],axis=-1)
345        # if use_angles:
346        #     rcij = rc[graph_angle["edge_src"]] + rc[graph_angle["edge_dst"]]
347        #     dangles = graph_angle["distances"]
348        #     fact = graph_angle["switch"]*((2*dangles/rcij))*jnp.exp(-0.5 * ((dangles - rcij)/(0.1*rcij))**2)
349        #     radial_basis_angle = jnp.concatenate([radial_basis_angle,fact[:,None]],axis=-1)
350
351
352        ##################################################
353        if self.lmax > 0:
354            Yij = generate_spherical_harmonics(lmax=self.lmax, normalize=False)(
355                graph["vec"] / graph["distances"][:, None]
356            )[:, None, :]
357            Yij = jnp.broadcast_to(Yij, (Yij.shape[0], self.nchannels_l, Yij.shape[2]))
358            nrep_l = np.array([2 * l + 1 for l in range(self.lmax + 1)], dtype=np.int32)
359            # ls = [0]
360            # for l in range(1, self.lmax + 1):
361            #    ls = ls + [l] * (2 * l + 1)
362            #ls = jnp.asarray(np.array(ls)[None, :], dtype=distances.dtype)
363            #lcut = (0.5 + 0.5 * jnp.cos((np.pi / cutoff) * distances[:, #None])) ** (
364            #    ls + 1
365            #)
366            # lcut = jnp.where(graph["edge_mask"][:, None], lcut, 0.0)
367            # rijl1 = (lcut * distances[:, None] ** ls)[:, None, :]
368
369        ##################################################
370        if use_angles:
371            ### ANGULAR BASIS ###
372            if self.angle_style == "fourier":
373                # build fourier series for angles
374                nangles = self.param(
375                    f"nangles",
376                    lambda key, dim: jnp.arange(dim, dtype=distances.dtype)[None, :],
377                    self.nmax_angle + 1,
378                )
379
380                phi = self.param(
381                    f"phi",
382                    lambda key, dim: jnp.zeros((1, dim), dtype=distances.dtype),
383                    self.nmax_angle + 1,
384                )
385                xa = jnp.cos(nangles * angles + phi)
386            elif self.angle_style == "fourier_full":
387                # build fourier series for angles including sin terms
388                nangles = self.param(
389                    f"nangles",
390                    lambda key, dim: jnp.arange(dim, dtype=distances.dtype)[None, :],
391                    self.nmax_angle + 1,
392                )
393
394                phi = self.param(
395                    f"phi",
396                    lambda key, dim: jnp.zeros((1, dim), dtype=distances.dtype),
397                    2 * self.nmax_angle + 1,
398                )
399                xac = jnp.cos(nangles * angles + phi[:, : self.nmax_angle + 1])
400                xas = jnp.sin(nangles[:, 1:] * angles + phi[:, self.nmax_angle + 1 :])
401                xa = jnp.concatenate([xac, xas], axis=-1)
402            elif self.angle_style == "ani":
403                # ANI-style angle embedding
404                angle_start = np.pi / (2 * (self.nmax_angle + 1))
405                shiftZ = self.param(
406                    f"shiftZ",
407                    lambda key, dim: jnp.asarray(
408                        (np.linspace(0, np.pi, dim + 1) + angle_start)[None, :-1],
409                        dtype=distances.dtype,
410                    ),
411                    self.nmax_angle + 1,
412                )
413                zeta = self.param(
414                    f"zeta",
415                    lambda key: jnp.asarray(self.zeta, dtype=distances.dtype),
416                )
417                xa = (0.5 + 0.5 * jnp.cos(angles - shiftZ)) ** zeta
418            else:
419                raise ValueError(f"Unknown angle style {self.angle_style}")
420            xa = xa[:, None, :]
421            if not self.angle_combine_pairs:
422                if reduce_memory: raise NotImplementedError("Angle embedding not implemented with reduce_memory")
423                xa = (xa * radial_basis_angle[:, :, None]).reshape(
424                    -1, 1, xa.shape[1] * radial_basis_angle.shape[1]
425                )
426
427            if self.pair_embedding_key is not None:
428                if filtered:
429                    ang_pair_src = filter_indices[angle_src]
430                    ang_pair_dst = filter_indices[angle_dst]
431                else:
432                    ang_pair_src = angle_src
433                    ang_pair_dst = angle_dst
434                ang_pairs = jnp.concatenate((ang_pair_src, ang_pair_dst))
435
436        ##################################################
437        ### DIMENSIONS ###
438        dim_src = (
439            [self.dim_src] * self.nlayers
440            if isinstance(self.dim_src, int)
441            else self.dim_src
442        )
443        assert (
444            len(dim_src) == self.nlayers
445        ), f"dim_src must be an integer or a list of length {self.nlayers}"
446        dim_dst = self.dim_dst
447        # dim_dst = (
448        #     [self.dim_dst] * self.nlayers
449        #     if isinstance(self.dim_dst, int)
450        #     else self.dim_dst
451        # )
452        # assert (
453        #     len(dim_dst) == self.nlayers
454        # ), f"dim_dst must be an integer or a list of length {self.nlayers}"
455
456        if use_angles:
457            dim_angle = (
458                [self.dim_angle] * self.nlayers
459                if isinstance(self.dim_angle, int)
460                else self.dim_angle
461            )
462            assert (
463                len(dim_angle) == self.nlayers
464            ), f"dim_angle must be an integer or a list of length {self.nlayers}"
465            # nmax_angle = [self.nmax_angle]*self.nlayers if isinstance(self.nmax_angle, int) else self.nmax_angle
466            # assert len(nmax_angle) == self.nlayers, f"nmax_angle must be an integer or a list of length {self.nlayers}"
467        
468        initialize_e3 = True
469        if self.lmax > 0:
470            n_tp = (
471                [self.n_tp] * self.nlayers
472                if isinstance(self.n_tp, int)
473                else self.n_tp
474            )
475            assert (
476                len(n_tp) == self.nlayers
477            ), f"n_tp must be an integer or a list of length {self.nlayers}"
478
479
480        message_passing = (
481            [self.message_passing] * self.nlayers
482            if isinstance(self.message_passing, bool)
483            else self.message_passing
484        )
485        assert (
486            len(message_passing) == self.nlayers
487        ), f"message_passing must be a boolean or a list of length {self.nlayers}"
488
489        ##################################################
490        ### INITIALIZE PAIR EMBEDDING ###
491        if self.pair_embedding_key is not None:
492            xij_s,xij_d = jnp.split(nn.Dense(2*dim_dst, name="pair_init_linear")(zi), [dim_dst], axis=-1)
493            xij = layer_norm(xij_s[edge_src]*xij_d[edge_dst])
494
495        ##################################################
496        if self.keep_all_layers:
497            xis = []
498        
499        ### LOOP OVER LAYERS ###
500        for layer in range(self.nlayers):
501            ##################################################
502            ### COMPACT DESCRIPTORS ###
503            si, si_dst = jnp.split(
504                nn.Dense(
505                    dim_src[layer] + dim_dst,
506                    name=f"species_linear_{layer}",
507                    use_bias=self.use_bias,
508                )(xi),
509                [
510                    dim_src[layer],
511                ],
512                axis=-1,
513            )
514
515            ##################################################
516            if message_passing[layer] or layer == 0:
517                ### MESSAGE PASSING ###
518                si_mp = si_dst[edge_dst]
519            else:
520                # if layer == 0:
521                #     si_mp = si_dst[edge_dst]
522                ### ATTENTION TO SIMULATE MP ###
523                Q = nn.Dense(
524                    dim_dst * self.att_dim, name=f"queries_{layer}", use_bias=False
525                )(si_dst).reshape(-1, dim_dst, self.att_dim)[edge_src]
526                K = nn.Dense(
527                    dim_dst * self.att_dim, name=f"keys_{layer}", use_bias=False
528                )(zi).reshape(-1, dim_dst, self.att_dim)[edge_dst]
529
530                si_mp = (K * Q).sum(axis=-1) / self.att_dim**0.5
531                # Vmp = jax.ops.segment_sum(
532                #     (KQ * switch)[:, :, None] * Yij, edge_src, species.shape[0]
533                # )
534                # si_mp = (Vmp[edge_src] * Yij).sum(axis=-1)
535                # Q = nn.Dense(
536                #     dim_dst * dim_dst, name=f"queries_{layer}", use_bias=False
537                # )(si_dst).reshape(-1, dim_dst, dim_dst)
538                # si_mp = (
539                #     si_mp + jax.vmap(jnp.dot)(Q[edge_src], si_mp) / self.dim_dst**0.5
540                # )
541
542            if self.pair_embedding_key is not None:
543                si_mp = si_mp + xij
544
545            ##################################################
546            ### PAIR EMBEDDING ###
547            if reduce_memory:
548                Li = jnp.zeros((species.shape[0]* radial_basis.shape[1],si_mp.shape[1]),dtype=si_mp.dtype)
549                for i in range(radial_basis.shape[1]):
550                    indices = i + edge_src*radial_basis.shape[1]
551                    Li = Li.at[indices].add(si_mp*radial_basis[:,i,None])
552                Li = Li.reshape(species.shape[0], radial_basis.shape[1]*si_mp.shape[1])
553            else:
554                Lij = (si_mp[:, None, :] * radial_basis[:, :, None]).reshape(
555                    radial_basis.shape[0], si_mp.shape[1] * radial_basis.shape[1]
556                )
557                ### AGGREGATE PAIR EMBEDDING ###
558                Li = jax.ops.segment_sum(Lij, edge_src, species.shape[0])
559
560            ### CONCATENATE EMBEDDING COMPONENTS ###
561            components = [si, Li]
562            if self.pair_embedding_key is not None:
563                if reduce_memory: raise NotImplementedError("Pair embedding not implemented with reduce_memory")
564                components_pair = [si[edge_src], xij, Lij]
565
566
567            ##################################################
568            ### ANGLE EMBEDDING ###
569            if use_angles and dim_angle[layer]>0:
570                si_mp_ang = si_mp[filter_indices] if filtered else si_mp
571                if self.angle_combine_pairs:
572                    Wa = self.param(
573                        f"Wa_{layer}",
574                        nn.initializers.normal(
575                            stddev=1.0
576                            / (si_mp.shape[1] * radial_basis_angle.shape[1]) ** 0.5
577                        ),
578                        (si_mp.shape[1], radial_basis_angle.shape[1], dim_angle[layer]),
579                    )
580                    Da = jnp.einsum(
581                        "...i,...j,ijk->...k",
582                        si_mp_ang,
583                        radial_basis_angle,
584                        Wa,
585                    )
586
587                else:
588                    if message_passing[layer] or layer == 0:
589                        Da = nn.Dense(dim_angle[layer], name=f"angle_linear_{layer}")(
590                            xi
591                        )[graph_angle["edge_dst"]]
592                    else:
593                        Da = nn.Dense(dim_angle[layer], name=f"angle_linear_{layer}")(
594                            si_mp_ang
595                        )
596
597                Da = Da[angle_dst] * Da[angle_src]
598                ## combine pair and angle info
599                if reduce_memory:
600                    ang_embedding = jnp.zeros((species.shape[0]* Da.shape[-1],xa.shape[-1]),dtype=Da.dtype)
601                    for i in range(Da.shape[-1]):
602                        indices = i + central_atom*Da.shape[-1]
603                        ang_embedding = ang_embedding.at[indices].add(Da[:,i,None]*xa[:,0,:])
604                    ang_embedding = ang_embedding.reshape(species.shape[0], xa.shape[-1]*Da.shape[-1])
605                else:
606                    radang = (xa * Da[:, :, None]).reshape(
607                        (-1, Da.shape[1] * xa.shape[2])
608                    )
609                    ### AGGREGATE  ANGLE EMBEDDING ###
610                    ang_embedding = jax.ops.segment_sum(
611                        radang, central_atom, species.shape[0]
612                    )
613                    
614
615                components.append(ang_embedding)
616
617                if self.pair_embedding_key is not None:
618                    ang_ij = jax.ops.segment_sum(
619                        jnp.concatenate((radang, radang)),
620                        ang_pairs,
621                        edge_src.shape[0],
622                    )
623                    components_pair.append(ang_ij)
624            
625            ##################################################
626            ### EQUIVARIANT EMBEDDING ###
627            if self.lmax > 0 and n_tp[layer] >= 0:
628                if initialize_e3 or not message_passing[layer]:
629                    Vij = Yij
630                elif self.edge_tp:
631                    Vij = FilteredTensorProduct(
632                            self.lmax, self.lmax, name=f"edge_tp_{layer}",ignore_parity=self.ignore_irreps_parity,weights_by_channel=False
633                        )(Vi[edge_dst], Yij)
634                else:
635                    Vij = Vi[edge_dst]
636
637                ### compute channel weights
638                dim_wij = self.nchannels_l
639                if self.resolve_wij_l:
640                    dim_wij=self.nchannels_l*(self.lmax+1)
641
642                eij = Lij if self.pair_embedding_key is None else jnp.concatenate([Lij,xij*switch],axis=-1)
643                wij = nn.Dense(
644                        dim_wij, name=f"e3_channel_{layer}", use_bias=False
645                    )(eij)
646                if self.resolve_wij_l:
647                    wij = jnp.repeat(wij.reshape(-1,self.nchannels_l,self.lmax+1),nrep_l,axis=-1)
648                else:
649                    wij = wij[:,:,None]
650                
651                ### aggregate equivariant messages
652                drhoi = jax.ops.segment_sum(
653                    wij * Vij,
654                    edge_src,
655                    species.shape[0],
656                )
657
658                Vi0 = []
659                if initialize_e3:
660                    rhoi = drhoi
661                    Vi = ChannelMixingE3(
662                        self.lmax,
663                        self.nchannels_l,
664                        self.nchannels_l,
665                        name=f"e3_initial_mixing_{layer}",
666                    )(rhoi)
667                    # assert n_tp[layer] > 0, "n_tp must be > 0 for the first equivariant layer."
668                else:
669                    rhoi = rhoi + drhoi
670                    # if message_passing[layer]:
671                        # Vi0.append(drhoi[:, :, 0])
672                initialize_e3 = False
673                if n_tp[layer] > 0:
674                    for itp in range(n_tp[layer]):
675                        dVi = FilteredTensorProduct(
676                            self.lmax, self.lmax, name=f"tensor_product_{layer}_{itp}",ignore_parity=self.ignore_irreps_parity,weights_by_channel=False
677                        )(rhoi, Vi)
678                        Vi = ChannelMixing(
679                                self.lmax,
680                                self.nchannels_l,
681                                self.nchannels_l,
682                                name=f"tp_mixing_{layer}_{itp}",
683                            )(Vi + dVi)
684                        Vi0.append(dVi[:, :, 0])
685                    Vi0 = jnp.concatenate(Vi0, axis=-1)
686                    components.append(Vi0)
687
688                if self.pair_embedding_key is not None:
689                    Vij = Vi[edge_src]*Yij
690                    Vij0 = [Vij[...,0]]
691                    for l in range(1,self.lmax+1):
692                        Vij0.append(Vij[...,l**2:(l+1)**2].sum(axis=-1))
693                    Vij0 = jnp.concatenate(Vij0,axis=-1)
694                    components_pair.append(Vij0)
695
696            ##################################################
697            ### CONCATENATE EMBEDDING COMPONENTS ###
698            if do_lode and nchannels_lode[layer] > 0:
699                zj = nn.Dense(dim_lr[layer], use_bias=False, name=f"LODE_{layer}")(xi)
700                if equivariant_lode:
701                    zj = zj.reshape(
702                        (species.shape[0], nchannels_lode[layer], lmax_lr + 1)
703                    ).repeat(nrep_lr, axis=-1)
704                xi_lr = jax.ops.segment_sum(
705                    eij_lr * zj[edge_dst_lr], edge_src_lr, species.shape[0]
706                )
707                if equivariant_lode:
708                    assert self.lode_use_field_norm or self.lode_multipole_interaction, "equivariant LODE requires field norm or multipole interaction"
709                    if self.lode_multipole_interaction:
710                        if initialize_e3:
711                            raise ValueError("equivariant LODE used before local equivariants initialized")
712                        size_l_lr = (lmax_lr+1)**2
713                        if self.lode_direct_multipoles:
714                            assert nchannels_lode[layer] <= self.nchannels_l
715                            Mi = Vi[:, : nchannels_lode[layer], :size_l_lr]
716                        else:
717                            Mi = ChannelMixingE3(
718                                lmax_lr,
719                                self.nchannels_l,
720                                nchannels_lode[layer],
721                                name=f"e3_LODE_{layer}",
722                            )(Vi[...,:size_l_lr])
723                        Mi_lr = Mi * xi_lr
724                    components.append(xi_lr[:, :, 0])
725                    if self.lode_use_field_norm and self.lode_equi_full_combine:
726                        xi_lr1 = ChannelMixing(
727                            lmax_lr,
728                            nchannels_lode[layer],
729                            nchannels_lode[layer],
730                            name=f"LODE_mixing_{layer}",
731                        )(xi_lr)
732                    norm = 1.
733                    for l in range(1, lmax_lr + 1):
734                        if self.lode_normalize_l:
735                            norm = 1. / (2 * l + 1)
736                        if self.lode_multipole_interaction:
737                            components.append(Mi_lr[:, :, l**2 : (l + 1) ** 2].sum(axis=-1)*norm)
738
739                        if self.lode_use_field_norm:
740                            if self.lode_equi_full_combine:
741                                components.append((xi_lr[:,:,l**2 : (l + 1) ** 2]*xi_lr1[:,:,l**2 : (l + 1) ** 2]).sum(axis=-1)*norm)
742                            else:
743                                components.append(
744                                    ((xi_lr[:, :, l**2 : (l + 1) ** 2]) ** 2).sum(axis=-1)*norm
745                                )
746                else:
747                    components.append(xi_lr)
748
749            dxi = jnp.concatenate(components, axis=-1)
750
751            ##################################################
752            ### CONCATENATE PAIR EMBEDDING COMPONENTS ###
753            if self.pair_embedding_key is not None:
754                dxij = jnp.concatenate(components_pair, axis=-1)
755
756            ##################################################
757            ### MIX AND APPLY NONLINEARITY ###
758            if self.block_index_key is not None:
759                block_index = inputs[self.block_index_key]
760                dxi = actmix(BlockIndexNet(
761                        output_dim=self.dim,
762                        hidden_neurons=self.mixing_hidden,
763                        activation=self.activation,
764                        name=f"dxi_{layer}",
765                        use_bias=self.use_bias,
766                        kernel_init=kernel_init,
767                    )((species,dxi, block_index))
768                )
769            else:
770                dxi = actmix(
771                    FullyConnectedNet(
772                        [*self.mixing_hidden, self.dim],
773                        activation=self.activation,
774                        name=f"dxi_{layer}",
775                        use_bias=self.use_bias,
776                        kernel_init=kernel_init,
777                    )(dxi)
778                )
779
780            if self.pair_embedding_key is not None:
781                ### UPDATE PAIR EMBEDDING ###
782                # dxij = tssr3(nn.Dense(dim_dst, name=f"dxij_{layer}",use_bias=False)(dxij))
783                dxij = actmix(
784                    FullyConnectedNet(
785                        [*self.pair_mixing_hidden, dim_dst],
786                        activation=self.activation,
787                        name=f"dxij_{layer}",
788                        use_bias=False,
789                        kernel_init=kernel_init,
790                    )(dxij)
791                )
792                xij = layer_norm(xij + dxij)
793
794            ##################################################
795            ### UPDATE EMBEDDING ###
796            if layer == 0 and not (self.species_init or self.charge_embedding):
797                xi = layer_norm(dxi)
798            else:
799                ### FORGET GATE ###
800                R = jax.nn.sigmoid(
801                    self.param(
802                        f"retention_{layer}",
803                        nn.initializers.normal(),
804                        (xi.shape[-1],),
805                    )
806                )
807                xi = layer_norm(R[None, :] * xi + dxi)
808
809            if self.keep_all_layers:
810                xis.append(xi)
811
812        embedding_key = (
813            self.embedding_key if self.embedding_key is not None else self.name
814        )
815        output = {
816            **inputs,
817            embedding_key: xi,
818        }
819        if self.lmax > 0:
820            output[embedding_key + "_tensor"] = Vi
821        if self.keep_all_layers:
822            output[embedding_key + "_layers"] = jnp.stack(xis, axis=1)
823        if self.charge_embedding:
824            output[embedding_key + "_charge"] = charge_embedding
825        if self.pair_embedding_key is not None:
826            output[self.pair_embedding_key] = xij
827        return output

Configurable Resources ATomic Environment

FID : CRATE

This class represents the CRATE (Configurable Resources ATomic Environment) embedding model. It is used to encode atomic environments using multiple sources of information (radial, angular, E(3), message-passing, LODE, etc.).

CRATEmbedding( _graphs_properties: Dict, dim: int = 256, nlayers: int = 2, keep_all_layers: bool = False, dim_src: int = 64, dim_dst: int = 32, angle_style: str = 'fourier', dim_angle: int = 8, nmax_angle: int = 4, zeta: float = 14.1, angle_combine_pairs: bool = True, message_passing: bool = True, att_dim: int = 1, lmax: int = 0, nchannels_l: int = 16, n_tp: int = 1, ignore_irreps_parity: bool = False, edge_tp: bool = False, resolve_wij_l: bool = False, species_init: bool = False, mixing_hidden: Sequence[int] = <factory>, pair_mixing_hidden: Sequence[int] = <factory>, activation: Union[Callable, str] = 'silu', kernel_init: Union[str, Callable] = 'lecun_normal()', activation_mixing: Union[Callable, str] = 'tssr3', layer_normalization: bool = False, use_bias: bool = True, graph_key: str = 'graph', graph_angle_key: Optional[str] = None, embedding_key: Optional[str] = None, pair_embedding_key: Optional[str] = None, species_encoding: Union[dict, str] = <factory>, radial_basis: dict = <factory>, radial_basis_angle: Optional[dict] = None, graph_lode: Optional[str] = None, lode_channels: Union[int, Sequence[int]] = 8, lmax_lode: int = 0, a_lode: float = -1.0, lode_resolve_l: bool = True, lode_multipole_interaction: bool = True, lode_direct_multipoles: bool = True, lode_equi_full_combine: bool = False, lode_normalize_l: bool = False, lode_use_field_norm: bool = True, lode_rshort: Optional[float] = None, lode_dshort: float = 0.5, charge_embedding: bool = False, total_charge_key: str = 'total_charge', block_index_key: Optional[str] = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
dim: int = 256

The size of the embedding vectors.

nlayers: int = 2

The number of interaction layers in the model.

keep_all_layers: bool = False

Whether to output all layers.

dim_src: int = 64

The size of the source embedding vectors.

dim_dst: int = 32

The size of the destination embedding vectors.

angle_style: str = 'fourier'

The style of angle representation.

dim_angle: int = 8

The size of the pairwise vectors used for triplet combinations.

nmax_angle: int = 4

The dimension of the angle representation (minus one).

zeta: float = 14.1

The zeta parameter for the ANI-style angular representation.

angle_combine_pairs: bool = True

Whether to combine angle pairs instead of using the average distance embedding as in ANI.

message_passing: bool = True

Whether to use message passing in the model.

att_dim: int = 1

The hidden size for the attention mechanism (only used when message-passing is disabled).

lmax: int = 0

The maximum order of spherical tensors.

nchannels_l: int = 16

The number of channels for spherical tensors.

n_tp: int = 1

The number of tensor products performed at each layer.

ignore_irreps_parity: bool = False

Whether to ignore the parity of the irreps in the tensor product.

edge_tp: bool = False

Whether to perform a tensor product on edges before sending messages.

resolve_wij_l: bool = False

Whether the equivariant message weights are l-dependent (one set of channel weights per angular order l).

species_init: bool = False

Whether to initialize the embedding using the species encoding.

mixing_hidden: Sequence[int]

The hidden layer sizes for the mixing network.

pair_mixing_hidden: Sequence[int]

The hidden layer sizes for the pair mixing network.

activation: Union[Callable, str] = 'silu'

The activation function for the mixing network.

kernel_init: Union[str, Callable] = 'lecun_normal()'

The kernel initialization function for Dense operations.

activation_mixing: Union[Callable, str] = 'tssr3'

The activation function applied after mixing.

layer_normalization: bool = False

Whether to apply layer normalization after each layer.

use_bias: bool = True

Whether to use bias in the Dense operations.

graph_key: str = 'graph'

The key for the graph data in the inputs dictionary.

graph_angle_key: Optional[str] = None

The key for the angle graph data in the inputs dictionary.

embedding_key: Optional[str] = None

The key for the embedding data in the output dictionary.

pair_embedding_key: Optional[str] = None

The key for the pair embedding data in the output dictionary.

species_encoding: Union[dict, str]

If str, it is the key in the inputs dictionary that contains species encodings. Else, it is the dictionary of parameters for species encoding. See fennol.models.misc.encodings.SpeciesEncoding.

radial_basis: dict

The dictionary of parameters for radial basis functions. See fennol.models.misc.encodings.RadialBasis.

radial_basis_angle: Optional[dict] = None

The dictionary of parameters for radial basis functions for angle embedding. If None, the radial basis for angles is the same as the radial basis for distances.

graph_lode: Optional[str] = None

The key for the lode graph data in the inputs dictionary.

lode_channels: Union[int, Sequence[int]] = 8

The number of channels for lode.

lmax_lode: int = 0

The maximum order of spherical tensors for lode.

a_lode: float = -1.0

The cutoff for the lode graph. If negative, the value is trainable with starting value -a_lode.

lode_resolve_l: bool = True

Whether to resolve the lode channels by l.

lode_multipole_interaction: bool = True

Whether to compute interactions between local multipole moments (derived from the local equivariants) and the long-range equivariant field.

lode_direct_multipoles: bool = True

Whether to directly use the first channels of the local equivariants to interact with the long-range equivariants. If false, the local equivariants are channel-mixed before the interaction.

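lode_equi_full_combine: bool = False

When the field norm is used, whether to contract each l-component of the long-range field with a channel-mixed copy of itself instead of taking its squared norm.

lode_normalize_l: bool = False

Whether to scale the order-l invariant LODE features by 1/(2l+1).

lode_use_field_norm: bool = True

Whether to add the squared norm of each l > 0 component of the long-range field as invariant features. Equivariant LODE requires either this option or lode_multipole_interaction.
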
lode_rshort: Optional[float] = None
lode_dshort: float = 0.5
charge_embedding: bool = False

Whether to include a charge embedding.

total_charge_key: str = 'total_charge'

The key for the total charge data in the inputs dictionary.

block_index_key: Optional[str] = None

The key for the block index in the inputs dictionary. If provided, a BlockIndexNet is used as the mixing network.

FID: ClassVar[str] = 'CRATE'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None