fennol.ase

  1import ase
  2import ase.calculators.calculator
  3import ase.units
  4import numpy as np
  5import jax.numpy as jnp
  6from . import FENNIX
  7from .models.preprocessing import convert_to_jax
  8from typing import Sequence, Union, Optional
  9from ase.stress import full_3x3_to_voigt_6_stress
 10import jax
 11from .utils import AtomicUnits as au
 12
 13class FENNIXCalculator(ase.calculators.calculator.Calculator):
 14    """FENNIX calculator for ASE.
 15
 16    Arguments:
 17    ----------
 18    model: str or FENNIX
 19        The path to the model or the model itself.
 20    use_atom_padding: bool, default=False
 21        Whether to use atom padding or not. Atom padding is useful to prevent recompiling the model if the number of atoms changes. If the number of atoms is expected to be fixed, it is recommended to set this to False.
 22    gpu_preprocessing: bool, default=False
 23        Whether to preprocess the data on the GPU or not. This is useful for large systems, but may not be necessary (or even slower) for small systems.
 24    atoms: ase.Atoms, default=None
 25        The atoms object to be used for the calculation. If provided, the calculator will be initialized with the atoms object.
 26    verbose: bool, default=False
 27        Whether to print nblist update information or not.
 28    energy_terms: list of str, default=None
 29        The energy terms to include in the total energy. If None, this will default to the energy terms defined in the model.
 30    """
 31
 32    implemented_properties = ["energy", "forces", "stress"]
 33
 34    def __init__(
 35        self,
 36        model: Union[str, FENNIX],
 37        gpu_preprocessing: bool = False,
 38        atoms: Optional[ase.Atoms] = None,
 39        verbose: bool = False,
 40        energy_terms: Optional[Sequence[str]] = None,
 41        use_float64: bool = False,
 42        matmul_prec: Optional[str] = None,
 43        **kwargs
 44    ):
 45        super().__init__()
 46        if use_float64:
 47            jax.config.update("jax_enable_x64", True)
 48        if matmul_prec is not None:
 49            assert matmul_prec in [
 50                "default",
 51                "high",
 52                "highest",
 53            ], "matmul_prec should be one of 'default', 'high', 'highest'"
 54            jax.config.update("jax_default_matmul_precision", matmul_prec)
 55
 56        if isinstance(model, FENNIX):
 57            self.model = model
 58        else:
 59            self.model = FENNIX.load(model, **kwargs)
 60        if energy_terms is not None:
 61            self.model.set_energy_terms(energy_terms)
 62        self.dtype = "float64" if use_float64 else "float32"
 63        self.gpu_preprocessing = gpu_preprocessing
 64        self.verbose = verbose
 65        self._fennol_inputs = None
 66        self._raw_inputs = None
 67
 68        model_unit = au.get_multiplier(self.model.energy_unit)
 69        self.energy_conv = ase.units.Hartree / model_unit
 70        if atoms is not None:
 71            self.preprocess(atoms)
 72
 73    def calculate(
 74        self,
 75        atoms=None,
 76        properties=["energy"],
 77        system_changes=ase.calculators.calculator.all_changes,
 78    ):
 79        super().calculate(atoms, properties, system_changes)
 80        inputs = self.preprocess(self.atoms, system_changes=system_changes)
 81
 82        results = {}
 83        if "stress" in properties:
 84            e, f, virial, output = self.model._energy_and_forces_and_virial(
 85                self.model.variables, inputs
 86            )
 87            volume = self.atoms.get_volume()
 88            stress = np.asarray(virial[0]) * self.energy_conv / volume
 89            results["stress"] = full_3x3_to_voigt_6_stress(stress)
 90            results["forces"] = np.asarray(f) * self.energy_conv
 91        elif "forces" in properties:
 92            e, f, output = self.model._energy_and_forces(self.model.variables, inputs)
 93            results["forces"] = np.asarray(f) * self.energy_conv
 94        else:
 95            e, output = self.model._total_energy(self.model.variables, inputs)
 96
 97        results["energy"] = float(e[0]) * self.energy_conv
 98        if self.model.use_atom_padding and "forces" in results:
 99            mask = np.asarray(output["true_atoms"])
