fennol.models.embeddings.crate

  1import jax
  2import jax.numpy as jnp
  3import flax.linen as nn
  4import dataclasses
  5import numpy as np
  6from typing import Dict, Union, Callable, Sequence, Optional, Tuple, ClassVar
  7
  8from ..misc.encodings import SpeciesEncoding, RadialBasis, positional_encoding
  9from ...utils.spherical_harmonics import generate_spherical_harmonics, CG_SO3
 10from ...utils.activations import activation_from_str
 11from ...utils.initializers import initializer_from_str
 12from ..misc.nets import FullyConnectedNet, BlockIndexNet
 13from ..misc.e3 import ChannelMixing, ChannelMixingE3, FilteredTensorProduct
 14from ...utils.periodic_table import D3_COV_RADII, D3_VDW_RADII, VALENCE_ELECTRONS
 15from ...utils import AtomicUnits as au
 16
 17class CRATEmbedding(nn.Module):
 18    """Configurable Resources ATomic Environment
 19
 20    FID : CRATE
 21
 22    This class represents the CRATE (Configurable Resources ATomic Environment) embedding model.
 23    It is used to encode atomic environments using multiple sources of information
 24      (radial, angular, E(3), message-passing, LODE, etc...)
 25    """
 26
 27    _graphs_properties: Dict
 28
 29    dim: int = 256
 30    """The size of the embedding vectors."""
 31    nlayers: int = 2
 32    """The number of interaction layers in the model."""
 33    keep_all_layers: bool = False
 34    """Whether to output all layers."""
 35
 36    dim_src: int = 64
 37    """The size of the source embedding vectors."""
 38    dim_dst: int = 32
 39    """The size of the destination embedding vectors."""
 40
 41    angle_style: str = "fourier"
 42    """The style of angle representation."""
 43    dim_angle: int = 8
 44    """The size of the pairwise vectors use for triplet combinations."""
 45    nmax_angle: int = 4
 46    """The dimension of the angle representation (minus one)."""
 47    zeta: float = 14.1
 48    """The zeta parameter for the model ANI angular representation."""
 49    angle_combine_pairs: bool = True
 50    """Whether to combine angle pairs instead of average distance embedding like in ANI."""
 51
 52    message_passing: bool = True
 53    """Whether to use message passing in the model."""
 54    att_dim: int = 1
 55    """The hidden size for the attention mechanism (only used when message-passing is disabled)."""
 56
 57    lmax: int = 0
 58    """The maximum order of spherical tensors."""
 59    nchannels_l: int = 16
 60    """The number of channels for spherical tensors."""
 61    n_tp: int = 1
 62    """The number of tensor products performed at each layer."""
 63    ignore_irreps_parity: bool = False
 64    """Whether to ignore the parity of the irreps in the tensor product."""
 65    edge_tp: bool = False
 66    """Whether to perform a tensor product on edges before sending messages."""
 67    resolve_wij_l: bool = False
 68    """Equivariant message weights are l-dependent."""
 69
 70    species_init: bool = False
 71    """Whether to initialize the embedding using the species encoding."""
 72    mixing_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [])
 73    """The hidden layer sizes for the mixing network."""
 74    pair_mixing_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [])
 75    """The hidden layer sizes for the pair mixing network."""
 76    activation: Union[Callable, str] = "silu"
 77    """The activation function for the mixing network."""
 78    kernel_init: Union[str, Callable] = "lecun_normal()"
 79    """The kernel initialization function for Dense operations."""
 80    activation_mixing: Union[Callable, str] = "tssr3"
 81    """The activation function applied after mixing."""
 82    layer_normalization: bool = False
 83    """Whether to apply layer normalization after each layer."""
 84    use_bias: bool = True
 85    """Whether to use bias in the Dense operations."""
 86
 87    graph_key: str = "graph"
 88    """The key for the graph data in the inputs dictionary."""
 89    graph_angle_key: Optional[str] = None
 90    """The key for the angle graph data in the inputs dictionary."""
 91    embedding_key: Optional[str] = None
 92    """The key for the embedding data in the output dictionary."""
 93    pair_embedding_key: Optional[str] = None
 94    """The key for the pair embedding data in the output dictionary."""
 95
 96    species_encoding: Union[dict, str] = dataclasses.field(default_factory=dict)
 97    """If `str`, it is the key in the inputs dictionary that contains species encodings. Else, it is the dictionary of parameters for species encoding. See `fennol.models.misc.encodings.SpeciesEncoding`."""
 98    radial_basis: dict = dataclasses.field(default_factory=dict)
 99    """The dictionary of parameters for radial basis functions. See `fennol.models.misc.encodings.RadialBasis`."""
100    radial_basis_angle: Optional[dict] = None
101    """The dictionary of parameters for radial basis functions for angle embedding. 
102        If None, the radial basis for angles is the same as the radial basis for distances."""
103
104    graph_lode: Optional[str] = None
105    """The key for the lode graph data in the inputs dictionary."""
106    lode_channels: Union[int, Sequence[int]] = 8
107    """The number of channels for lode."""
108    lmax_lode: int = 0
109    """The maximum order of spherical tensors for lode."""
110    a_lode: float = -1.
111    """The cutoff for the lode graph. If negative, the value is trainable with starting value -a_lode."""
112    lode_resolve_l: bool = True
113    """Whether to resolve the lode channels by l."""
114    lode_multipole_interaction: bool = True
115    """Whether to interact with the multipole moments of the lode graph."""
116    lode_direct_multipoles: bool = True
117    """Whether to directly use the first local equivariants to interact with long-range equivariants. If false, local equivariants are mixed before interaction."""
118    lode_equi_full_combine: bool = False
119    lode_normalize_l: bool = False
120    lode_use_field_norm: bool = True
121    lode_rshort: Optional[float] = None
122    lode_dshort: float = 0.5
123    lode_extra_powers: Sequence[int] = ()
124    
125
126    charge_embedding: bool = False
127    """Whether to include charge embedding."""
128    total_charge_key: str = "total_charge"
129    """The key for the total charge data in the inputs dictionary."""
130
131    block_index_key: Optional[str] = None
132    """The key for the block index. If provided, will use a BLOCK_INDEX_NET as a mixing network."""
133
134    FID: ClassVar[str] = "CRATE"
135
136    @nn.compact
137    def __call__(self, inputs):
138        species = inputs["species"]
139        assert (
140            len(species.shape) == 1
141        ), "Species must be a 1D array (batches must be flattened)"
142        reduce_memory =  "reduce_memory" in inputs.get("flags", {})
143
144        kernel_init = (
145            initializer_from_str(self.kernel_init)
146            if isinstance(self.kernel_init, str)
147            else self.kernel_init
148        )
149
150        actmix = activation_from_str(self.activation_mixing)
151
152        ##################################################
153        graph = inputs[self.graph_key]
154        use_angles = self.graph_angle_key is not None
155        if use_angles:
156            graph_angle = inputs[self.graph_angle_key]
157
158            # Check that the graph_angle is a subgraph of graph
159            correct_graph = (
160                self.graph_angle_key == self.graph_key
161                or self._graphs_properties[self.graph_angle_key]["parent_graph"]
162                == self.graph_key
163            )
164            assert (
165                correct_graph
166            ), f"graph_angle_key={self.graph_angle_key} must be a subgraph of graph_key={self.graph_key}"
167            assert (
168                "angles" in graph_angle
169            ), f"Graph {self.graph_angle_key} must contain angles"
170            # check if graph_angle is a filtered graph
171            filtered = "parent_graph" in self._graphs_properties[self.graph_angle_key]
172            if filtered:
173                filter_indices = graph_angle["filter_indices"]
174
175        ##################################################
176        ### SPECIES ENCODING ###
177        if isinstance(self.species_encoding, str):
178            zi = inputs[self.species_encoding]
179        else:
180            zi = SpeciesEncoding(
181                **self.species_encoding, name="SpeciesEncoding"
182            )(species)
183
184        
185        if self.layer_normalization:
186            def layer_norm(x):
187                mu = jnp.mean(x,axis=-1,keepdims=True)
188                dx = x-mu
189                var = jnp.mean(dx**2,axis=-1,keepdims=True)
190                sig = (1.e-6 + var)**(-0.5)
191                return dx*sig
192        else:
193            layer_norm = lambda x:x
194                
195
196        if self.charge_embedding:
197            xi, qi = jnp.split(
198                nn.Dense(self.dim + 1, use_bias=False, name="ChargeEncoding")(zi),
199                [self.dim],
200                axis=-1,
201            )
202            batch_index = inputs["batch_index"]
203            natoms = inputs["natoms"]
204            nsys = natoms.shape[0]
205            Zi = jnp.asarray(VALENCE_ELECTRONS)[species]
206            Ntot = jax.ops.segment_sum(Zi, batch_index, nsys) - inputs.get(
207                self.total_charge_key, jnp.zeros(nsys)
208            )
209            ai = jax.nn.softplus(qi.squeeze(-1))
210            A = jax.ops.segment_sum(ai, batch_index, nsys)
211            Ni = ai * (Ntot / A)[batch_index]
212            charge_embedding = positional_encoding(Ni, self.dim)
213            xi = layer_norm(xi + charge_embedding)
214        elif self.species_init:
215            xi = layer_norm(nn.Dense(self.dim, use_bias=False, name="SpeciesInit")(zi))
216        else:
217            xi = zi
218
219        ##################################################
220        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
221        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
222        distances = graph["distances"]
223        switch = graph["switch"][:, None]
224
225        ### COMPUTE RADIAL BASIS ###
226        radial_basis = RadialBasis(
227            **{
228                **self.radial_basis,
229                "end": cutoff,
230                "name": f"RadialBasis",
231            }
232        )(distances)
233
234        do_lode = self.graph_lode is not None
235        if do_lode:
236            graph_lode = inputs[self.graph_lode]
237            switch_lode = graph_lode["switch"][:, None]
238
239            edge_src_lr, edge_dst_lr = graph_lode["edge_src"], graph_lode["edge_dst"]
240            r = graph_lode["distances"][:, None]
241            rc = self._graphs_properties[self.graph_lode]["cutoff"]
242            
243            lmax_lr = self.lmax_lode
244            equivariant_lode = lmax_lr > 0
245            assert lmax_lr >=0, f"lmax_lode must be >= 0, got {lmax_lr}"
246            if self.lode_multipole_interaction:
247                assert lmax_lr <= self.lmax, f"lmax_lode must be <= lmax for multipole interaction, got {lmax_lr} > {self.lmax}"
248            nrep_lr = np.array([2 * l + 1 for l in range(lmax_lr + 1)], dtype=np.int32)
249            if self.lode_resolve_l and equivariant_lode:
250                ls_lr = np.arange(lmax_lr + 1)
251            else:
252                ls_lr = np.array([0])
253
254            nextra_powers = len(self.lode_extra_powers)
255            if nextra_powers > 0:
256                ls_lr = np.concatenate([self.lode_extra_powers,ls_lr])
257
258            if self.a_lode > 0:
259                a = self.a_lode**2
260            else:
261                a = (
262                    self.param(
263                        "a_lr",
264                        lambda key: jnp.asarray([-self.a_lode] * ls_lr.shape[0])[None, :],
265                    )
266                    ** 2
267                )
268            rc2a = rc**2 + a
269            ls_lr = 0.5 * (ls_lr[None, :] + 1)
270            ### minimal radial basis for long range (damped coulomb)
271            eij_lr = (
272                1.0 / (r**2 + a) ** ls_lr
273                - 1.0 / rc2a**ls_lr
274                + (r - rc) * rc * (2 * ls_lr) / rc2a ** (ls_lr + 1)
275            ) * switch_lode
276
277            if self.lode_rshort is not None:
278                rs = self.lode_rshort
279                d = self.lode_dshort
280                switch_short = 0.5 * (1 - jnp.cos(jnp.pi * (r - rs) / d)) * (r > rs) * (
281                    r < rs + d
282                ) + (r >= rs + d)
283                eij_lr = eij_lr * switch_short
284            
285            if nextra_powers>0:
286                eij_lr_extra = eij_lr[:,:nextra_powers]
287                eij_lr = eij_lr[:,nextra_powers:]
288
289
290            # dim_lr = self.nchannels_lode
291            nchannels_lode = (
292                [self.lode_channels] * self.nlayers
293                if isinstance(self.lode_channels, int)
294                else self.lode_channels
295            )
296            dim_lr = nchannels_lode
297            
298            if equivariant_lode:
299                if self.lode_resolve_l:
300                    eij_lr = eij_lr.repeat(nrep_lr, axis=-1)
301                Yij = generate_spherical_harmonics(lmax=lmax_lr, normalize=False)(
302                    graph_lode["vec"] / r
303                )
304                eij_lr = (eij_lr * Yij)[:, None, :]
305                dim_lr = [d * (lmax_lr + 1) for d in dim_lr]
306            
307            if nextra_powers > 0:
308                eij_lr_extra = eij_lr_extra[:,None,:]
309                extra_dims = [nextra_powers*d for d in nchannels_lode]
310                dim_lr = [d + ed for d,ed in zip(dim_lr,extra_dims)]
311            
312
313        ##################################################
314        ### GET ANGLES ###
315        if use_angles:
316            angles = graph_angle["angles"][:, None]
317            angle_src, angle_dst = graph_angle["angle_src"], graph_angle["angle_dst"]
318            switch_angles = graph_angle["switch"][:, None]
319            central_atom = graph_angle["central_atom"]
320
321            if not self.angle_combine_pairs:
322                assert (
323                    self.radial_basis_angle is not None
324                ), "radial_basis_angle must be specified if angle_combine_pairs=False"
325
326            ### COMPUTE RADIAL BASIS FOR ANGLES ###
327            if self.radial_basis_angle is not None:
328                dangles = graph_angle["distances"]
329                swang = switch_angles
330                if not self.angle_combine_pairs:
331                    dangles = 0.5 * (dangles[angle_src] + dangles[angle_dst])
332                    swang = switch_angles[angle_src] * switch_angles[angle_dst]
333                radial_basis_angle = (
334                    RadialBasis(
335                        **{
336                            **self.radial_basis_angle,
337                            "end": self._graphs_properties[self.graph_angle_key][
338                                "cutoff"
339                            ],
340                            "name": f"RadialBasisAngle",
341                        }
342                    )(dangles)
343                    * swang
344                )
345
346            else:
347                if filtered:
348                    radial_basis_angle = radial_basis[filter_indices] * switch_angles
349                else:
350                    radial_basis_angle = radial_basis * switch
351
352        radial_basis = radial_basis * switch
353
354        # # add covalent indicator
355        # rc = jnp.asarray([d/au.BOHR for d in D3_COV_RADII])[species]
356        # rcij = rc[edge_src] + rc[edge_dst]
357        # fact = graph["switch"]*(2*distances/rcij)*jnp.exp(-0.5 * ((distances - rcij)/(0.1*rcij)) ** 2)
358        # radial_basis = jnp.concatenate([radial_basis,fact[:,None]],axis=-1)
359        # if use_angles:
360        #     rcij = rc[graph_angle["edge_src"]] + rc[graph_angle["edge_dst"]]
361        #     dangles = graph_angle["distances"]
362        #     fact = graph_angle["switch"]*((2*dangles/rcij))*jnp.exp(-0.5 * ((dangles - rcij)/(0.1*rcij))**2)
363        #     radial_basis_angle = jnp.concatenate([radial_basis_angle,fact[:,None]],axis=-1)
364
365
366        ##################################################
367        if self.lmax > 0:
368            Yij = generate_spherical_harmonics(lmax=self.lmax, normalize=False)(
369                graph["vec"] / graph["distances"][:, None]
370            )[:, None, :]
371            Yij = jnp.broadcast_to(Yij, (Yij.shape[0], self.nchannels_l, Yij.shape[2]))
372            nrep_l = np.array([2 * l + 1 for l in range(self.lmax + 1)], dtype=np.int32)
373            # ls = [0]
374            # for l in range(1, self.lmax + 1):
375            #    ls = ls + [l] * (2 * l + 1)
376            #ls = jnp.asarray(np.array(ls)[None, :], dtype=distances.dtype)
377            #lcut = (0.5 + 0.5 * jnp.cos((np.pi / cutoff) * distances[:, #None])) ** (
378            #    ls + 1
379            #)
380            # lcut = jnp.where(graph["edge_mask"][:, None], lcut, 0.0)
381            # rijl1 = (lcut * distances[:, None] ** ls)[:, None, :]
382
383        ##################################################
384        if use_angles:
385            ### ANGULAR BASIS ###
386            if self.angle_style == "fourier":
387                # build fourier series for angles
388                nangles = self.param(
389                    f"nangles",
390                    lambda key, dim: jnp.arange(dim, dtype=distances.dtype)[None, :],
391                    self.nmax_angle + 1,
392                )
393
394                phi = self.param(
395                    f"phi",
396                    lambda key, dim: jnp.zeros((1, dim), dtype=distances.dtype),
397                    self.nmax_angle + 1,
398                )
399                xa = jnp.cos(nangles * angles + phi)
400            elif self.angle_style == "fourier_full":
401                # build fourier series for angles including sin terms
402                nangles = self.param(
403                    f"nangles",
404                    lambda key, dim: jnp.arange(dim, dtype=distances.dtype)[None, :],
405                    self.nmax_angle + 1,
406                )
407
408                phi = self.param(
409                    f"phi",
410                    lambda key, dim: jnp.zeros((1, dim), dtype=distances.dtype),
411                    2 * self.nmax_angle + 1,
412                )
413                xac = jnp.cos(nangles * angles + phi[:, : self.nmax_angle + 1])
414                xas = jnp.sin(nangles[:, 1:] * angles + phi[:, self.nmax_angle + 1 :])
415                xa = jnp.concatenate([xac, xas], axis=-1)
416            elif self.angle_style == "ani":
417                # ANI-style angle embedding
418                angle_start = np.pi / (2 * (self.nmax_angle + 1))
419                shiftZ = self.param(
420                    f"shiftZ",
421                    lambda key, dim: jnp.asarray(
422                        (np.linspace(0, np.pi, dim + 1) + angle_start)[None, :-1],
423                        dtype=distances.dtype,
424                    ),
425                    self.nmax_angle + 1,
426                )
427                zeta = self.param(
428                    f"zeta",
429                    lambda key: jnp.asarray(self.zeta, dtype=distances.dtype),
430                )
431                xa = (0.5 + 0.5 * jnp.cos(angles - shiftZ)) ** zeta
432            else:
433                raise ValueError(f"Unknown angle style {self.angle_style}")
434            xa = xa[:, None, :]
435            if not self.angle_combine_pairs:
436                if reduce_memory: raise NotImplementedError("Angle embedding not implemented with reduce_memory")
437                xa = (xa * radial_basis_angle[:, :, None]).reshape(
438                    -1, 1, xa.shape[1] * radial_basis_angle.shape[1]
439                )
440
441            if self.pair_embedding_key is not None:
442                if filtered:
443                    ang_pair_src = filter_indices[angle_src]
444                    ang_pair_dst = filter_indices[angle_dst]
445                else:
446                    ang_pair_src = angle_src
447                    ang_pair_dst = angle_dst
448                ang_pairs = jnp.concatenate((ang_pair_src, ang_pair_dst))
449
450        ##################################################
451        ### DIMENSIONS ###
452        dim_src = (
453            [self.dim_src] * self.nlayers
454            if isinstance(self.dim_src, int)
455            else self.dim_src
456        )
457        assert (
458            len(dim_src) == self.nlayers
459        ), f"dim_src must be an integer or a list of length {self.nlayers}"
460        dim_dst = self.dim_dst
461        # dim_dst = (
462        #     [self.dim_dst] * self.nlayers
463        #     if isinstance(self.dim_dst, int)
464        #     else self.dim_dst
465        # )
466        # assert (
467        #     len(dim_dst) == self.nlayers
468        # ), f"dim_dst must be an integer or a list of length {self.nlayers}"
469
470        if use_angles:
471            dim_angle = (
472                [self.dim_angle] * self.nlayers
473                if isinstance(self.dim_angle, int)
474                else self.dim_angle
475            )
476            assert (
477                len(dim_angle) == self.nlayers
478            ), f"dim_angle must be an integer or a list of length {self.nlayers}"
479            # nmax_angle = [self.nmax_angle]*self.nlayers if isinstance(self.nmax_angle, int) else self.nmax_angle
480            # assert len(nmax_angle) == self.nlayers, f"nmax_angle must be an integer or a list of length {self.nlayers}"
481        
482        initialize_e3 = True
483        if self.lmax > 0:
484            n_tp = (
485                [self.n_tp] * self.nlayers
486                if isinstance(self.n_tp, int)
487                else self.n_tp
488            )
489            assert (
490                len(n_tp) == self.nlayers
491            ), f"n_tp must be an integer or a list of length {self.nlayers}"
492
493
494        message_passing = (
495            [self.message_passing] * self.nlayers
496            if isinstance(self.message_passing, bool)
497            else self.message_passing
498        )
499        assert (
500            len(message_passing) == self.nlayers
501        ), f"message_passing must be a boolean or a list of length {self.nlayers}"
502
503        ##################################################
504        ### INITIALIZE PAIR EMBEDDING ###
505        if self.pair_embedding_key is not None:
506            xij_s,xij_d = jnp.split(nn.Dense(2*dim_dst, name="pair_init_linear")(zi), [dim_dst], axis=-1)
507            xij = layer_norm(xij_s[edge_src]*xij_d[edge_dst])
508
509        ##################################################
510        if self.keep_all_layers:
511            xis = []
512        
513        ### LOOP OVER LAYERS ###
514        for layer in range(self.nlayers):
515            ##################################################
516            ### COMPACT DESCRIPTORS ###
517            si, si_dst = jnp.split(
518                nn.Dense(
519                    dim_src[layer] + dim_dst,
520                    name=f"species_linear_{layer}",
521                    use_bias=self.use_bias,
522                )(xi),
523                [
524                    dim_src[layer],
525                ],
526                axis=-1,
527            )
528
529            ##################################################
530            if message_passing[layer] or layer == 0:
531                ### MESSAGE PASSING ###
532                si_mp = si_dst[edge_dst]
533            else:
534                # if layer == 0:
535                #     si_mp = si_dst[edge_dst]
536                ### ATTENTION TO SIMULATE MP ###
537                Q = nn.Dense(
538                    dim_dst * self.att_dim, name=f"queries_{layer}", use_bias=False
539                )(si_dst).reshape(-1, dim_dst, self.att_dim)[edge_src]
540                K = nn.Dense(
541                    dim_dst * self.att_dim, name=f"keys_{layer}", use_bias=False
542                )(zi).reshape(-1, dim_dst, self.att_dim)[edge_dst]
543
544                si_mp = (K * Q).sum(axis=-1) / self.att_dim**0.5
545                # Vmp = jax.ops.segment_sum(
546                #     (KQ * switch)[:, :, None] * Yij, edge_src, species.shape[0]
547                # )
548                # si_mp = (Vmp[edge_src] * Yij).sum(axis=-1)
549                # Q = nn.Dense(
550                #     dim_dst * dim_dst, name=f"queries_{layer}", use_bias=False
551                # )(si_dst).reshape(-1, dim_dst, dim_dst)
552                # si_mp = (
553                #     si_mp + jax.vmap(jnp.dot)(Q[edge_src], si_mp) / self.dim_dst**0.5
554                # )
555
556            if self.pair_embedding_key is not None:
557                si_mp = si_mp + xij
558
559            ##################################################
560            ### PAIR EMBEDDING ###
561            if reduce_memory:
562                Li = jnp.zeros((species.shape[0]* radial_basis.shape[1],si_mp.shape[1]),dtype=si_mp.dtype)
563                for i in range(radial_basis.shape[1]):
564                    indices = i + edge_src*radial_basis.shape[1]
565                    Li = Li.at[indices].add(si_mp*radial_basis[:,i,None])
566                Li = Li.reshape(species.shape[0], radial_basis.shape[1]*si_mp.shape[1])
567            else:
568                Lij = (si_mp[:, None, :] * radial_basis[:, :, None]).reshape(
569                    radial_basis.shape[0], si_mp.shape[1] * radial_basis.shape[1]
570                )
571                ### AGGREGATE PAIR EMBEDDING ###
572                Li = jax.ops.segment_sum(Lij, edge_src, species.shape[0])
573
574            ### CONCATENATE EMBEDDING COMPONENTS ###
575            components = [si, Li]
576            if self.pair_embedding_key is not None:
577                if reduce_memory: raise NotImplementedError("Pair embedding not implemented with reduce_memory")
578                components_pair = [si[edge_src], xij, Lij]
579
580
581            ##################################################
582            ### ANGLE EMBEDDING ###
583            if use_angles and dim_angle[layer]>0:
584                si_mp_ang = si_mp[filter_indices] if filtered else si_mp
585                if self.angle_combine_pairs:
586                    Wa = self.param(
587                        f"Wa_{layer}",
588                        nn.initializers.normal(
589                            stddev=1.0
590                            / (si_mp.shape[1] * radial_basis_angle.shape[1]) ** 0.5
591                        ),
592                        (si_mp.shape[1], radial_basis_angle.shape[1], dim_angle[layer]),
593                    )
594                    Da = jnp.einsum(
595                        "...i,...j,ijk->...k",
596                        si_mp_ang,
597                        radial_basis_angle,
598                        Wa,
599                    )
600
601                else:
602                    if message_passing[layer] or layer == 0:
603                        Da = nn.Dense(dim_angle[layer], name=f"angle_linear_{layer}")(
604                            xi
605                        )[graph_angle["edge_dst"]]
606                    else:
607                        Da = nn.Dense(dim_angle[layer], name=f"angle_linear_{layer}")(
608                            si_mp_ang
609                        )
610
611                Da = Da[angle_dst] * Da[angle_src]
612                ## combine pair and angle info
613                if reduce_memory:
614                    ang_embedding = jnp.zeros((species.shape[0]* Da.shape[-1],xa.shape[-1]),dtype=Da.dtype)
615                    for i in range(Da.shape[-1]):
616                        indices = i + central_atom*Da.shape[-1]
617                        ang_embedding = ang_embedding.at[indices].add(Da[:,i,None]*xa[:,0,:])
618                    ang_embedding = ang_embedding.reshape(species.shape[0], xa.shape[-1]*Da.shape[-1])
619                else:
620                    radang = (xa * Da[:, :, None]).reshape(
621                        (-1, Da.shape[1] * xa.shape[2])
622                    )
623                    ### AGGREGATE  ANGLE EMBEDDING ###
624                    ang_embedding = jax.ops.segment_sum(
625                        radang, central_atom, species.shape[0]
626                    )
627                    
628
629                components.append(ang_embedding)
630
631                if self.pair_embedding_key is not None:
632                    ang_ij = jax.ops.segment_sum(
633                        jnp.concatenate((radang, radang)),
634                        ang_pairs,
635                        edge_src.shape[0],
636                    )
637                    components_pair.append(ang_ij)
638            
639            ##################################################
640            ### EQUIVARIANT EMBEDDING ###
641            if self.lmax > 0 and n_tp[layer] >= 0:
642                if initialize_e3 or not message_passing[layer]:
643                    Vij = Yij
644                elif self.edge_tp:
645                    Vij = FilteredTensorProduct(
646                            self.lmax, self.lmax, name=f"edge_tp_{layer}",ignore_parity=self.ignore_irreps_parity,weights_by_channel=False
647                        )(Vi[edge_dst], Yij)
648                else:
649                    Vij = Vi[edge_dst]
650
651                ### compute channel weights
652                dim_wij = self.nchannels_l
653                if self.resolve_wij_l:
654                    dim_wij=self.nchannels_l*(self.lmax+1)
655
656                eij = Lij if self.pair_embedding_key is None else jnp.concatenate([Lij,xij*switch],axis=-1)
657                wij = nn.Dense(
658                        dim_wij, name=f"e3_channel_{layer}", use_bias=False
659                    )(eij)
660                if self.resolve_wij_l:
661                    wij = jnp.repeat(wij.reshape(-1,self.nchannels_l,self.lmax+1),nrep_l,axis=-1)
662                else:
663                    wij = wij[:,:,None]
664                
665                ### aggregate equivariant messages
666                drhoi = jax.ops.segment_sum(
667                    wij * Vij,
668                    edge_src,
669                    species.shape[0],
670                )
671
672                Vi0 = []
673                if initialize_e3:
674                    rhoi = drhoi
675                    Vi = ChannelMixingE3(
676                        self.lmax,
677                        self.nchannels_l,
678                        self.nchannels_l,
679                        name=f"e3_initial_mixing_{layer}",
680                    )(rhoi)
681                    # assert n_tp[layer] > 0, "n_tp must be > 0 for the first equivariant layer."
682                else:
683                    rhoi = rhoi + drhoi
684                    # if message_passing[layer]:
685                        # Vi0.append(drhoi[:, :, 0])
686                initialize_e3 = False
687                if n_tp[layer] > 0:
688                    for itp in range(n_tp[layer]):
689                        dVi = FilteredTensorProduct(
690                            self.lmax, self.lmax, name=f"tensor_product_{layer}_{itp}",ignore_parity=self.ignore_irreps_parity,weights_by_channel=False
691                        )(rhoi, Vi)
692                        Vi = ChannelMixing(
693                                self.lmax,
694                                self.nchannels_l,
695                                self.nchannels_l,
696                                name=f"tp_mixing_{layer}_{itp}",
697                            )(Vi + dVi)
698                        Vi0.append(dVi[:, :, 0])
699                    Vi0 = jnp.concatenate(Vi0, axis=-1)
700                    components.append(Vi0)
701
702                if self.pair_embedding_key is not None:
703                    Vij = Vi[edge_src]*Yij
704                    Vij0 = [Vij[...,0]]
705                    for l in range(1,self.lmax+1):
706                        Vij0.append(Vij[...,l**2:(l+1)**2].sum(axis=-1))
707                    Vij0 = jnp.concatenate(Vij0,axis=-1)
708                    components_pair.append(Vij0)
709
710            ##################################################
711            ### CONCATENATE EMBEDDING COMPONENTS ###
712            if do_lode and nchannels_lode[layer] > 0:
713                zj = nn.Dense(dim_lr[layer], use_bias=False, name=f"LODE_{layer}")(xi)
714                if nextra_powers > 0:
715                    zj_extra = zj[:,:nextra_powers*nchannels_lode[layer]].reshape(
716                        (species.shape[0],nchannels_lode[layer], nextra_powers)
717                    )
718                    zj = zj[:,nextra_powers*nchannels_lode[layer]:]
719                    xi_lr_extra = jax.ops.segment_sum(
720                        eij_lr_extra * zj_extra[edge_dst_lr], edge_src_lr, species.shape[0]
721                    )
722                    components.append(xi_lr_extra.reshape(species.shape[0],-1))
723
724                if equivariant_lode:
725                    zj = zj.reshape(
726                        (species.shape[0], nchannels_lode[layer], lmax_lr + 1)
727                    ).repeat(nrep_lr, axis=-1)
728                xi_lr = jax.ops.segment_sum(
729                    eij_lr * zj[edge_dst_lr], edge_src_lr, species.shape[0]
730                )
731                if equivariant_lode:
732                    assert self.lode_use_field_norm or self.lode_multipole_interaction, "equivariant LODE requires field norm or multipole interaction"
733                    if self.lode_multipole_interaction:
734                        if initialize_e3:
735                            raise ValueError("equivariant LODE used before local equivariants initialized")
736                        size_l_lr = (lmax_lr+1)**2
737                        if self.lode_direct_multipoles:
738                            assert nchannels_lode[layer] <= self.nchannels_l
739                            Mi = Vi[:, : nchannels_lode[layer], :size_l_lr]
740                        else:
741                            Mi = ChannelMixingE3(
742                                lmax_lr,
743                                self.nchannels_l,
744                                nchannels_lode[layer],
745                                name=f"e3_LODE_{layer}",
746                            )(Vi[...,:size_l_lr])
747                        Mi_lr = Mi * xi_lr
748                    components.append(xi_lr[:, :, 0])
749                    if self.lode_use_field_norm and self.lode_equi_full_combine:
750                        xi_lr1 = ChannelMixing(
751                            lmax_lr,
752                            nchannels_lode[layer],
753                            nchannels_lode[layer],
754                            name=f"LODE_mixing_{layer}",
755                        )(xi_lr)
756                    norm = 1.
757                    for l in range(1, lmax_lr + 1):
758                        if self.lode_normalize_l:
759                            norm = 1. / (2 * l + 1)
760                        if self.lode_multipole_interaction:
761                            components.append(Mi_lr[:, :, l**2 : (l + 1) ** 2].sum(axis=-1)*norm)
762
763                        if self.lode_use_field_norm:
764                            if self.lode_equi_full_combine:
765                                components.append((xi_lr[:,:,l**2 : (l + 1) ** 2]*xi_lr1[:,:,l**2 : (l + 1) ** 2]).sum(axis=-1)*norm)
766                            else:
767                                components.append(
768                                    ((xi_lr[:, :, l**2 : (l + 1) ** 2]) ** 2).sum(axis=-1)*norm
769                                )
770                else:
771                    components.append(xi_lr)
772
773            dxi = jnp.concatenate(components, axis=-1)
774
775            ##################################################
776            ### CONCATENATE PAIR EMBEDDING COMPONENTS ###
777            if self.pair_embedding_key is not None:
778                dxij = jnp.concatenate(components_pair, axis=-1)
779
780            ##################################################
781            ### MIX AND APPLY NONLINEARITY ###
782            if self.block_index_key is not None:
783                block_index = inputs[self.block_index_key]
784                dxi = actmix(BlockIndexNet(
785                        output_dim=self.dim,
786                        hidden_neurons=self.mixing_hidden,
787                        activation=self.activation,
788                        name=f"dxi_{layer}",
789                        use_bias=self.use_bias,
790                        kernel_init=kernel_init,
791                    )((species,dxi, block_index))
792                )
793            else:
794                dxi = actmix(
795                    FullyConnectedNet(
796                        [*self.mixing_hidden, self.dim],
797                        activation=self.activation,
798                        name=f"dxi_{layer}",
799                        use_bias=self.use_bias,
800                        kernel_init=kernel_init,
801                    )(dxi)
802                )
803
804            if self.pair_embedding_key is not None:
805                ### UPDATE PAIR EMBEDDING ###
806                # dxij = tssr3(nn.Dense(dim_dst, name=f"dxij_{layer}",use_bias=False)(dxij))
807                dxij = actmix(
808                    FullyConnectedNet(
809                        [*self.pair_mixing_hidden, dim_dst],
810                        activation=self.activation,
811                        name=f"dxij_{layer}",
812                        use_bias=False,
813                        kernel_init=kernel_init,
814                    )(dxij)
815                )
816                xij = layer_norm(xij + dxij)
817
818            ##################################################
819            ### UPDATE EMBEDDING ###
820            if layer == 0 and not (self.species_init or self.charge_embedding):
821                xi = layer_norm(dxi)
822            else:
823                ### FORGET GATE ###
824                R = jax.nn.sigmoid(
825                    self.param(
826                        f"retention_{layer}",
827                        nn.initializers.normal(),
828                        (xi.shape[-1],),
829                    )
830                )
831                xi = layer_norm(R[None, :] * xi + dxi)
832
833            if self.keep_all_layers:
834                xis.append(xi)
835
836        embedding_key = (
837            self.embedding_key if self.embedding_key is not None else self.name
838        )
839        output = {
840            **inputs,
841            embedding_key: xi,
842        }
843        if self.lmax > 0:
844            output[embedding_key + "_tensor"] = Vi
845        if self.keep_all_layers:
846            output[embedding_key + "_layers"] = jnp.stack(xis, axis=1)
847        if self.charge_embedding:
848            output[embedding_key + "_charge"] = charge_embedding
849        if self.pair_embedding_key is not None:
850            output[self.pair_embedding_key] = xij
851        return output
class CRATEmbedding(flax.linen.module.Module):
 18class CRATEmbedding(nn.Module):
 19    """Configurable Resources ATomic Environment
 20
 21    FID : CRATE
 22
 23    This class represents the CRATE (Configurable Resources ATomic Environment) embedding model.
 24    It is used to encode atomic environments using multiple sources of information
 25      (radial, angular, E(3), message-passing, LODE, etc...)
 26    """
 27
 28    _graphs_properties: Dict
 29
 30    dim: int = 256
 31    """The size of the embedding vectors."""
 32    nlayers: int = 2
 33    """The number of interaction layers in the model."""
 34    keep_all_layers: bool = False
 35    """Whether to output all layers."""
 36
 37    dim_src: int = 64
 38    """The size of the source embedding vectors."""
 39    dim_dst: int = 32
 40    """The size of the destination embedding vectors."""
 41
 42    angle_style: str = "fourier"
 43    """The style of angle representation."""
 44    dim_angle: int = 8
 45    """The size of the pairwise vectors use for triplet combinations."""
 46    nmax_angle: int = 4
 47    """The dimension of the angle representation (minus one)."""
 48    zeta: float = 14.1
 49    """The zeta parameter for the model ANI angular representation."""
 50    angle_combine_pairs: bool = True
 51    """Whether to combine angle pairs instead of average distance embedding like in ANI."""
 52
 53    message_passing: bool = True
 54    """Whether to use message passing in the model."""
 55    att_dim: int = 1
 56    """The hidden size for the attention mechanism (only used when message-passing is disabled)."""
 57
 58    lmax: int = 0
 59    """The maximum order of spherical tensors."""
 60    nchannels_l: int = 16
 61    """The number of channels for spherical tensors."""
 62    n_tp: int = 1
 63    """The number of tensor products performed at each layer."""
 64    ignore_irreps_parity: bool = False
 65    """Whether to ignore the parity of the irreps in the tensor product."""
 66    edge_tp: bool = False
 67    """Whether to perform a tensor product on edges before sending messages."""
 68    resolve_wij_l: bool = False
 69    """Equivariant message weights are l-dependent."""
 70
 71    species_init: bool = False
 72    """Whether to initialize the embedding using the species encoding."""
 73    mixing_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [])
 74    """The hidden layer sizes for the mixing network."""
 75    pair_mixing_hidden: Sequence[int] = dataclasses.field(default_factory=lambda: [])
 76    """The hidden layer sizes for the pair mixing network."""
 77    activation: Union[Callable, str] = "silu"
 78    """The activation function for the mixing network."""
 79    kernel_init: Union[str, Callable] = "lecun_normal()"
 80    """The kernel initialization function for Dense operations."""
 81    activation_mixing: Union[Callable, str] = "tssr3"
 82    """The activation function applied after mixing."""
 83    layer_normalization: bool = False
 84    """Whether to apply layer normalization after each layer."""
 85    use_bias: bool = True
 86    """Whether to use bias in the Dense operations."""
 87
 88    graph_key: str = "graph"
 89    """The key for the graph data in the inputs dictionary."""
 90    graph_angle_key: Optional[str] = None
 91    """The key for the angle graph data in the inputs dictionary."""
 92    embedding_key: Optional[str] = None
 93    """The key for the embedding data in the output dictionary."""
 94    pair_embedding_key: Optional[str] = None
 95    """The key for the pair embedding data in the output dictionary."""
 96
 97    species_encoding: Union[dict, str] = dataclasses.field(default_factory=dict)
 98    """If `str`, it is the key in the inputs dictionary that contains species encodings. Else, it is the dictionary of parameters for species encoding. See `fennol.models.misc.encodings.SpeciesEncoding`."""
 99    radial_basis: dict = dataclasses.field(default_factory=dict)
100    """The dictionary of parameters for radial basis functions. See `fennol.models.misc.encodings.RadialBasis`."""
101    radial_basis_angle: Optional[dict] = None
102    """The dictionary of parameters for radial basis functions for angle embedding. 
103        If None, the radial basis for angles is the same as the radial basis for distances."""
104
105    graph_lode: Optional[str] = None
106    """The key for the lode graph data in the inputs dictionary."""
107    lode_channels: Union[int, Sequence[int]] = 8
108    """The number of channels for lode."""
109    lmax_lode: int = 0
110    """The maximum order of spherical tensors for lode."""
111    a_lode: float = -1.
112    """The cutoff for the lode graph. If negative, the value is trainable with starting value -a_lode."""
113    lode_resolve_l: bool = True
114    """Whether to resolve the lode channels by l."""
115    lode_multipole_interaction: bool = True
116    """Whether to interact with the multipole moments of the lode graph."""
117    lode_direct_multipoles: bool = True
118    """Whether to directly use the first local equivariants to interact with long-range equivariants. If false, local equivariants are mixed before interaction."""
119    lode_equi_full_combine: bool = False
120    lode_normalize_l: bool = False
121    lode_use_field_norm: bool = True
122    lode_rshort: Optional[float] = None
123    lode_dshort: float = 0.5
124    lode_extra_powers: Sequence[int] = ()
125    
126
127    charge_embedding: bool = False
128    """Whether to include charge embedding."""
129    total_charge_key: str = "total_charge"
130    """The key for the total charge data in the inputs dictionary."""
131
132    block_index_key: Optional[str] = None
133    """The key for the block index. If provided, will use a BLOCK_INDEX_NET as a mixing network."""
134
135    FID: ClassVar[str] = "CRATE"
136
137    @nn.compact
138    def __call__(self, inputs):
139        species = inputs["species"]
140        assert (
141            len(species.shape) == 1
142        ), "Species must be a 1D array (batches must be flattened)"
143        reduce_memory =  "reduce_memory" in inputs.get("flags", {})
144
145        kernel_init = (
146            initializer_from_str(self.kernel_init)
147            if isinstance(self.kernel_init, str)
148            else self.kernel_init
149        )
150
151        actmix = activation_from_str(self.activation_mixing)
152
153        ##################################################
154        graph = inputs[self.graph_key]
155        use_angles = self.graph_angle_key is not None
156        if use_angles:
157            graph_angle = inputs[self.graph_angle_key]
158
159            # Check that the graph_angle is a subgraph of graph
160            correct_graph = (
161                self.graph_angle_key == self.graph_key
162                or self._graphs_properties[self.graph_angle_key]["parent_graph"]
163                == self.graph_key
164            )
165            assert (
166                correct_graph
167            ), f"graph_angle_key={self.graph_angle_key} must be a subgraph of graph_key={self.graph_key}"
168            assert (
169                "angles" in graph_angle
170            ), f"Graph {self.graph_angle_key} must contain angles"
171            # check if graph_angle is a filtered graph
172            filtered = "parent_graph" in self._graphs_properties[self.graph_angle_key]
173            if filtered:
174                filter_indices = graph_angle["filter_indices"]
175
176        ##################################################
177        ### SPECIES ENCODING ###
178        if isinstance(self.species_encoding, str):
179            zi = inputs[self.species_encoding]
180        else:
181            zi = SpeciesEncoding(
182                **self.species_encoding, name="SpeciesEncoding"
183            )(species)
184
185        
186        if self.layer_normalization:
187            def layer_norm(x):
188                mu = jnp.mean(x,axis=-1,keepdims=True)
189                dx = x-mu
190                var = jnp.mean(dx**2,axis=-1,keepdims=True)
191                sig = (1.e-6 + var)**(-0.5)
192                return dx*sig
193        else:
194            layer_norm = lambda x:x
195                
196
197        if self.charge_embedding:
198            xi, qi = jnp.split(
199                nn.Dense(self.dim + 1, use_bias=False, name="ChargeEncoding")(zi),
200                [self.dim],
201                axis=-1,
202            )
203            batch_index = inputs["batch_index"]
204            natoms = inputs["natoms"]
205            nsys = natoms.shape[0]
206            Zi = jnp.asarray(VALENCE_ELECTRONS)[species]
207            Ntot = jax.ops.segment_sum(Zi, batch_index, nsys) - inputs.get(
208                self.total_charge_key, jnp.zeros(nsys)
209            )
210            ai = jax.nn.softplus(qi.squeeze(-1))
211            A = jax.ops.segment_sum(ai, batch_index, nsys)
212            Ni = ai * (Ntot / A)[batch_index]
213            charge_embedding = positional_encoding(Ni, self.dim)
214            xi = layer_norm(xi + charge_embedding)
215        elif self.species_init:
216            xi = layer_norm(nn.Dense(self.dim, use_bias=False, name="SpeciesInit")(zi))
217        else:
218            xi = zi
219
220        ##################################################
221        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
222        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
223        distances = graph["distances"]
224        switch = graph["switch"][:, None]
225
226        ### COMPUTE RADIAL BASIS ###
227        radial_basis = RadialBasis(
228            **{
229                **self.radial_basis,
230                "end": cutoff,
231                "name": f"RadialBasis",
232            }
233        )(distances)
234
235        do_lode = self.graph_lode is not None
236        if do_lode:
237            graph_lode = inputs[self.graph_lode]
238            switch_lode = graph_lode["switch"][:, None]
239
240            edge_src_lr, edge_dst_lr = graph_lode["edge_src"], graph_lode["edge_dst"]
241            r = graph_lode["distances"][:, None]
242            rc = self._graphs_properties[self.graph_lode]["cutoff"]
243            
244            lmax_lr = self.lmax_lode
245            equivariant_lode = lmax_lr > 0
246            assert lmax_lr >=0, f"lmax_lode must be >= 0, got {lmax_lr}"
247            if self.lode_multipole_interaction:
248                assert lmax_lr <= self.lmax, f"lmax_lode must be <= lmax for multipole interaction, got {lmax_lr} > {self.lmax}"
249            nrep_lr = np.array([2 * l + 1 for l in range(lmax_lr + 1)], dtype=np.int32)
250            if self.lode_resolve_l and equivariant_lode:
251                ls_lr = np.arange(lmax_lr + 1)
252            else:
253                ls_lr = np.array([0])
254
255            nextra_powers = len(self.lode_extra_powers)
256            if nextra_powers > 0:
257                ls_lr = np.concatenate([self.lode_extra_powers,ls_lr])
258
259            if self.a_lode > 0:
260                a = self.a_lode**2
261            else:
262                a = (
263                    self.param(
264                        "a_lr",
265                        lambda key: jnp.asarray([-self.a_lode] * ls_lr.shape[0])[None, :],
266                    )
267                    ** 2
268                )
269            rc2a = rc**2 + a
270            ls_lr = 0.5 * (ls_lr[None, :] + 1)
271            ### minimal radial basis for long range (damped coulomb)
272            eij_lr = (
273                1.0 / (r**2 + a) ** ls_lr
274                - 1.0 / rc2a**ls_lr
275                + (r - rc) * rc * (2 * ls_lr) / rc2a ** (ls_lr + 1)
276            ) * switch_lode
277
278            if self.lode_rshort is not None:
279                rs = self.lode_rshort
280                d = self.lode_dshort
281                switch_short = 0.5 * (1 - jnp.cos(jnp.pi * (r - rs) / d)) * (r > rs) * (
282                    r < rs + d
283                ) + (r >= rs + d)
284                eij_lr = eij_lr * switch_short
285            
286            if nextra_powers>0:
287                eij_lr_extra = eij_lr[:,:nextra_powers]
288                eij_lr = eij_lr[:,nextra_powers:]
289
290
291            # dim_lr = self.nchannels_lode
292            nchannels_lode = (
293                [self.lode_channels] * self.nlayers
294                if isinstance(self.lode_channels, int)
295                else self.lode_channels
296            )
297            dim_lr = nchannels_lode
298            
299            if equivariant_lode:
300                if self.lode_resolve_l:
301                    eij_lr = eij_lr.repeat(nrep_lr, axis=-1)
302                Yij = generate_spherical_harmonics(lmax=lmax_lr, normalize=False)(
303                    graph_lode["vec"] / r
304                )
305                eij_lr = (eij_lr * Yij)[:, None, :]
306                dim_lr = [d * (lmax_lr + 1) for d in dim_lr]
307            
308            if nextra_powers > 0:
309                eij_lr_extra = eij_lr_extra[:,None,:]
310                extra_dims = [nextra_powers*d for d in nchannels_lode]
311                dim_lr = [d + ed for d,ed in zip(dim_lr,extra_dims)]
312            
313
314        ##################################################
315        ### GET ANGLES ###
316        if use_angles:
317            angles = graph_angle["angles"][:, None]
318            angle_src, angle_dst = graph_angle["angle_src"], graph_angle["angle_dst"]
319            switch_angles = graph_angle["switch"][:, None]
320            central_atom = graph_angle["central_atom"]
321
322            if not self.angle_combine_pairs:
323                assert (
324                    self.radial_basis_angle is not None
325                ), "radial_basis_angle must be specified if angle_combine_pairs=False"
326
327            ### COMPUTE RADIAL BASIS FOR ANGLES ###
328            if self.radial_basis_angle is not None:
329                dangles = graph_angle["distances"]
330                swang = switch_angles
331                if not self.angle_combine_pairs:
332                    dangles = 0.5 * (dangles[angle_src] + dangles[angle_dst])
333                    swang = switch_angles[angle_src] * switch_angles[angle_dst]
334                radial_basis_angle = (
335                    RadialBasis(
336                        **{
337                            **self.radial_basis_angle,
338                            "end": self._graphs_properties[self.graph_angle_key][
339                                "cutoff"
340                            ],
341                            "name": f"RadialBasisAngle",
342                        }
343                    )(dangles)
344                    * swang
345                )
346
347            else:
348                if filtered:
349                    radial_basis_angle = radial_basis[filter_indices] * switch_angles
350                else:
351                    radial_basis_angle = radial_basis * switch
352
353        radial_basis = radial_basis * switch
354
355        # # add covalent indicator
356        # rc = jnp.asarray([d/au.BOHR for d in D3_COV_RADII])[species]
357        # rcij = rc[edge_src] + rc[edge_dst]
358        # fact = graph["switch"]*(2*distances/rcij)*jnp.exp(-0.5 * ((distances - rcij)/(0.1*rcij)) ** 2)
359        # radial_basis = jnp.concatenate([radial_basis,fact[:,None]],axis=-1)
360        # if use_angles:
361        #     rcij = rc[graph_angle["edge_src"]] + rc[graph_angle["edge_dst"]]
362        #     dangles = graph_angle["distances"]
363        #     fact = graph_angle["switch"]*((2*dangles/rcij))*jnp.exp(-0.5 * ((dangles - rcij)/(0.1*rcij))**2)
364        #     radial_basis_angle = jnp.concatenate([radial_basis_angle,fact[:,None]],axis=-1)
365
366
367        ##################################################
368        if self.lmax > 0:
369            Yij = generate_spherical_harmonics(lmax=self.lmax, normalize=False)(
370                graph["vec"] / graph["distances"][:, None]
371            )[:, None, :]
372            Yij = jnp.broadcast_to(Yij, (Yij.shape[0], self.nchannels_l, Yij.shape[2]))
373            nrep_l = np.array([2 * l + 1 for l in range(self.lmax + 1)], dtype=np.int32)
374            # ls = [0]
375            # for l in range(1, self.lmax + 1):
376            #    ls = ls + [l] * (2 * l + 1)
377            #ls = jnp.asarray(np.array(ls)[None, :], dtype=distances.dtype)
378            #lcut = (0.5 + 0.5 * jnp.cos((np.pi / cutoff) * distances[:, #None])) ** (
379            #    ls + 1
380            #)
381            # lcut = jnp.where(graph["edge_mask"][:, None], lcut, 0.0)
382            # rijl1 = (lcut * distances[:, None] ** ls)[:, None, :]
383
384        ##################################################
385        if use_angles:
386            ### ANGULAR BASIS ###
387            if self.angle_style == "fourier":
388                # build fourier series for angles
389                nangles = self.param(
390                    f"nangles",
391                    lambda key, dim: jnp.arange(dim, dtype=distances.dtype)[None, :],
392                    self.nmax_angle + 1,
393                )
394
395                phi = self.param(
396                    f"phi",
397                    lambda key, dim: jnp.zeros((1, dim), dtype=distances.dtype),
398                    self.nmax_angle + 1,
399                )
400                xa = jnp.cos(nangles * angles + phi)
401            elif self.angle_style == "fourier_full":
402                # build fourier series for angles including sin terms
403                nangles = self.param(
404                    f"nangles",
405                    lambda key, dim: jnp.arange(dim, dtype=distances.dtype)[None, :],
406                    self.nmax_angle + 1,
407                )
408
409                phi = self.param(
410                    f"phi",
411                    lambda key, dim: jnp.zeros((1, dim), dtype=distances.dtype),
412                    2 * self.nmax_angle + 1,
413                )
414                xac = jnp.cos(nangles * angles + phi[:, : self.nmax_angle + 1])
415                xas = jnp.sin(nangles[:, 1:] * angles + phi[:, self.nmax_angle + 1 :])
416                xa = jnp.concatenate([xac, xas], axis=-1)
417            elif self.angle_style == "ani":
418                # ANI-style angle embedding
419                angle_start = np.pi / (2 * (self.nmax_angle + 1))
420                shiftZ = self.param(
421                    f"shiftZ",
422                    lambda key, dim: jnp.asarray(
423                        (np.linspace(0, np.pi, dim + 1) + angle_start)[None, :-1],
424                        dtype=distances.dtype,
425                    ),
426                    self.nmax_angle + 1,
427                )
428                zeta = self.param(
429                    f"zeta",
430                    lambda key: jnp.asarray(self.zeta, dtype=distances.dtype),
431                )
432                xa = (0.5 + 0.5 * jnp.cos(angles - shiftZ)) ** zeta
433            else:
434                raise ValueError(f"Unknown angle style {self.angle_style}")
435            xa = xa[:, None, :]
436            if not self.angle_combine_pairs:
437                if reduce_memory: raise NotImplementedError("Angle embedding not implemented with reduce_memory")
438                xa = (xa * radial_basis_angle[:, :, None]).reshape(
439                    -1, 1, xa.shape[1] * radial_basis_angle.shape[1]
440                )
441
442            if self.pair_embedding_key is not None:
443                if filtered:
444                    ang_pair_src = filter_indices[angle_src]
445                    ang_pair_dst = filter_indices[angle_dst]
446                else:
447                    ang_pair_src = angle_src
448                    ang_pair_dst = angle_dst
449                ang_pairs = jnp.concatenate((ang_pair_src, ang_pair_dst))
450
451        ##################################################
452        ### DIMENSIONS ###
453        dim_src = (
454            [self.dim_src] * self.nlayers
455            if isinstance(self.dim_src, int)
456            else self.dim_src
457        )
458        assert (
459            len(dim_src) == self.nlayers
460        ), f"dim_src must be an integer or a list of length {self.nlayers}"
461        dim_dst = self.dim_dst
462        # dim_dst = (
463        #     [self.dim_dst] * self.nlayers
464        #     if isinstance(self.dim_dst, int)
465        #     else self.dim_dst
466        # )
467        # assert (
468        #     len(dim_dst) == self.nlayers
469        # ), f"dim_dst must be an integer or a list of length {self.nlayers}"
470
471        if use_angles:
472            dim_angle = (
473                [self.dim_angle] * self.nlayers
474                if isinstance(self.dim_angle, int)
475                else self.dim_angle
476            )
477            assert (
478                len(dim_angle) == self.nlayers
479            ), f"dim_angle must be an integer or a list of length {self.nlayers}"
480            # nmax_angle = [self.nmax_angle]*self.nlayers if isinstance(self.nmax_angle, int) else self.nmax_angle
481            # assert len(nmax_angle) == self.nlayers, f"nmax_angle must be an integer or a list of length {self.nlayers}"
482        
483        initialize_e3 = True
484        if self.lmax > 0:
485            n_tp = (
486                [self.n_tp] * self.nlayers
487                if isinstance(self.n_tp, int)
488                else self.n_tp
489            )
490            assert (
491                len(n_tp) == self.nlayers
492            ), f"n_tp must be an integer or a list of length {self.nlayers}"
493
494
495        message_passing = (
496            [self.message_passing] * self.nlayers
497            if isinstance(self.message_passing, bool)
498            else self.message_passing
499        )
500        assert (
501            len(message_passing) == self.nlayers
502        ), f"message_passing must be a boolean or a list of length {self.nlayers}"
503
504        ##################################################
505        ### INITIALIZE PAIR EMBEDDING ###
506        if self.pair_embedding_key is not None:
507            xij_s,xij_d = jnp.split(nn.Dense(2*dim_dst, name="pair_init_linear")(zi), [dim_dst], axis=-1)
508            xij = layer_norm(xij_s[edge_src]*xij_d[edge_dst])
509
510        ##################################################
511        if self.keep_all_layers:
512            xis = []
513        
514        ### LOOP OVER LAYERS ###
515        for layer in range(self.nlayers):
516            ##################################################
517            ### COMPACT DESCRIPTORS ###
518            si, si_dst = jnp.split(
519                nn.Dense(
520                    dim_src[layer] + dim_dst,
521                    name=f"species_linear_{layer}",
522                    use_bias=self.use_bias,
523                )(xi),
524                [
525                    dim_src[layer],
526                ],
527                axis=-1,
528            )
529
530            ##################################################
531            if message_passing[layer] or layer == 0:
532                ### MESSAGE PASSING ###
533                si_mp = si_dst[edge_dst]
534            else:
535                # if layer == 0:
536                #     si_mp = si_dst[edge_dst]
537                ### ATTENTION TO SIMULATE MP ###
538                Q = nn.Dense(
539                    dim_dst * self.att_dim, name=f"queries_{layer}", use_bias=False
540                )(si_dst).reshape(-1, dim_dst, self.att_dim)[edge_src]
541                K = nn.Dense(
542                    dim_dst * self.att_dim, name=f"keys_{layer}", use_bias=False
543                )(zi).reshape(-1, dim_dst, self.att_dim)[edge_dst]
544
545                si_mp = (K * Q).sum(axis=-1) / self.att_dim**0.5
546                # Vmp = jax.ops.segment_sum(
547                #     (KQ * switch)[:, :, None] * Yij, edge_src, species.shape[0]
548                # )
549                # si_mp = (Vmp[edge_src] * Yij).sum(axis=-1)
550                # Q = nn.Dense(
551                #     dim_dst * dim_dst, name=f"queries_{layer}", use_bias=False
552                # )(si_dst).reshape(-1, dim_dst, dim_dst)
553                # si_mp = (
554                #     si_mp + jax.vmap(jnp.dot)(Q[edge_src], si_mp) / self.dim_dst**0.5
555                # )
556
557            if self.pair_embedding_key is not None:
558                si_mp = si_mp + xij
559
560            ##################################################
561            ### PAIR EMBEDDING ###
562            if reduce_memory:
563                Li = jnp.zeros((species.shape[0]* radial_basis.shape[1],si_mp.shape[1]),dtype=si_mp.dtype)
564                for i in range(radial_basis.shape[1]):
565                    indices = i + edge_src*radial_basis.shape[1]
566                    Li = Li.at[indices].add(si_mp*radial_basis[:,i,None])
567                Li = Li.reshape(species.shape[0], radial_basis.shape[1]*si_mp.shape[1])
568            else:
569                Lij = (si_mp[:, None, :] * radial_basis[:, :, None]).reshape(
570                    radial_basis.shape[0], si_mp.shape[1] * radial_basis.shape[1]
571                )
572                ### AGGREGATE PAIR EMBEDDING ###
573                Li = jax.ops.segment_sum(Lij, edge_src, species.shape[0])
574
575            ### CONCATENATE EMBEDDING COMPONENTS ###
576            components = [si, Li]
577            if self.pair_embedding_key is not None:
578                if reduce_memory: raise NotImplementedError("Pair embedding not implemented with reduce_memory")
579                components_pair = [si[edge_src], xij, Lij]
580
581
582            ##################################################
583            ### ANGLE EMBEDDING ###
584            if use_angles and dim_angle[layer]>0:
585                si_mp_ang = si_mp[filter_indices] if filtered else si_mp
586                if self.angle_combine_pairs:
587                    Wa = self.param(
588                        f"Wa_{layer}",
589                        nn.initializers.normal(
590                            stddev=1.0
591                            / (si_mp.shape[1] * radial_basis_angle.shape[1]) ** 0.5
592                        ),
593                        (si_mp.shape[1], radial_basis_angle.shape[1], dim_angle[layer]),
594                    )
595                    Da = jnp.einsum(
596                        "...i,...j,ijk->...k",
597                        si_mp_ang,
598                        radial_basis_angle,
599                        Wa,
600                    )
601
602                else:
603                    if message_passing[layer] or layer == 0:
604                        Da = nn.Dense(dim_angle[layer], name=f"angle_linear_{layer}")(
605                            xi
606                        )[graph_angle["edge_dst"]]
607                    else:
608                        Da = nn.Dense(dim_angle[layer], name=f"angle_linear_{layer}")(
609                            si_mp_ang
610                        )
611
612                Da = Da[angle_dst] * Da[angle_src]
613                ## combine pair and angle info
614                if reduce_memory:
615                    ang_embedding = jnp.zeros((species.shape[0]* Da.shape[-1],xa.shape[-1]),dtype=Da.dtype)
616                    for i in range(Da.shape[-1]):
617                        indices = i + central_atom*Da.shape[-1]
618                        ang_embedding = ang_embedding.at[indices].add(Da[:,i,None]*xa[:,0,:])
619                    ang_embedding = ang_embedding.reshape(species.shape[0], xa.shape[-1]*Da.shape[-1])
620                else:
621                    radang = (xa * Da[:, :, None]).reshape(
622                        (-1, Da.shape[1] * xa.shape[2])
623                    )
624                    ### AGGREGATE  ANGLE EMBEDDING ###
625                    ang_embedding = jax.ops.segment_sum(
626                        radang, central_atom, species.shape[0]
627                    )
628                    
629
630                components.append(ang_embedding)
631
632                if self.pair_embedding_key is not None:
633                    ang_ij = jax.ops.segment_sum(
634                        jnp.concatenate((radang, radang)),
635                        ang_pairs,
636                        edge_src.shape[0],
637                    )
638                    components_pair.append(ang_ij)
639            
640            ##################################################
641            ### EQUIVARIANT EMBEDDING ###
642            if self.lmax > 0 and n_tp[layer] >= 0:
643                if initialize_e3 or not message_passing[layer]:
644                    Vij = Yij
645                elif self.edge_tp:
646                    Vij = FilteredTensorProduct(
647                            self.lmax, self.lmax, name=f"edge_tp_{layer}",ignore_parity=self.ignore_irreps_parity,weights_by_channel=False
648                        )(Vi[edge_dst], Yij)
649                else:
650                    Vij = Vi[edge_dst]
651
652                ### compute channel weights
653                dim_wij = self.nchannels_l
654                if self.resolve_wij_l:
655                    dim_wij=self.nchannels_l*(self.lmax+1)
656
657                eij = Lij if self.pair_embedding_key is None else jnp.concatenate([Lij,xij*switch],axis=-1)
658                wij = nn.Dense(
659                        dim_wij, name=f"e3_channel_{layer}", use_bias=False
660                    )(eij)
661                if self.resolve_wij_l:
662                    wij = jnp.repeat(wij.reshape(-1,self.nchannels_l,self.lmax+1),nrep_l,axis=-1)
663                else:
664                    wij = wij[:,:,None]
665                
666                ### aggregate equivariant messages
667                drhoi = jax.ops.segment_sum(
668                    wij * Vij,
669                    edge_src,
670                    species.shape[0],
671                )
672
673                Vi0 = []
674                if initialize_e3:
675                    rhoi = drhoi
676                    Vi = ChannelMixingE3(
677                        self.lmax,
678                        self.nchannels_l,
679                        self.nchannels_l,
680                        name=f"e3_initial_mixing_{layer}",
681                    )(rhoi)
682                    # assert n_tp[layer] > 0, "n_tp must be > 0 for the first equivariant layer."
683                else:
684                    rhoi = rhoi + drhoi
685                    # if message_passing[layer]:
686                        # Vi0.append(drhoi[:, :, 0])
687                initialize_e3 = False
688                if n_tp[layer] > 0:
689                    for itp in range(n_tp[layer]):
690                        dVi = FilteredTensorProduct(
691                            self.lmax, self.lmax, name=f"tensor_product_{layer}_{itp}",ignore_parity=self.ignore_irreps_parity,weights_by_channel=False
692                        )(rhoi, Vi)
693                        Vi = ChannelMixing(
694                                self.lmax,
695                                self.nchannels_l,
696                                self.nchannels_l,
697                                name=f"tp_mixing_{layer}_{itp}",
698                            )(Vi + dVi)
699                        Vi0.append(dVi[:, :, 0])
700                    Vi0 = jnp.concatenate(Vi0, axis=-1)
701                    components.append(Vi0)
702
703                if self.pair_embedding_key is not None:
704                    Vij = Vi[edge_src]*Yij
705                    Vij0 = [Vij[...,0]]
706                    for l in range(1,self.lmax+1):
707                        Vij0.append(Vij[...,l**2:(l+1)**2].sum(axis=-1))
708                    Vij0 = jnp.concatenate(Vij0,axis=-1)
709                    components_pair.append(Vij0)
710
711            ##################################################
712            ### CONCATENATE EMBEDDING COMPONENTS ###
713            if do_lode and nchannels_lode[layer] > 0:
714                zj = nn.Dense(dim_lr[layer], use_bias=False, name=f"LODE_{layer}")(xi)
715                if nextra_powers > 0:
716                    zj_extra = zj[:,:nextra_powers*nchannels_lode[layer]].reshape(
717                        (species.shape[0],nchannels_lode[layer], nextra_powers)
718                    )
719                    zj = zj[:,nextra_powers*nchannels_lode[layer]:]
720                    xi_lr_extra = jax.ops.segment_sum(
721                        eij_lr_extra * zj_extra[edge_dst_lr], edge_src_lr, species.shape[0]
722                    )
723                    components.append(xi_lr_extra.reshape(species.shape[0],-1))
724
725                if equivariant_lode:
726                    zj = zj.reshape(
727                        (species.shape[0], nchannels_lode[layer], lmax_lr + 1)
728                    ).repeat(nrep_lr, axis=-1)
729                xi_lr = jax.ops.segment_sum(
730                    eij_lr * zj[edge_dst_lr], edge_src_lr, species.shape[0]
731                )
732                if equivariant_lode:
733                    assert self.lode_use_field_norm or self.lode_multipole_interaction, "equivariant LODE requires field norm or multipole interaction"
734                    if self.lode_multipole_interaction:
735                        if initialize_e3:
736                            raise ValueError("equivariant LODE used before local equivariants initialized")
737                        size_l_lr = (lmax_lr+1)**2
738                        if self.lode_direct_multipoles:
739                            assert nchannels_lode[layer] <= self.nchannels_l
740                            Mi = Vi[:, : nchannels_lode[layer], :size_l_lr]
741                        else:
742                            Mi = ChannelMixingE3(
743                                lmax_lr,
744                                self.nchannels_l,
745                                nchannels_lode[layer],
746                                name=f"e3_LODE_{layer}",
747                            )(Vi[...,:size_l_lr])
748                        Mi_lr = Mi * xi_lr
749                    components.append(xi_lr[:, :, 0])
750                    if self.lode_use_field_norm and self.lode_equi_full_combine:
751                        xi_lr1 = ChannelMixing(
752                            lmax_lr,
753                            nchannels_lode[layer],
754                            nchannels_lode[layer],
755                            name=f"LODE_mixing_{layer}",
756                        )(xi_lr)
757                    norm = 1.
758                    for l in range(1, lmax_lr + 1):
759                        if self.lode_normalize_l:
760                            norm = 1. / (2 * l + 1)
761                        if self.lode_multipole_interaction:
762                            components.append(Mi_lr[:, :, l**2 : (l + 1) ** 2].sum(axis=-1)*norm)
763
764                        if self.lode_use_field_norm:
765                            if self.lode_equi_full_combine:
766                                components.append((xi_lr[:,:,l**2 : (l + 1) ** 2]*xi_lr1[:,:,l**2 : (l + 1) ** 2]).sum(axis=-1)*norm)
767                            else:
768                                components.append(
769                                    ((xi_lr[:, :, l**2 : (l + 1) ** 2]) ** 2).sum(axis=-1)*norm
770                                )
771                else:
772                    components.append(xi_lr)
773
774            dxi = jnp.concatenate(components, axis=-1)
775
776            ##################################################
777            ### CONCATENATE PAIR EMBEDDING COMPONENTS ###
778            if self.pair_embedding_key is not None:
779                dxij = jnp.concatenate(components_pair, axis=-1)
780
781            ##################################################
782            ### MIX AND APPLY NONLINEARITY ###
783            if self.block_index_key is not None:
784                block_index = inputs[self.block_index_key]
785                dxi = actmix(BlockIndexNet(
786                        output_dim=self.dim,
787                        hidden_neurons=self.mixing_hidden,
788                        activation=self.activation,
789                        name=f"dxi_{layer}",
790                        use_bias=self.use_bias,
791                        kernel_init=kernel_init,
792                    )((species,dxi, block_index))
793                )
794            else:
795                dxi = actmix(
796                    FullyConnectedNet(
797                        [*self.mixing_hidden, self.dim],
798                        activation=self.activation,
799                        name=f"dxi_{layer}",
800                        use_bias=self.use_bias,
801                        kernel_init=kernel_init,
802                    )(dxi)
803                )
804
805            if self.pair_embedding_key is not None:
806                ### UPDATE PAIR EMBEDDING ###
807                # dxij = tssr3(nn.Dense(dim_dst, name=f"dxij_{layer}",use_bias=False)(dxij))
808                dxij = actmix(
809                    FullyConnectedNet(
810                        [*self.pair_mixing_hidden, dim_dst],
811                        activation=self.activation,
812                        name=f"dxij_{layer}",
813                        use_bias=False,
814                        kernel_init=kernel_init,
815                    )(dxij)
816                )
817                xij = layer_norm(xij + dxij)
818
819            ##################################################
820            ### UPDATE EMBEDDING ###
821            if layer == 0 and not (self.species_init or self.charge_embedding):
822                xi = layer_norm(dxi)
823            else:
824                ### FORGET GATE ###
825                R = jax.nn.sigmoid(
826                    self.param(
827                        f"retention_{layer}",
828                        nn.initializers.normal(),
829                        (xi.shape[-1],),
830                    )
831                )
832                xi = layer_norm(R[None, :] * xi + dxi)
833
834            if self.keep_all_layers:
835                xis.append(xi)
836
837        embedding_key = (
838            self.embedding_key if self.embedding_key is not None else self.name
839        )
840        output = {
841            **inputs,
842            embedding_key: xi,
843        }
844        if self.lmax > 0:
845            output[embedding_key + "_tensor"] = Vi
846        if self.keep_all_layers:
847            output[embedding_key + "_layers"] = jnp.stack(xis, axis=1)
848        if self.charge_embedding:
849            output[embedding_key + "_charge"] = charge_embedding
850        if self.pair_embedding_key is not None:
851            output[self.pair_embedding_key] = xij
852        return output

Configurable Resources ATomic Environment

FID : CRATE

This class represents the CRATE (Configurable Resources ATomic Environment) embedding model. It is used to encode atomic environments using multiple sources of information (radial, angular, E(3), message-passing, LODE, etc...)

CRATEmbedding( _graphs_properties: Dict, dim: int = 256, nlayers: int = 2, keep_all_layers: bool = False, dim_src: int = 64, dim_dst: int = 32, angle_style: str = 'fourier', dim_angle: int = 8, nmax_angle: int = 4, zeta: float = 14.1, angle_combine_pairs: bool = True, message_passing: bool = True, att_dim: int = 1, lmax: int = 0, nchannels_l: int = 16, n_tp: int = 1, ignore_irreps_parity: bool = False, edge_tp: bool = False, resolve_wij_l: bool = False, species_init: bool = False, mixing_hidden: Sequence[int] = <factory>, pair_mixing_hidden: Sequence[int] = <factory>, activation: Union[Callable, str] = 'silu', kernel_init: Union[str, Callable] = 'lecun_normal()', activation_mixing: Union[Callable, str] = 'tssr3', layer_normalization: bool = False, use_bias: bool = True, graph_key: str = 'graph', graph_angle_key: Optional[str] = None, embedding_key: Optional[str] = None, pair_embedding_key: Optional[str] = None, species_encoding: Union[dict, str] = <factory>, radial_basis: dict = <factory>, radial_basis_angle: Optional[dict] = None, graph_lode: Optional[str] = None, lode_channels: Union[int, Sequence[int]] = 8, lmax_lode: int = 0, a_lode: float = -1.0, lode_resolve_l: bool = True, lode_multipole_interaction: bool = True, lode_direct_multipoles: bool = True, lode_equi_full_combine: bool = False, lode_normalize_l: bool = False, lode_use_field_norm: bool = True, lode_rshort: Optional[float] = None, lode_dshort: float = 0.5, lode_extra_powers: Sequence[int] = (), charge_embedding: bool = False, total_charge_key: str = 'total_charge', block_index_key: Optional[str] = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
dim: int = 256

The size of the embedding vectors.

nlayers: int = 2

The number of interaction layers in the model.

keep_all_layers: bool = False

Whether to output all layers.

dim_src: int = 64

The size of the source embedding vectors.

dim_dst: int = 32

The size of the destination embedding vectors.

angle_style: str = 'fourier'

The style of angle representation.

dim_angle: int = 8

The size of the pairwise vectors used for triplet combinations.

nmax_angle: int = 4

The dimension of the angle representation (minus one).

zeta: float = 14.1

The zeta parameter for the ANI-style angular representation.

angle_combine_pairs: bool = True

Whether to combine angle pairs instead of using an average-distance embedding as in ANI.

message_passing: bool = True

Whether to use message passing in the model.

att_dim: int = 1

The hidden size for the attention mechanism (only used when message-passing is disabled).

lmax: int = 0

The maximum order of spherical tensors.

nchannels_l: int = 16

The number of channels for spherical tensors.

n_tp: int = 1

The number of tensor products performed at each layer.

ignore_irreps_parity: bool = False

Whether to ignore the parity of the irreps in the tensor product.

edge_tp: bool = False

Whether to perform a tensor product on edges before sending messages.

resolve_wij_l: bool = False

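Whether equivariant message weights are l-dependent (one weight per channel and per angular order l, instead of one weight per channel shared by all l).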

species_init: bool = False

Whether to initialize the embedding using the species encoding.

mixing_hidden: Sequence[int]

The hidden layer sizes for the mixing network.

pair_mixing_hidden: Sequence[int]

The hidden layer sizes for the pair mixing network.

activation: Union[Callable, str] = 'silu'

The activation function for the mixing network.

kernel_init: Union[str, Callable] = 'lecun_normal()'

The kernel initialization function for Dense operations.

activation_mixing: Union[Callable, str] = 'tssr3'

The activation function applied after mixing.

layer_normalization: bool = False

Whether to apply layer normalization after each layer.

use_bias: bool = True

Whether to use bias in the Dense operations.

graph_key: str = 'graph'

The key for the graph data in the inputs dictionary.

graph_angle_key: Optional[str] = None

The key for the angle graph data in the inputs dictionary.

embedding_key: Optional[str] = None

The key for the embedding data in the output dictionary.

pair_embedding_key: Optional[str] = None

The key for the pair embedding data in the output dictionary.

species_encoding: Union[dict, str]

If str, it is the key in the inputs dictionary that contains species encodings. Else, it is the dictionary of parameters for species encoding. See fennol.models.misc.encodings.SpeciesEncoding.

radial_basis: dict

The dictionary of parameters for radial basis functions. See fennol.models.misc.encodings.RadialBasis.

radial_basis_angle: Optional[dict] = None

The dictionary of parameters for radial basis functions for angle embedding. If None, the radial basis for angles is the same as the radial basis for distances.

graph_lode: Optional[str] = None

The key for the LODE graph data in the inputs dictionary.

lode_channels: Union[int, Sequence[int]] = 8

The number of channels for LODE.

lmax_lode: int = 0

The maximum order of spherical tensors for LODE.

a_lode: float = -1.0

The cutoff for the LODE graph. If negative, the value is trainable, with starting value -a_lode.

lode_resolve_l: bool = True

Whether to resolve the LODE channels by l.

lode_multipole_interaction: bool = True

Whether to couple local equivariant (multipole) features with the long-range LODE equivariants, adding one invariant feature per channel and per l ≥ 1.

lode_direct_multipoles: bool = True

Whether to directly use the first local equivariant channels (as many as there are LODE channels) to interact with the long-range equivariants. If False, the local equivariants are first mixed by a channel-mixing layer before the interaction.

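lode_equi_full_combine: bool = False

For equivariant LODE with lode_use_field_norm enabled: whether to contract each l-component of the long-range field with a channel-mixed copy of itself, instead of taking its squared norm.

lode_normalize_l: bool = False

Whether to scale the l ≥ 1 invariant LODE features of order l by 1/(2l+1).

lode_use_field_norm: bool = True

For equivariant LODE: whether to add, for each l ≥ 1, the squared norm of the long-range field (or its full combination, see lode_equi_full_combine) to the invariant features.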
lode_rshort: Optional[float] = None
lode_dshort: float = 0.5
lode_extra_powers: Sequence[int] = ()
charge_embedding: bool = False

Whether to include charge embedding.

total_charge_key: str = 'total_charge'

The key for the total charge data in the inputs dictionary.

block_index_key: Optional[str] = None

The key for the block index in the inputs dictionary. If provided, a BlockIndexNet is used as the mixing network.

FID: ClassVar[str] = 'CRATE'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None