fennol.models.preprocessing

   1import flax.linen as nn
   2from typing import Sequence, Callable, Union, Dict, Any, ClassVar
   3import jax.numpy as jnp
   4import jax
   5import numpy as np
   6from typing import Optional, Tuple
   7import numba
   8import dataclasses
   9from functools import partial
  10
  11from flax.core.frozen_dict import FrozenDict
  12
  13
  14from ..utils.activations import chain,safe_sqrt
  15from ..utils import deep_update, mask_filter_1d
  16from ..utils.kspace import get_reciprocal_space_parameters
  17from .misc.misc import SwitchFunction
  18from ..utils.periodic_table import PERIODIC_TABLE, PERIODIC_TABLE_REV_IDX,CHEMICAL_BLOCKS,CHEMICAL_BLOCKS_NAMES
  19
  20
  21@dataclasses.dataclass(frozen=True)
  22class GraphGenerator:
  23    """Generate a graph from a set of coordinates
  24
  25    FPID: GRAPH
  26
  27    For now, we generate all pairs of atoms and filter based on cutoff.
  28    If a `nblist_skin` is present in the state, we generate a second graph with a larger cutoff that includes all pairs within the cutoff+skin. This graph is then reused by the `update_skin` method to update the original graph without recomputing the full nblist.
  29    """
  30
  31    cutoff: float
  32    """Cutoff distance for the graph."""
  33    graph_key: str = "graph"
  34    """Key of the graph in the outputs."""
  35    switch_params: dict = dataclasses.field(default_factory=dict, hash=False)
  36    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""
  37    kmax: int = 30
  38    """Maximum number of k-points to consider."""
  39    kthr: float = 1e-6
  40    """Threshold for k-point filtering."""
  41    k_space: bool = False
  42    """Whether to generate k-space information for the graph."""
  43    mult_size: float = 1.05
  44    """Multiplicative factor for resizing the nblist."""
  45    # covalent_cutoff: bool = False
  46
  47    FPID: ClassVar[str] = "GRAPH"
  48
  49    def init(self):
  50        return FrozenDict(
  51            {
  52                "max_nat": 1,
  53                "npairs": 1,
  54                "nblist_mult_size": self.mult_size,
  55            }
  56        )
  57
  58    def get_processor(self) -> Tuple[nn.Module, Dict]:
  59        return GraphProcessor, {
  60            "cutoff": self.cutoff,
  61            "graph_key": self.graph_key,
  62            "switch_params": self.switch_params,
  63            "name": f"{self.graph_key}_Processor",
  64        }
  65
  66    def get_graph_properties(self):
  67        return {
  68            self.graph_key: {
  69                "cutoff": self.cutoff,
  70                "directed": True,
  71            }
  72        }
  73
  74    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
  75        """build a nblist on cpu with numpy and dynamic shapes + store max shapes"""
  76        if self.graph_key in inputs:
  77            graph = inputs[self.graph_key]
  78            if "keep_graph" in graph:
  79                return state, inputs
  80
  81        coords = np.array(inputs["coordinates"], dtype=np.float32)
  82        natoms = np.array(inputs["natoms"], dtype=np.int32)
  83        batch_index = np.array(inputs["batch_index"], dtype=np.int32)
  84
  85        new_state = {**state}
  86        state_up = {}
  87
  88        mult_size = state.get("nblist_mult_size", self.mult_size)
  89        assert mult_size >= 1.0, "mult_size should be larger or equal than 1.0"
  90
  91        if natoms.shape[0] == 1:
  92            max_nat = coords.shape[0]
  93            true_max_nat = max_nat
  94        else:
  95            max_nat = state.get("max_nat", round(coords.shape[0] / natoms.shape[0]))
  96            true_max_nat = int(np.max(natoms))
  97            if true_max_nat > max_nat:
  98                add_atoms = state.get("add_atoms", 0)
  99                new_maxnat = true_max_nat + add_atoms
 100                state_up["max_nat"] = (new_maxnat, max_nat)
 101                new_state["max_nat"] = new_maxnat
 102
 103        cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0)
 104
 105        ### compute indices of all pairs
 106        p1, p2 = np.triu_indices(true_max_nat, 1)
 107        p1, p2 = p1.astype(np.int32), p2.astype(np.int32)
 108        pbc_shifts = None
 109        if natoms.shape[0] > 1:
 110            ## batching => mask irrelevant pairs
 111            mask_p12 = (
 112                (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None])
 113            ).flatten()
 114            shift = np.concatenate(
 115                (np.array([0], dtype=np.int32), np.cumsum(natoms[:-1], dtype=np.int32))
 116            )
 117            p1 = np.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1)
 118            p2 = np.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1)
 119
 120        apply_pbc = "cells" in inputs
 121        if not apply_pbc:
 122            ### NO PBC
 123            vec = coords[p2] - coords[p1]
 124        else:
 125            cells = np.array(inputs["cells"], dtype=np.float32)
 126            reciprocal_cells = np.array(inputs["reciprocal_cells"], dtype=np.float32)
 127            minimage = state.get("minimum_image", True)
 128            if minimage:
 129                ## MINIMUM IMAGE CONVENTION
 130                vec = coords[p2] - coords[p1]
 131                if cells.shape[0] == 1:
 132                    vecpbc = np.dot(vec, reciprocal_cells[0])
 133                    pbc_shifts = -np.round(vecpbc)
 134                    vec = vec + np.dot(pbc_shifts, cells[0])
 135                else:
 136                    batch_index_vec = batch_index[p1]
 137                    vecpbc = np.einsum(
 138                        "aj,aji->ai", vec, reciprocal_cells[batch_index_vec]
 139                    )
 140                    pbc_shifts = -np.round(vecpbc)
 141                    vec = vec + np.einsum(
 142                        "aj,aji->ai", pbc_shifts, cells[batch_index_vec]
 143                    )
 144            else:
 145                ### GENERAL PBC
 146                ## put all atoms in central box
 147                if cells.shape[0] == 1:
 148                    coords_pbc = np.dot(coords, reciprocal_cells[0])
 149                    at_shifts = -np.floor(coords_pbc)
 150                    coords_pbc = coords + np.dot(at_shifts, cells[0])
 151                else:
 152                    coords_pbc = np.einsum(
 153                        "aj,aji->ai", coords, reciprocal_cells[batch_index]
 154                    )
 155                    at_shifts = -np.floor(coords_pbc)
 156                    coords_pbc = coords + np.einsum(
 157                        "aj,aji->ai", at_shifts, cells[batch_index]
 158                    )
 159                vec = coords_pbc[p2] - coords_pbc[p1]
 160
 161                ## compute maximum number of repeats
 162                inv_distances = (np.sum(reciprocal_cells**2, axis=1)) ** 0.5
 163                cdinv = cutoff_skin * inv_distances
 164                num_repeats_all = np.ceil(cdinv).astype(np.int32)
 165                if "true_sys" in inputs:
 166                    num_repeats_all = np.where(np.array(inputs["true_sys"],dtype=bool)[:, None], num_repeats_all, 0)
 167                # num_repeats_all = np.where(cdinv < 0.5, 0, num_repeats_all)
 168                num_repeats = np.max(num_repeats_all, axis=0)
 169                num_repeats_prev = np.array(state.get("num_repeats_pbc", (0, 0, 0)))
 170                if np.any(num_repeats > num_repeats_prev):
 171                    num_repeats_new = np.maximum(num_repeats, num_repeats_prev)
 172                    state_up["num_repeats_pbc"] = (
 173                        tuple(num_repeats_new),
 174                        tuple(num_repeats_prev),
 175                    )
 176                    new_state["num_repeats_pbc"] = tuple(num_repeats_new)
 177                ## build all possible shifts
 178                cell_shift_pbc = np.array(
 179                    np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]),
 180                    dtype=cells.dtype,
 181                ).T.reshape(-1, 3)
 182                ## shift applied to vectors
 183                if cells.shape[0] == 1:
 184                    dvec = np.dot(cell_shift_pbc, cells[0])[None, :, :]
 185                    vec = (vec[:, None, :] + dvec).reshape(-1, 3)
 186                    pbc_shifts = np.broadcast_to(
 187                        cell_shift_pbc[None, :, :],
 188                        (p1.shape[0], cell_shift_pbc.shape[0], 3),
 189                    ).reshape(-1, 3)
 190                    p1 = np.broadcast_to(
 191                        p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0])
 192                    ).flatten()
 193                    p2 = np.broadcast_to(
 194                        p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0])
 195                    ).flatten()
 196                    if natoms.shape[0] > 1:
 197                        mask_p12 = np.broadcast_to(
 198                            mask_p12[:, None],
 199                            (mask_p12.shape[0], cell_shift_pbc.shape[0]),
 200                        ).flatten()
 201                else:
 202                    dvec = np.einsum("bj,sji->sbi", cell_shift_pbc, cells)
 203
 204                    ## get pbc shifts specific to each box
 205                    cell_shift_pbc = np.broadcast_to(
 206                        cell_shift_pbc[None, :, :],
 207                        (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3),
 208                    )
 209                    mask = np.all(
 210                        np.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1
 211                    ).flatten()
 212                    idx = np.nonzero(mask)[0]
 213                    nshifts = idx.shape[0]
 214                    nshifts_prev = state.get("nshifts_pbc", 0)
 215                    if nshifts > nshifts_prev or add_margin:
 216                        nshifts_new = int(mult_size * max(nshifts, nshifts_prev)) + 1
 217                        state_up["nshifts_pbc"] = (nshifts_new, nshifts_prev)
 218                        new_state["nshifts_pbc"] = nshifts_new
 219
 220                    dvec_filter = dvec.reshape(-1, 3)[idx, :]
 221                    cell_shift_pbc_filter = cell_shift_pbc.reshape(-1, 3)[idx, :]
 222
 223                    ## get batch shift in the dvec_filter array
 224                    nrep = np.prod(2 * num_repeats_all + 1, axis=1)
 225                    bshift = np.concatenate((np.array([0]), np.cumsum(nrep)[:-1]))
 226
 227                    ## compute vectors
 228                    batch_index_vec = batch_index[p1]
 229                    nrep_vec = np.where(mask_p12,nrep[batch_index_vec],0)
 230                    vec = vec.repeat(nrep_vec, axis=0)
 231                    nvec_pbc = nrep_vec.sum() #vec.shape[0]
 232                    nvec_pbc_prev = state.get("nvec_pbc", 0)
 233                    if nvec_pbc > nvec_pbc_prev or add_margin:
 234                        nvec_pbc_new = int(mult_size * max(nvec_pbc, nvec_pbc_prev)) + 1
 235                        state_up["nvec_pbc"] = (nvec_pbc_new, nvec_pbc_prev)
 236                        new_state["nvec_pbc"] = nvec_pbc_new
 237
 238                    # print("cpu: ", nvec_pbc, nvec_pbc_prev, nshifts, nshifts_prev)
 239                    ## get shift index
 240                    dshift = np.concatenate(
 241                        (np.array([0]), np.cumsum(nrep_vec)[:-1])
 242                    ).repeat(nrep_vec)
 243                    # ishift = np.arange(dshift.shape[0])-dshift
 244                    # bshift_vec_rep = bshift[batch_index_vec].repeat(nrep_vec)
 245                    icellshift = (
 246                        np.arange(dshift.shape[0])
 247                        - dshift
 248                        + bshift[batch_index_vec].repeat(nrep_vec)
 249                    )
 250                    # shift vectors
 251                    vec = vec + dvec_filter[icellshift]
 252                    pbc_shifts = cell_shift_pbc_filter[icellshift]
 253
 254                    p1 = np.repeat(p1, nrep_vec)
 255                    p2 = np.repeat(p2, nrep_vec)
 256                    if natoms.shape[0] > 1:
 257                        mask_p12 = np.repeat(mask_p12, nrep_vec)
 258
 259        ## compute distances
 260        d12 = (vec**2).sum(axis=-1)
 261        if natoms.shape[0] > 1:
 262            d12 = np.where(mask_p12, d12, cutoff_skin**2)
 263
 264        ## filter pairs
 265        max_pairs = state.get("npairs", 1)
 266        mask = d12 < cutoff_skin**2
 267        idx = np.nonzero(mask)[0]
 268        npairs = idx.shape[0]
 269        if npairs > max_pairs or add_margin:
 270            prev_max_pairs = max_pairs
 271            max_pairs = int(mult_size * max(npairs, max_pairs)) + 1
 272            state_up["npairs"] = (max_pairs, prev_max_pairs)
 273            new_state["npairs"] = max_pairs
 274
 275        nat = coords.shape[0]
 276        edge_src = np.full(max_pairs, nat, dtype=np.int32)
 277        edge_dst = np.full(max_pairs, nat, dtype=np.int32)
 278        d12_ = np.full(max_pairs, cutoff_skin**2)
 279        edge_src[:npairs] = p1[idx]
 280        edge_dst[:npairs] = p2[idx]
 281        d12_[:npairs] = d12[idx]
 282        d12 = d12_
 283
 284        if apply_pbc:
 285            pbc_shifts_ = np.zeros((max_pairs, 3))
 286            pbc_shifts_[:npairs] = pbc_shifts[idx]
 287            pbc_shifts = pbc_shifts_
 288            if not minimage:
 289                pbc_shifts[:npairs] = (
 290                    pbc_shifts[:npairs]
 291                    + at_shifts[edge_dst[:npairs]]
 292                    - at_shifts[edge_src[:npairs]]
 293                )
 294
 295        if "nblist_skin" in state:
 296            edge_src_skin = edge_src
 297            edge_dst_skin = edge_dst
 298            if apply_pbc:
 299                pbc_shifts_skin = pbc_shifts
 300            max_pairs_skin = state.get("npairs_skin", 1)
 301            mask = d12 < self.cutoff**2
 302            idx = np.nonzero(mask)[0]
 303            npairs_skin = idx.shape[0]
 304            if npairs_skin > max_pairs_skin or add_margin:
 305                prev_max_pairs_skin = max_pairs_skin
 306                max_pairs_skin = int(mult_size * max(npairs_skin, max_pairs_skin)) + 1
 307                state_up["npairs_skin"] = (max_pairs_skin, prev_max_pairs_skin)
 308                new_state["npairs_skin"] = max_pairs_skin
 309            edge_src = np.full(max_pairs_skin, nat, dtype=np.int32)
 310            edge_dst = np.full(max_pairs_skin, nat, dtype=np.int32)
 311            d12_ = np.full(max_pairs_skin, self.cutoff**2)
 312            edge_src[:npairs_skin] = edge_src_skin[idx]
 313            edge_dst[:npairs_skin] = edge_dst_skin[idx]
 314            d12_[:npairs_skin] = d12[idx]
 315            d12 = d12_
 316            if apply_pbc:
 317                pbc_shifts = np.full((max_pairs_skin, 3), 0.0)
 318                pbc_shifts[:npairs_skin] = pbc_shifts_skin[idx]
 319
 320        ## symmetrize
 321        edge_src, edge_dst = np.concatenate((edge_src, edge_dst)), np.concatenate(
 322            (edge_dst, edge_src)
 323        )
 324        d12 = np.concatenate((d12, d12))
 325        if apply_pbc:
 326            pbc_shifts = np.concatenate((pbc_shifts, -pbc_shifts))
 327
 328        graph = inputs.get(self.graph_key, {})
 329        graph_out = {
 330            **graph,
 331            "edge_src": edge_src,
 332            "edge_dst": edge_dst,
 333            "d12": d12,
 334            "overflow": False,
 335            "pbc_shifts": pbc_shifts,
 336        }
 337        if "nblist_skin" in state:
 338            graph_out["edge_src_skin"] = edge_src_skin
 339            graph_out["edge_dst_skin"] = edge_dst_skin
 340            if apply_pbc:
 341                graph_out["pbc_shifts_skin"] = pbc_shifts_skin
 342
 343        if self.k_space and apply_pbc:
 344            if "k_points" not in graph:
 345                ks, _, _, bewald = get_reciprocal_space_parameters(
 346                    reciprocal_cells, self.cutoff, self.kmax, self.kthr
 347                )
 348            graph_out["k_points"] = ks
 349            graph_out["b_ewald"] = bewald
 350
 351        output = {**inputs, self.graph_key: graph_out}
 352
 353        if return_state_update:
 354            return FrozenDict(new_state), output, state_up
 355        return FrozenDict(new_state), output
 356
 357    def check_reallocate(self, state, inputs, parent_overflow=False):
 358        """check for overflow and reallocate nblist if necessary"""
 359        overflow = parent_overflow or inputs[self.graph_key].get("overflow", False)
 360        if not overflow:
 361            return state, {}, inputs, False
 362
 363        add_margin = inputs[self.graph_key].get("overflow", False)
 364        state, inputs, state_up = self(
 365            state, inputs, return_state_update=True, add_margin=add_margin
 366        )
 367        return state, state_up, inputs, True
 368
    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """Build the neighbor list on accelerator with jax and precomputed shapes.

        This mirrors the numpy routine (`__call__`) but every buffer size is
        read from `state` (`max_nat`, `npairs`, `npairs_skin`,
        `num_repeats_pbc`, `nshifts_pbc`, `nvec_pbc`). `self` and `state` are
        static arguments of the jit, so they must be hashable and those sizes
        are plain Python values at trace time. Whenever one of the sizes
        turns out to be too small for the current inputs, the `overflow`
        entry of the output graph is set instead of resizing.

        Returns `inputs` with the graph stored under `graph_key`, or `inputs`
        unchanged when the incoming graph holds a `keep_graph` key.
        """
        if self.graph_key in inputs:
            graph = inputs[self.graph_key]
            if "keep_graph" in graph:
                return inputs
        coords = inputs["coordinates"]
        natoms = inputs["natoms"]
        batch_index = inputs["batch_index"]

        if natoms.shape[0] == 1:
            max_nat = coords.shape[0]
        else:
            max_nat = state.get(
                "max_nat", int(round(coords.shape[0] / natoms.shape[0]))
            )
        cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0)

        ### compute indices of all pairs
        # numpy is used on purpose here: the pair indices only depend on the
        # static `max_nat`, so they are constants of the traced function
        p1, p2 = np.triu_indices(max_nat, 1)
        p1, p2 = p1.astype(np.int32), p2.astype(np.int32)
        pbc_shifts = None
        if natoms.shape[0] > 1:
            ## batching => mask irrelevant pairs
            mask_p12 = (
                (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None])
            ).flatten()
            # offset of the first atom of each system in the flat arrays
            shift = jnp.concatenate(
                (jnp.array([0], dtype=jnp.int32), jnp.cumsum(natoms[:-1]))
            )
            p1 = jnp.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1)
            p2 = jnp.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1)

        ## compute vectors
        overflow_repeats = jnp.asarray(False, dtype=bool)
        if "cells" not in inputs:
            vec = coords[p2] - coords[p1]
        else:
            cells = inputs["cells"]
            reciprocal_cells = inputs["reciprocal_cells"]
            minimage = state.get("minimum_image", True)

            def compute_pbc(vec, reciprocal_cell, cell, mode="round"):
                # Shift `vec` by an integer combination of cell vectors:
                # "round" rounds the fractional coordinates to the nearest
                # integer, "floor" rounds them down. Returns the shifted
                # vector and the integer shifts that were applied.
                vecpbc = jnp.dot(vec, reciprocal_cell)
                if mode == "round":
                    pbc_shifts = -jnp.round(vecpbc)
                elif mode == "floor":
                    pbc_shifts = -jnp.floor(vecpbc)
                else:
                    raise NotImplementedError(f"Unknown mode {mode} for compute_pbc.")
                return vec + jnp.dot(pbc_shifts, cell), pbc_shifts

            if minimage:
                ## minimum image convention
                vec = coords[p2] - coords[p1]

                if cells.shape[0] == 1:
                    vec, pbc_shifts = compute_pbc(vec, reciprocal_cells[0], cells[0])
                else:
                    batch_index_vec = batch_index[p1]
                    vec, pbc_shifts = jax.vmap(compute_pbc)(
                        vec, reciprocal_cells[batch_index_vec], cells[batch_index_vec]
                    )
            else:
                ### general PBC only for single cell yet
                # NOTE(review): the header above looks outdated, a batched
                # (multi-cell) branch is implemented further down.
                # if cells.shape[0] > 1:
                #     raise NotImplementedError(
                #         "General PBC not implemented for batches on accelerator."
                #     )
                # cell = cells[0]
                # reciprocal_cell = reciprocal_cells[0]

                ## put all atoms in central box
                if cells.shape[0] == 1:
                    coords_pbc, at_shifts = compute_pbc(
                        coords, reciprocal_cells[0], cells[0], mode="floor"
                    )
                else:
                    coords_pbc, at_shifts = jax.vmap(
                        partial(compute_pbc, mode="floor")
                    )(coords, reciprocal_cells[batch_index], cells[batch_index])
                vec = coords_pbc[p2] - coords_pbc[p1]
                # static number of cell repeats per direction, taken from the state
                num_repeats = state.get("num_repeats_pbc", (0, 0, 0))
                # if num_repeats is None:
                #     raise ValueError(
                #         "num_repeats_pbc should be provided for general PBC on accelerator. Call the numpy routine (self.__call__) first."
                #     )
                # check if num_repeats is larger than previous
                inv_distances = jnp.linalg.norm(reciprocal_cells, axis=1)
                cdinv = cutoff_skin * inv_distances
                num_repeats_all = jnp.ceil(cdinv).astype(jnp.int32)
                if "true_sys" in inputs:
                    num_repeats_all = jnp.where(inputs["true_sys"][:,None], num_repeats_all, 0)
                num_repeats_new = jnp.max(num_repeats_all, axis=0)
                # the stored repeats are too small for the current cells
                overflow_repeats = jnp.any(num_repeats_new > jnp.asarray(num_repeats))

                # all integer cell shifts within the static repeats (numpy constant)
                cell_shift_pbc = jnp.asarray(
                    np.array(
                        np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]),
                        dtype=cells.dtype,
                    ).T.reshape(-1, 3)
                )

                if cells.shape[0] == 1:
                    # single cell: every pair is replicated for every cell shift
                    vec = (vec[:,None,:] + jnp.dot(cell_shift_pbc, cells[0])[None, :, :]).reshape(-1, 3)
                    pbc_shifts = jnp.broadcast_to(
                        cell_shift_pbc[None, :, :],
                        (p1.shape[0], cell_shift_pbc.shape[0], 3),
                    ).reshape(-1, 3)
                    p1 = jnp.broadcast_to(
                        p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
                    p2 = jnp.broadcast_to(
                        p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
                    if natoms.shape[0] > 1:
                        mask_p12 = jnp.broadcast_to(
                            mask_p12[:, None], (mask_p12.shape[0], cell_shift_pbc.shape[0])
                        ).flatten()
                else:
                    dvec = jnp.einsum("bj,sji->sbi", cell_shift_pbc, cells).reshape(-1, 3)

                    ## get pbc shifts specific to each box
                    cell_shift_pbc = jnp.broadcast_to(
                        cell_shift_pbc[None, :, :],
                        (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3),
                    )
                    mask = jnp.all(
                        jnp.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1
                    ).flatten()
                    max_shifts  = state.get("nshifts_pbc", 1)

                    cell_shift_pbc = cell_shift_pbc.reshape(-1,3)
                    shiftx,shifty,shiftz = cell_shift_pbc[:,0],cell_shift_pbc[:,1],cell_shift_pbc[:,2]
                    dvecx,dvecy,dvecz = dvec[:,0],dvec[:,1],dvec[:,2]
                    # presumably `mask_filter_1d` keeps the entries where `mask`
                    # is true in arrays of static length `max_shifts`, padded
                    # with the given fill values, and also returns the scatter
                    # indices and the number of selected entries
                    # -- TODO confirm against ..utils.mask_filter_1d
                    (dvecx, dvecy,dvecz,shiftx,shifty,shiftz), scatter_idx, nshifts = mask_filter_1d(
                        mask,
                        max_shifts,
                        (dvecx, 0.),
                        (dvecy, 0.),
                        (dvecz, 0.),
                        (shiftx, 0),
                        (shifty, 0),
                        (shiftz, 0),
                    )
                    dvec = jnp.stack((dvecx,dvecy,dvecz),axis=-1)
                    cell_shift_pbc = jnp.stack((shiftx,shifty,shiftz),axis=-1)
                    overflow_repeats = overflow_repeats | (nshifts > max_shifts)

                    ## get batch shift in the dvec_filter array
                    nrep = jnp.prod(2 * num_repeats_all + 1, axis=1)
                    bshift = jnp.concatenate((jnp.array([0],dtype=jnp.int32), jnp.cumsum(nrep)[:-1]))

                    ## repeat vectors
                    nvec_max = state.get("nvec_pbc", 1)
                    batch_index_vec = batch_index[p1]
                    # masked (padding) pairs are repeated zero times
                    nrep_vec = jnp.where(mask_p12,nrep[batch_index_vec],0)
                    nvec = nrep_vec.sum()
                    overflow_repeats = overflow_repeats | (nvec > nvec_max)
                    # `total_repeat_length` fixes the output length to the static `nvec_max`
                    vec = jnp.repeat(vec,nrep_vec,axis=0,total_repeat_length=nvec_max)
                    # jax.debug.print("{nvec} {nvec_max} {nshifts} {max_shifts}",nvec=nvec,nvec_max=jnp.asarray(nvec_max),nshifts=nshifts,max_shifts=jnp.asarray(max_shifts))

                    ## get shift index
                    dshift = jnp.concatenate(
                        (jnp.array([0],dtype=jnp.int32), jnp.cumsum(nrep_vec)[:-1])
                    )
                    if nrep_vec.size == 0:
                        dshift = jnp.array([],dtype=jnp.int32)
                    dshift = jnp.repeat(dshift,nrep_vec, total_repeat_length=nvec_max)
                    bshift = jnp.repeat(bshift[batch_index_vec],nrep_vec, total_repeat_length=nvec_max)
                    # index of the cell shift of each repeated vector:
                    # position inside its own pair + start of its system's shifts
                    icellshift = jnp.arange(dshift.shape[0]) - dshift + bshift
                    vec = vec + dvec[icellshift]
                    pbc_shifts = cell_shift_pbc[icellshift]
                    p1 = jnp.repeat(p1,nrep_vec, total_repeat_length=nvec_max)
                    p2 = jnp.repeat(p2,nrep_vec, total_repeat_length=nvec_max)
                    if natoms.shape[0] > 1:
                        mask_p12 = jnp.repeat(mask_p12,nrep_vec, total_repeat_length=nvec_max)


        ## compute distances
        d12 = (vec**2).sum(axis=-1)
        if natoms.shape[0] > 1:
            # masked pairs get a distance that fails the strict cutoff test
            d12 = jnp.where(mask_p12, d12, cutoff_skin**2)

        ## filter pairs
        max_pairs = state.get("npairs", 1)
        mask = d12 < cutoff_skin**2
        # fill values: index `coords.shape[0]` (one past the last atom) for the
        # edges and the squared cutoff for the distances
        (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d(
            mask,
            max_pairs,
            (jnp.asarray(p1, dtype=jnp.int32), coords.shape[0]),
            (jnp.asarray(p2, dtype=jnp.int32), coords.shape[0]),
            (d12, cutoff_skin**2),
        )
        if "cells" in inputs:
            # scatter the shifts with the same indices; out-of-range targets are dropped
            pbc_shifts = (
                jnp.full((max_pairs, 3), 0.0, dtype=pbc_shifts.dtype)
                .at[scatter_idx]
                .set(pbc_shifts, mode="drop")
            )
            if not minimage:
                # account for the shifts that wrapped atoms into the central
                # box; out-of-range (padding) edges contribute 0
                pbc_shifts = (
                    pbc_shifts
                    + at_shifts.at[edge_dst].get(fill_value=0.0)
                    - at_shifts.at[edge_src].get(fill_value=0.0)
                )

        ## check for overflow
        if natoms.shape[0] == 1:
            true_max_nat = coords.shape[0]
        else:
            true_max_nat = jnp.max(natoms)
        overflow_count = npairs > max_pairs
        overflow_at = true_max_nat > max_nat
        overflow = overflow_count | overflow_at | overflow_repeats

        if "nblist_skin" in state:
            # keep the cutoff+skin list and extract the cutoff-only list from it
            # edge_mask_skin = edge_mask
            edge_src_skin = edge_src
            edge_dst_skin = edge_dst
            if "cells" in inputs:
                pbc_shifts_skin = pbc_shifts
            max_pairs_skin = state.get("npairs_skin", 1)
            mask = d12 < self.cutoff**2
            (edge_src, edge_dst, d12), scatter_idx, npairs_skin = mask_filter_1d(
                mask,
                max_pairs_skin,
                (edge_src, coords.shape[0]),
                (edge_dst, coords.shape[0]),
                (d12, self.cutoff**2),
            )
            if "cells" in inputs:
                pbc_shifts = (
                    jnp.full((max_pairs_skin, 3), 0.0, dtype=pbc_shifts.dtype)
                    .at[scatter_idx]
                    .set(pbc_shifts, mode="drop")
                )
            overflow = overflow | (npairs_skin > max_pairs_skin)

        ## symmetrize
        edge_src, edge_dst = jnp.concatenate((edge_src, edge_dst)), jnp.concatenate(
            (edge_dst, edge_src)
        )
        d12 = jnp.concatenate((d12, d12))
        if "cells" in inputs:
            pbc_shifts = jnp.concatenate((pbc_shifts, -pbc_shifts))

        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
        graph_out = {
            **graph,
            "edge_src": edge_src,
            "edge_dst": edge_dst,
            "d12": d12,
            "overflow": overflow,
            "pbc_shifts": pbc_shifts,
        }
        if "nblist_skin" in state:
            graph_out["edge_src_skin"] = edge_src_skin
            graph_out["edge_dst_skin"] = edge_dst_skin
            if "cells" in inputs:
                graph_out["pbc_shifts_skin"] = pbc_shifts_skin

        # k-space data cannot be generated here; it must already be in the graph
        if self.k_space and "cells" in inputs:
            if "k_points" not in graph:
                raise NotImplementedError(
                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
                )
        return {**inputs, self.graph_key: graph_out}
 638
 639    @partial(jax.jit, static_argnums=(0,))
 640    def update_skin(self, inputs):
 641        """update the nblist without recomputing the full nblist"""
 642        graph = inputs[self.graph_key]
 643
 644        edge_src_skin = graph["edge_src_skin"]
 645        edge_dst_skin = graph["edge_dst_skin"]
 646        coords = inputs["coordinates"]
 647        vec = coords.at[edge_dst_skin].get(
 648            mode="fill", fill_value=self.cutoff
 649        ) - coords.at[edge_src_skin].get(mode="fill", fill_value=0.0)
 650
 651        if "cells" in inputs:
 652            pbc_shifts_skin = graph["pbc_shifts_skin"]
 653            cells = inputs["cells"]
 654            if cells.shape[0] == 1:
 655                vec = vec + jnp.dot(pbc_shifts_skin, cells[0])
 656            else:
 657                batch_index_vec = inputs["batch_index"][edge_src_skin]
 658                vec = vec + jax.vmap(jnp.dot)(pbc_shifts_skin, cells[batch_index_vec])
 659
 660        nat = coords.shape[0]
 661        d12 = jnp.sum(vec**2, axis=-1)
 662        mask = d12 < self.cutoff**2
 663        max_pairs = graph["edge_src"].shape[0] // 2
 664        (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d(
 665            mask,
 666            max_pairs,
 667            (edge_src_skin, nat),
 668            (edge_dst_skin, nat),
 669            (d12, self.cutoff**2),
 670        )
 671        if "cells" in inputs:
 672            pbc_shifts = (
 673                jnp.full((max_pairs, 3), 0.0, dtype=pbc_shifts_skin.dtype)
 674                .at[scatter_idx]
 675                .set(pbc_shifts_skin)
 676            )
 677
 678        overflow = graph.get("overflow", False) | (npairs > max_pairs)
 679        graph_out = {
 680            **graph,
 681            "edge_src": jnp.concatenate((edge_src, edge_dst)),
 682            "edge_dst": jnp.concatenate((edge_dst, edge_src)),
 683            "d12": jnp.concatenate((d12, d12)),
 684            "overflow": overflow,
 685        }
 686        if "cells" in inputs:
 687            graph_out["pbc_shifts"] = jnp.concatenate((pbc_shifts, -pbc_shifts))
 688
 689        if self.k_space and "cells" in inputs:
 690            if "k_points" not in graph:
 691                raise NotImplementedError(
 692                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
 693                )
 694
 695        return {**inputs, self.graph_key: graph_out}
 696
 697
 698class GraphProcessor(nn.Module):
 699    """Process a pre-generated graph
 700
 701    The pre-generated graph should contain the following keys:
 702    - edge_src: source indices of the edges
 703    - edge_dst: destination indices of the edges
 704    - pbcs_shifts: pbc shifts for the edges (only if `cells` are present in the inputs)
 705
 706    This module is automatically added to a FENNIX model when a GraphGenerator is used.
 707
 708    """
 709
 710    cutoff: float
 711    """Cutoff distance for the graph."""
 712    graph_key: str = "graph"
 713    """Key of the graph in the outputs."""
 714    switch_params: dict = dataclasses.field(default_factory=dict)
 715    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""
 716
 717    @nn.compact
 718    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
 719        graph = inputs[self.graph_key]
 720        coords = inputs["coordinates"]
 721        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
 722        # edge_mask = edge_src < coords.shape[0]
 723        vec = coords.at[edge_dst].get(mode="fill", fill_value=self.cutoff) - coords.at[
 724            edge_src
 725        ].get(mode="fill", fill_value=0.0)
 726        if "cells" in inputs:
 727            cells = inputs["cells"]
 728            if cells.shape[0] == 1:
 729                vec = vec + jnp.dot(graph["pbc_shifts"], cells[0])
 730            else:
 731                batch_index_vec = inputs["batch_index"][edge_src]
 732                vec = vec + jax.vmap(jnp.dot)(
 733                    graph["pbc_shifts"], cells[batch_index_vec]
 734                )
 735
 736        d2  = jnp.sum(vec**2, axis=-1)
 737        distances = safe_sqrt(d2)
 738        edge_mask = distances < self.cutoff
 739
 740        switch = SwitchFunction(
 741            **{**self.switch_params, "cutoff": self.cutoff, "graph_key": None}
 742        )((distances, edge_mask))
 743
 744        graph_out = {
 745            **graph,
 746            "vec": vec,
 747            "distances": distances,
 748            "switch": switch,
 749            "edge_mask": edge_mask,
 750        }
 751
 752        if "alch_group" in inputs:
 753            alch_group = inputs["alch_group"]
 754            lambda_e = inputs["alch_elambda"]
 755            mask = alch_group[edge_src] == alch_group[edge_dst]
 756            graph_out["switch_raw"] = switch
 757            graph_out["switch"] = jnp.where(
 758                mask,
 759                switch,
 760                0.5*(1.-jnp.cos(jnp.pi*lambda_e)) * switch ,
 761            )
 762
 763            if "alch_softcore_e" in inputs or "alch_softcore_v" in inputs:
 764                graph_out["distances_raw"] = distances
 765                if "alch_softcore_e" in inputs:
 766                    alch_alpha = (1-inputs["alch_elambda"])*inputs["alch_softcore_e"]**2
 767                else:
 768                    alch_alpha = (1-inputs["alch_vlambda"])*inputs["alch_softcore_v"]**2
 769                distances = jnp.where(
 770                    mask,
 771                    distances,
 772                    safe_sqrt(alch_alpha + d2 * (1. - alch_alpha/self.cutoff**2))
 773                )  
 774                graph_out["distances"] = distances
 775
 776        return {**inputs, self.graph_key: graph_out}
 777
 778
 779@dataclasses.dataclass(frozen=True)
 780class GraphFilter:
 781    """Filter a graph based on a cutoff distance
 782
 783    FPID: GRAPH_FILTER
 784    """
 785
 786    cutoff: float
 787    """Cutoff distance for the filtering."""
 788    parent_graph: str
 789    """Key of the parent graph in the inputs."""
 790    graph_key: str
 791    """Key of the filtered graph in the outputs."""
 792    remove_hydrogens: int = False
 793    """Remove edges where the source is a hydrogen atom."""
 794    switch_params: FrozenDict = dataclasses.field(default_factory=FrozenDict)
 795    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""
 796    k_space: bool = False
 797    """Generate k-space information for the graph."""
 798    kmax: int = 30
 799    """Maximum number of k-points to consider."""
 800    kthr: float = 1e-6
 801    """Threshold for k-point filtering."""
 802    mult_size: float = 1.05
 803    """Multiplicative factor for resizing the nblist."""
 804
 805    FPID: ClassVar[str] = "GRAPH_FILTER"
 806
 807    def init(self):
 808        return FrozenDict(
 809            {
 810                "npairs": 1,
 811                "nblist_mult_size": self.mult_size,
 812            }
 813        )
 814
 815    def get_processor(self) -> Tuple[nn.Module, Dict]:
 816        return GraphFilterProcessor, {
 817            "cutoff": self.cutoff,
 818            "graph_key": self.graph_key,
 819            "parent_graph": self.parent_graph,
 820            "name": f"{self.graph_key}_Filter_{self.parent_graph}",
 821            "switch_params": self.switch_params,
 822        }
 823
 824    def get_graph_properties(self):
 825        return {
 826            self.graph_key: {
 827                "cutoff": self.cutoff,
 828                "directed": True,
 829                "parent_graph": self.parent_graph,
 830            }
 831        }
 832
 833    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
 834        """filter a nblist on cpu with numpy and dynamic shapes + store max shapes"""
 835        graph_in = inputs[self.parent_graph]
 836        nat = inputs["species"].shape[0]
 837
 838        new_state = {**state}
 839        state_up = {}
 840        mult_size = state.get("nblist_mult_size", self.mult_size)
 841        assert mult_size >= 1., "nblist_mult_size should be >= 1."
 842
 843        edge_src = np.array(graph_in["edge_src"], dtype=np.int32)
 844        d12 = np.array(graph_in["d12"], dtype=np.float32)
 845        if self.remove_hydrogens:
 846            species = inputs["species"]
 847            src_idx = (edge_src < nat).nonzero()[0]
 848            mask = np.zeros(edge_src.shape[0], dtype=bool)
 849            mask[src_idx] = (species > 1)[edge_src[src_idx]]
 850            d12 = np.where(mask, d12, self.cutoff**2)
 851        mask = d12 < self.cutoff**2
 852
 853        max_pairs = state.get("npairs", 1)
 854        idx = np.nonzero(mask)[0]
 855        npairs = idx.shape[0]
 856        if npairs > max_pairs or add_margin:
 857            prev_max_pairs = max_pairs
 858            max_pairs = int(mult_size * max(npairs, max_pairs)) + 1
 859            state_up["npairs"] = (max_pairs, prev_max_pairs)
 860            new_state["npairs"] = max_pairs
 861
 862        filter_indices = np.full(max_pairs, edge_src.shape[0], dtype=np.int32)
 863        edge_src = np.full(max_pairs, nat, dtype=np.int32)
 864        edge_dst = np.full(max_pairs, nat, dtype=np.int32)
 865        d12_ = np.full(max_pairs, self.cutoff**2)
 866        filter_indices[:npairs] = idx
 867        edge_src[:npairs] = graph_in["edge_src"][idx]
 868        edge_dst[:npairs] = graph_in["edge_dst"][idx]
 869        d12_[:npairs] = d12[idx]
 870        d12 = d12_
 871
 872        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
 873        graph_out = {
 874            **graph,
 875            "edge_src": edge_src,
 876            "edge_dst": edge_dst,
 877            "filter_indices": filter_indices,
 878            "d12": d12,
 879            "overflow": False,
 880        }
 881
 882        if self.k_space and "cells" in inputs:
 883            if "k_points" not in graph:
 884                ks, _, _, bewald = get_reciprocal_space_parameters(
 885                    inputs["reciprocal_cells"], self.cutoff, self.kmax, self.kthr
 886                )
 887            graph_out["k_points"] = ks
 888            graph_out["b_ewald"] = bewald
 889
 890        output = {**inputs, self.graph_key: graph_out}
 891        if return_state_update:
 892            return FrozenDict(new_state), output, state_up
 893        return FrozenDict(new_state), output
 894
 895    def check_reallocate(self, state, inputs, parent_overflow=False):
 896        """check for overflow and reallocate nblist if necessary"""
 897        overflow = parent_overflow or inputs[self.graph_key].get("overflow", False)
 898        if not overflow:
 899            return state, {}, inputs, False
 900
 901        add_margin = inputs[self.graph_key].get("overflow", False)
 902        state, inputs, state_up = self(
 903            state, inputs, return_state_update=True, add_margin=add_margin
 904        )
 905        return state, state_up, inputs, True
 906
 907    @partial(jax.jit, static_argnums=(0, 1))
 908    def process(self, state, inputs):
 909        """filter a nblist on accelerator with jax and precomputed shapes"""
 910        graph_in = inputs[self.parent_graph]
 911        if state is None:
 912            # skin update mode
 913            graph = inputs[self.graph_key]
 914            max_pairs = graph["edge_src"].shape[0]
 915        else:
 916            max_pairs = state.get("npairs", 1)
 917
 918        max_pairs_in = graph_in["edge_src"].shape[0]
 919        nat = inputs["species"].shape[0]
 920
 921        edge_src = graph_in["edge_src"]
 922        d12 = graph_in["d12"]
 923        if self.remove_hydrogens:
 924            species = inputs["species"]
 925            mask = (species > 1)[edge_src]
 926            d12 = jnp.where(mask, d12, self.cutoff**2)
 927        mask = d12 < self.cutoff**2
 928
 929        (edge_src, edge_dst, d12, filter_indices), _, npairs = mask_filter_1d(
 930            mask,
 931            max_pairs,
 932            (edge_src, nat),
 933            (graph_in["edge_dst"], nat),
 934            (d12, self.cutoff**2),
 935            (jnp.arange(max_pairs_in, dtype=jnp.int32), max_pairs_in),
 936        )
 937
 938        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
 939        overflow = graph.get("overflow", False) | (npairs > max_pairs)
 940        graph_out = {
 941            **graph,
 942            "edge_src": edge_src,
 943            "edge_dst": edge_dst,
 944            "filter_indices": filter_indices,
 945            "d12": d12,
 946            "overflow": overflow,
 947        }
 948
 949        if self.k_space and "cells" in inputs:
 950            if "k_points" not in graph:
 951                raise NotImplementedError(
 952                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
 953                )
 954
 955        return {**inputs, self.graph_key: graph_out}
 956
 957    @partial(jax.jit, static_argnums=(0,))
 958    def update_skin(self, inputs):
 959        return self.process(None, inputs)
 960
 961
 962class GraphFilterProcessor(nn.Module):
 963    """Filter processing for a pre-generated graph
 964
 965    This module is automatically added to a FENNIX model when a GraphFilter is used.
 966    """
 967
 968    cutoff: float
 969    """Cutoff distance for the filtering."""
 970    graph_key: str
 971    """Key of the filtered graph in the inputs."""
 972    parent_graph: str
 973    """Key of the parent graph in the inputs."""
 974    switch_params: dict = dataclasses.field(default_factory=dict)
 975    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""
 976
 977    @nn.compact
 978    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
 979        graph_in = inputs[self.parent_graph]
 980        graph = inputs[self.graph_key]
 981
 982        d_key = "distances_raw" if "distances_raw" in graph else "distances"
 983
 984        if graph_in["vec"].shape[0] == 0:
 985            vec = graph_in["vec"]
 986            distances = graph_in[d_key]
 987            filter_indices = jnp.asarray([], dtype=jnp.int32)
 988        else:
 989            filter_indices = graph["filter_indices"]
 990            vec = (
 991                graph_in["vec"]
 992                .at[filter_indices]
 993                .get(mode="fill", fill_value=self.cutoff)
 994            )
 995            distances = (
 996                graph_in[d_key]
 997                .at[filter_indices]
 998                .get(mode="fill", fill_value=self.cutoff)
 999            )
1000
1001        edge_mask = distances < self.cutoff
1002        switch = SwitchFunction(
1003            **{**self.switch_params, "cutoff": self.cutoff, "graph_key": None}
1004        )((distances, edge_mask))
1005
1006        graph_out = {
1007            **graph,
1008            "vec": vec,
1009            "distances": distances,
1010            "switch": switch,
1011            "filter_indices": filter_indices,
1012            "edge_mask": edge_mask,
1013        }
1014
1015        if "alch_group" in inputs:
1016            edge_src=graph["edge_src"]
1017            edge_dst=graph["edge_dst"]
1018            alch_group = inputs["alch_group"]
1019            lambda_e = inputs["alch_elambda"]
1020            mask = alch_group[edge_src] == alch_group[edge_dst]
1021            graph_out["switch_raw"] = switch
1022            graph_out["switch"] = jnp.where(
1023                mask,
1024                switch,
1025                0.5*(1.-jnp.cos(jnp.pi*lambda_e)) * switch ,
1026            )
1027
1028            if "alch_softcore_e" in inputs or "alch_softcore_v" in inputs:
1029                graph_out["distances_raw"] = distances
1030                if "alch_softcore_e" in inputs:
1031                    alch_alpha = (1-inputs["alch_elambda"])*inputs["alch_softcore_e"]**2
1032                else:
1033                    alch_alpha = (1-inputs["alch_vlambda"])*inputs["alch_softcore_v"]**2
1034                distances = jnp.where(
1035                    mask,
1036                    distances,
1037                    safe_sqrt(alch_alpha + distances**2 * (1. - alch_alpha/self.cutoff**2))
1038                )  
1039                graph_out["distances"] = distances
1040
1041        return {**inputs, self.graph_key: graph_out}
1042
1043
@dataclasses.dataclass(frozen=True)
class GraphAngularExtension:
    """Add angles list to a graph

    FPID: GRAPH_ANGULAR_EXTENSION

    An angle is a pair of edges sharing the same source atom. The graph gets
    `angle_src` / `angle_dst` (indices into the edge list), `central_atom`
    (source atom of each angle) and an `angle_overflow` flag.
    """

    mult_size: float = 1.05
    """Multiplicative factor for resizing the nblist."""
    add_neigh: int = 5
    """Additional neighbors to add to the nblist when resizing."""
    graph_key: str = "graph"
    """Key of the graph in the inputs."""

    FPID: ClassVar[str] = "GRAPH_ANGULAR_EXTENSION"

    def init(self):
        """Return the initial preprocessing state (angle and neighbor capacities)."""
        return FrozenDict(
            {
                "nangles": 0,
                "nblist_mult_size": self.mult_size,
                "max_neigh": self.add_neigh,
                "add_neigh": self.add_neigh,
            }
        )

    def get_processor(self) -> Tuple[nn.Module, Dict]:
        """Return the processor module class and the keyword arguments to build it."""
        return GraphAngleProcessor, {
            "graph_key": self.graph_key,
            "name": f"{self.graph_key}_AngleProcessor",
        }

    def get_graph_properties(self):
        """Return the static properties added to the graph, keyed by `graph_key`."""
        return {
            self.graph_key: {
                "has_angles": True,
            }
        }

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """build angle nblist on cpu with numpy and dynamic shapes + store max shapes

        Args:
            state: mapping holding `nangles`, `max_neigh` and optionally
                `add_neigh` / `nblist_mult_size`.
            inputs: mapping containing `species` and the graph (with `edge_src`).
            return_state_update: if True, also return the dict of changed
                state entries as `(new_value, previous_value)` tuples.
            add_margin: if True, grow the capacities even when the current
                ones are sufficient.

        Returns:
            `(new_state, output)` or `(new_state, output, state_up)`.
        """
        graph = inputs[self.graph_key]
        edge_src = np.array(graph["edge_src"], dtype=np.int32)
        nedge = edge_src.shape[0]

        new_state = {**state}
        state_up = {}
        mult_size = state.get("nblist_mult_size", self.mult_size)
        assert mult_size >= 1., "nblist_mult_size should be >= 1."

        ### count number of neighbors
        nat = inputs["species"].shape[0]
        # extra slot `nat` collects the padding edges and is ignored below
        count = np.zeros(nat + 1, dtype=np.int32)
        np.add.at(count, edge_src, 1)
        max_count = int(np.max(count[:-1]))

        ### get sizes
        max_neigh = state.get("max_neigh", self.add_neigh)
        if max_count > max_neigh or add_margin:
            prev_max_neigh = max_neigh
            max_neigh = max(max_count, max_neigh) + state.get(
                "add_neigh", self.add_neigh
            )
            state_up["max_neigh"] = (max_neigh, prev_max_neigh)
            new_state["max_neigh"] = max_neigh

        # dummy array: only its length is read back (in `process`)
        max_neigh_arr = np.empty(max_neigh, dtype=bool)

        ### sort edge_src
        idx_sort = np.argsort(edge_src)
        edge_src_sorted = edge_src[idx_sort]

        ### map sparse to dense nblist
        # Column of each sorted edge inside its atom's row: the position in
        # the sorted list modulo max_count. Edges of one atom are contiguous
        # and at most max_count long, so their columns are all distinct.
        tiled = np.tile(np.arange(max_count), nat)
        if max_count * nat >= nedge:
            offset = tiled[:nedge]
        else:
            offset = np.zeros(nedge, dtype=np.int32)
            offset[: max_count * nat] = tiled

        mask = edge_src_sorted < nat
        indices = edge_src_sorted * max_count + offset
        indices = indices[mask]
        idx_sort = idx_sort[mask]
        edge_idx = np.full(nat * max_count, nedge, dtype=np.int32)
        edge_idx[indices] = idx_sort
        edge_idx = edge_idx.reshape(nat, max_count)

        ### find all triplet for each atom center
        local_src, local_dst = np.triu_indices(max_count, 1)
        angle_src = edge_idx[:, local_src].flatten()
        angle_dst = edge_idx[:, local_dst].flatten()

        ### mask for valid angles
        mask1 = angle_src < nedge
        mask2 = angle_dst < nedge
        angle_mask = mask1 & mask2

        max_angles = state.get("nangles", 0)
        idx = np.nonzero(angle_mask)[0]
        nangles = idx.shape[0]
        if nangles > max_angles or add_margin:
            max_angles_prev = max_angles
            max_angles = int(mult_size * max(nangles, max_angles)) + 1
            state_up["nangles"] = (max_angles, max_angles_prev)
            new_state["nangles"] = max_angles

        ## filter angles to sparse representation
        angle_src_ = np.full(max_angles, nedge, dtype=np.int32)
        angle_dst_ = np.full(max_angles, nedge, dtype=np.int32)
        angle_src_[:nangles] = angle_src[idx]
        angle_dst_[:nangles] = angle_dst[idx]

        central_atom = np.full(max_angles, nat, dtype=np.int32)
        central_atom[:nangles] = edge_src[angle_src_[:nangles]]

        ## update graph
        output = {
            **inputs,
            self.graph_key: {
                **graph,
                "angle_src": angle_src_,
                "angle_dst": angle_dst_,
                "central_atom": central_atom,
                "angle_overflow": False,
                "max_neigh": max_neigh,
                "__max_neigh_array": max_neigh_arr,
            },
        }

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary

        Returns `(state, state_up, inputs, reallocated)`; the numpy routine is
        rerun when either the parent or the angle list reported an overflow.
        """
        overflow = parent_overflow or inputs[self.graph_key]["angle_overflow"]
        if not overflow:
            return state, {}, inputs, False

        # only grow the capacities when the angle list itself overflowed
        add_margin = inputs[self.graph_key]["angle_overflow"]
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """build angle nblist on accelerator with jax and precomputed shapes

        When `state` is None (skin update mode) the sizes are read from the
        graph (`__max_neigh_array` and the existing `angle_src`); otherwise
        they come from `state`.
        """
        graph = inputs[self.graph_key]
        edge_src = graph["edge_src"]

        ### count number of neighbors
        nat = inputs["species"].shape[0]
        count = jnp.zeros(nat, dtype=jnp.int32).at[edge_src].add(1, mode="drop")
        max_count = jnp.max(count)

        ### get sizes
        if state is None:
            max_neigh_arr = graph["__max_neigh_array"]
            max_neigh = max_neigh_arr.shape[0]
            prev_nangles = graph["angle_src"].shape[0]
        else:
            max_neigh = state.get("max_neigh", self.add_neigh)
            max_neigh_arr = jnp.empty(max_neigh, dtype=bool)
            prev_nangles = state.get("nangles", 0)

        nedge = edge_src.shape[0]

        ### sort edge_src
        idx_sort = jnp.argsort(edge_src).astype(jnp.int32)
        edge_src_sorted = edge_src[idx_sort]

        ### map sparse to dense nblist
        if max_neigh * nat < nedge:
            raise ValueError("Found max_neigh*nat < nedge. This should not happen.")
        offset = jnp.asarray(
            np.tile(np.arange(max_neigh), nat)[:nedge], dtype=jnp.int32
        )
        indices = edge_src_sorted * max_neigh + offset
        # padding edges (edge_src == nat) map outside the table and are dropped
        edge_idx = (
            jnp.full(nat * max_neigh, nedge, dtype=jnp.int32)
            .at[indices]
            .set(idx_sort, mode="drop")
            .reshape(nat, max_neigh)
        )

        ### find all triplet for each atom center
        local_src, local_dst = np.triu_indices(max_neigh, 1)
        angle_src = edge_idx[:, local_src].flatten()
        angle_dst = edge_idx[:, local_dst].flatten()

        ### mask for valid angles
        mask1 = angle_src < nedge
        mask2 = angle_dst < nedge
        angle_mask = mask1 & mask2

        ## filter angles to sparse representation
        (angle_src, angle_dst), _, nangles = mask_filter_1d(
            angle_mask,
            prev_nangles,
            (angle_src, nedge),
            (angle_dst, nedge),
        )
        ## find central atom
        # Padding angle slots hold the out-of-range edge index `nedge`. Fill
        # them with `nat`, like the numpy routine does, instead of relying on
        # index clamping (which would return the source of the last edge).
        central_atom = edge_src.at[angle_src].get(mode="fill", fill_value=nat)

        ## check for overflow
        angle_overflow = nangles > prev_nangles
        neigh_overflow = max_count > max_neigh
        overflow = graph.get("angle_overflow", False) | angle_overflow | neigh_overflow

        ## update graph
        output = {
            **inputs,
            self.graph_key: {
                **graph,
                "angle_src": angle_src,
                "angle_dst": angle_dst,
                "central_atom": central_atom,
                "angle_overflow": overflow,
                "__max_neigh_array": max_neigh_arr,
            },
        }

        return output

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        """Rebuild the angle list using the shapes stored in the existing graph."""
        return self.process(None, inputs)
1279
1280
class GraphAngleProcessor(nn.Module):
    """Process a pre-generated graph to compute angles

    Reads `vec`, the edge distances (`distances_raw` when present, otherwise
    `distances`), `angle_src` and `angle_dst` from the graph and adds the
    `angles` entry.

    This module is automatically added to a FENNIX model when a GraphAngularExtension is used.

    """

    graph_key: str
    """Key of the graph in the inputs."""

    @nn.compact
    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
        graph = inputs[self.graph_key]
        if "distances_raw" in graph:
            edge_lengths = graph["distances_raw"]
        else:
            edge_lengths = graph["distances"]

        # unit vector of each edge; the clip guards against division by ~0
        unit_vec = graph["vec"] / jnp.clip(edge_lengths[:, None], min=1.0e-5)

        # out-of-range (padding) edge indices are filled with 0.5 components
        u_src = unit_vec.at[graph["angle_src"]].get(mode="fill", fill_value=0.5)
        u_dst = unit_vec.at[graph["angle_dst"]].get(mode="fill", fill_value=0.5)
        cos_angles = (u_src * u_dst).sum(axis=-1)

        # the 0.95 factor keeps the arccos argument strictly inside (-1, 1)
        angles = jnp.arccos(0.95 * cos_angles)

        graph_out = {**graph, "angles": angles}
        return {**inputs, self.graph_key: graph_out}
1316
1317
@dataclasses.dataclass(frozen=True)
class SpeciesIndexer:
    """Build an index that splits atomic arrays by species.

    FPID: SPECIES_INDEXER

    If `species_order` is specified, the output will be a dense array with size (len(species_order), max_size) that can directly index atomic arrays.
    If `species_order` is None, the output will be a dictionary with species as keys and an index to filter atomic arrays for that species as values.

    Unused slots of the index are filled with the number of atoms (an
    out-of-range index). Species with a value <= 0 are not indexed.

    """

    output_key: str = "species_index"
    """Key for the output dictionary."""
    species_order: Optional[str] = None
    """Comma separated list of species in the order they should be indexed."""
    add_atoms: int = 0
    """Additional atoms to add to the sizes."""
    add_atoms_margin: int = 10
    """Additional atoms to add to the sizes when adding margin."""

    FPID: ClassVar[str] = "SPECIES_INDEXER"

    def init(self):
        """Return the initial preprocessing state (no species size known yet)."""
        return FrozenDict(
            {
                "sizes": {},
            }
        )

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """Build the species index on cpu with numpy and update the stored sizes.

        Args:
            state: mapping holding `sizes` (per-species capacity) and
                optionally `max_size`, `add_atoms`, `add_atoms_margin`.
            inputs: mapping containing `species`.
            return_state_update: if True, also return the dict of changed
                state entries as `(new_value, previous_value)` tuples.
            add_margin: if True, add `add_atoms_margin` extra slots to every
                species whose capacity has to grow.

        Returns:
            `(new_state, output)` or `(new_state, output, state_up)`.
        """
        species = np.array(inputs["species"], dtype=np.int32)
        nat = species.shape[0]
        set_species, counts = np.unique(species, return_counts=True)

        new_state = {**state}
        state_up = {}

        sizes = state.get("sizes", FrozenDict({}))
        new_sizes = {**sizes}
        up_sizes = False
        counts_dict = {}

        # extra slots allocated on top of the count when a capacity grows
        add_atoms = state.get("add_atoms", self.add_atoms)
        if add_margin:
            add_atoms += state.get("add_atoms_margin", self.add_atoms_margin)

        for s, c in zip(set_species, counts):
            if s <= 0:
                continue
            counts_dict[s] = c
            if c > sizes.get(s, 0):
                up_sizes = True
                new_sizes[s] = c + add_atoms

        new_sizes = FrozenDict(new_sizes)
        if up_sizes:
            state_up["sizes"] = (new_sizes, sizes)
            new_state["sizes"] = new_sizes

        if self.species_order is not None:
            species_order = [el.strip() for el in self.species_order.split(",")]
            max_size_prev = state.get("max_size", 0)
            # default=0 avoids a ValueError when no species has been registered
            max_size = max(new_sizes.values(), default=0)
            if max_size > max_size_prev:
                state_up["max_size"] = (max_size, max_size_prev)
                new_state["max_size"] = max_size

            species_index = np.full((len(species_order), max_size), nat, dtype=np.int32)
            for i, el in enumerate(species_order):
                s = PERIODIC_TABLE_REV_IDX[el]
                if s in counts_dict:
                    species_index[i, : counts_dict[s]] = np.nonzero(species == s)[0]
        else:
            species_index = {
                PERIODIC_TABLE[s]: np.full(c, nat, dtype=np.int32)
                for s, c in new_sizes.items()
            }
            for s, c in zip(set_species, counts):
                if s <= 0:
                    continue
                species_index[PERIODIC_TABLE[s]][:c] = np.nonzero(species == s)[0]

        output = {
            **inputs,
            self.output_key: species_index,
            self.output_key + "_overflow": False,
        }

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary

        Returns `(state, state_up, inputs, reallocated)`; the numpy routine is
        rerun when either the parent or this index reported an overflow.
        """
        overflow = parent_overflow or inputs[self.output_key + "_overflow"]
        if not overflow:
            return state, {}, inputs, False

        # only add the margin when this index itself overflowed
        add_margin = inputs[self.output_key + "_overflow"]
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """Build the species index on accelerator with jax and precomputed shapes.

        The inputs are returned unchanged when they already contain the index
        and the `recompute_species_index` flag is not set.

        Raises:
            ValueError: if the index has to be computed and `state` is None.
        """
        recompute_species_index = "recompute_species_index" in inputs.get("flags", {})
        if self.output_key in inputs and not recompute_species_index:
            return inputs

        if state is None:
            raise ValueError("Species Indexer state must be provided on accelerator.")

        species = inputs["species"]
        nat = species.shape[0]

        sizes = state["sizes"]
        # Bound for both branches: it used to be defined only in the
        # dictionary branch, so the dense branch failed at the return below.
        overflow = False

        if self.species_order is not None:
            species_order = [el.strip() for el in self.species_order.split(",")]
            max_size = state["max_size"]

            species_index = jnp.full(
                (len(species_order), max_size), nat, dtype=jnp.int32
            )
            for i, el in enumerate(species_order):
                s = PERIODIC_TABLE_REV_IDX[el]
                mask = species == s
                new_size = jnp.sum(mask)
                if s in sizes:
                    # more atoms than columns: the row below is truncated
                    overflow = overflow | (new_size > max_size)
                    species_index = species_index.at[i, :].set(
                        jnp.nonzero(mask, size=max_size, fill_value=nat)[0]
                    )
                else:
                    # requested species with no registered size now present
                    overflow = overflow | (new_size > 0)
        else:
            species_index = {}
            natcount = 0
            for s, c in sizes.items():
                mask = species == s
                new_size = jnp.sum(mask)
                natcount = natcount + new_size
                overflow = overflow | (new_size > c)  # check if sizes are correct
                species_index[PERIODIC_TABLE[s]] = jnp.nonzero(
                    mask, size=c, fill_value=nat
                )[0]

            # species <= 0 are never indexed but still count as accounted for
            natcount = natcount + jnp.sum(species <= 0)
            overflow = overflow | (
                natcount < species.shape[0]
            )  # check if any species missing

        return {
            **inputs,
            self.output_key: species_index,
            self.output_key + "_overflow": overflow,
        }

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        """Run `process` without a state (only valid when the index is already in the inputs)."""
        return self.process(None, inputs)
1489
@dataclasses.dataclass(frozen=True)
class BlockIndexer:
    """Build an index that splits atomic arrays by chemical blocks.

    FPID: BLOCK_INDEXER

    The output is a dictionary keyed by chemical block name. Each value is
    either `None` (no size allocated for that block) or an integer index array,
    padded with the number of atoms, selecting the atoms of that block.
    """

    output_key: str = "block_index"
    """Key for the output dictionary."""
    add_atoms: int = 0
    """Additional atoms to add to the sizes."""
    add_atoms_margin: int = 10
    """Additional atoms to add to the sizes when adding margin."""
    split_CNOPSSe: bool = False
    """Give C, N, O, P, S and Se dedicated blocks (see `build_chemical_blocks`)."""

    FPID: ClassVar[str] = "BLOCK_INDEXER"

    def init(self):
        """Return the initial state: no block size allocated yet."""
        return FrozenDict({"sizes": {}})

    def build_chemical_blocks(self):
        """Return `(block_names, block_table)` used to map species to blocks.

        Both are copies of the module-level `CHEMICAL_BLOCKS_NAMES` and
        `CHEMICAL_BLOCKS`. With `split_CNOPSSe`, block 1 is renamed "C" and
        entries 7, 8, 15, 16 and 34 of the table are redirected to new blocks
        named "N", "O", "P", "S" and "Se" appended after the default ones.
        """
        block_names = CHEMICAL_BLOCKS_NAMES.copy()
        block_table = CHEMICAL_BLOCKS.copy()
        if self.split_CNOPSSe:
            first_new_block = len(CHEMICAL_BLOCKS_NAMES)
            block_names[1] = "C"
            block_names.extend(["N", "O", "P", "S", "Se"])
            block_table[6] = 1
            for offset, entry in enumerate((7, 8, 15, 16, 34)):
                block_table[entry] = first_new_block + offset
        return block_names, block_table

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """Build the block index with numpy and grow the stored sizes if needed.

        Args:
            state: mapping holding `"sizes"`, keyed by `(block id, block name)`.
            inputs: dictionary containing at least `"species"`.
            return_state_update: also return the `{key: (new, previous)}` changes.
            add_margin: add `add_atoms_margin` on top of `add_atoms` when a
                size has to grow.

        Returns:
            `(new_state, output)` or `(new_state, output, state_up)`.
        """
        block_names, block_table = self.build_chemical_blocks()

        species = np.array(inputs["species"], dtype=np.int32)
        blocks = block_table[species]
        nat = species.shape[0]
        unique_blocks, unique_counts = np.unique(blocks, return_counts=True)
        # negative block ids are ignored everywhere below
        present = [
            (block_id, count)
            for block_id, count in zip(unique_blocks, unique_counts)
            if block_id >= 0
        ]

        sizes = state.get("sizes", FrozenDict({}))
        grown_sizes = {**sizes}
        resized = False
        for block_id, count in present:
            key = (block_id, block_names[block_id])
            if count <= sizes.get(key, 0):
                continue
            resized = True
            extra = state.get("add_atoms", self.add_atoms)
            if add_margin:
                extra += state.get("add_atoms_margin", self.add_atoms_margin)
            grown_sizes[key] = count + extra
        grown_sizes = FrozenDict(grown_sizes)

        new_state = {**state}
        state_up = {}
        if resized:
            state_up["sizes"] = (grown_sizes, sizes)
            new_state["sizes"] = grown_sizes

        # every known block gets an entry; unallocated blocks stay None
        block_index = dict.fromkeys(block_names)
        for (_, name), size in grown_sizes.items():
            block_index[name] = np.full(size, nat, dtype=np.int32)
        for block_id, count in present:
            block_index[block_names[block_id]][:count] = np.nonzero(
                blocks == block_id
            )[0]

        output = {**inputs}
        output[self.output_key] = block_index
        output[self.output_key + "_overflow"] = False

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """Rebuild the index (growing the sizes) if an overflow was flagged.

        Returns `(state, state_update, inputs, overflowed)`. The margin is only
        added when this layer itself overflowed, not when only a parent did.
        """
        own_overflow = inputs[self.output_key + "_overflow"]
        if not (parent_overflow or own_overflow):
            return state, {}, inputs, False

        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=own_overflow
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """Jitted version of the index construction using the sizes in `state`.

        Returns `inputs` untouched when the index is already present and no
        `recompute_species_index` flag is set. Otherwise returns a copy of
        `inputs` with the index and an `_overflow` flag that is True when a
        block exceeds its allocated size or when some atoms are not covered by
        any allocated block.

        Raises:
            ValueError: if the index must be computed and `state` is None.
        """
        block_names, block_table = self.build_chemical_blocks()

        flags = inputs.get("flags", {})
        if self.output_key in inputs and "recompute_species_index" not in flags:
            return inputs

        if state is None:
            raise ValueError("Block Indexer state must be provided on accelerator.")

        species = inputs["species"]
        blocks = jnp.asarray(block_table)[species]
        nat = species.shape[0]

        block_index = dict.fromkeys(block_names)
        overflow = False
        n_accounted = 0
        for (block_id, name), capacity in state["sizes"].items():
            in_block = blocks == block_id
            count = jnp.sum(in_block)
            n_accounted = n_accounted + count
            overflow = overflow | (count > capacity)  # allocated size too small
            block_index[name] = jnp.nonzero(
                in_block, size=capacity, fill_value=nat
            )[0]

        # atoms with a negative block id count as accounted for
        n_accounted = n_accounted + jnp.sum(blocks < 0)
        overflow = overflow | (n_accounted < nat)  # some block has no allocation

        return {
            **inputs,
            self.output_key: block_index,
            self.output_key + "_overflow": overflow,
        }

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        """Skin-update hook: delegate to `process` without any state."""
        stateless = None
        return self.process(stateless, inputs)
1650
1651
@dataclasses.dataclass(frozen=True)
class AtomPadding:
    """Pad atomic arrays to a fixed size.

    When padding is needed, padding atoms get species -1 and are assigned to
    the last of the extra (empty) systems appended after the real ones. The
    boolean masks `true_atoms` and `true_sys` are always added to the output so
    that the padding can be removed afterwards (see `atom_unpadding`).
    """

    mult_size: float = 1.2
    """Multiplicative factor for resizing the atomic arrays."""
    add_sys: int = 0
    """Extra system slots reserved whenever the number of systems grows."""

    def init(self):
        """Return the initial state: no atoms and no systems seen yet."""
        return {"prev_nat": 0, "prev_nsys": 0}

    def __call__(self, state, inputs: Dict) -> Tuple[Any, Dict]:
        """Pad per-atom and per-system arrays of `inputs`.

        Arrays are classified by the length of their first axis: `nat` for
        per-atom arrays (checked first), `nsys` for per-system arrays. Scalar
        (0-d) arrays have no such axis and are passed through untouched.

        Args:
            state: mapping with the previous padded sizes (`prev_nat`,
                `prev_nsys`).
            inputs: dictionary with at least `species`, `natoms` and
                `batch_index`.

        Returns:
            `(new_state, output)` where `new_state` is a FrozenDict holding the
            updated sizes and `output` is the padded copy of `inputs`.
        """
        species = inputs["species"]
        nat = species.shape[0]

        prev_nat = state.get("prev_nat", 0)
        prev_nat_ = prev_nat
        if nat > prev_nat_:
            # grow with some slack so that slightly larger inputs reuse the size
            prev_nat_ = int(self.mult_size * nat) + 1

        nsys = len(inputs["natoms"])
        prev_nsys = state.get("prev_nsys", 0)
        prev_nsys_ = prev_nsys
        if nsys > prev_nsys_:
            prev_nsys_ = nsys + self.add_sys

        add_atoms = prev_nat_ - nat
        # +1: one additional system to host the padding atoms
        add_sys = prev_nsys_ - nsys + 1
        output = {**inputs}
        if add_atoms > 0:
            for k, v in inputs.items():
                if not isinstance(v, (np.ndarray, jax.Array)):
                    continue
                if v.ndim == 0:
                    # scalar arrays have no atom/system axis (and no shape[0])
                    continue
                if v.shape[0] == nat:
                    output[k] = np.append(
                        v,
                        np.zeros((add_atoms, *v.shape[1:]), dtype=v.dtype),
                        axis=0,
                    )
                elif v.shape[0] == nsys:
                    if k == "cells":
                        # padding systems get a large diagonal cell, not zeros
                        output[k] = np.append(
                            v,
                            1000
                            * np.eye(3, dtype=v.dtype)[None, :, :].repeat(
                                add_sys, axis=0
                            ),
                            axis=0,
                        )
                    else:
                        output[k] = np.append(
                            v,
                            np.zeros((add_sys, *v.shape[1:]), dtype=v.dtype),
                            axis=0,
                        )
            # dedicated values override the generic zero padding done above
            output["natoms"] = np.append(
                inputs["natoms"], np.zeros(add_sys, dtype=np.int32)
            )
            output["species"] = np.append(
                species, -1 * np.ones(add_atoms, dtype=species.dtype)
            )
            output["batch_index"] = np.append(
                inputs["batch_index"],
                np.array([output["natoms"].shape[0] - 1] * add_atoms, dtype=inputs["batch_index"].dtype),
            )
            if "system_index" in inputs:
                output["system_index"] = np.append(
                    inputs["system_index"],
                    np.array([output["natoms"].shape[0] - 1] * add_sys, dtype=inputs["system_index"].dtype),
                )

        output["true_atoms"] = output["species"] > 0
        output["true_sys"] = np.arange(len(output["natoms"])) < nsys

        state = {**state, "prev_nat": prev_nat_, "prev_nsys": prev_nsys_}

        return FrozenDict(state), output
1728
1729
def atom_unpadding(inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Remove padding atoms and systems from the arrays of `inputs`.

    `inputs` is returned as-is when it has no `true_atoms` key or when the
    index of the first non-positive species is 0. Otherwise a new dictionary is
    returned in which arrays whose first axis matches the padded number of
    atoms (resp. systems) are filtered with `true_atoms` (resp. `true_sys`),
    and the two mask entries are removed.
    """
    if "true_atoms" not in inputs:
        return inputs

    species = inputs["species"]
    atom_mask = inputs["true_atoms"]
    sys_mask = inputs["true_sys"]

    # argmax of the boolean array: index of the first padding atom,
    # which is also 0 when no species is <= 0
    first_padding = np.argmax(species <= 0)
    if first_padding == 0:
        return inputs

    n_padded_atoms = species.shape[0]
    n_padded_sys = len(inputs["natoms"])

    def _strip(value):
        """Filter one entry; non-arrays and 0-d arrays pass through."""
        if not isinstance(value, (jax.Array, np.ndarray)):
            return value
        if value.ndim == 0:
            return value
        if value.shape[0] == n_padded_atoms:
            return value[atom_mask]
        if value.shape[0] == n_padded_sys:
            return value[sys_mask]
        return value

    return {
        key: _strip(value)
        for key, value in inputs.items()
        if key not in ("true_sys", "true_atoms")
    }
1758
1759
def check_input(inputs):
    """Validate `inputs` and fill in the derived batching entries.

    Requires `species` and `coordinates`. Returns a copy of `inputs` in which
    `natoms` (default: a single system holding all atoms) and `batch_index`
    (default: derived from `natoms`) are int32 arrays. When `cells` is given,
    `reciprocal_cells` is computed by matrix inversion unless provided, and
    both are promoted to a leading batch axis if they are 2-dimensional.
    """
    assert "species" in inputs, "species must be provided"
    assert "coordinates" in inputs, "coordinates must be provided"

    raw_species = inputs["species"]
    species = raw_species.astype(np.int32)
    first_fake = np.argmax(species <= 0)
    if first_fake > 0:
        # NOTE(review): by construction every entry before the first
        # non-positive species is positive, so this cannot fail - confirm intent
        assert np.all(species[:first_fake] > 0), "species must be positive"
    nat = raw_species.shape[0]

    # defaults are built eagerly, exactly like arguments of dict.get
    default_natoms = np.array([nat], dtype=np.int32)
    natoms = inputs.get("natoms", default_natoms).astype(np.int32)
    default_batch_index = np.repeat(np.arange(len(natoms), dtype=np.int32), natoms)
    batch_index = inputs.get("batch_index", default_batch_index).astype(np.int32)

    output = {**inputs}
    output["natoms"] = natoms
    output["batch_index"] = batch_index

    if "cells" not in inputs:
        return output

    cells = inputs["cells"]
    if "reciprocal_cells" in inputs:
        reciprocal_cells = inputs["reciprocal_cells"]
    else:
        reciprocal_cells = np.linalg.inv(cells)
    # a single (3, 3)-like matrix becomes a batch of one
    if cells.ndim == 2:
        cells = cells[None, :, :]
    if reciprocal_cells.ndim == 2:
        reciprocal_cells = reciprocal_cells[None, :, :]
    output["cells"] = cells
    output["reciprocal_cells"] = reciprocal_cells

    return output
1789
1790
def convert_to_jax(data):
    """Return a copy of the pytree `data` with numpy arrays turned into jax arrays.

    Leaves that are not `np.ndarray` instances are left untouched; no dtype is
    forced during the conversion.
    """

    def _to_jax(leaf):
        return jnp.asarray(leaf) if isinstance(leaf, np.ndarray) else leaf

    return jax.tree_util.tree_map(_to_jax, data)
1802
1803
class JaxConverter(nn.Module):
    """Flax module wrapper around `convert_to_jax`."""

    def __call__(self, data):
        """Return `data` with its numpy array leaves converted to jax arrays."""
        converted = convert_to_jax(data)
        return converted
1809
1810
@dataclasses.dataclass(frozen=True)
class PreprocessingChain:
    """Chain of preprocessing layers.

    The chain state stores one state per layer under `"layers_state"`. When
    `use_atom_padding` is set, the padder state occupies slot 0 and the states
    of `layers` follow.
    """

    layers: Tuple[Callable[..., Dict[str, Any]]]
    """Preprocessing layers."""
    use_atom_padding: bool = False
    """Add an AtomPadding layer at the beginning of the chain."""
    atom_padder: AtomPadding = AtomPadding()
    """AtomPadding layer."""

    def __post_init__(self):
        """Validate that `layers` is a non-empty sequence."""
        if not isinstance(self.layers, Sequence):
            kind = type(self.layers).__name__
            raise ValueError(f"'layers' must be a sequence, got '{kind}'.")
        if not self.layers:
            raise ValueError("Error: no Preprocessing layers were provided.")

    def __call__(self, state, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Run the numpy preprocessing of every layer.

        Inputs are first validated with `check_input` unless the state entry
        `check_input` is False. Returns `(new_state, inputs)` where the inputs
        have been converted to jax arrays.
        """
        if state.get("check_input", True):
            inputs = check_input(inputs)

        previous = state["layers_state"]
        updated = []
        offset = 0
        if self.use_atom_padding:
            padder_state, inputs = self.atom_padder(previous[0], inputs)
            updated.append(padder_state)
            offset = 1
        for idx, layer in enumerate(self.layers, start=offset):
            layer_state, inputs = layer(previous[idx], inputs, return_state_update=False)
            updated.append(layer_state)

        new_state = FrozenDict({**state, "layers_state": tuple(updated)})
        return new_state, convert_to_jax(inputs)

    def check_reallocate(self, state, inputs):
        """Let each layer reallocate after an overflow, propagating the flag.

        Returns `(state, state_updates, inputs, overflowed)`. Without overflow
        the original state and an empty dict are returned; otherwise
        `state_updates` is the list of per-layer updates (padder excluded).
        """
        previous = state["layers_state"]
        offset = 1 if self.use_atom_padding else 0
        # the padder state is carried over unchanged
        updated = [previous[0]] if self.use_atom_padding else []
        updates = []
        overflowed = False
        for idx, layer in enumerate(self.layers, start=offset):
            layer_state, layer_update, inputs, overflowed = layer.check_reallocate(
                previous[idx], inputs, overflowed
            )
            updated.append(layer_state)
            updates.append(layer_update)

        if not overflowed:
            return state, {}, inputs, False
        new_state = FrozenDict({**state, "layers_state": tuple(updated)})
        return new_state, updates, inputs, True

    def atom_padding(self, state, inputs):
        """Apply only the atom padder, if enabled.

        NOTE(review): with padding enabled the first returned element is the
        padder's own state, otherwise it is the chain state - confirm callers
        expect this.
        """
        if not self.use_atom_padding:
            return state, inputs
        return self.atom_padder(state["layers_state"][0], inputs)

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """Run the jitted `process` step of every layer (padder slot skipped)."""
        layer_states = state["layers_state"]
        offset = 1 if self.use_atom_padding else 0
        for idx, layer in enumerate(self.layers, start=offset):
            inputs = layer.process(layer_states[idx], inputs)
        return inputs

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        """Run the `update_skin` step of every layer in order."""
        for layer in self.layers:
            inputs = layer.update_skin(inputs)
        return inputs

    def init(self):
        """Return the initial chain state built from each layer's `init`."""
        layer_states = [self.atom_padder.init()] if self.use_atom_padding else []
        layer_states.extend(layer.init() for layer in self.layers)
        return FrozenDict({"check_input": True, "layers_state": layer_states})

    def init_with_output(self, inputs):
        """Initialize the state and immediately preprocess `inputs` with it."""
        return self(self.init(), inputs)

    def get_processors(self):
        """Collect `get_processor()` of the layers that define it."""
        return [
            layer.get_processor()
            for layer in self.layers
            if hasattr(layer, "get_processor")
        ]

    def get_graphs_properties(self):
        """Merge `get_graph_properties()` of the layers that define it."""
        merged = {}
        for layer in self.layers:
            if not hasattr(layer, "get_graph_properties"):
                continue
            merged = deep_update(merged, layer.get_graph_properties())
        return merged
1921
1922
1923# PREPROCESSING = {
1924#     "GRAPH": GraphGenerator,
1925#     # "GRAPH_FIXED": GraphGeneratorFixed,
1926#     "GRAPH_FILTER": GraphFilter,
1927#     "GRAPH_ANGULAR_EXTENSION": GraphAngularExtension,
1928#     # "GRAPH_DENSE_EXTENSION": GraphDenseExtension,
1929#     "SPECIES_INDEXER": SpeciesIndexer,
1930# }
@dataclasses.dataclass(frozen=True)
class GraphGenerator:
 22@dataclasses.dataclass(frozen=True)
 23class GraphGenerator:
 24    """Generate a graph from a set of coordinates
 25
 26    FPID: GRAPH
 27
 28    For now, we generate all pairs of atoms and filter based on cutoff.
 29    If a `nblist_skin` is present in the state, we generate a second graph with a larger cutoff that includes all pairs within the cutoff+skin. This graph is then reused by the `update_skin` method to update the original graph without recomputing the full nblist.
 30    """
 31
 32    cutoff: float
 33    """Cutoff distance for the graph."""
 34    graph_key: str = "graph"
 35    """Key of the graph in the outputs."""
 36    switch_params: dict = dataclasses.field(default_factory=dict, hash=False)
 37    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""
 38    kmax: int = 30
 39    """Maximum number of k-points to consider."""
 40    kthr: float = 1e-6
 41    """Threshold for k-point filtering."""
 42    k_space: bool = False
 43    """Whether to generate k-space information for the graph."""
 44    mult_size: float = 1.05
 45    """Multiplicative factor for resizing the nblist."""
 46    # covalent_cutoff: bool = False
 47
 48    FPID: ClassVar[str] = "GRAPH"
 49
 50    def init(self):
 51        return FrozenDict(
 52            {
 53                "max_nat": 1,
 54                "npairs": 1,
 55                "nblist_mult_size": self.mult_size,
 56            }
 57        )
 58
 59    def get_processor(self) -> Tuple[nn.Module, Dict]:
 60        return GraphProcessor, {
 61            "cutoff": self.cutoff,
 62            "graph_key": self.graph_key,
 63            "switch_params": self.switch_params,
 64            "name": f"{self.graph_key}_Processor",
 65        }
 66
 67    def get_graph_properties(self):
 68        return {
 69            self.graph_key: {
 70                "cutoff": self.cutoff,
 71                "directed": True,
 72            }
 73        }
 74
 75    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
 76        """build a nblist on cpu with numpy and dynamic shapes + store max shapes"""
 77        if self.graph_key in inputs:
 78            graph = inputs[self.graph_key]
 79            if "keep_graph" in graph:
 80                return state, inputs
 81
 82        coords = np.array(inputs["coordinates"], dtype=np.float32)
 83        natoms = np.array(inputs["natoms"], dtype=np.int32)
 84        batch_index = np.array(inputs["batch_index"], dtype=np.int32)
 85
 86        new_state = {**state}
 87        state_up = {}
 88
 89        mult_size = state.get("nblist_mult_size", self.mult_size)
 90        assert mult_size >= 1.0, "mult_size should be larger or equal than 1.0"
 91
 92        if natoms.shape[0] == 1:
 93            max_nat = coords.shape[0]
 94            true_max_nat = max_nat
 95        else:
 96            max_nat = state.get("max_nat", round(coords.shape[0] / natoms.shape[0]))
 97            true_max_nat = int(np.max(natoms))
 98            if true_max_nat > max_nat:
 99                add_atoms = state.get("add_atoms", 0)
100                new_maxnat = true_max_nat + add_atoms
101                state_up["max_nat"] = (new_maxnat, max_nat)
102                new_state["max_nat"] = new_maxnat
103
104        cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0)
105
106        ### compute indices of all pairs
107        p1, p2 = np.triu_indices(true_max_nat, 1)
108        p1, p2 = p1.astype(np.int32), p2.astype(np.int32)
109        pbc_shifts = None
110        if natoms.shape[0] > 1:
111            ## batching => mask irrelevant pairs
112            mask_p12 = (
113                (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None])
114            ).flatten()
115            shift = np.concatenate(
116                (np.array([0], dtype=np.int32), np.cumsum(natoms[:-1], dtype=np.int32))
117            )
118            p1 = np.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1)
119            p2 = np.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1)
120
121        apply_pbc = "cells" in inputs
122        if not apply_pbc:
123            ### NO PBC
124            vec = coords[p2] - coords[p1]
125        else:
126            cells = np.array(inputs["cells"], dtype=np.float32)
127            reciprocal_cells = np.array(inputs["reciprocal_cells"], dtype=np.float32)
128            minimage = state.get("minimum_image", True)
129            if minimage:
130                ## MINIMUM IMAGE CONVENTION
131                vec = coords[p2] - coords[p1]
132                if cells.shape[0] == 1:
133                    vecpbc = np.dot(vec, reciprocal_cells[0])
134                    pbc_shifts = -np.round(vecpbc)
135                    vec = vec + np.dot(pbc_shifts, cells[0])
136                else:
137                    batch_index_vec = batch_index[p1]
138                    vecpbc = np.einsum(
139                        "aj,aji->ai", vec, reciprocal_cells[batch_index_vec]
140                    )
141                    pbc_shifts = -np.round(vecpbc)
142                    vec = vec + np.einsum(
143                        "aj,aji->ai", pbc_shifts, cells[batch_index_vec]
144                    )
145            else:
146                ### GENERAL PBC
147                ## put all atoms in central box
148                if cells.shape[0] == 1:
149                    coords_pbc = np.dot(coords, reciprocal_cells[0])
150                    at_shifts = -np.floor(coords_pbc)
151                    coords_pbc = coords + np.dot(at_shifts, cells[0])
152                else:
153                    coords_pbc = np.einsum(
154                        "aj,aji->ai", coords, reciprocal_cells[batch_index]
155                    )
156                    at_shifts = -np.floor(coords_pbc)
157                    coords_pbc = coords + np.einsum(
158                        "aj,aji->ai", at_shifts, cells[batch_index]
159                    )
160                vec = coords_pbc[p2] - coords_pbc[p1]
161
162                ## compute maximum number of repeats
163                inv_distances = (np.sum(reciprocal_cells**2, axis=1)) ** 0.5
164                cdinv = cutoff_skin * inv_distances
165                num_repeats_all = np.ceil(cdinv).astype(np.int32)
166                if "true_sys" in inputs:
167                    num_repeats_all = np.where(np.array(inputs["true_sys"],dtype=bool)[:, None], num_repeats_all, 0)
168                # num_repeats_all = np.where(cdinv < 0.5, 0, num_repeats_all)
169                num_repeats = np.max(num_repeats_all, axis=0)
170                num_repeats_prev = np.array(state.get("num_repeats_pbc", (0, 0, 0)))
171                if np.any(num_repeats > num_repeats_prev):
172                    num_repeats_new = np.maximum(num_repeats, num_repeats_prev)
173                    state_up["num_repeats_pbc"] = (
174                        tuple(num_repeats_new),
175                        tuple(num_repeats_prev),
176                    )
177                    new_state["num_repeats_pbc"] = tuple(num_repeats_new)
178                ## build all possible shifts
179                cell_shift_pbc = np.array(
180                    np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]),
181                    dtype=cells.dtype,
182                ).T.reshape(-1, 3)
183                ## shift applied to vectors
184                if cells.shape[0] == 1:
185                    dvec = np.dot(cell_shift_pbc, cells[0])[None, :, :]
186                    vec = (vec[:, None, :] + dvec).reshape(-1, 3)
187                    pbc_shifts = np.broadcast_to(
188                        cell_shift_pbc[None, :, :],
189                        (p1.shape[0], cell_shift_pbc.shape[0], 3),
190                    ).reshape(-1, 3)
191                    p1 = np.broadcast_to(
192                        p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0])
193                    ).flatten()
194                    p2 = np.broadcast_to(
195                        p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0])
196                    ).flatten()
197                    if natoms.shape[0] > 1:
198                        mask_p12 = np.broadcast_to(
199                            mask_p12[:, None],
200                            (mask_p12.shape[0], cell_shift_pbc.shape[0]),
201                        ).flatten()
202                else:
203                    dvec = np.einsum("bj,sji->sbi", cell_shift_pbc, cells)
204
205                    ## get pbc shifts specific to each box
206                    cell_shift_pbc = np.broadcast_to(
207                        cell_shift_pbc[None, :, :],
208                        (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3),
209                    )
210                    mask = np.all(
211                        np.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1
212                    ).flatten()
213                    idx = np.nonzero(mask)[0]
214                    nshifts = idx.shape[0]
215                    nshifts_prev = state.get("nshifts_pbc", 0)
216                    if nshifts > nshifts_prev or add_margin:
217                        nshifts_new = int(mult_size * max(nshifts, nshifts_prev)) + 1
218                        state_up["nshifts_pbc"] = (nshifts_new, nshifts_prev)
219                        new_state["nshifts_pbc"] = nshifts_new
220
221                    dvec_filter = dvec.reshape(-1, 3)[idx, :]
222                    cell_shift_pbc_filter = cell_shift_pbc.reshape(-1, 3)[idx, :]
223
224                    ## get batch shift in the dvec_filter array
225                    nrep = np.prod(2 * num_repeats_all + 1, axis=1)
226                    bshift = np.concatenate((np.array([0]), np.cumsum(nrep)[:-1]))
227
228                    ## compute vectors
229                    batch_index_vec = batch_index[p1]
230                    nrep_vec = np.where(mask_p12,nrep[batch_index_vec],0)
231                    vec = vec.repeat(nrep_vec, axis=0)
232                    nvec_pbc = nrep_vec.sum() #vec.shape[0]
233                    nvec_pbc_prev = state.get("nvec_pbc", 0)
234                    if nvec_pbc > nvec_pbc_prev or add_margin:
235                        nvec_pbc_new = int(mult_size * max(nvec_pbc, nvec_pbc_prev)) + 1
236                        state_up["nvec_pbc"] = (nvec_pbc_new, nvec_pbc_prev)
237                        new_state["nvec_pbc"] = nvec_pbc_new
238
239                    # print("cpu: ", nvec_pbc, nvec_pbc_prev, nshifts, nshifts_prev)
240                    ## get shift index
241                    dshift = np.concatenate(
242                        (np.array([0]), np.cumsum(nrep_vec)[:-1])
243                    ).repeat(nrep_vec)
244                    # ishift = np.arange(dshift.shape[0])-dshift
245                    # bshift_vec_rep = bshift[batch_index_vec].repeat(nrep_vec)
246                    icellshift = (
247                        np.arange(dshift.shape[0])
248                        - dshift
249                        + bshift[batch_index_vec].repeat(nrep_vec)
250                    )
251                    # shift vectors
252                    vec = vec + dvec_filter[icellshift]
253                    pbc_shifts = cell_shift_pbc_filter[icellshift]
254
255                    p1 = np.repeat(p1, nrep_vec)
256                    p2 = np.repeat(p2, nrep_vec)
257                    if natoms.shape[0] > 1:
258                        mask_p12 = np.repeat(mask_p12, nrep_vec)
259
260        ## compute distances
261        d12 = (vec**2).sum(axis=-1)
262        if natoms.shape[0] > 1:
263            d12 = np.where(mask_p12, d12, cutoff_skin**2)
264
265        ## filter pairs
266        max_pairs = state.get("npairs", 1)
267        mask = d12 < cutoff_skin**2
268        idx = np.nonzero(mask)[0]
269        npairs = idx.shape[0]
270        if npairs > max_pairs or add_margin:
271            prev_max_pairs = max_pairs
272            max_pairs = int(mult_size * max(npairs, max_pairs)) + 1
273            state_up["npairs"] = (max_pairs, prev_max_pairs)
274            new_state["npairs"] = max_pairs
275
276        nat = coords.shape[0]
277        edge_src = np.full(max_pairs, nat, dtype=np.int32)
278        edge_dst = np.full(max_pairs, nat, dtype=np.int32)
279        d12_ = np.full(max_pairs, cutoff_skin**2)
280        edge_src[:npairs] = p1[idx]
281        edge_dst[:npairs] = p2[idx]
282        d12_[:npairs] = d12[idx]
283        d12 = d12_
284
285        if apply_pbc:
286            pbc_shifts_ = np.zeros((max_pairs, 3))
287            pbc_shifts_[:npairs] = pbc_shifts[idx]
288            pbc_shifts = pbc_shifts_
289            if not minimage:
290                pbc_shifts[:npairs] = (
291                    pbc_shifts[:npairs]
292                    + at_shifts[edge_dst[:npairs]]
293                    - at_shifts[edge_src[:npairs]]
294                )
295
296        if "nblist_skin" in state:
297            edge_src_skin = edge_src
298            edge_dst_skin = edge_dst
299            if apply_pbc:
300                pbc_shifts_skin = pbc_shifts
301            max_pairs_skin = state.get("npairs_skin", 1)
302            mask = d12 < self.cutoff**2
303            idx = np.nonzero(mask)[0]
304            npairs_skin = idx.shape[0]
305            if npairs_skin > max_pairs_skin or add_margin:
306                prev_max_pairs_skin = max_pairs_skin
307                max_pairs_skin = int(mult_size * max(npairs_skin, max_pairs_skin)) + 1
308                state_up["npairs_skin"] = (max_pairs_skin, prev_max_pairs_skin)
309                new_state["npairs_skin"] = max_pairs_skin
310            edge_src = np.full(max_pairs_skin, nat, dtype=np.int32)
311            edge_dst = np.full(max_pairs_skin, nat, dtype=np.int32)
312            d12_ = np.full(max_pairs_skin, self.cutoff**2)
313            edge_src[:npairs_skin] = edge_src_skin[idx]
314            edge_dst[:npairs_skin] = edge_dst_skin[idx]
315            d12_[:npairs_skin] = d12[idx]
316            d12 = d12_
317            if apply_pbc:
318                pbc_shifts = np.full((max_pairs_skin, 3), 0.0)
319                pbc_shifts[:npairs_skin] = pbc_shifts_skin[idx]
320
321        ## symmetrize
322        edge_src, edge_dst = np.concatenate((edge_src, edge_dst)), np.concatenate(
323            (edge_dst, edge_src)
324        )
325        d12 = np.concatenate((d12, d12))
326        if apply_pbc:
327            pbc_shifts = np.concatenate((pbc_shifts, -pbc_shifts))
328
329        graph = inputs.get(self.graph_key, {})
330        graph_out = {
331            **graph,
332            "edge_src": edge_src,
333            "edge_dst": edge_dst,
334            "d12": d12,
335            "overflow": False,
336            "pbc_shifts": pbc_shifts,
337        }
338        if "nblist_skin" in state:
339            graph_out["edge_src_skin"] = edge_src_skin
340            graph_out["edge_dst_skin"] = edge_dst_skin
341            if apply_pbc:
342                graph_out["pbc_shifts_skin"] = pbc_shifts_skin
343
344        if self.k_space and apply_pbc:
345            if "k_points" not in graph:
346                ks, _, _, bewald = get_reciprocal_space_parameters(
347                    reciprocal_cells, self.cutoff, self.kmax, self.kthr
348                )
349            graph_out["k_points"] = ks
350            graph_out["b_ewald"] = bewald
351
352        output = {**inputs, self.graph_key: graph_out}
353
354        if return_state_update:
355            return FrozenDict(new_state), output, state_up
356        return FrozenDict(new_state), output
357
358    def check_reallocate(self, state, inputs, parent_overflow=False):
359        """check for overflow and reallocate nblist if necessary"""
360        overflow = parent_overflow or inputs[self.graph_key].get("overflow", False)
361        if not overflow:
362            return state, {}, inputs, False
363
364        add_margin = inputs[self.graph_key].get("overflow", False)
365        state, inputs, state_up = self(
366            state, inputs, return_state_update=True, add_margin=add_margin
367        )
368        return state, state_up, inputs, True
369
    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """Build the neighbor list with JAX using preallocated sizes.

        The method is jit-compiled with `self` and `state` as static
        arguments, so every size read from `state` is fixed at trace time.
        When one of these preallocated sizes is too small, the `overflow`
        flag of the output graph is raised instead of resizing.

        Args:
            state: Preprocessing state. Keys read here: `max_nat`,
                `nblist_skin`, `minimum_image`, `num_repeats_pbc`,
                `nshifts_pbc`, `nvec_pbc`, `npairs` and `npairs_skin`.
            inputs: System dictionary. `coordinates`, `natoms` and
                `batch_index` are required. `cells` is optional and, when
                present, `reciprocal_cells` is read as well. `true_sys` is
                only consulted in the general (non minimum-image) PBC branch.

        Returns:
            A dictionary with the content of `inputs` plus, under
            `self.graph_key`, a graph holding `edge_src`, `edge_dst`, `d12`
            (squared distances), `overflow` and `pbc_shifts` (None when
            `cells` is absent). When `nblist_skin` is in `state`, the keys
            `edge_src_skin`, `edge_dst_skin` and (with cells)
            `pbc_shifts_skin` are added. `inputs` is returned unchanged when
            the incoming graph contains `keep_graph`.

        Raises:
            NotImplementedError: If `self.k_space` is set, `cells` is present
                and the incoming graph has no `k_points`.
        """
        # A graph flagged with "keep_graph" is passed through untouched.
        if self.graph_key in inputs:
            graph = inputs[self.graph_key]
            if "keep_graph" in graph:
                return inputs
        coords = inputs["coordinates"]
        natoms = inputs["natoms"]
        batch_index = inputs["batch_index"]

        # Single system: all atoms belong to it. Batch: use the preallocated
        # per-system maximum, defaulting to the average number of atoms.
        if natoms.shape[0] == 1:
            max_nat = coords.shape[0]
        else:
            max_nat = state.get(
                "max_nat", int(round(coords.shape[0] / natoms.shape[0]))
            )
        cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0)

        ### compute indices of all pairs
        # Only i < j pairs are enumerated; the list is symmetrized at the end.
        p1, p2 = np.triu_indices(max_nat, 1)
        p1, p2 = p1.astype(np.int32), p2.astype(np.int32)
        pbc_shifts = None
        if natoms.shape[0] > 1:
            ## batching => mask irrelevant pairs
            mask_p12 = (
                (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None])
            ).flatten()
            # Offset of the first atom of each system in the flat atom arrays.
            shift = jnp.concatenate(
                (jnp.array([0], dtype=jnp.int32), jnp.cumsum(natoms[:-1]))
            )
            # Pairs that do not exist in a system are given the index -1.
            p1 = jnp.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1)
            p2 = jnp.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1)

        ## compute vectors
        overflow_repeats = jnp.asarray(False, dtype=bool)
        if "cells" not in inputs:
            vec = coords[p2] - coords[p1]
        else:
            cells = inputs["cells"]
            reciprocal_cells = inputs["reciprocal_cells"]
            minimage = state.get("minimum_image", True)

            def compute_pbc(vec, reciprocal_cell, cell, mode="round"):
                # Project on the reciprocal cell, derive integer cell shifts
                # by rounding or flooring, and return the shifted vector
                # together with the shifts that were applied.
                vecpbc = jnp.dot(vec, reciprocal_cell)
                if mode == "round":
                    pbc_shifts = -jnp.round(vecpbc)
                elif mode == "floor":
                    pbc_shifts = -jnp.floor(vecpbc)
                else:
                    raise NotImplementedError(f"Unknown mode {mode} for compute_pbc.")
                return vec + jnp.dot(pbc_shifts, cell), pbc_shifts

            if minimage:
                ## minimum image convention
                vec = coords[p2] - coords[p1]

                if cells.shape[0] == 1:
                    vec, pbc_shifts = compute_pbc(vec, reciprocal_cells[0], cells[0])
                else:
                    # One cell per system: select the cell of each pair
                    # through the batch index of its first atom.
                    batch_index_vec = batch_index[p1]
                    vec, pbc_shifts = jax.vmap(compute_pbc)(
                        vec, reciprocal_cells[batch_index_vec], cells[batch_index_vec]
                    )
            else:
                ### general PBC only for single cell yet
                # if cells.shape[0] > 1:
                #     raise NotImplementedError(
                #         "General PBC not implemented for batches on accelerator."
                #     )
                # cell = cells[0]
                # reciprocal_cell = reciprocal_cells[0]
                # NOTE(review): the header above looks outdated, since the
                # `else` branch further down handles several cells.

                ## put all atoms in central box
                if cells.shape[0] == 1:
                    coords_pbc, at_shifts = compute_pbc(
                        coords, reciprocal_cells[0], cells[0], mode="floor"
                    )
                else:
                    coords_pbc, at_shifts = jax.vmap(
                        partial(compute_pbc, mode="floor")
                    )(coords, reciprocal_cells[batch_index], cells[batch_index])
                vec = coords_pbc[p2] - coords_pbc[p1]
                # Static number of cell repetitions preallocated in the state.
                num_repeats = state.get("num_repeats_pbc", (0, 0, 0))
                # if num_repeats is None:
                #     raise ValueError(
                #         "num_repeats_pbc should be provided for general PBC on accelerator. Call the numpy routine (self.__call__) first."
                #     )
                # check if num_repeats is larger than previous
                inv_distances = jnp.linalg.norm(reciprocal_cells, axis=1)
                cdinv = cutoff_skin * inv_distances
                # Repetitions required by the current cells for cutoff + skin.
                num_repeats_all = jnp.ceil(cdinv).astype(jnp.int32)
                if "true_sys" in inputs:
                    # Systems not flagged in `true_sys` require no repetition.
                    num_repeats_all = jnp.where(inputs["true_sys"][:,None], num_repeats_all, 0)
                num_repeats_new = jnp.max(num_repeats_all, axis=0)
                overflow_repeats = jnp.any(num_repeats_new > jnp.asarray(num_repeats))

                # All integer shift triplets in [-n, n] along each direction,
                # built with numpy from the static `num_repeats`.
                cell_shift_pbc = jnp.asarray(
                    np.array(
                        np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]),
                        dtype=cells.dtype,
                    ).T.reshape(-1, 3)
                )

                if cells.shape[0] == 1:
                    # Single cell: combine every pair with every cell shift.
                    vec = (vec[:,None,:] + jnp.dot(cell_shift_pbc, cells[0])[None, :, :]).reshape(-1, 3)
                    pbc_shifts = jnp.broadcast_to(
                        cell_shift_pbc[None, :, :],
                        (p1.shape[0], cell_shift_pbc.shape[0], 3),
                    ).reshape(-1, 3)
                    p1 = jnp.broadcast_to(
                        p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
                    p2 = jnp.broadcast_to(
                        p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
                    if natoms.shape[0] > 1:
                        mask_p12 = jnp.broadcast_to(
                            mask_p12[:, None], (mask_p12.shape[0], cell_shift_pbc.shape[0])
                        ).flatten()
                else:
                    # Cartesian translation of every shift for every system.
                    dvec = jnp.einsum("bj,sji->sbi", cell_shift_pbc, cells).reshape(-1, 3)

                    ## get pbc shifts specific to each box
                    cell_shift_pbc = jnp.broadcast_to(
                        cell_shift_pbc[None, :, :],
                        (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3),
                    )
                    # Keep, for each system, only the shifts it actually needs.
                    mask = jnp.all(
                        jnp.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1
                    ).flatten()
                    max_shifts  = state.get("nshifts_pbc", 1)

                    cell_shift_pbc = cell_shift_pbc.reshape(-1,3)
                    shiftx,shifty,shiftz = cell_shift_pbc[:,0],cell_shift_pbc[:,1],cell_shift_pbc[:,2]
                    dvecx,dvecy,dvecz = dvec[:,0],dvec[:,1],dvec[:,2]
                    # presumably mask_filter_1d compacts the masked entries
                    # into arrays of length `max_shifts`, padded with the
                    # given fill values, and returns the number of selected
                    # entries -- TODO confirm against ..utils
                    (dvecx, dvecy,dvecz,shiftx,shifty,shiftz), scatter_idx, nshifts = mask_filter_1d(
                        mask,
                        max_shifts,
                        (dvecx, 0.),
                        (dvecy, 0.),
                        (dvecz, 0.),
                        (shiftx, 0),
                        (shifty, 0),
                        (shiftz, 0),
                    )
                    dvec = jnp.stack((dvecx,dvecy,dvecz),axis=-1)
                    cell_shift_pbc = jnp.stack((shiftx,shifty,shiftz),axis=-1)
                    overflow_repeats = overflow_repeats | (nshifts > max_shifts)

                    ## get batch shift in the dvec_filter array
                    # Number of shifts per system and offset of each system's
                    # first shift.
                    nrep = jnp.prod(2 * num_repeats_all + 1, axis=1)
                    bshift = jnp.concatenate((jnp.array([0],dtype=jnp.int32), jnp.cumsum(nrep)[:-1]))

                    ## repeat vectors
                    nvec_max = state.get("nvec_pbc", 1)
                    batch_index_vec = batch_index[p1]
                    # Each valid pair is repeated once per shift of its
                    # system; masked pairs are repeated zero times.
                    nrep_vec = jnp.where(mask_p12,nrep[batch_index_vec],0)
                    nvec = nrep_vec.sum()
                    overflow_repeats = overflow_repeats | (nvec > nvec_max)
                    vec = jnp.repeat(vec,nrep_vec,axis=0,total_repeat_length=nvec_max)
                    # jax.debug.print("{nvec} {nvec_max} {nshifts} {max_shifts}",nvec=nvec,nvec_max=jnp.asarray(nvec_max),nshifts=nshifts,max_shifts=jnp.asarray(max_shifts))

                    ## get shift index
                    # Offset of the first copy of each pair in the repeated
                    # arrays.
                    dshift = jnp.concatenate(
                        (jnp.array([0],dtype=jnp.int32), jnp.cumsum(nrep_vec)[:-1])
                    )
                    if nrep_vec.size == 0:
                        dshift = jnp.array([],dtype=jnp.int32)
                    dshift = jnp.repeat(dshift,nrep_vec, total_repeat_length=nvec_max)
                    bshift = jnp.repeat(bshift[batch_index_vec],nrep_vec, total_repeat_length=nvec_max)
                    # Rank of the copy within its pair plus the offset of the
                    # system gives the index of the shift to apply.
                    icellshift = jnp.arange(dshift.shape[0]) - dshift + bshift
                    vec = vec + dvec[icellshift]
                    pbc_shifts = cell_shift_pbc[icellshift]
                    p1 = jnp.repeat(p1,nrep_vec, total_repeat_length=nvec_max)
                    p2 = jnp.repeat(p2,nrep_vec, total_repeat_length=nvec_max)
                    if natoms.shape[0] > 1:
                        mask_p12 = jnp.repeat(mask_p12,nrep_vec, total_repeat_length=nvec_max)


        ## compute distances
        # d12 holds squared distances; masked pairs are set exactly to the
        # squared cutoff so that the strict comparison below rejects them.
        d12 = (vec**2).sum(axis=-1)
        if natoms.shape[0] > 1:
            d12 = jnp.where(mask_p12, d12, cutoff_skin**2)

        ## filter pairs
        max_pairs = state.get("npairs", 1)
        mask = d12 < cutoff_skin**2
        # Selected pairs go into arrays of length `max_pairs`; padding uses
        # the out-of-range atom index coords.shape[0] and the squared cutoff.
        (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d(
            mask,
            max_pairs,
            (jnp.asarray(p1, dtype=jnp.int32), coords.shape[0]),
            (jnp.asarray(p2, dtype=jnp.int32), coords.shape[0]),
            (d12, cutoff_skin**2),
        )
        if "cells" in inputs:
            # Scatter the shifts to the slots of the selected pairs; indices
            # outside the output array are dropped.
            pbc_shifts = (
                jnp.full((max_pairs, 3), 0.0, dtype=pbc_shifts.dtype)
                .at[scatter_idx]
                .set(pbc_shifts, mode="drop")
            )
            if not minimage:
                # Add the per-atom wrapping shifts so that the pair shifts
                # apply to the original (unwrapped) coordinates.
                pbc_shifts = (
                    pbc_shifts
                    + at_shifts.at[edge_dst].get(fill_value=0.0)
                    - at_shifts.at[edge_src].get(fill_value=0.0)
                )

        ## check for overflow
        if natoms.shape[0] == 1:
            true_max_nat = coords.shape[0]
        else:
            true_max_nat = jnp.max(natoms)
        overflow_count = npairs > max_pairs
        overflow_at = true_max_nat > max_nat
        overflow = overflow_count | overflow_at | overflow_repeats

        if "nblist_skin" in state:
            # The list built so far (cutoff + skin) is kept as the skin list;
            # a second filtering extracts the pairs within the bare cutoff.
            # edge_mask_skin = edge_mask
            edge_src_skin = edge_src
            edge_dst_skin = edge_dst
            if "cells" in inputs:
                pbc_shifts_skin = pbc_shifts
            max_pairs_skin = state.get("npairs_skin", 1)
            mask = d12 < self.cutoff**2
            (edge_src, edge_dst, d12), scatter_idx, npairs_skin = mask_filter_1d(
                mask,
                max_pairs_skin,
                (edge_src, coords.shape[0]),
                (edge_dst, coords.shape[0]),
                (d12, self.cutoff**2),
            )
            if "cells" in inputs:
                pbc_shifts = (
                    jnp.full((max_pairs_skin, 3), 0.0, dtype=pbc_shifts.dtype)
                    .at[scatter_idx]
                    .set(pbc_shifts, mode="drop")
                )
            overflow = overflow | (npairs_skin > max_pairs_skin)

        ## symmetrize
        # Append the reversed pairs (with opposite shifts) to get both
        # directions of every edge.
        edge_src, edge_dst = jnp.concatenate((edge_src, edge_dst)), jnp.concatenate(
            (edge_dst, edge_src)
        )
        d12 = jnp.concatenate((d12, d12))
        if "cells" in inputs:
            pbc_shifts = jnp.concatenate((pbc_shifts, -pbc_shifts))

        # Entries already present in the incoming graph are preserved unless
        # overwritten below.
        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
        graph_out = {
            **graph,
            "edge_src": edge_src,
            "edge_dst": edge_dst,
            "d12": d12,
            "overflow": overflow,
            "pbc_shifts": pbc_shifts,
        }
        if "nblist_skin" in state:
            graph_out["edge_src_skin"] = edge_src_skin
            graph_out["edge_dst_skin"] = edge_dst_skin
            if "cells" in inputs:
                graph_out["pbc_shifts_skin"] = pbc_shifts_skin

        # k-space data cannot be generated here; it must already be in the
        # incoming graph.
        if self.k_space and "cells" in inputs:
            if "k_points" not in graph:
                raise NotImplementedError(
                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
                )
        return {**inputs, self.graph_key: graph_out}
639
640    @partial(jax.jit, static_argnums=(0,))
641    def update_skin(self, inputs):
642        """update the nblist without recomputing the full nblist"""
643        graph = inputs[self.graph_key]
644
645        edge_src_skin = graph["edge_src_skin"]
646        edge_dst_skin = graph["edge_dst_skin"]
647        coords = inputs["coordinates"]
648        vec = coords.at[edge_dst_skin].get(
649            mode="fill", fill_value=self.cutoff
650        ) - coords.at[edge_src_skin].get(mode="fill", fill_value=0.0)
651
652        if "cells" in inputs:
653            pbc_shifts_skin = graph["pbc_shifts_skin"]
654            cells = inputs["cells"]
655            if cells.shape[0] == 1:
656                vec = vec + jnp.dot(pbc_shifts_skin, cells[0])
657            else:
658                batch_index_vec = inputs["batch_index"][edge_src_skin]
659                vec = vec + jax.vmap(jnp.dot)(pbc_shifts_skin, cells[batch_index_vec])
660
661        nat = coords.shape[0]
662        d12 = jnp.sum(vec**2, axis=-1)
663        mask = d12 < self.cutoff**2
664        max_pairs = graph["edge_src"].shape[0] // 2
665        (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d(
666            mask,
667            max_pairs,
668            (edge_src_skin, nat),
669            (edge_dst_skin, nat),
670            (d12, self.cutoff**2),
671        )
672        if "cells" in inputs:
673            pbc_shifts = (
674                jnp.full((max_pairs, 3), 0.0, dtype=pbc_shifts_skin.dtype)
675                .at[scatter_idx]
676                .set(pbc_shifts_skin)
677            )
678
679        overflow = graph.get("overflow", False) | (npairs > max_pairs)
680        graph_out = {
681            **graph,
682            "edge_src": jnp.concatenate((edge_src, edge_dst)),
683            "edge_dst": jnp.concatenate((edge_dst, edge_src)),
684            "d12": jnp.concatenate((d12, d12)),
685            "overflow": overflow,
686        }
687        if "cells" in inputs:
688            graph_out["pbc_shifts"] = jnp.concatenate((pbc_shifts, -pbc_shifts))
689
690        if self.k_space and "cells" in inputs:
691            if "k_points" not in graph:
692                raise NotImplementedError(
693                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
694                )
695
696        return {**inputs, self.graph_key: graph_out}

Generate a graph from a set of coordinates

FPID: GRAPH

For now, we generate all pairs of atoms and filter them based on the cutoff. If nblist_skin is present in the state, we generate a second graph with a larger cutoff that includes all pairs within cutoff + skin. This graph is then reused by the update_skin method to update the original graph without recomputing the full nblist.

GraphGenerator( cutoff: float, graph_key: str = 'graph', switch_params: dict = <factory>, kmax: int = 30, kthr: float = 1e-06, k_space: bool = False, mult_size: float = 1.05)
cutoff: float

Cutoff distance for the graph.

graph_key: str = 'graph'

Key of the graph in the outputs.

switch_params: dict

Parameters for the switching function. See fennol.models.misc.misc.SwitchFunction.

kmax: int = 30

Maximum number of k-points to consider.

kthr: float = 1e-06

Threshold for k-point filtering.

k_space: bool = False

Whether to generate k-space information for the graph.

mult_size: float = 1.05

Multiplicative factor for resizing the nblist.

FPID: ClassVar[str] = 'GRAPH'
def init(self):
50    def init(self):
51        return FrozenDict(
52            {
53                "max_nat": 1,
54                "npairs": 1,
55                "nblist_mult_size": self.mult_size,
56            }
57        )
def get_processor(self) -> Tuple[flax.linen.module.Module, Dict]:
59    def get_processor(self) -> Tuple[nn.Module, Dict]:
60        return GraphProcessor, {
61            "cutoff": self.cutoff,
62            "graph_key": self.graph_key,
63            "switch_params": self.switch_params,
64            "name": f"{self.graph_key}_Processor",
65        }
def get_graph_properties(self):
67    def get_graph_properties(self):
68        return {
69            self.graph_key: {
70                "cutoff": self.cutoff,
71                "directed": True,
72            }
73        }
def check_reallocate(self, state, inputs, parent_overflow=False):
358    def check_reallocate(self, state, inputs, parent_overflow=False):
359        """check for overflow and reallocate nblist if necessary"""
360        overflow = parent_overflow or inputs[self.graph_key].get("overflow", False)
361        if not overflow:
362            return state, {}, inputs, False
363
364        add_margin = inputs[self.graph_key].get("overflow", False)
365        state, inputs, state_up = self(
366            state, inputs, return_state_update=True, add_margin=add_margin
367        )
368        return state, state_up, inputs, True

check for overflow and reallocate nblist if necessary

@partial(jax.jit, static_argnums=(0, 1))
def process(self, state, inputs):
    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """Build the neighbor list with JAX using preallocated sizes.

        The method is jit-compiled with `self` and `state` as static
        arguments, so every size read from `state` is fixed at trace time.
        When one of these preallocated sizes is too small, the `overflow`
        flag of the output graph is raised instead of resizing.

        Args:
            state: Preprocessing state. Keys read here: `max_nat`,
                `nblist_skin`, `minimum_image`, `num_repeats_pbc`,
                `nshifts_pbc`, `nvec_pbc`, `npairs` and `npairs_skin`.
            inputs: System dictionary. `coordinates`, `natoms` and
                `batch_index` are required. `cells` is optional and, when
                present, `reciprocal_cells` is read as well. `true_sys` is
                only consulted in the general (non minimum-image) PBC branch.

        Returns:
            A dictionary with the content of `inputs` plus, under
            `self.graph_key`, a graph holding `edge_src`, `edge_dst`, `d12`
            (squared distances), `overflow` and `pbc_shifts` (None when
            `cells` is absent). When `nblist_skin` is in `state`, the keys
            `edge_src_skin`, `edge_dst_skin` and (with cells)
            `pbc_shifts_skin` are added. `inputs` is returned unchanged when
            the incoming graph contains `keep_graph`.

        Raises:
            NotImplementedError: If `self.k_space` is set, `cells` is present
                and the incoming graph has no `k_points`.
        """
        # A graph flagged with "keep_graph" is passed through untouched.
        if self.graph_key in inputs:
            graph = inputs[self.graph_key]
            if "keep_graph" in graph:
                return inputs
        coords = inputs["coordinates"]
        natoms = inputs["natoms"]
        batch_index = inputs["batch_index"]

        # Single system: all atoms belong to it. Batch: use the preallocated
        # per-system maximum, defaulting to the average number of atoms.
        if natoms.shape[0] == 1:
            max_nat = coords.shape[0]
        else:
            max_nat = state.get(
                "max_nat", int(round(coords.shape[0] / natoms.shape[0]))
            )
        cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0)

        ### compute indices of all pairs
        # Only i < j pairs are enumerated; the list is symmetrized at the end.
        p1, p2 = np.triu_indices(max_nat, 1)
        p1, p2 = p1.astype(np.int32), p2.astype(np.int32)
        pbc_shifts = None
        if natoms.shape[0] > 1:
            ## batching => mask irrelevant pairs
            mask_p12 = (
                (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None])
            ).flatten()
            # Offset of the first atom of each system in the flat atom arrays.
            shift = jnp.concatenate(
                (jnp.array([0], dtype=jnp.int32), jnp.cumsum(natoms[:-1]))
            )
            # Pairs that do not exist in a system are given the index -1.
            p1 = jnp.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1)
            p2 = jnp.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1)

        ## compute vectors
        overflow_repeats = jnp.asarray(False, dtype=bool)
        if "cells" not in inputs:
            vec = coords[p2] - coords[p1]
        else:
            cells = inputs["cells"]
            reciprocal_cells = inputs["reciprocal_cells"]
            minimage = state.get("minimum_image", True)

            def compute_pbc(vec, reciprocal_cell, cell, mode="round"):
                # Project on the reciprocal cell, derive integer cell shifts
                # by rounding or flooring, and return the shifted vector
                # together with the shifts that were applied.
                vecpbc = jnp.dot(vec, reciprocal_cell)
                if mode == "round":
                    pbc_shifts = -jnp.round(vecpbc)
                elif mode == "floor":
                    pbc_shifts = -jnp.floor(vecpbc)
                else:
                    raise NotImplementedError(f"Unknown mode {mode} for compute_pbc.")
                return vec + jnp.dot(pbc_shifts, cell), pbc_shifts

            if minimage:
                ## minimum image convention
                vec = coords[p2] - coords[p1]

                if cells.shape[0] == 1:
                    vec, pbc_shifts = compute_pbc(vec, reciprocal_cells[0], cells[0])
                else:
                    # One cell per system: select the cell of each pair
                    # through the batch index of its first atom.
                    batch_index_vec = batch_index[p1]
                    vec, pbc_shifts = jax.vmap(compute_pbc)(
                        vec, reciprocal_cells[batch_index_vec], cells[batch_index_vec]
                    )
            else:
                ### general PBC only for single cell yet
                # if cells.shape[0] > 1:
                #     raise NotImplementedError(
                #         "General PBC not implemented for batches on accelerator."
                #     )
                # cell = cells[0]
                # reciprocal_cell = reciprocal_cells[0]
                # NOTE(review): the header above looks outdated, since the
                # `else` branch further down handles several cells.

                ## put all atoms in central box
                if cells.shape[0] == 1:
                    coords_pbc, at_shifts = compute_pbc(
                        coords, reciprocal_cells[0], cells[0], mode="floor"
                    )
                else:
                    coords_pbc, at_shifts = jax.vmap(
                        partial(compute_pbc, mode="floor")
                    )(coords, reciprocal_cells[batch_index], cells[batch_index])
                vec = coords_pbc[p2] - coords_pbc[p1]
                # Static number of cell repetitions preallocated in the state.
                num_repeats = state.get("num_repeats_pbc", (0, 0, 0))
                # if num_repeats is None:
                #     raise ValueError(
                #         "num_repeats_pbc should be provided for general PBC on accelerator. Call the numpy routine (self.__call__) first."
                #     )
                # check if num_repeats is larger than previous
                inv_distances = jnp.linalg.norm(reciprocal_cells, axis=1)
                cdinv = cutoff_skin * inv_distances
                # Repetitions required by the current cells for cutoff + skin.
                num_repeats_all = jnp.ceil(cdinv).astype(jnp.int32)
                if "true_sys" in inputs:
                    # Systems not flagged in `true_sys` require no repetition.
                    num_repeats_all = jnp.where(inputs["true_sys"][:,None], num_repeats_all, 0)
                num_repeats_new = jnp.max(num_repeats_all, axis=0)
                overflow_repeats = jnp.any(num_repeats_new > jnp.asarray(num_repeats))

                # All integer shift triplets in [-n, n] along each direction,
                # built with numpy from the static `num_repeats`.
                cell_shift_pbc = jnp.asarray(
                    np.array(
                        np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]),
                        dtype=cells.dtype,
                    ).T.reshape(-1, 3)
                )

                if cells.shape[0] == 1:
                    # Single cell: combine every pair with every cell shift.
                    vec = (vec[:,None,:] + jnp.dot(cell_shift_pbc, cells[0])[None, :, :]).reshape(-1, 3)
                    pbc_shifts = jnp.broadcast_to(
                        cell_shift_pbc[None, :, :],
                        (p1.shape[0], cell_shift_pbc.shape[0], 3),
                    ).reshape(-1, 3)
                    p1 = jnp.broadcast_to(
                        p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
                    p2 = jnp.broadcast_to(
                        p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0])
                    ).flatten()
                    if natoms.shape[0] > 1:
                        mask_p12 = jnp.broadcast_to(
                            mask_p12[:, None], (mask_p12.shape[0], cell_shift_pbc.shape[0])
                        ).flatten()
                else:
                    # Cartesian translation of every shift for every system.
                    dvec = jnp.einsum("bj,sji->sbi", cell_shift_pbc, cells).reshape(-1, 3)

                    ## get pbc shifts specific to each box
                    cell_shift_pbc = jnp.broadcast_to(
                        cell_shift_pbc[None, :, :],
                        (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3),
                    )
                    # Keep, for each system, only the shifts it actually needs.
                    mask = jnp.all(
                        jnp.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1
                    ).flatten()
                    max_shifts  = state.get("nshifts_pbc", 1)

                    cell_shift_pbc = cell_shift_pbc.reshape(-1,3)
                    shiftx,shifty,shiftz = cell_shift_pbc[:,0],cell_shift_pbc[:,1],cell_shift_pbc[:,2]
                    dvecx,dvecy,dvecz = dvec[:,0],dvec[:,1],dvec[:,2]
                    # presumably mask_filter_1d compacts the masked entries
                    # into arrays of length `max_shifts`, padded with the
                    # given fill values, and returns the number of selected
                    # entries -- TODO confirm against ..utils
                    (dvecx, dvecy,dvecz,shiftx,shifty,shiftz), scatter_idx, nshifts = mask_filter_1d(
                        mask,
                        max_shifts,
                        (dvecx, 0.),
                        (dvecy, 0.),
                        (dvecz, 0.),
                        (shiftx, 0),
                        (shifty, 0),
                        (shiftz, 0),
                    )
                    dvec = jnp.stack((dvecx,dvecy,dvecz),axis=-1)
                    cell_shift_pbc = jnp.stack((shiftx,shifty,shiftz),axis=-1)
                    overflow_repeats = overflow_repeats | (nshifts > max_shifts)

                    ## get batch shift in the dvec_filter array
                    # Number of shifts per system and offset of each system's
                    # first shift.
                    nrep = jnp.prod(2 * num_repeats_all + 1, axis=1)
                    bshift = jnp.concatenate((jnp.array([0],dtype=jnp.int32), jnp.cumsum(nrep)[:-1]))

                    ## repeat vectors
                    nvec_max = state.get("nvec_pbc", 1)
                    batch_index_vec = batch_index[p1]
                    # Each valid pair is repeated once per shift of its
                    # system; masked pairs are repeated zero times.
                    nrep_vec = jnp.where(mask_p12,nrep[batch_index_vec],0)
                    nvec = nrep_vec.sum()
                    overflow_repeats = overflow_repeats | (nvec > nvec_max)
                    vec = jnp.repeat(vec,nrep_vec,axis=0,total_repeat_length=nvec_max)
                    # jax.debug.print("{nvec} {nvec_max} {nshifts} {max_shifts}",nvec=nvec,nvec_max=jnp.asarray(nvec_max),nshifts=nshifts,max_shifts=jnp.asarray(max_shifts))

                    ## get shift index
                    # Offset of the first copy of each pair in the repeated
                    # arrays.
                    dshift = jnp.concatenate(
                        (jnp.array([0],dtype=jnp.int32), jnp.cumsum(nrep_vec)[:-1])
                    )
                    if nrep_vec.size == 0:
                        dshift = jnp.array([],dtype=jnp.int32)
                    dshift = jnp.repeat(dshift,nrep_vec, total_repeat_length=nvec_max)
                    bshift = jnp.repeat(bshift[batch_index_vec],nrep_vec, total_repeat_length=nvec_max)
                    # Rank of the copy within its pair plus the offset of the
                    # system gives the index of the shift to apply.
                    icellshift = jnp.arange(dshift.shape[0]) - dshift + bshift
                    vec = vec + dvec[icellshift]
                    pbc_shifts = cell_shift_pbc[icellshift]
                    p1 = jnp.repeat(p1,nrep_vec, total_repeat_length=nvec_max)
                    p2 = jnp.repeat(p2,nrep_vec, total_repeat_length=nvec_max)
                    if natoms.shape[0] > 1:
                        mask_p12 = jnp.repeat(mask_p12,nrep_vec, total_repeat_length=nvec_max)


        ## compute distances
        # d12 holds squared distances; masked pairs are set exactly to the
        # squared cutoff so that the strict comparison below rejects them.
        d12 = (vec**2).sum(axis=-1)
        if natoms.shape[0] > 1:
            d12 = jnp.where(mask_p12, d12, cutoff_skin**2)

        ## filter pairs
        max_pairs = state.get("npairs", 1)
        mask = d12 < cutoff_skin**2
        # Selected pairs go into arrays of length `max_pairs`; padding uses
        # the out-of-range atom index coords.shape[0] and the squared cutoff.
        (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d(
            mask,
            max_pairs,
            (jnp.asarray(p1, dtype=jnp.int32), coords.shape[0]),
            (jnp.asarray(p2, dtype=jnp.int32), coords.shape[0]),
            (d12, cutoff_skin**2),
        )
        if "cells" in inputs:
            # Scatter the shifts to the slots of the selected pairs; indices
            # outside the output array are dropped.
            pbc_shifts = (
                jnp.full((max_pairs, 3), 0.0, dtype=pbc_shifts.dtype)
                .at[scatter_idx]
                .set(pbc_shifts, mode="drop")
            )
            if not minimage:
                # Add the per-atom wrapping shifts so that the pair shifts
                # apply to the original (unwrapped) coordinates.
                pbc_shifts = (
                    pbc_shifts
                    + at_shifts.at[edge_dst].get(fill_value=0.0)
                    - at_shifts.at[edge_src].get(fill_value=0.0)
                )

        ## check for overflow
        if natoms.shape[0] == 1:
            true_max_nat = coords.shape[0]
        else:
            true_max_nat = jnp.max(natoms)
        overflow_count = npairs > max_pairs
        overflow_at = true_max_nat > max_nat
        overflow = overflow_count | overflow_at | overflow_repeats

        if "nblist_skin" in state:
            # The list built so far (cutoff + skin) is kept as the skin list;
            # a second filtering extracts the pairs within the bare cutoff.
            # edge_mask_skin = edge_mask
            edge_src_skin = edge_src
            edge_dst_skin = edge_dst
            if "cells" in inputs:
                pbc_shifts_skin = pbc_shifts
            max_pairs_skin = state.get("npairs_skin", 1)
            mask = d12 < self.cutoff**2
            (edge_src, edge_dst, d12), scatter_idx, npairs_skin = mask_filter_1d(
                mask,
                max_pairs_skin,
                (edge_src, coords.shape[0]),
                (edge_dst, coords.shape[0]),
                (d12, self.cutoff**2),
            )
            if "cells" in inputs:
                pbc_shifts = (
                    jnp.full((max_pairs_skin, 3), 0.0, dtype=pbc_shifts.dtype)
                    .at[scatter_idx]
                    .set(pbc_shifts, mode="drop")
                )
            overflow = overflow | (npairs_skin > max_pairs_skin)

        ## symmetrize
        # Append the reversed pairs (with opposite shifts) to get both
        # directions of every edge.
        edge_src, edge_dst = jnp.concatenate((edge_src, edge_dst)), jnp.concatenate(
            (edge_dst, edge_src)
        )
        d12 = jnp.concatenate((d12, d12))
        if "cells" in inputs:
            pbc_shifts = jnp.concatenate((pbc_shifts, -pbc_shifts))

        # Entries already present in the incoming graph are preserved unless
        # overwritten below.
        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
        graph_out = {
            **graph,
            "edge_src": edge_src,
            "edge_dst": edge_dst,
            "d12": d12,
            "overflow": overflow,
            "pbc_shifts": pbc_shifts,
        }
        if "nblist_skin" in state:
            graph_out["edge_src_skin"] = edge_src_skin
            graph_out["edge_dst_skin"] = edge_dst_skin
            if "cells" in inputs:
                graph_out["pbc_shifts_skin"] = pbc_shifts_skin

        # k-space data cannot be generated here; it must already be in the
        # incoming graph.
        if self.k_space and "cells" in inputs:
            if "k_points" not in graph:
                raise NotImplementedError(
                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
                )
        return {**inputs, self.graph_key: graph_out}

build a nblist on accelerator with jax and precomputed shapes

@partial(jax.jit, static_argnums=(0,))
def update_skin(self, inputs):
640    @partial(jax.jit, static_argnums=(0,))
641    def update_skin(self, inputs):
642        """update the nblist without recomputing the full nblist"""
643        graph = inputs[self.graph_key]
644
645        edge_src_skin = graph["edge_src_skin"]
646        edge_dst_skin = graph["edge_dst_skin"]
647        coords = inputs["coordinates"]
648        vec = coords.at[edge_dst_skin].get(
649            mode="fill", fill_value=self.cutoff
650        ) - coords.at[edge_src_skin].get(mode="fill", fill_value=0.0)
651
652        if "cells" in inputs:
653            pbc_shifts_skin = graph["pbc_shifts_skin"]
654            cells = inputs["cells"]
655            if cells.shape[0] == 1:
656                vec = vec + jnp.dot(pbc_shifts_skin, cells[0])
657            else:
658                batch_index_vec = inputs["batch_index"][edge_src_skin]
659                vec = vec + jax.vmap(jnp.dot)(pbc_shifts_skin, cells[batch_index_vec])
660
661        nat = coords.shape[0]
662        d12 = jnp.sum(vec**2, axis=-1)
663        mask = d12 < self.cutoff**2
664        max_pairs = graph["edge_src"].shape[0] // 2
665        (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d(
666            mask,
667            max_pairs,
668            (edge_src_skin, nat),
669            (edge_dst_skin, nat),
670            (d12, self.cutoff**2),
671        )
672        if "cells" in inputs:
673            pbc_shifts = (
674                jnp.full((max_pairs, 3), 0.0, dtype=pbc_shifts_skin.dtype)
675                .at[scatter_idx]
676                .set(pbc_shifts_skin)
677            )
678
679        overflow = graph.get("overflow", False) | (npairs > max_pairs)
680        graph_out = {
681            **graph,
682            "edge_src": jnp.concatenate((edge_src, edge_dst)),
683            "edge_dst": jnp.concatenate((edge_dst, edge_src)),
684            "d12": jnp.concatenate((d12, d12)),
685            "overflow": overflow,
686        }
687        if "cells" in inputs:
688            graph_out["pbc_shifts"] = jnp.concatenate((pbc_shifts, -pbc_shifts))
689
690        if self.k_space and "cells" in inputs:
691            if "k_points" not in graph:
692                raise NotImplementedError(
693                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
694                )
695
696        return {**inputs, self.graph_key: graph_out}

update the nblist without recomputing the full nblist

class GraphProcessor(flax.linen.module.Module):
class GraphProcessor(nn.Module):
    """Compute edge geometry for a pre-generated graph

    The graph stored under `graph_key` should contain the following keys:
    - edge_src: source indices of the edges
    - edge_dst: destination indices of the edges
    - pbc_shifts: pbc shifts for the edges (only if `cells` are present in the inputs)

    The module adds `vec`, `distances`, `switch` and `edge_mask` to the graph.
    When `alch_group` is present in the inputs it also stores `switch_raw`
    and, if a soft-core parameter is given, `distances_raw`.

    This module is automatically added to a FENNIX model when a GraphGenerator is used.

    """

    cutoff: float
    """Cutoff distance for the graph."""
    graph_key: str = "graph"
    """Key of the graph in the outputs."""
    switch_params: dict = dataclasses.field(default_factory=dict)
    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""

    @nn.compact
    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
        graph = inputs[self.graph_key]
        positions = inputs["coordinates"]
        edge_src = graph["edge_src"]
        edge_dst = graph["edge_dst"]

        # Out-of-range indices read the fill values (cutoff for the
        # destination, 0 for the source) instead of real coordinates.
        dst_positions = positions.at[edge_dst].get(
            mode="fill", fill_value=self.cutoff
        )
        src_positions = positions.at[edge_src].get(mode="fill", fill_value=0.0)
        vec = dst_positions - src_positions

        if "cells" in inputs:
            cells = inputs["cells"]
            shifts = graph["pbc_shifts"]
            if cells.shape[0] == 1:
                vec = vec + jnp.dot(shifts, cells[0])
            else:
                # several cells: use the cell of the system owning the source atom
                edge_batch = inputs["batch_index"][edge_src]
                vec = vec + jax.vmap(jnp.dot)(shifts, cells[edge_batch])

        d2 = jnp.sum(vec**2, axis=-1)
        distances = safe_sqrt(d2)
        edge_mask = distances < self.cutoff

        switch_kwargs = {
            **self.switch_params,
            "cutoff": self.cutoff,
            "graph_key": None,
        }
        switch = SwitchFunction(**switch_kwargs)((distances, edge_mask))

        graph_out = {
            **graph,
            "vec": vec,
            "distances": distances,
            "switch": switch,
            "edge_mask": edge_mask,
        }

        if "alch_group" not in inputs:
            return {**inputs, self.graph_key: graph_out}

        # Alchemical mode: edges joining two different groups are scaled.
        alch_group = inputs["alch_group"]
        lambda_e = inputs["alch_elambda"]
        same_group = alch_group[edge_src] == alch_group[edge_dst]
        graph_out["switch_raw"] = switch
        graph_out["switch"] = jnp.where(
            same_group,
            switch,
            0.5 * (1.0 - jnp.cos(jnp.pi * lambda_e)) * switch,
        )

        has_softcore_e = "alch_softcore_e" in inputs
        if has_softcore_e or "alch_softcore_v" in inputs:
            graph_out["distances_raw"] = distances
            if has_softcore_e:
                alch_alpha = (1 - inputs["alch_elambda"]) * inputs["alch_softcore_e"] ** 2
            else:
                alch_alpha = (1 - inputs["alch_vlambda"]) * inputs["alch_softcore_v"] ** 2
            softcore = safe_sqrt(
                alch_alpha + d2 * (1.0 - alch_alpha / self.cutoff**2)
            )
            distances = jnp.where(same_group, distances, softcore)
            graph_out["distances"] = distances

        return {**inputs, self.graph_key: graph_out}

Process a pre-generated graph

The pre-generated graph should contain the following keys:

  • edge_src: source indices of the edges
  • edge_dst: destination indices of the edges
  • pbc_shifts: pbc shifts for the edges (only if cells are present in the inputs)

This module is automatically added to a FENNIX model when a GraphGenerator is used.

GraphProcessor( cutoff: float, graph_key: str = 'graph', switch_params: dict = <factory>, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
cutoff: float

Cutoff distance for the graph.

graph_key: str = 'graph'

Key of the graph in the outputs.

switch_params: dict

Parameters for the switching function. See fennol.models.misc.misc.SwitchFunction.

parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
@dataclasses.dataclass(frozen=True)
class GraphFilter:
@dataclasses.dataclass(frozen=True)
class GraphFilter:
    """Filter a graph based on a cutoff distance

    The edges of the graph stored under `parent_graph` are filtered with a
    cutoff on their squared distance `d12`, and the result is stored under
    `graph_key` together with `filter_indices`, the positions of the kept
    edges in the parent edge list.

    FPID: GRAPH_FILTER
    """

    cutoff: float
    """Cutoff distance for the filtering."""
    parent_graph: str
    """Key of the parent graph in the inputs."""
    graph_key: str
    """Key of the filtered graph in the outputs."""
    remove_hydrogens: bool = False
    """Remove edges where the source is a hydrogen atom."""
    switch_params: FrozenDict = dataclasses.field(default_factory=FrozenDict)
    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""
    k_space: bool = False
    """Generate k-space information for the graph."""
    kmax: int = 30
    """Maximum number of k-points to consider."""
    kthr: float = 1e-6
    """Threshold for k-point filtering."""
    mult_size: float = 1.05
    """Multiplicative factor for resizing the nblist."""

    FPID: ClassVar[str] = "GRAPH_FILTER"

    def init(self):
        """Return the initial state: a capacity of one pair and the resize factor."""
        return FrozenDict(
            {
                "npairs": 1,
                "nblist_mult_size": self.mult_size,
            }
        )

    def get_processor(self) -> Tuple[nn.Module, Dict]:
        """Return the processor class for this filter and its keyword arguments."""
        return GraphFilterProcessor, {
            "cutoff": self.cutoff,
            "graph_key": self.graph_key,
            "parent_graph": self.parent_graph,
            "name": f"{self.graph_key}_Filter_{self.parent_graph}",
            "switch_params": self.switch_params,
        }

    def get_graph_properties(self):
        """Return the static properties of the filtered graph, keyed by `graph_key`."""
        return {
            self.graph_key: {
                "cutoff": self.cutoff,
                "directed": True,
                "parent_graph": self.parent_graph,
            }
        }

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """filter a nblist on cpu with numpy and dynamic shapes + store max shapes

        Args:
            state: mapping with the current capacity (`npairs`) and the resize
                factor (`nblist_mult_size`).
            inputs: dictionary containing `species` and the parent graph
                (with `edge_src`, `edge_dst` and `d12`).
            return_state_update: if True, also return the dictionary of
                `(new, previous)` values for every state entry that changed.
            add_margin: force a capacity increase even if the pairs fit.

        Returns:
            `(new_state, output)` or `(new_state, output, state_up)` where
            `output` is a copy of `inputs` with the filtered graph stored
            under `graph_key`.
        """
        graph_in = inputs[self.parent_graph]
        nat = inputs["species"].shape[0]

        new_state = {**state}
        state_up = {}
        mult_size = state.get("nblist_mult_size", self.mult_size)
        assert mult_size >= 1., "nblist_mult_size should be >= 1."

        edge_src = np.array(graph_in["edge_src"], dtype=np.int32)
        d12 = np.array(graph_in["d12"], dtype=np.float32)
        if self.remove_hydrogens:
            species = inputs["species"]
            # only look up species for edges whose source is a valid atom index
            src_idx = (edge_src < nat).nonzero()[0]
            mask = np.zeros(edge_src.shape[0], dtype=bool)
            mask[src_idx] = (species > 1)[edge_src[src_idx]]
            # rejected edges are pushed onto the cutoff so the test below drops them
            d12 = np.where(mask, d12, self.cutoff**2)
        mask = d12 < self.cutoff**2

        max_pairs = state.get("npairs", 1)
        idx = np.nonzero(mask)[0]
        npairs = idx.shape[0]
        if npairs > max_pairs or add_margin:
            prev_max_pairs = max_pairs
            max_pairs = int(mult_size * max(npairs, max_pairs)) + 1
            state_up["npairs"] = (max_pairs, prev_max_pairs)
            new_state["npairs"] = max_pairs

        # Fixed-size buffers: unused slots hold out-of-range indices
        # (`nat` for atoms, the parent edge count for `filter_indices`).
        filter_indices = np.full(max_pairs, edge_src.shape[0], dtype=np.int32)
        edge_src = np.full(max_pairs, nat, dtype=np.int32)
        edge_dst = np.full(max_pairs, nat, dtype=np.int32)
        d12_ = np.full(max_pairs, self.cutoff**2)
        filter_indices[:npairs] = idx
        edge_src[:npairs] = graph_in["edge_src"][idx]
        edge_dst[:npairs] = graph_in["edge_dst"][idx]
        d12_[:npairs] = d12[idx]
        d12 = d12_

        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
        graph_out = {
            **graph,
            "edge_src": edge_src,
            "edge_dst": edge_dst,
            "filter_indices": filter_indices,
            "d12": d12,
            "overflow": False,
        }

        # Generate the reciprocal-space data only when it is missing; existing
        # `k_points` are already carried over by `**graph` above. (The previous
        # code read `ks`/`bewald` outside the inner test and failed with an
        # unbound local when `k_points` was already present.)
        if self.k_space and "cells" in inputs and "k_points" not in graph:
            ks, _, _, bewald = get_reciprocal_space_parameters(
                inputs["reciprocal_cells"], self.cutoff, self.kmax, self.kthr
            )
            graph_out["k_points"] = ks
            graph_out["b_ewald"] = bewald

        output = {**inputs, self.graph_key: graph_out}
        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary

        Returns `(state, state_update, inputs, reallocated)`. The filter is
        rebuilt when either the parent graph or this graph overflowed; extra
        margin is only requested when this graph itself overflowed.
        """
        overflow = parent_overflow or inputs[self.graph_key].get("overflow", False)
        if not overflow:
            return state, {}, inputs, False

        add_margin = inputs[self.graph_key].get("overflow", False)
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """filter a nblist on accelerator with jax and precomputed shapes

        `state` is a static argument; when it is None (skin update mode) the
        capacity is taken from the shape of the existing filtered `edge_src`.

        Raises:
            NotImplementedError: if `k_space` is enabled with `cells` but the
                graph does not already contain `k_points`.
        """
        graph_in = inputs[self.parent_graph]
        if state is None:
            # skin update mode
            graph = inputs[self.graph_key]
            max_pairs = graph["edge_src"].shape[0]
        else:
            max_pairs = state.get("npairs", 1)

        max_pairs_in = graph_in["edge_src"].shape[0]
        nat = inputs["species"].shape[0]

        edge_src = graph_in["edge_src"]
        d12 = graph_in["d12"]
        if self.remove_hydrogens:
            species = inputs["species"]
            mask = (species > 1)[edge_src]
            d12 = jnp.where(mask, d12, self.cutoff**2)
        mask = d12 < self.cutoff**2

        (edge_src, edge_dst, d12, filter_indices), _, npairs = mask_filter_1d(
            mask,
            max_pairs,
            (edge_src, nat),
            (graph_in["edge_dst"], nat),
            (d12, self.cutoff**2),
            (jnp.arange(max_pairs_in, dtype=jnp.int32), max_pairs_in),
        )

        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
        overflow = graph.get("overflow", False) | (npairs > max_pairs)
        graph_out = {
            **graph,
            "edge_src": edge_src,
            "edge_dst": edge_dst,
            "filter_indices": filter_indices,
            "d12": d12,
            "overflow": overflow,
        }

        if self.k_space and "cells" in inputs:
            if "k_points" not in graph:
                raise NotImplementedError(
                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
                )

        return {**inputs, self.graph_key: graph_out}

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        """Re-filter the parent graph in skin update mode (no state)."""
        return self.process(None, inputs)

Filter a graph based on a cutoff distance

FPID: GRAPH_FILTER

GraphFilter( cutoff: float, parent_graph: str, graph_key: str, remove_hydrogens: int = False, switch_params: flax.core.frozen_dict.FrozenDict = <factory>, k_space: bool = False, kmax: int = 30, kthr: float = 1e-06, mult_size: float = 1.05)
cutoff: float

Cutoff distance for the filtering.

parent_graph: str

Key of the parent graph in the inputs.

graph_key: str

Key of the filtered graph in the outputs.

remove_hydrogens: int = False

Remove edges where the source is a hydrogen atom.

switch_params: flax.core.frozen_dict.FrozenDict

Parameters for the switching function. See fennol.models.misc.misc.SwitchFunction.

k_space: bool = False

Generate k-space information for the graph.

kmax: int = 30

Maximum number of k-points to consider.

kthr: float = 1e-06

Threshold for k-point filtering.

mult_size: float = 1.05

Multiplicative factor for resizing the nblist.

FPID: ClassVar[str] = 'GRAPH_FILTER'
def init(self):
808    def init(self):
809        return FrozenDict(
810            {
811                "npairs": 1,
812                "nblist_mult_size": self.mult_size,
813            }
814        )
def get_processor(self) -> Tuple[flax.linen.module.Module, Dict]:
816    def get_processor(self) -> Tuple[nn.Module, Dict]:
817        return GraphFilterProcessor, {
818            "cutoff": self.cutoff,
819            "graph_key": self.graph_key,
820            "parent_graph": self.parent_graph,
821            "name": f"{self.graph_key}_Filter_{self.parent_graph}",
822            "switch_params": self.switch_params,
823        }
def get_graph_properties(self):
825    def get_graph_properties(self):
826        return {
827            self.graph_key: {
828                "cutoff": self.cutoff,
829                "directed": True,
830                "parent_graph": self.parent_graph,
831            }
832        }
def check_reallocate(self, state, inputs, parent_overflow=False):
896    def check_reallocate(self, state, inputs, parent_overflow=False):
897        """check for overflow and reallocate nblist if necessary"""
898        overflow = parent_overflow or inputs[self.graph_key].get("overflow", False)
899        if not overflow:
900            return state, {}, inputs, False
901
902        add_margin = inputs[self.graph_key].get("overflow", False)
903        state, inputs, state_up = self(
904            state, inputs, return_state_update=True, add_margin=add_margin
905        )
906        return state, state_up, inputs, True

check for overflow and reallocate nblist if necessary

@partial(jax.jit, static_argnums=(0, 1))
def process(self, state, inputs):
908    @partial(jax.jit, static_argnums=(0, 1))
909    def process(self, state, inputs):
910        """filter a nblist on accelerator with jax and precomputed shapes"""
911        graph_in = inputs[self.parent_graph]
912        if state is None:
913            # skin update mode
914            graph = inputs[self.graph_key]
915            max_pairs = graph["edge_src"].shape[0]
916        else:
917            max_pairs = state.get("npairs", 1)
918
919        max_pairs_in = graph_in["edge_src"].shape[0]
920        nat = inputs["species"].shape[0]
921
922        edge_src = graph_in["edge_src"]
923        d12 = graph_in["d12"]
924        if self.remove_hydrogens:
925            species = inputs["species"]
926            mask = (species > 1)[edge_src]
927            d12 = jnp.where(mask, d12, self.cutoff**2)
928        mask = d12 < self.cutoff**2
929
930        (edge_src, edge_dst, d12, filter_indices), _, npairs = mask_filter_1d(
931            mask,
932            max_pairs,
933            (edge_src, nat),
934            (graph_in["edge_dst"], nat),
935            (d12, self.cutoff**2),
936            (jnp.arange(max_pairs_in, dtype=jnp.int32), max_pairs_in),
937        )
938
939        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
940        overflow = graph.get("overflow", False) | (npairs > max_pairs)
941        graph_out = {
942            **graph,
943            "edge_src": edge_src,
944            "edge_dst": edge_dst,
945            "filter_indices": filter_indices,
946            "d12": d12,
947            "overflow": overflow,
948        }
949
950        if self.k_space and "cells" in inputs:
951            if "k_points" not in graph:
952                raise NotImplementedError(
953                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
954                )
955
956        return {**inputs, self.graph_key: graph_out}

filter a nblist on accelerator with jax and precomputed shapes

@partial(jax.jit, static_argnums=(0,))
def update_skin(self, inputs):
958    @partial(jax.jit, static_argnums=(0,))
959    def update_skin(self, inputs):
960        return self.process(None, inputs)
class GraphFilterProcessor(flax.linen.module.Module):
 963class GraphFilterProcessor(nn.Module):
 964    """Filter processing for a pre-generated graph
 965
 966    This module is automatically added to a FENNIX model when a GraphFilter is used.
 967    """
 968
 969    cutoff: float
 970    """Cutoff distance for the filtering."""
 971    graph_key: str
 972    """Key of the filtered graph in the inputs."""
 973    parent_graph: str
 974    """Key of the parent graph in the inputs."""
 975    switch_params: dict = dataclasses.field(default_factory=dict)
 976    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""
 977
 978    @nn.compact
 979    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
 980        graph_in = inputs[self.parent_graph]
 981        graph = inputs[self.graph_key]
 982
 983        d_key = "distances_raw" if "distances_raw" in graph else "distances"
 984
 985        if graph_in["vec"].shape[0] == 0:
 986            vec = graph_in["vec"]
 987            distances = graph_in[d_key]
 988            filter_indices = jnp.asarray([], dtype=jnp.int32)
 989        else:
 990            filter_indices = graph["filter_indices"]
 991            vec = (
 992                graph_in["vec"]
 993                .at[filter_indices]
 994                .get(mode="fill", fill_value=self.cutoff)
 995            )
 996            distances = (
 997                graph_in[d_key]
 998                .at[filter_indices]
 999                .get(mode="fill", fill_value=self.cutoff)
1000            )
1001
1002        edge_mask = distances < self.cutoff
1003        switch = SwitchFunction(
1004            **{**self.switch_params, "cutoff": self.cutoff, "graph_key": None}
1005        )((distances, edge_mask))
1006
1007        graph_out = {
1008            **graph,
1009            "vec": vec,
1010            "distances": distances,
1011            "switch": switch,
1012            "filter_indices": filter_indices,
1013            "edge_mask": edge_mask,
1014        }
1015
1016        if "alch_group" in inputs:
1017            edge_src=graph["edge_src"]
1018            edge_dst=graph["edge_dst"]
1019            alch_group = inputs["alch_group"]
1020            lambda_e = inputs["alch_elambda"]
1021            mask = alch_group[edge_src] == alch_group[edge_dst]
1022            graph_out["switch_raw"] = switch
1023            graph_out["switch"] = jnp.where(
1024                mask,
1025                switch,
1026                0.5*(1.-jnp.cos(jnp.pi*lambda_e)) * switch ,
1027            )
1028
1029            if "alch_softcore_e" in inputs or "alch_softcore_v" in inputs:
1030                graph_out["distances_raw"] = distances
1031                if "alch_softcore_e" in inputs:
1032                    alch_alpha = (1-inputs["alch_elambda"])*inputs["alch_softcore_e"]**2
1033                else:
1034                    alch_alpha = (1-inputs["alch_vlambda"])*inputs["alch_softcore_v"]**2
1035                distances = jnp.where(
1036                    mask,
1037                    distances,
1038                    safe_sqrt(alch_alpha + distances**2 * (1. - alch_alpha/self.cutoff**2))
1039                )  
1040                graph_out["distances"] = distances
1041
1042        return {**inputs, self.graph_key: graph_out}

Filter processing for a pre-generated graph

This module is automatically added to a FENNIX model when a GraphFilter is used.

GraphFilterProcessor( cutoff: float, graph_key: str, parent_graph: str, switch_params: dict = <factory>, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
cutoff: float

Cutoff distance for the filtering.

graph_key: str

Key of the filtered graph in the inputs.

parent_graph: str

Key of the parent graph in the inputs.

switch_params: dict

Parameters for the switching function. See fennol.models.misc.misc.SwitchFunction.

parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
@dataclasses.dataclass(frozen=True)
class GraphAngularExtension:
@dataclasses.dataclass(frozen=True)
class GraphAngularExtension:
    """Add angles list to a graph

    An angle is a pair of edges of the graph that share the same source atom.
    The lists `angle_src` and `angle_dst` hold edge indices and
    `central_atom` holds the shared source atom of each angle.

    FPID: GRAPH_ANGULAR_EXTENSION
    """

    mult_size: float = 1.05
    """Multiplicative factor for resizing the nblist."""
    add_neigh: int = 5
    """Additional neighbors to add to the nblist when resizing."""
    graph_key: str = "graph"
    """Key of the graph in the inputs."""

    FPID: ClassVar[str] = "GRAPH_ANGULAR_EXTENSION"

    def init(self):
        """Return the initial state (no angle capacity, `add_neigh` neighbors)."""
        return FrozenDict(
            {
                "nangles": 0,
                "nblist_mult_size": self.mult_size,
                "max_neigh": self.add_neigh,
                "add_neigh": self.add_neigh,
            }
        )

    def get_processor(self) -> Tuple[nn.Module, Dict]:
        """Return the angle processor class and its keyword arguments."""
        return GraphAngleProcessor, {
            "graph_key": self.graph_key,
            "name": f"{self.graph_key}_AngleProcessor",
        }

    def get_graph_properties(self):
        """Return the static properties this extension adds to the graph."""
        return {
            self.graph_key: {
                "has_angles": True,
            }
        }

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """build angle nblist on cpu with numpy and dynamic shapes + store max shapes

        Args:
            state: mapping with the current capacities (`nangles`,
                `max_neigh`), the neighbor increment (`add_neigh`) and the
                resize factor (`nblist_mult_size`).
            inputs: dictionary containing `species` and the graph stored under
                `graph_key` (with `edge_src`).
            return_state_update: if True, also return the dictionary of
                `(new, previous)` values for every state entry that changed.
            add_margin: force a capacity increase even if everything fits.

        Returns:
            `(new_state, output)` or `(new_state, output, state_up)`.
        """
        graph = inputs[self.graph_key]
        edge_src = np.array(graph["edge_src"], dtype=np.int32)

        new_state = {**state}
        state_up = {}
        mult_size = state.get("nblist_mult_size", self.mult_size)
        assert mult_size >= 1., "nblist_mult_size should be >= 1."

        ### count number of neighbors
        nat = inputs["species"].shape[0]
        # the extra last slot collects edges whose source index equals `nat`
        count = np.zeros(nat + 1, dtype=np.int32)
        np.add.at(count, edge_src, 1)
        max_count = int(np.max(count[:-1]))

        ### get sizes
        max_neigh = state.get("max_neigh", self.add_neigh)
        nedge = edge_src.shape[0]
        if max_count > max_neigh or add_margin:
            prev_max_neigh = max_neigh
            max_neigh = max(max_count, max_neigh) + state.get(
                "add_neigh", self.add_neigh
            )
            state_up["max_neigh"] = (max_neigh, prev_max_neigh)
            new_state["max_neigh"] = max_neigh

        # Only the shape of this array is used: it carries `max_neigh` as a
        # static size to `process` in skin update mode.
        max_neigh_arr = np.empty(max_neigh, dtype=bool)

        ### sort edge_src
        idx_sort = np.argsort(edge_src)
        edge_src_sorted = edge_src[idx_sort]

        ### map sparse to dense nblist
        # Edges of one atom are contiguous after sorting, and an atom has at
        # most `max_count` of them, so consecutive positions taken modulo
        # `max_count` give distinct slots within that atom's row.
        tiled = np.tile(np.arange(max_count), nat)
        if max_count * nat >= nedge:
            offset = tiled[:nedge]
        else:
            offset = np.zeros(nedge, dtype=np.int32)
            offset[: max_count * nat] = tiled

        # offset = jnp.where(edge_src_sorted < nat, offset, 0)
        mask = edge_src_sorted < nat
        indices = edge_src_sorted * max_count + offset
        indices = indices[mask]
        idx_sort = idx_sort[mask]
        edge_idx = np.full(nat * max_count, nedge, dtype=np.int32)
        edge_idx[indices] = idx_sort
        edge_idx = edge_idx.reshape(nat, max_count)

        ### find all triplet for each atom center
        local_src, local_dst = np.triu_indices(max_count, 1)
        angle_src = edge_idx[:, local_src].flatten()
        angle_dst = edge_idx[:, local_dst].flatten()

        ### mask for valid angles
        mask1 = angle_src < nedge
        mask2 = angle_dst < nedge
        angle_mask = mask1 & mask2

        max_angles = state.get("nangles", 0)
        idx = np.nonzero(angle_mask)[0]
        nangles = idx.shape[0]
        if nangles > max_angles or add_margin:
            max_angles_prev = max_angles
            max_angles = int(mult_size * max(nangles, max_angles)) + 1
            state_up["nangles"] = (max_angles, max_angles_prev)
            new_state["nangles"] = max_angles

        ## filter angles to sparse representation
        angle_src_ = np.full(max_angles, nedge, dtype=np.int32)
        angle_dst_ = np.full(max_angles, nedge, dtype=np.int32)
        angle_src_[:nangles] = angle_src[idx]
        angle_dst_[:nangles] = angle_dst[idx]

        central_atom = np.full(max_angles, nat, dtype=np.int32)
        central_atom[:nangles] = edge_src[angle_src_[:nangles]]

        ## update graph
        output = {
            **inputs,
            self.graph_key: {
                **graph,
                "angle_src": angle_src_,
                "angle_dst": angle_dst_,
                "central_atom": central_atom,
                "angle_overflow": False,
                "max_neigh": max_neigh,
                "__max_neigh_array": max_neigh_arr,
            },
        }

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary

        Returns `(state, state_update, inputs, reallocated)`. Extra margin is
        only requested when the angle list itself overflowed.
        """
        overflow = parent_overflow or inputs[self.graph_key]["angle_overflow"]
        if not overflow:
            return state, {}, inputs, False

        add_margin = inputs[self.graph_key]["angle_overflow"]
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """build angle nblist on accelerator with jax and precomputed shapes

        `state` is a static argument. When it is None (skin update mode) the
        static sizes are recovered from the shapes of `__max_neigh_array` and
        `angle_src` stored in the graph.

        Raises:
            ValueError: if `max_neigh * nat` is smaller than the number of edges.
        """
        graph = inputs[self.graph_key]
        edge_src = graph["edge_src"]

        ### count number of neighbors
        nat = inputs["species"].shape[0]
        count = jnp.zeros(nat, dtype=jnp.int32).at[edge_src].add(1, mode="drop")
        max_count = jnp.max(count)

        ### get sizes
        if state is None:
            max_neigh_arr = graph["__max_neigh_array"]
            max_neigh = max_neigh_arr.shape[0]
            prev_nangles = graph["angle_src"].shape[0]
        else:
            max_neigh = state.get("max_neigh", self.add_neigh)
            max_neigh_arr = jnp.empty(max_neigh, dtype=bool)
            prev_nangles = state.get("nangles", 0)

        nedge = edge_src.shape[0]

        ### sort edge_src
        idx_sort = jnp.argsort(edge_src).astype(jnp.int32)
        edge_src_sorted = edge_src[idx_sort]

        ### map sparse to dense nblist
        if max_neigh * nat < nedge:
            raise ValueError("Found max_neigh*nat < nedge. This should not happen.")
        offset = jnp.asarray(
            np.tile(np.arange(max_neigh), nat)[:nedge], dtype=jnp.int32
        )
        # offset = jnp.where(edge_src_sorted < nat, offset, 0)
        indices = edge_src_sorted * max_neigh + offset
        # "drop" discards edges whose source index is out of range
        edge_idx = (
            jnp.full(nat * max_neigh, nedge, dtype=jnp.int32)
            .at[indices]
            .set(idx_sort, mode="drop")
            .reshape(nat, max_neigh)
        )

        ### find all triplet for each atom center
        local_src, local_dst = np.triu_indices(max_neigh, 1)
        angle_src = edge_idx[:, local_src].flatten()
        angle_dst = edge_idx[:, local_dst].flatten()

        ### mask for valid angles
        mask1 = angle_src < nedge
        mask2 = angle_dst < nedge
        angle_mask = mask1 & mask2

        ## filter angles to sparse representation
        (angle_src, angle_dst), _, nangles = mask_filter_1d(
            angle_mask,
            prev_nangles,
            (angle_src, nedge),
            (angle_dst, nedge),
        )
        ## find central atom
        # Padded angles carry the out-of-range edge index `nedge`; give them
        # the out-of-range atom index `nat`, as the numpy routine does,
        # instead of the atom of a clamped edge.
        central_atom = edge_src.at[angle_src].get(mode="fill", fill_value=nat)

        ## check for overflow
        angle_overflow = nangles > prev_nangles
        neigh_overflow = max_count > max_neigh
        overflow = graph.get("angle_overflow", False) | angle_overflow | neigh_overflow

        ## update graph
        output = {
            **inputs,
            self.graph_key: {
                **graph,
                "angle_src": angle_src,
                "angle_dst": angle_dst,
                "central_atom": central_atom,
                "angle_overflow": overflow,
                # "max_neigh": max_neigh,
                "__max_neigh_array": max_neigh_arr,
            },
        }

        return output

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        """Rebuild the angle list in skin update mode (no state)."""
        return self.process(None, inputs)

Add angles list to a graph

FPID: GRAPH_ANGULAR_EXTENSION

GraphAngularExtension( mult_size: float = 1.05, add_neigh: int = 5, graph_key: str = 'graph')
mult_size: float = 1.05

Multiplicative factor for resizing the nblist.

add_neigh: int = 5

Additional neighbors to add to the nblist when resizing.

graph_key: str = 'graph'

Key of the graph in the inputs.

FPID: ClassVar[str] = 'GRAPH_ANGULAR_EXTENSION'
def init(self):
1061    def init(self):
1062        return FrozenDict(
1063            {
1064                "nangles": 0,
1065                "nblist_mult_size": self.mult_size,
1066                "max_neigh": self.add_neigh,
1067                "add_neigh": self.add_neigh,
1068            }
1069        )
def get_processor(self) -> Tuple[flax.linen.module.Module, Dict]:
1071    def get_processor(self) -> Tuple[nn.Module, Dict]:
1072        return GraphAngleProcessor, {
1073            "graph_key": self.graph_key,
1074            "name": f"{self.graph_key}_AngleProcessor",
1075        }
def get_graph_properties(self):
1077    def get_graph_properties(self):
1078        return {
1079            self.graph_key: {
1080                "has_angles": True,
1081            }
1082        }
def check_reallocate(self, state, inputs, parent_overflow=False):
1182    def check_reallocate(self, state, inputs, parent_overflow=False):
1183        """check for overflow and reallocate nblist if necessary"""
1184        overflow = parent_overflow or inputs[self.graph_key]["angle_overflow"]
1185        if not overflow:
1186            return state, {}, inputs, False
1187
1188        add_margin = inputs[self.graph_key]["angle_overflow"]
1189        state, inputs, state_up = self(
1190            state, inputs, return_state_update=True, add_margin=add_margin
1191        )
1192        return state, state_up, inputs, True

check for overflow and reallocate nblist if necessary

@partial(jax.jit, static_argnums=(0, 1))
def process(self, state, inputs):
1194    @partial(jax.jit, static_argnums=(0, 1))
1195    def process(self, state, inputs):
1196        """build angle nblist on accelerator with jax and precomputed shapes"""
1197        graph = inputs[self.graph_key]
1198        edge_src = graph["edge_src"]
1199
1200        ### count number of neighbors
1201        nat = inputs["species"].shape[0]
1202        count = jnp.zeros(nat, dtype=jnp.int32).at[edge_src].add(1, mode="drop")
1203        max_count = jnp.max(count)
1204
1205        ### get sizes
1206        if state is None:
1207            max_neigh_arr = graph["__max_neigh_array"]
1208            max_neigh = max_neigh_arr.shape[0]
1209            prev_nangles = graph["angle_src"].shape[0]
1210        else:
1211            max_neigh = state.get("max_neigh", self.add_neigh)
1212            max_neigh_arr = jnp.empty(max_neigh, dtype=bool)
1213            prev_nangles = state.get("nangles", 0)
1214
1215        nedge = edge_src.shape[0]
1216
1217        ### sort edge_src
1218        idx_sort = jnp.argsort(edge_src).astype(jnp.int32)
1219        edge_src_sorted = edge_src[idx_sort]
1220
1221        ### map sparse to dense nblist
1222        if max_neigh * nat < nedge:
1223            raise ValueError("Found max_neigh*nat < nedge. This should not happen.")
1224        offset = jnp.asarray(
1225            np.tile(np.arange(max_neigh), nat)[:nedge], dtype=jnp.int32
1226        )
1227        # offset = jnp.where(edge_src_sorted < nat, offset, 0)
1228        indices = edge_src_sorted * max_neigh + offset
1229        edge_idx = (
1230            jnp.full(nat * max_neigh, nedge, dtype=jnp.int32)
1231            .at[indices]
1232            .set(idx_sort, mode="drop")
1233            .reshape(nat, max_neigh)
1234        )
1235
1236        ### find all triplet for each atom center
1237        local_src, local_dst = np.triu_indices(max_neigh, 1)
1238        angle_src = edge_idx[:, local_src].flatten()
1239        angle_dst = edge_idx[:, local_dst].flatten()
1240
1241        ### mask for valid angles
1242        mask1 = angle_src < nedge
1243        mask2 = angle_dst < nedge
1244        angle_mask = mask1 & mask2
1245
1246        ## filter angles to sparse representation
1247        (angle_src, angle_dst), _, nangles = mask_filter_1d(
1248            angle_mask,
1249            prev_nangles,
1250            (angle_src, nedge),
1251            (angle_dst, nedge),
1252        )
1253        ## find central atom
1254        central_atom = edge_src[angle_src]
1255
1256        ## check for overflow
1257        angle_overflow = nangles > prev_nangles
1258        neigh_overflow = max_count > max_neigh
1259        overflow = graph.get("angle_overflow", False) | angle_overflow | neigh_overflow
1260
1261        ## update graph
1262        output = {
1263            **inputs,
1264            self.graph_key: {
1265                **graph,
1266                "angle_src": angle_src,
1267                "angle_dst": angle_dst,
1268                "central_atom": central_atom,
1269                "angle_overflow": overflow,
1270                # "max_neigh": max_neigh,
1271                "__max_neigh_array": max_neigh_arr,
1272            },
1273        }
1274
1275        return output

build angle nblist on accelerator with jax and precomputed shapes

@partial(jax.jit, static_argnums=(0,))
def update_skin(self, inputs):
1277    @partial(jax.jit, static_argnums=(0,))
1278    def update_skin(self, inputs):
1279        return self.process(None, inputs)
class GraphAngleProcessor(flax.linen.module.Module):
class GraphAngleProcessor(nn.Module):
    """Compute the angles listed in a pre-generated graph.

    This module is automatically added to a FENNIX model when a GraphAngularExtension is used.

    The graph stored under `graph_key` must provide `vec`, `angle_src`,
    `angle_dst` and either `distances_raw` or `distances`. The result is
    stored in the graph under the key `angles`.
    """

    graph_key: str
    """Key of the graph in the inputs."""

    @nn.compact
    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
        graph = inputs[self.graph_key]
        # raw distances take precedence when the graph provides them
        if "distances_raw" in graph:
            distances = graph["distances_raw"]
        else:
            distances = graph["distances"]
        vec = graph["vec"]
        angle_src = graph["angle_src"]
        angle_dst = graph["angle_dst"]

        # normalized edge vectors; the clip guards against division by zero
        unit_vec = vec / jnp.clip(distances[:, None], min=1.0e-5)

        # out-of-range (padding) edge indices read a constant 0.5 vector
        u_src = unit_vec.at[angle_src].get(mode="fill", fill_value=0.5)
        u_dst = unit_vec.at[angle_dst].get(mode="fill", fill_value=0.5)
        cos_angles = (u_src * u_dst).sum(axis=-1)

        # the cosine is scaled by 0.95 before arccos
        # NOTE(review): presumably to stay away from the +/-1 end points
        # where arccos has an unbounded derivative -- confirm
        angles = jnp.arccos(0.95 * cos_angles)

        updated_graph = {**graph, "angles": angles}
        return {**inputs, self.graph_key: updated_graph}

Process a pre-generated graph to compute angles

This module is automatically added to a FENNIX model when a GraphAngularExtension is used.

GraphAngleProcessor( graph_key: str, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
graph_key: str

Key of the graph in the inputs.

parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
@dataclasses.dataclass(frozen=True)
class SpeciesIndexer:
@dataclasses.dataclass(frozen=True)
class SpeciesIndexer:
    """Build an index that splits atomic arrays by species.

    FPID: SPECIES_INDEXER

    If `species_order` is specified, the output will be a dense array with size (len(species_order), max_size) that can directly index atomic arrays.
    If `species_order` is None, the output will be a dictionary with species as keys and an index to filter atomic arrays for that species as values.

    Unused index slots are filled with the number of atoms, i.e. an
    out-of-range index. Species values <= 0 are never indexed.
    """

    output_key: str = "species_index"
    """Key for the output dictionary."""
    species_order: Optional[str] = None
    """Comma separated list of species in the order they should be indexed."""
    add_atoms: int = 0
    """Additional atoms to add to the sizes."""
    add_atoms_margin: int = 10
    """Additional atoms to add to the sizes when adding margin."""

    FPID: ClassVar[str] = "SPECIES_INDEXER"

    def init(self):
        """Return the initial state: no size allocated for any species."""
        return FrozenDict(
            {
                "sizes": {},
            }
        )

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """Build the species index with numpy and (re)allocate sizes.

        Args:
            state: current preprocessing state (mapping).
            inputs: dictionary containing at least `species`.
            return_state_update: if True, also return the dictionary of
                `(new, old)` values for every state entry that changed.
            add_margin: if True, `add_atoms_margin` extra slots are added
                on top of `add_atoms` for every species that is resized.

        Returns:
            `(new_state, output)` or `(new_state, output, state_update)`.
            `output` is `inputs` plus the index under `output_key` and an
            overflow flag (always False here) under `output_key + "_overflow"`.
        """
        species = np.array(inputs["species"], dtype=np.int32)
        nat = species.shape[0]
        set_species, counts = np.unique(species, return_counts=True)

        new_state = {**state}
        state_up = {}

        sizes = state.get("sizes", FrozenDict({}))
        new_sizes = {**sizes}

        # extra room given to a species when its size is (re)allocated;
        # it does not depend on the species, so compute it once
        add_atoms = state.get("add_atoms", self.add_atoms)
        if add_margin:
            add_atoms += state.get("add_atoms_margin", self.add_atoms_margin)

        up_sizes = False
        counts_dict = {}
        for s, c in zip(set_species, counts):
            if s <= 0:
                continue
            counts_dict[s] = c
            if c > sizes.get(s, 0):
                up_sizes = True
                new_sizes[s] = c + add_atoms

        new_sizes = FrozenDict(new_sizes)
        if up_sizes:
            state_up["sizes"] = (new_sizes, sizes)
            new_state["sizes"] = new_sizes

        if self.species_order is not None:
            species_order = [el.strip() for el in self.species_order.split(",")]
            max_size_prev = state.get("max_size", 0)
            max_size = max(new_sizes.values())
            if max_size > max_size_prev:
                state_up["max_size"] = (max_size, max_size_prev)
                new_state["max_size"] = max_size

            # one row per ordered species, padded with nat
            species_index = np.full((len(species_order), max_size), nat, dtype=np.int32)
            for i, el in enumerate(species_order):
                s = PERIODIC_TABLE_REV_IDX[el]
                if s in counts_dict:
                    species_index[i, : counts_dict[s]] = np.nonzero(species == s)[0]
        else:
            # one padded index array per allocated species, keyed by symbol
            species_index = {
                PERIODIC_TABLE[s]: np.full(c, nat, dtype=np.int32)
                for s, c in new_sizes.items()
            }
            for s, c in counts_dict.items():
                species_index[PERIODIC_TABLE[s]][:c] = np.nonzero(species == s)[0]

        output = {
            **inputs,
            self.output_key: species_index,
            self.output_key + "_overflow": False,
        }

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """Check for overflow and rebuild the index if necessary.

        Returns `(state, state_update, inputs, reallocated)`. Extra margin
        is requested only when the overflow was flagged by this indexer.
        """
        overflow = parent_overflow or inputs[self.output_key + "_overflow"]
        if not overflow:
            return state, {}, inputs, False

        add_margin = inputs[self.output_key + "_overflow"]
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """Build the species index on accelerator with static sizes.

        If the index is already present in `inputs` and the
        `recompute_species_index` flag is not set, `inputs` is returned
        unchanged. Otherwise the index is rebuilt from `state["sizes"]`
        (and `state["max_size"]` when `species_order` is set).

        The overflow flag stored under `output_key + "_overflow"` is set
        when the allocated sizes cannot hold the current species.

        Raises:
            ValueError: if the index must be computed and `state` is None.
        """
        recompute_species_index = "recompute_species_index" in inputs.get("flags", {})
        if self.output_key in inputs and not recompute_species_index:
            return inputs

        if state is None:
            raise ValueError("Species Indexer state must be provided on accelerator.")

        species = inputs["species"]
        nat = species.shape[0]

        sizes = state["sizes"]
        # defined for both branches: the dense branch previously left it
        # unset, which made the return statement fail
        overflow = False

        if self.species_order is not None:
            species_order = [el.strip() for el in self.species_order.split(",")]
            max_size = state["max_size"]

            species_index = jnp.full(
                (len(species_order), max_size), nat, dtype=jnp.int32
            )
            for i, el in enumerate(species_order):
                s = PERIODIC_TABLE_REV_IDX[el]
                mask = species == s
                new_size = jnp.sum(mask)
                if s in sizes:
                    # a row holds at most max_size atoms
                    overflow = overflow | (new_size > max_size)
                    species_index = species_index.at[i, :].set(
                        jnp.nonzero(mask, size=max_size, fill_value=nat)[0]
                    )
                else:
                    # species unknown to the state: its row stays empty,
                    # so any atom of that species requires a reallocation
                    overflow = overflow | (new_size > 0)
        else:
            species_index = {}
            natcount = 0
            for s, c in sizes.items():
                mask = species == s
                new_size = jnp.sum(mask)
                natcount = natcount + new_size
                overflow = overflow | (new_size > c)  # check if sizes are correct
                species_index[PERIODIC_TABLE[s]] = jnp.nonzero(
                    mask, size=c, fill_value=nat
                )[0]

            # atoms with species <= 0 are not indexed but still accounted for
            natcount = natcount + jnp.sum(species <= 0)
            overflow = overflow | (
                natcount < species.shape[0]
            )  # check if any species missing

        return {
            **inputs,
            self.output_key: species_index,
            self.output_key + "_overflow": overflow,
        }

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        """Call `process` without a state (index must already be in `inputs`)."""
        return self.process(None, inputs)

Build an index that splits atomic arrays by species.

FPID: SPECIES_INDEXER

If species_order is specified, the output will be a dense array with size (len(species_order), max_size) that can directly index atomic arrays. If species_order is None, the output will be a dictionary with species as keys and an index to filter atomic arrays for that species as values.

SpeciesIndexer( output_key: str = 'species_index', species_order: Optional[str] = None, add_atoms: int = 0, add_atoms_margin: int = 10)
output_key: str = 'species_index'

Key for the output dictionary.

species_order: Optional[str] = None

Comma separated list of species in the order they should be indexed.

add_atoms: int = 0

Additional atoms to add to the sizes.

add_atoms_margin: int = 10

Additional atoms to add to the sizes when adding margin.

FPID: ClassVar[str] = 'SPECIES_INDEXER'
def init(self):
1341    def init(self):
1342        return FrozenDict(
1343            {
1344                "sizes": {},
1345            }
1346        )
def check_reallocate(self, state, inputs, parent_overflow=False):
1410    def check_reallocate(self, state, inputs, parent_overflow=False):
1411        """check for overflow and reallocate nblist if necessary"""
1412        overflow = parent_overflow or inputs[self.output_key + "_overflow"]
1413        if not overflow:
1414            return state, {}, inputs, False
1415
1416        add_margin = inputs[self.output_key + "_overflow"]
1417        state, inputs, state_up = self(
1418            state, inputs, return_state_update=True, add_margin=add_margin
1419        )
1420        return state, state_up, inputs, True
1421        # return state, {}, inputs, parent_overflow

Check for overflow and reallocate the index arrays if necessary.

@partial(jax.jit, static_argnums=(0, 1))
def process(self, state, inputs):
1423    @partial(jax.jit, static_argnums=(0, 1))
1424    def process(self, state, inputs):
1425        # assert (
1426        #     self.output_key in inputs
1427        # ), f"Species Index {self.output_key} must be provided on accelerator. Call the numpy routine (self.__call__) first."
1428
1429        recompute_species_index = "recompute_species_index" in inputs.get("flags", {})
1430        if self.output_key in inputs and not recompute_species_index:
1431            return inputs
1432
1433        if state is None:
1434            raise ValueError("Species Indexer state must be provided on accelerator.")
1435
1436        species = inputs["species"]
1437        nat = species.shape[0]
1438
1439        sizes = state["sizes"]
1440
1441        if self.species_order is not None:
1442            species_order = [el.strip() for el in self.species_order.split(",")]
1443            max_size = state["max_size"]
1444
1445            species_index = jnp.full(
1446                (len(species_order), max_size), nat, dtype=jnp.int32
1447            )
1448            for i, el in enumerate(species_order):
1449                s = PERIODIC_TABLE_REV_IDX[el]
1450                if s in sizes.keys():
1451                    c = sizes[s]
1452                    species_index = species_index.at[i, :].set(
1453                        jnp.nonzero(species == s, size=max_size, fill_value=nat)[0]
1454                    )
1455                # if s in counts_dict.keys():
1456                #     species_index[i, : counts_dict[s]] = np.nonzero(species == s)[0]
1457        else:
1458            # species_index = {
1459            # PERIODIC_TABLE[s]: jnp.nonzero(species == s, size=c, fill_value=nat)[0]
1460            # for s, c in sizes.items()
1461            # }
1462            species_index = {}
1463            overflow = False
1464            natcount = 0
1465            for s, c in sizes.items():
1466                mask = species == s
1467                new_size = jnp.sum(mask)
1468                natcount = natcount + new_size
1469                overflow = overflow | (new_size > c)  # check if sizes are correct
1470                species_index[PERIODIC_TABLE[s]] = jnp.nonzero(
1471                    species == s, size=c, fill_value=nat
1472                )[0]
1473
1474            mask = species <= 0
1475            new_size = jnp.sum(mask)
1476            natcount = natcount + new_size
1477            overflow = overflow | (
1478                natcount < species.shape[0]
1479            )  # check if any species missing
1480
1481        return {
1482            **inputs,
1483            self.output_key: species_index,
1484            self.output_key + "_overflow": overflow,
1485        }
@partial(jax.jit, static_argnums=(0,))
def update_skin(self, inputs):
1487    @partial(jax.jit, static_argnums=(0,))
1488    def update_skin(self, inputs):
1489        return self.process(None, inputs)
@dataclasses.dataclass(frozen=True)
class BlockIndexer:
@dataclasses.dataclass(frozen=True)
class BlockIndexer:
    """Build an index that splits atomic arrays by chemical blocks.

    FPID: BLOCK_INDEXER

    The block of each atom is looked up from its species in the table
    returned by `build_chemical_blocks`. The output is a dictionary keyed by
    block name; each value is an index array selecting the atoms of that
    block (padded with the number of atoms), or None for a block that has
    no allocated size.
    """

    output_key: str = "block_index"
    """Key for the output dictionary."""
    add_atoms: int = 0
    """Additional atoms to add to the sizes."""
    add_atoms_margin: int = 10
    """Additional atoms to add to the sizes when adding margin."""
    split_CNOPSSe: bool = False

    FPID: ClassVar[str] = "BLOCK_INDEXER"

    def init(self):
        """Return the initial state: no size allocated for any block."""
        initial_state = {"sizes": {}}
        return FrozenDict(initial_state)

    def build_chemical_blocks(self):
        """Return `(block_names, species_to_block)` for this indexer.

        Both are copies of the module-level tables. When `split_CNOPSSe`
        is set, block 1 is renamed "C" (species 6 maps to it) and species
        7, 8, 15, 16 and 34 are moved to new blocks "N", "O", "P", "S"
        and "Se" appended after the standard ones.
        """
        names = CHEMICAL_BLOCKS_NAMES.copy()
        table = CHEMICAL_BLOCKS.copy()
        if self.split_CNOPSSe:
            names[1] = "C"
            names.extend(["N", "O", "P", "S", "Se"])
            first_new_block = len(CHEMICAL_BLOCKS_NAMES)
            table[6] = 1
            for shift, atomic_number in enumerate((7, 8, 15, 16, 34)):
                table[atomic_number] = first_new_block + shift
        return names, table

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """Build the block index with numpy and (re)allocate sizes.

        Returns `(new_state, output)`, followed by the dictionary of
        `(new, old)` state values when `return_state_update` is True.
        `output` is `inputs` plus the index under `output_key` and an
        overflow flag (always False here) under `output_key + "_overflow"`.
        """
        block_names, species_to_block = self.build_chemical_blocks()

        species = np.array(inputs["species"], dtype=np.int32)
        blocks = species_to_block[species]
        nat = species.shape[0]
        set_blocks, counts = np.unique(blocks, return_counts=True)
        # negative block ids are never indexed
        present = [(b, n) for b, n in zip(set_blocks, counts) if not b < 0]

        new_state = {**state}
        state_up = {}

        sizes = state.get("sizes", FrozenDict({}))
        new_sizes = {**sizes}
        up_sizes = False
        for b, n in present:
            key = (b, block_names[b])
            if n > sizes.get(key, 0):
                up_sizes = True
                extra = state.get("add_atoms", self.add_atoms)
                if add_margin:
                    extra += state.get("add_atoms_margin", self.add_atoms_margin)
                new_sizes[key] = n + extra

        new_sizes = FrozenDict(new_sizes)
        if up_sizes:
            state_up["sizes"] = (new_sizes, sizes)
            new_state["sizes"] = new_sizes

        # every block name is present; blocks without a size stay None
        block_index = dict.fromkeys(block_names)
        for (_, name), size in new_sizes.items():
            block_index[name] = np.full(size, nat, dtype=np.int32)
        for b, n in present:
            block_index[block_names[b]][:n] = np.nonzero(blocks == b)[0]

        output = {
            **inputs,
            self.output_key: block_index,
            self.output_key + "_overflow": False,
        }

        frozen_state = FrozenDict(new_state)
        if return_state_update:
            return frozen_state, output, state_up
        return frozen_state, output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """Check for overflow and rebuild the index if necessary.

        Returns `(state, state_update, inputs, reallocated)`. Extra margin
        is requested only when the overflow was flagged by this indexer.
        """
        own_overflow_key = self.output_key + "_overflow"
        overflow = parent_overflow or inputs[own_overflow_key]
        if not overflow:
            return state, {}, inputs, False

        add_margin = inputs[own_overflow_key]
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """Build the block index on accelerator with static sizes.

        If the index is already present in `inputs` and the
        `recompute_species_index` flag is not set, `inputs` is returned
        unchanged. The overflow flag is set when a block holds more atoms
        than its allocated size or when some atoms belong to a block that
        has no allocated size.

        Raises:
            ValueError: if the index must be computed and `state` is None.
        """
        block_names, species_to_block = self.build_chemical_blocks()

        flags = inputs.get("flags", {})
        recompute_species_index = "recompute_species_index" in flags
        if self.output_key in inputs and not recompute_species_index:
            return inputs

        if state is None:
            raise ValueError("Block Indexer state must be provided on accelerator.")

        species = inputs["species"]
        blocks = jnp.asarray(species_to_block)[species]
        nat = species.shape[0]

        sizes = state["sizes"]

        block_index = dict.fromkeys(block_names)
        overflow = False
        natcount = 0
        for (b, name), size in sizes.items():
            mask = blocks == b
            n_found = jnp.sum(mask)
            natcount = natcount + n_found
            overflow = overflow | (n_found > size)  # check if sizes are correct
            block_index[name] = jnp.nonzero(mask, size=size, fill_value=nat)[0]

        # atoms with a negative block id are not indexed but still counted
        natcount = natcount + jnp.sum(blocks < 0)
        overflow = overflow | (
            natcount < species.shape[0]
        )  # check if any species missing

        return {
            **inputs,
            self.output_key: block_index,
            self.output_key + "_overflow": overflow,
        }

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        """Call `process` without a state (index must already be in `inputs`)."""
        return self.process(None, inputs)

Build an index that splits atomic arrays by chemical blocks.

FPID: BLOCK_INDEXER

The chemical block of each atom is looked up from its species. The output is a dictionary with chemical block names as keys and, as values, an index to filter atomic arrays for that block (or None for a block that has no allocated size).

BlockIndexer( output_key: str = 'block_index', add_atoms: int = 0, add_atoms_margin: int = 10, split_CNOPSSe: bool = False)
output_key: str = 'block_index'

Key for the output dictionary.

add_atoms: int = 0

Additional atoms to add to the sizes.

add_atoms_margin: int = 10

Additional atoms to add to the sizes when adding margin.

split_CNOPSSe: bool = False
FPID: ClassVar[str] = 'BLOCK_INDEXER'
def init(self):
1512    def init(self):
1513        return FrozenDict(
1514            {
1515                "sizes": {},
1516            }
1517        )
def build_chemical_blocks(self):
1519    def build_chemical_blocks(self):
1520        _CHEMICAL_BLOCKS_NAMES = CHEMICAL_BLOCKS_NAMES.copy()
1521        if self.split_CNOPSSe:
1522            _CHEMICAL_BLOCKS_NAMES[1] = "C"
1523            _CHEMICAL_BLOCKS_NAMES.extend(["N","O","P","S","Se"])
1524        _CHEMICAL_BLOCKS = CHEMICAL_BLOCKS.copy()
1525        if self.split_CNOPSSe:
1526            _CHEMICAL_BLOCKS[6] = 1
1527            _CHEMICAL_BLOCKS[7] = len(CHEMICAL_BLOCKS_NAMES)
1528            _CHEMICAL_BLOCKS[8] = len(CHEMICAL_BLOCKS_NAMES)+1
1529            _CHEMICAL_BLOCKS[15] = len(CHEMICAL_BLOCKS_NAMES)+2
1530            _CHEMICAL_BLOCKS[16] = len(CHEMICAL_BLOCKS_NAMES)+3
1531            _CHEMICAL_BLOCKS[34] = len(CHEMICAL_BLOCKS_NAMES)+4
1532        return _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS
def check_reallocate(self, state, inputs, parent_overflow=False):
1586    def check_reallocate(self, state, inputs, parent_overflow=False):
1587        """check for overflow and reallocate nblist if necessary"""
1588        overflow = parent_overflow or inputs[self.output_key + "_overflow"]
1589        if not overflow:
1590            return state, {}, inputs, False
1591
1592        add_margin = inputs[self.output_key + "_overflow"]
1593        state, inputs, state_up = self(
1594            state, inputs, return_state_update=True, add_margin=add_margin
1595        )
1596        return state, state_up, inputs, True
1597        # return state, {}, inputs, parent_overflow

Check for overflow and reallocate the index arrays if necessary.

@partial(jax.jit, static_argnums=(0, 1))
def process(self, state, inputs):
1599    @partial(jax.jit, static_argnums=(0, 1))
1600    def process(self, state, inputs):
1601        _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS = self.build_chemical_blocks()
1602        # assert (
1603        #     self.output_key in inputs
1604        # ), f"Species Index {self.output_key} must be provided on accelerator. Call the numpy routine (self.__call__) first."
1605
1606        recompute_species_index = "recompute_species_index" in inputs.get("flags", {})
1607        if self.output_key in inputs and not recompute_species_index:
1608            return inputs
1609
1610        if state is None:
1611            raise ValueError("Block Indexer state must be provided on accelerator.")
1612
1613        species = inputs["species"]
1614        blocks = jnp.asarray(_CHEMICAL_BLOCKS)[species]
1615        nat = species.shape[0]
1616
1617        sizes = state["sizes"]
1618
1619        # species_index = {
1620        # PERIODIC_TABLE[s]: jnp.nonzero(species == s, size=c, fill_value=nat)[0]
1621        # for s, c in sizes.items()
1622        # }
1623        block_index = {n: None for n in _CHEMICAL_BLOCKS_NAMES}
1624        overflow = False
1625        natcount = 0
1626        for (s,name), c in sizes.items():
1627            mask = blocks == s
1628            new_size = jnp.sum(mask)
1629            natcount = natcount + new_size
1630            overflow = overflow | (new_size > c)  # check if sizes are correct
1631            block_index[name] = jnp.nonzero(
1632                mask, size=c, fill_value=nat
1633            )[0]
1634
1635        mask = blocks < 0
1636        new_size = jnp.sum(mask)
1637        natcount = natcount + new_size
1638        overflow = overflow | (
1639            natcount < species.shape[0]
1640        )  # check if any species missing
1641
1642        return {
1643            **inputs,
1644            self.output_key: block_index,
1645            self.output_key + "_overflow": overflow,
1646        }
@partial(jax.jit, static_argnums=(0,))
def update_skin(self, inputs):
1648    @partial(jax.jit, static_argnums=(0,))
1649    def update_skin(self, inputs):
1650        return self.process(None, inputs)
@dataclasses.dataclass(frozen=True)
class AtomPadding:
@dataclasses.dataclass(frozen=True)
class AtomPadding:
    """Pad atomic arrays to a fixed size.

    The reserved atom and system capacities are stored in the layer state
    under `prev_nat` and `prev_nsys`. They only grow when the incoming batch
    no longer fits, so consecutive batches of similar size are padded to the
    same shapes.
    """

    mult_size: float = 1.2
    """Multiplicative factor for resizing the atomic arrays."""
    add_sys: int = 0
    """Number of extra system slots reserved whenever the system capacity grows."""

    def init(self):
        """Return the initial state: no atom or system capacity reserved yet."""
        return {"prev_nat": 0, "prev_nsys": 0}

    def __call__(self, state, inputs: Dict) -> Tuple[Any, Dict[str, Any]]:
        """Pad per-atom and per-system arrays of `inputs`.

        Args:
            state: mapping holding the previous capacities (`prev_nat`,
                `prev_nsys`); missing entries default to 0.
            inputs: input dictionary; must contain `species`, `natoms` and
                `batch_index`.

        Returns:
            A pair `(new_state, output)`: the state frozen in a `FrozenDict`
            with updated capacities, and a copy of `inputs` with padded
            arrays plus the boolean masks `true_atoms` and `true_sys`.
        """
        species = inputs["species"]
        nat = species.shape[0]

        # Atom capacity: reuse the previous one unless this batch exceeds it.
        prev_nat_ = state.get("prev_nat", 0)
        if nat > prev_nat_:
            prev_nat_ = int(self.mult_size * nat) + 1

        # System capacity: same policy, with an additive margin.
        nsys = len(inputs["natoms"])
        prev_nsys_ = state.get("prev_nsys", 0)
        if nsys > prev_nsys_:
            prev_nsys_ = nsys + self.add_sys

        add_atoms = prev_nat_ - nat
        # At least one extra system is always appended: the padding atoms are
        # assigned to the last system of the padded batch (see batch_index).
        add_sys = prev_nsys_ - nsys + 1

        output = {**inputs}
        if add_atoms > 0:
            for k, v in inputs.items():
                if not isinstance(v, (np.ndarray, jax.Array)):
                    continue
                if v.ndim == 0:
                    # Scalar arrays have no leading axis to pad; indexing
                    # shape[0] would raise. atom_unpadding skips them too.
                    continue
                if v.shape[0] == nat:
                    # Per-atom array: append zero rows.
                    output[k] = np.append(
                        v,
                        np.zeros((add_atoms, *v.shape[1:]), dtype=v.dtype),
                        axis=0,
                    )
                elif v.shape[0] == nsys:
                    if k == "cells":
                        # Padding systems get a scaled identity cell rather
                        # than zeros.
                        filler = 1000 * np.eye(3, dtype=v.dtype)[None, :, :].repeat(
                            add_sys, axis=0
                        )
                    else:
                        filler = np.zeros((add_sys, *v.shape[1:]), dtype=v.dtype)
                    output[k] = np.append(v, filler, axis=0)

            # These keys need specific fill values, so they are rebuilt from
            # the unpadded inputs and overwrite the generic padding above.
            output["natoms"] = np.append(
                inputs["natoms"], np.zeros(add_sys, dtype=np.int32)
            )
            output["species"] = np.append(
                species, -1 * np.ones(add_atoms, dtype=species.dtype)
            )
            last_sys = output["natoms"].shape[0] - 1
            output["batch_index"] = np.append(
                inputs["batch_index"],
                np.array([last_sys] * add_atoms, dtype=inputs["batch_index"].dtype),
            )
            if "system_index" in inputs:
                output["system_index"] = np.append(
                    inputs["system_index"],
                    np.array([last_sys] * add_sys, dtype=inputs["system_index"].dtype),
                )

        # Masks used by atom_unpadding to strip the padding again.
        output["true_atoms"] = output["species"] > 0
        output["true_sys"] = np.arange(len(output["natoms"])) < nsys

        state = {**state, "prev_nat": prev_nat_, "prev_nsys": prev_nsys_}

        return FrozenDict(state), output

Pad atomic arrays to a fixed size.

AtomPadding(mult_size: float = 1.2, add_sys: int = 0)
mult_size: float = 1.2

Multiplicative factor for resizing the atomic arrays.

add_sys: int = 0
def init(self):
1661    def init(self):
1662        return {"prev_nat": 0, "prev_nsys": 0}
def atom_unpadding(inputs: Dict[str, Any]) -> Dict[str, Any]:
def atom_unpadding(inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Strip padded atoms and systems from the arrays of `inputs`.

    `inputs` is returned as-is when it carries no `true_atoms` mask, or when
    the first non-positive species sits at index 0 (which is also what
    `argmax` reports when no species is non-positive). Otherwise a new
    dictionary is returned without the `true_atoms` / `true_sys` keys, in
    which arrays whose leading dimension equals the padded atom (resp.
    system) count are filtered by the corresponding mask.
    """
    if "true_atoms" not in inputs:
        return inputs

    species = inputs["species"]
    atom_mask = inputs["true_atoms"]
    sys_mask = inputs["true_sys"]
    padded_nat = species.shape[0]
    if np.argmax(species <= 0) == 0:
        return inputs

    padded_nsys = len(inputs["natoms"])

    def strip(value):
        # Only non-scalar arrays can carry a padded leading axis.
        if not isinstance(value, (jax.Array, np.ndarray)) or value.ndim == 0:
            return value
        leading = value.shape[0]
        if leading == padded_nat:
            return value[atom_mask]
        if leading == padded_nsys:
            return value[sys_mask]
        return value

    return {
        key: strip(value)
        for key, value in inputs.items()
        if key not in ("true_sys", "true_atoms")
    }

Remove padding from atomic arrays.

def check_input(inputs):
def check_input(inputs):
    """Check the input dictionary for required keys and types.

    Returns a shallow copy of `inputs` in which:

    - `natoms` is an int32 array (a single system holding all atoms when
      the key is absent);
    - `batch_index` is an int32 array (derived from `natoms` when absent);
    - if `cells` is present, `cells` and `reciprocal_cells` are batched 3-D
      arrays; `reciprocal_cells` is the matrix inverse of `cells` unless it
      was supplied.

    Raises:
        AssertionError: if `species` or `coordinates` is missing.
    """
    assert "species" in inputs, "species must be provided"
    assert "coordinates" in inputs, "coordinates must be provided"

    species = inputs["species"].astype(np.int32)
    # NOTE(review): `ifake` is the index of the first non-positive species, so
    # every entry before it is positive by construction and this assertion
    # cannot fail. Kept as-is; confirm the intended validation.
    ifake = np.argmax(species <= 0)
    if ifake > 0:
        assert np.all(species[:ifake] > 0), "species must be positive"
    nat = inputs["species"].shape[0]

    if "natoms" in inputs:
        natoms = inputs["natoms"].astype(np.int32)
    else:
        natoms = np.array([nat], dtype=np.int32)

    # Build the default batch index only when the caller did not provide one,
    # instead of allocating it on every call.
    if "batch_index" in inputs:
        batch_index = inputs["batch_index"].astype(np.int32)
    else:
        batch_index = np.repeat(np.arange(len(natoms), dtype=np.int32), natoms)

    output = {**inputs, "natoms": natoms, "batch_index": batch_index}

    if "cells" in inputs:
        cells = inputs["cells"]
        if "reciprocal_cells" in inputs:
            reciprocal_cells = inputs["reciprocal_cells"]
        else:
            reciprocal_cells = np.linalg.inv(cells)
        # A single (3, 3) cell is promoted to a batch of one.
        if cells.ndim == 2:
            cells = cells[None, :, :]
        if reciprocal_cells.ndim == 2:
            reciprocal_cells = reciprocal_cells[None, :, :]
        output["cells"] = cells
        output["reciprocal_cells"] = reciprocal_cells

    return output

Check the input dictionary for required keys and types.

def convert_to_jax(data):
def convert_to_jax(data):
    """Convert numpy arrays to jax arrays in a pytree.

    Leaves that are not `np.ndarray` instances are passed through untouched;
    array dtypes are left to `jnp.asarray` (no explicit downcast is applied).
    """
    return jax.tree_util.tree_map(
        lambda leaf: jnp.asarray(leaf) if isinstance(leaf, np.ndarray) else leaf,
        data,
    )

Convert numpy arrays to jax arrays in a pytree.

class JaxConverter(flax.linen.module.Module):
class JaxConverter(nn.Module):
    """Convert numpy arrays to jax arrays in a pytree.

    Module wrapper around `convert_to_jax`; it declares no fields of its own.
    """

    def __call__(self, data):
        """Return the result of `convert_to_jax` applied to `data`."""
        converted = convert_to_jax(data)
        return converted

Convert numpy arrays to jax arrays in a pytree.

JaxConverter( parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
@dataclasses.dataclass(frozen=True)
class PreprocessingChain:
@dataclasses.dataclass(frozen=True)
class PreprocessingChain:
    """Chain of preprocessing layers.

    The chain state is a mapping with a `check_input` flag and a
    `layers_state` sequence holding one entry per layer, preceded by the
    atom padder's entry when `use_atom_padding` is set.
    """

    layers: Tuple[Callable[..., Dict[str, Any]]]
    """Preprocessing layers."""
    use_atom_padding: bool = False
    """Add an AtomPadding layer at the beginning of the chain."""
    atom_padder: AtomPadding = AtomPadding()
    """AtomPadding layer."""

    def __post_init__(self):
        """Reject a `layers` value that is not a sequence or is empty."""
        if not isinstance(self.layers, Sequence):
            kind = type(self.layers).__name__
            raise ValueError(f"'layers' must be a sequence, got '{kind}'.")
        if not self.layers:
            raise ValueError("Error: no Preprocessing layers were provided.")

    def __call__(self, state, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Run the full chain and return `(new_state, inputs)`.

        `check_input` is applied first unless the state disables it, then the
        atom padder (if enabled) and every layer in order. The returned
        inputs have gone through `convert_to_jax`.
        """
        if state.get("check_input", True):
            inputs = check_input(inputs)

        updated = []
        previous = state["layers_state"]
        offset = 0
        if self.use_atom_padding:
            padder_state, inputs = self.atom_padder(previous[0], inputs)
            updated.append(padder_state)
            offset = 1

        for idx, layer in enumerate(self.layers, start=offset):
            layer_state, inputs = layer(
                previous[idx], inputs, return_state_update=False
            )
            updated.append(layer_state)

        new_state = FrozenDict({**state, "layers_state": tuple(updated)})
        return new_state, convert_to_jax(inputs)

    def check_reallocate(self, state, inputs):
        """Let every layer check for overflow, threading the flag through.

        Returns `(state, state_updates, inputs, overflow)`. When the last
        layer reports no overflow, the original `state` and an empty dict of
        updates are returned; otherwise a new frozen state and the list of
        per-layer updates.
        """
        previous = state["layers_state"]
        offset = 1 if self.use_atom_padding else 0
        # The padder's state is carried over untouched.
        kept = [previous[0]] if self.use_atom_padding else []
        updates = []

        overflow = False
        for idx, layer in enumerate(self.layers, start=offset):
            layer_state, layer_update, inputs, overflow = layer.check_reallocate(
                previous[idx], inputs, overflow
            )
            kept.append(layer_state)
            updates.append(layer_update)

        if overflow:
            return (
                FrozenDict({**state, "layers_state": tuple(kept)}),
                updates,
                inputs,
                True,
            )
        return state, {}, inputs, False

    def atom_padding(self, state, inputs):
        """Apply only the atom padder.

        With padding enabled this returns the padder's own `(state, inputs)`
        result; otherwise the chain `state` and `inputs` are returned as-is.
        """
        if not self.use_atom_padding:
            return state, inputs
        return self.atom_padder(state["layers_state"][0], inputs)

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """Jitted pass: feed `inputs` through each layer's `process`."""
        states = state["layers_state"]
        offset = 1 if self.use_atom_padding else 0
        for idx, layer in enumerate(self.layers, start=offset):
            inputs = layer.process(states[idx], inputs)
        return inputs

    @partial(jax.jit, static_argnums=0)
    def update_skin(self, inputs):
        """Jitted pass: feed `inputs` through each layer's `update_skin`."""
        for layer in self.layers:
            inputs = layer.update_skin(inputs)
        return inputs

    def init(self):
        """Build the initial chain state from each layer's `init()`."""
        # NOTE(review): `layers_state` is a list here but a tuple in the state
        # returned by __call__ — confirm whether callers rely on either type.
        state = [self.atom_padder.init()] if self.use_atom_padding else []
        state.extend(layer.init() for layer in self.layers)
        return FrozenDict({"check_input": True, "layers_state": state})

    def init_with_output(self, inputs):
        """Initialize the state and immediately run the chain on `inputs`."""
        return self(self.init(), inputs)

    def get_processors(self):
        """Collect `get_processor()` from the layers that define it."""
        return [
            layer.get_processor()
            for layer in self.layers
            if hasattr(layer, "get_processor")
        ]

    def get_graphs_properties(self):
        """Merge `get_graph_properties()` of the layers that define it."""
        properties = {}
        for layer in self.layers:
            if not hasattr(layer, "get_graph_properties"):
                continue
            properties = deep_update(properties, layer.get_graph_properties())
        return properties

Chain of preprocessing layers.

PreprocessingChain( layers: Tuple[Callable[..., Dict[str, Any]]], use_atom_padding: bool = False, atom_padder: AtomPadding = AtomPadding(mult_size=1.2, add_sys=0))
layers: Tuple[Callable[..., Dict[str, Any]]]

Preprocessing layers.

use_atom_padding: bool = False

Add an AtomPadding layer at the beginning of the chain.

atom_padder: AtomPadding = AtomPadding(mult_size=1.2, add_sys=0)

AtomPadding layer.

def check_reallocate(self, state, inputs):
1850    def check_reallocate(self, state, inputs):
1851        new_state = []
1852        state_up = []
1853        layer_state = state["layers_state"]
1854        i = 0
1855        if self.use_atom_padding:
1856            new_state.append(layer_state[0])
1857            i += 1
1858        parent_overflow = False
1859        for layer in self.layers:
1860            s, s_up, inputs, parent_overflow = layer.check_reallocate(
1861                layer_state[i], inputs, parent_overflow
1862            )
1863            new_state.append(s)
1864            state_up.append(s_up)
1865            i += 1
1866
1867        if not parent_overflow:
1868            return state, {}, inputs, False
1869        return (
1870            FrozenDict({**state, "layers_state": tuple(new_state)}),
1871            state_up,
1872            inputs,
1873            True,
1874        )
def atom_padding(self, state, inputs):
1876    def atom_padding(self, state, inputs):
1877        if self.use_atom_padding:
1878            padder_state = state["layers_state"][0]
1879            return self.atom_padder(padder_state, inputs)
1880        return state, inputs
@partial(jax.jit, static_argnums=(0, 1))
def process(self, state, inputs):
1882    @partial(jax.jit, static_argnums=(0, 1))
1883    def process(self, state, inputs):
1884        layer_state = state["layers_state"]
1885        i = 1 if self.use_atom_padding else 0
1886        for layer in self.layers:
1887            inputs = layer.process(layer_state[i], inputs)
1888            i += 1
1889        return inputs
@partial(jax.jit, static_argnums=0)
def update_skin(self, inputs):
1891    @partial(jax.jit, static_argnums=(0))
1892    def update_skin(self, inputs):
1893        for layer in self.layers:
1894            inputs = layer.update_skin(inputs)
1895        return inputs
def init(self):
1897    def init(self):
1898        state = []
1899        if self.use_atom_padding:
1900            state.append(self.atom_padder.init())
1901        for layer in self.layers:
1902            state.append(layer.init())
1903        return FrozenDict({"check_input": True, "layers_state": state})
def init_with_output(self, inputs):
1905    def init_with_output(self, inputs):
1906        state = self.init()
1907        return self(state, inputs)
def get_processors(self):
1909    def get_processors(self):
1910        processors = []
1911        for layer in self.layers:
1912            if hasattr(layer, "get_processor"):
1913                processors.append(layer.get_processor())
1914        return processors
def get_graphs_properties(self):
1916    def get_graphs_properties(self):
1917        properties = {}
1918        for layer in self.layers:
1919            if hasattr(layer, "get_graph_properties"):
1920                properties = deep_update(properties, layer.get_graph_properties())
1921        return properties