fennol.models.preprocessing
import flax.linen as nn
from typing import Sequence, Callable, Union, Dict, Any, ClassVar
import jax.numpy as jnp
import jax
import numpy as np
from typing import Optional, Tuple
import numba
import dataclasses
from functools import partial

from flax.core.frozen_dict import FrozenDict


from ..utils.activations import chain,safe_sqrt
from ..utils import deep_update, mask_filter_1d
from ..utils.kspace import get_reciprocal_space_parameters
from .misc.misc import SwitchFunction
from ..utils.periodic_table import PERIODIC_TABLE, PERIODIC_TABLE_REV_IDX,CHEMICAL_BLOCKS,CHEMICAL_BLOCKS_NAMES


@dataclasses.dataclass(frozen=True)
class GraphGenerator:
    """Generate a graph from a set of coordinates

    FPID: GRAPH

    For now, we generate all pairs of atoms and filter based on cutoff.
    If a `nblist_skin` is present in the state, we generate a second graph with a larger cutoff that includes all pairs within the cutoff+skin. This graph is then reused by the `update_skin` method to update the original graph without recomputing the full nblist.
    """

    cutoff: float
    """Cutoff distance for the graph."""
    graph_key: str = "graph"
    """Key of the graph in the outputs."""
    switch_params: dict = dataclasses.field(default_factory=dict, hash=False)
    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""
    kmax: int = 30
    """Maximum number of k-points to consider."""
    kthr: float = 1e-6
    """Threshold for k-point filtering."""
    k_space: bool = False
    """Whether to generate k-space information for the graph."""
    mult_size: float = 1.05
    """Multiplicative factor for resizing the nblist."""
    # covalent_cutoff: bool = False

    FPID: ClassVar[str] = "GRAPH"

    def init(self):
        """Return the initial state: placeholder buffer sizes and the resize factor."""
        return FrozenDict(
            {
                "max_nat": 1,
                "npairs": 1,
                "nblist_mult_size": self.mult_size,
            }
        )

    def get_processor(self) -> Tuple[nn.Module, Dict]:
        """Return the flax module class (and its kwargs) that post-processes this graph."""
        return GraphProcessor, {
            "cutoff": self.cutoff,
            "graph_key": self.graph_key,
            "switch_params": self.switch_params,
            "name": f"{self.graph_key}_Processor",
        }

    def get_graph_properties(self):
        """Describe the produced graph (cutoff and directedness) under its key."""
        return {
            self.graph_key: {
                "cutoff": self.cutoff,
                "directed": True,
            }
        }

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """build a nblist on cpu with numpy and dynamic shapes + store max shapes

        All atom pairs i < j of each system are enumerated and kept when their
        squared distance is below ``(cutoff + state["nblist_skin"])**2``. The
        edge arrays are padded to a size stored in the state (index ``nat`` =
        number of atoms for padded edges) so that `process` can reuse fixed
        shapes, and the graph is symmetrized (each pair stored as (i, j) and
        (j, i)).

        Args:
            state: mapping of persistent sizes/options. Keys read here include
                "nblist_mult_size", "max_nat", "add_atoms", "nblist_skin",
                "minimum_image", "num_repeats_pbc", "nshifts_pbc", "nvec_pbc",
                "npairs" and "npairs_skin".
            inputs: dict with "coordinates", "natoms", "batch_index" and, for
                periodic systems, "cells" and "reciprocal_cells" (optionally
                "true_sys").
            return_state_update: if True, also return ``state_up``.
            add_margin: if True, enlarge the pair/shift buffers by
                ``mult_size`` even when the current size is sufficient.

        Returns:
            ``(new_state, output)``, or ``(new_state, output, state_up)`` when
            ``return_state_update`` is True. ``output`` is ``inputs`` with
            ``inputs[self.graph_key]`` replaced by the new graph and
            ``state_up`` maps each resized key to ``(new_value, previous_value)``.
            If the existing graph contains "keep_graph", ``state`` and
            ``inputs`` are returned unchanged (with an empty ``state_up`` when
            requested).
        """
        if self.graph_key in inputs:
            graph = inputs[self.graph_key]
            if "keep_graph" in graph:
                # Honour the requested return arity: `check_reallocate` unpacks
                # three values when it calls this method.
                if return_state_update:
                    return state, inputs, {}
                return state, inputs

        coords = np.array(inputs["coordinates"], dtype=np.float32)
        natoms = np.array(inputs["natoms"], dtype=np.int32)
        batch_index = np.array(inputs["batch_index"], dtype=np.int32)

        new_state = {**state}
        state_up = {}

        mult_size = state.get("nblist_mult_size", self.mult_size)
        assert mult_size >= 1.0, "mult_size should be larger or equal than 1.0"

        if natoms.shape[0] == 1:
            max_nat = coords.shape[0]
            true_max_nat = max_nat
        else:
            max_nat = state.get("max_nat", round(coords.shape[0] / natoms.shape[0]))
            true_max_nat = int(np.max(natoms))
            if true_max_nat > max_nat:
                add_atoms = state.get("add_atoms", 0)
                new_maxnat = true_max_nat + add_atoms
                state_up["max_nat"] = (new_maxnat, max_nat)
                new_state["max_nat"] = new_maxnat

        cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0)

        ### compute indices of all pairs
        p1, p2 = np.triu_indices(true_max_nat, 1)
        p1, p2 = p1.astype(np.int32), p2.astype(np.int32)
        pbc_shifts = None
        if natoms.shape[0] > 1:
            ## batching => mask irrelevant pairs
            # the same (i, j) template is replicated for every system and
            # offset by the index of the system's first atom
            mask_p12 = (
                (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None])
            ).flatten()
            shift = np.concatenate(
                (np.array([0], dtype=np.int32), np.cumsum(natoms[:-1], dtype=np.int32))
            )
            p1 = np.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1)
            p2 = np.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1)

        apply_pbc = "cells" in inputs
        if not apply_pbc:
            ### NO PBC
            vec = coords[p2] - coords[p1]
        else:
            cells = np.array(inputs["cells"], dtype=np.float32)
            reciprocal_cells = np.array(inputs["reciprocal_cells"], dtype=np.float32)
            minimage = state.get("minimum_image", True)
            if minimage:
                ## MINIMUM IMAGE CONVENTION
                vec = coords[p2] - coords[p1]
                if cells.shape[0] == 1:
                    vecpbc = np.dot(vec, reciprocal_cells[0])
                    pbc_shifts = -np.round(vecpbc)
                    vec = vec + np.dot(pbc_shifts, cells[0])
                else:
                    batch_index_vec = batch_index[p1]
                    vecpbc = np.einsum(
                        "aj,aji->ai", vec, reciprocal_cells[batch_index_vec]
                    )
                    pbc_shifts = -np.round(vecpbc)
                    vec = vec + np.einsum(
                        "aj,aji->ai", pbc_shifts, cells[batch_index_vec]
                    )
            else:
                ### GENERAL PBC
                ## put all atoms in central box
                if cells.shape[0] == 1:
                    coords_pbc = np.dot(coords, reciprocal_cells[0])
                    at_shifts = -np.floor(coords_pbc)
                    coords_pbc = coords + np.dot(at_shifts, cells[0])
                else:
                    coords_pbc = np.einsum(
                        "aj,aji->ai", coords, reciprocal_cells[batch_index]
                    )
                    at_shifts = -np.floor(coords_pbc)
                    coords_pbc = coords + np.einsum(
                        "aj,aji->ai", at_shifts, cells[batch_index]
                    )
                vec = coords_pbc[p2] - coords_pbc[p1]

                ## compute maximum number of repeats
                inv_distances = (np.sum(reciprocal_cells**2, axis=1)) ** 0.5
                cdinv = cutoff_skin * inv_distances
                num_repeats_all = np.ceil(cdinv).astype(np.int32)
                if "true_sys" in inputs:
                    num_repeats_all = np.where(np.array(inputs["true_sys"],dtype=bool)[:, None], num_repeats_all, 0)
                # num_repeats_all = np.where(cdinv < 0.5, 0, num_repeats_all)
                num_repeats = np.max(num_repeats_all, axis=0)
                num_repeats_prev = np.array(state.get("num_repeats_pbc", (0, 0, 0)))
                if np.any(num_repeats > num_repeats_prev):
                    num_repeats_new = np.maximum(num_repeats, num_repeats_prev)
                    state_up["num_repeats_pbc"] = (
                        tuple(num_repeats_new),
                        tuple(num_repeats_prev),
                    )
                    new_state["num_repeats_pbc"] = tuple(num_repeats_new)
                ## build all possible shifts
                cell_shift_pbc = np.array(
                    np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]),
                    dtype=cells.dtype,
                ).T.reshape(-1, 3)
                ## shift applied to vectors
                if cells.shape[0] == 1:
                    # single cell: every pair is replicated for every shift
                    dvec = np.dot(cell_shift_pbc, cells[0])[None, :, :]
                    vec = (vec[:, None, :] + dvec).reshape(-1, 3)
                    pbc_shifts = np.broadcast_to(
                        cell_shift_pbc[None, :, :],
                        (p1.shape[0], cell_shift_pbc.shape[0], 3),
                    ).reshape(-1, 3)
                    p1 = np.broadcast_to(
                        p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
                    p2 = np.broadcast_to(
                        p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
                    if natoms.shape[0] > 1:
                        mask_p12 = np.broadcast_to(
                            mask_p12[:, None],
                            (mask_p12.shape[0], cell_shift_pbc.shape[0]),
                        ).flatten()
                else:
                    # one cell per system: each system only uses the shifts
                    # within its own num_repeats_all
                    dvec = np.einsum("bj,sji->sbi", cell_shift_pbc, cells)

                    ## get pbc shifts specific to each box
                    cell_shift_pbc = np.broadcast_to(
                        cell_shift_pbc[None, :, :],
                        (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3),
                    )
                    mask = np.all(
                        np.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1
                    ).flatten()
                    idx = np.nonzero(mask)[0]
                    nshifts = idx.shape[0]
                    nshifts_prev = state.get("nshifts_pbc", 0)
                    if nshifts > nshifts_prev or add_margin:
                        nshifts_new = int(mult_size * max(nshifts, nshifts_prev)) + 1
                        state_up["nshifts_pbc"] = (nshifts_new, nshifts_prev)
                        new_state["nshifts_pbc"] = nshifts_new

                    dvec_filter = dvec.reshape(-1, 3)[idx, :]
                    cell_shift_pbc_filter = cell_shift_pbc.reshape(-1, 3)[idx, :]

                    ## get batch shift in the dvec_filter array
                    # nrep[s] shifts are kept for system s; bshift[s] is where
                    # they start in dvec_filter
                    nrep = np.prod(2 * num_repeats_all + 1, axis=1)
                    bshift = np.concatenate((np.array([0]), np.cumsum(nrep)[:-1]))

                    ## compute vectors
                    batch_index_vec = batch_index[p1]
                    nrep_vec = np.where(mask_p12,nrep[batch_index_vec],0)
                    vec = vec.repeat(nrep_vec, axis=0)
                    nvec_pbc = nrep_vec.sum() #vec.shape[0]
                    nvec_pbc_prev = state.get("nvec_pbc", 0)
                    if nvec_pbc > nvec_pbc_prev or add_margin:
                        nvec_pbc_new = int(mult_size * max(nvec_pbc, nvec_pbc_prev)) + 1
                        state_up["nvec_pbc"] = (nvec_pbc_new, nvec_pbc_prev)
                        new_state["nvec_pbc"] = nvec_pbc_new

                    # print("cpu: ", nvec_pbc, nvec_pbc_prev, nshifts, nshifts_prev)
                    ## get shift index
                    # position of each repeated vector within its own group of
                    # repeats, offset to the system's block of shifts
                    dshift = np.concatenate(
                        (np.array([0]), np.cumsum(nrep_vec)[:-1])
                    ).repeat(nrep_vec)
                    # ishift = np.arange(dshift.shape[0])-dshift
                    # bshift_vec_rep = bshift[batch_index_vec].repeat(nrep_vec)
                    icellshift = (
                        np.arange(dshift.shape[0])
                        - dshift
                        + bshift[batch_index_vec].repeat(nrep_vec)
                    )
                    # shift vectors
                    vec = vec + dvec_filter[icellshift]
                    pbc_shifts = cell_shift_pbc_filter[icellshift]

                    p1 = np.repeat(p1, nrep_vec)
                    p2 = np.repeat(p2, nrep_vec)
                    if natoms.shape[0] > 1:
                        mask_p12 = np.repeat(mask_p12, nrep_vec)

        ## compute distances
        d12 = (vec**2).sum(axis=-1)
        if natoms.shape[0] > 1:
            # masked (non-existent) pairs are pushed exactly to the cutoff so
            # that the strict '<' test below rejects them
            d12 = np.where(mask_p12, d12, cutoff_skin**2)

        ## filter pairs
        max_pairs = state.get("npairs", 1)
        mask = d12 < cutoff_skin**2
        idx = np.nonzero(mask)[0]
        npairs = idx.shape[0]
        if npairs > max_pairs or add_margin:
            prev_max_pairs = max_pairs
            max_pairs = int(mult_size * max(npairs, max_pairs)) + 1
            state_up["npairs"] = (max_pairs, prev_max_pairs)
            new_state["npairs"] = max_pairs

        # padded edges point to index nat (one past the last atom)
        nat = coords.shape[0]
        edge_src = np.full(max_pairs, nat, dtype=np.int32)
        edge_dst = np.full(max_pairs, nat, dtype=np.int32)
        d12_ = np.full(max_pairs, cutoff_skin**2)
        edge_src[:npairs] = p1[idx]
        edge_dst[:npairs] = p2[idx]
        d12_[:npairs] = d12[idx]
        d12 = d12_

        if apply_pbc:
            pbc_shifts_ = np.zeros((max_pairs, 3))
            pbc_shifts_[:npairs] = pbc_shifts[idx]
            pbc_shifts = pbc_shifts_
            if not minimage:
                # express shifts relative to the original (un-wrapped) coords
                pbc_shifts[:npairs] = (
                    pbc_shifts[:npairs]
                    + at_shifts[edge_dst[:npairs]]
                    - at_shifts[edge_src[:npairs]]
                )

        if "nblist_skin" in state:
            # keep the cutoff+skin graph as "*_skin" and filter it down to
            # the true cutoff for the main graph
            edge_src_skin = edge_src
            edge_dst_skin = edge_dst
            if apply_pbc:
                pbc_shifts_skin = pbc_shifts
            max_pairs_skin = state.get("npairs_skin", 1)
            mask = d12 < self.cutoff**2
            idx = np.nonzero(mask)[0]
            npairs_skin = idx.shape[0]
            if npairs_skin > max_pairs_skin or add_margin:
                prev_max_pairs_skin = max_pairs_skin
                max_pairs_skin = int(mult_size * max(npairs_skin, max_pairs_skin)) + 1
                state_up["npairs_skin"] = (max_pairs_skin, prev_max_pairs_skin)
                new_state["npairs_skin"] = max_pairs_skin
            edge_src = np.full(max_pairs_skin, nat, dtype=np.int32)
            edge_dst = np.full(max_pairs_skin, nat, dtype=np.int32)
            d12_ = np.full(max_pairs_skin, self.cutoff**2)
            edge_src[:npairs_skin] = edge_src_skin[idx]
            edge_dst[:npairs_skin] = edge_dst_skin[idx]
            d12_[:npairs_skin] = d12[idx]
            d12 = d12_
            if apply_pbc:
                pbc_shifts = np.full((max_pairs_skin, 3), 0.0)
                pbc_shifts[:npairs_skin] = pbc_shifts_skin[idx]

        ## symmetrize
        edge_src, edge_dst = np.concatenate((edge_src, edge_dst)), np.concatenate(
            (edge_dst, edge_src)
        )
        d12 = np.concatenate((d12, d12))
        if apply_pbc:
            pbc_shifts = np.concatenate((pbc_shifts, -pbc_shifts))

        graph = inputs.get(self.graph_key, {})
        graph_out = {
            **graph,
            "edge_src": edge_src,
            "edge_dst": edge_dst,
            "d12": d12,
            "overflow": False,
            "pbc_shifts": pbc_shifts,
        }
        if "nblist_skin" in state:
            graph_out["edge_src_skin"] = edge_src_skin
            graph_out["edge_dst_skin"] = edge_dst_skin
            if apply_pbc:
                graph_out["pbc_shifts_skin"] = pbc_shifts_skin

        if self.k_space and apply_pbc:
            if "k_points" not in graph:
                ks, _, _, bewald = get_reciprocal_space_parameters(
                    reciprocal_cells, self.cutoff, self.kmax, self.kthr
                )
                graph_out["k_points"] = ks
                graph_out["b_ewald"] = bewald

        output = {**inputs, self.graph_key: graph_out}

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary

        Returns:
            ``(state, state_up, inputs, reallocated)``. When neither this graph
            nor the parent overflowed, the inputs are returned untouched with
            an empty ``state_up`` and ``reallocated=False``. Otherwise the
            graph is rebuilt on CPU, with extra margin only if this graph
            itself overflowed.
        """
        overflow = parent_overflow or inputs[self.graph_key].get("overflow", False)
        if not overflow:
            return state, {}, inputs, False

        add_margin = inputs[self.graph_key].get("overflow", False)
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """build a nblist on accelerator with jax and precomputed shapes

        Jitted counterpart of `__call__`. Buffer sizes are read from the static
        ``state`` instead of being computed. The returned graph carries an
        "overflow" flag that is True when any buffer was too small: pairs,
        skin pairs, periodic repeats/shifts/vectors, or atoms per system.
        """
        if self.graph_key in inputs:
            graph = inputs[self.graph_key]
            if "keep_graph" in graph:
                return inputs
        coords = inputs["coordinates"]
        natoms = inputs["natoms"]
        batch_index = inputs["batch_index"]

        if natoms.shape[0] == 1:
            max_nat = coords.shape[0]
        else:
            max_nat = state.get(
                "max_nat", int(round(coords.shape[0] / natoms.shape[0]))
            )
        cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0)

        ### compute indices of all pairs
        p1, p2 = np.triu_indices(max_nat, 1)
        p1, p2 = p1.astype(np.int32), p2.astype(np.int32)
        pbc_shifts = None
        if natoms.shape[0] > 1:
            ## batching => mask irrelevant pairs
            mask_p12 = (
                (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None])
            ).flatten()
            shift = jnp.concatenate(
                (jnp.array([0], dtype=jnp.int32), jnp.cumsum(natoms[:-1]))
            )
            p1 = jnp.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1)
            p2 = jnp.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1)

        ## compute vectors
        overflow_repeats = jnp.asarray(False, dtype=bool)
        if "cells" not in inputs:
            vec = coords[p2] - coords[p1]
        else:
            cells = inputs["cells"]
            reciprocal_cells = inputs["reciprocal_cells"]
            minimage = state.get("minimum_image", True)

            def compute_pbc(vec, reciprocal_cell, cell, mode="round"):
                """Wrap ``vec`` into the cell; return (wrapped vec, integer shifts)."""
                vecpbc = jnp.dot(vec, reciprocal_cell)
                if mode == "round":
                    pbc_shifts = -jnp.round(vecpbc)
                elif mode == "floor":
                    pbc_shifts = -jnp.floor(vecpbc)
                else:
                    raise NotImplementedError(f"Unknown mode {mode} for compute_pbc.")
                return vec + jnp.dot(pbc_shifts, cell), pbc_shifts

            if minimage:
                ## minimum image convention
                vec = coords[p2] - coords[p1]

                if cells.shape[0] == 1:
                    vec, pbc_shifts = compute_pbc(vec, reciprocal_cells[0], cells[0])
                else:
                    batch_index_vec = batch_index[p1]
                    vec, pbc_shifts = jax.vmap(compute_pbc)(
                        vec, reciprocal_cells[batch_index_vec], cells[batch_index_vec]
                    )
            else:
                ### general PBC only for single cell yet
                # if cells.shape[0] > 1:
                #     raise NotImplementedError(
                #         "General PBC not implemented for batches on accelerator."
                #     )
                # cell = cells[0]
                # reciprocal_cell = reciprocal_cells[0]

                ## put all atoms in central box
                if cells.shape[0] == 1:
                    coords_pbc, at_shifts = compute_pbc(
                        coords, reciprocal_cells[0], cells[0], mode="floor"
                    )
                else:
                    coords_pbc, at_shifts = jax.vmap(
                        partial(compute_pbc, mode="floor")
                    )(coords, reciprocal_cells[batch_index], cells[batch_index])
                vec = coords_pbc[p2] - coords_pbc[p1]
                # static number of repeats (set by the CPU routine); a larger
                # requirement is reported through overflow_repeats
                num_repeats = state.get("num_repeats_pbc", (0, 0, 0))
                # if num_repeats is None:
                #     raise ValueError(
                #         "num_repeats_pbc should be provided for general PBC on accelerator. Call the numpy routine (self.__call__) first."
                #     )
                # check if num_repeats is larger than previous
                inv_distances = jnp.linalg.norm(reciprocal_cells, axis=1)
                cdinv = cutoff_skin * inv_distances
                num_repeats_all = jnp.ceil(cdinv).astype(jnp.int32)
                if "true_sys" in inputs:
                    num_repeats_all = jnp.where(inputs["true_sys"][:,None], num_repeats_all, 0)
                num_repeats_new = jnp.max(num_repeats_all, axis=0)
                overflow_repeats = jnp.any(num_repeats_new > jnp.asarray(num_repeats))

                cell_shift_pbc = jnp.asarray(
                    np.array(
                        np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]),
                        dtype=cells.dtype,
                    ).T.reshape(-1, 3)
                )

                if cells.shape[0] == 1:
                    vec = (vec[:,None,:] + jnp.dot(cell_shift_pbc, cells[0])[None, :, :]).reshape(-1, 3)
                    pbc_shifts = jnp.broadcast_to(
                        cell_shift_pbc[None, :, :],
                        (p1.shape[0], cell_shift_pbc.shape[0], 3),
                    ).reshape(-1, 3)
                    p1 = jnp.broadcast_to(
                        p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
                    p2 = jnp.broadcast_to(
                        p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
                    if natoms.shape[0] > 1:
                        mask_p12 = jnp.broadcast_to(
                            mask_p12[:, None], (mask_p12.shape[0], cell_shift_pbc.shape[0])
                        ).flatten()
                else:
                    dvec = jnp.einsum("bj,sji->sbi", cell_shift_pbc, cells).reshape(-1, 3)

                    ## get pbc shifts specific to each box
                    cell_shift_pbc = jnp.broadcast_to(
                        cell_shift_pbc[None, :, :],
                        (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3),
                    )
                    mask = jnp.all(
                        jnp.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1
                    ).flatten()
                    max_shifts = state.get("nshifts_pbc", 1)

                    # NOTE(review): mask_filter_1d is a project helper; it
                    # appears to compact entries where mask is True into
                    # arrays of length max_shifts (padding with the given
                    # fill values) and to return the true count — confirm.
                    cell_shift_pbc = cell_shift_pbc.reshape(-1,3)
                    shiftx,shifty,shiftz = cell_shift_pbc[:,0],cell_shift_pbc[:,1],cell_shift_pbc[:,2]
                    dvecx,dvecy,dvecz = dvec[:,0],dvec[:,1],dvec[:,2]
                    (dvecx, dvecy,dvecz,shiftx,shifty,shiftz), scatter_idx, nshifts = mask_filter_1d(
                        mask,
                        max_shifts,
                        (dvecx, 0.),
                        (dvecy, 0.),
                        (dvecz, 0.),
                        (shiftx, 0),
                        (shifty, 0),
                        (shiftz, 0),
                    )
                    dvec = jnp.stack((dvecx,dvecy,dvecz),axis=-1)
                    cell_shift_pbc = jnp.stack((shiftx,shifty,shiftz),axis=-1)
                    overflow_repeats = overflow_repeats | (nshifts > max_shifts)

                    ## get batch shift in the dvec_filter array
                    nrep = jnp.prod(2 * num_repeats_all + 1, axis=1)
                    bshift = jnp.concatenate((jnp.array([0],dtype=jnp.int32), jnp.cumsum(nrep)[:-1]))

                    ## repeat vectors
                    nvec_max = state.get("nvec_pbc", 1)
                    batch_index_vec = batch_index[p1]
                    nrep_vec = jnp.where(mask_p12,nrep[batch_index_vec],0)
                    nvec = nrep_vec.sum()
                    overflow_repeats = overflow_repeats | (nvec > nvec_max)
                    vec = jnp.repeat(vec,nrep_vec,axis=0,total_repeat_length=nvec_max)
                    # jax.debug.print("{nvec} {nvec_max} {nshifts} {max_shifts}",nvec=nvec,nvec_max=jnp.asarray(nvec_max),nshifts=nshifts,max_shifts=jnp.asarray(max_shifts))

                    ## get shift index
                    dshift = jnp.concatenate(
                        (jnp.array([0],dtype=jnp.int32), jnp.cumsum(nrep_vec)[:-1])
                    )
                    if nrep_vec.size == 0:
                        dshift = jnp.array([],dtype=jnp.int32)
                    dshift = jnp.repeat(dshift,nrep_vec, total_repeat_length=nvec_max)
                    bshift = jnp.repeat(bshift[batch_index_vec],nrep_vec, total_repeat_length=nvec_max)
                    icellshift = jnp.arange(dshift.shape[0]) - dshift + bshift
                    vec = vec + dvec[icellshift]
                    pbc_shifts = cell_shift_pbc[icellshift]
                    p1 = jnp.repeat(p1,nrep_vec, total_repeat_length=nvec_max)
                    p2 = jnp.repeat(p2,nrep_vec, total_repeat_length=nvec_max)
                    if natoms.shape[0] > 1:
                        mask_p12 = jnp.repeat(mask_p12,nrep_vec, total_repeat_length=nvec_max)


        ## compute distances
        d12 = (vec**2).sum(axis=-1)
        if natoms.shape[0] > 1:
            d12 = jnp.where(mask_p12, d12, cutoff_skin**2)

        ## filter pairs
        max_pairs = state.get("npairs", 1)
        mask = d12 < cutoff_skin**2
        (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d(
            mask,
            max_pairs,
            (jnp.asarray(p1, dtype=jnp.int32), coords.shape[0]),
            (jnp.asarray(p2, dtype=jnp.int32), coords.shape[0]),
            (d12, cutoff_skin**2),
        )
        if "cells" in inputs:
            pbc_shifts = (
                jnp.full((max_pairs, 3), 0.0, dtype=pbc_shifts.dtype)
                .at[scatter_idx]
                .set(pbc_shifts, mode="drop")
            )
            if not minimage:
                # Padded edges have edge_src == edge_dst == nat (out of range);
                # with mode="fill" both lookups return 0 for them.
                pbc_shifts = (
                    pbc_shifts
                    + at_shifts.at[edge_dst].get(mode="fill", fill_value=0.0)
                    - at_shifts.at[edge_src].get(mode="fill", fill_value=0.0)
                )

        ## check for overflow
        if natoms.shape[0] == 1:
            true_max_nat = coords.shape[0]
        else:
            true_max_nat = jnp.max(natoms)
        overflow_count = npairs > max_pairs
        overflow_at = true_max_nat > max_nat
        overflow = overflow_count | overflow_at | overflow_repeats

        if "nblist_skin" in state:
            # edge_mask_skin = edge_mask
            edge_src_skin = edge_src
            edge_dst_skin = edge_dst
            if "cells" in inputs:
                pbc_shifts_skin = pbc_shifts
            max_pairs_skin = state.get("npairs_skin", 1)
            mask = d12 < self.cutoff**2
            (edge_src, edge_dst, d12), scatter_idx, npairs_skin = mask_filter_1d(
                mask,
                max_pairs_skin,
                (edge_src, coords.shape[0]),
                (edge_dst, coords.shape[0]),
                (d12, self.cutoff**2),
            )
            if "cells" in inputs:
                pbc_shifts = (
                    jnp.full((max_pairs_skin, 3), 0.0, dtype=pbc_shifts.dtype)
                    .at[scatter_idx]
                    .set(pbc_shifts, mode="drop")
                )
            overflow = overflow | (npairs_skin > max_pairs_skin)

        ## symmetrize
        edge_src, edge_dst = jnp.concatenate((edge_src, edge_dst)), jnp.concatenate(
            (edge_dst, edge_src)
        )
        d12 = jnp.concatenate((d12, d12))
        if "cells" in inputs:
            pbc_shifts = jnp.concatenate((pbc_shifts, -pbc_shifts))

        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
        graph_out = {
            **graph,
            "edge_src": edge_src,
            "edge_dst": edge_dst,
            "d12": d12,
            "overflow": overflow,
            "pbc_shifts": pbc_shifts,
        }
        if "nblist_skin" in state:
            graph_out["edge_src_skin"] = edge_src_skin
            graph_out["edge_dst_skin"] = edge_dst_skin
            if "cells" in inputs:
                graph_out["pbc_shifts_skin"] = pbc_shifts_skin

        if self.k_space and "cells" in inputs:
            if "k_points" not in graph:
                raise NotImplementedError(
                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
                )
        return {**inputs, self.graph_key: graph_out}

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        """update the nblist without recomputing the full nblist

        Re-filters the stored skin pairs ("edge_src_skin"/"edge_dst_skin") against
        ``self.cutoff`` using the current coordinates. The output keeps the
        buffer size of the current graph and ORs its "overflow" flag with the
        new capacity check.
        """
        graph = inputs[self.graph_key]

        edge_src_skin = graph["edge_src_skin"]
        edge_dst_skin = graph["edge_dst_skin"]
        coords = inputs["coordinates"]
        # Padded skin edges (index == nat) get vec = (cutoff, cutoff, cutoff)
        # from the fill values, so their d12 exceeds cutoff**2 and they are
        # rejected by the mask below.
        vec = coords.at[edge_dst_skin].get(
            mode="fill", fill_value=self.cutoff
        ) - coords.at[edge_src_skin].get(mode="fill", fill_value=0.0)

        if "cells" in inputs:
            pbc_shifts_skin = graph["pbc_shifts_skin"]
            cells = inputs["cells"]
            if cells.shape[0] == 1:
                vec = vec + jnp.dot(pbc_shifts_skin, cells[0])
            else:
                batch_index_vec = inputs["batch_index"][edge_src_skin]
                vec = vec + jax.vmap(jnp.dot)(pbc_shifts_skin, cells[batch_index_vec])

        nat = coords.shape[0]
        d12 = jnp.sum(vec**2, axis=-1)
        mask = d12 < self.cutoff**2
        # the stored graph is symmetrized, so it holds 2 * max_pairs edges
        max_pairs = graph["edge_src"].shape[0] // 2
        (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d(
            mask,
            max_pairs,
            (edge_src_skin, nat),
            (edge_dst_skin, nat),
            (d12, self.cutoff**2),
        )
        if "cells" in inputs:
            pbc_shifts = (
                jnp.full((max_pairs, 3), 0.0, dtype=pbc_shifts_skin.dtype)
                .at[scatter_idx]
                .set(pbc_shifts_skin, mode="drop")
            )

        overflow = graph.get("overflow", False) | (npairs > max_pairs)
        graph_out = {
            **graph,
            "edge_src": jnp.concatenate((edge_src, edge_dst)),
            "edge_dst": jnp.concatenate((edge_dst, edge_src)),
            "d12": jnp.concatenate((d12, d12)),
            "overflow": overflow,
        }
        if "cells" in inputs:
            graph_out["pbc_shifts"] = jnp.concatenate((pbc_shifts, -pbc_shifts))

        if self.k_space and "cells" in inputs:
            if "k_points" not in graph:
                raise NotImplementedError(
                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
                )

        return {**inputs, self.graph_key: graph_out}


class GraphProcessor(nn.Module):
    """Process a pre-generated graph

    The pre-generated graph should contain the following keys:
    - edge_src: source indices of the edges
    - edge_dst: destination indices of the edges
    - pbc_shifts: pbc shifts for the edges (only if `cells` are present in the inputs)

    This module is automatically added to a FENNIX model when a GraphGenerator is used.

    """

    cutoff: float
    """Cutoff distance for the graph."""
    graph_key: str = "graph"
    """Key of the graph in the outputs."""
    switch_params: dict = dataclasses.field(default_factory=dict)
    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""

    @nn.compact
    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
        """Add "vec", "distances", "switch" and "edge_mask" to the graph.

        ``vec = r[edge_dst] - r[edge_src]`` (plus ``pbc_shifts @ cell`` for
        periodic inputs). Padded edges (out-of-range indices) receive
        ``vec = (cutoff, cutoff, cutoff)`` from the fill values, so they fall
        outside the cutoff and have ``edge_mask == False``. When
        "alch_group" is present, edges between different alchemical groups
        have their switch scaled by ``0.5 * (1 - cos(pi * alch_elambda))``;
        with soft-core inputs, their distances are softened as well and the
        unmodified distances are kept in "distances_raw".
        """
        graph = inputs[self.graph_key]
        coords = inputs["coordinates"]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        # edge_mask = edge_src < coords.shape[0]
        vec = coords.at[edge_dst].get(mode="fill", fill_value=self.cutoff) - coords.at[
            edge_src
        ].get(mode="fill", fill_value=0.0)
        if "cells" in inputs:
            cells = inputs["cells"]
            if cells.shape[0] == 1:
                vec = vec + jnp.dot(graph["pbc_shifts"], cells[0])
            else:
                batch_index_vec = inputs["batch_index"][edge_src]
                vec = vec + jax.vmap(jnp.dot)(
                    graph["pbc_shifts"], cells[batch_index_vec]
                )

        d2 = jnp.sum(vec**2, axis=-1)
        distances = safe_sqrt(d2)
        edge_mask = distances < self.cutoff

        switch = SwitchFunction(
            **{**self.switch_params, "cutoff": self.cutoff, "graph_key": None}
        )((distances, edge_mask))

        graph_out = {
            **graph,
            "vec": vec,
            "distances": distances,
            "switch": switch,
            "edge_mask": edge_mask,
        }

        if "alch_group" in inputs:
            alch_group = inputs["alch_group"]
            lambda_e = inputs["alch_elambda"]
            mask = alch_group[edge_src] == alch_group[edge_dst]
            graph_out["switch_raw"] = switch
            graph_out["switch"] = jnp.where(
                mask,
                switch,
                0.5*(1.-jnp.cos(jnp.pi*lambda_e)) * switch ,
            )

            if "alch_softcore_e" in inputs or "alch_softcore_v" in inputs:
                graph_out["distances_raw"] = distances
                if "alch_softcore_e" in inputs:
                    alch_alpha = (1-inputs["alch_elambda"])*inputs["alch_softcore_e"]**2
                else:
                    alch_alpha = (1-inputs["alch_vlambda"])*inputs["alch_softcore_v"]**2
                distances = jnp.where(
                    mask,
                    distances,
                    safe_sqrt(alch_alpha + d2 * (1. - alch_alpha/self.cutoff**2))
                )
                graph_out["distances"] = distances

        return {**inputs, self.graph_key: graph_out}


@dataclasses.dataclass(frozen=True)
class GraphFilter:
    """Filter a graph based on a cutoff distance

    FPID: GRAPH_FILTER
    """

    cutoff: float
    """Cutoff distance for the filtering."""
    parent_graph: str
    """Key of the parent graph in the inputs."""
    graph_key: str
    """Key of the filtered graph in the outputs."""
    remove_hydrogens: int = False
    """Remove edges where the source is a hydrogen atom."""
    switch_params: FrozenDict = dataclasses.field(default_factory=FrozenDict)
    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""
    k_space: bool = False
    """Generate k-space information for the graph."""
    kmax: int = 30
    """Maximum number of k-points to consider."""
    kthr: float = 1e-6
    """Threshold for k-point filtering."""
    mult_size: float = 1.05
    """Multiplicative factor for resizing the nblist."""

    FPID: ClassVar[str] = "GRAPH_FILTER"

    def init(self):
        """Return the initial state: placeholder pair count and the resize factor."""
        return FrozenDict(
            {
                "npairs": 1,
                "nblist_mult_size": self.mult_size,
            }
        )

    def get_processor(self) -> Tuple[nn.Module, Dict]:
        """Return the flax module class (and its kwargs) that post-processes this graph."""
        return GraphFilterProcessor, {
            "cutoff": self.cutoff,
            "graph_key": self.graph_key,
            "parent_graph": self.parent_graph,
            "name": f"{self.graph_key}_Filter_{self.parent_graph}",
            "switch_params": self.switch_params,
        }

    def get_graph_properties(self):
        """Describe the filtered graph (cutoff, directedness, parent graph key)."""
        return {
            self.graph_key: {
                "cutoff": self.cutoff,
                "directed": True,
                "parent_graph": self.parent_graph,
            }
        }

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """filter a nblist on cpu with numpy and dynamic shapes + store max shapes

        Keeps the parent edges with ``d12 < cutoff**2`` (and, if
        ``remove_hydrogens``, whose source species is > 1). The results are
        padded to ``state["npairs"]``. "filter_indices" records, for each kept
        edge, its index in the parent graph; padded entries hold the parent
        edge count.

        Returns:
            ``(new_state, output)`` or ``(new_state, output, state_up)`` when
            ``return_state_update`` is True.
        """
        graph_in = inputs[self.parent_graph]
        nat = inputs["species"].shape[0]

        new_state = {**state}
        state_up = {}
        mult_size = state.get("nblist_mult_size", self.mult_size)
        assert mult_size >= 1., "nblist_mult_size should be >= 1."

        edge_src = np.array(graph_in["edge_src"], dtype=np.int32)
        d12 = np.array(graph_in["d12"], dtype=np.float32)
        if self.remove_hydrogens:
            species = inputs["species"]
            # only look up species for in-range (non-padded) sources
            src_idx = (edge_src < nat).nonzero()[0]
            mask = np.zeros(edge_src.shape[0], dtype=bool)
            mask[src_idx] = (species > 1)[edge_src[src_idx]]
            d12 = np.where(mask, d12, self.cutoff**2)
        mask = d12 < self.cutoff**2

        max_pairs = state.get("npairs", 1)
        idx = np.nonzero(mask)[0]
        npairs = idx.shape[0]
        if npairs > max_pairs or add_margin:
            prev_max_pairs = max_pairs
            max_pairs = int(mult_size * max(npairs, max_pairs)) + 1
            state_up["npairs"] = (max_pairs, prev_max_pairs)
            new_state["npairs"] = max_pairs

        filter_indices = np.full(max_pairs, edge_src.shape[0], dtype=np.int32)
        edge_src = np.full(max_pairs, nat, dtype=np.int32)
        edge_dst = np.full(max_pairs, nat, dtype=np.int32)
        d12_ = np.full(max_pairs, self.cutoff**2)
        filter_indices[:npairs] = idx
        edge_src[:npairs] = graph_in["edge_src"][idx]
        edge_dst[:npairs] = graph_in["edge_dst"][idx]
        d12_[:npairs] = d12[idx]
        d12 = d12_

        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
        graph_out = {
            **graph,
            "edge_src": edge_src,
            "edge_dst": edge_dst,
            "filter_indices": filter_indices,
            "d12": d12,
            "overflow": False,
        }

        if self.k_space and "cells" in inputs:
            if "k_points" not in graph:
                ks, _, _, bewald = get_reciprocal_space_parameters(
                    inputs["reciprocal_cells"], self.cutoff, self.kmax, self.kthr
                )
                graph_out["k_points"] = ks
                graph_out["b_ewald"] = bewald

        output = {**inputs, self.graph_key: graph_out}
        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary

        Returns:
            ``(state, state_up, inputs, reallocated)``. Extra margin is added
            only when this graph itself reported an overflow.
        """
        overflow = parent_overflow or inputs[self.graph_key].get("overflow", False)
        if not overflow:
            return state, {}, inputs, False

        add_margin = inputs[self.graph_key].get("overflow", False)
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """filter a nblist on accelerator with jax and precomputed shapes

        With ``state=None`` (skin update mode), the buffer size is taken from
        the current filtered graph instead of ``state["npairs"]``.
        """
        graph_in = inputs[self.parent_graph]
        if state is None:
            # skin update mode
            graph = inputs[self.graph_key]
            max_pairs = graph["edge_src"].shape[0]
        else:
            max_pairs = state.get("npairs", 1)

        max_pairs_in = graph_in["edge_src"].shape[0]
        nat = inputs["species"].shape[0]

        edge_src = graph_in["edge_src"]
        d12 = graph_in["d12"]
        if self.remove_hydrogens:
            species = inputs["species"]
            mask = (species > 1)[edge_src]
            d12 = jnp.where(mask, d12, self.cutoff**2)
        mask = d12 < self.cutoff**2

        (edge_src, edge_dst, d12, filter_indices), _, npairs = mask_filter_1d(
            mask,
            max_pairs,
            (edge_src, nat),
            (graph_in["edge_dst"], nat),
            (d12, self.cutoff**2),
            (jnp.arange(max_pairs_in, dtype=jnp.int32), max_pairs_in),
        )

        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
        overflow = graph.get("overflow", False) | (npairs > max_pairs)
        graph_out = {
            **graph,
            "edge_src": edge_src,
            "edge_dst": edge_dst,
            "filter_indices": filter_indices,
            "d12": d12,
            "overflow": overflow,
        }

        if self.k_space and "cells" in inputs:
            if "k_points" not in graph:
                raise NotImplementedError(
                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
                )

        return {**inputs, self.graph_key: graph_out}

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        """Re-filter the parent graph keeping the current buffer size."""
        return self.process(None, inputs)


class GraphFilterProcessor(nn.Module):
    """Filter processing for a pre-generated graph

    This module is automatically added to a FENNIX model when a GraphFilter is used.
    """

    cutoff: float
    """Cutoff distance for the filtering."""
    graph_key: str
    """Key of the filtered graph in the inputs."""
    parent_graph: str
    """Key of the parent graph in the inputs."""
    switch_params: dict = dataclasses.field(default_factory=dict)
    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""

    @nn.compact
    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
        """Gather parent "vec"/distances through "filter_indices" and add switch/mask.

        Padded filter indices (out of range of the parent arrays) are filled
        with ``self.cutoff``, so those edges get ``edge_mask == False``.
        """
        graph_in = inputs[self.parent_graph]
        graph = inputs[self.graph_key]

        # The parent processor stores its un-softened distances under
        # "distances_raw" when it applied a soft-core transform to
        # "distances". Test membership in the same dict that is indexed
        # below (graph_in); otherwise the soft-core transform further down
        # would be applied to already-softened distances (or raise KeyError).
        d_key = "distances_raw" if "distances_raw" in graph_in else "distances"

        if graph_in["vec"].shape[0] == 0:
            vec = graph_in["vec"]
            distances = graph_in[d_key]
            filter_indices = jnp.asarray([], dtype=jnp.int32)
        else:
            filter_indices = graph["filter_indices"]
            vec = (
                graph_in["vec"]
                .at[filter_indices]
                .get(mode="fill", fill_value=self.cutoff)
            )
            distances = (
                graph_in[d_key]
                .at[filter_indices]
                .get(mode="fill", fill_value=self.cutoff)
            )

        edge_mask = distances < self.cutoff
        switch = SwitchFunction(
            **{**self.switch_params, "cutoff": self.cutoff, "graph_key": None}
        )((distances, edge_mask))

        graph_out = {
            **graph,
            "vec": vec,
            "distances": distances,
            "switch": switch,
            "filter_indices": filter_indices,
            "edge_mask": edge_mask,
        }

        if "alch_group" in inputs:
            edge_src=graph["edge_src"]
            edge_dst=graph["edge_dst"]
            alch_group = inputs["alch_group"]
            lambda_e = inputs["alch_elambda"]
            mask = alch_group[edge_src] == alch_group[edge_dst]
            graph_out["switch_raw"] = switch
            graph_out["switch"] = jnp.where(
                mask,
                switch,
                0.5*(1.-jnp.cos(jnp.pi*lambda_e)) * switch ,
            )

            if "alch_softcore_e" in inputs or "alch_softcore_v" in inputs:
                graph_out["distances_raw"] = distances
                if "alch_softcore_e" in inputs:
                    alch_alpha = (1-inputs["alch_elambda"])*inputs["alch_softcore_e"]**2
                else:
                    alch_alpha = (1-inputs["alch_vlambda"])*inputs["alch_softcore_v"]**2
                distances = jnp.where(
                    mask,
                    distances,
                    safe_sqrt(alch_alpha + distances**2 * (1. - alch_alpha/self.cutoff**2))
                )
                graph_out["distances"] = distances

        return {**inputs, self.graph_key: graph_out}


@dataclasses.dataclass(frozen=True)
class GraphAngularExtension:
    """Add angles list to a graph

    FPID: GRAPH_ANGULAR_EXTENSION
    """

    mult_size: float = 1.05
    """Multiplicative factor for resizing the nblist."""
    add_neigh: int = 5
    """Additional neighbors to add to the nblist when resizing."""
    graph_key: str = "graph"
    """Key of the graph in the inputs."""

    FPID: ClassVar[str] = "GRAPH_ANGULAR_EXTENSION"

    def init(self):
        return FrozenDict(
            {
                "nangles": 0,
                "nblist_mult_size": self.mult_size,
                "max_neigh": self.add_neigh,
                "add_neigh": self.add_neigh,
            }
        )

    def get_processor(self) -> Tuple[nn.Module, Dict]:
        return GraphAngleProcessor, {
            "graph_key": self.graph_key,
            "name": f"{self.graph_key}_AngleProcessor",
        }

    def get_graph_properties(self):
        return {
            self.graph_key: {
                "has_angles": True,
            }
        }

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """build angle nblist on cpu with numpy and dynamic shapes + store max shapes"""
        graph = inputs[self.graph_key]
        edge_src = np.array(graph["edge_src"], dtype=np.int32)

new_state = {**state} 1089 state_up = {} 1090 mult_size = state.get("nblist_mult_size", self.mult_size) 1091 assert mult_size >= 1., "nblist_mult_size should be >= 1." 1092 1093 ### count number of neighbors 1094 nat = inputs["species"].shape[0] 1095 count = np.zeros(nat + 1, dtype=np.int32) 1096 np.add.at(count, edge_src, 1) 1097 max_count = int(np.max(count[:-1])) 1098 1099 ### get sizes 1100 max_neigh = state.get("max_neigh", self.add_neigh) 1101 nedge = edge_src.shape[0] 1102 if max_count > max_neigh or add_margin: 1103 prev_max_neigh = max_neigh 1104 max_neigh = max(max_count, max_neigh) + state.get( 1105 "add_neigh", self.add_neigh 1106 ) 1107 state_up["max_neigh"] = (max_neigh, prev_max_neigh) 1108 new_state["max_neigh"] = max_neigh 1109 1110 max_neigh_arr = np.empty(max_neigh, dtype=bool) 1111 1112 nedge = edge_src.shape[0] 1113 1114 ### sort edge_src 1115 idx_sort = np.argsort(edge_src) 1116 edge_src_sorted = edge_src[idx_sort] 1117 1118 ### map sparse to dense nblist 1119 offset = np.tile(np.arange(max_count), nat) 1120 if max_count * nat >= nedge: 1121 offset = np.tile(np.arange(max_count), nat)[:nedge] 1122 else: 1123 offset = np.zeros(nedge, dtype=np.int32) 1124 offset[: max_count * nat] = np.tile(np.arange(max_count), nat) 1125 1126 # offset = jnp.where(edge_src_sorted < nat, offset, 0) 1127 mask = edge_src_sorted < nat 1128 indices = edge_src_sorted * max_count + offset 1129 indices = indices[mask] 1130 idx_sort = idx_sort[mask] 1131 edge_idx = np.full(nat * max_count, nedge, dtype=np.int32) 1132 edge_idx[indices] = idx_sort 1133 edge_idx = edge_idx.reshape(nat, max_count) 1134 1135 ### find all triplet for each atom center 1136 local_src, local_dst = np.triu_indices(max_count, 1) 1137 angle_src = edge_idx[:, local_src].flatten() 1138 angle_dst = edge_idx[:, local_dst].flatten() 1139 1140 ### mask for valid angles 1141 mask1 = angle_src < nedge 1142 mask2 = angle_dst < nedge 1143 angle_mask = mask1 & mask2 1144 1145 max_angles = state.get("nangles", 
0) 1146 idx = np.nonzero(angle_mask)[0] 1147 nangles = idx.shape[0] 1148 if nangles > max_angles or add_margin: 1149 max_angles_prev = max_angles 1150 max_angles = int(mult_size * max(nangles, max_angles)) + 1 1151 state_up["nangles"] = (max_angles, max_angles_prev) 1152 new_state["nangles"] = max_angles 1153 1154 ## filter angles to sparse representation 1155 angle_src_ = np.full(max_angles, nedge, dtype=np.int32) 1156 angle_dst_ = np.full(max_angles, nedge, dtype=np.int32) 1157 angle_src_[:nangles] = angle_src[idx] 1158 angle_dst_[:nangles] = angle_dst[idx] 1159 1160 central_atom = np.full(max_angles, nat, dtype=np.int32) 1161 central_atom[:nangles] = edge_src[angle_src_[:nangles]] 1162 1163 ## update graph 1164 output = { 1165 **inputs, 1166 self.graph_key: { 1167 **graph, 1168 "angle_src": angle_src_, 1169 "angle_dst": angle_dst_, 1170 "central_atom": central_atom, 1171 "angle_overflow": False, 1172 "max_neigh": max_neigh, 1173 "__max_neigh_array": max_neigh_arr, 1174 }, 1175 } 1176 1177 if return_state_update: 1178 return FrozenDict(new_state), output, state_up 1179 return FrozenDict(new_state), output 1180 1181 def check_reallocate(self, state, inputs, parent_overflow=False): 1182 """check for overflow and reallocate nblist if necessary""" 1183 overflow = parent_overflow or inputs[self.graph_key]["angle_overflow"] 1184 if not overflow: 1185 return state, {}, inputs, False 1186 1187 add_margin = inputs[self.graph_key]["angle_overflow"] 1188 state, inputs, state_up = self( 1189 state, inputs, return_state_update=True, add_margin=add_margin 1190 ) 1191 return state, state_up, inputs, True 1192 1193 @partial(jax.jit, static_argnums=(0, 1)) 1194 def process(self, state, inputs): 1195 """build angle nblist on accelerator with jax and precomputed shapes""" 1196 graph = inputs[self.graph_key] 1197 edge_src = graph["edge_src"] 1198 1199 ### count number of neighbors 1200 nat = inputs["species"].shape[0] 1201 count = jnp.zeros(nat, dtype=jnp.int32).at[edge_src].add(1, 
mode="drop") 1202 max_count = jnp.max(count) 1203 1204 ### get sizes 1205 if state is None: 1206 max_neigh_arr = graph["__max_neigh_array"] 1207 max_neigh = max_neigh_arr.shape[0] 1208 prev_nangles = graph["angle_src"].shape[0] 1209 else: 1210 max_neigh = state.get("max_neigh", self.add_neigh) 1211 max_neigh_arr = jnp.empty(max_neigh, dtype=bool) 1212 prev_nangles = state.get("nangles", 0) 1213 1214 nedge = edge_src.shape[0] 1215 1216 ### sort edge_src 1217 idx_sort = jnp.argsort(edge_src).astype(jnp.int32) 1218 edge_src_sorted = edge_src[idx_sort] 1219 1220 ### map sparse to dense nblist 1221 if max_neigh * nat < nedge: 1222 raise ValueError("Found max_neigh*nat < nedge. This should not happen.") 1223 offset = jnp.asarray( 1224 np.tile(np.arange(max_neigh), nat)[:nedge], dtype=jnp.int32 1225 ) 1226 # offset = jnp.where(edge_src_sorted < nat, offset, 0) 1227 indices = edge_src_sorted * max_neigh + offset 1228 edge_idx = ( 1229 jnp.full(nat * max_neigh, nedge, dtype=jnp.int32) 1230 .at[indices] 1231 .set(idx_sort, mode="drop") 1232 .reshape(nat, max_neigh) 1233 ) 1234 1235 ### find all triplet for each atom center 1236 local_src, local_dst = np.triu_indices(max_neigh, 1) 1237 angle_src = edge_idx[:, local_src].flatten() 1238 angle_dst = edge_idx[:, local_dst].flatten() 1239 1240 ### mask for valid angles 1241 mask1 = angle_src < nedge 1242 mask2 = angle_dst < nedge 1243 angle_mask = mask1 & mask2 1244 1245 ## filter angles to sparse representation 1246 (angle_src, angle_dst), _, nangles = mask_filter_1d( 1247 angle_mask, 1248 prev_nangles, 1249 (angle_src, nedge), 1250 (angle_dst, nedge), 1251 ) 1252 ## find central atom 1253 central_atom = edge_src[angle_src] 1254 1255 ## check for overflow 1256 angle_overflow = nangles > prev_nangles 1257 neigh_overflow = max_count > max_neigh 1258 overflow = graph.get("angle_overflow", False) | angle_overflow | neigh_overflow 1259 1260 ## update graph 1261 output = { 1262 **inputs, 1263 self.graph_key: { 1264 **graph, 1265 
"angle_src": angle_src, 1266 "angle_dst": angle_dst, 1267 "central_atom": central_atom, 1268 "angle_overflow": overflow, 1269 # "max_neigh": max_neigh, 1270 "__max_neigh_array": max_neigh_arr, 1271 }, 1272 } 1273 1274 return output 1275 1276 @partial(jax.jit, static_argnums=(0,)) 1277 def update_skin(self, inputs): 1278 return self.process(None, inputs) 1279 1280 1281class GraphAngleProcessor(nn.Module): 1282 """Process a pre-generated graph to compute angles 1283 1284 This module is automatically added to a FENNIX model when a GraphAngularExtension is used. 1285 1286 """ 1287 1288 graph_key: str 1289 """Key of the graph in the inputs.""" 1290 1291 @nn.compact 1292 def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]): 1293 graph = inputs[self.graph_key] 1294 distances = graph["distances_raw"] if "distances_raw" in graph else graph["distances"] 1295 vec = graph["vec"] 1296 angle_src = graph["angle_src"] 1297 angle_dst = graph["angle_dst"] 1298 1299 dir = vec / jnp.clip(distances[:, None], min=1.0e-5) 1300 cos_angles = ( 1301 dir.at[angle_src].get(mode="fill", fill_value=0.5) 1302 * dir.at[angle_dst].get(mode="fill", fill_value=0.5) 1303 ).sum(axis=-1) 1304 1305 angles = jnp.arccos(0.95 * cos_angles) 1306 1307 return { 1308 **inputs, 1309 self.graph_key: { 1310 **graph, 1311 # "cos_angles": cos_angles, 1312 "angles": angles, 1313 # "angle_mask": angle_mask, 1314 }, 1315 } 1316 1317 1318@dataclasses.dataclass(frozen=True) 1319class SpeciesIndexer: 1320 """Build an index that splits atomic arrays by species. 1321 1322 FPID: SPECIES_INDEXER 1323 1324 If `species_order` is specified, the output will be a dense array with size (len(species_order), max_size) that can directly index atomic arrays. 1325 If `species_order` is None, the output will be a dictionary with species as keys and an index to filter atomic arrays for that species as values. 
1326 1327 """ 1328 1329 output_key: str = "species_index" 1330 """Key for the output dictionary.""" 1331 species_order: Optional[str] = None 1332 """Comma separated list of species in the order they should be indexed.""" 1333 add_atoms: int = 0 1334 """Additional atoms to add to the sizes.""" 1335 add_atoms_margin: int = 10 1336 """Additional atoms to add to the sizes when adding margin.""" 1337 1338 FPID: ClassVar[str] = "SPECIES_INDEXER" 1339 1340 def init(self): 1341 return FrozenDict( 1342 { 1343 "sizes": {}, 1344 } 1345 ) 1346 1347 def __call__(self, state, inputs, return_state_update=False, add_margin=False): 1348 species = np.array(inputs["species"], dtype=np.int32) 1349 nat = species.shape[0] 1350 set_species, counts = np.unique(species, return_counts=True) 1351 1352 new_state = {**state} 1353 state_up = {} 1354 1355 sizes = state.get("sizes", FrozenDict({})) 1356 new_sizes = {**sizes} 1357 up_sizes = False 1358 counts_dict = {} 1359 for s, c in zip(set_species, counts): 1360 if s <= 0: 1361 continue 1362 counts_dict[s] = c 1363 if c > sizes.get(s, 0): 1364 up_sizes = True 1365 add_atoms = state.get("add_atoms", self.add_atoms) 1366 if add_margin: 1367 add_atoms += state.get("add_atoms_margin", self.add_atoms_margin) 1368 new_sizes[s] = c + add_atoms 1369 1370 new_sizes = FrozenDict(new_sizes) 1371 if up_sizes: 1372 state_up["sizes"] = (new_sizes, sizes) 1373 new_state["sizes"] = new_sizes 1374 1375 if self.species_order is not None: 1376 species_order = [el.strip() for el in self.species_order.split(",")] 1377 max_size_prev = state.get("max_size", 0) 1378 max_size = max(new_sizes.values()) 1379 if max_size > max_size_prev: 1380 state_up["max_size"] = (max_size, max_size_prev) 1381 new_state["max_size"] = max_size 1382 max_size_prev = max_size 1383 1384 species_index = np.full((len(species_order), max_size), nat, dtype=np.int32) 1385 for i, el in enumerate(species_order): 1386 s = PERIODIC_TABLE_REV_IDX[el] 1387 if s in counts_dict.keys(): 1388 
species_index[i, : counts_dict[s]] = np.nonzero(species == s)[0] 1389 else: 1390 species_index = { 1391 PERIODIC_TABLE[s]: np.full(c, nat, dtype=np.int32) 1392 for s, c in new_sizes.items() 1393 } 1394 for s, c in zip(set_species, counts): 1395 if s <= 0: 1396 continue 1397 species_index[PERIODIC_TABLE[s]][:c] = np.nonzero(species == s)[0] 1398 1399 output = { 1400 **inputs, 1401 self.output_key: species_index, 1402 self.output_key + "_overflow": False, 1403 } 1404 1405 if return_state_update: 1406 return FrozenDict(new_state), output, state_up 1407 return FrozenDict(new_state), output 1408 1409 def check_reallocate(self, state, inputs, parent_overflow=False): 1410 """check for overflow and reallocate nblist if necessary""" 1411 overflow = parent_overflow or inputs[self.output_key + "_overflow"] 1412 if not overflow: 1413 return state, {}, inputs, False 1414 1415 add_margin = inputs[self.output_key + "_overflow"] 1416 state, inputs, state_up = self( 1417 state, inputs, return_state_update=True, add_margin=add_margin 1418 ) 1419 return state, state_up, inputs, True 1420 # return state, {}, inputs, parent_overflow 1421 1422 @partial(jax.jit, static_argnums=(0, 1)) 1423 def process(self, state, inputs): 1424 # assert ( 1425 # self.output_key in inputs 1426 # ), f"Species Index {self.output_key} must be provided on accelerator. Call the numpy routine (self.__call__) first." 
1427 1428 recompute_species_index = "recompute_species_index" in inputs.get("flags", {}) 1429 if self.output_key in inputs and not recompute_species_index: 1430 return inputs 1431 1432 if state is None: 1433 raise ValueError("Species Indexer state must be provided on accelerator.") 1434 1435 species = inputs["species"] 1436 nat = species.shape[0] 1437 1438 sizes = state["sizes"] 1439 1440 if self.species_order is not None: 1441 species_order = [el.strip() for el in self.species_order.split(",")] 1442 max_size = state["max_size"] 1443 1444 species_index = jnp.full( 1445 (len(species_order), max_size), nat, dtype=jnp.int32 1446 ) 1447 for i, el in enumerate(species_order): 1448 s = PERIODIC_TABLE_REV_IDX[el] 1449 if s in sizes.keys(): 1450 c = sizes[s] 1451 species_index = species_index.at[i, :].set( 1452 jnp.nonzero(species == s, size=max_size, fill_value=nat)[0] 1453 ) 1454 # if s in counts_dict.keys(): 1455 # species_index[i, : counts_dict[s]] = np.nonzero(species == s)[0] 1456 else: 1457 # species_index = { 1458 # PERIODIC_TABLE[s]: jnp.nonzero(species == s, size=c, fill_value=nat)[0] 1459 # for s, c in sizes.items() 1460 # } 1461 species_index = {} 1462 overflow = False 1463 natcount = 0 1464 for s, c in sizes.items(): 1465 mask = species == s 1466 new_size = jnp.sum(mask) 1467 natcount = natcount + new_size 1468 overflow = overflow | (new_size > c) # check if sizes are correct 1469 species_index[PERIODIC_TABLE[s]] = jnp.nonzero( 1470 species == s, size=c, fill_value=nat 1471 )[0] 1472 1473 mask = species <= 0 1474 new_size = jnp.sum(mask) 1475 natcount = natcount + new_size 1476 overflow = overflow | ( 1477 natcount < species.shape[0] 1478 ) # check if any species missing 1479 1480 return { 1481 **inputs, 1482 self.output_key: species_index, 1483 self.output_key + "_overflow": overflow, 1484 } 1485 1486 @partial(jax.jit, static_argnums=(0,)) 1487 def update_skin(self, inputs): 1488 return self.process(None, inputs) 1489 1490@dataclasses.dataclass(frozen=True) 
1491class BlockIndexer: 1492 """Build an index that splits atomic arrays by chemical blocks. 1493 1494 FPID: BLOCK_INDEXER 1495 1496 If `species_order` is specified, the output will be a dense array with size (len(species_order), max_size) that can directly index atomic arrays. 1497 If `species_order` is None, the output will be a dictionary with species as keys and an index to filter atomic arrays for that species as values. 1498 1499 """ 1500 1501 output_key: str = "block_index" 1502 """Key for the output dictionary.""" 1503 add_atoms: int = 0 1504 """Additional atoms to add to the sizes.""" 1505 add_atoms_margin: int = 10 1506 """Additional atoms to add to the sizes when adding margin.""" 1507 split_CNOPSSe: bool = False 1508 1509 FPID: ClassVar[str] = "BLOCK_INDEXER" 1510 1511 def init(self): 1512 return FrozenDict( 1513 { 1514 "sizes": {}, 1515 } 1516 ) 1517 1518 def build_chemical_blocks(self): 1519 _CHEMICAL_BLOCKS_NAMES = CHEMICAL_BLOCKS_NAMES.copy() 1520 if self.split_CNOPSSe: 1521 _CHEMICAL_BLOCKS_NAMES[1] = "C" 1522 _CHEMICAL_BLOCKS_NAMES.extend(["N","O","P","S","Se"]) 1523 _CHEMICAL_BLOCKS = CHEMICAL_BLOCKS.copy() 1524 if self.split_CNOPSSe: 1525 _CHEMICAL_BLOCKS[6] = 1 1526 _CHEMICAL_BLOCKS[7] = len(CHEMICAL_BLOCKS_NAMES) 1527 _CHEMICAL_BLOCKS[8] = len(CHEMICAL_BLOCKS_NAMES)+1 1528 _CHEMICAL_BLOCKS[15] = len(CHEMICAL_BLOCKS_NAMES)+2 1529 _CHEMICAL_BLOCKS[16] = len(CHEMICAL_BLOCKS_NAMES)+3 1530 _CHEMICAL_BLOCKS[34] = len(CHEMICAL_BLOCKS_NAMES)+4 1531 return _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS 1532 1533 def __call__(self, state, inputs, return_state_update=False, add_margin=False): 1534 _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS = self.build_chemical_blocks() 1535 1536 species = np.array(inputs["species"], dtype=np.int32) 1537 blocks = _CHEMICAL_BLOCKS[species] 1538 nat = species.shape[0] 1539 set_blocks, counts = np.unique(blocks, return_counts=True) 1540 1541 new_state = {**state} 1542 state_up = {} 1543 1544 sizes = state.get("sizes", 
FrozenDict({})) 1545 new_sizes = {**sizes} 1546 up_sizes = False 1547 for s, c in zip(set_blocks, counts): 1548 if s < 0: 1549 continue 1550 key = (s, _CHEMICAL_BLOCKS_NAMES[s]) 1551 if c > sizes.get(key, 0): 1552 up_sizes = True 1553 add_atoms = state.get("add_atoms", self.add_atoms) 1554 if add_margin: 1555 add_atoms += state.get("add_atoms_margin", self.add_atoms_margin) 1556 new_sizes[key] = c + add_atoms 1557 1558 new_sizes = FrozenDict(new_sizes) 1559 if up_sizes: 1560 state_up["sizes"] = (new_sizes, sizes) 1561 new_state["sizes"] = new_sizes 1562 1563 block_index = {n:None for n in _CHEMICAL_BLOCKS_NAMES} 1564 for (_,n), c in new_sizes.items(): 1565 block_index[n] = np.full(c, nat, dtype=np.int32) 1566 # block_index = { 1567 # n: np.full(c, nat, dtype=np.int32) 1568 # for (_,n), c in new_sizes.items() 1569 # } 1570 for s, c in zip(set_blocks, counts): 1571 if s < 0: 1572 continue 1573 block_index[_CHEMICAL_BLOCKS_NAMES[s]][:c] = np.nonzero(blocks == s)[0] 1574 1575 output = { 1576 **inputs, 1577 self.output_key: block_index, 1578 self.output_key + "_overflow": False, 1579 } 1580 1581 if return_state_update: 1582 return FrozenDict(new_state), output, state_up 1583 return FrozenDict(new_state), output 1584 1585 def check_reallocate(self, state, inputs, parent_overflow=False): 1586 """check for overflow and reallocate nblist if necessary""" 1587 overflow = parent_overflow or inputs[self.output_key + "_overflow"] 1588 if not overflow: 1589 return state, {}, inputs, False 1590 1591 add_margin = inputs[self.output_key + "_overflow"] 1592 state, inputs, state_up = self( 1593 state, inputs, return_state_update=True, add_margin=add_margin 1594 ) 1595 return state, state_up, inputs, True 1596 # return state, {}, inputs, parent_overflow 1597 1598 @partial(jax.jit, static_argnums=(0, 1)) 1599 def process(self, state, inputs): 1600 _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS = self.build_chemical_blocks() 1601 # assert ( 1602 # self.output_key in inputs 1603 # ), f"Species 
Index {self.output_key} must be provided on accelerator. Call the numpy routine (self.__call__) first." 1604 1605 recompute_species_index = "recompute_species_index" in inputs.get("flags", {}) 1606 if self.output_key in inputs and not recompute_species_index: 1607 return inputs 1608 1609 if state is None: 1610 raise ValueError("Block Indexer state must be provided on accelerator.") 1611 1612 species = inputs["species"] 1613 blocks = jnp.asarray(_CHEMICAL_BLOCKS)[species] 1614 nat = species.shape[0] 1615 1616 sizes = state["sizes"] 1617 1618 # species_index = { 1619 # PERIODIC_TABLE[s]: jnp.nonzero(species == s, size=c, fill_value=nat)[0] 1620 # for s, c in sizes.items() 1621 # } 1622 block_index = {n: None for n in _CHEMICAL_BLOCKS_NAMES} 1623 overflow = False 1624 natcount = 0 1625 for (s,name), c in sizes.items(): 1626 mask = blocks == s 1627 new_size = jnp.sum(mask) 1628 natcount = natcount + new_size 1629 overflow = overflow | (new_size > c) # check if sizes are correct 1630 block_index[name] = jnp.nonzero( 1631 mask, size=c, fill_value=nat 1632 )[0] 1633 1634 mask = blocks < 0 1635 new_size = jnp.sum(mask) 1636 natcount = natcount + new_size 1637 overflow = overflow | ( 1638 natcount < species.shape[0] 1639 ) # check if any species missing 1640 1641 return { 1642 **inputs, 1643 self.output_key: block_index, 1644 self.output_key + "_overflow": overflow, 1645 } 1646 1647 @partial(jax.jit, static_argnums=(0,)) 1648 def update_skin(self, inputs): 1649 return self.process(None, inputs) 1650 1651 1652@dataclasses.dataclass(frozen=True) 1653class AtomPadding: 1654 """Pad atomic arrays to a fixed size.""" 1655 1656 mult_size: float = 1.2 1657 """Multiplicative factor for resizing the atomic arrays.""" 1658 add_sys: int = 0 1659 1660 def init(self): 1661 return {"prev_nat": 0, "prev_nsys": 0} 1662 1663 def __call__(self, state, inputs: Dict) -> Union[dict, jax.Array]: 1664 species = inputs["species"] 1665 nat = species.shape[0] 1666 1667 prev_nat = 
state.get("prev_nat", 0) 1668 prev_nat_ = prev_nat 1669 if nat > prev_nat_: 1670 prev_nat_ = int(self.mult_size * nat) + 1 1671 1672 nsys = len(inputs["natoms"]) 1673 prev_nsys = state.get("prev_nsys", 0) 1674 prev_nsys_ = prev_nsys 1675 if nsys > prev_nsys_: 1676 prev_nsys_ = nsys + self.add_sys 1677 1678 add_atoms = prev_nat_ - nat 1679 add_sys = prev_nsys_ - nsys + 1 1680 output = {**inputs} 1681 if add_atoms > 0: 1682 for k, v in inputs.items(): 1683 if isinstance(v, np.ndarray) or isinstance(v, jax.Array): 1684 if v.shape[0] == nat: 1685 output[k] = np.append( 1686 v, 1687 np.zeros((add_atoms, *v.shape[1:]), dtype=v.dtype), 1688 axis=0, 1689 ) 1690 elif v.shape[0] == nsys: 1691 if k == "cells": 1692 output[k] = np.append( 1693 v, 1694 1000 1695 * np.eye(3, dtype=v.dtype)[None, :, :].repeat( 1696 add_sys, axis=0 1697 ), 1698 axis=0, 1699 ) 1700 else: 1701 output[k] = np.append( 1702 v, 1703 np.zeros((add_sys, *v.shape[1:]), dtype=v.dtype), 1704 axis=0, 1705 ) 1706 output["natoms"] = np.append( 1707 inputs["natoms"], np.zeros(add_sys, dtype=np.int32) 1708 ) 1709 output["species"] = np.append( 1710 species, -1 * np.ones(add_atoms, dtype=species.dtype) 1711 ) 1712 output["batch_index"] = np.append( 1713 inputs["batch_index"], 1714 np.array([output["natoms"].shape[0] - 1] * add_atoms, dtype=inputs["batch_index"].dtype), 1715 ) 1716 if "system_index" in inputs: 1717 output["system_index"] = np.append( 1718 inputs["system_index"], 1719 np.array([output["natoms"].shape[0] - 1] * add_sys, dtype=inputs["system_index"].dtype), 1720 ) 1721 1722 output["true_atoms"] = output["species"] > 0 1723 output["true_sys"] = np.arange(len(output["natoms"])) < nsys 1724 1725 state = {**state, "prev_nat": prev_nat_, "prev_nsys": prev_nsys_} 1726 1727 return FrozenDict(state), output 1728 1729 1730def atom_unpadding(inputs: Dict[str, Any]) -> Dict[str, Any]: 1731 """Remove padding from atomic arrays.""" 1732 if "true_atoms" not in inputs: 1733 return inputs 1734 1735 species = 
inputs["species"] 1736 true_atoms = inputs["true_atoms"] 1737 true_sys = inputs["true_sys"] 1738 natall = species.shape[0] 1739 nat = np.argmax(species <= 0) 1740 if nat == 0: 1741 return inputs 1742 1743 natoms = inputs["natoms"] 1744 nsysall = len(natoms) 1745 1746 output = {**inputs} 1747 for k, v in inputs.items(): 1748 if isinstance(v, jax.Array) or isinstance(v, np.ndarray): 1749 if v.ndim == 0: 1750 continue 1751 if v.shape[0] == natall: 1752 output[k] = v[true_atoms] 1753 elif v.shape[0] == nsysall: 1754 output[k] = v[true_sys] 1755 del output["true_sys"] 1756 del output["true_atoms"] 1757 return output 1758 1759 1760def check_input(inputs): 1761 """Check the input dictionary for required keys and types.""" 1762 assert "species" in inputs, "species must be provided" 1763 assert "coordinates" in inputs, "coordinates must be provided" 1764 species = inputs["species"].astype(np.int32) 1765 ifake = np.argmax(species <= 0) 1766 if ifake > 0: 1767 assert np.all(species[:ifake] > 0), "species must be positive" 1768 nat = inputs["species"].shape[0] 1769 1770 natoms = inputs.get("natoms", np.array([nat], dtype=np.int32)).astype(np.int32) 1771 batch_index = inputs.get( 1772 "batch_index", np.repeat(np.arange(len(natoms), dtype=np.int32), natoms) 1773 ).astype(np.int32) 1774 output = {**inputs, "natoms": natoms, "batch_index": batch_index} 1775 if "cells" in inputs: 1776 cells = inputs["cells"] 1777 if "reciprocal_cells" not in inputs: 1778 reciprocal_cells = np.linalg.inv(cells) 1779 else: 1780 reciprocal_cells = inputs["reciprocal_cells"] 1781 if cells.ndim == 2: 1782 cells = cells[None, :, :] 1783 if reciprocal_cells.ndim == 2: 1784 reciprocal_cells = reciprocal_cells[None, :, :] 1785 output["cells"] = cells 1786 output["reciprocal_cells"] = reciprocal_cells 1787 1788 return output 1789 1790 1791def convert_to_jax(data): 1792 """Convert a numpy arrays to jax arrays in a pytree.""" 1793 1794 def convert(x): 1795 if isinstance(x, np.ndarray): 1796 # if x.dtype == 
np.float64: 1797 # return jnp.asarray(x, dtype=jnp.float32) 1798 return jnp.asarray(x) 1799 return x 1800 1801 return jax.tree_util.tree_map(convert, data) 1802 1803 1804class JaxConverter(nn.Module): 1805 """Convert numpy arrays to jax arrays in a pytree.""" 1806 1807 def __call__(self, data): 1808 return convert_to_jax(data) 1809 1810 1811@dataclasses.dataclass(frozen=True) 1812class PreprocessingChain: 1813 """Chain of preprocessing layers.""" 1814 1815 layers: Tuple[Callable[..., Dict[str, Any]]] 1816 """Preprocessing layers.""" 1817 use_atom_padding: bool = False 1818 """Add an AtomPadding layer at the beginning of the chain.""" 1819 atom_padder: AtomPadding = AtomPadding() 1820 """AtomPadding layer.""" 1821 1822 def __post_init__(self): 1823 if not isinstance(self.layers, Sequence): 1824 raise ValueError( 1825 f"'layers' must be a sequence, got '{type(self.layers).__name__}'." 1826 ) 1827 if not self.layers: 1828 raise ValueError(f"Error: no Preprocessing layers were provided.") 1829 1830 def __call__(self, state, inputs: Dict[str, Any]) -> Dict[str, Any]: 1831 do_check_input = state.get("check_input", True) 1832 if do_check_input: 1833 inputs = check_input(inputs) 1834 new_state = [] 1835 layer_state = state["layers_state"] 1836 i = 0 1837 if self.use_atom_padding: 1838 s, inputs = self.atom_padder(layer_state[0], inputs) 1839 new_state.append(s) 1840 i += 1 1841 for layer in self.layers: 1842 s, inputs = layer(layer_state[i], inputs, return_state_update=False) 1843 new_state.append(s) 1844 i += 1 1845 return FrozenDict({**state, "layers_state": tuple(new_state)}), convert_to_jax( 1846 inputs 1847 ) 1848 1849 def check_reallocate(self, state, inputs): 1850 new_state = [] 1851 state_up = [] 1852 layer_state = state["layers_state"] 1853 i = 0 1854 if self.use_atom_padding: 1855 new_state.append(layer_state[0]) 1856 i += 1 1857 parent_overflow = False 1858 for layer in self.layers: 1859 s, s_up, inputs, parent_overflow = layer.check_reallocate( 1860 
layer_state[i], inputs, parent_overflow 1861 ) 1862 new_state.append(s) 1863 state_up.append(s_up) 1864 i += 1 1865 1866 if not parent_overflow: 1867 return state, {}, inputs, False 1868 return ( 1869 FrozenDict({**state, "layers_state": tuple(new_state)}), 1870 state_up, 1871 inputs, 1872 True, 1873 ) 1874 1875 def atom_padding(self, state, inputs): 1876 if self.use_atom_padding: 1877 padder_state = state["layers_state"][0] 1878 return self.atom_padder(padder_state, inputs) 1879 return state, inputs 1880 1881 @partial(jax.jit, static_argnums=(0, 1)) 1882 def process(self, state, inputs): 1883 layer_state = state["layers_state"] 1884 i = 1 if self.use_atom_padding else 0 1885 for layer in self.layers: 1886 inputs = layer.process(layer_state[i], inputs) 1887 i += 1 1888 return inputs 1889 1890 @partial(jax.jit, static_argnums=(0)) 1891 def update_skin(self, inputs): 1892 for layer in self.layers: 1893 inputs = layer.update_skin(inputs) 1894 return inputs 1895 1896 def init(self): 1897 state = [] 1898 if self.use_atom_padding: 1899 state.append(self.atom_padder.init()) 1900 for layer in self.layers: 1901 state.append(layer.init()) 1902 return FrozenDict({"check_input": True, "layers_state": state}) 1903 1904 def init_with_output(self, inputs): 1905 state = self.init() 1906 return self(state, inputs) 1907 1908 def get_processors(self): 1909 processors = [] 1910 for layer in self.layers: 1911 if hasattr(layer, "get_processor"): 1912 processors.append(layer.get_processor()) 1913 return processors 1914 1915 def get_graphs_properties(self): 1916 properties = {} 1917 for layer in self.layers: 1918 if hasattr(layer, "get_graph_properties"): 1919 properties = deep_update(properties, layer.get_graph_properties()) 1920 return properties 1921 1922 1923# PREPROCESSING = { 1924# "GRAPH": GraphGenerator, 1925# # "GRAPH_FIXED": GraphGeneratorFixed, 1926# "GRAPH_FILTER": GraphFilter, 1927# "GRAPH_ANGULAR_EXTENSION": GraphAngularExtension, 1928# # "GRAPH_DENSE_EXTENSION": 
GraphDenseExtension, 1929# "SPECIES_INDEXER": SpeciesIndexer, 1930# }
@dataclasses.dataclass(frozen=True)
class GraphGenerator:
    """Generate a graph from a set of coordinates

    FPID: GRAPH

    For now, we generate all pairs of atoms and filter based on cutoff.
    If a `nblist_skin` is present in the state, we generate a second graph with a larger cutoff that includes all pairs within the cutoff+skin. This graph is then reused by the `update_skin` method to update the original graph without recomputing the full nblist.
    """

    cutoff: float
    """Cutoff distance for the graph."""
    graph_key: str = "graph"
    """Key of the graph in the outputs."""
    switch_params: dict = dataclasses.field(default_factory=dict, hash=False)
    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""
    kmax: int = 30
    """Maximum number of k-points to consider."""
    kthr: float = 1e-6
    """Threshold for k-point filtering."""
    k_space: bool = False
    """Whether to generate k-space information for the graph."""
    mult_size: float = 1.05
    """Multiplicative factor for resizing the nblist."""
    # covalent_cutoff: bool = False

    FPID: ClassVar[str] = "GRAPH"

    def init(self):
        """Return the initial preprocessing state.

        Buffer sizes ``max_nat`` and ``npairs`` start at 1 and are grown by
        ``__call__``; ``nblist_mult_size`` is initialised from ``mult_size``.
        """
        return FrozenDict(
            {
                "max_nat": 1,
                "npairs": 1,
                "nblist_mult_size": self.mult_size,
            }
        )

    def get_processor(self) -> Tuple[nn.Module, Dict]:
        """Return the ``GraphProcessor`` class and the kwargs used to build it."""
        return GraphProcessor, {
            "cutoff": self.cutoff,
            "graph_key": self.graph_key,
            "switch_params": self.switch_params,
            "name": f"{self.graph_key}_Processor",
        }

    def get_graph_properties(self):
        """Return the static properties of the generated graph, keyed by ``graph_key``."""
        return {
            self.graph_key: {
                "cutoff": self.cutoff,
                "directed": True,
            }
        }

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """Build a nblist on CPU with numpy (dynamic shapes) and record the max shapes.

        Args:
            state: preprocessing state mapping. Keys read here include
                ``nblist_mult_size``, ``max_nat``, ``add_atoms``, ``nblist_skin``,
                ``minimum_image``, ``num_repeats_pbc``, ``nshifts_pbc``,
                ``nvec_pbc``, ``npairs`` and ``npairs_skin``.
            inputs: dict with ``coordinates``, ``natoms``, ``batch_index`` and
                optionally ``cells``/``reciprocal_cells``, ``true_sys`` and an
                existing graph stored under ``graph_key``.
            return_state_update: if True, also return a dict mapping every
                resized state key to ``(new_value, previous_value)``.
            add_margin: if True, grow the buffers by ``mult_size`` even when the
                current sizes did not overflow.

        Returns:
            ``(new_state, outputs)``, or ``(new_state, outputs, state_update)``
            when ``return_state_update`` is True. If the existing graph carries
            ``keep_graph``, inputs are returned unchanged (with an empty state
            update when one is requested).
        """
        if self.graph_key in inputs:
            graph = inputs[self.graph_key]
            if "keep_graph" in graph:
                # Keep the return arity consistent with return_state_update:
                # check_reallocate unpacks three values from this call.
                if return_state_update:
                    return state, inputs, {}
                return state, inputs

        coords = np.array(inputs["coordinates"], dtype=np.float32)
        natoms = np.array(inputs["natoms"], dtype=np.int32)
        batch_index = np.array(inputs["batch_index"], dtype=np.int32)

        new_state = {**state}
        state_up = {}

        mult_size = state.get("nblist_mult_size", self.mult_size)
        assert mult_size >= 1.0, "mult_size should be larger or equal than 1.0"

        if natoms.shape[0] == 1:
            max_nat = coords.shape[0]
            true_max_nat = max_nat
        else:
            max_nat = state.get("max_nat", round(coords.shape[0] / natoms.shape[0]))
            true_max_nat = int(np.max(natoms))
            if true_max_nat > max_nat:
                add_atoms = state.get("add_atoms", 0)
                new_maxnat = true_max_nat + add_atoms
                state_up["max_nat"] = (new_maxnat, max_nat)
                new_state["max_nat"] = new_maxnat

        cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0)

        ### compute indices of all pairs
        p1, p2 = np.triu_indices(true_max_nat, 1)
        p1, p2 = p1.astype(np.int32), p2.astype(np.int32)
        pbc_shifts = None
        if natoms.shape[0] > 1:
            ## batching => mask irrelevant pairs (index -1 marks a masked pair)
            mask_p12 = (
                (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None])
            ).flatten()
            shift = np.concatenate(
                (np.array([0], dtype=np.int32), np.cumsum(natoms[:-1], dtype=np.int32))
            )
            p1 = np.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1)
            p2 = np.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1)

        apply_pbc = "cells" in inputs
        if not apply_pbc:
            ### NO PBC
            vec = coords[p2] - coords[p1]
        else:
            cells = np.array(inputs["cells"], dtype=np.float32)
            reciprocal_cells = np.array(inputs["reciprocal_cells"], dtype=np.float32)
            minimage = state.get("minimum_image", True)
            if minimage:
                ## MINIMUM IMAGE CONVENTION
                vec = coords[p2] - coords[p1]
                if cells.shape[0] == 1:
                    vecpbc = np.dot(vec, reciprocal_cells[0])
                    pbc_shifts = -np.round(vecpbc)
                    vec = vec + np.dot(pbc_shifts, cells[0])
                else:
                    batch_index_vec = batch_index[p1]
                    vecpbc = np.einsum(
                        "aj,aji->ai", vec, reciprocal_cells[batch_index_vec]
                    )
                    pbc_shifts = -np.round(vecpbc)
                    vec = vec + np.einsum(
                        "aj,aji->ai", pbc_shifts, cells[batch_index_vec]
                    )
            else:
                ### GENERAL PBC
                ## put all atoms in central box
                if cells.shape[0] == 1:
                    coords_pbc = np.dot(coords, reciprocal_cells[0])
                    at_shifts = -np.floor(coords_pbc)
                    coords_pbc = coords + np.dot(at_shifts, cells[0])
                else:
                    coords_pbc = np.einsum(
                        "aj,aji->ai", coords, reciprocal_cells[batch_index]
                    )
                    at_shifts = -np.floor(coords_pbc)
                    coords_pbc = coords + np.einsum(
                        "aj,aji->ai", at_shifts, cells[batch_index]
                    )
                vec = coords_pbc[p2] - coords_pbc[p1]

                ## compute maximum number of repeats
                inv_distances = (np.sum(reciprocal_cells**2, axis=1)) ** 0.5
                cdinv = cutoff_skin * inv_distances
                num_repeats_all = np.ceil(cdinv).astype(np.int32)
                if "true_sys" in inputs:
                    num_repeats_all = np.where(
                        np.array(inputs["true_sys"], dtype=bool)[:, None],
                        num_repeats_all,
                        0,
                    )
                num_repeats = np.max(num_repeats_all, axis=0)
                num_repeats_prev = np.array(state.get("num_repeats_pbc", (0, 0, 0)))
                if np.any(num_repeats > num_repeats_prev):
                    num_repeats_new = np.maximum(num_repeats, num_repeats_prev)
                    state_up["num_repeats_pbc"] = (
                        tuple(num_repeats_new),
                        tuple(num_repeats_prev),
                    )
                    new_state["num_repeats_pbc"] = tuple(num_repeats_new)
                ## build all possible shifts
                cell_shift_pbc = np.array(
                    np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]),
                    dtype=cells.dtype,
                ).T.reshape(-1, 3)
                ## shift applied to vectors
                if cells.shape[0] == 1:
                    dvec = np.dot(cell_shift_pbc, cells[0])[None, :, :]
                    vec = (vec[:, None, :] + dvec).reshape(-1, 3)
                    pbc_shifts = np.broadcast_to(
                        cell_shift_pbc[None, :, :],
                        (p1.shape[0], cell_shift_pbc.shape[0], 3),
                    ).reshape(-1, 3)
                    p1 = np.broadcast_to(
                        p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
                    p2 = np.broadcast_to(
                        p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
                    if natoms.shape[0] > 1:
                        mask_p12 = np.broadcast_to(
                            mask_p12[:, None],
                            (mask_p12.shape[0], cell_shift_pbc.shape[0]),
                        ).flatten()
                else:
                    dvec = np.einsum("bj,sji->sbi", cell_shift_pbc, cells)

                    ## get pbc shifts specific to each box
                    cell_shift_pbc = np.broadcast_to(
                        cell_shift_pbc[None, :, :],
                        (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3),
                    )
                    mask = np.all(
                        np.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1
                    ).flatten()
                    idx = np.nonzero(mask)[0]
                    nshifts = idx.shape[0]
                    nshifts_prev = state.get("nshifts_pbc", 0)
                    if nshifts > nshifts_prev or add_margin:
                        nshifts_new = int(mult_size * max(nshifts, nshifts_prev)) + 1
                        state_up["nshifts_pbc"] = (nshifts_new, nshifts_prev)
                        new_state["nshifts_pbc"] = nshifts_new

                    dvec_filter = dvec.reshape(-1, 3)[idx, :]
                    cell_shift_pbc_filter = cell_shift_pbc.reshape(-1, 3)[idx, :]

                    ## get batch shift in the dvec_filter array
                    nrep = np.prod(2 * num_repeats_all + 1, axis=1)
                    bshift = np.concatenate((np.array([0]), np.cumsum(nrep)[:-1]))

                    ## compute vectors: each pair is repeated once per shift of its box
                    batch_index_vec = batch_index[p1]
                    nrep_vec = np.where(mask_p12, nrep[batch_index_vec], 0)
                    vec = vec.repeat(nrep_vec, axis=0)
                    nvec_pbc = nrep_vec.sum()
                    nvec_pbc_prev = state.get("nvec_pbc", 0)
                    if nvec_pbc > nvec_pbc_prev or add_margin:
                        nvec_pbc_new = int(mult_size * max(nvec_pbc, nvec_pbc_prev)) + 1
                        state_up["nvec_pbc"] = (nvec_pbc_new, nvec_pbc_prev)
                        new_state["nvec_pbc"] = nvec_pbc_new

                    ## get shift index (position within the pair's block + box offset)
                    dshift = np.concatenate(
                        (np.array([0]), np.cumsum(nrep_vec)[:-1])
                    ).repeat(nrep_vec)
                    icellshift = (
                        np.arange(dshift.shape[0])
                        - dshift
                        + bshift[batch_index_vec].repeat(nrep_vec)
                    )
                    # shift vectors
                    vec = vec + dvec_filter[icellshift]
                    pbc_shifts = cell_shift_pbc_filter[icellshift]

                    p1 = np.repeat(p1, nrep_vec)
                    p2 = np.repeat(p2, nrep_vec)
                    if natoms.shape[0] > 1:
                        mask_p12 = np.repeat(mask_p12, nrep_vec)

        ## compute distances
        d12 = (vec**2).sum(axis=-1)
        if natoms.shape[0] > 1:
            # masked pairs are pushed to the cutoff so they get filtered out below
            d12 = np.where(mask_p12, d12, cutoff_skin**2)

        ## filter pairs
        max_pairs = state.get("npairs", 1)
        mask = d12 < cutoff_skin**2
        idx = np.nonzero(mask)[0]
        npairs = idx.shape[0]
        if npairs > max_pairs or add_margin:
            prev_max_pairs = max_pairs
            max_pairs = int(mult_size * max(npairs, max_pairs)) + 1
            state_up["npairs"] = (max_pairs, prev_max_pairs)
            new_state["npairs"] = max_pairs

        # padded entries point to index nat (one past the last atom)
        nat = coords.shape[0]
        edge_src = np.full(max_pairs, nat, dtype=np.int32)
        edge_dst = np.full(max_pairs, nat, dtype=np.int32)
        d12_ = np.full(max_pairs, cutoff_skin**2)
        edge_src[:npairs] = p1[idx]
        edge_dst[:npairs] = p2[idx]
        d12_[:npairs] = d12[idx]
        d12 = d12_

        if apply_pbc:
            pbc_shifts_ = np.zeros((max_pairs, 3))
            pbc_shifts_[:npairs] = pbc_shifts[idx]
            pbc_shifts = pbc_shifts_
            if not minimage:
                # fold back the shifts used to put atoms in the central box
                pbc_shifts[:npairs] = (
                    pbc_shifts[:npairs]
                    + at_shifts[edge_dst[:npairs]]
                    - at_shifts[edge_src[:npairs]]
                )

        if "nblist_skin" in state:
            # keep the cutoff+skin graph and extract the pairs within the bare cutoff
            edge_src_skin = edge_src
            edge_dst_skin = edge_dst
            if apply_pbc:
                pbc_shifts_skin = pbc_shifts
            max_pairs_skin = state.get("npairs_skin", 1)
            mask = d12 < self.cutoff**2
            idx = np.nonzero(mask)[0]
            npairs_skin = idx.shape[0]
            if npairs_skin > max_pairs_skin or add_margin:
                prev_max_pairs_skin = max_pairs_skin
                max_pairs_skin = int(mult_size * max(npairs_skin, max_pairs_skin)) + 1
                state_up["npairs_skin"] = (max_pairs_skin, prev_max_pairs_skin)
                new_state["npairs_skin"] = max_pairs_skin
            edge_src = np.full(max_pairs_skin, nat, dtype=np.int32)
            edge_dst = np.full(max_pairs_skin, nat, dtype=np.int32)
            d12_ = np.full(max_pairs_skin, self.cutoff**2)
            edge_src[:npairs_skin] = edge_src_skin[idx]
            edge_dst[:npairs_skin] = edge_dst_skin[idx]
            d12_[:npairs_skin] = d12[idx]
            d12 = d12_
            if apply_pbc:
                pbc_shifts = np.full((max_pairs_skin, 3), 0.0)
                pbc_shifts[:npairs_skin] = pbc_shifts_skin[idx]

        ## symmetrize
        edge_src, edge_dst = np.concatenate((edge_src, edge_dst)), np.concatenate(
            (edge_dst, edge_src)
        )
        d12 = np.concatenate((d12, d12))
        if apply_pbc:
            pbc_shifts = np.concatenate((pbc_shifts, -pbc_shifts))

        graph = inputs.get(self.graph_key, {})
        graph_out = {
            **graph,
            "edge_src": edge_src,
            "edge_dst": edge_dst,
            "d12": d12,
            "overflow": False,
            "pbc_shifts": pbc_shifts,
        }
        if "nblist_skin" in state:
            graph_out["edge_src_skin"] = edge_src_skin
            graph_out["edge_dst_skin"] = edge_dst_skin
            if apply_pbc:
                graph_out["pbc_shifts_skin"] = pbc_shifts_skin

        if self.k_space and apply_pbc:
            if "k_points" not in graph:
                ks, _, _, bewald = get_reciprocal_space_parameters(
                    reciprocal_cells, self.cutoff, self.kmax, self.kthr
                )
                graph_out["k_points"] = ks
                graph_out["b_ewald"] = bewald

        output = {**inputs, self.graph_key: graph_out}

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """Check for overflow and reallocate the nblist if necessary.

        Returns:
            ``(state, state_update, inputs, reallocated)``. When neither
            ``parent_overflow`` nor the graph's ``overflow`` flag is set, the
            inputs are returned unchanged with an empty update and ``False``.
            Otherwise the graph is rebuilt on CPU via ``__call__``; the graph's
            own ``overflow`` flag is forwarded as ``add_margin``.
        """
        overflow = parent_overflow or inputs[self.graph_key].get("overflow", False)
        if not overflow:
            return state, {}, inputs, False

        add_margin = inputs[self.graph_key].get("overflow", False)
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """Build a nblist on the accelerator with JAX and precomputed shapes.

        ``state`` is a static (hashable) argument of the jitted function; all
        buffer sizes (``max_nat``, ``npairs``, ``npairs_skin``, ``nshifts_pbc``,
        ``nvec_pbc``, ``num_repeats_pbc``) are read from it. When any of them
        is too small, the graph's ``overflow`` flag is set so that
        ``check_reallocate`` can resize them on CPU.
        """
        if self.graph_key in inputs:
            graph = inputs[self.graph_key]
            if "keep_graph" in graph:
                return inputs
        coords = inputs["coordinates"]
        natoms = inputs["natoms"]
        batch_index = inputs["batch_index"]

        if natoms.shape[0] == 1:
            max_nat = coords.shape[0]
        else:
            max_nat = state.get(
                "max_nat", int(round(coords.shape[0] / natoms.shape[0]))
            )
        cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0)

        ### compute indices of all pairs
        p1, p2 = np.triu_indices(max_nat, 1)
        p1, p2 = p1.astype(np.int32), p2.astype(np.int32)
        pbc_shifts = None
        if natoms.shape[0] > 1:
            ## batching => mask irrelevant pairs
            mask_p12 = (
                (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None])
            ).flatten()
            shift = jnp.concatenate(
                (jnp.array([0], dtype=jnp.int32), jnp.cumsum(natoms[:-1]))
            )
            p1 = jnp.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1)
            p2 = jnp.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1)

        ## compute vectors
        overflow_repeats = jnp.asarray(False, dtype=bool)
        if "cells" not in inputs:
            vec = coords[p2] - coords[p1]
        else:
            cells = inputs["cells"]
            reciprocal_cells = inputs["reciprocal_cells"]
            minimage = state.get("minimum_image", True)

            def compute_pbc(vec, reciprocal_cell, cell, mode="round"):
                vecpbc = jnp.dot(vec, reciprocal_cell)
                if mode == "round":
                    pbc_shifts = -jnp.round(vecpbc)
                elif mode == "floor":
                    pbc_shifts = -jnp.floor(vecpbc)
                else:
                    raise NotImplementedError(f"Unknown mode {mode} for compute_pbc.")
                return vec + jnp.dot(pbc_shifts, cell), pbc_shifts

            if minimage:
                ## minimum image convention
                vec = coords[p2] - coords[p1]

                if cells.shape[0] == 1:
                    vec, pbc_shifts = compute_pbc(vec, reciprocal_cells[0], cells[0])
                else:
                    batch_index_vec = batch_index[p1]
                    vec, pbc_shifts = jax.vmap(compute_pbc)(
                        vec, reciprocal_cells[batch_index_vec], cells[batch_index_vec]
                    )
            else:
                ### general PBC (single cell, or one cell per system via vmap)
                ## put all atoms in central box
                if cells.shape[0] == 1:
                    coords_pbc, at_shifts = compute_pbc(
                        coords, reciprocal_cells[0], cells[0], mode="floor"
                    )
                else:
                    coords_pbc, at_shifts = jax.vmap(
                        partial(compute_pbc, mode="floor")
                    )(coords, reciprocal_cells[batch_index], cells[batch_index])
                vec = coords_pbc[p2] - coords_pbc[p1]
                num_repeats = state.get("num_repeats_pbc", (0, 0, 0))
                # check if num_repeats is larger than previous
                inv_distances = jnp.linalg.norm(reciprocal_cells, axis=1)
                cdinv = cutoff_skin * inv_distances
                num_repeats_all = jnp.ceil(cdinv).astype(jnp.int32)
                if "true_sys" in inputs:
                    num_repeats_all = jnp.where(
                        inputs["true_sys"][:, None], num_repeats_all, 0
                    )
                num_repeats_new = jnp.max(num_repeats_all, axis=0)
                overflow_repeats = jnp.any(num_repeats_new > jnp.asarray(num_repeats))

                # built with numpy because num_repeats comes from the static state
                cell_shift_pbc = jnp.asarray(
                    np.array(
                        np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]),
                        dtype=cells.dtype,
                    ).T.reshape(-1, 3)
                )

                if cells.shape[0] == 1:
                    vec = (
                        vec[:, None, :] + jnp.dot(cell_shift_pbc, cells[0])[None, :, :]
                    ).reshape(-1, 3)
                    pbc_shifts = jnp.broadcast_to(
                        cell_shift_pbc[None, :, :],
                        (p1.shape[0], cell_shift_pbc.shape[0], 3),
                    ).reshape(-1, 3)
                    p1 = jnp.broadcast_to(
                        p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
                    p2 = jnp.broadcast_to(
                        p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
                    if natoms.shape[0] > 1:
                        mask_p12 = jnp.broadcast_to(
                            mask_p12[:, None], (mask_p12.shape[0], cell_shift_pbc.shape[0])
                        ).flatten()
                else:
                    dvec = jnp.einsum("bj,sji->sbi", cell_shift_pbc, cells).reshape(-1, 3)

                    ## get pbc shifts specific to each box
                    cell_shift_pbc = jnp.broadcast_to(
                        cell_shift_pbc[None, :, :],
                        (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3),
                    )
                    mask = jnp.all(
                        jnp.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1
                    ).flatten()
                    max_shifts = state.get("nshifts_pbc", 1)

                    cell_shift_pbc = cell_shift_pbc.reshape(-1, 3)
                    shiftx, shifty, shiftz = (
                        cell_shift_pbc[:, 0],
                        cell_shift_pbc[:, 1],
                        cell_shift_pbc[:, 2],
                    )
                    dvecx, dvecy, dvecz = dvec[:, 0], dvec[:, 1], dvec[:, 2]
                    (dvecx, dvecy, dvecz, shiftx, shifty, shiftz), scatter_idx, nshifts = (
                        mask_filter_1d(
                            mask,
                            max_shifts,
                            (dvecx, 0.0),
                            (dvecy, 0.0),
                            (dvecz, 0.0),
                            (shiftx, 0),
                            (shifty, 0),
                            (shiftz, 0),
                        )
                    )
                    dvec = jnp.stack((dvecx, dvecy, dvecz), axis=-1)
                    cell_shift_pbc = jnp.stack((shiftx, shifty, shiftz), axis=-1)
                    overflow_repeats = overflow_repeats | (nshifts > max_shifts)

                    ## get batch shift in the dvec_filter array
                    nrep = jnp.prod(2 * num_repeats_all + 1, axis=1)
                    bshift = jnp.concatenate(
                        (jnp.array([0], dtype=jnp.int32), jnp.cumsum(nrep)[:-1])
                    )

                    ## repeat vectors
                    nvec_max = state.get("nvec_pbc", 1)
                    batch_index_vec = batch_index[p1]
                    nrep_vec = jnp.where(mask_p12, nrep[batch_index_vec], 0)
                    nvec = nrep_vec.sum()
                    overflow_repeats = overflow_repeats | (nvec > nvec_max)
                    vec = jnp.repeat(vec, nrep_vec, axis=0, total_repeat_length=nvec_max)

                    ## get shift index
                    dshift = jnp.concatenate(
                        (jnp.array([0], dtype=jnp.int32), jnp.cumsum(nrep_vec)[:-1])
                    )
                    if nrep_vec.size == 0:
                        dshift = jnp.array([], dtype=jnp.int32)
                    dshift = jnp.repeat(dshift, nrep_vec, total_repeat_length=nvec_max)
                    bshift = jnp.repeat(
                        bshift[batch_index_vec], nrep_vec, total_repeat_length=nvec_max
                    )
                    icellshift = jnp.arange(dshift.shape[0]) - dshift + bshift
                    vec = vec + dvec[icellshift]
                    pbc_shifts = cell_shift_pbc[icellshift]
                    p1 = jnp.repeat(p1, nrep_vec, total_repeat_length=nvec_max)
                    p2 = jnp.repeat(p2, nrep_vec, total_repeat_length=nvec_max)
                    if natoms.shape[0] > 1:
                        mask_p12 = jnp.repeat(
                            mask_p12, nrep_vec, total_repeat_length=nvec_max
                        )

        ## compute distances
        d12 = (vec**2).sum(axis=-1)
        if natoms.shape[0] > 1:
            d12 = jnp.where(mask_p12, d12, cutoff_skin**2)

        ## filter pairs
        max_pairs = state.get("npairs", 1)
        mask = d12 < cutoff_skin**2
        (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d(
            mask,
            max_pairs,
            (jnp.asarray(p1, dtype=jnp.int32), coords.shape[0]),
            (jnp.asarray(p2, dtype=jnp.int32), coords.shape[0]),
            (d12, cutoff_skin**2),
        )
        if "cells" in inputs:
            pbc_shifts = (
                jnp.full((max_pairs, 3), 0.0, dtype=pbc_shifts.dtype)
                .at[scatter_idx]
                .set(pbc_shifts, mode="drop")
            )
            if not minimage:
                # mode="fill" is required for fill_value to take effect on the
                # padded indices (== coords.shape[0], out of bounds of at_shifts)
                pbc_shifts = (
                    pbc_shifts
                    + at_shifts.at[edge_dst].get(mode="fill", fill_value=0.0)
                    - at_shifts.at[edge_src].get(mode="fill", fill_value=0.0)
                )

        ## check for overflow
        if natoms.shape[0] == 1:
            true_max_nat = coords.shape[0]
        else:
            true_max_nat = jnp.max(natoms)
        overflow_count = npairs > max_pairs
        overflow_at = true_max_nat > max_nat
        overflow = overflow_count | overflow_at | overflow_repeats

        if "nblist_skin" in state:
            edge_src_skin = edge_src
            edge_dst_skin = edge_dst
            if "cells" in inputs:
                pbc_shifts_skin = pbc_shifts
            max_pairs_skin = state.get("npairs_skin", 1)
            mask = d12 < self.cutoff**2
            (edge_src, edge_dst, d12), scatter_idx, npairs_skin = mask_filter_1d(
                mask,
                max_pairs_skin,
                (edge_src, coords.shape[0]),
                (edge_dst, coords.shape[0]),
                (d12, self.cutoff**2),
            )
            if "cells" in inputs:
                pbc_shifts = (
                    jnp.full((max_pairs_skin, 3), 0.0, dtype=pbc_shifts.dtype)
                    .at[scatter_idx]
                    .set(pbc_shifts, mode="drop")
                )
            overflow = overflow | (npairs_skin > max_pairs_skin)

        ## symmetrize
        edge_src, edge_dst = jnp.concatenate((edge_src, edge_dst)), jnp.concatenate(
            (edge_dst, edge_src)
        )
        d12 = jnp.concatenate((d12, d12))
        if "cells" in inputs:
            pbc_shifts = jnp.concatenate((pbc_shifts, -pbc_shifts))

        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
        graph_out = {
            **graph,
            "edge_src": edge_src,
            "edge_dst": edge_dst,
            "d12": d12,
            "overflow": overflow,
            "pbc_shifts": pbc_shifts,
        }
        if "nblist_skin" in state:
            graph_out["edge_src_skin"] = edge_src_skin
            graph_out["edge_dst_skin"] = edge_dst_skin
            if "cells" in inputs:
                graph_out["pbc_shifts_skin"] = pbc_shifts_skin

        if self.k_space and "cells" in inputs:
            if "k_points" not in graph:
                raise NotImplementedError(
                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
                )
        return {**inputs, self.graph_key: graph_out}

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        """Update the nblist without recomputing the full nblist.

        Recomputes distances for the stored cutoff+skin pairs
        (``edge_src_skin``/``edge_dst_skin``) with the current coordinates and
        keeps those within ``cutoff``. The number of kept pairs is bounded by
        the existing ``edge_src`` buffer (half its length, before
        symmetrization); ``overflow`` is set if that bound is exceeded.
        """
        graph = inputs[self.graph_key]

        edge_src_skin = graph["edge_src_skin"]
        edge_dst_skin = graph["edge_dst_skin"]
        coords = inputs["coordinates"]
        # padded indices gather cutoff - 0.0 per component, so they fail the cutoff test
        vec = coords.at[edge_dst_skin].get(
            mode="fill", fill_value=self.cutoff
        ) - coords.at[edge_src_skin].get(mode="fill", fill_value=0.0)

        if "cells" in inputs:
            pbc_shifts_skin = graph["pbc_shifts_skin"]
            cells = inputs["cells"]
            if cells.shape[0] == 1:
                vec = vec + jnp.dot(pbc_shifts_skin, cells[0])
            else:
                batch_index_vec = inputs["batch_index"][edge_src_skin]
                vec = vec + jax.vmap(jnp.dot)(pbc_shifts_skin, cells[batch_index_vec])

        nat = coords.shape[0]
        d12 = jnp.sum(vec**2, axis=-1)
        mask = d12 < self.cutoff**2
        max_pairs = graph["edge_src"].shape[0] // 2
        (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d(
            mask,
            max_pairs,
            (edge_src_skin, nat),
            (edge_dst_skin, nat),
            (d12, self.cutoff**2),
        )
        if "cells" in inputs:
            pbc_shifts = (
                jnp.full((max_pairs, 3), 0.0, dtype=pbc_shifts_skin.dtype)
                .at[scatter_idx]
                .set(pbc_shifts_skin, mode="drop")
            )

        overflow = graph.get("overflow", False) | (npairs > max_pairs)
        graph_out = {
            **graph,
            "edge_src": jnp.concatenate((edge_src, edge_dst)),
            "edge_dst": jnp.concatenate((edge_dst, edge_src)),
            "d12": jnp.concatenate((d12, d12)),
            "overflow": overflow,
        }
        if "cells" in inputs:
            graph_out["pbc_shifts"] = jnp.concatenate((pbc_shifts, -pbc_shifts))

        if self.k_space and "cells" in inputs:
            if "k_points" not in graph:
                raise NotImplementedError(
                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
                )

        return {**inputs, self.graph_key: graph_out}
Generate a graph from a set of coordinates
FPID: GRAPH
For now, we generate all pairs of atoms and filter based on cutoff.
If a `nblist_skin` is present in the state, we generate a second graph with a larger cutoff that includes all pairs within the cutoff+skin. This graph is then reused by the `update_skin` method to update the original graph without recomputing the full nblist.
Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`.
def check_reallocate(self, state, inputs, parent_overflow=False):
    """Rebuild the nblist on CPU when an overflow was flagged.

    Returns a 4-tuple ``(state, state_update, inputs, reallocated)``. Without
    an overflow (neither ``parent_overflow`` nor the graph's ``overflow``
    flag), the inputs pass through untouched with an empty update and
    ``False``. Otherwise ``self`` is called with ``return_state_update=True``
    and the graph's own overflow flag as ``add_margin``.
    """
    graph_flag = inputs[self.graph_key].get("overflow", False)
    # short-circuits exactly like `not (parent_overflow or graph_flag)`
    if not parent_overflow and not graph_flag:
        return state, {}, inputs, False

    rebuilt_state, rebuilt_inputs, resized = self(
        state, inputs, return_state_update=True, add_margin=graph_flag
    )
    return rebuilt_state, resized, rebuilt_inputs, True
Check for overflow and reallocate the nblist if necessary.
@partial(jax.jit, static_argnums=(0, 1))
def process(self, state, inputs):
    """Build a nblist on the accelerator with JAX and precomputed shapes.

    ``state`` is a static (hashable) argument of the jitted function; all
    buffer sizes (``max_nat``, ``npairs``, ``npairs_skin``, ``nshifts_pbc``,
    ``nvec_pbc``, ``num_repeats_pbc``) are read from it. When any of them is
    too small, the graph's ``overflow`` flag is set so that the CPU routine
    can resize them.
    """
    if self.graph_key in inputs:
        graph = inputs[self.graph_key]
        if "keep_graph" in graph:
            return inputs
    coords = inputs["coordinates"]
    natoms = inputs["natoms"]
    batch_index = inputs["batch_index"]

    if natoms.shape[0] == 1:
        max_nat = coords.shape[0]
    else:
        max_nat = state.get(
            "max_nat", int(round(coords.shape[0] / natoms.shape[0]))
        )
    cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0)

    ### compute indices of all pairs
    p1, p2 = np.triu_indices(max_nat, 1)
    p1, p2 = p1.astype(np.int32), p2.astype(np.int32)
    pbc_shifts = None
    if natoms.shape[0] > 1:
        ## batching => mask irrelevant pairs
        mask_p12 = (
            (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None])
        ).flatten()
        shift = jnp.concatenate(
            (jnp.array([0], dtype=jnp.int32), jnp.cumsum(natoms[:-1]))
        )
        p1 = jnp.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1)
        p2 = jnp.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1)

    ## compute vectors
    overflow_repeats = jnp.asarray(False, dtype=bool)
    if "cells" not in inputs:
        vec = coords[p2] - coords[p1]
    else:
        cells = inputs["cells"]
        reciprocal_cells = inputs["reciprocal_cells"]
        minimage = state.get("minimum_image", True)

        def compute_pbc(vec, reciprocal_cell, cell, mode="round"):
            vecpbc = jnp.dot(vec, reciprocal_cell)
            if mode == "round":
                pbc_shifts = -jnp.round(vecpbc)
            elif mode == "floor":
                pbc_shifts = -jnp.floor(vecpbc)
            else:
                raise NotImplementedError(f"Unknown mode {mode} for compute_pbc.")
            return vec + jnp.dot(pbc_shifts, cell), pbc_shifts

        if minimage:
            ## minimum image convention
            vec = coords[p2] - coords[p1]

            if cells.shape[0] == 1:
                vec, pbc_shifts = compute_pbc(vec, reciprocal_cells[0], cells[0])
            else:
                batch_index_vec = batch_index[p1]
                vec, pbc_shifts = jax.vmap(compute_pbc)(
                    vec, reciprocal_cells[batch_index_vec], cells[batch_index_vec]
                )
        else:
            ### general PBC (single cell, or one cell per system via vmap)
            ## put all atoms in central box
            if cells.shape[0] == 1:
                coords_pbc, at_shifts = compute_pbc(
                    coords, reciprocal_cells[0], cells[0], mode="floor"
                )
            else:
                coords_pbc, at_shifts = jax.vmap(
                    partial(compute_pbc, mode="floor")
                )(coords, reciprocal_cells[batch_index], cells[batch_index])
            vec = coords_pbc[p2] - coords_pbc[p1]
            num_repeats = state.get("num_repeats_pbc", (0, 0, 0))
            # check if num_repeats is larger than previous
            inv_distances = jnp.linalg.norm(reciprocal_cells, axis=1)
            cdinv = cutoff_skin * inv_distances
            num_repeats_all = jnp.ceil(cdinv).astype(jnp.int32)
            if "true_sys" in inputs:
                num_repeats_all = jnp.where(
                    inputs["true_sys"][:, None], num_repeats_all, 0
                )
            num_repeats_new = jnp.max(num_repeats_all, axis=0)
            overflow_repeats = jnp.any(num_repeats_new > jnp.asarray(num_repeats))

            # built with numpy because num_repeats comes from the static state
            cell_shift_pbc = jnp.asarray(
                np.array(
                    np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]),
                    dtype=cells.dtype,
                ).T.reshape(-1, 3)
            )

            if cells.shape[0] == 1:
                vec = (
                    vec[:, None, :] + jnp.dot(cell_shift_pbc, cells[0])[None, :, :]
                ).reshape(-1, 3)
                pbc_shifts = jnp.broadcast_to(
                    cell_shift_pbc[None, :, :],
                    (p1.shape[0], cell_shift_pbc.shape[0], 3),
                ).reshape(-1, 3)
                p1 = jnp.broadcast_to(
                    p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0])
                ).flatten()
                p2 = jnp.broadcast_to(
                    p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0])
                ).flatten()
                if natoms.shape[0] > 1:
                    mask_p12 = jnp.broadcast_to(
                        mask_p12[:, None], (mask_p12.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
            else:
                dvec = jnp.einsum("bj,sji->sbi", cell_shift_pbc, cells).reshape(-1, 3)

                ## get pbc shifts specific to each box
                cell_shift_pbc = jnp.broadcast_to(
                    cell_shift_pbc[None, :, :],
                    (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3),
                )
                mask = jnp.all(
                    jnp.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1
                ).flatten()
                max_shifts = state.get("nshifts_pbc", 1)

                cell_shift_pbc = cell_shift_pbc.reshape(-1, 3)
                shiftx, shifty, shiftz = (
                    cell_shift_pbc[:, 0],
                    cell_shift_pbc[:, 1],
                    cell_shift_pbc[:, 2],
                )
                dvecx, dvecy, dvecz = dvec[:, 0], dvec[:, 1], dvec[:, 2]
                (dvecx, dvecy, dvecz, shiftx, shifty, shiftz), scatter_idx, nshifts = (
                    mask_filter_1d(
                        mask,
                        max_shifts,
                        (dvecx, 0.0),
                        (dvecy, 0.0),
                        (dvecz, 0.0),
                        (shiftx, 0),
                        (shifty, 0),
                        (shiftz, 0),
                    )
                )
                dvec = jnp.stack((dvecx, dvecy, dvecz), axis=-1)
                cell_shift_pbc = jnp.stack((shiftx, shifty, shiftz), axis=-1)
                overflow_repeats = overflow_repeats | (nshifts > max_shifts)

                ## get batch shift in the dvec_filter array
                nrep = jnp.prod(2 * num_repeats_all + 1, axis=1)
                bshift = jnp.concatenate(
                    (jnp.array([0], dtype=jnp.int32), jnp.cumsum(nrep)[:-1])
                )

                ## repeat vectors
                nvec_max = state.get("nvec_pbc", 1)
                batch_index_vec = batch_index[p1]
                nrep_vec = jnp.where(mask_p12, nrep[batch_index_vec], 0)
                nvec = nrep_vec.sum()
                overflow_repeats = overflow_repeats | (nvec > nvec_max)
                vec = jnp.repeat(vec, nrep_vec, axis=0, total_repeat_length=nvec_max)

                ## get shift index
                dshift = jnp.concatenate(
                    (jnp.array([0], dtype=jnp.int32), jnp.cumsum(nrep_vec)[:-1])
                )
                if nrep_vec.size == 0:
                    dshift = jnp.array([], dtype=jnp.int32)
                dshift = jnp.repeat(dshift, nrep_vec, total_repeat_length=nvec_max)
                bshift = jnp.repeat(
                    bshift[batch_index_vec], nrep_vec, total_repeat_length=nvec_max
                )
                icellshift = jnp.arange(dshift.shape[0]) - dshift + bshift
                vec = vec + dvec[icellshift]
                pbc_shifts = cell_shift_pbc[icellshift]
                p1 = jnp.repeat(p1, nrep_vec, total_repeat_length=nvec_max)
                p2 = jnp.repeat(p2, nrep_vec, total_repeat_length=nvec_max)
                if natoms.shape[0] > 1:
                    mask_p12 = jnp.repeat(
                        mask_p12, nrep_vec, total_repeat_length=nvec_max
                    )

    ## compute distances
    d12 = (vec**2).sum(axis=-1)
    if natoms.shape[0] > 1:
        d12 = jnp.where(mask_p12, d12, cutoff_skin**2)

    ## filter pairs
    max_pairs = state.get("npairs", 1)
    mask = d12 < cutoff_skin**2
    (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d(
        mask,
        max_pairs,
        (jnp.asarray(p1, dtype=jnp.int32), coords.shape[0]),
        (jnp.asarray(p2, dtype=jnp.int32), coords.shape[0]),
        (d12, cutoff_skin**2),
    )
    if "cells" in inputs:
        pbc_shifts = (
            jnp.full((max_pairs, 3), 0.0, dtype=pbc_shifts.dtype)
            .at[scatter_idx]
            .set(pbc_shifts, mode="drop")
        )
        if not minimage:
            # mode="fill" is required for fill_value to take effect on the
            # padded indices (== coords.shape[0], out of bounds of at_shifts)
            pbc_shifts = (
                pbc_shifts
                + at_shifts.at[edge_dst].get(mode="fill", fill_value=0.0)
                - at_shifts.at[edge_src].get(mode="fill", fill_value=0.0)
            )

    ## check for overflow
    if natoms.shape[0] == 1:
        true_max_nat = coords.shape[0]
    else:
        true_max_nat = jnp.max(natoms)
    overflow_count = npairs > max_pairs
    overflow_at = true_max_nat > max_nat
    overflow = overflow_count | overflow_at | overflow_repeats

    if "nblist_skin" in state:
        edge_src_skin = edge_src
        edge_dst_skin = edge_dst
        if "cells" in inputs:
            pbc_shifts_skin = pbc_shifts
        max_pairs_skin = state.get("npairs_skin", 1)
        mask = d12 < self.cutoff**2
        (edge_src, edge_dst, d12), scatter_idx, npairs_skin = mask_filter_1d(
            mask,
            max_pairs_skin,
            (edge_src, coords.shape[0]),
            (edge_dst, coords.shape[0]),
            (d12, self.cutoff**2),
        )
        if "cells" in inputs:
            pbc_shifts = (
                jnp.full((max_pairs_skin, 3), 0.0, dtype=pbc_shifts.dtype)
                .at[scatter_idx]
                .set(pbc_shifts, mode="drop")
            )
        overflow = overflow | (npairs_skin > max_pairs_skin)

    ## symmetrize
    edge_src, edge_dst = jnp.concatenate((edge_src, edge_dst)), jnp.concatenate(
        (edge_dst, edge_src)
    )
    d12 = jnp.concatenate((d12, d12))
    if "cells" in inputs:
        pbc_shifts = jnp.concatenate((pbc_shifts, -pbc_shifts))

    graph = inputs[self.graph_key] if self.graph_key in inputs else {}
    graph_out = {
        **graph,
        "edge_src": edge_src,
        "edge_dst": edge_dst,
        "d12": d12,
        "overflow": overflow,
        "pbc_shifts": pbc_shifts,
    }
    if "nblist_skin" in state:
        graph_out["edge_src_skin"] = edge_src_skin
        graph_out["edge_dst_skin"] = edge_dst_skin
        if "cells" in inputs:
            graph_out["pbc_shifts_skin"] = pbc_shifts_skin

    if self.k_space and "cells" in inputs:
        if "k_points" not in graph:
            raise NotImplementedError(
                "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
            )
    return {**inputs, self.graph_key: graph_out}
Build a nblist on the accelerator with JAX, using precomputed (static) shapes.
@partial(jax.jit, static_argnums=(0,))
def update_skin(self, inputs):
    """Update the nblist without recomputing the full nblist.

    Recomputes distances for the stored cutoff+skin pairs
    (``edge_src_skin``/``edge_dst_skin``) with the current coordinates and
    keeps those within ``self.cutoff``. The number of kept pairs is bounded by
    the existing ``edge_src`` buffer (half its length, before symmetrization);
    ``overflow`` is set if that bound is exceeded.
    """
    graph = inputs[self.graph_key]

    edge_src_skin = graph["edge_src_skin"]
    edge_dst_skin = graph["edge_dst_skin"]
    coords = inputs["coordinates"]
    # padded indices gather cutoff - 0.0 per component, so they fail the cutoff test
    vec = coords.at[edge_dst_skin].get(
        mode="fill", fill_value=self.cutoff
    ) - coords.at[edge_src_skin].get(mode="fill", fill_value=0.0)

    if "cells" in inputs:
        pbc_shifts_skin = graph["pbc_shifts_skin"]
        cells = inputs["cells"]
        if cells.shape[0] == 1:
            vec = vec + jnp.dot(pbc_shifts_skin, cells[0])
        else:
            batch_index_vec = inputs["batch_index"][edge_src_skin]
            vec = vec + jax.vmap(jnp.dot)(pbc_shifts_skin, cells[batch_index_vec])

    nat = coords.shape[0]
    d12 = jnp.sum(vec**2, axis=-1)
    mask = d12 < self.cutoff**2
    max_pairs = graph["edge_src"].shape[0] // 2
    (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d(
        mask,
        max_pairs,
        (edge_src_skin, nat),
        (edge_dst_skin, nat),
        (d12, self.cutoff**2),
    )
    if "cells" in inputs:
        # explicit mode="drop" (as in `process`): filtered-out entries scatter
        # to out-of-range indices and must be discarded
        pbc_shifts = (
            jnp.full((max_pairs, 3), 0.0, dtype=pbc_shifts_skin.dtype)
            .at[scatter_idx]
            .set(pbc_shifts_skin, mode="drop")
        )

    overflow = graph.get("overflow", False) | (npairs > max_pairs)
    graph_out = {
        **graph,
        "edge_src": jnp.concatenate((edge_src, edge_dst)),
        "edge_dst": jnp.concatenate((edge_dst, edge_src)),
        "d12": jnp.concatenate((d12, d12)),
        "overflow": overflow,
    }
    if "cells" in inputs:
        graph_out["pbc_shifts"] = jnp.concatenate((pbc_shifts, -pbc_shifts))

    if self.k_space and "cells" in inputs:
        if "k_points" not in graph:
            raise NotImplementedError(
                "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
            )

    return {**inputs, self.graph_key: graph_out}
update the nblist without recomputing the full nblist
class GraphProcessor(nn.Module):
    """Process a pre-generated graph

    The pre-generated graph should contain the following keys:
    - edge_src: source indices of the edges
    - edge_dst: destination indices of the edges
    - pbc_shifts: pbc shifts for the edges (only if `cells` are present in the inputs)

    This module is automatically added to a FENNIX model when a GraphGenerator is used.

    It adds ``vec`` (edge vectors), ``distances``, ``edge_mask`` (distance below
    the cutoff) and ``switch`` to the graph. With alchemical inputs
    (``alch_group``), edges across different groups get their switch scaled by
    ``0.5*(1-cos(pi*alch_elambda))`` and, if a softcore parameter is given,
    softcore distances; the unmodified values are kept under ``switch_raw`` /
    ``distances_raw``.
    """

    cutoff: float
    """Cutoff distance for the graph."""
    graph_key: str = "graph"
    """Key of the graph in the outputs."""
    switch_params: dict = dataclasses.field(default_factory=dict)
    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""

    @nn.compact
    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
        graph = inputs[self.graph_key]
        positions = inputs["coordinates"]
        src, dst = graph["edge_src"], graph["edge_dst"]

        # padding indices are out of range: the fill values push padded edges
        # (both endpoints padding) beyond the cutoff
        pos_dst = positions.at[dst].get(mode="fill", fill_value=self.cutoff)
        pos_src = positions.at[src].get(mode="fill", fill_value=0.0)
        vec = pos_dst - pos_src

        if "cells" in inputs:
            cells = inputs["cells"]
            if cells.shape[0] != 1:
                pair_batch = inputs["batch_index"][src]
                vec = vec + jax.vmap(jnp.dot)(graph["pbc_shifts"], cells[pair_batch])
            else:
                vec = vec + jnp.dot(graph["pbc_shifts"], cells[0])

        d2 = jnp.sum(vec**2, axis=-1)
        distances = safe_sqrt(d2)
        edge_mask = distances < self.cutoff

        switch_kwargs = {**self.switch_params, "cutoff": self.cutoff, "graph_key": None}
        switch = SwitchFunction(**switch_kwargs)((distances, edge_mask))

        graph_out = {
            **graph,
            "vec": vec,
            "distances": distances,
            "switch": switch,
            "edge_mask": edge_mask,
        }

        if "alch_group" not in inputs:
            return {**inputs, self.graph_key: graph_out}

        alch_group = inputs["alch_group"]
        lambda_e = inputs["alch_elambda"]
        same_group = alch_group[src] == alch_group[dst]
        graph_out["switch_raw"] = switch
        cross_scale = 0.5 * (1. - jnp.cos(jnp.pi * lambda_e))
        graph_out["switch"] = jnp.where(same_group, switch, cross_scale * switch)

        has_softcore_e = "alch_softcore_e" in inputs
        if has_softcore_e or "alch_softcore_v" in inputs:
            graph_out["distances_raw"] = distances
            if has_softcore_e:
                alpha = (1 - inputs["alch_elambda"]) * inputs["alch_softcore_e"]**2
            else:
                alpha = (1 - inputs["alch_vlambda"]) * inputs["alch_softcore_v"]**2
            softcore = safe_sqrt(alpha + d2 * (1. - alpha / self.cutoff**2))
            graph_out["distances"] = jnp.where(same_group, distances, softcore)

        return {**inputs, self.graph_key: graph_out}
Process a pre-generated graph
The pre-generated graph should contain the following keys:
- edge_src: source indices of the edges
- edge_dst: destination indices of the edges
- pbc_shifts: pbc shifts for the edges (only if `cells` are present in the inputs)
This module is automatically added to a FENNIX model when a GraphGenerator is used.
Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`.
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.
@dataclasses.dataclass(frozen=True)
class GraphFilter:
    """Filter a graph based on a cutoff distance

    FPID: GRAPH_FILTER

    The filtered graph keeps the pairs of ``parent_graph`` whose squared
    distance ``d12`` is below ``cutoff**2``. If ``remove_hydrogens`` is set,
    pairs whose source atom has species <= 1 are dropped as well. Kept pairs are
    packed into arrays of a preallocated size (tracked in the state under
    ``npairs``) padded with the index ``nat``. Their positions in the parent
    graph are stored in ``filter_indices`` (padded with the parent edge count),
    so that per-edge arrays of the parent graph can be gathered later.
    """

    cutoff: float
    """Cutoff distance for the filtering."""
    parent_graph: str
    """Key of the parent graph in the inputs."""
    graph_key: str
    """Key of the filtered graph in the outputs."""
    remove_hydrogens: bool = False
    """Remove edges where the source is a hydrogen atom."""
    switch_params: FrozenDict = dataclasses.field(default_factory=FrozenDict)
    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""
    k_space: bool = False
    """Generate k-space information for the graph."""
    kmax: int = 30
    """Maximum number of k-points to consider."""
    kthr: float = 1e-6
    """Threshold for k-point filtering."""
    mult_size: float = 1.05
    """Multiplicative factor for resizing the nblist."""

    FPID: ClassVar[str] = "GRAPH_FILTER"

    def init(self):
        return FrozenDict(
            {
                "npairs": 1,
                "nblist_mult_size": self.mult_size,
            }
        )

    def get_processor(self) -> Tuple[nn.Module, Dict]:
        return GraphFilterProcessor, {
            "cutoff": self.cutoff,
            "graph_key": self.graph_key,
            "parent_graph": self.parent_graph,
            "name": f"{self.graph_key}_Filter_{self.parent_graph}",
            "switch_params": self.switch_params,
        }

    def get_graph_properties(self):
        return {
            self.graph_key: {
                "cutoff": self.cutoff,
                "directed": True,
                "parent_graph": self.parent_graph,
            }
        }

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """filter a nblist on cpu with numpy and dynamic shapes + store max shapes

        Args:
            state: preprocessing state; ``npairs`` is the current capacity and
                ``nblist_mult_size`` the growth factor.
            inputs: model inputs containing ``species`` and the parent graph.
            return_state_update: if True, also return the dict of changed state
                entries as ``{key: (new, previous)}``.
            add_margin: force a resize (with margin) even without overflow.

        Returns:
            ``(new_state, outputs)`` or ``(new_state, outputs, state_up)``.
        """
        graph_in = inputs[self.parent_graph]
        nat = inputs["species"].shape[0]

        new_state = {**state}
        state_up = {}
        mult_size = state.get("nblist_mult_size", self.mult_size)
        assert mult_size >= 1., "nblist_mult_size should be >= 1."

        cutoff2 = self.cutoff**2
        edge_src = np.array(graph_in["edge_src"], dtype=np.int32)
        d12 = np.array(graph_in["d12"], dtype=np.float32)
        if self.remove_hydrogens:
            species = inputs["species"]
            # padded edges (source index >= nat) are never considered heavy
            src_idx = (edge_src < nat).nonzero()[0]
            heavy_src = np.zeros(edge_src.shape[0], dtype=bool)
            heavy_src[src_idx] = (species > 1)[edge_src[src_idx]]
            d12 = np.where(heavy_src, d12, cutoff2)
        mask = d12 < cutoff2

        max_pairs = state.get("npairs", 1)
        idx = np.nonzero(mask)[0]
        npairs = idx.shape[0]
        if npairs > max_pairs or add_margin:
            prev_max_pairs = max_pairs
            max_pairs = int(mult_size * max(npairs, max_pairs)) + 1
            state_up["npairs"] = (max_pairs, prev_max_pairs)
            new_state["npairs"] = max_pairs

        filter_indices = np.full(max_pairs, edge_src.shape[0], dtype=np.int32)
        edge_src_out = np.full(max_pairs, nat, dtype=np.int32)
        edge_dst_out = np.full(max_pairs, nat, dtype=np.int32)
        # keep the float32 dtype of the input distances (np.full defaults to float64)
        d12_out = np.full(max_pairs, cutoff2, dtype=np.float32)
        filter_indices[:npairs] = idx
        edge_src_out[:npairs] = edge_src[idx]
        edge_dst_out[:npairs] = graph_in["edge_dst"][idx]
        d12_out[:npairs] = d12[idx]

        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
        graph_out = {
            **graph,
            "edge_src": edge_src_out,
            "edge_dst": edge_dst_out,
            "filter_indices": filter_indices,
            "d12": d12_out,
            "overflow": False,
        }

        if self.k_space and "cells" in inputs:
            if "k_points" not in graph:
                ks, _, _, bewald = get_reciprocal_space_parameters(
                    inputs["reciprocal_cells"], self.cutoff, self.kmax, self.kthr
                )
                graph_out["k_points"] = ks
                graph_out["b_ewald"] = bewald

        output = {**inputs, self.graph_key: graph_out}
        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary

        Returns ``(state, state_up, inputs, reallocated)``. A margin is added to
        the new size only if this graph itself overflowed.
        """
        overflow = parent_overflow or inputs[self.graph_key].get("overflow", False)
        if not overflow:
            return state, {}, inputs, False

        add_margin = inputs[self.graph_key].get("overflow", False)
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """filter a nblist on accelerator with jax and precomputed shapes

        If ``state`` is None (skin update mode), the capacity is taken from the
        existing filtered graph; otherwise from ``state["npairs"]``. The
        ``overflow`` flag is set if more pairs pass the filter than fit.
        """
        graph_in = inputs[self.parent_graph]
        if state is None:
            # skin update mode
            graph = inputs[self.graph_key]
            max_pairs = graph["edge_src"].shape[0]
        else:
            max_pairs = state.get("npairs", 1)

        max_pairs_in = graph_in["edge_src"].shape[0]
        nat = inputs["species"].shape[0]

        edge_src = graph_in["edge_src"]
        d12 = graph_in["d12"]
        if self.remove_hydrogens:
            species = inputs["species"]
            # Plain indexing would clamp padded source indices (nat) to the last
            # atom; fill with False instead so padded edges are treated as
            # non-heavy, as in the numpy path.
            heavy_src = (species > 1).at[edge_src].get(mode="fill", fill_value=False)
            d12 = jnp.where(heavy_src, d12, self.cutoff**2)
        mask = d12 < self.cutoff**2

        (edge_src, edge_dst, d12, filter_indices), _, npairs = mask_filter_1d(
            mask,
            max_pairs,
            (edge_src, nat),
            (graph_in["edge_dst"], nat),
            (d12, self.cutoff**2),
            (jnp.arange(max_pairs_in, dtype=jnp.int32), max_pairs_in),
        )

        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
        overflow = graph.get("overflow", False) | (npairs > max_pairs)
        graph_out = {
            **graph,
            "edge_src": edge_src,
            "edge_dst": edge_dst,
            "filter_indices": filter_indices,
            "d12": d12,
            "overflow": overflow,
        }

        if self.k_space and "cells" in inputs:
            if "k_points" not in graph:
                raise NotImplementedError(
                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
                )

        return {**inputs, self.graph_key: graph_out}

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        """Re-filter the parent graph after a skin update, keeping current sizes."""
        return self.process(None, inputs)