100            results["forces"] = results["forces"][mask]
101
102        self.results.update(results)
103
104    def preprocess(self, atoms, system_changes=ase.calculators.calculator.all_changes):
105
106        force_cpu_preprocessing = False
107        if self._raw_inputs is None:
108            force_cpu_preprocessing = True
109            cell = np.asarray(atoms.get_cell(complete=True).array, dtype=self.dtype)
110            pbc = np.asarray(atoms.get_pbc(), dtype=bool)
111            if np.all(pbc):
112                use_pbc = True
113            elif np.any(pbc):
114                raise NotImplementedError("PBC should be activated in all directions.")
115            else:
116                use_pbc = False
117
118            species = np.asarray(atoms.get_atomic_numbers(), dtype=np.int32)
119            coordinates = np.asarray(atoms.get_positions(), dtype=self.dtype)
120            natoms = np.array([len(species)], dtype=np.int32)
121            batch_index = np.array([0] * len(species), dtype=np.int32)
122
123            inputs = {
124                "species": species,
125                "coordinates": coordinates,
126                "natoms": natoms,
127                "batch_index": batch_index,
128            }
129            if use_pbc:
130                reciprocal_cell = np.linalg.inv(cell)
131                inputs["cells"] = cell.reshape(1, 3, 3)
132                inputs["reciprocal_cells"] = reciprocal_cell.reshape(1, 3, 3)
133            self._raw_inputs = convert_to_jax(inputs)
134        else:
135            if "cell" in system_changes:
136                pbc = np.asarray(atoms.get_pbc(), dtype=bool)
137                if np.all(pbc):
138                    use_pbc = True
139                elif np.any(pbc):
140                    raise NotImplementedError(
141                        "PBC should be activated in all directions."
142                    )
143                else:
144                    use_pbc = False
145                if use_pbc:
146                    cell = np.asarray(
147                        atoms.get_cell(complete=True).array, dtype=self.dtype
148                    )
149                    reciprocal_cell = np.linalg.inv(cell)
150                    self._raw_inputs["cells"] = jnp.asarray(cell.reshape(1, 3, 3))
151                    self._raw_inputs["reciprocal_cells"] = jnp.asarray(
152                        reciprocal_cell.reshape(1, 3, 3)
153                    )
154                elif "cells" in self._raw_inputs:
155                    del self._raw_inputs["cells"]
156                    del self._raw_inputs["reciprocal_cells"]
157            if "numbers" in system_changes:
158                self._raw_inputs["species"] = jnp.asarray(
159                    atoms.get_atomic_numbers(), dtype=jnp.int32
160                )
161                self._raw_inputs["natoms"] = jnp.array(
162                    [len(self._raw_inputs["species"])], dtype=np.int32
163                )
164                self._raw_inputs["batch_index"] = jnp.array(
165                    [0] * len(self._raw_inputs["species"]), dtype=np.int32
166                )
167                force_cpu_preprocessing = True
168            if "positions" in system_changes:
169                self._raw_inputs["coordinates"] = jnp.asarray(
170                    atoms.get_positions(), dtype=self.dtype
171                )
172
173        if self.gpu_preprocessing and not force_cpu_preprocessing:
174            _, inputs = self.model.preprocessing.atom_padding(
175                self.model.preproc_state, self._raw_inputs
176            )
177            inputs = {**self._fennol_inputs, **inputs}
178
179            inputs = self.model.preprocessing.process(self.model.preproc_state, inputs)
180            self.model.preproc_state, state_up, inputs, overflow = (
181                self.model.preprocessing.check_reallocate(
182                    self.model.preproc_state, inputs
183                )
184            )
185            self._fennol_inputs = inputs
186            if self.verbose and overflow:
187                print("FENNIX nblist overflow => reallocating nblist")
188                print("  size updates:", state_up)
189        else:
190            self._fennol_inputs = self.model.preprocess(**self._raw_inputs)
191
192        return self._fennol_inputs
class FENNIXCalculator(ase.calculators.calculator.Calculator):
 14class FENNIXCalculator(ase.calculators.calculator.Calculator):
 15    """FENNIX calculator for ASE.
 16
 17    Arguments:
 18    ----------
 19    model: str or FENNIX
 20        The path to the model or the model itself.
 21    use_atom_padding: bool, default=False
 22        Whether to use atom padding or not. Atom padding is useful to prevent recompiling the model if the number of atoms changes. If the number of atoms is expected to be fixed, it is recommended to set this to False.
 23    gpu_preprocessing: bool, default=False
 24        Whether to preprocess the data on the GPU or not. This is useful for large systems, but may not be necessary (or even slower) for small systems.
 25    atoms: ase.Atoms, default=None
 26        The atoms object to be used for the calculation. If provided, the calculator will be initialized with the atoms object.
 27    verbose: bool, default=False
 28        Whether to print nblist update information or not.
 29    energy_terms: list of str, default=None
 30        The energy terms to include in the total energy. If None, this will default to the energy terms defined in the model.
 31    """
 32
 33    implemented_properties = ["energy", "forces", "stress"]
 34
 35    def __init__(
 36        self,
 37        model: Union[str, FENNIX],
 38        gpu_preprocessing: bool = False,
 39        atoms: Optional[ase.Atoms] = None,
 40        verbose: bool = False,
 41        energy_terms: Optional[Sequence[str]] = None,
 42        use_float64: bool = False,
 43        matmul_prec: Optional[str] = None,
 44        **kwargs
 45    ):
 46        super().__init__()
 47        if use_float64:
 48            jax.config.update("jax_enable_x64", True)
 49        if matmul_prec is not None:
 50            assert matmul_prec in [
 51                "default",
 52                "high",
 53                "highest",
 54            ], "matmul_prec should be one of 'default', 'high', 'highest'"
 55            jax.config.update("jax_default_matmul_precision", matmul_prec)
 56
 57        if isinstance(model, FENNIX):
 58            self.model = model
 59        else:
 60            self.model = FENNIX.load(model, **kwargs)
 61        if energy_terms is not None:
 62            self.model.set_energy_terms(energy_terms)
 63        self.dtype = "float64" if use_float64 else "float32"
 64        self.gpu_preprocessing = gpu_preprocessing
 65        self.verbose = verbose
 66        self._fennol_inputs = None
 67        self._raw_inputs = None
 68
 69        model_unit = au.get_multiplier(self.model.energy_unit)
 70        self.energy_conv = ase.units.Hartree / model_unit
 71        if atoms is not None:
 72            self.preprocess(atoms)
 73
 74    def calculate(
 75        self,
 76        atoms=None,
 77        properties=["energy"],
 78        system_changes=ase.calculators.calculator.all_changes,
 79    ):
 80        super().calculate(atoms, properties, system_changes)
 81        inputs = self.preprocess(self.atoms, system_changes=system_changes)
 82
 83        results = {}
 84        if "stress" in properties:
 85            e, f, virial, output = self.model._energy_and_forces_and_virial(
 86                self.model.variables, inputs
 87            )
 88            volume = self.atoms.get_volume()
 89            stress = np.asarray(virial[0]) * self.energy_conv / volume
 90            results["stress"] = full_3x3_to_voigt_6_stress(stress)
 91            results["forces"] = np.asarray(f) * self.energy_conv
 92        elif "forces" in properties:
 93            e, f, output = self.model._energy_and_forces(self.model.variables, inputs)
 94            results["forces"] = np.asarray(f) * self.energy_conv
 95        else:
 96            e, output = self.model._total_energy(self.model.variables, inputs)
 97
 98        results["energy"] = float(e[0]) * self.energy_conv
 99        if self.model.use_atom_padding and "forces" in results:
