fennol.models.misc.encodings

import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Optional, Union, List, Sequence, ClassVar, Dict
import math
import dataclasses
import numpy as np
from .nets import FullyConnectedNet
from ...utils import AtomicUnits as au
from functools import partial
from ...utils.periodic_table import (
    PERIODIC_TABLE_REV_IDX,
    PERIODIC_TABLE,
    EL_STRUCT,
    VALENCE_STRUCTURE,
    XENONPY_PROPS,
    SJS_COORDINATES,
    PERIODIC_COORDINATES,
    ATOMIC_IONIZATION_ENERGY,
    POLARIZABILITIES,
    D3_COV_RADII,
    VDW_RADII,
    OXIDATION_STATES,
)


class SpeciesEncoding(nn.Module):
    """A module that encodes chemical species information.

    FID: SPECIES_ENCODING
    """

    encoding: str = "random"
    """ The encoding to use. Can be one of "one_hot", "occupation", "electronic_structure",
        "properties", "valence_properties", "sjs_coordinates", "positional", "oxidation",
        "random", "randint", "randtri".
        Multiple encodings can be concatenated using the "+" separator.
    """
    dim: int = 16
    """ The dimension of the encoding if not fixed by design."""
    zmax: int = 86
    """ The maximum atomic number to encode."""
    output_key: Optional[str] = None
    """ The key to use for the output in the returned dictionary."""

    species_order: Optional[Union[str, Sequence[str]]] = None
    """ The order of the species to use for the encoding. Only used for "one_hot" encoding.
        If None, we encode all elements up to `zmax`."""
    trainable: bool = False
    """ Whether the encoding is trainable or fixed. Does not apply to "random" encoding, which is always trainable."""
    extra_params: Dict = dataclasses.field(default_factory=dict)
    """ Dictionary of extra parameters for the encoding (e.g. "drow", "nrow", "ncol" for "positional")."""

    FID: ClassVar[str] = "SPECIES_ENCODING"

    @nn.compact
    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:

        zmax = self.zmax
        if zmax <= 0 or zmax > len(PERIODIC_TABLE):
            zmax = len(PERIODIC_TABLE)

        # two extra rows: index 0 (padding/dummy atoms) and index zmax + 1
        zmaxpad = zmax + 2

        encoding = self.encoding.lower().strip()
        encodings = encoding.split("+")
        ############################
        conv_tensors = []

        if "one_hot" in encodings or "onehot" in encodings:
            if self.species_order is None:
                conv_tensor = np.eye(zmax)
                conv_tensor = np.concatenate(
                    [np.zeros((1, zmax)), conv_tensor, np.zeros((1, zmax))], axis=0
                )
            else:
                if isinstance(self.species_order, str):
                    species_order = [el.strip() for el in self.species_order.split(",")]
                else:
                    species_order = [el for el in self.species_order]
                conv_tensor = np.zeros((zmaxpad, len(species_order)))
                for i, s in enumerate(species_order):
                    conv_tensor[PERIODIC_TABLE_REV_IDX[s], i] = 1

            conv_tensors.append(conv_tensor)

        if "occupation" in encodings:
            conv_tensor = np.zeros((zmaxpad, (zmax + 1) // 2))
            for i in range(1, zmax + 1):
                conv_tensor[i, : i // 2] = 1
                if i % 2 == 1:
                    conv_tensor[i, i // 2] = 0.5

            conv_tensors.append(conv_tensor)

        if "electronic_structure" in encodings:
            Z = np.arange(1, zmax + 1).reshape(-1, 1)
            Zref = [zmax]
            e_struct = np.array(EL_STRUCT[1 : zmax + 1])
            eref = [2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 6, 2, 14, 10, 6]
            v_struct = np.array(VALENCE_STRUCTURE[1 : zmax + 1])
            vref = [2, 6, 10, 14]
            if zmax <= 86:
                e_struct = e_struct[:, :15]
                eref = eref[:15]
            ref = np.array(Zref + eref + vref)
            conv_tensor = np.concatenate([Z, e_struct, v_struct], axis=1)
            conv_tensor = conv_tensor / ref[None, :]
            dim = conv_tensor.shape[1]
            conv_tensor = np.concatenate(
                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
            )

            conv_tensors.append(conv_tensor)

        if "properties" in encodings:
            props = np.array(XENONPY_PROPS)[1:-1]
            assert (
                self.zmax <= props.shape[0]
            ), f"zmax > {props.shape[0]} not supported for xenonpy properties"
            conv_tensor = props[1 : zmax + 1]
            mean = np.mean(props, axis=0)
            std = np.std(props, axis=0)
            conv_tensor = (conv_tensor - mean[None, :]) / std[None, :]
            dim = conv_tensor.shape[1]
            conv_tensor = np.concatenate(
                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
            )
            conv_tensors.append(conv_tensor)

        if "valence_properties" in encodings:
            assert zmax <= 86, "Valence properties only available for zmax <= 86"
            Z = np.arange(1, zmax + 1).reshape(-1, 1)
            Zref = [zmax]
            Zinv = 1.0 / Z
            Zinvref = [1.0]
            v_struct = np.array(VALENCE_STRUCTURE[1 : zmax + 1])
            vref = [2, 6, 10, 14]
            ionization = np.array(ATOMIC_IONIZATION_ENERGY[1 : zmax + 1]).reshape(-1, 1)
            ionizationref = [0.5]
            polariz = np.array(POLARIZABILITIES[1 : zmax + 1]).reshape(-1, 1)
            polarizref = [np.median(polariz)]

            cov = np.array(D3_COV_RADII[1 : zmax + 1]).reshape(-1, 1)
            covref = [np.median(cov)]

            vdw = np.array(VDW_RADII[1 : zmax + 1]).reshape(-1, 1)
            vdwref = [np.median(vdw)]

            ref = np.array(
                Zref + Zinvref + ionizationref + polarizref + covref + vdwref + vref
            )
            conv_tensor = np.concatenate(
                [Z, Zinv, ionization, polariz, cov, vdw, v_struct], axis=1
            )

            conv_tensor = conv_tensor / ref[None, :]
            dim = conv_tensor.shape[1]
            conv_tensor = np.concatenate(
                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
            )
            conv_tensors.append(conv_tensor)

        if "sjs_coordinates" in encodings:
            coords = np.array(SJS_COORDINATES)[1:-1]
            conv_tensor = coords[1 : zmax + 1]
            mean = np.mean(coords, axis=0)
            std = np.std(coords, axis=0)
            conv_tensor = (conv_tensor - mean[None, :]) / std[None, :]
            dim = conv_tensor.shape[1]
            conv_tensor = np.concatenate(
                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
            )
            conv_tensors.append(conv_tensor)

        if "positional" in encodings:
            coords = np.array(PERIODIC_COORDINATES)[1 : zmax + 1]
            row, col = coords[:, 0], coords[:, 1]
            # dict.get takes its default positionally (there is no `default=` keyword)
            drow = self.extra_params.get("drow", 5)
            dcol = self.dim - drow
            nrow = self.extra_params.get("nrow", 100.0)
            ncol = self.extra_params.get("ncol", 1000.0)

            erow = positional_encoding_static(row, drow, nrow)
            ecol = positional_encoding_static(col, dcol, ncol)
            conv_tensor = np.concatenate([erow, ecol], axis=-1)
            dim = conv_tensor.shape[1]
            conv_tensor = np.concatenate(
                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
            )
            conv_tensors.append(conv_tensor)

        if "oxidation" in encodings:
            states_set = sorted(set(sum(OXIDATION_STATES, [])))
            nstates = len(states_set)
            state_dict = {s: i for i, s in enumerate(states_set)}
            conv_tensor = np.zeros((zmaxpad, nstates))
            # start=1 so that element Z lands on row Z (row 0 is the padding row)
            for i, states in enumerate(OXIDATION_STATES[1 : zmax + 1], start=1):
                for s in states:
                    conv_tensor[i, state_dict[s]] = 1
            conv_tensors.append(conv_tensor)

        if len(conv_tensors) > 0:
            conv_tensor = np.concatenate(conv_tensors, axis=1)
            if self.trainable:
                conv_tensor = self.param(
                    "conv_tensor",
                    lambda key: jnp.asarray(conv_tensor, dtype=jnp.float32),
                )
            else:
                conv_tensor = jnp.asarray(conv_tensor, dtype=jnp.float32)
            conv_tensors = [conv_tensor]
        else:
            conv_tensors = []

        if "random" in encodings:
            rand_encoding = self.param(
                "rand_encoding",
                lambda key, shape: jax.nn.standardize(
                    jax.random.normal(key, shape, dtype=jnp.float32)
                ),
                (zmaxpad, self.dim),
            )
            conv_tensors.append(rand_encoding)

        if "randint" in encodings:
            rand_encoding = self.param(
                "randint_encoding",
                lambda key, shape: jax.random.randint(key, shape, 0, 2).astype(
                    jnp.float32
                ),
                (zmaxpad, self.dim),
            )
            conv_tensors.append(rand_encoding)

        if "randtri" in encodings:
            rand_encoding = self.param(
                "randtri_encoding",
                lambda key, shape: (jax.random.randint(key, shape, 0, 3) - 1).astype(
                    jnp.float32
                ),
                (zmaxpad, self.dim),
            )
            conv_tensors.append(rand_encoding)

        assert len(conv_tensors) > 0, f"No encoding recognized in '{self.encoding}'"

        conv_tensor = jnp.concatenate(conv_tensors, axis=1)

        species = inputs["species"] if isinstance(inputs, dict) else inputs
        out = conv_tensor[species]
        ############################

        if isinstance(inputs, dict):
            output_key = self.name if self.output_key is None else self.output_key
            out = out.astype(inputs["coordinates"].dtype)
            return {**inputs, output_key: out} if output_key is not None else out
        return out


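# Minimal usage sketch, assuming the conventions above: species are atomic
# numbers, index 0 marks padding atoms, and all values here are illustrative.
def _example_species_encoding():
    enc = SpeciesEncoding(encoding="one_hot+random", dim=16, zmax=54)
    species = jnp.array([8, 1, 1, 0])  # O, H, H, and one padding atom
    params = enc.init(jax.random.PRNGKey(0), species)
    return enc.apply(params, species)  # shape (4, 54 + 16)

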
class RadialBasis(nn.Module):
    """Computes a radial encoding of distances.

    FID: RADIAL_BASIS
    """

    end: float
    """ The maximum distance to consider."""
    start: float = 0.0
    """ The minimum distance to consider."""
    dim: int = 8
    """ The dimension of the basis."""
    graph_key: Optional[str] = None
    """ The key of the graph in the inputs."""
    output_key: Optional[str] = None
    """ The key to use for the output in the returned dictionary."""
    basis: str = "bessel"
    """ The basis to use. Can be one of "bessel", "gaussian", "gaussian_rinv", "fourier",
        "spooky", "levels", "finite_support", "exp" (or "exp_lr"), "neural_net" (or "nn"),
        "damped_coulomb", "spherical_bessel_l" (with l an integer)."""
    trainable: bool = False
    """ Whether the basis parameters are trainable or fixed."""
    enforce_positive: bool = False
    """ Whether to enforce that distance - start is positive."""
    gamma: float = 1.0 / (2 * au.BOHR)
    """ The gamma parameter for the "spooky" basis."""
    n_levels: int = 10
    """ The number of levels for the "levels" basis."""
    alt_bessel_norm: bool = False
    """ If True, use the (2/(end-start))**0.5 normalization for the bessel basis."""
    extra_params: Dict = dataclasses.field(default_factory=dict)
    """ Dictionary of extra parameters for the basis."""

    FID: ClassVar[str] = "RADIAL_BASIS"

    @nn.compact
    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
        if self.graph_key is not None:
            x = inputs[self.graph_key]["distances"]
        else:
            x = inputs["distances"] if isinstance(inputs, dict) else inputs

        shape = x.shape
        x = x.reshape(-1)

        basis = self.basis.lower()
        ############################
        if basis == "bessel":
            # zeroth-order spherical Bessel (sine) basis: sin(n*pi*(r-start)/c) / (r-start)
            c = self.end - self.start
            x = x[:, None] - self.start
            # if self.enforce_positive:
            #     x = jax.nn.softplus(x)

            if self.trainable:
                bessel_roots = self.param(
                    "bessel_roots",
                    lambda key, dim: jnp.asarray(
                        np.arange(1, dim + 1, dtype=x.dtype)[None, :] * (math.pi / c)
                    ),
                    self.dim,
                )
                norm = 1.0 / jnp.max(
                    bessel_roots
                )  # (2.0 / c) ** 0.5 /jnp.max(bessel_roots)
            else:
                bessel_roots = jnp.asarray(
                    np.arange(1, self.dim + 1, dtype=x.dtype)[None, :] * (math.pi / c)
                )
                norm = 1.0 / (
                    self.dim * math.pi / c
                )  # (2.0 / c) ** 0.5/(self.dim*math.pi/c)
            if self.alt_bessel_norm:
                norm = (2.0 / c) ** 0.5
            out = norm * jnp.sin(x * bessel_roots) / x

            if self.enforce_positive:
                out = jnp.where(x > 0, out * (1.0 - jnp.exp(-(x**2))), 0.0)

        elif basis == "gaussian":
            if self.trainable:
                roots = self.param(
                    "radial_centers",
                    lambda key, dim, start, end: jnp.linspace(
                        start, end, dim + 1, dtype=x.dtype
                    )[None, :-1],
                    self.dim,
                    self.start,
                    self.end,
                )
                eta = self.param(
                    "radial_etas",
                    lambda key, dim, start, end: jnp.full(
                        dim,
                        dim / (end - start),
                        dtype=x.dtype,
                    )[None, :],
                    self.dim,
                    self.start,
                    self.end,
                )

            else:
                roots = jnp.asarray(
                    np.linspace(self.start, self.end, self.dim + 1)[None, :-1],
                    dtype=x.dtype,
                )
                eta = jnp.asarray(
                    np.full(self.dim, self.dim / (self.end - self.start))[None, :],
                    dtype=x.dtype,
                )

            x = x[:, None]
            x2 = (eta * (x - roots)) ** 2
            out = jnp.exp(-x2)
            if self.enforce_positive:
                out = jnp.where(
                    x > self.start,
                    out * (1.0 - jnp.exp(-10 * (x - self.start) ** 2)),
                    0.0,
                )

        elif basis == "gaussian_rinv":
            rinv_high = 1.0 / max(self.start, 0.1)
            rinv_low = 1.0 / (0.8 * self.end)

            if self.trainable:
                roots = self.param(
                    "radial_centers",
                    lambda key, dim, rinv_low, rinv_high: jnp.linspace(
                        rinv_low, rinv_high, dim, dtype=x.dtype
                    )[None, :],
                    self.dim,
                    rinv_low,
                    rinv_high,
                )
                sigmas = self.param(
                    "radial_sigmas",
                    lambda key, dim, rinv_low, rinv_high: jnp.full(
                        dim,
                        2**0.5 / (2 * dim * rinv_low),
                        dtype=x.dtype,
                    )[None, :],
                    self.dim,
                    rinv_low,
                    rinv_high,
                )
            else:
                roots = jnp.asarray(
                    np.linspace(rinv_low, rinv_high, self.dim, dtype=x.dtype)[None, :]
                )
                sigmas = jnp.asarray(
                    np.full(
                        self.dim,
                        2**0.5 / (2 * self.dim * rinv_low),
                    )[None, :],
                    dtype=x.dtype,
                )

            rinv = 1.0 / x[:, None]

            out = jnp.exp(-((rinv - roots) ** 2) / sigmas**2)

        elif basis == "fourier":
            if self.trainable:
                roots = self.param(
                    "roots",
                    lambda key, dim: jnp.arange(dim, dtype=x.dtype)[None, :] * math.pi,
                    self.dim,
                )
            else:
                roots = jnp.asarray(
                    np.arange(self.dim)[None, :] * math.pi, dtype=x.dtype
                )
            c = self.end - self.start
            x = x[:, None] - self.start
            # if self.enforce_positive:
            #     x = jax.nn.softplus(x)
            norm = 1.0 / (0.25 + 0.5 * self.dim) ** 0.5
            out = norm * jnp.cos(x * roots / c)
            if self.enforce_positive:
                out = jnp.where(x > 0, out, norm)

        elif basis == "spooky":
            # Bernstein-polynomial basis in e = exp(-gamma*x) (cf. SpookyNet)
            gamma = self.gamma
            if self.trainable:
                gamma = jnp.abs(
                    self.param("gamma", lambda key: jnp.asarray(gamma, dtype=x.dtype))
                )

            if self.enforce_positive:
                x = jnp.where(x - self.start > 1.0e-3, x - self.start, 1.0e-3)[:, None]
                dim = self.dim
            else:
                x = x[:, None] - self.start
                dim = self.dim - 1

            # binomial-coefficient normalization of the Bernstein basis
            norms = []
            for k in range(self.dim):
                norms.append(math.comb(dim, k))
            norms = jnp.asarray(np.array(norms)[None, :], dtype=x.dtype)

            e = jnp.exp(-gamma * x)
            k = jnp.asarray(np.arange(self.dim, dtype=x.dtype)[None, :])
            b = e**k * (1 - e) ** (dim - k)
            out = b * e * norms
            if self.enforce_positive:
                out = jnp.where(x > 1.0e-3, out * (1.0 - jnp.exp(-(x**2))), 0.0)
            # alternative log-space evaluation kept for reference:
            # logfac = np.zeros(self.dim)
            # for i in range(2, self.dim):
            #     logfac[i] = logfac[i - 1] + np.log(i)
            # k = np.arange(self.dim)
            # n = self.dim - 1 - k
            # logbin = jnp.asarray((logfac[-1] - logfac[k] - logfac[n])[None,:], dtype=x.dtype)
            # n = jnp.asarray(n[None,:], dtype=x.dtype)
            # k = jnp.asarray(k[None,:], dtype=x.dtype)

            # gamma = 1.0 / (2 * au.BOHR)
            # if self.trainable:
            #     gamma = jnp.abs(
            #         self.param("gamma", lambda key: jnp.asarray(gamma, dtype=x.dtype))
            #     )
            # gammar = (-gamma * x)[:,None]
            # x = logbin + n * gammar + k * jnp.log(-jnp.expm1(gammar))
            # out = jnp.exp(x)*jnp.exp(gammar)
        elif basis == "levels":
            assert self.n_levels >= 2, "Number of levels must be >= 2."

            def initialize_levels(key):
                key0, key1, key_phi = jax.random.split(key, 3)
                level0 = jax.random.randint(key0, (self.dim,), 0, 2)
                level1 = jax.random.randint(key1, (self.dim,), 0, 2)
                # level0 = jax.random.normal(key0, (self.dim,), dtype=jnp.float32)
                # level1 = jax.random.normal(key1, (self.dim,), dtype=jnp.float32)
                phi = jax.random.uniform(key_phi, (self.dim,), dtype=jnp.float32)
                levels = [level0]
                # build n_levels - 2 intermediate levels so that the stack has
                # exactly n_levels rows, matching the clipped indices below
                for l in range(2, self.n_levels):
                    tau = float(self.n_levels - l) / float(self.n_levels - 1)
                    phil = phi < tau
                    level = jnp.where(phil, level0, level1)
                    levels.append(level)
                levels.append(level1)
                return jnp.stack(levels).astype(jnp.float32)

            levels = self.param("levels", initialize_levels)
            # levels = self.param(
            #     "levels",
            #     lambda key, shape: jax.random.normal(key, shape, dtype=jnp.float32),
            #     (self.n_levels,self.dim),
            # )

            flevel = (x - self.start) / (self.end - self.start) * (self.n_levels - 1)
            ilevel = jnp.floor(flevel).astype(jnp.int32)
            ilevel1 = jnp.clip(ilevel + 1, 0, self.n_levels - 1)
            ilevel = jnp.clip(ilevel, 0, self.n_levels - 1)

            dx = flevel - ilevel
            w = 0.5 * (1 + jnp.cos(jnp.pi * dx))[:, None]

            ## interpolate between level vectors
            v1 = levels[ilevel]
            v2 = levels[ilevel1]
            out = v1 * w + v2 * (1 - w)
        elif basis == "finite_support":
            flevel = (x - self.start) / (self.end - self.start) * (self.dim + 1)
            ilevel = jnp.floor(flevel).astype(jnp.int32)
            ilevel1 = jnp.clip(ilevel + 1, 0, self.dim + 1)
            ilevel = jnp.clip(ilevel, 0, self.dim + 1)

            dx = flevel - ilevel
            w = 0.5 * (1 + jnp.cos(jnp.pi * dx))

            ilevelflat = ilevel + jnp.arange(x.shape[0]) * (self.dim + 2)
            ilevel1flat = ilevel1 + jnp.arange(x.shape[0]) * (self.dim + 2)

            out = (
                jnp.zeros((x.shape[0] * (self.dim + 2)), dtype=x.dtype)
                .at[ilevelflat]
                .set(w)
                .at[ilevel1flat]
                .set(1 - w)
                .reshape(-1, self.dim + 2)[:, 1:-1]
            )

        elif basis == "exp_lr" or basis == "exp":
            zeta = self.extra_params.get("zeta", 2.0)
            s = self.extra_params.get("s", 0.5)
            n = np.arange(self.dim)
            # if self.trainable:
            # zeta = jnp.abs(
            #     self.param("zeta", lambda key: jnp.asarray(zeta, dtype=x.dtype))
            # )
            # s = jnp.abs(self.param("s", lambda key: jnp.asarray(s, dtype=x.dtype)))

            a = zeta * s**n
            xx = np.linspace(0, self.end, 10000)

            if self.start > 0:
                a1 = np.minimum(1.0 / a, 1.0)
                switchstart = jnp.where(
                    x[:, None] < self.end,
                    1
                    - (
                        0.5
                        + 0.5
                        * jnp.cos(
                            np.pi * (x[:, None] - self.start) / (self.end - self.start)
                        )
                    )
                    ** a1[None, :],
                    1,
                ) * (x[:, None] > self.start)
                switchstartxx = (
                    1
                    - (
                        0.5
                        + 0.5
                        * np.cos(
                            np.pi * (xx[:, None] - self.start) / (self.end - self.start)
                        )
                    )
                    ** a1[None, :]
                ) * (xx[:, None] > self.start)

            else:
                switchstart = 1.0
                switchstartxx = 1.0

            # normalize each exponential to unit integral on [0, end]
            norm = 1.0 / np.trapz(
                switchstartxx * np.exp(-a[None, :] * xx[:, None]), xx, axis=0
            )
            # norm = 1./np.max(switchstartxx*np.exp(-a[None, :] * xx[:, None]), axis=0)

            if self.trainable:
                a = jnp.abs(self.param("exponent", lambda key: jnp.asarray(a)))

            out = switchstart * jnp.exp(-a[None, :] * x[:, None]) * norm[None, :]

        elif basis == "neural_net" or basis == "nn":
            neurons = self.extra_params.get(
                "hidden_neurons", [2 * self.dim]
            ) + [self.dim]
            activation = self.extra_params.get("activation", "swish")
            use_bias = self.extra_params.get("use_bias", True)

            out = FullyConnectedNet(
                neurons, activation=activation, use_bias=use_bias, squeeze=False
            )(x[:, None])
        elif basis == "damped_coulomb":
            l = np.arange(self.dim)[None, :]
            a2 = self.extra_params.get("a", 1.0) ** 2
            x = x[:, None] - self.start
            end = self.end - self.start
            x21 = a2 + x**2
            R21 = a2 + end**2
            out = (
                1.0 / x21 ** (0.5 * (l + 1))
                - 1.0 / (R21 ** (0.5 * (l + 1)))
                + (x - end) * ((l + 1) * end / (R21 ** (0.5 * (l + 3))))
            ) * (x < end)

        elif basis.startswith("spherical_bessel_"):
            l = int(basis.split("_")[-1])
            # pass arguments by keyword: the generator signature is (dim, rc, ls)
            out = generate_spherical_jn_basis(dim=self.dim, rc=self.end, ls=[l])(x)
        else:
            raise NotImplementedError(f"Unknown radial basis {basis}.")
        ############################

        out = out.reshape((*shape, self.dim))

        if self.graph_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: out}
        return out


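# Minimal usage sketch: distances can be passed directly as an array; with
# graph_key set, a dict input {graph_key: {"distances": ...}} is expected
# instead. Values are illustrative only.
def _example_radial_basis():
    rb = RadialBasis(end=5.0, dim=8, basis="bessel")
    distances = jnp.array([0.9, 1.5, 3.2])
    params = rb.init(jax.random.PRNGKey(0), distances)
    return rb.apply(params, distances)  # shape (3, 8)

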
def positional_encoding_static(t, d: int, n: float = 10000.0):
    if d % 2 == 0:
        k = np.arange(d // 2)
    else:
        k = np.arange((d + 1) // 2)
    wk = np.asarray(1.0 / (n ** (2 * k / d)))
    wkt = wk[None, :] * t[:, None]
    out = np.concatenate([np.cos(wkt), np.sin(wkt)], axis=-1)
    if d % 2 == 1:
        out = out[:, :-1]
    return out


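# Sketch: transformer-style sinusoidal encoding of integer positions; cosine
# components come first, then sine components (d and n are illustrative).
def _example_positional_encoding_static():
    t = np.arange(4)
    return positional_encoding_static(t, d=6, n=100.0)  # shape (4, 6)

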
@partial(jax.jit, static_argnums=(1, 2), inline=True)
def positional_encoding(t, d: int, n: float = 10000.0):
    if d % 2 == 0:
        k = np.arange(d // 2)
    else:
        k = np.arange((d + 1) // 2)
    wk = jnp.asarray(1.0 / (n ** (2 * k / d)))
    wkt = wk[None, :] * t[:, None]
    out = jnp.concatenate([jnp.cos(wkt), jnp.sin(wkt)], axis=-1)
    if d % 2 == 1:
        out = out[:, :-1]
    return out


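# Sketch: the jitted variant marks d and n as static arguments, so they are
# fixed at trace time and changing either triggers a recompilation.
def _example_positional_encoding_jit():
    t = jnp.linspace(0.0, 1.0, 5)
    return positional_encoding(t, 8)  # shape (5, 8)

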
def generate_spherical_jn_basis(
    dim: int,
    rc: float,
    ls: Union[int, Sequence[int]] = (0,),
    print_code: bool = False,
    jit: bool = False,
):
    from sympy import Symbol, jn, expand_func, jn_zeros
    from scipy.special import spherical_jn
    import scipy.integrate as integrate

    if isinstance(ls, int):
        ls = list(range(ls + 1))
    # index symbols and arrays by position in `ls` rather than by the value
    # of l, so that non-contiguous lists such as ls=[2] also work
    zl = [Symbol(f"xz[...,{il}]") for il in range(len(ls))]
    zn = np.array([jn_zeros(l, dim) for l in ls], dtype=float).T
    znrc = zn / rc
    norms = np.zeros((dim, len(ls)), dtype=float)
    for il, l in enumerate(ls):
        for i in range(dim):
            norms[i, il] = (
                integrate.quad(lambda x: (spherical_jn(l, x) * x) ** 2, 0, zn[i, il])[0]
                / znrc[i, il] ** 3
            ) ** (-0.5)

    fn_str = f"""def spherical_jn_basis_(x):
    from jax.numpy import cos, sin

    znrc = jnp.array({znrc.tolist()}, dtype=x.dtype)
    norms = jnp.array({norms.tolist()}, dtype=x.dtype)
    xshape = x.shape
    x = x.reshape(-1)
    xz = x[:, None, None] * znrc[None, :, :]

    jns = jnp.stack([
"""
    for il, l in enumerate(ls):
        fn_str += f"        {expand_func(jn(l, zl[il]))},\n"
    fn_str += f"""    ], axis=-1)
    return (norms[None, :, :] * jns).reshape(*xshape, {dim}, {len(ls)})
"""

    if print_code:
        print(fn_str)
    exec(fn_str)
    jn_basis = locals()["spherical_jn_basis_"]
    if jit:
        jn_basis = jax.jit(jn_basis)
    return jn_basis
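

# Sketch: generate and evaluate a spherical Bessel basis for l = 0 and 1
# (requires sympy and scipy at generation time; dim and rc are illustrative).
def _example_spherical_jn_basis():
    basis_fn = generate_spherical_jn_basis(dim=4, rc=5.0, ls=[0, 1])
    x = jnp.array([0.5, 1.0, 2.5])
    return basis_fn(x)  # shape (3, 4, 2)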