Filter a graph based on a cutoff distance
FPID: GRAPH_FILTER
Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`.
def check_reallocate(self, state, inputs, parent_overflow=False):
    """Rebuild the filtered nblist on CPU when it (or its parent) overflowed.

    Returns:
        ``(state, state_update, inputs, reallocated)``. When neither this graph
        nor its parent overflowed, everything is returned unchanged with an
        empty state update and ``False``. Otherwise the numpy routine
        (``self.__call__``) is rerun, adding a size margin only if this graph's
        own ``overflow`` flag is set, and ``True`` is returned.
    """
    own_overflow = inputs[self.graph_key].get("overflow", False)
    if parent_overflow or own_overflow:
        refreshed_state, refreshed_inputs, state_update = self(
            state, inputs, return_state_update=True, add_margin=own_overflow
        )
        return refreshed_state, state_update, refreshed_inputs, True
    return state, {}, inputs, False
check for overflow and reallocate nblist if necessary
@partial(jax.jit, static_argnums=(0, 1))
def process(self, state, inputs):
    """filter a nblist on accelerator with jax and precomputed shapes

    If ``state`` is None (skin update mode), the capacity is taken from the
    existing filtered graph; otherwise from ``state["npairs"]``. The
    ``overflow`` flag is set if more pairs pass the filter than fit.

    Raises:
        NotImplementedError: if k-space info is requested but not already
            present in the filtered graph.
    """
    graph_in = inputs[self.parent_graph]
    if state is None:
        # skin update mode
        graph = inputs[self.graph_key]
        max_pairs = graph["edge_src"].shape[0]
    else:
        max_pairs = state.get("npairs", 1)

    max_pairs_in = graph_in["edge_src"].shape[0]
    nat = inputs["species"].shape[0]

    edge_src = graph_in["edge_src"]
    d12 = graph_in["d12"]
    if self.remove_hydrogens:
        species = inputs["species"]
        # Plain indexing would clamp padded source indices (nat) to the last
        # atom; fill with False instead so padded edges are treated as
        # non-heavy, as in the numpy path.
        heavy_src = (species > 1).at[edge_src].get(mode="fill", fill_value=False)
        d12 = jnp.where(heavy_src, d12, self.cutoff**2)
    mask = d12 < self.cutoff**2

    (edge_src, edge_dst, d12, filter_indices), _, npairs = mask_filter_1d(
        mask,
        max_pairs,
        (edge_src, nat),
        (graph_in["edge_dst"], nat),
        (d12, self.cutoff**2),
        (jnp.arange(max_pairs_in, dtype=jnp.int32), max_pairs_in),
    )

    graph = inputs[self.graph_key] if self.graph_key in inputs else {}
    overflow = graph.get("overflow", False) | (npairs > max_pairs)
    graph_out = {
        **graph,
        "edge_src": edge_src,
        "edge_dst": edge_dst,
        "filter_indices": filter_indices,
        "d12": d12,
        "overflow": overflow,
    }

    if self.k_space and "cells" in inputs:
        if "k_points" not in graph:
            raise NotImplementedError(
                "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
            )

    return {**inputs, self.graph_key: graph_out}
filter a nblist on accelerator with jax and precomputed shapes
class GraphFilterProcessor(nn.Module):
    """Filter processing for a pre-generated graph

    This module is automatically added to a FENNIX model when a GraphFilter is used.

    It gathers ``vec`` and the distances of the parent graph at the filtered
    graph's ``filter_indices`` (out-of-range indices are filled with
    ``cutoff``), then computes ``edge_mask`` and the switching function for the
    filtered cutoff. The parent's raw (non-softcore) distances are used when the
    parent graph provides ``distances_raw``. Alchemical switch scaling and
    softcore distances are applied as in the parent processor.
    """

    cutoff: float
    """Cutoff distance for the filtering."""
    graph_key: str
    """Key of the filtered graph in the inputs."""
    parent_graph: str
    """Key of the parent graph in the inputs."""
    switch_params: dict = dataclasses.field(default_factory=dict)
    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""

    @nn.compact
    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
        graph_in = inputs[self.parent_graph]
        graph = inputs[self.graph_key]

        # The distances are read from the parent graph, so the presence of raw
        # (pre-softcore) distances must be checked on the parent too. Checking
        # the filtered graph would pick up the parent's softcore distances and
        # apply softcore a second time below.
        d_key = "distances_raw" if "distances_raw" in graph_in else "distances"

        if graph_in["vec"].shape[0] == 0:
            vec = graph_in["vec"]
            distances = graph_in[d_key]
            filter_indices = jnp.asarray([], dtype=jnp.int32)
        else:
            filter_indices = graph["filter_indices"]
            vec = (
                graph_in["vec"]
                .at[filter_indices]
                .get(mode="fill", fill_value=self.cutoff)
            )
            distances = (
                graph_in[d_key]
                .at[filter_indices]
                .get(mode="fill", fill_value=self.cutoff)
            )

        edge_mask = distances < self.cutoff
        switch = SwitchFunction(
            **{**self.switch_params, "cutoff": self.cutoff, "graph_key": None}
        )((distances, edge_mask))

        graph_out = {
            **graph,
            "vec": vec,
            "distances": distances,
            "switch": switch,
            "filter_indices": filter_indices,
            "edge_mask": edge_mask,
        }

        if "alch_group" in inputs:
            edge_src = graph["edge_src"]
            edge_dst = graph["edge_dst"]
            alch_group = inputs["alch_group"]
            lambda_e = inputs["alch_elambda"]
            # edges inside one alchemical group are left untouched
            mask = alch_group[edge_src] == alch_group[edge_dst]
            graph_out["switch_raw"] = switch
            graph_out["switch"] = jnp.where(
                mask,
                switch,
                0.5 * (1. - jnp.cos(jnp.pi * lambda_e)) * switch,
            )

            if "alch_softcore_e" in inputs or "alch_softcore_v" in inputs:
                graph_out["distances_raw"] = distances
                if "alch_softcore_e" in inputs:
                    alch_alpha = (1 - inputs["alch_elambda"]) * inputs["alch_softcore_e"]**2
                else:
                    alch_alpha = (1 - inputs["alch_vlambda"]) * inputs["alch_softcore_v"]**2
                distances = jnp.where(
                    mask,
                    distances,
                    safe_sqrt(alch_alpha + distances**2 * (1. - alch_alpha / self.cutoff**2)),
                )
                graph_out["distances"] = distances

        return {**inputs, self.graph_key: graph_out}
Filter processing for a pre-generated graph
This module is automatically added to a FENNIX model when a GraphFilter is used.
Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`.
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.
@dataclasses.dataclass(frozen=True)
class GraphAngularExtension:
    """Add angles list to a graph

    FPID: GRAPH_ANGULAR_EXTENSION

    For every atom, each unordered pair of edges sharing that atom as
    ``edge_src`` defines an angle. Angles are stored as pairs of edge indices
    (``angle_src``, ``angle_dst``) together with the ``central_atom``, padded to
    a preallocated size tracked in the state (``nangles``). Padding uses the
    edge count ``nedge`` for edge indices and ``nat`` for the central atom.
    """

    mult_size: float = 1.05
    """Multiplicative factor for resizing the nblist."""
    add_neigh: int = 5
    """Additional neighbors to add to the nblist when resizing."""
    graph_key: str = "graph"
    """Key of the graph in the inputs."""

    FPID: ClassVar[str] = "GRAPH_ANGULAR_EXTENSION"

    def init(self):
        return FrozenDict(
            {
                "nangles": 0,
                "nblist_mult_size": self.mult_size,
                "max_neigh": self.add_neigh,
                "add_neigh": self.add_neigh,
            }
        )

    def get_processor(self) -> Tuple[nn.Module, Dict]:
        return GraphAngleProcessor, {
            "graph_key": self.graph_key,
            "name": f"{self.graph_key}_AngleProcessor",
        }

    def get_graph_properties(self):
        return {
            self.graph_key: {
                "has_angles": True,
            }
        }

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """build angle nblist on cpu with numpy and dynamic shapes + store max shapes

        Args:
            state: preprocessing state (``max_neigh``, ``add_neigh``,
                ``nangles``, ``nblist_mult_size``).
            inputs: model inputs containing ``species`` and the graph.
            return_state_update: if True, also return the dict of changed state
                entries as ``{key: (new, previous)}``.
            add_margin: force a resize (with margin) even without overflow.

        Returns:
            ``(new_state, outputs)`` or ``(new_state, outputs, state_up)``.
        """
        graph = inputs[self.graph_key]
        edge_src = np.array(graph["edge_src"], dtype=np.int32)

        new_state = {**state}
        state_up = {}
        mult_size = state.get("nblist_mult_size", self.mult_size)
        assert mult_size >= 1., "nblist_mult_size should be >= 1."

        ### count number of neighbors (padded edges, src == nat, go to the last bin)
        nat = inputs["species"].shape[0]
        count = np.zeros(nat + 1, dtype=np.int32)
        np.add.at(count, edge_src, 1)
        max_count = int(np.max(count[:-1]))

        ### get sizes
        max_neigh = state.get("max_neigh", self.add_neigh)
        if max_count > max_neigh or add_margin:
            prev_max_neigh = max_neigh
            max_neigh = max(max_count, max_neigh) + state.get(
                "add_neigh", self.add_neigh
            )
            state_up["max_neigh"] = (max_neigh, prev_max_neigh)
            new_state["max_neigh"] = max_neigh

        # only the length of this array is used: process() reads it back in
        # skin update mode to recover max_neigh as a static shape
        max_neigh_arr = np.empty(max_neigh, dtype=bool)

        nedge = edge_src.shape[0]

        ### sort edge_src
        idx_sort = np.argsort(edge_src)
        edge_src_sorted = edge_src[idx_sort]

        ### map sparse to dense nblist
        # After sorting, the edges of one atom are contiguous and number at most
        # max_count, so "position modulo max_count" gives distinct slots within
        # each atom's row of the dense (nat, max_count) table.
        slots = np.tile(np.arange(max_count), nat)
        if slots.shape[0] >= nedge:
            offset = slots[:nedge]
        else:
            offset = np.zeros(nedge, dtype=np.int32)
            offset[: slots.shape[0]] = slots

        mask = edge_src_sorted < nat
        indices = edge_src_sorted * max_count + offset
        indices = indices[mask]
        idx_sort = idx_sort[mask]
        edge_idx = np.full(nat * max_count, nedge, dtype=np.int32)
        edge_idx[indices] = idx_sort
        edge_idx = edge_idx.reshape(nat, max_count)

        ### find all triplet for each atom center
        local_src, local_dst = np.triu_indices(max_count, 1)
        angle_src = edge_idx[:, local_src].flatten()
        angle_dst = edge_idx[:, local_dst].flatten()

        ### mask for valid angles (both edges must be real)
        angle_mask = (angle_src < nedge) & (angle_dst < nedge)

        max_angles = state.get("nangles", 0)
        idx = np.nonzero(angle_mask)[0]
        nangles = idx.shape[0]
        if nangles > max_angles or add_margin:
            max_angles_prev = max_angles
            max_angles = int(mult_size * max(nangles, max_angles)) + 1
            state_up["nangles"] = (max_angles, max_angles_prev)
            new_state["nangles"] = max_angles

        ## filter angles to sparse representation
        angle_src_ = np.full(max_angles, nedge, dtype=np.int32)
        angle_dst_ = np.full(max_angles, nedge, dtype=np.int32)
        angle_src_[:nangles] = angle_src[idx]
        angle_dst_[:nangles] = angle_dst[idx]

        central_atom = np.full(max_angles, nat, dtype=np.int32)
        central_atom[:nangles] = edge_src[angle_src_[:nangles]]

        ## update graph
        output = {
            **inputs,
            self.graph_key: {
                **graph,
                "angle_src": angle_src_,
                "angle_dst": angle_dst_,
                "central_atom": central_atom,
                "angle_overflow": False,
                "max_neigh": max_neigh,
                "__max_neigh_array": max_neigh_arr,
            },
        }

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary

        Returns ``(state, state_up, inputs, reallocated)``. A margin is added to
        the new sizes only if the angle list itself overflowed.
        """
        overflow = parent_overflow or inputs[self.graph_key]["angle_overflow"]
        if not overflow:
            return state, {}, inputs, False

        add_margin = inputs[self.graph_key]["angle_overflow"]
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """build angle nblist on accelerator with jax and precomputed shapes

        If ``state`` is None (skin update mode), ``max_neigh`` and the angle
        capacity are recovered from the shapes stored in the graph. The
        ``angle_overflow`` flag is set when either the angle capacity or
        ``max_neigh`` is exceeded.

        Raises:
            ValueError: if the dense table cannot hold all edges.
        """
        graph = inputs[self.graph_key]
        edge_src = graph["edge_src"]

        ### count number of neighbors
        nat = inputs["species"].shape[0]
        count = jnp.zeros(nat, dtype=jnp.int32).at[edge_src].add(1, mode="drop")
        max_count = jnp.max(count)

        ### get sizes
        if state is None:
            max_neigh_arr = graph["__max_neigh_array"]
            max_neigh = max_neigh_arr.shape[0]
            prev_nangles = graph["angle_src"].shape[0]
        else:
            max_neigh = state.get("max_neigh", self.add_neigh)
            max_neigh_arr = jnp.empty(max_neigh, dtype=bool)
            prev_nangles = state.get("nangles", 0)

        nedge = edge_src.shape[0]

        ### sort edge_src
        idx_sort = jnp.argsort(edge_src).astype(jnp.int32)
        edge_src_sorted = edge_src[idx_sort]

        ### map sparse to dense nblist
        if max_neigh * nat < nedge:
            raise ValueError("Found max_neigh*nat < nedge. This should not happen.")
        offset = jnp.asarray(
            np.tile(np.arange(max_neigh), nat)[:nedge], dtype=jnp.int32
        )
        indices = edge_src_sorted * max_neigh + offset
        # padded edges (src == nat) map past the table and are dropped
        edge_idx = (
            jnp.full(nat * max_neigh, nedge, dtype=jnp.int32)
            .at[indices]
            .set(idx_sort, mode="drop")
            .reshape(nat, max_neigh)
        )

        ### find all triplet for each atom center
        local_src, local_dst = np.triu_indices(max_neigh, 1)
        angle_src = edge_idx[:, local_src].flatten()
        angle_dst = edge_idx[:, local_dst].flatten()

        ### mask for valid angles
        angle_mask = (angle_src < nedge) & (angle_dst < nedge)

        ## filter angles to sparse representation
        (angle_src, angle_dst), _, nangles = mask_filter_1d(
            angle_mask,
            prev_nangles,
            (angle_src, nedge),
            (angle_dst, nedge),
        )
        ## find central atom
        # Padded angles carry edge index nedge; plain indexing would clamp it to
        # the last edge. Fill with nat instead, matching the numpy path.
        central_atom = edge_src.at[angle_src].get(mode="fill", fill_value=nat)

        ## check for overflow
        angle_overflow = nangles > prev_nangles
        neigh_overflow = max_count > max_neigh
        overflow = graph.get("angle_overflow", False) | angle_overflow | neigh_overflow

        ## update graph
        output = {
            **inputs,
            self.graph_key: {
                **graph,
                "angle_src": angle_src,
                "angle_dst": angle_dst,
                "central_atom": central_atom,
                "angle_overflow": overflow,
                "__max_neigh_array": max_neigh_arr,
            },
        }

        return output

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        """Rebuild the angle list after a skin update, keeping current sizes."""
        return self.process(None, inputs)
Add angles list to a graph
FPID: GRAPH_ANGULAR_EXTENSION
def check_reallocate(self, state, inputs, parent_overflow=False):
    """Rebuild the angle list on CPU when it (or the parent graph) overflowed.

    Returns:
        ``(state, state_update, inputs, reallocated)``. Without any overflow,
        the arguments are returned unchanged together with ``{}`` and
        ``False``. Otherwise the numpy routine (``self.__call__``) is rerun,
        adding a margin only if the graph's own ``angle_overflow`` is set.
    """
    angle_overflow = inputs[self.graph_key]["angle_overflow"]
    if not parent_overflow and not angle_overflow:
        return state, {}, inputs, False
    rebuilt_state, rebuilt_inputs, state_update = self(
        state, inputs, return_state_update=True, add_margin=angle_overflow
    )
    return rebuilt_state, state_update, rebuilt_inputs, True
check for overflow and reallocate nblist if necessary
@partial(jax.jit, static_argnums=(0, 1))
def process(self, state, inputs):
    """build angle nblist on accelerator with jax and precomputed shapes

    If ``state`` is None (skin update mode), ``max_neigh`` and the angle
    capacity are recovered from the shapes stored in the graph. The
    ``angle_overflow`` flag is set when either the angle capacity or
    ``max_neigh`` is exceeded.

    Raises:
        ValueError: if the dense table cannot hold all edges.
    """
    graph = inputs[self.graph_key]
    edge_src = graph["edge_src"]

    ### count number of neighbors
    nat = inputs["species"].shape[0]
    count = jnp.zeros(nat, dtype=jnp.int32).at[edge_src].add(1, mode="drop")
    max_count = jnp.max(count)

    ### get sizes
    if state is None:
        max_neigh_arr = graph["__max_neigh_array"]
        max_neigh = max_neigh_arr.shape[0]
        prev_nangles = graph["angle_src"].shape[0]
    else:
        max_neigh = state.get("max_neigh", self.add_neigh)
        max_neigh_arr = jnp.empty(max_neigh, dtype=bool)
        prev_nangles = state.get("nangles", 0)

    nedge = edge_src.shape[0]

    ### sort edge_src
    idx_sort = jnp.argsort(edge_src).astype(jnp.int32)
    edge_src_sorted = edge_src[idx_sort]

    ### map sparse to dense nblist
    if max_neigh * nat < nedge:
        raise ValueError("Found max_neigh*nat < nedge. This should not happen.")
    offset = jnp.asarray(
        np.tile(np.arange(max_neigh), nat)[:nedge], dtype=jnp.int32
    )
    indices = edge_src_sorted * max_neigh + offset
    # padded edges (src == nat) map past the table and are dropped
    edge_idx = (
        jnp.full(nat * max_neigh, nedge, dtype=jnp.int32)
        .at[indices]
        .set(idx_sort, mode="drop")
        .reshape(nat, max_neigh)
    )

    ### find all triplet for each atom center
    local_src, local_dst = np.triu_indices(max_neigh, 1)
    angle_src = edge_idx[:, local_src].flatten()
    angle_dst = edge_idx[:, local_dst].flatten()

    ### mask for valid angles
    angle_mask = (angle_src < nedge) & (angle_dst < nedge)

    ## filter angles to sparse representation
    (angle_src, angle_dst), _, nangles = mask_filter_1d(
        angle_mask,
        prev_nangles,
        (angle_src, nedge),
        (angle_dst, nedge),
    )
    ## find central atom
    # Padded angles carry edge index nedge; plain indexing would clamp it to
    # the last edge. Fill with nat instead, matching the numpy path.
    central_atom = edge_src.at[angle_src].get(mode="fill", fill_value=nat)

    ## check for overflow
    angle_overflow = nangles > prev_nangles
    neigh_overflow = max_count > max_neigh
    overflow = graph.get("angle_overflow", False) | angle_overflow | neigh_overflow

    ## update graph
    output = {
        **inputs,
        self.graph_key: {
            **graph,
            "angle_src": angle_src,
            "angle_dst": angle_dst,
            "central_atom": central_atom,
            "angle_overflow": overflow,
            "__max_neigh_array": max_neigh_arr,
        },
    }

    return output
build angle nblist on accelerator with jax and precomputed shapes
class GraphAngleProcessor(nn.Module):
    """Process a pre-generated graph to compute angles

    This module is automatically added to a FENNIX model when a GraphAngularExtension is used.

    For each angle (pair of edge indices ``angle_src``, ``angle_dst``), the
    cosine between the two normalized edge vectors is computed and converted
    with ``arccos(0.95 * cos)``. Because of the 0.95 factor, the stored
    ``angles`` are not exact geometric angles. Raw (non-softcore) distances are
    used for normalization when the graph provides ``distances_raw``.
    """

    graph_key: str
    """Key of the graph in the inputs."""

    @nn.compact
    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
        graph = inputs[self.graph_key]
        distances = graph["distances_raw"] if "distances_raw" in graph else graph["distances"]
        vec = graph["vec"]
        angle_src = graph["angle_src"]
        angle_dst = graph["angle_dst"]

        # normalized edge vectors; the clip avoids dividing by ~0 distances
        unit_vec = vec / jnp.clip(distances[:, None], min=1.0e-5)
        # padded angle indices are out of range and gather a constant 0.5 vector
        cos_angles = (
            unit_vec.at[angle_src].get(mode="fill", fill_value=0.5)
            * unit_vec.at[angle_dst].get(mode="fill", fill_value=0.5)
        ).sum(axis=-1)

        # For unit vectors |0.95*cos| <= 0.95, which keeps arccos away from +-1
        # where its derivative diverges (presumably the intent of the factor).
        angles = jnp.arccos(0.95 * cos_angles)

        return {
            **inputs,
            self.graph_key: {
                **graph,
                "angles": angles,
            },
        }
Process a pre-generated graph to compute angles
This module is automatically added to a FENNIX model when a GraphAngularExtension is used.
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.
@dataclasses.dataclass(frozen=True)
class SpeciesIndexer:
    """Build an index that splits atomic arrays by species.

    FPID: SPECIES_INDEXER

    If `species_order` is specified, the output will be a dense array with size (len(species_order), max_size) that can directly index atomic arrays.
    If `species_order` is None, the output will be a dictionary with species as keys and an index to filter atomic arrays for that species as values.

    Indices are padded with ``nat``. Per-species capacities are tracked in the
    state under ``sizes`` (and ``max_size`` for the dense layout). The flag
    ``<output_key>_overflow`` reports when the current species counts no
    longer fit these capacities.
    """

    output_key: str = "species_index"
    """Key for the output dictionary."""
    species_order: Optional[str] = None
    """Comma separated list of species in the order they should be indexed."""
    add_atoms: int = 0
    """Additional atoms to add to the sizes."""
    add_atoms_margin: int = 10
    """Additional atoms to add to the sizes when adding margin."""

    FPID: ClassVar[str] = "SPECIES_INDEXER"

    def init(self):
        return FrozenDict(
            {
                "sizes": {},
            }
        )

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """Build the species index on cpu with numpy and grow the stored sizes.

        Species values <= 0 are ignored. A species' size is increased (to its
        count plus ``add_atoms``, plus ``add_atoms_margin`` if ``add_margin``)
        only when its count exceeds the stored size.

        Returns:
            ``(new_state, outputs)`` or, if ``return_state_update``,
            ``(new_state, outputs, state_up)`` with ``{key: (new, previous)}``.
        """
        species = np.array(inputs["species"], dtype=np.int32)
        nat = species.shape[0]
        set_species, counts = np.unique(species, return_counts=True)

        new_state = {**state}
        state_up = {}

        sizes = state.get("sizes", FrozenDict({}))
        new_sizes = {**sizes}
        up_sizes = False
        counts_dict = {}
        for s, c in zip(set_species, counts):
            if s <= 0:
                continue
            counts_dict[s] = c
            if c > sizes.get(s, 0):
                up_sizes = True
                add_atoms = state.get("add_atoms", self.add_atoms)
                if add_margin:
                    add_atoms += state.get("add_atoms_margin", self.add_atoms_margin)
                new_sizes[s] = c + add_atoms

        new_sizes = FrozenDict(new_sizes)
        if up_sizes:
            state_up["sizes"] = (new_sizes, sizes)
            new_state["sizes"] = new_sizes

        if self.species_order is not None:
            species_order = [el.strip() for el in self.species_order.split(",")]
            max_size_prev = state.get("max_size", 0)
            max_size = max(new_sizes.values())
            if max_size > max_size_prev:
                state_up["max_size"] = (max_size, max_size_prev)
                new_state["max_size"] = max_size

            species_index = np.full((len(species_order), max_size), nat, dtype=np.int32)
            for i, el in enumerate(species_order):
                s = PERIODIC_TABLE_REV_IDX[el]
                if s in counts_dict:
                    species_index[i, : counts_dict[s]] = np.nonzero(species == s)[0]
        else:
            species_index = {
                PERIODIC_TABLE[s]: np.full(c, nat, dtype=np.int32)
                for s, c in new_sizes.items()
            }
            for s, c in zip(set_species, counts):
                if s <= 0:
                    continue
                species_index[PERIODIC_TABLE[s]][:c] = np.nonzero(species == s)[0]

        output = {
            **inputs,
            self.output_key: species_index,
            self.output_key + "_overflow": False,
        }

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary

        Returns ``(state, state_up, inputs, reallocated)``. A margin is added to
        the new sizes only if the species index itself overflowed.
        """
        overflow = parent_overflow or inputs[self.output_key + "_overflow"]
        if not overflow:
            return state, {}, inputs, False

        add_margin = inputs[self.output_key + "_overflow"]
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """Build the species index on accelerator with the sizes stored in ``state``.

        An already present index is returned as is, unless the
        ``recompute_species_index`` flag is set in ``inputs["flags"]``.

        Raises:
            ValueError: if ``state`` is None.
        """
        recompute_species_index = "recompute_species_index" in inputs.get("flags", {})
        if self.output_key in inputs and not recompute_species_index:
            return inputs

        if state is None:
            raise ValueError("Species Indexer state must be provided on accelerator.")

        species = inputs["species"]
        nat = species.shape[0]

        sizes = state["sizes"]

        # Overflow check, shared by both output layouts (the dense layout used to
        # skip it and then return an undefined `overflow`). It is set when a
        # known species has more atoms than its stored size, or when some atom
        # has a positive species absent from `sizes`.
        overflow = False
        natcount = 0
        for s, c in sizes.items():
            new_size = jnp.sum(species == s)
            natcount = natcount + new_size
            overflow = overflow | (new_size > c)
        natcount = natcount + jnp.sum(species <= 0)
        overflow = overflow | (natcount < nat)

        if self.species_order is not None:
            species_order = [el.strip() for el in self.species_order.split(",")]
            max_size = state["max_size"]

            species_index = jnp.full(
                (len(species_order), max_size), nat, dtype=jnp.int32
            )
            for i, el in enumerate(species_order):
                s = PERIODIC_TABLE_REV_IDX[el]
                if s in sizes.keys():
                    species_index = species_index.at[i, :].set(
                        jnp.nonzero(species == s, size=max_size, fill_value=nat)[0]
                    )
        else:
            species_index = {
                PERIODIC_TABLE[s]: jnp.nonzero(species == s, size=c, fill_value=nat)[0]
                for s, c in sizes.items()
            }

        return {
            **inputs,
            self.output_key: species_index,
            self.output_key + "_overflow": overflow,
        }

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        """Species index refresh for skin updates (requires an existing index)."""
        return self.process(None, inputs)
