fennol.models.misc.encodings
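
Species and radial encodings used as input featurizers in FeNNol models: SpeciesEncoding maps atomic numbers to fixed or trainable feature vectors, and RadialBasis expands interatomic distances in one of several radial bases.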

import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Optional, Union, List, Sequence, ClassVar, Dict
import math
import dataclasses
import numpy as np
from .nets import FullyConnectedNet
from ...utils import AtomicUnits as au
from functools import partial
from ...utils.periodic_table import (
    PERIODIC_TABLE_REV_IDX,
    PERIODIC_TABLE,
    EL_STRUCT,
    VALENCE_STRUCTURE,
    XENONPY_PROPS,
    SJS_COORDINATES,
    PERIODIC_COORDINATES,
    ATOMIC_IONIZATION_ENERGY,
    POLARIZABILITIES,
    D3_COV_RADII,
    VDW_RADII,
    OXIDATION_STATES,
)


class SpeciesEncoding(nn.Module):
    """A module that encodes chemical species information.

    FID: SPECIES_ENCODING
    """

    encoding: str = "random"
    """ The encoding to use. Can be one of "one_hot", "occupation", "electronic_structure",
        "valence_structure", "period", "period_onehot", "period_positional", "properties",
        "valence_properties", "sjs_coordinates", "positional", "oxidation", "random",
        "zeros", "randint", "randtri".
        Multiple encodings can be concatenated using the "+" separator.
    """
    dim: int = 16
    """ The dimension of the encoding if not fixed by design."""
    zmax: int = 86
    """ The maximum atomic number to encode."""
    output_key: Optional[str] = None
    """ The key to use for the output in the returned dictionary."""

    species_order: Optional[Union[str, Sequence[str]]] = None
    """ The order of the species to use for the encoding. Only used for the "one_hot" encoding.
         If None, we encode all elements up to `zmax`."""
    trainable: bool = False
    """ Whether the encoding is trainable or fixed. Does not apply to the "random" encoding, which is always trainable."""
    extra_params: Dict = dataclasses.field(default_factory=dict)
    """ Dictionary of extra parameters for the basis."""

    FID: ClassVar[str] = "SPECIES_ENCODING"

    @nn.compact
    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:

        zmax = self.zmax
        if zmax <= 0 or zmax > len(PERIODIC_TABLE):
            zmax = len(PERIODIC_TABLE)

        # two padding rows: index 0 and index zmax+1 (the last row)
        zmaxpad = zmax + 2

        encoding = self.encoding.lower().strip()
        encodings = encoding.split("+")
        ############################
        conv_tensors = []

        if "one_hot" in encodings or "onehot" in encodings:
            if self.species_order is None:
                conv_tensor = np.eye(zmax)
                conv_tensor = np.concatenate(
                    [np.zeros((1, zmax)), conv_tensor, np.zeros((1, zmax))], axis=0
                )
            else:
                if isinstance(self.species_order, str):
                    species_order = [el.strip() for el in self.species_order.split(",")]
                else:
                    species_order = [el for el in self.species_order]
                conv_tensor = np.zeros((zmaxpad, len(species_order)))
                for i, s in enumerate(species_order):
                    conv_tensor[PERIODIC_TABLE_REV_IDX[s], i] = 1

            conv_tensors.append(conv_tensor)

        if "occupation" in encodings:
            conv_tensor = np.zeros((zmaxpad, (zmax + 1) // 2))
            for i in range(1, zmax + 1):
                conv_tensor[i, : i // 2] = 1
                if i % 2 == 1:
                    conv_tensor[i, i // 2] = 0.5

            conv_tensors.append(conv_tensor)

        if "electronic_structure" in encodings:
            Z = np.arange(1, zmax + 1).reshape(-1, 1)
            Zref = [zmax]
            e_struct = np.array(EL_STRUCT[1 : zmax + 1])
            eref = [2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 6, 2, 14, 10, 6]
            v_struct = np.array(VALENCE_STRUCTURE[1 : zmax + 1])
            vref = [2, 6, 10, 14]
            if zmax <= 86:
                e_struct = e_struct[:, :15]
                eref = eref[:15]
            ref = np.array(Zref + eref + vref)
            conv_tensor = np.concatenate([Z, e_struct, v_struct], axis=1)
            conv_tensor = conv_tensor / ref[None, :]
            dim = conv_tensor.shape[1]
            conv_tensor = np.concatenate(
                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
            )

            conv_tensors.append(conv_tensor)

        if "valence_structure" in encodings:
            vref = np.array([2, 6, 10, 14])
            conv_tensor = np.array(VALENCE_STRUCTURE[1 : zmax + 1]) / vref[None, :]
            dim = conv_tensor.shape[1]
            conv_tensor = np.concatenate(
                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
            )
            conv_tensors.append(conv_tensor)

        if "period_onehot" in encodings:
            periods = np.asarray(PERIODIC_COORDINATES)[1 : zmax + 1, 0] - 1
            conv_tensor = np.eye(7)[periods]
            dim = conv_tensor.shape[1]
            conv_tensor = np.concatenate(
                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
            )
            conv_tensors.append(conv_tensor)

        if "period" in encodings:
            conv_tensor = (np.asarray(PERIODIC_COORDINATES)[1 : zmax + 1, 0] / 7.0)[:, None]
            dim = conv_tensor.shape[1]
            conv_tensor = np.concatenate(
                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
            )
            conv_tensors.append(conv_tensor)

        if "period_positional" in encodings:
            row = np.asarray(PERIODIC_COORDINATES)[1 : zmax + 1, 0]
            drow = self.extra_params.get("drow", 5)
            nrow = self.extra_params.get("nrow", 100.0)

            conv_tensor = positional_encoding_static(row, drow, nrow)
            dim = conv_tensor.shape[1]
            conv_tensor = np.concatenate(
                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
            )
            conv_tensors.append(conv_tensor)


        if "properties" in encodings:
            props = np.array(XENONPY_PROPS)[1:-1]
            assert (
                self.zmax <= props.shape[0]
            ), f"zmax > {props.shape[0]} not supported for xenonpy properties"
            conv_tensor = props[1 : zmax + 1]
            mean = np.mean(props, axis=0)
            std = np.std(props, axis=0)
            conv_tensor = (conv_tensor - mean[None, :]) / std[None, :]
            dim = conv_tensor.shape[1]
            conv_tensor = np.concatenate(
                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
            )
            conv_tensors.append(conv_tensor)

        if "valence_properties" in encodings:
            assert zmax <= 86, "Valence properties only available for zmax <= 86"
            Z = np.arange(1, zmax + 1).reshape(-1, 1)
            Zref = [zmax]
            Zinv = 1.0 / Z
            Zinvref = [1.0]
            v_struct = np.array(VALENCE_STRUCTURE[1 : zmax + 1])
            vref = [2, 6, 10, 14]
            ionization = np.array(ATOMIC_IONIZATION_ENERGY[1 : zmax + 1]).reshape(-1, 1)
            ionizationref = [0.5]
            polariz = np.array(POLARIZABILITIES[1 : zmax + 1]).reshape(-1, 1)
            polarizref = [np.median(polariz)]

            cov = np.array(D3_COV_RADII[1 : zmax + 1]).reshape(-1, 1)
            covref = [np.median(cov)]

            vdw = np.array(VDW_RADII[1 : zmax + 1]).reshape(-1, 1)
            vdwref = [np.median(vdw)]

            ref = np.array(
                Zref + Zinvref + ionizationref + polarizref + covref + vdwref + vref
            )
            conv_tensor = np.concatenate(
                [Z, Zinv, ionization, polariz, cov, vdw, v_struct], axis=1
            )

            conv_tensor = conv_tensor / ref[None, :]
            dim = conv_tensor.shape[1]
            conv_tensor = np.concatenate(
                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
            )
            conv_tensors.append(conv_tensor)

        if "sjs_coordinates" in encodings:
            coords = np.array(SJS_COORDINATES)[1:-1]
            conv_tensor = coords[1 : zmax + 1]
            mean = np.mean(coords, axis=0)
            std = np.std(coords, axis=0)
            conv_tensor = (conv_tensor - mean[None, :]) / std[None, :]
            dim = conv_tensor.shape[1]
            conv_tensor = np.concatenate(
                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
            )
            conv_tensors.append(conv_tensor)

        if "positional" in encodings:
            coords = np.array(PERIODIC_COORDINATES)[1 : zmax + 1]
            row, col = coords[:, 0], coords[:, 1]
            drow = self.extra_params.get("drow", 5)
            dcol = self.dim - drow
            nrow = self.extra_params.get("nrow", 100.0)
            ncol = self.extra_params.get("ncol", 1000.0)

            erow = positional_encoding_static(row, drow, nrow)
            ecol = positional_encoding_static(col, dcol, ncol)
            conv_tensor = np.concatenate([erow, ecol], axis=-1)
            dim = conv_tensor.shape[1]
            conv_tensor = np.concatenate(
                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
            )
            conv_tensors.append(conv_tensor)

        if "oxidation" in encodings:
            states_set = sorted(set(sum(OXIDATION_STATES, [])))
            nstates = len(states_set)
            state_dict = {s: i for i, s in enumerate(states_set)}
            conv_tensor = np.zeros((zmaxpad, nstates))
            # rows are indexed by atomic number; row 0 is a padding entry
            for i, states in enumerate(OXIDATION_STATES[1 : zmax + 1], start=1):
                for s in states:
                    conv_tensor[i, state_dict[s]] = 1
            conv_tensors.append(conv_tensor)

        if len(conv_tensors) > 0:
            conv_tensor = np.concatenate(conv_tensors, axis=1)
            if self.trainable:
                conv_tensor = self.param(
                    "conv_tensor",
                    lambda key: jnp.asarray(conv_tensor, dtype=jnp.float32),
                )
            else:
                conv_tensor = jnp.asarray(conv_tensor, dtype=jnp.float32)
            conv_tensors = [conv_tensor]
        else:
            conv_tensors = []

        if "random" in encodings:
            rand_encoding = self.param(
                "rand_encoding",
                lambda key, shape: jax.nn.standardize(
                    jax.random.normal(key, shape, dtype=jnp.float32)
                ),
                (zmaxpad, self.dim),
            )
            conv_tensors.append(rand_encoding)

        if "zeros" in encodings:
            zeros_encoding = self.param(
                "zeros_encoding",
                lambda key, shape: jnp.zeros(shape, dtype=jnp.float32),
                (zmaxpad, self.dim),
            )
            conv_tensors.append(zeros_encoding)

        if "randint" in encodings:
            rand_encoding = self.param(
                "randint_encoding",
                lambda key, shape: jax.random.randint(key, shape, 0, 2).astype(
                    jnp.float32
                ),
                (zmaxpad, self.dim),
            )
            conv_tensors.append(rand_encoding)

        if "randtri" in encodings:
            rand_encoding = self.param(
                "randtri_encoding",
                lambda key, shape: (jax.random.randint(key, shape, 0, 3) - 1).astype(
                    jnp.float32
                ),
                (zmaxpad, self.dim),
            )
            conv_tensors.append(rand_encoding)

        assert len(conv_tensors) > 0, f"No encoding recognized in '{self.encoding}'"

        conv_tensor = jnp.concatenate(conv_tensors, axis=1)

        species = inputs["species"] if isinstance(inputs, dict) else inputs
        out = conv_tensor[species]
        ############################

        if isinstance(inputs, dict):
            output_key = self.name if self.output_key is None else self.output_key
            out = out.astype(inputs["coordinates"].dtype)
            return {**inputs, output_key: out} if output_key is not None else out
        return out
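
For reference, a minimal usage sketch (illustrative, not part of the module source): when called on a bare array of atomic numbers, the module returns the per-atom encodings directly; when called on an input dictionary, it reads "species" and stores the result under `output_key`.

    import jax
    import jax.numpy as jnp

    # Hypothetical standalone use: encode water given as atomic numbers (O, H, H).
    enc = SpeciesEncoding(encoding="onehot+electronic_structure", zmax=86)
    species = jnp.array([8, 1, 1])
    params = enc.init(jax.random.PRNGKey(0), species)
    feat = enc.apply(params, species)  # (3, 106) here: 86 one-hot + 20 electronic-structure columns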


class RadialBasis(nn.Module):
    """Computes a radial encoding of distances.

    FID: RADIAL_BASIS
    """

    end: float
    """ The maximum distance to consider."""
    start: float = 0.0
    """ The minimum distance to consider."""
    dim: int = 8
    """ The dimension of the basis."""
    graph_key: Optional[str] = None
    """ The key of the graph in the inputs."""
    output_key: Optional[str] = None
    """ The key to use for the output in the returned dictionary."""
    basis: str = "bessel"
    """ The basis to use. Can be one of "bessel", "gaussian", "gaussian_rinv", "fourier", "spooky",
        "levels", "finite_support", "exp" (or "exp_lr"), "neural_net" (or "nn"),
        "damped_coulomb", or "spherical_bessel_{l}" for a given angular momentum l."""
    trainable: bool = False
    """ Whether the basis parameters are trainable or fixed."""
    enforce_positive: bool = False
    """ Whether to enforce that (distance - start) is positive."""
    gamma: float = 1.0 / (2 * au.BOHR)
    """ The gamma parameter for the "spooky" basis."""
    n_levels: int = 10
    """ The number of levels for the "levels" basis."""
    alt_bessel_norm: bool = False
    """ If True, use the (2/(end-start))**0.5 normalization for the bessel basis."""
    extra_params: Dict = dataclasses.field(default_factory=dict)
    """ Dictionary of extra parameters for the basis."""

    FID: ClassVar[str] = "RADIAL_BASIS"

    @nn.compact
    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
        if self.graph_key is not None:
            x = inputs[self.graph_key]["distances"]
        else:
            x = inputs["distances"] if isinstance(inputs, dict) else inputs

        shape = x.shape
        x = x.reshape(-1)

        basis = self.basis.lower()
        ############################
        if basis == "bessel":
            # sine basis sin(k*pi*(r-start)/c)/r with k = 1..dim
            c = self.end - self.start
            x = x[:, None] - self.start
            # if self.enforce_positive:
            #     x = jax.nn.softplus(x)

            if self.trainable:
                bessel_roots = self.param(
                    "bessel_roots",
                    lambda key, dim: jnp.asarray(
                        np.arange(1, dim + 1, dtype=x.dtype)[None, :] * (math.pi / c)
                    ),
                    self.dim,
                )
                norm = 1.0 / jnp.max(
                    bessel_roots
                )  # (2.0 / c) ** 0.5 /jnp.max(bessel_roots)
            else:
                bessel_roots = jnp.asarray(
                    np.arange(1, self.dim + 1, dtype=x.dtype)[None, :] * (math.pi / c)
                )
                norm = 1.0 / (
                    self.dim * math.pi / c
                )  # (2.0 / c) ** 0.5/(self.dim*math.pi/c)
            if self.alt_bessel_norm:
                norm = (2.0 / c) ** 0.5
            out = norm * jnp.sin(x * bessel_roots) / x

            if self.enforce_positive:
                out = jnp.where(x > 0, out * (1.0 - jnp.exp(-(x**2))), 0.0)

        elif basis == "gaussian":
            if self.trainable:
                roots = self.param(
                    "radial_centers",
                    lambda key, dim, start, end: jnp.linspace(
                        start, end, dim + 1, dtype=x.dtype
                    )[None, :-1],
                    self.dim,
                    self.start,
                    self.end,
                )
                eta = self.param(
                    "radial_etas",
                    lambda key, dim, start, end: jnp.full(
                        dim,
                        dim / (end - start),
                        dtype=x.dtype,
                    )[None, :],
                    self.dim,
                    self.start,
                    self.end,
                )

            else:
                roots = jnp.asarray(
                    np.linspace(self.start, self.end, self.dim + 1)[None, :-1],
                    dtype=x.dtype,
                )
                eta = jnp.asarray(
                    np.full(self.dim, self.dim / (self.end - self.start))[None, :],
                    dtype=x.dtype,
                )

            x = x[:, None]
            x2 = (eta * (x - roots)) ** 2
            out = jnp.exp(-x2)
            if self.enforce_positive:
                out = jnp.where(
                    x > self.start,
                    out * (1.0 - jnp.exp(-10 * (x - self.start) ** 2)),
                    0.0,
                )

        elif basis == "gaussian_rinv":
            rinv_high = 1.0 / max(self.start, 0.1)
            rinv_low = 1.0 / (0.8 * self.end)

            if self.trainable:
                roots = self.param(
                    "radial_centers",
                    lambda key, dim, rinv_low, rinv_high: jnp.linspace(
                        rinv_low, rinv_high, dim, dtype=x.dtype
                    )[None, :],
                    self.dim,
                    rinv_low,
                    rinv_high,
                )
                sigmas = self.param(
                    "radial_sigmas",
                    lambda key, dim, rinv_low, rinv_high: jnp.full(
                        dim,
                        2**0.5 / (2 * dim * rinv_low),
                        dtype=x.dtype,
                    )[None, :],
                    self.dim,
                    rinv_low,
                    rinv_high,
                )
            else:
                roots = jnp.asarray(
                    np.linspace(rinv_low, rinv_high, self.dim, dtype=x.dtype)[None, :]
                )
                sigmas = jnp.asarray(
                    np.full(
                        self.dim,
                        2**0.5 / (2 * self.dim * rinv_low),
                    )[None, :],
                    dtype=x.dtype,
                )

            rinv = 1.0 / x[:, None]

            out = jnp.exp(-((rinv - roots) ** 2) / sigmas**2)

        elif basis == "fourier":
            if self.trainable:
                roots = self.param(
                    "roots",
                    lambda key, dim: jnp.arange(dim, dtype=x.dtype)[None, :] * math.pi,
                    self.dim,
                )
            else:
                roots = jnp.asarray(
                    np.arange(self.dim)[None, :] * math.pi, dtype=x.dtype
                )
            c = self.end - self.start
            x = x[:, None] - self.start
            # if self.enforce_positive:
            #     x = jax.nn.softplus(x)
            norm = 1.0 / (0.25 + 0.5 * self.dim) ** 0.5
            out = norm * jnp.cos(x * roots / c)
            if self.enforce_positive:
                out = jnp.where(x > 0, out, norm)

        elif basis == "spooky":
            # Bernstein-polynomial basis in exp(-gamma*x) (SpookyNet-style)

            gamma = self.gamma
            if self.trainable:
                gamma = jnp.abs(
                    self.param("gamma", lambda key: jnp.asarray(gamma, dtype=x.dtype))
                )

            if self.enforce_positive:
                x = jnp.where(x - self.start > 1.0e-3, x - self.start, 1.0e-3)[:, None]
                dim = self.dim
            else:
                x = x[:, None] - self.start
                dim = self.dim - 1

            norms = []
            for k in range(self.dim):
                norms.append(math.comb(dim, k))
            norms = jnp.asarray(np.array(norms)[None, :], dtype=x.dtype)

            e = jnp.exp(-gamma * x)
            k = jnp.asarray(np.arange(self.dim, dtype=x.dtype)[None, :])
            b = e**k * (1 - e) ** (dim - k)
            out = b * e * norms
            if self.enforce_positive:
                out = jnp.where(x > 1.0e-3, out * (1.0 - jnp.exp(-(x**2))), 0.0)
            # logfac = np.zeros(self.dim)
            # for i in range(2, self.dim):
            #     logfac[i] = logfac[i - 1] + np.log(i)
            # k = np.arange(self.dim)
            # n = self.dim - 1 - k
            # logbin = jnp.asarray((logfac[-1] - logfac[k] - logfac[n])[None,:], dtype=x.dtype)
            # n = jnp.asarray(n[None,:], dtype=x.dtype)
            # k = jnp.asarray(k[None,:], dtype=x.dtype)

            # gamma = 1.0 / (2 * au.BOHR)
            # if self.trainable:
            #     gamma = jnp.abs(
            #         self.param("gamma", lambda key: jnp.asarray(gamma, dtype=x.dtype))
            #     )
            # gammar = (-gamma * x)[:,None]
            # x = logbin + n * gammar + k * jnp.log(-jnp.expm1(gammar))
            # out = jnp.exp(x)*jnp.exp(gammar)
        elif basis == "levels":
            assert self.n_levels >= 2, "Number of levels must be >= 2."

            def initialize_levels(key):
                key0, key1, key_phi = jax.random.split(key, 3)
                level0 = jax.random.randint(key0, (self.dim,), 0, 2)
                level1 = jax.random.randint(key1, (self.dim,), 0, 2)
                # level0 = jax.random.normal(key0, (self.dim,), dtype=jnp.float32)
                # level1 = jax.random.normal(key1, (self.dim,), dtype=jnp.float32)
                phi = jax.random.uniform(key_phi, (self.dim,), dtype=jnp.float32)
                levels = [level0]
                for l in range(2, self.n_levels - 1):
                    tau = float(self.n_levels - l) / float(self.n_levels - 1)
                    phil = phi < tau
                    level = jnp.where(phil, level0, level1)
                    levels.append(level)
                levels.append(level1)
                return jnp.stack(levels).astype(jnp.float32)

            levels = self.param("levels", initialize_levels)
            # levels = self.param(
            #     "levels",
            #     lambda key, shape: jax.random.normal(key, shape, dtype=jnp.float32),
            #     (self.n_levels,self.dim),
            # )

            flevel = (x - self.start) / (self.end - self.start) * (self.n_levels - 1)
            ilevel = jnp.floor(flevel).astype(jnp.int32)
            ilevel1 = jnp.clip(ilevel + 1, 0, self.n_levels - 1)
            ilevel = jnp.clip(ilevel, 0, self.n_levels - 1)

            dx = flevel - ilevel
            w = 0.5 * (1 + jnp.cos(jnp.pi * dx))[:, None]

            ## interpolate between level vectors
            v1 = levels[ilevel]
            v2 = levels[ilevel1]
            out = v1 * w + v2 * (1 - w)
        elif basis == "finite_support":
            flevel = (x - self.start) / (self.end - self.start) * (self.dim + 1)
            ilevel = jnp.floor(flevel).astype(jnp.int32)
            ilevel1 = jnp.clip(ilevel + 1, 0, self.dim + 1)
            ilevel = jnp.clip(ilevel, 0, self.dim + 1)

            dx = flevel - ilevel
            w = 0.5 * (1 + jnp.cos(jnp.pi * dx))

            ilevelflat = ilevel + jnp.arange(x.shape[0]) * (self.dim + 2)
            ilevel1flat = ilevel1 + jnp.arange(x.shape[0]) * (self.dim + 2)

            out = (
                jnp.zeros((x.shape[0] * (self.dim + 2)), dtype=x.dtype)
                .at[ilevelflat]
                .set(w)
                .at[ilevel1flat]
                .set(1 - w)
                .reshape(-1, self.dim + 2)[:, 1:-1]
            )

        elif basis == "exp_lr" or basis == "exp":
            zeta = self.extra_params.get("zeta", 2.0)
            s = self.extra_params.get("s", 0.5)
            n = np.arange(self.dim)
            # if self.trainable:
            # zeta = jnp.abs(
            #     self.param("zeta", lambda key: jnp.asarray(zeta, dtype=x.dtype))
            # )
            # s = jnp.abs(self.param("s", lambda key: jnp.asarray(s, dtype=x.dtype)))

            a = zeta * s**n
            xx = np.linspace(0, self.end, 10000)

            if self.start > 0:
                a1 = np.minimum(1.0 / a, 1.0)
                switchstart = jnp.where(
                    x[:, None] < self.end,
                    1
                    - (
                        0.5
                        + 0.5
                        * jnp.cos(
                            np.pi * (x[:, None] - self.start) / (self.end - self.start)
                        )
                    )
                    ** a1[None, :],
                    1,
                ) * (x[:, None] > self.start)
                switchstartxx = (
                    1
                    - (
                        0.5
                        + 0.5
                        * np.cos(
                            np.pi * (xx[:, None] - self.start) / (self.end - self.start)
                        )
                    )
                    ** a1[None, :]
                ) * (xx[:, None] > self.start)

            else:
                switchstart = 1.0
                switchstartxx = 1.0

            norm = 1.0 / np.trapz(
                switchstartxx * np.exp(-a[None, :] * xx[:, None]), xx, axis=0
            )
            # norm = 1./np.max(switchstartxx*np.exp(-a[None, :] * xx[:, None]), axis=0)

            if self.trainable:
                a = jnp.abs(self.param("exponent", lambda key: jnp.asarray(a)))

            out = switchstart * jnp.exp(-a[None, :] * x[:, None]) * norm[None, :]

        elif basis == "neural_net" or basis == "nn":
            neurons = [
                *self.extra_params.get("hidden_neurons", [2 * self.dim]),
                self.dim,
            ]
            activation = self.extra_params.get("activation", "swish")
            use_bias = self.extra_params.get("use_bias", True)

            out = FullyConnectedNet(
                neurons, activation=activation, use_bias=use_bias, squeeze=False
            )(x[:, None])
        elif basis == "damped_coulomb":
            l = np.arange(self.dim)[None, :]
            a = self.extra_params.get("a", 1.0)
            if self.trainable:
                a = self.param("a", lambda key: jnp.asarray([a] * self.dim)[None, :])
            a2 = a**2
            x = x[:, None] - self.start
            end = self.end - self.start
            x21 = a2 + x**2
            R21 = a2 + end**2
            out = (
                1.0 / x21 ** (0.5 * (l + 1))
                - 1.0 / (R21 ** (0.5 * (l + 1)))
                + (x - end) * ((l + 1) * end / (R21 ** (0.5 * (l + 3))))
            ) * (x < end)

        elif basis.startswith("spherical_bessel_"):
            l = int(basis.split("_")[-1])
            out = generate_spherical_jn_basis(self.dim, self.end, [l])(x)
        else:
            raise NotImplementedError(f"Unknown radial basis {basis}.")
        ############################

        out = out.reshape((*shape, self.dim))

        if self.graph_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: out}
        return out
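
A similar sketch for RadialBasis (illustrative values, not part of the module source): with `graph_key=None`, an array of distances is mapped directly to basis values of shape (..., dim).

    import jax
    import jax.numpy as jnp

    # Hypothetical standalone use: expand three pair distances in a Bessel basis.
    rb = RadialBasis(end=5.0, dim=8, basis="bessel")
    r = jnp.array([0.9, 1.5, 3.2])
    params = rb.init(jax.random.PRNGKey(0), r)
    out = rb.apply(params, r)  # shape (3, 8)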


def positional_encoding_static(t, d: int, n: float = 10000.0):
    """Sinusoidal (transformer-style) positional encoding of `t`, computed with NumPy."""
    if d % 2 == 0:
        k = np.arange(d // 2)
    else:
        k = np.arange((d + 1) // 2)
    wk = np.asarray(1.0 / (n ** (2 * k / d)))
    wkt = wk[None, :] * t[:, None]
    out = np.concatenate([np.cos(wkt), np.sin(wkt)], axis=-1)
    if d % 2 == 1:
        out = out[:, :-1]
    return out


@partial(jax.jit, static_argnums=(1, 2), inline=True)
def positional_encoding(t, d: int, n: float = 10000.0):
    """JAX (jitted) version of `positional_encoding_static`."""
    if d % 2 == 0:
        k = np.arange(d // 2)
    else:
        k = np.arange((d + 1) // 2)
    wk = jnp.asarray(1.0 / (n ** (2 * k / d)))
    wkt = wk[None, :] * t[:, None]
    out = jnp.concatenate([jnp.cos(wkt), jnp.sin(wkt)], axis=-1)
    if d % 2 == 1:
        out = out[:, :-1]
    return out
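
Both helpers implement the transformer-style sinusoidal encoding: the first half of the output holds cos(t / n^(2k/d)) and the second half the matching sines. A quick numerical illustration (values assumed for the example):

    t = np.array([1.0, 2.0])
    emb = positional_encoding_static(t, 4, 100.0)  # shape (2, 4)
    # frequencies are [1, 0.1]; emb[:, :2] are cosine channels, emb[:, 2:] sine channels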


def generate_spherical_jn_basis(
    dim: int,
    rc: float,
    ls: Union[int, Sequence[int]] = [0],
    print_code: bool = False,
    jit: bool = False,
):
    """Generate a normalized spherical Bessel radial basis (code-generated with sympy)."""
    from sympy import Symbol, jn, expand_func
    from scipy.special import spherical_jn
    from sympy import jn_zeros
    import scipy.integrate as integrate

    if isinstance(ls, int):
        ls = list(range(ls + 1))
    zl = [Symbol(f"xz[...,{l}]") for l in ls]
    zn = np.array([jn_zeros(l, dim) for l in ls], dtype=float).T
    znrc = zn / rc
    norms = np.zeros((dim, len(ls)), dtype=float)
    for il, l in enumerate(ls):
        for i in range(dim):
            norms[i, il] = (
                integrate.quad(lambda x: (spherical_jn(l, x) * x) ** 2, 0, zn[i, il])[0]
                / znrc[i, il] ** 3
            ) ** (-0.5)

    fn_str = f"""def spherical_jn_basis_(x):
    from jax.numpy import cos,sin

    znrc = jnp.array({znrc.tolist()},dtype=x.dtype)
    norms = jnp.array({norms.tolist()},dtype=x.dtype)
    xshape = x.shape
    x = x.reshape(-1)
    xz = x[:,None,None]*znrc[None,:,:]

    jns = jnp.stack([
  """
    for il, l in enumerate(ls):
        fn_str += f"      {expand_func(jn(l, zl[il]))},\n"
    fn_str += f"""    ],axis=-1)
    return (norms[None,:,:]*jns).reshape(*xshape,{dim},{len(ls)})
  """

    if print_code:
        print(fn_str)
    exec(fn_str)
    jn_basis = locals()["spherical_jn_basis_"]
    if jit:
        jn_basis = jax.jit(jn_basis)
    return jn_basis
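
The factory builds and `exec`s the basis code once (it requires sympy and scipy at call time) and returns a plain callable that can be jitted. A small sketch of direct use (illustrative values):

    jn_basis = generate_spherical_jn_basis(dim=8, rc=5.0, ls=[0, 1])
    x = jnp.linspace(0.5, 4.5, 100)
    out = jn_basis(x)  # shape (100, 8, 2): 8 radial functions for each of l = 0, 1

    # Inside RadialBasis, this path is reached with basis="spherical_bessel_{l}",
    # which selects a single angular momentum l.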
class SpeciesEncoding(flax.linen.module.Module):
 28class SpeciesEncoding(nn.Module):
 29    """A module that encodes chemical species information.
 30
 31    FID: SPECIES_ENCODING
 32    """
 33
 34    encoding: str = "random"
 35    """ The encoding to use. Can be one of "one_hot", "occupation", "electronic_structure", "properties", "sjs_coordinates", "random". 
 36        Multiple encodings can be concatenated using the "+" separator.
 37    """
 38    dim: int = 16
 39    """ The dimension of the encoding if not fixed by design."""
 40    zmax: int = 86
 41    """ The maximum atomic number to encode."""
 42    output_key: Optional[str] = None
 43    """ The key to use for the output in the returned dictionary."""
 44
 45    species_order: Optional[Union[str, Sequence[str]]] = None
 46    """ The order of the species to use for the encoding. Only used for "onehot" encoding.
 47         If None, we encode all elements up to `zmax`."""
 48    trainable: bool = False
 49    """ Whether the encoding is trainable or fixed. Does not apply to "random" encoding which is always trainable."""
 50    extra_params: Dict = dataclasses.field(default_factory=dict)
 51    """ Dictionary of extra parameters for the basis."""
 52
 53    FID: ClassVar[str] = "SPECIES_ENCODING"
 54
 55    @nn.compact
 56    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
 57
 58        zmax = self.zmax
 59        if zmax <= 0 or zmax > len(PERIODIC_TABLE):
 60            zmax = len(PERIODIC_TABLE)
 61
 62        zmaxpad = zmax + 2
 63
 64        encoding = self.encoding.lower().strip()
 65        encodings = encoding.split("+")
 66        ############################
 67        conv_tensors = []
 68
 69        if "one_hot" in encodings or "onehot" in encodings:
 70            if self.species_order is None:
 71                conv_tensor = np.eye(zmax)
 72                conv_tensor = np.concatenate(
 73                    [np.zeros((1, zmax)), conv_tensor, np.zeros((1, zmax))], axis=0
 74                )
 75            else:
 76                if isinstance(self.species_order, str):
 77                    species_order = [el.strip() for el in self.species_order.split(",")]
 78                else:
 79                    species_order = [el for el in self.species_order]
 80                conv_tensor = np.zeros((zmaxpad, len(species_order)))
 81                for i, s in enumerate(species_order):
 82                    conv_tensor[PERIODIC_TABLE_REV_IDX[s], i] = 1
 83
 84            conv_tensors.append(conv_tensor)
 85
 86        if "occupation" in encodings:
 87            conv_tensor = np.zeros((zmaxpad, (zmax + 1) // 2))
 88            for i in range(1, zmax + 1):
 89                conv_tensor[i, : i // 2] = 1
 90                if i % 2 == 1:
 91                    conv_tensor[i, i // 2] = 0.5
 92
 93            conv_tensors.append(conv_tensor)
 94
 95        if "electronic_structure" in encodings:
 96            Z = np.arange(1, zmax + 1).reshape(-1, 1)
 97            Zref = [zmax]
 98            e_struct = np.array(EL_STRUCT[1 : zmax + 1])
 99            eref = [2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 6, 2, 14, 10, 6]
100            v_struct = np.array(VALENCE_STRUCTURE[1 : zmax + 1])
101            vref = [2, 6, 10, 14]
102            if zmax <= 86:
103                e_struct = e_struct[:, :15]
104                eref = eref[:15]
105            ref = np.array(Zref + eref + vref)
106            conv_tensor = np.concatenate([Z, e_struct, v_struct], axis=1)
107            conv_tensor = conv_tensor / ref[None, :]
108            dim = conv_tensor.shape[1]
109            conv_tensor = np.concatenate(
110                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
111            )
112
113            conv_tensors.append(conv_tensor)
114        
115        if "valence_structure" in encodings:
116            vref = np.array([2, 6, 10, 14])
117            conv_tensor = np.array(VALENCE_STRUCTURE[1 : zmax + 1])/vref[None,:]
118            dim = conv_tensor.shape[1]
119            conv_tensor = np.concatenate(
120                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
121            )
122            conv_tensors.append(conv_tensor)
123        
124        if "period_onehot" in encodings:
125            periods = PERIODIC_COORDINATES[1 : zmax + 1, 0]-1
126            conv_tensor = np.eye(7)[periods]
127            dim = conv_tensor.shape[1]
128            conv_tensor = np.concatenate(
129                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
130            )
131            conv_tensors.append(conv_tensor)
132        
133        if "period" in encodings:
134            conv_tensor = (PERIODIC_COORDINATES[1 : zmax + 1, 0]/7.)[:,None]
135            dim = conv_tensor.shape[1]
136            conv_tensor = np.concatenate(
137                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
138            )
139            conv_tensors.append(conv_tensor)
140        
141        if "period_positional" in encodings:
142            row = PERIODIC_COORDINATES[1 : zmax + 1, 0]
143            drow = self.extra_params.get("drow", default=5)
144            nrow = self.extra_params.get("nrow", default=100.0)
145
146            conv_tensor = positional_encoding_static(row, drow, nrow)
147            dim = conv_tensor.shape[1]
148            conv_tensor = np.concatenate(
149                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
150            )
151            conv_tensors.append(conv_tensor)
152
153
154        if "properties" in encodings:
155            props = np.array(XENONPY_PROPS)[1:-1]
156            assert (
157                self.zmax <= props.shape[0]
158            ), f"zmax > {props.shape[0]} not supported for xenonpy properties"
159            conv_tensor = props[1 : zmax + 1]
160            mean = np.mean(props, axis=0)
161            std = np.std(props, axis=0)
162            conv_tensor = (conv_tensor - mean[None, :]) / std[None, :]
163            dim = conv_tensor.shape[1]
164            conv_tensor = np.concatenate(
165                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
166            )
167            conv_tensors.append(conv_tensor)
168        
169        
170        if "valence_properties" in encodings:
171            assert zmax <= 86, "Valence properties only available for zmax <= 86"
172            Z = np.arange(1, zmax + 1).reshape(-1, 1)
173            Zref = [zmax]
174            Zinv = 1.0 / Z
175            Zinvref = [1.0]
176            v_struct = np.array(VALENCE_STRUCTURE[1 : zmax + 1])
177            vref = [2, 6, 10, 14]
178            ionization = np.array(ATOMIC_IONIZATION_ENERGY[1 : zmax + 1]).reshape(-1, 1)
179            ionizationref = [0.5]
180            polariz = np.array(POLARIZABILITIES[1 : zmax + 1]).reshape(-1, 1)
181            polarizref = [np.median(polariz)]
182
183            cov = np.array(D3_COV_RADII[1 : zmax + 1]).reshape(-1, 1)
184            covref = [np.median(cov)]
185
186            vdw = np.array(VDW_RADII[1 : zmax + 1]).reshape(-1, 1)
187            vdwref = [np.median(vdw)]
188
189            ref = np.array(
190                Zref + Zinvref + ionizationref + polarizref + covref + vdwref + vref
191            )
192            conv_tensor = np.concatenate(
193                [Z, Zinv, ionization, polariz, cov, vdw, v_struct], axis=1
194            )
195
196            conv_tensor = conv_tensor / ref[None, :]
197            dim = conv_tensor.shape[1]
198            conv_tensor = np.concatenate(
199                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
200            )
201            conv_tensors.append(conv_tensor)
202
203        if "sjs_coordinates" in encodings:
204            coords = np.array(SJS_COORDINATES)[1:-1]
205            conv_tensor = coords[1 : zmax + 1]
206            mean = np.mean(coords, axis=0)
207            std = np.std(coords, axis=0)
208            conv_tensor = (conv_tensor - mean[None, :]) / std[None, :]
209            dim = conv_tensor.shape[1]
210            conv_tensor = np.concatenate(
211                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
212            )
213            conv_tensors.append(conv_tensor)
214
215        if "positional" in encodings:
216            coords = np.array(PERIODIC_COORDINATES)[1 : zmax + 1]
217            row, col = coords[:, 0], coords[:, 1]
218            drow = self.extra_params.get("drow", default=5)
219            dcol = self.dim - drow
220            nrow = self.extra_params.get("nrow", default=100.0)
221            ncol = self.extra_params.get("ncol", default=1000.0)
222
223            erow = positional_encoding_static(row, drow, nrow)
224            ecol = positional_encoding_static(col, dcol, ncol)
225            conv_tensor = np.concatenate([erow, ecol], axis=-1)
226            dim = conv_tensor.shape[1]
227            conv_tensor = np.concatenate(
228                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
229            )
230            conv_tensors.append(conv_tensor)
231
232        if "oxidation" in encodings:
233            states_set = sorted(set(sum(OXIDATION_STATES, [])))
234            nstates = len(states_set)
235            state_dict = {s: i for i, s in enumerate(states_set)}
236            conv_tensor = np.zeros((zmaxpad, nstates))
237            for i, states in enumerate(OXIDATION_STATES[1 : zmax + 1]):
238                for s in states:
239                    conv_tensor[i, state_dict[s]] = 1
240            conv_tensors.append(conv_tensor)
241
242        if len(conv_tensors) > 0:
243            conv_tensor = np.concatenate(conv_tensors, axis=1)
244            if self.trainable:
245                conv_tensor = self.param(
246                    "conv_tensor",
247                    lambda key: jnp.asarray(conv_tensor, dtype=jnp.float32),
248                )
249            else:
250                conv_tensor = jnp.asarray(conv_tensor, dtype=jnp.float32)
251            conv_tensors = [conv_tensor]
252        else:
253            conv_tensors = []
254
255        if "random" in encodings:
256            rand_encoding = self.param(
257                "rand_encoding",
258                lambda key, shape: jax.nn.standardize(
259                    jax.random.normal(key, shape, dtype=jnp.float32)
260                ),
261                (zmaxpad, self.dim),
262            )
263            conv_tensors.append(rand_encoding)
264        
265        if "zeros" in encodings:
266            zeros_encoding = self.param(
267                "zeros_encoding",
268                lambda key, shape: jnp.zeros(shape, dtype=jnp.float32),
269                (zmaxpad, self.dim),
270            )
271            conv_tensors.append(zeros_encoding)
272
273        if "randint" in encodings:
274            rand_encoding = self.param(
275                "randint_encoding",
276                lambda key, shape: jax.random.randint(key, shape, 0, 2).astype(
277                    jnp.float32
278                ),
279                (zmaxpad, self.dim),
280            )
281            conv_tensors.append(rand_encoding)
282
283        if "randtri" in encodings:
284            rand_encoding = self.param(
285                "randtri_encoding",
286                lambda key, shape: (jax.random.randint(key, shape, 0, 3) - 1).astype(
287                    jnp.float32
288                ),
289                (zmaxpad, self.dim),
290            )
291            conv_tensors.append(rand_encoding)
292
293        assert len(conv_tensors) > 0, f"No encoding recognized in '{self.encoding}'"
294
295        conv_tensor = jnp.concatenate(conv_tensors, axis=1)
296
297        species = inputs["species"] if isinstance(inputs, dict) else inputs
298        out = conv_tensor[species]
299        ############################
300
301        if isinstance(inputs, dict):
302            output_key = self.name if self.output_key is None else self.output_key
303            out = out.astype(inputs["coordinates"].dtype)
304            return {**inputs, output_key: out} if output_key is not None else out
305        return out

A module that encodes chemical species information.

FID: SPECIES_ENCODING

SpeciesEncoding( encoding: str = 'random', dim: int = 16, zmax: int = 86, output_key: Optional[str] = None, species_order: Union[str, Sequence[str], NoneType] = None, trainable: bool = False, extra_params: Dict = <factory>, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
encoding: str = 'random'

The encoding to use. Can be one of "one_hot", "occupation", "electronic_structure", "properties", "sjs_coordinates", "random". Multiple encodings can be concatenated using the "+" separator.

dim: int = 16

The dimension of the encoding if not fixed by design.

zmax: int = 86

The maximum atomic number to encode.

output_key: Optional[str] = None

The key to use for the output in the returned dictionary.

species_order: Union[str, Sequence[str], NoneType] = None

The order of the species to use for the encoding. Only used for "onehot" encoding. If None, we encode all elements up to zmax.

trainable: bool = False

Whether the encoding is trainable or fixed. Does not apply to "random" encoding which is always trainable.

extra_params: Dict

Dictionary of extra parameters for the basis.

FID: ClassVar[str] = 'SPECIES_ENCODING'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class RadialBasis(flax.linen.module.Module):
308class RadialBasis(nn.Module):
309    """Computes a radial encoding of distances.
310
311    FID: RADIAL_BASIS
312    """
313
314    end: float
315    """ The maximum distance to consider."""
316    start: float = 0.0
317    """ The minimum distance to consider."""
318    dim: int = 8
319    """ The dimension of the basis."""
320    graph_key: Optional[str] = None
321    """ The key of the graph in the inputs."""
322    output_key: Optional[str] = None
323    """ The key to use for the output in the returned dictionary."""
324    basis: str = "bessel"
325    """ The basis to use. Can be one of "bessel", "gaussian", "gaussian_rinv", "fourier", "spooky"."""
326    trainable: bool = False
327    """ Whether the basis parameters are trainable or fixed."""
328    enforce_positive: bool = False
329    """ Whether to enforce distance-start to be positive"""
330    gamma: float = 1.0 / (2 * au.BOHR)
331    """ The gamma parameter for the "spooky" basis."""
332    n_levels: int = 10
333    """ The number of levels for the "levels" basis."""
334    alt_bessel_norm: bool = False
335    """ If True, use the (2/(end-start))**0.5 normalization for the bessel basis."""
336    extra_params: Dict = dataclasses.field(default_factory=dict)
337    """ Dictionary of extra parameters for the basis."""
338
339    FID: ClassVar[str] = "RADIAL_BASIS"
340
341    @nn.compact
342    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
343        if self.graph_key is not None:
344            x = inputs[self.graph_key]["distances"]
345        else:
346            x = inputs["distances"] if isinstance(inputs, dict) else inputs
347
348        shape = x.shape
349        x = x.reshape(-1)
350
351        basis = self.basis.lower()
352        ############################
353        if basis == "bessel":
354            c = self.end - self.start
355            x = x[:, None] - self.start
356            # if self.enforce_positive:
357            #     x = jax.nn.softplus(x)
358
359            if self.trainable:
360                bessel_roots = self.param(
361                    "bessel_roots",
362                    lambda key, dim: jnp.asarray(
363                        np.arange(1, dim + 1, dtype=x.dtype)[None, :] * (math.pi / c)
364                    ),
365                    self.dim,
366                )
367                norm = 1.0 / jnp.max(
368                    bessel_roots
369                )  # (2.0 / c) ** 0.5 /jnp.max(bessel_roots)
370            else:
371                bessel_roots = jnp.asarray(
372                    np.arange(1, self.dim + 1, dtype=x.dtype)[None, :] * (math.pi / c)
373                )
374                norm = 1.0 / (
375                    self.dim * math.pi / c
376                )  # (2.0 / c) ** 0.5/(self.dim*math.pi/c)
377            if self.alt_bessel_norm:
378                norm = (2.0 / c) ** 0.5
379            out = norm * jnp.sin(x * bessel_roots) / x
380
381            if self.enforce_positive:
382                out = jnp.where(x > 0, out * (1.0 - jnp.exp(-(x**2))), 0.0)
383
384        elif basis == "gaussian":
385            if self.trainable:
386                roots = self.param(
387                    "radial_centers",
388                    lambda key, dim, start, end: jnp.linspace(
389                        start, end, dim + 1, dtype=x.dtype
390                    )[None, :-1],
391                    self.dim,
392                    self.start,
393                    self.end,
394                )
395                eta = self.param(
396                    "radial_etas",
397                    lambda key, dim, start, end: jnp.full(
398                        dim,
399                        dim / (end - start),
400                        dtype=x.dtype,
401                    )[None, :],
402                    self.dim,
403                    self.start,
404                    self.end,
405                )
406
407            else:
408                roots = jnp.asarray(
409                    np.linspace(self.start, self.end, self.dim + 1)[None, :-1],
410                    dtype=x.dtype,
411                )
412                eta = jnp.asarray(
413                    np.full(self.dim, self.dim / (self.end - self.start))[None, :],
414                    dtype=x.dtype,
415                )
416
417            x = x[:, None]
418            x2 = (eta * (x - roots)) ** 2
419            out = jnp.exp(-x2)
420            if self.enforce_positive:
421                out = jnp.where(
422                    x > self.start,
423                    out * (1.0 - jnp.exp(-10 * (x - self.start) ** 2)),
424                    0.0,
425                )
426
427        elif basis == "gaussian_rinv":
428            rinv_high = 1.0 / max(self.start, 0.1)
429            rinv_low = 1.0 / (0.8 * self.end)
430
431            if self.trainable:
432                roots = self.param(
433                    "radial_centers",
434                    lambda key, dim, rinv_low, rinv_high: jnp.linspace(
435                        rinv_low, rinv_high, dim, dtype=x.dtype
436                    )[None, :],
437                    self.dim,
438                    rinv_low,
439                    rinv_high,
440                )
441                sigmas = self.param(
442                    "radial_sigmas",
443                    lambda key, dim, rinv_low, rinv_high: jnp.full(
444                        dim,
445                        2**0.5 / (2 * dim * rinv_low),
446                        dtype=x.dtype,
447                    )[None, :],
448                    self.dim,
449                    rinv_low,
450                    rinv_high,
451                )
452            else:
453                roots = jnp.asarray(
454                    np.linspace(rinv_low, rinv_high, self.dim, dtype=x.dtype)[None, :]
455                )
456                sigmas = jnp.asarray(
457                    np.full(
458                        self.dim,
459                        2**0.5 / (2 * self.dim * rinv_low),
460                    )[None, :],
461                    dtype=x.dtype,
462                )
463
464            rinv = 1.0 / x[:, None]
465
466            out = jnp.exp(-((rinv - roots) ** 2) / sigmas**2)
467
468        elif basis == "fourier":
469            if self.trainable:
470                roots = self.param(
471                    "roots",
472                    lambda key, dim: jnp.arange(dim, dtype=x.dtype)[None, :] * math.pi,
473                    self.dim,
474                )
475            else:
476                roots = jnp.asarray(
477                    np.arange(self.dim)[None, :] * math.pi, dtype=x.dtype
478                )
479            c = self.end - self.start
480            x = x[:, None] - self.start
481            # if self.enforce_positive:
482            #     x = jax.nn.softplus(x)
483            norm = 1.0 / (0.25 + 0.5 * self.dim) ** 0.5
484            out = norm * jnp.cos(x * roots / c)
485            if self.enforce_positive:
486                out = jnp.where(x > 0, out, norm)
487
488        elif basis == "spooky":
489
490            gamma = self.gamma
491            if self.trainable:
492                gamma = jnp.abs(
493                    self.param("gamma", lambda key: jnp.asarray(gamma, dtype=x.dtype))
494                )
495
496            if self.enforce_positive:
497                x = jnp.where(x - self.start > 1.0e-3, x - self.start, 1.0e-3)[:, None]
498                dim = self.dim
499            else:
500                x = x[:, None] - self.start
501                dim = self.dim - 1
502
503            norms = []
504            for k in range(self.dim):
505                norms.append(math.comb(dim, k))
506            norms = jnp.asarray(np.array(norms)[None, :], dtype=x.dtype)
507
508            e = jnp.exp(-gamma * x)
509            k = jnp.asarray(np.arange(self.dim, dtype=x.dtype)[None, :])
510            b = e**k * (1 - e) ** (dim - k)
511            out = b * e * norms
512            if self.enforce_positive:
513                out = jnp.where(x > 1.0e-3, out * (1.0 - jnp.exp(-(x**2))), 0.0)
514            # logfac = np.zeros(self.dim)
515            # for i in range(2, self.dim):
516            #     logfac[i] = logfac[i - 1] + np.log(i)
517            # k = np.arange(self.dim)
518            # n = self.dim - 1 - k
519            # logbin = jnp.asarray((logfac[-1] - logfac[k] - logfac[n])[None,:], dtype=x.dtype)
520            # n = jnp.asarray(n[None,:], dtype=x.dtype)
521            # k = jnp.asarray(k[None,:], dtype=x.dtype)
522
523            # gamma = 1.0 / (2 * au.BOHR)
524            # if self.trainable:
525            #     gamma = jnp.abs(
526            #         self.param("gamma", lambda key: jnp.asarray(gamma, dtype=x.dtype))
527            #     )
528            # gammar = (-gamma * x)[:,None]
529            # x = logbin + n * gammar + k * jnp.log(-jnp.expm1(gammar))
530            # out = jnp.exp(x)*jnp.exp(gammar)
531        elif basis == "levels":
532            assert self.n_levels >= 2, "Number of levels must be >= 2."
533
534            def initialize_levels(key):
535                key0, key1, key_phi = jax.random.split(key, 3)
536                level0 = jax.random.randint(key0, (self.dim,), 0, 2)
537                level1 = jax.random.randint(key1, (self.dim,), 0, 2)
538                # level0 = jax.random.normal(key0, (self.dim,), dtype=jnp.float32)
539                # level1 = jax.random.normal(key1, (self.dim,), dtype=jnp.float32)
540                phi = jax.random.uniform(key_phi, (self.dim,), dtype=jnp.float32)
541                levels = [level0]
542                for l in range(2, self.n_levels - 1):
543                    tau = float(self.n_levels - l) / float(self.n_levels - 1)
544                    phil = phi < tau
545                    level = jnp.where(phil, level0, level1)
546                    levels.append(level)
547                levels.append(level1)
548                return jnp.stack(levels).astype(jnp.float32)
549
550            levels = self.param("levels", initialize_levels)
551            # levels = self.param(
552            #     "levels",
553            #     lambda key, shape: jax.random.normal(key, shape, dtype=jnp.float32),
554            #     (self.n_levels,self.dim),
555            # )
556
557            flevel = (x - self.start) / (self.end - self.start) * (self.n_levels - 1)
558            ilevel = jnp.floor(flevel).astype(jnp.int32)
559            ilevel1 = jnp.clip(ilevel + 1, 0, self.n_levels - 1)
560            ilevel = jnp.clip(ilevel, 0, self.n_levels - 1)
561
562            dx = flevel - ilevel
563            w = 0.5 * (1 + jnp.cos(jnp.pi * dx))[:, None]
564
565            ## interpolate between level vectors
566            v1 = levels[ilevel]
567            v2 = levels[ilevel1]
568            out = v1 * w + v2 * (1 - w)
569        elif basis == "finite_support":
570            flevel = (x - self.start) / (self.end - self.start) * (self.dim + 1)
571            ilevel = jnp.floor(flevel).astype(jnp.int32)
572            ilevel1 = jnp.clip(ilevel + 1, 0, self.dim + 1)
573            ilevel = jnp.clip(ilevel, 0, self.dim + 1)
574
575            dx = flevel - ilevel
576            w = 0.5 * (1 + jnp.cos(jnp.pi * dx))
577
578            ilevelflat = ilevel + jnp.arange(x.shape[0]) * (self.dim + 2)
579            ilevel1flat = ilevel1 + jnp.arange(x.shape[0]) * (self.dim + 2)
580
581            out = (
582                jnp.zeros((x.shape[0] * (self.dim + 2)), dtype=x.dtype)
583                .at[ilevelflat]
584                .set(w)
585                .at[ilevel1flat]
586                .set(1 - w)
587                .reshape(-1, self.dim + 2)[:, 1:-1]
588            )
589
590        elif basis == "exp_lr" or basis == "exp":
591            zeta = self.extra_params.get("zeta", 2.0)
592            s = self.extra_params.get("s", 0.5)
593            n = np.arange(self.dim)
594            # if self.trainable:
595            # zeta = jnp.abs(
596            #     self.param("zeta", lambda key: jnp.asarray(zeta, dtype=x.dtype))
597            # )
598            # s = jnp.abs(self.param("s", lambda key: jnp.asarray(s, dtype=x.dtype)))
599
600            a = zeta * s**n
601            xx = np.linspace(0, self.end, 10000)
602
603            if self.start > 0:
604                a1 = np.minimum(1.0 / a, 1.0)
605                switchstart = jnp.where(
606                    x[:, None] < self.end,
607                    1
608                    - (
609                        0.5
610                        + 0.5
611                        * jnp.cos(
612                            np.pi * (x[:, None] - self.start) / (self.end - self.start)
613                        )
614                    )
615                    ** a1[None, :],
616                    1,
617                ) * (x[:, None] > self.start)
618                switchstartxx = (
619                    1
620                    - (
621                        0.5
622                        + 0.5
623                        * np.cos(
624                            np.pi * (xx[:, None] - self.start) / (self.end - self.start)
625                        )
626                    )
627                    ** a1[None, :]
628                ) * (xx[:, None] > self.start)
629
630            else:
631                switchstart = 1.0
632                switchstartxx = 1.0
633
634            norm = 1.0 / np.trapz(
635                switchstartxx * np.exp(-a[None, :] * xx[:, None]), xx, axis=0
636            )
637            # norm = 1./np.max(switchstartxx*np.exp(-a[None, :] * xx[:, None]), axis=0)
638
639            if self.trainable:
640                a = jnp.abs(self.param("exponent", lambda key: jnp.asarray(a)))
641
642            out = switchstart * jnp.exp(-a[None, :] * x[:, None]) * norm[None, :]
643
644        elif basis == "neural_net" or basis == "nn":
645            neurons = [*self.extra_params.get(
646                "hidden_neurons", [2 * self.dim]
647            ), self.dim]
648            activation = self.extra_params.get("activation", "swish")
649            use_bias = self.extra_params.get("use_bias", True)
650
651            out = FullyConnectedNet(
652                neurons, activation=activation, use_bias=use_bias, squeeze=False
653            )(x[:, None])
654        elif basis == "damped_coulomb":
655            l = np.arange(self.dim)[None, :]
656            a = self.extra_params.get("a", 1.0)
657            if self.trainable:
658                a = self.param("a", lambda key: jnp.asarray([a] * self.dim)[None, :])
659            a2 = a**2
660            x = x[:, None] - self.start
661            end = self.end - self.start
662            x21 = a2 + x**2
663            R21 = a2 + end**2
664            out = (
665                1.0 / x21 ** (0.5 * (l + 1))
666                - 1.0 / (R21 ** (0.5 * (l + 1)))
667                + (x - end) * ((l + 1) * end / (R21 ** (0.5 * (l + 3))))
668            ) * (x < end)
669
670        elif basis.startswith("spherical_bessel_"):
671            l = int(basis.split("_")[-1])
672            out = generate_spherical_jn_basis(self.dim, self.end, [l])(x)
673        else:
674            raise NotImplementedError(f"Unknown radial basis {basis}.")
675        ############################
676
677        out = out.reshape((*shape, self.dim))
678
679        if self.graph_key is not None:
680            output_key = self.name if self.output_key is None else self.output_key
681            return {**inputs, output_key: out}
682        return out

Computes a radial encoding of distances.

FID: RADIAL_BASIS

RadialBasis(end: float, start: float = 0.0, dim: int = 8, graph_key: Optional[str] = None, output_key: Optional[str] = None, basis: str = 'bessel', trainable: bool = False, enforce_positive: bool = False, gamma: float = 0.9448630639252209, n_levels: int = 10, alt_bessel_norm: bool = False, extra_params: Dict = <factory>)
end: float

The maximum distance to consider.

start: float = 0.0

The minimum distance to consider.

dim: int = 8

The dimension of the basis.

graph_key: Optional[str] = None

The key of the graph in the inputs.

output_key: Optional[str] = None

The key to use for the output in the returned dictionary.

basis: str = 'bessel'

The basis to use. Can be one of "bessel", "gaussian", "gaussian_rinv", "fourier", "spooky", "levels", "finite_support", "exp" (or "exp_lr"), "neural_net" (or "nn"), "damped_coulomb", or "spherical_bessel_{l}" where {l} is the angular momentum order.

trainable: bool = False

Whether the basis parameters are trainable or fixed.

enforce_positive: bool = False

Whether to enforce that distance - start is positive.

gamma: float = 0.9448630639252209

The gamma parameter for the "spooky" basis.

n_levels: int = 10

The number of levels for the "levels" basis.

alt_bessel_norm: bool = False

If True, use the (2/(end-start))**0.5 normalization for the bessel basis.

extra_params: Dict

Dictionary of extra parameters for the basis.

FID: ClassVar[str] = 'RADIAL_BASIS'
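
A minimal usage sketch (illustrative, not part of the documented API: the distance values and the raw-array input path are assumptions here; with a graph_key set, the module would instead read distances from the graph entry of the input dictionary and return the dictionary augmented with output_key):

    import jax
    import jax.numpy as jnp
    from fennol.models.misc.encodings import RadialBasis

    # encode 5 distances into an 8-dimensional bessel basis (shapes illustrative)
    basis = RadialBasis(end=5.0, dim=8, basis="bessel")
    distances = jnp.linspace(0.5, 4.5, 5)
    params = basis.init(jax.random.PRNGKey(0), distances)
    out = basis.apply(params, distances)  # shape (5, 8)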
685def positional_encoding_static(t, d: int, n: float = 10000.0):
686    if d % 2 == 0:
687        k = np.arange(d // 2)
688    else:
689        k = np.arange((d + 1) // 2)
690    wk = np.asarray(1.0 / (n ** (2 * k / d)))
691    wkt = wk[None, :] * t[:, None]
692    out = np.concatenate([np.cos(wkt), np.sin(wkt)], axis=-1)
693    if d % 2 == 1:
694        out = out[:, :-1]
695    return out
698@partial(jax.jit, static_argnums=(1, 2), inline=True)
699def positional_encoding(t, d: int, n: float = 10000.0):
700    if d % 2 == 0:
701        k = np.arange(d // 2)
702    else:
703        k = np.arange((d + 1) // 2)
704    wk = jnp.asarray(1.0 / (n ** (2 * k / d)))
705    wkt = wk[None, :] * t[:, None]
706    out = jnp.concatenate([jnp.cos(wkt), jnp.sin(wkt)], axis=-1)
707    if d % 2 == 1:
708        out = out[:, :-1]
709    return out
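
For reference, this is the standard sinusoidal positional encoding with frequencies w_k = n**(-2k/d): the output concatenates cos(w_k t) and sin(w_k t) along the last axis, dropping the final column when d is odd. A quick sanity check of the static variant (values illustrative):

    import numpy as np

    t = np.array([0.0, 1.0, 2.0])
    enc = positional_encoding_static(t, d=4)  # shape (3, 4)
    # with d=4, w_0 = 1, so column 0 is cos(t) and column 2 is sin(t)
    assert np.allclose(enc[:, 0], np.cos(t))
    assert np.allclose(enc[:, 2], np.sin(t))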
712def generate_spherical_jn_basis(dim: int, rc: float, ls: Union[int, Sequence[int]] = [0], print_code: bool = False, jit: bool = False):
713    from sympy import Symbol, jn, expand_func
714    from scipy.special import spherical_jn
715    from sympy import jn_zeros
716    import scipy.integrate as integrate
717
718    if isinstance(ls, int):
719        ls = list(range(ls + 1))
720    zl = [Symbol(f"xz[...,{il}]") for il in range(len(ls))]  # index by position in ls, not by l
721    zn = np.array([jn_zeros(l, dim) for l in ls], dtype=float).T
722    znrc = zn / rc
723    norms = np.zeros((dim, len(ls)), dtype=float)
724    for il, l in enumerate(ls):
725        for i in range(dim):
726            norms[i, il] = (
727                integrate.quad(lambda x: (spherical_jn(l, x) * x) ** 2, 0, zn[i, il])[0]
728                / znrc[i, il] ** 3
729            ) ** (-0.5)
730
731    fn_str = f"""def spherical_jn_basis_(x):
732    from jax.numpy import cos,sin
733    
734    znrc = jnp.array({znrc.tolist()},dtype=x.dtype)
735    norms = jnp.array({norms.tolist()},dtype=x.dtype)
736    xshape = x.shape
737    x = x.reshape(-1)
738    xz = x[:,None,None]*znrc[None,:,:]
739
740    jns = jnp.stack([
741  """
742    for il, l in enumerate(ls):
743        fn_str += f"      {expand_func(jn(l, zl[il]))},\n"
744    fn_str += f"""    ],axis=-1)
745    return (norms[None,:,:]*jns).reshape(*xshape,{dim},{len(ls)})
746  """
747
748    if print_code:
749        print(fn_str)
750    exec(fn_str)
751    jn_basis = locals()["spherical_jn_basis_"]
752    if jit:
753        jn_basis = jax.jit(jn_basis)
754    return jn_basis
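
A short usage sketch (cutoff and dimension are illustrative): build an l = 0 spherical Bessel basis and evaluate it on a few distances. Note that sympy and scipy are only needed at generation time; the returned function is pure JAX.

    import jax.numpy as jnp

    # 4 radial functions of order l = 0, cut off at rc = 5.0
    jn_basis = generate_spherical_jn_basis(dim=4, rc=5.0, ls=[0])
    r = jnp.linspace(0.5, 4.5, 3)
    out = jn_basis(r)  # shape (3, 4, 1): (*r.shape, dim, len(ls))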