fennol.models.fennix

  1from typing import Any, Sequence, Callable, Union, Optional, Tuple, Dict
  2from copy import deepcopy
  3import dataclasses
  4from collections import OrderedDict
  5
  6import jax
  7import jax.numpy as jnp
  8import flax.linen as nn
  9import numpy as np
 10from flax import serialization
 11from flax.core.frozen_dict import freeze, unfreeze, FrozenDict
 12from ..utils import AtomicUnits as au
 13
 14from .preprocessing import (
 15    GraphGenerator,
 16    PreprocessingChain,
 17    JaxConverter,
 18    atom_unpadding,
 19    check_input,
 20)
 21from .modules import MODULES, PREPROCESSING, FENNIXModules
 22
 23
 24@dataclasses.dataclass
 25class FENNIX:
 26    """
 27    Static wrapper for FENNIX models
 28
 29    The underlying model is a `fennol.models.modules.FENNIXModules` built from the `modules` dictionary
 30    which references registered modules in `fennol.models.modules.MODULES` and provides the parameters for initialization.
 31
 32    Since the model is static and contains variables, it must be initialized right away with either
 33    `example_data`, `variables` or `rng_key`. If `variables` is provided, it is used directly. If `example_data`
 34    is provided, the model is initialized with `example_data` and the resulting variables are stored
 35    in the wrapper. If only `rng_key` is provided, the model is initialized with a dummy system and the resulting.
 36    """
 37
 38    cutoff: Union[float, None]
 39    modules: FENNIXModules
 40    variables: Dict
 41    preprocessing: PreprocessingChain
 42    _apply: Callable[[Dict, Dict], Dict]
 43    _total_energy: Callable[[Dict, Dict], Tuple[jnp.ndarray, Dict]]
 44    _energy_and_forces: Callable[[Dict, Dict], Tuple[jnp.ndarray, jnp.ndarray, Dict]]
 45    _input_args: Dict
 46    _graphs_properties: Dict
 47    preproc_state: Dict
 48    energy_terms: Optional[Sequence[str]] = None
 49    _initializing: bool = True
 50    use_atom_padding: bool = False
 51
 52    def __init__(
 53        self,
 54        cutoff: float,
 55        modules: OrderedDict,
 56        preprocessing: OrderedDict = OrderedDict(),
 57        example_data=None,
 58        rng_key: Optional[jax.random.PRNGKey] = None,
 59        variables: Optional[dict] = None,
 60        energy_terms: Optional[Sequence[str]] = None,
 61        use_atom_padding: bool = False,
 62        graph_config: Dict = {},
 63        energy_unit: str = "Ha",
 64        **kwargs,
 65    ) -> None:
 66        """Initialize the FENNIX model
 67
 68        Arguments:
 69        ----------
 70        cutoff: float
 71            The cutoff radius for the model
 72        modules: OrderedDict
 73            The dictionary defining the sequence of FeNNol modules and their parameters.
 74        preprocessing: OrderedDict
 75            The dictionary defining the sequence of preprocessing modules and their parameters.
 76        example_data: dict
 77            Example data to initialize the model. If not provided, a dummy system is generated.
 78        rng_key: jax.random.PRNGKey
 79            The random key to initialize the model. If not provided, jax.random.PRNGKey(0) is used (should be avoided).
 80        variables: dict
 81            The variables of the model (i.e. weights, biases and all other tunable parameters).
 82            If not provided, the variables are initialized (usually at random)
 83        energy_terms: Sequence[str]
 84            The energy terms in the model output that will be summed to compute the total energy.
 85            If None, the total energy is always zero (useful for non-PES models).
 86        use_atom_padding: bool
 87            If True, the model will use atom padding for the input data.
 88            This is useful when one plans to frequently change the number of atoms in the system (for example during training).
 89        graph_config: dict
 90            Edit the graph configuration. Mostly used to change a long-range cutoff as a function of a simulation box size.
 91
 92        """
 93        self._input_args = {
 94            "cutoff": cutoff,
 95            "modules": OrderedDict(modules),
 96            "preprocessing": OrderedDict(preprocessing),
 97            "energy_unit": energy_unit,
 98        }
 99        self.energy_unit = energy_unit