Build an index that splits atomic arrays by species.
FPID: SPECIES_INDEXER
If `species_order` is specified, the output will be a dense array with size (len(species_order), max_size) that can directly index atomic arrays.
If `species_order` is None, the output will be a dictionary with species as keys and an index to filter atomic arrays for that species as values.
Comma-separated list of species in the order they should be indexed.
def check_reallocate(self, state, inputs, parent_overflow=False):
    """Rebuild the species index on the host when an overflow was detected.

    Parameters
    ----------
    state : the indexer state passed on to ``self.__call__``.
    inputs : dict holding ``<output_key>_overflow`` as set by ``process``.
    parent_overflow : whether a previous layer in the chain overflowed.

    Returns
    -------
    ``(state, state_up, inputs, reallocated)``. When nothing overflowed, the
    inputs are returned untouched with an empty ``state_up`` and ``False``.
    """
    own_overflow = inputs[self.output_key + "_overflow"]
    if not (parent_overflow or own_overflow):
        return state, {}, inputs, False

    # Extra margin is only requested when this layer itself overflowed; a
    # parent-triggered rebuild just recomputes with the regular sizes.
    state, inputs, state_up = self(
        state, inputs, return_state_update=True, add_margin=own_overflow
    )
    return state, state_up, inputs, True
