fennol.models.misc.encodings
import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Optional, Union, List, Sequence, ClassVar, Dict
import math
import dataclasses
import numpy as np
from .nets import FullyConnectedNet
from ...utils import AtomicUnits as au
from functools import partial
from ...utils.periodic_table import (
    PERIODIC_TABLE_REV_IDX,
    PERIODIC_TABLE,
    EL_STRUCT,
    VALENCE_STRUCTURE,
    XENONPY_PROPS,
    SJS_COORDINATES,
    PERIODIC_COORDINATES,
    ATOMIC_IONIZATION_ENERGY,
    POLARIZABILITIES,
    D3_COV_RADII,
    VDW_RADII,
    OXIDATION_STATES,
)
class SpeciesEncoding(nn.Module):
    """A module that encodes chemical species information.

    FID: SPECIES_ENCODING
    """

    encoding: str = "random"
    """ The encoding to use. Can be one of "one_hot", "occupation", "electronic_structure", "properties", "valence_properties", "sjs_coordinates", "positional", "oxidation", "random", "randint", "randtri".
        Multiple encodings can be concatenated using the "+" separator.
    """
    dim: int = 16
    """ The dimension of the encoding if not fixed by design."""
    zmax: int = 86
    """ The maximum atomic number to encode."""
    output_key: Optional[str] = None
    """ The key to use for the output in the returned dictionary."""

    species_order: Optional[Union[str, Sequence[str]]] = None
    """ The order of the species to use for the encoding. Only used for the "one_hot" encoding.
        If None, we encode all elements up to `zmax`."""
    trainable: bool = False
    """ Whether the encoding is trainable or fixed. Does not apply to the "random" encoding, which is always trainable."""
    extra_params: Dict = dataclasses.field(default_factory=dict)
    """ Dictionary of extra parameters for the encoding (the "positional" encoding reads "drow", "nrow" and "ncol" from it)."""

    FID: ClassVar[str] = "SPECIES_ENCODING"

    @nn.compact
    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
        zmax = self.zmax
        if zmax <= 0 or zmax > len(PERIODIC_TABLE):
            zmax = len(PERIODIC_TABLE)

        # rows 0 and zmax+1 are padding rows for dummy / out-of-range species
        zmaxpad = zmax + 2

        encoding = self.encoding.lower().strip()
        encodings = encoding.split("+")
        ############################
        conv_tensors = []

        if "one_hot" in encodings or "onehot" in encodings:
            if self.species_order is None:
                conv_tensor = np.eye(zmax)
                conv_tensor = np.concatenate(
                    [np.zeros((1, zmax)), conv_tensor, np.zeros((1, zmax))], axis=0
                )
            else:
                if isinstance(self.species_order, str):
                    species_order = [el.strip() for el in self.species_order.split(",")]
                else:
                    species_order = [el for el in self.species_order]
                conv_tensor = np.zeros((zmaxpad, len(species_order)))
                for i, s in enumerate(species_order):
                    conv_tensor[PERIODIC_TABLE_REV_IDX[s], i] = 1

            conv_tensors.append(conv_tensor)

        if "occupation" in encodings:
            conv_tensor = np.zeros((zmaxpad, (zmax + 1) // 2))
            for i in range(1, zmax + 1):
                conv_tensor[i, : i // 2] = 1
                if i % 2 == 1:
                    conv_tensor[i, i // 2] = 0.5

            conv_tensors.append(conv_tensor)

        if "electronic_structure" in encodings:
            Z = np.arange(1, zmax + 1).reshape(-1, 1)
            Zref = [zmax]
            e_struct = np.array(EL_STRUCT[1 : zmax + 1])
            eref = [2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 6, 2, 14, 10, 6]
            v_struct = np.array(VALENCE_STRUCTURE[1 : zmax + 1])
            vref = [2, 6, 10, 14]
            if zmax <= 86:
                e_struct = e_struct[:, :15]
                eref = eref[:15]
            ref = np.array(Zref + eref + vref)
            conv_tensor = np.concatenate([Z, e_struct, v_struct], axis=1)
            conv_tensor = conv_tensor / ref[None, :]
            dim = conv_tensor.shape[1]
            conv_tensor = np.concatenate(
                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
            )

            conv_tensors.append(conv_tensor)

        if "properties" in encodings:
            props = np.array(XENONPY_PROPS)[1:-1]
            assert (
                self.zmax <= props.shape[0]
            ), f"zmax > {props.shape[0]} not supported for xenonpy properties"
            conv_tensor = props[1 : zmax + 1]
            mean = np.mean(props, axis=0)
            std = np.std(props, axis=0)
            conv_tensor = (conv_tensor - mean[None, :]) / std[None, :]
            dim = conv_tensor.shape[1]
            conv_tensor = np.concatenate(
                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
            )
            conv_tensors.append(conv_tensor)

        if "valence_properties" in encodings:
            assert zmax <= 86, "Valence properties only available for zmax <= 86"
            Z = np.arange(1, zmax + 1).reshape(-1, 1)
            Zref = [zmax]
            Zinv = 1.0 / Z
            Zinvref = [1.0]
            v_struct = np.array(VALENCE_STRUCTURE[1 : zmax + 1])
            vref = [2, 6, 10, 14]
            ionization = np.array(ATOMIC_IONIZATION_ENERGY[1 : zmax + 1]).reshape(-1, 1)
            ionizationref = [0.5]
            polariz = np.array(POLARIZABILITIES[1 : zmax + 1]).reshape(-1, 1)
            polarizref = [np.median(polariz)]

            cov = np.array(D3_COV_RADII[1 : zmax + 1]).reshape(-1, 1)
            covref = [np.median(cov)]

            vdw = np.array(VDW_RADII[1 : zmax + 1]).reshape(-1, 1)
            vdwref = [np.median(vdw)]

            ref = np.array(
                Zref + Zinvref + ionizationref + polarizref + covref + vdwref + vref
            )
            conv_tensor = np.concatenate(
                [Z, Zinv, ionization, polariz, cov, vdw, v_struct], axis=1
            )

            conv_tensor = conv_tensor / ref[None, :]
            dim = conv_tensor.shape[1]
            conv_tensor = np.concatenate(
                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
            )
            conv_tensors.append(conv_tensor)

        if "sjs_coordinates" in encodings:
            coords = np.array(SJS_COORDINATES)[1:-1]
            conv_tensor = coords[1 : zmax + 1]
            mean = np.mean(coords, axis=0)
            std = np.std(coords, axis=0)
            conv_tensor = (conv_tensor - mean[None, :]) / std[None, :]
            dim = conv_tensor.shape[1]
            conv_tensor = np.concatenate(
                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
            )
            conv_tensors.append(conv_tensor)

        if "positional" in encodings:
            coords = np.array(PERIODIC_COORDINATES)[1 : zmax + 1]
            row, col = coords[:, 0], coords[:, 1]
            # dict.get takes the default positionally (a keyword `default=`
            # would fail on a plain dict)
            drow = self.extra_params.get("drow", 5)
            dcol = self.dim - drow
            nrow = self.extra_params.get("nrow", 100.0)
            ncol = self.extra_params.get("ncol", 1000.0)

            erow = positional_encoding_static(row, drow, nrow)
            ecol = positional_encoding_static(col, dcol, ncol)
            conv_tensor = np.concatenate([erow, ecol], axis=-1)
            dim = conv_tensor.shape[1]
            conv_tensor = np.concatenate(
                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
            )
            conv_tensors.append(conv_tensor)

        if "oxidation" in encodings:
            states_set = sorted(set(sum(OXIDATION_STATES, [])))
            nstates = len(states_set)
            state_dict = {s: i for i, s in enumerate(states_set)}
            conv_tensor = np.zeros((zmaxpad, nstates))
            # start=1 keeps row 0 as padding, consistent with the other encodings
            for i, states in enumerate(OXIDATION_STATES[1 : zmax + 1], start=1):
                for s in states:
                    conv_tensor[i, state_dict[s]] = 1
            conv_tensors.append(conv_tensor)

        if len(conv_tensors) > 0:
            conv_tensor = np.concatenate(conv_tensors, axis=1)
            if self.trainable:
                conv_tensor = self.param(
                    "conv_tensor",
                    lambda key: jnp.asarray(conv_tensor, dtype=jnp.float32),
                )
            else:
                conv_tensor = jnp.asarray(conv_tensor, dtype=jnp.float32)
            conv_tensors = [conv_tensor]
        else:
            conv_tensors = []

        if "random" in encodings:
            rand_encoding = self.param(
                "rand_encoding",
                lambda key, shape: jax.nn.standardize(
                    jax.random.normal(key, shape, dtype=jnp.float32)
                ),
                (zmaxpad, self.dim),
            )
            conv_tensors.append(rand_encoding)

        if "randint" in encodings:
            rand_encoding = self.param(
                "randint_encoding",
                lambda key, shape: jax.random.randint(key, shape, 0, 2).astype(
                    jnp.float32
                ),
                (zmaxpad, self.dim),
            )
            conv_tensors.append(rand_encoding)

        if "randtri" in encodings:
            rand_encoding = self.param(
                "randtri_encoding",
                lambda key, shape: (jax.random.randint(key, shape, 0, 3) - 1).astype(
                    jnp.float32
                ),
                (zmaxpad, self.dim),
            )
            conv_tensors.append(rand_encoding)

        assert len(conv_tensors) > 0, f"No encoding recognized in '{self.encoding}'"

        conv_tensor = jnp.concatenate(conv_tensors, axis=1)

        species = inputs["species"] if isinstance(inputs, dict) else inputs
        out = conv_tensor[species]
        ############################

        if isinstance(inputs, dict):
            output_key = self.name if self.output_key is None else self.output_key
            out = out.astype(inputs["coordinates"].dtype)
            return {**inputs, output_key: out} if output_key is not None else out
        return out
A module that encodes chemical species information.

FID: SPECIES_ENCODING

encoding: The encoding to use. Can be one of "one_hot", "occupation", "electronic_structure", "properties", "valence_properties", "sjs_coordinates", "positional", "oxidation", "random", "randint", "randtri". Multiple encodings can be concatenated using the "+" separator.
dim: The dimension of the encoding if not fixed by design.
zmax: The maximum atomic number to encode.
output_key: The key to use for the output in the returned dictionary.
species_order: The order of the species to use for the encoding. Only used for the "one_hot" encoding. If None, we encode all elements up to `zmax`.
trainable: Whether the encoding is trainable or fixed. Does not apply to the "random" encoding, which is always trainable.
extra_params: Dictionary of extra parameters for the encoding (the "positional" encoding reads "drow", "nrow" and "ncol" from it).
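As an illustration, here is a minimal standalone usage sketch (not taken from the FeNNol documentation; in practice the module is usually instantiated from a model configuration). It assumes the standard flax init/apply workflow with a raw array of atomic numbers; the dict pathway would additionally require "species" and "coordinates" entries.

import jax
import jax.numpy as jnp
from fennol.models.misc.encodings import SpeciesEncoding

# Concatenate a fixed one-hot encoding (zmax columns) with a trainable
# random encoding (dim columns).
encoder = SpeciesEncoding(encoding="one_hot+random", dim=8, zmax=10)
species = jnp.array([1, 6, 8])  # H, C, O

params = encoder.init(jax.random.PRNGKey(0), species)
out = encoder.apply(params, species)
print(out.shape)  # (3, 18): 10 one-hot columns + 8 random columns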
class RadialBasis(nn.Module):
    """Computes a radial encoding of distances.

    FID: RADIAL_BASIS
    """

    end: float
    """ The maximum distance to consider."""
    start: float = 0.0
    """ The minimum distance to consider."""
    dim: int = 8
    """ The dimension of the basis."""
    graph_key: Optional[str] = None
    """ The key of the graph in the inputs."""
    output_key: Optional[str] = None
    """ The key to use for the output in the returned dictionary."""
    basis: str = "bessel"
    """ The basis to use. Can be one of "bessel", "gaussian", "gaussian_rinv", "fourier", "spooky", "levels", "finite_support", "exp" (or "exp_lr"), "neural_net" (or "nn"), "damped_coulomb", or "spherical_bessel_{l}"."""
    trainable: bool = False
    """ Whether the basis parameters are trainable or fixed."""
    enforce_positive: bool = False
    """ Whether to enforce that distance - start is positive."""
    gamma: float = 1.0 / (2 * au.BOHR)
    """ The gamma parameter for the "spooky" basis."""
    n_levels: int = 10
    """ The number of levels for the "levels" basis."""
    alt_bessel_norm: bool = False
    """ If True, use the (2/(end-start))**0.5 normalization for the bessel basis."""
    extra_params: Dict = dataclasses.field(default_factory=dict)
    """ Dictionary of extra parameters for the basis."""

    FID: ClassVar[str] = "RADIAL_BASIS"

    @nn.compact
    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
        if self.graph_key is not None:
            x = inputs[self.graph_key]["distances"]
        else:
            x = inputs["distances"] if isinstance(inputs, dict) else inputs

        shape = x.shape
        x = x.reshape(-1)

        basis = self.basis.lower()
        ############################
        if basis == "bessel":
            c = self.end - self.start
            x = x[:, None] - self.start

            if self.trainable:
                bessel_roots = self.param(
                    "bessel_roots",
                    lambda key, dim: jnp.asarray(
                        np.arange(1, dim + 1, dtype=x.dtype)[None, :] * (math.pi / c)
                    ),
                    self.dim,
                )
                norm = 1.0 / jnp.max(bessel_roots)
            else:
                bessel_roots = jnp.asarray(
                    np.arange(1, self.dim + 1, dtype=x.dtype)[None, :] * (math.pi / c)
                )
                norm = 1.0 / (self.dim * math.pi / c)
            if self.alt_bessel_norm:
                norm = (2.0 / c) ** 0.5
            out = norm * jnp.sin(x * bessel_roots) / x

            if self.enforce_positive:
                out = jnp.where(x > 0, out * (1.0 - jnp.exp(-(x**2))), 0.0)

        elif basis == "gaussian":
            if self.trainable:
                roots = self.param(
                    "radial_centers",
                    lambda key, dim, start, end: jnp.linspace(
                        start, end, dim + 1, dtype=x.dtype
                    )[None, :-1],
                    self.dim,
                    self.start,
                    self.end,
                )
                eta = self.param(
                    "radial_etas",
                    lambda key, dim, start, end: jnp.full(
                        dim,
                        dim / (end - start),
                        dtype=x.dtype,
                    )[None, :],
                    self.dim,
                    self.start,
                    self.end,
                )
            else:
                roots = jnp.asarray(
                    np.linspace(self.start, self.end, self.dim + 1)[None, :-1],
                    dtype=x.dtype,
                )
                eta = jnp.asarray(
                    np.full(self.dim, self.dim / (self.end - self.start))[None, :],
                    dtype=x.dtype,
                )

            x = x[:, None]
            x2 = (eta * (x - roots)) ** 2
            out = jnp.exp(-x2)
            if self.enforce_positive:
                out = jnp.where(
                    x > self.start,
                    out * (1.0 - jnp.exp(-10 * (x - self.start) ** 2)),
                    0.0,
                )

        elif basis == "gaussian_rinv":
            rinv_high = 1.0 / max(self.start, 0.1)
            rinv_low = 1.0 / (0.8 * self.end)

            if self.trainable:
                roots = self.param(
                    "radial_centers",
                    lambda key, dim, rinv_low, rinv_high: jnp.linspace(
                        rinv_low, rinv_high, dim, dtype=x.dtype
                    )[None, :],
                    self.dim,
                    rinv_low,
                    rinv_high,
                )
                sigmas = self.param(
                    "radial_sigmas",
                    lambda key, dim, rinv_low, rinv_high: jnp.full(
                        dim,
                        2**0.5 / (2 * dim * rinv_low),
                        dtype=x.dtype,
                    )[None, :],
                    self.dim,
                    rinv_low,
                    rinv_high,
                )
            else:
                roots = jnp.asarray(
                    np.linspace(rinv_low, rinv_high, self.dim, dtype=x.dtype)[None, :]
                )
                sigmas = jnp.asarray(
                    np.full(
                        self.dim,
                        2**0.5 / (2 * self.dim * rinv_low),
                    )[None, :],
                    dtype=x.dtype,
                )

            rinv = 1.0 / x[:, None]

            out = jnp.exp(-((rinv - roots) ** 2) / sigmas**2)

        elif basis == "fourier":
            if self.trainable:
                roots = self.param(
                    "roots",
                    lambda key, dim: jnp.arange(dim, dtype=x.dtype)[None, :] * math.pi,
                    self.dim,
                )
            else:
                roots = jnp.asarray(
                    np.arange(self.dim)[None, :] * math.pi, dtype=x.dtype
                )
            c = self.end - self.start
            x = x[:, None] - self.start
            norm = 1.0 / (0.25 + 0.5 * self.dim) ** 0.5
            out = norm * jnp.cos(x * roots / c)
            if self.enforce_positive:
                out = jnp.where(x > 0, out, norm)

        elif basis == "spooky":
            gamma = self.gamma
            if self.trainable:
                gamma = jnp.abs(
                    self.param("gamma", lambda key: jnp.asarray(gamma, dtype=x.dtype))
                )

            if self.enforce_positive:
                x = jnp.where(x - self.start > 1.0e-3, x - self.start, 1.0e-3)[:, None]
                dim = self.dim
            else:
                x = x[:, None] - self.start
                dim = self.dim - 1

            # binomial coefficients of the Bernstein-polynomial normalization
            norms = []
            for k in range(self.dim):
                norms.append(math.comb(dim, k))
            norms = jnp.asarray(np.array(norms)[None, :], dtype=x.dtype)

            e = jnp.exp(-gamma * x)
            k = jnp.asarray(np.arange(self.dim, dtype=x.dtype)[None, :])
            b = e**k * (1 - e) ** (dim - k)
            out = b * e * norms
            if self.enforce_positive:
                out = jnp.where(x > 1.0e-3, out * (1.0 - jnp.exp(-(x**2))), 0.0)

        elif basis == "levels":
            assert self.n_levels >= 2, "Number of levels must be >= 2."

            def initialize_levels(key):
                key0, key1, key_phi = jax.random.split(key, 3)
                level0 = jax.random.randint(key0, (self.dim,), 0, 2)
                level1 = jax.random.randint(key1, (self.dim,), 0, 2)
                phi = jax.random.uniform(key_phi, (self.dim,), dtype=jnp.float32)
                levels = [level0]
                for l in range(2, self.n_levels - 1):
                    tau = float(self.n_levels - l) / float(self.n_levels - 1)
                    phil = phi < tau
                    level = jnp.where(phil, level0, level1)
                    levels.append(level)
                levels.append(level1)
                return jnp.stack(levels).astype(jnp.float32)

            levels = self.param("levels", initialize_levels)

            flevel = (x - self.start) / (self.end - self.start) * (self.n_levels - 1)
            ilevel = jnp.floor(flevel).astype(jnp.int32)
            ilevel1 = jnp.clip(ilevel + 1, 0, self.n_levels - 1)
            ilevel = jnp.clip(ilevel, 0, self.n_levels - 1)

            dx = flevel - ilevel
            w = 0.5 * (1 + jnp.cos(jnp.pi * dx))[:, None]

            # interpolate between level vectors
            v1 = levels[ilevel]
            v2 = levels[ilevel1]
            out = v1 * w + v2 * (1 - w)

        elif basis == "finite_support":
            flevel = (x - self.start) / (self.end - self.start) * (self.dim + 1)
            ilevel = jnp.floor(flevel).astype(jnp.int32)
            ilevel1 = jnp.clip(ilevel + 1, 0, self.dim + 1)
            ilevel = jnp.clip(ilevel, 0, self.dim + 1)

            dx = flevel - ilevel
            w = 0.5 * (1 + jnp.cos(jnp.pi * dx))

            ilevelflat = ilevel + jnp.arange(x.shape[0]) * (self.dim + 2)
            ilevel1flat = ilevel1 + jnp.arange(x.shape[0]) * (self.dim + 2)

            out = (
                jnp.zeros((x.shape[0] * (self.dim + 2)), dtype=x.dtype)
                .at[ilevelflat]
                .set(w)
                .at[ilevel1flat]
                .set(1 - w)
                .reshape(-1, self.dim + 2)[:, 1:-1]
            )

        elif basis == "exp_lr" or basis == "exp":
            zeta = self.extra_params.get("zeta", 2.0)
            s = self.extra_params.get("s", 0.5)
            n = np.arange(self.dim)

            a = zeta * s**n
            xx = np.linspace(0, self.end, 10000)

            if self.start > 0:
                a1 = np.minimum(1.0 / a, 1.0)
                switchstart = jnp.where(
                    x[:, None] < self.end,
                    1
                    - (
                        0.5
                        + 0.5
                        * jnp.cos(
                            np.pi * (x[:, None] - self.start) / (self.end - self.start)
                        )
                    )
                    ** a1[None, :],
                    1,
                ) * (x[:, None] > self.start)
                switchstartxx = (
                    1
                    - (
                        0.5
                        + 0.5
                        * np.cos(
                            np.pi * (xx[:, None] - self.start) / (self.end - self.start)
                        )
                    )
                    ** a1[None, :]
                ) * (xx[:, None] > self.start)
            else:
                switchstart = 1.0
                switchstartxx = 1.0

            # normalize each exponential to unit integral on [0, end]
            # (note: np.trapz was renamed np.trapezoid in NumPy 2.0)
            norm = 1.0 / np.trapz(
                switchstartxx * np.exp(-a[None, :] * xx[:, None]), xx, axis=0
            )

            if self.trainable:
                a = jnp.abs(self.param("exponent", lambda key: jnp.asarray(a)))

            out = switchstart * jnp.exp(-a[None, :] * x[:, None]) * norm[None, :]

        elif basis == "neural_net" or basis == "nn":
            neurons = self.extra_params.get("hidden_neurons", [2 * self.dim]) + [
                self.dim
            ]
            activation = self.extra_params.get("activation", "swish")
            use_bias = self.extra_params.get("use_bias", True)

            out = FullyConnectedNet(
                neurons, activation=activation, use_bias=use_bias, squeeze=False
            )(x[:, None])

        elif basis == "damped_coulomb":
            l = np.arange(self.dim)[None, :]
            a2 = self.extra_params.get("a", 1.0) ** 2
            x = x[:, None] - self.start
            end = self.end - self.start
            x21 = a2 + x**2
            R21 = a2 + end**2
            out = (
                1.0 / x21 ** (0.5 * (l + 1))
                - 1.0 / (R21 ** (0.5 * (l + 1)))
                + (x - end) * ((l + 1) * end / (R21 ** (0.5 * (l + 3))))
            ) * (x < end)

        elif basis.startswith("spherical_bessel_"):
            l = int(basis.split("_")[-1])
            # arguments follow generate_spherical_jn_basis(dim, rc, ls)
            out = generate_spherical_jn_basis(self.dim, self.end, [l])(x)
        else:
            raise NotImplementedError(f"Unknown radial basis {basis}.")
        ############################

        out = out.reshape((*shape, self.dim))

        if self.graph_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: out}
        return out
Computes a radial encoding of distances.

FID: RADIAL_BASIS

end: The maximum distance to consider.
start: The minimum distance to consider.
dim: The dimension of the basis.
graph_key: The key of the graph in the inputs.
output_key: The key to use for the output in the returned dictionary.
basis: The basis to use. Can be one of "bessel", "gaussian", "gaussian_rinv", "fourier", "spooky", "levels", "finite_support", "exp" (or "exp_lr"), "neural_net" (or "nn"), "damped_coulomb", or "spherical_bessel_{l}".
trainable: Whether the basis parameters are trainable or fixed.
enforce_positive: Whether to enforce that distance - start is positive.
gamma: The gamma parameter for the "spooky" basis.
n_levels: The number of levels for the "levels" basis.
alt_bessel_norm: If True, use the (2/(end-start))**0.5 normalization for the bessel basis.
extra_params: Dictionary of extra parameters for the basis (e.g. "zeta" and "s" for "exp", "hidden_neurons", "activation" and "use_bias" for "neural_net", "a" for "damped_coulomb").
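A similar minimal sketch for RadialBasis applied to a raw array of distances (assumed values, in the same length unit as `end`); with `graph_key` set, the module would instead read `inputs[graph_key]["distances"]` and store the result under `output_key` in the returned dictionary.

import jax
import jax.numpy as jnp
from fennol.models.misc.encodings import RadialBasis

# Expand each distance into 8 "bessel" basis functions on [0, 5].
rbasis = RadialBasis(end=5.0, dim=8, basis="bessel")
distances = jnp.array([0.9, 1.5, 3.2])

params = rbasis.init(jax.random.PRNGKey(0), distances)  # empty unless trainable=True
out = rbasis.apply(params, distances)
print(out.shape)  # (3, 8): one row of basis values per distance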
def positional_encoding_static(t, d: int, n: float = 10000.0):
    """Sinusoidal positional encoding of `t` into `d` dimensions (NumPy version)."""
    if d % 2 == 0:
        k = np.arange(d // 2)
    else:
        k = np.arange((d + 1) // 2)
    wk = np.asarray(1.0 / (n ** (2 * k / d)))
    wkt = wk[None, :] * t[:, None]
    out = np.concatenate([np.cos(wkt), np.sin(wkt)], axis=-1)
    if d % 2 == 1:
        # drop the extra column produced by the odd-dimension padding
        out = out[:, :-1]
    return out
@partial(jax.jit, static_argnums=(1, 2), inline=True)
def positional_encoding(t, d: int, n: float = 10000.0):
    """JIT-compiled JAX version of positional_encoding_static."""
    if d % 2 == 0:
        k = np.arange(d // 2)
    else:
        k = np.arange((d + 1) // 2)
    wk = jnp.asarray(1.0 / (n ** (2 * k / d)))
    wkt = wk[None, :] * t[:, None]
    out = jnp.concatenate([jnp.cos(wkt), jnp.sin(wkt)], axis=-1)
    if d % 2 == 1:
        out = out[:, :-1]
    return out
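A short worked example of the sinusoidal encoding above; the values follow directly from the frequencies w_k = n**(-2k/d).

import numpy as np
from fennol.models.misc.encodings import positional_encoding_static

# With d=4 and n=100, the two frequencies are w_0 = 1 and w_1 = 100**(-1/2),
# so each position t maps to [cos(w_0*t), cos(w_1*t), sin(w_0*t), sin(w_1*t)].
t = np.array([0.0, 1.0, 2.0])
emb = positional_encoding_static(t, d=4, n=100.0)
print(emb.shape)  # (3, 4)
print(emb[0])     # position 0 encodes to [1., 1., 0., 0.]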
def generate_spherical_jn_basis(
    dim: int,
    rc: float,
    ls: Union[int, Sequence[int]] = [0],
    print_code: bool = False,
    jit: bool = False,
):
    """Generate a normalized spherical Bessel basis with `dim` functions per
    angular momentum in `ls`, vanishing at the cutoff `rc`."""
    from sympy import Symbol, jn, expand_func
    from scipy.special import spherical_jn
    from sympy import jn_zeros
    import scipy.integrate as integrate

    if isinstance(ls, int):
        ls = list(range(ls + 1))
    # index the symbols and arrays by position in `ls` (indexing them by the
    # value of l would break for e.g. ls=[2])
    zl = [Symbol(f"xz[...,{il}]") for il in range(len(ls))]
    zn = np.array([jn_zeros(l, dim) for l in ls], dtype=float).T
    znrc = zn / rc
    norms = np.zeros((dim, len(ls)), dtype=float)
    for il, l in enumerate(ls):
        for i in range(dim):
            norms[i, il] = (
                integrate.quad(lambda x: (spherical_jn(l, x) * x) ** 2, 0, zn[i, il])[0]
                / znrc[i, il] ** 3
            ) ** (-0.5)

    fn_str = f"""def spherical_jn_basis_(x):
    from jax.numpy import cos, sin

    znrc = jnp.array({znrc.tolist()}, dtype=x.dtype)
    norms = jnp.array({norms.tolist()}, dtype=x.dtype)
    xshape = x.shape
    x = x.reshape(-1)
    xz = x[:, None, None] * znrc[None, :, :]

    jns = jnp.stack([
"""
    for il, l in enumerate(ls):
        fn_str += f"        {expand_func(jn(l, zl[il]))},\n"
    fn_str += f"""    ], axis=-1)
    return (norms[None, :, :] * jns).reshape(*xshape, {dim}, {len(ls)})
"""

    if print_code:
        print(fn_str)
    # exec into an explicit namespace (more robust than reading locals())
    namespace = {"jnp": jnp}
    exec(fn_str, namespace)
    jn_basis = namespace["spherical_jn_basis_"]
    if jit:
        jn_basis = jax.jit(jn_basis)
    return jn_basis
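Finally, a small usage sketch for the generated basis (assumed parameter values; sympy and scipy are required at generation time since the basis code is built symbolically).

import numpy as np
from fennol.models.misc.encodings import generate_spherical_jn_basis

# Build an 8-function l=0 spherical Bessel basis with cutoff rc=5.0 and
# evaluate it on a few distances.
jn_basis = generate_spherical_jn_basis(dim=8, rc=5.0, ls=[0])
x = np.linspace(0.5, 4.5, 5)
out = jn_basis(x)
print(out.shape)  # (5, 8, 1): dim basis functions per input, one l channel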