fennol.models.embeddings.raster

import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Sequence, Dict, Union, ClassVar, Optional
import numpy as np
from ...utils.periodic_table import D3_COV_RADII
import dataclasses
from ..misc.encodings import RadialBasis, SpeciesEncoding
from ...utils.spherical_harmonics import generate_spherical_harmonics
from ..misc.e3 import ChannelMixing
from ..misc.nets import FullyConnectedNet, BlockIndexNet
from ...utils.activations import activation_from_str
from ...utils.atomic_units import au
from ...utils.initializers import initializer_from_str

class RaSTER(nn.Module):
    """Range-Separated Transformer with Equivariant Representations

    FID : RASTER

    """

    _graphs_properties: Dict
    dim: int = 176
    """The dimension of the output embedding."""
    nlayers: int = 2
    """The number of message-passing layers."""
    att_dim: int = 16
    """The dimension of the attention heads."""
    scal_heads: int = 16
    """The number of scalar attention heads."""
    tens_heads: int = 4
    """The number of tensor attention heads."""
    lmax: int = 3
    """The maximum angular momentum to consider."""
    normalize_vec: bool = True
    """Whether to normalize the vector features before computing spherical harmonics."""
    att_activation: str = "identity"
    """The activation function to use for the attention coefficients."""
    activation: str = "swish"
    """The activation function to use for the update network."""
    update_hidden: Sequence[int] = ()
    """The hidden layers for the update network."""
    update_bias: bool = True
    """Whether to use bias in the update network."""
    positional_activation: str = "swish"
    """The activation function to use for the positional embedding network."""
    positional_bias: bool = True
    """Whether to use bias in the positional embedding network."""
    switch_before_net: bool = False
    """Whether to apply the switch function to the radial basis before the edge neural network."""
    ignore_parity: bool = False
    """Whether to ignore the parity of the spherical harmonics when constructing the relative positional encoding."""
    additive_positional: bool = False
    """Whether to use additive relative positional encoding. If False, multiplicative relative positional encoding is used."""
    edge_value: bool = False
    """Whether to use edge values in the attention mechanism."""
    layer_normalization: bool = True
    """Whether to use layer normalization of atomic embeddings."""
    graph_key: str = "graph"
    """The key in the input dictionary that corresponds to the radial graph."""
    embedding_key: str = "embedding"
    """The key in the output dictionary that corresponds to the embedding."""
    radial_basis: dict = dataclasses.field(
        default_factory=lambda: {"start": 0.8, "basis": "gaussian", "dim": 16}
    )
    """The dictionary of parameters for radial basis functions. See `fennol.models.misc.encodings.RadialBasis`."""
    species_encoding: str | dict = dataclasses.field(
        default_factory=lambda: {"dim": 16, "trainable": True, "encoding": "random"}
    )
    """The dictionary of parameters for species encoding. See `fennol.models.misc.encodings.SpeciesEncoding`.
    If a string is given instead, the precomputed encoding is read from the input dictionary at that key."""
    graph_lode: Optional[str] = None
    """The key in the input dictionary that corresponds to the long-range graph."""
    lmax_lode: int = 0
    """The maximum angular momentum for the long-range features."""
    lode_rshort: Optional[float] = None
    """The short-range cutoff for the long-range features."""
    lode_dshort: float = 2.0
    """The width of the short-range cutoff for the long-range features."""
    lode_extra_powers: Sequence[int] = ()
    """The extra powers to include in the long-range features."""
    a_lode: float = -1.0
    """The damping parameter for the long-range features. If negative, the damping is trainable with initial value abs(a_lode)."""
    block_index_key: Optional[str] = None
    """The key in the input dictionary that corresponds to the block index for the MoE network. If None, a normal neural network is used."""
    lode_channels: int = 1
    """The number of channels for the long-range features."""
    switch_cov_start: float = 0.5
    """The start of the close-range covalent switch (as a fraction of the sum of covalent radii)."""
    switch_cov_end: float = 0.6
    """The end of the close-range covalent switch (as a fraction of the sum of covalent radii)."""
    normalize_keys: bool = False
    """Whether to normalize queries and keys in the attention mechanism."""
    normalize_components: bool = False
    """Whether to normalize the components before the update network."""
    keep_all_layers: bool = False
    """Whether to return the stacked scalar embeddings from all message-passing layers."""
    kernel_init: Optional[str] = None
    """The initializer for the Dense layer kernels, given as a string parsed by `fennol.utils.initializers.initializer_from_str`."""

    FID: ClassVar[str] = "RASTER"

    @nn.compact
    def __call__(self, inputs):
        species = inputs["species"]

        ## SETUP LAYER NORMALIZATION
        def _layer_norm(x):
            mu = jnp.mean(x, axis=-1, keepdims=True)
            dx = x - mu
            var = jnp.mean(dx**2, axis=-1, keepdims=True)
            sig = (1.0e-6 + var) ** (-0.5)
            return dx * sig
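        # NOTE: a bias-free layer norm: zero mean and unit variance over the
        # feature axis, regularized by 1e-6, with no learned scale or offset.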

        if self.layer_normalization:
            layer_norm = _layer_norm
        else:
            layer_norm = lambda x: x

        if self.normalize_keys:
            ln_qk = _layer_norm
        else:
            ln_qk = lambda x: x

        kernel_init = initializer_from_str(self.kernel_init)

        ## SPECIES ENCODING
        if isinstance(self.species_encoding, str):
            Zi = inputs[self.species_encoding]
        else:
            Zi = SpeciesEncoding(**self.species_encoding)(species)

        ## INITIALIZE SCALAR FEATURES
        xi = layer_norm(
            nn.Dense(self.dim, use_bias=False, name="species_linear", kernel_init=kernel_init)(Zi)
        )

        ## RADIAL GRAPH
        graph = inputs[self.graph_key]
        distances = graph["distances"]
        switch = graph["switch"]
        edge_src = graph["edge_src"]
        edge_dst = graph["edge_dst"]
        vec = (
            graph["vec"] / graph["distances"][:, None]
            if self.normalize_vec
            else graph["vec"]
        )

        ## CLOSE-RANGE SWITCH
        use_switch_cov = False
        if self.switch_cov_end > 0 and self.switch_cov_start > 0:
            use_switch_cov = True
            assert (
                self.switch_cov_start < self.switch_cov_end
            ), f"switch_cov_start {self.switch_cov_start} must be smaller than switch_cov_end {self.switch_cov_end}"
            assert (
                self.switch_cov_start > 0 and self.switch_cov_end < 1
            ), f"switch_cov_start {self.switch_cov_start} and switch_cov_end {self.switch_cov_end} must be between 0 and 1"
            rc = jnp.array(D3_COV_RADII * au.ANG)[species]
            rcij = rc[edge_src] + rc[edge_dst]
            rstart = rcij * self.switch_cov_start
            rend = rcij * self.switch_cov_end
            switch_short = (distances >= rend) + 0.5 * (
                1 - jnp.cos(jnp.pi * (distances - rstart) / (rend - rstart))
            ) * (distances > rstart) * (distances < rend)
            switch = switch * switch_short
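            # NOTE: switch_short is a cosine ramp on the covalent scale: 0 for
            # r <= rstart, rising smoothly to 1 between rstart and rend, and 1
            # for r >= rend, so sub-covalent edges are screened out of the attention.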

        ## COMPUTE SPHERICAL HARMONICS ON EDGES
        Yij = generate_spherical_harmonics(lmax=self.lmax, normalize=False)(vec)[:, None, :]
        nrep = np.array([2 * l + 1 for l in range(self.lmax + 1)])
        ls = np.arange(self.lmax + 1).repeat(nrep)

        parity = jnp.array((-1) ** ls[None, None, :])
        if self.ignore_parity:
            parity = -jnp.ones_like(parity)
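        # NOTE: since Y_l(-r) = (-1)^l Y_l(r), multiplying the source features by
        # parity amounts to contracting them with the spherical harmonics of the
        # reversed bond vector in the edge-tensor contraction of the main loop.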

        ## INITIALIZE TENSOR FEATURES
        Vi = 0.0  # lazily initialized; becomes (natoms, tens_heads, (lmax+1)**2) in the first layer

        ## RADIAL BASIS
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        radial_terms = RadialBasis(
            **{
                **self.radial_basis,
                "end": cutoff,
                "name": "RadialBasis",
            }
        )(distances)
        if self.switch_before_net:
            radial_terms = radial_terms * switch[:, None]
        elif use_switch_cov:
            radial_terms = radial_terms * switch_short[:, None]

        ## INITIALIZE LODE
        do_lode = self.graph_lode is not None
        if do_lode:
            ## LONG-RANGE GRAPH
            graph_lode = inputs[self.graph_lode]
            switch_lode = graph_lode["switch"][:, None]
            edge_src_lr, edge_dst_lr = graph_lode["edge_src"], graph_lode["edge_dst"]
            r = graph_lode["distances"][:, None]
            rc = self._graphs_properties[self.graph_lode]["cutoff"]

            lmax_lr = self.lmax_lode
            equivariant_lode = lmax_lr > 0
            assert lmax_lr >= 0, f"lmax_lode must be >= 0, got {lmax_lr}"
            assert (
                lmax_lr <= self.lmax
            ), f"lmax_lode must be <= lmax for multipole interaction, got {lmax_lr} > {self.lmax}"
            nrep_lr = np.array([2 * l + 1 for l in range(lmax_lr + 1)], dtype=np.int32)
            if equivariant_lode:
                ls_lr = np.arange(lmax_lr + 1)
            else:
                ls_lr = np.array([0])

            ## PARAMETERS FOR THE LR RADIAL BASIS
            nextra_powers = len(self.lode_extra_powers)
            if nextra_powers > 0:
                ls_lr = np.concatenate([self.lode_extra_powers, ls_lr])

            if self.a_lode > 0:
                a = self.a_lode**2
            else:
                a = (
                    self.param(
                        "a_lr",
                        lambda key: jnp.asarray([-self.a_lode] * ls_lr.shape[0])[
                            None, :
                        ],
                    )
                    ** 2
                )
            rc2a = rc**2 + a
            ls_lr = 0.5 * (ls_lr[None, :] + 1)
            ### minimal radial basis for long range (damped coulomb)
            eij_lr = (
                1.0 / (r**2 + a) ** ls_lr
                - 1.0 / rc2a**ls_lr
                + (r - rc) * rc * (2 * ls_lr) / rc2a ** (ls_lr + 1)
            ) * switch_lode
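            # NOTE: each column of eij_lr is a damped power law ~ 1/(r^2+a)^((l+1)/2)
            # (i.e. ~ 1/r^(l+1) at large r), shifted and tilted by the last two terms
            # so that both its value and its first derivative vanish at the cutoff rc.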

            if self.lode_rshort is not None:
                rs = self.lode_rshort
                d = self.lode_dshort
                switch_short = 0.5 * (1 - jnp.cos(jnp.pi * (r - rs) / d)) * (r > rs) * (
                    r < rs + d
                ) + (r >= rs + d)
                eij_lr = eij_lr * switch_short

            dim_lr = 1
            if nextra_powers > 0:
                eij_lr_extra = eij_lr[:, :nextra_powers]
                eij_lr = eij_lr[:, nextra_powers:]
                dim_lr += nextra_powers

            if equivariant_lode:
                ## SPHERICAL HARMONICS ON LONG-RANGE GRAPH
                eij_lr = eij_lr.repeat(nrep_lr, axis=-1)
                Yij_lr = generate_spherical_harmonics(lmax=lmax_lr, normalize=False)(
                    graph_lode["vec"] / r
                )
                dim_lr += lmax_lr
                eij_lr = eij_lr * Yij_lr
                del Yij_lr

        if self.keep_all_layers:
            xis = []

        ### START MESSAGE PASSING ITERATIONS
        for layer in range(self.nlayers):
            ## GATHER SCALAR EDGE FEATURES
            u = [radial_terms]
            if layer > 0:
                ## edge-tensor contraction
                xij2 = (Vi[edge_dst] + (parity * Vi)[edge_src]) * Yij
                for l in range(self.lmax + 1):
                    u.append((xij2[:, :, l**2 : (l + 1) ** 2]).sum(axis=-1))
            ur = jnp.concatenate(u, axis=-1)

            ## BUILD RELATIVE POSITIONAL ENCODING
            if self.edge_value:
                nout = 2
            else:
                nout = 1
            w = FullyConnectedNet(
                [2 * self.att_dim, nout * self.att_dim],
                activation=self.positional_activation,
                use_bias=self.positional_bias,
                name=f"positional_encoding_{layer}",
            )(ur).reshape(radial_terms.shape[0], nout, self.att_dim)
            if self.edge_value:
                w, vij = jnp.split(w, 2, axis=1)

            nls = self.lmax + 1 if layer == 0 else 2 * (self.lmax + 1)

            ## QUERY, KEY, VALUE
            q = ln_qk(
                nn.Dense(
                    (self.scal_heads + nls * self.tens_heads) * self.att_dim,
                    use_bias=False,
                    name=f"queries_{layer}",
                    kernel_init=kernel_init,
                )(xi).reshape(xi.shape[0], self.scal_heads + nls * self.tens_heads, self.att_dim)
            )
            k = nn.Dense(
                (self.scal_heads + nls * self.tens_heads) * self.att_dim,
                use_bias=False,
                name=f"keys_{layer}",
                kernel_init=kernel_init,
            )(xi).reshape(xi.shape[0], self.scal_heads + nls * self.tens_heads, self.att_dim)

            v = nn.Dense(
                self.scal_heads * self.att_dim,
                use_bias=False,
                name=f"values_{layer}",
                kernel_init=kernel_init,
            )(xi).reshape(xi.shape[0], self.scal_heads, self.att_dim)

            ## ATTENTION COEFFICIENTS
            if self.additive_positional:
                wk = ln_qk(w + k[edge_dst])
            else:
                wk = ln_qk(w * k[edge_dst])

            act = activation_from_str(self.att_activation)
            aij = (
                act((q[edge_src] * wk).sum(axis=-1) / (self.att_dim**0.5))
                * switch[:, None]
            )
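            # NOTE: scaled dot-product attention per edge: the relative positional
            # encoding w modulates the destination key (additively or multiplicatively
            # via wk above), and the score q_i . wk / sqrt(att_dim) is gated by the
            # cutoff switch rather than normalized with a softmax over neighbors.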

            aijl = (
                aij[:, : self.tens_heads * (self.lmax + 1)]
                .reshape(-1, self.tens_heads, self.lmax + 1)
                .repeat(nrep, axis=-1)
            )
            if layer > 0:
                aijl1 = (
                    aij[:, self.tens_heads * (self.lmax + 1) : self.tens_heads * nls]
                    .reshape(-1, self.tens_heads, self.lmax + 1)
                    .repeat(nrep, axis=-1)
                )
            aij = aij[:, self.tens_heads * nls :, None]

            if self.edge_value:
                ## EDGE VALUES
                if self.additive_positional:
                    vij = vij + v[edge_dst]
                else:
                    vij = vij * v[edge_dst]
            else:
                ## MOVE DEST VALUES TO EDGE
                vij = v[edge_dst]

            ## SCALAR ATTENDED FEATURES
            vai = jax.ops.segment_sum(
                aij * vij,
                edge_src,
                num_segments=xi.shape[0],
            )
            vai = vai.reshape(xi.shape[0], -1)

            ### TENSOR ATTENDED FEATURES
            uij = aijl * Yij
            if layer > 0:
                uij = uij + aijl1 * Vi[edge_dst]
            Vi = Vi + jax.ops.segment_sum(uij, edge_src, num_segments=Zi.shape[0])

            ## SELF SCALAR FEATURES
            si = nn.Dense(
                self.att_dim, use_bias=False, name=f"self_values_{layer}", kernel_init=kernel_init
            )(xi)

            components = [si, vai]

            ### CONTRACT TENSOR FEATURES TO BUILD INVARIANTS
            if self.tens_heads == 1:
                Vi2 = Vi**2
            else:
                Vi2 = Vi * ChannelMixing(
                    self.lmax, self.tens_heads, name=f"extract_mixing_{layer}"
                )(Vi)
            for l in range(self.lmax + 1):
                norm = 1.0 / (2 * l + 1)
                components.append((Vi2[:, :, l**2 : (l + 1) ** 2]).sum(axis=-1) * norm)
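            # NOTE: summing each l-block of Vi2 over its 2l+1 components yields
            # rotation-invariant (power-spectrum-like) scalars for the update network.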

            ### LODE (~ LONG-RANGE ATTENTION)
            if do_lode and layer == self.nlayers - 1:
                assert self.lode_channels <= self.tens_heads
                zj = nn.Dense(
                    self.lode_channels * dim_lr,
                    use_bias=False,
                    name=f"lode_values_{layer}",
                    kernel_init=kernel_init,
                )(xi).reshape(xi.shape[0], self.lode_channels, dim_lr)
                if nextra_powers > 0:
                    zj_extra = zj[:, :, :nextra_powers]
                    zj = zj[:, :, nextra_powers:]
                    xi_lr_extra = jax.ops.segment_sum(
                        eij_lr_extra[:, None, :] * zj_extra[edge_dst_lr],
                        edge_src_lr,
                        species.shape[0],
                    ).reshape(species.shape[0], -1)
                    components.append(xi_lr_extra)
                if equivariant_lode:
                    zj = zj.repeat(nrep_lr, axis=-1)
                Vi_lr = jax.ops.segment_sum(
                    eij_lr[:, None, :] * zj[edge_dst_lr], edge_src_lr, species.shape[0]
                )
                components.append(Vi_lr[:, :, 0])
                if equivariant_lode:
                    Mi_lr = Vi[:, : self.lode_channels, : (lmax_lr + 1) ** 2] * Vi_lr
                    for l in range(1, lmax_lr + 1):
                        norm = 1.0 / (2 * l + 1)
                        components.append(
                            Mi_lr[:, :, l**2 : (l + 1) ** 2].sum(axis=-1) * norm
                        )

            ### CONCATENATE UPDATE COMPONENTS
            components = jnp.concatenate(components, axis=-1)
            if self.normalize_components:
                components = _layer_norm(components)

            ### COMPUTE UPDATE
            if self.block_index_key is not None:
                ## MoE neural network from block index
                block_index = inputs[self.block_index_key]
                updi = BlockIndexNet(
                    output_dim=self.dim + self.tens_heads * (self.lmax + 1),
                    hidden_neurons=self.update_hidden,
                    activation=self.activation,
                    use_bias=self.update_bias,
                    name=f"update_net_{layer}",
                    kernel_init=kernel_init,
                )((species, components, block_index))
            else:
                updi = FullyConnectedNet(
                    [*self.update_hidden, self.dim + self.tens_heads * (self.lmax + 1)],
                    activation=self.activation,
                    use_bias=self.update_bias,
                    name=f"update_net_{layer}",
                    kernel_init=kernel_init,
                )(components)

            ## UPDATE ATOM FEATURES
            xi = layer_norm(xi + updi[:, : self.dim])
            Vi = Vi * (1 + updi[:, self.dim :]).reshape(
                -1, self.tens_heads, self.lmax + 1
            ).repeat(nrep, axis=-1)
            if self.tens_heads > 1:
                Vi = ChannelMixing(self.lmax, self.tens_heads, name=f"update_mixing_{layer}")(Vi)

            if self.keep_all_layers:
                ## STORE ALL LAYERS
                xis.append(xi)

        output = {**inputs, self.embedding_key: xi, self.embedding_key + "_tensor": Vi}
        if self.keep_all_layers:
            output[self.embedding_key + "_layers"] = jnp.stack(xis, axis=1)
        return output
class RaSTER(flax.linen.module.Module):

Range-Separated Transformer with Equivariant Representations

FID : RASTER

RaSTER( _graphs_properties: Dict, dim: int = 176, nlayers: int = 2, att_dim: int = 16, scal_heads: int = 16, tens_heads: int = 4, lmax: int = 3, normalize_vec: bool = True, att_activation: str = 'identity', activation: str = 'swish', update_hidden: Sequence[int] = (), update_bias: bool = True, positional_activation: str = 'swish', positional_bias: bool = True, switch_before_net: bool = False, ignore_parity: bool = False, additive_positional: bool = False, edge_value: bool = False, layer_normalization: bool = True, graph_key: str = 'graph', embedding_key: str = 'embedding', radial_basis: dict = <factory>, species_encoding: str | dict = <factory>, graph_lode: Optional[str] = None, lmax_lode: int = 0, lode_rshort: Optional[float] = None, lode_dshort: float = 2.0, lode_extra_powers: Sequence[int] = (), a_lode: float = -1.0, block_index_key: Optional[str] = None, lode_channels: int = 1, switch_cov_start: float = 0.5, switch_cov_end: float = 0.6, normalize_keys: bool = False, normalize_components: bool = False, keep_all_layers: bool = False, kernel_init: Optional[str] = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
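
A minimal usage sketch (a hypothetical input layout following the keys read in __call__ above; in practice the graph dictionary and _graphs_properties are provided by the enclosing FeNNol model and its preprocessing):

    import jax
    import jax.numpy as jnp
    from fennol.models.embeddings.raster import RaSTER

    # hypothetical O-H dimer; a real graph would come from FeNNol's preprocessing
    graph = {
        "distances": jnp.array([1.0, 1.0]),
        "vec": jnp.array([[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]]),  # edge vectors
        "switch": jnp.array([1.0, 1.0]),  # smooth cutoff value per edge
        "edge_src": jnp.array([0, 1]),
        "edge_dst": jnp.array([1, 0]),
    }
    inputs = {"species": jnp.array([8, 1]), "graph": graph}

    model = RaSTER(_graphs_properties={"graph": {"cutoff": 5.0}})
    params = model.init(jax.random.PRNGKey(0), inputs)
    out = model.apply(params, inputs)
    out["embedding"]         # (natoms, dim) scalar features
    out["embedding_tensor"]  # (natoms, tens_heads, (lmax+1)**2) tensor features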
dim: int = 176

The dimension of the output embedding.

nlayers: int = 2

The number of message-passing layers.

att_dim: int = 16

The dimension of the attention heads.

scal_heads: int = 16

The number of scalar attention heads.

tens_heads: int = 4

The number of tensor attention heads.

lmax: int = 3

The maximum angular momentum to consider.

normalize_vec: bool = True

Whether to normalize the vector features before computing spherical harmonics.

att_activation: str = 'identity'

The activation function to use for the attention coefficients.

activation: str = 'swish'

The activation function to use for the update network.

update_hidden: Sequence[int] = ()

The hidden layers for the update network.

update_bias: bool = True

Whether to use bias in the update network.

positional_activation: str = 'swish'

The activation function to use for the positional embedding network.

positional_bias: bool = True

Whether to use bias in the positional embedding network.

switch_before_net: bool = False

Whether to apply the switch function to the radial basis before the edge neural network.

ignore_parity: bool = False

Whether to ignore the parity of the spherical harmonics when constructing the relative positional encoding.

additive_positional: bool = False

Whether to use additive relative positional encoding. If False, multiplicative relative positional encoding is used.

edge_value: bool = False

Whether to use edge values in the attention mechanism.

layer_normalization: bool = True

Whether to use layer normalization of atomic embeddings.

graph_key: str = 'graph'

The key in the input dictionary that corresponds to the radial graph.

embedding_key: str = 'embedding'

The key in the output dictionary that corresponds to the embedding.

radial_basis: dict

The dictionary of parameters for radial basis functions. See fennol.models.misc.encodings.RadialBasis.
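For example, a wider Gaussian basis could be requested with a hypothetical configuration such as:

    RaSTER(..., radial_basis={"start": 0.8, "basis": "gaussian", "dim": 32})

The "end" parameter does not need to be supplied: it is filled in automatically from the radial graph cutoff, as done in __call__ above.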

species_encoding: str | dict

The dictionary of parameters for species encoding. See fennol.models.misc.encodings.SpeciesEncoding. If a string is given instead, the precomputed encoding is read from the input dictionary at that key.

graph_lode: Optional[str] = None

The key in the input dictionary that corresponds to the long-range graph.

lmax_lode: int = 0

The maximum angular momentum for the long-range features.

lode_rshort: Optional[float] = None

The short-range cutoff for the long-range features.

lode_dshort: float = 2.0

The width of the short-range cutoff for the long-range features.

lode_extra_powers: Sequence[int] = ()

The extra powers to include in the long-range features.

a_lode: float = -1.0

The damping parameter for the long-range features. If negative, the damping is trainable with initial value abs(a_lode).
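
When the damping is trainable, one parameter per long-range power is registered (as "a_lr") and squared in the basis, which keeps the effective damping positive during training.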

block_index_key: Optional[str] = None

The key in the input dictionary that corresponds to the block index for the MoE network. If None, a normal neural network is used.

lode_channels: int = 1

The number of channels for the long-range features.

switch_cov_start: float = 0.5

The start of the close-range covalent switch (as a fraction of the sum of covalent radii).

switch_cov_end: float = 0.6

The end of the close-range covalent switch (as a fraction of the sum of covalent radii).

normalize_keys: bool = False

Whether to normalize queries and keys in the attention mechanism.

normalize_components: bool = False

Whether to normalize the components before the update network.

keep_all_layers: bool = False

Whether to return the stacked scalar embeddings from all message-passing layers.
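
When enabled, the per-layer scalar embeddings are stacked along a new second axis and returned under embedding_key + "_layers", with shape (natoms, nlayers, dim).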

kernel_init: Optional[str] = None

The initializer for the Dense layer kernels, given as a string parsed by fennol.utils.initializers.initializer_from_str.

FID: ClassVar[str] = 'RASTER'