Check for overflow and rebuild (reallocate) the species index if necessary.
@partial(jax.jit, static_argnums=(0, 1))
def process(self, state, inputs):
    """Rebuild the species index on the accelerator from the (static) state.

    If the index is already present in ``inputs`` and no
    ``"recompute_species_index"`` flag is set, ``inputs`` is returned as-is.
    Otherwise the index is recomputed with the fixed sizes stored in
    ``state`` and an ``<output_key>_overflow`` flag is set when those sizes
    are too small. Both the dense layout (``species_order`` set) and the
    dictionary layout report overflow; previously the dense layout never
    defined ``overflow`` and failed with ``UnboundLocalError``.

    Raises
    ------
    ValueError
        If the index must be computed but ``state`` is None.
    """
    recompute_species_index = "recompute_species_index" in inputs.get("flags", {})
    if self.output_key in inputs and not recompute_species_index:
        return inputs

    if state is None:
        raise ValueError("Species Indexer state must be provided on accelerator.")

    species = inputs["species"]
    nat = species.shape[0]
    sizes = state["sizes"]

    if self.species_order is not None:
        species_order = [el.strip() for el in self.species_order.split(",")]
        max_size = state["max_size"]

        # Dense (n_species, max_size) index; unused slots hold `nat`.
        species_index = jnp.full(
            (len(species_order), max_size), nat, dtype=jnp.int32
        )
        for i, el in enumerate(species_order):
            s = PERIODIC_TABLE_REV_IDX[el]
            if s in sizes:
                species_index = species_index.at[i, :].set(
                    jnp.nonzero(species == s, size=max_size, fill_value=nat)[0]
                )
        # Every row of the dense index can hold `max_size` atoms.
        capacities = {s: max_size for s in sizes}
    else:
        species_index = {
            PERIODIC_TABLE[s]: jnp.nonzero(species == s, size=c, fill_value=nat)[0]
            for s, c in sizes.items()
        }
        capacities = sizes

    # Overflow if a species has more atoms than its capacity, or if some
    # atoms belong to a species that is not in `sizes` (the counted total
    # of known + non-positive species then falls short of `nat`).
    overflow = False
    natcount = jnp.sum(species <= 0)
    for s, capacity in capacities.items():
        count = jnp.sum(species == s)
        natcount = natcount + count
        overflow = overflow | (count > capacity)
    overflow = overflow | (natcount < nat)

    return {
        **inputs,
        self.output_key: species_index,
        self.output_key + "_overflow": overflow,
    }