100        self.Ha_to_model_energy = au.get_multiplier(energy_unit)
101        self.cutoff = cutoff
102        self.energy_terms = energy_terms
103        self.use_atom_padding = use_atom_padding
104
105        # add non-differentiable/non-jittable modules
106        preprocessing = deepcopy(preprocessing)
107        if cutoff is None:
108            preprocessing_modules = []
109        else:
110            prep_keys = list(preprocessing.keys())
111            graph_params = {"cutoff": cutoff, "graph_key": "graph"}
112            if len(prep_keys) > 0 and prep_keys[0] == "graph":
113                graph_params = {
114                    **graph_params,
115                    **preprocessing.pop("graph"),
116                }
117            graph_params = {**graph_params, **graph_config}
118
119            preprocessing_modules = [
120                GraphGenerator(**graph_params),
121            ]
122
123        for name, params in preprocessing.items():
124            key = str(params.pop("module_name")) if "module_name" in params else name
125            key = str(params.pop("FID")) if "FID" in params else key
126            mod = PREPROCESSING[key.upper()](**freeze(params))
127            preprocessing_modules.append(mod)
128
129        self.preprocessing = PreprocessingChain(
130            tuple(preprocessing_modules), use_atom_padding
131        )
132        graphs_properties = self.preprocessing.get_graphs_properties()
133        self._graphs_properties = freeze(graphs_properties)
134        # add preprocessing modules that should be differentiated/jitted
135        mods = [(JaxConverter, {})] + self.preprocessing.get_processors()
136        # mods = self.preprocessing.get_processors(return_list=True)
137
138        # build the model
139        modules = deepcopy(modules)
140        modules_names = []
141        for name, params in modules.items():
142            key = str(params.pop("module_name")) if "module_name" in params else name
143            key = str(params.pop("FID")) if "FID" in params else key
144            if name in modules_names:
145                raise ValueError(f"Module {name} already exists")
146            modules_names.append(name)
147            params["name"] = name
148            mod = MODULES[key.upper()]
149            fields = [f.name for f in dataclasses.fields(mod)]
150            if "_graphs_properties" in fields:
151                params["_graphs_properties"] = graphs_properties
152            if "_energy_unit" in fields:
153                params["_energy_unit"] = energy_unit
154            mods.append((mod, params))
155
156        self.modules = FENNIXModules(mods)
157
158        self.__apply = self.modules.apply
159        self._apply = jax.jit(self.modules.apply)
160
161        self.set_energy_terms(energy_terms)
162
163        # initialize the model
164
165        inputs, rng_key = self.reinitialize_preprocessing(rng_key, example_data)
166
167        if variables is not None:
168            self.variables = variables
169        elif rng_key is not None:
170            self.variables = self.modules.init(rng_key, inputs)
171        else:
172            raise ValueError(
173                "Either variables or a jax.random.PRNGKey must be provided for initialization"
174            )
175
176        self._initializing = False
177
178    def set_energy_terms(
179        self, energy_terms: Union[Sequence[str], None], jit: bool = True
180    ) -> None:
181        """Set the energy terms to be computed by the model and prepare the energy and force functions."""
182        object.__setattr__(self, "energy_terms", energy_terms)
183        if isinstance(energy_terms, str):
184            energy_terms = [energy_terms]
185        
186        if energy_terms is None or len(energy_terms) == 0:
187
188            def total_energy(variables, data):
189                out = self.__apply(variables, data)
190                coords = out["coordinates"]
191                nsys = out["natoms"].shape[0]
192                nat = coords.shape[0]
193                dtype = coords.dtype
194                e = jnp.zeros(nsys, dtype=dtype)
195                eat = jnp.zeros(nat, dtype=dtype)
196                out["total_energy"] = e
197                out["atomic_energies"] = eat
198                return e, out
199
200            def energy_and_forces(variables, data):
201                e, out = total_energy(variables, data)
202                f = jnp.zeros_like(out["coordinates"])
203                out["forces"] = f
204                return e, f, out
205
206            def energy_and_forces_and_virial(variables, data):
207                e, f, out = energy_and_forces(variables, data)
208                v = jnp.zeros(
209                    (out["natoms"].shape[0], 3, 3), dtype=out["coordinates"].dtype
210                )
211                out["virial_tensor"] = v
212                return e, f, v, out
213
214        else:
215            # build the energy and force functions
216            def total_energy(variables, data):
217                out = self.__apply(variables, data)
218                atomic_energies = 0.0
219                system_energies = 0.0
220                species = out["species"]
221                nsys = out["natoms"].shape[0]
222                for term in self.energy_terms:
223                    e = out[term]
224                    if e.ndim > 1 and e.shape[-1] == 1:
225                        e = jnp.squeeze(e, axis=-1)
226                    if e.shape[0] == nsys and nsys != species.shape[0]:
227                        system_energies += e
228                        continue
229                    assert e.shape == species.shape
230                    atomic_energies += e
231                # atomic_energies = jnp.squeeze(atomic_energies, axis=-1)
232                if isinstance(atomic_energies, jnp.ndarray):
233                    if "true_atoms" in out:
234                        atomic_energies = jnp.where(
235                            out["true_atoms"], atomic_energies, 0.0
236                        )
237                    out["atomic_energies"] = atomic_energies
238                    energies = jax.ops.segment_sum(
239                        atomic_energies,
240                        data["batch_index"],
241                        num_segments=len(data["natoms"]),
242                    )
243                else:
244                    energies = 0.0
245
246                if isinstance(system_energies, jnp.ndarray):
247                    if "true_sys" in out:
248                        system_energies = jnp.where(
249                            out["true_sys"], system_energies, 0.0
250                        )
251                    out["system_energies"] = system_energies
252
253                out["total_energy"] = energies + system_energies
254                return out["total_energy"], out
255
256            def energy_and_forces(variables, data):
257                def _etot(variables, coordinates):
258                    energy, out = total_energy(
259                        variables, {**data, "coordinates": coordinates}
260                    )
261                    return energy.sum(), out
262
263                de, out = jax.grad(_etot, argnums=1, has_aux=True)(
264                    variables, data["coordinates"]
265                )
266                out["forces"] = -de
267
268                return out["total_energy"], out["forces"], out
269
270            # def energy_and_forces_and_virial(variables, data):
271            #     x = data["coordinates"]
272            #     batch_index = data["batch_index"]
273            #     if "cells" in data:
274            #         cells = data["cells"]
275            #         ## cells is a nbatchx3x3 matrix which lines are cell vectors (i.e. cells[0,0,:] is the first cell vector of the first system)
276
277            #         def _etot(variables, coordinates, cells):
278            #             reciprocal_cells = jnp.linalg.inv(cells)
279            #             energy, out = total_energy(
280            #                 variables,
281            #                 {
282            #                     **data,
283            #                     "coordinates": coordinates,
284            #                     "cells": cells,
285            #                     "reciprocal_cells": reciprocal_cells,
286            #                 },
287            #             )
288            #             return energy.sum(), out
289
290            #         (dedx, dedcells), out = jax.grad(_etot, argnums=(1, 2), has_aux=True)(
291            #             variables, x, cells
292            #         )
293            #         f= -dedx
294            #         out["forces"] = f
295            #     else:
296            #         _,f,out = energy_and_forces(variables, data)
297
298            #     vir = -jax.ops.segment_sum(
299            #         f[:, :, None] * x[:, None, :],
300            #         batch_index,
301            #         num_segments=len(data["natoms"]),
302            #     )
303
304            #     if "cells" in data:
305            #         # dvir = jax.vmap(jnp.matmul)(dedcells, cells.transpose(0, 2, 1))
306            #         dvir = jnp.einsum("...ki,...kj->...ij", dedcells, cells)
307            #         nsys = data["natoms"].shape[0]
308            #         if cells.shape[0]==1 and nsys>1:
309            #             dvir = dvir / nsys
310            #         vir = vir + dvir
311
312            #     out["virial_tensor"] = vir
313
314            #     return out["total_energy"], f, vir, out
315
316            def energy_and_forces_and_virial(variables, data):
317                x = data["coordinates"]
318                scaling = jnp.asarray(
319                    np.eye(3)[None, :, :].repeat(data["natoms"].shape[0], axis=0)
320                )
321                def _etot(variables, coordinates, scaling):
322                    batch_index = data["batch_index"]
323                    coordinates = jax.vmap(jnp.matmul)(
324                        coordinates, scaling[batch_index]
325                    )
326                    inputs = {**data, "coordinates": coordinates}
327                    if "cells" in data:
328                        ## cells is a nbatchx3x3 matrix which lines are cell vectors (i.e. cells[0,0,:] is the first cell vector of the first system)
329                        cells = jax.vmap(jnp.matmul)(data["cells"], scaling)
330                        reciprocal_cells = jnp.linalg.inv(cells)
331                        inputs["cells"] = cells
332                        inputs["reciprocal_cells"] = reciprocal_cells
333                    energy, out = total_energy(variables, inputs)
334                    return energy.sum(), out
335
336                (dedx, vir), out = jax.grad(_etot, argnums=(1, 2), has_aux=True)(
337                    variables, x, scaling
338                )
339                f = -dedx
340                out["forces"] = f
341                out["virial_tensor"] = vir
342
343                return out["total_energy"], f, vir, out
344
345        object.__setattr__(self, "_total_energy_raw", total_energy)
346        if jit:
347            object.__setattr__(self, "_total_energy", jax.jit(total_energy))
348            object.__setattr__(self, "_energy_and_forces", jax.jit(energy_and_forces))
349            object.__setattr__(
350                self,
351                "_energy_and_forces_and_virial",
352                jax.jit(energy_and_forces_and_virial),
353            )
354        else:
355            object.__setattr__(self, "_total_energy", total_energy)
356            object.__setattr__(self, "_energy_and_forces", energy_and_forces)
357            object.__setattr__(
358                self, "_energy_and_forces_and_virial", energy_and_forces_and_virial
359            )
360
361    def get_gradient_function(
362        self,
363        *gradient_keys: Sequence[str],
364        jit: bool = True,
365        variables_as_input: bool = False,
366    ):
367        """Return a function that computes the energy and the gradient of the energy with respect to the keys in gradient_keys"""
368
369        def _energy_gradient(variables, data):
370            def _etot(variables, inputs):
371                if "strain" in inputs:
372                    scaling = inputs["strain"]
373                    batch_index = data["batch_index"]
374                    coordinates = inputs["coordinates"] if "coordinates" in inputs else data["coordinates"]
375                    coordinates = jax.vmap(jnp.matmul)(
376                        coordinates, scaling[batch_index]
377                    )
378                    inputs = {**inputs, "coordinates": coordinates}
379                    if "cells" in inputs or "cells" in data:
380                        cells = inputs["cells"] if "cells" in inputs else data["cells"]
381                        cells = jax.vmap(jnp.matmul)(cells, scaling)
382                        inputs["cells"] = cells
383                if "cells" in inputs:
384                    reciprocal_cells = jnp.linalg.inv(inputs["cells"])
385                    inputs = {**inputs, "reciprocal_cells": reciprocal_cells}
386                energy, out = self._total_energy_raw(variables, {**data, **inputs})
387                return energy.sum(), out
388
389            if "strain" in gradient_keys and "strain" not in data:
390                data = {**data, "strain": jnp.array(np.eye(3)[None, :, :].repeat(data["natoms"].shape[0], axis=0))}
391            inputs = {k: data[k] for k in gradient_keys}
392            de, out = jax.grad(_etot, argnums=1, has_aux=True)(variables, inputs)
393
394            return (
395                out["total_energy"],
396                de,
397                {**out, **{"dEd_" + k: de[k] for k in gradient_keys}},
398            )
399
400        if variables_as_input:
401            energy_gradient = _energy_gradient
402        else:
403
404            def energy_gradient(data):
405                return _energy_gradient(self.variables, data)
406
407        if jit:
408            return jax.jit(energy_gradient)
409        else:
410            return energy_gradient
411
412    def preprocess(self,use_gpu=False, verbose=False,**inputs) -> Dict[str, Any]:
413        """apply preprocessing to the input data
414
415        !!! This is not a pure function => do not apply jax transforms !!!"""
416        if self.preproc_state is None:
417            out, _ = self.reinitialize_preprocessing(example_data=inputs)
418        elif use_gpu:
419            do_check_input = self.preproc_state.get("check_input", True)
420            if do_check_input:
421                inputs = check_input(inputs)
422            preproc_state, inputs = self.preprocessing.atom_padding(
423                self.preproc_state, inputs
424            )
425            inputs = self.preprocessing.process(preproc_state, inputs)
426            preproc_state, state_up, out, overflow = (
427                self.preprocessing.check_reallocate(
428                    preproc_state, inputs
429                )
430            )
431            if verbose and overflow:
432                print("GPU preprocessing: nblist overflow => reallocating nblist")
433                print("size updates:", state_up)
434        else:
435            preproc_state, out = self.preprocessing(self.preproc_state, inputs)
436
437        object.__setattr__(self, "preproc_state", preproc_state)
438        return out
439
440    def reinitialize_preprocessing(
441        self, rng_key: Optional[jax.random.PRNGKey] = None, example_data=None
442    ) -> None:
443        ### TODO ###
444        if rng_key is None:
445            rng_key_pre = jax.random.PRNGKey(0)
446        else:
447            rng_key, rng_key_pre = jax.random.split(rng_key)
448
449        if example_data is None:
450            rng_key_sys, rng_key_pre = jax.random.split(rng_key_pre)
451            example_data = self.generate_dummy_system(rng_key_sys, n_atoms=10)
452
453        preproc_state, inputs = self.preprocessing.init_with_output(example_data)
454        object.__setattr__(self, "preproc_state", preproc_state)
455        return inputs, rng_key
456
457    def __call__(self, variables: Optional[dict] = None, gpu_preprocessing=False,**inputs) -> Dict[str, Any]:
458        """Apply the FENNIX model (preprocess + modules)
459
460        !!! This is not a pure function => do not apply jax transforms !!!
461        if you want to apply jax transforms, use  self._apply(variables, inputs) which is pure and preprocess the input using self.preprocess
462        """
463        if variables is None:
464            variables = self.variables
465        inputs = self.preprocess(use_gpu=gpu_preprocessing,**inputs)
466        output = self._apply(variables, inputs)
467        if self.use_atom_padding:
468            output = atom_unpadding(output)
469        return output
470
471    def total_energy(
472        self, variables: Optional[dict] = None, unit: Union[float, str] = None, gpu_preprocessing=False,**inputs
473    ) -> Tuple[jnp.ndarray, Dict]:
474        """compute the total energy of the system
475
476        !!! This is not a pure function => do not apply jax transforms !!!
477        if you want to apply jax transforms, use self._total_energy(variables, inputs) which is pure and preprocess the input using self.preprocess
478        """
479        if variables is None:
480            variables = self.variables
481        inputs = self.preprocess(use_gpu=gpu_preprocessing,**inputs)
482        # def print_shape(path,value):
483        #     if isinstance(value,jnp.ndarray):
484        #         print(path,value.shape)
485        #     else:
486        #         print(path,value)
487        # jax.tree_util.tree_map_with_path(print_shape,inputs)
488        _, output = self._total_energy(variables, inputs)
489        if self.use_atom_padding:
490            output = atom_unpadding(output)
491        e = output["total_energy"]
492        if unit is not None:
493            model_energy_unit = self.Ha_to_model_energy
494            if isinstance(unit, str):
495                unit = au.get_multiplier(unit)
496            e = e * (unit / model_energy_unit)
497        return e, output
498
499    def energy_and_forces(
500        self, variables: Optional[dict] = None, unit: Union[float, str] = None,gpu_preprocessing=False,**inputs
501    ) -> Tuple[jnp.ndarray, jnp.ndarray, Dict]:
502        """compute the total energy and forces of the system
503
504        !!! This is not a pure function => do not apply jax transforms !!!
505        if you want to apply jax transforms, use  self._energy_and_forces(variables, inputs) which is pure and preprocess the input using self.preprocess
506        """
507        if variables is None:
508            variables = self.variables
509        inputs = self.preprocess(use_gpu=gpu_preprocessing,**inputs)
510        _, _, output = self._energy_and_forces(variables, inputs)
511        if self.use_atom_padding:
512            output = atom_unpadding(output)
513        e = output["total_energy"]
514        f = output["forces"]
515        if unit is not None:
516            model_energy_unit = self.Ha_to_model_energy
517            if isinstance(unit, str):
518                unit = au.get_multiplier(unit)
519            e = e * (unit / model_energy_unit)
520            f = f * (unit / model_energy_unit)
521        return e, f, output
522
523    def energy_and_forces_and_virial(
524        self, variables: Optional[dict] = None, unit: Union[float, str] = None,gpu_preprocessing=False, **inputs
525    ) -> Tuple[jnp.ndarray, jnp.ndarray, Dict]:
526        """compute the total energy and forces of the system
527
528        !!! This is not a pure function => do not apply jax transforms !!!
529        if you want to apply jax transforms, use  self._energy_and_forces_and_virial(variables, inputs) which is pure and preprocess the input using self.preprocess
530        """
531        if variables is None:
532            variables = self.variables
533        inputs = self.preprocess(use_gpu=gpu_preprocessing,**inputs)
534        _, _, _, output = self._energy_and_forces_and_virial(variables, inputs)
535        if self.use_atom_padding:
536            output = atom_unpadding(output)
537        e = output["total_energy"]
538        f = output["forces"]
539        v = output["virial_tensor"]
540        if unit is not None:
541            model_energy_unit = self.Ha_to_model_energy
542            if isinstance(unit, str):
543                unit = au.get_multiplier(unit)
544            e = e * (unit / model_energy_unit)
545            f = f * (unit / model_energy_unit)
546            v = v * (unit / model_energy_unit)
547        return e, f, v, output
548
549    def remove_atom_padding(self, output):
550        """remove atom padding from the output"""
551        return atom_unpadding(output)
552
553    def get_model(self) -> Tuple[FENNIXModules, Dict]:
554        """return the model and its variables"""
555        return self.modules, self.variables
556
557    def get_preprocessing(self) -> Tuple[PreprocessingChain, Dict]:
558        """return the preprocessing chain and its state"""
559        return self.preprocessing, self.preproc_state
560
561    def __setattr__(self, __name: str, __value: Any) -> None:
562        if __name == "variables":
563            if __value is not None:
564                if not (
565                    isinstance(__value, dict)
566                    or isinstance(__value, OrderedDict)
567                    or isinstance(__value, FrozenDict)
568                ):
569                    raise ValueError(f"{__name} must be a dict")
570                object.__setattr__(self, __name, JaxConverter()(__value))
571            else:
572                raise ValueError(f"{__name} cannot be None")
573        elif __name == "preproc_state":
574            if __value is not None:
575                if not (
576                    isinstance(__value, dict)
577                    or isinstance(__value, OrderedDict)
578                    or isinstance(__value, FrozenDict)
579                ):
580                    raise ValueError(f"{__name} must be a FrozenDict")
581                object.__setattr__(self, __name, freeze(JaxConverter()(__value)))
582            else:
583                raise ValueError(f"{__name} cannot be None")
584
585        elif self._initializing:
586            object.__setattr__(self, __name, __value)
587        else:
588            raise ValueError(f"{__name} attribute of FENNIX model is immutable.")
589
590    def generate_dummy_system(
591        self, rng_key: jax.random.PRNGKey, box_size=None, n_atoms: int = 10
592    ) -> Dict[str, Any]:
593        """
594        Generate dummy system for initialization
595        """
596        if box_size is None:
597            box_size = 2 * self.cutoff
598            for g in self._graphs_properties.values():
599                cutoff = g["cutoff"]
600                if cutoff is not None:
601                    box_size = min(box_size, 2 * g["cutoff"])
602        coordinates = np.array(
603            jax.random.uniform(rng_key, (n_atoms, 3), maxval=box_size), dtype=np.float64
604        )
605        species = np.ones((n_atoms,), dtype=np.int32)
606        batch_index = np.zeros((n_atoms,), dtype=np.int32)
607        natoms = np.array([n_atoms], dtype=np.int32)
608        return {
609            "species": species,
610            "coordinates": coordinates,
611            # "graph": graph,
612            "batch_index": batch_index,
613            "natoms": natoms,
614        }
615
616    def summarize(
617        self, rng_key: jax.random.PRNGKey = None, example_data=None, **kwargs
618    ) -> str:
619        """Summarize the model architecture and parameters"""
620        if rng_key is None:
621            head = "Summarizing with example data:\n"
622            rng_key = jax.random.PRNGKey(0)
623        if example_data is None:
624            head = "Summarizing with dummy 10 atoms system:\n"
625            rng_key, rng_key_sys = jax.random.split(rng_key)
626            example_data = self.generate_dummy_system(rng_key_sys, n_atoms=10)
627        rng_key, rng_key_pre = jax.random.split(rng_key)
628        _, inputs = self.preprocessing.init_with_output(example_data)
629        return head + nn.tabulate(self.modules, rng_key, **kwargs)(inputs)
630
631    def to_dict(self):
632        """return a dictionary representation of the model"""
633        return {
634            **self._input_args,
635            "energy_terms": self.energy_terms,
636            "variables": deepcopy(self.variables),
637        }
638
639    def save(self, filename):
640        """save the model to a file"""
641        state_dict = self.to_dict()
642        state_dict["preprocessing"] = [
643            [k, v] for k, v in state_dict["preprocessing"].items()
644        ]
645        state_dict["modules"] = [[k, v] for k, v in state_dict["modules"].items()]
646        with open(filename, "wb") as f:
647            f.write(serialization.msgpack_serialize(state_dict))
648
649    @classmethod
650    def load(
651        cls,
652        filename,
653        use_atom_padding=False,
654        graph_config={},
655    ):
656        """load a model from a file"""
657        with open(filename, "rb") as f:
658            state_dict = serialization.msgpack_restore(f.read())
659        state_dict["preprocessing"] = {k: v for k, v in state_dict["preprocessing"]}
660        state_dict["modules"] = {k: v for k, v in state_dict["modules"]}
661        return cls(
662            **state_dict,
663            graph_config=graph_config,
664            use_atom_padding=use_atom_padding,
665        )
@dataclasses.dataclass
class FENNIX:
 25@dataclasses.dataclass
 26class FENNIX:
 27    """
 28    Static wrapper for FENNIX models
 29
 30    The underlying model is a `fennol.models.modules.FENNIXModules` built from the `modules` dictionary
 31    which references registered modules in `fennol.models.modules.MODULES` and provides the parameters for initialization.
 32
 33    Since the model is static and contains variables, it must be initialized right away with either
 34    `example_data`, `variables` or `rng_key`. If `variables` is provided, it is used directly. If `example_data`
 35    is provided, the model is initialized with `example_data` and the resulting variables are stored
 36    in the wrapper. If only `rng_key` is provided, the model is initialized with a dummy system and the resulting.
 37    """
 38
 39    cutoff: Union[float, None]
 40    modules: FENNIXModules
 41    variables: Dict
 42    preprocessing: PreprocessingChain
 43    _apply: Callable[[Dict, Dict], Dict]
 44    _total_energy: Callable[[Dict, Dict], Tuple[jnp.ndarray, Dict]]
 45    _energy_and_forces: Callable[[Dict, Dict], Tuple[jnp.ndarray, jnp.ndarray, Dict]]
 46    _input_args: Dict
 47    _graphs_properties: Dict
 48    preproc_state: Dict
 49    energy_terms: Optional[Sequence[str]] = None
 50    _initializing: bool = True
 51    use_atom_padding: bool = False
 52
 53    def __init__(
 54        self,
 55        cutoff: float,
 56        modules: OrderedDict,
 57        preprocessing: OrderedDict = OrderedDict(),
 58        example_data=None,
 59        rng_key: Optional[jax.random.PRNGKey] = None,
 60        variables: Optional[dict] = None,
 61        energy_terms: Optional[Sequence[str]] = None,
 62        use_atom_padding: bool = False,
 63        graph_config: Dict = {},
 64        energy_unit: str = "Ha",
 65        **kwargs,
 66    ) -> None:
 67        """Initialize the FENNIX model
 68
 69        Arguments:
 70        ----------
 71        cutoff: float
 72            The cutoff radius for the model
 73        modules: OrderedDict
 74            The dictionary defining the sequence of FeNNol modules and their parameters.
 75        preprocessing: OrderedDict
 76            The dictionary defining the sequence of preprocessing modules and their parameters.
 77        example_data: dict
 78            Example data to initialize the model. If not provided, a dummy system is generated.
 79        rng_key: jax.random.PRNGKey
 80            The random key to initialize the model. If not provided, jax.random.PRNGKey(0) is used (should be avoided).
 81        variables: dict
 82            The variables of the model (i.e. weights, biases and all other tunable parameters).
 83            If not provided, the variables are initialized (usually at random)
 84        energy_terms: Sequence[str]
 85            The energy terms in the model output that will be summed to compute the total energy.
 86            If None, the total energy is always zero (useful for non-PES models).
 87        use_atom_padding: bool
 88            If True, the model will use atom padding for the input data.
 89            This is useful when one plans to frequently change the number of atoms in the system (for example during training).
 90        graph_config: dict
 91            Edit the graph configuration. Mostly used to change a long-range cutoff as a function of a simulation box size.
 92
 93        """
 94        self._input_args = {
 95            "cutoff": cutoff,
 96            "modules": OrderedDict(modules),
 97            "preprocessing": OrderedDict(preprocessing),
 98            "energy_unit": energy_unit,
 99        }