100            mask = np.asarray(output["true_atoms"])
101            results["forces"] = results["forces"][mask]
102
103        self.results.update(results)
104
105    def preprocess(self, atoms, system_changes=ase.calculators.calculator.all_changes):
106
107        force_cpu_preprocessing = False
108        if self._raw_inputs is None:
109            force_cpu_preprocessing = True
110            cell = np.asarray(atoms.get_cell(complete=True).array, dtype=self.dtype)
111            pbc = np.asarray(atoms.get_pbc(), dtype=bool)
112            if np.all(pbc):
113                use_pbc = True
114            elif np.any(pbc):
115                raise NotImplementedError("PBC should be activated in all directions.")
116            else:
117                use_pbc = False
118
119            species = np.asarray(atoms.get_atomic_numbers(), dtype=np.int32)
120            coordinates = np.asarray(atoms.get_positions(), dtype=self.dtype)
121            natoms = np.array([len(species)], dtype=np.int32)
122            batch_index = np.array([0] * len(species), dtype=np.int32)
123
124            inputs = {
125                "species": species,
126                "coordinates": coordinates,
127                "natoms": natoms,
128                "batch_index": batch_index,
129            }
130            if use_pbc:
131                reciprocal_cell = np.linalg.inv(cell)
132                inputs["cells"] = cell.reshape(1, 3, 3)
133                inputs["reciprocal_cells"] = reciprocal_cell.reshape(1, 3, 3)
134            self._raw_inputs = convert_to_jax(inputs)
135        else:
136            if "cell" in system_changes:
137                pbc = np.asarray(atoms.get_pbc(), dtype=bool)
138                if np.all(pbc):
139                    use_pbc = True
140                elif np.any(pbc):
141                    raise NotImplementedError(
142                        "PBC should be activated in all directions."
143                    )
144                else:
145                    use_pbc = False
146                if use_pbc:
147                    cell = np.asarray(
148                        atoms.get_cell(complete=True).array, dtype=self.dtype
149                    )
150                    reciprocal_cell = np.linalg.inv(cell)
151                    self._raw_inputs["cells"] = jnp.asarray(cell.reshape(1, 3, 3))
152                    self._raw_inputs["reciprocal_cells"] = jnp.asarray(
153                        reciprocal_cell.reshape(1, 3, 3)
154                    )
155                elif "cells" in self._raw_inputs:
156                    del self._raw_inputs["cells"]
157                    del self._raw_inputs["reciprocal_cells"]
158            if "numbers" in system_changes:
159                self._raw_inputs["species"] = jnp.asarray(
160                    atoms.get_atomic_numbers(), dtype=jnp.int32
161                )
162                self._raw_inputs["natoms"] = jnp.array(
163                    [len(self._raw_inputs["species"])], dtype=np.int32
164                )
165                self._raw_inputs["batch_index"] = jnp.array(
166                    [0] * len(self._raw_inputs["species"]), dtype=np.int32
167                )
168                force_cpu_preprocessing = True
169            if "positions" in system_changes:
170                self._raw_inputs["coordinates"] = jnp.asarray(
171                    atoms.get_positions(), dtype=self.dtype
172                )
173
174        if self.gpu_preprocessing and not force_cpu_preprocessing:
175            _, inputs = self.model.preprocessing.atom_padding(
176                self.model.preproc_state, self._raw_inputs
177            )
178            inputs = {**self._fennol_inputs, **inputs}
179
180            inputs = self.model.preprocessing.process(self.model.preproc_state, inputs)
181            self.model.preproc_state, state_up, inputs, overflow = (
182                self.model.preprocessing.check_reallocate(
183                    self.model.preproc_state, inputs
184                )
185            )
186            self._fennol_inputs = inputs
187            if self.verbose and overflow:
188                print("FENNIX nblist overflow => reallocating nblist")
189                print("  size updates:", state_up)
190        else:
191            self._fennol_inputs = self.model.preprocess(**self._raw_inputs)
192
193        return self._fennol_inputs

