fennol.models.misc.encodings
import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Optional, Union, List, Sequence, ClassVar, Dict
import math
import dataclasses
import numpy as np
from .nets import FullyConnectedNet
from ...utils import AtomicUnits as au
from functools import partial
from ...utils.periodic_table import (
    PERIODIC_TABLE_REV_IDX,
    PERIODIC_TABLE,
    EL_STRUCT,
    VALENCE_STRUCTURE,
    XENONPY_PROPS,
    SJS_COORDINATES,
    PERIODIC_COORDINATES,
    ATOMIC_IONIZATION_ENERGY,
    POLARIZABILITIES,
    D3_COV_RADII,
    VDW_RADII,
    OXIDATION_STATES,
)


class SpeciesEncoding(nn.Module):
    """A module that encodes chemical species information.

    FID: SPECIES_ENCODING
    """

    encoding: str = "random"
    """The encoding to use. Can be one of "one_hot", "occupation", "electronic_structure",
    "valence_structure", "period", "period_onehot", "period_positional", "properties",
    "valence_properties", "sjs_coordinates", "positional", "oxidation", "random",
    "zeros", "randint", "randtri".
    Multiple encodings can be concatenated using the "+" separator.
    """
    dim: int = 16
    """The dimension of the encoding if not fixed by design."""
    zmax: int = 86
    """The maximum atomic number to encode."""
    output_key: Optional[str] = None
    """The key to use for the output in the returned dictionary."""

    species_order: Optional[Union[str, Sequence[str]]] = None
    """The order of the species to use for the encoding. Only used for the "one_hot" encoding.
    If None, we encode all elements up to `zmax`."""
    trainable: bool = False
    """Whether the encoding is trainable or fixed. Does not apply to the "random" encoding, which is always trainable."""
    extra_params: Dict = dataclasses.field(default_factory=dict)
    """Dictionary of extra parameters for the basis."""

    FID: ClassVar[str] = "SPECIES_ENCODING"

    @nn.compact
    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:

        zmax = self.zmax
        if zmax <= 0 or zmax > len(PERIODIC_TABLE):
            zmax = len(PERIODIC_TABLE)

        # two padding rows: index 0 (dummy/ghost atoms) and index zmax + 1
        zmaxpad = zmax + 2

        encoding = self.encoding.lower().strip()
        encodings = encoding.split("+")
        ############################
        conv_tensors = []

        if "one_hot" in encodings or "onehot" in encodings:
            if self.species_order is None:
                conv_tensor = np.eye(zmax)
                conv_tensor = np.concatenate(
                    [np.zeros((1, zmax)), conv_tensor, np.zeros((1, zmax))], axis=0
                )
            else:
                if isinstance(self.species_order, str):
                    species_order = [el.strip() for el in self.species_order.split(",")]
                else:
                    species_order = [el for el in self.species_order]
                conv_tensor = np.zeros((zmaxpad, len(species_order)))
                for i, s in enumerate(species_order):
                    conv_tensor[PERIODIC_TABLE_REV_IDX[s], i] = 1

            conv_tensors.append(conv_tensor)

        if "occupation" in encodings:
            conv_tensor = np.zeros((zmaxpad, (zmax + 1) // 2))
            for i in range(1, zmax + 1):
                conv_tensor[i, : i // 2] = 1
                if i % 2 == 1:
                    conv_tensor[i, i // 2] = 0.5

            conv_tensors.append(conv_tensor)

        if "electronic_structure" in encodings:
            Z = np.arange(1, zmax + 1).reshape(-1, 1)
            Zref = [zmax]
            e_struct = np.array(EL_STRUCT[1 : zmax + 1])
            eref = [2, 2, 6, 2, 6, 2, 10, 6, 2, 10, 6, 2, 14, 10, 6, 2, 14, 10, 6]
            v_struct = np.array(VALENCE_STRUCTURE[1 : zmax + 1])
            vref = [2, 6, 10, 14]
            if zmax <= 86:
                e_struct = e_struct[:, :15]
                eref = eref[:15]
            ref = np.array(Zref + eref + vref)
            conv_tensor = np.concatenate([Z, e_struct, v_struct], axis=1)
            conv_tensor = conv_tensor / ref[None, :]
            dim = conv_tensor.shape[1]
            conv_tensor = np.concatenate(
                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
            )

            conv_tensors.append(conv_tensor)

        if "valence_structure" in encodings:
            vref = np.array([2, 6, 10, 14])
            conv_tensor = np.array(VALENCE_STRUCTURE[1 : zmax + 1]) / vref[None, :]
            dim = conv_tensor.shape[1]
            conv_tensor = np.concatenate(
                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
            )
            conv_tensors.append(conv_tensor)

        if "period_onehot" in encodings:
            periods = PERIODIC_COORDINATES[1 : zmax + 1, 0] - 1
            conv_tensor = np.eye(7)[periods]
            dim = conv_tensor.shape[1]
            conv_tensor = np.concatenate(
                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
            )
            conv_tensors.append(conv_tensor)

        if "period" in encodings:
            conv_tensor = (PERIODIC_COORDINATES[1 : zmax + 1, 0] / 7.0)[:, None]
            dim = conv_tensor.shape[1]
            conv_tensor = np.concatenate(
                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
            )
            conv_tensors.append(conv_tensor)

        if "period_positional" in encodings:
            row = PERIODIC_COORDINATES[1 : zmax + 1, 0]
            drow = self.extra_params.get("drow", default=5)
            nrow = self.extra_params.get("nrow", default=100.0)

            conv_tensor = positional_encoding_static(row, drow, nrow)
            dim = conv_tensor.shape[1]
            conv_tensor = np.concatenate(
                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
            )
            conv_tensors.append(conv_tensor)

        if "properties" in encodings:
            props = np.array(XENONPY_PROPS)[1:-1]
            assert (
                self.zmax <= props.shape[0]
            ), f"zmax > {props.shape[0]} not supported for xenonpy properties"
            conv_tensor = props[1 : zmax + 1]
            mean = np.mean(props, axis=0)
            std = np.std(props, axis=0)
            conv_tensor = (conv_tensor - mean[None, :]) / std[None, :]
            dim = conv_tensor.shape[1]
            conv_tensor = np.concatenate(
                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
            )
            conv_tensors.append(conv_tensor)

        if "valence_properties" in encodings:
            assert zmax <= 86, "Valence properties only available for zmax <= 86"
            Z = np.arange(1, zmax + 1).reshape(-1, 1)
            Zref = [zmax]
            Zinv = 1.0 / Z
            Zinvref = [1.0]
            v_struct = np.array(VALENCE_STRUCTURE[1 : zmax + 1])
            vref = [2, 6, 10, 14]
            ionization = np.array(ATOMIC_IONIZATION_ENERGY[1 : zmax + 1]).reshape(-1, 1)
            ionizationref = [0.5]
            polariz = np.array(POLARIZABILITIES[1 : zmax + 1]).reshape(-1, 1)
            polarizref = [np.median(polariz)]

            cov = np.array(D3_COV_RADII[1 : zmax + 1]).reshape(-1, 1)
            covref = [np.median(cov)]

            vdw = np.array(VDW_RADII[1 : zmax + 1]).reshape(-1, 1)
            vdwref = [np.median(vdw)]

            ref = np.array(
                Zref + Zinvref + ionizationref + polarizref + covref + vdwref + vref
            )
            conv_tensor = np.concatenate(
                [Z, Zinv, ionization, polariz, cov, vdw, v_struct], axis=1
            )

            conv_tensor = conv_tensor / ref[None, :]
            dim = conv_tensor.shape[1]
            conv_tensor = np.concatenate(
                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
            )
            conv_tensors.append(conv_tensor)

        if "sjs_coordinates" in encodings:
            coords = np.array(SJS_COORDINATES)[1:-1]
            conv_tensor = coords[1 : zmax + 1]
            mean = np.mean(coords, axis=0)
            std = np.std(coords, axis=0)
            conv_tensor = (conv_tensor - mean[None, :]) / std[None, :]
            dim = conv_tensor.shape[1]
            conv_tensor = np.concatenate(
                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
            )
            conv_tensors.append(conv_tensor)

        if "positional" in encodings:
            coords = np.array(PERIODIC_COORDINATES)[1 : zmax + 1]
            row, col = coords[:, 0], coords[:, 1]
            drow = self.extra_params.get("drow", default=5)
            dcol = self.dim - drow
            nrow = self.extra_params.get("nrow", default=100.0)
            ncol = self.extra_params.get("ncol", default=1000.0)

            erow = positional_encoding_static(row, drow, nrow)
            ecol = positional_encoding_static(col, dcol, ncol)
            conv_tensor = np.concatenate([erow, ecol], axis=-1)
            dim = conv_tensor.shape[1]
            conv_tensor = np.concatenate(
                [np.zeros((1, dim)), conv_tensor, np.zeros((1, dim))], axis=0
            )
            conv_tensors.append(conv_tensor)

        if "oxidation" in encodings:
            states_set = sorted(set(sum(OXIDATION_STATES, [])))
            nstates = len(states_set)
            state_dict = {s: i for i, s in enumerate(states_set)}
            conv_tensor = np.zeros((zmaxpad, nstates))
            # start=1 so that element Z lands on row Z (row 0 is the padding row)
            for i, states in enumerate(OXIDATION_STATES[1 : zmax + 1], start=1):
                for s in states:
                    conv_tensor[i, state_dict[s]] = 1
            conv_tensors.append(conv_tensor)

        if len(conv_tensors) > 0:
            conv_tensor = np.concatenate(conv_tensors, axis=1)
            if self.trainable:
                conv_tensor = self.param(
                    "conv_tensor",
                    lambda key: jnp.asarray(conv_tensor, dtype=jnp.float32),
                )
            else:
                conv_tensor = jnp.asarray(conv_tensor, dtype=jnp.float32)
            conv_tensors = [conv_tensor]
        else:
            conv_tensors = []

        if "random" in encodings:
            rand_encoding = self.param(
                "rand_encoding",
                lambda key, shape: jax.nn.standardize(
                    jax.random.normal(key, shape, dtype=jnp.float32)
                ),
                (zmaxpad, self.dim),
            )
            conv_tensors.append(rand_encoding)

        if "zeros" in encodings:
            zeros_encoding = self.param(
                "zeros_encoding",
                lambda key, shape: jnp.zeros(shape, dtype=jnp.float32),
                (zmaxpad, self.dim),
            )
            conv_tensors.append(zeros_encoding)

        if "randint" in encodings:
            rand_encoding = self.param(
                "randint_encoding",
                lambda key, shape: jax.random.randint(key, shape, 0, 2).astype(
                    jnp.float32
                ),
                (zmaxpad, self.dim),
            )
            conv_tensors.append(rand_encoding)

        if "randtri" in encodings:
            rand_encoding = self.param(
                "randtri_encoding",
                lambda key, shape: (jax.random.randint(key, shape, 0, 3) - 1).astype(
                    jnp.float32
                ),
                (zmaxpad, self.dim),
            )
            conv_tensors.append(rand_encoding)

        assert len(conv_tensors) > 0, f"No encoding recognized in '{self.encoding}'"

        conv_tensor = jnp.concatenate(conv_tensors, axis=1)

        species = inputs["species"] if isinstance(inputs, dict) else inputs
        out = conv_tensor[species]
        ############################

        if isinstance(inputs, dict):
            output_key = self.name if self.output_key is None else self.output_key
            out = out.astype(inputs["coordinates"].dtype)
            return {**inputs, output_key: out} if output_key is not None else out
        return out


class RadialBasis(nn.Module):
    """Computes a radial encoding of distances.

    FID: RADIAL_BASIS
    """

    end: float
    """The maximum distance to consider."""
    start: float = 0.0
    """The minimum distance to consider."""
    dim: int = 8
    """The dimension of the basis."""
    graph_key: Optional[str] = None
    """The key of the graph in the inputs."""
    output_key: Optional[str] = None
    """The key to use for the output in the returned dictionary."""
    basis: str = "bessel"
    """The basis to use. Can be one of "bessel", "gaussian", "gaussian_rinv",
    "fourier", "spooky", "levels", "finite_support", "exp" ("exp_lr"),
    "neural_net" ("nn"), "damped_coulomb", or "spherical_bessel_{l}"."""
    trainable: bool = False
    """Whether the basis parameters are trainable or fixed."""
    enforce_positive: bool = False
    """Whether to enforce that distance - start is positive."""
    gamma: float = 1.0 / (2 * au.BOHR)
    """The gamma parameter for the "spooky" basis."""
    n_levels: int = 10
    """The number of levels for the "levels" basis."""
    alt_bessel_norm: bool = False
    """If True, use the (2/(end-start))**0.5 normalization for the bessel basis."""
    extra_params: Dict = dataclasses.field(default_factory=dict)
    """Dictionary of extra parameters for the basis."""

    FID: ClassVar[str] = "RADIAL_BASIS"

    @nn.compact
    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
        if self.graph_key is not None:
            x = inputs[self.graph_key]["distances"]
        else:
            x = inputs["distances"] if isinstance(inputs, dict) else inputs

        shape = x.shape
        x = x.reshape(-1)

        basis = self.basis.lower()
        ############################
        if basis == "bessel":
            c = self.end - self.start
            x = x[:, None] - self.start

            if self.trainable:
                bessel_roots = self.param(
                    "bessel_roots",
                    lambda key, dim: jnp.asarray(
                        np.arange(1, dim + 1, dtype=x.dtype)[None, :] * (math.pi / c)
                    ),
                    self.dim,
                )
                norm = 1.0 / jnp.max(bessel_roots)
            else:
                bessel_roots = jnp.asarray(
                    np.arange(1, self.dim + 1, dtype=x.dtype)[None, :] * (math.pi / c)
                )
                norm = 1.0 / (self.dim * math.pi / c)
            if self.alt_bessel_norm:
                norm = (2.0 / c) ** 0.5
            out = norm * jnp.sin(x * bessel_roots) / x

            if self.enforce_positive:
                out = jnp.where(x > 0, out * (1.0 - jnp.exp(-(x**2))), 0.0)

        elif basis == "gaussian":
            if self.trainable:
                roots = self.param(
                    "radial_centers",
                    lambda key, dim, start, end: jnp.linspace(
                        start, end, dim + 1, dtype=x.dtype
                    )[None, :-1],
                    self.dim,
                    self.start,
                    self.end,
                )
                eta = self.param(
                    "radial_etas",
                    lambda key, dim, start, end: jnp.full(
                        dim,
                        dim / (end - start),
                        dtype=x.dtype,
                    )[None, :],
                    self.dim,
                    self.start,
                    self.end,
                )
            else:
                roots = jnp.asarray(
                    np.linspace(self.start, self.end, self.dim + 1)[None, :-1],
                    dtype=x.dtype,
                )
                eta = jnp.asarray(
                    np.full(self.dim, self.dim / (self.end - self.start))[None, :],
                    dtype=x.dtype,
                )

            x = x[:, None]
            x2 = (eta * (x - roots)) ** 2
            out = jnp.exp(-x2)
            if self.enforce_positive:
                out = jnp.where(
                    x > self.start,
                    out * (1.0 - jnp.exp(-10 * (x - self.start) ** 2)),
                    0.0,
                )

        elif basis == "gaussian_rinv":
            rinv_high = 1.0 / max(self.start, 0.1)
            rinv_low = 1.0 / (0.8 * self.end)

            if self.trainable:
                roots = self.param(
                    "radial_centers",
                    lambda key, dim, rinv_low, rinv_high: jnp.linspace(
                        rinv_low, rinv_high, dim, dtype=x.dtype
                    )[None, :],
                    self.dim,
                    rinv_low,
                    rinv_high,
                )
                sigmas = self.param(
                    "radial_sigmas",
                    lambda key, dim, rinv_low, rinv_high: jnp.full(
                        dim,
                        2**0.5 / (2 * dim * rinv_low),
                        dtype=x.dtype,
                    )[None, :],
                    self.dim,
                    rinv_low,
                    rinv_high,
                )
            else:
                roots = jnp.asarray(
                    np.linspace(rinv_low, rinv_high, self.dim, dtype=x.dtype)[None, :]
                )
                sigmas = jnp.asarray(
                    np.full(
                        self.dim,
                        2**0.5 / (2 * self.dim * rinv_low),
                    )[None, :],
                    dtype=x.dtype,
                )

            rinv = 1.0 / x[:, None]

            out = jnp.exp(-((rinv - roots) ** 2) / sigmas**2)

        elif basis == "fourier":
            if self.trainable:
                roots = self.param(
                    "roots",
                    lambda key, dim: jnp.arange(dim, dtype=x.dtype)[None, :] * math.pi,
                    self.dim,
                )
            else:
                roots = jnp.asarray(
                    np.arange(self.dim)[None, :] * math.pi, dtype=x.dtype
                )
            c = self.end - self.start
            x = x[:, None] - self.start
            norm = 1.0 / (0.25 + 0.5 * self.dim) ** 0.5
            out = norm * jnp.cos(x * roots / c)
            if self.enforce_positive:
                out = jnp.where(x > 0, out, norm)

        elif basis == "spooky":
            gamma = self.gamma
            if self.trainable:
                gamma = jnp.abs(
                    self.param("gamma", lambda key: jnp.asarray(gamma, dtype=x.dtype))
                )

            if self.enforce_positive:
                x = jnp.where(x - self.start > 1.0e-3, x - self.start, 1.0e-3)[:, None]
                dim = self.dim
            else:
                x = x[:, None] - self.start
                dim = self.dim - 1

            norms = []
            for k in range(self.dim):
                norms.append(math.comb(dim, k))
            norms = jnp.asarray(np.array(norms)[None, :], dtype=x.dtype)

            e = jnp.exp(-gamma * x)
            k = jnp.asarray(np.arange(self.dim, dtype=x.dtype)[None, :])
            b = e**k * (1 - e) ** (dim - k)
            out = b * e * norms
            if self.enforce_positive:
                out = jnp.where(x > 1.0e-3, out * (1.0 - jnp.exp(-(x**2))), 0.0)

        elif basis == "levels":
            assert self.n_levels >= 2, "Number of levels must be >= 2."

            def initialize_levels(key):
                key0, key1, key_phi = jax.random.split(key, 3)
                level0 = jax.random.randint(key0, (self.dim,), 0, 2)
                level1 = jax.random.randint(key1, (self.dim,), 0, 2)
                phi = jax.random.uniform(key_phi, (self.dim,), dtype=jnp.float32)
                levels = [level0]
                for l in range(2, self.n_levels - 1):
                    tau = float(self.n_levels - l) / float(self.n_levels - 1)
                    phil = phi < tau
                    level = jnp.where(phil, level0, level1)
                    levels.append(level)
                levels.append(level1)
                return jnp.stack(levels).astype(jnp.float32)

            levels = self.param("levels", initialize_levels)

            flevel = (x - self.start) / (self.end - self.start) * (self.n_levels - 1)
            ilevel = jnp.floor(flevel).astype(jnp.int32)
            ilevel1 = jnp.clip(ilevel + 1, 0, self.n_levels - 1)
            ilevel = jnp.clip(ilevel, 0, self.n_levels - 1)

            dx = flevel - ilevel
            w = 0.5 * (1 + jnp.cos(jnp.pi * dx))[:, None]

            ## interpolate between level vectors
            v1 = levels[ilevel]
            v2 = levels[ilevel1]
            out = v1 * w + v2 * (1 - w)

        elif basis == "finite_support":
            flevel = (x - self.start) / (self.end - self.start) * (self.dim + 1)
            ilevel = jnp.floor(flevel).astype(jnp.int32)
            ilevel1 = jnp.clip(ilevel + 1, 0, self.dim + 1)
            ilevel = jnp.clip(ilevel, 0, self.dim + 1)

            dx = flevel - ilevel
            w = 0.5 * (1 + jnp.cos(jnp.pi * dx))

            ilevelflat = ilevel + jnp.arange(x.shape[0]) * (self.dim + 2)
            ilevel1flat = ilevel1 + jnp.arange(x.shape[0]) * (self.dim + 2)

            out = (
                jnp.zeros((x.shape[0] * (self.dim + 2)), dtype=x.dtype)
                .at[ilevelflat]
                .set(w)
                .at[ilevel1flat]
                .set(1 - w)
                .reshape(-1, self.dim + 2)[:, 1:-1]
            )

        elif basis == "exp_lr" or basis == "exp":
            zeta = self.extra_params.get("zeta", default=2.0)
            s = self.extra_params.get("s", default=0.5)
            n = np.arange(self.dim)

            a = zeta * s**n
            xx = np.linspace(0, self.end, 10000)

            if self.start > 0:
                a1 = np.minimum(1.0 / a, 1.0)
                switchstart = jnp.where(
                    x[:, None] < self.end,
                    1
                    - (
                        0.5
                        + 0.5
                        * jnp.cos(
                            np.pi * (x[:, None] - self.start) / (self.end - self.start)
                        )
                    )
                    ** a1[None, :],
                    1,
                ) * (x[:, None] > self.start)
                switchstartxx = (
                    1
                    - (
                        0.5
                        + 0.5
                        * np.cos(
                            np.pi * (xx[:, None] - self.start) / (self.end - self.start)
                        )
                    )
                    ** a1[None, :]
                ) * (xx[:, None] > self.start)
            else:
                switchstart = 1.0
                switchstartxx = 1.0

            norm = 1.0 / np.trapz(
                switchstartxx * np.exp(-a[None, :] * xx[:, None]), xx, axis=0
            )

            if self.trainable:
                a = jnp.abs(self.param("exponent", lambda key: jnp.asarray(a)))

            out = switchstart * jnp.exp(-a[None, :] * x[:, None]) * norm[None, :]

        elif basis == "neural_net" or basis == "nn":
            neurons = [
                *self.extra_params.get("hidden_neurons", default=[2 * self.dim]),
                self.dim,
            ]
            activation = self.extra_params.get("activation", default="swish")
            use_bias = self.extra_params.get("use_bias", default=True)

            out = FullyConnectedNet(
                neurons, activation=activation, use_bias=use_bias, squeeze=False
            )(x[:, None])

        elif basis == "damped_coulomb":
            l = np.arange(self.dim)[None, :]
            a = self.extra_params.get("a", default=1.0)
            if self.trainable:
                a = self.param("a", lambda key: jnp.asarray([a] * self.dim)[None, :])
            a2 = a**2
            x = x[:, None] - self.start
            end = self.end - self.start
            x21 = a2 + x**2
            R21 = a2 + end**2
            out = (
                1.0 / x21 ** (0.5 * (l + 1))
                - 1.0 / (R21 ** (0.5 * (l + 1)))
                + (x - end) * ((l + 1) * end / (R21 ** (0.5 * (l + 3))))
            ) * (x < end)

        elif basis.startswith("spherical_bessel_"):
            l = int(basis.split("_")[-1])
            out = generate_spherical_jn_basis(self.dim, self.end, [l])(x)
        else:
            raise NotImplementedError(f"Unknown radial basis {basis}.")
        ############################

        out = out.reshape((*shape, self.dim))

        if self.graph_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: out}
        return out


def positional_encoding_static(t, d: int, n: float = 10000.0):
    if d % 2 == 0:
        k = np.arange(d // 2)
    else:
        k = np.arange((d + 1) // 2)
    wk = np.asarray(1.0 / (n ** (2 * k / d)))
    wkt = wk[None, :] * t[:, None]
    out = np.concatenate([np.cos(wkt), np.sin(wkt)], axis=-1)
    if d % 2 == 1:
        out = out[:, :-1]
    return out


@partial(jax.jit, static_argnums=(1, 2), inline=True)
def positional_encoding(t, d: int, n: float = 10000.0):
    if d % 2 == 0:
        k = np.arange(d // 2)
    else:
        k = np.arange((d + 1) // 2)
    wk = jnp.asarray(1.0 / (n ** (2 * k / d)))
    wkt = wk[None, :] * t[:, None]
    out = jnp.concatenate([jnp.cos(wkt), jnp.sin(wkt)], axis=-1)
    if d % 2 == 1:
        out = out[:, :-1]
    return out


def generate_spherical_jn_basis(
    dim: int,
    rc: float,
    ls: Union[int, Sequence[int]] = [0],
    print_code: bool = False,
    jit: bool = False,
):
    from sympy import Symbol, jn, expand_func, jn_zeros
    from scipy.special import spherical_jn
    import scipy.integrate as integrate

    if isinstance(ls, int):
        ls = list(range(ls + 1))
    # symbols index the last axis of xz by *position* in ls (not by l itself)
    zl = [Symbol(f"xz[...,{il}]") for il in range(len(ls))]
    zn = np.array([jn_zeros(l, dim) for l in ls], dtype=float).T
    znrc = zn / rc
    norms = np.zeros((dim, len(ls)), dtype=float)
    for il, l in enumerate(ls):
        for i in range(dim):
            norms[i, il] = (
                integrate.quad(lambda x: (spherical_jn(l, x) * x) ** 2, 0, zn[i, il])[0]
                / znrc[i, il] ** 3
            ) ** (-0.5)

    fn_str = f"""def spherical_jn_basis_(x):
    from jax.numpy import cos, sin

    znrc = jnp.array({znrc.tolist()}, dtype=x.dtype)
    norms = jnp.array({norms.tolist()}, dtype=x.dtype)
    xshape = x.shape
    x = x.reshape(-1)
    xz = x[:, None, None] * znrc[None, :, :]

    jns = jnp.stack([
"""
    for il, l in enumerate(ls):
        fn_str += f"        {expand_func(jn(l, zl[il]))},\n"
    fn_str += f"""    ], axis=-1)
    return (norms[None, :, :] * jns).reshape(*xshape, {dim}, {len(ls)})
"""

    if print_code:
        print(fn_str)
    # exec into an explicit namespace so the generated function can be retrieved
    # reliably (function-local exec does not update locals() on all Python versions)
    namespace = {"jnp": jnp}
    exec(fn_str, namespace)
    jn_basis = namespace["spherical_jn_basis_"]
    if jit:
        jn_basis = jax.jit(jn_basis)
    return jn_basis
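The modules above share a dict-passing convention: when called with a dictionary, each module reads what it needs and returns a copy of the dictionary extended with its output under output_key (or the module name). A minimal chaining sketch, not taken from the FeNNol documentation: the "graph" entry below only mimics the {"distances": ...} layout that RadialBasis reads via graph_key; the rest of that schema is assumed for illustration.

import jax
import jax.numpy as jnp

inputs = {
    "species": jnp.array([8, 1, 1]),          # atomic numbers: O, H, H
    "coordinates": jnp.zeros((3, 3)),
    "graph": {"distances": jnp.array([0.96, 0.96, 1.52])},
}

senc = SpeciesEncoding(encoding="one_hot", output_key="zi")
rbas = RadialBasis(end=5.0, dim=8, graph_key="graph", output_key="rij")

params1 = senc.init(jax.random.PRNGKey(0), inputs)
out = senc.apply(params1, inputs)                 # adds "zi"
params2 = rbas.init(jax.random.PRNGKey(1), out)
out = rbas.apply(params2, out)                    # adds "rij"
print(out["zi"].shape, out["rij"].shape)          # (3, 86) (3, 8)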
class SpeciesEncoding(nn.Module)
A module that encodes chemical species information.

FID: SPECIES_ENCODING

encoding: The encoding to use. Can be one of "one_hot", "occupation", "electronic_structure", "valence_structure", "period", "period_onehot", "period_positional", "properties", "valence_properties", "sjs_coordinates", "positional", "oxidation", "random", "zeros", "randint", "randtri". Multiple encodings can be concatenated using the "+" separator.
species_order: The order of the species to use for the encoding. Only used for the "one_hot" encoding. If None, we encode all elements up to zmax.
trainable: Whether the encoding is trainable or fixed. Does not apply to the "random" encoding, which is always trainable.
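A minimal usage sketch (not part of the library documentation; the input keys follow the conventions visible in __call__ above): a fixed one-hot block concatenated with a trainable random block, applied to a small water fragment.

import jax
import jax.numpy as jnp

enc = SpeciesEncoding(encoding="one_hot+random", dim=16, zmax=86,
                      output_key="species_enc")

inputs = {
    "species": jnp.array([8, 1, 1]),    # atomic numbers: O, H, H
    "coordinates": jnp.zeros((3, 3)),   # only used here to pick the output dtype
}
params = enc.init(jax.random.PRNGKey(0), inputs)  # only the random block is learned
out = enc.apply(params, inputs)
print(out["species_enc"].shape)  # (3, 102): 86 one-hot columns + 16 random columns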
class RadialBasis(nn.Module)
Computes a radial encoding of distances.

FID: RADIAL_BASIS

basis: The basis to use. Can be one of "bessel", "gaussian", "gaussian_rinv", "fourier", "spooky", "levels", "finite_support", "exp" ("exp_lr"), "neural_net" ("nn"), "damped_coulomb", or "spherical_bessel_{l}".
alt_bessel_norm: If True, use the (2/(end-start))**0.5 normalization for the bessel basis.
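A usage sketch under the same conventions (here distances are passed directly as an array; with graph_key set, the module instead reads inputs[graph_key]["distances"]):

import jax
import jax.numpy as jnp

rb = RadialBasis(end=5.0, dim=8, basis="bessel")

r = jnp.linspace(0.1, 5.0, 50)              # pair distances, same units as end
params = rb.init(jax.random.PRNGKey(0), r)  # empty unless trainable=True
phi = rb.apply(params, r)
print(phi.shape)  # (50, 8): one 8-dimensional radial feature vector per distance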
def positional_encoding_static(t, d: int, n: float = 10000.0)
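A worked example (illustrative): for odd d, cos/sin pairs are built for ceil(d/2) frequencies and the last column is dropped, so the output always has exactly d columns.

import numpy as np

t = np.array([1, 2, 3])                            # e.g. period numbers
out = positional_encoding_static(t, d=5, n=100.0)
print(out.shape)  # (3, 5)
# columns 0..2 hold cos(t / 100**(2k/5)) for k = 0, 1, 2;
# columns 3..4 hold the matching sin terms (the k = 2 sin column is dropped)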
def positional_encoding(t, d: int, n: float = 10000.0)
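The two variants differ only in their array backend: positional_encoding_static runs in NumPy so that SpeciesEncoding can bake fixed lookup tables at module-construction time, while positional_encoding is a jitted jax.numpy version for values only known at trace time; d and n are marked static (static_argnums=(1, 2)) so the frequency table has a fixed shape under jit.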
def generate_spherical_jn_basis(dim: int, rc: float, ls: Union[int, Sequence[int]] = [0], print_code: bool = False, jit: bool = False)
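A usage sketch (illustrative): generate an l=0 spherical Bessel basis of dimension 8 with cutoff rc=5.0 and evaluate it on a batch of distances; passing print_code=True prints the sympy-generated source before it is exec'd.

import jax.numpy as jnp

jn_basis = generate_spherical_jn_basis(dim=8, rc=5.0, ls=[0])
r = jnp.linspace(0.1, 5.0, 50)
phi = jn_basis(r)
print(phi.shape)  # (50, 8, 1): 8 basis functions for the single requested l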