100        self.energy_unit = energy_unit
101        self.Ha_to_model_energy = au.get_multiplier(energy_unit)
102        self.cutoff = cutoff
103        self.energy_terms = energy_terms
104        self.use_atom_padding = use_atom_padding
105
106        # add non-differentiable/non-jittable modules
107        preprocessing = deepcopy(preprocessing)
108        if cutoff is None:
109            preprocessing_modules = []
110        else:
111            prep_keys = list(preprocessing.keys())
112            graph_params = {"cutoff": cutoff, "graph_key": "graph"}
113            if len(prep_keys) > 0 and prep_keys[0] == "graph":
114                graph_params = {
115                    **graph_params,
116                    **preprocessing.pop("graph"),
117                }
118            graph_params = {**graph_params, **graph_config}
119
120            preprocessing_modules = [
121                GraphGenerator(**graph_params),
122            ]
123
124        for name, params in preprocessing.items():
125            key = str(params.pop("module_name")) if "module_name" in params else name
126            key = str(params.pop("FID")) if "FID" in params else key
127            mod = PREPROCESSING[key.upper()](**freeze(params))
128            preprocessing_modules.append(mod)
129
130        self.preprocessing = PreprocessingChain(
131            tuple(preprocessing_modules), use_atom_padding
132        )
133        graphs_properties = self.preprocessing.get_graphs_properties()
134        self._graphs_properties = freeze(graphs_properties)
135        # add preprocessing modules that should be differentiated/jitted
136        mods = [(JaxConverter, {})] + self.preprocessing.get_processors()
137        # mods = self.preprocessing.get_processors(return_list=True)
138
139        # build the model
140        modules = deepcopy(modules)
141        modules_names = []
142        for name, params in modules.items():
143            key = str(params.pop("module_name")) if "module_name" in params else name
144            key = str(params.pop("FID")) if "FID" in params else key
145            if name in modules_names:
146                raise ValueError(f"Module {name} already exists")
147            modules_names.append(name)
148            params["name"] = name
149            mod = MODULES[key.upper()]
150            fields = [f.name for f in dataclasses.fields(mod)]
151            if "_graphs_properties" in fields:
152                params["_graphs_properties"] = graphs_properties
153            if "_energy_unit" in fields:
154                params["_energy_unit"] = energy_unit
155            mods.append((mod, params))
156
157        self.modules = FENNIXModules(mods)
158
159        self.__apply = self.modules.apply
160        self._apply = jax.jit(self.modules.apply)
161
162        self.set_energy_terms(energy_terms)
163
164        # initialize the model
165
166        inputs, rng_key = self.reinitialize_preprocessing(rng_key, example_data)
167
168        if variables is not None:
169            self.variables = variables
170        elif rng_key is not None:
171            self.variables = self.modules.init(rng_key, inputs)
172        else:
173            raise ValueError(
174                "Either variables or a jax.random.PRNGKey must be provided for initialization"
175            )
176
177        self._initializing = False
178
179    def set_energy_terms(
180        self, energy_terms: Union[Sequence[str], None], jit: bool = True
181    ) -> None:
182        """Set the energy terms to be computed by the model and prepare the energy and force functions."""
183        object.__setattr__(self, "energy_terms", energy_terms)
184        if isinstance(energy_terms, str):
185            energy_terms = [energy_terms]
186        
187        if energy_terms is None or len(energy_terms) == 0:
188
189            def total_energy(variables, data):
190                out = self.__apply(variables, data)
191                coords = out["coordinates"]
192                nsys = out["natoms"].shape[0]
193                nat = coords.shape[0]
194                dtype = coords.dtype
195                e = jnp.zeros(nsys, dtype=dtype)
196                eat = jnp.zeros(nat, dtype=dtype)
197                out["total_energy"] = e
198                out["atomic_energies"] = eat
199                return e, out
200
201            def energy_and_forces(variables, data):
202                e, out = total_energy(variables, data)
203                f = jnp.zeros_like(out["coordinates"])
204                out["forces"] = f
205                return e, f, out
206
207            def energy_and_forces_and_virial(variables, data):
208                e, f, out = energy_and_forces(variables, data)
209                v = jnp.zeros(
210                    (out["natoms"].shape[0], 3, 3), dtype=out["coordinates"].dtype
211                )
212                out["virial_tensor"] = v
213                return e, f, v, out
214
215        else:
216            # build the energy and force functions
217            def total_energy(variables, data):
218                out = self.__apply(variables, data)
219                atomic_energies = 0.0
220                system_energies = 0.0
221                species = out["species"]
222                nsys = out["natoms"].shape[0]
223                for term in self.energy_terms:
224                    e = out[term]
225                    if e.ndim > 1 and e.shape[-1] == 1:
226                        e = jnp.squeeze(e, axis=-1)
227                    if e.shape[0] == nsys and nsys != species.shape[0]:
228                        system_energies += e
229                        continue
230                    assert e.shape == species.shape
231                    atomic_energies += e
232                # atomic_energies = jnp.squeeze(atomic_energies, axis=-1)
233                if isinstance(atomic_energies, jnp.ndarray):
234                    if "true_atoms" in out:
235                        atomic_energies = jnp.where(
236                            out["true_atoms"], atomic_energies, 0.0
237                        )
238                    out["atomic_energies"] = atomic_energies
239                    energies = jax.ops.segment_sum(
240                        atomic_energies,
241                        data["batch_index"],
242                        num_segments=len(data["natoms"]),
243                    )
244                else:
245                    energies = 0.0
246
247                if isinstance(system_energies, jnp.ndarray):
248                    if "true_sys" in out:
249                        system_energies = jnp.where(
250                            out["true_sys"], system_energies, 0.0
251                        )
252                    out["system_energies"] = system_energies
253
254                out["total_energy"] = energies + system_energies
255                return out["total_energy"], out
256
257            def energy_and_forces(variables, data):
258                def _etot(variables, coordinates):
259                    energy, out = total_energy(
260                        variables, {**data, "coordinates": coordinates}
261                    )
262                    return energy.sum(), out
263
264                de, out = jax.grad(_etot, argnums=1, has_aux=True)(
265                    variables, data["coordinates"]
266                )
267                out["forces"] = -de
268
269                return out["total_energy"], out["forces"], out
270
271            # def energy_and_forces_and_virial(variables, data):
272            #     x = data["coordinates"]
273            #     batch_index = data["batch_index"]
274            #     if "cells" in data:
275            #         cells = data["cells"]
276            #         ## cells is a nbatchx3x3 matrix which lines are cell vectors (i.e. cells[0,0,:] is the first cell vector of the first system)
277
278            #         def _etot(variables, coordinates, cells):
279            #             reciprocal_cells = jnp.linalg.inv(cells)
280            #             energy, out = total_energy(
281            #                 variables,
282            #                 {
283            #                     **data,
284            #                     "coordinates": coordinates,
285            #                     "cells": cells,
286            #                     "reciprocal_cells": reciprocal_cells,
287            #                 },
288            #             )
289            #             return energy.sum(), out
290
291            #         (dedx, dedcells), out = jax.grad(_etot, argnums=(1, 2), has_aux=True)(
292            #             variables, x, cells
293            #         )
294            #         f= -dedx
295            #         out["forces"] = f
296            #     else:
297            #         _,f,out = energy_and_forces(variables, data)
298
299            #     vir = -jax.ops.segment_sum(
300            #         f[:, :, None] * x[:, None, :],
301            #         batch_index,
302            #         num_segments=len(data["natoms"]),
303            #     )
304
305            #     if "cells" in data:
306            #         # dvir = jax.vmap(jnp.matmul)(dedcells, cells.transpose(0, 2, 1))
307            #         dvir = jnp.einsum("...ki,...kj->...ij", dedcells, cells)
308            #         nsys = data["natoms"].shape[0]
309            #         if cells.shape[0]==1 and nsys>1:
310            #             dvir = dvir / nsys
311            #         vir = vir + dvir
312
313            #     out["virial_tensor"] = vir
314
315            #     return out["total_energy"], f, vir, out
316
317            def energy_and_forces_and_virial(variables, data):
318                x = data["coordinates"]
319                scaling = jnp.asarray(
320                    np.eye(3)[None, :, :].repeat(data["natoms"].shape[0], axis=0)
321                )
322                def _etot(variables, coordinates, scaling):
323                    batch_index = data["batch_index"]
324                    coordinates = jax.vmap(jnp.matmul)(
325                        coordinates, scaling[batch_index]
326                    )
327                    inputs = {**data, "coordinates": coordinates}
328                    if "cells" in data:
329                        ## cells is a nbatchx3x3 matrix which lines are cell vectors (i.e. cells[0,0,:] is the first cell vector of the first system)
330                        cells = jax.vmap(jnp.matmul)(data["cells"], scaling)
331                        reciprocal_cells = jnp.linalg.inv(cells)
332                        inputs["cells"] = cells
333                        inputs["reciprocal_cells"] = reciprocal_cells
334                    energy, out = total_energy(variables, inputs)
335                    return energy.sum(), out
336
337                (dedx, vir), out = jax.grad(_etot, argnums=(1, 2), has_aux=True)(
338                    variables, x, scaling
339                )
340                f = -dedx
341                out["forces"] = f
342                out["virial_tensor"] = vir
343
344                return out["total_energy"], f, vir, out
345
346        object.__setattr__(self, "_total_energy_raw", total_energy)
347        if jit:
348            object.__setattr__(self, "_total_energy", jax.jit(total_energy))
349            object.__setattr__(self, "_energy_and_forces", jax.jit(energy_and_forces))
350            object.__setattr__(
351                self,
352                "_energy_and_forces_and_virial",
353                jax.jit(energy_and_forces_and_virial),
354            )
355        else:
356            object.__setattr__(self, "_total_energy", total_energy)
357            object.__setattr__(self, "_energy_and_forces", energy_and_forces)
358            object.__setattr__(
359                self, "_energy_and_forces_and_virial", energy_and_forces_and_virial
360            )
361
362    def get_gradient_function(
363        self,
364        *gradient_keys: Sequence[str],
365        jit: bool = True,
366        variables_as_input: bool = False,
367    ):
368        """Return a function that computes the energy and the gradient of the energy with respect to the keys in gradient_keys"""
369
370        def _energy_gradient(variables, data):
371            def _etot(variables, inputs):
372                if "strain" in inputs:
373                    scaling = inputs["strain"]
374                    batch_index = data["batch_index"]
375                    coordinates = inputs["coordinates"] if "coordinates" in inputs else data["coordinates"]
376                    coordinates = jax.vmap(jnp.matmul)(
377                        coordinates, scaling[batch_index]
378                    )
379                    inputs = {**inputs, "coordinates": coordinates}
380                    if "cells" in inputs or "cells" in data:
381                        cells = inputs["cells"] if "cells" in inputs else data["cells"]
382                        cells = jax.vmap(jnp.matmul)(cells, scaling)
383                        inputs["cells"] = cells
384                if "cells" in inputs:
385                    reciprocal_cells = jnp.linalg.inv(inputs["cells"])
386                    inputs = {**inputs, "reciprocal_cells": reciprocal_cells}
387                energy, out = self._total_energy_raw(variables, {**data, **inputs})
388                return energy.sum(), out
389
390            if "strain" in gradient_keys and "strain" not in data:
391                data = {**data, "strain": jnp.array(np.eye(3)[None, :, :].repeat(data["natoms"].shape[0], axis=0))}
392            inputs = {k: data[k] for k in gradient_keys}
393            de, out = jax.grad(_etot, argnums=1, has_aux=True)(variables, inputs)
394
395            return (
396                out["total_energy"],
397                de,
398                {**out, **{"dEd_" + k: de[k] for k in gradient_keys}},
399            )
400
401        if variables_as_input:
402            energy_gradient = _energy_gradient
403        else:
404
405            def energy_gradient(data):
406                return _energy_gradient(self.variables, data)
407
408        if jit:
409            return jax.jit(energy_gradient)
410        else:
411            return energy_gradient
412
413    def preprocess(self,use_gpu=False, verbose=False,**inputs) -> Dict[str, Any]:
414        """apply preprocessing to the input data
415
416        !!! This is not a pure function => do not apply jax transforms !!!"""
417        if self.preproc_state is None:
418            out, _ = self.reinitialize_preprocessing(example_data=inputs)
419        elif use_gpu:
420            do_check_input = self.preproc_state.get("check_input", True)
421            if do_check_input:
422                inputs = check_input(inputs)
423            preproc_state, inputs = self.preprocessing.atom_padding(
424                self.preproc_state, inputs
425            )
426            inputs = self.preprocessing.process(preproc_state, inputs)
427            preproc_state, state_up, out, overflow = (
428                self.preprocessing.check_reallocate(
429                    preproc_state, inputs
430                )
431            )
432            if verbose and overflow:
433                print("GPU preprocessing: nblist overflow => reallocating nblist")
434                print("size updates:", state_up)
435        else:
436            preproc_state, out = self.preprocessing(self.preproc_state, inputs)
437
438        object.__setattr__(self, "preproc_state", preproc_state)
439        return out
440
441    def reinitialize_preprocessing(
442        self, rng_key: Optional[jax.random.PRNGKey] = None, example_data=None
443    ) -> None:
444        ### TODO ###
445        if rng_key is None:
446            rng_key_pre = jax.random.PRNGKey(0)
447        else:
448            rng_key, rng_key_pre = jax.random.split(rng_key)
449
450        if example_data is None:
451            rng_key_sys, rng_key_pre = jax.random.split(rng_key_pre)
452            example_data = self.generate_dummy_system(rng_key_sys, n_atoms=10)
453
454        preproc_state, inputs = self.preprocessing.init_with_output(example_data)
455        object.__setattr__(self, "preproc_state", preproc_state)
456        return inputs, rng_key
457
458    def __call__(self, variables: Optional[dict] = None, gpu_preprocessing=False,**inputs) -> Dict[str, Any]:
459        """Apply the FENNIX model (preprocess + modules)
460
461        !!! This is not a pure function => do not apply jax transforms !!!
462        if you want to apply jax transforms, use  self._apply(variables, inputs) which is pure and preprocess the input using self.preprocess
463        """
464        if variables is None:
465            variables = self.variables
466        inputs = self.preprocess(use_gpu=gpu_preprocessing,**inputs)
467        output = self._apply(variables, inputs)
468        if self.use_atom_padding:
469            output = atom_unpadding(output)
470        return output
471
472    def total_energy(
473        self, variables: Optional[dict] = None, unit: Union[float, str] = None, gpu_preprocessing=False,**inputs
474    ) -> Tuple[jnp.ndarray, Dict]:
475        """compute the total energy of the system
476
477        !!! This is not a pure function => do not apply jax transforms !!!
478        if you want to apply jax transforms, use self._total_energy(variables, inputs) which is pure and preprocess the input using self.preprocess
479        """
480        if variables is None:
481            variables = self.variables
482        inputs = self.preprocess(use_gpu=gpu_preprocessing,**inputs)
483        # def print_shape(path,value):
484        #     if isinstance(value,jnp.ndarray):
485        #         print(path,value.shape)
486        #     else:
487        #         print(path,value)
488        # jax.tree_util.tree_map_with_path(print_shape,inputs)
489        _, output = self._total_energy(variables, inputs)
490        if self.use_atom_padding:
491            output = atom_unpadding(output)
492        e = output["total_energy"]
493        if unit is not None:
494            model_energy_unit = self.Ha_to_model_energy
495            if isinstance(unit, str):
496                unit = au.get_multiplier(unit)
497            e = e * (unit / model_energy_unit)
498        return e, output
499
500    def energy_and_forces(
501        self, variables: Optional[dict] = None, unit: Union[float, str] = None,gpu_preprocessing=False,**inputs
502    ) -> Tuple[jnp.ndarray, jnp.ndarray, Dict]:
503        """compute the total energy and forces of the system
504
505        !!! This is not a pure function => do not apply jax transforms !!!
506        if you want to apply jax transforms, use  self._energy_and_forces(variables, inputs) which is pure and preprocess the input using self.preprocess
507        """
508        if variables is None:
509            variables = self.variables
510        inputs = self.preprocess(use_gpu=gpu_preprocessing,**inputs)
511        _, _, output = self._energy_and_forces(variables, inputs)
512        if self.use_atom_padding:
513            output = atom_unpadding(output)
514        e = output["total_energy"]
515        f = output["forces"]
516        if unit is not None:
517            model_energy_unit = self.Ha_to_model_energy
518            if isinstance(unit, str):
519                unit = au.get_multiplier(unit)
520            e = e * (unit / model_energy_unit)
521            f = f * (unit / model_energy_unit)
522        return e, f, output
523
524    def energy_and_forces_and_virial(
525        self, variables: Optional[dict] = None, unit: Union[float, str] = None,gpu_preprocessing=False, **inputs
526    ) -> Tuple[jnp.ndarray, jnp.ndarray, Dict]:
527        """compute the total energy and forces of the system
528
529        !!! This is not a pure function => do not apply jax transforms !!!
530        if you want to apply jax transforms, use  self._energy_and_forces_and_virial(variables, inputs) which is pure and preprocess the input using self.preprocess
531        """
532        if variables is None:
533            variables = self.variables
534        inputs = self.preprocess(use_gpu=gpu_preprocessing,**inputs)
535        _, _, _, output = self._energy_and_forces_and_virial(variables, inputs)
536        if self.use_atom_padding:
537            output = atom_unpadding(output)
538        e = output["total_energy"]
539        f = output["forces"]
540        v = output["virial_tensor"]
541        if unit is not None:
542            model_energy_unit = self.Ha_to_model_energy
543            if isinstance(unit, str):
544                unit = au.get_multiplier(unit)
545            e = e * (unit / model_energy_unit)
546            f = f * (unit / model_energy_unit)
547            v = v * (unit / model_energy_unit)
548        return e, f, v, output
549
550    def remove_atom_padding(self, output):
551        """remove atom padding from the output"""
552        return atom_unpadding(output)
553
554    def get_model(self) -> Tuple[FENNIXModules, Dict]:
555        """return the model and its variables"""
556        return self.modules, self.variables
557
558    def get_preprocessing(self) -> Tuple[PreprocessingChain, Dict]:
559        """return the preprocessing chain and its state"""
560        return self.preprocessing, self.preproc_state
561
562    def __setattr__(self, __name: str, __value: Any) -> None:
563        if __name == "variables":
564            if __value is not None:
565                if not (
566                    isinstance(__value, dict)
567                    or isinstance(__value, OrderedDict)
568                    or isinstance(__value, FrozenDict)
569                ):
570                    raise ValueError(f"{__name} must be a dict")
571                object.__setattr__(self, __name, JaxConverter()(__value))
572            else:
573                raise ValueError(f"{__name} cannot be None")
574        elif __name == "preproc_state":
575            if __value is not None:
576                if not (
577                    isinstance(__value, dict)
578                    or isinstance(__value, OrderedDict)
579                    or isinstance(__value, FrozenDict)
580                ):
581                    raise ValueError(f"{__name} must be a FrozenDict")
582                object.__setattr__(self, __name, freeze(JaxConverter()(__value)))
583            else:
584                raise ValueError(f"{__name} cannot be None")
585
586        elif self._initializing:
587            object.__setattr__(self, __name, __value)
588        else:
589            raise ValueError(f"{__name} attribute of FENNIX model is immutable.")
590
591    def generate_dummy_system(
592        self, rng_key: jax.random.PRNGKey, box_size=None, n_atoms: int = 10
593    ) -> Dict[str, Any]:
594        """
595        Generate dummy system for initialization
596        """
597        if box_size is None:
598            box_size = 2 * self.cutoff
599            for g in self._graphs_properties.values():
600                cutoff = g["cutoff"]
601                if cutoff is not None:
602                    box_size = min(box_size, 2 * g["cutoff"])
603        coordinates = np.array(
604            jax.random.uniform(rng_key, (n_atoms, 3), maxval=box_size), dtype=np.float64
605        )
606        species = np.ones((n_atoms,), dtype=np.int32)
607        batch_index = np.zeros((n_atoms,), dtype=np.int32)
608        natoms = np.array([n_atoms], dtype=np.int32)
609        return {
610            "species": species,
611            "coordinates": coordinates,
612            # "graph": graph,
613            "batch_index": batch_index,
614            "natoms": natoms,
615        }
616
617    def summarize(
618        self, rng_key: jax.random.PRNGKey = None, example_data=None, **kwargs
619    ) -> str:
620        """Summarize the model architecture and parameters"""
621        if rng_key is None:
622            head = "Summarizing with example data:\n"
623            rng_key = jax.random.PRNGKey(0)
624        if example_data is None:
625            head = "Summarizing with dummy 10 atoms system:\n"
626            rng_key, rng_key_sys = jax.random.split(rng_key)
627            example_data = self.generate_dummy_system(rng_key_sys, n_atoms=10)
628        rng_key, rng_key_pre = jax.random.split(rng_key)
629        _, inputs = self.preprocessing.init_with_output(example_data)
630        return head + nn.tabulate(self.modules, rng_key, **kwargs)(inputs)
631
632    def to_dict(self):
633        """return a dictionary representation of the model"""
634        return {
635            **self._input_args,
636            "energy_terms": self.energy_terms,
637            "variables": deepcopy(self.variables),
638        }
639
640    def save(self, filename):
641        """save the model to a file"""
642        state_dict = self.to_dict()
643        state_dict["preprocessing"] = [
644            [k, v] for k, v in state_dict["preprocessing"].items()
645        ]
646        state_dict["modules"] = [[k, v] for k, v in state_dict["modules"].items()]
647        with open(filename, "wb") as f:
648            f.write(serialization.msgpack_serialize(state_dict))
649
650    @classmethod
651    def load(
652        cls,
653        filename,
654        use_atom_padding=False,
655        graph_config={},
656    ):
657        """load a model from a file"""
658        with open(filename, "rb") as f:
659            state_dict = serialization.msgpack_restore(f.read())
660        state_dict["preprocessing"] = {k: v for k, v in state_dict["preprocessing"]}
661        state_dict["modules"] = {k: v for k, v in state_dict["modules"]}
662        return cls(
663            **state_dict,
664            graph_config=graph_config,
665            use_atom_padding=use_atom_padding,
666        )