FENNIX calculator for ASE.

Arguments:

- model: str or FENNIX. The path to the model or the model itself.
- use_atom_padding: bool, default=False. Whether to use atom padding or not. Atom padding is useful to prevent recompiling the model if the number of atoms changes. If the number of atoms is expected to be fixed, it is recommended to set this to False. (This is not a named argument in the signature below; extra keyword arguments are forwarded to FENNIX.load when model is a path.)
- gpu_preprocessing: bool, default=False. Whether to preprocess the data on the GPU or not. This is useful for large systems, but may not be necessary (or even slower) for small systems.
- atoms: ase.Atoms, default=None. The atoms object to be used for the calculation. If provided, the calculator will be initialized with the atoms object.
- verbose: bool, default=False. Whether to print nblist update information or not.
- energy_terms: list of str, default=None. The energy terms to include in the total energy. If None, this will default to the energy terms defined in the model.
- use_float64: bool, default=False. If True, enables the JAX option jax_enable_x64 and builds the input arrays as float64 instead of float32.
- matmul_prec: str, default=None. If given, must be one of 'default', 'high', 'highest'; sets the JAX option jax_default_matmul_precision.

FENNIXCalculator( model: Union[str, fennol.models.fennix.FENNIX], gpu_preprocessing: bool = False, atoms: Optional[ase.atoms.Atoms] = None, verbose: bool = False, energy_terms: Optional[Sequence[str]] = None, use_float64: bool = False, matmul_prec: Optional[str] = None, **kwargs)
35    def __init__(
36        self,
37        model: Union[str, FENNIX],
38        gpu_preprocessing: bool = False,
39        atoms: Optional[ase.Atoms] = None,
40        verbose: bool = False,
41        energy_terms: Optional[Sequence[str]] = None,
42        use_float64: bool = False,
43        matmul_prec: Optional[str] = None,
44        **kwargs
45    ):
46        super().__init__()
47        if use_float64:
48            jax.config.update("jax_enable_x64", True)
49        if matmul_prec is not None:
50            assert matmul_prec in [
51                "default",
52                "high",
53                "highest",
54            ], "matmul_prec should be one of 'default', 'high', 'highest'"
55            jax.config.update("jax_default_matmul_precision", matmul_prec)
56
57        if isinstance(model, FENNIX):
58            self.model = model
59        else:
60            self.model = FENNIX.load(model, **kwargs)
61        if energy_terms is not None:
62            self.model.set_energy_terms(energy_terms)
63        self.dtype = "float64" if use_float64 else "float32"
64        self.gpu_preprocessing = gpu_preprocessing
65        self.verbose = verbose
66        self._fennol_inputs = None
67        self._raw_inputs = None
68
69        model_unit = au.get_multiplier(self.model.energy_unit)
70        self.energy_conv = ase.units.Hartree / model_unit
71        if atoms is not None:
72            self.preprocess(atoms)

