fennol.models.preprocessing
import flax.linen as nn
from typing import Sequence, Callable, Union, Dict, Any, ClassVar
import jax.numpy as jnp
import jax
import numpy as np
from typing import Optional, Tuple
import numba
import dataclasses
from functools import partial

from flax.core.frozen_dict import FrozenDict


from ..utils.activations import chain
from ..utils import deep_update, mask_filter_1d
from ..utils.kspace import get_reciprocal_space_parameters
from .misc.misc import SwitchFunction
from ..utils.periodic_table import (
    PERIODIC_TABLE,
    PERIODIC_TABLE_REV_IDX,
    CHEMICAL_BLOCKS,
    CHEMICAL_BLOCKS_NAMES,
)


@dataclasses.dataclass(frozen=True)
class GraphGenerator:
    """Generate a graph from a set of coordinates

    FPID: GRAPH

    For now, we generate all pairs of atoms and filter based on cutoff.
    If a `nblist_skin` is present in the state, we generate a second graph with a larger cutoff that includes all pairs within the cutoff+skin. This graph is then reused by the `update_skin` method to update the original graph without recomputing the full nblist.
    """

    cutoff: float
    """Cutoff distance for the graph."""
    graph_key: str = "graph"
    """Key of the graph in the outputs."""
    switch_params: dict = dataclasses.field(default_factory=dict, hash=False)
    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""
    kmax: int = 30
    """Maximum number of k-points to consider."""
    kthr: float = 1e-6
    """Threshold for k-point filtering."""
    k_space: bool = False
    """Whether to generate k-space information for the graph."""
    mult_size: float = 1.05
    """Multiplicative factor for resizing the nblist."""
    # covalent_cutoff: bool = False

    FPID: ClassVar[str] = "GRAPH"

    def init(self):
        return FrozenDict(
            {
                "max_nat": 1,
                "npairs": 1,
                "nblist_mult_size": self.mult_size,
            }
        )

    def get_processor(self) -> Tuple[nn.Module, Dict]:
        return GraphProcessor, {
            "cutoff": self.cutoff,
            "graph_key": self.graph_key,
            "switch_params": self.switch_params,
            "name": f"{self.graph_key}_Processor",
        }

    def get_graph_properties(self):
        return {
            self.graph_key: {
                "cutoff": self.cutoff,
                "directed": True,
            }
        }

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """build a nblist on cpu with numpy and dynamic shapes + store max shapes"""
        if self.graph_key in inputs:
            graph = inputs[self.graph_key]
            if "keep_graph" in graph:
                return state, inputs

        coords = np.array(inputs["coordinates"], dtype=np.float32)
        natoms = np.array(inputs["natoms"], dtype=np.int32)
        batch_index = np.array(inputs["batch_index"], dtype=np.int32)

        new_state = {**state}
        state_up = {}

        mult_size = state.get("nblist_mult_size", self.mult_size)
        assert mult_size >= 1.0, "mult_size should be greater than or equal to 1.0"

        if natoms.shape[0] == 1:
            max_nat = coords.shape[0]
            true_max_nat = max_nat
        else:
            max_nat = state.get("max_nat", round(coords.shape[0] / natoms.shape[0]))
            true_max_nat = int(np.max(natoms))
            if true_max_nat > max_nat:
                add_atoms = state.get("add_atoms", 0)
                new_maxnat = true_max_nat + add_atoms
                state_up["max_nat"] = (new_maxnat, max_nat)
                new_state["max_nat"] = new_maxnat

        cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0)

        ### compute indices of all pairs
        p1, p2 = np.triu_indices(true_max_nat, 1)
        p1, p2 = p1.astype(np.int32), p2.astype(np.int32)
        pbc_shifts = None
        if natoms.shape[0] > 1:
            ## batching => mask irrelevant pairs
            mask_p12 = (
                (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None])
            ).flatten()
            shift = np.concatenate(
                (np.array([0], dtype=np.int32), np.cumsum(natoms[:-1], dtype=np.int32))
            )
            p1 = np.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1)
            p2 = np.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1)

        apply_pbc = "cells" in inputs
        if not apply_pbc:
            ### NO PBC
            vec = coords[p2] - coords[p1]
        else:
            cells = np.array(inputs["cells"], dtype=np.float32)
            reciprocal_cells = np.array(inputs["reciprocal_cells"], dtype=np.float32)
            minimage = state.get("minimum_image", True)
            if minimage:
                ## MINIMUM IMAGE CONVENTION
                vec = coords[p2] - coords[p1]
                if cells.shape[0] == 1:
                    vecpbc = np.dot(vec, reciprocal_cells[0])
                    pbc_shifts = -np.round(vecpbc)
                    vec = vec + np.dot(pbc_shifts, cells[0])
                else:
                    batch_index_vec = batch_index[p1]
                    vecpbc = np.einsum(
                        "aj,aji->ai", vec, reciprocal_cells[batch_index_vec]
                    )
                    pbc_shifts = -np.round(vecpbc)
                    vec = vec + np.einsum(
                        "aj,aji->ai", pbc_shifts, cells[batch_index_vec]
                    )
            else:
                ### GENERAL PBC
                ## put all atoms in central box
                if cells.shape[0] == 1:
                    coords_pbc = np.dot(coords, reciprocal_cells[0])
                    at_shifts = -np.floor(coords_pbc)
                    coords_pbc = coords + np.dot(at_shifts, cells[0])
                else:
                    coords_pbc = np.einsum(
                        "aj,aji->ai", coords, reciprocal_cells[batch_index]
                    )
                    at_shifts = -np.floor(coords_pbc)
                    coords_pbc = coords + np.einsum(
                        "aj,aji->ai", at_shifts, cells[batch_index]
                    )
                vec = coords_pbc[p2] - coords_pbc[p1]

                ## compute maximum number of repeats
                inv_distances = (np.sum(reciprocal_cells**2, axis=1)) ** 0.5
                cdinv = cutoff_skin * inv_distances
                num_repeats_all = np.ceil(cdinv).astype(np.int32)
                if "true_sys" in inputs:
                    num_repeats_all = np.where(
                        np.array(inputs["true_sys"], dtype=bool)[:, None],
                        num_repeats_all,
                        0,
                    )
                # num_repeats_all = np.where(cdinv < 0.5, 0, num_repeats_all)
                num_repeats = np.max(num_repeats_all, axis=0)
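                # Illustrative note: each entry of num_repeats bounds how many
                # periodic images are needed along one lattice direction. The
                # column norms of the reciprocal cell computed above are the
                # inverse inter-face distances of the box, so e.g. a cubic box
                # of side L = 30 with cutoff_skin = 10 gives inv_distances of
                # 1/30 and num_repeats = ceil(10 / 30) = 1 image per direction.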
                num_repeats_prev = np.array(state.get("num_repeats_pbc", (0, 0, 0)))
                if np.any(num_repeats > num_repeats_prev):
                    num_repeats_new = np.maximum(num_repeats, num_repeats_prev)
                    state_up["num_repeats_pbc"] = (
                        tuple(num_repeats_new),
                        tuple(num_repeats_prev),
                    )
                    new_state["num_repeats_pbc"] = tuple(num_repeats_new)
                ## build all possible shifts
                cell_shift_pbc = np.array(
                    np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]),
                    dtype=cells.dtype,
                ).T.reshape(-1, 3)
                ## shift applied to vectors
                if cells.shape[0] == 1:
                    dvec = np.dot(cell_shift_pbc, cells[0])[None, :, :]
                    vec = (vec[:, None, :] + dvec).reshape(-1, 3)
                    pbc_shifts = np.broadcast_to(
                        cell_shift_pbc[None, :, :],
                        (p1.shape[0], cell_shift_pbc.shape[0], 3),
                    ).reshape(-1, 3)
                    p1 = np.broadcast_to(
                        p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
                    p2 = np.broadcast_to(
                        p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
                    if natoms.shape[0] > 1:
                        mask_p12 = np.broadcast_to(
                            mask_p12[:, None],
                            (mask_p12.shape[0], cell_shift_pbc.shape[0]),
                        ).flatten()
                else:
                    dvec = np.einsum("bj,sji->sbi", cell_shift_pbc, cells)

                    ## get pbc shifts specific to each box
                    cell_shift_pbc = np.broadcast_to(
                        cell_shift_pbc[None, :, :],
                        (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3),
                    )
                    mask = np.all(
                        np.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1
                    ).flatten()
                    idx = np.nonzero(mask)[0]
                    nshifts = idx.shape[0]
                    nshifts_prev = state.get("nshifts_pbc", 0)
                    if nshifts > nshifts_prev or add_margin:
                        nshifts_new = int(mult_size * max(nshifts, nshifts_prev)) + 1
                        state_up["nshifts_pbc"] = (nshifts_new, nshifts_prev)
                        new_state["nshifts_pbc"] = nshifts_new

                    dvec_filter = dvec.reshape(-1, 3)[idx, :]
                    cell_shift_pbc_filter = cell_shift_pbc.reshape(-1, 3)[idx, :]

                    ## get batch shift in the dvec_filter array
                    nrep = np.prod(2 * num_repeats_all + 1, axis=1)
                    bshift = np.concatenate((np.array([0]), np.cumsum(nrep)[:-1]))

                    ## compute vectors
                    batch_index_vec = batch_index[p1]
                    nrep_vec = np.where(mask_p12, nrep[batch_index_vec], 0)
                    vec = vec.repeat(nrep_vec, axis=0)
                    nvec_pbc = nrep_vec.sum()  # vec.shape[0]
                    nvec_pbc_prev = state.get("nvec_pbc", 0)
                    if nvec_pbc > nvec_pbc_prev or add_margin:
                        nvec_pbc_new = int(mult_size * max(nvec_pbc, nvec_pbc_prev)) + 1
                        state_up["nvec_pbc"] = (nvec_pbc_new, nvec_pbc_prev)
                        new_state["nvec_pbc"] = nvec_pbc_new

                    # print("cpu: ", nvec_pbc, nvec_pbc_prev, nshifts, nshifts_prev)
                    ## get shift index
                    dshift = np.concatenate(
                        (np.array([0]), np.cumsum(nrep_vec)[:-1])
                    ).repeat(nrep_vec)
                    # ishift = np.arange(dshift.shape[0])-dshift
                    # bshift_vec_rep = bshift[batch_index_vec].repeat(nrep_vec)
                    icellshift = (
                        np.arange(dshift.shape[0])
                        - dshift
                        + bshift[batch_index_vec].repeat(nrep_vec)
                    )
                    # shift vectors
                    vec = vec + dvec_filter[icellshift]
                    pbc_shifts = cell_shift_pbc_filter[icellshift]

                    p1 = np.repeat(p1, nrep_vec)
                    p2 = np.repeat(p2, nrep_vec)
                    if natoms.shape[0] > 1:
                        mask_p12 = np.repeat(mask_p12, nrep_vec)

        ## compute distances
        d12 = (vec**2).sum(axis=-1)
        if natoms.shape[0] > 1:
            d12 = np.where(mask_p12, d12, cutoff_skin**2)
        ## filter pairs
        max_pairs = state.get("npairs", 1)
        mask = d12 < cutoff_skin**2
        idx = np.nonzero(mask)[0]
        npairs = idx.shape[0]
        if npairs > max_pairs or add_margin:
            prev_max_pairs = max_pairs
            max_pairs = int(mult_size * max(npairs, max_pairs)) + 1
            state_up["npairs"] = (max_pairs, prev_max_pairs)
            new_state["npairs"] = max_pairs

        nat = coords.shape[0]
        edge_src = np.full(max_pairs, nat, dtype=np.int32)
        edge_dst = np.full(max_pairs, nat, dtype=np.int32)
        d12_ = np.full(max_pairs, cutoff_skin**2)
        edge_src[:npairs] = p1[idx]
        edge_dst[:npairs] = p2[idx]
        d12_[:npairs] = d12[idx]
        d12 = d12_

        if apply_pbc:
            pbc_shifts_ = np.zeros((max_pairs, 3))
            pbc_shifts_[:npairs] = pbc_shifts[idx]
            pbc_shifts = pbc_shifts_
            if not minimage:
                pbc_shifts[:npairs] = (
                    pbc_shifts[:npairs]
                    + at_shifts[edge_dst[:npairs]]
                    - at_shifts[edge_src[:npairs]]
                )

        if "nblist_skin" in state:
            edge_src_skin = edge_src
            edge_dst_skin = edge_dst
            if apply_pbc:
                pbc_shifts_skin = pbc_shifts
            max_pairs_skin = state.get("npairs_skin", 1)
            mask = d12 < self.cutoff**2
            idx = np.nonzero(mask)[0]
            npairs_skin = idx.shape[0]
            if npairs_skin > max_pairs_skin or add_margin:
                prev_max_pairs_skin = max_pairs_skin
                max_pairs_skin = int(mult_size * max(npairs_skin, max_pairs_skin)) + 1
                state_up["npairs_skin"] = (max_pairs_skin, prev_max_pairs_skin)
                new_state["npairs_skin"] = max_pairs_skin
            edge_src = np.full(max_pairs_skin, nat, dtype=np.int32)
            edge_dst = np.full(max_pairs_skin, nat, dtype=np.int32)
            d12_ = np.full(max_pairs_skin, self.cutoff**2)
            edge_src[:npairs_skin] = edge_src_skin[idx]
            edge_dst[:npairs_skin] = edge_dst_skin[idx]
            d12_[:npairs_skin] = d12[idx]
            d12 = d12_
            if apply_pbc:
                pbc_shifts = np.full((max_pairs_skin, 3), 0.0)
                pbc_shifts[:npairs_skin] = pbc_shifts_skin[idx]

        ## symmetrize
        edge_src, edge_dst = np.concatenate((edge_src, edge_dst)), np.concatenate(
            (edge_dst, edge_src)
        )
        d12 = np.concatenate((d12, d12))
        if apply_pbc:
            pbc_shifts = np.concatenate((pbc_shifts, -pbc_shifts))

        graph = inputs.get(self.graph_key, {})
        graph_out = {
            **graph,
            "edge_src": edge_src,
            "edge_dst": edge_dst,
            "d12": d12,
            "overflow": False,
            "pbc_shifts": pbc_shifts,
        }
        if "nblist_skin" in state:
            graph_out["edge_src_skin"] = edge_src_skin
            graph_out["edge_dst_skin"] = edge_dst_skin
            if apply_pbc:
                graph_out["pbc_shifts_skin"] = pbc_shifts_skin

        if self.k_space and apply_pbc:
            if "k_points" not in graph:
                ks, _, _, bewald = get_reciprocal_space_parameters(
                    reciprocal_cells, self.cutoff, self.kmax, self.kthr
                )
                graph_out["k_points"] = ks
                graph_out["b_ewald"] = bewald

        output = {**inputs, self.graph_key: graph_out}

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary"""
        overflow = parent_overflow or inputs[self.graph_key].get("overflow", False)
        if not overflow:
            return state, {}, inputs, False

        add_margin = inputs[self.graph_key].get("overflow", False)
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """build a nblist on accelerator with jax and precomputed shapes"""
        if self.graph_key in inputs:
            graph = inputs[self.graph_key]
            if "keep_graph" in graph:
                return inputs
        coords = inputs["coordinates"]
        natoms = inputs["natoms"]
        batch_index = inputs["batch_index"]

        if natoms.shape[0] == 1:
            max_nat = coords.shape[0]
        else:
            max_nat = state.get(
                "max_nat", int(round(coords.shape[0] / natoms.shape[0]))
            )
        cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0)

        ### compute indices of all pairs
        p1, p2 = np.triu_indices(max_nat, 1)
        p1, p2 = p1.astype(np.int32), p2.astype(np.int32)
        pbc_shifts = None
        if natoms.shape[0] > 1:
            ## batching => mask irrelevant pairs
            mask_p12 = (
                (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None])
            ).flatten()
            shift = jnp.concatenate(
                (jnp.array([0], dtype=jnp.int32), jnp.cumsum(natoms[:-1]))
            )
            p1 = jnp.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1)
            p2 = jnp.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1)

        ## compute vectors
        overflow_repeats = jnp.asarray(False, dtype=bool)
        if "cells" not in inputs:
            vec = coords[p2] - coords[p1]
        else:
            cells = inputs["cells"]
            reciprocal_cells = inputs["reciprocal_cells"]
            minimage = state.get("minimum_image", True)
            def compute_pbc(vec, reciprocal_cell, cell, mode="round"):
                vecpbc = jnp.dot(vec, reciprocal_cell)
                if mode == "round":
                    pbc_shifts = -jnp.round(vecpbc)
                elif mode == "floor":
                    pbc_shifts = -jnp.floor(vecpbc)
                else:
                    raise NotImplementedError(f"Unknown mode {mode} for compute_pbc.")
                return vec + jnp.dot(pbc_shifts, cell), pbc_shifts

            if minimage:
                ## minimum image convention
                vec = coords[p2] - coords[p1]

                if cells.shape[0] == 1:
                    vec, pbc_shifts = compute_pbc(vec, reciprocal_cells[0], cells[0])
                else:
                    batch_index_vec = batch_index[p1]
                    vec, pbc_shifts = jax.vmap(compute_pbc)(
                        vec, reciprocal_cells[batch_index_vec], cells[batch_index_vec]
                    )
            else:
                ### general PBC only for single cell yet
                # if cells.shape[0] > 1:
                #     raise NotImplementedError(
                #         "General PBC not implemented for batches on accelerator."
                #     )
                # cell = cells[0]
                # reciprocal_cell = reciprocal_cells[0]

                ## put all atoms in central box
                if cells.shape[0] == 1:
                    coords_pbc, at_shifts = compute_pbc(
                        coords, reciprocal_cells[0], cells[0], mode="floor"
                    )
                else:
                    coords_pbc, at_shifts = jax.vmap(
                        partial(compute_pbc, mode="floor")
                    )(coords, reciprocal_cells[batch_index], cells[batch_index])
                vec = coords_pbc[p2] - coords_pbc[p1]
                num_repeats = state.get("num_repeats_pbc", (0, 0, 0))
                # if num_repeats is None:
                #     raise ValueError(
                #         "num_repeats_pbc should be provided for general PBC on accelerator. Call the numpy routine (self.__call__) first."
                #     )
                # check if num_repeats is larger than previous
                inv_distances = jnp.linalg.norm(reciprocal_cells, axis=1)
                cdinv = cutoff_skin * inv_distances
                num_repeats_all = jnp.ceil(cdinv).astype(jnp.int32)
                if "true_sys" in inputs:
                    num_repeats_all = jnp.where(
                        inputs["true_sys"][:, None], num_repeats_all, 0
                    )
                num_repeats_new = jnp.max(num_repeats_all, axis=0)
                overflow_repeats = jnp.any(num_repeats_new > jnp.asarray(num_repeats))

                cell_shift_pbc = jnp.asarray(
                    np.array(
                        np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]),
                        dtype=cells.dtype,
                    ).T.reshape(-1, 3)
                )

                if cells.shape[0] == 1:
                    vec = (
                        vec[:, None, :] + jnp.dot(cell_shift_pbc, cells[0])[None, :, :]
                    ).reshape(-1, 3)
                    pbc_shifts = jnp.broadcast_to(
                        cell_shift_pbc[None, :, :],
                        (p1.shape[0], cell_shift_pbc.shape[0], 3),
                    ).reshape(-1, 3)
                    p1 = jnp.broadcast_to(
                        p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
                    p2 = jnp.broadcast_to(
                        p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
                    if natoms.shape[0] > 1:
                        mask_p12 = jnp.broadcast_to(
                            mask_p12[:, None],
                            (mask_p12.shape[0], cell_shift_pbc.shape[0]),
                        ).flatten()
                else:
                    dvec = jnp.einsum("bj,sji->sbi", cell_shift_pbc, cells).reshape(-1, 3)

                    ## get pbc shifts specific to each box
                    cell_shift_pbc = jnp.broadcast_to(
                        cell_shift_pbc[None, :, :],
                        (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3),
                    )
                    mask = jnp.all(
                        jnp.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1
                    ).flatten()
                    max_shifts = state.get("nshifts_pbc", 1)

                    cell_shift_pbc = cell_shift_pbc.reshape(-1, 3)
                    shiftx, shifty, shiftz = (
                        cell_shift_pbc[:, 0],
                        cell_shift_pbc[:, 1],
                        cell_shift_pbc[:, 2],
                    )
                    dvecx, dvecy, dvecz = dvec[:, 0], dvec[:, 1], dvec[:, 2]
                    (
                        (dvecx, dvecy, dvecz, shiftx, shifty, shiftz),
                        scatter_idx,
                        nshifts,
                    ) = mask_filter_1d(
                        mask,
                        max_shifts,
                        (dvecx, 0.0),
                        (dvecy, 0.0),
                        (dvecz, 0.0),
                        (shiftx, 0),
                        (shifty, 0),
                        (shiftz, 0),
                    )
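                    # Illustrative note: mask_filter_1d (from ..utils) packs
                    # the entries where `mask` is True at the front of fixed-
                    # capacity buffers of length `max_shifts`, fills the tail
                    # with the provided default values (0.0 / 0 here) and
                    # returns the true count `nshifts`, so that overflow can
                    # be detected without introducing dynamic shapes under jit.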
                    dvec = jnp.stack((dvecx, dvecy, dvecz), axis=-1)
                    cell_shift_pbc = jnp.stack((shiftx, shifty, shiftz), axis=-1)
                    overflow_repeats = overflow_repeats | (nshifts > max_shifts)

                    ## get batch shift in the dvec_filter array
                    nrep = jnp.prod(2 * num_repeats_all + 1, axis=1)
                    bshift = jnp.concatenate(
                        (jnp.array([0], dtype=jnp.int32), jnp.cumsum(nrep)[:-1])
                    )

                    ## repeat vectors
                    nvec_max = state.get("nvec_pbc", 1)
                    batch_index_vec = batch_index[p1]
                    nrep_vec = jnp.where(mask_p12, nrep[batch_index_vec], 0)
                    nvec = nrep_vec.sum()
                    overflow_repeats = overflow_repeats | (nvec > nvec_max)
                    vec = jnp.repeat(vec, nrep_vec, axis=0, total_repeat_length=nvec_max)
                    # jax.debug.print("{nvec} {nvec_max} {nshifts} {max_shifts}", nvec=nvec, nvec_max=jnp.asarray(nvec_max), nshifts=nshifts, max_shifts=jnp.asarray(max_shifts))

                    ## get shift index
                    dshift = jnp.concatenate(
                        (jnp.array([0], dtype=jnp.int32), jnp.cumsum(nrep_vec)[:-1])
                    )
                    if nrep_vec.size == 0:
                        dshift = jnp.array([], dtype=jnp.int32)
                    dshift = jnp.repeat(dshift, nrep_vec, total_repeat_length=nvec_max)
                    bshift = jnp.repeat(
                        bshift[batch_index_vec], nrep_vec, total_repeat_length=nvec_max
                    )
                    icellshift = jnp.arange(dshift.shape[0]) - dshift + bshift
                    vec = vec + dvec[icellshift]
                    pbc_shifts = cell_shift_pbc[icellshift]
                    p1 = jnp.repeat(p1, nrep_vec, total_repeat_length=nvec_max)
                    p2 = jnp.repeat(p2, nrep_vec, total_repeat_length=nvec_max)
                    if natoms.shape[0] > 1:
                        mask_p12 = jnp.repeat(
                            mask_p12, nrep_vec, total_repeat_length=nvec_max
                        )

        ## compute distances
        d12 = (vec**2).sum(axis=-1)
        if natoms.shape[0] > 1:
            d12 = jnp.where(mask_p12, d12, cutoff_skin**2)

        ## filter pairs
        max_pairs = state.get("npairs", 1)
        mask = d12 < cutoff_skin**2
        (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d(
            mask,
            max_pairs,
            (jnp.asarray(p1, dtype=jnp.int32), coords.shape[0]),
            (jnp.asarray(p2, dtype=jnp.int32), coords.shape[0]),
            (d12, cutoff_skin**2),
        )
        if "cells" in inputs:
            pbc_shifts = (
                jnp.full((max_pairs, 3), 0.0, dtype=pbc_shifts.dtype)
                .at[scatter_idx]
                .set(pbc_shifts, mode="drop")
            )
            if not minimage:
                pbc_shifts = (
                    pbc_shifts
                    + at_shifts.at[edge_dst].get(fill_value=0.0)
                    - at_shifts.at[edge_src].get(fill_value=0.0)
                )

        ## check for overflow
        if natoms.shape[0] == 1:
            true_max_nat = coords.shape[0]
        else:
            true_max_nat = jnp.max(natoms)
        overflow_count = npairs > max_pairs
        overflow_at = true_max_nat > max_nat
        overflow = overflow_count | overflow_at | overflow_repeats

        if "nblist_skin" in state:
            # edge_mask_skin = edge_mask
            edge_src_skin = edge_src
            edge_dst_skin = edge_dst
            if "cells" in inputs:
                pbc_shifts_skin = pbc_shifts
            max_pairs_skin = state.get("npairs_skin", 1)
            mask = d12 < self.cutoff**2
            (edge_src, edge_dst, d12), scatter_idx, npairs_skin = mask_filter_1d(
                mask,
                max_pairs_skin,
                (edge_src, coords.shape[0]),
                (edge_dst, coords.shape[0]),
                (d12, self.cutoff**2),
            )
            if "cells" in inputs:
                pbc_shifts = (
                    jnp.full((max_pairs_skin, 3), 0.0, dtype=pbc_shifts.dtype)
                    .at[scatter_idx]
                    .set(pbc_shifts, mode="drop")
                )
            overflow = overflow | (npairs_skin > max_pairs_skin)

        ## symmetrize
        edge_src, edge_dst = jnp.concatenate((edge_src, edge_dst)), jnp.concatenate(
            (edge_dst, edge_src)
        )
        d12 = jnp.concatenate((d12, d12))
        if "cells" in inputs:
            pbc_shifts = jnp.concatenate((pbc_shifts, -pbc_shifts))

        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
        graph_out = {
            **graph,
            "edge_src": edge_src,
            "edge_dst": edge_dst,
            "d12": d12,
            "overflow": overflow,
            "pbc_shifts": pbc_shifts,
        }
        if "nblist_skin" in state:
            graph_out["edge_src_skin"] = edge_src_skin
            graph_out["edge_dst_skin"] = edge_dst_skin
            if "cells" in inputs:
                graph_out["pbc_shifts_skin"] = pbc_shifts_skin

        if self.k_space and "cells" in inputs:
            if "k_points" not in graph:
                raise NotImplementedError(
                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
                )
        return {**inputs, self.graph_key: graph_out}

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        """update the nblist without recomputing the full nblist"""
        graph = inputs[self.graph_key]

        edge_src_skin = graph["edge_src_skin"]
        edge_dst_skin = graph["edge_dst_skin"]
        coords = inputs["coordinates"]
        vec = coords.at[edge_dst_skin].get(
            mode="fill", fill_value=self.cutoff
        ) - coords.at[edge_src_skin].get(mode="fill", fill_value=0.0)

        if "cells" in inputs:
            pbc_shifts_skin = graph["pbc_shifts_skin"]
            cells = inputs["cells"]
            if cells.shape[0] == 1:
                vec = vec + jnp.dot(pbc_shifts_skin, cells[0])
            else:
                batch_index_vec = inputs["batch_index"][edge_src_skin]
                vec = vec + jax.vmap(jnp.dot)(pbc_shifts_skin, cells[batch_index_vec])

        nat = coords.shape[0]
        d12 = jnp.sum(vec**2, axis=-1)
        mask = d12 < self.cutoff**2
        max_pairs = graph["edge_src"].shape[0] // 2
        (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d(
            mask,
            max_pairs,
            (edge_src_skin, nat),
            (edge_dst_skin, nat),
            (d12, self.cutoff**2),
        )
        if "cells" in inputs:
            pbc_shifts = (
                jnp.full((max_pairs, 3), 0.0, dtype=pbc_shifts_skin.dtype)
                .at[scatter_idx]
                .set(pbc_shifts_skin)
            )

        overflow = graph.get("overflow", False) | (npairs > max_pairs)
        graph_out = {
            **graph,
            "edge_src": jnp.concatenate((edge_src, edge_dst)),
            "edge_dst": jnp.concatenate((edge_dst, edge_src)),
            "d12": jnp.concatenate((d12, d12)),
            "overflow": overflow,
        }
        if "cells" in inputs:
            graph_out["pbc_shifts"] = jnp.concatenate((pbc_shifts, -pbc_shifts))

        if self.k_space and "cells" in inputs:
            if "k_points" not in graph:
                raise NotImplementedError(
                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
                )

        return {**inputs, self.graph_key: graph_out}
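
# Example (illustrative sketch, not part of the public API): typical lifecycle
# of a GraphGenerator. The CPU `__call__` builds the nblist with dynamic shapes
# and records the required capacities in `state`; the jitted `process` then
# rebuilds it on the accelerator with those fixed shapes. It relies on the
# `check_input` and `convert_to_jax` helpers defined further down this module;
# the coordinates and 5.0 cutoff are arbitrary values for the sketch.


def _example_graph_generator():
    gen = GraphGenerator(cutoff=5.0)
    state = gen.init()
    inputs = check_input(
        {
            "species": np.array([8, 1, 1], dtype=np.int32),
            "coordinates": np.array(
                [[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]],
                dtype=np.float32,
            ),
        }
    )
    state, outputs = gen(state, inputs)  # CPU pass: allocates "npairs", etc.
    outputs = gen.process(state, convert_to_jax(inputs))  # jitted fixed-shape pass
    return outputs[gen.graph_key]["edge_src"], outputs[gen.graph_key]["d12"]
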

class GraphProcessor(nn.Module):
    """Process a pre-generated graph

    The pre-generated graph should contain the following keys:
    - edge_src: source indices of the edges
    - edge_dst: destination indices of the edges
    - pbc_shifts: pbc shifts for the edges (only if `cells` are present in the inputs)

    This module is automatically added to a FENNIX model when a GraphGenerator is used.

    """

    cutoff: float
    """Cutoff distance for the graph."""
    graph_key: str = "graph"
    """Key of the graph in the outputs."""
    switch_params: dict = dataclasses.field(default_factory=dict)
    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""

    @nn.compact
    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
        graph = inputs[self.graph_key]
        coords = inputs["coordinates"]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        # edge_mask = edge_src < coords.shape[0]
        vec = coords.at[edge_dst].get(mode="fill", fill_value=self.cutoff) - coords.at[
            edge_src
        ].get(mode="fill", fill_value=0.0)
        if "cells" in inputs:
            cells = inputs["cells"]
            if cells.shape[0] == 1:
                vec = vec + jnp.dot(graph["pbc_shifts"], cells[0])
            else:
                batch_index_vec = inputs["batch_index"][edge_src]
                vec = vec + jax.vmap(jnp.dot)(
                    graph["pbc_shifts"], cells[batch_index_vec]
                )

        distances = jnp.linalg.norm(vec, axis=-1)
        edge_mask = distances < self.cutoff

        switch = SwitchFunction(
            **{**self.switch_params, "cutoff": self.cutoff, "graph_key": None}
        )((distances, edge_mask))

        graph_out = {
            **graph,
            "vec": vec,
            "distances": distances,
            "switch": switch,
            "edge_mask": edge_mask,
        }

        if "alch_group" in inputs:
            alch_group = inputs["alch_group"]
            lambda_e = inputs["alch_elambda"]
            lambda_e = 0.5 * (1.0 - jnp.cos(jnp.pi * lambda_e))
            mask = alch_group[edge_src] == alch_group[edge_dst]
            graph_out["switch_raw"] = switch
            graph_out["switch"] = jnp.where(
                mask,
                switch,
                lambda_e * switch,
            )

        return {**inputs, self.graph_key: graph_out}
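
# Example (illustrative): the alchemical scaling in GraphProcessor above uses a
# half-cosine schedule so that inter-group interactions are turned on smoothly
# as `alch_elambda` goes from 0 to 1. A standalone sketch of that schedule:


def _example_alch_elambda_schedule(lambda_e: float) -> float:
    # same formula as in GraphProcessor.__call__: 0 at lambda_e = 0, 1 at
    # lambda_e = 1, with zero slope at both endpoints
    return 0.5 * (1.0 - np.cos(np.pi * lambda_e))
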

@dataclasses.dataclass(frozen=True)
class GraphFilter:
    """Filter a graph based on a cutoff distance

    FPID: GRAPH_FILTER
    """

    cutoff: float
    """Cutoff distance for the filtering."""
    parent_graph: str
    """Key of the parent graph in the inputs."""
    graph_key: str
    """Key of the filtered graph in the outputs."""
    remove_hydrogens: bool = False
    """Remove edges where the source is a hydrogen atom."""
    switch_params: FrozenDict = dataclasses.field(default_factory=FrozenDict)
    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""
    k_space: bool = False
    """Generate k-space information for the graph."""
    kmax: int = 30
    """Maximum number of k-points to consider."""
    kthr: float = 1e-6
    """Threshold for k-point filtering."""
    mult_size: float = 1.05
    """Multiplicative factor for resizing the nblist."""

    FPID: ClassVar[str] = "GRAPH_FILTER"

    def init(self):
        return FrozenDict(
            {
                "npairs": 1,
                "nblist_mult_size": self.mult_size,
            }
        )

    def get_processor(self) -> Tuple[nn.Module, Dict]:
        return GraphFilterProcessor, {
            "cutoff": self.cutoff,
            "graph_key": self.graph_key,
            "parent_graph": self.parent_graph,
            "name": f"{self.graph_key}_Filter_{self.parent_graph}",
            "switch_params": self.switch_params,
        }

    def get_graph_properties(self):
        return {
            self.graph_key: {
                "cutoff": self.cutoff,
                "directed": True,
                "parent_graph": self.parent_graph,
            }
        }

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """filter a nblist on cpu with numpy and dynamic shapes + store max shapes"""
        graph_in = inputs[self.parent_graph]
        nat = inputs["species"].shape[0]

        new_state = {**state}
        state_up = {}
        mult_size = state.get("nblist_mult_size", self.mult_size)
        assert mult_size >= 1.0, "nblist_mult_size should be >= 1."

        edge_src = np.array(graph_in["edge_src"], dtype=np.int32)
        d12 = np.array(graph_in["d12"], dtype=np.float32)
        if self.remove_hydrogens:
            species = inputs["species"]
            src_idx = (edge_src < nat).nonzero()[0]
            mask = np.zeros(edge_src.shape[0], dtype=bool)
            mask[src_idx] = (species > 1)[edge_src[src_idx]]
            d12 = np.where(mask, d12, self.cutoff**2)
        mask = d12 < self.cutoff**2

        max_pairs = state.get("npairs", 1)
        idx = np.nonzero(mask)[0]
        npairs = idx.shape[0]
        if npairs > max_pairs or add_margin:
            prev_max_pairs = max_pairs
            max_pairs = int(mult_size * max(npairs, max_pairs)) + 1
            state_up["npairs"] = (max_pairs, prev_max_pairs)
            new_state["npairs"] = max_pairs

        filter_indices = np.full(max_pairs, edge_src.shape[0], dtype=np.int32)
        edge_src = np.full(max_pairs, nat, dtype=np.int32)
        edge_dst = np.full(max_pairs, nat, dtype=np.int32)
        d12_ = np.full(max_pairs, self.cutoff**2)
        filter_indices[:npairs] = idx
        edge_src[:npairs] = graph_in["edge_src"][idx]
        edge_dst[:npairs] = graph_in["edge_dst"][idx]
        d12_[:npairs] = d12[idx]
        d12 = d12_

        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
        graph_out = {
            **graph,
            "edge_src": edge_src,
            "edge_dst": edge_dst,
            "filter_indices": filter_indices,
            "d12": d12,
            "overflow": False,
        }

        if self.k_space and "cells" in inputs:
            if "k_points" not in graph:
                ks, _, _, bewald = get_reciprocal_space_parameters(
                    inputs["reciprocal_cells"], self.cutoff, self.kmax, self.kthr
                )
                graph_out["k_points"] = ks
                graph_out["b_ewald"] = bewald

        output = {**inputs, self.graph_key: graph_out}
        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary"""
        overflow = parent_overflow or inputs[self.graph_key].get("overflow", False)
        if not overflow:
            return state, {}, inputs, False

        add_margin = inputs[self.graph_key].get("overflow", False)
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """filter a nblist on accelerator with jax and precomputed shapes"""
        graph_in = inputs[self.parent_graph]
        if state is None:
            # skin update mode
            graph = inputs[self.graph_key]
            max_pairs = graph["edge_src"].shape[0]
        else:
            max_pairs = state.get("npairs", 1)

        max_pairs_in = graph_in["edge_src"].shape[0]
        nat = inputs["species"].shape[0]

        edge_src = graph_in["edge_src"]
        d12 = graph_in["d12"]
        if self.remove_hydrogens:
            species = inputs["species"]
            mask = (species > 1)[edge_src]
            d12 = jnp.where(mask, d12, self.cutoff**2)
        mask = d12 < self.cutoff**2

        (edge_src, edge_dst, d12, filter_indices), _, npairs = mask_filter_1d(
            mask,
            max_pairs,
            (edge_src, nat),
            (graph_in["edge_dst"], nat),
            (d12, self.cutoff**2),
            (jnp.arange(max_pairs_in, dtype=jnp.int32), max_pairs_in),
        )

        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
        overflow = graph.get("overflow", False) | (npairs > max_pairs)
        graph_out = {
            **graph,
            "edge_src": edge_src,
            "edge_dst": edge_dst,
            "filter_indices": filter_indices,
            "d12": d12,
            "overflow": overflow,
        }

        if self.k_space and "cells" in inputs:
            if "k_points" not in graph:
                raise NotImplementedError(
                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
                )

        return {**inputs, self.graph_key: graph_out}

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        return self.process(None, inputs)
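
# Example (illustrative sketch): chaining a GraphFilter to a GraphGenerator so
# that only one full neighbor list is built. The 5.0/3.0 cutoffs and graph keys
# are arbitrary values for the sketch; `inputs` is assumed to have passed
# through `check_input` (defined below) and to contain "species".


def _example_graph_filter(inputs):
    gen = GraphGenerator(cutoff=5.0, graph_key="graph")
    flt = GraphFilter(cutoff=3.0, parent_graph="graph", graph_key="graph_short")
    state_gen, state_flt = gen.init(), flt.init()
    state_gen, inputs = gen(state_gen, inputs)  # build the full parent nblist
    state_flt, inputs = flt(state_flt, inputs)  # extract the short-range subgraph
    # "filter_indices" points back into the parent edge arrays and is used by
    # GraphFilterProcessor to gather vec/distances without recomputing them
    return inputs["graph_short"]["filter_indices"]
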

class GraphFilterProcessor(nn.Module):
    """Filter processing for a pre-generated graph

    This module is automatically added to a FENNIX model when a GraphFilter is used.
    """

    cutoff: float
    """Cutoff distance for the filtering."""
    graph_key: str
    """Key of the filtered graph in the inputs."""
    parent_graph: str
    """Key of the parent graph in the inputs."""
    switch_params: dict = dataclasses.field(default_factory=dict)
    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""

    @nn.compact
    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
        graph_in = inputs[self.parent_graph]
        graph = inputs[self.graph_key]

        if graph_in["vec"].shape[0] == 0:
            vec = graph_in["vec"]
            distances = graph_in["distances"]
            filter_indices = jnp.asarray([], dtype=jnp.int32)
        else:
            filter_indices = graph["filter_indices"]
            vec = (
                graph_in["vec"]
                .at[filter_indices]
                .get(mode="fill", fill_value=self.cutoff)
            )
            distances = (
                graph_in["distances"]
                .at[filter_indices]
                .get(mode="fill", fill_value=self.cutoff)
            )

        edge_mask = distances < self.cutoff
        switch = SwitchFunction(
            **{**self.switch_params, "cutoff": self.cutoff, "graph_key": None}
        )((distances, edge_mask))

        graph_out = {
            **graph,
            "vec": vec,
            "distances": distances,
            "switch": switch,
            "filter_indices": filter_indices,
            "edge_mask": edge_mask,
        }

        if "alch_group" in inputs:
            edge_src = graph["edge_src"]
            edge_dst = graph["edge_dst"]
            alch_group = inputs["alch_group"]
            lambda_e = inputs["alch_elambda"]
            lambda_e = 0.5 * (1.0 - jnp.cos(jnp.pi * lambda_e))
            mask = alch_group[edge_src] == alch_group[edge_dst]
            graph_out["switch_raw"] = switch
            graph_out["switch"] = jnp.where(
                mask,
                switch,
                lambda_e * switch,
            )

        return {**inputs, self.graph_key: graph_out}

@dataclasses.dataclass(frozen=True)
class GraphAngularExtension:
    """Add angles list to a graph

    FPID: GRAPH_ANGULAR_EXTENSION
    """

    mult_size: float = 1.05
    """Multiplicative factor for resizing the nblist."""
    add_neigh: int = 5
    """Additional neighbors to add to the nblist when resizing."""
    graph_key: str = "graph"
    """Key of the graph in the inputs."""

    FPID: ClassVar[str] = "GRAPH_ANGULAR_EXTENSION"

    def init(self):
        return FrozenDict(
            {
                "nangles": 0,
                "nblist_mult_size": self.mult_size,
                "max_neigh": self.add_neigh,
                "add_neigh": self.add_neigh,
            }
        )

    def get_processor(self) -> Tuple[nn.Module, Dict]:
        return GraphAngleProcessor, {
            "graph_key": self.graph_key,
            "name": f"{self.graph_key}_AngleProcessor",
        }

    def get_graph_properties(self):
        return {
            self.graph_key: {
                "has_angles": True,
            }
        }

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """build angle nblist on cpu with numpy and dynamic shapes + store max shapes"""
        graph = inputs[self.graph_key]
        edge_src = np.array(graph["edge_src"], dtype=np.int32)

        new_state = {**state}
        state_up = {}
        mult_size = state.get("nblist_mult_size", self.mult_size)
        assert mult_size >= 1.0, "nblist_mult_size should be >= 1."

        ### count number of neighbors
        nat = inputs["species"].shape[0]
        count = np.zeros(nat + 1, dtype=np.int32)
        np.add.at(count, edge_src, 1)
        max_count = int(np.max(count[:-1]))

        ### get sizes
        max_neigh = state.get("max_neigh", self.add_neigh)
        nedge = edge_src.shape[0]
        if max_count > max_neigh or add_margin:
            prev_max_neigh = max_neigh
            max_neigh = max(max_count, max_neigh) + state.get(
                "add_neigh", self.add_neigh
            )
            state_up["max_neigh"] = (max_neigh, prev_max_neigh)
            new_state["max_neigh"] = max_neigh

        max_neigh_arr = np.empty(max_neigh, dtype=bool)

        ### sort edge_src
        idx_sort = np.argsort(edge_src)
        edge_src_sorted = edge_src[idx_sort]

        ### map sparse to dense nblist
        if max_count * nat >= nedge:
            offset = np.tile(np.arange(max_count), nat)[:nedge]
        else:
            offset = np.zeros(nedge, dtype=np.int32)
            offset[: max_count * nat] = np.tile(np.arange(max_count), nat)

        # offset = jnp.where(edge_src_sorted < nat, offset, 0)
        mask = edge_src_sorted < nat
        indices = edge_src_sorted * max_count + offset
        indices = indices[mask]
        idx_sort = idx_sort[mask]
        edge_idx = np.full(nat * max_count, nedge, dtype=np.int32)
        edge_idx[indices] = idx_sort
        edge_idx = edge_idx.reshape(nat, max_count)

        ### find all triplets for each atom center
        local_src, local_dst = np.triu_indices(max_count, 1)
        angle_src = edge_idx[:, local_src].flatten()
        angle_dst = edge_idx[:, local_dst].flatten()

        ### mask for valid angles
        mask1 = angle_src < nedge
        mask2 = angle_dst < nedge
        angle_mask = mask1 & mask2

        max_angles = state.get("nangles", 0)
        idx = np.nonzero(angle_mask)[0]
        nangles = idx.shape[0]
        if nangles > max_angles or add_margin:
            max_angles_prev = max_angles
            max_angles = int(mult_size * max(nangles, max_angles)) + 1
            state_up["nangles"] = (max_angles, max_angles_prev)
            new_state["nangles"] = max_angles

        ## filter angles to sparse representation
        angle_src_ = np.full(max_angles, nedge, dtype=np.int32)
        angle_dst_ = np.full(max_angles, nedge, dtype=np.int32)
        angle_src_[:nangles] = angle_src[idx]
        angle_dst_[:nangles] = angle_dst[idx]

        central_atom = np.full(max_angles, nat, dtype=np.int32)
        central_atom[:nangles] = edge_src[angle_src_[:nangles]]

        ## update graph
        output = {
            **inputs,
            self.graph_key: {
                **graph,
                "angle_src": angle_src_,
                "angle_dst": angle_dst_,
                "central_atom": central_atom,
                "angle_overflow": False,
                "max_neigh": max_neigh,
                "__max_neigh_array": max_neigh_arr,
            },
        }

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output
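
    # Illustrative note on the dense mapping in __call__ above: edge_idx has
    # one row per atom and max_count slots per row; slot j of row i holds the
    # position, in the sparse edge arrays, of the j-th edge whose source is
    # atom i, and unused slots hold the out-of-range value nedge. E.g. for
    # edge_src = [0, 0, 1] with nat = 2 and max_count = 2, edge_idx is
    # [[0, 1], [2, nedge]]; np.triu_indices over the slots then enumerates
    # all neighbor pairs (i.e. angle triplets) around each central atom.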
    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary"""
        overflow = parent_overflow or inputs[self.graph_key]["angle_overflow"]
        if not overflow:
            return state, {}, inputs, False

        add_margin = inputs[self.graph_key]["angle_overflow"]
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """build angle nblist on accelerator with jax and precomputed shapes"""
        graph = inputs[self.graph_key]
        edge_src = graph["edge_src"]

        ### count number of neighbors
        nat = inputs["species"].shape[0]
        count = jnp.zeros(nat, dtype=jnp.int32).at[edge_src].add(1, mode="drop")
        max_count = jnp.max(count)

        ### get sizes
        if state is None:
            max_neigh_arr = graph["__max_neigh_array"]
            max_neigh = max_neigh_arr.shape[0]
            prev_nangles = graph["angle_src"].shape[0]
        else:
            max_neigh = state.get("max_neigh", self.add_neigh)
            max_neigh_arr = jnp.empty(max_neigh, dtype=bool)
            prev_nangles = state.get("nangles", 0)

        nedge = edge_src.shape[0]

        ### sort edge_src
        idx_sort = jnp.argsort(edge_src).astype(jnp.int32)
        edge_src_sorted = edge_src[idx_sort]

        ### map sparse to dense nblist
        if max_neigh * nat < nedge:
            raise ValueError("Found max_neigh*nat < nedge. This should not happen.")
        offset = jnp.asarray(
            np.tile(np.arange(max_neigh), nat)[:nedge], dtype=jnp.int32
        )
        # offset = jnp.where(edge_src_sorted < nat, offset, 0)
        indices = edge_src_sorted * max_neigh + offset
        edge_idx = (
            jnp.full(nat * max_neigh, nedge, dtype=jnp.int32)
            .at[indices]
            .set(idx_sort, mode="drop")
            .reshape(nat, max_neigh)
        )

        ### find all triplets for each atom center
        local_src, local_dst = np.triu_indices(max_neigh, 1)
        angle_src = edge_idx[:, local_src].flatten()
        angle_dst = edge_idx[:, local_dst].flatten()

        ### mask for valid angles
        mask1 = angle_src < nedge
        mask2 = angle_dst < nedge
        angle_mask = mask1 & mask2

        ## filter angles to sparse representation
        (angle_src, angle_dst), _, nangles = mask_filter_1d(
            angle_mask,
            prev_nangles,
            (angle_src, nedge),
            (angle_dst, nedge),
        )
        ## find central atom
        central_atom = edge_src[angle_src]

        ## check for overflow
        angle_overflow = nangles > prev_nangles
        neigh_overflow = max_count > max_neigh
        overflow = graph.get("angle_overflow", False) | angle_overflow | neigh_overflow

        ## update graph
        output = {
            **inputs,
            self.graph_key: {
                **graph,
                "angle_src": angle_src,
                "angle_dst": angle_dst,
                "central_atom": central_atom,
                "angle_overflow": overflow,
                # "max_neigh": max_neigh,
                "__max_neigh_array": max_neigh_arr,
            },
        }

        return output

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        return self.process(None, inputs)
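
# Example (illustrative sketch): adding an angle list to an already-built
# graph. `inputs` is assumed to contain a "graph" entry produced by a
# GraphGenerator CPU pass (see the GraphGenerator example above).


def _example_angular_extension(inputs):
    ext = GraphAngularExtension(graph_key="graph")
    state = ext.init()
    state, outputs = ext(state, inputs)
    g = outputs["graph"]
    # angle_src/angle_dst index pairs of edges sharing the same central_atom;
    # GraphAngleProcessor (defined below) turns them into actual angle values
    return g["angle_src"], g["angle_dst"], g["central_atom"]
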

class GraphAngleProcessor(nn.Module):
    """Process a pre-generated graph to compute angles

    This module is automatically added to a FENNIX model when a GraphAngularExtension is used.

    """

    graph_key: str
    """Key of the graph in the inputs."""

    @nn.compact
    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
        graph = inputs[self.graph_key]
        distances = graph["distances"]
        vec = graph["vec"]
        angle_src = graph["angle_src"]
        angle_dst = graph["angle_dst"]

        dir = vec / jnp.clip(distances[:, None], a_min=1.0e-5)
        cos_angles = (
            dir.at[angle_src].get(mode="fill", fill_value=0.5)
            * dir.at[angle_dst].get(mode="fill", fill_value=0.5)
        ).sum(axis=-1)

        angles = jnp.arccos(0.95 * cos_angles)

        return {
            **inputs,
            self.graph_key: {
                **graph,
                # "cos_angles": cos_angles,
                "angles": angles,
                # "angle_mask": angle_mask,
            },
        }


@dataclasses.dataclass(frozen=True)
class SpeciesIndexer:
    """Build an index that splits atomic arrays by species.

    FPID: SPECIES_INDEXER

    If `species_order` is specified, the output will be a dense array with size (len(species_order), max_size) that can directly index atomic arrays.
    If `species_order` is None, the output will be a dictionary with species as keys and an index to filter atomic arrays for that species as values.

    """

    output_key: str = "species_index"
    """Key for the output dictionary."""
    species_order: Optional[str] = None
    """Comma separated list of species in the order they should be indexed."""
    add_atoms: int = 0
    """Additional atoms to add to the sizes."""
    add_atoms_margin: int = 10
    """Additional atoms to add to the sizes when adding margin."""

    FPID: ClassVar[str] = "SPECIES_INDEXER"

    def init(self):
        return FrozenDict(
            {
                "sizes": {},
            }
        )

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        species = np.array(inputs["species"], dtype=np.int32)
        nat = species.shape[0]
        set_species, counts = np.unique(species, return_counts=True)

        new_state = {**state}
        state_up = {}

        sizes = state.get("sizes", FrozenDict({}))
        new_sizes = {**sizes}
        up_sizes = False
        counts_dict = {}
        for s, c in zip(set_species, counts):
            if s <= 0:
                continue
            counts_dict[s] = c
            if c > sizes.get(s, 0):
                up_sizes = True
                add_atoms = state.get("add_atoms", self.add_atoms)
                if add_margin:
                    add_atoms += state.get("add_atoms_margin", self.add_atoms_margin)
                new_sizes[s] = c + add_atoms

        new_sizes = FrozenDict(new_sizes)
        if up_sizes:
            state_up["sizes"] = (new_sizes, sizes)
            new_state["sizes"] = new_sizes

        if self.species_order is not None:
            species_order = [el.strip() for el in self.species_order.split(",")]
            max_size_prev = state.get("max_size", 0)
            max_size = max(new_sizes.values())
            if max_size > max_size_prev:
                state_up["max_size"] = (max_size, max_size_prev)
                new_state["max_size"] = max_size
                max_size_prev = max_size

            species_index = np.full(
                (len(species_order), max_size), nat, dtype=np.int32
            )
            for i, el in enumerate(species_order):
                s = PERIODIC_TABLE_REV_IDX[el]
                if s in counts_dict.keys():
                    species_index[i, : counts_dict[s]] = np.nonzero(species == s)[0]
        else:
            species_index = {
                PERIODIC_TABLE[s]: np.full(c, nat, dtype=np.int32)
                for s, c in new_sizes.items()
            }
            for s, c in zip(set_species, counts):
                if s <= 0:
                    continue
                species_index[PERIODIC_TABLE[s]][:c] = np.nonzero(species == s)[0]

        output = {
            **inputs,
            self.output_key: species_index,
            self.output_key + "_overflow": False,
        }

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary"""
        overflow = parent_overflow or inputs[self.output_key + "_overflow"]
        if not overflow:
            return state, {}, inputs, False

        add_margin = inputs[self.output_key + "_overflow"]
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True
        # return state, {}, inputs, parent_overflow

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        # assert (
        #     self.output_key in inputs
        # ), f"Species Index {self.output_key} must be provided on accelerator. Call the numpy routine (self.__call__) first."

        recompute_species_index = "recompute_species_index" in inputs.get("flags", {})
        if self.output_key in inputs and not recompute_species_index:
            return inputs

        if state is None:
            raise ValueError("Species Indexer state must be provided on accelerator.")

        species = inputs["species"]
        nat = species.shape[0]

        sizes = state["sizes"]

        if self.species_order is not None:
            species_order = [el.strip() for el in self.species_order.split(",")]
            max_size = state["max_size"]

            species_index = jnp.full(
                (len(species_order), max_size), nat, dtype=jnp.int32
            )
            for i, el in enumerate(species_order):
                s = PERIODIC_TABLE_REV_IDX[el]
                if s in sizes.keys():
                    c = sizes[s]
                    species_index = species_index.at[i, :].set(
                        jnp.nonzero(species == s, size=max_size, fill_value=nat)[0]
                    )
                # if s in counts_dict.keys():
                #     species_index[i, : counts_dict[s]] = np.nonzero(species == s)[0]
            overflow = False  # no size check is performed in dense mode
        else:
            # species_index = {
            #     PERIODIC_TABLE[s]: jnp.nonzero(species == s, size=c, fill_value=nat)[0]
            #     for s, c in sizes.items()
            # }
            species_index = {}
            overflow = False
            natcount = 0
            for s, c in sizes.items():
                mask = species == s
                new_size = jnp.sum(mask)
                natcount = natcount + new_size
                overflow = overflow | (new_size > c)  # check if sizes are correct
                species_index[PERIODIC_TABLE[s]] = jnp.nonzero(
                    species == s, size=c, fill_value=nat
                )[0]

            mask = species <= 0
            new_size = jnp.sum(mask)
            natcount = natcount + new_size
            overflow = overflow | (
                natcount < species.shape[0]
            )  # check if any species missing

        return {
            **inputs,
            self.output_key: species_index,
            self.output_key + "_overflow": overflow,
        }

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        return self.process(None, inputs)
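
# Example (illustrative): dict-mode species indexing. Padding entries point to
# index nat, so gathers through these indices can use fill/drop modes safely.


def _example_species_indexer():
    idxer = SpeciesIndexer()  # dict mode since species_order is None
    state = idxer.init()
    inputs = {"species": np.array([8, 1, 1], dtype=np.int32)}
    state, out = idxer(state, inputs)
    # expected: {"O": array([0]), "H": array([1, 2])} as int32 index arrays
    return out[idxer.output_key]
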

@dataclasses.dataclass(frozen=True)
class BlockIndexer:
    """Build an index that splits atomic arrays by chemical blocks.

    FPID: BLOCK_INDEXER

    The output is a dictionary with chemical block names as keys and an index
    to filter atomic arrays for that block as values. If `split_CNOPSSe` is
    True, the C, N, O, P, S and Se elements are split out of their blocks and
    indexed individually.

    """

    output_key: str = "block_index"
    """Key for the output dictionary."""
    add_atoms: int = 0
    """Additional atoms to add to the sizes."""
    add_atoms_margin: int = 10
    """Additional atoms to add to the sizes when adding margin."""
    split_CNOPSSe: bool = False
    """Index the C, N, O, P, S and Se elements separately from their blocks."""

    FPID: ClassVar[str] = "BLOCK_INDEXER"

    def init(self):
        return FrozenDict(
            {
                "sizes": {},
            }
        )

    def build_chemical_blocks(self):
        _CHEMICAL_BLOCKS_NAMES = CHEMICAL_BLOCKS_NAMES.copy()
        if self.split_CNOPSSe:
            _CHEMICAL_BLOCKS_NAMES[1] = "C"
            _CHEMICAL_BLOCKS_NAMES.extend(["N", "O", "P", "S", "Se"])
        _CHEMICAL_BLOCKS = CHEMICAL_BLOCKS.copy()
        if self.split_CNOPSSe:
            _CHEMICAL_BLOCKS[6] = 1
            _CHEMICAL_BLOCKS[7] = len(CHEMICAL_BLOCKS_NAMES)
            _CHEMICAL_BLOCKS[8] = len(CHEMICAL_BLOCKS_NAMES) + 1
            _CHEMICAL_BLOCKS[15] = len(CHEMICAL_BLOCKS_NAMES) + 2
            _CHEMICAL_BLOCKS[16] = len(CHEMICAL_BLOCKS_NAMES) + 3
            _CHEMICAL_BLOCKS[34] = len(CHEMICAL_BLOCKS_NAMES) + 4
        return _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS = self.build_chemical_blocks()

        species = np.array(inputs["species"], dtype=np.int32)
        blocks = _CHEMICAL_BLOCKS[species]
        nat = species.shape[0]
        set_blocks, counts = np.unique(blocks, return_counts=True)

        new_state = {**state}
        state_up = {}

        sizes = state.get("sizes", FrozenDict({}))
        new_sizes = {**sizes}
        up_sizes = False
        for s, c in zip(set_blocks, counts):
            if s < 0:
                continue
            key = (s, _CHEMICAL_BLOCKS_NAMES[s])
            if c > sizes.get(key, 0):
                up_sizes = True
                add_atoms = state.get("add_atoms", self.add_atoms)
                if add_margin:
                    add_atoms += state.get("add_atoms_margin", self.add_atoms_margin)
                new_sizes[key] = c + add_atoms

        new_sizes = FrozenDict(new_sizes)
        if up_sizes:
            state_up["sizes"] = (new_sizes, sizes)
            new_state["sizes"] = new_sizes

        block_index = {n: None for n in _CHEMICAL_BLOCKS_NAMES}
        for (_, n), c in new_sizes.items():
            block_index[n] = np.full(c, nat, dtype=np.int32)
        # block_index = {
        #     n: np.full(c, nat, dtype=np.int32)
        #     for (_, n), c in new_sizes.items()
        # }
        for s, c in zip(set_blocks, counts):
            if s < 0:
                continue
            block_index[_CHEMICAL_BLOCKS_NAMES[s]][:c] = np.nonzero(blocks == s)[0]

        output = {
            **inputs,
            self.output_key: block_index,
            self.output_key + "_overflow": False,
        }

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary"""
        overflow = parent_overflow or inputs[self.output_key + "_overflow"]
        if not overflow:
            return state, {}, inputs, False

        add_margin = inputs[self.output_key + "_overflow"]
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True
        # return state, {}, inputs, parent_overflow
    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS = self.build_chemical_blocks()
        # assert (
        #     self.output_key in inputs
        # ), f"Block Index {self.output_key} must be provided on accelerator. Call the numpy routine (self.__call__) first."

        recompute_species_index = "recompute_species_index" in inputs.get("flags", {})
        if self.output_key in inputs and not recompute_species_index:
            return inputs

        if state is None:
            raise ValueError("Block Indexer state must be provided on accelerator.")

        species = inputs["species"]
        blocks = jnp.asarray(_CHEMICAL_BLOCKS)[species]
        nat = species.shape[0]

        sizes = state["sizes"]

        # species_index = {
        #     PERIODIC_TABLE[s]: jnp.nonzero(species == s, size=c, fill_value=nat)[0]
        #     for s, c in sizes.items()
        # }
        block_index = {n: None for n in _CHEMICAL_BLOCKS_NAMES}
        overflow = False
        natcount = 0
        for (s, name), c in sizes.items():
            mask = blocks == s
            new_size = jnp.sum(mask)
            natcount = natcount + new_size
            overflow = overflow | (new_size > c)  # check if sizes are correct
            block_index[name] = jnp.nonzero(mask, size=c, fill_value=nat)[0]

        mask = blocks < 0
        new_size = jnp.sum(mask)
        natcount = natcount + new_size
        overflow = overflow | (
            natcount < species.shape[0]
        )  # check if any species missing

        return {
            **inputs,
            self.output_key: block_index,
            self.output_key + "_overflow": overflow,
        }

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        return self.process(None, inputs)

@dataclasses.dataclass(frozen=True)
class AtomPadding:
    """Pad atomic arrays to a fixed size."""

    mult_size: float = 1.2
    """Multiplicative factor for resizing the atomic arrays."""
    add_sys: int = 0
    """Additional systems to add when resizing."""

    def init(self):
        return {"prev_nat": 0, "prev_nsys": 0}

    def __call__(self, state, inputs: Dict) -> Union[dict, jax.Array]:
        species = inputs["species"]
        nat = species.shape[0]

        prev_nat = state.get("prev_nat", 0)
        prev_nat_ = prev_nat
        if nat > prev_nat_:
            prev_nat_ = int(self.mult_size * nat) + 1

        nsys = len(inputs["natoms"])
        prev_nsys = state.get("prev_nsys", 0)
        prev_nsys_ = prev_nsys
        if nsys > prev_nsys_:
            prev_nsys_ = nsys + self.add_sys

        add_atoms = prev_nat_ - nat
        add_sys = prev_nsys_ - nsys + 1
        output = {**inputs}
        if add_atoms > 0:
            for k, v in inputs.items():
                if isinstance(v, np.ndarray) or isinstance(v, jax.Array):
                    if v.shape[0] == nat:
                        output[k] = np.append(
                            v,
                            np.zeros((add_atoms, *v.shape[1:]), dtype=v.dtype),
                            axis=0,
                        )
                    elif v.shape[0] == nsys:
                        if k == "cells":
                            output[k] = np.append(
                                v,
                                1000
                                * np.eye(3, dtype=v.dtype)[None, :, :].repeat(
                                    add_sys, axis=0
                                ),
                                axis=0,
                            )
                        else:
                            output[k] = np.append(
                                v,
                                np.zeros((add_sys, *v.shape[1:]), dtype=v.dtype),
                                axis=0,
                            )
            output["natoms"] = np.append(
                inputs["natoms"], np.zeros(add_sys, dtype=np.int32)
            )
            output["species"] = np.append(
                species, -1 * np.ones(add_atoms, dtype=species.dtype)
            )
            output["batch_index"] = np.append(
                inputs["batch_index"],
                np.array(
                    [output["natoms"].shape[0] - 1] * add_atoms,
                    dtype=inputs["batch_index"].dtype,
                ),
            )
            if "system_index" in inputs:
                output["system_index"] = np.append(
                    inputs["system_index"],
                    np.array(
                        [output["natoms"].shape[0] - 1] * add_sys,
                        dtype=inputs["system_index"].dtype,
                    ),
                )

        output["true_atoms"] = output["species"] > 0
        output["true_sys"] = np.arange(len(output["natoms"])) < nsys

        state = {**state, "prev_nat": prev_nat_, "prev_nsys": prev_nsys_}

        return FrozenDict(state), output


def atom_unpadding(inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Remove padding from atomic arrays."""
    if "true_atoms" not in inputs:
        return inputs

    species = inputs["species"]
    true_atoms = inputs["true_atoms"]
    true_sys = inputs["true_sys"]
    natall = species.shape[0]
    nat = np.argmax(species <= 0)
    if nat == 0:
        return inputs

    natoms = inputs["natoms"]
    nsysall = len(natoms)

    output = {**inputs}
    for k, v in inputs.items():
        if isinstance(v, jax.Array) or isinstance(v, np.ndarray):
            if v.ndim == 0:
                continue
            if v.shape[0] == natall:
                output[k] = v[true_atoms]
            elif v.shape[0] == nsysall:
                output[k] = v[true_sys]
    del output["true_sys"]
    del output["true_atoms"]
    return output
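
# Example (illustrative sketch): pad/unpad round trip. Padding keeps array
# shapes stable across calls (20% headroom by default) so that jitted
# preprocessing does not recompile on every new system size. `inputs` is
# assumed to have passed through `check_input` (defined below).


def _example_padding_roundtrip(inputs):
    padder = AtomPadding(mult_size=1.2)
    state = padder.init()
    state, padded = padder(state, inputs)
    # padded["species"] has trailing -1 entries; "true_atoms"/"true_sys" mark
    # the real entries so that atom_unpadding can strip them back out
    return atom_unpadding(padded)
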
1800 ) 1801 if not self.layers: 1802 raise ValueError(f"Error: no Preprocessing layers were provided.") 1803 1804 def __call__(self, state, inputs: Dict[str, Any]) -> Dict[str, Any]: 1805 do_check_input = state.get("check_input", True) 1806 if do_check_input: 1807 inputs = check_input(inputs) 1808 new_state = [] 1809 layer_state = state["layers_state"] 1810 i = 0 1811 if self.use_atom_padding: 1812 s, inputs = self.atom_padder(layer_state[0], inputs) 1813 new_state.append(s) 1814 i += 1 1815 for layer in self.layers: 1816 s, inputs = layer(layer_state[i], inputs, return_state_update=False) 1817 new_state.append(s) 1818 i += 1 1819 return FrozenDict({**state, "layers_state": tuple(new_state)}), convert_to_jax( 1820 inputs 1821 ) 1822 1823 def check_reallocate(self, state, inputs): 1824 new_state = [] 1825 state_up = [] 1826 layer_state = state["layers_state"] 1827 i = 0 1828 if self.use_atom_padding: 1829 new_state.append(layer_state[0]) 1830 i += 1 1831 parent_overflow = False 1832 for layer in self.layers: 1833 s, s_up, inputs, parent_overflow = layer.check_reallocate( 1834 layer_state[i], inputs, parent_overflow 1835 ) 1836 new_state.append(s) 1837 state_up.append(s_up) 1838 i += 1 1839 1840 if not parent_overflow: 1841 return state, {}, inputs, False 1842 return ( 1843 FrozenDict({**state, "layers_state": tuple(new_state)}), 1844 state_up, 1845 inputs, 1846 True, 1847 ) 1848 1849 def atom_padding(self, state, inputs): 1850 if self.use_atom_padding: 1851 padder_state = state["layers_state"][0] 1852 return self.atom_padder(padder_state, inputs) 1853 return state, inputs 1854 1855 @partial(jax.jit, static_argnums=(0, 1)) 1856 def process(self, state, inputs): 1857 layer_state = state["layers_state"] 1858 i = 1 if self.use_atom_padding else 0 1859 for layer in self.layers: 1860 inputs = layer.process(layer_state[i], inputs) 1861 i += 1 1862 return inputs 1863 1864 @partial(jax.jit, static_argnums=(0)) 1865 def update_skin(self, inputs): 1866 for layer in self.layers: 1867 inputs = layer.update_skin(inputs) 1868 return inputs 1869 1870 def init(self): 1871 state = [] 1872 if self.use_atom_padding: 1873 state.append(self.atom_padder.init()) 1874 for layer in self.layers: 1875 state.append(layer.init()) 1876 return FrozenDict({"check_input": True, "layers_state": state}) 1877 1878 def init_with_output(self, inputs): 1879 state = self.init() 1880 return self(state, inputs) 1881 1882 def get_processors(self): 1883 processors = [] 1884 for layer in self.layers: 1885 if hasattr(layer, "get_processor"): 1886 processors.append(layer.get_processor()) 1887 return processors 1888 1889 def get_graphs_properties(self): 1890 properties = {} 1891 for layer in self.layers: 1892 if hasattr(layer, "get_graph_properties"): 1893 properties = deep_update(properties, layer.get_graph_properties()) 1894 return properties 1895 1896 1897# PREPROCESSING = { 1898# "GRAPH": GraphGenerator, 1899# # "GRAPH_FIXED": GraphGeneratorFixed, 1900# "GRAPH_FILTER": GraphFilter, 1901# "GRAPH_ANGULAR_EXTENSION": GraphAngularExtension, 1902# # "GRAPH_DENSE_EXTENSION": GraphDenseExtension, 1903# "SPECIES_INDEXER": SpeciesIndexer, 1904# }
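To make the composition concrete, here is a minimal end-to-end sketch (illustrative only: the molecule and cutoffs are made up, only keys documented in `check_input` are supplied, and `PreprocessingChain` and the layer classes are used exactly as defined above):

import numpy as np
from fennol.models.preprocessing import (
    GraphAngularExtension,
    GraphFilter,
    GraphGenerator,
    PreprocessingChain,
)

# a single water molecule; check_input will fill in natoms and batch_index
inputs = {
    "species": np.array([8, 1, 1], dtype=np.int32),
    "coordinates": np.array(
        [[0.00, 0.00, 0.00], [0.96, 0.00, 0.00], [-0.24, 0.93, 0.00]],
        dtype=np.float32,
    ),
}

chain = PreprocessingChain(
    layers=(
        GraphGenerator(cutoff=5.0, graph_key="graph"),
        GraphFilter(cutoff=3.0, parent_graph="graph", graph_key="graph_short"),
        GraphAngularExtension(graph_key="graph_short"),
    ),
    use_atom_padding=True,
)

# numpy pass with dynamic shapes: builds the nblists and records max sizes
state, outputs = chain.init_with_output(inputs)
# jitted accelerator pass that reuses the recorded static shapes
outputs = chain.process(state, outputs)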
22@dataclasses.dataclass(frozen=True) 23class GraphGenerator: 24 """Generate a graph from a set of coordinates 25 26 FPID: GRAPH 27 28 For now, we generate all pairs of atoms and filter based on cutoff. 29 If a `nblist_skin` is present in the state, we generate a second graph with a larger cutoff that includes all pairs within the cutoff+skin. This graph is then reused by the `update_skin` method to update the original graph without recomputing the full nblist. 30 """ 31 32 cutoff: float 33 """Cutoff distance for the graph.""" 34 graph_key: str = "graph" 35 """Key of the graph in the outputs.""" 36 switch_params: dict = dataclasses.field(default_factory=dict, hash=False) 37 """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`.""" 38 kmax: int = 30 39 """Maximum number of k-points to consider.""" 40 kthr: float = 1e-6 41 """Threshold for k-point filtering.""" 42 k_space: bool = False 43 """Whether to generate k-space information for the graph.""" 44 mult_size: float = 1.05 45 """Multiplicative factor for resizing the nblist.""" 46 # covalent_cutoff: bool = False 47 48 FPID: ClassVar[str] = "GRAPH" 49 50 def init(self): 51 return FrozenDict( 52 { 53 "max_nat": 1, 54 "npairs": 1, 55 "nblist_mult_size": self.mult_size, 56 } 57 ) 58 59 def get_processor(self) -> Tuple[nn.Module, Dict]: 60 return GraphProcessor, { 61 "cutoff": self.cutoff, 62 "graph_key": self.graph_key, 63 "switch_params": self.switch_params, 64 "name": f"{self.graph_key}_Processor", 65 } 66 67 def get_graph_properties(self): 68 return { 69 self.graph_key: { 70 "cutoff": self.cutoff, 71 "directed": True, 72 } 73 } 74 75 def __call__(self, state, inputs, return_state_update=False, add_margin=False): 76 """build a nblist on cpu with numpy and dynamic shapes + store max shapes""" 77 if self.graph_key in inputs: 78 graph = inputs[self.graph_key] 79 if "keep_graph" in graph: 80 return state, inputs 81 82 coords = np.array(inputs["coordinates"], dtype=np.float32) 83 natoms = np.array(inputs["natoms"], dtype=np.int32) 84 batch_index = np.array(inputs["batch_index"], dtype=np.int32) 85 86 new_state = {**state} 87 state_up = {} 88 89 mult_size = state.get("nblist_mult_size", self.mult_size) 90 assert mult_size >= 1.0, "mult_size should be larger or equal than 1.0" 91 92 if natoms.shape[0] == 1: 93 max_nat = coords.shape[0] 94 true_max_nat = max_nat 95 else: 96 max_nat = state.get("max_nat", round(coords.shape[0] / natoms.shape[0])) 97 true_max_nat = int(np.max(natoms)) 98 if true_max_nat > max_nat: 99 add_atoms = state.get("add_atoms", 0) 100 new_maxnat = true_max_nat + add_atoms 101 state_up["max_nat"] = (new_maxnat, max_nat) 102 new_state["max_nat"] = new_maxnat 103 104 cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0) 105 106 ### compute indices of all pairs 107 p1, p2 = np.triu_indices(true_max_nat, 1) 108 p1, p2 = p1.astype(np.int32), p2.astype(np.int32) 109 pbc_shifts = None 110 if natoms.shape[0] > 1: 111 ## batching => mask irrelevant pairs 112 mask_p12 = ( 113 (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None]) 114 ).flatten() 115 shift = np.concatenate( 116 (np.array([0], dtype=np.int32), np.cumsum(natoms[:-1], dtype=np.int32)) 117 ) 118 p1 = np.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1) 119 p2 = np.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1) 120 121 apply_pbc = "cells" in inputs 122 if not apply_pbc: 123 ### NO PBC 124 vec = coords[p2] - coords[p1] 125 else: 126 cells = np.array(inputs["cells"], dtype=np.float32) 127 reciprocal_cells 
= np.array(inputs["reciprocal_cells"], dtype=np.float32) 128 minimage = state.get("minimum_image", True) 129 if minimage: 130 ## MINIMUM IMAGE CONVENTION 131 vec = coords[p2] - coords[p1] 132 if cells.shape[0] == 1: 133 vecpbc = np.dot(vec, reciprocal_cells[0]) 134 pbc_shifts = -np.round(vecpbc) 135 vec = vec + np.dot(pbc_shifts, cells[0]) 136 else: 137 batch_index_vec = batch_index[p1] 138 vecpbc = np.einsum( 139 "aj,aji->ai", vec, reciprocal_cells[batch_index_vec] 140 ) 141 pbc_shifts = -np.round(vecpbc) 142 vec = vec + np.einsum( 143 "aj,aji->ai", pbc_shifts, cells[batch_index_vec] 144 ) 145 else: 146 ### GENERAL PBC 147 ## put all atoms in central box 148 if cells.shape[0] == 1: 149 coords_pbc = np.dot(coords, reciprocal_cells[0]) 150 at_shifts = -np.floor(coords_pbc) 151 coords_pbc = coords + np.dot(at_shifts, cells[0]) 152 else: 153 coords_pbc = np.einsum( 154 "aj,aji->ai", coords, reciprocal_cells[batch_index] 155 ) 156 at_shifts = -np.floor(coords_pbc) 157 coords_pbc = coords + np.einsum( 158 "aj,aji->ai", at_shifts, cells[batch_index] 159 ) 160 vec = coords_pbc[p2] - coords_pbc[p1] 161 162 ## compute maximum number of repeats 163 inv_distances = (np.sum(reciprocal_cells**2, axis=1)) ** 0.5 164 cdinv = cutoff_skin * inv_distances 165 num_repeats_all = np.ceil(cdinv).astype(np.int32) 166 if "true_sys" in inputs: 167 num_repeats_all = np.where(np.array(inputs["true_sys"],dtype=bool)[:, None], num_repeats_all, 0) 168 # num_repeats_all = np.where(cdinv < 0.5, 0, num_repeats_all) 169 num_repeats = np.max(num_repeats_all, axis=0) 170 num_repeats_prev = np.array(state.get("num_repeats_pbc", (0, 0, 0))) 171 if np.any(num_repeats > num_repeats_prev): 172 num_repeats_new = np.maximum(num_repeats, num_repeats_prev) 173 state_up["num_repeats_pbc"] = ( 174 tuple(num_repeats_new), 175 tuple(num_repeats_prev), 176 ) 177 new_state["num_repeats_pbc"] = tuple(num_repeats_new) 178 ## build all possible shifts 179 cell_shift_pbc = np.array( 180 np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]), 181 dtype=cells.dtype, 182 ).T.reshape(-1, 3) 183 ## shift applied to vectors 184 if cells.shape[0] == 1: 185 dvec = np.dot(cell_shift_pbc, cells[0])[None, :, :] 186 vec = (vec[:, None, :] + dvec).reshape(-1, 3) 187 pbc_shifts = np.broadcast_to( 188 cell_shift_pbc[None, :, :], 189 (p1.shape[0], cell_shift_pbc.shape[0], 3), 190 ).reshape(-1, 3) 191 p1 = np.broadcast_to( 192 p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0]) 193 ).flatten() 194 p2 = np.broadcast_to( 195 p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0]) 196 ).flatten() 197 if natoms.shape[0] > 1: 198 mask_p12 = np.broadcast_to( 199 mask_p12[:, None], 200 (mask_p12.shape[0], cell_shift_pbc.shape[0]), 201 ).flatten() 202 else: 203 dvec = np.einsum("bj,sji->sbi", cell_shift_pbc, cells) 204 205 ## get pbc shifts specific to each box 206 cell_shift_pbc = np.broadcast_to( 207 cell_shift_pbc[None, :, :], 208 (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3), 209 ) 210 mask = np.all( 211 np.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1 212 ).flatten() 213 idx = np.nonzero(mask)[0] 214 nshifts = idx.shape[0] 215 nshifts_prev = state.get("nshifts_pbc", 0) 216 if nshifts > nshifts_prev or add_margin: 217 nshifts_new = int(mult_size * max(nshifts, nshifts_prev)) + 1 218 state_up["nshifts_pbc"] = (nshifts_new, nshifts_prev) 219 new_state["nshifts_pbc"] = nshifts_new 220 221 dvec_filter = dvec.reshape(-1, 3)[idx, :] 222 cell_shift_pbc_filter = cell_shift_pbc.reshape(-1, 3)[idx, :] 223 224 ## get batch shift in the dvec_filter array 
225 nrep = np.prod(2 * num_repeats_all + 1, axis=1) 226 bshift = np.concatenate((np.array([0]), np.cumsum(nrep)[:-1])) 227 228 ## compute vectors 229 batch_index_vec = batch_index[p1] 230 nrep_vec = np.where(mask_p12,nrep[batch_index_vec],0) 231 vec = vec.repeat(nrep_vec, axis=0) 232 nvec_pbc = nrep_vec.sum() #vec.shape[0] 233 nvec_pbc_prev = state.get("nvec_pbc", 0) 234 if nvec_pbc > nvec_pbc_prev or add_margin: 235 nvec_pbc_new = int(mult_size * max(nvec_pbc, nvec_pbc_prev)) + 1 236 state_up["nvec_pbc"] = (nvec_pbc_new, nvec_pbc_prev) 237 new_state["nvec_pbc"] = nvec_pbc_new 238 239 # print("cpu: ", nvec_pbc, nvec_pbc_prev, nshifts, nshifts_prev) 240 ## get shift index 241 dshift = np.concatenate( 242 (np.array([0]), np.cumsum(nrep_vec)[:-1]) 243 ).repeat(nrep_vec) 244 # ishift = np.arange(dshift.shape[0])-dshift 245 # bshift_vec_rep = bshift[batch_index_vec].repeat(nrep_vec) 246 icellshift = ( 247 np.arange(dshift.shape[0]) 248 - dshift 249 + bshift[batch_index_vec].repeat(nrep_vec) 250 ) 251 # shift vectors 252 vec = vec + dvec_filter[icellshift] 253 pbc_shifts = cell_shift_pbc_filter[icellshift] 254 255 p1 = np.repeat(p1, nrep_vec) 256 p2 = np.repeat(p2, nrep_vec) 257 if natoms.shape[0] > 1: 258 mask_p12 = np.repeat(mask_p12, nrep_vec) 259 260 ## compute distances 261 d12 = (vec**2).sum(axis=-1) 262 if natoms.shape[0] > 1: 263 d12 = np.where(mask_p12, d12, cutoff_skin**2) 264 265 ## filter pairs 266 max_pairs = state.get("npairs", 1) 267 mask = d12 < cutoff_skin**2 268 idx = np.nonzero(mask)[0] 269 npairs = idx.shape[0] 270 if npairs > max_pairs or add_margin: 271 prev_max_pairs = max_pairs 272 max_pairs = int(mult_size * max(npairs, max_pairs)) + 1 273 state_up["npairs"] = (max_pairs, prev_max_pairs) 274 new_state["npairs"] = max_pairs 275 276 nat = coords.shape[0] 277 edge_src = np.full(max_pairs, nat, dtype=np.int32) 278 edge_dst = np.full(max_pairs, nat, dtype=np.int32) 279 d12_ = np.full(max_pairs, cutoff_skin**2) 280 edge_src[:npairs] = p1[idx] 281 edge_dst[:npairs] = p2[idx] 282 d12_[:npairs] = d12[idx] 283 d12 = d12_ 284 285 if apply_pbc: 286 pbc_shifts_ = np.zeros((max_pairs, 3)) 287 pbc_shifts_[:npairs] = pbc_shifts[idx] 288 pbc_shifts = pbc_shifts_ 289 if not minimage: 290 pbc_shifts[:npairs] = ( 291 pbc_shifts[:npairs] 292 + at_shifts[edge_dst[:npairs]] 293 - at_shifts[edge_src[:npairs]] 294 ) 295 296 if "nblist_skin" in state: 297 edge_src_skin = edge_src 298 edge_dst_skin = edge_dst 299 if apply_pbc: 300 pbc_shifts_skin = pbc_shifts 301 max_pairs_skin = state.get("npairs_skin", 1) 302 mask = d12 < self.cutoff**2 303 idx = np.nonzero(mask)[0] 304 npairs_skin = idx.shape[0] 305 if npairs_skin > max_pairs_skin or add_margin: 306 prev_max_pairs_skin = max_pairs_skin 307 max_pairs_skin = int(mult_size * max(npairs_skin, max_pairs_skin)) + 1 308 state_up["npairs_skin"] = (max_pairs_skin, prev_max_pairs_skin) 309 new_state["npairs_skin"] = max_pairs_skin 310 edge_src = np.full(max_pairs_skin, nat, dtype=np.int32) 311 edge_dst = np.full(max_pairs_skin, nat, dtype=np.int32) 312 d12_ = np.full(max_pairs_skin, self.cutoff**2) 313 edge_src[:npairs_skin] = edge_src_skin[idx] 314 edge_dst[:npairs_skin] = edge_dst_skin[idx] 315 d12_[:npairs_skin] = d12[idx] 316 d12 = d12_ 317 if apply_pbc: 318 pbc_shifts = np.full((max_pairs_skin, 3), 0.0) 319 pbc_shifts[:npairs_skin] = pbc_shifts_skin[idx] 320 321 ## symmetrize 322 edge_src, edge_dst = np.concatenate((edge_src, edge_dst)), np.concatenate( 323 (edge_dst, edge_src) 324 ) 325 d12 = np.concatenate((d12, d12)) 326 if apply_pbc: 327 
pbc_shifts = np.concatenate((pbc_shifts, -pbc_shifts)) 328 329 graph = inputs.get(self.graph_key, {}) 330 graph_out = { 331 **graph, 332 "edge_src": edge_src, 333 "edge_dst": edge_dst, 334 "d12": d12, 335 "overflow": False, 336 "pbc_shifts": pbc_shifts, 337 } 338 if "nblist_skin" in state: 339 graph_out["edge_src_skin"] = edge_src_skin 340 graph_out["edge_dst_skin"] = edge_dst_skin 341 if apply_pbc: 342 graph_out["pbc_shifts_skin"] = pbc_shifts_skin 343 344 if self.k_space and apply_pbc: 345 if "k_points" not in graph: 346 ks, _, _, bewald = get_reciprocal_space_parameters( 347 reciprocal_cells, self.cutoff, self.kmax, self.kthr 348 ) 349 graph_out["k_points"] = ks 350 graph_out["b_ewald"] = bewald 351 352 output = {**inputs, self.graph_key: graph_out} 353 354 if return_state_update: 355 return FrozenDict(new_state), output, state_up 356 return FrozenDict(new_state), output 357 358 def check_reallocate(self, state, inputs, parent_overflow=False): 359 """check for overflow and reallocate nblist if necessary""" 360 overflow = parent_overflow or inputs[self.graph_key].get("overflow", False) 361 if not overflow: 362 return state, {}, inputs, False 363 364 add_margin = inputs[self.graph_key].get("overflow", False) 365 state, inputs, state_up = self( 366 state, inputs, return_state_update=True, add_margin=add_margin 367 ) 368 return state, state_up, inputs, True 369 370 @partial(jax.jit, static_argnums=(0, 1)) 371 def process(self, state, inputs): 372 """build a nblist on accelerator with jax and precomputed shapes""" 373 if self.graph_key in inputs: 374 graph = inputs[self.graph_key] 375 if "keep_graph" in graph: 376 return inputs 377 coords = inputs["coordinates"] 378 natoms = inputs["natoms"] 379 batch_index = inputs["batch_index"] 380 381 if natoms.shape[0] == 1: 382 max_nat = coords.shape[0] 383 else: 384 max_nat = state.get( 385 "max_nat", int(round(coords.shape[0] / natoms.shape[0])) 386 ) 387 cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0) 388 389 ### compute indices of all pairs 390 p1, p2 = np.triu_indices(max_nat, 1) 391 p1, p2 = p1.astype(np.int32), p2.astype(np.int32) 392 pbc_shifts = None 393 if natoms.shape[0] > 1: 394 ## batching => mask irrelevant pairs 395 mask_p12 = ( 396 (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None]) 397 ).flatten() 398 shift = jnp.concatenate( 399 (jnp.array([0], dtype=jnp.int32), jnp.cumsum(natoms[:-1])) 400 ) 401 p1 = jnp.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1) 402 p2 = jnp.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1) 403 404 ## compute vectors 405 overflow_repeats = jnp.asarray(False, dtype=bool) 406 if "cells" not in inputs: 407 vec = coords[p2] - coords[p1] 408 else: 409 cells = inputs["cells"] 410 reciprocal_cells = inputs["reciprocal_cells"] 411 minimage = state.get("minimum_image", True) 412 413 def compute_pbc(vec, reciprocal_cell, cell, mode="round"): 414 vecpbc = jnp.dot(vec, reciprocal_cell) 415 if mode == "round": 416 pbc_shifts = -jnp.round(vecpbc) 417 elif mode == "floor": 418 pbc_shifts = -jnp.floor(vecpbc) 419 else: 420 raise NotImplementedError(f"Unknown mode {mode} for compute_pbc.") 421 return vec + jnp.dot(pbc_shifts, cell), pbc_shifts 422 423 if minimage: 424 ## minimum image convention 425 vec = coords[p2] - coords[p1] 426 427 if cells.shape[0] == 1: 428 vec, pbc_shifts = compute_pbc(vec, reciprocal_cells[0], cells[0]) 429 else: 430 batch_index_vec = batch_index[p1] 431 vec, pbc_shifts = jax.vmap(compute_pbc)( 432 vec, reciprocal_cells[batch_index_vec], 
cells[batch_index_vec] 433 ) 434 else: 435 ### general PBC only for single cell yet 436 # if cells.shape[0] > 1: 437 # raise NotImplementedError( 438 # "General PBC not implemented for batches on accelerator." 439 # ) 440 # cell = cells[0] 441 # reciprocal_cell = reciprocal_cells[0] 442 443 ## put all atoms in central box 444 if cells.shape[0] == 1: 445 coords_pbc, at_shifts = compute_pbc( 446 coords, reciprocal_cells[0], cells[0], mode="floor" 447 ) 448 else: 449 coords_pbc, at_shifts = jax.vmap( 450 partial(compute_pbc, mode="floor") 451 )(coords, reciprocal_cells[batch_index], cells[batch_index]) 452 vec = coords_pbc[p2] - coords_pbc[p1] 453 num_repeats = state.get("num_repeats_pbc", (0, 0, 0)) 454 # if num_repeats is None: 455 # raise ValueError( 456 # "num_repeats_pbc should be provided for general PBC on accelerator. Call the numpy routine (self.__call__) first." 457 # ) 458 # check if num_repeats is larger than previous 459 inv_distances = jnp.linalg.norm(reciprocal_cells, axis=1) 460 cdinv = cutoff_skin * inv_distances 461 num_repeats_all = jnp.ceil(cdinv).astype(jnp.int32) 462 if "true_sys" in inputs: 463 num_repeats_all = jnp.where(inputs["true_sys"][:,None], num_repeats_all, 0) 464 num_repeats_new = jnp.max(num_repeats_all, axis=0) 465 overflow_repeats = jnp.any(num_repeats_new > jnp.asarray(num_repeats)) 466 467 cell_shift_pbc = jnp.asarray( 468 np.array( 469 np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]), 470 dtype=cells.dtype, 471 ).T.reshape(-1, 3) 472 ) 473 474 if cells.shape[0] == 1: 475 vec = (vec[:,None,:] + jnp.dot(cell_shift_pbc, cells[0])[None, :, :]).reshape(-1, 3) 476 pbc_shifts = jnp.broadcast_to( 477 cell_shift_pbc[None, :, :], 478 (p1.shape[0], cell_shift_pbc.shape[0], 3), 479 ).reshape(-1, 3) 480 p1 = jnp.broadcast_to( 481 p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0]) 482 ).flatten() 483 p2 = jnp.broadcast_to( 484 p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0]) 485 ).flatten() 486 if natoms.shape[0] > 1: 487 mask_p12 = jnp.broadcast_to( 488 mask_p12[:, None], (mask_p12.shape[0], cell_shift_pbc.shape[0]) 489 ).flatten() 490 else: 491 dvec = jnp.einsum("bj,sji->sbi", cell_shift_pbc, cells).reshape(-1, 3) 492 493 ## get pbc shifts specific to each box 494 cell_shift_pbc = jnp.broadcast_to( 495 cell_shift_pbc[None, :, :], 496 (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3), 497 ) 498 mask = jnp.all( 499 jnp.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1 500 ).flatten() 501 max_shifts = state.get("nshifts_pbc", 1) 502 503 cell_shift_pbc = cell_shift_pbc.reshape(-1,3) 504 shiftx,shifty,shiftz = cell_shift_pbc[:,0],cell_shift_pbc[:,1],cell_shift_pbc[:,2] 505 dvecx,dvecy,dvecz = dvec[:,0],dvec[:,1],dvec[:,2] 506 (dvecx, dvecy,dvecz,shiftx,shifty,shiftz), scatter_idx, nshifts = mask_filter_1d( 507 mask, 508 max_shifts, 509 (dvecx, 0.), 510 (dvecy, 0.), 511 (dvecz, 0.), 512 (shiftx, 0), 513 (shifty, 0), 514 (shiftz, 0), 515 ) 516 dvec = jnp.stack((dvecx,dvecy,dvecz),axis=-1) 517 cell_shift_pbc = jnp.stack((shiftx,shifty,shiftz),axis=-1) 518 overflow_repeats = overflow_repeats | (nshifts > max_shifts) 519 520 ## get batch shift in the dvec_filter array 521 nrep = jnp.prod(2 * num_repeats_all + 1, axis=1) 522 bshift = jnp.concatenate((jnp.array([0],dtype=jnp.int32), jnp.cumsum(nrep)[:-1])) 523 524 ## repeat vectors 525 nvec_max = state.get("nvec_pbc", 1) 526 batch_index_vec = batch_index[p1] 527 nrep_vec = jnp.where(mask_p12,nrep[batch_index_vec],0) 528 nvec = nrep_vec.sum() 529 overflow_repeats = overflow_repeats | (nvec > nvec_max) 
530 vec = jnp.repeat(vec,nrep_vec,axis=0,total_repeat_length=nvec_max) 531 # jax.debug.print("{nvec} {nvec_max} {nshifts} {max_shifts}",nvec=nvec,nvec_max=jnp.asarray(nvec_max),nshifts=nshifts,max_shifts=jnp.asarray(max_shifts)) 532 533 ## get shift index 534 dshift = jnp.concatenate( 535 (jnp.array([0],dtype=jnp.int32), jnp.cumsum(nrep_vec)[:-1]) 536 ) 537 if nrep_vec.size == 0: 538 dshift = jnp.array([],dtype=jnp.int32) 539 dshift = jnp.repeat(dshift,nrep_vec, total_repeat_length=nvec_max) 540 bshift = jnp.repeat(bshift[batch_index_vec],nrep_vec, total_repeat_length=nvec_max) 541 icellshift = jnp.arange(dshift.shape[0]) - dshift + bshift 542 vec = vec + dvec[icellshift] 543 pbc_shifts = cell_shift_pbc[icellshift] 544 p1 = jnp.repeat(p1,nrep_vec, total_repeat_length=nvec_max) 545 p2 = jnp.repeat(p2,nrep_vec, total_repeat_length=nvec_max) 546 if natoms.shape[0] > 1: 547 mask_p12 = jnp.repeat(mask_p12,nrep_vec, total_repeat_length=nvec_max) 548 549 550 ## compute distances 551 d12 = (vec**2).sum(axis=-1) 552 if natoms.shape[0] > 1: 553 d12 = jnp.where(mask_p12, d12, cutoff_skin**2) 554 555 ## filter pairs 556 max_pairs = state.get("npairs", 1) 557 mask = d12 < cutoff_skin**2 558 (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d( 559 mask, 560 max_pairs, 561 (jnp.asarray(p1, dtype=jnp.int32), coords.shape[0]), 562 (jnp.asarray(p2, dtype=jnp.int32), coords.shape[0]), 563 (d12, cutoff_skin**2), 564 ) 565 if "cells" in inputs: 566 pbc_shifts = ( 567 jnp.full((max_pairs, 3), 0.0, dtype=pbc_shifts.dtype) 568 .at[scatter_idx] 569 .set(pbc_shifts, mode="drop") 570 ) 571 if not minimage: 572 pbc_shifts = ( 573 pbc_shifts 574 + at_shifts.at[edge_dst].get(fill_value=0.0) 575 - at_shifts.at[edge_src].get(fill_value=0.0) 576 ) 577 578 ## check for overflow 579 if natoms.shape[0] == 1: 580 true_max_nat = coords.shape[0] 581 else: 582 true_max_nat = jnp.max(natoms) 583 overflow_count = npairs > max_pairs 584 overflow_at = true_max_nat > max_nat 585 overflow = overflow_count | overflow_at | overflow_repeats 586 587 if "nblist_skin" in state: 588 # edge_mask_skin = edge_mask 589 edge_src_skin = edge_src 590 edge_dst_skin = edge_dst 591 if "cells" in inputs: 592 pbc_shifts_skin = pbc_shifts 593 max_pairs_skin = state.get("npairs_skin", 1) 594 mask = d12 < self.cutoff**2 595 (edge_src, edge_dst, d12), scatter_idx, npairs_skin = mask_filter_1d( 596 mask, 597 max_pairs_skin, 598 (edge_src, coords.shape[0]), 599 (edge_dst, coords.shape[0]), 600 (d12, self.cutoff**2), 601 ) 602 if "cells" in inputs: 603 pbc_shifts = ( 604 jnp.full((max_pairs_skin, 3), 0.0, dtype=pbc_shifts.dtype) 605 .at[scatter_idx] 606 .set(pbc_shifts, mode="drop") 607 ) 608 overflow = overflow | (npairs_skin > max_pairs_skin) 609 610 ## symmetrize 611 edge_src, edge_dst = jnp.concatenate((edge_src, edge_dst)), jnp.concatenate( 612 (edge_dst, edge_src) 613 ) 614 d12 = jnp.concatenate((d12, d12)) 615 if "cells" in inputs: 616 pbc_shifts = jnp.concatenate((pbc_shifts, -pbc_shifts)) 617 618 graph = inputs[self.graph_key] if self.graph_key in inputs else {} 619 graph_out = { 620 **graph, 621 "edge_src": edge_src, 622 "edge_dst": edge_dst, 623 "d12": d12, 624 "overflow": overflow, 625 "pbc_shifts": pbc_shifts, 626 } 627 if "nblist_skin" in state: 628 graph_out["edge_src_skin"] = edge_src_skin 629 graph_out["edge_dst_skin"] = edge_dst_skin 630 if "cells" in inputs: 631 graph_out["pbc_shifts_skin"] = pbc_shifts_skin 632 633 if self.k_space and "cells" in inputs: 634 if "k_points" not in graph: 635 raise NotImplementedError( 636 "k_space 
generation not implemented on accelerator. Call the numpy routine (self.__call__) first." 637 ) 638 return {**inputs, self.graph_key: graph_out} 639 640 @partial(jax.jit, static_argnums=(0,)) 641 def update_skin(self, inputs): 642 """update the nblist without recomputing the full nblist""" 643 graph = inputs[self.graph_key] 644 645 edge_src_skin = graph["edge_src_skin"] 646 edge_dst_skin = graph["edge_dst_skin"] 647 coords = inputs["coordinates"] 648 vec = coords.at[edge_dst_skin].get( 649 mode="fill", fill_value=self.cutoff 650 ) - coords.at[edge_src_skin].get(mode="fill", fill_value=0.0) 651 652 if "cells" in inputs: 653 pbc_shifts_skin = graph["pbc_shifts_skin"] 654 cells = inputs["cells"] 655 if cells.shape[0] == 1: 656 vec = vec + jnp.dot(pbc_shifts_skin, cells[0]) 657 else: 658 batch_index_vec = inputs["batch_index"][edge_src_skin] 659 vec = vec + jax.vmap(jnp.dot)(pbc_shifts_skin, cells[batch_index_vec]) 660 661 nat = coords.shape[0] 662 d12 = jnp.sum(vec**2, axis=-1) 663 mask = d12 < self.cutoff**2 664 max_pairs = graph["edge_src"].shape[0] // 2 665 (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d( 666 mask, 667 max_pairs, 668 (edge_src_skin, nat), 669 (edge_dst_skin, nat), 670 (d12, self.cutoff**2), 671 ) 672 if "cells" in inputs: 673 pbc_shifts = ( 674 jnp.full((max_pairs, 3), 0.0, dtype=pbc_shifts_skin.dtype) 675 .at[scatter_idx] 676 .set(pbc_shifts_skin) 677 ) 678 679 overflow = graph.get("overflow", False) | (npairs > max_pairs) 680 graph_out = { 681 **graph, 682 "edge_src": jnp.concatenate((edge_src, edge_dst)), 683 "edge_dst": jnp.concatenate((edge_dst, edge_src)), 684 "d12": jnp.concatenate((d12, d12)), 685 "overflow": overflow, 686 } 687 if "cells" in inputs: 688 graph_out["pbc_shifts"] = jnp.concatenate((pbc_shifts, -pbc_shifts)) 689 690 if self.k_space and "cells" in inputs: 691 if "k_points" not in graph: 692 raise NotImplementedError( 693 "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first." 694 ) 695 696 return {**inputs, self.graph_key: graph_out}
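As an aside, the minimum-image branch above amounts to rounding the fractional coordinates of each pair vector. A tiny standalone illustration (numbers invented; an orthorhombic cell is used only for readability, the code above handles triclinic cells the same way):

import numpy as np

cell = np.diag([10.0, 10.0, 10.0]).astype(np.float32)  # rows = lattice vectors
reciprocal_cell = np.linalg.inv(cell)
vec = np.array([[9.0, 0.0, 0.0]], dtype=np.float32)    # raw pair vector
pbc_shifts = -np.round(vec @ reciprocal_cell)          # shift to nearest image
vec_mic = vec + pbc_shifts @ cell                      # -> [[-1., 0., 0.]]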
Generate a graph from a set of coordinates

FPID: GRAPH

For now, we generate all pairs of atoms and filter based on cutoff. If a `nblist_skin` is present in the state, we generate a second graph with a larger cutoff that includes all pairs within the cutoff+skin. This graph is then reused by the `update_skin` method to update the original graph without recomputing the full nblist.

switch_params: Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`.
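A hedged usage sketch of the two-stage protocol this docstring implies (values arbitrary; `nblist_skin` is read from the state dict exactly as in the source):

import numpy as np
from fennol.models.preprocessing import GraphGenerator

gen = GraphGenerator(cutoff=5.0)
state = {**gen.init(), "nblist_skin": 1.0}  # also build the cutoff+skin graph

inputs = {
    "species": np.array([8, 1, 1], dtype=np.int32),
    "coordinates": np.random.randn(3, 3).astype(np.float32),
    "natoms": np.array([3], dtype=np.int32),
    "batch_index": np.zeros(3, dtype=np.int32),
}

state, outputs = gen(state, inputs)    # numpy pass: allocate + store shapes
outputs = gen.process(state, outputs)  # jitted rebuild with static shapes
outputs = gen.update_skin(outputs)     # cheap re-filter after small moves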
699class GraphProcessor(nn.Module): 700 """Process a pre-generated graph 701 702 The pre-generated graph should contain the following keys: 703 - edge_src: source indices of the edges 704 - edge_dst: destination indices of the edges 705 - pbcs_shifts: pbc shifts for the edges (only if `cells` are present in the inputs) 706 707 This module is automatically added to a FENNIX model when a GraphGenerator is used. 708 709 """ 710 711 cutoff: float 712 """Cutoff distance for the graph.""" 713 graph_key: str = "graph" 714 """Key of the graph in the outputs.""" 715 switch_params: dict = dataclasses.field(default_factory=dict) 716 """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`.""" 717 718 @nn.compact 719 def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]): 720 graph = inputs[self.graph_key] 721 coords = inputs["coordinates"] 722 edge_src, edge_dst = graph["edge_src"], graph["edge_dst"] 723 # edge_mask = edge_src < coords.shape[0] 724 vec = coords.at[edge_dst].get(mode="fill", fill_value=self.cutoff) - coords.at[ 725 edge_src 726 ].get(mode="fill", fill_value=0.0) 727 if "cells" in inputs: 728 cells = inputs["cells"] 729 if cells.shape[0] == 1: 730 vec = vec + jnp.dot(graph["pbc_shifts"], cells[0]) 731 else: 732 batch_index_vec = inputs["batch_index"][edge_src] 733 vec = vec + jax.vmap(jnp.dot)( 734 graph["pbc_shifts"], cells[batch_index_vec] 735 ) 736 737 distances = jnp.linalg.norm(vec, axis=-1) 738 edge_mask = distances < self.cutoff 739 740 switch = SwitchFunction( 741 **{**self.switch_params, "cutoff": self.cutoff, "graph_key": None} 742 )((distances, edge_mask)) 743 744 graph_out = { 745 **graph, 746 "vec": vec, 747 "distances": distances, 748 "switch": switch, 749 "edge_mask": edge_mask, 750 } 751 752 if "alch_group" in inputs: 753 alch_group = inputs["alch_group"] 754 lambda_e = inputs["alch_elambda"] 755 lambda_e = 0.5*(1.-jnp.cos(jnp.pi*lambda_e)) 756 mask = alch_group[edge_src] == alch_group[edge_dst] 757 graph_out["switch_raw"] = switch 758 graph_out["switch"] = jnp.where( 759 mask, 760 switch, 761 lambda_e * switch , 762 ) 763 764 765 return {**inputs, self.graph_key: graph_out}
Process a pre-generated graph

The pre-generated graph should contain the following keys:
- edge_src: source indices of the edges
- edge_dst: destination indices of the edges
- pbc_shifts: pbc shifts for the edges (only if `cells` are present in the inputs)

This module is automatically added to a FENNIX model when a GraphGenerator is used.

switch_params: Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`.
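The out-of-bounds gathers above are what keep padded edges inert: padded entries point at index `nat`, so `mode="fill"` turns them into vectors at least `cutoff` long, which the mask then rejects. A standalone illustration with toy numbers:

import jax.numpy as jnp

cutoff = 5.0
coords = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])  # nat = 2
edge_src = jnp.array([0, 1, 2])  # index 2 == nat marks a padded edge
edge_dst = jnp.array([1, 0, 2])

vec = coords.at[edge_dst].get(mode="fill", fill_value=cutoff) - coords.at[
    edge_src
].get(mode="fill", fill_value=0.0)
distances = jnp.linalg.norm(vec, axis=-1)  # [1., 1., cutoff*sqrt(3)]
edge_mask = distances < cutoff             # [True, True, False]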
768@dataclasses.dataclass(frozen=True) 769class GraphFilter: 770 """Filter a graph based on a cutoff distance 771 772 FPID: GRAPH_FILTER 773 """ 774 775 cutoff: float 776 """Cutoff distance for the filtering.""" 777 parent_graph: str 778 """Key of the parent graph in the inputs.""" 779 graph_key: str 780 """Key of the filtered graph in the outputs.""" 781 remove_hydrogens: int = False 782 """Remove edges where the source is a hydrogen atom.""" 783 switch_params: FrozenDict = dataclasses.field(default_factory=FrozenDict) 784 """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`.""" 785 k_space: bool = False 786 """Generate k-space information for the graph.""" 787 kmax: int = 30 788 """Maximum number of k-points to consider.""" 789 kthr: float = 1e-6 790 """Threshold for k-point filtering.""" 791 mult_size: float = 1.05 792 """Multiplicative factor for resizing the nblist.""" 793 794 FPID: ClassVar[str] = "GRAPH_FILTER" 795 796 def init(self): 797 return FrozenDict( 798 { 799 "npairs": 1, 800 "nblist_mult_size": self.mult_size, 801 } 802 ) 803 804 def get_processor(self) -> Tuple[nn.Module, Dict]: 805 return GraphFilterProcessor, { 806 "cutoff": self.cutoff, 807 "graph_key": self.graph_key, 808 "parent_graph": self.parent_graph, 809 "name": f"{self.graph_key}_Filter_{self.parent_graph}", 810 "switch_params": self.switch_params, 811 } 812 813 def get_graph_properties(self): 814 return { 815 self.graph_key: { 816 "cutoff": self.cutoff, 817 "directed": True, 818 "parent_graph": self.parent_graph, 819 } 820 } 821 822 def __call__(self, state, inputs, return_state_update=False, add_margin=False): 823 """filter a nblist on cpu with numpy and dynamic shapes + store max shapes""" 824 graph_in = inputs[self.parent_graph] 825 nat = inputs["species"].shape[0] 826 827 new_state = {**state} 828 state_up = {} 829 mult_size = state.get("nblist_mult_size", self.mult_size) 830 assert mult_size >= 1., "nblist_mult_size should be >= 1." 
831 832 edge_src = np.array(graph_in["edge_src"], dtype=np.int32) 833 d12 = np.array(graph_in["d12"], dtype=np.float32) 834 if self.remove_hydrogens: 835 species = inputs["species"] 836 src_idx = (edge_src < nat).nonzero()[0] 837 mask = np.zeros(edge_src.shape[0], dtype=bool) 838 mask[src_idx] = (species > 1)[edge_src[src_idx]] 839 d12 = np.where(mask, d12, self.cutoff**2) 840 mask = d12 < self.cutoff**2 841 842 max_pairs = state.get("npairs", 1) 843 idx = np.nonzero(mask)[0] 844 npairs = idx.shape[0] 845 if npairs > max_pairs or add_margin: 846 prev_max_pairs = max_pairs 847 max_pairs = int(mult_size * max(npairs, max_pairs)) + 1 848 state_up["npairs"] = (max_pairs, prev_max_pairs) 849 new_state["npairs"] = max_pairs 850 851 filter_indices = np.full(max_pairs, edge_src.shape[0], dtype=np.int32) 852 edge_src = np.full(max_pairs, nat, dtype=np.int32) 853 edge_dst = np.full(max_pairs, nat, dtype=np.int32) 854 d12_ = np.full(max_pairs, self.cutoff**2) 855 filter_indices[:npairs] = idx 856 edge_src[:npairs] = graph_in["edge_src"][idx] 857 edge_dst[:npairs] = graph_in["edge_dst"][idx] 858 d12_[:npairs] = d12[idx] 859 d12 = d12_ 860 861 graph = inputs[self.graph_key] if self.graph_key in inputs else {} 862 graph_out = { 863 **graph, 864 "edge_src": edge_src, 865 "edge_dst": edge_dst, 866 "filter_indices": filter_indices, 867 "d12": d12, 868 "overflow": False, 869 } 870 871 if self.k_space and "cells" in inputs: 872 if "k_points" not in graph: 873 ks, _, _, bewald = get_reciprocal_space_parameters( 874 inputs["reciprocal_cells"], self.cutoff, self.kmax, self.kthr 875 ) 876 graph_out["k_points"] = ks 877 graph_out["b_ewald"] = bewald 878 879 output = {**inputs, self.graph_key: graph_out} 880 if return_state_update: 881 return FrozenDict(new_state), output, state_up 882 return FrozenDict(new_state), output 883 884 def check_reallocate(self, state, inputs, parent_overflow=False): 885 """check for overflow and reallocate nblist if necessary""" 886 overflow = parent_overflow or inputs[self.graph_key].get("overflow", False) 887 if not overflow: 888 return state, {}, inputs, False 889 890 add_margin = inputs[self.graph_key].get("overflow", False) 891 state, inputs, state_up = self( 892 state, inputs, return_state_update=True, add_margin=add_margin 893 ) 894 return state, state_up, inputs, True 895 896 @partial(jax.jit, static_argnums=(0, 1)) 897 def process(self, state, inputs): 898 """filter a nblist on accelerator with jax and precomputed shapes""" 899 graph_in = inputs[self.parent_graph] 900 if state is None: 901 # skin update mode 902 graph = inputs[self.graph_key] 903 max_pairs = graph["edge_src"].shape[0] 904 else: 905 max_pairs = state.get("npairs", 1) 906 907 max_pairs_in = graph_in["edge_src"].shape[0] 908 nat = inputs["species"].shape[0] 909 910 edge_src = graph_in["edge_src"] 911 d12 = graph_in["d12"] 912 if self.remove_hydrogens: 913 species = inputs["species"] 914 mask = (species > 1)[edge_src] 915 d12 = jnp.where(mask, d12, self.cutoff**2) 916 mask = d12 < self.cutoff**2 917 918 (edge_src, edge_dst, d12, filter_indices), _, npairs = mask_filter_1d( 919 mask, 920 max_pairs, 921 (edge_src, nat), 922 (graph_in["edge_dst"], nat), 923 (d12, self.cutoff**2), 924 (jnp.arange(max_pairs_in, dtype=jnp.int32), max_pairs_in), 925 ) 926 927 graph = inputs[self.graph_key] if self.graph_key in inputs else {} 928 overflow = graph.get("overflow", False) | (npairs > max_pairs) 929 graph_out = { 930 **graph, 931 "edge_src": edge_src, 932 "edge_dst": edge_dst, 933 "filter_indices": filter_indices, 934 "d12": 
d12, 935 "overflow": overflow, 936 } 937 938 if self.k_space and "cells" in inputs: 939 if "k_points" not in graph: 940 raise NotImplementedError( 941 "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first." 942 ) 943 944 return {**inputs, self.graph_key: graph_out} 945 946 @partial(jax.jit, static_argnums=(0,)) 947 def update_skin(self, inputs): 948 return self.process(None, inputs)
Filter a graph based on a cutoff distance

FPID: GRAPH_FILTER

switch_params: Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`.
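For orientation, a hedged sketch of how a filtered subgraph is typically declared next to its parent (keys and cutoffs are arbitrary):

from fennol.models.preprocessing import GraphFilter, GraphGenerator

layers = (
    GraphGenerator(cutoff=5.3, graph_key="graph"),
    # reuse the 5.3 pair list: keep only pairs closer than 2.2 and drop
    # edges whose source atom is a hydrogen
    GraphFilter(
        cutoff=2.2,
        parent_graph="graph",
        graph_key="graph_bonds",
        remove_hydrogens=True,
    ),
)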
951class GraphFilterProcessor(nn.Module): 952 """Filter processing for a pre-generated graph 953 954 This module is automatically added to a FENNIX model when a GraphFilter is used. 955 """ 956 957 cutoff: float 958 """Cutoff distance for the filtering.""" 959 graph_key: str 960 """Key of the filtered graph in the inputs.""" 961 parent_graph: str 962 """Key of the parent graph in the inputs.""" 963 switch_params: dict = dataclasses.field(default_factory=dict) 964 """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`.""" 965 966 @nn.compact 967 def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]): 968 graph_in = inputs[self.parent_graph] 969 graph = inputs[self.graph_key] 970 971 if graph_in["vec"].shape[0] == 0: 972 vec = graph_in["vec"] 973 distances = graph_in["distances"] 974 filter_indices = jnp.asarray([], dtype=jnp.int32) 975 else: 976 filter_indices = graph["filter_indices"] 977 vec = ( 978 graph_in["vec"] 979 .at[filter_indices] 980 .get(mode="fill", fill_value=self.cutoff) 981 ) 982 distances = ( 983 graph_in["distances"] 984 .at[filter_indices] 985 .get(mode="fill", fill_value=self.cutoff) 986 ) 987 988 edge_mask = distances < self.cutoff 989 switch = SwitchFunction( 990 **{**self.switch_params, "cutoff": self.cutoff, "graph_key": None} 991 )((distances, edge_mask)) 992 993 graph_out = { 994 **graph, 995 "vec": vec, 996 "distances": distances, 997 "switch": switch, 998 "filter_indices": filter_indices, 999 "edge_mask": edge_mask, 1000 } 1001 1002 if "alch_group" in inputs: 1003 edge_src=graph["edge_src"] 1004 edge_dst=graph["edge_dst"] 1005 alch_group = inputs["alch_group"] 1006 lambda_e = inputs["alch_elambda"] 1007 lambda_e = 0.5*(1.-jnp.cos(jnp.pi*lambda_e)) 1008 mask = alch_group[edge_src] == alch_group[edge_dst] 1009 graph_out["switch_raw"] = switch 1010 graph_out["switch"] = jnp.where( 1011 mask, 1012 switch, 1013 lambda_e * switch , 1014 ) 1015 1016 return {**inputs, self.graph_key: graph_out}
Filter processing for a pre-generated graph

This module is automatically added to a FENNIX model when a GraphFilter is used.

switch_params: Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`.
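Both processors apply the same alchemical decoupling: `alch_elambda` is mapped through a cosine ramp and scales the switch of every edge that crosses an `alch_group` boundary. In isolation (toy arrays):

import jax.numpy as jnp

alch_elambda = 0.25
lambda_e = 0.5 * (1.0 - jnp.cos(jnp.pi * alch_elambda))  # smooth ramp 0 -> 1

alch_group = jnp.array([0, 0, 1])  # atoms 0,1 in group 0; atom 2 in group 1
edge_src = jnp.array([0, 0, 1])
edge_dst = jnp.array([1, 2, 2])
switch = jnp.ones(3)

same_group = alch_group[edge_src] == alch_group[edge_dst]
switch = jnp.where(same_group, switch, lambda_e * switch)
# intra-group edges keep their switch; cross-group edges are scaled down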
@dataclasses.dataclass(frozen=True)
class GraphAngularExtension:
    """Add an angle list to a graph

    FPID: GRAPH_ANGULAR_EXTENSION
    """

    mult_size: float = 1.05
    """Multiplicative factor for resizing the nblist."""
    add_neigh: int = 5
    """Additional neighbors to add to the nblist when resizing."""
    graph_key: str = "graph"
    """Key of the graph in the inputs."""

    FPID: ClassVar[str] = "GRAPH_ANGULAR_EXTENSION"

    def init(self):
        return FrozenDict(
            {
                "nangles": 0,
                "nblist_mult_size": self.mult_size,
                "max_neigh": self.add_neigh,
                "add_neigh": self.add_neigh,
            }
        )

    def get_processor(self) -> Tuple[nn.Module, Dict]:
        return GraphAngleProcessor, {
            "graph_key": self.graph_key,
            "name": f"{self.graph_key}_AngleProcessor",
        }

    def get_graph_properties(self):
        return {
            self.graph_key: {
                "has_angles": True,
            }
        }

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """build angle nblist on cpu with numpy and dynamic shapes + store max shapes"""
        graph = inputs[self.graph_key]
        edge_src = np.array(graph["edge_src"], dtype=np.int32)

        new_state = {**state}
        state_up = {}
        mult_size = state.get("nblist_mult_size", self.mult_size)
        assert mult_size >= 1.0, "nblist_mult_size should be >= 1."

        ### count the number of neighbors of each atom
        nat = inputs["species"].shape[0]
        count = np.zeros(nat + 1, dtype=np.int32)
        np.add.at(count, edge_src, 1)
        max_count = int(np.max(count[:-1]))

        ### get sizes
        max_neigh = state.get("max_neigh", self.add_neigh)
        nedge = edge_src.shape[0]
        if max_count > max_neigh or add_margin:
            prev_max_neigh = max_neigh
            max_neigh = max(max_count, max_neigh) + state.get(
                "add_neigh", self.add_neigh
            )
            state_up["max_neigh"] = (max_neigh, prev_max_neigh)
            new_state["max_neigh"] = max_neigh

        # placeholder array whose static shape carries max_neigh to the jitted code
        max_neigh_arr = np.empty(max_neigh, dtype=bool)

        ### sort edge_src
        idx_sort = np.argsort(edge_src)
        edge_src_sorted = edge_src[idx_sort]

        ### map the sparse nblist to a dense (nat, max_count) nblist
        if max_count * nat >= nedge:
            offset = np.tile(np.arange(max_count), nat)[:nedge]
        else:
            offset = np.zeros(nedge, dtype=np.int32)
            offset[: max_count * nat] = np.tile(np.arange(max_count), nat)

        mask = edge_src_sorted < nat
        indices = edge_src_sorted * max_count + offset
        indices = indices[mask]
        idx_sort = idx_sort[mask]
        edge_idx = np.full(nat * max_count, nedge, dtype=np.int32)
        edge_idx[indices] = idx_sort
        edge_idx = edge_idx.reshape(nat, max_count)

        ### find all triplets for each atom center
        local_src, local_dst = np.triu_indices(max_count, 1)
        angle_src = edge_idx[:, local_src].flatten()
        angle_dst = edge_idx[:, local_dst].flatten()

        ### mask for valid angles
        mask1 = angle_src < nedge
        mask2 = angle_dst < nedge
        angle_mask = mask1 & mask2

        max_angles = state.get("nangles", 0)
        idx = np.nonzero(angle_mask)[0]
        nangles = idx.shape[0]
        if nangles > max_angles or add_margin:
            max_angles_prev = max_angles
            max_angles = int(mult_size * max(nangles, max_angles)) + 1
            state_up["nangles"] = (max_angles, max_angles_prev)
            new_state["nangles"] = max_angles

        ## filter angles to a sparse representation
        angle_src_ = np.full(max_angles, nedge, dtype=np.int32)
        angle_dst_ = np.full(max_angles, nedge, dtype=np.int32)
        angle_src_[:nangles] = angle_src[idx]
        angle_dst_[:nangles] = angle_dst[idx]

        central_atom = np.full(max_angles, nat, dtype=np.int32)
        central_atom[:nangles] = edge_src[angle_src_[:nangles]]

        ## update graph
        output = {
            **inputs,
            self.graph_key: {
                **graph,
                "angle_src": angle_src_,
                "angle_dst": angle_dst_,
                "central_atom": central_atom,
                "angle_overflow": False,
                "max_neigh": max_neigh,
                "__max_neigh_array": max_neigh_arr,
            },
        }

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary"""
        overflow = parent_overflow or inputs[self.graph_key]["angle_overflow"]
        if not overflow:
            return state, {}, inputs, False

        add_margin = inputs[self.graph_key]["angle_overflow"]
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """build angle nblist on accelerator with jax and precomputed shapes"""
        graph = inputs[self.graph_key]
        edge_src = graph["edge_src"]

        ### count the number of neighbors of each atom
        nat = inputs["species"].shape[0]
        count = jnp.zeros(nat, dtype=jnp.int32).at[edge_src].add(1, mode="drop")
        max_count = jnp.max(count)

        ### get sizes
        if state is None:
            max_neigh_arr = graph["__max_neigh_array"]
            max_neigh = max_neigh_arr.shape[0]
            prev_nangles = graph["angle_src"].shape[0]
        else:
            max_neigh = state.get("max_neigh", self.add_neigh)
            max_neigh_arr = jnp.empty(max_neigh, dtype=bool)
            prev_nangles = state.get("nangles", 0)

        nedge = edge_src.shape[0]

        ### sort edge_src
        idx_sort = jnp.argsort(edge_src).astype(jnp.int32)
        edge_src_sorted = edge_src[idx_sort]

        ### map the sparse nblist to a dense (nat, max_neigh) nblist
        if max_neigh * nat < nedge:
            raise ValueError("Found max_neigh*nat < nedge. This should not happen.")
        offset = jnp.asarray(
            np.tile(np.arange(max_neigh), nat)[:nedge], dtype=jnp.int32
        )
        indices = edge_src_sorted * max_neigh + offset
        edge_idx = (
            jnp.full(nat * max_neigh, nedge, dtype=jnp.int32)
            .at[indices]
            .set(idx_sort, mode="drop")
            .reshape(nat, max_neigh)
        )

        ### find all triplets for each atom center
        local_src, local_dst = np.triu_indices(max_neigh, 1)
        angle_src = edge_idx[:, local_src].flatten()
        angle_dst = edge_idx[:, local_dst].flatten()

        ### mask for valid angles
        mask1 = angle_src < nedge
        mask2 = angle_dst < nedge
        angle_mask = mask1 & mask2

        ## filter angles to a sparse representation
        (angle_src, angle_dst), _, nangles = mask_filter_1d(
            angle_mask,
            prev_nangles,
            (angle_src, nedge),
            (angle_dst, nedge),
        )
        ## find central atom
        central_atom = edge_src[angle_src]

        ## check for overflow
        angle_overflow = nangles > prev_nangles
        neigh_overflow = max_count > max_neigh
        overflow = graph.get("angle_overflow", False) | angle_overflow | neigh_overflow

        ## update graph
        output = {
            **inputs,
            self.graph_key: {
                **graph,
                "angle_src": angle_src,
                "angle_dst": angle_dst,
                "central_atom": central_atom,
                "angle_overflow": overflow,
                "__max_neigh_array": max_neigh_arr,
            },
        }

        return output

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        return self.process(None, inputs)
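The dense-mapping trick above is easier to see on a toy example. Below is a minimal numpy sketch (the hand-built edge_src is illustrative only, not part of the module) of how a sorted sparse neighbor list is scattered into a padded (nat, max_count) table, whose row-wise pairs then enumerate candidate angles. Because the offset pattern is a global tile, an atom's edges land in distinct but not necessarily left-aligned slots of its row.

import numpy as np

nat, nedge = 3, 4
edge_src = np.array([0, 1, 0, 2], dtype=np.int32)   # source atom of each edge

count = np.bincount(edge_src, minlength=nat)
max_count = int(count.max())                         # here 2

idx_sort = np.argsort(edge_src)                      # [0, 2, 1, 3]
offset = np.tile(np.arange(max_count), nat)[:nedge]  # [0, 1, 0, 1]
indices = edge_src[idx_sort] * max_count + offset    # [0, 1, 2, 5]

# scatter sorted edge indices into a padded table; `nedge` marks empty slots
edge_idx = np.full(nat * max_count, nedge, dtype=np.int32)
edge_idx[indices] = idx_sort
edge_idx = edge_idx.reshape(nat, max_count)          # [[0, 2], [1, 4], [4, 3]]

# every unordered pair of valid entries in a row is an angle around that atom
local_src, local_dst = np.triu_indices(max_count, 1)
angle_src = edge_idx[:, local_src].flatten()         # [0, 1, 4]
angle_dst = edge_idx[:, local_dst].flatten()         # [2, 4, 3]
valid = (angle_src < nedge) & (angle_dst < nedge)    # only atom 0 yields an angle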
class GraphAngleProcessor(nn.Module):
    """Process a pre-generated graph to compute angles

    This module is automatically added to a FENNIX model when a GraphAngularExtension is used.
    """

    graph_key: str
    """Key of the graph in the inputs."""

    @nn.compact
    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
        graph = inputs[self.graph_key]
        distances = graph["distances"]
        vec = graph["vec"]
        angle_src = graph["angle_src"]
        angle_dst = graph["angle_dst"]

        # unit vectors along each edge; the clip guards against division by zero
        dir = vec / jnp.clip(distances[:, None], a_min=1.0e-5)
        cos_angles = (
            dir.at[angle_src].get(mode="fill", fill_value=0.5)
            * dir.at[angle_dst].get(mode="fill", fill_value=0.5)
        ).sum(axis=-1)

        # the 0.95 factor keeps the argument strictly inside (-1, 1) so that
        # arccos stays differentiable at the boundaries
        angles = jnp.arccos(0.95 * cos_angles)

        return {
            **inputs,
            self.graph_key: {
                **graph,
                "angles": angles,
            },
        }
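As a quick sanity check of the formula above, here is a self-contained sketch (toy vectors, not module API) reproducing the cosine computation for two perpendicular edges sharing a central atom:

import jax.numpy as jnp

vec = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])  # two edges of a common center
distances = jnp.linalg.norm(vec, axis=-1)
dir = vec / jnp.clip(distances[:, None], 1.0e-5)     # unit vectors
cos_angle = (dir[0] * dir[1]).sum()                  # 0.0 for perpendicular edges
angle = jnp.arccos(0.95 * cos_angle)                 # ~pi/2, slightly compressed
                                                     # toward the center by 0.95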
@dataclasses.dataclass(frozen=True)
class SpeciesIndexer:
    """Build an index that splits atomic arrays by species.

    FPID: SPECIES_INDEXER

    If `species_order` is specified, the output will be a dense array with size (len(species_order), max_size) that can directly index atomic arrays.
    If `species_order` is None, the output will be a dictionary with species as keys and an index to filter atomic arrays for that species as values.
    """

    output_key: str = "species_index"
    """Key for the output dictionary."""
    species_order: Optional[str] = None
    """Comma separated list of species in the order they should be indexed."""
    add_atoms: int = 0
    """Additional atoms to add to the sizes."""
    add_atoms_margin: int = 10
    """Additional atoms to add to the sizes when adding margin."""

    FPID: ClassVar[str] = "SPECIES_INDEXER"

    def init(self):
        return FrozenDict(
            {
                "sizes": {},
            }
        )

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        species = np.array(inputs["species"], dtype=np.int32)
        nat = species.shape[0]
        set_species, counts = np.unique(species, return_counts=True)

        new_state = {**state}
        state_up = {}

        sizes = state.get("sizes", FrozenDict({}))
        new_sizes = {**sizes}
        up_sizes = False
        counts_dict = {}
        for s, c in zip(set_species, counts):
            if s <= 0:
                continue
            counts_dict[s] = c
            if c > sizes.get(s, 0):
                up_sizes = True
                add_atoms = state.get("add_atoms", self.add_atoms)
                if add_margin:
                    add_atoms += state.get("add_atoms_margin", self.add_atoms_margin)
                new_sizes[s] = c + add_atoms

        new_sizes = FrozenDict(new_sizes)
        if up_sizes:
            state_up["sizes"] = (new_sizes, sizes)
            new_state["sizes"] = new_sizes

        if self.species_order is not None:
            species_order = [el.strip() for el in self.species_order.split(",")]
            max_size_prev = state.get("max_size", 0)
            max_size = max(new_sizes.values())
            if max_size > max_size_prev:
                state_up["max_size"] = (max_size, max_size_prev)
                new_state["max_size"] = max_size
                max_size_prev = max_size

            species_index = np.full((len(species_order), max_size), nat, dtype=np.int32)
            for i, el in enumerate(species_order):
                s = PERIODIC_TABLE_REV_IDX[el]
                if s in counts_dict.keys():
                    species_index[i, : counts_dict[s]] = np.nonzero(species == s)[0]
        else:
            species_index = {
                PERIODIC_TABLE[s]: np.full(c, nat, dtype=np.int32)
                for s, c in new_sizes.items()
            }
            for s, c in zip(set_species, counts):
                if s <= 0:
                    continue
                species_index[PERIODIC_TABLE[s]][:c] = np.nonzero(species == s)[0]

        output = {
            **inputs,
            self.output_key: species_index,
            self.output_key + "_overflow": False,
        }

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary"""
        overflow = parent_overflow or inputs[self.output_key + "_overflow"]
        if not overflow:
            return state, {}, inputs, False

        add_margin = inputs[self.output_key + "_overflow"]
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        recompute_species_index = "recompute_species_index" in inputs.get("flags", {})
        if self.output_key in inputs and not recompute_species_index:
            return inputs

        if state is None:
            raise ValueError("Species Indexer state must be provided on accelerator.")

        species = inputs["species"]
        nat = species.shape[0]

        sizes = state["sizes"]

        if self.species_order is not None:
            species_order = [el.strip() for el in self.species_order.split(",")]
            max_size = state["max_size"]

            overflow = False
            species_index = jnp.full(
                (len(species_order), max_size), nat, dtype=jnp.int32
            )
            for i, el in enumerate(species_order):
                s = PERIODIC_TABLE_REV_IDX[el]
                if s in sizes.keys():
                    # check that the dense rows are still large enough
                    overflow = overflow | (jnp.sum(species == s) > max_size)
                    species_index = species_index.at[i, :].set(
                        jnp.nonzero(species == s, size=max_size, fill_value=nat)[0]
                    )
        else:
            species_index = {}
            overflow = False
            natcount = 0
            for s, c in sizes.items():
                mask = species == s
                new_size = jnp.sum(mask)
                natcount = natcount + new_size
                overflow = overflow | (new_size > c)  # check if sizes are correct
                species_index[PERIODIC_TABLE[s]] = jnp.nonzero(
                    species == s, size=c, fill_value=nat
                )[0]

            mask = species <= 0
            new_size = jnp.sum(mask)
            natcount = natcount + new_size
            overflow = overflow | (
                natcount < species.shape[0]
            )  # check if any species missing

        return {
            **inputs,
            self.output_key: species_index,
            self.output_key + "_overflow": overflow,
        }

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        return self.process(None, inputs)
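A minimal usage sketch of the two output modes (the hand-built inputs dict is illustrative; in a FENNIX model the indexer runs inside the preprocessing chain):

import numpy as np
from fennol.models.preprocessing import SpeciesIndexer

# water molecule: O, H, H
inputs = {"species": np.array([8, 1, 1], dtype=np.int32)}

# dictionary mode: one index array per species present in the system
indexer = SpeciesIndexer()
state, out = indexer(indexer.init(), inputs)
# out["species_index"] -> {"O": [0], "H": [1, 2]}

# dense mode: rows follow species_order, padded with nat (here 3)
indexer_dense = SpeciesIndexer(species_order="H, O")
state, out = indexer_dense(indexer_dense.init(), inputs)
# out["species_index"] -> [[1, 2], [0, 3]]  # row 0: H atoms, row 1: O + padding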
@dataclasses.dataclass(frozen=True)
class BlockIndexer:
    """Build an index that splits atomic arrays by chemical blocks.

    FPID: BLOCK_INDEXER

    The output is a dictionary with chemical block names as keys and an index to filter atomic arrays for that block as values (None for blocks absent from the system).
    """

    output_key: str = "block_index"
    """Key for the output dictionary."""
    add_atoms: int = 0
    """Additional atoms to add to the sizes."""
    add_atoms_margin: int = 10
    """Additional atoms to add to the sizes when adding margin."""
    split_CNOPSSe: bool = False
    """Give C, N, O, P, S and Se their own individual blocks."""

    FPID: ClassVar[str] = "BLOCK_INDEXER"

    def init(self):
        return FrozenDict(
            {
                "sizes": {},
            }
        )

    def build_chemical_blocks(self):
        _CHEMICAL_BLOCKS_NAMES = CHEMICAL_BLOCKS_NAMES.copy()
        if self.split_CNOPSSe:
            _CHEMICAL_BLOCKS_NAMES[1] = "C"
            _CHEMICAL_BLOCKS_NAMES.extend(["N", "O", "P", "S", "Se"])
        _CHEMICAL_BLOCKS = CHEMICAL_BLOCKS.copy()
        if self.split_CNOPSSe:
            _CHEMICAL_BLOCKS[6] = 1
            _CHEMICAL_BLOCKS[7] = len(CHEMICAL_BLOCKS_NAMES)
            _CHEMICAL_BLOCKS[8] = len(CHEMICAL_BLOCKS_NAMES) + 1
            _CHEMICAL_BLOCKS[15] = len(CHEMICAL_BLOCKS_NAMES) + 2
            _CHEMICAL_BLOCKS[16] = len(CHEMICAL_BLOCKS_NAMES) + 3
            _CHEMICAL_BLOCKS[34] = len(CHEMICAL_BLOCKS_NAMES) + 4
        return _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS = self.build_chemical_blocks()

        species = np.array(inputs["species"], dtype=np.int32)
        blocks = _CHEMICAL_BLOCKS[species]
        nat = species.shape[0]
        set_blocks, counts = np.unique(blocks, return_counts=True)

        new_state = {**state}
        state_up = {}

        sizes = state.get("sizes", FrozenDict({}))
        new_sizes = {**sizes}
        up_sizes = False
        for s, c in zip(set_blocks, counts):
            if s < 0:
                continue
            key = (s, _CHEMICAL_BLOCKS_NAMES[s])
            if c > sizes.get(key, 0):
                up_sizes = True
                add_atoms = state.get("add_atoms", self.add_atoms)
                if add_margin:
                    add_atoms += state.get("add_atoms_margin", self.add_atoms_margin)
                new_sizes[key] = c + add_atoms

        new_sizes = FrozenDict(new_sizes)
        if up_sizes:
            state_up["sizes"] = (new_sizes, sizes)
            new_state["sizes"] = new_sizes

        block_index = {n: None for n in _CHEMICAL_BLOCKS_NAMES}
        for (_, n), c in new_sizes.items():
            block_index[n] = np.full(c, nat, dtype=np.int32)
        for s, c in zip(set_blocks, counts):
            if s < 0:
                continue
            block_index[_CHEMICAL_BLOCKS_NAMES[s]][:c] = np.nonzero(blocks == s)[0]

        output = {
            **inputs,
            self.output_key: block_index,
            self.output_key + "_overflow": False,
        }

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary"""
        overflow = parent_overflow or inputs[self.output_key + "_overflow"]
        if not overflow:
            return state, {}, inputs, False

        add_margin = inputs[self.output_key + "_overflow"]
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS = self.build_chemical_blocks()

        recompute_species_index = "recompute_species_index" in inputs.get("flags", {})
        if self.output_key in inputs and not recompute_species_index:
            return inputs

        if state is None:
            raise ValueError("Block Indexer state must be provided on accelerator.")

        species = inputs["species"]
        blocks = jnp.asarray(_CHEMICAL_BLOCKS)[species]
        nat = species.shape[0]

        sizes = state["sizes"]

        block_index = {n: None for n in _CHEMICAL_BLOCKS_NAMES}
        overflow = False
        natcount = 0
        for (s, name), c in sizes.items():
            mask = blocks == s
            new_size = jnp.sum(mask)
            natcount = natcount + new_size
            overflow = overflow | (new_size > c)  # check if sizes are correct
            block_index[name] = jnp.nonzero(mask, size=c, fill_value=nat)[0]

        mask = blocks < 0
        new_size = jnp.sum(mask)
        natcount = natcount + new_size
        overflow = overflow | (
            natcount < species.shape[0]
        )  # check if any species missing

        return {
            **inputs,
            self.output_key: block_index,
            self.output_key + "_overflow": overflow,
        }

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        return self.process(None, inputs)
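A minimal usage sketch, assuming only that the output maps chemical block names to index arrays as built above (the exact block names come from CHEMICAL_BLOCKS_NAMES in fennol.utils.periodic_table):

import numpy as np
from fennol.models.preprocessing import BlockIndexer

# small organic fragment: two carbons and two hydrogens
inputs = {"species": np.array([6, 6, 1, 1], dtype=np.int32)}

indexer = BlockIndexer()
state, out = indexer(indexer.init(), inputs)
# out["block_index"] maps every chemical block name to the indices of the
# atoms in that block, or None for blocks absent from the system
present = {k: v for k, v in out["block_index"].items() if v is not None}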
@dataclasses.dataclass(frozen=True)
class AtomPadding:
    """Pad atomic arrays to a fixed size."""

    mult_size: float = 1.2
    """Multiplicative factor for resizing the atomic arrays."""
    add_sys: int = 0
    """Additional systems to add when resizing."""

    def init(self):
        return {"prev_nat": 0, "prev_nsys": 0}

    def __call__(self, state, inputs: Dict) -> Tuple[FrozenDict, Dict]:
        species = inputs["species"]
        nat = species.shape[0]

        prev_nat = state.get("prev_nat", 0)
        prev_nat_ = prev_nat
        if nat > prev_nat_:
            prev_nat_ = int(self.mult_size * nat) + 1

        nsys = len(inputs["natoms"])
        prev_nsys = state.get("prev_nsys", 0)
        prev_nsys_ = prev_nsys
        if nsys > prev_nsys_:
            prev_nsys_ = nsys + self.add_sys

        add_atoms = prev_nat_ - nat
        add_sys = prev_nsys_ - nsys + 1  # an extra ghost system collects padding atoms
        output = {**inputs}
        if add_atoms > 0:
            for k, v in inputs.items():
                if isinstance(v, np.ndarray) or isinstance(v, jax.Array):
                    if v.shape[0] == nat:
                        output[k] = np.append(
                            v,
                            np.zeros((add_atoms, *v.shape[1:]), dtype=v.dtype),
                            axis=0,
                        )
                    elif v.shape[0] == nsys:
                        if k == "cells":
                            # pad with large dummy cells to avoid spurious periodic images
                            output[k] = np.append(
                                v,
                                1000
                                * np.eye(3, dtype=v.dtype)[None, :, :].repeat(
                                    add_sys, axis=0
                                ),
                                axis=0,
                            )
                        else:
                            output[k] = np.append(
                                v,
                                np.zeros((add_sys, *v.shape[1:]), dtype=v.dtype),
                                axis=0,
                            )
            output["natoms"] = np.append(
                inputs["natoms"], np.zeros(add_sys, dtype=np.int32)
            )
            output["species"] = np.append(
                species, -1 * np.ones(add_atoms, dtype=species.dtype)
            )
            output["batch_index"] = np.append(
                inputs["batch_index"],
                np.array(
                    [output["natoms"].shape[0] - 1] * add_atoms,
                    dtype=inputs["batch_index"].dtype,
                ),
            )
            if "system_index" in inputs:
                output["system_index"] = np.append(
                    inputs["system_index"],
                    np.array(
                        [output["natoms"].shape[0] - 1] * add_sys,
                        dtype=inputs["system_index"].dtype,
                    ),
                )

        output["true_atoms"] = output["species"] > 0
        output["true_sys"] = np.arange(len(output["natoms"])) < nsys

        state = {**state, "prev_nat": prev_nat_, "prev_nsys": prev_nsys_}

        return FrozenDict(state), output
def atom_unpadding(inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Remove padding from atomic arrays."""
    if "true_atoms" not in inputs:
        return inputs

    species = inputs["species"]
    true_atoms = inputs["true_atoms"]
    true_sys = inputs["true_sys"]
    natall = species.shape[0]
    nat = np.argmax(species <= 0)
    if nat == 0:
        # no padding atoms found
        return inputs

    natoms = inputs["natoms"]
    nsysall = len(natoms)

    output = {**inputs}
    for k, v in inputs.items():
        if isinstance(v, jax.Array) or isinstance(v, np.ndarray):
            if v.ndim == 0:
                continue
            if v.shape[0] == natall:
                output[k] = v[true_atoms]
            elif v.shape[0] == nsysall:
                output[k] = v[true_sys]
    del output["true_sys"]
    del output["true_atoms"]
    return output
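A round-trip sketch of padding and unpadding (hand-built single-system inputs, illustrative only):

import numpy as np
from fennol.models.preprocessing import AtomPadding, atom_unpadding

inputs = {
    "species": np.array([8, 1, 1], dtype=np.int32),
    "coordinates": np.zeros((3, 3), dtype=np.float32),
    "natoms": np.array([3], dtype=np.int32),
    "batch_index": np.zeros(3, dtype=np.int32),
}

padder = AtomPadding(mult_size=1.2)
state, padded = padder(padder.init(), inputs)
# padded["species"] has int(1.2 * 3) + 1 = 4 entries; the extra atom has
# species -1 and belongs to a ghost system appended at the end of "natoms"

restored = atom_unpadding(padded)
assert restored["species"].shape[0] == 3  # padding removed via the true_atoms mask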
def check_input(inputs):
    """Check the input dictionary for required keys and types."""
    assert "species" in inputs, "species must be provided"
    assert "coordinates" in inputs, "coordinates must be provided"
    species = inputs["species"].astype(np.int32)
    ifake = np.argmax(species <= 0)
    if ifake > 0:
        assert np.all(species[:ifake] > 0), "species must be positive"
    nat = inputs["species"].shape[0]

    natoms = inputs.get("natoms", np.array([nat], dtype=np.int32)).astype(np.int32)
    batch_index = inputs.get(
        "batch_index", np.repeat(np.arange(len(natoms), dtype=np.int32), natoms)
    ).astype(np.int32)
    output = {**inputs, "natoms": natoms, "batch_index": batch_index}
    if "cells" in inputs:
        cells = inputs["cells"]
        if "reciprocal_cells" not in inputs:
            reciprocal_cells = np.linalg.inv(cells)
        else:
            reciprocal_cells = inputs["reciprocal_cells"]
        if cells.ndim == 2:
            cells = cells[None, :, :]
        if reciprocal_cells.ndim == 2:
            reciprocal_cells = reciprocal_cells[None, :, :]
        output["cells"] = cells
        output["reciprocal_cells"] = reciprocal_cells

    return output
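A minimal sketch of the defaults this function fills in (illustrative inputs):

import numpy as np
from fennol.models.preprocessing import check_input

# minimal single-system input: natoms and batch_index are filled in automatically
inputs = {
    "species": np.array([8, 1, 1], dtype=np.int32),
    "coordinates": np.zeros((3, 3), dtype=np.float32),
}
out = check_input(inputs)
# out["natoms"] -> [3], out["batch_index"] -> [0, 0, 0]

# with a 2D cell, a batch dimension is added and reciprocal_cells is computed
out_pbc = check_input({**inputs, "cells": 10.0 * np.eye(3)})
# out_pbc["cells"].shape -> (1, 3, 3)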
def convert_to_jax(data):
    """Convert numpy arrays to jax arrays in a pytree."""

    def convert(x):
        if isinstance(x, np.ndarray):
            return jnp.asarray(x)
        return x

    return jax.tree_util.tree_map(convert, data)
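For example (illustrative pytree):

import numpy as np
import jax
from fennol.models.preprocessing import convert_to_jax

data = {"species": np.array([8, 1, 1]), "meta": {"name": "water"}}
jax_data = convert_to_jax(data)
# numpy arrays become jax arrays; non-array leaves pass through unchanged
assert isinstance(jax_data["species"], jax.Array)
assert jax_data["meta"]["name"] == "water"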
class JaxConverter(nn.Module):
    """Convert numpy arrays to jax arrays in a pytree."""

    def __call__(self, data):
        return convert_to_jax(data)
@dataclasses.dataclass(frozen=True)
class PreprocessingChain:
    """Chain of preprocessing layers."""

    layers: Tuple[Callable[..., Dict[str, Any]], ...]
    """Preprocessing layers."""
    use_atom_padding: bool = False
    """Add an AtomPadding layer at the beginning of the chain."""
    atom_padder: AtomPadding = AtomPadding()
    """AtomPadding layer."""

    def __post_init__(self):
        if not isinstance(self.layers, Sequence):
            raise ValueError(
                f"'layers' must be a sequence, got '{type(self.layers).__name__}'."
            )
        if not self.layers:
            raise ValueError("No preprocessing layers were provided.")

    def __call__(self, state, inputs: Dict[str, Any]) -> Tuple[FrozenDict, Dict[str, Any]]:
        do_check_input = state.get("check_input", True)
        if do_check_input:
            inputs = check_input(inputs)
        new_state = []
        layer_state = state["layers_state"]
        i = 0
        if self.use_atom_padding:
            s, inputs = self.atom_padder(layer_state[0], inputs)
            new_state.append(s)
            i += 1
        for layer in self.layers:
            s, inputs = layer(layer_state[i], inputs, return_state_update=False)
            new_state.append(s)
            i += 1
        return FrozenDict({**state, "layers_state": tuple(new_state)}), convert_to_jax(
            inputs
        )

    def check_reallocate(self, state, inputs):
        new_state = []
        state_up = []
        layer_state = state["layers_state"]
        i = 0
        if self.use_atom_padding:
            new_state.append(layer_state[0])
            i += 1
        parent_overflow = False
        for layer in self.layers:
            s, s_up, inputs, parent_overflow = layer.check_reallocate(
                layer_state[i], inputs, parent_overflow
            )
            new_state.append(s)
            state_up.append(s_up)
            i += 1

        if not parent_overflow:
            return state, {}, inputs, False
        return (
            FrozenDict({**state, "layers_state": tuple(new_state)}),
            state_up,
            inputs,
            True,
        )

    def atom_padding(self, state, inputs):
        if self.use_atom_padding:
            padder_state = state["layers_state"][0]
            return self.atom_padder(padder_state, inputs)
        return state, inputs

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        layer_state = state["layers_state"]
        i = 1 if self.use_atom_padding else 0
        for layer in self.layers:
            inputs = layer.process(layer_state[i], inputs)
            i += 1
        return inputs

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        for layer in self.layers:
            inputs = layer.update_skin(inputs)
        return inputs

    def init(self):
        state = []
        if self.use_atom_padding:
            state.append(self.atom_padder.init())
        for layer in self.layers:
            state.append(layer.init())
        return FrozenDict({"check_input": True, "layers_state": state})

    def init_with_output(self, inputs):
        state = self.init()
        return self(state, inputs)

    def get_processors(self):
        processors = []
        for layer in self.layers:
            if hasattr(layer, "get_processor"):
                processors.append(layer.get_processor())
        return processors

    def get_graphs_properties(self):
        properties = {}
        for layer in self.layers:
            if hasattr(layer, "get_graph_properties"):
                properties = deep_update(properties, layer.get_graph_properties())
        return properties
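A usage sketch, assuming the typical case of a GraphGenerator followed by a GraphAngularExtension (key names as defined in this module; the exact contents of the output graph depend on the layers used):

import numpy as np
from fennol.models.preprocessing import (
    PreprocessingChain,
    GraphGenerator,
    GraphAngularExtension,
)

chain = PreprocessingChain(
    layers=(GraphGenerator(cutoff=5.0), GraphAngularExtension()),
    use_atom_padding=True,
)

inputs = {
    "species": np.array([8, 1, 1], dtype=np.int32),
    "coordinates": np.random.randn(3, 3).astype(np.float32),
}
# check_input fills in natoms/batch_index, then each layer builds its lists
state, outputs = chain.init_with_output(inputs)
# outputs["graph"] now holds the edge neighbor list from GraphGenerator and
# the angle lists (angle_src, angle_dst, central_atom) from GraphAngularExtension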