Static wrapper for FENNIX models

The underlying model is a fennol.models.modules.FENNIXModules built from the modules dictionary which references registered modules in fennol.models.modules.MODULES and provides the parameters for initialization.

Since the model is static and contains variables, it must be initialized right away, either with variables or with an rng_key. If variables is provided, it is used directly. Otherwise, the model is initialized with rng_key, using example_data if it is provided or a dummy system otherwise, and the resulting variables are stored in the wrapper.

FENNIX( cutoff: float, modules: collections.OrderedDict, preprocessing: collections.OrderedDict = OrderedDict(), example_data=None, rng_key: Optional[PRNGKey] = None, variables: Optional[dict] = None, energy_terms: Optional[Sequence[str]] = None, use_atom_padding: bool = False, graph_config: Dict = {}, energy_unit: str = 'Ha', **kwargs)
    def __init__(
        self,
        cutoff: float,
        modules: OrderedDict,
        preprocessing: OrderedDict = OrderedDict(),
        example_data=None,
        rng_key: Optional[jax.random.PRNGKey] = None,
        variables: Optional[dict] = None,
        energy_terms: Optional[Sequence[str]] = None,
        use_atom_padding: bool = False,
        graph_config: Dict = {},
        energy_unit: str = "Ha",
        **kwargs,
    ) -> None:
        """Initialize the FENNIX model

        Arguments:
        ----------
        cutoff: float
            The cutoff radius for the model. If None, no default graph generator is
            prepended to the preprocessing chain.
        modules: OrderedDict
            The dictionary defining the sequence of FeNNol modules and their parameters.
            The module type of each entry is taken from its "FID" parameter if present,
            otherwise from its "module_name" parameter, otherwise from the entry name.
            Module names must be unique.
        preprocessing: OrderedDict
            The dictionary defining the sequence of preprocessing modules and their parameters
            (same module-type lookup as for `modules`). If the first entry is named "graph",
            its parameters are merged into the default graph generator.
        example_data: dict
            Example data used to initialize the preprocessing (and the variables when they
            are initialized here). If not provided, a dummy system is generated.
        rng_key: jax.random.PRNGKey
            The random key used to initialize the model variables. If not provided,
            jax.random.PRNGKey(0) is only used for the preprocessing / dummy system, and
            `variables` must then be provided, otherwise a ValueError is raised.
        variables: dict
            The variables of the model (i.e. weights, biases and all other tunable parameters).
            If not provided, the variables are initialized (usually at random) from `rng_key`.
        energy_terms: Sequence[str]
            The energy terms in the model output that will be summed to compute the total energy.
            If None, the total energy is always zero (useful for non-PES models).
        use_atom_padding: bool
            If True, the model will use atom padding for the input data.
            This is useful when one plans to frequently change the number of atoms in the system (for example during training).
        graph_config: dict
            Edit the graph configuration. Mostly used to change a long-range cutoff as a function of a simulation box size.
            Its entries override both the defaults and the "graph" preprocessing entry.
        energy_unit: str
            Energy unit of the model (default "Ha"). Its multiplier from `au.get_multiplier`
            is stored as `Ha_to_model_energy`, and the unit is passed to every module that
            declares an `_energy_unit` field.
        **kwargs:
            Extra keyword arguments are accepted and ignored.

        Raises:
        -------
        ValueError
            If two modules share the same name, or if neither `variables` nor `rng_key`
            is provided.
        """
        # constructor arguments kept so that to_dict()/save() can rebuild the model
        self._input_args = {
            "cutoff": cutoff,
            "modules": OrderedDict(modules),
            "preprocessing": OrderedDict(preprocessing),
            "energy_unit": energy_unit,
        }
        self.energy_unit = energy_unit
        self.Ha_to_model_energy = au.get_multiplier(energy_unit)
        self.cutoff = cutoff
        self.energy_terms = energy_terms
        self.use_atom_padding = use_atom_padding

        # add non-differentiable/non-jittable modules
        # (deep copy: keys are popped from the parameter dicts below)
        preprocessing = deepcopy(preprocessing)
        if cutoff is None:
            preprocessing_modules = []
        else:
            prep_keys = list(preprocessing.keys())
            graph_params = {"cutoff": cutoff, "graph_key": "graph"}
            # a leading "graph" entry customizes the default graph generator
            if len(prep_keys) > 0 and prep_keys[0] == "graph":
                graph_params = {
                    **graph_params,
                    **preprocessing.pop("graph"),
                }
            # graph_config entries override everything above
            graph_params = {**graph_params, **graph_config}

            preprocessing_modules = [
                GraphGenerator(**graph_params),
            ]

        for name, params in preprocessing.items():
            # module type lookup: "FID" > "module_name" > entry name
            key = str(params.pop("module_name")) if "module_name" in params else name
            key = str(params.pop("FID")) if "FID" in params else key
            mod = PREPROCESSING[key.upper()](**freeze(params))
            preprocessing_modules.append(mod)

        self.preprocessing = PreprocessingChain(
            tuple(preprocessing_modules), use_atom_padding
        )
        graphs_properties = self.preprocessing.get_graphs_properties()
        self._graphs_properties = freeze(graphs_properties)
        # add preprocessing modules that should be differentiated/jitted
        mods = [(JaxConverter, {})] + self.preprocessing.get_processors()
        # mods = self.preprocessing.get_processors(return_list=True)

        # build the model
        modules = deepcopy(modules)
        modules_names = []
        for name, params in modules.items():
            # module type lookup: "FID" > "module_name" > entry name
            key = str(params.pop("module_name")) if "module_name" in params else name
            key = str(params.pop("FID")) if "FID" in params else key
            if name in modules_names:
                raise ValueError(f"Module {name} already exists")
            modules_names.append(name)
            params["name"] = name
            mod = MODULES[key.upper()]
            # inject shared settings only into modules that declare the matching field
            fields = [f.name for f in dataclasses.fields(mod)]
            if "_graphs_properties" in fields:
                params["_graphs_properties"] = graphs_properties
            if "_energy_unit" in fields:
                params["_energy_unit"] = energy_unit
            mods.append((mod, params))

        self.modules = FENNIXModules(mods)

        # un-jitted apply (name-mangled to _FENNIX__apply), called inside the energy
        # functions built by set_energy_terms, which are jitted as a whole there
        self.__apply = self.modules.apply
        self._apply = jax.jit(self.modules.apply)

        self.set_energy_terms(energy_terms)

        # initialize the model

        # also sets self.preproc_state; returns the preprocessed example inputs and the
        # remaining key (None when no rng_key was given)
        inputs, rng_key = self.reinitialize_preprocessing(rng_key, example_data)

        if variables is not None:
            self.variables = variables
        elif rng_key is not None:
            self.variables = self.modules.init(rng_key, inputs)
        else:
            raise ValueError(
                "Either variables or a jax.random.PRNGKey must be provided for initialization"
            )

        self._initializing = False