Basic calculator implementation.

restart: str Prefix for restart file. May contain a directory. Default is None: don't restart. ignore_bad_restart_file: bool Deprecated, please do not use. Passing more than one positional argument to Calculator() is deprecated and will stop working in the future. Ignore broken or missing restart file. By default, it is an error if the restart file is missing or broken. directory: str or PurePath Working directory in which to read and write files and perform calculations. label: str Name used for all files. Not supported by all calculators. May contain a directory, but please use the directory parameter for that instead. atoms: Atoms object Optional Atoms object to which the calculator will be attached. When restarting, atoms will get its positions and unit-cell updated from file.

implemented_properties = ['energy', 'forces', 'stress']

Properties calculator can handle (energy, forces, ...)

dtype
gpu_preprocessing
verbose
energy_conv
def calculate( self, atoms=None, properties=['energy'], system_changes=['positions', 'numbers', 'cell', 'pbc', 'initial_charges', 'initial_magmoms']):
 74    def calculate(
 75        self,
 76        atoms=None,
 77        properties=["energy"],
 78        system_changes=ase.calculators.calculator.all_changes,
 79    ):
 80        super().calculate(atoms, properties, system_changes)
 81        inputs = self.preprocess(self.atoms, system_changes=system_changes)
 82
 83        results = {}
 84        if "stress" in properties:
 85            e, f, virial, output = self.model._energy_and_forces_and_virial(
 86                self.model.variables, inputs
 87            )
 88            volume = self.atoms.get_volume()
 89            stress = np.asarray(virial[0]) * self.energy_conv / volume
 90            results["stress"] = full_3x3_to_voigt_6_stress(stress)
 91            results["forces"] = np.asarray(f) * self.energy_conv
 92        elif "forces" in properties:
 93            e, f, output = self.model._energy_and_forces(self.model.variables, inputs)
 94            results["forces"] = np.asarray(f) * self.energy_conv
 95        else:
 96            e, output = self.model._total_energy(self.model.variables, inputs)
 97
 98        results["energy"] = float(e[0]) * self.energy_conv
 99        if self.model.use_atom_padding and "forces" in results:
100            mask = np.asarray(output["true_atoms"])
101            results["forces"] = results["forces"][mask]
102
103        self.results.update(results)

Do the calculation.

properties: list of str List of what needs to be calculated. Can be any combination of 'energy', 'forces', 'stress', 'dipole', 'charges', 'magmom' and 'magmoms'. system_changes: list of str List of what has changed since last calculation. Can be any combination of these six: 'positions', 'numbers', 'cell', 'pbc', 'initial_charges' and 'initial_magmoms'.

Subclasses need to implement this, but can ignore properties and system_changes if they want. Calculated properties should be inserted into results dictionary like shown in this dummy example::

