fennol.models.fennix

from typing import Any, Sequence, Callable, Union, Optional, Tuple, Dict
from copy import deepcopy
import dataclasses
from collections import OrderedDict

import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
from flax import serialization
from flax.core.frozen_dict import freeze, unfreeze, FrozenDict
from ..utils import AtomicUnits as au

from .preprocessing import (
    GraphGenerator,
    PreprocessingChain,
    JaxConverter,
    atom_unpadding,
    check_input,
)
from .modules import MODULES, PREPROCESSING, FENNIXModules

@dataclasses.dataclass
class FENNIX:
    """
    Static wrapper for FENNIX models.

    The underlying model is a `fennol.models.modules.FENNIXModules` built from the `modules` dictionary,
    which references registered modules in `fennol.models.modules.MODULES` and provides the parameters for initialization.

    Since the model is static and contains variables, it must be initialized right away with either
    `example_data`, `variables` or `rng_key`. If `variables` is provided, it is used directly. If `example_data`
    is provided, the model is initialized with `example_data` and the resulting variables are stored
    in the wrapper. If only `rng_key` is provided, the model is initialized with a dummy system and the
    resulting variables are stored in the wrapper.
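
    Example (an illustrative sketch only; the module key "energies" and its
    "module_name" value are hypothetical and must match entries registered in
    `fennol.models.modules.MODULES`; the input arrays follow the same layout as
    the output of `generate_dummy_system`):

        model = FENNIX(
            cutoff=5.0,
            modules=OrderedDict(energies={"module_name": "SOME_MODULE"}),
            rng_key=jax.random.PRNGKey(0),
            energy_terms=["energies"],
        )
        e, f, out = model.energy_and_forces(
            species=species,
            coordinates=coordinates,
            batch_index=batch_index,
            natoms=natoms,
        )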
 36    """

    cutoff: Union[float, None]
    modules: FENNIXModules
    variables: Dict
    preprocessing: PreprocessingChain
    _apply: Callable[[Dict, Dict], Dict]
    _total_energy: Callable[[Dict, Dict], Tuple[jnp.ndarray, Dict]]
    _energy_and_forces: Callable[[Dict, Dict], Tuple[jnp.ndarray, jnp.ndarray, Dict]]
    _input_args: Dict
    _graphs_properties: Dict
    preproc_state: Dict
    energy_terms: Optional[Sequence[str]] = None
    _initializing: bool = True
    use_atom_padding: bool = False

    def __init__(
        self,
        cutoff: float,
        modules: OrderedDict,
        preprocessing: OrderedDict = OrderedDict(),
        example_data=None,
        rng_key: Optional[jax.random.PRNGKey] = None,
        variables: Optional[dict] = None,
        energy_terms: Optional[Sequence[str]] = None,
        use_atom_padding: bool = False,
        graph_config: Dict = {},
        energy_unit: str = "Ha",
        **kwargs,
    ) -> None:
        """Initialize the FENNIX model.

        Arguments:
        ----------
        cutoff: float
            The cutoff radius for the model.
        modules: OrderedDict
            The dictionary defining the sequence of FeNNol modules and their parameters.
        preprocessing: OrderedDict
            The dictionary defining the sequence of preprocessing modules and their parameters.
        example_data: dict
            Example data to initialize the model. If not provided, a dummy system is generated.
        rng_key: jax.random.PRNGKey
            The random key used to initialize the model. If not provided, jax.random.PRNGKey(0) is used (this should be avoided).
        variables: dict
            The variables of the model (i.e. weights, biases and all other tunable parameters).
            If not provided, the variables are initialized (usually at random).
        energy_terms: Sequence[str]
            The energy terms in the model output that will be summed to compute the total energy.
            If None, the total energy is always zero (useful for non-PES models).
        use_atom_padding: bool
            If True, the model will use atom padding for the input data.
            This is useful when one plans to frequently change the number of atoms in the system (for example during training).
        graph_config: dict
            Edits the graph configuration. Mostly used to change a long-range cutoff as a function of the simulation box size.
        energy_unit: str
            The unit in which the model expresses energies (default: "Ha").
        """
        self._input_args = {
            "cutoff": cutoff,
            "modules": OrderedDict(modules),
            "preprocessing": OrderedDict(preprocessing),
            "energy_unit": energy_unit,
        }
        self.energy_unit = energy_unit
        self.Ha_to_model_energy = au.get_multiplier(energy_unit)
        self.cutoff = cutoff
        self.energy_terms = energy_terms
        self.use_atom_padding = use_atom_padding

        # add non-differentiable/non-jittable modules
        preprocessing = deepcopy(preprocessing)
        if cutoff is None:
            preprocessing_modules = []
        else:
            prep_keys = list(preprocessing.keys())
            graph_params = {"cutoff": cutoff, "graph_key": "graph"}
            if len(prep_keys) > 0 and prep_keys[0] == "graph":
                graph_params = {
                    **graph_params,
                    **preprocessing.pop("graph"),
                }
            graph_params = {**graph_params, **graph_config}

            preprocessing_modules = [
                GraphGenerator(**graph_params),
            ]

        for name, params in preprocessing.items():
            key = str(params.pop("module_name")) if "module_name" in params else name
            key = str(params.pop("FID")) if "FID" in params else key
            mod = PREPROCESSING[key.upper()](**freeze(params))
            preprocessing_modules.append(mod)

        self.preprocessing = PreprocessingChain(
            tuple(preprocessing_modules), use_atom_padding
        )
        graphs_properties = self.preprocessing.get_graphs_properties()
        self._graphs_properties = freeze(graphs_properties)
        # add preprocessing modules that should be differentiated/jitted
        mods = [(JaxConverter, {})] + self.preprocessing.get_processors()

        # build the model
        modules = deepcopy(modules)
        modules_names = []
        for name, params in modules.items():
            key = str(params.pop("module_name")) if "module_name" in params else name
            key = str(params.pop("FID")) if "FID" in params else key
            if name in modules_names:
                raise ValueError(f"Module {name} already exists")
            modules_names.append(name)
            params["name"] = name
            mod = MODULES[key.upper()]
            fields = [f.name for f in dataclasses.fields(mod)]
            if "_graphs_properties" in fields:
                params["_graphs_properties"] = graphs_properties
            if "_energy_unit" in fields:
                params["_energy_unit"] = energy_unit
            mods.append((mod, params))

        self.modules = FENNIXModules(mods)

        self.__apply = self.modules.apply
        self._apply = jax.jit(self.modules.apply)

        self.set_energy_terms(energy_terms)

        # initialize the model
        inputs, rng_key = self.reinitialize_preprocessing(rng_key, example_data)

        if variables is not None:
            self.variables = variables
        elif rng_key is not None:
            self.variables = self.modules.init(rng_key, inputs)
        else:
            raise ValueError(
                "Either variables or a jax.random.PRNGKey must be provided for initialization"
            )

        self._initializing = False

    def set_energy_terms(
        self, energy_terms: Union[Sequence[str], None], jit: bool = True
    ) -> None:
        """Set the energy terms to be computed by the model and prepare the energy and force functions."""
        object.__setattr__(self, "energy_terms", energy_terms)
        if energy_terms is None or len(energy_terms) == 0:
            # no PES: total energy, forces and virial are identically zero

            def total_energy(variables, data):
                out = self.__apply(variables, data)
                coords = out["coordinates"]
                nsys = out["natoms"].shape[0]
                nat = coords.shape[0]
                dtype = coords.dtype
                e = jnp.zeros(nsys, dtype=dtype)
                eat = jnp.zeros(nat, dtype=dtype)
                out["total_energy"] = e
                out["atomic_energies"] = eat
                return e, out

            def energy_and_forces(variables, data):
                e, out = total_energy(variables, data)
                f = jnp.zeros_like(out["coordinates"])
                out["forces"] = f
                return e, f, out

            def energy_and_forces_and_virial(variables, data):
                e, f, out = energy_and_forces(variables, data)
                v = jnp.zeros(
                    (out["natoms"].shape[0], 3, 3), dtype=out["coordinates"].dtype
                )
                out["virial_tensor"] = v
                return e, f, v, out

        else:
            # build the energy and force functions
            def total_energy(variables, data):
                out = self.__apply(variables, data)
                atomic_energies = 0.0
                system_energies = 0.0
                species = out["species"]
                nsys = out["natoms"].shape[0]
                for term in self.energy_terms:
                    e = out[term]
                    if e.ndim > 1 and e.shape[-1] == 1:
                        e = jnp.squeeze(e, axis=-1)
                    # per-system terms are accumulated separately from per-atom terms
                    if e.shape[0] == nsys and nsys != species.shape[0]:
                        system_energies += e
                        continue
                    assert e.shape == species.shape
                    atomic_energies += e
                if isinstance(atomic_energies, jnp.ndarray):
                    if "true_atoms" in out:
                        # zero out contributions from padding atoms
                        atomic_energies = jnp.where(
                            out["true_atoms"], atomic_energies, 0.0
                        )
                    out["atomic_energies"] = atomic_energies
                    energies = jax.ops.segment_sum(
                        atomic_energies,
                        data["batch_index"],
                        num_segments=len(data["natoms"]),
                    )
                else:
                    energies = 0.0

                if isinstance(system_energies, jnp.ndarray):
                    if "true_sys" in out:
                        # zero out contributions from padding systems
                        system_energies = jnp.where(
                            out["true_sys"], system_energies, 0.0
                        )
                    out["system_energies"] = system_energies

                out["total_energy"] = energies + system_energies
                return out["total_energy"], out

            def energy_and_forces(variables, data):
                def _etot(variables, coordinates):
                    energy, out = total_energy(
                        variables, {**data, "coordinates": coordinates}
                    )
                    return energy.sum(), out

                de, out = jax.grad(_etot, argnums=1, has_aux=True)(
                    variables, data["coordinates"]
                )
                out["forces"] = -de

                return out["total_energy"], out["forces"], out

            # def energy_and_forces_and_virial(variables, data):
            #     x = data["coordinates"]
            #     batch_index = data["batch_index"]
            #     if "cells" in data:
            #         cells = data["cells"]
            #         ## cells is a nbatch x 3 x 3 matrix whose rows are cell vectors
            #         ## (i.e. cells[0, 0, :] is the first cell vector of the first system)

            #         def _etot(variables, coordinates, cells):
            #             reciprocal_cells = jnp.linalg.inv(cells)
            #             energy, out = total_energy(
            #                 variables,
            #                 {
            #                     **data,
            #                     "coordinates": coordinates,
            #                     "cells": cells,
            #                     "reciprocal_cells": reciprocal_cells,
            #                 },
            #             )
            #             return energy.sum(), out

            #         (dedx, dedcells), out = jax.grad(_etot, argnums=(1, 2), has_aux=True)(
            #             variables, x, cells
            #         )
            #         f = -dedx
            #         out["forces"] = f
            #     else:
            #         _, f, out = energy_and_forces(variables, data)

            #     vir = -jax.ops.segment_sum(
            #         f[:, :, None] * x[:, None, :],
            #         batch_index,
            #         num_segments=len(data["natoms"]),
            #     )

            #     if "cells" in data:
            #         # dvir = jax.vmap(jnp.matmul)(dedcells, cells.transpose(0, 2, 1))
            #         dvir = jnp.einsum("...ki,...kj->...ij", dedcells, cells)
            #         nsys = data["natoms"].shape[0]
            #         if cells.shape[0] == 1 and nsys > 1:
            #             dvir = dvir / nsys
            #         vir = vir + dvir

            #     out["virial_tensor"] = vir

            #     return out["total_energy"], f, vir, out

            def energy_and_forces_and_virial(variables, data):
                # the virial is obtained by differentiating the energy with respect
                # to a uniform scaling (deformation) applied to both coordinates and
                # cell vectors, starting from the identity scaling
                x = data["coordinates"]
                scaling = jnp.asarray(
                    np.eye(3)[None, :, :].repeat(data["natoms"].shape[0], axis=0)
                )

                def _etot(variables, coordinates, scaling):
                    batch_index = data["batch_index"]
                    coordinates = jax.vmap(jnp.matmul)(
                        coordinates, scaling[batch_index]
                    )
                    inputs = {**data, "coordinates": coordinates}
                    if "cells" in data:
                        # cells is a nbatch x 3 x 3 matrix whose rows are cell vectors
                        # (i.e. cells[0, 0, :] is the first cell vector of the first system)
                        cells = jax.vmap(jnp.matmul)(data["cells"], scaling)
                        reciprocal_cells = jnp.linalg.inv(cells)
                        inputs["cells"] = cells
                        inputs["reciprocal_cells"] = reciprocal_cells
                    energy, out = total_energy(variables, inputs)
                    return energy.sum(), out

                (dedx, vir), out = jax.grad(_etot, argnums=(1, 2), has_aux=True)(
                    variables, x, scaling
                )
                f = -dedx
                out["forces"] = f
                out["virial_tensor"] = vir

                return out["total_energy"], f, vir, out

        object.__setattr__(self, "_total_energy_raw", total_energy)
        if jit:
            object.__setattr__(self, "_total_energy", jax.jit(total_energy))
            object.__setattr__(self, "_energy_and_forces", jax.jit(energy_and_forces))
            object.__setattr__(
                self,
                "_energy_and_forces_and_virial",
                jax.jit(energy_and_forces_and_virial),
            )
        else:
            object.__setattr__(self, "_total_energy", total_energy)
            object.__setattr__(self, "_energy_and_forces", energy_and_forces)
            object.__setattr__(
                self, "_energy_and_forces_and_virial", energy_and_forces_and_virial
            )

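    # Illustrative sketch of switching the PES definition after construction
    # (the term name "energies" is hypothetical and must be a key of the model output):
    #
    #   model.set_energy_terms(["energies"])  # sum the "energies" output into total_energy
    #   model.set_energy_terms(None)          # disable the PES (total energy is always zero)
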
    def get_gradient_function(
        self,
        *gradient_keys: str,
        jit: bool = True,
        variables_as_input: bool = False,
    ):
        """Return a function that computes the energy and the gradient of the energy with respect to the keys in `gradient_keys`."""

        def _energy_gradient(variables, data):
            def _etot(variables, inputs):
                if "cells" in inputs:
                    reciprocal_cells = jnp.linalg.inv(inputs["cells"])
                    inputs = {**inputs, "reciprocal_cells": reciprocal_cells}
                energy, out = self._total_energy_raw(variables, {**data, **inputs})
                return energy.sum(), out

            inputs = {k: data[k] for k in gradient_keys}
            de, out = jax.grad(_etot, argnums=1, has_aux=True)(variables, inputs)

            return (
                out["total_energy"],
                de,
                {**out, **{"dEd_" + k: de[k] for k in gradient_keys}},
            )

        if variables_as_input:
            energy_gradient = _energy_gradient
        else:

            def energy_gradient(data):
                return _energy_gradient(self.variables, data)

        if jit:
            return jax.jit(energy_gradient)
        else:
            return energy_gradient

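    # Illustrative sketch (assumes the inputs have already been preprocessed, e.g.
    # data = model.preprocess(**inputs), and that "coordinates" is present in data):
    #
    #   grad_fn = model.get_gradient_function("coordinates")
    #   etot, de, out = grad_fn(data)
    #   forces = -out["dEd_coordinates"]
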
    def preprocess(self, use_gpu=False, verbose=False, **inputs) -> Dict[str, Any]:
        """Apply preprocessing to the input data.

        !!! This is not a pure function => do not apply jax transforms !!!"""
        if self.preproc_state is None:
            out, _ = self.reinitialize_preprocessing(example_data=inputs)
            preproc_state = self.preproc_state
        elif use_gpu:
            do_check_input = self.preproc_state.get("check_input", True)
            if do_check_input:
                inputs = check_input(inputs)
            preproc_state, inputs = self.preprocessing.atom_padding(
                self.preproc_state, inputs
            )
            inputs = self.preprocessing.process(preproc_state, inputs)
            preproc_state, state_up, out, overflow = self.preprocessing.check_reallocate(
                preproc_state, inputs
            )
            if verbose and overflow:
                print("GPU preprocessing: nblist overflow => reallocating nblist")
                print("size updates:", state_up)
        else:
            preproc_state, out = self.preprocessing(self.preproc_state, inputs)

        object.__setattr__(self, "preproc_state", preproc_state)
        return out

    def reinitialize_preprocessing(
        self, rng_key: Optional[jax.random.PRNGKey] = None, example_data=None
    ) -> Tuple[Dict[str, Any], Optional[jax.random.PRNGKey]]:
        """Reinitialize the preprocessing state, generating a dummy system if no example data is provided."""
        if rng_key is None:
            rng_key_pre = jax.random.PRNGKey(0)
        else:
            rng_key, rng_key_pre = jax.random.split(rng_key)

        if example_data is None:
            rng_key_sys, rng_key_pre = jax.random.split(rng_key_pre)
            example_data = self.generate_dummy_system(rng_key_sys, n_atoms=10)

        preproc_state, inputs = self.preprocessing.init_with_output(example_data)
        object.__setattr__(self, "preproc_state", preproc_state)
        return inputs, rng_key

    def __call__(
        self, variables: Optional[dict] = None, gpu_preprocessing=False, **inputs
    ) -> Dict[str, Any]:
        """Apply the FENNIX model (preprocess + modules).

        !!! This is not a pure function => do not apply jax transforms !!!
        If you want to apply jax transforms, use self._apply(variables, inputs), which is pure,
        and preprocess the inputs with self.preprocess.
        """
        if variables is None:
            variables = self.variables
        inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
        output = self._apply(variables, inputs)
        if self.use_atom_padding:
            output = atom_unpadding(output)
        return output

    def total_energy(
        self,
        variables: Optional[dict] = None,
        unit: Optional[Union[float, str]] = None,
        gpu_preprocessing=False,
        **inputs,
    ) -> Tuple[jnp.ndarray, Dict]:
        """Compute the total energy of the system.

        !!! This is not a pure function => do not apply jax transforms !!!
        If you want to apply jax transforms, use self._total_energy(variables, inputs), which is pure,
        and preprocess the inputs with self.preprocess.
        """
        if variables is None:
            variables = self.variables
        inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
        _, output = self._total_energy(variables, inputs)
        if self.use_atom_padding:
            output = atom_unpadding(output)
        e = output["total_energy"]
        if unit is not None:
            model_energy_unit = self.Ha_to_model_energy
            if isinstance(unit, str):
                unit = au.get_multiplier(unit)
            e = e * (unit / model_energy_unit)
        return e, output

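    # Illustrative sketch of the output unit conversion (assumes "kcal/mol" is a
    # unit string known to fennol.utils.AtomicUnits):
    #
    #   e_model, out = model.total_energy(**inputs)                  # in the model energy unit
    #   e_kcal, out = model.total_energy(unit="kcal/mol", **inputs)  # converted on output
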
    def energy_and_forces(
        self,
        variables: Optional[dict] = None,
        unit: Optional[Union[float, str]] = None,
        gpu_preprocessing=False,
        **inputs,
    ) -> Tuple[jnp.ndarray, jnp.ndarray, Dict]:
        """Compute the total energy and forces of the system.

        !!! This is not a pure function => do not apply jax transforms !!!
        If you want to apply jax transforms, use self._energy_and_forces(variables, inputs), which is pure,
        and preprocess the inputs with self.preprocess.
        """
        if variables is None:
            variables = self.variables
        inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
        _, _, output = self._energy_and_forces(variables, inputs)
        if self.use_atom_padding:
            output = atom_unpadding(output)
        e = output["total_energy"]
        f = output["forces"]
        if unit is not None:
            model_energy_unit = self.Ha_to_model_energy
            if isinstance(unit, str):
                unit = au.get_multiplier(unit)
            e = e * (unit / model_energy_unit)
            f = f * (unit / model_energy_unit)
        return e, f, output

    def energy_and_forces_and_virial(
        self,
        variables: Optional[dict] = None,
        unit: Optional[Union[float, str]] = None,
        gpu_preprocessing=False,
        **inputs,
    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, Dict]:
        """Compute the total energy, forces and virial tensor of the system.

        !!! This is not a pure function => do not apply jax transforms !!!
        If you want to apply jax transforms, use self._energy_and_forces_and_virial(variables, inputs),
        which is pure, and preprocess the inputs with self.preprocess.
        """
        if variables is None:
            variables = self.variables
        inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
        _, _, _, output = self._energy_and_forces_and_virial(variables, inputs)
        if self.use_atom_padding:
            output = atom_unpadding(output)
        e = output["total_energy"]
        f = output["forces"]
        v = output["virial_tensor"]
        if unit is not None:
            model_energy_unit = self.Ha_to_model_energy
            if isinstance(unit, str):
                unit = au.get_multiplier(unit)
            e = e * (unit / model_energy_unit)
            f = f * (unit / model_energy_unit)
            v = v * (unit / model_energy_unit)
        return e, f, v, output

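    # Illustrative sketch (the virial tensor is returned per system with shape
    # (nsys, 3, 3), in the same energy unit as the total energy):
    #
    #   e, f, v, out = model.energy_and_forces_and_virial(cells=cells, **inputs)
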
    def remove_atom_padding(self, output):
        """Remove atom padding from the output."""
        return atom_unpadding(output)

    def get_model(self) -> Tuple[FENNIXModules, Dict]:
        """Return the model and its variables."""
        return self.modules, self.variables

    def get_preprocessing(self) -> Tuple[PreprocessingChain, Dict]:
        """Return the preprocessing chain and its state."""
        return self.preprocessing, self.preproc_state

    def __setattr__(self, __name: str, __value: Any) -> None:
        if __name == "variables":
            if __value is None:
                raise ValueError(f"{__name} cannot be None")
            if not isinstance(__value, (dict, OrderedDict, FrozenDict)):
                raise ValueError(f"{__name} must be a dict")
            object.__setattr__(self, __name, JaxConverter()(__value))
        elif __name == "preproc_state":
            if __value is None:
                raise ValueError(f"{__name} cannot be None")
            if not isinstance(__value, (dict, OrderedDict, FrozenDict)):
                raise ValueError(f"{__name} must be a dict")
            object.__setattr__(self, __name, freeze(JaxConverter()(__value)))
        elif self._initializing:
            object.__setattr__(self, __name, __value)
        else:
            raise ValueError(f"{__name} attribute of FENNIX model is immutable.")

    def generate_dummy_system(
        self, rng_key: jax.random.PRNGKey, box_size=None, n_atoms: int = 10
    ) -> Dict[str, Any]:
        """Generate a dummy system for initialization."""
        if box_size is None:
            box_size = 2 * self.cutoff
            for g in self._graphs_properties.values():
                cutoff = g["cutoff"]
                if cutoff is not None:
                    box_size = min(box_size, 2 * cutoff)
        coordinates = np.array(
            jax.random.uniform(rng_key, (n_atoms, 3), maxval=box_size), dtype=np.float64
        )
        species = np.ones((n_atoms,), dtype=np.int32)
        batch_index = np.zeros((n_atoms,), dtype=np.int32)
        natoms = np.array([n_atoms], dtype=np.int32)
        return {
            "species": species,
            "coordinates": coordinates,
            "batch_index": batch_index,
            "natoms": natoms,
        }

    def summarize(
        self, rng_key: jax.random.PRNGKey = None, example_data=None, **kwargs
    ) -> str:
        """Summarize the model architecture and parameters."""
        if rng_key is None:
            rng_key = jax.random.PRNGKey(0)
        head = "Summarizing with example data:\n"
        if example_data is None:
            head = "Summarizing with dummy 10 atoms system:\n"
            rng_key, rng_key_sys = jax.random.split(rng_key)
            example_data = self.generate_dummy_system(rng_key_sys, n_atoms=10)
        rng_key, rng_key_pre = jax.random.split(rng_key)
        _, inputs = self.preprocessing.init_with_output(example_data)
        return head + nn.tabulate(self.modules, rng_key, **kwargs)(inputs)

    def to_dict(self):
        """Return a dictionary representation of the model."""
        return {
            **self._input_args,
            "energy_terms": self.energy_terms,
            "variables": deepcopy(self.variables),
        }

    def save(self, filename):
        """Save the model to a file."""
        state_dict = self.to_dict()
        # convert the ordered mappings to lists of pairs for msgpack serialization
        state_dict["preprocessing"] = [
            [k, v] for k, v in state_dict["preprocessing"].items()
        ]
        state_dict["modules"] = [[k, v] for k, v in state_dict["modules"].items()]
        with open(filename, "wb") as f:
            f.write(serialization.msgpack_serialize(state_dict))

    @classmethod
    def load(
        cls,
        filename,
        use_atom_padding=False,
        graph_config={},
    ):
        """Load a model from a file."""
        with open(filename, "rb") as f:
            state_dict = serialization.msgpack_restore(f.read())
        state_dict["preprocessing"] = {k: v for k, v in state_dict["preprocessing"]}
        state_dict["modules"] = {k: v for k, v in state_dict["modules"]}
        return cls(
            **state_dict,
            graph_config=graph_config,
            use_atom_padding=use_atom_padding,
        )
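
Example usage of the serialization round trip (an illustrative sketch; "model.fnx" is a
hypothetical file name and `inputs` stands for a dict with species, coordinates,
batch_index and natoms arrays):

    model.save("model.fnx")
    model2 = FENNIX.load("model.fnx", use_atom_padding=True)
    e, f, out = model2.energy_and_forces(**inputs)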
@dataclasses.dataclass
class FENNIX:
 25@dataclasses.dataclass
 26class FENNIX:
 27    """
 28    Static wrapper for FENNIX models
 29
 30    The underlying model is a `fennol.models.modules.FENNIXModules` built from the `modules` dictionary
 31    which references registered modules in `fennol.models.modules.MODULES` and provides the parameters for initialization.
 32
 33    Since the model is static and contains variables, it must be initialized right away with either
 34    `example_data`, `variables` or `rng_key`. If `variables` is provided, it is used directly. If `example_data`
 35    is provided, the model is initialized with `example_data` and the resulting variables are stored
 36    in the wrapper. If only `rng_key` is provided, the model is initialized with a dummy system and the resulting.
 37    """
 38
 39    cutoff: Union[float, None]
 40    modules: FENNIXModules
 41    variables: Dict
 42    preprocessing: PreprocessingChain
 43    _apply: Callable[[Dict, Dict], Dict]
 44    _total_energy: Callable[[Dict, Dict], Tuple[jnp.ndarray, Dict]]
 45    _energy_and_forces: Callable[[Dict, Dict], Tuple[jnp.ndarray, jnp.ndarray, Dict]]
 46    _input_args: Dict
 47    _graphs_properties: Dict
 48    preproc_state: Dict
 49    energy_terms: Optional[Sequence[str]] = None
 50    _initializing: bool = True
 51    use_atom_padding: bool = False
 52
 53    def __init__(
 54        self,
 55        cutoff: float,
 56        modules: OrderedDict,
 57        preprocessing: OrderedDict = OrderedDict(),
 58        example_data=None,
 59        rng_key: Optional[jax.random.PRNGKey] = None,
 60        variables: Optional[dict] = None,
 61        energy_terms: Optional[Sequence[str]] = None,
 62        use_atom_padding: bool = False,
 63        graph_config: Dict = {},
 64        energy_unit: str = "Ha",
 65        **kwargs,
 66    ) -> None:
 67        """Initialize the FENNIX model
 68
 69        Arguments:
 70        ----------
 71        cutoff: float
 72            The cutoff radius for the model
 73        modules: OrderedDict
 74            The dictionary defining the sequence of FeNNol modules and their parameters.
 75        preprocessing: OrderedDict
 76            The dictionary defining the sequence of preprocessing modules and their parameters.
 77        example_data: dict
 78            Example data to initialize the model. If not provided, a dummy system is generated.
 79        rng_key: jax.random.PRNGKey
 80            The random key to initialize the model. If not provided, jax.random.PRNGKey(0) is used (should be avoided).
 81        variables: dict
 82            The variables of the model (i.e. weights, biases and all other tunable parameters).
 83            If not provided, the variables are initialized (usually at random)
 84        energy_terms: Sequence[str]
 85            The energy terms in the model output that will be summed to compute the total energy.
 86            If None, the total energy is always zero (useful for non-PES models).
 87        use_atom_padding: bool
 88            If True, the model will use atom padding for the input data.
 89            This is useful when one plans to frequently change the number of atoms in the system (for example during training).
 90        graph_config: dict
 91            Edit the graph configuration. Mostly used to change a long-range cutoff as a function of a simulation box size.
 92
 93        """
 94        self._input_args = {
 95            "cutoff": cutoff,
 96            "modules": OrderedDict(modules),
 97            "preprocessing": OrderedDict(preprocessing),
 98            "energy_unit": energy_unit,
 99        }
100        self.energy_unit = energy_unit
101        self.Ha_to_model_energy = au.get_multiplier(energy_unit)
102        self.cutoff = cutoff
103        self.energy_terms = energy_terms
104        self.use_atom_padding = use_atom_padding
105
106        # add non-differentiable/non-jittable modules
107        preprocessing = deepcopy(preprocessing)
108        if cutoff is None:
109            preprocessing_modules = []
110        else:
111            prep_keys = list(preprocessing.keys())
112            graph_params = {"cutoff": cutoff, "graph_key": "graph"}
113            if len(prep_keys) > 0 and prep_keys[0] == "graph":
114                graph_params = {
115                    **graph_params,
116                    **preprocessing.pop("graph"),
117                }
118            graph_params = {**graph_params, **graph_config}
119
120            preprocessing_modules = [
121                GraphGenerator(**graph_params),
122            ]
123
124        for name, params in preprocessing.items():
125            key = str(params.pop("module_name")) if "module_name" in params else name
126            key = str(params.pop("FID")) if "FID" in params else key
127            mod = PREPROCESSING[key.upper()](**freeze(params))
128            preprocessing_modules.append(mod)
129
130        self.preprocessing = PreprocessingChain(
131            tuple(preprocessing_modules), use_atom_padding
132        )
133        graphs_properties = self.preprocessing.get_graphs_properties()
134        self._graphs_properties = freeze(graphs_properties)
135        # add preprocessing modules that should be differentiated/jitted
136        mods = [(JaxConverter, {})] + self.preprocessing.get_processors()
137        # mods = self.preprocessing.get_processors(return_list=True)
138
139        # build the model
140        modules = deepcopy(modules)
141        modules_names = []
142        for name, params in modules.items():
143            key = str(params.pop("module_name")) if "module_name" in params else name
144            key = str(params.pop("FID")) if "FID" in params else key
145            if name in modules_names:
146                raise ValueError(f"Module {name} already exists")
147            modules_names.append(name)
148            params["name"] = name
149            mod = MODULES[key.upper()]
150            fields = [f.name for f in dataclasses.fields(mod)]
151            if "_graphs_properties" in fields:
152                params["_graphs_properties"] = graphs_properties
153            if "_energy_unit" in fields:
154                params["_energy_unit"] = energy_unit
155            mods.append((mod, params))
156
157        self.modules = FENNIXModules(mods)
158
159        self.__apply = self.modules.apply
160        self._apply = jax.jit(self.modules.apply)
161
162        self.set_energy_terms(energy_terms)
163
164        # initialize the model
165
166        inputs, rng_key = self.reinitialize_preprocessing(rng_key, example_data)
167
168        if variables is not None:
169            self.variables = variables
170        elif rng_key is not None:
171            self.variables = self.modules.init(rng_key, inputs)
172        else:
173            raise ValueError(
174                "Either variables or a jax.random.PRNGKey must be provided for initialization"
175            )
176
177        self._initializing = False
178
179    def set_energy_terms(
180        self, energy_terms: Union[Sequence[str], None], jit: bool = True
181    ) -> None:
182        """Set the energy terms to be computed by the model and prepare the energy and force functions."""
183        object.__setattr__(self, "energy_terms", energy_terms)
184        if energy_terms is None or len(energy_terms) == 0:
185
186            def total_energy(variables, data):
187                out = self.__apply(variables, data)
188                coords = out["coordinates"]
189                nsys = out["natoms"].shape[0]
190                nat = coords.shape[0]
191                dtype = coords.dtype
192                e = jnp.zeros(nsys, dtype=dtype)
193                eat = jnp.zeros(nat, dtype=dtype)
194                out["total_energy"] = e
195                out["atomic_energies"] = eat
196                return e, out
197
198            def energy_and_forces(variables, data):
199                e, out = total_energy(variables, data)
200                f = jnp.zeros_like(out["coordinates"])
201                out["forces"] = f
202                return e, f, out
203
204            def energy_and_forces_and_virial(variables, data):
205                e, f, out = energy_and_forces(variables, data)
206                v = jnp.zeros(
207                    (out["natoms"].shape[0], 3, 3), dtype=out["coordinates"].dtype
208                )
209                out["virial_tensor"] = v
210                return e, f, v, out
211
212        else:
213            # build the energy and force functions
214            def total_energy(variables, data):
215                out = self.__apply(variables, data)
216                atomic_energies = 0.0
217                system_energies = 0.0
218                species = out["species"]
219                nsys = out["natoms"].shape[0]
220                for term in self.energy_terms:
221                    e = out[term]
222                    if e.ndim > 1 and e.shape[-1] == 1:
223                        e = jnp.squeeze(e, axis=-1)
224                    if e.shape[0] == nsys and nsys != species.shape[0]:
225                        system_energies += e
226                        continue
227                    assert e.shape == species.shape
228                    atomic_energies += e
229                # atomic_energies = jnp.squeeze(atomic_energies, axis=-1)
230                if isinstance(atomic_energies, jnp.ndarray):
231                    if "true_atoms" in out:
232                        atomic_energies = jnp.where(
233                            out["true_atoms"], atomic_energies, 0.0
234                        )
235                    out["atomic_energies"] = atomic_energies
236                    energies = jax.ops.segment_sum(
237                        atomic_energies,
238                        data["batch_index"],
239                        num_segments=len(data["natoms"]),
240                    )
241                else:
242                    energies = 0.0
243
244                if isinstance(system_energies, jnp.ndarray):
245                    if "true_sys" in out:
246                        system_energies = jnp.where(
247                            out["true_sys"], system_energies, 0.0
248                        )
249                    out["system_energies"] = system_energies
250
251                out["total_energy"] = energies + system_energies
252                return out["total_energy"], out
253
254            def energy_and_forces(variables, data):
255                def _etot(variables, coordinates):
256                    energy, out = total_energy(
257                        variables, {**data, "coordinates": coordinates}
258                    )
259                    return energy.sum(), out
260
261                de, out = jax.grad(_etot, argnums=1, has_aux=True)(
262                    variables, data["coordinates"]
263                )
264                out["forces"] = -de
265
266                return out["total_energy"], out["forces"], out
267
268            # def energy_and_forces_and_virial(variables, data):
269            #     x = data["coordinates"]
270            #     batch_index = data["batch_index"]
271            #     if "cells" in data:
272            #         cells = data["cells"]
273            #         ## cells is a nbatchx3x3 matrix which lines are cell vectors (i.e. cells[0,0,:] is the first cell vector of the first system)
274
275            #         def _etot(variables, coordinates, cells):
276            #             reciprocal_cells = jnp.linalg.inv(cells)
277            #             energy, out = total_energy(
278            #                 variables,
279            #                 {
280            #                     **data,
281            #                     "coordinates": coordinates,
282            #                     "cells": cells,
283            #                     "reciprocal_cells": reciprocal_cells,
284            #                 },
285            #             )
286            #             return energy.sum(), out
287
288            #         (dedx, dedcells), out = jax.grad(_etot, argnums=(1, 2), has_aux=True)(
289            #             variables, x, cells
290            #         )
291            #         f= -dedx
292            #         out["forces"] = f
293            #     else:
294            #         _,f,out = energy_and_forces(variables, data)
295
296            #     vir = -jax.ops.segment_sum(
297            #         f[:, :, None] * x[:, None, :],
298            #         batch_index,
299            #         num_segments=len(data["natoms"]),
300            #     )
301
302            #     if "cells" in data:
303            #         # dvir = jax.vmap(jnp.matmul)(dedcells, cells.transpose(0, 2, 1))
304            #         dvir = jnp.einsum("...ki,...kj->...ij", dedcells, cells)
305            #         nsys = data["natoms"].shape[0]
306            #         if cells.shape[0]==1 and nsys>1:
307            #             dvir = dvir / nsys
308            #         vir = vir + dvir
309
310            #     out["virial_tensor"] = vir
311
312            #     return out["total_energy"], f, vir, out
313
314            def energy_and_forces_and_virial(variables, data):
315                x = data["coordinates"]
316                scaling = jnp.asarray(
317                    np.eye(3)[None, :, :].repeat(data["natoms"].shape[0], axis=0)
318                )
319                def _etot(variables, coordinates, scaling):
320                    batch_index = data["batch_index"]
321                    coordinates = jax.vmap(jnp.matmul)(
322                        coordinates, scaling[batch_index]
323                    )
324                    inputs = {**data, "coordinates": coordinates}
325                    if "cells" in data:
326                        ## cells is a nbatchx3x3 matrix which lines are cell vectors (i.e. cells[0,0,:] is the first cell vector of the first system)
327                        cells = jax.vmap(jnp.matmul)(data["cells"], scaling)
328                        reciprocal_cells = jnp.linalg.inv(cells)
329                        inputs["cells"] = cells
330                        inputs["reciprocal_cells"] = reciprocal_cells
331                    energy, out = total_energy(variables, inputs)
332                    return energy.sum(), out
333
334                (dedx, vir), out = jax.grad(_etot, argnums=(1, 2), has_aux=True)(
335                    variables, x, scaling
336                )
337                f = -dedx
338                out["forces"] = f
339                out["virial_tensor"] = vir
340
341                return out["total_energy"], f, vir, out
342
343        object.__setattr__(self, "_total_energy_raw", total_energy)
344        if jit:
345            object.__setattr__(self, "_total_energy", jax.jit(total_energy))
346            object.__setattr__(self, "_energy_and_forces", jax.jit(energy_and_forces))
347            object.__setattr__(
348                self,
349                "_energy_and_forces_and_virial",
350                jax.jit(energy_and_forces_and_virial),
351            )
352        else:
353            object.__setattr__(self, "_total_energy", total_energy)
354            object.__setattr__(self, "_energy_and_forces", energy_and_forces)
355            object.__setattr__(
356                self, "_energy_and_forces_and_virial", energy_and_forces_and_virial
357            )
358
359    def get_gradient_function(
360        self,
361        *gradient_keys: Sequence[str],
362        jit: bool = True,
363        variables_as_input: bool = False,
364    ):
365        """Return a function that computes the energy and the gradient of the energy with respect to the keys in gradient_keys"""
366
367        def _energy_gradient(variables, data):
368            def _etot(variables, inputs):
369                if "cells" in inputs:
370                    reciprocal_cells = jnp.linalg.inv(inputs["cells"])
371                    inputs = {**inputs, "reciprocal_cells": reciprocal_cells}
372                energy, out = self._total_energy_raw(variables, {**data, **inputs})
373                return energy.sum(), out
374
375            inputs = {k: data[k] for k in gradient_keys}
376            de, out = jax.grad(_etot, argnums=1, has_aux=True)(variables, inputs)
377
378            return (
379                out["total_energy"],
380                de,
381                {**out, **{"dEd_" + k: de[k] for k in gradient_keys}},
382            )
383
384        if variables_as_input:
385            energy_gradient = _energy_gradient
386        else:
387
388            def energy_gradient(data):
389                return _energy_gradient(self.variables, data)
390
391        if jit:
392            return jax.jit(energy_gradient)
393        else:
394            return energy_gradient
395
396    def preprocess(self,use_gpu=False, verbose=False,**inputs) -> Dict[str, Any]:
397        """apply preprocessing to the input data
398
399        !!! This is not a pure function => do not apply jax transforms !!!"""
400        if self.preproc_state is None:
401            out, _ = self.reinitialize_preprocessing(example_data=inputs)
402        elif use_gpu:
403            do_check_input = self.preproc_state.get("check_input", True)
404            if do_check_input:
405                inputs = check_input(inputs)
406            preproc_state, inputs = self.preprocessing.atom_padding(
407                self.preproc_state, inputs
408            )
409            inputs = self.preprocessing.process(preproc_state, inputs)
410            preproc_state, state_up, out, overflow = (
411                self.preprocessing.check_reallocate(
412                    preproc_state, inputs
413                )
414            )
415            if verbose and overflow:
416                print("GPU preprocessing: nblist overflow => reallocating nblist")
417                print("size updates:", state_up)
418        else:
419            preproc_state, out = self.preprocessing(self.preproc_state, inputs)
420
421        object.__setattr__(self, "preproc_state", preproc_state)
422        return out
423
424    def reinitialize_preprocessing(
425        self, rng_key: Optional[jax.random.PRNGKey] = None, example_data=None
426    ) -> None:
427        ### TODO ###
428        if rng_key is None:
429            rng_key_pre = jax.random.PRNGKey(0)
430        else:
431            rng_key, rng_key_pre = jax.random.split(rng_key)
432
433        if example_data is None:
434            rng_key_sys, rng_key_pre = jax.random.split(rng_key_pre)
435            example_data = self.generate_dummy_system(rng_key_sys, n_atoms=10)
436
437        preproc_state, inputs = self.preprocessing.init_with_output(example_data)
438        object.__setattr__(self, "preproc_state", preproc_state)
439        return inputs, rng_key
440
441    def __call__(self, variables: Optional[dict] = None, gpu_preprocessing=False,**inputs) -> Dict[str, Any]:
442        """Apply the FENNIX model (preprocess + modules)
443
444        !!! This is not a pure function => do not apply jax transforms !!!
445        if you want to apply jax transforms, use  self._apply(variables, inputs) which is pure and preprocess the input using self.preprocess
446        """
447        if variables is None:
448            variables = self.variables
449        inputs = self.preprocess(use_gpu=gpu_preprocessing,**inputs)
450        output = self._apply(variables, inputs)
451        if self.use_atom_padding:
452            output = atom_unpadding(output)
453        return output
454
455    def total_energy(
456        self, variables: Optional[dict] = None, unit: Union[float, str] = None, gpu_preprocessing=False,**inputs
457    ) -> Tuple[jnp.ndarray, Dict]:
458        """compute the total energy of the system
459
460        !!! This is not a pure function => do not apply jax transforms !!!
461        if you want to apply jax transforms, use self._total_energy(variables, inputs) which is pure and preprocess the input using self.preprocess
462        """
463        if variables is None:
464            variables = self.variables
465        inputs = self.preprocess(use_gpu=gpu_preprocessing,**inputs)
466        # def print_shape(path,value):
467        #     if isinstance(value,jnp.ndarray):
468        #         print(path,value.shape)
469        #     else:
470        #         print(path,value)
471        # jax.tree_util.tree_map_with_path(print_shape,inputs)
472        _, output = self._total_energy(variables, inputs)
473        if self.use_atom_padding:
474            output = atom_unpadding(output)
475        e = output["total_energy"]
476        if unit is not None:
477            model_energy_unit = self.Ha_to_model_energy
478            if isinstance(unit, str):
479                unit = au.get_multiplier(unit)
480            e = e * (unit / model_energy_unit)
481        return e, output
482
483    def energy_and_forces(
484        self, variables: Optional[dict] = None, unit: Union[float, str] = None,gpu_preprocessing=False,**inputs
485    ) -> Tuple[jnp.ndarray, jnp.ndarray, Dict]:
486        """compute the total energy and forces of the system
487
488        !!! This is not a pure function => do not apply jax transforms !!!
489        if you want to apply jax transforms, use  self._energy_and_forces(variables, inputs) which is pure and preprocess the input using self.preprocess
490        """
491        if variables is None:
492            variables = self.variables
493        inputs = self.preprocess(use_gpu=gpu_preprocessing,**inputs)
494        _, _, output = self._energy_and_forces(variables, inputs)
495        if self.use_atom_padding:
496            output = atom_unpadding(output)
497        e = output["total_energy"]
498        f = output["forces"]
499        if unit is not None:
500            model_energy_unit = self.Ha_to_model_energy
501            if isinstance(unit, str):
502                unit = au.get_multiplier(unit)
503            e = e * (unit / model_energy_unit)
504            f = f * (unit / model_energy_unit)
505        return e, f, output
506
507    def energy_and_forces_and_virial(
508        self, variables: Optional[dict] = None, unit: Union[float, str] = None,gpu_preprocessing=False, **inputs
509    ) -> Tuple[jnp.ndarray, jnp.ndarray, Dict]:
510        """compute the total energy and forces of the system
511
512        !!! This is not a pure function => do not apply jax transforms !!!
513        if you want to apply jax transforms, use  self._energy_and_forces_and_virial(variables, inputs) which is pure and preprocess the input using self.preprocess
514        """
515        if variables is None:
516            variables = self.variables
517        inputs = self.preprocess(use_gpu=gpu_preprocessing,**inputs)
518        _, _, _, output = self._energy_and_forces_and_virial(variables, inputs)
519        if self.use_atom_padding:
520            output = atom_unpadding(output)
521        e = output["total_energy"]
522        f = output["forces"]
523        v = output["virial_tensor"]
524        if unit is not None:
525            model_energy_unit = self.Ha_to_model_energy
526            if isinstance(unit, str):
527                unit = au.get_multiplier(unit)
528            e = e * (unit / model_energy_unit)
529            f = f * (unit / model_energy_unit)
530            v = v * (unit / model_energy_unit)
531        return e, f, v, output
532
533    def remove_atom_padding(self, output):
534        """remove atom padding from the output"""
535        return atom_unpadding(output)
536
537    def get_model(self) -> Tuple[FENNIXModules, Dict]:
538        """return the model and its variables"""
539        return self.modules, self.variables
540
541    def get_preprocessing(self) -> Tuple[PreprocessingChain, Dict]:
542        """return the preprocessing chain and its state"""
543        return self.preprocessing, self.preproc_state
544
545    def __setattr__(self, __name: str, __value: Any) -> None:
546        if __name == "variables":
547            if __value is not None:
548                if not (
549                    isinstance(__value, dict)
550                    or isinstance(__value, OrderedDict)
551                    or isinstance(__value, FrozenDict)
552                ):
553                    raise ValueError(f"{__name} must be a dict")
554                object.__setattr__(self, __name, JaxConverter()(__value))
555            else:
556                raise ValueError(f"{__name} cannot be None")
557        elif __name == "preproc_state":
558            if __value is not None:
559                if not (
560                    isinstance(__value, dict)
561                    or isinstance(__value, OrderedDict)
562                    or isinstance(__value, FrozenDict)
563                ):
564                    raise ValueError(f"{__name} must be a FrozenDict")
565                object.__setattr__(self, __name, freeze(JaxConverter()(__value)))
566            else:
567                raise ValueError(f"{__name} cannot be None")
568
569        elif self._initializing:
570            object.__setattr__(self, __name, __value)
571        else:
572            raise ValueError(f"{__name} attribute of FENNIX model is immutable.")
573
574    def generate_dummy_system(
575        self, rng_key: jax.random.PRNGKey, box_size=None, n_atoms: int = 10
576    ) -> Dict[str, Any]:
577        """
578        Generate dummy system for initialization
579        """
580        if box_size is None:
581            box_size = 2 * self.cutoff
582            for g in self._graphs_properties.values():
583                cutoff = g["cutoff"]
584                if cutoff is not None:
585                    box_size = min(box_size, 2 * g["cutoff"])
586        coordinates = np.array(
587            jax.random.uniform(rng_key, (n_atoms, 3), maxval=box_size), dtype=np.float64
588        )
589        species = np.ones((n_atoms,), dtype=np.int32)
590        batch_index = np.zeros((n_atoms,), dtype=np.int32)
591        natoms = np.array([n_atoms], dtype=np.int32)
592        return {
593            "species": species,
594            "coordinates": coordinates,
595            # "graph": graph,
596            "batch_index": batch_index,
597            "natoms": natoms,
598        }
599
600    def summarize(
601        self, rng_key: jax.random.PRNGKey = None, example_data=None, **kwargs
602    ) -> str:
603        """Summarize the model architecture and parameters"""
604        if rng_key is None:
605            head = "Summarizing with example data:\n"
606            rng_key = jax.random.PRNGKey(0)
607        if example_data is None:
608            head = "Summarizing with dummy 10 atoms system:\n"
609            rng_key, rng_key_sys = jax.random.split(rng_key)
610            example_data = self.generate_dummy_system(rng_key_sys, n_atoms=10)
611        rng_key, rng_key_pre = jax.random.split(rng_key)
612        _, inputs = self.preprocessing.init_with_output(example_data)
613        return head + nn.tabulate(self.modules, rng_key, **kwargs)(inputs)
614
615    def to_dict(self):
616        """return a dictionary representation of the model"""
617        return {
618            **self._input_args,
619            "energy_terms": self.energy_terms,
620            "variables": deepcopy(self.variables),
621        }
622
623    def save(self, filename):
624        """save the model to a file"""
625        state_dict = self.to_dict()
626        state_dict["preprocessing"] = [
627            [k, v] for k, v in state_dict["preprocessing"].items()
628        ]
629        state_dict["modules"] = [[k, v] for k, v in state_dict["modules"].items()]
630        with open(filename, "wb") as f:
631            f.write(serialization.msgpack_serialize(state_dict))
632
633    @classmethod
634    def load(
635        cls,
636        filename,
637        use_atom_padding=False,
638        graph_config={},
639    ):
640        """load a model from a file"""
641        with open(filename, "rb") as f:
642            state_dict = serialization.msgpack_restore(f.read())
643        state_dict["preprocessing"] = {k: v for k, v in state_dict["preprocessing"]}
644        state_dict["modules"] = {k: v for k, v in state_dict["modules"]}
645        return cls(
646            **state_dict,
647            graph_config=graph_config,
648            use_atom_padding=use_atom_padding,
649        )

Static wrapper for FENNIX models

The underlying model is a fennol.models.modules.FENNIXModules built from the modules dictionary which references registered modules in fennol.models.modules.MODULES and provides the parameters for initialization.

Since the model is static and contains variables, it must be initialized right away with either example_data, variables or rng_key. If variables is provided, it is used directly. If example_data is provided, the model is initialized with example_data and the resulting variables are stored in the wrapper. If only rng_key is provided, the model is initialized with a dummy system and the resulting variables are stored in the wrapper.
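
A minimal construction sketch; "MY_ENERGY" is a hypothetical key standing in for an actual entry registered in fennol.models.modules.MODULES, and "dim" a made-up parameter of that module:

    import jax
    from collections import OrderedDict
    from fennol.models.fennix import FENNIX

    # "MY_ENERGY" is a placeholder for a key registered in MODULES;
    # its parameters depend on the chosen module.
    modules = OrderedDict(energy={"module_name": "MY_ENERGY", "dim": 64})

    model = FENNIX(
        cutoff=5.0,                      # neighbor-list cutoff radius
        modules=modules,
        rng_key=jax.random.PRNGKey(0),   # variables initialized at random
        energy_terms=["energy"],         # output key(s) summed into total_energy
    )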

FENNIX( cutoff: float, modules: collections.OrderedDict, preprocessing: collections.OrderedDict = OrderedDict(), example_data=None, rng_key: Optional[PRNGKey] = None, variables: Optional[dict] = None, energy_terms: Optional[Sequence[str]] = None, use_atom_padding: bool = False, graph_config: Dict = {}, energy_unit: str = 'Ha', **kwargs)
 53    def __init__(
 54        self,
 55        cutoff: float,
 56        modules: OrderedDict,
 57        preprocessing: OrderedDict = OrderedDict(),
 58        example_data=None,
 59        rng_key: Optional[jax.random.PRNGKey] = None,
 60        variables: Optional[dict] = None,
 61        energy_terms: Optional[Sequence[str]] = None,
 62        use_atom_padding: bool = False,
 63        graph_config: Dict = {},
 64        energy_unit: str = "Ha",
 65        **kwargs,
 66    ) -> None:
 67        """Initialize the FENNIX model
 68
 69        Arguments:
 70        ----------
 71        cutoff: float
 72            The cutoff radius for the model
 73        modules: OrderedDict
 74            The dictionary defining the sequence of FeNNol modules and their parameters.
 75        preprocessing: OrderedDict
 76            The dictionary defining the sequence of preprocessing modules and their parameters.
 77        example_data: dict
 78            Example data to initialize the model. If not provided, a dummy system is generated.
 79        rng_key: jax.random.PRNGKey
 80            The random key to initialize the model. If not provided, jax.random.PRNGKey(0) is used (should be avoided).
 81        variables: dict
 82            The variables of the model (i.e. weights, biases and all other tunable parameters).
 83            If not provided, the variables are initialized (usually at random)
 84        energy_terms: Sequence[str]
 85            The energy terms in the model output that will be summed to compute the total energy.
 86            If None, the total energy is always zero (useful for non-PES models).
 87        use_atom_padding: bool
 88            If True, the model will use atom padding for the input data.
 89            This is useful when one plans to frequently change the number of atoms in the system (for example during training).
 90        graph_config: dict
 91            Edit the graph configuration. Mostly used to change a long-range cutoff as a function of a simulation box size.
 92
 93        """
 94        self._input_args = {
 95            "cutoff": cutoff,
 96            "modules": OrderedDict(modules),
 97            "preprocessing": OrderedDict(preprocessing),
 98            "energy_unit": energy_unit,
 99        }
100        self.energy_unit = energy_unit
101        self.Ha_to_model_energy = au.get_multiplier(energy_unit)
102        self.cutoff = cutoff
103        self.energy_terms = energy_terms
104        self.use_atom_padding = use_atom_padding
105
106        # add non-differentiable/non-jittable modules
107        preprocessing = deepcopy(preprocessing)
108        if cutoff is None:
109            preprocessing_modules = []
110        else:
111            prep_keys = list(preprocessing.keys())
112            graph_params = {"cutoff": cutoff, "graph_key": "graph"}
113            if len(prep_keys) > 0 and prep_keys[0] == "graph":
114                graph_params = {
115                    **graph_params,
116                    **preprocessing.pop("graph"),
117                }
118            graph_params = {**graph_params, **graph_config}
119
120            preprocessing_modules = [
121                GraphGenerator(**graph_params),
122            ]
123
124        for name, params in preprocessing.items():
125            key = str(params.pop("module_name")) if "module_name" in params else name
126            key = str(params.pop("FID")) if "FID" in params else key
127            mod = PREPROCESSING[key.upper()](**freeze(params))
128            preprocessing_modules.append(mod)
129
130        self.preprocessing = PreprocessingChain(
131            tuple(preprocessing_modules), use_atom_padding
132        )
133        graphs_properties = self.preprocessing.get_graphs_properties()
134        self._graphs_properties = freeze(graphs_properties)
135        # add preprocessing modules that should be differentiated/jitted
136        mods = [(JaxConverter, {})] + self.preprocessing.get_processors()
137        # mods = self.preprocessing.get_processors(return_list=True)
138
139        # build the model
140        modules = deepcopy(modules)
141        modules_names = []
142        for name, params in modules.items():
143            key = str(params.pop("module_name")) if "module_name" in params else name
144            key = str(params.pop("FID")) if "FID" in params else key
145            if name in modules_names:
146                raise ValueError(f"Module {name} already exists")
147            modules_names.append(name)
148            params["name"] = name
149            mod = MODULES[key.upper()]
150            fields = [f.name for f in dataclasses.fields(mod)]
151            if "_graphs_properties" in fields:
152                params["_graphs_properties"] = graphs_properties
153            if "_energy_unit" in fields:
154                params["_energy_unit"] = energy_unit
155            mods.append((mod, params))
156
157        self.modules = FENNIXModules(mods)
158
159        self.__apply = self.modules.apply
160        self._apply = jax.jit(self.modules.apply)
161
162        self.set_energy_terms(energy_terms)
163
164        # initialize the model
165
166        inputs, rng_key = self.reinitialize_preprocessing(rng_key, example_data)
167
168        if variables is not None:
169            self.variables = variables
170        elif rng_key is not None:
171            self.variables = self.modules.init(rng_key, inputs)
172        else:
173            raise ValueError(
174                "Either variables or a jax.random.PRNGKey must be provided for initialization"
175            )
176
177        self._initializing = False

Initialize the FENNIX model

Arguments:

cutoff: float
    The cutoff radius for the model.
modules: OrderedDict
    The dictionary defining the sequence of FeNNol modules and their parameters.
preprocessing: OrderedDict
    The dictionary defining the sequence of preprocessing modules and their parameters.
example_data: dict
    Example data to initialize the model. If not provided, a dummy system is generated.
rng_key: jax.random.PRNGKey
    The random key to initialize the model. If not provided, jax.random.PRNGKey(0) is used (should be avoided).
variables: dict
    The variables of the model (i.e. weights, biases and all other tunable parameters).
    If not provided, the variables are initialized (usually at random).
energy_terms: Sequence[str]
    The energy terms in the model output that will be summed to compute the total energy.
    If None, the total energy is always zero (useful for non-PES models).
use_atom_padding: bool
    If True, the model will use atom padding for the input data.
    This is useful when one plans to frequently change the number of atoms in the system (for example during training).
graph_config: dict
    Edit the graph configuration. Mostly used to change a long-range cutoff as a function of a simulation box size.

cutoff: Optional[float]
variables: Dict
preproc_state: Dict
energy_terms: Optional[Sequence[str]] = None
use_atom_padding: bool = False
energy_unit
Ha_to_model_energy
def set_energy_terms(self, energy_terms: Optional[Sequence[str]], jit: bool = True) -> None:
179    def set_energy_terms(
180        self, energy_terms: Union[Sequence[str], None], jit: bool = True
181    ) -> None:
182        """Set the energy terms to be computed by the model and prepare the energy and force functions."""
183        object.__setattr__(self, "energy_terms", energy_terms)
184        if energy_terms is None or len(energy_terms) == 0:
185
186            def total_energy(variables, data):
187                out = self.__apply(variables, data)
188                coords = out["coordinates"]
189                nsys = out["natoms"].shape[0]
190                nat = coords.shape[0]
191                dtype = coords.dtype
192                e = jnp.zeros(nsys, dtype=dtype)
193                eat = jnp.zeros(nat, dtype=dtype)
194                out["total_energy"] = e
195                out["atomic_energies"] = eat
196                return e, out
197
198            def energy_and_forces(variables, data):
199                e, out = total_energy(variables, data)
200                f = jnp.zeros_like(out["coordinates"])
201                out["forces"] = f
202                return e, f, out
203
204            def energy_and_forces_and_virial(variables, data):
205                e, f, out = energy_and_forces(variables, data)
206                v = jnp.zeros(
207                    (out["natoms"].shape[0], 3, 3), dtype=out["coordinates"].dtype
208                )
209                out["virial_tensor"] = v
210                return e, f, v, out
211
212        else:
213            # build the energy and force functions
214            def total_energy(variables, data):
215                out = self.__apply(variables, data)
216                atomic_energies = 0.0
217                system_energies = 0.0
218                species = out["species"]
219                nsys = out["natoms"].shape[0]
220                for term in self.energy_terms:
221                    e = out[term]
222                    if e.ndim > 1 and e.shape[-1] == 1:
223                        e = jnp.squeeze(e, axis=-1)
224                    if e.shape[0] == nsys and nsys != species.shape[0]:
225                        system_energies += e
226                        continue
227                    assert e.shape == species.shape
228                    atomic_energies += e
229                # atomic_energies = jnp.squeeze(atomic_energies, axis=-1)
230                if isinstance(atomic_energies, jnp.ndarray):
231                    if "true_atoms" in out:
232                        atomic_energies = jnp.where(
233                            out["true_atoms"], atomic_energies, 0.0
234                        )
235                    out["atomic_energies"] = atomic_energies
236                    energies = jax.ops.segment_sum(
237                        atomic_energies,
238                        data["batch_index"],
239                        num_segments=len(data["natoms"]),
240                    )
241                else:
242                    energies = 0.0
243
244                if isinstance(system_energies, jnp.ndarray):
245                    if "true_sys" in out:
246                        system_energies = jnp.where(
247                            out["true_sys"], system_energies, 0.0
248                        )
249                    out["system_energies"] = system_energies
250
251                out["total_energy"] = energies + system_energies
252                return out["total_energy"], out
253
254            def energy_and_forces(variables, data):
255                def _etot(variables, coordinates):
256                    energy, out = total_energy(
257                        variables, {**data, "coordinates": coordinates}
258                    )
259                    return energy.sum(), out
260
261                de, out = jax.grad(_etot, argnums=1, has_aux=True)(
262                    variables, data["coordinates"]
263                )
264                out["forces"] = -de
265
266                return out["total_energy"], out["forces"], out
267
268            # def energy_and_forces_and_virial(variables, data):
269            #     x = data["coordinates"]
270            #     batch_index = data["batch_index"]
271            #     if "cells" in data:
272            #         cells = data["cells"]
273            #         ## cells is a nbatchx3x3 matrix which lines are cell vectors (i.e. cells[0,0,:] is the first cell vector of the first system)
274
275            #         def _etot(variables, coordinates, cells):
276            #             reciprocal_cells = jnp.linalg.inv(cells)
277            #             energy, out = total_energy(
278            #                 variables,
279            #                 {
280            #                     **data,
281            #                     "coordinates": coordinates,
282            #                     "cells": cells,
283            #                     "reciprocal_cells": reciprocal_cells,
284            #                 },
285            #             )
286            #             return energy.sum(), out
287
288            #         (dedx, dedcells), out = jax.grad(_etot, argnums=(1, 2), has_aux=True)(
289            #             variables, x, cells
290            #         )
291            #         f= -dedx
292            #         out["forces"] = f
293            #     else:
294            #         _,f,out = energy_and_forces(variables, data)
295
296            #     vir = -jax.ops.segment_sum(
297            #         f[:, :, None] * x[:, None, :],
298            #         batch_index,
299            #         num_segments=len(data["natoms"]),
300            #     )
301
302            #     if "cells" in data:
303            #         # dvir = jax.vmap(jnp.matmul)(dedcells, cells.transpose(0, 2, 1))
304            #         dvir = jnp.einsum("...ki,...kj->...ij", dedcells, cells)
305            #         nsys = data["natoms"].shape[0]
306            #         if cells.shape[0]==1 and nsys>1:
307            #             dvir = dvir / nsys
308            #         vir = vir + dvir
309
310            #     out["virial_tensor"] = vir
311
312            #     return out["total_energy"], f, vir, out
313
314            def energy_and_forces_and_virial(variables, data):
315                x = data["coordinates"]
316                scaling = jnp.asarray(
317                    np.eye(3)[None, :, :].repeat(data["natoms"].shape[0], axis=0)
318                )
319                def _etot(variables, coordinates, scaling):
320                    batch_index = data["batch_index"]
321                    coordinates = jax.vmap(jnp.matmul)(
322                        coordinates, scaling[batch_index]
323                    )
324                    inputs = {**data, "coordinates": coordinates}
325                    if "cells" in data:
326                        ## cells is a nbatchx3x3 matrix which lines are cell vectors (i.e. cells[0,0,:] is the first cell vector of the first system)
327                        cells = jax.vmap(jnp.matmul)(data["cells"], scaling)
328                        reciprocal_cells = jnp.linalg.inv(cells)
329                        inputs["cells"] = cells
330                        inputs["reciprocal_cells"] = reciprocal_cells
331                    energy, out = total_energy(variables, inputs)
332                    return energy.sum(), out
333
334                (dedx, vir), out = jax.grad(_etot, argnums=(1, 2), has_aux=True)(
335                    variables, x, scaling
336                )
337                f = -dedx
338                out["forces"] = f
339                out["virial_tensor"] = vir
340
341                return out["total_energy"], f, vir, out
342
343        object.__setattr__(self, "_total_energy_raw", total_energy)
344        if jit:
345            object.__setattr__(self, "_total_energy", jax.jit(total_energy))
346            object.__setattr__(self, "_energy_and_forces", jax.jit(energy_and_forces))
347            object.__setattr__(
348                self,
349                "_energy_and_forces_and_virial",
350                jax.jit(energy_and_forces_and_virial),
351            )
352        else:
353            object.__setattr__(self, "_total_energy", total_energy)
354            object.__setattr__(self, "_energy_and_forces", energy_and_forces)
355            object.__setattr__(
356                self, "_energy_and_forces_and_virial", energy_and_forces_and_virial
357            )

Set the energy terms to be computed by the model and prepare the energy and force functions.
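
A short usage sketch; "energy" and "repulsion" are hypothetical output keys, any keys actually produced by the model's modules can be used:

    # rebuild the (jitted) energy/force/virial functions for new terms
    model.set_energy_terms(["energy", "repulsion"])
    # disable the PES entirely: total energy and forces become zero
    model.set_energy_terms(None)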

def get_gradient_function( self, *gradient_keys: Sequence[str], jit: bool = True, variables_as_input: bool = False):
359    def get_gradient_function(
360        self,
361        *gradient_keys: Sequence[str],
362        jit: bool = True,
363        variables_as_input: bool = False,
364    ):
365        """Return a function that computes the energy and the gradient of the energy with respect to the keys in gradient_keys"""
366
367        def _energy_gradient(variables, data):
368            def _etot(variables, inputs):
369                if "cells" in inputs:
370                    reciprocal_cells = jnp.linalg.inv(inputs["cells"])
371                    inputs = {**inputs, "reciprocal_cells": reciprocal_cells}
372                energy, out = self._total_energy_raw(variables, {**data, **inputs})
373                return energy.sum(), out
374
375            inputs = {k: data[k] for k in gradient_keys}
376            de, out = jax.grad(_etot, argnums=1, has_aux=True)(variables, inputs)
377
378            return (
379                out["total_energy"],
380                de,
381                {**out, **{"dEd_" + k: de[k] for k in gradient_keys}},
382            )
383
384        if variables_as_input:
385            energy_gradient = _energy_gradient
386        else:
387
388            def energy_gradient(data):
389                return _energy_gradient(self.variables, data)
390
391        if jit:
392            return jax.jit(energy_gradient)
393        else:
394            return energy_gradient

Return a function that computes the energy and the gradient of the energy with respect to the keys in gradient_keys
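
A hedged sketch: with the default variables_as_input=False, the returned function takes only a (preprocessed) input dictionary and returns the total energy, the raw gradient pytree, and the output dictionary augmented with "dEd_"-prefixed entries:

    inputs = model.preprocess(**system)    # system: any raw input dict
    grad_fn = model.get_gradient_function("coordinates")
    e, de, out = grad_fn(inputs)
    dedx = out["dEd_coordinates"]          # same array as de["coordinates"]
    forces = -dedx                         # forces are the negative gradient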

def preprocess(self, use_gpu=False, verbose=False, **inputs) -> Dict[str, Any]:
396    def preprocess(self, use_gpu=False, verbose=False, **inputs) -> Dict[str, Any]:
397        """apply preprocessing to the input data
398
399        !!! This is not a pure function => do not apply jax transforms !!!"""
400        if self.preproc_state is None:
401            out, _ = self.reinitialize_preprocessing(example_data=inputs)
402        elif use_gpu:
403            do_check_input = self.preproc_state.get("check_input", True)
404            if do_check_input:
405                inputs = check_input(inputs)
406            preproc_state, inputs = self.preprocessing.atom_padding(
407                self.preproc_state, inputs
408            )
409            inputs = self.preprocessing.process(preproc_state, inputs)
410            preproc_state, state_up, out, overflow = (
411                self.preprocessing.check_reallocate(
412                    preproc_state, inputs
413                )
414            )
415            if verbose and overflow:
416                print("GPU preprocessing: nblist overflow => reallocating nblist")
417                print("size updates:", state_up)
418        else:
419            preproc_state, out = self.preprocessing(self.preproc_state, inputs)
420
421        object.__setattr__(self, "preproc_state", preproc_state)
422        return out

apply preprocessing to the input data

!!! This is not a pure function => do not apply jax transforms !!!
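
A sketch of a direct call, mirroring the keys produced by generate_dummy_system (the water-like geometry and its units are purely illustrative):

    import numpy as np

    inputs = model.preprocess(
        species=np.array([8, 1, 1], dtype=np.int32),
        coordinates=np.array(
            [[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]]
        ),
        natoms=np.array([3], dtype=np.int32),
        batch_index=np.zeros(3, dtype=np.int32),
    )

With use_gpu=True, the neighbor list is processed on device and reallocated on overflow, as the verbose messages in the source above indicate.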

def reinitialize_preprocessing(self, rng_key: Optional[PRNGKey] = None, example_data=None) -> Tuple[Dict[str, Any], Optional[PRNGKey]]:
424    def reinitialize_preprocessing(
425        self, rng_key: Optional[jax.random.PRNGKey] = None, example_data=None
426    ) -> Tuple[Dict[str, Any], Optional[jax.random.PRNGKey]]:
427        """Reinitialize the preprocessing state (generating a dummy system if no example data is provided) and return the preprocessed inputs and the split rng_key."""
428        if rng_key is None:
429            rng_key_pre = jax.random.PRNGKey(0)
430        else:
431            rng_key, rng_key_pre = jax.random.split(rng_key)
432
433        if example_data is None:
434            rng_key_sys, rng_key_pre = jax.random.split(rng_key_pre)
435            example_data = self.generate_dummy_system(rng_key_sys, n_atoms=10)
436
437        preproc_state, inputs = self.preprocessing.init_with_output(example_data)
438        object.__setattr__(self, "preproc_state", preproc_state)
439        return inputs, rng_key
def total_energy( self, variables: Optional[dict] = None, unit: Union[float, str] = None, gpu_preprocessing=False, **inputs) -> Tuple[jax.Array, Dict]:
455    def total_energy(
456        self, variables: Optional[dict] = None, unit: Union[float, str] = None, gpu_preprocessing=False, **inputs
457    ) -> Tuple[jnp.ndarray, Dict]:
458        """compute the total energy of the system
459
460        !!! This is not a pure function => do not apply jax transforms !!!
461        If you want to apply jax transforms, use self._total_energy(variables, inputs), which is pure, and preprocess the inputs beforehand with self.preprocess.
462        """
463        if variables is None:
464            variables = self.variables
465        inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
466        # def print_shape(path,value):
467        #     if isinstance(value,jnp.ndarray):
468        #         print(path,value.shape)
469        #     else:
470        #         print(path,value)
471        # jax.tree_util.tree_map_with_path(print_shape,inputs)
472        _, output = self._total_energy(variables, inputs)
473        if self.use_atom_padding:
474            output = atom_unpadding(output)
475        e = output["total_energy"]
476        if unit is not None:
477            model_energy_unit = self.Ha_to_model_energy
478            if isinstance(unit, str):
479                unit = au.get_multiplier(unit)
480            e = e * (unit / model_energy_unit)
481        return e, output

compute the total energy of the system

!!! This is not a pure function => do not apply jax transforms !!! If you want to apply jax transforms, use self._total_energy(variables, inputs), which is pure, and preprocess the inputs beforehand with self.preprocess.
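
Since the conversion factor applied is unit / Ha_to_model_energy, passing unit converts the result from the model's internal energy_unit to the requested one. A sketch (assuming "eV" is a unit string recognized by fennol.utils.AtomicUnits):

    e_model, out = model.total_energy(**system)        # in the model's energy_unit
    e_ev, _ = model.total_energy(unit="eV", **system)  # converted to eV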

def energy_and_forces( self, variables: Optional[dict] = None, unit: Union[float, str] = None, gpu_preprocessing=False, **inputs) -> Tuple[jax.Array, jax.Array, Dict]:
483    def energy_and_forces(
484        self, variables: Optional[dict] = None, unit: Union[float, str] = None, gpu_preprocessing=False, **inputs
485    ) -> Tuple[jnp.ndarray, jnp.ndarray, Dict]:
486        """compute the total energy and forces of the system
487
488        !!! This is not a pure function => do not apply jax transforms !!!
489        If you want to apply jax transforms, use self._energy_and_forces(variables, inputs), which is pure, and preprocess the inputs beforehand with self.preprocess.
490        """
491        if variables is None:
492            variables = self.variables
493        inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
494        _, _, output = self._energy_and_forces(variables, inputs)
495        if self.use_atom_padding:
496            output = atom_unpadding(output)
497        e = output["total_energy"]
498        f = output["forces"]
499        if unit is not None:
500            model_energy_unit = self.Ha_to_model_energy
501            if isinstance(unit, str):
502                unit = au.get_multiplier(unit)
503            e = e * (unit / model_energy_unit)
504            f = f * (unit / model_energy_unit)
505        return e, f, output

compute the total energy and forces of the system

!!! This is not a pure function => do not apply jax transforms !!! If you want to apply jax transforms, use self._energy_and_forces(variables, inputs), which is pure, and preprocess the inputs beforehand with self.preprocess.
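
A quick sketch (system is any raw input dict accepted by preprocess); forces come back per atom:

    e, f, out = model.energy_and_forces(**system)
    # f has shape (n_atoms, 3), in the model's energy unit per length unit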

def energy_and_forces_and_virial( self, variables: Optional[dict] = None, unit: Union[float, str] = None, gpu_preprocessing=False, **inputs) -> Tuple[jax.Array, jax.Array, jax.Array, Dict]:
507    def energy_and_forces_and_virial(
508        self, variables: Optional[dict] = None, unit: Union[float, str] = None, gpu_preprocessing=False, **inputs
509    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, Dict]:
510        """compute the total energy, forces and virial tensor of the system
511
512        !!! This is not a pure function => do not apply jax transforms !!!
513        If you want to apply jax transforms, use self._energy_and_forces_and_virial(variables, inputs), which is pure, and preprocess the inputs beforehand with self.preprocess.
514        """
515        if variables is None:
516            variables = self.variables
517        inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
518        _, _, _, output = self._energy_and_forces_and_virial(variables, inputs)
519        if self.use_atom_padding:
520            output = atom_unpadding(output)
521        e = output["total_energy"]
522        f = output["forces"]
523        v = output["virial_tensor"]
524        if unit is not None:
525            model_energy_unit = self.Ha_to_model_energy
526            if isinstance(unit, str):
527                unit = au.get_multiplier(unit)
528            e = e * (unit / model_energy_unit)
529            f = f * (unit / model_energy_unit)
530            v = v * (unit / model_energy_unit)
531        return e, f, v, output

compute the total energy, forces and virial tensor of the system

!!! This is not a pure function => do not apply jax transforms !!! If you want to apply jax transforms, use self._energy_and_forces_and_virial(variables, inputs), which is pure, and preprocess the inputs beforehand with self.preprocess.
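
Internally, the virial is obtained by the scaling-matrix trick visible in the source above: an identity scaling S is applied to the coordinates (and cells), and dE/dS evaluated at S = I is the virial tensor. A self-contained toy illustration of the same trick (the harmonic pair potential is purely illustrative):

    import jax
    import jax.numpy as jnp

    def toy_energy(coords):
        # illustrative pairwise harmonic potential
        diff = coords[:, None, :] - coords[None, :, :]
        d = jnp.sqrt(jnp.sum(diff**2, axis=-1) + 1e-12)  # regularized distances
        return 0.5 * jnp.sum(jnp.triu(d, k=1) ** 2)

    def energy_of_scaling(scaling, coords):
        # scale all coordinates by S before evaluating the energy
        return toy_energy(coords @ scaling)

    coords = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
    virial = jax.grad(energy_of_scaling)(jnp.eye(3), coords)  # dE/dS at S = I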

def remove_atom_padding(self, output):
533    def remove_atom_padding(self, output):
534        """remove atom padding from the output"""
535        return atom_unpadding(output)

remove atom padding from the output

def get_model(self) -> Tuple[fennol.models.modules.FENNIXModules, Dict]:
537    def get_model(self) -> Tuple[FENNIXModules, Dict]:
538        """return the model and its variables"""
539        return self.modules, self.variables

return the model and its variables

def get_preprocessing(self) -> Tuple[fennol.models.preprocessing.PreprocessingChain, Dict]:
541    def get_preprocessing(self) -> Tuple[PreprocessingChain, Dict]:
542        """return the preprocessing chain and its state"""
543        return self.preprocessing, self.preproc_state

return the preprocessing chain and its state

def generate_dummy_system( self, rng_key: jax.random.PRNGKey, box_size=None, n_atoms: int = 10) -> Dict[str, Any]:
574    def generate_dummy_system(
575        self, rng_key: jax.random.PRNGKey, box_size=None, n_atoms: int = 10
576    ) -> Dict[str, Any]:
577        """
578        Generate dummy system for initialization
579        """
580        if box_size is None:
581            box_size = 2 * self.cutoff
582            for g in self._graphs_properties.values():
583                cutoff = g["cutoff"]
584                if cutoff is not None:
585                    box_size = min(box_size, 2 * g["cutoff"])
586        coordinates = np.array(
587            jax.random.uniform(rng_key, (n_atoms, 3), maxval=box_size), dtype=np.float64
588        )
589        species = np.ones((n_atoms,), dtype=np.int32)
590        batch_index = np.zeros((n_atoms,), dtype=np.int32)
591        natoms = np.array([n_atoms], dtype=np.int32)
592        return {
593            "species": species,
594            "coordinates": coordinates,
595            # "graph": graph,
596            "batch_index": batch_index,
597            "natoms": natoms,
598        }

Generate dummy system for initialization
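
For example:

    data = model.generate_dummy_system(jax.random.PRNGKey(0), n_atoms=5)
    # data["coordinates"] has shape (5, 3); all species are set to 1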

def summarize( self, rng_key: jax.random.PRNGKey = None, example_data=None, **kwargs) -> str:
600    def summarize(
601        self, rng_key: jax.random.PRNGKey = None, example_data=None, **kwargs
602    ) -> str:
603        """Summarize the model architecture and parameters"""
604        head = "Summarizing with example data:\n"
605        if rng_key is None:
606            rng_key = jax.random.PRNGKey(0)
607        if example_data is None:
608            head = "Summarizing with dummy 10 atoms system:\n"
609            rng_key, rng_key_sys = jax.random.split(rng_key)
610            example_data = self.generate_dummy_system(rng_key_sys, n_atoms=10)
611        rng_key, rng_key_pre = jax.random.split(rng_key)
612        _, inputs = self.preprocessing.init_with_output(example_data)
613        return head + nn.tabulate(self.modules, rng_key, **kwargs)(inputs)

Summarize the model architecture and parameters
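
Typical use:

    print(model.summarize())   # flax tabulate of all modules on a dummy system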

def to_dict(self):
615    def to_dict(self):
616        """return a dictionary representation of the model"""
617        return {
618            **self._input_args,
619            "energy_terms": self.energy_terms,
620            "variables": deepcopy(self.variables),
621        }

return a dictionary representation of the model

def save(self, filename):
623    def save(self, filename):
624        """save the model to a file"""
625        state_dict = self.to_dict()
626        state_dict["preprocessing"] = [
627            [k, v] for k, v in state_dict["preprocessing"].items()
628        ]
629        state_dict["modules"] = [[k, v] for k, v in state_dict["modules"].items()]
630        with open(filename, "wb") as f:
631            f.write(serialization.msgpack_serialize(state_dict))

save the model to a file

@classmethod
def load(cls, filename, use_atom_padding=False, graph_config={}):
633    @classmethod
634    def load(
635        cls,
636        filename,
637        use_atom_padding=False,
638        graph_config={},
639    ):
640        """load a model from a file"""
641        with open(filename, "rb") as f:
642            state_dict = serialization.msgpack_restore(f.read())
643        state_dict["preprocessing"] = {k: v for k, v in state_dict["preprocessing"]}
644        state_dict["modules"] = {k: v for k, v in state_dict["modules"]}
645        return cls(
646            **state_dict,
647            graph_config=graph_config,
648            use_atom_padding=use_atom_padding,
649        )

load a model from a file
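
A round-trip sketch (the ".fnx" extension is a hypothetical convention; any filename works):

    model.save("model.fnx")
    restored = FENNIX.load("model.fnx", use_atom_padding=True)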