Initialize the FENNIX model

Arguments:

cutoff: float — The cutoff radius for the model. modules: OrderedDict — The dictionary defining the sequence of FeNNol modules and their parameters. preprocessing: OrderedDict — The dictionary defining the sequence of preprocessing modules and their parameters. example_data: dict — Example data to initialize the model. If not provided, a dummy system is generated. rng_key: jax.random.PRNGKey — The random key used to initialize the model variables. If not provided, jax.random.PRNGKey(0) is only used for preprocessing and dummy-system generation, and variables must then be provided (otherwise a ValueError is raised). variables: dict — The variables of the model (i.e. weights, biases and all other tunable parameters). If not provided, the variables are initialized (usually at random). energy_terms: Sequence[str] — The energy terms in the model output that will be summed to compute the total energy. If None, the total energy is always zero (useful for non-PES models). use_atom_padding: bool — If True, the model will use atom padding for the input data. This is useful when one plans to frequently change the number of atoms in the system (for example during training). graph_config: dict — Edit the graph configuration. Mostly used to change a long-range cutoff as a function of a simulation box size. energy_unit: str — The energy unit of the model (default "Ha").

cutoff: Optional[float]
variables: Dict
preproc_state: Dict
energy_terms: Optional[Sequence[str]] = None
use_atom_padding: bool = False
energy_unit
Ha_to_model_energy
def set_energy_terms(self, energy_terms: Optional[Sequence[str]], jit: bool = True) -> None:
179    def set_energy_terms(
180        self, energy_terms: Union[Sequence[str], None], jit: bool = True
181    ) -> None:
182        """Set the energy terms to be computed by the model and prepare the energy and force functions."""
183        object.__setattr__(self, "energy_terms", energy_terms)
184        if isinstance(energy_terms, str):
185            energy_terms = [energy_terms]
186        
187        if energy_terms is None or len(energy_terms) == 0:
188
189            def total_energy(variables, data):
190                out = self.__apply(variables, data)
191                coords = out["coordinates"]
192                nsys = out["natoms"].shape[0]
193                nat = coords.shape[0]
194                dtype = coords.dtype
195                e = jnp.zeros(nsys, dtype=dtype)
196                eat = jnp.zeros(nat, dtype=dtype)
197                out["total_energy"] = e
198                out["atomic_energies"] = eat
199                return e, out
200
201            def energy_and_forces(variables, data):
202                e, out = total_energy(variables, data)
203                f = jnp.zeros_like(out["coordinates"])
204                out["forces"] = f
205                return e, f, out
206
207            def energy_and_forces_and_virial(variables, data):
208                e, f, out = energy_and_forces(variables, data)
209                v = jnp.zeros(
210                    (out["natoms"].shape[0], 3, 3), dtype=out["coordinates"].dtype
211                )
212                out["virial_tensor"] = v
213                return e, f, v, out
214
215        else:
216            # build the energy and force functions
217            def total_energy(variables, data):
218                out = self.__apply(variables, data)
219                atomic_energies = 0.0
220                system_energies = 0.0
221                species = out["species"]
222                nsys = out["natoms"].shape[0]
223                for term in self.energy_terms:
224                    e = out[term]
225                    if e.ndim > 1 and e.shape[-1] == 1:
226                        e = jnp.squeeze(e, axis=-1)
227                    if e.shape[0] == nsys and nsys != species.shape[0]:
228                        system_energies += e
229                        continue
230                    assert e.shape == species.shape
231                    atomic_energies += e
232                # atomic_energies = jnp.squeeze(atomic_energies, axis=-1)
233                if isinstance(atomic_energies, jnp.ndarray):
234                    if "true_atoms" in out:
235                        atomic_energies = jnp.where(
236                            out["true_atoms"], atomic_energies, 0.0
237                        )
238                    out["atomic_energies"] = atomic_energies
239                    energies = jax.ops.segment_sum(
240                        atomic_energies,
241                        data["batch_index"],
242                        num_segments=len(data["natoms"]),
243                    )
244                else:
245                    energies = 0.0
246
247                if isinstance(system_energies, jnp.ndarray):
248                    if "true_sys" in out:
249                        system_energies = jnp.where(
250                            out["true_sys"], system_energies, 0.0
251                        )
252                    out["system_energies"] = system_energies
253
254                out["total_energy"] = energies + system_energies
255                return out["total_energy"], out
256
257            def energy_and_forces(variables, data):
258                def _etot(variables, coordinates):
259                    energy, out = total_energy(
260                        variables, {**data, "coordinates": coordinates}
261                    )
262                    return energy.sum(), out
263
264                de, out = jax.grad(_etot, argnums=1, has_aux=True)(
265                    variables, data["coordinates"]
266                )
267                out["forces"] = -de
268
269                return out["total_energy"], out["forces"], out
270
271            # def energy_and_forces_and_virial(variables, data):
272            #     x = data["coordinates"]
273            #     batch_index = data["batch_index"]
274            #     if "cells" in data:
275            #         cells = data["cells"]
276            #         ## cells is a nbatchx3x3 matrix which lines are cell vectors (i.e. cells[0,0,:] is the first cell vector of the first system)
277
278            #         def _etot(variables, coordinates, cells):
279            #             reciprocal_cells = jnp.linalg.inv(cells)
280            #             energy, out = total_energy(
281            #                 variables,
282            #                 {
283            #                     **data,
284            #                     "coordinates": coordinates,
285            #                     "cells": cells,
286            #                     "reciprocal_cells": reciprocal_cells,
287            #                 },
288            #             )
289            #             return energy.sum(), out
290
291            #         (dedx, dedcells), out = jax.grad(_etot, argnums=(1, 2), has_aux=True)(
292            #             variables, x, cells
293            #         )
294            #         f= -dedx
295            #         out["forces"] = f
296            #     else:
297            #         _,f,out = energy_and_forces(variables, data)
298
299            #     vir = -jax.ops.segment_sum(
300            #         f[:, :, None] * x[:, None, :],
301            #         batch_index,
302            #         num_segments=len(data["natoms"]),
303            #     )
304
305            #     if "cells" in data:
306            #         # dvir = jax.vmap(jnp.matmul)(dedcells, cells.transpose(0, 2, 1))
307            #         dvir = jnp.einsum("...ki,...kj->...ij", dedcells, cells)
308            #         nsys = data["natoms"].shape[0]
309            #         if cells.shape[0]==1 and nsys>1:
310            #             dvir = dvir / nsys
311            #         vir = vir + dvir
312
313            #     out["virial_tensor"] = vir
314
315            #     return out["total_energy"], f, vir, out
316
317            def energy_and_forces_and_virial(variables, data):
318                x = data["coordinates"]
319                scaling = jnp.asarray(
320                    np.eye(3)[None, :, :].repeat(data["natoms"].shape[0], axis=0)
321                )
322                def _etot(variables, coordinates, scaling):
323                    batch_index = data["batch_index"]
324                    coordinates = jax.vmap(jnp.matmul)(
325                        coordinates, scaling[batch_index]
326                    )
327                    inputs = {**data, "coordinates": coordinates}
328                    if "cells" in data:
329                        ## cells is a nbatchx3x3 matrix which lines are cell vectors (i.e. cells[0,0,:] is the first cell vector of the first system)
330                        cells = jax.vmap(jnp.matmul)(data["cells"], scaling)
331                        reciprocal_cells = jnp.linalg.inv(cells)
332                        inputs["cells"] = cells
333                        inputs["reciprocal_cells"] = reciprocal_cells
334                    energy, out = total_energy(variables, inputs)
335                    return energy.sum(), out
336
337                (dedx, vir), out = jax.grad(_etot, argnums=(1, 2), has_aux=True)(
338                    variables, x, scaling
339                )
340                f = -dedx
341                out["forces"] = f
342                out["virial_tensor"] = vir
343
344                return out["total_energy"], f, vir, out
345
346        object.__setattr__(self, "_total_energy_raw", total_energy)
347        if jit:
348            object.__setattr__(self, "_total_energy", jax.jit(total_energy))
349            object.__setattr__(self, "_energy_and_forces", jax.jit(energy_and_forces))
350            object.__setattr__(
351                self,
352                "_energy_and_forces_and_virial",
353                jax.jit(energy_and_forces_and_virial),
354            )
355        else:
356            object.__setattr__(self, "_total_energy", total_energy)
357            object.__setattr__(self, "_energy_and_forces", energy_and_forces)
358            object.__setattr__(
359                self, "_energy_and_forces_and_virial", energy_and_forces_and_virial
360            )

