fennol.models.preprocessing
import flax.linen as nn
from typing import Sequence, Callable, Union, Dict, Any, ClassVar
import jax.numpy as jnp
import jax
import numpy as np
from typing import Optional, Tuple
import numba
import dataclasses
from functools import partial

from flax.core.frozen_dict import FrozenDict


from ..utils.activations import chain, safe_sqrt
from ..utils import deep_update, mask_filter_1d
from ..utils.kspace import get_reciprocal_space_parameters
from .misc.misc import SwitchFunction
from ..utils.periodic_table import (
    PERIODIC_TABLE,
    PERIODIC_TABLE_REV_IDX,
    CHEMICAL_BLOCKS,
    CHEMICAL_BLOCKS_NAMES,
)


@dataclasses.dataclass(frozen=True)
class GraphGenerator:
    """Generate a graph from a set of coordinates

    FPID: GRAPH

    For now, we generate all pairs of atoms and filter them based on the cutoff.
    If a `nblist_skin` is present in the state, we generate a second graph with a
    larger cutoff that includes all pairs within cutoff+skin. This graph is then
    reused by the `update_skin` method to update the original graph without
    recomputing the full neighbor list.
    """

    cutoff: float
    """Cutoff distance for the graph."""
    graph_key: str = "graph"
    """Key of the graph in the outputs."""
    switch_params: dict = dataclasses.field(default_factory=dict, hash=False)
    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""
    kmax: int = 30
    """Maximum number of k-points to consider."""
    kthr: float = 1e-6
    """Threshold for k-point filtering."""
    k_space: bool = False
    """Whether to generate k-space information for the graph."""
    mult_size: float = 1.05
    """Multiplicative factor for resizing the nblist."""
    # covalent_cutoff: bool = False

    FPID: ClassVar[str] = "GRAPH"

    def init(self):
        return FrozenDict(
            {
                "max_nat": 1,
                "npairs": 1,
                "nblist_mult_size": self.mult_size,
            }
        )

    def get_processor(self) -> Tuple[nn.Module, Dict]:
        return GraphProcessor, {
            "cutoff": self.cutoff,
            "graph_key": self.graph_key,
            "switch_params": self.switch_params,
            "name": f"{self.graph_key}_Processor",
        }

    def get_graph_properties(self):
        return {
            self.graph_key: {
                "cutoff": self.cutoff,
                "directed": True,
            }
        }

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """build a nblist on cpu with numpy and dynamic shapes + store max shapes"""
        if self.graph_key in inputs:
            graph = inputs[self.graph_key]
            if "keep_graph" in graph:
                return state, inputs

        coords = np.array(inputs["coordinates"], dtype=np.float32)
        natoms = np.array(inputs["natoms"], dtype=np.int32)
        batch_index = np.array(inputs["batch_index"], dtype=np.int32)

        new_state = {**state}
        state_up = {}

        mult_size = state.get("nblist_mult_size", self.mult_size)
        assert mult_size >= 1.0, "mult_size should be greater than or equal to 1.0"

        if natoms.shape[0] == 1:
            max_nat = coords.shape[0]
            true_max_nat = max_nat
        else:
            max_nat = state.get("max_nat", round(coords.shape[0] / natoms.shape[0]))
            true_max_nat = int(np.max(natoms))
            if true_max_nat > max_nat:
                add_atoms = state.get("add_atoms", 0)
                new_maxnat = true_max_nat + add_atoms
                state_up["max_nat"] = (new_maxnat, max_nat)
                new_state["max_nat"] = new_maxnat

        cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0)

        ### compute indices of all pairs
        p1, p2 = np.triu_indices(true_max_nat, 1)
        p1, p2 = p1.astype(np.int32), p2.astype(np.int32)
        pbc_shifts = None
        if natoms.shape[0] > 1:
            ## batching => mask irrelevant pairs
            mask_p12 = (
                (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None])
            ).flatten()
            shift = np.concatenate(
                (np.array([0], dtype=np.int32), np.cumsum(natoms[:-1], dtype=np.int32))
            )
            p1 = np.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1)
            p2 = np.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1)

        apply_pbc = "cells" in inputs
        if not apply_pbc:
            ### NO PBC
            vec = coords[p2] - coords[p1]
        else:
            cells = np.array(inputs["cells"], dtype=np.float32)
            reciprocal_cells = np.array(inputs["reciprocal_cells"], dtype=np.float32)
            minimage = "minimum_image" in inputs.get("flags", {})
            if minimage:
                ## MINIMUM IMAGE CONVENTION
                vec = coords[p2] - coords[p1]
                if cells.shape[0] == 1:
                    vecpbc = np.dot(vec, reciprocal_cells[0])
                    pbc_shifts = -np.round(vecpbc)
                    vec = vec + np.dot(pbc_shifts, cells[0])
                else:
                    batch_index_vec = batch_index[p1]
                    vecpbc = np.einsum(
                        "aj,aji->ai", vec, reciprocal_cells[batch_index_vec]
                    )
                    pbc_shifts = -np.round(vecpbc)
                    vec = vec + np.einsum(
                        "aj,aji->ai", pbc_shifts, cells[batch_index_vec]
                    )
            else:
                ### GENERAL PBC
                ## put all atoms in central box
                if cells.shape[0] == 1:
                    coords_pbc = np.dot(coords, reciprocal_cells[0])
                    at_shifts = -np.floor(coords_pbc)
                    coords_pbc = coords + np.dot(at_shifts, cells[0])
                else:
                    coords_pbc = np.einsum(
                        "aj,aji->ai", coords, reciprocal_cells[batch_index]
                    )
                    at_shifts = -np.floor(coords_pbc)
                    coords_pbc = coords + np.einsum(
                        "aj,aji->ai", at_shifts, cells[batch_index]
                    )
                vec = coords_pbc[p2] - coords_pbc[p1]

                ## compute maximum number of repeats
                inv_distances = (np.sum(reciprocal_cells**2, axis=1)) ** 0.5
                cdinv = cutoff_skin * inv_distances
                num_repeats_all = np.ceil(cdinv).astype(np.int32)
                if "true_sys" in inputs:
                    num_repeats_all = np.where(
                        np.array(inputs["true_sys"], dtype=bool)[:, None],
                        num_repeats_all,
                        0,
                    )
                # num_repeats_all = np.where(cdinv < 0.5, 0, num_repeats_all)
                num_repeats = np.max(num_repeats_all, axis=0)
                num_repeats_prev = np.array(state.get("num_repeats_pbc", (0, 0, 0)))
                if np.any(num_repeats > num_repeats_prev):
                    num_repeats_new = np.maximum(num_repeats, num_repeats_prev)
                    state_up["num_repeats_pbc"] = (
                        tuple(num_repeats_new),
                        tuple(num_repeats_prev),
                    )
                    new_state["num_repeats_pbc"] = tuple(num_repeats_new)
                ## build all possible shifts
                cell_shift_pbc = np.array(
                    np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]),
                    dtype=cells.dtype,
                ).T.reshape(-1, 3)
                ## shift applied to vectors
                if cells.shape[0] == 1:
                    dvec = np.dot(cell_shift_pbc, cells[0])[None, :, :]
                    vec = (vec[:, None, :] + dvec).reshape(-1, 3)
                    pbc_shifts = np.broadcast_to(
                        cell_shift_pbc[None, :, :],
                        (p1.shape[0], cell_shift_pbc.shape[0], 3),
                    ).reshape(-1, 3)
                    p1 = np.broadcast_to(
                        p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
                    p2 = np.broadcast_to(
                        p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
                    if natoms.shape[0] > 1:
                        mask_p12 = np.broadcast_to(
                            mask_p12[:, None],
                            (mask_p12.shape[0], cell_shift_pbc.shape[0]),
                        ).flatten()
                else:
                    dvec = np.einsum("bj,sji->sbi", cell_shift_pbc, cells)

                    ## get pbc shifts specific to each box
                    cell_shift_pbc = np.broadcast_to(
                        cell_shift_pbc[None, :, :],
                        (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3),
                    )
                    mask = np.all(
                        np.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1
                    ).flatten()
                    idx = np.nonzero(mask)[0]
                    nshifts = idx.shape[0]
                    nshifts_prev = state.get("nshifts_pbc", 0)
                    if nshifts > nshifts_prev or add_margin:
                        nshifts_new = int(mult_size * max(nshifts, nshifts_prev)) + 1
                        state_up["nshifts_pbc"] = (nshifts_new, nshifts_prev)
                        new_state["nshifts_pbc"] = nshifts_new

                    dvec_filter = dvec.reshape(-1, 3)[idx, :]
                    cell_shift_pbc_filter = cell_shift_pbc.reshape(-1, 3)[idx, :]

                    ## get batch shift in the dvec_filter array
                    nrep = np.prod(2 * num_repeats_all + 1, axis=1)
                    bshift = np.concatenate((np.array([0]), np.cumsum(nrep)[:-1]))

                    ## compute vectors
                    batch_index_vec = batch_index[p1]
                    nrep_vec = np.where(mask_p12, nrep[batch_index_vec], 0)
                    vec = vec.repeat(nrep_vec, axis=0)
                    nvec_pbc = nrep_vec.sum()  # vec.shape[0]
                    nvec_pbc_prev = state.get("nvec_pbc", 0)
                    if nvec_pbc > nvec_pbc_prev or add_margin:
                        nvec_pbc_new = int(mult_size * max(nvec_pbc, nvec_pbc_prev)) + 1
                        state_up["nvec_pbc"] = (nvec_pbc_new, nvec_pbc_prev)
                        new_state["nvec_pbc"] = nvec_pbc_new

                    ## get shift index
                    dshift = np.concatenate(
                        (np.array([0]), np.cumsum(nrep_vec)[:-1])
                    ).repeat(nrep_vec)
                    icellshift = (
                        np.arange(dshift.shape[0])
                        - dshift
                        + bshift[batch_index_vec].repeat(nrep_vec)
                    )
                    ## shift vectors
                    vec = vec + dvec_filter[icellshift]
                    pbc_shifts = cell_shift_pbc_filter[icellshift]

                    p1 = np.repeat(p1, nrep_vec)
                    p2 = np.repeat(p2, nrep_vec)
                    if natoms.shape[0] > 1:
                        mask_p12 = np.repeat(mask_p12, nrep_vec)

        ## compute distances
        d12 = (vec**2).sum(axis=-1)
        if natoms.shape[0] > 1:
            d12 = np.where(mask_p12, d12, cutoff_skin**2)

        ## filter pairs
        max_pairs = state.get("npairs", 1)
        mask = d12 < cutoff_skin**2
        idx = np.nonzero(mask)[0]
        npairs = idx.shape[0]
        if npairs > max_pairs or add_margin:
            prev_max_pairs = max_pairs
            max_pairs = int(mult_size * max(npairs, max_pairs)) + 1
            state_up["npairs"] = (max_pairs, prev_max_pairs)
            new_state["npairs"] = max_pairs

        nat = coords.shape[0]
        edge_src = np.full(max_pairs, nat, dtype=np.int32)
        edge_dst = np.full(max_pairs, nat, dtype=np.int32)
        d12_ = np.full(max_pairs, cutoff_skin**2)
        edge_src[:npairs] = p1[idx]
        edge_dst[:npairs] = p2[idx]
        d12_[:npairs] = d12[idx]
        d12 = d12_

        if apply_pbc:
            pbc_shifts_ = np.zeros((max_pairs, 3))
            pbc_shifts_[:npairs] = pbc_shifts[idx]
            pbc_shifts = pbc_shifts_
            if not minimage:
                pbc_shifts[:npairs] = (
                    pbc_shifts[:npairs]
                    + at_shifts[edge_dst[:npairs]]
                    - at_shifts[edge_src[:npairs]]
                )

        if "nblist_skin" in state:
            edge_src_skin = edge_src
            edge_dst_skin = edge_dst
            if apply_pbc:
                pbc_shifts_skin = pbc_shifts
            max_pairs_skin = state.get("npairs_skin", 1)
            mask = d12 < self.cutoff**2
            idx = np.nonzero(mask)[0]
            npairs_skin = idx.shape[0]
            if npairs_skin > max_pairs_skin or add_margin:
                prev_max_pairs_skin = max_pairs_skin
                max_pairs_skin = int(mult_size * max(npairs_skin, max_pairs_skin)) + 1
                state_up["npairs_skin"] = (max_pairs_skin, prev_max_pairs_skin)
                new_state["npairs_skin"] = max_pairs_skin
            edge_src = np.full(max_pairs_skin, nat, dtype=np.int32)
            edge_dst = np.full(max_pairs_skin, nat, dtype=np.int32)
            d12_ = np.full(max_pairs_skin, self.cutoff**2)
            edge_src[:npairs_skin] = edge_src_skin[idx]
            edge_dst[:npairs_skin] = edge_dst_skin[idx]
            d12_[:npairs_skin] = d12[idx]
            d12 = d12_
            if apply_pbc:
                pbc_shifts = np.full((max_pairs_skin, 3), 0.0)
                pbc_shifts[:npairs_skin] = pbc_shifts_skin[idx]

        ## symmetrize
        edge_src, edge_dst = np.concatenate((edge_src, edge_dst)), np.concatenate(
            (edge_dst, edge_src)
        )
        d12 = np.concatenate((d12, d12))
        if apply_pbc:
            pbc_shifts = np.concatenate((pbc_shifts, -pbc_shifts))

        graph = inputs.get(self.graph_key, {})
        graph_out = {
            **graph,
            "edge_src": edge_src,
            "edge_dst": edge_dst,
            "d12": d12,
            "overflow": False,
            "pbc_shifts": pbc_shifts,
        }
        if "nblist_skin" in state:
            graph_out["edge_src_skin"] = edge_src_skin
            graph_out["edge_dst_skin"] = edge_dst_skin
            if apply_pbc:
                graph_out["pbc_shifts_skin"] = pbc_shifts_skin

        if self.k_space and apply_pbc:
            if "k_points" not in graph:
                ks, _, _, bewald = get_reciprocal_space_parameters(
                    reciprocal_cells, self.cutoff, self.kmax, self.kthr
                )
                graph_out["k_points"] = ks
                graph_out["b_ewald"] = bewald

        output = {**inputs, self.graph_key: graph_out}

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary"""
        overflow = parent_overflow or inputs[self.graph_key].get("overflow", False)
        if not overflow:
            return state, {}, inputs, False

        add_margin = inputs[self.graph_key].get("overflow", False)
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """build a nblist on accelerator with jax and precomputed shapes"""
        if self.graph_key in inputs:
            graph = inputs[self.graph_key]
            if "keep_graph" in graph:
                return inputs
        coords = inputs["coordinates"]
        natoms = inputs["natoms"]
        batch_index = inputs["batch_index"]

        if natoms.shape[0] == 1:
            max_nat = coords.shape[0]
        else:
            max_nat = state.get(
                "max_nat", int(round(coords.shape[0] / natoms.shape[0]))
            )
        cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0)

        ### compute indices of all pairs
        p1, p2 = np.triu_indices(max_nat, 1)
        p1, p2 = p1.astype(np.int32), p2.astype(np.int32)
        pbc_shifts = None
        if natoms.shape[0] > 1:
            ## batching => mask irrelevant pairs
            mask_p12 = (
                (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None])
            ).flatten()
            shift = jnp.concatenate(
                (jnp.array([0], dtype=jnp.int32), jnp.cumsum(natoms[:-1]))
            )
            p1 = jnp.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1)
            p2 = jnp.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1)

        ## compute vectors
        overflow_repeats = jnp.asarray(False, dtype=bool)
        if "cells" not in inputs:
            vec = coords[p2] - coords[p1]
        else:
            cells = inputs["cells"]
            reciprocal_cells = inputs["reciprocal_cells"]
            minimage = "minimum_image" in inputs.get("flags", {})

            def compute_pbc(vec, reciprocal_cell, cell, mode="round"):
                vecpbc = jnp.dot(vec, reciprocal_cell)
                if mode == "round":
                    pbc_shifts = -jnp.round(vecpbc)
                elif mode == "floor":
                    pbc_shifts = -jnp.floor(vecpbc)
                else:
                    raise NotImplementedError(f"Unknown mode {mode} for compute_pbc.")
                return vec + jnp.dot(pbc_shifts, cell), pbc_shifts

            if minimage:
                ## minimum image convention
                vec = coords[p2] - coords[p1]

                if cells.shape[0] == 1:
                    vec, pbc_shifts = compute_pbc(vec, reciprocal_cells[0], cells[0])
                else:
                    batch_index_vec = batch_index[p1]
                    vec, pbc_shifts = jax.vmap(compute_pbc)(
                        vec, reciprocal_cells[batch_index_vec], cells[batch_index_vec]
                    )
            else:
                ### GENERAL PBC
                ## put all atoms in central box
                if cells.shape[0] == 1:
                    coords_pbc, at_shifts = compute_pbc(
                        coords, reciprocal_cells[0], cells[0], mode="floor"
                    )
                else:
                    coords_pbc, at_shifts = jax.vmap(
                        partial(compute_pbc, mode="floor")
                    )(coords, reciprocal_cells[batch_index], cells[batch_index])
                vec = coords_pbc[p2] - coords_pbc[p1]
                num_repeats = state.get("num_repeats_pbc", (0, 0, 0))
                ## check if num_repeats is larger than previous
                inv_distances = jnp.linalg.norm(reciprocal_cells, axis=1)
                cdinv = cutoff_skin * inv_distances
                num_repeats_all = jnp.ceil(cdinv).astype(jnp.int32)
                if "true_sys" in inputs:
                    num_repeats_all = jnp.where(
                        inputs["true_sys"][:, None], num_repeats_all, 0
                    )
                num_repeats_new = jnp.max(num_repeats_all, axis=0)
                overflow_repeats = jnp.any(num_repeats_new > jnp.asarray(num_repeats))

                cell_shift_pbc = jnp.asarray(
                    np.array(
                        np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]),
                        dtype=cells.dtype,
                    ).T.reshape(-1, 3)
                )

                if cells.shape[0] == 1:
                    vec = (
                        vec[:, None, :] + jnp.dot(cell_shift_pbc, cells[0])[None, :, :]
                    ).reshape(-1, 3)
                    pbc_shifts = jnp.broadcast_to(
                        cell_shift_pbc[None, :, :],
                        (p1.shape[0], cell_shift_pbc.shape[0], 3),
                    ).reshape(-1, 3)
                    p1 = jnp.broadcast_to(
                        p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
                    p2 = jnp.broadcast_to(
                        p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
                    if natoms.shape[0] > 1:
                        mask_p12 = jnp.broadcast_to(
                            mask_p12[:, None],
                            (mask_p12.shape[0], cell_shift_pbc.shape[0]),
                        ).flatten()
                else:
                    dvec = jnp.einsum("bj,sji->sbi", cell_shift_pbc, cells).reshape(-1, 3)

                    ## get pbc shifts specific to each box
                    cell_shift_pbc = jnp.broadcast_to(
                        cell_shift_pbc[None, :, :],
                        (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3),
                    )
                    mask = jnp.all(
                        jnp.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1
                    ).flatten()
                    max_shifts = state.get("nshifts_pbc", 1)

                    cell_shift_pbc = cell_shift_pbc.reshape(-1, 3)
                    shiftx, shifty, shiftz = (
                        cell_shift_pbc[:, 0],
                        cell_shift_pbc[:, 1],
                        cell_shift_pbc[:, 2],
                    )
                    dvecx, dvecy, dvecz = dvec[:, 0], dvec[:, 1], dvec[:, 2]
                    (
                        (dvecx, dvecy, dvecz, shiftx, shifty, shiftz),
                        scatter_idx,
                        nshifts,
                    ) = mask_filter_1d(
                        mask,
                        max_shifts,
                        (dvecx, 0.0),
                        (dvecy, 0.0),
                        (dvecz, 0.0),
                        (shiftx, 0),
                        (shifty, 0),
                        (shiftz, 0),
                    )
                    dvec = jnp.stack((dvecx, dvecy, dvecz), axis=-1)
                    cell_shift_pbc = jnp.stack((shiftx, shifty, shiftz), axis=-1)
                    overflow_repeats = overflow_repeats | (nshifts > max_shifts)

                    ## get batch shift in the dvec_filter array
                    nrep = jnp.prod(2 * num_repeats_all + 1, axis=1)
                    bshift = jnp.concatenate(
                        (jnp.array([0], dtype=jnp.int32), jnp.cumsum(nrep)[:-1])
                    )

                    ## repeat vectors
                    nvec_max = state.get("nvec_pbc", 1)
                    batch_index_vec = batch_index[p1]
                    nrep_vec = jnp.where(mask_p12, nrep[batch_index_vec], 0)
                    nvec = nrep_vec.sum()
                    overflow_repeats = overflow_repeats | (nvec > nvec_max)
                    vec = jnp.repeat(vec, nrep_vec, axis=0, total_repeat_length=nvec_max)

                    ## get shift index
                    dshift = jnp.concatenate(
                        (jnp.array([0], dtype=jnp.int32), jnp.cumsum(nrep_vec)[:-1])
                    )
                    if nrep_vec.size == 0:
                        dshift = jnp.array([], dtype=jnp.int32)
                    dshift = jnp.repeat(dshift, nrep_vec, total_repeat_length=nvec_max)
                    bshift = jnp.repeat(
                        bshift[batch_index_vec], nrep_vec, total_repeat_length=nvec_max
                    )
                    icellshift = jnp.arange(dshift.shape[0]) - dshift + bshift
                    vec = vec + dvec[icellshift]
                    pbc_shifts = cell_shift_pbc[icellshift]
                    p1 = jnp.repeat(p1, nrep_vec, total_repeat_length=nvec_max)
                    p2 = jnp.repeat(p2, nrep_vec, total_repeat_length=nvec_max)
                    if natoms.shape[0] > 1:
                        mask_p12 = jnp.repeat(
                            mask_p12, nrep_vec, total_repeat_length=nvec_max
                        )

        ## compute distances
        d12 = (vec**2).sum(axis=-1)
        if natoms.shape[0] > 1:
            d12 = jnp.where(mask_p12, d12, cutoff_skin**2)

        ## filter pairs
        max_pairs = state.get("npairs", 1)
        mask = d12 < cutoff_skin**2
        (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d(
            mask,
            max_pairs,
            (jnp.asarray(p1, dtype=jnp.int32), coords.shape[0]),
            (jnp.asarray(p2, dtype=jnp.int32), coords.shape[0]),
            (d12, cutoff_skin**2),
        )
        if "cells" in inputs:
            pbc_shifts = (
                jnp.full((max_pairs, 3), 0.0, dtype=pbc_shifts.dtype)
                .at[scatter_idx]
                .set(pbc_shifts, mode="drop")
            )
            if not minimage:
                pbc_shifts = (
                    pbc_shifts
                    + at_shifts.at[edge_dst].get(fill_value=0.0)
                    - at_shifts.at[edge_src].get(fill_value=0.0)
                )

        ## check for overflow
        if natoms.shape[0] == 1:
            true_max_nat = coords.shape[0]
        else:
            true_max_nat = jnp.max(natoms)
        overflow_count = npairs > max_pairs
        overflow_at = true_max_nat > max_nat
        overflow = overflow_count | overflow_at | overflow_repeats

        if "nblist_skin" in state:
            edge_src_skin = edge_src
            edge_dst_skin = edge_dst
            if "cells" in inputs:
                pbc_shifts_skin = pbc_shifts
            max_pairs_skin = state.get("npairs_skin", 1)
            mask = d12 < self.cutoff**2
            (edge_src, edge_dst, d12), scatter_idx, npairs_skin = mask_filter_1d(
                mask,
                max_pairs_skin,
                (edge_src, coords.shape[0]),
                (edge_dst, coords.shape[0]),
                (d12, self.cutoff**2),
            )
            if "cells" in inputs:
                pbc_shifts = (
                    jnp.full((max_pairs_skin, 3), 0.0, dtype=pbc_shifts.dtype)
                    .at[scatter_idx]
                    .set(pbc_shifts, mode="drop")
                )
            overflow = overflow | (npairs_skin > max_pairs_skin)

        ## symmetrize
        edge_src, edge_dst = jnp.concatenate((edge_src, edge_dst)), jnp.concatenate(
            (edge_dst, edge_src)
        )
        d12 = jnp.concatenate((d12, d12))
        if "cells" in inputs:
            pbc_shifts = jnp.concatenate((pbc_shifts, -pbc_shifts))

        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
        graph_out = {
            **graph,
            "edge_src": edge_src,
            "edge_dst": edge_dst,
            "d12": d12,
            "overflow": overflow,
            "pbc_shifts": pbc_shifts,
        }
        if "nblist_skin" in state:
            graph_out["edge_src_skin"] = edge_src_skin
            graph_out["edge_dst_skin"] = edge_dst_skin
            if "cells" in inputs:
                graph_out["pbc_shifts_skin"] = pbc_shifts_skin

        if self.k_space and "cells" in inputs:
            if "k_points" not in graph:
                raise NotImplementedError(
                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
                )
        return {**inputs, self.graph_key: graph_out}

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        """update the nblist without recomputing the full nblist"""
        graph = inputs[self.graph_key]

        edge_src_skin = graph["edge_src_skin"]
        edge_dst_skin = graph["edge_dst_skin"]
        coords = inputs["coordinates"]
        vec = coords.at[edge_dst_skin].get(
            mode="fill", fill_value=self.cutoff
        ) - coords.at[edge_src_skin].get(mode="fill", fill_value=0.0)

        if "cells" in inputs:
            pbc_shifts_skin = graph["pbc_shifts_skin"]
            cells = inputs["cells"]
            if cells.shape[0] == 1:
                vec = vec + jnp.dot(pbc_shifts_skin, cells[0])
            else:
                batch_index_vec = inputs["batch_index"][edge_src_skin]
                vec = vec + jax.vmap(jnp.dot)(pbc_shifts_skin, cells[batch_index_vec])

        nat = coords.shape[0]
        d12 = jnp.sum(vec**2, axis=-1)
        mask = d12 < self.cutoff**2
        max_pairs = graph["edge_src"].shape[0] // 2
        (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d(
            mask,
            max_pairs,
            (edge_src_skin, nat),
            (edge_dst_skin, nat),
            (d12, self.cutoff**2),
        )
        if "cells" in inputs:
            pbc_shifts = (
                jnp.full((max_pairs, 3), 0.0, dtype=pbc_shifts_skin.dtype)
                .at[scatter_idx]
                .set(pbc_shifts_skin)
            )

        overflow = graph.get("overflow", False) | (npairs > max_pairs)
        graph_out = {
            **graph,
            "edge_src": jnp.concatenate((edge_src, edge_dst)),
            "edge_dst": jnp.concatenate((edge_dst, edge_src)),
            "d12": jnp.concatenate((d12, d12)),
            "overflow": overflow,
        }
        if "cells" in inputs:
            graph_out["pbc_shifts"] = jnp.concatenate((pbc_shifts, -pbc_shifts))

        if self.k_space and "cells" in inputs:
            if "k_points" not in graph:
                raise NotImplementedError(
                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
                )

        return {**inputs, self.graph_key: graph_out}
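

# Illustrative usage sketch (not part of the library API): the dynamic-shape
# numpy pass (`__call__`) runs once to size the buffers stored in the state,
# after which the jitted `process` rebuilds the nblist with fixed shapes.
def _example_graph_generator_usage():
    gen = GraphGenerator(cutoff=5.0)
    inputs = check_input(
        {
            "species": np.array([8, 1, 1], dtype=np.int32),
            "coordinates": np.array(
                [[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]],
                dtype=np.float32,
            ),
        }
    )
    state, outputs = gen(gen.init(), inputs)  # numpy pass, allocates shapes
    outputs = gen.process(state, convert_to_jax(outputs))  # jitted pass
    return outputs["graph"]["edge_src"], outputs["graph"]["d12"]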


class GraphProcessor(nn.Module):
    """Process a pre-generated graph

    The pre-generated graph should contain the following keys:
    - edge_src: source indices of the edges
    - edge_dst: destination indices of the edges
    - pbc_shifts: pbc shifts for the edges (only if `cells` are present in the inputs)

    This module is automatically added to a FENNIX model when a GraphGenerator is used.
    """

    cutoff: float
    """Cutoff distance for the graph."""
    graph_key: str = "graph"
    """Key of the graph in the outputs."""
    switch_params: dict = dataclasses.field(default_factory=dict)
    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""

    @nn.compact
    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
        graph = inputs[self.graph_key]
        coords = inputs["coordinates"]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        vec = coords.at[edge_dst].get(mode="fill", fill_value=self.cutoff) - coords.at[
            edge_src
        ].get(mode="fill", fill_value=0.0)
        if "cells" in inputs:
            cells = inputs["cells"]
            if cells.shape[0] == 1:
                vec = vec + jnp.dot(graph["pbc_shifts"], cells[0])
            else:
                batch_index_vec = inputs["batch_index"][edge_src]
                vec = vec + jax.vmap(jnp.dot)(
                    graph["pbc_shifts"], cells[batch_index_vec]
                )

        d2 = jnp.sum(vec**2, axis=-1)
        distances = safe_sqrt(d2)
        edge_mask = distances < self.cutoff

        switch = SwitchFunction(
            **{**self.switch_params, "cutoff": self.cutoff, "graph_key": None}
        )((distances, edge_mask))

        graph_out = {
            **graph,
            "vec": vec,
            "distances": distances,
            "switch": switch,
            "edge_mask": edge_mask,
        }

        if "alch_group" in inputs:
            alch_group = inputs["alch_group"]
            lambda_e = inputs["alch_elambda"]
            lambda_v = inputs["alch_vlambda"]
            mask = alch_group[edge_src] == alch_group[edge_dst]
            graph_out["switch_raw"] = switch
            graph_out["switch"] = jnp.where(
                mask,
                switch,
                0.5 * (1.0 - jnp.cos(jnp.pi * lambda_e)) * switch,
            )
            graph_out["distances_raw"] = distances
            if "alch_softcore_e" in inputs:
                alch_alpha = (1 - lambda_e) * inputs["alch_softcore_e"] ** 2
            else:
                alch_alpha = (1 - lambda_v) * inputs.get("alch_softcore_v", 0.5) ** 2

            graph_out["distances"] = jnp.where(
                mask,
                distances,
                safe_sqrt(alch_alpha + d2 * (1.0 - alch_alpha / self.cutoff**2)),
            )

        return {**inputs, self.graph_key: graph_out}
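

# A minimal numpy sketch (not executed by the library) of the minimum-image
# wrap used above: fractional coordinates are rounded to get the integer cell
# shift that maps a pair vector onto its nearest periodic image.
def _example_minimum_image_wrap():
    cell = 10.0 * np.eye(3, dtype=np.float32)  # cubic box, rows = lattice vectors
    reciprocal_cell = np.linalg.inv(cell)
    vec = np.array([[9.0, 0.0, 0.0]], dtype=np.float32)  # raw pair vector
    pbc_shifts = -np.round(np.dot(vec, reciprocal_cell))  # -> [[-1., 0., 0.]]
    return vec + np.dot(pbc_shifts, cell)  # nearest image: [[-1., 0., 0.]]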


@dataclasses.dataclass(frozen=True)
class GraphFilter:
    """Filter a graph based on a cutoff distance

    FPID: GRAPH_FILTER
    """

    cutoff: float
    """Cutoff distance for the filtering."""
    parent_graph: str
    """Key of the parent graph in the inputs."""
    graph_key: str
    """Key of the filtered graph in the outputs."""
    remove_hydrogens: bool = False
    """Remove edges whose source is a hydrogen atom."""
    switch_params: FrozenDict = dataclasses.field(default_factory=FrozenDict)
    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""
    k_space: bool = False
    """Whether to generate k-space information for the graph."""
    kmax: int = 30
    """Maximum number of k-points to consider."""
    kthr: float = 1e-6
    """Threshold for k-point filtering."""
    mult_size: float = 1.05
    """Multiplicative factor for resizing the nblist."""

    FPID: ClassVar[str] = "GRAPH_FILTER"

    def init(self):
        return FrozenDict(
            {
                "npairs": 1,
                "nblist_mult_size": self.mult_size,
            }
        )

    def get_processor(self) -> Tuple[nn.Module, Dict]:
        return GraphFilterProcessor, {
            "cutoff": self.cutoff,
            "graph_key": self.graph_key,
            "parent_graph": self.parent_graph,
            "name": f"{self.graph_key}_Filter_{self.parent_graph}",
            "switch_params": self.switch_params,
        }

    def get_graph_properties(self):
        return {
            self.graph_key: {
                "cutoff": self.cutoff,
                "directed": True,
                "parent_graph": self.parent_graph,
            }
        }

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """filter a nblist on cpu with numpy and dynamic shapes + store max shapes"""
        graph_in = inputs[self.parent_graph]
        nat = inputs["species"].shape[0]

        new_state = {**state}
        state_up = {}
        mult_size = state.get("nblist_mult_size", self.mult_size)
        assert mult_size >= 1.0, "nblist_mult_size should be >= 1."

        edge_src = np.array(graph_in["edge_src"], dtype=np.int32)
        d12 = np.array(graph_in["d12"], dtype=np.float32)
        if self.remove_hydrogens:
            species = inputs["species"]
            src_idx = (edge_src < nat).nonzero()[0]
            mask = np.zeros(edge_src.shape[0], dtype=bool)
            mask[src_idx] = (species > 1)[edge_src[src_idx]]
            d12 = np.where(mask, d12, self.cutoff**2)
        mask = d12 < self.cutoff**2

        max_pairs = state.get("npairs", 1)
        idx = np.nonzero(mask)[0]
        npairs = idx.shape[0]
        if npairs > max_pairs or add_margin:
            prev_max_pairs = max_pairs
            max_pairs = int(mult_size * max(npairs, max_pairs)) + 1
            state_up["npairs"] = (max_pairs, prev_max_pairs)
            new_state["npairs"] = max_pairs

        filter_indices = np.full(max_pairs, edge_src.shape[0], dtype=np.int32)
        edge_src = np.full(max_pairs, nat, dtype=np.int32)
        edge_dst = np.full(max_pairs, nat, dtype=np.int32)
        d12_ = np.full(max_pairs, self.cutoff**2)
        filter_indices[:npairs] = idx
        edge_src[:npairs] = graph_in["edge_src"][idx]
        edge_dst[:npairs] = graph_in["edge_dst"][idx]
        d12_[:npairs] = d12[idx]
        d12 = d12_

        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
        graph_out = {
            **graph,
            "edge_src": edge_src,
            "edge_dst": edge_dst,
            "filter_indices": filter_indices,
            "d12": d12,
            "overflow": False,
        }

        if self.k_space and "cells" in inputs:
            if "k_points" not in graph:
                ks, _, _, bewald = get_reciprocal_space_parameters(
                    inputs["reciprocal_cells"], self.cutoff, self.kmax, self.kthr
                )
                graph_out["k_points"] = ks
                graph_out["b_ewald"] = bewald

        output = {**inputs, self.graph_key: graph_out}
        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary"""
        overflow = parent_overflow or inputs[self.graph_key].get("overflow", False)
        if not overflow:
            return state, {}, inputs, False

        add_margin = inputs[self.graph_key].get("overflow", False)
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """filter a nblist on accelerator with jax and precomputed shapes"""
        graph_in = inputs[self.parent_graph]
        if state is None:
            # skin update mode
            graph = inputs[self.graph_key]
            max_pairs = graph["edge_src"].shape[0]
        else:
            max_pairs = state.get("npairs", 1)

        max_pairs_in = graph_in["edge_src"].shape[0]
        nat = inputs["species"].shape[0]

        edge_src = graph_in["edge_src"]
        d12 = graph_in["d12"]
        if self.remove_hydrogens:
            species = inputs["species"]
            mask = (species > 1)[edge_src]
            d12 = jnp.where(mask, d12, self.cutoff**2)
        mask = d12 < self.cutoff**2

        (edge_src, edge_dst, d12, filter_indices), _, npairs = mask_filter_1d(
            mask,
            max_pairs,
            (edge_src, nat),
            (graph_in["edge_dst"], nat),
            (d12, self.cutoff**2),
            (jnp.arange(max_pairs_in, dtype=jnp.int32), max_pairs_in),
        )

        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
        overflow = graph.get("overflow", False) | (npairs > max_pairs)
        graph_out = {
            **graph,
            "edge_src": edge_src,
            "edge_dst": edge_dst,
            "filter_indices": filter_indices,
            "d12": d12,
            "overflow": overflow,
        }

        if self.k_space and "cells" in inputs:
            if "k_points" not in graph:
                raise NotImplementedError(
                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
                )

        return {**inputs, self.graph_key: graph_out}

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        return self.process(None, inputs)
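

# Illustrative sketch (not part of the API): `filter_indices` map each edge of
# the filtered graph back to its position in the parent graph, so any per-edge
# quantity computed on the parent (a jax array) can be gathered onto the
# sub-graph; padded (out-of-range) indices pick up the fill value.
def _example_gather_on_filtered_graph(parent_quantity, filter_indices):
    return parent_quantity.at[filter_indices].get(mode="fill", fill_value=0.0)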


class GraphFilterProcessor(nn.Module):
    """Filter processing for a pre-generated graph

    This module is automatically added to a FENNIX model when a GraphFilter is used.
    """

    cutoff: float
    """Cutoff distance for the filtering."""
    graph_key: str
    """Key of the filtered graph in the inputs."""
    parent_graph: str
    """Key of the parent graph in the inputs."""
    switch_params: dict = dataclasses.field(default_factory=dict)
    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""

    @nn.compact
    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
        graph_in = inputs[self.parent_graph]
        graph = inputs[self.graph_key]

        # use raw (non soft-cored) distances from the parent graph if available
        d_key = "distances_raw" if "distances_raw" in graph_in else "distances"

        if graph_in["vec"].shape[0] == 0:
            vec = graph_in["vec"]
            distances = graph_in[d_key]
            filter_indices = jnp.asarray([], dtype=jnp.int32)
        else:
            filter_indices = graph["filter_indices"]
            vec = (
                graph_in["vec"]
                .at[filter_indices]
                .get(mode="fill", fill_value=self.cutoff)
            )
            distances = (
                graph_in[d_key]
                .at[filter_indices]
                .get(mode="fill", fill_value=self.cutoff)
            )

        edge_mask = distances < self.cutoff
        switch = SwitchFunction(
            **{**self.switch_params, "cutoff": self.cutoff, "graph_key": None}
        )((distances, edge_mask))

        graph_out = {
            **graph,
            "vec": vec,
            "distances": distances,
            "switch": switch,
            "filter_indices": filter_indices,
            "edge_mask": edge_mask,
        }

        if "alch_group" in inputs:
            edge_src = graph["edge_src"]
            edge_dst = graph["edge_dst"]
            alch_group = inputs["alch_group"]
            lambda_e = inputs["alch_elambda"]
            lambda_v = inputs["alch_vlambda"]
            mask = alch_group[edge_src] == alch_group[edge_dst]
            graph_out["switch_raw"] = switch
            graph_out["switch"] = jnp.where(
                mask,
                switch,
                0.5 * (1.0 - jnp.cos(jnp.pi * lambda_e)) * switch,
            )

            graph_out["distances_raw"] = distances
            if "alch_softcore_e" in inputs:
                alch_alpha = (1 - lambda_e) * inputs["alch_softcore_e"] ** 2
            else:
                alch_alpha = (1 - lambda_v) * inputs.get("alch_softcore_v", 0.5) ** 2

            graph_out["distances"] = jnp.where(
                mask,
                distances,
                safe_sqrt(
                    alch_alpha + distances**2 * (1.0 - alch_alpha / self.cutoff**2)
                ),
            )

        return {**inputs, self.graph_key: graph_out}
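

# A minimal sketch (not part of the API) of the soft-core distance applied to
# inter-group alchemical pairs above: d_sc = sqrt(alpha + d^2 (1 - alpha/rc^2))
# with alpha = (1 - lambda) * softcore^2, so d_sc -> sqrt(alpha) at d = 0 and
# d_sc = rc exactly at the cutoff.
def _example_alchemical_softcore(d, lambda_v=0.5, softcore=0.5, cutoff=5.0):
    alch_alpha = (1 - lambda_v) * softcore**2
    return np.sqrt(alch_alpha + d**2 * (1.0 - alch_alpha / cutoff**2))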


@dataclasses.dataclass(frozen=True)
class GraphAngularExtension:
    """Add an angle list to a graph

    FPID: GRAPH_ANGULAR_EXTENSION
    """

    mult_size: float = 1.05
    """Multiplicative factor for resizing the nblist."""
    add_neigh: int = 5
    """Additional neighbors to add to the nblist when resizing."""
    graph_key: str = "graph"
    """Key of the graph in the inputs."""

    FPID: ClassVar[str] = "GRAPH_ANGULAR_EXTENSION"

    def init(self):
        return FrozenDict(
            {
                "nangles": 0,
                "nblist_mult_size": self.mult_size,
                "max_neigh": self.add_neigh,
                "add_neigh": self.add_neigh,
            }
        )

    def get_processor(self) -> Tuple[nn.Module, Dict]:
        return GraphAngleProcessor, {
            "graph_key": self.graph_key,
            "name": f"{self.graph_key}_AngleProcessor",
        }

    def get_graph_properties(self):
        return {
            self.graph_key: {
                "has_angles": True,
            }
        }

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """build angle nblist on cpu with numpy and dynamic shapes + store max shapes"""
        graph = inputs[self.graph_key]
        edge_src = np.array(graph["edge_src"], dtype=np.int32)

        new_state = {**state}
        state_up = {}
        mult_size = state.get("nblist_mult_size", self.mult_size)
        assert mult_size >= 1.0, "nblist_mult_size should be >= 1."

        ### count number of neighbors
        nat = inputs["species"].shape[0]
        count = np.zeros(nat + 1, dtype=np.int32)
        np.add.at(count, edge_src, 1)
        max_count = int(np.max(count[:-1]))

        ### get sizes
        max_neigh = state.get("max_neigh", self.add_neigh)
        nedge = edge_src.shape[0]
        if max_count > max_neigh or add_margin:
            prev_max_neigh = max_neigh
            max_neigh = max(max_count, max_neigh) + state.get(
                "add_neigh", self.add_neigh
            )
            state_up["max_neigh"] = (max_neigh, prev_max_neigh)
            new_state["max_neigh"] = max_neigh

        # dummy array whose shape carries max_neigh as a static size through jit
        max_neigh_arr = np.empty(max_neigh, dtype=bool)

        ### sort edge_src
        idx_sort = np.argsort(edge_src)
        edge_src_sorted = edge_src[idx_sort]

        ### map sparse to dense nblist
        if max_count * nat >= nedge:
            offset = np.tile(np.arange(max_count), nat)[:nedge]
        else:
            offset = np.zeros(nedge, dtype=np.int32)
            offset[: max_count * nat] = np.tile(np.arange(max_count), nat)

        mask = edge_src_sorted < nat
        indices = edge_src_sorted * max_count + offset
        indices = indices[mask]
        idx_sort = idx_sort[mask]
        edge_idx = np.full(nat * max_count, nedge, dtype=np.int32)
        edge_idx[indices] = idx_sort
        edge_idx = edge_idx.reshape(nat, max_count)

        ### find all triplets for each atom center
        local_src, local_dst = np.triu_indices(max_count, 1)
        angle_src = edge_idx[:, local_src].flatten()
        angle_dst = edge_idx[:, local_dst].flatten()

        ### mask for valid angles
        mask1 = angle_src < nedge
        mask2 = angle_dst < nedge
        angle_mask = mask1 & mask2

        max_angles = state.get("nangles", 0)
        idx = np.nonzero(angle_mask)[0]
        nangles = idx.shape[0]
        if nangles > max_angles or add_margin:
            max_angles_prev = max_angles
            max_angles = int(mult_size * max(nangles, max_angles)) + 1
            state_up["nangles"] = (max_angles, max_angles_prev)
            new_state["nangles"] = max_angles

        ## filter angles to sparse representation
        angle_src_ = np.full(max_angles, nedge, dtype=np.int32)
        angle_dst_ = np.full(max_angles, nedge, dtype=np.int32)
        angle_src_[:nangles] = angle_src[idx]
        angle_dst_[:nangles] = angle_dst[idx]

        central_atom = np.full(max_angles, nat, dtype=np.int32)
        central_atom[:nangles] = edge_src[angle_src_[:nangles]]

        ## update graph
        output = {
            **inputs,
            self.graph_key: {
                **graph,
                "angle_src": angle_src_,
                "angle_dst": angle_dst_,
                "central_atom": central_atom,
                "angle_overflow": False,
                "max_neigh": max_neigh,
                "__max_neigh_array": max_neigh_arr,
            },
        }

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary"""
        overflow = parent_overflow or inputs[self.graph_key]["angle_overflow"]
        if not overflow:
            return state, {}, inputs, False

        add_margin = inputs[self.graph_key]["angle_overflow"]
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """build angle nblist on accelerator with jax and precomputed shapes"""
        graph = inputs[self.graph_key]
        edge_src = graph["edge_src"]

        ### count number of neighbors
        nat = inputs["species"].shape[0]
        count = jnp.zeros(nat, dtype=jnp.int32).at[edge_src].add(1, mode="drop")
        max_count = jnp.max(count)

        ### get sizes
        if state is None:
            max_neigh_arr = graph["__max_neigh_array"]
            max_neigh = max_neigh_arr.shape[0]
            prev_nangles = graph["angle_src"].shape[0]
        else:
            max_neigh = state.get("max_neigh", self.add_neigh)
            max_neigh_arr = jnp.empty(max_neigh, dtype=bool)
            prev_nangles = state.get("nangles", 0)

        nedge = edge_src.shape[0]

        ### sort edge_src
        idx_sort = jnp.argsort(edge_src).astype(jnp.int32)
        edge_src_sorted = edge_src[idx_sort]

        ### map sparse to dense nblist
        if max_neigh * nat < nedge:
            raise ValueError("Found max_neigh*nat < nedge. This should not happen.")
        offset = jnp.asarray(
            np.tile(np.arange(max_neigh), nat)[:nedge], dtype=jnp.int32
        )
        indices = edge_src_sorted * max_neigh + offset
        edge_idx = (
            jnp.full(nat * max_neigh, nedge, dtype=jnp.int32)
            .at[indices]
            .set(idx_sort, mode="drop")
            .reshape(nat, max_neigh)
        )

        ### find all triplets for each atom center
        local_src, local_dst = np.triu_indices(max_neigh, 1)
        angle_src = edge_idx[:, local_src].flatten()
        angle_dst = edge_idx[:, local_dst].flatten()

        ### mask for valid angles
        mask1 = angle_src < nedge
        mask2 = angle_dst < nedge
        angle_mask = mask1 & mask2

        ## filter angles to sparse representation
        (angle_src, angle_dst), _, nangles = mask_filter_1d(
            angle_mask,
            prev_nangles,
            (angle_src, nedge),
            (angle_dst, nedge),
        )
        ## find central atom
        central_atom = edge_src[angle_src]

        ## check for overflow
        angle_overflow = nangles > prev_nangles
        neigh_overflow = max_count > max_neigh
        overflow = graph.get("angle_overflow", False) | angle_overflow | neigh_overflow

        ## update graph
        output = {
            **inputs,
            self.graph_key: {
                **graph,
                "angle_src": angle_src,
                "angle_dst": angle_dst,
                "central_atom": central_atom,
                "angle_overflow": overflow,
                "__max_neigh_array": max_neigh_arr,
            },
        }

        return output

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        return self.process(None, inputs)
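

# Illustrative sketch (not part of the API): triplet enumeration from the
# dense per-atom edge table built above. Each row holds the edge indices of
# one central atom (padded with `nedge`); every pair of columns is an angle.
def _example_angle_triplets():
    edge_idx = np.array([[0, 1], [2, 3]], dtype=np.int32)  # (nat=2, max_neigh=2)
    local_src, local_dst = np.triu_indices(edge_idx.shape[1], 1)
    angle_src = edge_idx[:, local_src].flatten()  # -> [0, 2]
    angle_dst = edge_idx[:, local_dst].flatten()  # -> [1, 3]
    return angle_src, angle_dst  # central atom of angle k is edge_src[angle_src[k]]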


class GraphAngleProcessor(nn.Module):
    """Process a pre-generated graph to compute angles

    This module is automatically added to a FENNIX model when a GraphAngularExtension is used.
    """

    graph_key: str
    """Key of the graph in the inputs."""

    @nn.compact
    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
        graph = inputs[self.graph_key]
        distances = (
            graph["distances_raw"] if "distances_raw" in graph else graph["distances"]
        )
        vec = graph["vec"]
        angle_src = graph["angle_src"]
        angle_dst = graph["angle_dst"]

        dir = vec / jnp.clip(distances[:, None], min=1.0e-5)
        cos_angles = (
            dir.at[angle_src].get(mode="fill", fill_value=0.5)
            * dir.at[angle_dst].get(mode="fill", fill_value=0.5)
        ).sum(axis=-1)

        # damp cos_angles slightly so that arccos stays differentiable
        # at perfectly aligned (|cos| = 1) configurations
        angles = jnp.arccos(0.95 * cos_angles)

        return {
            **inputs,
            self.graph_key: {
                **graph,
                "angles": angles,
            },
        }
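

# A minimal sketch (not part of the API) of the angle evaluation above: unit
# bond vectors are gathered for the two edges of a triplet and the angle is
# the arccos of their (slightly damped) dot product.
def _example_angle_from_unit_vectors():
    u = np.array([1.0, 0.0, 0.0])
    v = np.array([0.0, 1.0, 0.0])
    return np.arccos(0.95 * np.dot(u, v))  # pi/2 for orthogonal bonds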


@dataclasses.dataclass(frozen=True)
class SpeciesIndexer:
    """Build an index that splits atomic arrays by species.

    FPID: SPECIES_INDEXER

    If `species_order` is specified, the output will be a dense array of shape
    (len(species_order), max_size) that can directly index atomic arrays.
    If `species_order` is None, the output will be a dictionary with species as
    keys and an index to filter atomic arrays for that species as values.

    """

    output_key: str = "species_index"
    """Key for the output dictionary."""
    species_order: Optional[str] = None
    """Comma separated list of species in the order they should be indexed."""
    add_atoms: int = 0
    """Additional atoms to add to the sizes."""
    add_atoms_margin: int = 10
    """Additional atoms to add to the sizes when adding margin."""

    FPID: ClassVar[str] = "SPECIES_INDEXER"

    def init(self):
        return FrozenDict(
            {
                "sizes": {},
            }
        )

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        species = np.array(inputs["species"], dtype=np.int32)
        nat = species.shape[0]
        set_species, counts = np.unique(species, return_counts=True)

        new_state = {**state}
        state_up = {}

        sizes = state.get("sizes", FrozenDict({}))
        new_sizes = {**sizes}
        up_sizes = False
        counts_dict = {}
        for s, c in zip(set_species, counts):
            if s <= 0:
                continue
            counts_dict[s] = c
            if c > sizes.get(s, 0):
                up_sizes = True
                add_atoms = state.get("add_atoms", self.add_atoms)
                if add_margin:
                    add_atoms += state.get("add_atoms_margin", self.add_atoms_margin)
                new_sizes[s] = c + add_atoms

        new_sizes = FrozenDict(new_sizes)
        if up_sizes:
            state_up["sizes"] = (new_sizes, sizes)
            new_state["sizes"] = new_sizes

        if self.species_order is not None:
            species_order = [el.strip() for el in self.species_order.split(",")]
            max_size_prev = state.get("max_size", 0)
            max_size = max(new_sizes.values())
            if max_size > max_size_prev:
                state_up["max_size"] = (max_size, max_size_prev)
                new_state["max_size"] = max_size
                max_size_prev = max_size

            species_index = np.full(
                (len(species_order), max_size), nat, dtype=np.int32
            )
            for i, el in enumerate(species_order):
                s = PERIODIC_TABLE_REV_IDX[el]
                if s in counts_dict.keys():
                    species_index[i, : counts_dict[s]] = np.nonzero(species == s)[0]
        else:
            species_index = {
                PERIODIC_TABLE[s]: np.full(c, nat, dtype=np.int32)
                for s, c in new_sizes.items()
            }
            for s, c in zip(set_species, counts):
                if s <= 0:
                    continue
                species_index[PERIODIC_TABLE[s]][:c] = np.nonzero(species == s)[0]

        output = {
            **inputs,
            self.output_key: species_index,
            self.output_key + "_overflow": False,
        }

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary"""
        overflow = parent_overflow or inputs[self.output_key + "_overflow"]
        if not overflow:
            return state, {}, inputs, False

        add_margin = inputs[self.output_key + "_overflow"]
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        recompute_species_index = "recompute_species_index" in inputs.get("flags", {})
        if self.output_key in inputs and not recompute_species_index:
            return inputs

        if state is None:
            raise ValueError("Species Indexer state must be provided on accelerator.")

        species = inputs["species"]
        nat = species.shape[0]

        sizes = state["sizes"]

        overflow = False
        if self.species_order is not None:
            species_order = [el.strip() for el in self.species_order.split(",")]
            max_size = state["max_size"]

            species_index = jnp.full(
                (len(species_order), max_size), nat, dtype=jnp.int32
            )
            for i, el in enumerate(species_order):
                s = PERIODIC_TABLE_REV_IDX[el]
                if s in sizes.keys():
                    species_index = species_index.at[i, :].set(
                        jnp.nonzero(species == s, size=max_size, fill_value=nat)[0]
                    )
        else:
            species_index = {}
            natcount = 0
            for s, c in sizes.items():
                mask = species == s
                new_size = jnp.sum(mask)
                natcount = natcount + new_size
                overflow = overflow | (new_size > c)  # check if sizes are correct
                species_index[PERIODIC_TABLE[s]] = jnp.nonzero(
                    species == s, size=c, fill_value=nat
                )[0]

            mask = species <= 0
            new_size = jnp.sum(mask)
            natcount = natcount + new_size
            overflow = overflow | (
                natcount < species.shape[0]
            )  # check if any species is missing

        return {
            **inputs,
            self.output_key: species_index,
            self.output_key + "_overflow": overflow,
        }

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        return self.process(None, inputs)
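

# Illustrative sketch (not part of the API): without a `species_order`, the
# indexer returns one index array per element, padded with `nat`.
def _example_species_indexer():
    indexer = SpeciesIndexer()
    state, outputs = indexer(
        indexer.init(), {"species": np.array([8, 1, 1], dtype=np.int32)}
    )
    return outputs["species_index"]  # e.g. {"O": array([0]), "H": array([1, 2])}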


@dataclasses.dataclass(frozen=True)
class BlockIndexer:
    """Build an index that splits atomic arrays by chemical blocks.

    FPID: BLOCK_INDEXER

    The output is a dictionary with chemical block names as keys and an index
    to filter atomic arrays for that block as values.

    """

    output_key: str = "block_index"
    """Key for the output dictionary."""
    add_atoms: int = 0
    """Additional atoms to add to the sizes."""
    add_atoms_margin: int = 10
    """Additional atoms to add to the sizes when adding margin."""
    split_CNOPSSe: bool = False
    """Whether to split C, N, O, P, S and Se into their own blocks."""

    FPID: ClassVar[str] = "BLOCK_INDEXER"

    def init(self):
        return FrozenDict(
            {
                "sizes": {},
            }
        )

    def build_chemical_blocks(self):
        _CHEMICAL_BLOCKS_NAMES = CHEMICAL_BLOCKS_NAMES.copy()
        if self.split_CNOPSSe:
            _CHEMICAL_BLOCKS_NAMES[1] = "C"
            _CHEMICAL_BLOCKS_NAMES.extend(["N", "O", "P", "S", "Se"])
        _CHEMICAL_BLOCKS = CHEMICAL_BLOCKS.copy()
        if self.split_CNOPSSe:
            _CHEMICAL_BLOCKS[6] = 1
            _CHEMICAL_BLOCKS[7] = len(CHEMICAL_BLOCKS_NAMES)
            _CHEMICAL_BLOCKS[8] = len(CHEMICAL_BLOCKS_NAMES) + 1
            _CHEMICAL_BLOCKS[15] = len(CHEMICAL_BLOCKS_NAMES) + 2
            _CHEMICAL_BLOCKS[16] = len(CHEMICAL_BLOCKS_NAMES) + 3
            _CHEMICAL_BLOCKS[34] = len(CHEMICAL_BLOCKS_NAMES) + 4
        return _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS = self.build_chemical_blocks()

        species = np.array(inputs["species"], dtype=np.int32)
        blocks = _CHEMICAL_BLOCKS[species]
        nat = species.shape[0]
        set_blocks, counts = np.unique(blocks, return_counts=True)

        new_state = {**state}
        state_up = {}

        sizes = state.get("sizes", FrozenDict({}))
        new_sizes = {**sizes}
        up_sizes = False
        for s, c in zip(set_blocks, counts):
            if s < 0:
                continue
            key = (s, _CHEMICAL_BLOCKS_NAMES[s])
            if c > sizes.get(key, 0):
                up_sizes = True
                add_atoms = state.get("add_atoms", self.add_atoms)
                if add_margin:
                    add_atoms += state.get("add_atoms_margin", self.add_atoms_margin)
                new_sizes[key] = c + add_atoms

        new_sizes = FrozenDict(new_sizes)
        if up_sizes:
            state_up["sizes"] = (new_sizes, sizes)
            new_state["sizes"] = new_sizes

        block_index = {n: None for n in _CHEMICAL_BLOCKS_NAMES}
        for (_, n), c in new_sizes.items():
            block_index[n] = np.full(c, nat, dtype=np.int32)
        for s, c in zip(set_blocks, counts):
            if s < 0:
                continue
            block_index[_CHEMICAL_BLOCKS_NAMES[s]][:c] = np.nonzero(blocks == s)[0]

        output = {
            **inputs,
            self.output_key: block_index,
            self.output_key + "_overflow": False,
        }

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary"""
        overflow = parent_overflow or inputs[self.output_key + "_overflow"]
        if not overflow:
            return state, {}, inputs, False

        add_margin = inputs[self.output_key + "_overflow"]
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS = self.build_chemical_blocks()

        recompute_species_index = "recompute_species_index" in inputs.get("flags", {})
        if self.output_key in inputs and not recompute_species_index:
            return inputs

        if state is None:
            raise ValueError("Block Indexer state must be provided on accelerator.")

        species = inputs["species"]
        blocks = jnp.asarray(_CHEMICAL_BLOCKS)[species]
        nat = species.shape[0]

        sizes = state["sizes"]

        block_index = {n: None for n in _CHEMICAL_BLOCKS_NAMES}
        overflow = False
        natcount = 0
        for (s, name), c in sizes.items():
            mask = blocks == s
            new_size = jnp.sum(mask)
            natcount = natcount + new_size
            overflow = overflow | (new_size > c)  # check if sizes are correct
            block_index[name] = jnp.nonzero(mask, size=c, fill_value=nat)[0]

        mask = blocks < 0
        new_size = jnp.sum(mask)
        natcount = natcount + new_size
        overflow = overflow | (
            natcount < species.shape[0]
        )  # check if any species is missing

        return {
            **inputs,
            self.output_key: block_index,
            self.output_key + "_overflow": overflow,
        }

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        return self.process(None, inputs)
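

# Illustrative sketch (not part of the API): like SpeciesIndexer, but atoms
# are grouped by the chemical block of their element (names taken from
# CHEMICAL_BLOCKS_NAMES); blocks with no atoms in the system remain None.
def _example_block_indexer():
    indexer = BlockIndexer()
    state, outputs = indexer(
        indexer.init(), {"species": np.array([6, 1, 8], dtype=np.int32)}
    )
    return outputs["block_index"]  # dict keyed by block name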
np.append( 1720 inputs["system_index"], 1721 np.array([output["natoms"].shape[0] - 1] * add_sys, dtype=inputs["system_index"].dtype), 1722 ) 1723 1724 output["true_atoms"] = output["species"] > 0 1725 output["true_sys"] = np.arange(len(output["natoms"])) < nsys 1726 1727 state = {**state, "prev_nat": prev_nat_, "prev_nsys": prev_nsys_} 1728 1729 return FrozenDict(state), output 1730 1731 1732def atom_unpadding(inputs: Dict[str, Any]) -> Dict[str, Any]: 1733 """Remove padding from atomic arrays.""" 1734 if "true_atoms" not in inputs: 1735 return inputs 1736 1737 species = np.asarray(inputs["species"]) 1738 true_atoms = np.asarray(inputs["true_atoms"]) 1739 true_sys = np.asarray(inputs["true_sys"]) 1740 natall = species.shape[0] 1741 nat = np.argmax(species <= 0) 1742 if nat == 0: 1743 return inputs 1744 1745 natoms = inputs["natoms"] 1746 nsysall = len(natoms) 1747 1748 output = {**inputs} 1749 for k, v in inputs.items(): 1750 if isinstance(v, jax.Array) or isinstance(v, np.ndarray): 1751 v = np.asarray(v) 1752 if v.ndim == 0: 1753 output[k] = v 1754 elif v.shape[0] == natall: 1755 output[k] = v[true_atoms] 1756 elif v.shape[0] == nsysall: 1757 output[k] = v[true_sys] 1758 del output["true_sys"] 1759 del output["true_atoms"] 1760 return output 1761 1762 1763def check_input(inputs): 1764 """Check the input dictionary for required keys and types.""" 1765 assert "species" in inputs, "species must be provided" 1766 assert "coordinates" in inputs, "coordinates must be provided" 1767 species = inputs["species"].astype(np.int32) 1768 ifake = np.argmax(species <= 0) 1769 if ifake > 0: 1770 assert np.all(species[:ifake] > 0), "species must be positive" 1771 nat = inputs["species"].shape[0] 1772 1773 natoms = inputs.get("natoms", np.array([nat], dtype=np.int32)).astype(np.int32) 1774 batch_index = inputs.get( 1775 "batch_index", np.repeat(np.arange(len(natoms), dtype=np.int32), natoms) 1776 ).astype(np.int32) 1777 output = {**inputs, "natoms": natoms, "batch_index": batch_index} 1778 if "cells" in inputs: 1779 cells = inputs["cells"] 1780 if "reciprocal_cells" not in inputs: 1781 reciprocal_cells = np.linalg.inv(cells) 1782 else: 1783 reciprocal_cells = inputs["reciprocal_cells"] 1784 if cells.ndim == 2: 1785 cells = cells[None, :, :] 1786 if reciprocal_cells.ndim == 2: 1787 reciprocal_cells = reciprocal_cells[None, :, :] 1788 output["cells"] = cells 1789 output["reciprocal_cells"] = reciprocal_cells 1790 1791 return output 1792 1793 1794def convert_to_jax(data): 1795 """Convert a numpy arrays to jax arrays in a pytree.""" 1796 1797 def convert(x): 1798 if isinstance(x, np.ndarray): 1799 # if x.dtype == np.float64: 1800 # return jnp.asarray(x, dtype=jnp.float32) 1801 return jnp.asarray(x) 1802 return x 1803 1804 return jax.tree_util.tree_map(convert, data) 1805 1806 1807class JaxConverter(nn.Module): 1808 """Convert numpy arrays to jax arrays in a pytree.""" 1809 1810 def __call__(self, data): 1811 return convert_to_jax(data) 1812 1813 1814@dataclasses.dataclass(frozen=True) 1815class PreprocessingChain: 1816 """Chain of preprocessing layers.""" 1817 1818 layers: Tuple[Callable[..., Dict[str, Any]]] 1819 """Preprocessing layers.""" 1820 use_atom_padding: bool = False 1821 """Add an AtomPadding layer at the beginning of the chain.""" 1822 atom_padder: AtomPadding = AtomPadding() 1823 """AtomPadding layer.""" 1824 1825 def __post_init__(self): 1826 if not isinstance(self.layers, Sequence): 1827 raise ValueError( 1828 f"'layers' must be a sequence, got '{type(self.layers).__name__}'." 
1829 ) 1830 if not self.layers: 1831 raise ValueError(f"Error: no Preprocessing layers were provided.") 1832 1833 def __call__(self, state, inputs: Dict[str, Any]) -> Dict[str, Any]: 1834 do_check_input = state.get("check_input", True) 1835 if do_check_input: 1836 inputs = check_input(inputs) 1837 new_state = {**state} 1838 if self.use_atom_padding: 1839 s, inputs = self.atom_padder(state["padder_state"], inputs) 1840 new_state["padder_state"] = s 1841 layer_state = state["layers_state"] 1842 new_layer_state = [] 1843 for i,layer in enumerate(self.layers): 1844 s, inputs = layer(layer_state[i], inputs, return_state_update=False) 1845 new_layer_state.append(s) 1846 new_state["layers_state"] = tuple(new_layer_state) 1847 return FrozenDict(new_state), convert_to_jax(inputs) 1848 1849 def check_reallocate(self, state, inputs): 1850 new_state = [] 1851 state_up = [] 1852 layer_state = state["layers_state"] 1853 parent_overflow = False 1854 for i,layer in enumerate(self.layers): 1855 s, s_up, inputs, parent_overflow = layer.check_reallocate( 1856 layer_state[i], inputs, parent_overflow 1857 ) 1858 new_state.append(s) 1859 state_up.append(s_up) 1860 1861 if not parent_overflow: 1862 return state, {}, inputs, False 1863 return ( 1864 FrozenDict({**state, "layers_state": tuple(new_state)}), 1865 state_up, 1866 inputs, 1867 True, 1868 ) 1869 1870 def atom_padding(self, state, inputs): 1871 if self.use_atom_padding: 1872 padder_state,inputs = self.atom_padder(state["padder_state"], inputs) 1873 return FrozenDict({**state,"padder_state": padder_state}), inputs 1874 return state, inputs 1875 1876 @partial(jax.jit, static_argnums=(0, 1)) 1877 def process(self, state, inputs): 1878 layer_state = state["layers_state"] 1879 for i,layer in enumerate(self.layers): 1880 inputs = layer.process(layer_state[i], inputs) 1881 return inputs 1882 1883 @partial(jax.jit, static_argnums=(0)) 1884 def update_skin(self, inputs): 1885 for layer in self.layers: 1886 inputs = layer.update_skin(inputs) 1887 return inputs 1888 1889 def init(self): 1890 state = {"check_input": True} 1891 if self.use_atom_padding: 1892 state["padder_state"] = self.atom_padder.init() 1893 layer_state = [] 1894 for layer in self.layers: 1895 layer_state.append(layer.init()) 1896 state["layers_state"] = tuple(layer_state) 1897 return FrozenDict(state) 1898 1899 def init_with_output(self, inputs): 1900 state = self.init() 1901 return self(state, inputs) 1902 1903 def get_processors(self): 1904 processors = [] 1905 for layer in self.layers: 1906 if hasattr(layer, "get_processor"): 1907 processors.append(layer.get_processor()) 1908 return processors 1909 1910 def get_graphs_properties(self): 1911 properties = {} 1912 for layer in self.layers: 1913 if hasattr(layer, "get_graph_properties"): 1914 properties = deep_update(properties, layer.get_graph_properties()) 1915 return properties 1916 1917 1918# PREPROCESSING = { 1919# "GRAPH": GraphGenerator, 1920# # "GRAPH_FIXED": GraphGeneratorFixed, 1921# "GRAPH_FILTER": GraphFilter, 1922# "GRAPH_ANGULAR_EXTENSION": GraphAngularExtension, 1923# # "GRAPH_DENSE_EXTENSION": GraphDenseExtension, 1924# "SPECIES_INDEXER": SpeciesIndexer, 1925# }
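As a usage illustration, here is a minimal sketch of assembling and running a chain on the CPU path. It only uses the classes defined above; the cutoff values, the toy water-like inputs and the variable names are arbitrary examples, not package defaults.

import numpy as np

# a long-range GRAPH generator chained with a short-range GRAPH_FILTER
chain = PreprocessingChain(
    layers=(
        GraphGenerator(cutoff=5.0, graph_key="graph"),
        GraphFilter(cutoff=2.0, parent_graph="graph", graph_key="graph_short"),
    ),
    use_atom_padding=True,
)

inputs = {
    "species": np.array([8, 1, 1], dtype=np.int32),          # one O and two H
    "coordinates": np.random.randn(3, 3).astype(np.float32),
}

# init() builds the per-layer states; __call__ runs the dynamic-shape numpy path
# (check_input fills in "natoms" and "batch_index" for a single system)
state, outputs = chain.init_with_output(inputs)
# outputs["graph"]["edge_src"] now holds the padded, symmetrized edge list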
class GraphGenerator
Generate a graph from a set of coordinates
FPID: GRAPH
For now, we generate all pairs of atoms and filter based on cutoff.
If a `nblist_skin` is present in the state, we generate a second graph with a larger cutoff that includes all pairs within the cutoff+skin. This graph is then reused by the `update_skin` method to update the original graph without recomputing the full nblist.

switch_params: Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`.
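The skin mechanism is driven entirely by the layer state: when a `nblist_skin` entry is present, the CPU build also stores a larger edge list at cutoff+skin (`edge_src_skin`/`edge_dst_skin`) that `update_skin` can later re-filter. A minimal sketch, assuming `inputs` has already been passed through `check_input`; the 0.5 skin is an arbitrary example value.

gen = GraphGenerator(cutoff=5.0)
state = {**gen.init(), "nblist_skin": 0.5}   # auxiliary list built at 5.5

state, outputs = gen(state, inputs)          # numpy path, dynamic shapes
graph = outputs["graph"]
# graph["edge_src"], graph["edge_dst"]: pairs within 5.0 (padded, symmetrized)
# graph["edge_src_skin"], graph["edge_dst_skin"]: pairs within 5.5, kept for update_skin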
GraphGenerator.check_reallocate(state, inputs, parent_overflow=False)
check for overflow and reallocate nblist if necessary
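As far as one can tell from the code, the intended division of labor is that the jitted `process` merely flags an overflow while the host calls `check_reallocate` to re-run the numpy builder with enlarged buffers. A hedged sketch of that host-side pattern:

outputs = gen.process(state, inputs)                     # jitted, fixed shapes
if outputs[gen.graph_key]["overflow"]:
    state, state_up, outputs, reallocated = gen.check_reallocate(state, outputs)
    # state_up records (new, old) buffer sizes; the next process() call
    # retraces with the enlarged shapes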
GraphGenerator.process(state, inputs)
build a nblist on accelerator with jax and precomputed shapes
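Both `self` and `state` are static arguments of the jit, so the buffer sizes stored in the state (`npairs`, `max_nat`, `num_repeats_pbc`, ...) are compile-time constants and each distinct state compiles its own executable. A sketch of the intended two-phase usage, assuming `inputs_cpu` is a dict that already went through `check_input`:

state, outputs = gen(dict(gen.init()), inputs_cpu)         # numpy pass sizes the buffers
outputs = gen.process(state, convert_to_jax(inputs_cpu))   # fixed-shape, jitted pass
# further calls reusing the same state hit the already-compiled executable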
GraphGenerator.update_skin(inputs)
update the nblist without recomputing the full nblist
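This is what makes a Verlet-style skin list pay off in dynamics: between full rebuilds, only the stored skin list is re-filtered against the current coordinates. A sketch under the usual assumption that no atom moves by more than half the skin between rebuilds; `integrate` and `rebuild_every` are hypothetical placeholders.

for step in range(nsteps):
    inputs = integrate(inputs)                 # hypothetical MD step
    if step % rebuild_every == 0:
        inputs = gen.process(state, inputs)    # full fixed-shape rebuild
    else:
        inputs = gen.update_skin(inputs)       # cheap re-filter of the skin list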
class GraphProcessor(nn.Module)
Process a pre-generated graph
The pre-generated graph should contain the following keys:
- edge_src: source indices of the edges
- edge_dst: destination indices of the edges
- pbc_shifts: pbc shifts for the edges (only if `cells` are present in the inputs)
This module is automatically added to a FENNIX model when a GraphGenerator is used.
switch_params: Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`.
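Downstream modules read the processed graph through `vec`, `distances` and `switch`; padded edges carry out-of-range indices and are flagged by `edge_mask`. A minimal sketch that accumulates an arbitrary smooth per-edge quantity onto atoms (the exponential kernel is purely illustrative):

graph = outputs["graph"]                     # as filled in by GraphProcessor
nat = outputs["coordinates"].shape[0]
contrib = jnp.exp(-graph["distances"]) * graph["switch"]
contrib = jnp.where(graph["edge_mask"], contrib, 0.0)
# padded edge indices equal nat and are dropped by the scatter
per_atom = jnp.zeros(nat).at[graph["edge_src"]].add(contrib, mode="drop")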
class GraphFilter
Filter a graph based on a cutoff distance
FPID: GRAPH_FILTER
switch_params: Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`.
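Because a filter only re-indexes its parent's pair list, a short-range graph can be derived from a long-range one without a second neighbor search. A minimal sketch chaining the two CPU builders by hand; the cutoffs are arbitrary examples and `inputs_cpu` is assumed to contain at least `species` and `coordinates` plus the keys added by `check_input`:

gen = GraphGenerator(cutoff=8.0, graph_key="graph_lr")
filt = GraphFilter(cutoff=3.0, parent_graph="graph_lr", graph_key="graph_sr")

state_g, outputs = gen(dict(gen.init()), inputs_cpu)
state_f, outputs = filt(dict(filt.init()), outputs)
# outputs["graph_sr"]["filter_indices"] maps each kept edge back into graph_lr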
GraphFilter.check_reallocate(state, inputs, parent_overflow=False)
check for overflow and reallocate nblist if necessary
GraphFilter.process(state, inputs)
filter a nblist on accelerator with jax and precomputed shapes
class GraphFilterProcessor(nn.Module)
@dataclasses.dataclass(frozen=True)
class GraphAngularExtension:
    """Add angles list to a graph

    FPID: GRAPH_ANGULAR_EXTENSION
    """

    mult_size: float = 1.05
    """Multiplicative factor for resizing the nblist."""
    add_neigh: int = 5
    """Additional neighbors to add to the nblist when resizing."""
    graph_key: str = "graph"
    """Key of the graph in the inputs."""

    FPID: ClassVar[str] = "GRAPH_ANGULAR_EXTENSION"

    def init(self):
        return FrozenDict(
            {
                "nangles": 0,
                "nblist_mult_size": self.mult_size,
                "max_neigh": self.add_neigh,
                "add_neigh": self.add_neigh,
            }
        )

    def get_processor(self) -> Tuple[nn.Module, Dict]:
        return GraphAngleProcessor, {
            "graph_key": self.graph_key,
            "name": f"{self.graph_key}_AngleProcessor",
        }

    def get_graph_properties(self):
        return {
            self.graph_key: {
                "has_angles": True,
            }
        }

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """build angle nblist on cpu with numpy and dynamic shapes + store max shapes"""
        graph = inputs[self.graph_key]
        edge_src = np.array(graph["edge_src"], dtype=np.int32)

        new_state = {**state}
        state_up = {}
        mult_size = state.get("nblist_mult_size", self.mult_size)
        assert mult_size >= 1.0, "nblist_mult_size should be >= 1."

        ### count number of neighbors
        nat = inputs["species"].shape[0]
        count = np.zeros(nat + 1, dtype=np.int32)
        np.add.at(count, edge_src, 1)
        max_count = int(np.max(count[:-1]))

        ### get sizes
        max_neigh = state.get("max_neigh", self.add_neigh)
        nedge = edge_src.shape[0]
        if max_count > max_neigh or add_margin:
            prev_max_neigh = max_neigh
            max_neigh = max(max_count, max_neigh) + state.get(
                "add_neigh", self.add_neigh
            )
            state_up["max_neigh"] = (max_neigh, prev_max_neigh)
            new_state["max_neigh"] = max_neigh

        # dummy array whose static shape carries max_neigh through jit
        max_neigh_arr = np.empty(max_neigh, dtype=bool)

        ### sort edge_src
        idx_sort = np.argsort(edge_src)
        edge_src_sorted = edge_src[idx_sort]

        ### map sparse to dense nblist
        if max_count * nat >= nedge:
            offset = np.tile(np.arange(max_count), nat)[:nedge]
        else:
            offset = np.zeros(nedge, dtype=np.int32)
            offset[: max_count * nat] = np.tile(np.arange(max_count), nat)

        mask = edge_src_sorted < nat
        indices = edge_src_sorted * max_count + offset
        indices = indices[mask]
        idx_sort = idx_sort[mask]
        edge_idx = np.full(nat * max_count, nedge, dtype=np.int32)
        edge_idx[indices] = idx_sort
        edge_idx = edge_idx.reshape(nat, max_count)

        ### find all triplets for each atom center
        local_src, local_dst = np.triu_indices(max_count, 1)
        angle_src = edge_idx[:, local_src].flatten()
        angle_dst = edge_idx[:, local_dst].flatten()

        ### mask for valid angles
        mask1 = angle_src < nedge
        mask2 = angle_dst < nedge
        angle_mask = mask1 & mask2

        max_angles = state.get("nangles", 0)
        idx = np.nonzero(angle_mask)[0]
        nangles = idx.shape[0]
        if nangles > max_angles or add_margin:
            max_angles_prev = max_angles
            max_angles = int(mult_size * max(nangles, max_angles)) + 1
            state_up["nangles"] = (max_angles, max_angles_prev)
            new_state["nangles"] = max_angles

        ## filter angles to sparse representation
        angle_src_ = np.full(max_angles, nedge, dtype=np.int32)
        angle_dst_ = np.full(max_angles, nedge, dtype=np.int32)
        angle_src_[:nangles] = angle_src[idx]
        angle_dst_[:nangles] = angle_dst[idx]

        central_atom = np.full(max_angles, nat, dtype=np.int32)
        central_atom[:nangles] = edge_src[angle_src_[:nangles]]

        ## update graph
        output = {
            **inputs,
            self.graph_key: {
                **graph,
                "angle_src": angle_src_,
                "angle_dst": angle_dst_,
                "central_atom": central_atom,
                "angle_overflow": False,
                "max_neigh": max_neigh,
                "__max_neigh_array": max_neigh_arr,
            },
        }

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary"""
        overflow = parent_overflow or inputs[self.graph_key]["angle_overflow"]
        if not overflow:
            return state, {}, inputs, False

        add_margin = inputs[self.graph_key]["angle_overflow"]
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """build angle nblist on accelerator with jax and precomputed shapes"""
        graph = inputs[self.graph_key]
        edge_src = graph["edge_src"]

        ### count number of neighbors
        nat = inputs["species"].shape[0]
        count = jnp.zeros(nat, dtype=jnp.int32).at[edge_src].add(1, mode="drop")
        max_count = jnp.max(count)

        ### get sizes
        if state is None:
            max_neigh_arr = graph["__max_neigh_array"]
            max_neigh = max_neigh_arr.shape[0]
            prev_nangles = graph["angle_src"].shape[0]
        else:
            max_neigh = state.get("max_neigh", self.add_neigh)
            max_neigh_arr = jnp.empty(max_neigh, dtype=bool)
            prev_nangles = state.get("nangles", 0)

        nedge = edge_src.shape[0]

        ### sort edge_src
        idx_sort = jnp.argsort(edge_src).astype(jnp.int32)
        edge_src_sorted = edge_src[idx_sort]

        ### map sparse to dense nblist
        if max_neigh * nat < nedge:
            raise ValueError("Found max_neigh*nat < nedge. This should not happen.")
        offset = jnp.asarray(
            np.tile(np.arange(max_neigh), nat)[:nedge], dtype=jnp.int32
        )
        indices = edge_src_sorted * max_neigh + offset
        edge_idx = (
            jnp.full(nat * max_neigh, nedge, dtype=jnp.int32)
            .at[indices]
            .set(idx_sort, mode="drop")
            .reshape(nat, max_neigh)
        )

        ### find all triplets for each atom center
        local_src, local_dst = np.triu_indices(max_neigh, 1)
        angle_src = edge_idx[:, local_src].flatten()
        angle_dst = edge_idx[:, local_dst].flatten()

        ### mask for valid angles
        mask1 = angle_src < nedge
        mask2 = angle_dst < nedge
        angle_mask = mask1 & mask2

        ## filter angles to sparse representation
        (angle_src, angle_dst), _, nangles = mask_filter_1d(
            angle_mask,
            prev_nangles,
            (angle_src, nedge),
            (angle_dst, nedge),
        )
        ## find central atom
        central_atom = edge_src[angle_src]

        ## check for overflow
        angle_overflow = nangles > prev_nangles
        neigh_overflow = max_count > max_neigh
        overflow = graph.get("angle_overflow", False) | angle_overflow | neigh_overflow

        ## update graph
        output = {
            **inputs,
            self.graph_key: {
                **graph,
                "angle_src": angle_src,
                "angle_dst": angle_dst,
                "central_atom": central_atom,
                "angle_overflow": overflow,
                "__max_neigh_array": max_neigh_arr,
            },
        }

        return output

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        return self.process(None, inputs)
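Both the numpy and jax paths above use the same dense-table trick, which can be hard to read in vectorized form. A small numpy walk-through (illustrative sizes) shows how edges sorted by source atom are scattered into a (nat, max_neigh) table of edge indices (padded with `nedge`), and how np.triu_indices over its columns enumerates candidate angles around each center:

import numpy as np

nat, max_neigh = 3, 2
edge_src = np.array([0, 0, 1, 2], dtype=np.int32)   # already sorted by source
nedge = edge_src.shape[0]

# per-edge column offset; sorted sources guarantee collision-free slots
offset = np.tile(np.arange(max_neigh), nat)[:nedge]
edge_idx = np.full(nat * max_neigh, nedge, dtype=np.int32)
edge_idx[edge_src * max_neigh + offset] = np.arange(nedge)
edge_idx = edge_idx.reshape(nat, max_neigh)         # [[0, 1], [2, 4], [4, 3]]

i, j = np.triu_indices(max_neigh, 1)
angle_src = edge_idx[:, i].ravel()                  # [0, 2, 4]
angle_dst = edge_idx[:, j].ravel()                  # [1, 4, 3]
valid = (angle_src < nedge) & (angle_dst < nedge)   # only atom 0 forms an angle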
class GraphAngleProcessor(nn.Module):
    """Process a pre-generated graph to compute angles

    This module is automatically added to a FENNIX model when a GraphAngularExtension is used.
    """

    graph_key: str
    """Key of the graph in the inputs."""

    @nn.compact
    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
        graph = inputs[self.graph_key]
        distances = (
            graph["distances_raw"] if "distances_raw" in graph else graph["distances"]
        )
        vec = graph["vec"]
        angle_src = graph["angle_src"]
        angle_dst = graph["angle_dst"]

        dir = vec / jnp.clip(distances[:, None], min=1.0e-5)
        cos_angles = (
            dir.at[angle_src].get(mode="fill", fill_value=0.5)
            * dir.at[angle_dst].get(mode="fill", fill_value=0.5)
        ).sum(axis=-1)

        angles = jnp.arccos(0.95 * cos_angles)

        return {
            **inputs,
            self.graph_key: {
                **graph,
                "angles": angles,
            },
        }
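A tiny numerical illustration of the angle computation (illustrative vectors): unit bond vectors are gathered for each (angle_src, angle_dst) pair and their dot product gives cos(theta). The 0.95 factor keeps the argument of arccos strictly inside (-1, 1), so the gradient stays finite for collinear bonds, at the price of slightly compressed angles.

import jax.numpy as jnp

vec = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])  # two bonds from one center
distances = jnp.linalg.norm(vec, axis=-1)
dir = vec / jnp.clip(distances[:, None], min=1.0e-5)
cos_angle = (dir[0] * dir[1]).sum()                  # 0.0 for orthogonal bonds
angle = jnp.arccos(0.95 * cos_angle)                 # pi/2 here, since cos is 0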
@dataclasses.dataclass(frozen=True)
class SpeciesIndexer:
    """Build an index that splits atomic arrays by species.

    FPID: SPECIES_INDEXER

    If `species_order` is specified, the output will be a dense array with size (len(species_order), max_size) that can directly index atomic arrays.
    If `species_order` is None, the output will be a dictionary with species as keys and an index to filter atomic arrays for that species as values.
    """

    output_key: str = "species_index"
    """Key for the output dictionary."""
    species_order: Optional[str] = None
    """Comma separated list of species in the order they should be indexed."""
    add_atoms: int = 0
    """Additional atoms to add to the sizes."""
    add_atoms_margin: int = 10
    """Additional atoms to add to the sizes when adding margin."""

    FPID: ClassVar[str] = "SPECIES_INDEXER"

    def init(self):
        return FrozenDict(
            {
                "sizes": {},
            }
        )

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        species = np.array(inputs["species"], dtype=np.int32)
        nat = species.shape[0]
        set_species, counts = np.unique(species, return_counts=True)

        new_state = {**state}
        state_up = {}

        sizes = state.get("sizes", FrozenDict({}))
        new_sizes = {**sizes}
        up_sizes = False
        counts_dict = {}
        for s, c in zip(set_species, counts):
            if s <= 0:
                continue
            counts_dict[s] = c
            if c > sizes.get(s, 0):
                up_sizes = True
                add_atoms = state.get("add_atoms", self.add_atoms)
                if add_margin:
                    add_atoms += state.get("add_atoms_margin", self.add_atoms_margin)
                new_sizes[s] = c + add_atoms

        new_sizes = FrozenDict(new_sizes)
        if up_sizes:
            state_up["sizes"] = (new_sizes, sizes)
            new_state["sizes"] = new_sizes

        if self.species_order is not None:
            species_order = [el.strip() for el in self.species_order.split(",")]
            max_size_prev = state.get("max_size", 0)
            max_size = max(new_sizes.values())
            if max_size > max_size_prev:
                state_up["max_size"] = (max_size, max_size_prev)
                new_state["max_size"] = max_size
                max_size_prev = max_size

            species_index = np.full(
                (len(species_order), max_size), nat, dtype=np.int32
            )
            for i, el in enumerate(species_order):
                s = PERIODIC_TABLE_REV_IDX[el]
                if s in counts_dict:
                    species_index[i, : counts_dict[s]] = np.nonzero(species == s)[0]
        else:
            species_index = {
                PERIODIC_TABLE[s]: np.full(c, nat, dtype=np.int32)
                for s, c in new_sizes.items()
            }
            for s, c in zip(set_species, counts):
                if s <= 0:
                    continue
                species_index[PERIODIC_TABLE[s]][:c] = np.nonzero(species == s)[0]

        output = {
            **inputs,
            self.output_key: species_index,
            self.output_key + "_overflow": False,
        }

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary"""
        overflow = parent_overflow or inputs[self.output_key + "_overflow"]
        if not overflow:
            return state, {}, inputs, False

        add_margin = inputs[self.output_key + "_overflow"]
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        recompute_species_index = "recompute_species_index" in inputs.get("flags", {})
        if self.output_key in inputs and not recompute_species_index:
            return inputs

        if state is None:
            raise ValueError("Species Indexer state must be provided on accelerator.")

        species = inputs["species"]
        nat = species.shape[0]

        sizes = state["sizes"]

        if self.species_order is not None:
            species_order = [el.strip() for el in self.species_order.split(",")]
            max_size = state["max_size"]

            species_index = jnp.full(
                (len(species_order), max_size), nat, dtype=jnp.int32
            )
            for i, el in enumerate(species_order):
                s = PERIODIC_TABLE_REV_IDX[el]
                if s in sizes:
                    species_index = species_index.at[i, :].set(
                        jnp.nonzero(species == s, size=max_size, fill_value=nat)[0]
                    )
            # NOTE: the dense path does not detect overflow on accelerator;
            # this is defined here to fix an undefined-name error in the
            # original source (the dict path below does check sizes).
            overflow = False
        else:
            species_index = {}
            overflow = False
            natcount = 0
            for s, c in sizes.items():
                mask = species == s
                new_size = jnp.sum(mask)
                natcount = natcount + new_size
                overflow = overflow | (new_size > c)  # check if sizes are correct
                species_index[PERIODIC_TABLE[s]] = jnp.nonzero(
                    species == s, size=c, fill_value=nat
                )[0]

            mask = species <= 0
            new_size = jnp.sum(mask)
            natcount = natcount + new_size
            overflow = overflow | (
                natcount < species.shape[0]
            )  # check if any species missing

        return {
            **inputs,
            self.output_key: species_index,
            self.output_key + "_overflow": overflow,
        }

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        return self.process(None, inputs)
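A hedged usage sketch (illustrative values): with `species_order` set, the indexer returns a dense (len(species_order), max_size) array whose padding slots point at `nat`, which plays well with padded atomic arrays and `.at[...].get(mode="fill")` downstream.

import numpy as np

indexer = SpeciesIndexer(species_order="O, H")
state = indexer.init()
inputs = {"species": np.array([8, 1, 1], dtype=np.int32)}  # one water molecule
state, outputs = indexer(state, inputs)
outputs["species_index"]
# -> array([[0, 3],
#           [1, 2]], dtype=int32)   row 0: O, row 1: H; 3 == nat marks empty slots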
@dataclasses.dataclass(frozen=True)
class BlockIndexer:
    """Build an index that splits atomic arrays by chemical blocks.

    FPID: BLOCK_INDEXER

    The output is a dictionary with block names as keys and, for each block, an index that filters atomic arrays for the atoms in that block.
    """

    output_key: str = "block_index"
    """Key for the output dictionary."""
    add_atoms: int = 0
    """Additional atoms to add to the sizes."""
    add_atoms_margin: int = 10
    """Additional atoms to add to the sizes when adding margin."""
    split_CNOPSSe: bool = False
    """Split C, N, O, P, S and Se into their own individual blocks."""

    FPID: ClassVar[str] = "BLOCK_INDEXER"

    def init(self):
        return FrozenDict(
            {
                "sizes": {},
            }
        )

    def build_chemical_blocks(self):
        _CHEMICAL_BLOCKS_NAMES = CHEMICAL_BLOCKS_NAMES.copy()
        if self.split_CNOPSSe:
            _CHEMICAL_BLOCKS_NAMES[1] = "C"
            _CHEMICAL_BLOCKS_NAMES.extend(["N", "O", "P", "S", "Se"])
        _CHEMICAL_BLOCKS = CHEMICAL_BLOCKS.copy()
        if self.split_CNOPSSe:
            _CHEMICAL_BLOCKS[6] = 1
            _CHEMICAL_BLOCKS[7] = len(CHEMICAL_BLOCKS_NAMES)
            _CHEMICAL_BLOCKS[8] = len(CHEMICAL_BLOCKS_NAMES) + 1
            _CHEMICAL_BLOCKS[15] = len(CHEMICAL_BLOCKS_NAMES) + 2
            _CHEMICAL_BLOCKS[16] = len(CHEMICAL_BLOCKS_NAMES) + 3
            _CHEMICAL_BLOCKS[34] = len(CHEMICAL_BLOCKS_NAMES) + 4
        return _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS = self.build_chemical_blocks()

        species = np.array(inputs["species"], dtype=np.int32)
        blocks = _CHEMICAL_BLOCKS[species]
        nat = species.shape[0]
        set_blocks, counts = np.unique(blocks, return_counts=True)

        new_state = {**state}
        state_up = {}

        sizes = state.get("sizes", FrozenDict({}))
        new_sizes = {**sizes}
        up_sizes = False
        for s, c in zip(set_blocks, counts):
            if s < 0:
                continue
            key = (s, _CHEMICAL_BLOCKS_NAMES[s])
            if c > sizes.get(key, 0):
                up_sizes = True
                add_atoms = state.get("add_atoms", self.add_atoms)
                if add_margin:
                    add_atoms += state.get("add_atoms_margin", self.add_atoms_margin)
                new_sizes[key] = c + add_atoms

        new_sizes = FrozenDict(new_sizes)
        if up_sizes:
            state_up["sizes"] = (new_sizes, sizes)
            new_state["sizes"] = new_sizes

        block_index = {n: None for n in _CHEMICAL_BLOCKS_NAMES}
        for (_, n), c in new_sizes.items():
            block_index[n] = np.full(c, nat, dtype=np.int32)
        for s, c in zip(set_blocks, counts):
            if s < 0:
                continue
            block_index[_CHEMICAL_BLOCKS_NAMES[s]][:c] = np.nonzero(blocks == s)[0]

        output = {
            **inputs,
            self.output_key: block_index,
            self.output_key + "_overflow": False,
        }

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary"""
        overflow = parent_overflow or inputs[self.output_key + "_overflow"]
        if not overflow:
            return state, {}, inputs, False

        add_margin = inputs[self.output_key + "_overflow"]
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS = self.build_chemical_blocks()

        recompute_species_index = "recompute_species_index" in inputs.get("flags", {})
        if self.output_key in inputs and not recompute_species_index:
            return inputs

        if state is None:
            raise ValueError("Block Indexer state must be provided on accelerator.")

        species = inputs["species"]
        blocks = jnp.asarray(_CHEMICAL_BLOCKS)[species]
        nat = species.shape[0]

        sizes = state["sizes"]

        block_index = {n: None for n in _CHEMICAL_BLOCKS_NAMES}
        overflow = False
        natcount = 0
        for (s, name), c in sizes.items():
            mask = blocks == s
            new_size = jnp.sum(mask)
            natcount = natcount + new_size
            overflow = overflow | (new_size > c)  # check if sizes are correct
            block_index[name] = jnp.nonzero(mask, size=c, fill_value=nat)[0]

        mask = blocks < 0
        new_size = jnp.sum(mask)
        natcount = natcount + new_size
        overflow = overflow | (
            natcount < species.shape[0]
        )  # check if any species missing

        return {
            **inputs,
            self.output_key: block_index,
            self.output_key + "_overflow": overflow,
        }

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        return self.process(None, inputs)
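Usage mirrors SpeciesIndexer, except the output is always a dictionary keyed by block name. A hedged sketch (illustrative; actual block names depend on CHEMICAL_BLOCKS_NAMES from the periodic-table utilities):

import numpy as np

indexer = BlockIndexer()
state = indexer.init()
inputs = {"species": np.array([6, 1, 8], dtype=np.int32)}
state, outputs = indexer(state, inputs)
# outputs["block_index"] maps each block name to a fixed-size index array
# (padded with nat); blocks absent from the system map to None.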
@dataclasses.dataclass(frozen=True)
class AtomPadding:
    """Pad atomic arrays to a fixed size."""

    mult_size: float = 1.2
    """Multiplicative factor for resizing the atomic arrays."""
    add_sys: int = 0
    """Additional systems to add when resizing."""

    def init(self):
        return {"prev_nat": 0, "prev_nsys": 0}

    def __call__(self, state, inputs: Dict) -> Tuple[FrozenDict, Dict]:
        species = inputs["species"]
        nat = species.shape[0]

        prev_nat = state.get("prev_nat", 0)
        prev_nat_ = prev_nat
        if nat > prev_nat_:
            prev_nat_ = int(self.mult_size * nat) + 1

        nsys = len(inputs["natoms"])
        prev_nsys = state.get("prev_nsys", 0)
        prev_nsys_ = prev_nsys
        if nsys > prev_nsys_:
            prev_nsys_ = nsys + self.add_sys

        add_atoms = prev_nat_ - nat
        add_sys = prev_nsys_ - nsys + 1  # +1 ghost system collects padded atoms
        output = {**inputs}
        if add_atoms > 0:
            for k, v in inputs.items():
                if isinstance(v, np.ndarray) or isinstance(v, jax.Array):
                    if v.shape[0] == nat:
                        output[k] = np.append(
                            v,
                            np.zeros((add_atoms, *v.shape[1:]), dtype=v.dtype),
                            axis=0,
                        )
                    elif v.shape[0] == nsys:
                        if k == "cells":
                            output[k] = np.append(
                                v,
                                1000
                                * np.eye(3, dtype=v.dtype)[None, :, :].repeat(
                                    add_sys, axis=0
                                ),
                                axis=0,
                            )
                        else:
                            output[k] = np.append(
                                v,
                                np.zeros((add_sys, *v.shape[1:]), dtype=v.dtype),
                                axis=0,
                            )
            output["natoms"] = np.append(
                inputs["natoms"], np.zeros(add_sys, dtype=np.int32)
            )
            output["species"] = np.append(
                species, -1 * np.ones(add_atoms, dtype=species.dtype)
            )
            output["batch_index"] = np.append(
                inputs["batch_index"],
                np.array(
                    [output["natoms"].shape[0] - 1] * add_atoms,
                    dtype=inputs["batch_index"].dtype,
                ),
            )
            if "system_index" in inputs:
                output["system_index"] = np.append(
                    inputs["system_index"],
                    np.array(
                        [output["natoms"].shape[0] - 1] * add_sys,
                        dtype=inputs["system_index"].dtype,
                    ),
                )

        output["true_atoms"] = output["species"] > 0
        output["true_sys"] = np.arange(len(output["natoms"])) < nsys

        state = {**state, "prev_nat": prev_nat_, "prev_nsys": prev_nsys_}

        return FrozenDict(state), output
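A hedged sketch of the padding behaviour (illustrative values): padded atoms get species -1 and are routed to an extra ghost system appended to the batch, so they can be masked out later via `true_atoms`/`true_sys`.

import numpy as np

padder = AtomPadding(mult_size=1.2)
state = padder.init()
inputs = {
    "species": np.array([8, 1, 1], dtype=np.int32),
    "coordinates": np.zeros((3, 3), dtype=np.float32),
    "natoms": np.array([3], dtype=np.int32),
    "batch_index": np.zeros(3, dtype=np.int32),
}
state, padded = padder(state, inputs)
padded["species"]      # array([ 8,  1,  1, -1]) -> one padding atom added
padded["true_atoms"]   # array([ True,  True,  True, False])
padded["batch_index"]  # array([0, 0, 0, 1])     -> padding atom in ghost system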
def atom_unpadding(inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Remove padding from atomic arrays."""
    if "true_atoms" not in inputs:
        return inputs

    species = np.asarray(inputs["species"])
    true_atoms = np.asarray(inputs["true_atoms"])
    true_sys = np.asarray(inputs["true_sys"])
    natall = species.shape[0]
    nat = np.argmax(species <= 0)
    if nat == 0:
        # no padding atom found
        return inputs

    natoms = inputs["natoms"]
    nsysall = len(natoms)

    output = {**inputs}
    for k, v in inputs.items():
        if isinstance(v, jax.Array) or isinstance(v, np.ndarray):
            v = np.asarray(v)
            if v.ndim == 0:
                output[k] = v
            elif v.shape[0] == natall:
                output[k] = v[true_atoms]
            elif v.shape[0] == nsysall:
                output[k] = v[true_sys]
    del output["true_sys"]
    del output["true_atoms"]
    return output
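Continuing the AtomPadding sketch above, unpadding drops everything flagged as padding and recovers the original arrays:

restored = atom_unpadding(padded)
restored["species"]    # array([8, 1, 1]) -> back to the unpadded system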
def check_input(inputs):
    """Check the input dictionary for required keys and types."""
    assert "species" in inputs, "species must be provided"
    assert "coordinates" in inputs, "coordinates must be provided"
    species = inputs["species"].astype(np.int32)
    ifake = np.argmax(species <= 0)
    if ifake > 0:
        assert np.all(species[:ifake] > 0), "species must be positive"
    nat = inputs["species"].shape[0]

    natoms = inputs.get("natoms", np.array([nat], dtype=np.int32)).astype(np.int32)
    batch_index = inputs.get(
        "batch_index", np.repeat(np.arange(len(natoms), dtype=np.int32), natoms)
    ).astype(np.int32)
    output = {**inputs, "natoms": natoms, "batch_index": batch_index}
    if "cells" in inputs:
        cells = inputs["cells"]
        if "reciprocal_cells" not in inputs:
            reciprocal_cells = np.linalg.inv(cells)
        else:
            reciprocal_cells = inputs["reciprocal_cells"]
        if cells.ndim == 2:
            cells = cells[None, :, :]
        if reciprocal_cells.ndim == 2:
            reciprocal_cells = reciprocal_cells[None, :, :]
        output["cells"] = cells
        output["reciprocal_cells"] = reciprocal_cells

    return output
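A minimal sketch (illustrative values): for a single system, `natoms` and `batch_index` may be omitted and are filled in, and 2D cells are promoted to a batch of one.

import numpy as np

inputs = check_input({
    "species": np.array([8, 1, 1], dtype=np.int32),
    "coordinates": np.zeros((3, 3), dtype=np.float32),
    "cells": 10.0 * np.eye(3),
})
inputs["natoms"]        # array([3], dtype=int32)
inputs["batch_index"]   # array([0, 0, 0], dtype=int32)
inputs["cells"].shape   # (1, 3, 3), reciprocal_cells computed automatically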
def convert_to_jax(data):
    """Convert numpy arrays to jax arrays in a pytree."""

    def convert(x):
        if isinstance(x, np.ndarray):
            return jnp.asarray(x)
        return x

    return jax.tree_util.tree_map(convert, data)
class JaxConverter(nn.Module):
    """Convert numpy arrays to jax arrays in a pytree."""

    def __call__(self, data):
        return convert_to_jax(data)
@dataclasses.dataclass(frozen=True)
class PreprocessingChain:
    """Chain of preprocessing layers."""

    layers: Tuple[Callable[..., Dict[str, Any]], ...]
    """Preprocessing layers."""
    use_atom_padding: bool = False
    """Add an AtomPadding layer at the beginning of the chain."""
    atom_padder: AtomPadding = AtomPadding()
    """AtomPadding layer."""

    def __post_init__(self):
        if not isinstance(self.layers, Sequence):
            raise ValueError(
                f"'layers' must be a sequence, got '{type(self.layers).__name__}'."
            )
        if not self.layers:
            raise ValueError("no Preprocessing layers were provided.")

    def __call__(self, state, inputs: Dict[str, Any]) -> Tuple[FrozenDict, Dict[str, Any]]:
        do_check_input = state.get("check_input", True)
        if do_check_input:
            inputs = check_input(inputs)
        new_state = {**state}
        if self.use_atom_padding:
            s, inputs = self.atom_padder(state["padder_state"], inputs)
            new_state["padder_state"] = s
        layer_state = state["layers_state"]
        new_layer_state = []
        for i, layer in enumerate(self.layers):
            s, inputs = layer(layer_state[i], inputs, return_state_update=False)
            new_layer_state.append(s)
        new_state["layers_state"] = tuple(new_layer_state)
        return FrozenDict(new_state), convert_to_jax(inputs)

    def check_reallocate(self, state, inputs):
        new_state = []
        state_up = []
        layer_state = state["layers_state"]
        parent_overflow = False
        for i, layer in enumerate(self.layers):
            s, s_up, inputs, parent_overflow = layer.check_reallocate(
                layer_state[i], inputs, parent_overflow
            )
            new_state.append(s)
            state_up.append(s_up)

        if not parent_overflow:
            return state, {}, inputs, False
        return (
            FrozenDict({**state, "layers_state": tuple(new_state)}),
            state_up,
            inputs,
            True,
        )

    def atom_padding(self, state, inputs):
        if self.use_atom_padding:
            padder_state, inputs = self.atom_padder(state["padder_state"], inputs)
            return FrozenDict({**state, "padder_state": padder_state}), inputs
        return state, inputs

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        layer_state = state["layers_state"]
        for i, layer in enumerate(self.layers):
            inputs = layer.process(layer_state[i], inputs)
        return inputs

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        for layer in self.layers:
            inputs = layer.update_skin(inputs)
        return inputs

    def init(self):
        state = {"check_input": True}
        if self.use_atom_padding:
            state["padder_state"] = self.atom_padder.init()
        layer_state = []
        for layer in self.layers:
            layer_state.append(layer.init())
        state["layers_state"] = tuple(layer_state)
        return FrozenDict(state)

    def init_with_output(self, inputs):
        state = self.init()
        return self(state, inputs)

    def get_processors(self):
        processors = []
        for layer in self.layers:
            if hasattr(layer, "get_processor"):
                processors.append(layer.get_processor())
        return processors

    def get_graphs_properties(self):
        properties = {}
        for layer in self.layers:
            if hasattr(layer, "get_graph_properties"):
                properties = deep_update(properties, layer.get_graph_properties())
        return properties
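A hedged end-to-end sketch (illustrative keys and values, assuming each layer exposes the eager numpy `__call__` and jitted `process` paths shown in this module): the chain is run once eagerly to allocate shapes and pad arrays, after which the fixed-shape jitted path can be reused.

import numpy as np

chain = PreprocessingChain(
    layers=(GraphGenerator(cutoff=5.0),),
    use_atom_padding=True,
)
state = chain.init()
inputs = {
    "species": np.array([8, 1, 1], dtype=np.int32),
    "coordinates": np.random.randn(3, 3).astype(np.float32),
}
state, outputs = chain(state, inputs)    # numpy path: sizes the nblist, pads
outputs = chain.process(state, outputs)  # jitted path: fixed shapes, accelerator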