fennol.models.embeddings.raster

import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Sequence, Dict, Union, ClassVar, Optional
import numpy as np
from ...utils.periodic_table import D3_COV_RADII
import dataclasses
from ..misc.encodings import RadialBasis, SpeciesEncoding
from ...utils.spherical_harmonics import generate_spherical_harmonics
from ..misc.e3 import ChannelMixing
from ..misc.nets import FullyConnectedNet, BlockIndexNet
from ...utils.activations import activation_from_str
from ...utils import AtomicUnits as au

class RaSTER(nn.Module):
    """ Range-Separated Transformer with Equivariant Representations

    FID : RASTER

    """

    _graphs_properties: Dict
    dim: int = 176
    """The dimension of the output embedding."""
    nlayers: int = 2
    """The number of message-passing layers."""
    att_dim: int = 16
    """The dimension of the attention heads."""
    scal_heads: int = 16
    """The number of scalar attention heads."""
    tens_heads: int = 4
    """The number of tensor attention heads."""
    lmax: int = 3
    """The maximum angular momentum to consider."""
    normalize_vec: bool = True
    """Whether to normalize the vector features before computing spherical harmonics."""
    att_activation: str = "identity"
    """The activation function to use for the attention coefficients."""
    activation: str = "swish"
    """The activation function to use for the update network."""
    update_hidden: Sequence[int] = ()
    """The hidden layers for the update network."""
    update_bias: bool = True
    """Whether to use bias in the update network."""
    positional_activation: str = "swish"
    """The activation function to use for the positional embedding network."""
    positional_bias: bool = True
    """Whether to use bias in the positional embedding network."""
    switch_before_net: bool = False
    """Whether to apply the switch function to the radial basis before the edge neural network."""
    ignore_parity: bool = False
    """Whether to ignore the parity of the spherical harmonics when constructing the relative positional encoding."""
    additive_positional: bool = False
    """Whether to use additive relative positional encoding. If False, multiplicative relative positional encoding is used."""
    edge_value: bool = False
    """Whether to use edge values in the attention mechanism."""
    layer_normalization: bool = True
    """Whether to use layer normalization of atomic embeddings."""
    graph_key: str = "graph"
    """The key in the input dictionary that corresponds to the radial graph."""
    embedding_key: str = "embedding"
    """The key in the output dictionary that corresponds to the embedding."""
    radial_basis: dict = dataclasses.field(
        default_factory=lambda: {"start": 0.8, "basis": "gaussian", "dim": 16}
    )
    """The dictionary of parameters for radial basis functions. See `fennol.models.misc.encodings.RadialBasis`."""
    species_encoding: str | dict = dataclasses.field(
        default_factory=lambda: {"dim": 16, "trainable": True, "encoding": "random"}
    )
    """The dictionary of parameters for species encoding. See `fennol.models.misc.encodings.SpeciesEncoding`."""
    graph_lode: Optional[str] = None
    """The key in the input dictionary that corresponds to the long-range graph."""
    lmax_lode: int = 0
    """The maximum angular momentum for the long-range features."""
    lode_rshort: Optional[float] = None
    """The short-range cutoff for the long-range features."""
    lode_dshort: float = 2.0
    """The width of the short-range cutoff for the long-range features."""
    lode_extra_powers: Sequence[int] = ()
    """The extra powers to include in the long-range features."""
    a_lode: float = -1.0
    """The damping parameter for the long-range features. If negative, the damping is trainable with initial value abs(a_lode)."""
    block_index_key: Optional[str] = None
    """The key in the input dictionary that corresponds to the block index for the MoE network. If None, a normal neural network is used."""
    lode_channels: int = 1
    """The number of channels for the long-range features."""
    switch_cov_start: float = 0.5
    """The start of close-range covalent switch (in units of covalent radii)."""
    switch_cov_end: float = 0.6
    """The end of close-range covalent switch (in units of covalent radii)."""
    normalize_keys: bool = False
    """Whether to normalize queries and keys in the attention mechanism."""
    keep_all_layers: bool = False
    """Whether to return the stacked scalar embeddings from all message-passing layers."""

    FID: ClassVar[str] = "RASTER"

    @nn.compact
    def __call__(self, inputs):
        species = inputs["species"]

        ## SETUP LAYER NORMALIZATION
        def _layer_norm(x):
            mu = jnp.mean(x, axis=-1, keepdims=True)
            dx = x - mu
            var = jnp.mean(dx**2, axis=-1, keepdims=True)
            sig = (1.0e-6 + var) ** (-0.5)
            return dx * sig

        if self.layer_normalization:
            layer_norm = _layer_norm
        else:
            layer_norm = lambda x: x

        if self.normalize_keys:
            ln_qk = _layer_norm
        else:
            ln_qk = lambda x: x

        ## SPECIES ENCODING
        if isinstance(self.species_encoding, str):
            Zi = inputs[self.species_encoding]
        else:
            Zi = SpeciesEncoding(**self.species_encoding)(species)

        ## INITIALIZE SCALAR FEATURES
        xi = layer_norm(nn.Dense(self.dim, use_bias=False, name="species_linear")(Zi))

        ## RADIAL GRAPH
        graph = inputs[self.graph_key]
        distances = graph["distances"]
        switch = graph["switch"]
        edge_src = graph["edge_src"]
        edge_dst = graph["edge_dst"]
        vec = (
            graph["vec"] / graph["distances"][:, None]
            if self.normalize_vec
            else graph["vec"]
        )
        ## CLOSE-RANGE SWITCH
        use_switch_cov = False
        if self.switch_cov_end > 0 and self.switch_cov_start > 0:
            use_switch_cov = True
            assert self.switch_cov_start < self.switch_cov_end, f"switch_cov_start {self.switch_cov_start} must be smaller than switch_cov_end {self.switch_cov_end}"
            assert self.switch_cov_start > 0 and self.switch_cov_end < 1, f"switch_cov_start {self.switch_cov_start} and switch_cov_end {self.switch_cov_end} must be between 0 and 1"
            rc = jnp.array(D3_COV_RADII * au.BOHR)[species]
            rcij = rc[edge_src] + rc[edge_dst]
            rstart = rcij * self.switch_cov_start
            rend = rcij * self.switch_cov_end
            switch_short = (distances >= rend) + 0.5 * (1 - jnp.cos(jnp.pi * (distances - rstart) / (rend - rstart))) * (distances > rstart) * (distances < rend)
            switch = switch * switch_short

        ## COMPUTE SPHERICAL HARMONICS ON EDGES
        Yij = generate_spherical_harmonics(lmax=self.lmax, normalize=False)(vec)[:, None, :]
        nrep = np.array([2 * l + 1 for l in range(self.lmax + 1)])
        ls = np.arange(self.lmax + 1).repeat(nrep)

        parity = jnp.array((-1) ** ls[None, None, :])
        if self.ignore_parity:
            parity = -jnp.ones_like(parity)

        ## INITIALIZE TENSOR FEATURES
        Vi = 0.  # jnp.zeros((Zi.shape[0], self.tens_heads, Yij.shape[1]))

        ## RADIAL BASIS
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        radial_terms = RadialBasis(
            **{
                **self.radial_basis,
                "end": cutoff,
                "name": f"RadialBasis",
            }
        )(distances)
        if self.switch_before_net:
            radial_terms = radial_terms * switch[:, None]
        elif use_switch_cov:
            radial_terms = radial_terms * switch_short[:, None]

        ## INITIALIZE LODE
        do_lode = self.graph_lode is not None
        if do_lode:
            ## LONG-RANGE GRAPH
            graph_lode = inputs[self.graph_lode]
            switch_lode = graph_lode["switch"][:, None]
            edge_src_lr, edge_dst_lr = graph_lode["edge_src"], graph_lode["edge_dst"]
            r = graph_lode["distances"][:, None]
            rc = self._graphs_properties[self.graph_lode]["cutoff"]

            lmax_lr = self.lmax_lode
            equivariant_lode = lmax_lr > 0
            assert lmax_lr >= 0, f"lmax_lode must be >= 0, got {lmax_lr}"
            assert (
                lmax_lr <= self.lmax
            ), f"lmax_lode must be <= lmax for multipole interaction, got {lmax_lr} > {self.lmax}"
            nrep_lr = np.array([2 * l + 1 for l in range(lmax_lr + 1)], dtype=np.int32)
            if equivariant_lode:
                ls_lr = np.arange(lmax_lr + 1)
            else:
                ls_lr = np.array([0])

            ## PARAMETERS FOR THE LR RADIAL BASIS
            nextra_powers = len(self.lode_extra_powers)
            if nextra_powers > 0:
                ls_lr = np.concatenate([self.lode_extra_powers, ls_lr])

            if self.a_lode > 0:
                a = self.a_lode**2
            else:
                a = (
                    self.param(
                        "a_lr",
                        lambda key: jnp.asarray([-self.a_lode] * ls_lr.shape[0])[
                            None, :
                        ],
                    )
                    ** 2
                )
            rc2a = rc**2 + a
            ls_lr = 0.5 * (ls_lr[None, :] + 1)
            ### minimal radial basis for long range (damped coulomb)
            eij_lr = (
                1.0 / (r**2 + a) ** ls_lr
                - 1.0 / rc2a**ls_lr
                + (r - rc) * rc * (2 * ls_lr) / rc2a ** (ls_lr + 1)
            ) * switch_lode

            if self.lode_rshort is not None:
                rs = self.lode_rshort
                d = self.lode_dshort
                switch_short = 0.5 * (1 - jnp.cos(jnp.pi * (r - rs) / d)) * (r > rs) * (
                    r < rs + d
                ) + (r >= rs + d)
                eij_lr = eij_lr * switch_short

            dim_lr = 1
            if nextra_powers > 0:
                eij_lr_extra = eij_lr[:, :nextra_powers]
                eij_lr = eij_lr[:, nextra_powers:]
                dim_lr += nextra_powers

            if equivariant_lode:
                ## SPHERICAL HARMONICS ON LONG-RANGE GRAPH
                eij_lr = eij_lr.repeat(nrep_lr, axis=-1)
                Yij_lr = generate_spherical_harmonics(lmax=lmax_lr, normalize=False)(
                    graph_lode["vec"] / r
                )
                dim_lr += lmax_lr
                eij_lr = eij_lr * Yij_lr
                del Yij_lr


        if self.keep_all_layers:
            fis = []

        ### START MESSAGE PASSING ITERATIONS
        for layer in range(self.nlayers):
            ## GATHER SCALAR EDGE FEATURES
            u = [radial_terms]
            if layer > 0:
                ## edge-tensor contraction
                xij2 = (Vi[edge_dst] + (parity * Vi)[edge_src]) * Yij
                for l in range(self.lmax + 1):
                    u.append((xij2[:, :, l**2 : (l + 1) ** 2]).sum(axis=-1))
            ur = jnp.concatenate(u, axis=-1)

            ## BUILD RELATIVE POSITIONAL ENCODING
            if self.edge_value:
                nout = 2
            else:
                nout = 1
            w = FullyConnectedNet(
                [2 * self.att_dim, nout * self.att_dim],
                activation=self.positional_activation,
                use_bias=self.positional_bias,
                name=f"positional_encoding_{layer}",
            )(ur).reshape(radial_terms.shape[0], nout, self.att_dim)
            if self.edge_value:
                w, vij = jnp.split(w, 2, axis=1)

            nls = self.lmax + 1 if layer == 0 else 2 * (self.lmax + 1)


            ## QUERY, KEY, VALUE
            q = ln_qk(nn.Dense((self.scal_heads + nls * self.tens_heads) * self.att_dim, use_bias=False, name=f"queries_{layer}")(
                xi
            ).reshape(xi.shape[0], self.scal_heads + nls * self.tens_heads, self.att_dim))
            k = nn.Dense((self.scal_heads + nls * self.tens_heads) * self.att_dim, use_bias=False, name=f"keys_{layer}")(
                xi
            ).reshape(xi.shape[0], self.scal_heads + nls * self.tens_heads, self.att_dim)

            v = nn.Dense(self.scal_heads * self.att_dim, use_bias=False, name=f"values_{layer}")(xi).reshape(
                xi.shape[0], self.scal_heads, self.att_dim
            )

            ## ATTENTION COEFFICIENTS
            if self.additive_positional:
                wk = ln_qk(w + k[edge_dst])
            else:
                wk = ln_qk(w * k[edge_dst])

            act = activation_from_str(self.att_activation)
            aij = (
                act((q[edge_src] * wk).sum(axis=-1) / (self.att_dim**0.5))
                * switch[:, None]
            )

            aijl = aij[:, : self.tens_heads * (self.lmax + 1)].reshape(-1, self.tens_heads, self.lmax + 1).repeat(nrep, axis=-1)
            if layer > 0:
                aijl1 = aij[:, self.tens_heads * (self.lmax + 1) : self.tens_heads * nls].reshape(-1, self.tens_heads, self.lmax + 1).repeat(nrep, axis=-1)
            aij = aij[:, self.tens_heads * nls :, None]

            if self.edge_value:
                ## EDGE VALUES
                if self.additive_positional:
                    vij = vij + v[edge_dst]
                else:
                    vij = vij * v[edge_dst]
            else:
                ## MOVE DEST VALUES TO EDGE
                vij = v[edge_dst]

            ## SCALAR ATTENDED FEATURES
            vai = jax.ops.segment_sum(
                aij * vij,
                edge_src,
                num_segments=xi.shape[0],
            )
            vai = vai.reshape(xi.shape[0], -1)

            ### TENSOR ATTENDED FEATURES
            uij = aijl * Yij
            if layer > 0:
                uij = uij + aijl1 * Vi[edge_dst]
            Vi = Vi + jax.ops.segment_sum(uij, edge_src, num_segments=Zi.shape[0])

            ## SELF SCALAR FEATURES
            si = nn.Dense(self.att_dim, use_bias=False, name=f"self_values_{layer}")(xi)

            components = [si, vai]

            ### CONTRACT TENSOR FEATURES TO BUILD INVARIANTS
            if self.tens_heads == 1:
                Vi2 = Vi**2
            else:
                Vi2 = Vi * ChannelMixing(self.lmax, self.tens_heads, name=f"extract_mixing_{layer}")(Vi)
            for l in range(self.lmax + 1):
                norm = 1.0 / (2 * l + 1)
                components.append(
                    (Vi2[:, :, l**2 : (l + 1) ** 2]).sum(axis=-1) * norm
                )

            ### LODE (~ LONG-RANGE ATTENTION)
            if do_lode and layer == self.nlayers - 1:
                assert self.lode_channels <= self.tens_heads
                zj = nn.Dense(self.lode_channels * dim_lr, use_bias=False, name=f"lode_values_{layer}")(xi).reshape(
                    xi.shape[0], self.lode_channels, dim_lr
                )
                if nextra_powers > 0:
                    zj_extra = zj[:, :, :nextra_powers]
                    zj = zj[:, :, nextra_powers:]
                    xi_lr_extra = jax.ops.segment_sum(
                        eij_lr_extra[:, None, :] * zj_extra[edge_dst_lr],
                        edge_src_lr,
                        species.shape[0],
                    ).reshape(species.shape[0], -1)
                    components.append(xi_lr_extra)
                if equivariant_lode:
                    zj = zj.repeat(nrep_lr, axis=-1)
                Vi_lr = jax.ops.segment_sum(
                    eij_lr[:, None, :] * zj[edge_dst_lr], edge_src_lr, species.shape[0]
                )
                components.append(Vi_lr[:, :, 0])
                if equivariant_lode:
                    Mi_lr = Vi[:, :self.lode_channels, : (lmax_lr + 1) ** 2] * Vi_lr
                    for l in range(1, lmax_lr + 1):
                        norm = 1.0 / (2 * l + 1)
                        components.append(
                            Mi_lr[:, :, l**2 : (l + 1) ** 2].sum(axis=-1)
                            * norm
                        )

            ### CONCATENATE UPDATE COMPONENTS
            components = jnp.concatenate(components, axis=-1)
            ### COMPUTE UPDATE
            if self.block_index_key is not None:
                ## MoE neural network from block index
                block_index = inputs[self.block_index_key]
                updi = BlockIndexNet(
                        output_dim=self.dim + self.tens_heads * (self.lmax + 1),
                        hidden_neurons=self.update_hidden,
                        activation=self.activation,
                        use_bias=self.update_bias,
                        name=f"update_net_{layer}",
                    )((species, components, block_index))
            else:
                updi = FullyConnectedNet(
                        [*self.update_hidden, self.dim + self.tens_heads * (self.lmax + 1)],
                        activation=self.activation,
                        use_bias=self.update_bias,
                        name=f"update_net_{layer}",
                    )(components)

            ## UPDATE ATOM FEATURES
            xi = layer_norm(xi + updi[:, :self.dim])
            Vi = Vi * (1 + updi[:, self.dim:]).reshape(-1, self.tens_heads, self.lmax + 1).repeat(nrep, axis=-1)
            if self.tens_heads > 1:
                Vi = ChannelMixing(self.lmax, self.tens_heads, name=f"update_mixing_{layer}")(Vi)

            if self.keep_all_layers:
                ## STORE ALL LAYERS
                fis.append(xi)


        output = {**inputs, self.embedding_key: xi, self.embedding_key + "_tensor": Vi}
        if self.keep_all_layers:
            output[self.embedding_key + '_layers'] = jnp.stack(fis, axis=1)
        return output
class RaSTER(flax.linen.module.Module):

Range-Separated Transformer with Equivariant Representations

FID : RASTER

RaSTER(
    _graphs_properties: Dict,
    dim: int = 176,
    nlayers: int = 2,
    att_dim: int = 16,
    scal_heads: int = 16,
    tens_heads: int = 4,
    lmax: int = 3,
    normalize_vec: bool = True,
    att_activation: str = 'identity',
    activation: str = 'swish',
    update_hidden: Sequence[int] = (),
    update_bias: bool = True,
    positional_activation: str = 'swish',
    positional_bias: bool = True,
    switch_before_net: bool = False,
    ignore_parity: bool = False,
    additive_positional: bool = False,
    edge_value: bool = False,
    layer_normalization: bool = True,
    graph_key: str = 'graph',
    embedding_key: str = 'embedding',
    radial_basis: dict = <factory>,
    species_encoding: str | dict = <factory>,
    graph_lode: Optional[str] = None,
    lmax_lode: int = 0,
    lode_rshort: Optional[float] = None,
    lode_dshort: float = 2.0,
    lode_extra_powers: Sequence[int] = (),
    a_lode: float = -1.0,
    block_index_key: Optional[str] = None,
    lode_channels: int = 1,
    switch_cov_start: float = 0.5,
    switch_cov_end: float = 0.6,
    normalize_keys: bool = False,
    keep_all_layers: bool = False,
    parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>,
    name: Optional[str] = None,
)
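
Below is a minimal usage sketch (not part of the original documentation). It assumes the module is driven directly through flax's init/apply; in practice the neighbour graph and the `_graphs_properties` entry are produced by FeNNol's preprocessing, and the toy two-atom graph here is written by hand for illustration only.

# Minimal sketch: embed a toy two-atom system (hypothetical values throughout).
import jax
import jax.numpy as jnp
from fennol.models.embeddings.raster import RaSTER

graph = {
    "distances": jnp.array([1.0, 1.0]),                      # one pair, both directions
    "vec": jnp.array([[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]]),   # edge vectors
    "switch": jnp.array([1.0, 1.0]),                         # radial switch values
    "edge_src": jnp.array([0, 1]),
    "edge_dst": jnp.array([1, 0]),
}
inputs = {"species": jnp.array([1, 1]), "graph": graph}

module = RaSTER(_graphs_properties={"graph": {"cutoff": 5.0}})
params = module.init(jax.random.PRNGKey(0), inputs)
output = module.apply(params, inputs)

xi = output["embedding"]         # (natoms, dim) scalar embedding
Vi = output["embedding_tensor"]  # (natoms, tens_heads, (lmax + 1)**2) tensor features
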
dim: int = 176

The dimension of the output embedding.

nlayers: int = 2

The number of message-passing layers.

att_dim: int = 16

The dimension of the attention heads.

scal_heads: int = 16

The number of scalar attention heads.

tens_heads: int = 4

The number of tensor attention heads.

lmax: int = 3

The maximum angular momentum to consider.

normalize_vec: bool = True

Whether to normalize the vector features before computing spherical harmonics.

att_activation: str = 'identity'

The activation function to use for the attention coefficients.
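
For reference, the attention coefficients assembled in the source above can be written (our notation, read directly from `__call__`) as

$$ a_{ij}^{(h)} = \mathrm{act}\!\left( \frac{q_i^{(h)} \cdot \big(w_{ij} \odot k_j^{(h)}\big)}{\sqrt{d_{\mathrm{att}}}} \right) s_{ij} $$

where act is `att_activation`, q and k are the per-head queries and keys of the central atom i and neighbour j, w_ij is the relative positional encoding built from the edge features, the elementwise product is replaced by a sum when `additive_positional=True`, and s_ij is the switch value of edge (i, j).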

activation: str = 'swish'

The activation function to use for the update network.

update_hidden: Sequence[int] = ()

The hidden layers for the update network.

update_bias: bool = True

Whether to use bias in the update network.

positional_activation: str = 'swish'

The activation function to use for the positional embedding network.

positional_bias: bool = True

Whether to use bias in the positional embedding network.

switch_before_net: bool = False

Whether to apply the switch function to the radial basis before the edge neural network.

ignore_parity: bool = False

Whether to ignore the parity of the spherical harmonics when constructing the relative positional encoding.
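
Concretely (our notation, read from the source above), after the first layer the scalar edge features that feed the relative positional encoding are contractions of the tensor features with the spherical harmonics,

$$ u_{ij}^{(l)} = \sum_{m=-l}^{l} \left( V_{j,lm} + (-1)^{l}\, V_{i,lm} \right) Y_{lm}(\hat r_{ij}), $$

and setting ignore_parity=True replaces the (-1)^l factor by -1 for all l.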

additive_positional: bool = False

Whether to use additive relative positional encoding. If False, multiplicative relative positional encoding is used.

edge_value: bool = False

Whether to use edge values in the attention mechanism.

layer_normalization: bool = True

Whether to use layer normalization of atomic embeddings.

graph_key: str = 'graph'

The key in the input dictionary that corresponds to the radial graph.

embedding_key: str = 'embedding'

The key in the output dictionary that corresponds to the embedding.

radial_basis: dict

The dictionary of parameters for radial basis functions. See fennol.models.misc.encodings.RadialBasis.

species_encoding: str | dict

The dictionary of parameters for species encoding. See fennol.models.misc.encodings.SpeciesEncoding.
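
As an illustration (hypothetical values, not a recommendation), both dictionaries are forwarded to the corresponding encoding modules; the radial basis `end` parameter is filled in automatically from the radial-graph cutoff, so any value given for it here would be overridden:

# Hypothetical override of the two encoding dictionaries (keys taken from the defaults above).
from fennol.models.embeddings.raster import RaSTER

embedding = RaSTER(
    _graphs_properties={"graph": {"cutoff": 5.0}},
    radial_basis={"start": 0.8, "basis": "gaussian", "dim": 32},
    species_encoding={"dim": 32, "trainable": True, "encoding": "random"},
)
# species_encoding may also be a string, in which case it names the key of the
# input dictionary that already holds per-atom species features.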

graph_lode: Optional[str] = None

The key in the input dictionary that corresponds to the long-range graph.

lmax_lode: int = 0

The maximum angular momentum for the long-range features.

lode_rshort: Optional[float] = None

The short-range cutoff for the long-range features.

lode_dshort: float = 2.0

The width of the short-range cutoff for the long-range features.

lode_extra_powers: Sequence[int] = ()

The extra powers to include in the long-range features.

a_lode: float = -1.0

The damping parameter for the long-range features. If negative, the damping is trainable with initial value abs(a_lode).
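
For reference, the long-range radial basis assembled in the source above is, for each angular momentum or extra power l and damping a (our notation),

$$ e_{ij}^{(l)} = \left[ \frac{1}{(r_{ij}^2 + a)^{(l+1)/2}} - \frac{1}{(r_c^2 + a)^{(l+1)/2}} + \frac{(r_{ij} - r_c)\, r_c\, (l+1)}{(r_c^2 + a)^{(l+3)/2}} \right] s_{ij}, $$

a damped Coulomb-like term shifted so that both its value and its first derivative vanish at the long-range cutoff r_c; s_ij is the switch of the long-range graph. When a_lode is negative, sqrt(a) is a trainable parameter (one per power) initialized to abs(a_lode); otherwise a = a_lode**2.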

block_index_key: Optional[str] = None

The key in the input dictionary that corresponds to the block index for the MoE network. If None, a normal neural network is used.

lode_channels: int = 1

The number of channels for the long-range features.

switch_cov_start: float = 0.5

The start of close-range covalent switch (in units of covalent radii).

switch_cov_end: float = 0.6

The end of close-range covalent switch (in units of covalent radii).
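
Concretely (read from the source above), with r_start = switch_cov_start * (R_i + R_j) and r_end = switch_cov_end * (R_i + R_j), where R_i is the D3 covalent radius of atom i, the close-range factor multiplying the radial switch is

$$ s^{\mathrm{cov}}_{ij}(r) = \begin{cases} 0 & r \le r_{\mathrm{start}} \\ \tfrac{1}{2}\left(1 - \cos \pi \dfrac{r - r_{\mathrm{start}}}{r_{\mathrm{end}} - r_{\mathrm{start}}}\right) & r_{\mathrm{start}} < r < r_{\mathrm{end}} \\ 1 & r \ge r_{\mathrm{end}} \end{cases} $$

The switch is disabled when switch_cov_start or switch_cov_end is set to a non-positive value.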

normalize_keys: bool = False

Whether to normalize queries and keys in the attention mechanism.

keep_all_layers: bool = False

Whether to return the stacked scalar embeddings from all message-passing layers.

FID: ClassVar[str] = 'RASTER'