Set the energy terms to be computed by the model and prepare the energy and force functions.

def get_gradient_function( self, *gradient_keys: Sequence[str], jit: bool = True, variables_as_input: bool = False):
362    def get_gradient_function(
363        self,
364        *gradient_keys: Sequence[str],
365        jit: bool = True,
366        variables_as_input: bool = False,
367    ):
368        """Return a function that computes the energy and the gradient of the energy with respect to the keys in gradient_keys"""
369
370        def _energy_gradient(variables, data):
371            def _etot(variables, inputs):
372                if "strain" in inputs:
373                    scaling = inputs["strain"]
374                    batch_index = data["batch_index"]
375                    coordinates = inputs["coordinates"] if "coordinates" in inputs else data["coordinates"]
376                    coordinates = jax.vmap(jnp.matmul)(
377                        coordinates, scaling[batch_index]
378                    )
379                    inputs = {**inputs, "coordinates": coordinates}
380                    if "cells" in inputs or "cells" in data:
381                        cells = inputs["cells"] if "cells" in inputs else data["cells"]
382                        cells = jax.vmap(jnp.matmul)(cells, scaling)
383                        inputs["cells"] = cells
384                if "cells" in inputs:
385                    reciprocal_cells = jnp.linalg.inv(inputs["cells"])
386                    inputs = {**inputs, "reciprocal_cells": reciprocal_cells}
387                energy, out = self._total_energy_raw(variables, {**data, **inputs})
388                return energy.sum(), out
389
390            if "strain" in gradient_keys and "strain" not in data:
391                data = {**data, "strain": jnp.array(np.eye(3)[None, :, :].repeat(data["natoms"].shape[0], axis=0))}
392            inputs = {k: data[k] for k in gradient_keys}
393            de, out = jax.grad(_etot, argnums=1, has_aux=True)(variables, inputs)
394
395            return (
396                out["total_energy"],
397                de,
398                {**out, **{"dEd_" + k: de[k] for k in gradient_keys}},
399            )
400
401        if variables_as_input:
402            energy_gradient = _energy_gradient
403        else:
404
405            def energy_gradient(data):
406                return _energy_gradient(self.variables, data)
407
408        if jit:
409            return jax.jit(energy_gradient)
410        else:
411            return energy_gradient

