fennol.models.preprocessing

   1import flax.linen as nn
   2from typing import Sequence, Callable, Union, Dict, Any, ClassVar
   3import jax.numpy as jnp
   4import jax
   5import numpy as np
   6from typing import Optional, Tuple
   7import numba
   8import dataclasses
   9from functools import partial
  10
  11from flax.core.frozen_dict import FrozenDict
  12
  13
  14from ..utils.activations import chain
  15from ..utils import deep_update, mask_filter_1d
  16from ..utils.kspace import get_reciprocal_space_parameters
  17from .misc.misc import SwitchFunction
  18from ..utils.periodic_table import PERIODIC_TABLE, PERIODIC_TABLE_REV_IDX,CHEMICAL_BLOCKS,CHEMICAL_BLOCKS_NAMES
  19
  20
  21@dataclasses.dataclass(frozen=True)
  22class GraphGenerator:
  23    """Generate a graph from a set of coordinates
  24
  25    FPID: GRAPH
  26
  27    For now, we generate all pairs of atoms and filter based on cutoff.
  28    If a `nblist_skin` is present in the state, we generate a second graph with a larger cutoff that includes all pairs within the cutoff+skin. This graph is then reused by the `update_skin` method to update the original graph without recomputing the full nblist.
  29    """
  30
  31    cutoff: float
  32    """Cutoff distance for the graph."""
  33    graph_key: str = "graph"
  34    """Key of the graph in the outputs."""
  35    switch_params: dict = dataclasses.field(default_factory=dict, hash=False)
  36    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""
  37    kmax: int = 30
  38    """Maximum number of k-points to consider."""
  39    kthr: float = 1e-6
  40    """Threshold for k-point filtering."""
  41    k_space: bool = False
  42    """Whether to generate k-space information for the graph."""
  43    mult_size: float = 1.05
  44    """Multiplicative factor for resizing the nblist."""
  45    # covalent_cutoff: bool = False
  46
  47    FPID: ClassVar[str] = "GRAPH"
  48
  49    def init(self):
  50        return FrozenDict(
  51            {
  52                "max_nat": 1,
  53                "npairs": 1,
  54                "nblist_mult_size": self.mult_size,
  55            }
  56        )
  57
  58    def get_processor(self) -> Tuple[nn.Module, Dict]:
  59        return GraphProcessor, {
  60            "cutoff": self.cutoff,
  61            "graph_key": self.graph_key,
  62            "switch_params": self.switch_params,
  63            "name": f"{self.graph_key}_Processor",
  64        }
  65
  66    def get_graph_properties(self):
  67        return {
  68            self.graph_key: {
  69                "cutoff": self.cutoff,
  70                "directed": True,
  71            }
  72        }
  73
  74    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
  75        """build a nblist on cpu with numpy and dynamic shapes + store max shapes"""
  76        if self.graph_key in inputs:
  77            graph = inputs[self.graph_key]
  78            if "keep_graph" in graph:
  79                return state, inputs
  80
  81        coords = np.array(inputs["coordinates"], dtype=np.float32)
  82        natoms = np.array(inputs["natoms"], dtype=np.int32)
  83        batch_index = np.array(inputs["batch_index"], dtype=np.int32)
  84
  85        new_state = {**state}
  86        state_up = {}
  87
  88        mult_size = state.get("nblist_mult_size", self.mult_size)
  89        assert mult_size >= 1.0, "mult_size should be larger or equal than 1.0"
  90
  91        if natoms.shape[0] == 1:
  92            max_nat = coords.shape[0]
  93            true_max_nat = max_nat
  94        else:
  95            max_nat = state.get("max_nat", round(coords.shape[0] / natoms.shape[0]))
  96            true_max_nat = int(np.max(natoms))
  97            if true_max_nat > max_nat:
  98                add_atoms = state.get("add_atoms", 0)
  99                new_maxnat = true_max_nat + add_atoms
 100                state_up["max_nat"] = (new_maxnat, max_nat)
 101                new_state["max_nat"] = new_maxnat
 102
 103        cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0)
 104
 105        ### compute indices of all pairs
 106        p1, p2 = np.triu_indices(true_max_nat, 1)
 107        p1, p2 = p1.astype(np.int32), p2.astype(np.int32)
 108        pbc_shifts = None
 109        if natoms.shape[0] > 1:
 110            ## batching => mask irrelevant pairs
 111            mask_p12 = (
 112                (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None])
 113            ).flatten()
 114            shift = np.concatenate(
 115                (np.array([0], dtype=np.int32), np.cumsum(natoms[:-1], dtype=np.int32))
 116            )
 117            p1 = np.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1)
 118            p2 = np.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1)
 119
 120        apply_pbc = "cells" in inputs
 121        if not apply_pbc:
 122            ### NO PBC
 123            vec = coords[p2] - coords[p1]
 124        else:
 125            cells = np.array(inputs["cells"], dtype=np.float32)
 126            reciprocal_cells = np.array(inputs["reciprocal_cells"], dtype=np.float32)
 127            minimage = state.get("minimum_image", True)
 128            if minimage:
 129                ## MINIMUM IMAGE CONVENTION
 130                vec = coords[p2] - coords[p1]
 131                if cells.shape[0] == 1:
 132                    vecpbc = np.dot(vec, reciprocal_cells[0])
 133                    pbc_shifts = -np.round(vecpbc)
 134                    vec = vec + np.dot(pbc_shifts, cells[0])
 135                else:
 136                    batch_index_vec = batch_index[p1]
 137                    vecpbc = np.einsum(
 138                        "aj,aji->ai", vec, reciprocal_cells[batch_index_vec]
 139                    )
 140                    pbc_shifts = -np.round(vecpbc)
 141                    vec = vec + np.einsum(
 142                        "aj,aji->ai", pbc_shifts, cells[batch_index_vec]
 143                    )
 144            else:
 145                ### GENERAL PBC
 146                ## put all atoms in central box
 147                if cells.shape[0] == 1:
 148                    coords_pbc = np.dot(coords, reciprocal_cells[0])
 149                    at_shifts = -np.floor(coords_pbc)
 150                    coords_pbc = coords + np.dot(at_shifts, cells[0])
 151                else:
 152                    coords_pbc = np.einsum(
 153                        "aj,aji->ai", coords, reciprocal_cells[batch_index]
 154                    )
 155                    at_shifts = -np.floor(coords_pbc)
 156                    coords_pbc = coords + np.einsum(
 157                        "aj,aji->ai", at_shifts, cells[batch_index]
 158                    )
 159                vec = coords_pbc[p2] - coords_pbc[p1]
 160
 161                ## compute maximum number of repeats
 162                inv_distances = (np.sum(reciprocal_cells**2, axis=1)) ** 0.5
 163                cdinv = cutoff_skin * inv_distances
 164                num_repeats_all = np.ceil(cdinv).astype(np.int32)
 165                if "true_sys" in inputs:
 166                    num_repeats_all = np.where(np.array(inputs["true_sys"],dtype=bool)[:, None], num_repeats_all, 0)
 167                # num_repeats_all = np.where(cdinv < 0.5, 0, num_repeats_all)
 168                num_repeats = np.max(num_repeats_all, axis=0)
 169                num_repeats_prev = np.array(state.get("num_repeats_pbc", (0, 0, 0)))
 170                if np.any(num_repeats > num_repeats_prev):
 171                    num_repeats_new = np.maximum(num_repeats, num_repeats_prev)
 172                    state_up["num_repeats_pbc"] = (
 173                        tuple(num_repeats_new),
 174                        tuple(num_repeats_prev),
 175                    )
 176                    new_state["num_repeats_pbc"] = tuple(num_repeats_new)
 177                ## build all possible shifts
 178                cell_shift_pbc = np.array(
 179                    np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]),
 180                    dtype=cells.dtype,
 181                ).T.reshape(-1, 3)
 182                ## shift applied to vectors
 183                if cells.shape[0] == 1:
 184                    dvec = np.dot(cell_shift_pbc, cells[0])[None, :, :]
 185                    vec = (vec[:, None, :] + dvec).reshape(-1, 3)
 186                    pbc_shifts = np.broadcast_to(
 187                        cell_shift_pbc[None, :, :],
 188                        (p1.shape[0], cell_shift_pbc.shape[0], 3),
 189                    ).reshape(-1, 3)
 190                    p1 = np.broadcast_to(
 191                        p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0])
 192                    ).flatten()
 193                    p2 = np.broadcast_to(
 194                        p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0])
 195                    ).flatten()
 196                    if natoms.shape[0] > 1:
 197                        mask_p12 = np.broadcast_to(
 198                            mask_p12[:, None],
 199                            (mask_p12.shape[0], cell_shift_pbc.shape[0]),
 200                        ).flatten()
 201                else:
 202                    dvec = np.einsum("bj,sji->sbi", cell_shift_pbc, cells)
 203
 204                    ## get pbc shifts specific to each box
 205                    cell_shift_pbc = np.broadcast_to(
 206                        cell_shift_pbc[None, :, :],
 207                        (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3),
 208                    )
 209                    mask = np.all(
 210                        np.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1
 211                    ).flatten()
 212                    idx = np.nonzero(mask)[0]
 213                    nshifts = idx.shape[0]
 214                    nshifts_prev = state.get("nshifts_pbc", 0)
 215                    if nshifts > nshifts_prev or add_margin:
 216                        nshifts_new = int(mult_size * max(nshifts, nshifts_prev)) + 1
 217                        state_up["nshifts_pbc"] = (nshifts_new, nshifts_prev)
 218                        new_state["nshifts_pbc"] = nshifts_new
 219
 220                    dvec_filter = dvec.reshape(-1, 3)[idx, :]
 221                    cell_shift_pbc_filter = cell_shift_pbc.reshape(-1, 3)[idx, :]
 222
 223                    ## get batch shift in the dvec_filter array
 224                    nrep = np.prod(2 * num_repeats_all + 1, axis=1)
 225                    bshift = np.concatenate((np.array([0]), np.cumsum(nrep)[:-1]))
 226
 227                    ## compute vectors
 228                    batch_index_vec = batch_index[p1]
 229                    nrep_vec = np.where(mask_p12,nrep[batch_index_vec],0)
 230                    vec = vec.repeat(nrep_vec, axis=0)
 231                    nvec_pbc = nrep_vec.sum() #vec.shape[0]
 232                    nvec_pbc_prev = state.get("nvec_pbc", 0)
 233                    if nvec_pbc > nvec_pbc_prev or add_margin:
 234                        nvec_pbc_new = int(mult_size * max(nvec_pbc, nvec_pbc_prev)) + 1
 235                        state_up["nvec_pbc"] = (nvec_pbc_new, nvec_pbc_prev)
 236                        new_state["nvec_pbc"] = nvec_pbc_new
 237
 238                    # print("cpu: ", nvec_pbc, nvec_pbc_prev, nshifts, nshifts_prev)
 239                    ## get shift index
 240                    dshift = np.concatenate(
 241                        (np.array([0]), np.cumsum(nrep_vec)[:-1])
 242                    ).repeat(nrep_vec)
 243                    # ishift = np.arange(dshift.shape[0])-dshift
 244                    # bshift_vec_rep = bshift[batch_index_vec].repeat(nrep_vec)
 245                    icellshift = (
 246                        np.arange(dshift.shape[0])
 247                        - dshift
 248                        + bshift[batch_index_vec].repeat(nrep_vec)
 249                    )
 250                    # shift vectors
 251                    vec = vec + dvec_filter[icellshift]
 252                    pbc_shifts = cell_shift_pbc_filter[icellshift]
 253
 254                    p1 = np.repeat(p1, nrep_vec)
 255                    p2 = np.repeat(p2, nrep_vec)
 256                    if natoms.shape[0] > 1:
 257                        mask_p12 = np.repeat(mask_p12, nrep_vec)
 258
 259        ## compute distances
 260        d12 = (vec**2).sum(axis=-1)
 261        if natoms.shape[0] > 1:
 262            d12 = np.where(mask_p12, d12, cutoff_skin**2)
 263
 264        ## filter pairs
 265        max_pairs = state.get("npairs", 1)
 266        mask = d12 < cutoff_skin**2
 267        idx = np.nonzero(mask)[0]
 268        npairs = idx.shape[0]
 269        if npairs > max_pairs or add_margin:
 270            prev_max_pairs = max_pairs
 271            max_pairs = int(mult_size * max(npairs, max_pairs)) + 1
 272            state_up["npairs"] = (max_pairs, prev_max_pairs)
 273            new_state["npairs"] = max_pairs
 274
 275        nat = coords.shape[0]
 276        edge_src = np.full(max_pairs, nat, dtype=np.int32)
 277        edge_dst = np.full(max_pairs, nat, dtype=np.int32)
 278        d12_ = np.full(max_pairs, cutoff_skin**2)
 279        edge_src[:npairs] = p1[idx]
 280        edge_dst[:npairs] = p2[idx]
 281        d12_[:npairs] = d12[idx]
 282        d12 = d12_
 283
 284        if apply_pbc:
 285            pbc_shifts_ = np.zeros((max_pairs, 3))
 286            pbc_shifts_[:npairs] = pbc_shifts[idx]
 287            pbc_shifts = pbc_shifts_
 288            if not minimage:
 289                pbc_shifts[:npairs] = (
 290                    pbc_shifts[:npairs]
 291                    + at_shifts[edge_dst[:npairs]]
 292                    - at_shifts[edge_src[:npairs]]
 293                )
 294
 295        if "nblist_skin" in state:
 296            edge_src_skin = edge_src
 297            edge_dst_skin = edge_dst
 298            if apply_pbc:
 299                pbc_shifts_skin = pbc_shifts
 300            max_pairs_skin = state.get("npairs_skin", 1)
 301            mask = d12 < self.cutoff**2
 302            idx = np.nonzero(mask)[0]
 303            npairs_skin = idx.shape[0]
 304            if npairs_skin > max_pairs_skin or add_margin:
 305                prev_max_pairs_skin = max_pairs_skin
 306                max_pairs_skin = int(mult_size * max(npairs_skin, max_pairs_skin)) + 1
 307                state_up["npairs_skin"] = (max_pairs_skin, prev_max_pairs_skin)
 308                new_state["npairs_skin"] = max_pairs_skin
 309            edge_src = np.full(max_pairs_skin, nat, dtype=np.int32)
 310            edge_dst = np.full(max_pairs_skin, nat, dtype=np.int32)
 311            d12_ = np.full(max_pairs_skin, self.cutoff**2)
 312            edge_src[:npairs_skin] = edge_src_skin[idx]
 313            edge_dst[:npairs_skin] = edge_dst_skin[idx]
 314            d12_[:npairs_skin] = d12[idx]
 315            d12 = d12_
 316            if apply_pbc:
 317                pbc_shifts = np.full((max_pairs_skin, 3), 0.0)
 318                pbc_shifts[:npairs_skin] = pbc_shifts_skin[idx]
 319
 320        ## symmetrize
 321        edge_src, edge_dst = np.concatenate((edge_src, edge_dst)), np.concatenate(
 322            (edge_dst, edge_src)
 323        )
 324        d12 = np.concatenate((d12, d12))
 325        if apply_pbc:
 326            pbc_shifts = np.concatenate((pbc_shifts, -pbc_shifts))
 327
 328        graph = inputs.get(self.graph_key, {})
 329        graph_out = {
 330            **graph,
 331            "edge_src": edge_src,
 332            "edge_dst": edge_dst,
 333            "d12": d12,
 334            "overflow": False,
 335            "pbc_shifts": pbc_shifts,
 336        }
 337        if "nblist_skin" in state:
 338            graph_out["edge_src_skin"] = edge_src_skin
 339            graph_out["edge_dst_skin"] = edge_dst_skin
 340            if apply_pbc:
 341                graph_out["pbc_shifts_skin"] = pbc_shifts_skin
 342
 343        if self.k_space and apply_pbc:
 344            if "k_points" not in graph:
 345                ks, _, _, bewald = get_reciprocal_space_parameters(
 346                    reciprocal_cells, self.cutoff, self.kmax, self.kthr
 347                )
 348            graph_out["k_points"] = ks
 349            graph_out["b_ewald"] = bewald
 350
 351        output = {**inputs, self.graph_key: graph_out}
 352
 353        if return_state_update:
 354            return FrozenDict(new_state), output, state_up
 355        return FrozenDict(new_state), output
 356
 357    def check_reallocate(self, state, inputs, parent_overflow=False):
 358        """check for overflow and reallocate nblist if necessary"""
 359        overflow = parent_overflow or inputs[self.graph_key].get("overflow", False)
 360        if not overflow:
 361            return state, {}, inputs, False
 362
 363        add_margin = inputs[self.graph_key].get("overflow", False)
 364        state, inputs, state_up = self(
 365            state, inputs, return_state_update=True, add_margin=add_margin
 366        )
 367        return state, state_up, inputs, True
 368
    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """Build a neighbor list on accelerator with jax and precomputed shapes.

        `self` and `state` are static arguments of the jitted function, so all
        buffer sizes read from `state` (`max_nat`, `npairs`, `npairs_skin`,
        `nshifts_pbc`, `nvec_pbc`, `num_repeats_pbc`) are compile-time constants.
        When one of these sizes is too small for the current inputs, the graph
        is still returned but its `overflow` entry evaluates to True.

        Args:
            state: preprocessing state (see the keys listed above, plus
                `nblist_skin` and `minimum_image`).
            inputs: mapping with `coordinates`, `natoms`, `batch_index` and
                optionally `cells`/`reciprocal_cells`, `true_sys` and a
                previous graph under `self.graph_key`.

        Returns:
            `inputs` with the graph stored under `self.graph_key`. If the
            incoming graph contains `keep_graph`, `inputs` is returned as-is.

        Raises:
            NotImplementedError: if `k_space` is enabled, `cells` are given and
                the incoming graph has no `k_points`.
        """
        if self.graph_key in inputs:
            graph = inputs[self.graph_key]
            if "keep_graph" in graph:
                return inputs
        coords = inputs["coordinates"]
        natoms = inputs["natoms"]
        batch_index = inputs["batch_index"]

        if natoms.shape[0] == 1:
            max_nat = coords.shape[0]
        else:
            max_nat = state.get(
                "max_nat", int(round(coords.shape[0] / natoms.shape[0]))
            )
        cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0)

        ### compute indices of all pairs
        # numpy is used here: max_nat is a static Python int, so the pair
        # indices are constants of the traced computation
        p1, p2 = np.triu_indices(max_nat, 1)
        p1, p2 = p1.astype(np.int32), p2.astype(np.int32)
        pbc_shifts = None
        if natoms.shape[0] > 1:
            ## batching => mask irrelevant pairs
            mask_p12 = (
                (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None])
            ).flatten()
            # offset of the first atom of each system in the flat atom arrays
            shift = jnp.concatenate(
                (jnp.array([0], dtype=jnp.int32), jnp.cumsum(natoms[:-1]))
            )
            p1 = jnp.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1)
            p2 = jnp.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1)

        ## compute vectors
        overflow_repeats = jnp.asarray(False, dtype=bool)
        if "cells" not in inputs:
            vec = coords[p2] - coords[p1]
        else:
            cells = inputs["cells"]
            reciprocal_cells = inputs["reciprocal_cells"]
            minimage = state.get("minimum_image", True)

            def compute_pbc(vec, reciprocal_cell, cell, mode="round"):
                # Return the vector shifted by an integer number of cell vectors,
                # together with that integer shift ("round" or "floor" of the
                # fractional coordinates).
                vecpbc = jnp.dot(vec, reciprocal_cell)
                if mode == "round":
                    pbc_shifts = -jnp.round(vecpbc)
                elif mode == "floor":
                    pbc_shifts = -jnp.floor(vecpbc)
                else:
                    raise NotImplementedError(f"Unknown mode {mode} for compute_pbc.")
                return vec + jnp.dot(pbc_shifts, cell), pbc_shifts

            if minimage:
                ## minimum image convention
                vec = coords[p2] - coords[p1]

                if cells.shape[0] == 1:
                    vec, pbc_shifts = compute_pbc(vec, reciprocal_cells[0], cells[0])
                else:
                    batch_index_vec = batch_index[p1]
                    vec, pbc_shifts = jax.vmap(compute_pbc)(
                        vec, reciprocal_cells[batch_index_vec], cells[batch_index_vec]
                    )
            else:
                ### general PBC only for single cell yet
                # if cells.shape[0] > 1:
                #     raise NotImplementedError(
                #         "General PBC not implemented for batches on accelerator."
                #     )
                # cell = cells[0]
                # reciprocal_cell = reciprocal_cells[0]

                ## put all atoms in central box
                if cells.shape[0] == 1:
                    coords_pbc, at_shifts = compute_pbc(
                        coords, reciprocal_cells[0], cells[0], mode="floor"
                    )
                else:
                    coords_pbc, at_shifts = jax.vmap(
                        partial(compute_pbc, mode="floor")
                    )(coords, reciprocal_cells[batch_index], cells[batch_index])
                vec = coords_pbc[p2] - coords_pbc[p1]
                # static number of repeats taken from the state
                num_repeats = state.get("num_repeats_pbc", (0, 0, 0))
                # if num_repeats is None:
                #     raise ValueError(
                #         "num_repeats_pbc should be provided for general PBC on accelerator. Call the numpy routine (self.__call__) first."
                #     )
                # check if num_repeats is larger than previous
                inv_distances = jnp.linalg.norm(reciprocal_cells, axis=1)
                cdinv = cutoff_skin * inv_distances
                num_repeats_all = jnp.ceil(cdinv).astype(jnp.int32)
                if "true_sys" in inputs:
                    num_repeats_all = jnp.where(inputs["true_sys"][:,None], num_repeats_all, 0)
                num_repeats_new = jnp.max(num_repeats_all, axis=0)
                # flag an overflow when the current cells need more images than
                # the static `num_repeats` allows
                overflow_repeats = jnp.any(num_repeats_new > jnp.asarray(num_repeats))

                cell_shift_pbc = jnp.asarray(
                    np.array(
                        np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]),
                        dtype=cells.dtype,
                    ).T.reshape(-1, 3)
                )

                if cells.shape[0] == 1:
                    # single cell: every pair is replicated over every shift
                    vec = (vec[:,None,:] + jnp.dot(cell_shift_pbc, cells[0])[None, :, :]).reshape(-1, 3)
                    pbc_shifts = jnp.broadcast_to(
                        cell_shift_pbc[None, :, :],
                        (p1.shape[0], cell_shift_pbc.shape[0], 3),
                    ).reshape(-1, 3)
                    p1 = jnp.broadcast_to(
                        p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
                    p2 = jnp.broadcast_to(
                        p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
                    if natoms.shape[0] > 1:
                        mask_p12 = jnp.broadcast_to(
                            mask_p12[:, None], (mask_p12.shape[0], cell_shift_pbc.shape[0])
                        ).flatten()
                else:
                    dvec = jnp.einsum("bj,sji->sbi", cell_shift_pbc, cells).reshape(-1, 3)

                    ## get pbc shifts specific to each box
                    cell_shift_pbc = jnp.broadcast_to(
                        cell_shift_pbc[None, :, :],
                        (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3),
                    )
                    mask = jnp.all(
                        jnp.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1
                    ).flatten()
                    max_shifts  = state.get("nshifts_pbc", 1)

                    cell_shift_pbc = cell_shift_pbc.reshape(-1,3)
                    shiftx,shifty,shiftz = cell_shift_pbc[:,0],cell_shift_pbc[:,1],cell_shift_pbc[:,2]
                    dvecx,dvecy,dvecz = dvec[:,0],dvec[:,1],dvec[:,2]
                    # NOTE(review): mask_filter_1d is defined outside this file;
                    # presumably it compacts the masked entries into arrays of
                    # length `max_shifts` padded with the given fill values and
                    # returns the scatter indices and the selected count — confirm.
                    (dvecx, dvecy,dvecz,shiftx,shifty,shiftz), scatter_idx, nshifts = mask_filter_1d(
                        mask,
                        max_shifts,
                        (dvecx, 0.),
                        (dvecy, 0.),
                        (dvecz, 0.),
                        (shiftx, 0),
                        (shifty, 0),
                        (shiftz, 0),
                    )
                    dvec = jnp.stack((dvecx,dvecy,dvecz),axis=-1)
                    cell_shift_pbc = jnp.stack((shiftx,shifty,shiftz),axis=-1)
                    overflow_repeats = overflow_repeats | (nshifts > max_shifts)

                    ## get batch shift in the dvec_filter array
                    nrep = jnp.prod(2 * num_repeats_all + 1, axis=1)
                    bshift = jnp.concatenate((jnp.array([0],dtype=jnp.int32), jnp.cumsum(nrep)[:-1]))

                    ## repeat vectors
                    nvec_max = state.get("nvec_pbc", 1)
                    batch_index_vec = batch_index[p1]
                    nrep_vec = jnp.where(mask_p12,nrep[batch_index_vec],0)
                    nvec = nrep_vec.sum()
                    overflow_repeats = overflow_repeats | (nvec > nvec_max)
                    # total_repeat_length fixes the output length to the static nvec_max
                    vec = jnp.repeat(vec,nrep_vec,axis=0,total_repeat_length=nvec_max)
                    # jax.debug.print("{nvec} {nvec_max} {nshifts} {max_shifts}",nvec=nvec,nvec_max=jnp.asarray(nvec_max),nshifts=nshifts,max_shifts=jnp.asarray(max_shifts))

                    ## get shift index
                    dshift = jnp.concatenate(
                        (jnp.array([0],dtype=jnp.int32), jnp.cumsum(nrep_vec)[:-1])
                    )
                    if nrep_vec.size == 0:
                        dshift = jnp.array([],dtype=jnp.int32)
                    dshift = jnp.repeat(dshift,nrep_vec, total_repeat_length=nvec_max)
                    bshift = jnp.repeat(bshift[batch_index_vec],nrep_vec, total_repeat_length=nvec_max)
                    icellshift = jnp.arange(dshift.shape[0]) - dshift + bshift
                    vec = vec + dvec[icellshift]
                    pbc_shifts = cell_shift_pbc[icellshift]
                    p1 = jnp.repeat(p1,nrep_vec, total_repeat_length=nvec_max)
                    p2 = jnp.repeat(p2,nrep_vec, total_repeat_length=nvec_max)
                    if natoms.shape[0] > 1:
                        mask_p12 = jnp.repeat(mask_p12,nrep_vec, total_repeat_length=nvec_max)


        ## compute distances
        d12 = (vec**2).sum(axis=-1)
        if natoms.shape[0] > 1:
            # push masked pairs onto the cutoff so that the strict filter rejects them
            d12 = jnp.where(mask_p12, d12, cutoff_skin**2)

        ## filter pairs
        max_pairs = state.get("npairs", 1)
        mask = d12 < cutoff_skin**2
        # unused slots are filled with index coords.shape[0] (one past the last atom)
        (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d(
            mask,
            max_pairs,
            (jnp.asarray(p1, dtype=jnp.int32), coords.shape[0]),
            (jnp.asarray(p2, dtype=jnp.int32), coords.shape[0]),
            (d12, cutoff_skin**2),
        )
        if "cells" in inputs:
            pbc_shifts = (
                jnp.full((max_pairs, 3), 0.0, dtype=pbc_shifts.dtype)
                .at[scatter_idx]
                .set(pbc_shifts, mode="drop")
            )
            if not minimage:
                # fold the per-atom wrapping shifts back into the pair shifts.
                # NOTE(review): no explicit `mode` is passed to `.get` here,
                # unlike in update_skin — confirm `fill_value` is honoured.
                pbc_shifts = (
                    pbc_shifts
                    + at_shifts.at[edge_dst].get(fill_value=0.0)
                    - at_shifts.at[edge_src].get(fill_value=0.0)
                )

        ## check for overflow
        if natoms.shape[0] == 1:
            true_max_nat = coords.shape[0]
        else:
            true_max_nat = jnp.max(natoms)
        overflow_count = npairs > max_pairs
        overflow_at = true_max_nat > max_nat
        overflow = overflow_count | overflow_at | overflow_repeats

        if "nblist_skin" in state:
            # keep the cutoff+skin graph, then extract the pairs within the bare cutoff
            # edge_mask_skin = edge_mask
            edge_src_skin = edge_src
            edge_dst_skin = edge_dst
            if "cells" in inputs:
                pbc_shifts_skin = pbc_shifts
            max_pairs_skin = state.get("npairs_skin", 1)
            mask = d12 < self.cutoff**2
            (edge_src, edge_dst, d12), scatter_idx, npairs_skin = mask_filter_1d(
                mask,
                max_pairs_skin,
                (edge_src, coords.shape[0]),
                (edge_dst, coords.shape[0]),
                (d12, self.cutoff**2),
            )
            if "cells" in inputs:
                pbc_shifts = (
                    jnp.full((max_pairs_skin, 3), 0.0, dtype=pbc_shifts.dtype)
                    .at[scatter_idx]
                    .set(pbc_shifts, mode="drop")
                )
            overflow = overflow | (npairs_skin > max_pairs_skin)

        ## symmetrize
        edge_src, edge_dst = jnp.concatenate((edge_src, edge_dst)), jnp.concatenate(
            (edge_dst, edge_src)
        )
        d12 = jnp.concatenate((d12, d12))
        if "cells" in inputs:
            pbc_shifts = jnp.concatenate((pbc_shifts, -pbc_shifts))

        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
        graph_out = {
            **graph,
            "edge_src": edge_src,
            "edge_dst": edge_dst,
            "d12": d12,
            "overflow": overflow,
            "pbc_shifts": pbc_shifts,
        }
        if "nblist_skin" in state:
            graph_out["edge_src_skin"] = edge_src_skin
            graph_out["edge_dst_skin"] = edge_dst_skin
            if "cells" in inputs:
                graph_out["pbc_shifts_skin"] = pbc_shifts_skin

        if self.k_space and "cells" in inputs:
            # k-points are not computed here; they must already be in the graph
            if "k_points" not in graph:
                raise NotImplementedError(
                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
                )
        return {**inputs, self.graph_key: graph_out}
 638
 639    @partial(jax.jit, static_argnums=(0,))
 640    def update_skin(self, inputs):
 641        """update the nblist without recomputing the full nblist"""
 642        graph = inputs[self.graph_key]
 643
 644        edge_src_skin = graph["edge_src_skin"]
 645        edge_dst_skin = graph["edge_dst_skin"]
 646        coords = inputs["coordinates"]
 647        vec = coords.at[edge_dst_skin].get(
 648            mode="fill", fill_value=self.cutoff
 649        ) - coords.at[edge_src_skin].get(mode="fill", fill_value=0.0)
 650
 651        if "cells" in inputs:
 652            pbc_shifts_skin = graph["pbc_shifts_skin"]
 653            cells = inputs["cells"]
 654            if cells.shape[0] == 1:
 655                vec = vec + jnp.dot(pbc_shifts_skin, cells[0])
 656            else:
 657                batch_index_vec = inputs["batch_index"][edge_src_skin]
 658                vec = vec + jax.vmap(jnp.dot)(pbc_shifts_skin, cells[batch_index_vec])
 659
 660        nat = coords.shape[0]
 661        d12 = jnp.sum(vec**2, axis=-1)
 662        mask = d12 < self.cutoff**2
 663        max_pairs = graph["edge_src"].shape[0] // 2
 664        (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d(
 665            mask,
 666            max_pairs,
 667            (edge_src_skin, nat),
 668            (edge_dst_skin, nat),
 669            (d12, self.cutoff**2),
 670        )
 671        if "cells" in inputs:
 672            pbc_shifts = (
 673                jnp.full((max_pairs, 3), 0.0, dtype=pbc_shifts_skin.dtype)
 674                .at[scatter_idx]
 675                .set(pbc_shifts_skin)
 676            )
 677
 678        overflow = graph.get("overflow", False) | (npairs > max_pairs)
 679        graph_out = {
 680            **graph,
 681            "edge_src": jnp.concatenate((edge_src, edge_dst)),
 682            "edge_dst": jnp.concatenate((edge_dst, edge_src)),
 683            "d12": jnp.concatenate((d12, d12)),
 684            "overflow": overflow,
 685        }
 686        if "cells" in inputs:
 687            graph_out["pbc_shifts"] = jnp.concatenate((pbc_shifts, -pbc_shifts))
 688
 689        if self.k_space and "cells" in inputs:
 690            if "k_points" not in graph:
 691                raise NotImplementedError(
 692                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
 693                )
 694
 695        return {**inputs, self.graph_key: graph_out}
 696
 697
 698class GraphProcessor(nn.Module):
 699    """Process a pre-generated graph
 700
 701    The pre-generated graph should contain the following keys:
 702    - edge_src: source indices of the edges
 703    - edge_dst: destination indices of the edges
 704    - pbcs_shifts: pbc shifts for the edges (only if `cells` are present in the inputs)
 705
 706    This module is automatically added to a FENNIX model when a GraphGenerator is used.
 707
 708    """
 709
 710    cutoff: float
 711    """Cutoff distance for the graph."""
 712    graph_key: str = "graph"
 713    """Key of the graph in the outputs."""
 714    switch_params: dict = dataclasses.field(default_factory=dict)
 715    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""
 716
 717    @nn.compact
 718    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
 719        graph = inputs[self.graph_key]
 720        coords = inputs["coordinates"]
 721        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
 722        # edge_mask = edge_src < coords.shape[0]
 723        vec = coords.at[edge_dst].get(mode="fill", fill_value=self.cutoff) - coords.at[
 724            edge_src
 725        ].get(mode="fill", fill_value=0.0)
 726        if "cells" in inputs:
 727            cells = inputs["cells"]
 728            if cells.shape[0] == 1:
 729                vec = vec + jnp.dot(graph["pbc_shifts"], cells[0])
 730            else:
 731                batch_index_vec = inputs["batch_index"][edge_src]
 732                vec = vec + jax.vmap(jnp.dot)(
 733                    graph["pbc_shifts"], cells[batch_index_vec]
 734                )
 735
 736        distances = jnp.linalg.norm(vec, axis=-1)
 737        edge_mask = distances < self.cutoff
 738
 739        switch = SwitchFunction(
 740            **{**self.switch_params, "cutoff": self.cutoff, "graph_key": None}
 741        )((distances, edge_mask))
 742
 743        graph_out = {
 744            **graph,
 745            "vec": vec,
 746            "distances": distances,
 747            "switch": switch,
 748            "edge_mask": edge_mask,
 749        }
 750
 751        if "alch_group" in inputs:
 752            alch_group = inputs["alch_group"]
 753            lambda_e = inputs["alch_elambda"]
 754            lambda_e = 0.5*(1.-jnp.cos(jnp.pi*lambda_e))
 755            mask = alch_group[edge_src] == alch_group[edge_dst]
 756            graph_out["switch_raw"] = switch
 757            graph_out["switch"] = jnp.where(
 758                mask,
 759                switch,
 760                lambda_e * switch ,
 761            )
 762
 763
 764        return {**inputs, self.graph_key: graph_out}
 765
 766
 767@dataclasses.dataclass(frozen=True)
 768class GraphFilter:
 769    """Filter a graph based on a cutoff distance
 770
 771    FPID: GRAPH_FILTER
 772    """
 773
 774    cutoff: float
 775    """Cutoff distance for the filtering."""
 776    parent_graph: str
 777    """Key of the parent graph in the inputs."""
 778    graph_key: str
 779    """Key of the filtered graph in the outputs."""
 780    remove_hydrogens: int = False
 781    """Remove edges where the source is a hydrogen atom."""
 782    switch_params: FrozenDict = dataclasses.field(default_factory=FrozenDict)
 783    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""
 784    k_space: bool = False
 785    """Generate k-space information for the graph."""
 786    kmax: int = 30
 787    """Maximum number of k-points to consider."""
 788    kthr: float = 1e-6
 789    """Threshold for k-point filtering."""
 790    mult_size: float = 1.05
 791    """Multiplicative factor for resizing the nblist."""
 792
 793    FPID: ClassVar[str] = "GRAPH_FILTER"
 794
 795    def init(self):
 796        return FrozenDict(
 797            {
 798                "npairs": 1,
 799                "nblist_mult_size": self.mult_size,
 800            }
 801        )
 802
 803    def get_processor(self) -> Tuple[nn.Module, Dict]:
 804        return GraphFilterProcessor, {
 805            "cutoff": self.cutoff,
 806            "graph_key": self.graph_key,
 807            "parent_graph": self.parent_graph,
 808            "name": f"{self.graph_key}_Filter_{self.parent_graph}",
 809            "switch_params": self.switch_params,
 810        }
 811
 812    def get_graph_properties(self):
 813        return {
 814            self.graph_key: {
 815                "cutoff": self.cutoff,
 816                "directed": True,
 817                "parent_graph": self.parent_graph,
 818            }
 819        }
 820
 821    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
 822        """filter a nblist on cpu with numpy and dynamic shapes + store max shapes"""
 823        graph_in = inputs[self.parent_graph]
 824        nat = inputs["species"].shape[0]
 825
 826        new_state = {**state}
 827        state_up = {}
 828        mult_size = state.get("nblist_mult_size", self.mult_size)
 829        assert mult_size >= 1., "nblist_mult_size should be >= 1."
 830
 831        edge_src = np.array(graph_in["edge_src"], dtype=np.int32)
 832        d12 = np.array(graph_in["d12"], dtype=np.float32)
 833        if self.remove_hydrogens:
 834            species = inputs["species"]
 835            src_idx = (edge_src < nat).nonzero()[0]
 836            mask = np.zeros(edge_src.shape[0], dtype=bool)
 837            mask[src_idx] = (species > 1)[edge_src[src_idx]]
 838            d12 = np.where(mask, d12, self.cutoff**2)
 839        mask = d12 < self.cutoff**2
 840
 841        max_pairs = state.get("npairs", 1)
 842        idx = np.nonzero(mask)[0]
 843        npairs = idx.shape[0]
 844        if npairs > max_pairs or add_margin:
 845            prev_max_pairs = max_pairs
 846            max_pairs = int(mult_size * max(npairs, max_pairs)) + 1
 847            state_up["npairs"] = (max_pairs, prev_max_pairs)
 848            new_state["npairs"] = max_pairs
 849
 850        filter_indices = np.full(max_pairs, edge_src.shape[0], dtype=np.int32)
 851        edge_src = np.full(max_pairs, nat, dtype=np.int32)
 852        edge_dst = np.full(max_pairs, nat, dtype=np.int32)
 853        d12_ = np.full(max_pairs, self.cutoff**2)
 854        filter_indices[:npairs] = idx
 855        edge_src[:npairs] = graph_in["edge_src"][idx]
 856        edge_dst[:npairs] = graph_in["edge_dst"][idx]
 857        d12_[:npairs] = d12[idx]
 858        d12 = d12_
 859
 860        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
 861        graph_out = {
 862            **graph,
 863            "edge_src": edge_src,
 864            "edge_dst": edge_dst,
 865            "filter_indices": filter_indices,
 866            "d12": d12,
 867            "overflow": False,
 868        }
 869
 870        if self.k_space and "cells" in inputs:
 871            if "k_points" not in graph:
 872                ks, _, _, bewald = get_reciprocal_space_parameters(
 873                    inputs["reciprocal_cells"], self.cutoff, self.kmax, self.kthr
 874                )
 875            graph_out["k_points"] = ks
 876            graph_out["b_ewald"] = bewald
 877
 878        output = {**inputs, self.graph_key: graph_out}
 879        if return_state_update:
 880            return FrozenDict(new_state), output, state_up
 881        return FrozenDict(new_state), output
 882
 883    def check_reallocate(self, state, inputs, parent_overflow=False):
 884        """check for overflow and reallocate nblist if necessary"""
 885        overflow = parent_overflow or inputs[self.graph_key].get("overflow", False)
 886        if not overflow:
 887            return state, {}, inputs, False
 888
 889        add_margin = inputs[self.graph_key].get("overflow", False)
 890        state, inputs, state_up = self(
 891            state, inputs, return_state_update=True, add_margin=add_margin
 892        )
 893        return state, state_up, inputs, True
 894
 895    @partial(jax.jit, static_argnums=(0, 1))
 896    def process(self, state, inputs):
 897        """filter a nblist on accelerator with jax and precomputed shapes"""
 898        graph_in = inputs[self.parent_graph]
 899        if state is None:
 900            # skin update mode
 901            graph = inputs[self.graph_key]
 902            max_pairs = graph["edge_src"].shape[0]
 903        else:
 904            max_pairs = state.get("npairs", 1)
 905
 906        max_pairs_in = graph_in["edge_src"].shape[0]
 907        nat = inputs["species"].shape[0]
 908
 909        edge_src = graph_in["edge_src"]
 910        d12 = graph_in["d12"]
 911        if self.remove_hydrogens:
 912            species = inputs["species"]
 913            mask = (species > 1)[edge_src]
 914            d12 = jnp.where(mask, d12, self.cutoff**2)
 915        mask = d12 < self.cutoff**2
 916
 917        (edge_src, edge_dst, d12, filter_indices), _, npairs = mask_filter_1d(
 918            mask,
 919            max_pairs,
 920            (edge_src, nat),
 921            (graph_in["edge_dst"], nat),
 922            (d12, self.cutoff**2),
 923            (jnp.arange(max_pairs_in, dtype=jnp.int32), max_pairs_in),
 924        )
 925
 926        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
 927        overflow = graph.get("overflow", False) | (npairs > max_pairs)
 928        graph_out = {
 929            **graph,
 930            "edge_src": edge_src,
 931            "edge_dst": edge_dst,
 932            "filter_indices": filter_indices,
 933            "d12": d12,
 934            "overflow": overflow,
 935        }
 936
 937        if self.k_space and "cells" in inputs:
 938            if "k_points" not in graph:
 939                raise NotImplementedError(
 940                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
 941                )
 942
 943        return {**inputs, self.graph_key: graph_out}
 944
 945    @partial(jax.jit, static_argnums=(0,))
 946    def update_skin(self, inputs):
 947        return self.process(None, inputs)
 948
 949
 950class GraphFilterProcessor(nn.Module):
 951    """Filter processing for a pre-generated graph
 952
 953    This module is automatically added to a FENNIX model when a GraphFilter is used.
 954    """
 955
 956    cutoff: float
 957    """Cutoff distance for the filtering."""
 958    graph_key: str
 959    """Key of the filtered graph in the inputs."""
 960    parent_graph: str
 961    """Key of the parent graph in the inputs."""
 962    switch_params: dict = dataclasses.field(default_factory=dict)
 963    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""
 964
 965    @nn.compact
 966    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
 967        graph_in = inputs[self.parent_graph]
 968        graph = inputs[self.graph_key]
 969
 970        if graph_in["vec"].shape[0] == 0:
 971            vec = graph_in["vec"]
 972            distances = graph_in["distances"]
 973            filter_indices = jnp.asarray([], dtype=jnp.int32)
 974        else:
 975            filter_indices = graph["filter_indices"]
 976            vec = (
 977                graph_in["vec"]
 978                .at[filter_indices]
 979                .get(mode="fill", fill_value=self.cutoff)
 980            )
 981            distances = (
 982                graph_in["distances"]
 983                .at[filter_indices]
 984                .get(mode="fill", fill_value=self.cutoff)
 985            )
 986
 987        edge_mask = distances < self.cutoff
 988        switch = SwitchFunction(
 989            **{**self.switch_params, "cutoff": self.cutoff, "graph_key": None}
 990        )((distances, edge_mask))
 991
 992        graph_out = {
 993            **graph,
 994            "vec": vec,
 995            "distances": distances,
 996            "switch": switch,
 997            "filter_indices": filter_indices,
 998            "edge_mask": edge_mask,
 999        }
1000
1001        if "alch_group" in inputs:
1002            edge_src=graph["edge_src"]
1003            edge_dst=graph["edge_dst"]
1004            alch_group = inputs["alch_group"]
1005            lambda_e = inputs["alch_elambda"]
1006            lambda_e = 0.5*(1.-jnp.cos(jnp.pi*lambda_e))
1007            mask = alch_group[edge_src] == alch_group[edge_dst]
1008            graph_out["switch_raw"] = switch
1009            graph_out["switch"] = jnp.where(
1010                mask,
1011                switch,
1012                lambda_e * switch ,
1013            )
1014
1015        return {**inputs, self.graph_key: graph_out}
1016
1017
@dataclasses.dataclass(frozen=True)
class GraphAngularExtension:
    """Add angles list to a graph

    FPID: GRAPH_ANGULAR_EXTENSION

    An angle is stored as a pair of edge indices (`angle_src`, `angle_dst`)
    that share the same source atom (`central_atom`). Unused slots are padded
    with `nedge` (edge indices) and `nat` (central atom).
    """

    mult_size: float = 1.05
    """Multiplicative factor for resizing the nblist."""
    add_neigh: int = 5
    """Additional neighbors to add to the nblist when resizing."""
    graph_key: str = "graph"
    """Key of the graph in the inputs."""

    FPID: ClassVar[str] = "GRAPH_ANGULAR_EXTENSION"

    def init(self):
        """Return the initial preprocessing state of the extension."""
        return FrozenDict(
            {
                "nangles": 0,
                "nblist_mult_size": self.mult_size,
                "max_neigh": self.add_neigh,
                "add_neigh": self.add_neigh,
            }
        )

    def get_processor(self) -> Tuple[nn.Module, Dict]:
        """Return the module class and keyword arguments of the matching processor."""
        return GraphAngleProcessor, {
            "graph_key": self.graph_key,
            "name": f"{self.graph_key}_AngleProcessor",
        }

    def get_graph_properties(self):
        """Return the static properties this extension adds to the graph."""
        return {
            self.graph_key: {
                "has_angles": True,
            }
        }

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """build angle nblist on cpu with numpy and dynamic shapes + store max shapes

        Args:
            state: mapping with the current allocation (`nangles`,
                `max_neigh`, `add_neigh`, `nblist_mult_size`).
            inputs: system dictionary; must contain `species` and the graph
                (with `edge_src`).
            return_state_update: if True, also return the changed state
                entries as `(new_value, previous_value)` tuples.
            add_margin: force a resize of the allocations even if the
                current ones are large enough.

        Returns:
            `(new_state, outputs)` or `(new_state, outputs, state_update)`.
        """
        graph = inputs[self.graph_key]
        edge_src = np.array(graph["edge_src"], dtype=np.int32)

        new_state = {**state}
        state_up = {}
        mult_size = state.get("nblist_mult_size", self.mult_size)
        assert mult_size >= 1., "nblist_mult_size should be >= 1."

        ### count number of neighbors
        nat = inputs["species"].shape[0]
        # one extra slot collects edges whose source index is `nat`;
        # it is excluded from the maximum
        count = np.zeros(nat + 1, dtype=np.int32)
        np.add.at(count, edge_src, 1)
        max_count = int(np.max(count[:-1]))

        ### get sizes
        max_neigh = state.get("max_neigh", self.add_neigh)
        if max_count > max_neigh or add_margin:
            prev_max_neigh = max_neigh
            max_neigh = max(max_count, max_neigh) + state.get(
                "add_neigh", self.add_neigh
            )
            state_up["max_neigh"] = (max_neigh, prev_max_neigh)
            new_state["max_neigh"] = max_neigh

        # dummy array whose length carries `max_neigh` as a static shape
        # (read back by `process` in skin update mode)
        max_neigh_arr = np.empty(max_neigh, dtype=bool)

        nedge = edge_src.shape[0]

        ### sort edge_src
        idx_sort = np.argsort(edge_src)
        edge_src_sorted = edge_src[idx_sort]

        ### map sparse to dense nblist
        # After sorting, the edges of one atom are contiguous and at most
        # `max_count` long, so `position % max_count` is a distinct column
        # for each of them.
        if max_count * nat >= nedge:
            offset = np.tile(np.arange(max_count), nat)[:nedge]
        else:
            offset = np.zeros(nedge, dtype=np.int32)
            offset[: max_count * nat] = np.tile(np.arange(max_count), nat)

        mask = edge_src_sorted < nat
        indices = edge_src_sorted * max_count + offset
        indices = indices[mask]
        idx_sort = idx_sort[mask]
        edge_idx = np.full(nat * max_count, nedge, dtype=np.int32)
        edge_idx[indices] = idx_sort
        edge_idx = edge_idx.reshape(nat, max_count)

        ### find all triplet for each atom center
        local_src, local_dst = np.triu_indices(max_count, 1)
        angle_src = edge_idx[:, local_src].flatten()
        angle_dst = edge_idx[:, local_dst].flatten()

        ### mask for valid angles
        mask1 = angle_src < nedge
        mask2 = angle_dst < nedge
        angle_mask = mask1 & mask2

        max_angles = state.get("nangles", 0)
        idx = np.nonzero(angle_mask)[0]
        nangles = idx.shape[0]
        if nangles > max_angles or add_margin:
            max_angles_prev = max_angles
            max_angles = int(mult_size * max(nangles, max_angles)) + 1
            state_up["nangles"] = (max_angles, max_angles_prev)
            new_state["nangles"] = max_angles

        ## filter angles to sparse representation
        angle_src_ = np.full(max_angles, nedge, dtype=np.int32)
        angle_dst_ = np.full(max_angles, nedge, dtype=np.int32)
        angle_src_[:nangles] = angle_src[idx]
        angle_dst_[:nangles] = angle_dst[idx]

        central_atom = np.full(max_angles, nat, dtype=np.int32)
        central_atom[:nangles] = edge_src[angle_src_[:nangles]]

        ## update graph
        output = {
            **inputs,
            self.graph_key: {
                **graph,
                "angle_src": angle_src_,
                "angle_dst": angle_dst_,
                "central_atom": central_atom,
                "angle_overflow": False,
                "max_neigh": max_neigh,
                "__max_neigh_array": max_neigh_arr,
            },
        }

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary

        Returns `(state, state_update, inputs, reallocated)`. The numpy
        routine is re-run when either the parent or the angle list reported
        an overflow; the margin is only added for an angle overflow.
        """
        overflow = parent_overflow or inputs[self.graph_key]["angle_overflow"]
        if not overflow:
            return state, {}, inputs, False

        add_margin = inputs[self.graph_key]["angle_overflow"]
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """build angle nblist on accelerator with jax and precomputed shapes

        With `state=None` (skin update mode) the sizes are read from the
        shapes of `__max_neigh_array` and `angle_src` in the existing graph;
        otherwise they come from `state`. `angle_overflow` is raised when
        the number of angles or of neighbors exceeds the allocation.

        Raises:
            ValueError: if `max_neigh * nat < nedge`.
        """
        graph = inputs[self.graph_key]
        edge_src = graph["edge_src"]

        ### count number of neighbors
        nat = inputs["species"].shape[0]
        count = jnp.zeros(nat, dtype=jnp.int32).at[edge_src].add(1, mode="drop")
        max_count = jnp.max(count)

        ### get sizes
        if state is None:
            max_neigh_arr = graph["__max_neigh_array"]
            max_neigh = max_neigh_arr.shape[0]
            prev_nangles = graph["angle_src"].shape[0]
        else:
            max_neigh = state.get("max_neigh", self.add_neigh)
            max_neigh_arr = jnp.empty(max_neigh, dtype=bool)
            prev_nangles = state.get("nangles", 0)

        nedge = edge_src.shape[0]

        ### sort edge_src
        idx_sort = jnp.argsort(edge_src).astype(jnp.int32)
        edge_src_sorted = edge_src[idx_sort]

        ### map sparse to dense nblist
        if max_neigh * nat < nedge:
            raise ValueError("Found max_neigh*nat < nedge. This should not happen.")
        offset = jnp.asarray(
            np.tile(np.arange(max_neigh), nat)[:nedge], dtype=jnp.int32
        )
        indices = edge_src_sorted * max_neigh + offset
        # edges with source index `nat` map past the end and are dropped
        edge_idx = (
            jnp.full(nat * max_neigh, nedge, dtype=jnp.int32)
            .at[indices]
            .set(idx_sort, mode="drop")
            .reshape(nat, max_neigh)
        )

        ### find all triplet for each atom center
        local_src, local_dst = np.triu_indices(max_neigh, 1)
        angle_src = edge_idx[:, local_src].flatten()
        angle_dst = edge_idx[:, local_dst].flatten()

        ### mask for valid angles
        mask1 = angle_src < nedge
        mask2 = angle_dst < nedge
        angle_mask = mask1 & mask2

        ## filter angles to sparse representation
        (angle_src, angle_dst), _, nangles = mask_filter_1d(
            angle_mask,
            prev_nangles,
            (angle_src, nedge),
            (angle_dst, nedge),
        )
        ## find central atom
        # Padded angle slots hold the out-of-range index `nedge`. A plain
        # gather would clamp them onto the last edge (possibly a real atom);
        # fill with `nat` instead, like the numpy routine does.
        central_atom = edge_src.at[angle_src].get(mode="fill", fill_value=nat)

        ## check for overflow
        angle_overflow = nangles > prev_nangles
        neigh_overflow = max_count > max_neigh
        overflow = graph.get("angle_overflow", False) | angle_overflow | neigh_overflow

        ## update graph
        output = {
            **inputs,
            self.graph_key: {
                **graph,
                "angle_src": angle_src,
                "angle_dst": angle_dst,
                "central_atom": central_atom,
                "angle_overflow": overflow,
                "__max_neigh_array": max_neigh_arr,
            },
        }

        return output

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        """Rebuild the angle list, reusing the shapes stored in the existing graph."""
        return self.process(None, inputs)
1253
1254
class GraphAngleProcessor(nn.Module):
    """Process a pre-generated graph to compute angles

    Reads `vec`, `distances`, `angle_src` and `angle_dst` from the graph and
    adds `angles`, the angle between the two edges of each listed pair.

    This module is automatically added to a FENNIX model when a GraphAngularExtension is used.

    """

    graph_key: str
    """Key of the graph in the inputs."""

    @nn.compact
    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
        graph = inputs[self.graph_key]
        distances = graph["distances"]
        vec = graph["vec"]
        angle_src = graph["angle_src"]
        angle_dst = graph["angle_dst"]

        # Unit vectors along each edge. The lower bound avoids a division by
        # zero and is passed positionally so the call works with both the
        # old (`a_min`) and new (`min`) keyword names of `jnp.clip`.
        unit_vec = vec / jnp.clip(distances[:, None], 1.0e-5)
        # out-of-range edge indices read a constant vector
        cos_angles = (
            unit_vec.at[angle_src].get(mode="fill", fill_value=0.5)
            * unit_vec.at[angle_dst].get(mode="fill", fill_value=0.5)
        ).sum(axis=-1)

        # Scaling by 0.95 keeps the argument away from +/-1 for unit
        # vectors (presumably to keep arccos and its gradient finite).
        angles = jnp.arccos(0.95 * cos_angles)

        return {
            **inputs,
            self.graph_key: {
                **graph,
                "angles": angles,
            },
        }
1290
1291
@dataclasses.dataclass(frozen=True)
class SpeciesIndexer:
    """Build an index that splits atomic arrays by species.

    FPID: SPECIES_INDEXER

    If `species_order` is specified, the output will be a dense array with size (len(species_order), max_size) that can directly index atomic arrays.
    If `species_order` is None, the output will be a dictionary with species as keys and an index to filter atomic arrays for that species as values.

    Unused slots of the index are padded with the number of atoms.
    """

    output_key: str = "species_index"
    """Key for the output dictionary."""
    species_order: Optional[str] = None
    """Comma separated list of species in the order they should be indexed."""
    add_atoms: int = 0
    """Additional atoms to add to the sizes."""
    add_atoms_margin: int = 10
    """Additional atoms to add to the sizes when adding margin."""

    FPID: ClassVar[str] = "SPECIES_INDEXER"

    def init(self):
        """Return the initial preprocessing state (no species size recorded)."""
        return FrozenDict(
            {
                "sizes": {},
            }
        )

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """Build the species index on cpu with numpy and update the stored sizes.

        Args:
            state: mapping with the current allocation (`sizes`, and
                `max_size` when `species_order` is set).
            inputs: system dictionary; must contain `species`.
            return_state_update: if True, also return the changed state
                entries as `(new_value, previous_value)` tuples.
            add_margin: add the extra `add_atoms_margin` when a species
                allocation has to grow.

        Returns:
            `(new_state, outputs)` or `(new_state, outputs, state_update)`.
        """
        species = np.array(inputs["species"], dtype=np.int32)
        nat = species.shape[0]
        set_species, counts = np.unique(species, return_counts=True)

        new_state = {**state}
        state_up = {}

        sizes = state.get("sizes", FrozenDict({}))
        new_sizes = {**sizes}
        up_sizes = False
        counts_dict = {}
        for s, c in zip(set_species, counts):
            # non-positive species numbers are not indexed
            if s <= 0:
                continue
            counts_dict[s] = c
            if c > sizes.get(s, 0):
                up_sizes = True
                add_atoms = state.get("add_atoms", self.add_atoms)
                if add_margin:
                    add_atoms += state.get("add_atoms_margin", self.add_atoms_margin)
                new_sizes[s] = c + add_atoms

        new_sizes = FrozenDict(new_sizes)
        if up_sizes:
            state_up["sizes"] = (new_sizes, sizes)
            new_state["sizes"] = new_sizes

        if self.species_order is not None:
            species_order = [el.strip() for el in self.species_order.split(",")]
            max_size_prev = state.get("max_size", 0)
            max_size = max(new_sizes.values())
            if max_size > max_size_prev:
                state_up["max_size"] = (max_size, max_size_prev)
                new_state["max_size"] = max_size

            species_index = np.full((len(species_order), max_size), nat, dtype=np.int32)
            for i, el in enumerate(species_order):
                s = PERIODIC_TABLE_REV_IDX[el]
                if s in counts_dict:
                    species_index[i, : counts_dict[s]] = np.nonzero(species == s)[0]
        else:
            species_index = {
                PERIODIC_TABLE[s]: np.full(c, nat, dtype=np.int32)
                for s, c in new_sizes.items()
            }
            for s, c in zip(set_species, counts):
                if s <= 0:
                    continue
                species_index[PERIODIC_TABLE[s]][:c] = np.nonzero(species == s)[0]

        output = {
            **inputs,
            self.output_key: species_index,
            self.output_key + "_overflow": False,
        }

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary

        Returns `(state, state_update, inputs, reallocated)`. The numpy
        routine is re-run when either the parent or this index reported an
        overflow; the margin is only added for an overflow of this index.
        """
        overflow = parent_overflow or inputs[self.output_key + "_overflow"]
        if not overflow:
            return state, {}, inputs, False

        add_margin = inputs[self.output_key + "_overflow"]
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """Build the species index on accelerator with jax and precomputed sizes.

        The index already present in `inputs` is returned untouched unless
        the `recompute_species_index` flag is set in `inputs["flags"]`.

        Raises:
            ValueError: if the index has to be computed and `state` is None.
        """
        recompute_species_index = "recompute_species_index" in inputs.get("flags", {})
        if self.output_key in inputs and not recompute_species_index:
            return inputs

        if state is None:
            raise ValueError("Species Indexer state must be provided on accelerator.")

        species = inputs["species"]
        nat = species.shape[0]

        sizes = state["sizes"]
        # Defined up front so that both layouts below report an overflow
        # flag; the dense layout previously left it unbound.
        overflow = False

        if self.species_order is not None:
            species_order = [el.strip() for el in self.species_order.split(",")]
            max_size = state["max_size"]

            species_index = jnp.full(
                (len(species_order), max_size), nat, dtype=jnp.int32
            )
            for i, el in enumerate(species_order):
                s = PERIODIC_TABLE_REV_IDX[el]
                mask = species == s
                n_found = jnp.sum(mask)
                if s in sizes:
                    # a row holds at most `max_size` atoms
                    overflow = overflow | (n_found > max_size)
                    species_index = species_index.at[i, :].set(
                        jnp.nonzero(mask, size=max_size, fill_value=nat)[0]
                    )
                else:
                    # A requested species with no recorded size keeps an
                    # all-padding row; flag it if such atoms are present so
                    # the numpy routine gets re-run.
                    overflow = overflow | (n_found > 0)
        else:
            species_index = {}
            natcount = 0
            for s, c in sizes.items():
                mask = species == s
                new_size = jnp.sum(mask)
                natcount = natcount + new_size
                overflow = overflow | (new_size > c)  # check if sizes are correct
                species_index[PERIODIC_TABLE[s]] = jnp.nonzero(
                    mask, size=c, fill_value=nat
                )[0]

            # atoms with non-positive species are never indexed but still
            # count towards the total
            natcount = natcount + jnp.sum(species <= 0)
            overflow = overflow | (
                natcount < species.shape[0]
            )  # check if any species missing

        return {
            **inputs,
            self.output_key: species_index,
            self.output_key + "_overflow": overflow,
        }

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        """Run `process` without a state (only valid when the index is already in `inputs`)."""
        return self.process(None, inputs)
1463
@dataclasses.dataclass(frozen=True)
class BlockIndexer:
    """Build an index that splits atomic arrays by chemical blocks.

    FPID: BLOCK_INDEXER

    The output is a dictionary keyed by chemical block name. A block that has
    a registered size maps to an int32 index array of that size, filled with
    the indices of the atoms belonging to the block and padded with the total
    number of atoms. Blocks without a registered size map to None.
    """

    output_key: str = "block_index"
    """Key for the output dictionary."""
    add_atoms: int = 0
    """Additional atoms to add to the sizes."""
    add_atoms_margin: int = 10
    """Additional atoms to add to the sizes when adding margin."""
    split_CNOPSSe: bool = False
    """Give C, N, O, P, S and Se dedicated entries in the block table."""

    FPID: ClassVar[str] = "BLOCK_INDEXER"

    def init(self):
        """Return the initial state, in which no block size is registered."""
        return FrozenDict({"sizes": {}})

    def build_chemical_blocks(self):
        """Return copies of the block names and of the species -> block table.

        When `split_CNOPSSe` is set, block 1 is renamed "C", the names
        "N", "O", "P", "S", "Se" are appended, and the table entries of
        species 6, 7, 8, 15, 16 and 34 are redirected to those blocks.
        """
        names = CHEMICAL_BLOCKS_NAMES.copy()
        table = CHEMICAL_BLOCKS.copy()
        if not self.split_CNOPSSe:
            return names, table

        # index of the first appended name in the extended list
        first_new = len(CHEMICAL_BLOCKS_NAMES)
        names[1] = "C"
        names.extend(["N", "O", "P", "S", "Se"])
        table[6] = 1
        for offset, number in enumerate((7, 8, 15, 16, 34)):
            table[number] = first_new + offset
        return names, table

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """Build the block index with numpy and grow the stored sizes if needed.

        Args:
            state: mapping holding the registered block sizes under "sizes"
                (optionally "add_atoms" / "add_atoms_margin" overrides).
            inputs: dictionary containing at least "species".
            return_state_update: also return the dictionary of changed state
                entries, each stored as a `(new, old)` pair.
            add_margin: add the margin to the extra atoms of any block whose
                size is increased.

        Returns:
            `(new_state, output)` or `(new_state, output, state_up)`.
        """
        block_names, block_table = self.build_chemical_blocks()

        species = np.array(inputs["species"], dtype=np.int32)
        blocks = block_table[species]
        nat = species.shape[0]
        found_blocks, found_counts = np.unique(blocks, return_counts=True)
        # blocks with a negative id are ignored
        present = [(b, n) for b, n in zip(found_blocks, found_counts) if b >= 0]

        old_sizes = state.get("sizes", FrozenDict({}))
        sizes = {**old_sizes}
        grown = False
        for block_id, count in present:
            key = (block_id, block_names[block_id])
            if count > old_sizes.get(key, 0):
                grown = True
                extra = state.get("add_atoms", self.add_atoms)
                if add_margin:
                    extra += state.get("add_atoms_margin", self.add_atoms_margin)
                sizes[key] = count + extra
        sizes = FrozenDict(sizes)

        new_state = {**state}
        state_up = {}
        if grown:
            state_up["sizes"] = (sizes, old_sizes)
            new_state["sizes"] = sizes

        # every known block gets an entry; only sized blocks get an array
        block_index = dict.fromkeys(block_names)
        for (_, name), size in sizes.items():
            block_index[name] = np.full(size, nat, dtype=np.int32)
        for block_id, count in present:
            block_index[block_names[block_id]][:count] = np.nonzero(
                blocks == block_id
            )[0]

        output = {
            **inputs,
            self.output_key: block_index,
            self.output_key + "_overflow": False,
        }

        frozen_state = FrozenDict(new_state)
        if return_state_update:
            return frozen_state, output, state_up
        return frozen_state, output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """Rebuild the block index if this layer or a parent layer overflowed.

        Returns `(state, state_up, inputs, overflow)`. The margin is only
        added when the overflow flag of this layer itself is set.
        """
        own_overflow = inputs[self.output_key + "_overflow"]
        if not (parent_overflow or own_overflow):
            return state, {}, inputs, False

        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=own_overflow
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """Build the block index with jax using the sizes stored in `state`.

        If the index is already in `inputs` and the
        "recompute_species_index" flag is absent, `inputs` is returned as is.
        The "_overflow" entry of the output is set when a block holds more
        atoms than its registered size, or when some atoms are counted in no
        registered block and have no negative block id.

        Raises:
            ValueError: if the index must be computed and `state` is None.
        """
        block_names, block_table = self.build_chemical_blocks()

        force_recompute = "recompute_species_index" in inputs.get("flags", {})
        if self.output_key in inputs and not force_recompute:
            return inputs

        if state is None:
            raise ValueError("Block Indexer state must be provided on accelerator.")

        species = inputs["species"]
        blocks = jnp.asarray(block_table)[species]
        nat = species.shape[0]

        block_index = dict.fromkeys(block_names)
        overflow = False
        n_assigned = 0
        for (block_id, name), size in state["sizes"].items():
            in_block = blocks == block_id
            count = jnp.sum(in_block)
            n_assigned = n_assigned + count
            # more atoms in the block than the allocated size
            overflow = overflow | (count > size)
            block_index[name] = jnp.nonzero(in_block, size=size, fill_value=nat)[0]

        # atoms with a negative block id also count as accounted for
        n_assigned = n_assigned + jnp.sum(blocks < 0)
        # fewer counted atoms than entries => a block is missing from the sizes
        overflow = overflow | (n_assigned < species.shape[0])

        return {
            **inputs,
            self.output_key: block_index,
            self.output_key + "_overflow": overflow,
        }

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        """Call `process` without a state."""
        return self.process(None, inputs)
1624
1625
@dataclasses.dataclass(frozen=True)
class AtomPadding:
    """Pad atomic arrays to a fixed size.

    The padded sizes are remembered in the state ("prev_nat", "prev_nsys")
    and only grow, so successive calls with smaller inputs reuse them.
    """

    mult_size: float = 1.2
    """Multiplicative factor for resizing the atomic arrays."""
    add_sys: int = 0
    """Extra systems allocated when the number of systems exceeds the stored one."""

    def init(self):
        """Return the initial state: no atoms and no systems seen yet."""
        return {"prev_nat": 0, "prev_nsys": 0}

    def __call__(self, state, inputs: Dict) -> Tuple[Any, Dict[str, Any]]:
        """Pad the per-atom and per-system arrays of `inputs`.

        Arrays are classified by their leading dimension: those matching the
        number of atoms are padded with zeros along axis 0, those matching
        the number of systems are padded with zeros ("cells" is padded with
        1000 * identity instead). 0-d arrays and non-array values are copied
        unchanged. Padding atoms get species -1 and are attached to the last
        padded system. When atoms are padded, at least one padding system is
        appended.

        Args:
            state: mapping with the previously allocated sizes.
            inputs: dictionary containing at least "species", "natoms" and
                "batch_index".

        Returns:
            `(new_state, output)`; `output` additionally holds the boolean
            masks "true_atoms" (species > 0) and "true_sys".
        """
        species = inputs["species"]
        nat = species.shape[0]

        # grow the allocated number of atoms only when the input exceeds it
        alloc_nat = state.get("prev_nat", 0)
        if nat > alloc_nat:
            alloc_nat = int(self.mult_size * nat) + 1

        nsys = len(inputs["natoms"])
        alloc_nsys = state.get("prev_nsys", 0)
        if nsys > alloc_nsys:
            alloc_nsys = nsys + self.add_sys

        add_atoms = alloc_nat - nat
        # one more system than allocated, to host the padding atoms
        add_sys = alloc_nsys - nsys + 1

        output = {**inputs}
        if add_atoms > 0:
            for key, value in inputs.items():
                if not isinstance(value, (np.ndarray, jax.Array)):
                    continue
                if value.ndim == 0:
                    # scalar arrays have no leading axis to pad (and
                    # `shape[0]` would raise IndexError)
                    continue
                lead = value.shape[0]
                if lead == nat:
                    filler = np.zeros(
                        (add_atoms, *value.shape[1:]), dtype=value.dtype
                    )
                elif lead == nsys:
                    if key == "cells":
                        # padding systems get a large cubic cell
                        filler = 1000 * np.eye(3, dtype=value.dtype)[
                            None, :, :
                        ].repeat(add_sys, axis=0)
                    else:
                        filler = np.zeros(
                            (add_sys, *value.shape[1:]), dtype=value.dtype
                        )
                else:
                    continue
                output[key] = np.append(value, filler, axis=0)

            # these entries get dedicated padding values
            output["natoms"] = np.append(
                inputs["natoms"], np.zeros(add_sys, dtype=np.int32)
            )
            output["species"] = np.append(
                species, -1 * np.ones(add_atoms, dtype=species.dtype)
            )
            last_sys = output["natoms"].shape[0] - 1
            output["batch_index"] = np.append(
                inputs["batch_index"],
                np.array([last_sys] * add_atoms, dtype=inputs["batch_index"].dtype),
            )
            if "system_index" in inputs:
                output["system_index"] = np.append(
                    inputs["system_index"],
                    np.array([last_sys] * add_sys, dtype=inputs["system_index"].dtype),
                )

        output["true_atoms"] = output["species"] > 0
        output["true_sys"] = np.arange(len(output["natoms"])) < nsys

        state = {**state, "prev_nat": alloc_nat, "prev_nsys": alloc_nsys}

        return FrozenDict(state), output
1702
1703
def atom_unpadding(inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Remove padding from atomic arrays.

    `inputs` is returned untouched when it has no "true_atoms" entry or when
    the index of the first non-positive species is 0. Otherwise arrays whose
    leading dimension equals the padded number of atoms (resp. systems) are
    filtered with "true_atoms" (resp. "true_sys"), and both masks are removed
    from the returned copy. 0-d arrays and non-array values are kept as is.
    """
    if "true_atoms" not in inputs:
        return inputs

    species = inputs["species"]
    atom_mask = inputs["true_atoms"]
    sys_mask = inputs["true_sys"]
    n_atoms_padded = species.shape[0]
    # argmax of the boolean array = index of the first non-positive species
    if np.argmax(species <= 0) == 0:
        return inputs

    n_sys_padded = len(inputs["natoms"])

    output = {**inputs}
    for key, value in inputs.items():
        if not isinstance(value, (jax.Array, np.ndarray)):
            continue
        if value.ndim == 0:
            continue
        lead = value.shape[0]
        if lead == n_atoms_padded:
            output[key] = value[atom_mask]
        elif lead == n_sys_padded:
            output[key] = value[sys_mask]

    for mask_key in ("true_sys", "true_atoms"):
        del output[mask_key]
    return output
1732
1733
def check_input(inputs):
    """Check the input dictionary and fill in the missing standard entries.

    "species" and "coordinates" are required. "natoms" defaults to a single
    system containing all atoms and "batch_index" to the system index of
    each atom; both are cast to int32. When "cells" is present, the
    reciprocal cells are taken from the inputs or computed by matrix
    inversion, and both are given a leading batch axis if they are 2D.

    Returns a new dictionary; `inputs` is not modified.
    """
    assert "species" in inputs, "species must be provided"
    assert "coordinates" in inputs, "coordinates must be provided"

    species = inputs["species"].astype(np.int32)
    first_fake = np.argmax(species <= 0)
    if first_fake > 0:
        # NOTE(review): `first_fake` is the first non-positive entry, so the
        # slice below only contains positive species by construction.
        assert np.all(species[:first_fake] > 0), "species must be positive"
    nat = inputs["species"].shape[0]

    single_system = np.array([nat], dtype=np.int32)
    natoms = inputs.get("natoms", single_system).astype(np.int32)
    system_of_atom = np.repeat(np.arange(len(natoms), dtype=np.int32), natoms)
    batch_index = inputs.get("batch_index", system_of_atom).astype(np.int32)

    output = {**inputs, "natoms": natoms, "batch_index": batch_index}
    if "cells" not in inputs:
        return output

    cells = inputs["cells"]
    if "reciprocal_cells" in inputs:
        reciprocal_cells = inputs["reciprocal_cells"]
    else:
        reciprocal_cells = np.linalg.inv(cells)
    # a single cell is promoted to a batch of one cell
    if cells.ndim == 2:
        cells = cells[None, :, :]
    if reciprocal_cells.ndim == 2:
        reciprocal_cells = reciprocal_cells[None, :, :]
    output["cells"] = cells
    output["reciprocal_cells"] = reciprocal_cells
    return output
1763
1764
def convert_to_jax(data):
    """Convert the numpy arrays of a pytree to jax arrays.

    Leaves that are not `np.ndarray` instances are returned unchanged.
    """
    return jax.tree_util.tree_map(
        lambda leaf: jnp.asarray(leaf) if isinstance(leaf, np.ndarray) else leaf,
        data,
    )
1776
1777
class JaxConverter(nn.Module):
    """Convert numpy arrays to jax arrays in a pytree."""

    def __call__(self, data):
        """Return `data` with its numpy array leaves converted by `convert_to_jax`."""
        return convert_to_jax(data)
1783
1784
@dataclasses.dataclass(frozen=True)
class PreprocessingChain:
    """Chain of preprocessing layers.

    The chain state is a mapping with a "check_input" flag and a
    "layers_state" sequence holding one state per layer. When
    `use_atom_padding` is set, the first entry of "layers_state" belongs to
    the atom padder and the layer states follow.
    """

    layers: Tuple[Callable[..., Dict[str, Any]]]
    """Preprocessing layers."""
    use_atom_padding: bool = False
    """Add an AtomPadding layer at the beginning of the chain."""
    atom_padder: AtomPadding = AtomPadding()
    """AtomPadding layer."""

    def __post_init__(self):
        """Validate that `layers` is a non-empty sequence."""
        if not isinstance(self.layers, Sequence):
            raise ValueError(
                f"'layers' must be a sequence, got '{type(self.layers).__name__}'."
            )
        if not self.layers:
            raise ValueError("Error: no Preprocessing layers were provided.")

    def __call__(self, state, inputs: Dict[str, Any]) -> Tuple[Any, Dict[str, Any]]:
        """Run every layer's `__call__` in order.

        `check_input` is applied first unless the state disables it. Returns
        the updated chain state and the outputs converted with
        `convert_to_jax`.
        """
        if state.get("check_input", True):
            inputs = check_input(inputs)

        layer_state = state["layers_state"]
        new_state = []
        first_layer = 0
        if self.use_atom_padding:
            padder_state, inputs = self.atom_padder(layer_state[0], inputs)
            new_state.append(padder_state)
            first_layer = 1
        for i, layer in enumerate(self.layers, start=first_layer):
            s, inputs = layer(layer_state[i], inputs, return_state_update=False)
            new_state.append(s)

        return FrozenDict({**state, "layers_state": tuple(new_state)}), convert_to_jax(
            inputs
        )

    def check_reallocate(self, state, inputs):
        """Let every layer check for overflow and reallocate if needed.

        The overflow flag returned by a layer is forwarded to the next one.
        Returns `(state, state_up, inputs, overflow)`; when nothing
        overflowed, the original state and an empty `state_up` are returned.
        `state_up` holds one entry per layer (none for the atom padder).
        """
        layer_state = state["layers_state"]
        new_state = []
        state_up = []
        first_layer = 0
        if self.use_atom_padding:
            # the padder state is carried over unchanged
            new_state.append(layer_state[0])
            first_layer = 1

        parent_overflow = False
        for i, layer in enumerate(self.layers, start=first_layer):
            s, s_up, inputs, parent_overflow = layer.check_reallocate(
                layer_state[i], inputs, parent_overflow
            )
            new_state.append(s)
            state_up.append(s_up)

        if not parent_overflow:
            return state, {}, inputs, False
        return (
            FrozenDict({**state, "layers_state": tuple(new_state)}),
            state_up,
            inputs,
            True,
        )

    def atom_padding(self, state, inputs):
        """Apply the atom padder, if enabled, and return the chain state.

        Both branches return a chain state, so the result can be passed
        directly to `process` or `check_reallocate`. With padding enabled
        only the first entry of "layers_state" is replaced by the updated
        padder state.
        """
        if not self.use_atom_padding:
            return state, inputs

        layer_state = state["layers_state"]
        padder_state, inputs = self.atom_padder(layer_state[0], inputs)
        new_layers = (padder_state, *layer_state[1:])
        return FrozenDict({**state, "layers_state": new_layers}), inputs

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """Run every layer's `process` in order (jitted, `state` is static).

        The atom padder, if any, is skipped: only its slot in "layers_state"
        is accounted for.
        """
        layer_state = state["layers_state"]
        first_layer = 1 if self.use_atom_padding else 0
        for i, layer in enumerate(self.layers, start=first_layer):
            inputs = layer.process(layer_state[i], inputs)
        return inputs

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        """Run every layer's `update_skin` in order (jitted)."""
        for layer in self.layers:
            inputs = layer.update_skin(inputs)
        return inputs

    def init(self):
        """Return the initial chain state built from each layer's `init`."""
        state = []
        if self.use_atom_padding:
            state.append(self.atom_padder.init())
        for layer in self.layers:
            state.append(layer.init())
        return FrozenDict({"check_input": True, "layers_state": state})

    def init_with_output(self, inputs):
        """Initialize the state and immediately run the chain on `inputs`."""
        state = self.init()
        return self(state, inputs)

    def get_processors(self):
        """Collect `get_processor()` from the layers that define it."""
        return [
            layer.get_processor()
            for layer in self.layers
            if hasattr(layer, "get_processor")
        ]

    def get_graphs_properties(self):
        """Merge `get_graph_properties()` of the layers that define it."""
        properties = {}
        for layer in self.layers:
            if hasattr(layer, "get_graph_properties"):
                properties = deep_update(properties, layer.get_graph_properties())
        return properties
1895
1896
1897# PREPROCESSING = {
1898#     "GRAPH": GraphGenerator,
1899#     # "GRAPH_FIXED": GraphGeneratorFixed,
1900#     "GRAPH_FILTER": GraphFilter,
1901#     "GRAPH_ANGULAR_EXTENSION": GraphAngularExtension,
1902#     # "GRAPH_DENSE_EXTENSION": GraphDenseExtension,
1903#     "SPECIES_INDEXER": SpeciesIndexer,
1904# }
@dataclasses.dataclass(frozen=True)
class GraphGenerator:
 22@dataclasses.dataclass(frozen=True)
 23class GraphGenerator:
 24    """Generate a graph from a set of coordinates
 25
 26    FPID: GRAPH
 27
 28    For now, we generate all pairs of atoms and filter based on cutoff.
 29    If a `nblist_skin` is present in the state, we generate a second graph with a larger cutoff that includes all pairs within the cutoff+skin. This graph is then reused by the `update_skin` method to update the original graph without recomputing the full nblist.
 30    """
 31
 32    cutoff: float
 33    """Cutoff distance for the graph."""
 34    graph_key: str = "graph"
 35    """Key of the graph in the outputs."""
 36    switch_params: dict = dataclasses.field(default_factory=dict, hash=False)
 37    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""
 38    kmax: int = 30
 39    """Maximum number of k-points to consider."""
 40    kthr: float = 1e-6
 41    """Threshold for k-point filtering."""
 42    k_space: bool = False
 43    """Whether to generate k-space information for the graph."""
 44    mult_size: float = 1.05
 45    """Multiplicative factor for resizing the nblist."""
 46    # covalent_cutoff: bool = False
 47
 48    FPID: ClassVar[str] = "GRAPH"
 49
 50    def init(self):
 51        return FrozenDict(
 52            {
 53                "max_nat": 1,
 54                "npairs": 1,
 55                "nblist_mult_size": self.mult_size,
 56            }
 57        )
 58
 59    def get_processor(self) -> Tuple[nn.Module, Dict]:
 60        return GraphProcessor, {
 61            "cutoff": self.cutoff,
 62            "graph_key": self.graph_key,
 63            "switch_params": self.switch_params,
 64            "name": f"{self.graph_key}_Processor",
 65        }
 66
 67    def get_graph_properties(self):
 68        return {
 69            self.graph_key: {
 70                "cutoff": self.cutoff,
 71                "directed": True,
 72            }
 73        }
 74
 75    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
 76        """build a nblist on cpu with numpy and dynamic shapes + store max shapes"""
 77        if self.graph_key in inputs:
 78            graph = inputs[self.graph_key]
 79            if "keep_graph" in graph:
 80                return state, inputs
 81
 82        coords = np.array(inputs["coordinates"], dtype=np.float32)
 83        natoms = np.array(inputs["natoms"], dtype=np.int32)
 84        batch_index = np.array(inputs["batch_index"], dtype=np.int32)
 85
 86        new_state = {**state}
 87        state_up = {}
 88
 89        mult_size = state.get("nblist_mult_size", self.mult_size)
 90        assert mult_size >= 1.0, "mult_size should be larger or equal than 1.0"
 91
 92        if natoms.shape[0] == 1:
 93            max_nat = coords.shape[0]
 94            true_max_nat = max_nat
 95        else:
 96            max_nat = state.get("max_nat", round(coords.shape[0] / natoms.shape[0]))
 97            true_max_nat = int(np.max(natoms))
 98            if true_max_nat > max_nat:
 99                add_atoms = state.get("add_atoms", 0)
100                new_maxnat = true_max_nat + add_atoms
101                state_up["max_nat"] = (new_maxnat, max_nat)
102                new_state["max_nat"] = new_maxnat
103
104        cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0)
105
106        ### compute indices of all pairs
107        p1, p2 = np.triu_indices(true_max_nat, 1)
108        p1, p2 = p1.astype(np.int32), p2.astype(np.int32)
109        pbc_shifts = None
110        if natoms.shape[0] > 1:
111            ## batching => mask irrelevant pairs
112            mask_p12 = (
113                (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None])
114            ).flatten()
115            shift = np.concatenate(
116                (np.array([0], dtype=np.int32), np.cumsum(natoms[:-1], dtype=np.int32))
117            )
118            p1 = np.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1)
119            p2 = np.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1)
120
121        apply_pbc = "cells" in inputs
122        if not apply_pbc:
123            ### NO PBC
124            vec = coords[p2] - coords[p1]
125        else:
126            cells = np.array(inputs["cells"], dtype=np.float32)
127            reciprocal_cells = np.array(inputs["reciprocal_cells"], dtype=np.float32)
128            minimage = state.get("minimum_image", True)
129            if minimage:
130                ## MINIMUM IMAGE CONVENTION
131                vec = coords[p2] - coords[p1]
132                if cells.shape[0] == 1:
133                    vecpbc = np.dot(vec, reciprocal_cells[0])
134                    pbc_shifts = -np.round(vecpbc)
135                    vec = vec + np.dot(pbc_shifts, cells[0])
136                else:
137                    batch_index_vec = batch_index[p1]
138                    vecpbc = np.einsum(
139                        "aj,aji->ai", vec, reciprocal_cells[batch_index_vec]
140                    )
141                    pbc_shifts = -np.round(vecpbc)
142                    vec = vec + np.einsum(
143                        "aj,aji->ai", pbc_shifts, cells[batch_index_vec]
144                    )
145            else:
146                ### GENERAL PBC
147                ## put all atoms in central box
148                if cells.shape[0] == 1:
149                    coords_pbc = np.dot(coords, reciprocal_cells[0])
150                    at_shifts = -np.floor(coords_pbc)
151                    coords_pbc = coords + np.dot(at_shifts, cells[0])
152                else:
153                    coords_pbc = np.einsum(
154                        "aj,aji->ai", coords, reciprocal_cells[batch_index]
155                    )
156                    at_shifts = -np.floor(coords_pbc)
157                    coords_pbc = coords + np.einsum(
158                        "aj,aji->ai", at_shifts, cells[batch_index]
159                    )
160                vec = coords_pbc[p2] - coords_pbc[p1]
161
162                ## compute maximum number of repeats
163                inv_distances = (np.sum(reciprocal_cells**2, axis=1)) ** 0.5
164                cdinv = cutoff_skin * inv_distances
165                num_repeats_all = np.ceil(cdinv).astype(np.int32)
166                if "true_sys" in inputs:
167                    num_repeats_all = np.where(np.array(inputs["true_sys"],dtype=bool)[:, None], num_repeats_all, 0)
168                # num_repeats_all = np.where(cdinv < 0.5, 0, num_repeats_all)
169                num_repeats = np.max(num_repeats_all, axis=0)
170                num_repeats_prev = np.array(state.get("num_repeats_pbc", (0, 0, 0)))
171                if np.any(num_repeats > num_repeats_prev):
172                    num_repeats_new = np.maximum(num_repeats, num_repeats_prev)
173                    state_up["num_repeats_pbc"] = (
174                        tuple(num_repeats_new),
175                        tuple(num_repeats_prev),
176                    )
177                    new_state["num_repeats_pbc"] = tuple(num_repeats_new)
178                ## build all possible shifts
179                cell_shift_pbc = np.array(
180                    np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]),
181                    dtype=cells.dtype,
182                ).T.reshape(-1, 3)
183                ## shift applied to vectors
184                if cells.shape[0] == 1:
185                    dvec = np.dot(cell_shift_pbc, cells[0])[None, :, :]
186                    vec = (vec[:, None, :] + dvec).reshape(-1, 3)
187                    pbc_shifts = np.broadcast_to(
188                        cell_shift_pbc[None, :, :],
189                        (p1.shape[0], cell_shift_pbc.shape[0], 3),
190                    ).reshape(-1, 3)
191                    p1 = np.broadcast_to(
192                        p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0])
193                    ).flatten()
194                    p2 = np.broadcast_to(
195                        p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0])
196                    ).flatten()
197                    if natoms.shape[0] > 1:
198                        mask_p12 = np.broadcast_to(
199                            mask_p12[:, None],
200                            (mask_p12.shape[0], cell_shift_pbc.shape[0]),
201                        ).flatten()
202                else:
203                    dvec = np.einsum("bj,sji->sbi", cell_shift_pbc, cells)
204
205                    ## get pbc shifts specific to each box
206                    cell_shift_pbc = np.broadcast_to(
207                        cell_shift_pbc[None, :, :],
208                        (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3),
209                    )
210                    mask = np.all(
211                        np.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1
212                    ).flatten()
213                    idx = np.nonzero(mask)[0]
214                    nshifts = idx.shape[0]
215                    nshifts_prev = state.get("nshifts_pbc", 0)
216                    if nshifts > nshifts_prev or add_margin:
217                        nshifts_new = int(mult_size * max(nshifts, nshifts_prev)) + 1
218                        state_up["nshifts_pbc"] = (nshifts_new, nshifts_prev)
219                        new_state["nshifts_pbc"] = nshifts_new
220
221                    dvec_filter = dvec.reshape(-1, 3)[idx, :]
222                    cell_shift_pbc_filter = cell_shift_pbc.reshape(-1, 3)[idx, :]
223
224                    ## get batch shift in the dvec_filter array
225                    nrep = np.prod(2 * num_repeats_all + 1, axis=1)
226                    bshift = np.concatenate((np.array([0]), np.cumsum(nrep)[:-1]))
227
228                    ## compute vectors
229                    batch_index_vec = batch_index[p1]
230                    nrep_vec = np.where(mask_p12,nrep[batch_index_vec],0)
231                    vec = vec.repeat(nrep_vec, axis=0)
232                    nvec_pbc = nrep_vec.sum() #vec.shape[0]
233                    nvec_pbc_prev = state.get("nvec_pbc", 0)
234                    if nvec_pbc > nvec_pbc_prev or add_margin:
235                        nvec_pbc_new = int(mult_size * max(nvec_pbc, nvec_pbc_prev)) + 1
236                        state_up["nvec_pbc"] = (nvec_pbc_new, nvec_pbc_prev)
237                        new_state["nvec_pbc"] = nvec_pbc_new
238
239                    # print("cpu: ", nvec_pbc, nvec_pbc_prev, nshifts, nshifts_prev)
240                    ## get shift index
241                    dshift = np.concatenate(
242                        (np.array([0]), np.cumsum(nrep_vec)[:-1])
243                    ).repeat(nrep_vec)
244                    # ishift = np.arange(dshift.shape[0])-dshift
245                    # bshift_vec_rep = bshift[batch_index_vec].repeat(nrep_vec)
246                    icellshift = (
247                        np.arange(dshift.shape[0])
248                        - dshift
249                        + bshift[batch_index_vec].repeat(nrep_vec)
250                    )
251                    # shift vectors
252                    vec = vec + dvec_filter[icellshift]
253                    pbc_shifts = cell_shift_pbc_filter[icellshift]
254
255                    p1 = np.repeat(p1, nrep_vec)
256                    p2 = np.repeat(p2, nrep_vec)
257                    if natoms.shape[0] > 1:
258                        mask_p12 = np.repeat(mask_p12, nrep_vec)
259
260        ## compute distances
261        d12 = (vec**2).sum(axis=-1)
262        if natoms.shape[0] > 1:
263            d12 = np.where(mask_p12, d12, cutoff_skin**2)
264
265        ## filter pairs
266        max_pairs = state.get("npairs", 1)
267        mask = d12 < cutoff_skin**2
268        idx = np.nonzero(mask)[0]
269        npairs = idx.shape[0]
270        if npairs > max_pairs or add_margin:
271            prev_max_pairs = max_pairs
272            max_pairs = int(mult_size * max(npairs, max_pairs)) + 1
273            state_up["npairs"] = (max_pairs, prev_max_pairs)
274            new_state["npairs"] = max_pairs
275
276        nat = coords.shape[0]
277        edge_src = np.full(max_pairs, nat, dtype=np.int32)
278        edge_dst = np.full(max_pairs, nat, dtype=np.int32)
279        d12_ = np.full(max_pairs, cutoff_skin**2)
280        edge_src[:npairs] = p1[idx]
281        edge_dst[:npairs] = p2[idx]
282        d12_[:npairs] = d12[idx]
283        d12 = d12_
284
285        if apply_pbc:
286            pbc_shifts_ = np.zeros((max_pairs, 3))
287            pbc_shifts_[:npairs] = pbc_shifts[idx]
288            pbc_shifts = pbc_shifts_
289            if not minimage:
290                pbc_shifts[:npairs] = (
291                    pbc_shifts[:npairs]
292                    + at_shifts[edge_dst[:npairs]]
293                    - at_shifts[edge_src[:npairs]]
294                )
295
296        if "nblist_skin" in state:
297            edge_src_skin = edge_src
298            edge_dst_skin = edge_dst
299            if apply_pbc:
300                pbc_shifts_skin = pbc_shifts
301            max_pairs_skin = state.get("npairs_skin", 1)
302            mask = d12 < self.cutoff**2
303            idx = np.nonzero(mask)[0]
304            npairs_skin = idx.shape[0]
305            if npairs_skin > max_pairs_skin or add_margin:
306                prev_max_pairs_skin = max_pairs_skin
307                max_pairs_skin = int(mult_size * max(npairs_skin, max_pairs_skin)) + 1
308                state_up["npairs_skin"] = (max_pairs_skin, prev_max_pairs_skin)
309                new_state["npairs_skin"] = max_pairs_skin
310            edge_src = np.full(max_pairs_skin, nat, dtype=np.int32)
311            edge_dst = np.full(max_pairs_skin, nat, dtype=np.int32)
312            d12_ = np.full(max_pairs_skin, self.cutoff**2)
313            edge_src[:npairs_skin] = edge_src_skin[idx]
314            edge_dst[:npairs_skin] = edge_dst_skin[idx]
315            d12_[:npairs_skin] = d12[idx]
316            d12 = d12_
317            if apply_pbc:
318                pbc_shifts = np.full((max_pairs_skin, 3), 0.0)
319                pbc_shifts[:npairs_skin] = pbc_shifts_skin[idx]
320
321        ## symmetrize
322        edge_src, edge_dst = np.concatenate((edge_src, edge_dst)), np.concatenate(
323            (edge_dst, edge_src)
324        )
325        d12 = np.concatenate((d12, d12))
326        if apply_pbc:
327            pbc_shifts = np.concatenate((pbc_shifts, -pbc_shifts))
328
329        graph = inputs.get(self.graph_key, {})
330        graph_out = {
331            **graph,
332            "edge_src": edge_src,
333            "edge_dst": edge_dst,
334            "d12": d12,
335            "overflow": False,
336            "pbc_shifts": pbc_shifts,
337        }
338        if "nblist_skin" in state:
339            graph_out["edge_src_skin"] = edge_src_skin
340            graph_out["edge_dst_skin"] = edge_dst_skin
341            if apply_pbc:
342                graph_out["pbc_shifts_skin"] = pbc_shifts_skin
343
344        if self.k_space and apply_pbc:
345            if "k_points" not in graph:
346                ks, _, _, bewald = get_reciprocal_space_parameters(
347                    reciprocal_cells, self.cutoff, self.kmax, self.kthr
348                )
349            graph_out["k_points"] = ks
350            graph_out["b_ewald"] = bewald
351
352        output = {**inputs, self.graph_key: graph_out}
353
354        if return_state_update:
355            return FrozenDict(new_state), output, state_up
356        return FrozenDict(new_state), output
357
358    def check_reallocate(self, state, inputs, parent_overflow=False):
359        """check for overflow and reallocate nblist if necessary"""
360        overflow = parent_overflow or inputs[self.graph_key].get("overflow", False)
361        if not overflow:
362            return state, {}, inputs, False
363
364        add_margin = inputs[self.graph_key].get("overflow", False)
365        state, inputs, state_up = self(
366            state, inputs, return_state_update=True, add_margin=add_margin
367        )
368        return state, state_up, inputs, True
369
    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """build a nblist on accelerator with jax and precomputed shapes

        The method is jitted with `self` and `state` as static arguments. The
        buffer sizes are read from `state` ("max_nat", "npairs", "npairs_skin",
        "num_repeats_pbc", "nshifts_pbc", "nvec_pbc"); when one of them is too
        small, the "overflow" entry of the output graph is set instead of
        resizing anything here.

        Args:
            state: preprocessing state (static under jit).
            inputs: dict that must contain "coordinates", "natoms" and
                "batch_index". "cells" and "reciprocal_cells" enable periodic
                boundary conditions; "true_sys" is optional.

        Returns:
            A new dict equal to `inputs` with `self.graph_key` set to the graph
            ("edge_src", "edge_dst", "d12", "overflow", "pbc_shifts", plus the
            "*_skin" arrays when "nblist_skin" is in `state`). `inputs` is
            returned unchanged if its graph contains "keep_graph".

        Raises:
            NotImplementedError: if `self.k_space` is set with periodic cells
                and the incoming graph has no "k_points".
        """
        # an existing graph flagged with "keep_graph" is left untouched
        if self.graph_key in inputs:
            graph = inputs[self.graph_key]
            if "keep_graph" in graph:
                return inputs
        coords = inputs["coordinates"]
        natoms = inputs["natoms"]
        batch_index = inputs["batch_index"]

        # single system: every atom belongs to it; batch: use the per-system
        # atom bound stored in the state (or the average size as a fallback)
        if natoms.shape[0] == 1:
            max_nat = coords.shape[0]
        else:
            max_nat = state.get(
                "max_nat", int(round(coords.shape[0] / natoms.shape[0]))
            )
        # pairs are first selected with the enlarged cutoff (cutoff + skin)
        cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0)

        ### compute indices of all pairs
        # static numpy index arrays of the strict upper triangle (i < j)
        p1, p2 = np.triu_indices(max_nat, 1)
        p1, p2 = p1.astype(np.int32), p2.astype(np.int32)
        pbc_shifts = None
        if natoms.shape[0] > 1:
            ## batching => mask irrelevant pairs
            mask_p12 = (
                (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None])
            ).flatten()
            # offset of each system in the flattened atom arrays
            shift = jnp.concatenate(
                (jnp.array([0], dtype=jnp.int32), jnp.cumsum(natoms[:-1]))
            )
            # global atom index = local index + system offset; masked pairs get -1
            p1 = jnp.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1)
            p2 = jnp.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1)

        ## compute vectors
        # only the general-PBC branch below can raise this flag
        overflow_repeats = jnp.asarray(False, dtype=bool)
        if "cells" not in inputs:
            vec = coords[p2] - coords[p1]
        else:
            cells = inputs["cells"]
            reciprocal_cells = inputs["reciprocal_cells"]
            minimage = state.get("minimum_image", True)

            # Wrap `vec` with integer cell shifts computed from its fractional
            # coordinates (vec @ reciprocal_cell): "round" gives the nearest
            # image, "floor" puts the vector in the central box.
            # Returns the shifted vector and the integer shifts.
            def compute_pbc(vec, reciprocal_cell, cell, mode="round"):
                vecpbc = jnp.dot(vec, reciprocal_cell)
                if mode == "round":
                    pbc_shifts = -jnp.round(vecpbc)
                elif mode == "floor":
                    pbc_shifts = -jnp.floor(vecpbc)
                else:
                    raise NotImplementedError(f"Unknown mode {mode} for compute_pbc.")
                return vec + jnp.dot(pbc_shifts, cell), pbc_shifts

            if minimage:
                ## minimum image convention
                vec = coords[p2] - coords[p1]

                if cells.shape[0] == 1:
                    vec, pbc_shifts = compute_pbc(vec, reciprocal_cells[0], cells[0])
                else:
                    # one cell per system: pick the cell of each pair
                    batch_index_vec = batch_index[p1]
                    vec, pbc_shifts = jax.vmap(compute_pbc)(
                        vec, reciprocal_cells[batch_index_vec], cells[batch_index_vec]
                    )
            else:
                ### general PBC only for single cell yet
                # if cells.shape[0] > 1:
                #     raise NotImplementedError(
                #         "General PBC not implemented for batches on accelerator."
                #     )
                # cell = cells[0]
                # reciprocal_cell = reciprocal_cells[0]

                ## put all atoms in central box
                if cells.shape[0] == 1:
                    coords_pbc, at_shifts = compute_pbc(
                        coords, reciprocal_cells[0], cells[0], mode="floor"
                    )
                else:
                    coords_pbc, at_shifts = jax.vmap(
                        partial(compute_pbc, mode="floor")
                    )(coords, reciprocal_cells[batch_index], cells[batch_index])
                vec = coords_pbc[p2] - coords_pbc[p1]
                # static number of cell images per axis, precomputed in the state
                num_repeats = state.get("num_repeats_pbc", (0, 0, 0))
                # if num_repeats is None:
                #     raise ValueError(
                #         "num_repeats_pbc should be provided for general PBC on accelerator. Call the numpy routine (self.__call__) first."
                #     )
                # check if num_repeats is larger than previous
                inv_distances = jnp.linalg.norm(reciprocal_cells, axis=1)
                cdinv = cutoff_skin * inv_distances
                num_repeats_all = jnp.ceil(cdinv).astype(jnp.int32)
                # systems where true_sys is False are given zero repeats
                if "true_sys" in inputs:
                    num_repeats_all = jnp.where(inputs["true_sys"][:,None], num_repeats_all, 0)
                num_repeats_new = jnp.max(num_repeats_all, axis=0)
                # the static image counts are too small for the current cells
                overflow_repeats = jnp.any(num_repeats_new > jnp.asarray(num_repeats))

                # all integer shift triplets in [-n, n] along each axis (static)
                cell_shift_pbc = jnp.asarray(
                    np.array(
                        np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]),
                        dtype=cells.dtype,
                    ).T.reshape(-1, 3)
                )

                if cells.shape[0] == 1:
                    # single cell: every pair is replicated for every cell shift
                    vec = (vec[:,None,:] + jnp.dot(cell_shift_pbc, cells[0])[None, :, :]).reshape(-1, 3)
                    pbc_shifts = jnp.broadcast_to(
                        cell_shift_pbc[None, :, :],
                        (p1.shape[0], cell_shift_pbc.shape[0], 3),
                    ).reshape(-1, 3)
                    p1 = jnp.broadcast_to(
                        p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
                    p2 = jnp.broadcast_to(
                        p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
                    if natoms.shape[0] > 1:
                        mask_p12 = jnp.broadcast_to(
                            mask_p12[:, None], (mask_p12.shape[0], cell_shift_pbc.shape[0])
                        ).flatten()
                else:
                    # cartesian displacement of every shift for every system
                    dvec = jnp.einsum("bj,sji->sbi", cell_shift_pbc, cells).reshape(-1, 3)

                    ## get pbc shifts specific to each box
                    cell_shift_pbc = jnp.broadcast_to(
                        cell_shift_pbc[None, :, :],
                        (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3),
                    )
                    # keep, for each system, only the shifts within its own repeats
                    mask = jnp.all(
                        jnp.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1
                    ).flatten()
                    max_shifts  = state.get("nshifts_pbc", 1)

                    cell_shift_pbc = cell_shift_pbc.reshape(-1,3)
                    shiftx,shifty,shiftz = cell_shift_pbc[:,0],cell_shift_pbc[:,1],cell_shift_pbc[:,2]
                    dvecx,dvecy,dvecz = dvec[:,0],dvec[:,1],dvec[:,2]
                    # NOTE(review): mask_filter_1d presumably compacts the masked
                    # entries of each 1D array into a buffer of length max_shifts
                    # (padded with the given fill value) and returns the scatter
                    # indices and the true count; confirm in utils.mask_filter_1d
                    (dvecx, dvecy,dvecz,shiftx,shifty,shiftz), scatter_idx, nshifts = mask_filter_1d(
                        mask,
                        max_shifts,
                        (dvecx, 0.),
                        (dvecy, 0.),
                        (dvecz, 0.),
                        (shiftx, 0),
                        (shifty, 0),
                        (shiftz, 0),
                    )
                    dvec = jnp.stack((dvecx,dvecy,dvecz),axis=-1)
                    cell_shift_pbc = jnp.stack((shiftx,shifty,shiftz),axis=-1)
                    overflow_repeats = overflow_repeats | (nshifts > max_shifts)

                    ## get batch shift in the dvec_filter array
                    nrep = jnp.prod(2 * num_repeats_all + 1, axis=1)
                    bshift = jnp.concatenate((jnp.array([0],dtype=jnp.int32), jnp.cumsum(nrep)[:-1]))

                    ## repeat vectors
                    nvec_max = state.get("nvec_pbc", 1)
                    batch_index_vec = batch_index[p1]
                    # masked pairs are not repeated at all
                    nrep_vec = jnp.where(mask_p12,nrep[batch_index_vec],0)
                    nvec = nrep_vec.sum()
                    overflow_repeats = overflow_repeats | (nvec > nvec_max)
                    # total_repeat_length fixes the output length to nvec_max
                    vec = jnp.repeat(vec,nrep_vec,axis=0,total_repeat_length=nvec_max)
                    # jax.debug.print("{nvec} {nvec_max} {nshifts} {max_shifts}",nvec=nvec,nvec_max=jnp.asarray(nvec_max),nshifts=nshifts,max_shifts=jnp.asarray(max_shifts))

                    ## get shift index
                    dshift = jnp.concatenate(
                        (jnp.array([0],dtype=jnp.int32), jnp.cumsum(nrep_vec)[:-1])
                    )
                    # with no candidate pair, dshift must be empty to match nrep_vec
                    if nrep_vec.size == 0:
                        dshift = jnp.array([],dtype=jnp.int32)
                    dshift = jnp.repeat(dshift,nrep_vec, total_repeat_length=nvec_max)
                    bshift = jnp.repeat(bshift[batch_index_vec],nrep_vec, total_repeat_length=nvec_max)
                    # index in the filtered shift arrays: rank of the image within
                    # its pair + offset of the pair's system
                    icellshift = jnp.arange(dshift.shape[0]) - dshift + bshift
                    vec = vec + dvec[icellshift]
                    pbc_shifts = cell_shift_pbc[icellshift]
                    p1 = jnp.repeat(p1,nrep_vec, total_repeat_length=nvec_max)
                    p2 = jnp.repeat(p2,nrep_vec, total_repeat_length=nvec_max)
                    if natoms.shape[0] > 1:
                        mask_p12 = jnp.repeat(mask_p12,nrep_vec, total_repeat_length=nvec_max)


        ## compute distances
        d12 = (vec**2).sum(axis=-1)
        if natoms.shape[0] > 1:
            # masked pairs sit exactly at the cutoff, so the strict test below rejects them
            d12 = jnp.where(mask_p12, d12, cutoff_skin**2)

        ## filter pairs
        max_pairs = state.get("npairs", 1)
        mask = d12 < cutoff_skin**2
        # padded edges point to atom index coords.shape[0] (one past the last atom)
        (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d(
            mask,
            max_pairs,
            (jnp.asarray(p1, dtype=jnp.int32), coords.shape[0]),
            (jnp.asarray(p2, dtype=jnp.int32), coords.shape[0]),
            (d12, cutoff_skin**2),
        )
        if "cells" in inputs:
            # move the shifts of the kept pairs to their slots; mode="drop"
            # ignores out-of-bounds scatter indices
            pbc_shifts = (
                jnp.full((max_pairs, 3), 0.0, dtype=pbc_shifts.dtype)
                .at[scatter_idx]
                .set(pbc_shifts, mode="drop")
            )
            if not minimage:
                # atoms were wrapped in the central box: add the per-atom wrapping
                # shifts so that pbc_shifts apply to the original coordinates
                pbc_shifts = (
                    pbc_shifts
                    + at_shifts.at[edge_dst].get(fill_value=0.0)
                    - at_shifts.at[edge_src].get(fill_value=0.0)
                )

        ## check for overflow
        if natoms.shape[0] == 1:
            true_max_nat = coords.shape[0]
        else:
            true_max_nat = jnp.max(natoms)
        overflow_count = npairs > max_pairs
        overflow_at = true_max_nat > max_nat
        overflow = overflow_count | overflow_at | overflow_repeats

        if "nblist_skin" in state:
            # keep the (cutoff + skin) list aside, then filter it down to the
            # pairs within the bare cutoff
            # edge_mask_skin = edge_mask
            edge_src_skin = edge_src
            edge_dst_skin = edge_dst
            if "cells" in inputs:
                pbc_shifts_skin = pbc_shifts
            max_pairs_skin = state.get("npairs_skin", 1)
            mask = d12 < self.cutoff**2
            (edge_src, edge_dst, d12), scatter_idx, npairs_skin = mask_filter_1d(
                mask,
                max_pairs_skin,
                (edge_src, coords.shape[0]),
                (edge_dst, coords.shape[0]),
                (d12, self.cutoff**2),
            )
            if "cells" in inputs:
                pbc_shifts = (
                    jnp.full((max_pairs_skin, 3), 0.0, dtype=pbc_shifts.dtype)
                    .at[scatter_idx]
                    .set(pbc_shifts, mode="drop")
                )
            overflow = overflow | (npairs_skin > max_pairs_skin)

        ## symmetrize
        # each pair (i, j) is stored in both directions; shifts change sign
        edge_src, edge_dst = jnp.concatenate((edge_src, edge_dst)), jnp.concatenate(
            (edge_dst, edge_src)
        )
        d12 = jnp.concatenate((d12, d12))
        if "cells" in inputs:
            pbc_shifts = jnp.concatenate((pbc_shifts, -pbc_shifts))

        # entries already present in the incoming graph are preserved
        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
        graph_out = {
            **graph,
            "edge_src": edge_src,
            "edge_dst": edge_dst,
            "d12": d12,
            "overflow": overflow,
            "pbc_shifts": pbc_shifts,
        }
        if "nblist_skin" in state:
            graph_out["edge_src_skin"] = edge_src_skin
            graph_out["edge_dst_skin"] = edge_dst_skin
            if "cells" in inputs:
                graph_out["pbc_shifts_skin"] = pbc_shifts_skin

        # k-points are never generated here: they must come from the incoming graph
        if self.k_space and "cells" in inputs:
            if "k_points" not in graph:
                raise NotImplementedError(
                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
                )
        return {**inputs, self.graph_key: graph_out}
639
640    @partial(jax.jit, static_argnums=(0,))
641    def update_skin(self, inputs):
642        """update the nblist without recomputing the full nblist"""
643        graph = inputs[self.graph_key]
644
645        edge_src_skin = graph["edge_src_skin"]
646        edge_dst_skin = graph["edge_dst_skin"]
647        coords = inputs["coordinates"]
648        vec = coords.at[edge_dst_skin].get(
649            mode="fill", fill_value=self.cutoff
650        ) - coords.at[edge_src_skin].get(mode="fill", fill_value=0.0)
651
652        if "cells" in inputs:
653            pbc_shifts_skin = graph["pbc_shifts_skin"]
654            cells = inputs["cells"]
655            if cells.shape[0] == 1:
656                vec = vec + jnp.dot(pbc_shifts_skin, cells[0])
657            else:
658                batch_index_vec = inputs["batch_index"][edge_src_skin]
659                vec = vec + jax.vmap(jnp.dot)(pbc_shifts_skin, cells[batch_index_vec])
660
661        nat = coords.shape[0]
662        d12 = jnp.sum(vec**2, axis=-1)
663        mask = d12 < self.cutoff**2
664        max_pairs = graph["edge_src"].shape[0] // 2
665        (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d(
666            mask,
667            max_pairs,
668            (edge_src_skin, nat),
669            (edge_dst_skin, nat),
670            (d12, self.cutoff**2),
671        )
672        if "cells" in inputs:
673            pbc_shifts = (
674                jnp.full((max_pairs, 3), 0.0, dtype=pbc_shifts_skin.dtype)
675                .at[scatter_idx]
676                .set(pbc_shifts_skin)
677            )
678
679        overflow = graph.get("overflow", False) | (npairs > max_pairs)
680        graph_out = {
681            **graph,
682            "edge_src": jnp.concatenate((edge_src, edge_dst)),
683            "edge_dst": jnp.concatenate((edge_dst, edge_src)),
684            "d12": jnp.concatenate((d12, d12)),
685            "overflow": overflow,
686        }
687        if "cells" in inputs:
688            graph_out["pbc_shifts"] = jnp.concatenate((pbc_shifts, -pbc_shifts))
689
690        if self.k_space and "cells" in inputs:
691            if "k_points" not in graph:
692                raise NotImplementedError(
693                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
694                )
695
696        return {**inputs, self.graph_key: graph_out}

Generate a graph from a set of coordinates

FPID: GRAPH

For now, we generate all pairs of atoms and filter them based on the cutoff. If `nblist_skin` is present in the state, we generate a second graph with a larger cutoff that includes all pairs within cutoff+skin. This graph is then reused by the `update_skin` method to update the original graph without recomputing the full neighbor list.

GraphGenerator( cutoff: float, graph_key: str = 'graph', switch_params: dict = <factory>, kmax: int = 30, kthr: float = 1e-06, k_space: bool = False, mult_size: float = 1.05)
cutoff: float

Cutoff distance for the graph.

graph_key: str = 'graph'

Key of the graph in the outputs.

switch_params: dict

Parameters for the switching function. See fennol.models.misc.misc.SwitchFunction.

kmax: int = 30

Maximum number of k-points to consider.

kthr: float = 1e-06

Threshold for k-point filtering.

k_space: bool = False

Whether to generate k-space information for the graph.

mult_size: float = 1.05

Multiplicative factor for resizing the nblist.

FPID: ClassVar[str] = 'GRAPH'
def init(self):
50    def init(self):
51        return FrozenDict(
52            {
53                "max_nat": 1,
54                "npairs": 1,
55                "nblist_mult_size": self.mult_size,
56            }
57        )
def get_processor(self) -> Tuple[flax.linen.module.Module, Dict]:
59    def get_processor(self) -> Tuple[nn.Module, Dict]:
60        return GraphProcessor, {
61            "cutoff": self.cutoff,
62            "graph_key": self.graph_key,
63            "switch_params": self.switch_params,
64            "name": f"{self.graph_key}_Processor",
65        }
def get_graph_properties(self):
67    def get_graph_properties(self):
68        return {
69            self.graph_key: {
70                "cutoff": self.cutoff,
71                "directed": True,
72            }
73        }
def check_reallocate(self, state, inputs, parent_overflow=False):
358    def check_reallocate(self, state, inputs, parent_overflow=False):
359        """check for overflow and reallocate nblist if necessary"""
360        overflow = parent_overflow or inputs[self.graph_key].get("overflow", False)
361        if not overflow:
362            return state, {}, inputs, False
363
364        add_margin = inputs[self.graph_key].get("overflow", False)
365        state, inputs, state_up = self(
366            state, inputs, return_state_update=True, add_margin=add_margin
367        )
368        return state, state_up, inputs, True

check for overflow and reallocate nblist if necessary

@partial(jax.jit, static_argnums=(0, 1))
def process(self, state, inputs):
    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """build a nblist on accelerator with jax and precomputed shapes

        The method is jitted with `self` and `state` as static arguments. The
        buffer sizes are read from `state` ("max_nat", "npairs", "npairs_skin",
        "num_repeats_pbc", "nshifts_pbc", "nvec_pbc"); when one of them is too
        small, the "overflow" entry of the output graph is set instead of
        resizing anything here.

        Args:
            state: preprocessing state (static under jit).
            inputs: dict that must contain "coordinates", "natoms" and
                "batch_index". "cells" and "reciprocal_cells" enable periodic
                boundary conditions; "true_sys" is optional.

        Returns:
            A new dict equal to `inputs` with `self.graph_key` set to the graph
            ("edge_src", "edge_dst", "d12", "overflow", "pbc_shifts", plus the
            "*_skin" arrays when "nblist_skin" is in `state`). `inputs` is
            returned unchanged if its graph contains "keep_graph".

        Raises:
            NotImplementedError: if `self.k_space` is set with periodic cells
                and the incoming graph has no "k_points".
        """
        # an existing graph flagged with "keep_graph" is left untouched
        if self.graph_key in inputs:
            graph = inputs[self.graph_key]
            if "keep_graph" in graph:
                return inputs
        coords = inputs["coordinates"]
        natoms = inputs["natoms"]
        batch_index = inputs["batch_index"]

        # single system: every atom belongs to it; batch: use the per-system
        # atom bound stored in the state (or the average size as a fallback)
        if natoms.shape[0] == 1:
            max_nat = coords.shape[0]
        else:
            max_nat = state.get(
                "max_nat", int(round(coords.shape[0] / natoms.shape[0]))
            )
        # pairs are first selected with the enlarged cutoff (cutoff + skin)
        cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0)

        ### compute indices of all pairs
        # static numpy index arrays of the strict upper triangle (i < j)
        p1, p2 = np.triu_indices(max_nat, 1)
        p1, p2 = p1.astype(np.int32), p2.astype(np.int32)
        pbc_shifts = None
        if natoms.shape[0] > 1:
            ## batching => mask irrelevant pairs
            mask_p12 = (
                (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None])
            ).flatten()
            # offset of each system in the flattened atom arrays
            shift = jnp.concatenate(
                (jnp.array([0], dtype=jnp.int32), jnp.cumsum(natoms[:-1]))
            )
            # global atom index = local index + system offset; masked pairs get -1
            p1 = jnp.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1)
            p2 = jnp.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1)

        ## compute vectors
        # only the general-PBC branch below can raise this flag
        overflow_repeats = jnp.asarray(False, dtype=bool)
        if "cells" not in inputs:
            vec = coords[p2] - coords[p1]
        else:
            cells = inputs["cells"]
            reciprocal_cells = inputs["reciprocal_cells"]
            minimage = state.get("minimum_image", True)

            # Wrap `vec` with integer cell shifts computed from its fractional
            # coordinates (vec @ reciprocal_cell): "round" gives the nearest
            # image, "floor" puts the vector in the central box.
            # Returns the shifted vector and the integer shifts.
            def compute_pbc(vec, reciprocal_cell, cell, mode="round"):
                vecpbc = jnp.dot(vec, reciprocal_cell)
                if mode == "round":
                    pbc_shifts = -jnp.round(vecpbc)
                elif mode == "floor":
                    pbc_shifts = -jnp.floor(vecpbc)
                else:
                    raise NotImplementedError(f"Unknown mode {mode} for compute_pbc.")
                return vec + jnp.dot(pbc_shifts, cell), pbc_shifts

            if minimage:
                ## minimum image convention
                vec = coords[p2] - coords[p1]

                if cells.shape[0] == 1:
                    vec, pbc_shifts = compute_pbc(vec, reciprocal_cells[0], cells[0])
                else:
                    # one cell per system: pick the cell of each pair
                    batch_index_vec = batch_index[p1]
                    vec, pbc_shifts = jax.vmap(compute_pbc)(
                        vec, reciprocal_cells[batch_index_vec], cells[batch_index_vec]
                    )
            else:
                ### general PBC only for single cell yet
                # if cells.shape[0] > 1:
                #     raise NotImplementedError(
                #         "General PBC not implemented for batches on accelerator."
                #     )
                # cell = cells[0]
                # reciprocal_cell = reciprocal_cells[0]

                ## put all atoms in central box
                if cells.shape[0] == 1:
                    coords_pbc, at_shifts = compute_pbc(
                        coords, reciprocal_cells[0], cells[0], mode="floor"
                    )
                else:
                    coords_pbc, at_shifts = jax.vmap(
                        partial(compute_pbc, mode="floor")
                    )(coords, reciprocal_cells[batch_index], cells[batch_index])
                vec = coords_pbc[p2] - coords_pbc[p1]
                # static number of cell images per axis, precomputed in the state
                num_repeats = state.get("num_repeats_pbc", (0, 0, 0))
                # if num_repeats is None:
                #     raise ValueError(
                #         "num_repeats_pbc should be provided for general PBC on accelerator. Call the numpy routine (self.__call__) first."
                #     )
                # check if num_repeats is larger than previous
                inv_distances = jnp.linalg.norm(reciprocal_cells, axis=1)
                cdinv = cutoff_skin * inv_distances
                num_repeats_all = jnp.ceil(cdinv).astype(jnp.int32)
                # systems where true_sys is False are given zero repeats
                if "true_sys" in inputs:
                    num_repeats_all = jnp.where(inputs["true_sys"][:,None], num_repeats_all, 0)
                num_repeats_new = jnp.max(num_repeats_all, axis=0)
                # the static image counts are too small for the current cells
                overflow_repeats = jnp.any(num_repeats_new > jnp.asarray(num_repeats))

                # all integer shift triplets in [-n, n] along each axis (static)
                cell_shift_pbc = jnp.asarray(
                    np.array(
                        np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]),
                        dtype=cells.dtype,
                    ).T.reshape(-1, 3)
                )

                if cells.shape[0] == 1:
                    # single cell: every pair is replicated for every cell shift
                    vec = (vec[:,None,:] + jnp.dot(cell_shift_pbc, cells[0])[None, :, :]).reshape(-1, 3)
                    pbc_shifts = jnp.broadcast_to(
                        cell_shift_pbc[None, :, :],
                        (p1.shape[0], cell_shift_pbc.shape[0], 3),
                    ).reshape(-1, 3)
                    p1 = jnp.broadcast_to(
                        p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
                    p2 = jnp.broadcast_to(
                        p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
                    if natoms.shape[0] > 1:
                        mask_p12 = jnp.broadcast_to(
                            mask_p12[:, None], (mask_p12.shape[0], cell_shift_pbc.shape[0])
                        ).flatten()
                else:
                    # cartesian displacement of every shift for every system
                    dvec = jnp.einsum("bj,sji->sbi", cell_shift_pbc, cells).reshape(-1, 3)

                    ## get pbc shifts specific to each box
                    cell_shift_pbc = jnp.broadcast_to(
                        cell_shift_pbc[None, :, :],
                        (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3),
                    )
                    # keep, for each system, only the shifts within its own repeats
                    mask = jnp.all(
                        jnp.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1
                    ).flatten()
                    max_shifts  = state.get("nshifts_pbc", 1)

                    cell_shift_pbc = cell_shift_pbc.reshape(-1,3)
                    shiftx,shifty,shiftz = cell_shift_pbc[:,0],cell_shift_pbc[:,1],cell_shift_pbc[:,2]
                    dvecx,dvecy,dvecz = dvec[:,0],dvec[:,1],dvec[:,2]
                    # NOTE(review): mask_filter_1d presumably compacts the masked
                    # entries of each 1D array into a buffer of length max_shifts
                    # (padded with the given fill value) and returns the scatter
                    # indices and the true count; confirm in utils.mask_filter_1d
                    (dvecx, dvecy,dvecz,shiftx,shifty,shiftz), scatter_idx, nshifts = mask_filter_1d(
                        mask,
                        max_shifts,
                        (dvecx, 0.),
                        (dvecy, 0.),
                        (dvecz, 0.),
                        (shiftx, 0),
                        (shifty, 0),
                        (shiftz, 0),
                    )
                    dvec = jnp.stack((dvecx,dvecy,dvecz),axis=-1)
                    cell_shift_pbc = jnp.stack((shiftx,shifty,shiftz),axis=-1)
                    overflow_repeats = overflow_repeats | (nshifts > max_shifts)

                    ## get batch shift in the dvec_filter array
                    nrep = jnp.prod(2 * num_repeats_all + 1, axis=1)
                    bshift = jnp.concatenate((jnp.array([0],dtype=jnp.int32), jnp.cumsum(nrep)[:-1]))

                    ## repeat vectors
                    nvec_max = state.get("nvec_pbc", 1)
                    batch_index_vec = batch_index[p1]
                    # masked pairs are not repeated at all
                    nrep_vec = jnp.where(mask_p12,nrep[batch_index_vec],0)
                    nvec = nrep_vec.sum()
                    overflow_repeats = overflow_repeats | (nvec > nvec_max)
                    # total_repeat_length fixes the output length to nvec_max
                    vec = jnp.repeat(vec,nrep_vec,axis=0,total_repeat_length=nvec_max)
                    # jax.debug.print("{nvec} {nvec_max} {nshifts} {max_shifts}",nvec=nvec,nvec_max=jnp.asarray(nvec_max),nshifts=nshifts,max_shifts=jnp.asarray(max_shifts))

                    ## get shift index
                    dshift = jnp.concatenate(
                        (jnp.array([0],dtype=jnp.int32), jnp.cumsum(nrep_vec)[:-1])
                    )
                    # with no candidate pair, dshift must be empty to match nrep_vec
                    if nrep_vec.size == 0:
                        dshift = jnp.array([],dtype=jnp.int32)
                    dshift = jnp.repeat(dshift,nrep_vec, total_repeat_length=nvec_max)
                    bshift = jnp.repeat(bshift[batch_index_vec],nrep_vec, total_repeat_length=nvec_max)
                    # index in the filtered shift arrays: rank of the image within
                    # its pair + offset of the pair's system
                    icellshift = jnp.arange(dshift.shape[0]) - dshift + bshift
                    vec = vec + dvec[icellshift]
                    pbc_shifts = cell_shift_pbc[icellshift]
                    p1 = jnp.repeat(p1,nrep_vec, total_repeat_length=nvec_max)
                    p2 = jnp.repeat(p2,nrep_vec, total_repeat_length=nvec_max)
                    if natoms.shape[0] > 1:
                        mask_p12 = jnp.repeat(mask_p12,nrep_vec, total_repeat_length=nvec_max)


        ## compute distances
        d12 = (vec**2).sum(axis=-1)
        if natoms.shape[0] > 1:
            # masked pairs sit exactly at the cutoff, so the strict test below rejects them
            d12 = jnp.where(mask_p12, d12, cutoff_skin**2)

        ## filter pairs
        max_pairs = state.get("npairs", 1)
        mask = d12 < cutoff_skin**2
        # padded edges point to atom index coords.shape[0] (one past the last atom)
        (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d(
            mask,
            max_pairs,
            (jnp.asarray(p1, dtype=jnp.int32), coords.shape[0]),
            (jnp.asarray(p2, dtype=jnp.int32), coords.shape[0]),
            (d12, cutoff_skin**2),
        )
        if "cells" in inputs:
            # move the shifts of the kept pairs to their slots; mode="drop"
            # ignores out-of-bounds scatter indices
            pbc_shifts = (
                jnp.full((max_pairs, 3), 0.0, dtype=pbc_shifts.dtype)
                .at[scatter_idx]
                .set(pbc_shifts, mode="drop")
            )
            if not minimage:
                # atoms were wrapped in the central box: add the per-atom wrapping
                # shifts so that pbc_shifts apply to the original coordinates
                pbc_shifts = (
                    pbc_shifts
                    + at_shifts.at[edge_dst].get(fill_value=0.0)
                    - at_shifts.at[edge_src].get(fill_value=0.0)
                )

        ## check for overflow
        if natoms.shape[0] == 1:
            true_max_nat = coords.shape[0]
        else:
            true_max_nat = jnp.max(natoms)
        overflow_count = npairs > max_pairs
        overflow_at = true_max_nat > max_nat
        overflow = overflow_count | overflow_at | overflow_repeats

        if "nblist_skin" in state:
            # keep the (cutoff + skin) list aside, then filter it down to the
            # pairs within the bare cutoff
            # edge_mask_skin = edge_mask
            edge_src_skin = edge_src
            edge_dst_skin = edge_dst
            if "cells" in inputs:
                pbc_shifts_skin = pbc_shifts
            max_pairs_skin = state.get("npairs_skin", 1)
            mask = d12 < self.cutoff**2
            (edge_src, edge_dst, d12), scatter_idx, npairs_skin = mask_filter_1d(
                mask,
                max_pairs_skin,
                (edge_src, coords.shape[0]),
                (edge_dst, coords.shape[0]),
                (d12, self.cutoff**2),
            )
            if "cells" in inputs:
                pbc_shifts = (
                    jnp.full((max_pairs_skin, 3), 0.0, dtype=pbc_shifts.dtype)
                    .at[scatter_idx]
                    .set(pbc_shifts, mode="drop")
                )
            overflow = overflow | (npairs_skin > max_pairs_skin)

        ## symmetrize
        # each pair (i, j) is stored in both directions; shifts change sign
        edge_src, edge_dst = jnp.concatenate((edge_src, edge_dst)), jnp.concatenate(
            (edge_dst, edge_src)
        )
        d12 = jnp.concatenate((d12, d12))
        if "cells" in inputs:
            pbc_shifts = jnp.concatenate((pbc_shifts, -pbc_shifts))

        # entries already present in the incoming graph are preserved
        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
        graph_out = {
            **graph,
            "edge_src": edge_src,
            "edge_dst": edge_dst,
            "d12": d12,
            "overflow": overflow,
            "pbc_shifts": pbc_shifts,
        }
        if "nblist_skin" in state:
            graph_out["edge_src_skin"] = edge_src_skin
            graph_out["edge_dst_skin"] = edge_dst_skin
            if "cells" in inputs:
                graph_out["pbc_shifts_skin"] = pbc_shifts_skin

        # k-points are never generated here: they must come from the incoming graph
        if self.k_space and "cells" in inputs:
            if "k_points" not in graph:
                raise NotImplementedError(
                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
                )
        return {**inputs, self.graph_key: graph_out}

build a nblist on accelerator with jax and precomputed shapes

@partial(jax.jit, static_argnums=(0,))
def update_skin(self, inputs):
640    @partial(jax.jit, static_argnums=(0,))
641    def update_skin(self, inputs):
642        """update the nblist without recomputing the full nblist"""
643        graph = inputs[self.graph_key]
644
645        edge_src_skin = graph["edge_src_skin"]
646        edge_dst_skin = graph["edge_dst_skin"]
647        coords = inputs["coordinates"]
648        vec = coords.at[edge_dst_skin].get(
649            mode="fill", fill_value=self.cutoff
650        ) - coords.at[edge_src_skin].get(mode="fill", fill_value=0.0)
651
652        if "cells" in inputs:
653            pbc_shifts_skin = graph["pbc_shifts_skin"]
654            cells = inputs["cells"]
655            if cells.shape[0] == 1:
656                vec = vec + jnp.dot(pbc_shifts_skin, cells[0])
657            else:
658                batch_index_vec = inputs["batch_index"][edge_src_skin]
659                vec = vec + jax.vmap(jnp.dot)(pbc_shifts_skin, cells[batch_index_vec])
660
661        nat = coords.shape[0]
662        d12 = jnp.sum(vec**2, axis=-1)
663        mask = d12 < self.cutoff**2
664        max_pairs = graph["edge_src"].shape[0] // 2
665        (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d(
666            mask,
667            max_pairs,
668            (edge_src_skin, nat),
669            (edge_dst_skin, nat),
670            (d12, self.cutoff**2),
671        )
672        if "cells" in inputs:
673            pbc_shifts = (
674                jnp.full((max_pairs, 3), 0.0, dtype=pbc_shifts_skin.dtype)
675                .at[scatter_idx]
676                .set(pbc_shifts_skin)
677            )
678
679        overflow = graph.get("overflow", False) | (npairs > max_pairs)
680        graph_out = {
681            **graph,
682            "edge_src": jnp.concatenate((edge_src, edge_dst)),
683            "edge_dst": jnp.concatenate((edge_dst, edge_src)),
684            "d12": jnp.concatenate((d12, d12)),
685            "overflow": overflow,
686        }
687        if "cells" in inputs:
688            graph_out["pbc_shifts"] = jnp.concatenate((pbc_shifts, -pbc_shifts))
689
690        if self.k_space and "cells" in inputs:
691            if "k_points" not in graph:
692                raise NotImplementedError(
693                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
694                )
695
696        return {**inputs, self.graph_key: graph_out}

update the nblist without recomputing the full nblist

class GraphProcessor(flax.linen.module.Module):
class GraphProcessor(nn.Module):
    """Process a pre-generated graph

    The pre-generated graph should contain the following keys:
    - edge_src: source indices of the edges
    - edge_dst: destination indices of the edges
    - pbc_shifts: pbc shifts for the edges (only if `cells` are present in the inputs)

    The module adds `vec`, `distances`, `switch` and `edge_mask` to the graph.
    When `alch_group` is present in the inputs, it also stores the unscaled
    switch as `switch_raw` and scales `switch` for edges joining atoms of
    different alchemical groups.

    This module is automatically added to a FENNIX model when a GraphGenerator is used.

    """

    cutoff: float
    """Cutoff distance for the graph."""
    graph_key: str = "graph"
    """Key of the graph in the outputs."""
    switch_params: dict = dataclasses.field(default_factory=dict)
    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""

    @nn.compact
    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
        graph = inputs[self.graph_key]
        positions = inputs["coordinates"]
        src = graph["edge_src"]
        dst = graph["edge_dst"]

        # out-of-range indices are filled with `cutoff` (destination) and 0
        # (source) rather than being clamped to a real atom
        pos_dst = positions.at[dst].get(mode="fill", fill_value=self.cutoff)
        pos_src = positions.at[src].get(mode="fill", fill_value=0.0)
        vec = pos_dst - pos_src

        if "cells" in inputs:
            cells = inputs["cells"]
            shifts = graph["pbc_shifts"]
            if cells.shape[0] == 1:
                vec = vec + jnp.dot(shifts, cells[0])
            else:
                # several cells: use the cell of each edge's source atom
                edge_cells = cells[inputs["batch_index"][src]]
                vec = vec + jax.vmap(jnp.dot)(shifts, edge_cells)

        distances = jnp.linalg.norm(vec, axis=-1)
        edge_mask = distances < self.cutoff

        # `cutoff` and `graph_key` always override user-supplied switch params
        switch_kwargs = dict(self.switch_params)
        switch_kwargs.update(cutoff=self.cutoff, graph_key=None)
        switch = SwitchFunction(**switch_kwargs)((distances, edge_mask))

        graph_out = dict(graph)
        graph_out.update(
            vec=vec,
            distances=distances,
            switch=switch,
            edge_mask=edge_mask,
        )

        if "alch_group" in inputs:
            groups = inputs["alch_group"]
            lam = inputs["alch_elambda"]
            # smooth cosine mapping of lambda
            lam = 0.5 * (1.0 - jnp.cos(jnp.pi * lam))
            same_group = groups[src] == groups[dst]
            graph_out["switch_raw"] = switch
            graph_out["switch"] = jnp.where(same_group, switch, lam * switch)

        return {**inputs, self.graph_key: graph_out}

Process a pre-generated graph

The pre-generated graph should contain the following keys:

  • edge_src: source indices of the edges
  • edge_dst: destination indices of the edges
  • pbc_shifts: pbc shifts for the edges (only if cells are present in the inputs)

This module is automatically added to a FENNIX model when a GraphGenerator is used.

GraphProcessor( cutoff: float, graph_key: str = 'graph', switch_params: dict = <factory>, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
cutoff: float

Cutoff distance for the graph.

graph_key: str = 'graph'

Key of the graph in the outputs.

switch_params: dict

Parameters for the switching function. See fennol.models.misc.misc.SwitchFunction.

parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
@dataclasses.dataclass(frozen=True)
class GraphFilter:
@dataclasses.dataclass(frozen=True)
class GraphFilter:
    """Filter a graph based on a cutoff distance

    FPID: GRAPH_FILTER

    The filtered graph keeps the edges of `parent_graph` whose squared distance
    `d12` is below `cutoff**2` and stores, in `filter_indices`, the position of
    each kept edge in the parent edge list.
    """

    cutoff: float
    """Cutoff distance for the filtering."""
    parent_graph: str
    """Key of the parent graph in the inputs."""
    graph_key: str
    """Key of the filtered graph in the outputs."""
    remove_hydrogens: bool = False
    """Remove edges where the source is a hydrogen atom."""
    switch_params: FrozenDict = dataclasses.field(default_factory=FrozenDict)
    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""
    k_space: bool = False
    """Generate k-space information for the graph."""
    kmax: int = 30
    """Maximum number of k-points to consider."""
    kthr: float = 1e-6
    """Threshold for k-point filtering."""
    mult_size: float = 1.05
    """Multiplicative factor for resizing the nblist."""

    FPID: ClassVar[str] = "GRAPH_FILTER"

    def init(self):
        """Return the initial state: a one-pair buffer and the growth factor."""
        return FrozenDict(
            {
                "npairs": 1,
                "nblist_mult_size": self.mult_size,
            }
        )

    def get_processor(self) -> Tuple[nn.Module, Dict]:
        """Return the processor module class and the keyword arguments to build it."""
        return GraphFilterProcessor, {
            "cutoff": self.cutoff,
            "graph_key": self.graph_key,
            "parent_graph": self.parent_graph,
            "name": f"{self.graph_key}_Filter_{self.parent_graph}",
            "switch_params": self.switch_params,
        }

    def get_graph_properties(self):
        """Return the static properties of the filtered graph, keyed by `graph_key`."""
        return {
            self.graph_key: {
                "cutoff": self.cutoff,
                "directed": True,
                "parent_graph": self.parent_graph,
            }
        }

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """filter a nblist on cpu with numpy and dynamic shapes + store max shapes

        Args:
            state: mapping with the current buffer size (`npairs`) and growth
                factor (`nblist_mult_size`).
            inputs: dict with `species` and the parent graph (providing
                `edge_src`, `edge_dst` and `d12`). `reciprocal_cells` is read
                when k-space information has to be generated.
            return_state_update: also return the dict of resized state entries,
                each stored as a `(new, previous)` tuple.
            add_margin: grow the buffer even if the current size is sufficient.

        Returns:
            `(new_state, output)` or `(new_state, output, state_up)` where
            `output` is a copy of `inputs` with the filtered graph stored under
            `self.graph_key`.
        """
        graph_in = inputs[self.parent_graph]
        nat = inputs["species"].shape[0]

        new_state = {**state}
        state_up = {}
        mult_size = state.get("nblist_mult_size", self.mult_size)
        assert mult_size >= 1., "nblist_mult_size should be >= 1."

        edge_src = np.array(graph_in["edge_src"], dtype=np.int32)
        d12 = np.array(graph_in["d12"], dtype=np.float32)
        if self.remove_hydrogens:
            species = inputs["species"]
            # only look up species for in-range sources; entries with an index
            # >= nat keep a False mask and are pushed to the cutoff below
            src_idx = (edge_src < nat).nonzero()[0]
            mask = np.zeros(edge_src.shape[0], dtype=bool)
            mask[src_idx] = (species > 1)[edge_src[src_idx]]
            d12 = np.where(mask, d12, self.cutoff**2)
        mask = d12 < self.cutoff**2

        max_pairs = state.get("npairs", 1)
        idx = np.nonzero(mask)[0]
        npairs = idx.shape[0]
        if npairs > max_pairs or add_margin:
            prev_max_pairs = max_pairs
            max_pairs = int(mult_size * max(npairs, max_pairs)) + 1
            state_up["npairs"] = (max_pairs, prev_max_pairs)
            new_state["npairs"] = max_pairs

        # fixed-size buffers padded with out-of-range / cutoff values
        filter_indices = np.full(max_pairs, edge_src.shape[0], dtype=np.int32)
        edge_src = np.full(max_pairs, nat, dtype=np.int32)
        edge_dst = np.full(max_pairs, nat, dtype=np.int32)
        d12_ = np.full(max_pairs, self.cutoff**2)
        filter_indices[:npairs] = idx
        edge_src[:npairs] = graph_in["edge_src"][idx]
        edge_dst[:npairs] = graph_in["edge_dst"][idx]
        d12_[:npairs] = d12[idx]
        d12 = d12_

        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
        graph_out = {
            **graph,
            "edge_src": edge_src,
            "edge_dst": edge_dst,
            "filter_indices": filter_indices,
            "d12": d12,
            "overflow": False,
        }

        # Generate k-space data only when the graph does not carry it yet.
        # Existing entries are already propagated through `**graph` above;
        # assigning outside this condition would reference `ks`/`bewald`
        # before they are defined.
        if self.k_space and "cells" in inputs and "k_points" not in graph:
            ks, _, _, bewald = get_reciprocal_space_parameters(
                inputs["reciprocal_cells"], self.cutoff, self.kmax, self.kthr
            )
            graph_out["k_points"] = ks
            graph_out["b_ewald"] = bewald

        output = {**inputs, self.graph_key: graph_out}
        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary

        Returns `(state, state_update, inputs, reallocated)`. Nothing is
        recomputed unless this graph or its parent reported an overflow.
        """
        overflow = parent_overflow or inputs[self.graph_key].get("overflow", False)
        if not overflow:
            return state, {}, inputs, False

        # only force extra room when this graph itself overflowed
        add_margin = inputs[self.graph_key].get("overflow", False)
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """filter a nblist on accelerator with jax and precomputed shapes

        Args:
            state: static mapping providing `npairs`, or None to reuse the size
                of the existing filtered `edge_src` (skin update mode).
            inputs: dict with `species` and the parent graph.

        Returns:
            A copy of `inputs` with the filtered graph under `self.graph_key`.

        Raises:
            NotImplementedError: if `k_space` is enabled with `cells` present
                but the graph has no `k_points`.
        """
        graph_in = inputs[self.parent_graph]
        if state is None:
            # skin update mode
            graph = inputs[self.graph_key]
            max_pairs = graph["edge_src"].shape[0]
        else:
            max_pairs = state.get("npairs", 1)

        max_pairs_in = graph_in["edge_src"].shape[0]
        nat = inputs["species"].shape[0]

        edge_src = graph_in["edge_src"]
        d12 = graph_in["d12"]
        if self.remove_hydrogens:
            species = inputs["species"]
            mask = (species > 1)[edge_src]
            d12 = jnp.where(mask, d12, self.cutoff**2)
        mask = d12 < self.cutoff**2

        (edge_src, edge_dst, d12, filter_indices), _, npairs = mask_filter_1d(
            mask,
            max_pairs,
            (edge_src, nat),
            (graph_in["edge_dst"], nat),
            (d12, self.cutoff**2),
            (jnp.arange(max_pairs_in, dtype=jnp.int32), max_pairs_in),
        )

        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
        overflow = graph.get("overflow", False) | (npairs > max_pairs)
        graph_out = {
            **graph,
            "edge_src": edge_src,
            "edge_dst": edge_dst,
            "filter_indices": filter_indices,
            "d12": d12,
            "overflow": overflow,
        }

        if self.k_space and "cells" in inputs:
            if "k_points" not in graph:
                raise NotImplementedError(
                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
                )

        return {**inputs, self.graph_key: graph_out}

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        """Re-filter the graph in skin update mode (`process` without a state)."""
        return self.process(None, inputs)

Filter a graph based on a cutoff distance

FPID: GRAPH_FILTER

GraphFilter( cutoff: float, parent_graph: str, graph_key: str, remove_hydrogens: int = False, switch_params: flax.core.frozen_dict.FrozenDict = <factory>, k_space: bool = False, kmax: int = 30, kthr: float = 1e-06, mult_size: float = 1.05)
cutoff: float

Cutoff distance for the filtering.

parent_graph: str

Key of the parent graph in the inputs.

graph_key: str

Key of the filtered graph in the outputs.

remove_hydrogens: int = False

Remove edges where the source is a hydrogen atom.

switch_params: flax.core.frozen_dict.FrozenDict

Parameters for the switching function. See fennol.models.misc.misc.SwitchFunction.

k_space: bool = False

Generate k-space information for the graph.

kmax: int = 30

Maximum number of k-points to consider.

kthr: float = 1e-06

Threshold for k-point filtering.

mult_size: float = 1.05

Multiplicative factor for resizing the nblist.

FPID: ClassVar[str] = 'GRAPH_FILTER'
def init(self):
796    def init(self):
797        return FrozenDict(
798            {
799                "npairs": 1,
800                "nblist_mult_size": self.mult_size,
801            }
802        )
def get_processor(self) -> Tuple[flax.linen.module.Module, Dict]:
804    def get_processor(self) -> Tuple[nn.Module, Dict]:
805        return GraphFilterProcessor, {
806            "cutoff": self.cutoff,
807            "graph_key": self.graph_key,
808            "parent_graph": self.parent_graph,
809            "name": f"{self.graph_key}_Filter_{self.parent_graph}",
810            "switch_params": self.switch_params,
811        }
def get_graph_properties(self):
813    def get_graph_properties(self):
814        return {
815            self.graph_key: {
816                "cutoff": self.cutoff,
817                "directed": True,
818                "parent_graph": self.parent_graph,
819            }
820        }
def check_reallocate(self, state, inputs, parent_overflow=False):
884    def check_reallocate(self, state, inputs, parent_overflow=False):
885        """check for overflow and reallocate nblist if necessary"""
886        overflow = parent_overflow or inputs[self.graph_key].get("overflow", False)
887        if not overflow:
888            return state, {}, inputs, False
889
890        add_margin = inputs[self.graph_key].get("overflow", False)
891        state, inputs, state_up = self(
892            state, inputs, return_state_update=True, add_margin=add_margin
893        )
894        return state, state_up, inputs, True

check for overflow and reallocate nblist if necessary

@partial(jax.jit, static_argnums=(0, 1))
def process(self, state, inputs):
896    @partial(jax.jit, static_argnums=(0, 1))
897    def process(self, state, inputs):
898        """filter a nblist on accelerator with jax and precomputed shapes"""
899        graph_in = inputs[self.parent_graph]
900        if state is None:
901            # skin update mode
902            graph = inputs[self.graph_key]
903            max_pairs = graph["edge_src"].shape[0]
904        else:
905            max_pairs = state.get("npairs", 1)
906
907        max_pairs_in = graph_in["edge_src"].shape[0]
908        nat = inputs["species"].shape[0]
909
910        edge_src = graph_in["edge_src"]
911        d12 = graph_in["d12"]
912        if self.remove_hydrogens:
913            species = inputs["species"]
914            mask = (species > 1)[edge_src]
915            d12 = jnp.where(mask, d12, self.cutoff**2)
916        mask = d12 < self.cutoff**2
917
918        (edge_src, edge_dst, d12, filter_indices), _, npairs = mask_filter_1d(
919            mask,
920            max_pairs,
921            (edge_src, nat),
922            (graph_in["edge_dst"], nat),
923            (d12, self.cutoff**2),
924            (jnp.arange(max_pairs_in, dtype=jnp.int32), max_pairs_in),
925        )
926
927        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
928        overflow = graph.get("overflow", False) | (npairs > max_pairs)
929        graph_out = {
930            **graph,
931            "edge_src": edge_src,
932            "edge_dst": edge_dst,
933            "filter_indices": filter_indices,
934            "d12": d12,
935            "overflow": overflow,
936        }
937
938        if self.k_space and "cells" in inputs:
939            if "k_points" not in graph:
940                raise NotImplementedError(
941                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
942                )
943
944        return {**inputs, self.graph_key: graph_out}

filter a nblist on accelerator with jax and precomputed shapes

@partial(jax.jit, static_argnums=(0,))
def update_skin(self, inputs):
946    @partial(jax.jit, static_argnums=(0,))
947    def update_skin(self, inputs):
948        return self.process(None, inputs)
class GraphFilterProcessor(flax.linen.module.Module):
 951class GraphFilterProcessor(nn.Module):
 952    """Filter processing for a pre-generated graph
 953
 954    This module is automatically added to a FENNIX model when a GraphFilter is used.
 955    """
 956
 957    cutoff: float
 958    """Cutoff distance for the filtering."""
 959    graph_key: str
 960    """Key of the filtered graph in the inputs."""
 961    parent_graph: str
 962    """Key of the parent graph in the inputs."""
 963    switch_params: dict = dataclasses.field(default_factory=dict)
 964    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""
 965
 966    @nn.compact
 967    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
 968        graph_in = inputs[self.parent_graph]
 969        graph = inputs[self.graph_key]
 970
 971        if graph_in["vec"].shape[0] == 0:
 972            vec = graph_in["vec"]
 973            distances = graph_in["distances"]
 974            filter_indices = jnp.asarray([], dtype=jnp.int32)
 975        else:
 976            filter_indices = graph["filter_indices"]
 977            vec = (
 978                graph_in["vec"]
 979                .at[filter_indices]
 980                .get(mode="fill", fill_value=self.cutoff)
 981            )
 982            distances = (
 983                graph_in["distances"]
 984                .at[filter_indices]
 985                .get(mode="fill", fill_value=self.cutoff)
 986            )
 987
 988        edge_mask = distances < self.cutoff
 989        switch = SwitchFunction(
 990            **{**self.switch_params, "cutoff": self.cutoff, "graph_key": None}
 991        )((distances, edge_mask))
 992
 993        graph_out = {
 994            **graph,
 995            "vec": vec,
 996            "distances": distances,
 997            "switch": switch,
 998            "filter_indices": filter_indices,
 999            "edge_mask": edge_mask,
1000        }
1001
1002        if "alch_group" in inputs:
1003            edge_src=graph["edge_src"]
1004            edge_dst=graph["edge_dst"]
1005            alch_group = inputs["alch_group"]
1006            lambda_e = inputs["alch_elambda"]
1007            lambda_e = 0.5*(1.-jnp.cos(jnp.pi*lambda_e))
1008            mask = alch_group[edge_src] == alch_group[edge_dst]
1009            graph_out["switch_raw"] = switch
1010            graph_out["switch"] = jnp.where(
1011                mask,
1012                switch,
1013                lambda_e * switch ,
1014            )
1015
1016        return {**inputs, self.graph_key: graph_out}

Filter processing for a pre-generated graph

This module is automatically added to a FENNIX model when a GraphFilter is used.

GraphFilterProcessor( cutoff: float, graph_key: str, parent_graph: str, switch_params: dict = <factory>, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
cutoff: float

Cutoff distance for the filtering.

graph_key: str

Key of the filtered graph in the inputs.

parent_graph: str

Key of the parent graph in the inputs.

switch_params: dict

Parameters for the switching function. See fennol.models.misc.misc.SwitchFunction.

parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
@dataclasses.dataclass(frozen=True)
class GraphAngularExtension:
@dataclasses.dataclass(frozen=True)
class GraphAngularExtension:
    """Add angles list to a graph

    FPID: GRAPH_ANGULAR_EXTENSION

    An angle is a pair of edges of the graph that share the same source atom.
    `angle_src` and `angle_dst` index into the edge list and `central_atom`
    holds the shared source atom of each angle.
    """

    mult_size: float = 1.05
    """Multiplicative factor for resizing the nblist."""
    add_neigh: int = 5
    """Additional neighbors to add to the nblist when resizing."""
    graph_key: str = "graph"
    """Key of the graph in the inputs."""

    FPID: ClassVar[str] = "GRAPH_ANGULAR_EXTENSION"

    def init(self):
        """Return the initial state (empty angle buffer, `add_neigh` neighbor slots)."""
        return FrozenDict(
            {
                "nangles": 0,
                "nblist_mult_size": self.mult_size,
                "max_neigh": self.add_neigh,
                "add_neigh": self.add_neigh,
            }
        )

    def get_processor(self) -> Tuple[nn.Module, Dict]:
        """Return the processor module class and the keyword arguments to build it."""
        return GraphAngleProcessor, {
            "graph_key": self.graph_key,
            "name": f"{self.graph_key}_AngleProcessor",
        }

    def get_graph_properties(self):
        """Return the static properties this extension adds to the graph."""
        return {
            self.graph_key: {
                "has_angles": True,
            }
        }

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """build angle nblist on cpu with numpy and dynamic shapes + store max shapes

        Args:
            state: mapping with the current sizes (`max_neigh`, `nangles`), the
                growth factor (`nblist_mult_size`) and `add_neigh`.
            inputs: dict with `species` and the graph under `self.graph_key`
                (providing `edge_src`).
            return_state_update: also return the dict of resized state entries,
                each stored as a `(new, previous)` tuple.
            add_margin: grow the buffers even if the current sizes are sufficient.

        Returns:
            `(new_state, output)` or `(new_state, output, state_up)` where
            `output` is a copy of `inputs` whose graph gained `angle_src`,
            `angle_dst`, `central_atom`, `angle_overflow`, `max_neigh` and
            `__max_neigh_array`.
        """
        graph = inputs[self.graph_key]
        edge_src = np.array(graph["edge_src"], dtype=np.int32)
        nedge = edge_src.shape[0]

        new_state = {**state}
        state_up = {}
        mult_size = state.get("nblist_mult_size", self.mult_size)
        assert mult_size >= 1., "nblist_mult_size should be >= 1."

        ### count number of neighbors
        nat = inputs["species"].shape[0]
        # one extra slot collects the edges whose source index is `nat`
        count = np.zeros(nat + 1, dtype=np.int32)
        np.add.at(count, edge_src, 1)
        # `initial=0` keeps this well-defined when there are no atoms
        max_count = int(np.max(count[:-1], initial=0))

        ### get sizes
        max_neigh = state.get("max_neigh", self.add_neigh)
        if max_count > max_neigh or add_margin:
            prev_max_neigh = max_neigh
            max_neigh = max(max_count, max_neigh) + state.get(
                "add_neigh", self.add_neigh
            )
            state_up["max_neigh"] = (max_neigh, prev_max_neigh)
            new_state["max_neigh"] = max_neigh

        # only the length of this array matters: `process` reads its shape to
        # recover `max_neigh` when it is called without a state
        max_neigh_arr = np.empty(max_neigh, dtype=bool)

        ### sort edge_src
        idx_sort = np.argsort(edge_src)
        edge_src_sorted = edge_src[idx_sort]

        ### map sparse to dense nblist
        # cyclic slot numbers 0..max_count-1: consecutive sorted edges of one
        # atom (at most max_count of them) always receive distinct slots
        tiled = np.tile(np.arange(max_count), nat)
        if max_count * nat >= nedge:
            offset = tiled[:nedge]
        else:
            offset = np.zeros(nedge, dtype=np.int32)
            offset[: max_count * nat] = tiled

        # keep only edges whose source is an actual atom
        mask = edge_src_sorted < nat
        indices = edge_src_sorted * max_count + offset
        indices = indices[mask]
        idx_sort = idx_sort[mask]
        edge_idx = np.full(nat * max_count, nedge, dtype=np.int32)
        edge_idx[indices] = idx_sort
        edge_idx = edge_idx.reshape(nat, max_count)

        ### find all triplet for each atom center
        local_src, local_dst = np.triu_indices(max_count, 1)
        angle_src = edge_idx[:, local_src].flatten()
        angle_dst = edge_idx[:, local_dst].flatten()

        ### mask for valid angles
        mask1 = angle_src < nedge
        mask2 = angle_dst < nedge
        angle_mask = mask1 & mask2

        max_angles = state.get("nangles", 0)
        idx = np.nonzero(angle_mask)[0]
        nangles = idx.shape[0]
        if nangles > max_angles or add_margin:
            max_angles_prev = max_angles
            max_angles = int(mult_size * max(nangles, max_angles)) + 1
            state_up["nangles"] = (max_angles, max_angles_prev)
            new_state["nangles"] = max_angles

        ## filter angles to sparse representation
        angle_src_ = np.full(max_angles, nedge, dtype=np.int32)
        angle_dst_ = np.full(max_angles, nedge, dtype=np.int32)
        angle_src_[:nangles] = angle_src[idx]
        angle_dst_[:nangles] = angle_dst[idx]

        central_atom = np.full(max_angles, nat, dtype=np.int32)
        central_atom[:nangles] = edge_src[angle_src_[:nangles]]

        ## update graph
        output = {
            **inputs,
            self.graph_key: {
                **graph,
                "angle_src": angle_src_,
                "angle_dst": angle_dst_,
                "central_atom": central_atom,
                "angle_overflow": False,
                "max_neigh": max_neigh,
                "__max_neigh_array": max_neigh_arr,
            },
        }

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary

        Returns `(state, state_update, inputs, reallocated)`. Nothing is
        recomputed unless the angle list or its parent reported an overflow.
        """
        overflow = parent_overflow or inputs[self.graph_key]["angle_overflow"]
        if not overflow:
            return state, {}, inputs, False

        # only force extra room when the angle list itself overflowed
        add_margin = inputs[self.graph_key]["angle_overflow"]
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """build angle nblist on accelerator with jax and precomputed shapes

        Args:
            state: static mapping providing `max_neigh` and `nangles`, or None
                to reuse the sizes stored in the graph (`__max_neigh_array`
                and the existing `angle_src`).
            inputs: dict with `species` and the graph under `self.graph_key`.

        Returns:
            A copy of `inputs` whose graph has refreshed `angle_src`,
            `angle_dst`, `central_atom`, `angle_overflow` and
            `__max_neigh_array` entries.

        Raises:
            ValueError: if `max_neigh * nat` is smaller than the number of edges.
        """
        graph = inputs[self.graph_key]
        edge_src = graph["edge_src"]

        ### count number of neighbors
        nat = inputs["species"].shape[0]
        count = jnp.zeros(nat, dtype=jnp.int32).at[edge_src].add(1, mode="drop")
        max_count = jnp.max(count)

        ### get sizes
        if state is None:
            max_neigh_arr = graph["__max_neigh_array"]
            max_neigh = max_neigh_arr.shape[0]
            prev_nangles = graph["angle_src"].shape[0]
        else:
            max_neigh = state.get("max_neigh", self.add_neigh)
            max_neigh_arr = jnp.empty(max_neigh, dtype=bool)
            prev_nangles = state.get("nangles", 0)

        nedge = edge_src.shape[0]

        ### sort edge_src
        idx_sort = jnp.argsort(edge_src).astype(jnp.int32)
        edge_src_sorted = edge_src[idx_sort]

        ### map sparse to dense nblist
        if max_neigh * nat < nedge:
            raise ValueError("Found max_neigh*nat < nedge. This should not happen.")
        offset = jnp.asarray(
            np.tile(np.arange(max_neigh), nat)[:nedge], dtype=jnp.int32
        )
        indices = edge_src_sorted * max_neigh + offset
        edge_idx = (
            jnp.full(nat * max_neigh, nedge, dtype=jnp.int32)
            .at[indices]
            .set(idx_sort, mode="drop")
            .reshape(nat, max_neigh)
        )

        ### find all triplet for each atom center
        local_src, local_dst = np.triu_indices(max_neigh, 1)
        angle_src = edge_idx[:, local_src].flatten()
        angle_dst = edge_idx[:, local_dst].flatten()

        ### mask for valid angles
        mask1 = angle_src < nedge
        mask2 = angle_dst < nedge
        angle_mask = mask1 & mask2

        ## filter angles to sparse representation
        (angle_src, angle_dst), _, nangles = mask_filter_1d(
            angle_mask,
            prev_nangles,
            (angle_src, nedge),
            (angle_dst, nedge),
        )
        ## find central atom
        # Angle slots holding an out-of-range edge index get `nat`, the same
        # padding value the numpy routine (`__call__`) uses, instead of a
        # clamped lookup of the last edge.
        # NOTE(review): assumes mask_filter_1d pads unused slots with `nedge` - confirm.
        central_atom = edge_src.at[angle_src].get(mode="fill", fill_value=nat)

        ## check for overflow
        angle_overflow = nangles > prev_nangles
        neigh_overflow = max_count > max_neigh
        overflow = graph.get("angle_overflow", False) | angle_overflow | neigh_overflow

        ## update graph
        output = {
            **inputs,
            self.graph_key: {
                **graph,
                "angle_src": angle_src,
                "angle_dst": angle_dst,
                "central_atom": central_atom,
                "angle_overflow": overflow,
                "__max_neigh_array": max_neigh_arr,
            },
        }

        return output

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        """Rebuild the angle list in skin update mode (`process` without a state)."""
        return self.process(None, inputs)

Add angles list to a graph

FPID: GRAPH_ANGULAR_EXTENSION

GraphAngularExtension( mult_size: float = 1.05, add_neigh: int = 5, graph_key: str = 'graph')
mult_size: float = 1.05

Multiplicative factor for resizing the nblist.

add_neigh: int = 5

Additional neighbors to add to the nblist when resizing.

graph_key: str = 'graph'

Key of the graph in the inputs.

FPID: ClassVar[str] = 'GRAPH_ANGULAR_EXTENSION'
def init(self):
1035    def init(self):
1036        return FrozenDict(
1037            {
1038                "nangles": 0,
1039                "nblist_mult_size": self.mult_size,
1040                "max_neigh": self.add_neigh,
1041                "add_neigh": self.add_neigh,
1042            }
1043        )
def get_processor(self) -> Tuple[flax.linen.module.Module, Dict]:
1045    def get_processor(self) -> Tuple[nn.Module, Dict]:
1046        return GraphAngleProcessor, {
1047            "graph_key": self.graph_key,
1048            "name": f"{self.graph_key}_AngleProcessor",
1049        }
def get_graph_properties(self):
1051    def get_graph_properties(self):
1052        return {
1053            self.graph_key: {
1054                "has_angles": True,
1055            }
1056        }
def check_reallocate(self, state, inputs, parent_overflow=False):
1156    def check_reallocate(self, state, inputs, parent_overflow=False):
1157        """check for overflow and reallocate nblist if necessary"""
1158        overflow = parent_overflow or inputs[self.graph_key]["angle_overflow"]
1159        if not overflow:
1160            return state, {}, inputs, False
1161
1162        add_margin = inputs[self.graph_key]["angle_overflow"]
1163        state, inputs, state_up = self(
1164            state, inputs, return_state_update=True, add_margin=add_margin
1165        )
1166        return state, state_up, inputs, True

check for overflow and reallocate nblist if necessary

@partial(jax.jit, static_argnums=(0, 1))
def process(self, state, inputs):
1168    @partial(jax.jit, static_argnums=(0, 1))
1169    def process(self, state, inputs):
1170        """build angle nblist on accelerator with jax and precomputed shapes"""
1171        graph = inputs[self.graph_key]
1172        edge_src = graph["edge_src"]
1173
1174        ### count number of neighbors
1175        nat = inputs["species"].shape[0]
1176        count = jnp.zeros(nat, dtype=jnp.int32).at[edge_src].add(1, mode="drop")
1177        max_count = jnp.max(count)
1178
1179        ### get sizes
1180        if state is None:
1181            max_neigh_arr = graph["__max_neigh_array"]
1182            max_neigh = max_neigh_arr.shape[0]
1183            prev_nangles = graph["angle_src"].shape[0]
1184        else:
1185            max_neigh = state.get("max_neigh", self.add_neigh)
1186            max_neigh_arr = jnp.empty(max_neigh, dtype=bool)
1187            prev_nangles = state.get("nangles", 0)
1188
1189        nedge = edge_src.shape[0]
1190
1191        ### sort edge_src
1192        idx_sort = jnp.argsort(edge_src).astype(jnp.int32)
1193        edge_src_sorted = edge_src[idx_sort]
1194
1195        ### map sparse to dense nblist
1196        if max_neigh * nat < nedge:
1197            raise ValueError("Found max_neigh*nat < nedge. This should not happen.")
1198        offset = jnp.asarray(
1199            np.tile(np.arange(max_neigh), nat)[:nedge], dtype=jnp.int32
1200        )
1201        # offset = jnp.where(edge_src_sorted < nat, offset, 0)
1202        indices = edge_src_sorted * max_neigh + offset
1203        edge_idx = (
1204            jnp.full(nat * max_neigh, nedge, dtype=jnp.int32)
1205            .at[indices]
1206            .set(idx_sort, mode="drop")
1207            .reshape(nat, max_neigh)
1208        )
1209
1210        ### find all triplet for each atom center
1211        local_src, local_dst = np.triu_indices(max_neigh, 1)
1212        angle_src = edge_idx[:, local_src].flatten()
1213        angle_dst = edge_idx[:, local_dst].flatten()
1214
1215        ### mask for valid angles
1216        mask1 = angle_src < nedge
1217        mask2 = angle_dst < nedge
1218        angle_mask = mask1 & mask2
1219
1220        ## filter angles to sparse representation
1221        (angle_src, angle_dst), _, nangles = mask_filter_1d(
1222            angle_mask,
1223            prev_nangles,
1224            (angle_src, nedge),
1225            (angle_dst, nedge),
1226        )
1227        ## find central atom
1228        central_atom = edge_src[angle_src]
1229
1230        ## check for overflow
1231        angle_overflow = nangles > prev_nangles
1232        neigh_overflow = max_count > max_neigh
1233        overflow = graph.get("angle_overflow", False) | angle_overflow | neigh_overflow
1234
1235        ## update graph
1236        output = {
1237            **inputs,
1238            self.graph_key: {
1239                **graph,
1240                "angle_src": angle_src,
1241                "angle_dst": angle_dst,
1242                "central_atom": central_atom,
1243                "angle_overflow": overflow,
1244                # "max_neigh": max_neigh,
1245                "__max_neigh_array": max_neigh_arr,
1246            },
1247        }
1248
1249        return output

build angle nblist on accelerator with jax and precomputed shapes

@partial(jax.jit, static_argnums=(0,))
def update_skin(self, inputs):
1251    @partial(jax.jit, static_argnums=(0,))
1252    def update_skin(self, inputs):
1253        return self.process(None, inputs)
class GraphAngleProcessor(flax.linen.module.Module):
class GraphAngleProcessor(nn.Module):
    """Process a pre-generated graph to compute angles

    This module is automatically added to a FENNIX model when a GraphAngularExtension is used.

    The angles are computed between the pairs of edges listed in `angle_src`
    and `angle_dst` and stored in the graph under the key `angles`.
    """

    graph_key: str
    """Key of the graph in the inputs."""

    @nn.compact
    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
        graph = inputs[self.graph_key]
        distances = graph["distances"]
        vec = graph["vec"]
        angle_src = graph["angle_src"]
        angle_dst = graph["angle_dst"]

        # Unit vectors along the edges. The distance is bounded from below to
        # avoid dividing by zero. `jnp.maximum` is used instead of
        # `jnp.clip(..., a_min=...)` because the `a_min` keyword was deprecated
        # and later removed from JAX.
        unit_vec = vec / jnp.maximum(distances[:, None], 1.0e-5)

        # Out-of-range (padding) angle indices are filled with a constant
        # vector so that the result stays finite.
        src_dir = unit_vec.at[angle_src].get(mode="fill", fill_value=0.5)
        dst_dir = unit_vec.at[angle_dst].get(mode="fill", fill_value=0.5)
        cos_angles = (src_dir * dst_dir).sum(axis=-1)

        # NOTE(review): the 0.95 factor presumably keeps the argument away
        # from +/-1, where the derivative of arccos diverges -- confirm intent.
        angles = jnp.arccos(0.95 * cos_angles)

        return {
            **inputs,
            self.graph_key: {
                **graph,
                "angles": angles,
            },
        }

Process a pre-generated graph to compute angles

This module is automatically added to a FENNIX model when a GraphAngularExtension is used.

GraphAngleProcessor( graph_key: str, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
graph_key: str

Key of the graph in the inputs.

parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
@dataclasses.dataclass(frozen=True)
class SpeciesIndexer:
@dataclasses.dataclass(frozen=True)
class SpeciesIndexer:
    """Build an index that splits atomic arrays by species.

    FPID: SPECIES_INDEXER

    If `species_order` is specified, the output will be a dense array with size (len(species_order), max_size) that can directly index atomic arrays.
    If `species_order` is None, the output will be a dictionary with species as keys and an index to filter atomic arrays for that species as values.

    Unused slots of the index are filled with the number of atoms, i.e. an
    out-of-range index.
    """

    output_key: str = "species_index"
    """Key for the output dictionary."""
    species_order: Optional[str] = None
    """Comma separated list of species in the order they should be indexed."""
    add_atoms: int = 0
    """Additional atoms to add to the sizes."""
    add_atoms_margin: int = 10
    """Additional atoms to add to the sizes when adding margin."""

    FPID: ClassVar[str] = "SPECIES_INDEXER"

    def init(self):
        """Return the initial state: no size allocated for any species."""
        return FrozenDict(
            {
                "sizes": {},
            }
        )

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """Build the species index on the host with numpy and grow the sizes if needed.

        Returns `(new_state, output)`, plus the dictionary of state changes
        `{key: (new_value, old_value)}` when `return_state_update` is True.
        Species with a non-positive number are ignored.
        """
        species = np.array(inputs["species"], dtype=np.int32)
        nat = species.shape[0]
        set_species, counts = np.unique(species, return_counts=True)

        new_state = {**state}
        state_up = {}

        sizes = state.get("sizes", FrozenDict({}))
        new_sizes = {**sizes}
        up_sizes = False
        counts_dict = {}
        for s, c in zip(set_species, counts):
            if s <= 0:
                continue
            counts_dict[s] = c
            if c > sizes.get(s, 0):
                # The allocated size is too small for this species: grow it.
                up_sizes = True
                add_atoms = state.get("add_atoms", self.add_atoms)
                if add_margin:
                    add_atoms += state.get("add_atoms_margin", self.add_atoms_margin)
                new_sizes[s] = c + add_atoms

        new_sizes = FrozenDict(new_sizes)
        if up_sizes:
            state_up["sizes"] = (new_sizes, sizes)
            new_state["sizes"] = new_sizes

        if self.species_order is not None:
            species_order = [el.strip() for el in self.species_order.split(",")]
            max_size_prev = state.get("max_size", 0)
            # default=0 covers systems without any atom of positive species.
            max_size = max(new_sizes.values(), default=0)
            if max_size > max_size_prev:
                state_up["max_size"] = (max_size, max_size_prev)
                new_state["max_size"] = max_size

            species_index = np.full((len(species_order), max_size), nat, dtype=np.int32)
            for i, el in enumerate(species_order):
                s = PERIODIC_TABLE_REV_IDX[el]
                if s in counts_dict:
                    species_index[i, : counts_dict[s]] = np.nonzero(species == s)[0]
        else:
            species_index = {
                PERIODIC_TABLE[s]: np.full(c, nat, dtype=np.int32)
                for s, c in new_sizes.items()
            }
            for s, c in counts_dict.items():
                species_index[PERIODIC_TABLE[s]][:c] = np.nonzero(species == s)[0]

        output = {
            **inputs,
            self.output_key: species_index,
            self.output_key + "_overflow": False,
        }

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """Check for overflow and rebuild the species index if necessary.

        Returns `(state, state_update, inputs, reallocated)`.
        """
        overflow = parent_overflow or inputs[self.output_key + "_overflow"]
        if not overflow:
            return state, {}, inputs, False

        # Only add a margin when the overflow comes from this index itself.
        add_margin = inputs[self.output_key + "_overflow"]
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """Build the species index on accelerator with the sizes stored in `state`.

        If the index is already present in `inputs` and the flag
        `recompute_species_index` is not set, `inputs` is returned unchanged.

        Raises:
            ValueError: if the index must be computed and `state` is None.
        """
        recompute_species_index = "recompute_species_index" in inputs.get("flags", {})
        if self.output_key in inputs and not recompute_species_index:
            return inputs

        if state is None:
            raise ValueError("Species Indexer state must be provided on accelerator.")

        species = inputs["species"]
        nat = species.shape[0]

        sizes = state["sizes"]
        # Defined before branching so that both branches can report overflow.
        overflow = False

        if self.species_order is not None:
            species_order = [el.strip() for el in self.species_order.split(",")]
            max_size = state["max_size"]

            species_index = jnp.full(
                (len(species_order), max_size), nat, dtype=jnp.int32
            )
            for i, el in enumerate(species_order):
                s = PERIODIC_TABLE_REV_IDX[el]
                mask = species == s
                count = jnp.sum(mask)
                if s in sizes:
                    species_index = species_index.at[i, :].set(
                        jnp.nonzero(mask, size=max_size, fill_value=nat)[0]
                    )
                    # more atoms than the row can hold: the index is truncated
                    overflow = overflow | (count > max_size)
                else:
                    # species never sized on the host but present in the system
                    overflow = overflow | (count > 0)
        else:
            species_index = {}
            natcount = 0
            for s, c in sizes.items():
                mask = species == s
                new_size = jnp.sum(mask)
                natcount = natcount + new_size
                overflow = overflow | (new_size > c)  # check if sizes are correct
                species_index[PERIODIC_TABLE[s]] = jnp.nonzero(
                    mask, size=c, fill_value=nat
                )[0]

            # Atoms with non-positive species are deliberately not indexed.
            natcount = natcount + jnp.sum(species <= 0)
            overflow = overflow | (
                natcount < species.shape[0]
            )  # check if any species missing

        return {
            **inputs,
            self.output_key: species_index,
            self.output_key + "_overflow": overflow,
        }

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        """Run `process` without a state (only valid if the index is already in `inputs`)."""
        return self.process(None, inputs)

Build an index that splits atomic arrays by species.

FPID: SPECIES_INDEXER

If species_order is specified, the output will be a dense array with size (len(species_order), max_size) that can directly index atomic arrays. If species_order is None, the output will be a dictionary with species as keys and an index to filter atomic arrays for that species as values.

SpeciesIndexer( output_key: str = 'species_index', species_order: Optional[str] = None, add_atoms: int = 0, add_atoms_margin: int = 10)
output_key: str = 'species_index'

Key for the output dictionary.

species_order: Optional[str] = None

Comma separated list of species in the order they should be indexed.

add_atoms: int = 0

Additional atoms to add to the sizes.

add_atoms_margin: int = 10

Additional atoms to add to the sizes when adding margin.

FPID: ClassVar[str] = 'SPECIES_INDEXER'
def init(self):
1315    def init(self):
1316        return FrozenDict(
1317            {
1318                "sizes": {},
1319            }
1320        )
def check_reallocate(self, state, inputs, parent_overflow=False):
1384    def check_reallocate(self, state, inputs, parent_overflow=False):
1385        """check for overflow and reallocate nblist if necessary"""
1386        overflow = parent_overflow or inputs[self.output_key + "_overflow"]
1387        if not overflow:
1388            return state, {}, inputs, False
1389
1390        add_margin = inputs[self.output_key + "_overflow"]
1391        state, inputs, state_up = self(
1392            state, inputs, return_state_update=True, add_margin=add_margin
1393        )
1394        return state, state_up, inputs, True
1395        # return state, {}, inputs, parent_overflow

Check for overflow and rebuild the species index if necessary.

@partial(jax.jit, static_argnums=(0, 1))
def process(self, state, inputs):
1397    @partial(jax.jit, static_argnums=(0, 1))
1398    def process(self, state, inputs):
1399        # assert (
1400        #     self.output_key in inputs
1401        # ), f"Species Index {self.output_key} must be provided on accelerator. Call the numpy routine (self.__call__) first."
1402
1403        recompute_species_index = "recompute_species_index" in inputs.get("flags", {})
1404        if self.output_key in inputs and not recompute_species_index:
1405            return inputs
1406
1407        if state is None:
1408            raise ValueError("Species Indexer state must be provided on accelerator.")
1409
1410        species = inputs["species"]
1411        nat = species.shape[0]
1412
1413        sizes = state["sizes"]
1414
1415        if self.species_order is not None:
1416            species_order = [el.strip() for el in self.species_order.split(",")]
1417            max_size = state["max_size"]
1418
1419            species_index = jnp.full(
1420                (len(species_order), max_size), nat, dtype=jnp.int32
1421            )
1422            for i, el in enumerate(species_order):
1423                s = PERIODIC_TABLE_REV_IDX[el]
1424                if s in sizes.keys():
1425                    c = sizes[s]
1426                    species_index = species_index.at[i, :].set(
1427                        jnp.nonzero(species == s, size=max_size, fill_value=nat)[0]
1428                    )
1429                # if s in counts_dict.keys():
1430                #     species_index[i, : counts_dict[s]] = np.nonzero(species == s)[0]
1431        else:
1432            # species_index = {
1433            # PERIODIC_TABLE[s]: jnp.nonzero(species == s, size=c, fill_value=nat)[0]
1434            # for s, c in sizes.items()
1435            # }
1436            species_index = {}
1437            overflow = False
1438            natcount = 0
1439            for s, c in sizes.items():
1440                mask = species == s
1441                new_size = jnp.sum(mask)
1442                natcount = natcount + new_size
1443                overflow = overflow | (new_size > c)  # check if sizes are correct
1444                species_index[PERIODIC_TABLE[s]] = jnp.nonzero(
1445                    species == s, size=c, fill_value=nat
1446                )[0]
1447
1448            mask = species <= 0
1449            new_size = jnp.sum(mask)
1450            natcount = natcount + new_size
1451            overflow = overflow | (
1452                natcount < species.shape[0]
1453            )  # check if any species missing
1454
1455        return {
1456            **inputs,
1457            self.output_key: species_index,
1458            self.output_key + "_overflow": overflow,
1459        }
@partial(jax.jit, static_argnums=(0,))
def update_skin(self, inputs):
1461    @partial(jax.jit, static_argnums=(0,))
1462    def update_skin(self, inputs):
1463        return self.process(None, inputs)
@dataclasses.dataclass(frozen=True)
class BlockIndexer:
@dataclasses.dataclass(frozen=True)
class BlockIndexer:
    """Build an index that splits atomic arrays by chemical blocks.

    FPID: BLOCK_INDEXER

    The output is a dictionary with the chemical block names as keys. The value
    is an index to filter atomic arrays for that block, or None for blocks
    that have no allocated size. Unused slots of an index are filled with the
    number of atoms, i.e. an out-of-range index.

    """

    output_key: str = "block_index"
    """Key for the output dictionary."""
    add_atoms: int = 0
    """Additional atoms to add to the sizes."""
    add_atoms_margin: int = 10
    """Additional atoms to add to the sizes when adding margin."""
    split_CNOPSSe: bool = False
    """Whether C, N, O, P, S and Se each get their own block."""

    FPID: ClassVar[str] = "BLOCK_INDEXER"

    def init(self):
        """Return the initial state: no size allocated for any block."""
        initial_state = {"sizes": {}}
        return FrozenDict(initial_state)

    def build_chemical_blocks(self):
        """Return copies of the block names and of the species-to-block table.

        With `split_CNOPSSe`, block 1 is renamed "C" and only keeps carbon,
        while N, O, P, S and Se are moved to new blocks appended at the end.
        """
        block_names = CHEMICAL_BLOCKS_NAMES.copy()
        block_of_species = CHEMICAL_BLOCKS.copy()
        if not self.split_CNOPSSe:
            return block_names, block_of_species

        first_new_block = len(CHEMICAL_BLOCKS_NAMES)
        block_names[1] = "C"
        block_names.extend(["N", "O", "P", "S", "Se"])
        block_of_species[6] = 1
        # atomic numbers of N, O, P, S, Se, in the order of the appended names
        for shift, atomic_number in enumerate((7, 8, 15, 16, 34)):
            block_of_species[atomic_number] = first_new_block + shift
        return block_names, block_of_species

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """Build the block index on the host with numpy and grow the sizes if needed.

        Returns `(new_state, output)`, plus the dictionary of state changes
        `{key: (new_value, old_value)}` when `return_state_update` is True.
        Atoms whose block id is negative are ignored.
        """
        block_names, block_of_species = self.build_chemical_blocks()

        species = np.array(inputs["species"], dtype=np.int32)
        blocks = block_of_species[species]
        nat = species.shape[0]
        present_blocks, block_counts = np.unique(blocks, return_counts=True)
        populated = [
            (b, n) for b, n in zip(present_blocks, block_counts) if b >= 0
        ]

        next_state = {**state}
        state_update = {}

        old_sizes = state.get("sizes", FrozenDict({}))
        grown_sizes = {**old_sizes}
        resized = False
        for b, n in populated:
            key = (b, block_names[b])
            if n > old_sizes.get(key, 0):
                # The allocated size is too small for this block: grow it.
                resized = True
                padding = state.get("add_atoms", self.add_atoms)
                if add_margin:
                    padding += state.get("add_atoms_margin", self.add_atoms_margin)
                grown_sizes[key] = n + padding

        grown_sizes = FrozenDict(grown_sizes)
        if resized:
            state_update["sizes"] = (grown_sizes, old_sizes)
            next_state["sizes"] = grown_sizes

        # Every block name is present in the output; unsized blocks map to None.
        block_index = dict.fromkeys(block_names)
        for (_, name), capacity in grown_sizes.items():
            block_index[name] = np.full(capacity, nat, dtype=np.int32)
        for b, n in populated:
            block_index[block_names[b]][:n] = np.nonzero(blocks == b)[0]

        output = {
            **inputs,
            self.output_key: block_index,
            self.output_key + "_overflow": False,
        }

        next_state = FrozenDict(next_state)
        if return_state_update:
            return next_state, output, state_update
        return next_state, output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """Check for overflow and rebuild the block index if necessary.

        Returns a tuple `(state, state_update, inputs, reallocated)`.
        """
        overflow_key = self.output_key + "_overflow"
        needs_rebuild = parent_overflow or inputs[overflow_key]
        if needs_rebuild:
            # Only add a margin when the overflow comes from this index itself.
            own_overflow = inputs[overflow_key]
            new_state, new_inputs, state_update = self(
                state, inputs, return_state_update=True, add_margin=own_overflow
            )
            return new_state, state_update, new_inputs, True
        return state, {}, inputs, False

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """Build the block index on accelerator with the sizes stored in `state`.

        If the index is already present in `inputs` and the flag
        `recompute_species_index` is not set, `inputs` is returned unchanged.

        Raises:
            ValueError: if the index must be computed and `state` is None.
        """
        block_names, block_of_species = self.build_chemical_blocks()

        flags = inputs.get("flags", {})
        if self.output_key in inputs and "recompute_species_index" not in flags:
            return inputs

        if state is None:
            raise ValueError("Block Indexer state must be provided on accelerator.")

        species = inputs["species"]
        blocks = jnp.asarray(block_of_species)[species]
        nat = species.shape[0]

        block_index = dict.fromkeys(block_names)
        overflow = False
        n_accounted = 0
        for (block_id, name), capacity in state["sizes"].items():
            in_block = blocks == block_id
            n_in_block = jnp.sum(in_block)
            n_accounted = n_accounted + n_in_block
            # more atoms in the block than allocated slots
            overflow = overflow | (n_in_block > capacity)
            block_index[name] = jnp.nonzero(in_block, size=capacity, fill_value=nat)[0]

        # Atoms with a negative block id are not indexed; any other atom left
        # unaccounted for belongs to a block without allocated size.
        n_accounted = n_accounted + jnp.sum(blocks < 0)
        overflow = overflow | (n_accounted < species.shape[0])

        return {
            **inputs,
            self.output_key: block_index,
            self.output_key + "_overflow": overflow,
        }

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        """Run `process` without a state, forwarding its result unchanged."""
        result = self.process(None, inputs)
        return result

Build an index that splits atomic arrays by chemical blocks.

FPID: BLOCK_INDEXER

The output is a dictionary with the chemical block names as keys and, as values, an index to filter atomic arrays for that block (None for blocks that have no allocated size). Unused slots of an index are filled with the number of atoms.

BlockIndexer( output_key: str = 'block_index', add_atoms: int = 0, add_atoms_margin: int = 10, split_CNOPSSe: bool = False)
output_key: str = 'block_index'

Key for the output dictionary.

add_atoms: int = 0

Additional atoms to add to the sizes.

add_atoms_margin: int = 10

Additional atoms to add to the sizes when adding margin.

split_CNOPSSe: bool = False
FPID: ClassVar[str] = 'BLOCK_INDEXER'
def init(self):
1486    def init(self):
1487        return FrozenDict(
1488            {
1489                "sizes": {},
1490            }
1491        )
def build_chemical_blocks(self):
1493    def build_chemical_blocks(self):
1494        _CHEMICAL_BLOCKS_NAMES = CHEMICAL_BLOCKS_NAMES.copy()
1495        if self.split_CNOPSSe:
1496            _CHEMICAL_BLOCKS_NAMES[1] = "C"
1497            _CHEMICAL_BLOCKS_NAMES.extend(["N","O","P","S","Se"])
1498        _CHEMICAL_BLOCKS = CHEMICAL_BLOCKS.copy()
1499        if self.split_CNOPSSe:
1500            _CHEMICAL_BLOCKS[6] = 1
1501            _CHEMICAL_BLOCKS[7] = len(CHEMICAL_BLOCKS_NAMES)
1502            _CHEMICAL_BLOCKS[8] = len(CHEMICAL_BLOCKS_NAMES)+1
1503            _CHEMICAL_BLOCKS[15] = len(CHEMICAL_BLOCKS_NAMES)+2
1504            _CHEMICAL_BLOCKS[16] = len(CHEMICAL_BLOCKS_NAMES)+3
1505            _CHEMICAL_BLOCKS[34] = len(CHEMICAL_BLOCKS_NAMES)+4
1506        return _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS
def check_reallocate(self, state, inputs, parent_overflow=False):
1560    def check_reallocate(self, state, inputs, parent_overflow=False):
1561        """check for overflow and reallocate nblist if necessary"""
1562        overflow = parent_overflow or inputs[self.output_key + "_overflow"]
1563        if not overflow:
1564            return state, {}, inputs, False
1565
1566        add_margin = inputs[self.output_key + "_overflow"]
1567        state, inputs, state_up = self(
1568            state, inputs, return_state_update=True, add_margin=add_margin
1569        )
1570        return state, state_up, inputs, True
1571        # return state, {}, inputs, parent_overflow

Check for overflow and rebuild the block index if necessary.

@partial(jax.jit, static_argnums=(0, 1))
def process(self, state, inputs):
1573    @partial(jax.jit, static_argnums=(0, 1))
1574    def process(self, state, inputs):
1575        _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS = self.build_chemical_blocks()
1576        # assert (
1577        #     self.output_key in inputs
1578        # ), f"Species Index {self.output_key} must be provided on accelerator. Call the numpy routine (self.__call__) first."
1579
1580        recompute_species_index = "recompute_species_index" in inputs.get("flags", {})
1581        if self.output_key in inputs and not recompute_species_index:
1582            return inputs
1583
1584        if state is None:
1585            raise ValueError("Block Indexer state must be provided on accelerator.")
1586
1587        species = inputs["species"]
1588        blocks = jnp.asarray(_CHEMICAL_BLOCKS)[species]
1589        nat = species.shape[0]
1590
1591        sizes = state["sizes"]
1592
1593        # species_index = {
1594        # PERIODIC_TABLE[s]: jnp.nonzero(species == s, size=c, fill_value=nat)[0]
1595        # for s, c in sizes.items()
1596        # }
1597        block_index = {n: None for n in _CHEMICAL_BLOCKS_NAMES}
1598        overflow = False
1599        natcount = 0
1600        for (s,name), c in sizes.items():
1601            mask = blocks == s
1602            new_size = jnp.sum(mask)
1603            natcount = natcount + new_size
1604            overflow = overflow | (new_size > c)  # check if sizes are correct
1605            block_index[name] = jnp.nonzero(
1606                mask, size=c, fill_value=nat
1607            )[0]
1608
1609        mask = blocks < 0
1610        new_size = jnp.sum(mask)
1611        natcount = natcount + new_size
1612        overflow = overflow | (
1613            natcount < species.shape[0]
1614        )  # check if any species missing
1615
1616        return {
1617            **inputs,
1618            self.output_key: block_index,
1619            self.output_key + "_overflow": overflow,
1620        }
@partial(jax.jit, static_argnums=(0,))
def update_skin(self, inputs):
1622    @partial(jax.jit, static_argnums=(0,))
1623    def update_skin(self, inputs):
1624        return self.process(None, inputs)
@dataclasses.dataclass(frozen=True)
class AtomPadding:
@dataclasses.dataclass(frozen=True)
class AtomPadding:
    """Pad atomic arrays to a fixed size.

    The padded sizes are remembered in the layer state (`prev_nat`,
    `prev_nsys`) and only grow, so that successive inputs of similar size end
    up with identical array shapes.
    """

    mult_size: float = 1.2
    """Multiplicative factor for resizing the atomic arrays."""
    add_sys: int = 0
    """Number of extra system slots reserved when the system capacity grows."""

    def init(self):
        """Return the initial state: no atom or system capacity allocated."""
        return {"prev_nat": 0, "prev_nsys": 0}

    def __call__(self, state, inputs: Dict) -> Tuple[Any, Dict[str, Any]]:
        """Pad per-atom and per-system arrays of `inputs`.

        Args:
            state: mapping holding the previous capacities (`prev_nat`,
                `prev_nsys`); missing entries default to 0.
            inputs: input dictionary. Must contain `species` and `natoms`,
                and `batch_index` whenever padding atoms are added.

        Returns:
            A `(new_state, output)` tuple. `new_state` is a `FrozenDict` with
            the updated capacities. `output` is a copy of `inputs` with padded
            arrays plus the boolean masks `true_atoms` and `true_sys`.

        Arrays (numpy or jax) are classified by the length of their first
        axis: `nat` means per-atom, `nsys` means per-system (per-atom wins if
        both are equal). Zero-dimensional arrays are left untouched. Padding
        atoms get species -1 and are assigned to the last padding system;
        padding cells are 1000 * identity, all other padding entries are 0.
        Nothing is padded when no padding atom is needed (`add_atoms <= 0`).
        """
        species = inputs["species"]
        nat = species.shape[0]
        nsys = len(inputs["natoms"])

        # Capacities only grow: reallocate when the current input does not fit.
        nat_capacity = state.get("prev_nat", 0)
        if nat > nat_capacity:
            nat_capacity = int(self.mult_size * nat) + 1

        nsys_capacity = state.get("prev_nsys", 0)
        if nsys > nsys_capacity:
            nsys_capacity = nsys + self.add_sys

        add_atoms = nat_capacity - nat
        # One system is always added on top of the capacity: the padding atoms
        # are attached to the last system of the padded batch.
        add_sys = nsys_capacity - nsys + 1

        output = {**inputs}
        if add_atoms > 0:
            for k, v in inputs.items():
                if not isinstance(v, (np.ndarray, jax.Array)):
                    continue
                if v.ndim == 0:
                    # Scalars have no leading axis to pad (v.shape[0] would
                    # raise IndexError); atom_unpadding skips them as well.
                    continue
                if v.shape[0] == nat:
                    output[k] = np.append(
                        v,
                        np.zeros((add_atoms, *v.shape[1:]), dtype=v.dtype),
                        axis=0,
                    )
                elif v.shape[0] == nsys:
                    if k == "cells":
                        # Large identity cells for the padding systems.
                        filler = 1000 * np.eye(3, dtype=v.dtype)[None, :, :].repeat(
                            add_sys, axis=0
                        )
                    else:
                        filler = np.zeros((add_sys, *v.shape[1:]), dtype=v.dtype)
                    output[k] = np.append(v, filler, axis=0)

            # The generic loop above zero-pads these keys; overwrite them with
            # their dedicated padding values.
            output["natoms"] = np.append(
                inputs["natoms"], np.zeros(add_sys, dtype=np.int32)
            )
            output["species"] = np.append(
                species, -1 * np.ones(add_atoms, dtype=species.dtype)
            )
            last_sys = output["natoms"].shape[0] - 1
            output["batch_index"] = np.append(
                inputs["batch_index"],
                np.full(add_atoms, last_sys, dtype=inputs["batch_index"].dtype),
            )
            if "system_index" in inputs:
                output["system_index"] = np.append(
                    inputs["system_index"],
                    np.full(add_sys, last_sys, dtype=inputs["system_index"].dtype),
                )

        output["true_atoms"] = output["species"] > 0
        output["true_sys"] = np.arange(len(output["natoms"])) < nsys

        state = {**state, "prev_nat": nat_capacity, "prev_nsys": nsys_capacity}

        return FrozenDict(state), output

Pad atomic arrays to a fixed size.

AtomPadding(mult_size: float = 1.2, add_sys: int = 0)
mult_size: float = 1.2

Multiplicative factor for resizing the atomic arrays.

add_sys: int = 0
def init(self):
1635    def init(self):
1636        return {"prev_nat": 0, "prev_nsys": 0}
def atom_unpadding(inputs: Dict[str, Any]) -> Dict[str, Any]:
def atom_unpadding(inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Remove padding from atomic arrays.

    Args:
        inputs: dictionary possibly holding padded arrays. Padding is
            described by the boolean masks `true_atoms` and `true_sys`;
            `species` and `natoms` give the padded atom and system counts.

    Returns:
        `inputs` itself when it carries no `true_atoms` mask or when the
        masks select everything (nothing to strip). Otherwise a new
        dictionary where every array (numpy or jax) whose first axis has the
        padded atom count is filtered with `true_atoms`, every array whose
        first axis has the padded system count is filtered with `true_sys`,
        and both mask entries are removed. Zero-dimensional arrays and
        non-array values are passed through.
    """
    if "true_atoms" not in inputs:
        return inputs

    true_atoms = inputs["true_atoms"]
    true_sys = inputs["true_sys"]

    # Decide from the masks themselves whether anything is padded. Testing
    # `np.argmax(species <= 0) == 0` cannot tell "no padding atom" apart from
    # "the first atom is a padding atom", since argmax returns 0 in both cases.
    if np.all(true_atoms) and np.all(true_sys):
        return inputs

    natall = inputs["species"].shape[0]
    nsysall = len(inputs["natoms"])

    output = {**inputs}
    for k, v in inputs.items():
        if not isinstance(v, (jax.Array, np.ndarray)):
            continue
        if v.ndim == 0:
            continue
        if v.shape[0] == natall:
            output[k] = v[true_atoms]
        elif v.shape[0] == nsysall:
            output[k] = v[true_sys]
    del output["true_sys"]
    del output["true_atoms"]
    return output

Remove padding from atomic arrays.

def check_input(inputs):
def check_input(inputs):
    """Check the input dictionary and fill in derived entries.

    `species` and `coordinates` are required. The returned dictionary is a
    copy of `inputs` in which `natoms` and `batch_index` are always present
    as int32 arrays: `natoms` defaults to a single system holding all atoms,
    `batch_index` defaults to the system index of each atom derived from
    `natoms`. When `cells` is given, `reciprocal_cells` is computed by matrix
    inversion if absent, and both are promoted to 3 dimensions.
    """
    assert "species" in inputs, "species must be provided"
    assert "coordinates" in inputs, "coordinates must be provided"

    atomic_numbers = inputs["species"].astype(np.int32)
    first_fake = np.argmax(atomic_numbers <= 0)
    if first_fake > 0:
        # NOTE(review): every entry before the first non-positive one is
        # positive by construction, so this assertion cannot fail as written.
        assert np.all(atomic_numbers[:first_fake] > 0), "species must be positive"
    total_atoms = inputs["species"].shape[0]

    fallback_natoms = np.array([total_atoms], dtype=np.int32)
    natoms = inputs.get("natoms", fallback_natoms).astype(np.int32)

    fallback_batch_index = np.repeat(np.arange(len(natoms), dtype=np.int32), natoms)
    batch_index = inputs.get("batch_index", fallback_batch_index).astype(np.int32)

    output = {**inputs}
    output["natoms"] = natoms
    output["batch_index"] = batch_index

    if "cells" not in inputs:
        return output

    cells = inputs["cells"]
    if "reciprocal_cells" in inputs:
        reciprocal_cells = inputs["reciprocal_cells"]
    else:
        reciprocal_cells = np.linalg.inv(cells)

    # A single 3x3 matrix is promoted to a batch of one.
    if cells.ndim == 2:
        cells = cells[None, :, :]
    if reciprocal_cells.ndim == 2:
        reciprocal_cells = reciprocal_cells[None, :, :]

    output["cells"] = cells
    output["reciprocal_cells"] = reciprocal_cells
    return output

Check the input dictionary for required keys and types.

def convert_to_jax(data):
def convert_to_jax(data):
    """Convert the numpy arrays of a pytree to jax arrays.

    Every leaf that is a `np.ndarray` is passed through `jnp.asarray`; all
    other leaves are returned unchanged.
    """
    return jax.tree_util.tree_map(
        lambda leaf: jnp.asarray(leaf) if isinstance(leaf, np.ndarray) else leaf,
        data,
    )

Convert numpy arrays to jax arrays in a pytree.

class JaxConverter(flax.linen.module.Module):
class JaxConverter(nn.Module):
    """Flax module that applies `convert_to_jax` to its input pytree."""

    def __call__(self, data):
        """Return the result of `convert_to_jax(data)`."""
        converted = convert_to_jax(data)
        return converted

Convert numpy arrays to jax arrays in a pytree.

JaxConverter( parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
@dataclasses.dataclass(frozen=True)
class PreprocessingChain:
@dataclasses.dataclass(frozen=True)
class PreprocessingChain:
    """Chain of preprocessing layers.

    The chain state is a `FrozenDict` with a `"check_input"` flag and a
    `"layers_state"` tuple holding one state per layer, in application order.
    When `use_atom_padding` is set, entry 0 of `"layers_state"` belongs to
    `atom_padder` and the states of `layers` follow.
    """

    layers: Tuple[Callable[..., Dict[str, Any]], ...]
    """Preprocessing layers."""
    use_atom_padding: bool = False
    """Add an AtomPadding layer at the beginning of the chain."""
    atom_padder: AtomPadding = AtomPadding()
    """AtomPadding layer."""

    def __post_init__(self):
        """Validate that `layers` is a non-empty sequence."""
        if not isinstance(self.layers, Sequence):
            raise ValueError(
                f"'layers' must be a sequence, got '{type(self.layers).__name__}'."
            )
        if not self.layers:
            raise ValueError("Error: no Preprocessing layers were provided.")

    def __call__(self, state, inputs: Dict[str, Any]) -> Tuple[Any, Dict[str, Any]]:
        """Run the full chain and return `(new_state, outputs)`.

        Inputs are first passed through `check_input` unless the state's
        `"check_input"` flag is false. The padder (if enabled) and then each
        layer are called with their own state; the outputs are finally
        converted with `convert_to_jax`.
        """
        if state.get("check_input", True):
            inputs = check_input(inputs)
        layer_state = state["layers_state"]
        new_state = []
        i = 0
        if self.use_atom_padding:
            s, inputs = self.atom_padder(layer_state[0], inputs)
            new_state.append(s)
            i += 1
        for layer in self.layers:
            s, inputs = layer(layer_state[i], inputs, return_state_update=False)
            new_state.append(s)
            i += 1
        return FrozenDict({**state, "layers_state": tuple(new_state)}), convert_to_jax(
            inputs
        )

    def check_reallocate(self, state, inputs):
        """Let every layer check for overflow and reallocate if needed.

        The overflow flag returned by one layer is forwarded to the next one.
        Returns `(state, state_updates, inputs, overflow)`; when the final
        flag is false the original `state` and an empty dict are returned.
        """
        new_state = []
        state_up = []
        layer_state = state["layers_state"]
        i = 0
        if self.use_atom_padding:
            # The padder takes no part in reallocation: keep its state as is.
            new_state.append(layer_state[0])
            i += 1
        parent_overflow = False
        for layer in self.layers:
            s, s_up, inputs, parent_overflow = layer.check_reallocate(
                layer_state[i], inputs, parent_overflow
            )
            new_state.append(s)
            state_up.append(s_up)
            i += 1

        if not parent_overflow:
            return state, {}, inputs, False
        return (
            FrozenDict({**state, "layers_state": tuple(new_state)}),
            state_up,
            inputs,
            True,
        )

    def atom_padding(self, state, inputs):
        """Apply only the atom padding step and return `(state, inputs)`.

        The returned state is always the chain state, so it can be handed to
        `process` or `check_reallocate`. With padding enabled, the padder's
        new state replaces entry 0 of `"layers_state"`; without it, `state`
        and `inputs` are returned untouched.
        """
        if not self.use_atom_padding:
            return state, inputs
        layer_state = state["layers_state"]
        padder_state, inputs = self.atom_padder(layer_state[0], inputs)
        new_layers = (padder_state, *layer_state[1:])
        return FrozenDict({**state, "layers_state": new_layers}), inputs

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """Jitted pass through each layer's `process`.

        `self` and `state` are static arguments of the jit, so both must be
        hashable. The padder's state slot (if any) is skipped.
        """
        layer_state = state["layers_state"]
        i = 1 if self.use_atom_padding else 0
        for layer in self.layers:
            inputs = layer.process(layer_state[i], inputs)
            i += 1
        return inputs

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        """Jitted pass through each layer's `update_skin`."""
        for layer in self.layers:
            inputs = layer.update_skin(inputs)
        return inputs

    def init(self):
        """Build the initial chain state.

        `"layers_state"` is stored as a tuple and the padder state is frozen,
        matching what `__call__` produces and keeping the state hashable when
        the layer states are hashable (required by the jitted `process`).
        """
        state = []
        if self.use_atom_padding:
            state.append(FrozenDict(self.atom_padder.init()))
        for layer in self.layers:
            state.append(layer.init())
        return FrozenDict({"check_input": True, "layers_state": tuple(state)})

    def init_with_output(self, inputs):
        """Initialize the state and immediately run the chain on `inputs`."""
        state = self.init()
        return self(state, inputs)

    def get_processors(self):
        """Collect `get_processor()` from every layer that defines it."""
        return [
            layer.get_processor()
            for layer in self.layers
            if hasattr(layer, "get_processor")
        ]

    def get_graphs_properties(self):
        """Merge `get_graph_properties()` of all layers that define it."""
        properties = {}
        for layer in self.layers:
            if hasattr(layer, "get_graph_properties"):
                properties = deep_update(properties, layer.get_graph_properties())
        return properties

Chain of preprocessing layers.

PreprocessingChain( layers: Tuple[Callable[..., Dict[str, Any]]], use_atom_padding: bool = False, atom_padder: AtomPadding = AtomPadding(mult_size=1.2, add_sys=0))
layers: Tuple[Callable[..., Dict[str, Any]]]

Preprocessing layers.

use_atom_padding: bool = False

Add an AtomPadding layer at the beginning of the chain.

atom_padder: AtomPadding = AtomPadding(mult_size=1.2, add_sys=0)

AtomPadding layer.

def check_reallocate(self, state, inputs):
1824    def check_reallocate(self, state, inputs):
1825        new_state = []
1826        state_up = []
1827        layer_state = state["layers_state"]
1828        i = 0
1829        if self.use_atom_padding:
1830            new_state.append(layer_state[0])
1831            i += 1
1832        parent_overflow = False
1833        for layer in self.layers:
1834            s, s_up, inputs, parent_overflow = layer.check_reallocate(
1835                layer_state[i], inputs, parent_overflow
1836            )
1837            new_state.append(s)
1838            state_up.append(s_up)
1839            i += 1
1840
1841        if not parent_overflow:
1842            return state, {}, inputs, False
1843        return (
1844            FrozenDict({**state, "layers_state": tuple(new_state)}),
1845            state_up,
1846            inputs,
1847            True,
1848        )
def atom_padding(self, state, inputs):
1850    def atom_padding(self, state, inputs):
1851        if self.use_atom_padding:
1852            padder_state = state["layers_state"][0]
1853            return self.atom_padder(padder_state, inputs)
1854        return state, inputs
@partial(jax.jit, static_argnums=(0, 1))
def process(self, state, inputs):
1856    @partial(jax.jit, static_argnums=(0, 1))
1857    def process(self, state, inputs):
1858        layer_state = state["layers_state"]
1859        i = 1 if self.use_atom_padding else 0
1860        for layer in self.layers:
1861            inputs = layer.process(layer_state[i], inputs)
1862            i += 1
1863        return inputs
@partial(jax.jit, static_argnums=0)
def update_skin(self, inputs):
1865    @partial(jax.jit, static_argnums=(0))
1866    def update_skin(self, inputs):
1867        for layer in self.layers:
1868            inputs = layer.update_skin(inputs)
1869        return inputs
def init(self):
1871    def init(self):
1872        state = []
1873        if self.use_atom_padding:
1874            state.append(self.atom_padder.init())
1875        for layer in self.layers:
1876            state.append(layer.init())
1877        return FrozenDict({"check_input": True, "layers_state": state})
def init_with_output(self, inputs):
1879    def init_with_output(self, inputs):
1880        state = self.init()
1881        return self(state, inputs)
def get_processors(self):
1883    def get_processors(self):
1884        processors = []
1885        for layer in self.layers:
1886            if hasattr(layer, "get_processor"):
1887                processors.append(layer.get_processor())
1888        return processors
def get_graphs_properties(self):
1890    def get_graphs_properties(self):
1891        properties = {}
1892        for layer in self.layers:
1893            if hasattr(layer, "get_graph_properties"):
1894                properties = deep_update(properties, layer.get_graph_properties())
1895        return properties