@dataclasses.dataclass(frozen=True)
class BlockIndexer:
    """Build an index that splits atomic arrays by chemical blocks.

    FPID: BLOCK_INDEXER

    Every element is mapped to a chemical block through the lookup table
    returned by `build_chemical_blocks`. The output stored under `output_key`
    is a dictionary with block names as keys and, as values, an index array
    that can be used to filter atomic arrays for that block. Index arrays have
    the size recorded in the state; unused slots are filled with the total
    number of atoms. Blocks that are not recorded in the stored sizes map to
    `None`.
    """

    output_key: str = "block_index"
    """Key for the output dictionary."""
    add_atoms: int = 0
    """Additional atoms to add to the sizes."""
    add_atoms_margin: int = 10
    """Additional atoms to add to the sizes when adding margin."""
    split_CNOPSSe: bool = False
    """If True, C, N, O, P, S and Se are each assigned their own block (see `build_chemical_blocks`)."""

    FPID: ClassVar[str] = "BLOCK_INDEXER"

    def init(self):
        return FrozenDict(
            {
                "sizes": {},
            }
        )

    def build_chemical_blocks(self):
        """Return ``(block_names, element_to_block)`` lookup tables.

        Copies of the module-level tables are returned. With
        ``split_CNOPSSe``, block 1 is renamed "C" and carbon (Z=6) is assigned
        to it, while N, O, P, S and Se get new blocks appended after the
        existing ones.
        """
        _CHEMICAL_BLOCKS_NAMES = CHEMICAL_BLOCKS_NAMES.copy()
        _CHEMICAL_BLOCKS = CHEMICAL_BLOCKS.copy()
        if self.split_CNOPSSe:
            nblocks = len(CHEMICAL_BLOCKS_NAMES)
            _CHEMICAL_BLOCKS_NAMES[1] = "C"
            _CHEMICAL_BLOCKS[6] = 1
            for offset, (name, z) in enumerate(
                (("N", 7), ("O", 8), ("P", 15), ("S", 16), ("Se", 34))
            ):
                _CHEMICAL_BLOCKS_NAMES.append(name)
                _CHEMICAL_BLOCKS[z] = nblocks + offset
        return _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """Build the block index on the host with numpy and grow the stored sizes.

        Sizes (keyed by ``(block_id, block_name)``) only increase. When a
        block has more atoms than its stored size, the size becomes the count
        plus ``add_atoms`` (plus ``add_atoms_margin`` if ``add_margin``).
        Returns ``(new_state, outputs)`` or, with ``return_state_update``,
        ``(new_state, outputs, state_up)``.
        """
        _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS = self.build_chemical_blocks()

        species = np.array(inputs["species"], dtype=np.int32)
        # NOTE(review): padding atoms from AtomPadding have species -1, which
        # indexes the LAST entry of the lookup table (numpy negative indexing);
        # confirm that entry maps to a negative (ignored) block.
        blocks = _CHEMICAL_BLOCKS[species]
        nat = species.shape[0]
        set_blocks, counts = np.unique(blocks, return_counts=True)

        new_state = {**state}
        state_up = {}

        sizes = state.get("sizes", FrozenDict({}))
        new_sizes = {**sizes}
        up_sizes = False
        for s, c in zip(set_blocks, counts):
            if s < 0:
                continue  # negative block ids are not indexed
            key = (s, _CHEMICAL_BLOCKS_NAMES[s])
            if c > sizes.get(key, 0):
                up_sizes = True
                add_atoms = state.get("add_atoms", self.add_atoms)
                if add_margin:
                    add_atoms += state.get("add_atoms_margin", self.add_atoms_margin)
                new_sizes[key] = c + add_atoms

        new_sizes = FrozenDict(new_sizes)
        if up_sizes:
            state_up["sizes"] = (new_sizes, sizes)
            new_state["sizes"] = new_sizes

        # Unused slots point to `nat` (one past the last atom).
        block_index = {n: None for n in _CHEMICAL_BLOCKS_NAMES}
        for (_, n), c in new_sizes.items():
            block_index[n] = np.full(c, nat, dtype=np.int32)
        for s, c in zip(set_blocks, counts):
            if s < 0:
                continue
            block_index[_CHEMICAL_BLOCKS_NAMES[s]][:c] = np.nonzero(blocks == s)[0]

        output = {
            **inputs,
            self.output_key: block_index,
            self.output_key + "_overflow": False,
        }

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """Check for overflow and rebuild (reallocate) the block index if necessary.

        Returns ``(state, state_up, inputs, reallocated)``.
        """
        overflow = parent_overflow or inputs[self.output_key + "_overflow"]
        if not overflow:
            return state, {}, inputs, False

        # Only add a margin when this layer itself overflowed.
        add_margin = inputs[self.output_key + "_overflow"]
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """Rebuild the block index on the accelerator from the (static) state.

        Returns ``inputs`` unchanged if the index is already present and no
        ``"recompute_species_index"`` flag is set. Otherwise the index is
        recomputed with the stored sizes, and ``<output_key>_overflow`` is set
        when a block exceeds its size or atoms fall in unknown blocks.

        Raises
        ------
        ValueError
            If the index must be computed but ``state`` is None.
        """
        _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS = self.build_chemical_blocks()

        recompute_species_index = "recompute_species_index" in inputs.get("flags", {})
        if self.output_key in inputs and not recompute_species_index:
            return inputs

        if state is None:
            raise ValueError("Block Indexer state must be provided on accelerator.")

        species = inputs["species"]
        blocks = jnp.asarray(_CHEMICAL_BLOCKS)[species]
        nat = species.shape[0]

        sizes = state["sizes"]

        block_index = {n: None for n in _CHEMICAL_BLOCKS_NAMES}
        overflow = False
        natcount = 0
        for (s, name), c in sizes.items():
            mask = blocks == s
            new_size = jnp.sum(mask)
            natcount = natcount + new_size
            overflow = overflow | (new_size > c)  # check if sizes are correct
            block_index[name] = jnp.nonzero(mask, size=c, fill_value=nat)[0]

        mask = blocks < 0
        new_size = jnp.sum(mask)
        natcount = natcount + new_size
        # some atoms belong to a block that is not in `sizes`
        overflow = overflow | (natcount < species.shape[0])

        return {
            **inputs,
            self.output_key: block_index,
            self.output_key + "_overflow": overflow,
        }

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        return self.process(None, inputs)
Build an index that splits atomic arrays by chemical blocks.
FPID: BLOCK_INDEXER
The output is a dictionary with chemical block names as keys and, as values, an index array to filter atomic arrays for that block (`None` for blocks not recorded in the stored sizes).
def build_chemical_blocks(self):
    """Return copies of the block-name list and the element -> block table.

    When ``self.split_CNOPSSe`` is set, block 1 is renamed "C" and carbon
    (Z=6) is mapped to it, and N, O, P, S and Se are each given a new block
    appended after the original ones.
    """
    names = CHEMICAL_BLOCKS_NAMES.copy()
    lookup = CHEMICAL_BLOCKS.copy()
    if not self.split_CNOPSSe:
        return names, lookup

    first_new_block = len(CHEMICAL_BLOCKS_NAMES)
    names[1] = "C"
    lookup[6] = 1
    extra = (("N", 7), ("O", 8), ("P", 15), ("S", 16), ("Se", 34))
    for rank, (symbol, atomic_number) in enumerate(extra):
        names.append(symbol)
        lookup[atomic_number] = first_new_block + rank
    return names, lookup