Return a function that computes the energy and the gradient of the energy with respect to the keys in gradient_keys

def preprocess(self, use_gpu=False, verbose=False, **inputs) -> Dict[str, Any]:
413    def preprocess(self,use_gpu=False, verbose=False,**inputs) -> Dict[str, Any]:
414        """apply preprocessing to the input data
415
416        !!! This is not a pure function => do not apply jax transforms !!!"""
417        if self.preproc_state is None:
418            out, _ = self.reinitialize_preprocessing(example_data=inputs)
419        elif use_gpu:
420            do_check_input = self.preproc_state.get("check_input", True)
421            if do_check_input:
422                inputs = check_input(inputs)
423            preproc_state, inputs = self.preprocessing.atom_padding(
424                self.preproc_state, inputs
425            )
426            inputs = self.preprocessing.process(preproc_state, inputs)
427            preproc_state, state_up, out, overflow = (
428                self.preprocessing.check_reallocate(
429                    preproc_state, inputs
430                )
431            )
432            if verbose and overflow:
433                print("GPU preprocessing: nblist overflow => reallocating nblist")
434                print("size updates:", state_up)
435        else:
436            preproc_state, out = self.preprocessing(self.preproc_state, inputs)
437
438        object.__setattr__(self, "preproc_state", preproc_state)
439        return out