self.results = {'energy': 0.0,
                'forces': np.zeros((len(atoms), 3)),
                'stress': np.zeros(6),
                'dipole': np.zeros(3),
                'charges': np.zeros(len(atoms)),
                'magmom': 0.0,
                'magmoms': np.zeros(len(atoms))}

The subclass implementation should first call this implementation to set the atoms attribute and create any missing directories.

def preprocess( self, atoms, system_changes=['positions', 'numbers', 'cell', 'pbc', 'initial_charges', 'initial_magmoms']):
105    def preprocess(self, atoms, system_changes=ase.calculators.calculator.all_changes):
106
107        force_cpu_preprocessing = False
108        if self._raw_inputs is None:
109            force_cpu_preprocessing = True
110            cell = np.asarray(atoms.get_cell(complete=True).array, dtype=self.dtype)
111            pbc = np.asarray(atoms.get_pbc(), dtype=bool)
112            if np.all(pbc):
113                use_pbc = True
114            elif np.any(pbc):
115                raise NotImplementedError("PBC should be activated in all directions.")
116            else:
117                use_pbc = False
118
119            species = np.asarray(atoms.get_atomic_numbers(), dtype=np.int32)
120            coordinates = np.asarray(atoms.get_positions(), dtype=self.dtype)
121            natoms = np.array([len(species)], dtype=np.int32)
122            batch_index = np.array([0] * len(species), dtype=np.int32)
123
124            inputs = {
125                "species": species,
126                "coordinates": coordinates,
127                "natoms": natoms,
128                "batch_index": batch_index,
129            }
130            if use_pbc:
131                reciprocal_cell = np.linalg.inv(cell)
132                inputs["cells"] = cell.reshape(1, 3, 3)
133                inputs["reciprocal_cells"] = reciprocal_cell.reshape(1, 3, 3)
134            self._raw_inputs = convert_to_jax(inputs)
135        else:
136            if "cell" in system_changes:
137                pbc = np.asarray(atoms.get_pbc(), dtype=bool)
138                if np.all(pbc):
139                    use_pbc = True
140                elif np.any(pbc):
141                    raise NotImplementedError(
142                        "PBC should be activated in all directions."
143                    )
144                else:
145                    use_pbc = False
146                if use_pbc:
147                    cell = np.asarray(
148                        atoms.get_cell(complete=True).array, dtype=self.dtype
149                    )
150                    reciprocal_cell = np.linalg.inv(cell)
151                    self._raw_inputs["cells"] = jnp.asarray(cell.reshape(1, 3, 3))
152                    self._raw_inputs["reciprocal_cells"] = jnp.asarray(
153                        reciprocal_cell.reshape(1, 3, 3)
154                    )
155                elif "cells" in self._raw_inputs:
156                    del self._raw_inputs["cells"]
157                    del self._raw_inputs["reciprocal_cells"]
158            if "numbers" in system_changes:
159                self._raw_inputs["species"] = jnp.asarray(
160                    atoms.get_atomic_numbers(), dtype=jnp.int32
161                )
162                self._raw_inputs["natoms"] = jnp.array(
163                    [len(self._raw_inputs["species"])], dtype=np.int32
164                )
165                self._raw_inputs["batch_index"] = jnp.array(
166                    [0] * len(self._raw_inputs["species"]), dtype=np.int32
167                )
168                force_cpu_preprocessing = True
169            if "positions" in system_changes:
170                self._raw_inputs["coordinates"] = jnp.asarray(
171                    atoms.get_positions(), dtype=self.dtype
172                )
173
174        if self.gpu_preprocessing and not force_cpu_preprocessing:
175            _, inputs = self.model.preprocessing.atom_padding(
176                self.model.preproc_state, self._raw_inputs
177            )
178            inputs = {**self._fennol_inputs, **inputs}
179
180            inputs = self.model.preprocessing.process(self.model.preproc_state, inputs)
181            self.model.preproc_state, state_up, inputs, overflow = (
182                self.model.preprocessing.check_reallocate(
183                    self.model.preproc_state, inputs
184                )
185            )
186            self._fennol_inputs = inputs
187            if self.verbose and overflow:
188                print("FENNIX nblist overflow => reallocating nblist")
189                print("  size updates:", state_up)
190        else:
191            self._fennol_inputs = self.model.preprocess(**self._raw_inputs)
192
193        return self._fennol_inputs