def check_reallocate(self, state, inputs, parent_overflow=False):
    """Recompute the block index on the host if it, or a parent layer, overflowed.

    Returns ``(state, state_up, inputs, reallocated)``; without any overflow,
    ``state_up`` is an empty dict and ``reallocated`` is False.
    """
    self_overflow = inputs[self.output_key + "_overflow"]
    if parent_overflow or self_overflow:
        # margin only when this indexer was the one that ran out of room
        rebuilt_state, rebuilt_inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=self_overflow
        )
        return rebuilt_state, state_up, rebuilt_inputs, True
    return state, {}, inputs, False
Check for overflow and rebuild (reallocate) the block index if necessary.
@partial(jax.jit, static_argnums=(0, 1))
def process(self, state, inputs):
    """Recompute the chemical-block index on the accelerator.

    The index is reused when already present in ``inputs`` (unless the
    ``"recompute_species_index"`` flag is set). Otherwise each block gets a
    fixed-size index array taken from ``state["sizes"]``, padded with the atom
    count, and ``<output_key>_overflow`` reports whether those sizes were
    insufficient.

    Raises
    ------
    ValueError
        If recomputation is needed but ``state`` is None.
    """
    block_names, block_lookup = self.build_chemical_blocks()

    needs_recompute = "recompute_species_index" in inputs.get("flags", {})
    if self.output_key in inputs and not needs_recompute:
        return inputs
    if state is None:
        raise ValueError("Block Indexer state must be provided on accelerator.")

    species = inputs["species"]
    blocks = jnp.asarray(block_lookup)[species]
    nat = species.shape[0]

    block_index = dict.fromkeys(block_names)
    overflow = False
    counted = 0
    for (block_id, name), capacity in state["sizes"].items():
        in_block = blocks == block_id
        n_in_block = jnp.sum(in_block)
        counted = counted + n_in_block
        overflow = overflow | (n_in_block > capacity)
        block_index[name] = jnp.nonzero(in_block, size=capacity, fill_value=nat)[0]

    # atoms in negative blocks are not indexed but still accounted for; any
    # shortfall means some atoms fell in a block missing from the sizes
    counted = counted + jnp.sum(blocks < 0)
    overflow = overflow | (counted < nat)

    return {
        **inputs,
        self.output_key: block_index,
        self.output_key + "_overflow": overflow,
    }
