fennol.models.preprocessing

   1import flax.linen as nn
   2from typing import Sequence, Callable, Union, Dict, Any, ClassVar
   3import jax.numpy as jnp
   4import jax
   5import numpy as np
   6from typing import Optional, Tuple
   7import numba
   8import dataclasses
   9from functools import partial
  10
  11from flax.core.frozen_dict import FrozenDict
  12
  13
  14from ..utils.activations import chain, safe_sqrt
  15from ..utils import deep_update, mask_filter_1d
  16from ..utils.kspace import get_reciprocal_space_parameters
  17from .misc.misc import SwitchFunction
  18from ..utils.periodic_table import PERIODIC_TABLE, PERIODIC_TABLE_REV_IDX, CHEMICAL_BLOCKS, CHEMICAL_BLOCKS_NAMES
  19
  20
  21@dataclasses.dataclass(frozen=True)
  22class GraphGenerator:
  23    """Generate a graph from a set of coordinates
  24
  25    FPID: GRAPH
  26
  27    For now, we generate all pairs of atoms and filter them by the cutoff distance.
  28    If an `nblist_skin` entry is present in the state, we also generate a second graph with a larger cutoff that includes all pairs within cutoff+skin. This skin graph is then reused by the `update_skin` method to update the original graph without recomputing the full neighbor list.
  29    """
  30
  31    cutoff: float
  32    """Cutoff distance for the graph."""
  33    graph_key: str = "graph"
  34    """Key of the graph in the outputs."""
  35    switch_params: dict = dataclasses.field(default_factory=dict, hash=False)
  36    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""
  37    kmax: int = 30
  38    """Maximum number of k-points to consider."""
  39    kthr: float = 1e-6
  40    """Threshold for k-point filtering."""
  41    k_space: bool = False
  42    """Whether to generate k-space information for the graph."""
  43    mult_size: float = 1.05
  44    """Multiplicative factor for resizing the nblist."""
  45    # covalent_cutoff: bool = False
  46
  47    FPID: ClassVar[str] = "GRAPH"
  48
  49    def init(self):
  50        return FrozenDict(
  51            {
  52                "max_nat": 1,
  53                "npairs": 1,
  54                "nblist_mult_size": self.mult_size,
  55            }
  56        )
  57
  58    def get_processor(self) -> Tuple[nn.Module, Dict]:
  59        return GraphProcessor, {
  60            "cutoff": self.cutoff,
  61            "graph_key": self.graph_key,
  62            "switch_params": self.switch_params,
  63            "name": f"{self.graph_key}_Processor",
  64        }
  65
  66    def get_graph_properties(self):
  67        return {
  68            self.graph_key: {
  69                "cutoff": self.cutoff,
  70                "directed": True,
  71            }
  72        }
  73
  74    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
  75        """build a neighbor list on CPU (NumPy, dynamic shapes) and store the max shapes in the state"""
  76        if self.graph_key in inputs:
  77            graph = inputs[self.graph_key]
  78            if "keep_graph" in graph:
  79                return state, inputs
  80
  81        coords = np.array(inputs["coordinates"], dtype=np.float32)
  82        natoms = np.array(inputs["natoms"], dtype=np.int32)
  83        batch_index = np.array(inputs["batch_index"], dtype=np.int32)
  84
  85        new_state = {**state}
  86        state_up = {}
  87
  88        mult_size = state.get("nblist_mult_size", self.mult_size)
  89        assert mult_size >= 1.0, "mult_size should be greater than or equal to 1.0"
  90
  91        if natoms.shape[0] == 1:
  92            max_nat = coords.shape[0]
  93            true_max_nat = max_nat
  94        else:
  95            max_nat = state.get("max_nat", round(coords.shape[0] / natoms.shape[0]))
  96            true_max_nat = int(np.max(natoms))
  97            if true_max_nat > max_nat:
  98                add_atoms = state.get("add_atoms", 0)
  99                new_maxnat = true_max_nat + add_atoms
 100                state_up["max_nat"] = (new_maxnat, max_nat)
 101                new_state["max_nat"] = new_maxnat
 102
 103        cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0)
 104
 105        ### compute indices of all pairs
 106        p1, p2 = np.triu_indices(true_max_nat, 1)
 107        p1, p2 = p1.astype(np.int32), p2.astype(np.int32)
 108        pbc_shifts = None
 109        if natoms.shape[0] > 1:
 110            ## batching => mask irrelevant pairs
 111            mask_p12 = (
 112                (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None])
 113            ).flatten()
 114            shift = np.concatenate(
 115                (np.array([0], dtype=np.int32), np.cumsum(natoms[:-1], dtype=np.int32))
 116            )
 117            p1 = np.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1)
 118            p2 = np.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1)
 119
 120        apply_pbc = "cells" in inputs
 121        if not apply_pbc:
 122            ### NO PBC
 123            vec = coords[p2] - coords[p1]
 124        else:
 125            cells = np.array(inputs["cells"], dtype=np.float32)
 126            reciprocal_cells = np.array(inputs["reciprocal_cells"], dtype=np.float32)
 127            minimage = "minimum_image" in inputs.get("flags", {})
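            ## Minimum image convention: each pair interacts only with the nearest
            ## periodic image, obtained by shifting the fractional-coordinate
            ## difference by -round(vec @ reciprocal_cell). This is only valid when
            ## the cutoff does not exceed half the smallest cell width.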
 128            if minimage:
 129                ## MINIMUM IMAGE CONVENTION
 130                vec = coords[p2] - coords[p1]
 131                if cells.shape[0] == 1:
 132                    vecpbc = np.dot(vec, reciprocal_cells[0])
 133                    pbc_shifts = -np.round(vecpbc)
 134                    vec = vec + np.dot(pbc_shifts, cells[0])
 135                else:
 136                    batch_index_vec = batch_index[p1]
 137                    vecpbc = np.einsum(
 138                        "aj,aji->ai", vec, reciprocal_cells[batch_index_vec]
 139                    )
 140                    pbc_shifts = -np.round(vecpbc)
 141                    vec = vec + np.einsum(
 142                        "aj,aji->ai", pbc_shifts, cells[batch_index_vec]
 143                    )
 144            else:
 145                ### GENERAL PBC
 146                ## put all atoms in central box
 147                if cells.shape[0] == 1:
 148                    coords_pbc = np.dot(coords, reciprocal_cells[0])
 149                    at_shifts = -np.floor(coords_pbc)
 150                    coords_pbc = coords + np.dot(at_shifts, cells[0])
 151                else:
 152                    coords_pbc = np.einsum(
 153                        "aj,aji->ai", coords, reciprocal_cells[batch_index]
 154                    )
 155                    at_shifts = -np.floor(coords_pbc)
 156                    coords_pbc = coords + np.einsum(
 157                        "aj,aji->ai", at_shifts, cells[batch_index]
 158                    )
 159                vec = coords_pbc[p2] - coords_pbc[p1]
 160
 161                ## compute maximum number of repeats
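                ## The column norms of the reciprocal cell bound the inverse spacing
                ## between lattice planes, so ceil(cutoff * ||b_i||) periodic copies
                ## along each cell vector are enough to cover all images within
                ## cutoff(+skin).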
 162                inv_distances = (np.sum(reciprocal_cells**2, axis=1)) ** 0.5
 163                cdinv = cutoff_skin * inv_distances
 164                num_repeats_all = np.ceil(cdinv).astype(np.int32)
 165                if "true_sys" in inputs:
 166                    num_repeats_all = np.where(np.array(inputs["true_sys"], dtype=bool)[:, None], num_repeats_all, 0)
 167                # num_repeats_all = np.where(cdinv < 0.5, 0, num_repeats_all)
 168                num_repeats = np.max(num_repeats_all, axis=0)
 169                num_repeats_prev = np.array(state.get("num_repeats_pbc", (0, 0, 0)))
 170                if np.any(num_repeats > num_repeats_prev):
 171                    num_repeats_new = np.maximum(num_repeats, num_repeats_prev)
 172                    state_up["num_repeats_pbc"] = (
 173                        tuple(num_repeats_new),
 174                        tuple(num_repeats_prev),
 175                    )
 176                    new_state["num_repeats_pbc"] = tuple(num_repeats_new)
 177                ## build all possible shifts
 178                cell_shift_pbc = np.array(
 179                    np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]),
 180                    dtype=cells.dtype,
 181                ).T.reshape(-1, 3)
 182                ## shift applied to vectors
 183                if cells.shape[0] == 1:
 184                    dvec = np.dot(cell_shift_pbc, cells[0])[None, :, :]
 185                    vec = (vec[:, None, :] + dvec).reshape(-1, 3)
 186                    pbc_shifts = np.broadcast_to(
 187                        cell_shift_pbc[None, :, :],
 188                        (p1.shape[0], cell_shift_pbc.shape[0], 3),
 189                    ).reshape(-1, 3)
 190                    p1 = np.broadcast_to(
 191                        p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0])
 192                    ).flatten()
 193                    p2 = np.broadcast_to(
 194                        p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0])
 195                    ).flatten()
 196                    if natoms.shape[0] > 1:
 197                        mask_p12 = np.broadcast_to(
 198                            mask_p12[:, None],
 199                            (mask_p12.shape[0], cell_shift_pbc.shape[0]),
 200                        ).flatten()
 201                else:
 202                    dvec = np.einsum("bj,sji->sbi", cell_shift_pbc, cells)
 203
 204                    ## get pbc shifts specific to each box
 205                    cell_shift_pbc = np.broadcast_to(
 206                        cell_shift_pbc[None, :, :],
 207                        (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3),
 208                    )
 209                    mask = np.all(
 210                        np.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1
 211                    ).flatten()
 212                    idx = np.nonzero(mask)[0]
 213                    nshifts = idx.shape[0]
 214                    nshifts_prev = state.get("nshifts_pbc", 0)
 215                    if nshifts > nshifts_prev or add_margin:
 216                        nshifts_new = int(mult_size * max(nshifts, nshifts_prev)) + 1
 217                        state_up["nshifts_pbc"] = (nshifts_new, nshifts_prev)
 218                        new_state["nshifts_pbc"] = nshifts_new
 219
 220                    dvec_filter = dvec.reshape(-1, 3)[idx, :]
 221                    cell_shift_pbc_filter = cell_shift_pbc.reshape(-1, 3)[idx, :]
 222
 223                    ## get batch shift in the dvec_filter array
 224                    nrep = np.prod(2 * num_repeats_all + 1, axis=1)
 225                    bshift = np.concatenate((np.array([0]), np.cumsum(nrep)[:-1]))
 226
 227                    ## compute vectors
 228                    batch_index_vec = batch_index[p1]
 229                    nrep_vec = np.where(mask_p12, nrep[batch_index_vec], 0)
 230                    vec = vec.repeat(nrep_vec, axis=0)
 231                    nvec_pbc = nrep_vec.sum()  # == vec.shape[0]
 232                    nvec_pbc_prev = state.get("nvec_pbc", 0)
 233                    if nvec_pbc > nvec_pbc_prev or add_margin:
 234                        nvec_pbc_new = int(mult_size * max(nvec_pbc, nvec_pbc_prev)) + 1
 235                        state_up["nvec_pbc"] = (nvec_pbc_new, nvec_pbc_prev)
 236                        new_state["nvec_pbc"] = nvec_pbc_new
 237
 238                    # print("cpu: ", nvec_pbc, nvec_pbc_prev, nshifts, nshifts_prev)
 239                    ## get shift index
 240                    dshift = np.concatenate(
 241                        (np.array([0]), np.cumsum(nrep_vec)[:-1])
 242                    ).repeat(nrep_vec)
 243                    # ishift = np.arange(dshift.shape[0])-dshift
 244                    # bshift_vec_rep = bshift[batch_index_vec].repeat(nrep_vec)
 245                    icellshift = (
 246                        np.arange(dshift.shape[0])
 247                        - dshift
 248                        + bshift[batch_index_vec].repeat(nrep_vec)
 249                    )
 250                    # shift vectors
 251                    vec = vec + dvec_filter[icellshift]
 252                    pbc_shifts = cell_shift_pbc_filter[icellshift]
 253
 254                    p1 = np.repeat(p1, nrep_vec)
 255                    p2 = np.repeat(p2, nrep_vec)
 256                    if natoms.shape[0] > 1:
 257                        mask_p12 = np.repeat(mask_p12, nrep_vec)
 258
 259        ## compute distances
 260        d12 = (vec**2).sum(axis=-1)
 261        if natoms.shape[0] > 1:
 262            d12 = np.where(mask_p12, d12, cutoff_skin**2)
 263
 264        ## filter pairs
 265        max_pairs = state.get("npairs", 1)
 266        mask = d12 < cutoff_skin**2
 267        idx = np.nonzero(mask)[0]
 268        npairs = idx.shape[0]
 269        if npairs > max_pairs or add_margin:
 270            prev_max_pairs = max_pairs
 271            max_pairs = int(mult_size * max(npairs, max_pairs)) + 1
 272            state_up["npairs"] = (max_pairs, prev_max_pairs)
 273            new_state["npairs"] = max_pairs
 274
 275        nat = coords.shape[0]
 276        edge_src = np.full(max_pairs, nat, dtype=np.int32)
 277        edge_dst = np.full(max_pairs, nat, dtype=np.int32)
 278        d12_ = np.full(max_pairs, cutoff_skin**2)
 279        edge_src[:npairs] = p1[idx]
 280        edge_dst[:npairs] = p2[idx]
 281        d12_[:npairs] = d12[idx]
 282        d12 = d12_
 283
 284        if apply_pbc:
 285            pbc_shifts_ = np.zeros((max_pairs, 3))
 286            pbc_shifts_[:npairs] = pbc_shifts[idx]
 287            pbc_shifts = pbc_shifts_
 288            if not minimage:
 289                pbc_shifts[:npairs] = (
 290                    pbc_shifts[:npairs]
 291                    + at_shifts[edge_dst[:npairs]]
 292                    - at_shifts[edge_src[:npairs]]
 293                )
 294
 295        if "nblist_skin" in state:
 296            edge_src_skin = edge_src
 297            edge_dst_skin = edge_dst
 298            if apply_pbc:
 299                pbc_shifts_skin = pbc_shifts
 300            max_pairs_skin = state.get("npairs_skin", 1)
 301            mask = d12 < self.cutoff**2
 302            idx = np.nonzero(mask)[0]
 303            npairs_skin = idx.shape[0]
 304            if npairs_skin > max_pairs_skin or add_margin:
 305                prev_max_pairs_skin = max_pairs_skin
 306                max_pairs_skin = int(mult_size * max(npairs_skin, max_pairs_skin)) + 1
 307                state_up["npairs_skin"] = (max_pairs_skin, prev_max_pairs_skin)
 308                new_state["npairs_skin"] = max_pairs_skin
 309            edge_src = np.full(max_pairs_skin, nat, dtype=np.int32)
 310            edge_dst = np.full(max_pairs_skin, nat, dtype=np.int32)
 311            d12_ = np.full(max_pairs_skin, self.cutoff**2)
 312            edge_src[:npairs_skin] = edge_src_skin[idx]
 313            edge_dst[:npairs_skin] = edge_dst_skin[idx]
 314            d12_[:npairs_skin] = d12[idx]
 315            d12 = d12_
 316            if apply_pbc:
 317                pbc_shifts = np.full((max_pairs_skin, 3), 0.0)
 318                pbc_shifts[:npairs_skin] = pbc_shifts_skin[idx]
 319
 320        ## symmetrize
 321        edge_src, edge_dst = np.concatenate((edge_src, edge_dst)), np.concatenate(
 322            (edge_dst, edge_src)
 323        )
 324        d12 = np.concatenate((d12, d12))
 325        if apply_pbc:
 326            pbc_shifts = np.concatenate((pbc_shifts, -pbc_shifts))
 327
 328        graph = inputs.get(self.graph_key, {})
 329        graph_out = {
 330            **graph,
 331            "edge_src": edge_src,
 332            "edge_dst": edge_dst,
 333            "d12": d12,
 334            "overflow": False,
 335            "pbc_shifts": pbc_shifts,
 336        }
 337        if "nblist_skin" in state:
 338            graph_out["edge_src_skin"] = edge_src_skin
 339            graph_out["edge_dst_skin"] = edge_dst_skin
 340            if apply_pbc:
 341                graph_out["pbc_shifts_skin"] = pbc_shifts_skin
 342
 343        if self.k_space and apply_pbc:
 344            if "k_points" not in graph:
 345                ks, _, _, bewald = get_reciprocal_space_parameters(
 346                    reciprocal_cells, self.cutoff, self.kmax, self.kthr
 347                )
 348                graph_out["k_points"] = ks
 349                graph_out["b_ewald"] = bewald
 350
 351        output = {**inputs, self.graph_key: graph_out}
 352
 353        if return_state_update:
 354            return FrozenDict(new_state), output, state_up
 355        return FrozenDict(new_state), output
 356
 357    def check_reallocate(self, state, inputs, parent_overflow=False):
 358        """check for overflow and reallocate nblist if necessary"""
 359        overflow = parent_overflow or inputs[self.graph_key].get("overflow", False)
 360        if not overflow:
 361            return state, {}, inputs, False
 362
 363        add_margin = inputs[self.graph_key].get("overflow", False)
 364        state, inputs, state_up = self(
 365            state, inputs, return_state_update=True, add_margin=add_margin
 366        )
 367        return state, state_up, inputs, True
 368
 369    @partial(jax.jit, static_argnums=(0, 1))
 370    def process(self, state, inputs):
 371        """build a neighbor list on accelerator (JAX, precomputed static shapes)"""
 372        if self.graph_key in inputs:
 373            graph = inputs[self.graph_key]
 374            if "keep_graph" in graph:
 375                return inputs
 376        coords = inputs["coordinates"]
 377        natoms = inputs["natoms"]
 378        batch_index = inputs["batch_index"]
 379
 380        if natoms.shape[0] == 1:
 381            max_nat = coords.shape[0]
 382        else:
 383            max_nat = state.get(
 384                "max_nat", int(round(coords.shape[0] / natoms.shape[0]))
 385            )
 386        cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0)
 387
 388        ### compute indices of all pairs
 389        p1, p2 = np.triu_indices(max_nat, 1)
 390        p1, p2 = p1.astype(np.int32), p2.astype(np.int32)
 391        pbc_shifts = None
 392        if natoms.shape[0] > 1:
 393            ## batching => mask irrelevant pairs
 394            mask_p12 = (
 395                (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None])
 396            ).flatten()
 397            shift = jnp.concatenate(
 398                (jnp.array([0], dtype=jnp.int32), jnp.cumsum(natoms[:-1]))
 399            )
 400            p1 = jnp.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1)
 401            p2 = jnp.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1)
 402
 403        ## compute vectors
 404        overflow_repeats = jnp.asarray(False, dtype=bool)
 405        if "cells" not in inputs:
 406            vec = coords[p2] - coords[p1]
 407        else:
 408            cells = inputs["cells"]
 409            reciprocal_cells = inputs["reciprocal_cells"]
 410            # minimage = state.get("minimum_image", True)
 411            minimage = "minimum_image" in inputs.get("flags", {})
 412
 413            def compute_pbc(vec, reciprocal_cell, cell, mode="round"):
 414                vecpbc = jnp.dot(vec, reciprocal_cell)
 415                if mode == "round":
 416                    pbc_shifts = -jnp.round(vecpbc)
 417                elif mode == "floor":
 418                    pbc_shifts = -jnp.floor(vecpbc)
 419                else:
 420                    raise NotImplementedError(f"Unknown mode {mode} for compute_pbc.")
 421                return vec + jnp.dot(pbc_shifts, cell), pbc_shifts
 422
 423            if minimage:
 424                ## minimum image convention
 425                vec = coords[p2] - coords[p1]
 426
 427                if cells.shape[0] == 1:
 428                    vec, pbc_shifts = compute_pbc(vec, reciprocal_cells[0], cells[0])
 429                else:
 430                    batch_index_vec = batch_index[p1]
 431                    vec, pbc_shifts = jax.vmap(compute_pbc)(
 432                        vec, reciprocal_cells[batch_index_vec], cells[batch_index_vec]
 433                    )
 434            else:
 435                ### GENERAL PBC
 436                ## (handles both a single cell and batched cells below)
 437
 438
 439
 440
 441
 442
 443                ## put all atoms in central box
 444                if cells.shape[0] == 1:
 445                    coords_pbc, at_shifts = compute_pbc(
 446                        coords, reciprocal_cells[0], cells[0], mode="floor"
 447                    )
 448                else:
 449                    coords_pbc, at_shifts = jax.vmap(
 450                        partial(compute_pbc, mode="floor")
 451                    )(coords, reciprocal_cells[batch_index], cells[batch_index])
 452                vec = coords_pbc[p2] - coords_pbc[p1]
 453                num_repeats = state.get("num_repeats_pbc", (0, 0, 0))
 454                # if num_repeats is None:
 455                #     raise ValueError(
 456                #         "num_repeats_pbc should be provided for general PBC on accelerator. Call the numpy routine (self.__call__) first."
 457                #     )
 458                # check if num_repeats is larger than previous
 459                inv_distances = jnp.linalg.norm(reciprocal_cells, axis=1)
 460                cdinv = cutoff_skin * inv_distances
 461                num_repeats_all = jnp.ceil(cdinv).astype(jnp.int32)
 462                if "true_sys" in inputs:
 463                    num_repeats_all = jnp.where(inputs["true_sys"][:, None], num_repeats_all, 0)
 464                num_repeats_new = jnp.max(num_repeats_all, axis=0)
 465                overflow_repeats = jnp.any(num_repeats_new > jnp.asarray(num_repeats))
 466
 467                cell_shift_pbc = jnp.asarray(
 468                    np.array(
 469                        np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]),
 470                        dtype=cells.dtype,
 471                    ).T.reshape(-1, 3)
 472                )
 473
 474                if cells.shape[0] == 1:
 475                    vec = (vec[:, None, :] + jnp.dot(cell_shift_pbc, cells[0])[None, :, :]).reshape(-1, 3)
 476                    pbc_shifts = jnp.broadcast_to(
 477                        cell_shift_pbc[None, :, :],
 478                        (p1.shape[0], cell_shift_pbc.shape[0], 3),
 479                    ).reshape(-1, 3)
 480                    p1 = jnp.broadcast_to(
 481                        p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0])
 482                    ).flatten()
 483                    p2 = jnp.broadcast_to(
 484                        p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0])
 485                    ).flatten()
 486                    if natoms.shape[0] > 1:
 487                        mask_p12 = jnp.broadcast_to(
 488                            mask_p12[:, None], (mask_p12.shape[0], cell_shift_pbc.shape[0])
 489                        ).flatten()
 490                else:
 491                    dvec = jnp.einsum("bj,sji->sbi", cell_shift_pbc, cells).reshape(-1, 3)
 492
 493                    ## get pbc shifts specific to each box
 494                    cell_shift_pbc = jnp.broadcast_to(
 495                        cell_shift_pbc[None, :, :],
 496                        (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3),
 497                    )
 498                    mask = jnp.all(
 499                        jnp.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1
 500                    ).flatten()
 501                    max_shifts = state.get("nshifts_pbc", 1)
 502
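                    ## mask_filter_1d (from ..utils) compacts the entries selected by
                    ## `mask` into fixed-size output arrays (here of length max_shifts),
                    ## padding the tail with the per-array fill values; it also returns
                    ## the scatter indices and the true number of selected entries,
                    ## which is compared to the allocated size to detect overflow.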
 503                    cell_shift_pbc = cell_shift_pbc.reshape(-1, 3)
 504                    shiftx, shifty, shiftz = cell_shift_pbc[:, 0], cell_shift_pbc[:, 1], cell_shift_pbc[:, 2]
 505                    dvecx, dvecy, dvecz = dvec[:, 0], dvec[:, 1], dvec[:, 2]
 506                    (dvecx, dvecy, dvecz, shiftx, shifty, shiftz), scatter_idx, nshifts = mask_filter_1d(
 507                        mask,
 508                        max_shifts,
 509                        (dvecx, 0.),
 510                        (dvecy, 0.),
 511                        (dvecz, 0.),
 512                        (shiftx, 0),
 513                        (shifty, 0),
 514                        (shiftz, 0),
 515                    )
 516                    dvec = jnp.stack((dvecx, dvecy, dvecz), axis=-1)
 517                    cell_shift_pbc = jnp.stack((shiftx, shifty, shiftz), axis=-1)
 518                    overflow_repeats = overflow_repeats | (nshifts > max_shifts)
 519
 520                    ## get batch shift in the dvec_filter array
 521                    nrep = jnp.prod(2 * num_repeats_all + 1, axis=1)
 522                    bshift = jnp.concatenate((jnp.array([0], dtype=jnp.int32), jnp.cumsum(nrep)[:-1]))
 523
 524                    ## repeat vectors
 525                    nvec_max = state.get("nvec_pbc", 1)
 526                    batch_index_vec = batch_index[p1]
 527                    nrep_vec = jnp.where(mask_p12, nrep[batch_index_vec], 0)
 528                    nvec = nrep_vec.sum()
 529                    overflow_repeats = overflow_repeats | (nvec > nvec_max)
 530                    vec = jnp.repeat(vec, nrep_vec, axis=0, total_repeat_length=nvec_max)
 531                    # jax.debug.print("{nvec} {nvec_max} {nshifts} {max_shifts}",nvec=nvec,nvec_max=jnp.asarray(nvec_max),nshifts=nshifts,max_shifts=jnp.asarray(max_shifts))
 532
 533                    ## get shift index
 534                    dshift = jnp.concatenate(
 535                        (jnp.array([0], dtype=jnp.int32), jnp.cumsum(nrep_vec)[:-1])
 536                    )
 537                    if nrep_vec.size == 0:
 538                        dshift = jnp.array([], dtype=jnp.int32)
 539                    dshift = jnp.repeat(dshift, nrep_vec, total_repeat_length=nvec_max)
 540                    bshift = jnp.repeat(bshift[batch_index_vec], nrep_vec, total_repeat_length=nvec_max)
 541                    icellshift = jnp.arange(dshift.shape[0]) - dshift + bshift
 542                    vec = vec + dvec[icellshift]
 543                    pbc_shifts = cell_shift_pbc[icellshift]
 544                    p1 = jnp.repeat(p1, nrep_vec, total_repeat_length=nvec_max)
 545                    p2 = jnp.repeat(p2, nrep_vec, total_repeat_length=nvec_max)
 546                    if natoms.shape[0] > 1:
 547                        mask_p12 = jnp.repeat(mask_p12, nrep_vec, total_repeat_length=nvec_max)
 548
 549
 550        ## compute distances
 551        d12 = (vec**2).sum(axis=-1)
 552        if natoms.shape[0] > 1:
 553            d12 = jnp.where(mask_p12, d12, cutoff_skin**2)
 554
 555        ## filter pairs
 556        max_pairs = state.get("npairs", 1)
 557        mask = d12 < cutoff_skin**2
 558        (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d(
 559            mask,
 560            max_pairs,
 561            (jnp.asarray(p1, dtype=jnp.int32), coords.shape[0]),
 562            (jnp.asarray(p2, dtype=jnp.int32), coords.shape[0]),
 563            (d12, cutoff_skin**2),
 564        )
 565        if "cells" in inputs:
 566            pbc_shifts = (
 567                jnp.full((max_pairs, 3), 0.0, dtype=pbc_shifts.dtype)
 568                .at[scatter_idx]
 569                .set(pbc_shifts, mode="drop")
 570            )
 571            if not minimage:
 572                pbc_shifts = (
 573                    pbc_shifts
 574                    + at_shifts.at[edge_dst].get(fill_value=0.0)
 575                    - at_shifts.at[edge_src].get(fill_value=0.0)
 576                )
 577
 578        ## check for overflow
 579        if natoms.shape[0] == 1:
 580            true_max_nat = coords.shape[0]
 581        else:
 582            true_max_nat = jnp.max(natoms)
 583        overflow_count = npairs > max_pairs
 584        overflow_at = true_max_nat > max_nat
 585        overflow = overflow_count | overflow_at | overflow_repeats
 586
 587        if "nblist_skin" in state:
 588            # edge_mask_skin = edge_mask
 589            edge_src_skin = edge_src
 590            edge_dst_skin = edge_dst
 591            if "cells" in inputs:
 592                pbc_shifts_skin = pbc_shifts
 593            max_pairs_skin = state.get("npairs_skin", 1)
 594            mask = d12 < self.cutoff**2
 595            (edge_src, edge_dst, d12), scatter_idx, npairs_skin = mask_filter_1d(
 596                mask,
 597                max_pairs_skin,
 598                (edge_src, coords.shape[0]),
 599                (edge_dst, coords.shape[0]),
 600                (d12, self.cutoff**2),
 601            )
 602            if "cells" in inputs:
 603                pbc_shifts = (
 604                    jnp.full((max_pairs_skin, 3), 0.0, dtype=pbc_shifts.dtype)
 605                    .at[scatter_idx]
 606                    .set(pbc_shifts, mode="drop")
 607                )
 608            overflow = overflow | (npairs_skin > max_pairs_skin)
 609
 610        ## symmetrize
 611        edge_src, edge_dst = jnp.concatenate((edge_src, edge_dst)), jnp.concatenate(
 612            (edge_dst, edge_src)
 613        )
 614        d12 = jnp.concatenate((d12, d12))
 615        if "cells" in inputs:
 616            pbc_shifts = jnp.concatenate((pbc_shifts, -pbc_shifts))
 617
 618        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
 619        graph_out = {
 620            **graph,
 621            "edge_src": edge_src,
 622            "edge_dst": edge_dst,
 623            "d12": d12,
 624            "overflow": overflow,
 625            "pbc_shifts": pbc_shifts,
 626        }
 627        if "nblist_skin" in state:
 628            graph_out["edge_src_skin"] = edge_src_skin
 629            graph_out["edge_dst_skin"] = edge_dst_skin
 630            if "cells" in inputs:
 631                graph_out["pbc_shifts_skin"] = pbc_shifts_skin
 632
 633        if self.k_space and "cells" in inputs:
 634            if "k_points" not in graph:
 635                raise NotImplementedError(
 636                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
 637                )
 638        return {**inputs, self.graph_key: graph_out}
 639
 640    @partial(jax.jit, static_argnums=(0,))
 641    def update_skin(self, inputs):
 642        """update the neighbor list from the stored skin graph without rebuilding the full neighbor list"""
 643        graph = inputs[self.graph_key]
 644
 645        edge_src_skin = graph["edge_src_skin"]
 646        edge_dst_skin = graph["edge_dst_skin"]
 647        coords = inputs["coordinates"]
 648        vec = coords.at[edge_dst_skin].get(
 649            mode="fill", fill_value=self.cutoff
 650        ) - coords.at[edge_src_skin].get(mode="fill", fill_value=0.0)
 651
 652        if "cells" in inputs:
 653            pbc_shifts_skin = graph["pbc_shifts_skin"]
 654            cells = inputs["cells"]
 655            if cells.shape[0] == 1:
 656                vec = vec + jnp.dot(pbc_shifts_skin, cells[0])
 657            else:
 658                batch_index_vec = inputs["batch_index"][edge_src_skin]
 659                vec = vec + jax.vmap(jnp.dot)(pbc_shifts_skin, cells[batch_index_vec])
 660
 661        nat = coords.shape[0]
 662        d12 = jnp.sum(vec**2, axis=-1)
 663        mask = d12 < self.cutoff**2
 664        max_pairs = graph["edge_src"].shape[0] // 2
 665        (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d(
 666            mask,
 667            max_pairs,
 668            (edge_src_skin, nat),
 669            (edge_dst_skin, nat),
 670            (d12, self.cutoff**2),
 671        )
 672        if "cells" in inputs:
 673            pbc_shifts = (
 674                jnp.full((max_pairs, 3), 0.0, dtype=pbc_shifts_skin.dtype)
 675                .at[scatter_idx]
 676                .set(pbc_shifts_skin)
 677            )
 678
 679        overflow = graph.get("overflow", False) | (npairs > max_pairs)
 680        graph_out = {
 681            **graph,
 682            "edge_src": jnp.concatenate((edge_src, edge_dst)),
 683            "edge_dst": jnp.concatenate((edge_dst, edge_src)),
 684            "d12": jnp.concatenate((d12, d12)),
 685            "overflow": overflow,
 686        }
 687        if "cells" in inputs:
 688            graph_out["pbc_shifts"] = jnp.concatenate((pbc_shifts, -pbc_shifts))
 689
 690        if self.k_space and "cells" in inputs:
 691            if "k_points" not in graph:
 692                raise NotImplementedError(
 693                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
 694                )
 695
 696        return {**inputs, self.graph_key: graph_out}
 697
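# Illustrative usage sketch (not part of the original module): how a
# GraphGenerator is typically driven. The input keys ("coordinates", "natoms",
# "batch_index") are the ones consumed above; the 8-atom geometry is made up.
def _example_graph_generator_usage():
    inputs = {
        "coordinates": np.random.randn(8, 3).astype(np.float32) * 3.0,
        "natoms": np.array([8], dtype=np.int32),
        "batch_index": np.zeros(8, dtype=np.int32),
    }
    gen = GraphGenerator(cutoff=5.0, graph_key="graph")
    state = gen.init()
    # CPU path: dynamic shapes, records the max sizes in `state`
    state, outputs = gen(state, inputs)
    # accelerator path: jitted, reuses the fixed shapes stored in `state`
    outputs = gen.process(state, inputs)
    return outputs["graph"]["edge_src"], outputs["graph"]["d12"]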
 698
 699class GraphProcessor(nn.Module):
 700    """Process a pre-generated graph
 701
 702    The pre-generated graph should contain the following keys:
 703    - edge_src: source indices of the edges
 704    - edge_dst: destination indices of the edges
 705    - pbc_shifts: PBC shifts for the edges (only if `cells` are present in the inputs)
 706
 707    This module is automatically added to a FENNIX model when a GraphGenerator is used.
 708
 709    """
 710
 711    cutoff: float
 712    """Cutoff distance for the graph."""
 713    graph_key: str = "graph"
 714    """Key of the graph in the outputs."""
 715    switch_params: dict = dataclasses.field(default_factory=dict)
 716    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""
 717
 718    @nn.compact
 719    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
 720        graph = inputs[self.graph_key]
 721        coords = inputs["coordinates"]
 722        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
 723        # edge_mask = edge_src < coords.shape[0]
 724        vec = coords.at[edge_dst].get(mode="fill", fill_value=self.cutoff) - coords.at[
 725            edge_src
 726        ].get(mode="fill", fill_value=0.0)
 727        if "cells" in inputs:
 728            cells = inputs["cells"]
 729            if cells.shape[0] == 1:
 730                vec = vec + jnp.dot(graph["pbc_shifts"], cells[0])
 731            else:
 732                batch_index_vec = inputs["batch_index"][edge_src]
 733                vec = vec + jax.vmap(jnp.dot)(
 734                    graph["pbc_shifts"], cells[batch_index_vec]
 735                )
 736
 737        d2 = jnp.sum(vec**2, axis=-1)
 738        distances = safe_sqrt(d2)
 739        edge_mask = distances < self.cutoff
 740
 741        switch = SwitchFunction(
 742            **{**self.switch_params, "cutoff": self.cutoff, "graph_key": None}
 743        )((distances, edge_mask))
 744
 745        graph_out = {
 746            **graph,
 747            "vec": vec,
 748            "distances": distances,
 749            "switch": switch,
 750            "edge_mask": edge_mask,
 751        }
 752
 753        if "alch_group" in inputs:
 754            alch_group = inputs["alch_group"]
 755            lambda_e = inputs["alch_elambda"]
 756            lambda_v = inputs["alch_vlambda"]
 757            mask = alch_group[edge_src] == alch_group[edge_dst]
 758            graph_out["switch_raw"] = switch
 759            graph_out["switch"] = jnp.where(
 760                mask,
 761                switch,
 762                0.5 * (1.0 - jnp.cos(jnp.pi * lambda_e)) * switch,
 763            )
 764            graph_out["distances_raw"] = distances
 765            if "alch_softcore_e" in inputs:
 766                alch_alpha = (1 - lambda_e) * inputs["alch_softcore_e"]**2
 767            else:
 768                alch_alpha = (1 - lambda_v) * inputs.get("alch_softcore_v", 0.5)**2
 769
 770            graph_out["distances"] = jnp.where(
 771                mask,
 772                distances,
 773                safe_sqrt(alch_alpha + d2 * (1.0 - alch_alpha / self.cutoff**2)),
 774            )
 775
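            ## Note on the softcore mapping above: with alpha = (1 - lambda) * softcore**2,
            ## the effective distance sqrt(alpha + d**2 * (1 - alpha / cutoff**2)) tends to
            ## sqrt(alpha) > 0 as d -> 0 (removing the short-range singularity between
            ## alchemical groups) and returns to exactly cutoff at d = cutoff, so the
            ## support of the switching function is unchanged.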
 776
 777        return {**inputs, self.graph_key: graph_out}
 778
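# Illustrative sketch (not part of the original module): GraphProcessor is a
# Flax module, so it is used through init/apply on outputs that already contain
# the raw graph produced by GraphGenerator (the `outputs` dict is assumed to
# come from a sketch like _example_graph_generator_usage above).
def _example_graph_processor_usage(outputs):
    proc = GraphProcessor(cutoff=5.0, graph_key="graph")
    variables = proc.init(jax.random.PRNGKey(0), outputs)
    processed = proc.apply(variables, outputs)
    # the processed graph now holds "vec", "distances", "switch" and "edge_mask"
    return processed["graph"]["distances"]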
 779
 780@dataclasses.dataclass(frozen=True)
 781class GraphFilter:
 782    """Filter a graph based on a cutoff distance
 783
 784    FPID: GRAPH_FILTER
 785    """
 786
 787    cutoff: float
 788    """Cutoff distance for the filtering."""
 789    parent_graph: str
 790    """Key of the parent graph in the inputs."""
 791    graph_key: str
 792    """Key of the filtered graph in the outputs."""
 793    remove_hydrogens: bool = False
 794    """Remove edges where the source is a hydrogen atom."""
 795    switch_params: FrozenDict = dataclasses.field(default_factory=FrozenDict)
 796    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""
 797    k_space: bool = False
 798    """Generate k-space information for the graph."""
 799    kmax: int = 30
 800    """Maximum number of k-points to consider."""
 801    kthr: float = 1e-6
 802    """Threshold for k-point filtering."""
 803    mult_size: float = 1.05
 804    """Multiplicative factor for resizing the nblist."""
 805
 806    FPID: ClassVar[str] = "GRAPH_FILTER"
 807
 808    def init(self):
 809        return FrozenDict(
 810            {
 811                "npairs": 1,
 812                "nblist_mult_size": self.mult_size,
 813            }
 814        )
 815
 816    def get_processor(self) -> Tuple[nn.Module, Dict]:
 817        return GraphFilterProcessor, {
 818            "cutoff": self.cutoff,
 819            "graph_key": self.graph_key,
 820            "parent_graph": self.parent_graph,
 821            "name": f"{self.graph_key}_Filter_{self.parent_graph}",
 822            "switch_params": self.switch_params,
 823        }
 824
 825    def get_graph_properties(self):
 826        return {
 827            self.graph_key: {
 828                "cutoff": self.cutoff,
 829                "directed": True,
 830                "parent_graph": self.parent_graph,
 831            }
 832        }
 833
 834    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
 835        """filter a neighbor list on CPU (NumPy, dynamic shapes) and store the max shapes in the state"""
 836        graph_in = inputs[self.parent_graph]
 837        nat = inputs["species"].shape[0]
 838
 839        new_state = {**state}
 840        state_up = {}
 841        mult_size = state.get("nblist_mult_size", self.mult_size)
 842        assert mult_size >= 1., "nblist_mult_size should be >= 1."
 843
 844        edge_src = np.array(graph_in["edge_src"], dtype=np.int32)
 845        d12 = np.array(graph_in["d12"], dtype=np.float32)
 846        if self.remove_hydrogens:
 847            species = inputs["species"]
 848            src_idx = (edge_src < nat).nonzero()[0]
 849            mask = np.zeros(edge_src.shape[0], dtype=bool)
 850            mask[src_idx] = (species > 1)[edge_src[src_idx]]
 851            d12 = np.where(mask, d12, self.cutoff**2)
 852        mask = d12 < self.cutoff**2
 853
 854        max_pairs = state.get("npairs", 1)
 855        idx = np.nonzero(mask)[0]
 856        npairs = idx.shape[0]
 857        if npairs > max_pairs or add_margin:
 858            prev_max_pairs = max_pairs
 859            max_pairs = int(mult_size * max(npairs, max_pairs)) + 1
 860            state_up["npairs"] = (max_pairs, prev_max_pairs)
 861            new_state["npairs"] = max_pairs
 862
 863        filter_indices = np.full(max_pairs, edge_src.shape[0], dtype=np.int32)
 864        edge_src = np.full(max_pairs, nat, dtype=np.int32)
 865        edge_dst = np.full(max_pairs, nat, dtype=np.int32)
 866        d12_ = np.full(max_pairs, self.cutoff**2)
 867        filter_indices[:npairs] = idx
 868        edge_src[:npairs] = graph_in["edge_src"][idx]
 869        edge_dst[:npairs] = graph_in["edge_dst"][idx]
 870        d12_[:npairs] = d12[idx]
 871        d12 = d12_
 872
 873        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
 874        graph_out = {
 875            **graph,
 876            "edge_src": edge_src,
 877            "edge_dst": edge_dst,
 878            "filter_indices": filter_indices,
 879            "d12": d12,
 880            "overflow": False,
 881        }
 882
 883        if self.k_space and "cells" in inputs:
 884            if "k_points" not in graph:
 885                ks, _, _, bewald = get_reciprocal_space_parameters(
 886                    inputs["reciprocal_cells"], self.cutoff, self.kmax, self.kthr
 887                )
 888                graph_out["k_points"] = ks
 889                graph_out["b_ewald"] = bewald
 890
 891        output = {**inputs, self.graph_key: graph_out}
 892        if return_state_update:
 893            return FrozenDict(new_state), output, state_up
 894        return FrozenDict(new_state), output
 895
 896    def check_reallocate(self, state, inputs, parent_overflow=False):
 897        """check for overflow and reallocate nblist if necessary"""
 898        overflow = parent_overflow or inputs[self.graph_key].get("overflow", False)
 899        if not overflow:
 900            return state, {}, inputs, False
 901
 902        add_margin = inputs[self.graph_key].get("overflow", False)
 903        state, inputs, state_up = self(
 904            state, inputs, return_state_update=True, add_margin=add_margin
 905        )
 906        return state, state_up, inputs, True
 907
 908    @partial(jax.jit, static_argnums=(0, 1))
 909    def process(self, state, inputs):
 910        """filter a neighbor list on accelerator (JAX, precomputed static shapes)"""
 911        graph_in = inputs[self.parent_graph]
 912        if state is None:
 913            # skin update mode
 914            graph = inputs[self.graph_key]
 915            max_pairs = graph["edge_src"].shape[0]
 916        else:
 917            max_pairs = state.get("npairs", 1)
 918
 919        max_pairs_in = graph_in["edge_src"].shape[0]
 920        nat = inputs["species"].shape[0]
 921
 922        edge_src = graph_in["edge_src"]
 923        d12 = graph_in["d12"]
 924        if self.remove_hydrogens:
 925            species = inputs["species"]
 926            mask = (species > 1)[edge_src]
 927            d12 = jnp.where(mask, d12, self.cutoff**2)
 928        mask = d12 < self.cutoff**2
 929
 930        (edge_src, edge_dst, d12, filter_indices), _, npairs = mask_filter_1d(
 931            mask,
 932            max_pairs,
 933            (edge_src, nat),
 934            (graph_in["edge_dst"], nat),
 935            (d12, self.cutoff**2),
 936            (jnp.arange(max_pairs_in, dtype=jnp.int32), max_pairs_in),
 937        )
 938
 939        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
 940        overflow = graph.get("overflow", False) | (npairs > max_pairs)
 941        graph_out = {
 942            **graph,
 943            "edge_src": edge_src,
 944            "edge_dst": edge_dst,
 945            "filter_indices": filter_indices,
 946            "d12": d12,
 947            "overflow": overflow,
 948        }
 949
 950        if self.k_space and "cells" in inputs:
 951            if "k_points" not in graph:
 952                raise NotImplementedError(
 953                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
 954                )
 955
 956        return {**inputs, self.graph_key: graph_out}
 957
 958    @partial(jax.jit, static_argnums=(0,))
 959    def update_skin(self, inputs):
 960        return self.process(None, inputs)
 961
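# Illustrative sketch (not part of the original module): chaining a GraphFilter
# to extract a short-range subgraph from a longer-range parent graph. Assumes
# `outputs` already contains a "graph" entry with d12 (e.g. from GraphGenerator)
# and a "species" array.
def _example_graph_filter_usage(outputs):
    filt = GraphFilter(cutoff=3.0, parent_graph="graph", graph_key="graph_short")
    fstate = filt.init()
    fstate, outputs = filt(fstate, outputs)  # CPU sizing pass
    outputs = filt.process(fstate, outputs)  # jitted fixed-shape pass
    return outputs["graph_short"]["filter_indices"]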
 962
 963class GraphFilterProcessor(nn.Module):
 964    """Filter processing for a pre-generated graph
 965
 966    This module is automatically added to a FENNIX model when a GraphFilter is used.
 967    """
 968
 969    cutoff: float
 970    """Cutoff distance for the filtering."""
 971    graph_key: str
 972    """Key of the filtered graph in the inputs."""
 973    parent_graph: str
 974    """Key of the parent graph in the inputs."""
 975    switch_params: dict = dataclasses.field(default_factory=dict)
 976    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""
 977
 978    @nn.compact
 979    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
 980        graph_in = inputs[self.parent_graph]
 981        graph = inputs[self.graph_key]
 982
 983        d_key = "distances_raw" if "distances_raw" in graph_in else "distances"
 984
 985        if graph_in["vec"].shape[0] == 0:
 986            vec = graph_in["vec"]
 987            distances = graph_in[d_key]
 988            filter_indices = jnp.asarray([], dtype=jnp.int32)
 989        else:
 990            filter_indices = graph["filter_indices"]
 991            vec = (
 992                graph_in["vec"]
 993                .at[filter_indices]
 994                .get(mode="fill", fill_value=self.cutoff)
 995            )
 996            distances = (
 997                graph_in[d_key]
 998                .at[filter_indices]
 999                .get(mode="fill", fill_value=self.cutoff)
1000            )
1001
1002        edge_mask = distances < self.cutoff
1003        switch = SwitchFunction(
1004            **{**self.switch_params, "cutoff": self.cutoff, "graph_key": None}
1005        )((distances, edge_mask))
1006
1007        graph_out = {
1008            **graph,
1009            "vec": vec,
1010            "distances": distances,
1011            "switch": switch,
1012            "filter_indices": filter_indices,
1013            "edge_mask": edge_mask,
1014        }
1015
1016        if "alch_group" in inputs:
1017            edge_src = graph["edge_src"]
1018            edge_dst = graph["edge_dst"]
1019            alch_group = inputs["alch_group"]
1020            lambda_e = inputs["alch_elambda"]
1021            lambda_v = inputs["alch_vlambda"]
1022            mask = alch_group[edge_src] == alch_group[edge_dst]
1023            graph_out["switch_raw"] = switch
1024            graph_out["switch"] = jnp.where(
1025                mask,
1026                switch,
1027                0.5 * (1.0 - jnp.cos(jnp.pi * lambda_e)) * switch,
1028            )
1029
1030            graph_out["distances_raw"] = distances
1031            if "alch_softcore_e" in inputs:
1032                alch_alpha = (1 - lambda_e) * inputs["alch_softcore_e"]**2
1033            else:
1034                alch_alpha = (1 - lambda_v) * inputs.get("alch_softcore_v", 0.5)**2
1035
1036            graph_out["distances"] = jnp.where(
1037                mask,
1038                distances,
1039                safe_sqrt(alch_alpha + distances**2 * (1.0 - alch_alpha / self.cutoff**2)),
1040            )
1041
1042
1043        return {**inputs, self.graph_key: graph_out}
1044
1045
1046@dataclasses.dataclass(frozen=True)
1047class GraphAngularExtension:
1048    """Add angles list to a graph
1049
1050    FPID: GRAPH_ANGULAR_EXTENSION
1051    """
1052
1053    mult_size: float = 1.05
1054    """Multiplicative factor for resizing the nblist."""
1055    add_neigh: int = 5
1056    """Additional neighbors to add to the nblist when resizing."""
1057    graph_key: str = "graph"
1058    """Key of the graph in the inputs."""
1059
1060    FPID: ClassVar[str] = "GRAPH_ANGULAR_EXTENSION"
1061
1062    def init(self):
1063        return FrozenDict(
1064            {
1065                "nangles": 0,
1066                "nblist_mult_size": self.mult_size,
1067                "max_neigh": self.add_neigh,
1068                "add_neigh": self.add_neigh,
1069            }
1070        )
1071
1072    def get_processor(self) -> Tuple[nn.Module, Dict]:
1073        return GraphAngleProcessor, {
1074            "graph_key": self.graph_key,
1075            "name": f"{self.graph_key}_AngleProcessor",
1076        }
1077
1078    def get_graph_properties(self):
1079        return {
1080            self.graph_key: {
1081                "has_angles": True,
1082            }
1083        }
1084
1085    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
1086        """build the angle neighbor list on CPU (NumPy, dynamic shapes) and store the max shapes in the state"""
1087        graph = inputs[self.graph_key]
1088        edge_src = np.array(graph["edge_src"], dtype=np.int32)
1089
1090        new_state = {**state}
1091        state_up = {}
1092        mult_size = state.get("nblist_mult_size", self.mult_size)
1093        assert mult_size >= 1., "nblist_mult_size should be >= 1."
1094
1095        ### count number of neighbors
1096        nat = inputs["species"].shape[0]
1097        count = np.zeros(nat + 1, dtype=np.int32)
1098        np.add.at(count, edge_src, 1)
1099        max_count = int(np.max(count[:-1]))
1100
1101        ### get sizes
1102        max_neigh = state.get("max_neigh", self.add_neigh)
1103        nedge = edge_src.shape[0]
1104        if max_count > max_neigh or add_margin:
1105            prev_max_neigh = max_neigh
1106            max_neigh = max(max_count, max_neigh) + state.get(
1107                "add_neigh", self.add_neigh
1108            )
1109            state_up["max_neigh"] = (max_neigh, prev_max_neigh)
1110            new_state["max_neigh"] = max_neigh
1111
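        ## dummy boolean array whose shape encodes max_neigh: the jitted process()
        ## recovers max_neigh from "__max_neigh_array".shape[0] instead of a traced value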
1112        max_neigh_arr = np.empty(max_neigh, dtype=bool)
1113
1114
1115
1116        ### sort edge_src
1117        idx_sort = np.argsort(edge_src)
1118        edge_src_sorted = edge_src[idx_sort]
1119
1120        ### map sparse to dense nblist
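        ## edge_idx is a dense (nat, max_count) table: row i holds the indices (into
        ## edge_src/edge_dst) of the edges whose source is atom i, padded with the
        ## out-of-range value `nedge`. E.g. with nat=2 and edge sources [0, 0, 1],
        ## edge_idx = [[0, 1], [2, nedge]].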
1121
1122        if max_count * nat >= nedge:
1123            offset = np.tile(np.arange(max_count), nat)[:nedge]
1124        else:
1125            offset = np.zeros(nedge, dtype=np.int32)
1126            offset[: max_count * nat] = np.tile(np.arange(max_count), nat)
1127
1128        # offset = jnp.where(edge_src_sorted < nat, offset, 0)
1129        mask = edge_src_sorted < nat
1130        indices = edge_src_sorted * max_count + offset
1131        indices = indices[mask]
1132        idx_sort = idx_sort[mask]
1133        edge_idx = np.full(nat * max_count, nedge, dtype=np.int32)
1134        edge_idx[indices] = idx_sort
1135        edge_idx = edge_idx.reshape(nat, max_count)
1136
1137        ### find all triplet for each atom center
1138        ### find all triplets for each atom center
1139        angle_src = edge_idx[:, local_src].flatten()
1140        angle_dst = edge_idx[:, local_dst].flatten()
1141
1142        ### mask for valid angles
1143        mask1 = angle_src < nedge
1144        mask2 = angle_dst < nedge
1145        angle_mask = mask1 & mask2
1146
1147        max_angles = state.get("nangles", 0)
1148        idx = np.nonzero(angle_mask)[0]
1149        nangles = idx.shape[0]
1150        if nangles > max_angles or add_margin:
1151            max_angles_prev = max_angles
1152            max_angles = int(mult_size * max(nangles, max_angles)) + 1
1153            state_up["nangles"] = (max_angles, max_angles_prev)
1154            new_state["nangles"] = max_angles
1155
1156        ## filter angles to sparse representation
1157        angle_src_ = np.full(max_angles, nedge, dtype=np.int32)
1158        angle_dst_ = np.full(max_angles, nedge, dtype=np.int32)
1159        angle_src_[:nangles] = angle_src[idx]
1160        angle_dst_[:nangles] = angle_dst[idx]
1161
1162        central_atom = np.full(max_angles, nat, dtype=np.int32)
1163        central_atom[:nangles] = edge_src[angle_src_[:nangles]]
1164
1165        ## update graph
1166        output = {
1167            **inputs,
1168            self.graph_key: {
1169                **graph,
1170                "angle_src": angle_src_,
1171                "angle_dst": angle_dst_,
1172                "central_atom": central_atom,
1173                "angle_overflow": False,
1174                "max_neigh": max_neigh,
1175                "__max_neigh_array": max_neigh_arr,
1176            },
1177        }
1178
1179        if return_state_update:
1180            return FrozenDict(new_state), output, state_up
1181        return FrozenDict(new_state), output
1182
1183    def check_reallocate(self, state, inputs, parent_overflow=False):
1184        """check for overflow and reallocate nblist if necessary"""
1185        overflow = parent_overflow or inputs[self.graph_key]["angle_overflow"]
1186        if not overflow:
1187            return state, {}, inputs, False
1188
1189        add_margin = inputs[self.graph_key]["angle_overflow"]
1190        state, inputs, state_up = self(
1191            state, inputs, return_state_update=True, add_margin=add_margin
1192        )
1193        return state, state_up, inputs, True
1194
1195    @partial(jax.jit, static_argnums=(0, 1))
1196    def process(self, state, inputs):
1197        """build the angle neighbor list on accelerator (JAX, precomputed static shapes)"""
1198        graph = inputs[self.graph_key]
1199        edge_src = graph["edge_src"]
1200
1201        ### count number of neighbors
1202        nat = inputs["species"].shape[0]
1203        count = jnp.zeros(nat, dtype=jnp.int32).at[edge_src].add(1, mode="drop")
1204        max_count = jnp.max(count)
1205
1206        ### get sizes
1207        if state is None:
1208            max_neigh_arr = graph["__max_neigh_array"]
1209            max_neigh = max_neigh_arr.shape[0]
1210            prev_nangles = graph["angle_src"].shape[0]
1211        else:
1212            max_neigh = state.get("max_neigh", self.add_neigh)
1213            max_neigh_arr = jnp.empty(max_neigh, dtype=bool)
1214            prev_nangles = state.get("nangles", 0)
1215
1216        nedge = edge_src.shape[0]
1217
1218        ### sort edge_src
1219        idx_sort = jnp.argsort(edge_src).astype(jnp.int32)
1220        edge_src_sorted = edge_src[idx_sort]
1221
1222        ### map sparse to dense nblist
1223        if max_neigh * nat < nedge:
1224            raise ValueError("Found max_neigh*nat < nedge. This should not happen.")
1225        offset = jnp.asarray(
1226            np.tile(np.arange(max_neigh), nat)[:nedge], dtype=jnp.int32
1227        )
1228        # offset = jnp.where(edge_src_sorted < nat, offset, 0)
1229        indices = edge_src_sorted * max_neigh + offset
1230        edge_idx = (
1231            jnp.full(nat * max_neigh, nedge, dtype=jnp.int32)
1232            .at[indices]
1233            .set(idx_sort, mode="drop")
1234            .reshape(nat, max_neigh)
1235        )
1236
1237        ### find all triplet for each atom center
1238        ### find all triplets for each atom center
1239        angle_src = edge_idx[:, local_src].flatten()
1240        angle_dst = edge_idx[:, local_dst].flatten()
1241
1242        ### mask for valid angles
1243        mask1 = angle_src < nedge
1244        mask2 = angle_dst < nedge
1245        angle_mask = mask1 & mask2
1246
1247        ## filter angles to sparse representation
1248        (angle_src, angle_dst), _, nangles = mask_filter_1d(
1249            angle_mask,
1250            prev_nangles,
1251            (angle_src, nedge),
1252            (angle_dst, nedge),
1253        )
1254        ## find central atom
1255        central_atom = edge_src[angle_src]
1256
1257        ## check for overflow
1258        angle_overflow = nangles > prev_nangles
1259        neigh_overflow = max_count > max_neigh
1260        overflow = graph.get("angle_overflow", False) | angle_overflow | neigh_overflow
1261
1262        ## update graph
1263        output = {
1264            **inputs,
1265            self.graph_key: {
1266                **graph,
1267                "angle_src": angle_src,
1268                "angle_dst": angle_dst,
1269                "central_atom": central_atom,
1270                "angle_overflow": overflow,
1271                # "max_neigh": max_neigh,
1272                "__max_neigh_array": max_neigh_arr,
1273            },
1274        }
1275
1276        return output
1277
1278    @partial(jax.jit, static_argnums=(0,))
1279    def update_skin(self, inputs):
1280        return self.process(None, inputs)
1281
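# Illustrative sketch (not part of the original module): adding an angle list to
# an existing graph. Assumes `outputs` holds a processed "graph" (with edge_src)
# and a "species" array.
def _example_angular_extension_usage(outputs):
    ang = GraphAngularExtension(graph_key="graph")
    astate = ang.init()
    astate, outputs = ang(astate, outputs)  # CPU sizing pass
    outputs = ang.process(astate, outputs)  # jitted fixed-shape pass
    graph = outputs["graph"]
    # angle_src/angle_dst index pairs of edges sharing the same central_atom
    return graph["angle_src"], graph["angle_dst"], graph["central_atom"]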
1282
1283class GraphAngleProcessor(nn.Module):
1284    """Process a pre-generated graph to compute angles
1285
1286    This module is automatically added to a FENNIX model when a GraphAngularExtension is used.
1287
1288    """
1289
1290    graph_key: str
1291    """Key of the graph in the inputs."""
1292
1293    @nn.compact
1294    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
1295        graph = inputs[self.graph_key]
1296        distances = graph["distances_raw"] if "distances_raw" in graph else graph["distances"]
1297        vec = graph["vec"]
1298        angle_src = graph["angle_src"]
1299        angle_dst = graph["angle_dst"]
1300
1301        unit_vec = vec / jnp.clip(distances[:, None], min=1.0e-5)
1302        cos_angles = (
1303            unit_vec.at[angle_src].get(mode="fill", fill_value=0.5)
1304            * unit_vec.at[angle_dst].get(mode="fill", fill_value=0.5)
1305        ).sum(axis=-1)
1306        # rescale cosines so arccos stays differentiable at cos = +-1
1307        angles = jnp.arccos(0.95 * cos_angles)
1308
1309        return {
1310            **inputs,
1311            self.graph_key: {
1312                **graph,
1313                # "cos_angles": cos_angles,
1314                "angles": angles,
1315                # "angle_mask": angle_mask,
1316            },
1317        }
1318
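# --- Editor's note (illustrative, not part of the original source) ---
# The 0.95 factor above keeps the argument of arccos strictly inside (-1, 1):
# d/dx arccos(x) = -1/sqrt(1 - x^2) diverges at x = +-1 (perfectly aligned or
# opposite bonds), which would yield NaN gradients; the price is a slightly
# compressed angle range. A quick check:
#
# import jax, jax.numpy as jnp
# jax.grad(lambda c: jnp.arccos(0.95 * c))(1.0)   # finite, about -3.04
# jax.grad(jnp.arccos)(1.0)                       # -inf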
1319
1320@dataclasses.dataclass(frozen=True)
1321class SpeciesIndexer:
1322    """Build an index that splits atomic arrays by species.
1323
1324    FPID: SPECIES_INDEXER
1325
1326    If `species_order` is specified, the output will be a dense array with size (len(species_order), max_size) that can directly index atomic arrays.
1327    If `species_order` is None, the output will be a dictionary with species as keys and an index to filter atomic arrays for that species as values.
1328
1329    """
1330
1331    output_key: str = "species_index"
1332    """Key for the output dictionary."""
1333    species_order: Optional[str] = None
1334    """Comma separated list of species in the order they should be indexed."""
1335    add_atoms: int = 0
1336    """Additional atoms to add to the sizes."""
1337    add_atoms_margin: int = 10
1338    """Additional atoms to add to the sizes when adding margin."""
1339
1340    FPID: ClassVar[str] = "SPECIES_INDEXER"
1341
1342    def init(self):
1343        return FrozenDict(
1344            {
1345                "sizes": {},
1346            }
1347        )
1348
1349    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
1350        species = np.array(inputs["species"], dtype=np.int32)
1351        nat = species.shape[0]
1352        set_species, counts = np.unique(species, return_counts=True)
1353
1354        new_state = {**state}
1355        state_up = {}
1356
1357        sizes = state.get("sizes", FrozenDict({}))
1358        new_sizes = {**sizes}
1359        up_sizes = False
1360        counts_dict = {}
1361        for s, c in zip(set_species, counts):
1362            if s <= 0:
1363                continue
1364            counts_dict[s] = c
1365            if c > sizes.get(s, 0):
1366                up_sizes = True
1367                add_atoms = state.get("add_atoms", self.add_atoms)
1368                if add_margin:
1369                    add_atoms += state.get("add_atoms_margin", self.add_atoms_margin)
1370                new_sizes[s] = c + add_atoms
1371
1372        new_sizes = FrozenDict(new_sizes)
1373        if up_sizes:
1374            state_up["sizes"] = (new_sizes, sizes)
1375            new_state["sizes"] = new_sizes
1376
1377        if self.species_order is not None:
1378            species_order = [el.strip() for el in self.species_order.split(",")]
1379            max_size_prev = state.get("max_size", 0)
1380            max_size = max(new_sizes.values())
1381            if max_size > max_size_prev:
1382                state_up["max_size"] = (max_size, max_size_prev)
1383                new_state["max_size"] = max_size
1384                max_size_prev = max_size
1385
1386            species_index = np.full((len(species_order), max_size), nat, dtype=np.int32)
1387            for i, el in enumerate(species_order):
1388                s = PERIODIC_TABLE_REV_IDX[el]
1389                if s in counts_dict.keys():
1390                    species_index[i, : counts_dict[s]] = np.nonzero(species == s)[0]
1391        else:
1392            species_index = {
1393                PERIODIC_TABLE[s]: np.full(c, nat, dtype=np.int32)
1394                for s, c in new_sizes.items()
1395            }
1396            for s, c in zip(set_species, counts):
1397                if s <= 0:
1398                    continue
1399                species_index[PERIODIC_TABLE[s]][:c] = np.nonzero(species == s)[0]
1400
1401        output = {
1402            **inputs,
1403            self.output_key: species_index,
1404            self.output_key + "_overflow": False,
1405        }
1406
1407        if return_state_update:
1408            return FrozenDict(new_state), output, state_up
1409        return FrozenDict(new_state), output
1410
1411    def check_reallocate(self, state, inputs, parent_overflow=False):
1412        """Check for overflow and reallocate the species index if necessary."""
1413        overflow = parent_overflow or inputs[self.output_key + "_overflow"]
1414        if not overflow:
1415            return state, {}, inputs, False
1416
1417        add_margin = inputs[self.output_key + "_overflow"]
1418        state, inputs, state_up = self(
1419            state, inputs, return_state_update=True, add_margin=add_margin
1420        )
1421        return state, state_up, inputs, True
1422        # return state, {}, inputs, parent_overflow
1423
1424    @partial(jax.jit, static_argnums=(0, 1))
1425    def process(self, state, inputs):
1426        # assert (
1427        #     self.output_key in inputs
1428        # ), f"Species Index {self.output_key} must be provided on accelerator. Call the numpy routine (self.__call__) first."
1429
1430        recompute_species_index = "recompute_species_index" in inputs.get("flags", {})
1431        if self.output_key in inputs and not recompute_species_index:
1432            return inputs
1433
1434        if state is None:
1435            raise ValueError("Species Indexer state must be provided on accelerator.")
1436
1437        species = inputs["species"]
1438        nat = species.shape[0]
1439
1440        sizes = state["sizes"]
1441
1442        if self.species_order is not None:
1443            species_order = [el.strip() for el in self.species_order.split(",")]
1444            max_size = state["max_size"]
1445            overflow = False
1446            species_index = jnp.full(
1447                (len(species_order), max_size), nat, dtype=jnp.int32
1448            )
1449            for i, el in enumerate(species_order):
1450                s = PERIODIC_TABLE_REV_IDX[el]
1451                if s in sizes.keys():
1452                    c = sizes[s]
1453                    species_index = species_index.at[i, :].set(
1454                        jnp.nonzero(species == s, size=max_size, fill_value=nat)[0]
1455                    )
1456                # if s in counts_dict.keys():
1457                #     species_index[i, : counts_dict[s]] = np.nonzero(species == s)[0]
1458        else:
1459            # species_index = {
1460            # PERIODIC_TABLE[s]: jnp.nonzero(species == s, size=c, fill_value=nat)[0]
1461            # for s, c in sizes.items()
1462            # }
1463            species_index = {}
1464            overflow = False
1465            natcount = 0
1466            for s, c in sizes.items():
1467                mask = species == s
1468                new_size = jnp.sum(mask)
1469                natcount = natcount + new_size
1470                overflow = overflow | (new_size > c)  # check if sizes are correct
1471                species_index[PERIODIC_TABLE[s]] = jnp.nonzero(
1472                    species == s, size=c, fill_value=nat
1473                )[0]
1474
1475            mask = species <= 0
1476            new_size = jnp.sum(mask)
1477            natcount = natcount + new_size
1478            overflow = overflow | (
1479                natcount < species.shape[0]
1480            )  # check if any species missing
1481
1482        return {
1483            **inputs,
1484            self.output_key: species_index,
1485            self.output_key + "_overflow": overflow,
1486        }
1487
1488    @partial(jax.jit, static_argnums=(0,))
1489    def update_skin(self, inputs):
1490        return self.process(None, inputs)
1491
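# --- Editor's illustrative sketch (not part of the original source) ---
# With `species_order` set, the indexer yields a dense (nspecies, max_size)
# integer array whose empty slots hold `nat`, so per-species gathers stay
# shape-stable under jit; out-of-range slots can be read with a fill value.
# The toy arrays below are hypothetical:
#
# import jax.numpy as jnp
#
# energies = jnp.array([0.1, 0.2, 0.3])              # one value per atom (nat=3)
# species_index = jnp.array([[0, 2, 3], [1, 3, 3]])  # e.g. rows ("H", "O")
# per_species = energies.at[species_index].get(mode="fill", fill_value=0.0)
# # per_species[0] -> [0.1, 0.3, 0.0]: padded slots read as 0.0
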
1492@dataclasses.dataclass(frozen=True)
1493class BlockIndexer:
1494    """Build an index that splits atomic arrays by chemical blocks.
1495
1496    FPID: BLOCK_INDEXER
1497
1498    The output is a dictionary with chemical block names as keys and, for each block,
1499    an index that filters atomic arrays down to the atoms of that block.
1500
1501    """
1502
1503    output_key: str = "block_index"
1504    """Key for the output dictionary."""
1505    add_atoms: int = 0
1506    """Additional atoms to add to the sizes."""
1507    add_atoms_margin: int = 10
1508    """Additional atoms to add to the sizes when adding margin."""
1509    split_CNOPSSe: bool = False
1510    """Whether to index C, N, O, P, S and Se separately from their chemical blocks."""
1511    FPID: ClassVar[str] = "BLOCK_INDEXER"
1512
1513    def init(self):
1514        return FrozenDict(
1515            {
1516                "sizes": {},
1517            }
1518        )
1519
1520    def build_chemical_blocks(self):
1521        _CHEMICAL_BLOCKS_NAMES = CHEMICAL_BLOCKS_NAMES.copy()
1522        if self.split_CNOPSSe:
1523            _CHEMICAL_BLOCKS_NAMES[1] = "C"
1524            _CHEMICAL_BLOCKS_NAMES.extend(["N","O","P","S","Se"])
1525        _CHEMICAL_BLOCKS = CHEMICAL_BLOCKS.copy()
1526        if self.split_CNOPSSe:
1527            _CHEMICAL_BLOCKS[6] = 1
1528            _CHEMICAL_BLOCKS[7] = len(CHEMICAL_BLOCKS_NAMES)
1529            _CHEMICAL_BLOCKS[8] = len(CHEMICAL_BLOCKS_NAMES)+1
1530            _CHEMICAL_BLOCKS[15] = len(CHEMICAL_BLOCKS_NAMES)+2
1531            _CHEMICAL_BLOCKS[16] = len(CHEMICAL_BLOCKS_NAMES)+3
1532            _CHEMICAL_BLOCKS[34] = len(CHEMICAL_BLOCKS_NAMES)+4
1533        return _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS
1534
1535    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
1536        _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS = self.build_chemical_blocks()
1537
1538        species = np.array(inputs["species"], dtype=np.int32)
1539        blocks = _CHEMICAL_BLOCKS[species]
1540        nat = species.shape[0]
1541        set_blocks, counts = np.unique(blocks, return_counts=True)
1542
1543        new_state = {**state}
1544        state_up = {}
1545
1546        sizes = state.get("sizes", FrozenDict({}))
1547        new_sizes = {**sizes}
1548        up_sizes = False
1549        for s, c in zip(set_blocks, counts):
1550            if s < 0:
1551                continue
1552            key = (s, _CHEMICAL_BLOCKS_NAMES[s])
1553            if c > sizes.get(key, 0):
1554                up_sizes = True
1555                add_atoms = state.get("add_atoms", self.add_atoms)
1556                if add_margin:
1557                    add_atoms += state.get("add_atoms_margin", self.add_atoms_margin)
1558                new_sizes[key] = c + add_atoms
1559
1560        new_sizes = FrozenDict(new_sizes)
1561        if up_sizes:
1562            state_up["sizes"] = (new_sizes, sizes)
1563            new_state["sizes"] = new_sizes
1564
1565        block_index = {n:None for n in _CHEMICAL_BLOCKS_NAMES}
1566        for (_,n), c in new_sizes.items():
1567            block_index[n] = np.full(c, nat, dtype=np.int32)
1568        # block_index = {
1569            # n: np.full(c, nat, dtype=np.int32)
1570            # for (_,n), c in new_sizes.items()
1571        # }
1572        for s, c in zip(set_blocks, counts):
1573            if s < 0:
1574                continue
1575            block_index[_CHEMICAL_BLOCKS_NAMES[s]][:c] = np.nonzero(blocks == s)[0]
1576
1577        output = {
1578            **inputs,
1579            self.output_key: block_index,
1580            self.output_key + "_overflow": False,
1581        }
1582
1583        if return_state_update:
1584            return FrozenDict(new_state), output, state_up
1585        return FrozenDict(new_state), output
1586
1587    def check_reallocate(self, state, inputs, parent_overflow=False):
1588        """Check for overflow and reallocate the block index if necessary."""
1589        overflow = parent_overflow or inputs[self.output_key + "_overflow"]
1590        if not overflow:
1591            return state, {}, inputs, False
1592
1593        add_margin = inputs[self.output_key + "_overflow"]
1594        state, inputs, state_up = self(
1595            state, inputs, return_state_update=True, add_margin=add_margin
1596        )
1597        return state, state_up, inputs, True
1598        # return state, {}, inputs, parent_overflow
1599
1600    @partial(jax.jit, static_argnums=(0, 1))
1601    def process(self, state, inputs):
1602        _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS = self.build_chemical_blocks()
1603        # assert (
1604        #     self.output_key in inputs
1605        # ), f"Species Index {self.output_key} must be provided on accelerator. Call the numpy routine (self.__call__) first."
1606
1607        recompute_species_index = "recompute_species_index" in inputs.get("flags", {})
1608        if self.output_key in inputs and not recompute_species_index:
1609            return inputs
1610
1611        if state is None:
1612            raise ValueError("Block Indexer state must be provided on accelerator.")
1613
1614        species = inputs["species"]
1615        blocks = jnp.asarray(_CHEMICAL_BLOCKS)[species]
1616        nat = species.shape[0]
1617
1618        sizes = state["sizes"]
1619
1620        # species_index = {
1621        # PERIODIC_TABLE[s]: jnp.nonzero(species == s, size=c, fill_value=nat)[0]
1622        # for s, c in sizes.items()
1623        # }
1624        block_index = {n: None for n in _CHEMICAL_BLOCKS_NAMES}
1625        overflow = False
1626        natcount = 0
1627        for (s,name), c in sizes.items():
1628            mask = blocks == s
1629            new_size = jnp.sum(mask)
1630            natcount = natcount + new_size
1631            overflow = overflow | (new_size > c)  # check if sizes are correct
1632            block_index[name] = jnp.nonzero(
1633                mask, size=c, fill_value=nat
1634            )[0]
1635
1636        mask = blocks < 0
1637        new_size = jnp.sum(mask)
1638        natcount = natcount + new_size
1639        overflow = overflow | (
1640            natcount < species.shape[0]
1641        )  # check if any species missing
1642
1643        return {
1644            **inputs,
1645            self.output_key: block_index,
1646            self.output_key + "_overflow": overflow,
1647        }
1648
1649    @partial(jax.jit, static_argnums=(0,))
1650    def update_skin(self, inputs):
1651        return self.process(None, inputs)
1652
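# --- Editor's illustrative sketch (not part of the original source) ---
# BlockIndexer works like SpeciesIndexer but first maps atomic numbers to
# chemical-block ids through a lookup table, then indexes atoms per block.
# The toy table below stands in for CHEMICAL_BLOCKS:
#
# import numpy as np
#
# toy_blocks = np.array([-1, 0, 0, 1, 1, 1, 2, 2, 2])  # indexed by atomic number
# species = np.array([1, 6, 8, 1])                     # H, C, O, H
# blocks = toy_blocks[species]                         # -> [0, 2, 2, 0]
# for b in np.unique(blocks):
#     print(b, np.nonzero(blocks == b)[0])             # atom indices per block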
1653
1654@dataclasses.dataclass(frozen=True)
1655class AtomPadding:
1656    """Pad atomic arrays to a fixed size."""
1657
1658    mult_size: float = 1.2
1659    """Multiplicative factor for resizing the atomic arrays."""
1660    add_sys: int = 0
1661    """Additional system slots to reserve when resizing."""
1662    def init(self):
1663        return {"prev_nat": 0, "prev_nsys": 0}
1664
1665    def __call__(self, state, inputs: Dict) -> Tuple[Dict, Dict]:
1666        species = inputs["species"]
1667        nat = species.shape[0]
1668
1669        prev_nat = state.get("prev_nat", 0)
1670        prev_nat_ = prev_nat
1671        if nat > prev_nat_:
1672            prev_nat_ = int(self.mult_size * nat) + 1
1673
1674        nsys = len(inputs["natoms"])
1675        prev_nsys = state.get("prev_nsys", 0)
1676        prev_nsys_ = prev_nsys
1677        if nsys > prev_nsys_:
1678            prev_nsys_ = nsys + self.add_sys
1679
1680        add_atoms = prev_nat_ - nat
1681        add_sys = prev_nsys_ - nsys + 1  # one extra system hosts the padding atoms
1682        output = {**inputs}
1683        if add_atoms > 0:
1684            for k, v in inputs.items():
1685                if isinstance(v, np.ndarray) or isinstance(v, jax.Array):
1686                    if v.shape[0] == nat:
1687                        output[k] = np.append(
1688                            v,
1689                            np.zeros((add_atoms, *v.shape[1:]), dtype=v.dtype),
1690                            axis=0,
1691                        )
1692                    elif v.shape[0] == nsys:
1693                        if k == "cells":
1694                            output[k] = np.append(
1695                                v,
1696                                1000
1697                                * np.eye(3, dtype=v.dtype)[None, :, :].repeat(
1698                                    add_sys, axis=0
1699                                ),
1700                                axis=0,
1701                            )
1702                        else:
1703                            output[k] = np.append(
1704                                v,
1705                                np.zeros((add_sys, *v.shape[1:]), dtype=v.dtype),
1706                                axis=0,
1707                            )
1708            output["natoms"] = np.append(
1709                inputs["natoms"], np.zeros(add_sys, dtype=np.int32)
1710            )
1711            output["species"] = np.append(
1712                species, -1 * np.ones(add_atoms, dtype=species.dtype)
1713            )
1714            output["batch_index"] = np.append(
1715                inputs["batch_index"],
1716                np.array([output["natoms"].shape[0] - 1] * add_atoms, dtype=inputs["batch_index"].dtype),
1717            )
1718            if "system_index" in inputs:
1719                output["system_index"] = np.append(
1720                    inputs["system_index"],
1721                    np.array([output["natoms"].shape[0] - 1] * add_sys, dtype=inputs["system_index"].dtype),
1722                )
1723
1724        output["true_atoms"] = output["species"] > 0
1725        output["true_sys"] = np.arange(len(output["natoms"])) < nsys
1726
1727        state = {**state, "prev_nat": prev_nat_, "prev_nsys": prev_nsys_}
1728
1729        return FrozenDict(state), output
1730
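# --- Editor's illustrative sketch (not part of the original source) ---
# Padding keeps array shapes constant across calls so that jitted model
# functions are not retraced for every new system size; padding atoms get
# species -1 and are masked out downstream through `true_atoms`. A toy version:
#
# import numpy as np
#
# def pad_species(species, target_nat):
#     add = target_nat - species.shape[0]
#     return np.append(species, -np.ones(add, dtype=species.dtype))
#
# padded = pad_species(np.array([8, 1, 1]), target_nat=8)
# true_atoms = padded > 0        # [True, True, True, False, ...]
# unpadded = padded[true_atoms]  # the inverse operation, as in atom_unpadding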
1731
1732def atom_unpadding(inputs: Dict[str, Any]) -> Dict[str, Any]:
1733    """Remove padding from atomic arrays."""
1734    if "true_atoms" not in inputs:
1735        return inputs
1736
1737    species = np.asarray(inputs["species"])
1738    true_atoms = np.asarray(inputs["true_atoms"])
1739    true_sys = np.asarray(inputs["true_sys"])
1740    natall = species.shape[0]
1741    nat = np.argmax(species <= 0)
1742    if nat == 0:
1743        return inputs
1744
1745    natoms = inputs["natoms"]
1746    nsysall = len(natoms)
1747
1748    output = {**inputs}
1749    for k, v in inputs.items():
1750        if isinstance(v, jax.Array) or isinstance(v, np.ndarray):
1751            v = np.asarray(v)
1752            if v.ndim == 0:
1753                output[k] = v
1754            elif v.shape[0] == natall:
1755                output[k] = v[true_atoms]
1756            elif v.shape[0] == nsysall:
1757                output[k] = v[true_sys]
1758    del output["true_sys"]
1759    del output["true_atoms"]
1760    return output
1761
1762
1763def check_input(inputs):
1764    """Check the input dictionary for required keys and types."""
1765    assert "species" in inputs, "species must be provided"
1766    assert "coordinates" in inputs, "coordinates must be provided"
1767    species = inputs["species"].astype(np.int32)
1768    ifake = np.argmax(species <= 0)
1769    if ifake > 0:
1770        assert np.all(species[:ifake] > 0), "species must be positive"
1771    nat = inputs["species"].shape[0]
1772
1773    natoms = inputs.get("natoms", np.array([nat], dtype=np.int32)).astype(np.int32)
1774    batch_index = inputs.get(
1775        "batch_index", np.repeat(np.arange(len(natoms), dtype=np.int32), natoms)
1776    ).astype(np.int32)
1777    output = {**inputs, "natoms": natoms, "batch_index": batch_index}
1778    if "cells" in inputs:
1779        cells = inputs["cells"]
1780        if "reciprocal_cells" not in inputs:
1781            reciprocal_cells = np.linalg.inv(cells)
1782        else:
1783            reciprocal_cells = inputs["reciprocal_cells"]
1784        if cells.ndim == 2:
1785            cells = cells[None, :, :]
1786        if reciprocal_cells.ndim == 2:
1787            reciprocal_cells = reciprocal_cells[None, :, :]
1788        output["cells"] = cells
1789        output["reciprocal_cells"] = reciprocal_cells
1790
1791    return output
1792
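# --- Editor's illustrative sketch (not part of the original source) ---
# check_input fills in `natoms` and `batch_index` when they are absent, so a
# minimal single-system input only needs species and coordinates:
#
# import numpy as np
#
# raw = {
#     "species": np.array([8, 1, 1], dtype=np.int32),
#     "coordinates": np.zeros((3, 3), dtype=np.float32),
# }
# out = check_input(raw)
# # out["natoms"] -> [3]; out["batch_index"] -> [0, 0, 0]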
1793
1794def convert_to_jax(data):
1795    """Convert numpy arrays to jax arrays in a pytree."""
1796
1797    def convert(x):
1798        if isinstance(x, np.ndarray):
1799            # if x.dtype == np.float64:
1800            #     return jnp.asarray(x, dtype=jnp.float32)
1801            return jnp.asarray(x)
1802        return x
1803
1804    return jax.tree_util.tree_map(convert, data)
1805
1806
1807class JaxConverter(nn.Module):
1808    """Convert numpy arrays to jax arrays in a pytree."""
1809
1810    def __call__(self, data):
1811        return convert_to_jax(data)
1812
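# --- Editor's note (illustrative, not part of the original source) ---
# convert_to_jax walks an arbitrary pytree and converts only numpy leaves;
# other leaves (python scalars, strings, ...) pass through untouched:
#
# import numpy as np
#
# tree = {"coords": np.zeros((2, 3)), "name": "water", "nat": 2}
# jtree = convert_to_jax(tree)  # "coords" becomes a jax.Array, rest unchanged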
1813
1814@dataclasses.dataclass(frozen=True)
1815class PreprocessingChain:
1816    """Chain of preprocessing layers."""
1817
1818    layers: Tuple[Callable[..., Dict[str, Any]], ...]
1819    """Preprocessing layers."""
1820    use_atom_padding: bool = False
1821    """Add an AtomPadding layer at the beginning of the chain."""
1822    atom_padder: AtomPadding = AtomPadding()
1823    """AtomPadding layer."""
1824
1825    def __post_init__(self):
1826        if not isinstance(self.layers, Sequence):
1827            raise ValueError(
1828                f"'layers' must be a sequence, got '{type(self.layers).__name__}'."
1829            )
1830        if not self.layers:
1831            raise ValueError("No preprocessing layers were provided.")
1832
1833    def __call__(self, state, inputs: Dict[str, Any]) -> Dict[str, Any]:
1834        do_check_input = state.get("check_input", True)
1835        if do_check_input:
1836            inputs = check_input(inputs)
1837        new_state = {**state}
1838        if self.use_atom_padding:
1839            s, inputs = self.atom_padder(state["padder_state"], inputs)
1840            new_state["padder_state"] = s
1841        layer_state = state["layers_state"]
1842        new_layer_state = []
1843        for i,layer in enumerate(self.layers):
1844            s, inputs = layer(layer_state[i], inputs, return_state_update=False)
1845            new_layer_state.append(s)
1846        new_state["layers_state"] = tuple(new_layer_state)
1847        return FrozenDict(new_state), convert_to_jax(inputs)
1848
1849    def check_reallocate(self, state, inputs):
1850        new_state = []
1851        state_up = []
1852        layer_state = state["layers_state"]
1853        parent_overflow = False
1854        for i,layer in enumerate(self.layers):
1855            s, s_up, inputs, parent_overflow = layer.check_reallocate(
1856                layer_state[i], inputs, parent_overflow
1857            )
1858            new_state.append(s)
1859            state_up.append(s_up)
1860
1861        if not parent_overflow:
1862            return state, {}, inputs, False
1863        return (
1864            FrozenDict({**state, "layers_state": tuple(new_state)}),
1865            state_up,
1866            inputs,
1867            True,
1868        )
1869
1870    def atom_padding(self, state, inputs):
1871        if self.use_atom_padding:
1872            padder_state,inputs = self.atom_padder(state["padder_state"], inputs)
1873            return FrozenDict({**state,"padder_state": padder_state}), inputs
1874        return state, inputs
1875
1876    @partial(jax.jit, static_argnums=(0, 1))
1877    def process(self, state, inputs):
1878        layer_state = state["layers_state"]
1879        for i,layer in enumerate(self.layers):
1880            inputs = layer.process(layer_state[i], inputs)
1881        return inputs
1882
1883    @partial(jax.jit, static_argnums=(0,))
1884    def update_skin(self, inputs):
1885        for layer in self.layers:
1886            inputs = layer.update_skin(inputs)
1887        return inputs
1888
1889    def init(self):
1890        state = {"check_input": True}
1891        if self.use_atom_padding:
1892            state["padder_state"] = self.atom_padder.init()
1893        layer_state = []
1894        for layer in self.layers:
1895            layer_state.append(layer.init())
1896        state["layers_state"] = tuple(layer_state)
1897        return FrozenDict(state)
1898
1899    def init_with_output(self, inputs):
1900        state = self.init()
1901        return self(state, inputs)
1902
1903    def get_processors(self):
1904        processors = []
1905        for layer in self.layers:
1906            if hasattr(layer, "get_processor"):
1907                processors.append(layer.get_processor())
1908        return processors
1909
1910    def get_graphs_properties(self):
1911        properties = {}
1912        for layer in self.layers:
1913            if hasattr(layer, "get_graph_properties"):
1914                properties = deep_update(properties, layer.get_graph_properties())
1915        return properties
1916
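# --- Editor's illustrative sketch (hypothetical configuration) ---
# Typical wiring of a chain, assuming the GraphGenerator defined in this
# module and a raw input dictionary `raw` as accepted by check_input:
#
# chain = PreprocessingChain(
#     layers=(GraphGenerator(cutoff=5.0),),
#     use_atom_padding=True,
# )
# state = chain.init()
# state, inputs = chain(state, raw)      # CPU/numpy pass: builds and resizes nblists
# inputs = chain.process(state, inputs)  # jitted pass with fixed shapes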
1917
1918# PREPROCESSING = {
1919#     "GRAPH": GraphGenerator,
1920#     # "GRAPH_FIXED": GraphGeneratorFixed,
1921#     "GRAPH_FILTER": GraphFilter,
1922#     "GRAPH_ANGULAR_EXTENSION": GraphAngularExtension,
1923#     # "GRAPH_DENSE_EXTENSION": GraphDenseExtension,
1924#     "SPECIES_INDEXER": SpeciesIndexer,
1925# }
@dataclasses.dataclass(frozen=True)
class GraphGenerator:
 22@dataclasses.dataclass(frozen=True)
 23class GraphGenerator:
 24    """Generate a graph from a set of coordinates
 25
 26    FPID: GRAPH
 27
 28    For now, we generate all pairs of atoms and filter based on cutoff.
 29    If a `nblist_skin` is present in the state, we generate a second graph with a larger cutoff that includes all pairs within the cutoff+skin. This graph is then reused by the `update_skin` method to update the original graph without recomputing the full nblist.
 30    """
 31
 32    cutoff: float
 33    """Cutoff distance for the graph."""
 34    graph_key: str = "graph"
 35    """Key of the graph in the outputs."""
 36    switch_params: dict = dataclasses.field(default_factory=dict, hash=False)
 37    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""
 38    kmax: int = 30
 39    """Maximum number of k-points to consider."""
 40    kthr: float = 1e-6
 41    """Threshold for k-point filtering."""
 42    k_space: bool = False
 43    """Whether to generate k-space information for the graph."""
 44    mult_size: float = 1.05
 45    """Multiplicative factor for resizing the nblist."""
 46    # covalent_cutoff: bool = False
 47
 48    FPID: ClassVar[str] = "GRAPH"
 49
 50    def init(self):
 51        return FrozenDict(
 52            {
 53                "max_nat": 1,
 54                "npairs": 1,
 55                "nblist_mult_size": self.mult_size,
 56            }
 57        )
 58
 59    def get_processor(self) -> Tuple[nn.Module, Dict]:
 60        return GraphProcessor, {
 61            "cutoff": self.cutoff,
 62            "graph_key": self.graph_key,
 63            "switch_params": self.switch_params,
 64            "name": f"{self.graph_key}_Processor",
 65        }
 66
 67    def get_graph_properties(self):
 68        return {
 69            self.graph_key: {
 70                "cutoff": self.cutoff,
 71                "directed": True,
 72            }
 73        }
 74
 75    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
 76        """build a nblist on cpu with numpy and dynamic shapes + store max shapes"""
 77        if self.graph_key in inputs:
 78            graph = inputs[self.graph_key]
 79            if "keep_graph" in graph:
 80                return state, inputs
 81
 82        coords = np.array(inputs["coordinates"], dtype=np.float32)
 83        natoms = np.array(inputs["natoms"], dtype=np.int32)
 84        batch_index = np.array(inputs["batch_index"], dtype=np.int32)
 85
 86        new_state = {**state}
 87        state_up = {}
 88
 89        mult_size = state.get("nblist_mult_size", self.mult_size)
 90        assert mult_size >= 1.0, "mult_size should be larger or equal than 1.0"
 91
 92        if natoms.shape[0] == 1:
 93            max_nat = coords.shape[0]
 94            true_max_nat = max_nat
 95        else:
 96            max_nat = state.get("max_nat", round(coords.shape[0] / natoms.shape[0]))
 97            true_max_nat = int(np.max(natoms))
 98            if true_max_nat > max_nat:
 99                add_atoms = state.get("add_atoms", 0)
100                new_maxnat = true_max_nat + add_atoms
101                state_up["max_nat"] = (new_maxnat, max_nat)
102                new_state["max_nat"] = new_maxnat
103
104        cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0)
105
106        ### compute indices of all pairs
107        p1, p2 = np.triu_indices(true_max_nat, 1)
108        p1, p2 = p1.astype(np.int32), p2.astype(np.int32)
109        pbc_shifts = None
110        if natoms.shape[0] > 1:
111            ## batching => mask irrelevant pairs
112            mask_p12 = (
113                (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None])
114            ).flatten()
115            shift = np.concatenate(
116                (np.array([0], dtype=np.int32), np.cumsum(natoms[:-1], dtype=np.int32))
117            )
118            p1 = np.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1)
119            p2 = np.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1)
120
121        apply_pbc = "cells" in inputs
122        if not apply_pbc:
123            ### NO PBC
124            vec = coords[p2] - coords[p1]
125        else:
126            cells = np.array(inputs["cells"], dtype=np.float32)
127            reciprocal_cells = np.array(inputs["reciprocal_cells"], dtype=np.float32)
128            minimage = "minimum_image" in inputs.get("flags", {})
129            if minimage:
130                ## MINIMUM IMAGE CONVENTION
131                vec = coords[p2] - coords[p1]
132                if cells.shape[0] == 1:
133                    vecpbc = np.dot(vec, reciprocal_cells[0])
134                    pbc_shifts = -np.round(vecpbc)
135                    vec = vec + np.dot(pbc_shifts, cells[0])
136                else:
137                    batch_index_vec = batch_index[p1]
138                    vecpbc = np.einsum(
139                        "aj,aji->ai", vec, reciprocal_cells[batch_index_vec]
140                    )
141                    pbc_shifts = -np.round(vecpbc)
142                    vec = vec + np.einsum(
143                        "aj,aji->ai", pbc_shifts, cells[batch_index_vec]
144                    )
145            else:
146                ### GENERAL PBC
147                ## put all atoms in central box
148                if cells.shape[0] == 1:
149                    coords_pbc = np.dot(coords, reciprocal_cells[0])
150                    at_shifts = -np.floor(coords_pbc)
151                    coords_pbc = coords + np.dot(at_shifts, cells[0])
152                else:
153                    coords_pbc = np.einsum(
154                        "aj,aji->ai", coords, reciprocal_cells[batch_index]
155                    )
156                    at_shifts = -np.floor(coords_pbc)
157                    coords_pbc = coords + np.einsum(
158                        "aj,aji->ai", at_shifts, cells[batch_index]
159                    )
160                vec = coords_pbc[p2] - coords_pbc[p1]
161
162                ## compute maximum number of repeats
163                inv_distances = (np.sum(reciprocal_cells**2, axis=1)) ** 0.5
164                cdinv = cutoff_skin * inv_distances
165                num_repeats_all = np.ceil(cdinv).astype(np.int32)
166                if "true_sys" in inputs:
167                    num_repeats_all = np.where(np.array(inputs["true_sys"],dtype=bool)[:, None], num_repeats_all, 0)
168                # num_repeats_all = np.where(cdinv < 0.5, 0, num_repeats_all)
169                num_repeats = np.max(num_repeats_all, axis=0)
170                num_repeats_prev = np.array(state.get("num_repeats_pbc", (0, 0, 0)))
171                if np.any(num_repeats > num_repeats_prev):
172                    num_repeats_new = np.maximum(num_repeats, num_repeats_prev)
173                    state_up["num_repeats_pbc"] = (
174                        tuple(num_repeats_new),
175                        tuple(num_repeats_prev),
176                    )
177                    new_state["num_repeats_pbc"] = tuple(num_repeats_new)
178                ## build all possible shifts
179                cell_shift_pbc = np.array(
180                    np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]),
181                    dtype=cells.dtype,
182                ).T.reshape(-1, 3)
183                ## shift applied to vectors
184                if cells.shape[0] == 1:
185                    dvec = np.dot(cell_shift_pbc, cells[0])[None, :, :]
186                    vec = (vec[:, None, :] + dvec).reshape(-1, 3)
187                    pbc_shifts = np.broadcast_to(
188                        cell_shift_pbc[None, :, :],
189                        (p1.shape[0], cell_shift_pbc.shape[0], 3),
190                    ).reshape(-1, 3)
191                    p1 = np.broadcast_to(
192                        p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0])
193                    ).flatten()
194                    p2 = np.broadcast_to(
195                        p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0])
196                    ).flatten()
197                    if natoms.shape[0] > 1:
198                        mask_p12 = np.broadcast_to(
199                            mask_p12[:, None],
200                            (mask_p12.shape[0], cell_shift_pbc.shape[0]),
201                        ).flatten()
202                else:
203                    dvec = np.einsum("bj,sji->sbi", cell_shift_pbc, cells)
204
205                    ## get pbc shifts specific to each box
206                    cell_shift_pbc = np.broadcast_to(
207                        cell_shift_pbc[None, :, :],
208                        (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3),
209                    )
210                    mask = np.all(
211                        np.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1
212                    ).flatten()
213                    idx = np.nonzero(mask)[0]
214                    nshifts = idx.shape[0]
215                    nshifts_prev = state.get("nshifts_pbc", 0)
216                    if nshifts > nshifts_prev or add_margin:
217                        nshifts_new = int(mult_size * max(nshifts, nshifts_prev)) + 1
218                        state_up["nshifts_pbc"] = (nshifts_new, nshifts_prev)
219                        new_state["nshifts_pbc"] = nshifts_new
220
221                    dvec_filter = dvec.reshape(-1, 3)[idx, :]
222                    cell_shift_pbc_filter = cell_shift_pbc.reshape(-1, 3)[idx, :]
223
224                    ## get batch shift in the dvec_filter array
225                    nrep = np.prod(2 * num_repeats_all + 1, axis=1)
226                    bshift = np.concatenate((np.array([0]), np.cumsum(nrep)[:-1]))
227
228                    ## compute vectors
229                    batch_index_vec = batch_index[p1]
230                    nrep_vec = np.where(mask_p12,nrep[batch_index_vec],0)
231                    vec = vec.repeat(nrep_vec, axis=0)
232                    nvec_pbc = nrep_vec.sum() #vec.shape[0]
233                    nvec_pbc_prev = state.get("nvec_pbc", 0)
234                    if nvec_pbc > nvec_pbc_prev or add_margin:
235                        nvec_pbc_new = int(mult_size * max(nvec_pbc, nvec_pbc_prev)) + 1
236                        state_up["nvec_pbc"] = (nvec_pbc_new, nvec_pbc_prev)
237                        new_state["nvec_pbc"] = nvec_pbc_new
238
239                    # print("cpu: ", nvec_pbc, nvec_pbc_prev, nshifts, nshifts_prev)
240                    ## get shift index
241                    dshift = np.concatenate(
242                        (np.array([0]), np.cumsum(nrep_vec)[:-1])
243                    ).repeat(nrep_vec)
244                    # ishift = np.arange(dshift.shape[0])-dshift
245                    # bshift_vec_rep = bshift[batch_index_vec].repeat(nrep_vec)
246                    icellshift = (
247                        np.arange(dshift.shape[0])
248                        - dshift
249                        + bshift[batch_index_vec].repeat(nrep_vec)
250                    )
251                    # shift vectors
252                    vec = vec + dvec_filter[icellshift]
253                    pbc_shifts = cell_shift_pbc_filter[icellshift]
254
255                    p1 = np.repeat(p1, nrep_vec)
256                    p2 = np.repeat(p2, nrep_vec)
257                    if natoms.shape[0] > 1:
258                        mask_p12 = np.repeat(mask_p12, nrep_vec)
259
260        ## compute distances
261        d12 = (vec**2).sum(axis=-1)
262        if natoms.shape[0] > 1:
263            d12 = np.where(mask_p12, d12, cutoff_skin**2)
264
265        ## filter pairs
266        max_pairs = state.get("npairs", 1)
267        mask = d12 < cutoff_skin**2
268        idx = np.nonzero(mask)[0]
269        npairs = idx.shape[0]
270        if npairs > max_pairs or add_margin:
271            prev_max_pairs = max_pairs
272            max_pairs = int(mult_size * max(npairs, max_pairs)) + 1
273            state_up["npairs"] = (max_pairs, prev_max_pairs)
274            new_state["npairs"] = max_pairs
275
276        nat = coords.shape[0]
277        edge_src = np.full(max_pairs, nat, dtype=np.int32)
278        edge_dst = np.full(max_pairs, nat, dtype=np.int32)
279        d12_ = np.full(max_pairs, cutoff_skin**2)
280        edge_src[:npairs] = p1[idx]
281        edge_dst[:npairs] = p2[idx]
282        d12_[:npairs] = d12[idx]
283        d12 = d12_
284
285        if apply_pbc:
286            pbc_shifts_ = np.zeros((max_pairs, 3))
287            pbc_shifts_[:npairs] = pbc_shifts[idx]
288            pbc_shifts = pbc_shifts_
289            if not minimage:
290                pbc_shifts[:npairs] = (
291                    pbc_shifts[:npairs]
292                    + at_shifts[edge_dst[:npairs]]
293                    - at_shifts[edge_src[:npairs]]
294                )
295
296        if "nblist_skin" in state:
297            edge_src_skin = edge_src
298            edge_dst_skin = edge_dst
299            if apply_pbc:
300                pbc_shifts_skin = pbc_shifts
301            max_pairs_skin = state.get("npairs_skin", 1)
302            mask = d12 < self.cutoff**2
303            idx = np.nonzero(mask)[0]
304            npairs_skin = idx.shape[0]
305            if npairs_skin > max_pairs_skin or add_margin:
306                prev_max_pairs_skin = max_pairs_skin
307                max_pairs_skin = int(mult_size * max(npairs_skin, max_pairs_skin)) + 1
308                state_up["npairs_skin"] = (max_pairs_skin, prev_max_pairs_skin)
309                new_state["npairs_skin"] = max_pairs_skin
310            edge_src = np.full(max_pairs_skin, nat, dtype=np.int32)
311            edge_dst = np.full(max_pairs_skin, nat, dtype=np.int32)
312            d12_ = np.full(max_pairs_skin, self.cutoff**2)
313            edge_src[:npairs_skin] = edge_src_skin[idx]
314            edge_dst[:npairs_skin] = edge_dst_skin[idx]
315            d12_[:npairs_skin] = d12[idx]
316            d12 = d12_
317            if apply_pbc:
318                pbc_shifts = np.full((max_pairs_skin, 3), 0.0)
319                pbc_shifts[:npairs_skin] = pbc_shifts_skin[idx]
320
321        ## symmetrize
322        edge_src, edge_dst = np.concatenate((edge_src, edge_dst)), np.concatenate(
323            (edge_dst, edge_src)
324        )
325        d12 = np.concatenate((d12, d12))
326        if apply_pbc:
327            pbc_shifts = np.concatenate((pbc_shifts, -pbc_shifts))
328
329        graph = inputs.get(self.graph_key, {})
330        graph_out = {
331            **graph,
332            "edge_src": edge_src,
333            "edge_dst": edge_dst,
334            "d12": d12,
335            "overflow": False,
336            "pbc_shifts": pbc_shifts,
337        }
338        if "nblist_skin" in state:
339            graph_out["edge_src_skin"] = edge_src_skin
340            graph_out["edge_dst_skin"] = edge_dst_skin
341            if apply_pbc:
342                graph_out["pbc_shifts_skin"] = pbc_shifts_skin
343
344        if self.k_space and apply_pbc:
345            if "k_points" not in graph:
346                ks, _, _, bewald = get_reciprocal_space_parameters(
347                    reciprocal_cells, self.cutoff, self.kmax, self.kthr
348                )
349            graph_out["k_points"] = ks
350            graph_out["b_ewald"] = bewald
351
352        output = {**inputs, self.graph_key: graph_out}
353
354        if return_state_update:
355            return FrozenDict(new_state), output, state_up
356        return FrozenDict(new_state), output
357
358    def check_reallocate(self, state, inputs, parent_overflow=False):
359        """check for overflow and reallocate nblist if necessary"""
360        overflow = parent_overflow or inputs[self.graph_key].get("overflow", False)
361        if not overflow:
362            return state, {}, inputs, False
363
364        add_margin = inputs[self.graph_key].get("overflow", False)
365        state, inputs, state_up = self(
366            state, inputs, return_state_update=True, add_margin=add_margin
367        )
368        return state, state_up, inputs, True
369
370    @partial(jax.jit, static_argnums=(0, 1))
371    def process(self, state, inputs):
372        """build a nblist on accelerator with jax and precomputed shapes"""
373        if self.graph_key in inputs:
374            graph = inputs[self.graph_key]
375            if "keep_graph" in graph:
376                return inputs
377        coords = inputs["coordinates"]
378        natoms = inputs["natoms"]
379        batch_index = inputs["batch_index"]
380
381        if natoms.shape[0] == 1:
382            max_nat = coords.shape[0]
383        else:
384            max_nat = state.get(
385                "max_nat", int(round(coords.shape[0] / natoms.shape[0]))
386            )
387        cutoff_skin = self.cutoff + state.get("nblist_skin", 0.0)
388
389        ### compute indices of all pairs
390        p1, p2 = np.triu_indices(max_nat, 1)
391        p1, p2 = p1.astype(np.int32), p2.astype(np.int32)
392        pbc_shifts = None
393        if natoms.shape[0] > 1:
394            ## batching => mask irrelevant pairs
395            mask_p12 = (
396                (p1[None, :] < natoms[:, None]) * (p2[None, :] < natoms[:, None])
397            ).flatten()
398            shift = jnp.concatenate(
399                (jnp.array([0], dtype=jnp.int32), jnp.cumsum(natoms[:-1]))
400            )
401            p1 = jnp.where(mask_p12, (p1[None, :] + shift[:, None]).flatten(), -1)
402            p2 = jnp.where(mask_p12, (p2[None, :] + shift[:, None]).flatten(), -1)
403
404        ## compute vectors
405        overflow_repeats = jnp.asarray(False, dtype=bool)
406        if "cells" not in inputs:
407            vec = coords[p2] - coords[p1]
408        else:
409            cells = inputs["cells"]
410            reciprocal_cells = inputs["reciprocal_cells"]
411            # minimage = state.get("minimum_image", True)
412            minimage = "minimum_image" in inputs.get("flags", {})
413
414            def compute_pbc(vec, reciprocal_cell, cell, mode="round"):
415                vecpbc = jnp.dot(vec, reciprocal_cell)
416                if mode == "round":
417                    pbc_shifts = -jnp.round(vecpbc)
418                elif mode == "floor":
419                    pbc_shifts = -jnp.floor(vecpbc)
420                else:
421                    raise NotImplementedError(f"Unknown mode {mode} for compute_pbc.")
422                return vec + jnp.dot(pbc_shifts, cell), pbc_shifts
423
424            if minimage:
425                ## minimum image convention
426                vec = coords[p2] - coords[p1]
427
428                if cells.shape[0] == 1:
429                    vec, pbc_shifts = compute_pbc(vec, reciprocal_cells[0], cells[0])
430                else:
431                    batch_index_vec = batch_index[p1]
432                    vec, pbc_shifts = jax.vmap(compute_pbc)(
433                        vec, reciprocal_cells[batch_index_vec], cells[batch_index_vec]
434                    )
435            else:
436                ### general PBC only for single cell yet
437                # if cells.shape[0] > 1:
438                #     raise NotImplementedError(
439                #         "General PBC not implemented for batches on accelerator."
440                #     )
441                # cell = cells[0]
442                # reciprocal_cell = reciprocal_cells[0]
443
444                ## put all atoms in central box
445                if cells.shape[0] == 1:
446                    coords_pbc, at_shifts = compute_pbc(
447                        coords, reciprocal_cells[0], cells[0], mode="floor"
448                    )
449                else:
450                    coords_pbc, at_shifts = jax.vmap(
451                        partial(compute_pbc, mode="floor")
452                    )(coords, reciprocal_cells[batch_index], cells[batch_index])
453                vec = coords_pbc[p2] - coords_pbc[p1]
454                num_repeats = state.get("num_repeats_pbc", (0, 0, 0))
455                # if num_repeats is None:
456                #     raise ValueError(
457                #         "num_repeats_pbc should be provided for general PBC on accelerator. Call the numpy routine (self.__call__) first."
458                #     )
459                # check if num_repeats is larger than previous
460                inv_distances = jnp.linalg.norm(reciprocal_cells, axis=1)
461                cdinv = cutoff_skin * inv_distances
462                num_repeats_all = jnp.ceil(cdinv).astype(jnp.int32)
463                if "true_sys" in inputs:
464                    num_repeats_all = jnp.where(inputs["true_sys"][:,None], num_repeats_all, 0)
465                num_repeats_new = jnp.max(num_repeats_all, axis=0)
466                overflow_repeats = jnp.any(num_repeats_new > jnp.asarray(num_repeats))
467
468                cell_shift_pbc = jnp.asarray(
469                    np.array(
470                        np.meshgrid(*[np.arange(-n, n + 1) for n in num_repeats]),
471                        dtype=cells.dtype,
472                    ).T.reshape(-1, 3)
473                )
474
475                if cells.shape[0] == 1:
476                    vec = (vec[:,None,:] + jnp.dot(cell_shift_pbc, cells[0])[None, :, :]).reshape(-1, 3)    
477                    pbc_shifts = jnp.broadcast_to(
478                        cell_shift_pbc[None, :, :],
479                        (p1.shape[0], cell_shift_pbc.shape[0], 3),
480                    ).reshape(-1, 3)
481                    p1 = jnp.broadcast_to(
482                        p1[:, None], (p1.shape[0], cell_shift_pbc.shape[0])
483                    ).flatten()
484                    p2 = jnp.broadcast_to(
485                        p2[:, None], (p2.shape[0], cell_shift_pbc.shape[0])
486                    ).flatten()
487                    if natoms.shape[0] > 1:
488                        mask_p12 = jnp.broadcast_to(
489                            mask_p12[:, None], (mask_p12.shape[0], cell_shift_pbc.shape[0])
490                        ).flatten()
491                else:
492                    dvec = jnp.einsum("bj,sji->sbi", cell_shift_pbc, cells).reshape(-1, 3)
493
494                    ## get pbc shifts specific to each box
495                    cell_shift_pbc = jnp.broadcast_to(
496                        cell_shift_pbc[None, :, :],
497                        (num_repeats_all.shape[0], cell_shift_pbc.shape[0], 3),
498                    )
499                    mask = jnp.all(
500                        jnp.abs(cell_shift_pbc) <= num_repeats_all[:, None, :], axis=-1
501                    ).flatten()
502                    max_shifts  = state.get("nshifts_pbc", 1)
503
504                    cell_shift_pbc = cell_shift_pbc.reshape(-1,3)
505                    shiftx,shifty,shiftz = cell_shift_pbc[:,0],cell_shift_pbc[:,1],cell_shift_pbc[:,2]
506                    dvecx,dvecy,dvecz = dvec[:,0],dvec[:,1],dvec[:,2]
507                    (dvecx, dvecy,dvecz,shiftx,shifty,shiftz), scatter_idx, nshifts = mask_filter_1d(
508                        mask,
509                        max_shifts,
510                        (dvecx, 0.),
511                        (dvecy, 0.),
512                        (dvecz, 0.),
513                        (shiftx, 0),
514                        (shifty, 0),
515                        (shiftz, 0),
516                    )
517                    dvec = jnp.stack((dvecx,dvecy,dvecz),axis=-1)
518                    cell_shift_pbc = jnp.stack((shiftx,shifty,shiftz),axis=-1)
519                    overflow_repeats = overflow_repeats | (nshifts > max_shifts)
520
521                    ## get batch shift in the dvec_filter array
522                    nrep = jnp.prod(2 * num_repeats_all + 1, axis=1)
523                    bshift = jnp.concatenate((jnp.array([0],dtype=jnp.int32), jnp.cumsum(nrep)[:-1]))
524
525                    ## repeat vectors
526                    nvec_max = state.get("nvec_pbc", 1)
527                    batch_index_vec = batch_index[p1]
528                    nrep_vec = jnp.where(mask_p12,nrep[batch_index_vec],0)
529                    nvec = nrep_vec.sum()
530                    overflow_repeats = overflow_repeats | (nvec > nvec_max)
531                    vec = jnp.repeat(vec,nrep_vec,axis=0,total_repeat_length=nvec_max)
532                    # jax.debug.print("{nvec} {nvec_max} {nshifts} {max_shifts}",nvec=nvec,nvec_max=jnp.asarray(nvec_max),nshifts=nshifts,max_shifts=jnp.asarray(max_shifts))
533
534                    ## get shift index
535                    dshift = jnp.concatenate(
536                        (jnp.array([0],dtype=jnp.int32), jnp.cumsum(nrep_vec)[:-1])
537                    )
538                    if nrep_vec.size == 0:
539                        dshift = jnp.array([],dtype=jnp.int32)
540                    dshift = jnp.repeat(dshift,nrep_vec, total_repeat_length=nvec_max)
541                    bshift = jnp.repeat(bshift[batch_index_vec],nrep_vec, total_repeat_length=nvec_max)
542                    icellshift = jnp.arange(dshift.shape[0]) - dshift + bshift
543                    vec = vec + dvec[icellshift]
544                    pbc_shifts = cell_shift_pbc[icellshift]
545                    p1 = jnp.repeat(p1,nrep_vec, total_repeat_length=nvec_max)
546                    p2 = jnp.repeat(p2,nrep_vec, total_repeat_length=nvec_max)
547                    if natoms.shape[0] > 1:
548                        mask_p12 = jnp.repeat(mask_p12,nrep_vec, total_repeat_length=nvec_max)
549                
550
551        ## compute distances
552        d12 = (vec**2).sum(axis=-1)
553        if natoms.shape[0] > 1:
554            d12 = jnp.where(mask_p12, d12, cutoff_skin**2)
555
556        ## filter pairs
557        max_pairs = state.get("npairs", 1)
558        mask = d12 < cutoff_skin**2
559        (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d(
560            mask,
561            max_pairs,
562            (jnp.asarray(p1, dtype=jnp.int32), coords.shape[0]),
563            (jnp.asarray(p2, dtype=jnp.int32), coords.shape[0]),
564            (d12, cutoff_skin**2),
565        )
566        if "cells" in inputs:
567            pbc_shifts = (
568                jnp.full((max_pairs, 3), 0.0, dtype=pbc_shifts.dtype)
569                .at[scatter_idx]
570                .set(pbc_shifts, mode="drop")
571            )
572            if not minimage:
573                pbc_shifts = (
574                    pbc_shifts
575                    + at_shifts.at[edge_dst].get(fill_value=0.0)
576                    - at_shifts.at[edge_src].get(fill_value=0.0)
577                )
578
579        ## check for overflow
580        if natoms.shape[0] == 1:
581            true_max_nat = coords.shape[0]
582        else:
583            true_max_nat = jnp.max(natoms)
584        overflow_count = npairs > max_pairs
585        overflow_at = true_max_nat > max_nat
586        overflow = overflow_count | overflow_at | overflow_repeats
587
588        if "nblist_skin" in state:
589            # edge_mask_skin = edge_mask
590            edge_src_skin = edge_src
591            edge_dst_skin = edge_dst
592            if "cells" in inputs:
593                pbc_shifts_skin = pbc_shifts
594            max_pairs_skin = state.get("npairs_skin", 1)
595            mask = d12 < self.cutoff**2
596            (edge_src, edge_dst, d12), scatter_idx, npairs_skin = mask_filter_1d(
597                mask,
598                max_pairs_skin,
599                (edge_src, coords.shape[0]),
600                (edge_dst, coords.shape[0]),
601                (d12, self.cutoff**2),
602            )
603            if "cells" in inputs:
604                pbc_shifts = (
605                    jnp.full((max_pairs_skin, 3), 0.0, dtype=pbc_shifts.dtype)
606                    .at[scatter_idx]
607                    .set(pbc_shifts, mode="drop")
608                )
609            overflow = overflow | (npairs_skin > max_pairs_skin)
610
611        ## symmetrize
612        edge_src, edge_dst = jnp.concatenate((edge_src, edge_dst)), jnp.concatenate(
613            (edge_dst, edge_src)
614        )
615        d12 = jnp.concatenate((d12, d12))
616        if "cells" in inputs:
617            pbc_shifts = jnp.concatenate((pbc_shifts, -pbc_shifts))
618
619        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
620        graph_out = {
621            **graph,
622            "edge_src": edge_src,
623            "edge_dst": edge_dst,
624            "d12": d12,
625            "overflow": overflow,
626            "pbc_shifts": pbc_shifts,
627        }
628        if "nblist_skin" in state:
629            graph_out["edge_src_skin"] = edge_src_skin
630            graph_out["edge_dst_skin"] = edge_dst_skin
631            if "cells" in inputs:
632                graph_out["pbc_shifts_skin"] = pbc_shifts_skin
633
634        if self.k_space and "cells" in inputs:
635            if "k_points" not in graph:
636                raise NotImplementedError(
637                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
638                )
639        return {**inputs, self.graph_key: graph_out}
640
641    @partial(jax.jit, static_argnums=(0,))
642    def update_skin(self, inputs):
643        """update the nblist without recomputing the full nblist"""
644        graph = inputs[self.graph_key]
645
646        edge_src_skin = graph["edge_src_skin"]
647        edge_dst_skin = graph["edge_dst_skin"]
648        coords = inputs["coordinates"]
649        vec = coords.at[edge_dst_skin].get(
650            mode="fill", fill_value=self.cutoff
651        ) - coords.at[edge_src_skin].get(mode="fill", fill_value=0.0)
652
653        if "cells" in inputs:
654            pbc_shifts_skin = graph["pbc_shifts_skin"]
655            cells = inputs["cells"]
656            if cells.shape[0] == 1:
657                vec = vec + jnp.dot(pbc_shifts_skin, cells[0])
658            else:
659                batch_index_vec = inputs["batch_index"][edge_src_skin]
660                vec = vec + jax.vmap(jnp.dot)(pbc_shifts_skin, cells[batch_index_vec])
661
662        nat = coords.shape[0]
663        d12 = jnp.sum(vec**2, axis=-1)
664        mask = d12 < self.cutoff**2
665        max_pairs = graph["edge_src"].shape[0] // 2
666        (edge_src, edge_dst, d12), scatter_idx, npairs = mask_filter_1d(
667            mask,
668            max_pairs,
669            (edge_src_skin, nat),
670            (edge_dst_skin, nat),
671            (d12, self.cutoff**2),
672        )
673        if "cells" in inputs:
674            pbc_shifts = (
675                jnp.full((max_pairs, 3), 0.0, dtype=pbc_shifts_skin.dtype)
676                .at[scatter_idx]
677                .set(pbc_shifts_skin)
678            )
679
680        overflow = graph.get("overflow", False) | (npairs > max_pairs)
681        graph_out = {
682            **graph,
683            "edge_src": jnp.concatenate((edge_src, edge_dst)),
684            "edge_dst": jnp.concatenate((edge_dst, edge_src)),
685            "d12": jnp.concatenate((d12, d12)),
686            "overflow": overflow,
687        }
688        if "cells" in inputs:
689            graph_out["pbc_shifts"] = jnp.concatenate((pbc_shifts, -pbc_shifts))
690
691        if self.k_space and "cells" in inputs:
692            if "k_points" not in graph:
693                raise NotImplementedError(
694                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
695                )
696
697        return {**inputs, self.graph_key: graph_out}

Generate a graph from a set of coordinates

FPID: GRAPH

For now, we generate all pairs of atoms and filter them based on the cutoff. If a `nblist_skin` is present in the state, we also build a second graph with the enlarged cutoff `cutoff + nblist_skin` that includes all pairs within that distance. This skin graph is then reused by the `update_skin` method to refresh the original graph without recomputing the full nblist (see the usage sketch below).

GraphGenerator(cutoff: float, graph_key: str = 'graph', switch_params: dict = <factory>, kmax: int = 30, kthr: float = 1e-06, k_space: bool = False, mult_size: float = 1.05)
cutoff: float

Cutoff distance for the graph.

graph_key: str = 'graph'

Key of the graph in the outputs.

switch_params: dict

Parameters for the switching function. See fennol.models.misc.misc.SwitchFunction.

kmax: int = 30

Maximum number of k-points to consider.

kthr: float = 1e-06

Threshold for k-point filtering.

k_space: bool = False

Whether to generate k-space information for the graph.

mult_size: float = 1.05

Multiplicative factor for resizing the nblist.

FPID: ClassVar[str] = 'GRAPH'
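
A minimal usage sketch (hedged: the preparation of `inputs` is schematic here, and in a full FENNIX model these calls are orchestrated automatically):

    gen = GraphGenerator(cutoff=5.0, graph_key="graph")
    state = gen.init()
    # CPU pass with numpy and dynamic shapes: builds the nblist and records max sizes
    state, outputs = gen(state, inputs)
    # subsequent jitted passes reuse the recorded shapes
    outputs = gen.process(state, inputs)
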
def init(self):
def get_processor(self) -> Tuple[flax.linen.module.Module, Dict]:
def get_graph_properties(self):
def check_reallocate(self, state, inputs, parent_overflow=False):

check for overflow and reallocate nblist if necessary

@partial(jax.jit, static_argnums=(0, 1))
def process(self, state, inputs):

build a nblist on accelerator with jax and precomputed shapes
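
`process` assumes the buffer sizes recorded in `state` by a prior CPU call ("npairs", "max_nat", and optional skin sizes); when a capacity is exceeded it only raises the "overflow" flag, and `check_reallocate` rebuilds the nblist outside of jit. A hedged sketch of the intended pattern:

    outputs = gen.process(state, inputs)
    state, state_up, outputs, reallocated = gen.check_reallocate(state, outputs)
    # if `reallocated` is True, `outputs` already holds the rebuilt, larger nblist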

@partial(jax.jit, static_argnums=(0,))
def update_skin(self, inputs):

update the nblist without recomputing the full nblist
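
During dynamics, `update_skin` can stand in for a full rebuild as long as no atom has moved by more than about half the skin width since the last full construction. A hedged sketch of such a loop (`propagate` and `rebuild_stride` are illustrative names, not part of this module):

    outputs = gen.process(state, inputs)  # full build, stores the skin graph
    for step in range(nsteps):
        outputs["coordinates"] = propagate(outputs)   # hypothetical integrator step
        if step % rebuild_stride == 0:
            outputs = gen.process(state, outputs)     # periodic full rebuild
        else:
            outputs = gen.update_skin(outputs)        # cheap refresh within cutoff+skin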

class GraphProcessor(flax.linen.module.Module):
class GraphProcessor(nn.Module):
    """Process a pre-generated graph

    The pre-generated graph should contain the following keys:
    - edge_src: source indices of the edges
    - edge_dst: destination indices of the edges
    - pbc_shifts: pbc shifts for the edges (only if `cells` are present in the inputs)

    This module is automatically added to a FENNIX model when a GraphGenerator is used.

    """

    cutoff: float
    """Cutoff distance for the graph."""
    graph_key: str = "graph"
    """Key of the graph in the outputs."""
    switch_params: dict = dataclasses.field(default_factory=dict)
    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""

    @nn.compact
    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
        graph = inputs[self.graph_key]
        coords = inputs["coordinates"]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        # out-of-range (padding) indices yield a vector of norm >= cutoff,
        # so padded edges are excluded by the edge mask below
        vec = coords.at[edge_dst].get(mode="fill", fill_value=self.cutoff) - coords.at[
            edge_src
        ].get(mode="fill", fill_value=0.0)
        if "cells" in inputs:
            cells = inputs["cells"]
            if cells.shape[0] == 1:
                vec = vec + jnp.dot(graph["pbc_shifts"], cells[0])
            else:
                batch_index_vec = inputs["batch_index"][edge_src]
                vec = vec + jax.vmap(jnp.dot)(
                    graph["pbc_shifts"], cells[batch_index_vec]
                )

        d2 = jnp.sum(vec**2, axis=-1)
        distances = safe_sqrt(d2)
        edge_mask = distances < self.cutoff

        switch = SwitchFunction(
            **{**self.switch_params, "cutoff": self.cutoff, "graph_key": None}
        )((distances, edge_mask))

        graph_out = {
            **graph,
            "vec": vec,
            "distances": distances,
            "switch": switch,
            "edge_mask": edge_mask,
        }

        if "alch_group" in inputs:
            ## alchemical transformation: scale the switch and soften the
            ## distances for pairs that cross the alchemical group boundary
            alch_group = inputs["alch_group"]
            lambda_e = inputs["alch_elambda"]
            lambda_v = inputs["alch_vlambda"]
            mask = alch_group[edge_src] == alch_group[edge_dst]
            graph_out["switch_raw"] = switch
            graph_out["switch"] = jnp.where(
                mask,
                switch,
                0.5 * (1.0 - jnp.cos(jnp.pi * lambda_e)) * switch,
            )
            graph_out["distances_raw"] = distances
            if "alch_softcore_e" in inputs:
                alch_alpha = (1 - lambda_e) * inputs["alch_softcore_e"] ** 2
            else:
                alch_alpha = (1 - lambda_v) * inputs.get("alch_softcore_v", 0.5) ** 2

            graph_out["distances"] = jnp.where(
                mask,
                distances,
                safe_sqrt(alch_alpha + d2 * (1.0 - alch_alpha / self.cutoff**2)),
            )

        return {**inputs, self.graph_key: graph_out}

Process a pre-generated graph

The pre-generated graph should contain the following keys:

  • edge_src: source indices of the edges
  • edge_dst: destination indices of the edges
  • pbc_shifts: pbc shifts for the edges (only if cells are present in the inputs)

This module is automatically added to a FENNIX model when a GraphGenerator is used.
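
A hedged sketch of how a downstream module typically consumes the processed graph (`some_pair_potential` is purely illustrative):

    graph = outputs["graph"]
    # padded edges are masked out (their switch is zero at the fill distance)
    pair_energy = jnp.where(
        graph["edge_mask"],
        some_pair_potential(graph["distances"]) * graph["switch"],
        0.0,
    )
    # nat: number of atoms; out-of-range padded sources are dropped
    energy = jnp.zeros(nat).at[graph["edge_src"]].add(pair_energy, mode="drop")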

GraphProcessor(cutoff: float, graph_key: str = 'graph', switch_params: dict = <factory>, name: Optional[str] = None)
cutoff: float

Cutoff distance for the graph.

graph_key: str = 'graph'

Key of the graph in the outputs.

switch_params: dict

Parameters for the switching function. See fennol.models.misc.misc.SwitchFunction.

@dataclasses.dataclass(frozen=True)
class GraphFilter:
@dataclasses.dataclass(frozen=True)
class GraphFilter:
    """Filter a graph based on a cutoff distance

    FPID: GRAPH_FILTER
    """

    cutoff: float
    """Cutoff distance for the filtering."""
    parent_graph: str
    """Key of the parent graph in the inputs."""
    graph_key: str
    """Key of the filtered graph in the outputs."""
    remove_hydrogens: bool = False
    """Remove edges where the source is a hydrogen atom."""
    switch_params: FrozenDict = dataclasses.field(default_factory=FrozenDict)
    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""
    k_space: bool = False
    """Generate k-space information for the graph."""
    kmax: int = 30
    """Maximum number of k-points to consider."""
    kthr: float = 1e-6
    """Threshold for k-point filtering."""
    mult_size: float = 1.05
    """Multiplicative factor for resizing the nblist."""

    FPID: ClassVar[str] = "GRAPH_FILTER"

    def init(self):
        return FrozenDict(
            {
                "npairs": 1,
                "nblist_mult_size": self.mult_size,
            }
        )

    def get_processor(self) -> Tuple[nn.Module, Dict]:
        return GraphFilterProcessor, {
            "cutoff": self.cutoff,
            "graph_key": self.graph_key,
            "parent_graph": self.parent_graph,
            "name": f"{self.graph_key}_Filter_{self.parent_graph}",
            "switch_params": self.switch_params,
        }

    def get_graph_properties(self):
        return {
            self.graph_key: {
                "cutoff": self.cutoff,
                "directed": True,
                "parent_graph": self.parent_graph,
            }
        }

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """filter a nblist on cpu with numpy and dynamic shapes + store max shapes"""
        graph_in = inputs[self.parent_graph]
        nat = inputs["species"].shape[0]

        new_state = {**state}
        state_up = {}
        mult_size = state.get("nblist_mult_size", self.mult_size)
        assert mult_size >= 1.0, "nblist_mult_size should be >= 1."

        edge_src = np.array(graph_in["edge_src"], dtype=np.int32)
        d12 = np.array(graph_in["d12"], dtype=np.float32)
        if self.remove_hydrogens:
            species = inputs["species"]
            src_idx = (edge_src < nat).nonzero()[0]
            mask = np.zeros(edge_src.shape[0], dtype=bool)
            mask[src_idx] = (species > 1)[edge_src[src_idx]]
            d12 = np.where(mask, d12, self.cutoff**2)
        mask = d12 < self.cutoff**2

        max_pairs = state.get("npairs", 1)
        idx = np.nonzero(mask)[0]
        npairs = idx.shape[0]
        if npairs > max_pairs or add_margin:
            prev_max_pairs = max_pairs
            max_pairs = int(mult_size * max(npairs, max_pairs)) + 1
            state_up["npairs"] = (max_pairs, prev_max_pairs)
            new_state["npairs"] = max_pairs

        filter_indices = np.full(max_pairs, edge_src.shape[0], dtype=np.int32)
        edge_src = np.full(max_pairs, nat, dtype=np.int32)
        edge_dst = np.full(max_pairs, nat, dtype=np.int32)
        d12_ = np.full(max_pairs, self.cutoff**2)
        filter_indices[:npairs] = idx
        edge_src[:npairs] = graph_in["edge_src"][idx]
        edge_dst[:npairs] = graph_in["edge_dst"][idx]
        d12_[:npairs] = d12[idx]
        d12 = d12_

        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
        graph_out = {
            **graph,
            "edge_src": edge_src,
            "edge_dst": edge_dst,
            "filter_indices": filter_indices,
            "d12": d12,
            "overflow": False,
        }

        if self.k_space and "cells" in inputs:
            # only (re)compute k-space parameters when they are not already
            # present; previously `ks` and `bewald` were referenced even when
            # "k_points" was already in the graph, which raised a NameError
            if "k_points" not in graph:
                ks, _, _, bewald = get_reciprocal_space_parameters(
                    inputs["reciprocal_cells"], self.cutoff, self.kmax, self.kthr
                )
                graph_out["k_points"] = ks
                graph_out["b_ewald"] = bewald

        output = {**inputs, self.graph_key: graph_out}
        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary"""
        overflow = parent_overflow or inputs[self.graph_key].get("overflow", False)
        if not overflow:
            return state, {}, inputs, False

        add_margin = inputs[self.graph_key].get("overflow", False)
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """filter a nblist on accelerator with jax and precomputed shapes"""
        graph_in = inputs[self.parent_graph]
        if state is None:
            # skin update mode
            graph = inputs[self.graph_key]
            max_pairs = graph["edge_src"].shape[0]
        else:
            max_pairs = state.get("npairs", 1)

        max_pairs_in = graph_in["edge_src"].shape[0]
        nat = inputs["species"].shape[0]

        edge_src = graph_in["edge_src"]
        d12 = graph_in["d12"]
        if self.remove_hydrogens:
            species = inputs["species"]
            mask = (species > 1)[edge_src]
            d12 = jnp.where(mask, d12, self.cutoff**2)
        mask = d12 < self.cutoff**2

        (edge_src, edge_dst, d12, filter_indices), _, npairs = mask_filter_1d(
            mask,
            max_pairs,
            (edge_src, nat),
            (graph_in["edge_dst"], nat),
            (d12, self.cutoff**2),
            (jnp.arange(max_pairs_in, dtype=jnp.int32), max_pairs_in),
        )

        graph = inputs[self.graph_key] if self.graph_key in inputs else {}
        overflow = graph.get("overflow", False) | (npairs > max_pairs)
        graph_out = {
            **graph,
            "edge_src": edge_src,
            "edge_dst": edge_dst,
            "filter_indices": filter_indices,
            "d12": d12,
            "overflow": overflow,
        }

        if self.k_space and "cells" in inputs:
            if "k_points" not in graph:
                raise NotImplementedError(
                    "k_space generation not implemented on accelerator. Call the numpy routine (self.__call__) first."
                )

        return {**inputs, self.graph_key: graph_out}

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        return self.process(None, inputs)

Filter a graph based on a cutoff distance

FPID: GRAPH_FILTER

GraphFilter(cutoff: float, parent_graph: str, graph_key: str, remove_hydrogens: bool = False, switch_params: flax.core.frozen_dict.FrozenDict = <factory>, k_space: bool = False, kmax: int = 30, kthr: float = 1e-06, mult_size: float = 1.05)
cutoff: float

Cutoff distance for the filtering.

parent_graph: str

Key of the parent graph in the inputs.

graph_key: str

Key of the filtered graph in the outputs.

remove_hydrogens: bool = False

Remove edges where the source is a hydrogen atom.

switch_params: flax.core.frozen_dict.FrozenDict

Parameters for the switching function. See fennol.models.misc.misc.SwitchFunction.

k_space: bool = False

Generate k-space information for the graph.

kmax: int = 30

Maximum number of k-points to consider.

kthr: float = 1e-06

Threshold for k-point filtering.

mult_size: float = 1.05

Multiplicative factor for resizing the nblist.

FPID: ClassVar[str] = 'GRAPH_FILTER'
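
A typical use is deriving a short-range graph from a longer-range parent so that the expensive pair search runs only once; a hedged configuration sketch:

    long_range = GraphGenerator(cutoff=8.0, graph_key="graph")
    short_range = GraphFilter(cutoff=4.0, parent_graph="graph", graph_key="graph_short")
    # the filtered graph stores `filter_indices`: the position of each kept
    # edge in the parent edge list, so processed quantities can be gathered
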
def init(self):
def get_processor(self) -> Tuple[flax.linen.module.Module, Dict]:
def get_graph_properties(self):
def check_reallocate(self, state, inputs, parent_overflow=False):

check for overflow and reallocate nblist if necessary

@partial(jax.jit, static_argnums=(0, 1))
def process(self, state, inputs):

filter a nblist on accelerator with jax and precomputed shapes

@partial(jax.jit, static_argnums=(0,))
def update_skin(self, inputs):
class GraphFilterProcessor(flax.linen.module.Module):
class GraphFilterProcessor(nn.Module):
    """Filter processing for a pre-generated graph

    This module is automatically added to a FENNIX model when a GraphFilter is used.
    """

    cutoff: float
    """Cutoff distance for the filtering."""
    graph_key: str
    """Key of the filtered graph in the inputs."""
    parent_graph: str
    """Key of the parent graph in the inputs."""
    switch_params: dict = dataclasses.field(default_factory=dict)
    """Parameters for the switching function. See `fennol.models.misc.misc.SwitchFunction`."""

    @nn.compact
    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
        graph_in = inputs[self.parent_graph]
        graph = inputs[self.graph_key]

        d_key = "distances_raw" if "distances_raw" in graph else "distances"

        if graph_in["vec"].shape[0] == 0:
            vec = graph_in["vec"]
            distances = graph_in[d_key]
            filter_indices = jnp.asarray([], dtype=jnp.int32)
        else:
            filter_indices = graph["filter_indices"]
            vec = (
                graph_in["vec"]
                .at[filter_indices]
                .get(mode="fill", fill_value=self.cutoff)
            )
            distances = (
                graph_in[d_key]
                .at[filter_indices]
                .get(mode="fill", fill_value=self.cutoff)
            )

        edge_mask = distances < self.cutoff
        switch = SwitchFunction(
            **{**self.switch_params, "cutoff": self.cutoff, "graph_key": None}
        )((distances, edge_mask))

        graph_out = {
            **graph,
            "vec": vec,
            "distances": distances,
            "switch": switch,
            "filter_indices": filter_indices,
            "edge_mask": edge_mask,
        }

        if "alch_group" in inputs:
            ## alchemical transformation: scale the switch and soften the
            ## distances for pairs that cross the alchemical group boundary
            edge_src = graph["edge_src"]
            edge_dst = graph["edge_dst"]
            alch_group = inputs["alch_group"]
            lambda_e = inputs["alch_elambda"]
            lambda_v = inputs["alch_vlambda"]
            mask = alch_group[edge_src] == alch_group[edge_dst]
            graph_out["switch_raw"] = switch
            graph_out["switch"] = jnp.where(
                mask,
                switch,
                0.5 * (1.0 - jnp.cos(jnp.pi * lambda_e)) * switch,
            )

            graph_out["distances_raw"] = distances
            if "alch_softcore_e" in inputs:
                alch_alpha = (1 - lambda_e) * inputs["alch_softcore_e"] ** 2
            else:
                alch_alpha = (1 - lambda_v) * inputs.get("alch_softcore_v", 0.5) ** 2

            graph_out["distances"] = jnp.where(
                mask,
                distances,
                safe_sqrt(alch_alpha + distances**2 * (1.0 - alch_alpha / self.cutoff**2)),
            )

        return {**inputs, self.graph_key: graph_out}

Filter processing for a pre-generated graph

This module is automatically added to a FENNIX model when a GraphFilter is used.
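
Rather than recomputing geometry, the processor gathers the parent graph's precomputed `vec` and `distances` through `filter_indices`. A hedged, self-contained illustration of the gather semantics:

    import jax.numpy as jnp
    parent_vec = jnp.ones((10, 3))           # edge vectors of the parent graph
    filter_indices = jnp.array([0, 3, 7])    # positions of the kept edges
    vec = parent_vec.at[filter_indices].get(mode="fill", fill_value=5.0)
    # out-of-range (padding) indices return the fill value instead of failing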

GraphFilterProcessor(cutoff: float, graph_key: str, parent_graph: str, switch_params: dict = <factory>, name: Optional[str] = None)
cutoff: float

Cutoff distance for the filtering.

graph_key: str

Key of the filtered graph in the inputs.

parent_graph: str

Key of the parent graph in the inputs.

switch_params: dict

Parameters for the switching function. See fennol.models.misc.misc.SwitchFunction.

@dataclasses.dataclass(frozen=True)
class GraphAngularExtension:
@dataclasses.dataclass(frozen=True)
class GraphAngularExtension:
    """Add angles list to a graph

    FPID: GRAPH_ANGULAR_EXTENSION
    """

    mult_size: float = 1.05
    """Multiplicative factor for resizing the nblist."""
    add_neigh: int = 5
    """Additional neighbors to add to the nblist when resizing."""
    graph_key: str = "graph"
    """Key of the graph in the inputs."""

    FPID: ClassVar[str] = "GRAPH_ANGULAR_EXTENSION"

    def init(self):
        return FrozenDict(
            {
                "nangles": 0,
                "nblist_mult_size": self.mult_size,
                "max_neigh": self.add_neigh,
                "add_neigh": self.add_neigh,
            }
        )

    def get_processor(self) -> Tuple[nn.Module, Dict]:
        return GraphAngleProcessor, {
            "graph_key": self.graph_key,
            "name": f"{self.graph_key}_AngleProcessor",
        }

    def get_graph_properties(self):
        return {
            self.graph_key: {
                "has_angles": True,
            }
        }

    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
        """build angle nblist on cpu with numpy and dynamic shapes + store max shapes"""
        graph = inputs[self.graph_key]
        edge_src = np.array(graph["edge_src"], dtype=np.int32)

        new_state = {**state}
        state_up = {}
        mult_size = state.get("nblist_mult_size", self.mult_size)
        assert mult_size >= 1.0, "nblist_mult_size should be >= 1."

        ### count number of neighbors
        nat = inputs["species"].shape[0]
        count = np.zeros(nat + 1, dtype=np.int32)
        np.add.at(count, edge_src, 1)
        max_count = int(np.max(count[:-1]))

        ### get sizes
        max_neigh = state.get("max_neigh", self.add_neigh)
        nedge = edge_src.shape[0]
        if max_count > max_neigh or add_margin:
            prev_max_neigh = max_neigh
            max_neigh = max(max_count, max_neigh) + state.get(
                "add_neigh", self.add_neigh
            )
            state_up["max_neigh"] = (max_neigh, prev_max_neigh)
            new_state["max_neigh"] = max_neigh

        # dummy boolean array whose length carries max_neigh through jit
        max_neigh_arr = np.empty(max_neigh, dtype=bool)

        ### sort edge_src
        idx_sort = np.argsort(edge_src)
        edge_src_sorted = edge_src[idx_sort]

        ### map sparse to dense nblist
        if max_count * nat >= nedge:
            offset = np.tile(np.arange(max_count), nat)[:nedge]
        else:
            offset = np.zeros(nedge, dtype=np.int32)
            offset[: max_count * nat] = np.tile(np.arange(max_count), nat)

        mask = edge_src_sorted < nat
        indices = edge_src_sorted * max_count + offset
        indices = indices[mask]
        idx_sort = idx_sort[mask]
        edge_idx = np.full(nat * max_count, nedge, dtype=np.int32)
        edge_idx[indices] = idx_sort
        edge_idx = edge_idx.reshape(nat, max_count)

        ### find all triplet for each atom center
        local_src, local_dst = np.triu_indices(max_count, 1)
        angle_src = edge_idx[:, local_src].flatten()
        angle_dst = edge_idx[:, local_dst].flatten()

        ### mask for valid angles
        mask1 = angle_src < nedge
        mask2 = angle_dst < nedge
        angle_mask = mask1 & mask2

        max_angles = state.get("nangles", 0)
        idx = np.nonzero(angle_mask)[0]
        nangles = idx.shape[0]
        if nangles > max_angles or add_margin:
            max_angles_prev = max_angles
            max_angles = int(mult_size * max(nangles, max_angles)) + 1
            state_up["nangles"] = (max_angles, max_angles_prev)
            new_state["nangles"] = max_angles

        ## filter angles to sparse representation
        angle_src_ = np.full(max_angles, nedge, dtype=np.int32)
        angle_dst_ = np.full(max_angles, nedge, dtype=np.int32)
        angle_src_[:nangles] = angle_src[idx]
        angle_dst_[:nangles] = angle_dst[idx]

        central_atom = np.full(max_angles, nat, dtype=np.int32)
        central_atom[:nangles] = edge_src[angle_src_[:nangles]]

        ## update graph
        output = {
            **inputs,
            self.graph_key: {
                **graph,
                "angle_src": angle_src_,
                "angle_dst": angle_dst_,
                "central_atom": central_atom,
                "angle_overflow": False,
                "max_neigh": max_neigh,
                "__max_neigh_array": max_neigh_arr,
            },
        }

        if return_state_update:
            return FrozenDict(new_state), output, state_up
        return FrozenDict(new_state), output

    def check_reallocate(self, state, inputs, parent_overflow=False):
        """check for overflow and reallocate nblist if necessary"""
        overflow = parent_overflow or inputs[self.graph_key]["angle_overflow"]
        if not overflow:
            return state, {}, inputs, False

        add_margin = inputs[self.graph_key]["angle_overflow"]
        state, inputs, state_up = self(
            state, inputs, return_state_update=True, add_margin=add_margin
        )
        return state, state_up, inputs, True

    @partial(jax.jit, static_argnums=(0, 1))
    def process(self, state, inputs):
        """build angle nblist on accelerator with jax and precomputed shapes"""
        graph = inputs[self.graph_key]
        edge_src = graph["edge_src"]

        ### count number of neighbors
        nat = inputs["species"].shape[0]
        count = jnp.zeros(nat, dtype=jnp.int32).at[edge_src].add(1, mode="drop")
        max_count = jnp.max(count)

        ### get sizes
        if state is None:
            max_neigh_arr = graph["__max_neigh_array"]
            max_neigh = max_neigh_arr.shape[0]
            prev_nangles = graph["angle_src"].shape[0]
        else:
            max_neigh = state.get("max_neigh", self.add_neigh)
            max_neigh_arr = jnp.empty(max_neigh, dtype=bool)
            prev_nangles = state.get("nangles", 0)

        nedge = edge_src.shape[0]

        ### sort edge_src
        idx_sort = jnp.argsort(edge_src).astype(jnp.int32)
        edge_src_sorted = edge_src[idx_sort]

        ### map sparse to dense nblist
        if max_neigh * nat < nedge:
            raise ValueError("Found max_neigh*nat < nedge. This should not happen.")
        offset = jnp.asarray(
            np.tile(np.arange(max_neigh), nat)[:nedge], dtype=jnp.int32
        )
        indices = edge_src_sorted * max_neigh + offset
        edge_idx = (
            jnp.full(nat * max_neigh, nedge, dtype=jnp.int32)
            .at[indices]
            .set(idx_sort, mode="drop")
            .reshape(nat, max_neigh)
        )

        ### find all triplet for each atom center
        local_src, local_dst = np.triu_indices(max_neigh, 1)
        angle_src = edge_idx[:, local_src].flatten()
        angle_dst = edge_idx[:, local_dst].flatten()

        ### mask for valid angles
        mask1 = angle_src < nedge
        mask2 = angle_dst < nedge
        angle_mask = mask1 & mask2

        ## filter angles to sparse representation
        (angle_src, angle_dst), _, nangles = mask_filter_1d(
            angle_mask,
            prev_nangles,
            (angle_src, nedge),
            (angle_dst, nedge),
        )
        ## find central atom
        central_atom = edge_src[angle_src]

        ## check for overflow
        angle_overflow = nangles > prev_nangles
        neigh_overflow = max_count > max_neigh
        overflow = graph.get("angle_overflow", False) | angle_overflow | neigh_overflow

        ## update graph
        output = {
            **inputs,
            self.graph_key: {
                **graph,
                "angle_src": angle_src,
                "angle_dst": angle_dst,
                "central_atom": central_atom,
                "angle_overflow": overflow,
                "__max_neigh_array": max_neigh_arr,
            },
        }

        return output

    @partial(jax.jit, static_argnums=(0,))
    def update_skin(self, inputs):
        return self.process(None, inputs)

Add angles list to a graph

FPID: GRAPH_ANGULAR_EXTENSION

GraphAngularExtension(mult_size: float = 1.05, add_neigh: int = 5, graph_key: str = 'graph')
mult_size: float = 1.05

Multiplicative factor for resizing the nblist.

add_neigh: int = 5

Additional neighbors to add to the nblist when resizing.

graph_key: str = 'graph'

Key of the graph in the inputs.

FPID: ClassVar[str] = 'GRAPH_ANGULAR_EXTENSION'
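
The angle list is built by mapping the sparse edge list to a dense (atom, neighbor) table and pairing the neighbors of each center with `np.triu_indices`. A hedged, self-contained illustration of the pairing step:

    import numpy as np
    # dense row for one atom whose three valid edges sit at indices 4, 7 and 9
    edge_idx = np.array([4, 7, 9])
    local_src, local_dst = np.triu_indices(3, 1)  # all unordered neighbor pairs
    angle_src, angle_dst = edge_idx[local_src], edge_idx[local_dst]
    # -> edge pairs (4, 7), (4, 9), (7, 9): one angle per pair of edges
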
def init(self):
def get_processor(self) -> Tuple[flax.linen.module.Module, Dict]:
def get_graph_properties(self):
def check_reallocate(self, state, inputs, parent_overflow=False):

check for overflow and reallocate nblist if necessary

@partial(jax.jit, static_argnums=(0, 1))
def process(self, state, inputs):

build angle nblist on accelerator with jax and precomputed shapes

@partial(jax.jit, static_argnums=(0,))
def update_skin(self, inputs):
class GraphAngleProcessor(flax.linen.module.Module):
class GraphAngleProcessor(nn.Module):
    """Process a pre-generated graph to compute angles

    This module is automatically added to a FENNIX model when a GraphAngularExtension is used.

    """

    graph_key: str
    """Key of the graph in the inputs."""

    @nn.compact
    def __call__(self, inputs: Union[dict, Tuple[jax.Array, dict]]):
        graph = inputs[self.graph_key]
        distances = graph["distances_raw"] if "distances_raw" in graph else graph["distances"]
        vec = graph["vec"]
        angle_src = graph["angle_src"]
        angle_dst = graph["angle_dst"]

        dir = vec / jnp.clip(distances[:, None], min=1.0e-5)
        cos_angles = (
            dir.at[angle_src].get(mode="fill", fill_value=0.5)
            * dir.at[angle_dst].get(mode="fill", fill_value=0.5)
        ).sum(axis=-1)

        # the 0.95 factor keeps |cos| strictly below 1 so that the gradient of
        # arccos stays finite for perfectly aligned (or padded) edge pairs
        angles = jnp.arccos(0.95 * cos_angles)

        return {
            **inputs,
            self.graph_key: {
                **graph,
                "angles": angles,
            },
        }

Process a pre-generated graph to compute angles

This module is automatically added to a FENNIX model when a GraphAngularExtension is used.

GraphAngleProcessor(graph_key: str, name: Optional[str] = None)
graph_key: str

Key of the graph in the inputs.
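
A hedged, standalone check of the damped arccos this module applies:

    import jax.numpy as jnp
    u = jnp.array([1.0, 0.0, 0.0])   # unit bond directions
    v = jnp.array([0.0, 1.0, 0.0])
    cos_theta = jnp.dot(u, v)
    theta = jnp.arccos(0.95 * cos_theta)  # pi/2 here, since cos_theta == 0
    # the 0.95 damping only matters near |cos_theta| == 1, where the gradient
    # of arccos diverges; it slightly compresses angles at the extremes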

@dataclasses.dataclass(frozen=True)
class SpeciesIndexer:
1321@dataclasses.dataclass(frozen=True)
1322class SpeciesIndexer:
1323    """Build an index that splits atomic arrays by species.
1324
1325    FPID: SPECIES_INDEXER
1326
1327    If `species_order` is specified, the output will be a dense array with size (len(species_order), max_size) that can directly index atomic arrays.
1328    If `species_order` is None, the output will be a dictionary with species as keys and an index to filter atomic arrays for that species as values.
1329
1330    """
1331
1332    output_key: str = "species_index"
1333    """Key for the output dictionary."""
1334    species_order: Optional[str] = None
1335    """Comma separated list of species in the order they should be indexed."""
1336    add_atoms: int = 0
1337    """Additional atoms to add to the sizes."""
1338    add_atoms_margin: int = 10
1339    """Additional atoms to add to the sizes when adding margin."""
1340
1341    FPID: ClassVar[str] = "SPECIES_INDEXER"
1342
1343    def init(self):
1344        return FrozenDict(
1345            {
1346                "sizes": {},
1347            }
1348        )
1349
1350    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
1351        species = np.array(inputs["species"], dtype=np.int32)
1352        nat = species.shape[0]
1353        set_species, counts = np.unique(species, return_counts=True)
1354
1355        new_state = {**state}
1356        state_up = {}
1357
1358        sizes = state.get("sizes", FrozenDict({}))
1359        new_sizes = {**sizes}
1360        up_sizes = False
1361        counts_dict = {}
1362        for s, c in zip(set_species, counts):
1363            if s <= 0:
1364                continue
1365            counts_dict[s] = c
1366            if c > sizes.get(s, 0):
1367                up_sizes = True
1368                add_atoms = state.get("add_atoms", self.add_atoms)
1369                if add_margin:
1370                    add_atoms += state.get("add_atoms_margin", self.add_atoms_margin)
1371                new_sizes[s] = c + add_atoms
1372
1373        new_sizes = FrozenDict(new_sizes)
1374        if up_sizes:
1375            state_up["sizes"] = (new_sizes, sizes)
1376            new_state["sizes"] = new_sizes
1377
1378        if self.species_order is not None:
1379            species_order = [el.strip() for el in self.species_order.split(",")]
1380            max_size_prev = state.get("max_size", 0)
1381            max_size = max(new_sizes.values())
1382            if max_size > max_size_prev:
1383                state_up["max_size"] = (max_size, max_size_prev)
1384                new_state["max_size"] = max_size
1385                max_size_prev = max_size
1386
1387            species_index = np.full((len(species_order), max_size), nat, dtype=np.int32)
1388            for i, el in enumerate(species_order):
1389                s = PERIODIC_TABLE_REV_IDX[el]
1390                if s in counts_dict.keys():
1391                    species_index[i, : counts_dict[s]] = np.nonzero(species == s)[0]
1392        else:
1393            species_index = {
1394                PERIODIC_TABLE[s]: np.full(c, nat, dtype=np.int32)
1395                for s, c in new_sizes.items()
1396            }
1397            for s, c in zip(set_species, counts):
1398                if s <= 0:
1399                    continue
1400                species_index[PERIODIC_TABLE[s]][:c] = np.nonzero(species == s)[0]
1401
1402        output = {
1403            **inputs,
1404            self.output_key: species_index,
1405            self.output_key + "_overflow": False,
1406        }
1407
1408        if return_state_update:
1409            return FrozenDict(new_state), output, state_up
1410        return FrozenDict(new_state), output
1411
1412    def check_reallocate(self, state, inputs, parent_overflow=False):
1413        """check for overflow and reallocate nblist if necessary"""
1414        overflow = parent_overflow or inputs[self.output_key + "_overflow"]
1415        if not overflow:
1416            return state, {}, inputs, False
1417
1418        add_margin = inputs[self.output_key + "_overflow"]
1419        state, inputs, state_up = self(
1420            state, inputs, return_state_update=True, add_margin=add_margin
1421        )
1422        return state, state_up, inputs, True
1423        # return state, {}, inputs, parent_overflow
1424
1425    @partial(jax.jit, static_argnums=(0, 1))
1426    def process(self, state, inputs):
1427        # assert (
1428        #     self.output_key in inputs
1429        # ), f"Species Index {self.output_key} must be provided on accelerator. Call the numpy routine (self.__call__) first."
1430
1431        recompute_species_index = "recompute_species_index" in inputs.get("flags", {})
1432        if self.output_key in inputs and not recompute_species_index:
1433            return inputs
1434
1435        if state is None:
1436            raise ValueError("Species Indexer state must be provided on accelerator.")
1437
1438        species = inputs["species"]
1439        nat = species.shape[0]
1440
1441        sizes = state["sizes"]
1442
1443        if self.species_order is not None:
1444            species_order = [el.strip() for el in self.species_order.split(",")]
1445            max_size = state["max_size"]
1446            overflow = False  # dense mode has static capacity max_size; no overflow detection here
1447            species_index = jnp.full(
1448                (len(species_order), max_size), nat, dtype=jnp.int32
1449            )
1450            for i, el in enumerate(species_order):
1451                s = PERIODIC_TABLE_REV_IDX[el]
1452                if s in sizes.keys():
1453                    c = sizes[s]
1454                    species_index = species_index.at[i, :].set(
1455                        jnp.nonzero(species == s, size=max_size, fill_value=nat)[0]
1456                    )
1457                # if s in counts_dict.keys():
1458                #     species_index[i, : counts_dict[s]] = np.nonzero(species == s)[0]
1459        else:
1460            # species_index = {
1461            # PERIODIC_TABLE[s]: jnp.nonzero(species == s, size=c, fill_value=nat)[0]
1462            # for s, c in sizes.items()
1463            # }
1464            species_index = {}
1465            overflow = False
1466            natcount = 0
1467            for s, c in sizes.items():
1468                mask = species == s
1469                new_size = jnp.sum(mask)
1470                natcount = natcount + new_size
1471                overflow = overflow | (new_size > c)  # check if sizes are correct
1472                species_index[PERIODIC_TABLE[s]] = jnp.nonzero(
1473                    species == s, size=c, fill_value=nat
1474                )[0]
1475
1476            mask = species <= 0
1477            new_size = jnp.sum(mask)
1478            natcount = natcount + new_size
1479            overflow = overflow | (
1480                natcount < species.shape[0]
1481            )  # check if any species missing
1482
1483        return {
1484            **inputs,
1485            self.output_key: species_index,
1486            self.output_key + "_overflow": overflow,
1487        }
1488
1489    @partial(jax.jit, static_argnums=(0,))
1490    def update_skin(self, inputs):
1491        return self.process(None, inputs)
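
To illustrate the two output modes, here is a numpy sketch on a toy system (the element counts, the ordering and the `numbers` mapping are made up for illustration; as in the class, `nat` is used as the fill index so that padded slots point past the end of atomic arrays):

    import numpy as np

    species = np.array([8, 1, 1, 6, 1])  # toy system: O, H, H, C, H
    nat = species.shape[0]

    # Dense mode (species_order="C,H,O"): shape (n_species, max_size)
    order = ["C", "H", "O"]
    numbers = {"C": 6, "H": 1, "O": 8}
    max_size = max(np.sum(species == z) for z in numbers.values())
    index = np.full((len(order), max_size), nat, dtype=np.int32)
    for i, el in enumerate(order):
        idx = np.nonzero(species == numbers[el])[0]
        index[i, : idx.shape[0]] = idx
    # index[1] -> [1, 2, 4] : positions of the H atoms

    # Dict mode (species_order=None): one index array per species
    index_dict = {el: np.nonzero(species == z)[0] for el, z in numbers.items()}
    # index_dict["H"] -> array([1, 2, 4])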

@dataclasses.dataclass(frozen=True)
class BlockIndexer:
1493@dataclasses.dataclass(frozen=True)
1494class BlockIndexer:
1495    """Build an index that splits atomic arrays by chemical blocks.
1496
1497    FPID: BLOCK_INDEXER
1498
1499    The output is a dictionary with chemical block names as keys and an index to
1500    filter atomic arrays for the atoms of that block as values.
1501
1502    """
1503
1504    output_key: str = "block_index"
1505    """Key for the output dictionary."""
1506    add_atoms: int = 0
1507    """Additional atoms to add to the sizes."""
1508    add_atoms_margin: int = 10
1509    """Additional atoms to add to the sizes when adding margin."""
1510    split_CNOPSSe: bool = False  # split C, N, O, P, S and Se out of their chemical blocks into separate categories
1511
1512    FPID: ClassVar[str] = "BLOCK_INDEXER"
1513
1514    def init(self):
1515        return FrozenDict(
1516            {
1517                "sizes": {},
1518            }
1519        )
1520
1521    def build_chemical_blocks(self):
1522        _CHEMICAL_BLOCKS_NAMES = CHEMICAL_BLOCKS_NAMES.copy()
1523        if self.split_CNOPSSe:
1524            _CHEMICAL_BLOCKS_NAMES[1] = "C"
1525            _CHEMICAL_BLOCKS_NAMES.extend(["N","O","P","S","Se"])
1526        _CHEMICAL_BLOCKS = CHEMICAL_BLOCKS.copy()
1527        if self.split_CNOPSSe:
1528            _CHEMICAL_BLOCKS[6] = 1
1529            _CHEMICAL_BLOCKS[7] = len(CHEMICAL_BLOCKS_NAMES)
1530            _CHEMICAL_BLOCKS[8] = len(CHEMICAL_BLOCKS_NAMES)+1
1531            _CHEMICAL_BLOCKS[15] = len(CHEMICAL_BLOCKS_NAMES)+2
1532            _CHEMICAL_BLOCKS[16] = len(CHEMICAL_BLOCKS_NAMES)+3
1533            _CHEMICAL_BLOCKS[34] = len(CHEMICAL_BLOCKS_NAMES)+4
1534        return _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS
1535
1536    def __call__(self, state, inputs, return_state_update=False, add_margin=False):
1537        _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS = self.build_chemical_blocks()
1538
1539        species = np.array(inputs["species"], dtype=np.int32)
1540        blocks = _CHEMICAL_BLOCKS[species]
1541        nat = species.shape[0]
1542        set_blocks, counts = np.unique(blocks, return_counts=True)
1543
1544        new_state = {**state}
1545        state_up = {}
1546
1547        sizes = state.get("sizes", FrozenDict({}))
1548        new_sizes = {**sizes}
1549        up_sizes = False
1550        for s, c in zip(set_blocks, counts):
1551            if s < 0:
1552                continue
1553            key = (s, _CHEMICAL_BLOCKS_NAMES[s])
1554            if c > sizes.get(key, 0):
1555                up_sizes = True
1556                add_atoms = state.get("add_atoms", self.add_atoms)
1557                if add_margin:
1558                    add_atoms += state.get("add_atoms_margin", self.add_atoms_margin)
1559                new_sizes[key] = c + add_atoms
1560
1561        new_sizes = FrozenDict(new_sizes)
1562        if up_sizes:
1563            state_up["sizes"] = (new_sizes, sizes)
1564            new_state["sizes"] = new_sizes
1565
1566        block_index = {n:None for n in _CHEMICAL_BLOCKS_NAMES}
1567        for (_,n), c in new_sizes.items():
1568            block_index[n] = np.full(c, nat, dtype=np.int32)
1569        # block_index = {
1570            # n: np.full(c, nat, dtype=np.int32)
1571            # for (_,n), c in new_sizes.items()
1572        # }
1573        for s, c in zip(set_blocks, counts):
1574            if s < 0:
1575                continue
1576            block_index[_CHEMICAL_BLOCKS_NAMES[s]][:c] = np.nonzero(blocks == s)[0]
1577
1578        output = {
1579            **inputs,
1580            self.output_key: block_index,
1581            self.output_key + "_overflow": False,
1582        }
1583
1584        if return_state_update:
1585            return FrozenDict(new_state), output, state_up
1586        return FrozenDict(new_state), output
1587
1588    def check_reallocate(self, state, inputs, parent_overflow=False):
1589        """check for overflow and reallocate nblist if necessary"""
1590        overflow = parent_overflow or inputs[self.output_key + "_overflow"]
1591        if not overflow:
1592            return state, {}, inputs, False
1593
1594        add_margin = inputs[self.output_key + "_overflow"]
1595        state, inputs, state_up = self(
1596            state, inputs, return_state_update=True, add_margin=add_margin
1597        )
1598        return state, state_up, inputs, True
1599        # return state, {}, inputs, parent_overflow
1600
1601    @partial(jax.jit, static_argnums=(0, 1))
1602    def process(self, state, inputs):
1603        _CHEMICAL_BLOCKS_NAMES, _CHEMICAL_BLOCKS = self.build_chemical_blocks()
1604        # assert (
1605        #     self.output_key in inputs
1606        # ), f"Species Index {self.output_key} must be provided on accelerator. Call the numpy routine (self.__call__) first."
1607
1608        recompute_species_index = "recompute_species_index" in inputs.get("flags", {})
1609        if self.output_key in inputs and not recompute_species_index:
1610            return inputs
1611
1612        if state is None:
1613            raise ValueError("Block Indexer state must be provided on accelerator.")
1614
1615        species = inputs["species"]
1616        blocks = jnp.asarray(_CHEMICAL_BLOCKS)[species]
1617        nat = species.shape[0]
1618
1619        sizes = state["sizes"]
1620
1621        # species_index = {
1622        # PERIODIC_TABLE[s]: jnp.nonzero(species == s, size=c, fill_value=nat)[0]
1623        # for s, c in sizes.items()
1624        # }
1625        block_index = {n: None for n in _CHEMICAL_BLOCKS_NAMES}
1626        overflow = False
1627        natcount = 0
1628        for (s,name), c in sizes.items():
1629            mask = blocks == s
1630            new_size = jnp.sum(mask)
1631            natcount = natcount + new_size
1632            overflow = overflow | (new_size > c)  # check if sizes are correct
1633            block_index[name] = jnp.nonzero(
1634                mask, size=c, fill_value=nat
1635            )[0]
1636
1637        mask = blocks < 0
1638        new_size = jnp.sum(mask)
1639        natcount = natcount + new_size
1640        overflow = overflow | (
1641            natcount < species.shape[0]
1642        )  # check if any species missing
1643
1644        return {
1645            **inputs,
1646            self.output_key: block_index,
1647            self.output_key + "_overflow": overflow,
1648        }
1649
1650    @partial(jax.jit, static_argnums=(0,))
1651    def update_skin(self, inputs):
1652        return self.process(None, inputs)
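
Under jit, `process` cannot produce arrays whose shape depends on the data, which is why it relies on `jnp.nonzero` with a static `size` and a `fill_value`. A small sketch of that mechanism with a hypothetical two-block lookup table (the table, the species and the capacity below are made up for illustration):

    import jax.numpy as jnp

    # Hypothetical lookup table: atomic number -> block id (-1 = no block)
    blocks_of_z = jnp.array([-1, 0, -1, 1, 1, 1, 0, 0, 0])
    species = jnp.array([1, 6, 8, 3, 1])
    blocks = blocks_of_z[species]   # -> [0, 0, 0, 1, 0]
    nat = species.shape[0]

    # Fixed-shape index: always (size,) regardless of how many atoms match;
    # missing slots are filled with `nat` so they can be masked downstream.
    size = 4  # preallocated capacity from the state, as in BlockIndexer
    idx = jnp.nonzero(blocks == 0, size=size, fill_value=nat)[0]  # [0, 1, 2, 4]
    overflow = jnp.sum(blocks == 0) > size  # True would trigger reallocation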

@dataclasses.dataclass(frozen=True)
class AtomPadding:
1655@dataclasses.dataclass(frozen=True)
1656class AtomPadding:
1657    """Pad atomic arrays to a fixed size."""
1658
1659    mult_size: float = 1.2
1660    """Multiplicative factor for resizing the atomic arrays."""
1661    add_sys: int = 0  # extra system slots to allocate when resizing
1662
1663    def init(self):
1664        return {"prev_nat": 0, "prev_nsys": 0}
1665
1666    def __call__(self, state, inputs: Dict) -> Union[dict, jax.Array]:
1667        species = inputs["species"]
1668        nat = species.shape[0]
1669
1670        prev_nat = state.get("prev_nat", 0)
1671        prev_nat_ = prev_nat
1672        if nat > prev_nat_:
1673            prev_nat_ = int(self.mult_size * nat) + 1
1674
1675        nsys = len(inputs["natoms"])
1676        prev_nsys = state.get("prev_nsys", 0)
1677        prev_nsys_ = prev_nsys
1678        if nsys > prev_nsys_:
1679            prev_nsys_ = nsys + self.add_sys
1680
1681        add_atoms = prev_nat_ - nat
1682        add_sys = prev_nsys_ - nsys + 1  # one extra system slot to host the padding atoms
1683        output = {**inputs}
1684        if add_atoms > 0:
1685            for k, v in inputs.items():
1686                if isinstance(v, np.ndarray) or isinstance(v, jax.Array):
1687                    if v.shape[0] == nat:
1688                        output[k] = np.append(
1689                            v,
1690                            np.zeros((add_atoms, *v.shape[1:]), dtype=v.dtype),
1691                            axis=0,
1692                        )
1693                    elif v.shape[0] == nsys:
1694                        if k == "cells":
1695                            output[k] = np.append(
1696                                v,
1697                                1000
1698                                * np.eye(3, dtype=v.dtype)[None, :, :].repeat(
1699                                    add_sys, axis=0
1700                                ),
1701                                axis=0,
1702                            )
1703                        else:
1704                            output[k] = np.append(
1705                                v,
1706                                np.zeros((add_sys, *v.shape[1:]), dtype=v.dtype),
1707                                axis=0,
1708                            )
1709            output["natoms"] = np.append(
1710                inputs["natoms"], np.zeros(add_sys, dtype=np.int32)
1711            )
1712            output["species"] = np.append(
1713                species, -1 * np.ones(add_atoms, dtype=species.dtype)
1714            )
1715            output["batch_index"] = np.append(
1716                inputs["batch_index"],
1717                np.array([output["natoms"].shape[0] - 1] * add_atoms, dtype=inputs["batch_index"].dtype),
1718            )
1719            if "system_index" in inputs:
1720                output["system_index"] = np.append(
1721                    inputs["system_index"],
1722                    np.array([output["natoms"].shape[0] - 1] * add_sys, dtype=inputs["system_index"].dtype),
1723                )
1724
1725        output["true_atoms"] = output["species"] > 0
1726        output["true_sys"] = np.arange(len(output["natoms"])) < nsys
1727
1728        state = {**state, "prev_nat": prev_nat_, "prev_nsys": prev_nsys_}
1729
1730        return FrozenDict(state), output

def atom_unpadding(inputs: Dict[str, Any]) -> Dict[str, Any]:
1733def atom_unpadding(inputs: Dict[str, Any]) -> Dict[str, Any]:
1734    """Remove padding from atomic arrays."""
1735    if "true_atoms" not in inputs:
1736        return inputs
1737
1738    species = np.asarray(inputs["species"])
1739    true_atoms = np.asarray(inputs["true_atoms"])
1740    true_sys = np.asarray(inputs["true_sys"])
1741    natall = species.shape[0]
1742    nat = np.argmax(species <= 0)
1743    if nat == 0:
1744        return inputs
1745
1746    natoms = inputs["natoms"]
1747    nsysall = len(natoms)
1748
1749    output = {**inputs}
1750    for k, v in inputs.items():
1751        if isinstance(v, jax.Array) or isinstance(v, np.ndarray):
1752            v = np.asarray(v)
1753            if v.ndim == 0:
1754                output[k] = v
1755            elif v.shape[0] == natall:
1756                output[k] = v[true_atoms]
1757            elif v.shape[0] == nsysall:
1758                output[k] = v[true_sys]
1759    del output["true_sys"]
1760    del output["true_atoms"]
1761    return output
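
A round trip through AtomPadding and atom_unpadding looks like the following sketch (toy inputs; the padder state comes from `init()` and is threaded through the call, as PreprocessingChain does internally):

    import numpy as np

    padder = AtomPadding(mult_size=1.2)
    state = padder.init()

    inputs = {
        "species": np.array([8, 1, 1], dtype=np.int32),
        "coordinates": np.zeros((3, 3), dtype=np.float32),
        "natoms": np.array([3], dtype=np.int32),
        "batch_index": np.zeros(3, dtype=np.int32),
    }

    state, padded = padder(state, inputs)
    # padded["species"] has int(1.2 * 3) + 1 = 4 entries; the extra slot is -1
    # padded["true_atoms"] marks the 3 real atoms

    restored = atom_unpadding(padded)
    assert restored["species"].shape[0] == 3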

def check_input(inputs):
1764def check_input(inputs):
1765    """Check the input dictionary for required keys and types."""
1766    assert "species" in inputs, "species must be provided"
1767    assert "coordinates" in inputs, "coordinates must be provided"
1768    species = inputs["species"].astype(np.int32)
1769    ifake = np.argmax(species <= 0)
1770    if ifake > 0:
1771        assert np.all(species[:ifake] > 0), "species must be positive"
1772    nat = inputs["species"].shape[0]
1773
1774    natoms = inputs.get("natoms", np.array([nat], dtype=np.int32)).astype(np.int32)
1775    batch_index = inputs.get(
1776        "batch_index", np.repeat(np.arange(len(natoms), dtype=np.int32), natoms)
1777    ).astype(np.int32)
1778    output = {**inputs, "natoms": natoms, "batch_index": batch_index}
1779    if "cells" in inputs:
1780        cells = inputs["cells"]
1781        if "reciprocal_cells" not in inputs:
1782            reciprocal_cells = np.linalg.inv(cells)
1783        else:
1784            reciprocal_cells = inputs["reciprocal_cells"]
1785        if cells.ndim == 2:
1786            cells = cells[None, :, :]
1787        if reciprocal_cells.ndim == 2:
1788            reciprocal_cells = reciprocal_cells[None, :, :]
1789        output["cells"] = cells
1790        output["reciprocal_cells"] = reciprocal_cells
1791
1792    return output
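
For a single system, only `species` and `coordinates` are mandatory; `natoms` and `batch_index` are filled in with single-system defaults. A sketch with toy values:

    import numpy as np

    inputs = {
        "species": np.array([8, 1, 1], dtype=np.int32),
        "coordinates": np.random.randn(3, 3).astype(np.float32),
    }
    out = check_input(inputs)
    print(out["natoms"])       # [3]
    print(out["batch_index"])  # [0 0 0]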

def convert_to_jax(data):
1795def convert_to_jax(data):
1796    """Convert numpy arrays to jax arrays in a pytree."""
1797
1798    def convert(x):
1799        if isinstance(x, np.ndarray):
1800            # if x.dtype == np.float64:
1801            #     return jnp.asarray(x, dtype=jnp.float32)
1802            return jnp.asarray(x)
1803        return x
1804
1805    return jax.tree_util.tree_map(convert, data)
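
Only numpy leaves are converted; other leaves in the pytree (scalars, strings, existing jax arrays) pass through unchanged. For example (toy data):

    import numpy as np

    data = {"coordinates": np.zeros((3, 3)), "cutoff": 5.0, "name": "water"}
    jdata = convert_to_jax(data)
    # jdata["coordinates"] is now a jax.Array; "cutoff" and "name" are untouched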

class JaxConverter(flax.linen.module.Module):
1808class JaxConverter(nn.Module):
1809    """Convert numpy arrays to jax arrays in a pytree."""
1810
1811    def __call__(self, data):
1812        return convert_to_jax(data)

@dataclasses.dataclass(frozen=True)
class PreprocessingChain:
1815@dataclasses.dataclass(frozen=True)
1816class PreprocessingChain:
1817    """Chain of preprocessing layers."""
1818
1819    layers: Tuple[Callable[..., Dict[str, Any]], ...]
1820    """Preprocessing layers."""
1821    use_atom_padding: bool = False
1822    """Add an AtomPadding layer at the beginning of the chain."""
1823    atom_padder: AtomPadding = AtomPadding()
1824    """AtomPadding layer."""
1825
1826    def __post_init__(self):
1827        if not isinstance(self.layers, Sequence):
1828            raise ValueError(
1829                f"'layers' must be a sequence, got '{type(self.layers).__name__}'."
1830            )
1831        if not self.layers:
1832            raise ValueError("No preprocessing layers were provided.")
1833
1834    def __call__(self, state, inputs: Dict[str, Any]) -> Dict[str, Any]:
1835        do_check_input = state.get("check_input", True)
1836        if do_check_input:
1837            inputs = check_input(inputs)
1838        new_state = {**state}
1839        if self.use_atom_padding:
1840            s, inputs = self.atom_padder(state["padder_state"], inputs)
1841            new_state["padder_state"] = s
1842        layer_state = state["layers_state"]
1843        new_layer_state = []
1844        for i,layer in enumerate(self.layers):
1845            s, inputs = layer(layer_state[i], inputs, return_state_update=False)
1846            new_layer_state.append(s)
1847        new_state["layers_state"] = tuple(new_layer_state)
1848        return FrozenDict(new_state), convert_to_jax(inputs)
1849
1850    def check_reallocate(self, state, inputs):
1851        new_state = []
1852        state_up = []
1853        layer_state = state["layers_state"]
1854        parent_overflow = False
1855        for i,layer in enumerate(self.layers):
1856            s, s_up, inputs, parent_overflow = layer.check_reallocate(
1857                layer_state[i], inputs, parent_overflow
1858            )
1859            new_state.append(s)
1860            state_up.append(s_up)
1861
1862        if not parent_overflow:
1863            return state, {}, inputs, False
1864        return (
1865            FrozenDict({**state, "layers_state": tuple(new_state)}),
1866            state_up,
1867            inputs,
1868            True,
1869        )
1870
1871    def atom_padding(self, state, inputs):
1872        if self.use_atom_padding:
1873            padder_state,inputs = self.atom_padder(state["padder_state"], inputs)
1874            return FrozenDict({**state,"padder_state": padder_state}), inputs
1875        return state, inputs
1876
1877    @partial(jax.jit, static_argnums=(0, 1))
1878    def process(self, state, inputs):
1879        layer_state = state["layers_state"]
1880        for i,layer in enumerate(self.layers):
1881            inputs = layer.process(layer_state[i], inputs)
1882        return inputs
1883
1884    @partial(jax.jit, static_argnums=(0,))
1885    def update_skin(self, inputs):
1886        for layer in self.layers:
1887            inputs = layer.update_skin(inputs)
1888        return inputs
1889
1890    def init(self):
1891        state = {"check_input": True}
1892        if self.use_atom_padding:
1893            state["padder_state"] = self.atom_padder.init()
1894        layer_state = []
1895        for layer in self.layers:
1896            layer_state.append(layer.init())
1897        state["layers_state"] = tuple(layer_state)
1898        return FrozenDict(state)
1899
1900    def init_with_output(self, inputs):
1901        state = self.init()
1902        return self(state, inputs)
1903
1904    def get_processors(self):
1905        processors = []
1906        for layer in self.layers:
1907            if hasattr(layer, "get_processor"):
1908                processors.append(layer.get_processor())
1909        return processors
1910
1911    def get_graphs_properties(self):
1912        properties = {}
1913        for layer in self.layers:
1914            if hasattr(layer, "get_graph_properties"):
1915                properties = deep_update(properties, layer.get_graph_properties())
1916        return properties
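
Putting the pieces together, a typical chain builds the neighbor list and a species index, with atom padding enabled to keep shapes static under jit. A sketch (the 5.0 cutoff and the input values are illustrative, and this assumes the layers expose the `init`/`process`/`check_reallocate` interface used above):

    import numpy as np

    chain = PreprocessingChain(
        layers=(GraphGenerator(cutoff=5.0), SpeciesIndexer()),
        use_atom_padding=True,
    )

    inputs = {
        "species": np.array([8, 1, 1], dtype=np.int32),
        "coordinates": np.random.randn(3, 3).astype(np.float32),
    }

    state = chain.init()                     # per-layer states + padder state
    state, outputs = chain(state, inputs)    # CPU pass: pads, builds nblists, converts to jax
    outputs = chain.process(state, outputs)  # jitted pass with preallocated shapes
    state, _, outputs, reallocated = chain.check_reallocate(state, outputs)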