apply preprocessing to the input data

!!! This is not a pure function => do not apply jax transforms !!!

def reinitialize_preprocessing(self, rng_key: Optional[PRNGKey] = None, example_data=None) -> None:
441    def reinitialize_preprocessing(
442        self, rng_key: Optional[jax.random.PRNGKey] = None, example_data=None
443    ) -> None:
444        ### TODO ###
445        if rng_key is None:
446            rng_key_pre = jax.random.PRNGKey(0)
447        else:
448            rng_key, rng_key_pre = jax.random.split(rng_key)
449
450        if example_data is None:
451            rng_key_sys, rng_key_pre = jax.random.split(rng_key_pre)
452            example_data = self.generate_dummy_system(rng_key_sys, n_atoms=10)
453
454        preproc_state, inputs = self.preprocessing.init_with_output(example_data)
455        object.__setattr__(self, "preproc_state", preproc_state)
456        return inputs, rng_key
def total_energy( self, variables: Optional[dict] = None, unit: Union[float, str] = None, gpu_preprocessing=False, **inputs) -> Tuple[jax.Array, Dict]:
472    def total_energy(
473        self, variables: Optional[dict] = None, unit: Union[float, str] = None, gpu_preprocessing=False,**inputs
474    ) -> Tuple[jnp.ndarray, Dict]:
475        """compute the total energy of the system
476
477        !!! This is not a pure function => do not apply jax transforms !!!
478        if you want to apply jax transforms, use self._total_energy(variables, inputs) which is pure and preprocess the input using self.preprocess
479        """
480        if variables is None:
481            variables = self.variables
482        inputs = self.preprocess(use_gpu=gpu_preprocessing,**inputs)
483        # def print_shape(path,value):
484        #     if isinstance(value,jnp.ndarray):
485        #         print(path,value.shape)
486        #     else:
487        #         print(path,value)
488        # jax.tree_util.tree_map_with_path(print_shape,inputs)
489        _, output = self._total_energy(variables, inputs)
490        if self.use_atom_padding:
491            output = atom_unpadding(output)
492        e = output["total_energy"]
493        if unit is not None:
494            model_energy_unit = self.Ha_to_model_energy
495            if isinstance(unit, str):
496                unit = au.get_multiplier(unit)
497            e = e * (unit / model_energy_unit)
498        return e, output

compute the total energy of the system

!!! This is not a pure function => do not apply jax transforms !!! if you want to apply jax transforms, use self._total_energy(variables, inputs) which is pure and preprocess the input using self.preprocess

def energy_and_forces( self, variables: Optional[dict] = None, unit: Union[float, str] = None, gpu_preprocessing=False, **inputs) -> Tuple[jax.Array, jax.Array, Dict]:
500    def energy_and_forces(
501        self, variables: Optional[dict] = None, unit: Union[float, str] = None,gpu_preprocessing=False,**inputs
502    ) -> Tuple[jnp.ndarray, jnp.ndarray, Dict]:
503        """compute the total energy and forces of the system
504
505        !!! This is not a pure function => do not apply jax transforms !!!
506        if you want to apply jax transforms, use  self._energy_and_forces(variables, inputs) which is pure and preprocess the input using self.preprocess
507        """
508        if variables is None:
509            variables = self.variables
510        inputs = self.preprocess(use_gpu=gpu_preprocessing,**inputs)
511        _, _, output = self._energy_and_forces(variables, inputs)
512        if self.use_atom_padding:
513            output = atom_unpadding(output)
514        e = output["total_energy"]
515        f = output["forces"]
516        if unit is not None:
517            model_energy_unit = self.Ha_to_model_energy
518            if isinstance(unit, str):
519                unit = au.get_multiplier(unit)
520            e = e * (unit / model_energy_unit)
521            f = f * (unit / model_energy_unit)
522        return e, f, output

compute the total energy and forces of the system

!!! This is not a pure function => do not apply jax transforms !!! if you want to apply jax transforms, use self._energy_and_forces(variables, inputs) which is pure and preprocess the input using self.preprocess

def energy_and_forces_and_virial( self, variables: Optional[dict] = None, unit: Union[float, str] = None, gpu_preprocessing=False, **inputs) -> Tuple[jax.Array, jax.Array, Dict]:
524    def energy_and_forces_and_virial(
525        self, variables: Optional[dict] = None, unit: Union[float, str] = None,gpu_preprocessing=False, **inputs
526    ) -> Tuple[jnp.ndarray, jnp.ndarray, Dict]:
527        """compute the total energy and forces of the system
528
529        !!! This is not a pure function => do not apply jax transforms !!!
530        if you want to apply jax transforms, use  self._energy_and_forces_and_virial(variables, inputs) which is pure and preprocess the input using self.preprocess
531        """
532        if variables is None:
533            variables = self.variables
534        inputs = self.preprocess(use_gpu=gpu_preprocessing,**inputs)
535        _, _, _, output = self._energy_and_forces_and_virial(variables, inputs)
536        if self.use_atom_padding:
537            output = atom_unpadding(output)
538        e = output["total_energy"]
539        f = output["forces"]
540        v = output["virial_tensor"]
541        if unit is not None:
542            model_energy_unit = self.Ha_to_model_energy
543            if isinstance(unit, str):
544                unit = au.get_multiplier(unit)
545            e = e * (unit / model_energy_unit)
546            f = f * (unit / model_energy_unit)
547            v = v * (unit / model_energy_unit)
548        return e, f, v, output

compute the total energy, forces and virial tensor of the system

!!! This is not a pure function => do not apply jax transforms !!! if you want to apply jax transforms, use self._energy_and_forces_and_virial(variables, inputs) which is pure and preprocess the input using self.preprocess

def remove_atom_padding(self, output):
550    def remove_atom_padding(self, output):
551        """remove atom padding from the output"""
552        return atom_unpadding(output)

remove atom padding from the output

def get_model(self) -> Tuple[fennol.models.modules.FENNIXModules, Dict]:
554    def get_model(self) -> Tuple[FENNIXModules, Dict]:
555        """return the model and its variables"""
556        return self.modules, self.variables

return the model and its variables

def get_preprocessing(self) -> Tuple[fennol.models.preprocessing.PreprocessingChain, Dict]:
558    def get_preprocessing(self) -> Tuple[PreprocessingChain, Dict]:
559        """return the preprocessing chain and its state"""
560        return self.preprocessing, self.preproc_state

return the preprocessing chain and its state

def generate_dummy_system( self, rng_key: jax.random.PRNGKey, box_size=None, n_atoms: int = 10) -> Dict[str, Any]:
591    def generate_dummy_system(
592        self, rng_key: jax.random.PRNGKey, box_size=None, n_atoms: int = 10
593    ) -> Dict[str, Any]:
594        """
595        Generate dummy system for initialization
596        """
597        if box_size is None:
598            box_size = 2 * self.cutoff
599            for g in self._graphs_properties.values():
600                cutoff = g["cutoff"]
601                if cutoff is not None:
602                    box_size = min(box_size, 2 * g["cutoff"])
603        coordinates = np.array(
604            jax.random.uniform(rng_key, (n_atoms, 3), maxval=box_size), dtype=np.float64
605        )
606        species = np.ones((n_atoms,), dtype=np.int32)
607        batch_index = np.zeros((n_atoms,), dtype=np.int32)
608        natoms = np.array([n_atoms], dtype=np.int32)
609        return {
610            "species": species,
611            "coordinates": coordinates,
612            # "graph": graph,
613            "batch_index": batch_index,
614            "natoms": natoms,
615        }

Generate dummy system for initialization

def summarize( self, rng_key: jax.random.PRNGKey = None, example_data=None, **kwargs) -> str:
617    def summarize(
618        self, rng_key: jax.random.PRNGKey = None, example_data=None, **kwargs
619    ) -> str:
620        """Summarize the model architecture and parameters"""
621        if rng_key is None:
622            head = "Summarizing with example data:\n"
623            rng_key = jax.random.PRNGKey(0)
624        if example_data is None:
625            head = "Summarizing with dummy 10 atoms system:\n"
626            rng_key, rng_key_sys = jax.random.split(rng_key)
627            example_data = self.generate_dummy_system(rng_key_sys, n_atoms=10)
628        rng_key, rng_key_pre = jax.random.split(rng_key)
629        _, inputs = self.preprocessing.init_with_output(example_data)
630        return head + nn.tabulate(self.modules, rng_key, **kwargs)(inputs)

Summarize the model architecture and parameters

def to_dict(self):
632    def to_dict(self):
633        """return a dictionary representation of the model"""
634        return {
635            **self._input_args,
636            "energy_terms": self.energy_terms,
637            "variables": deepcopy(self.variables),
638        }

return a dictionary representation of the model

def save(self, filename):
640    def save(self, filename):
641        """save the model to a file"""
642        state_dict = self.to_dict()
643        state_dict["preprocessing"] = [
644            [k, v] for k, v in state_dict["preprocessing"].items()
645        ]
646        state_dict["modules"] = [[k, v] for k, v in state_dict["modules"].items()]
647        with open(filename, "wb") as f:
648            f.write(serialization.msgpack_serialize(state_dict))

save the model to a file

@classmethod
def load(cls, filename, use_atom_padding=False, graph_config={}):
650    @classmethod
651    def load(
652        cls,
653        filename,
654        use_atom_padding=False,
655        graph_config={},
656    ):
657        """load a model from a file"""
658        with open(filename, "rb") as f:
659            state_dict = serialization.msgpack_restore(f.read())
660        state_dict["preprocessing"] = {k: v for k, v in state_dict["preprocessing"]}
661        state_dict["modules"] = {k: v for k, v in state_dict["modules"]}
662        return cls(
663            **state_dict,
664            graph_config=graph_config,
665            use_atom_padding=use_atom_padding,
666        )

load a model from a file