@dataclasses.dataclass(frozen=True)
class AtomPadding:
    """Pad atomic arrays to a fixed size.

    Arrays whose leading dimension equals the number of atoms (or of systems)
    are extended so that their shapes only change when the capacity stored in
    the state is exceeded. Padding atoms get species -1 and are assigned to an
    extra padding system; the boolean masks ``true_atoms`` and ``true_sys``
    flag the real entries.
    """

    mult_size: float = 1.2
    """Multiplicative factor for resizing the atomic arrays."""
    add_sys: int = 0
    """Extra system slots allocated when the number of systems grows."""

    def init(self):
        # FrozenDict, like the other preprocessing layers, keeps the state hashable.
        return FrozenDict({"prev_nat": 0, "prev_nsys": 0})

    def __call__(self, state, inputs: Dict) -> Tuple[FrozenDict, Dict]:
        """Pad ``inputs`` and return ``(new_state, padded_inputs)``."""
        species = inputs["species"]
        nat = species.shape[0]

        # The atom capacity grows (by `mult_size`) only when it is exceeded.
        prev_nat = state.get("prev_nat", 0)
        prev_nat_ = prev_nat
        if nat > prev_nat_:
            prev_nat_ = int(self.mult_size * nat) + 1

        nsys = len(inputs["natoms"])
        prev_nsys = state.get("prev_nsys", 0)
        prev_nsys_ = prev_nsys
        if nsys > prev_nsys_:
            prev_nsys_ = nsys + self.add_sys

        add_atoms = prev_nat_ - nat
        # +1: there is always one padding system that receives the padding atoms
        add_sys = prev_nsys_ - nsys + 1
        output = {**inputs}
        if add_atoms > 0:
            # Padding systems get a large cubic box; their reciprocal cell is
            # its exact inverse so that cells/reciprocal_cells stay consistent.
            pad_box = 1000
            for k, v in inputs.items():
                if not isinstance(v, (np.ndarray, jax.Array)) or v.ndim == 0:
                    # scalars have no leading dimension to pad
                    continue
                if v.shape[0] == nat:
                    output[k] = np.append(
                        v,
                        np.zeros((add_atoms, *v.shape[1:]), dtype=v.dtype),
                        axis=0,
                    )
                elif v.shape[0] == nsys:
                    if k == "cells":
                        pad = pad_box * np.eye(3, dtype=v.dtype)[None, :, :].repeat(
                            add_sys, axis=0
                        )
                    elif k == "reciprocal_cells":
                        pad = np.eye(3, dtype=v.dtype)[None, :, :].repeat(
                            add_sys, axis=0
                        ) / pad_box
                    else:
                        pad = np.zeros((add_sys, *v.shape[1:]), dtype=v.dtype)
                    output[k] = np.append(v, pad, axis=0)
            # The generic zero-padding above is overridden for bookkeeping arrays.
            output["natoms"] = np.append(
                inputs["natoms"], np.zeros(add_sys, dtype=np.int32)
            )
            output["species"] = np.append(
                species, -1 * np.ones(add_atoms, dtype=species.dtype)
            )
            output["batch_index"] = np.append(
                inputs["batch_index"],
                np.array(
                    [output["natoms"].shape[0] - 1] * add_atoms,
                    dtype=inputs["batch_index"].dtype,
                ),
            )
            if "system_index" in inputs:
                output["system_index"] = np.append(
                    inputs["system_index"],
                    np.array(
                        [output["natoms"].shape[0] - 1] * add_sys,
                        dtype=inputs["system_index"].dtype,
                    ),
                )

        output["true_atoms"] = output["species"] > 0
        output["true_sys"] = np.arange(len(output["natoms"])) < nsys

        state = {**state, "prev_nat": prev_nat_, "prev_nsys": prev_nsys_}

        return FrozenDict(state), output
Pad atomic arrays to a fixed size.
def atom_unpadding(inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Strip the padding introduced by `AtomPadding`.

    Arrays whose leading dimension equals the padded number of atoms are
    filtered with ``true_atoms``, and arrays matching the padded number of
    systems with ``true_sys``. Both masks are dropped from the result. Inputs
    without a ``true_atoms`` key, or where the first atom already has a
    non-positive species, are returned as-is (same object).
    """
    if "true_atoms" not in inputs:
        return inputs

    species = inputs["species"]
    atom_mask = inputs["true_atoms"]
    sys_mask = inputs["true_sys"]
    padded_nat = species.shape[0]
    # argmax gives the first non-positive species; 0 means nothing to strip
    if np.argmax(species <= 0) == 0:
        return inputs

    padded_nsys = len(inputs["natoms"])

    def _strip(value):
        if not isinstance(value, (jax.Array, np.ndarray)) or value.ndim == 0:
            return value
        leading = value.shape[0]
        if leading == padded_nat:
            return value[atom_mask]
        if leading == padded_nsys:
            return value[sys_mask]
        return value

    return {
        key: _strip(value)
        for key, value in inputs.items()
        if key not in ("true_atoms", "true_sys")
    }
Remove padding from atomic arrays.
def check_input(inputs):
    """Validate a raw input dictionary and fill in default batching keys.

    Requires ``species`` and ``coordinates``. Adds (as int32)
    ``natoms``, defaulting to a single system with all atoms, and
    ``batch_index``, defaulting to the system id repeated per atom. If
    ``cells`` is given, ``reciprocal_cells`` is computed when missing, and
    single 3x3 matrices are promoted to a batch of one. Returns a new dict;
    ``inputs`` is not modified.
    """
    assert "species" in inputs, "species must be provided"
    assert "coordinates" in inputs, "coordinates must be provided"

    species = inputs["species"].astype(np.int32)
    first_fake = np.argmax(species <= 0)
    if first_fake > 0:
        # NOTE(review): argmax returns the FIRST non-positive entry, so every
        # species before it is positive by construction and this assertion can
        # never fail. The intended check is unclear — confirm.
        assert np.all(species[:first_fake] > 0), "species must be positive"
    nat = inputs["species"].shape[0]

    default_natoms = np.array([nat], dtype=np.int32)
    natoms = inputs.get("natoms", default_natoms).astype(np.int32)
    # the default is built eagerly, exactly like an inline .get() default
    default_batch = np.repeat(np.arange(len(natoms), dtype=np.int32), natoms)
    batch_index = inputs.get("batch_index", default_batch).astype(np.int32)

    checked = dict(inputs)
    checked["natoms"] = natoms
    checked["batch_index"] = batch_index

    if "cells" in inputs:
        cells = inputs["cells"]
        reciprocal_cells = (
            inputs["reciprocal_cells"]
            if "reciprocal_cells" in inputs
            else np.linalg.inv(cells)
        )
        checked["cells"] = cells[None, :, :] if cells.ndim == 2 else cells
        checked["reciprocal_cells"] = (
            reciprocal_cells[None, :, :]
            if reciprocal_cells.ndim == 2
            else reciprocal_cells
        )

    return checked
Check the input dictionary for required keys and types.
def convert_to_jax(data):
    """Return a copy of the pytree ``data`` with numpy array leaves turned into jax arrays.

    Leaves that are not numpy arrays are kept as-is.
    """

    def _to_jax(leaf):
        return jnp.asarray(leaf) if isinstance(leaf, np.ndarray) else leaf

    return jax.tree_util.tree_map(_to_jax, data)
Convert a numpy arrays to jax arrays in a pytree.
class JaxConverter(nn.Module):
    """Flax module wrapping `convert_to_jax`.

    Calling the module returns ``data`` with every numpy array leaf replaced
    by a jax array; the module has no parameters.
    """

    def __call__(self, data):
        converted = convert_to_jax(data)
        return converted
Convert numpy arrays to jax arrays in a pytree.
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.
@dataclasses.dataclass(frozen=True)
class PreprocessingChain:
    """Chain of preprocessing layers.

    The host-side pass (`__call__`) runs every layer with numpy and stores the
    per-layer states as a tuple under ``state["layers_state"]``. When
    `use_atom_padding` is True, the padder state comes first. `process` and
    `update_skin` run the jitted accelerator versions of the layers.
    """

    layers: Tuple[Callable[..., Dict[str, Any]]]
    """Preprocessing layers."""
    use_atom_padding: bool = False
    """Add an AtomPadding layer at the beginning of the chain."""
    atom_padder: AtomPadding = AtomPadding()
    """AtomPadding layer."""

    def __post_init__(self):
        if not isinstance(self.layers, Sequence):
            raise ValueError(
                f"'layers' must be a sequence, got '{type(self.layers).__name__}'."
            )
        if not self.layers:
            raise ValueError("Error: no Preprocessing layers were provided.")
        # `self` is a static argument of the jitted methods, so it must be
        # hashable: store the layers as a tuple even if a list was given.
        object.__setattr__(self, "layers", tuple(self.layers))

    def __call__(
        self, state, inputs: Dict[str, Any]
    ) -> Tuple[FrozenDict, Dict[str, Any]]:
        """Run the numpy preprocessing of every layer.

        Returns the updated state and the processed inputs converted to jax
        arrays.
        """
        do_check_input = state.get("check_input", True)
        if do_check_input:
            inputs = check_input(inputs)
        new_state = []
        layer_state = state["layers_state"]
        i = 0
        if self.use_atom_padding:
            s, inputs = self.atom_padder(layer_state[0], inputs)
            new_state.append(s)
            i += 1
        for layer in self.layers:
            s, inputs = layer(layer_state[i], inputs, return_state_update=False)
            new_state.append(s)
            i += 1
        return FrozenDict({**state, "layers_state": tuple(new_state)}), convert_to_jax(
            inputs
        )

    def check_reallocate(self, state, inputs):
        """Let each layer rebuild its host-side data after an overflow.

        Overflow propagates to the following layers. Returns
        ``(state, state_up, inputs, reallocated)``.
        """
        new_state = []
        state_up = []
        layer_state = state["layers_state"]
        i = 0
        if self.use_atom_padding:
            new_state.append(layer_state[0])  # padder state is kept as-is
            i += 1
        parent_overflow = False
        for layer in self.layers:
            s, s_up, inputs, parent_overflow = layer.check_reallocate(
                layer_state[i], inputs, parent_overflow
            )
            new_state.append(s)
            state_up.append(s_up)
            i += 1

        if not parent_overflow:
            return state, {}, inputs, False
        return (
            FrozenDict({**state, "layers_state": tuple(new_state)}),
            state_up,
            inputs,
            True,
        )

    def atom_padding(self, state, inputs):
        """Apply only the atom padder, or return the arguments unchanged if disabled.

        When padding is enabled, the returned state is the padder's own state,
        not the full chain state.
        """
        if self.use_atom_padding:
            padder_state = state["layers_state"][0]
            return self.atom_padder(padder_state, inputs)
        return state, inputs

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """Run the accelerator (jitted) processing of every layer."""
        layer_state = state["layers_state"]
        i = 1 if self.use_atom_padding else 0
        for layer in self.layers:
            inputs = layer.process(layer_state[i], inputs)
            i += 1
        return inputs

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        """Run each layer's `update_skin` on the accelerator."""
        for layer in self.layers:
            inputs = layer.update_skin(inputs)
        return inputs

    def init(self):
        """Return the initial chain state with one initial state per layer."""
        state = []
        if self.use_atom_padding:
            state.append(self.atom_padder.init())
        for layer in self.layers:
            state.append(layer.init())
        # tuple (as in `__call__`) so the FrozenDict state stays hashable for jit
        return FrozenDict({"check_input": True, "layers_state": tuple(state)})

    def init_with_output(self, inputs):
        state = self.init()
        return self(state, inputs)

    def get_processors(self):
        processors = []
        for layer in self.layers:
            if hasattr(layer, "get_processor"):
                processors.append(layer.get_processor())
        return processors

    def get_graphs_properties(self):
        properties = {}
        for layer in self.layers:
            if hasattr(layer, "get_graph_properties"):
                properties = deep_update(properties, layer.get_graph_properties())
        return properties
Chain of preprocessing layers.
def check_reallocate(self, state, inputs):
    """Give every layer a chance to rebuild its host-side data after an overflow.

    The overflow flag returned by one layer is passed to the next. Returns
    ``(state, state_up, inputs, reallocated)``. When nothing overflowed, the
    original ``state`` and an empty dict are returned.
    """
    layer_states = state["layers_state"]
    offset = 1 if self.use_atom_padding else 0
    # The padder state is carried over unchanged.
    updated_states = [layer_states[0]] if self.use_atom_padding else []
    updates = []
    overflow = False
    for idx, layer in enumerate(self.layers, start=offset):
        layer_state, layer_update, inputs, overflow = layer.check_reallocate(
            layer_states[idx], inputs, overflow
        )
        updated_states.append(layer_state)
        updates.append(layer_update)

    if overflow:
        return (
            FrozenDict({**state, "layers_state": tuple(updated_states)}),
            updates,
            inputs,
            True,
        )
    return state, {}, inputs, False