fennol.ase
import ase
import ase.calculators.calculator
import ase.units
import numpy as np
import jax.numpy as jnp
from . import FENNIX
from .models.preprocessing import convert_to_jax
from typing import Sequence, Union, Optional
from ase.stress import full_3x3_to_voigt_6_stress
import jax
from .utils import AtomicUnits as au


class FENNIXCalculator(ase.calculators.calculator.Calculator):
    """FENNIX calculator for ASE.

    Arguments:
    ----------
    model: str or FENNIX
        The path to the model or the model itself.
    gpu_preprocessing: bool, default=False
        Whether to preprocess the data on the GPU or not. This is useful for
        large systems, but may not be necessary (or even slower) for small
        systems.
    atoms: ase.Atoms, default=None
        The atoms object to be used for the calculation. If provided, the
        calculator will be initialized with the atoms object.
    verbose: bool, default=False
        Whether to print nblist update information or not.
    energy_terms: list of str, default=None
        The energy terms to include in the total energy. If None, this will
        default to the energy terms defined in the model.
    use_float64: bool, default=False
        Whether to run in double precision. This enables JAX's *global*
        ``jax_enable_x64`` flag and feeds coordinates/cells to the model as
        float64 instead of float32.
    matmul_prec: str, default=None
        One of "default", "high", "highest". When given, sets JAX's *global*
        ``jax_default_matmul_precision`` option.
    **kwargs
        Forwarded to ``FENNIX.load`` when ``model`` is a path; ignored when a
        FENNIX instance is passed.

    Atom padding (``use_atom_padding``) is not an explicit argument of this
    constructor: the calculator reads it from ``model.use_atom_padding``.
    Presumably it is set through ``**kwargs`` when loading from a path
    (verify against ``FENNIX.load``). Atom padding is useful to prevent
    recompiling the model if the number of atoms changes; if the number of
    atoms is expected to be fixed, it is recommended to disable it.
    """

    implemented_properties = ["energy", "forces", "stress"]

    def __init__(
        self,
        model: Union[str, FENNIX],
        gpu_preprocessing: bool = False,
        atoms: Optional[ase.Atoms] = None,
        verbose: bool = False,
        energy_terms: Optional[Sequence[str]] = None,
        use_float64: bool = False,
        matmul_prec: Optional[str] = None,
        **kwargs
    ):
        """Load or attach the model, configure JAX and optionally preprocess `atoms`.

        See the class docstring for the meaning of each argument.

        Raises:
            ValueError: if `matmul_prec` is given and is not one of
                "default", "high", "highest".
        """
        super().__init__()
        # Validate before touching global JAX state. An explicit raise is used
        # instead of `assert`, which is stripped when running `python -O`.
        if matmul_prec is not None and matmul_prec not in (
            "default",
            "high",
            "highest",
        ):
            raise ValueError(
                "matmul_prec should be one of 'default', 'high', 'highest'"
            )
        # Both options below change process-wide JAX configuration.
        if use_float64:
            jax.config.update("jax_enable_x64", True)
        if matmul_prec is not None:
            jax.config.update("jax_default_matmul_precision", matmul_prec)

        if isinstance(model, FENNIX):
            # kwargs are only consumed by FENNIX.load; they are ignored here.
            self.model = model
        else:
            self.model = FENNIX.load(model, **kwargs)
        if energy_terms is not None:
            # Presumably updates the model in place (also a caller-provided one).
            self.model.set_energy_terms(energy_terms)
        self.dtype = "float64" if use_float64 else "float32"
        self.gpu_preprocessing = gpu_preprocessing
        self.verbose = verbose
        # Caches reused by `preprocess` across calls.
        self._fennol_inputs = None  # last fully preprocessed model inputs
        self._raw_inputs = None  # raw jax arrays built from the Atoms object

        # Multiply model energies by this factor to obtain ASE energies (eV).
        # Presumably `get_multiplier` returns "model units per Hartree".
        model_unit = au.get_multiplier(self.model.energy_unit)
        self.energy_conv = ase.units.Hartree / model_unit
        if atoms is not None:
            self.preprocess(atoms)

    def calculate(
        self,
        atoms=None,
        properties=None,
        system_changes=ase.calculators.calculator.all_changes,
    ):
        """Compute the requested properties and store them in `self.results`.

        Args:
            atoms: Atoms to compute; handled by the ASE base class.
            properties: list of requested property names; defaults to
                ["energy"]. Requesting "stress" also computes forces.
            system_changes: what changed since the last calculation; used to
                update the cached model inputs incrementally.

        All energy-derived quantities are multiplied by `self.energy_conv`;
        stress is additionally divided by the cell volume and returned in
        Voigt (6-component) form.

        Raises:
            PropertyNotImplementedError: if "stress" is requested for a
                non-periodic system (the cell volume would be zero).
        """
        if properties is None:
            properties = ["energy"]
        super().calculate(atoms, properties, system_changes)
        inputs = self.preprocess(self.atoms, system_changes=system_changes)

        results = {}
        f = None
        if "stress" in properties:
            # Partial PBC was already rejected by `preprocess`; reject the
            # non-periodic case here instead of dividing by a zero volume.
            if not np.all(self.atoms.get_pbc()):
                raise ase.calculators.calculator.PropertyNotImplementedError(
                    "stress is only available for fully periodic systems"
                )
            e, f, virial, output = self.model._energy_and_forces_and_virial(
                self.model.variables, inputs
            )
            volume = self.atoms.get_volume()
            # NOTE(review): the sign convention of `virial` is defined by the
            # model; confirm it matches ASE's stress = (1/V) dE/d(strain).
            stress = np.asarray(virial[0]) * self.energy_conv / volume
            results["stress"] = full_3x3_to_voigt_6_stress(stress)
        elif "forces" in properties:
            e, f, output = self.model._energy_and_forces(self.model.variables, inputs)
        else:
            e, output = self.model._total_energy(self.model.variables, inputs)

        results["energy"] = float(e[0]) * self.energy_conv
        if f is not None:
            forces = np.asarray(f) * self.energy_conv
            if self.model.use_atom_padding:
                # Keep only the rows of real atoms; padding atoms are dropped.
                forces = forces[np.asarray(output["true_atoms"])]
            results["forces"] = forces

        self.results.update(results)

    @staticmethod
    def _uses_pbc(atoms):
        """Return True if `atoms` is periodic in all directions, False if in none.

        Raises:
            NotImplementedError: if PBC is enabled in only some directions.
        """
        pbc = np.asarray(atoms.get_pbc(), dtype=bool)
        if np.all(pbc):
            return True
        if np.any(pbc):
            raise NotImplementedError("PBC should be activated in all directions.")
        return False

    def _cell_inputs(self, atoms):
        """Return cell and reciprocal cell of `atoms` as (1, 3, 3) numpy arrays."""
        cell = np.asarray(atoms.get_cell(complete=True).array, dtype=self.dtype)
        return {
            "cells": cell.reshape(1, 3, 3),
            "reciprocal_cells": np.linalg.inv(cell).reshape(1, 3, 3),
        }

    def preprocess(self, atoms, system_changes=ase.calculators.calculator.all_changes):
        """Build, or incrementally update, the FENNIX inputs for `atoms`.

        On the first call every raw input is built from scratch: species,
        coordinates, natoms, batch_index and, for fully periodic systems,
        cells/reciprocal_cells. Later calls only refresh the entries affected
        by `system_changes`.

        The raw inputs are then preprocessed in one of two ways.
        - Full preprocessing (`model.preprocess`) runs on the first call,
          when atomic numbers change, when periodicity is switched on or
          off, or when `gpu_preprocessing` is disabled.
        - Otherwise atom padding + `process` are re-run on top of the
          previous inputs, and the neighbor list is reallocated on overflow.

        Returns:
            dict: the preprocessed model inputs (also cached in
            `self._fennol_inputs`).

        Raises:
            NotImplementedError: if PBC is enabled in only some directions.
        """
        force_cpu_preprocessing = False
        if self._raw_inputs is None:
            force_cpu_preprocessing = True
            species = np.asarray(atoms.get_atomic_numbers(), dtype=np.int32)
            nat = len(species)
            inputs = {
                "species": species,
                "coordinates": np.asarray(atoms.get_positions(), dtype=self.dtype),
                "natoms": np.array([nat], dtype=np.int32),
                "batch_index": np.zeros(nat, dtype=np.int32),
            }
            if self._uses_pbc(atoms):
                inputs.update(self._cell_inputs(atoms))
            self._raw_inputs = convert_to_jax(inputs)
        else:
            # "pbc" is reported separately from "cell" by ASE; both can flip
            # the periodic/non-periodic state.
            if "cell" in system_changes or "pbc" in system_changes:
                had_cells = "cells" in self._raw_inputs
                if self._uses_pbc(atoms):
                    for key, value in self._cell_inputs(atoms).items():
                        self._raw_inputs[key] = jnp.asarray(value)
                elif had_cells:
                    del self._raw_inputs["cells"]
                    del self._raw_inputs["reciprocal_cells"]
                # The GPU path merges with the previous inputs, which would
                # resurrect stale cells (or miss the periodic setup), so
                # redo the full preprocessing when periodicity toggles.
                if ("cells" in self._raw_inputs) != had_cells:
                    force_cpu_preprocessing = True
            if "numbers" in system_changes:
                species = jnp.asarray(atoms.get_atomic_numbers(), dtype=jnp.int32)
                nat = len(species)
                self._raw_inputs["species"] = species
                self._raw_inputs["natoms"] = jnp.array([nat], dtype=np.int32)
                self._raw_inputs["batch_index"] = jnp.zeros(nat, dtype=np.int32)
                force_cpu_preprocessing = True
            if "positions" in system_changes:
                self._raw_inputs["coordinates"] = jnp.asarray(
                    atoms.get_positions(), dtype=self.dtype
                )

        if self.gpu_preprocessing and not force_cpu_preprocessing:
            _, inputs = self.model.preprocessing.atom_padding(
                self.model.preproc_state, self._raw_inputs
            )
            # Reuse previously computed entries not produced by atom_padding.
            inputs = {**self._fennol_inputs, **inputs}

            inputs = self.model.preprocessing.process(self.model.preproc_state, inputs)
            self.model.preproc_state, state_up, inputs, overflow = (
                self.model.preprocessing.check_reallocate(
                    self.model.preproc_state, inputs
                )
            )
            self._fennol_inputs = inputs
            if self.verbose and overflow:
                print("FENNIX nblist overflow => reallocating nblist")
                print("  size updates:", state_up)
        else:
            self._fennol_inputs = self.model.preprocess(**self._raw_inputs)

        return self._fennol_inputs
class FENNIXCalculator(ase.calculators.calculator.Calculator):
    """FENNIX calculator for ASE.

    Arguments:
    ----------
    model: str or FENNIX
        The path to the model or the model itself.
    gpu_preprocessing: bool, default=False
        Whether to preprocess the data on the GPU or not. This is useful for
        large systems, but may not be necessary (or even slower) for small
        systems.
    atoms: ase.Atoms, default=None
        The atoms object to be used for the calculation. If provided, the
        calculator will be initialized with the atoms object.
    verbose: bool, default=False
        Whether to print nblist update information or not.
    energy_terms: list of str, default=None
        The energy terms to include in the total energy. If None, this will
        default to the energy terms defined in the model.
    use_float64: bool, default=False
        Whether to run in double precision. This enables JAX's *global*
        ``jax_enable_x64`` flag and feeds coordinates/cells to the model as
        float64 instead of float32.
    matmul_prec: str, default=None
        One of "default", "high", "highest". When given, sets JAX's *global*
        ``jax_default_matmul_precision`` option.
    **kwargs
        Forwarded to ``FENNIX.load`` when ``model`` is a path; ignored when a
        FENNIX instance is passed.

    Atom padding (``use_atom_padding``) is not an explicit argument of this
    constructor: the calculator reads it from ``model.use_atom_padding``.
    Presumably it is set through ``**kwargs`` when loading from a path
    (verify against ``FENNIX.load``). Atom padding is useful to prevent
    recompiling the model if the number of atoms changes; if the number of
    atoms is expected to be fixed, it is recommended to disable it.
    """

    implemented_properties = ["energy", "forces", "stress"]

    def __init__(
        self,
        model: Union[str, FENNIX],
        gpu_preprocessing: bool = False,
        atoms: Optional[ase.Atoms] = None,
        verbose: bool = False,
        energy_terms: Optional[Sequence[str]] = None,
        use_float64: bool = False,
        matmul_prec: Optional[str] = None,
        **kwargs
    ):
        """Load or attach the model, configure JAX and optionally preprocess `atoms`.

        See the class docstring for the meaning of each argument.

        Raises:
            ValueError: if `matmul_prec` is given and is not one of
                "default", "high", "highest".
        """
        super().__init__()
        # Validate before touching global JAX state. An explicit raise is used
        # instead of `assert`, which is stripped when running `python -O`.
        if matmul_prec is not None and matmul_prec not in (
            "default",
            "high",
            "highest",
        ):
            raise ValueError(
                "matmul_prec should be one of 'default', 'high', 'highest'"
            )
        # Both options below change process-wide JAX configuration.
        if use_float64:
            jax.config.update("jax_enable_x64", True)
        if matmul_prec is not None:
            jax.config.update("jax_default_matmul_precision", matmul_prec)

        if isinstance(model, FENNIX):
            # kwargs are only consumed by FENNIX.load; they are ignored here.
            self.model = model
        else:
            self.model = FENNIX.load(model, **kwargs)
        if energy_terms is not None:
            # Presumably updates the model in place (also a caller-provided one).
            self.model.set_energy_terms(energy_terms)
        self.dtype = "float64" if use_float64 else "float32"
        self.gpu_preprocessing = gpu_preprocessing
        self.verbose = verbose
        # Caches reused by `preprocess` across calls.
        self._fennol_inputs = None  # last fully preprocessed model inputs
        self._raw_inputs = None  # raw jax arrays built from the Atoms object

        # Multiply model energies by this factor to obtain ASE energies (eV).
        # Presumably `get_multiplier` returns "model units per Hartree".
        model_unit = au.get_multiplier(self.model.energy_unit)
        self.energy_conv = ase.units.Hartree / model_unit
        if atoms is not None:
            self.preprocess(atoms)

    def calculate(
        self,
        atoms=None,
        properties=None,
        system_changes=ase.calculators.calculator.all_changes,
    ):
        """Compute the requested properties and store them in `self.results`.

        Args:
            atoms: Atoms to compute; handled by the ASE base class.
            properties: list of requested property names; defaults to
                ["energy"]. Requesting "stress" also computes forces.
            system_changes: what changed since the last calculation; used to
                update the cached model inputs incrementally.

        All energy-derived quantities are multiplied by `self.energy_conv`;
        stress is additionally divided by the cell volume and returned in
        Voigt (6-component) form.

        Raises:
            PropertyNotImplementedError: if "stress" is requested for a
                non-periodic system (the cell volume would be zero).
        """
        if properties is None:
            properties = ["energy"]
        super().calculate(atoms, properties, system_changes)
        inputs = self.preprocess(self.atoms, system_changes=system_changes)

        results = {}
        f = None
        if "stress" in properties:
            # Partial PBC was already rejected by `preprocess`; reject the
            # non-periodic case here instead of dividing by a zero volume.
            if not np.all(self.atoms.get_pbc()):
                raise ase.calculators.calculator.PropertyNotImplementedError(
                    "stress is only available for fully periodic systems"
                )
            e, f, virial, output = self.model._energy_and_forces_and_virial(
                self.model.variables, inputs
            )
            volume = self.atoms.get_volume()
            # NOTE(review): the sign convention of `virial` is defined by the
            # model; confirm it matches ASE's stress = (1/V) dE/d(strain).
            stress = np.asarray(virial[0]) * self.energy_conv / volume
            results["stress"] = full_3x3_to_voigt_6_stress(stress)
        elif "forces" in properties:
            e, f, output = self.model._energy_and_forces(self.model.variables, inputs)
        else:
            e, output = self.model._total_energy(self.model.variables, inputs)

        results["energy"] = float(e[0]) * self.energy_conv
        if f is not None:
            forces = np.asarray(f) * self.energy_conv
            if self.model.use_atom_padding:
                # Keep only the rows of real atoms; padding atoms are dropped.
                forces = forces[np.asarray(output["true_atoms"])]
            results["forces"] = forces

        self.results.update(results)

    @staticmethod
    def _uses_pbc(atoms):
        """Return True if `atoms` is periodic in all directions, False if in none.

        Raises:
            NotImplementedError: if PBC is enabled in only some directions.
        """
        pbc = np.asarray(atoms.get_pbc(), dtype=bool)
        if np.all(pbc):
            return True
        if np.any(pbc):
            raise NotImplementedError("PBC should be activated in all directions.")
        return False

    def _cell_inputs(self, atoms):
        """Return cell and reciprocal cell of `atoms` as (1, 3, 3) numpy arrays."""
        cell = np.asarray(atoms.get_cell(complete=True).array, dtype=self.dtype)
        return {
            "cells": cell.reshape(1, 3, 3),
            "reciprocal_cells": np.linalg.inv(cell).reshape(1, 3, 3),
        }

    def preprocess(self, atoms, system_changes=ase.calculators.calculator.all_changes):
        """Build, or incrementally update, the FENNIX inputs for `atoms`.

        On the first call every raw input is built from scratch: species,
        coordinates, natoms, batch_index and, for fully periodic systems,
        cells/reciprocal_cells. Later calls only refresh the entries affected
        by `system_changes`.

        The raw inputs are then preprocessed in one of two ways.
        - Full preprocessing (`model.preprocess`) runs on the first call,
          when atomic numbers change, when periodicity is switched on or
          off, or when `gpu_preprocessing` is disabled.
        - Otherwise atom padding + `process` are re-run on top of the
          previous inputs, and the neighbor list is reallocated on overflow.

        Returns:
            dict: the preprocessed model inputs (also cached in
            `self._fennol_inputs`).

        Raises:
            NotImplementedError: if PBC is enabled in only some directions.
        """
        force_cpu_preprocessing = False
        if self._raw_inputs is None:
            force_cpu_preprocessing = True
            species = np.asarray(atoms.get_atomic_numbers(), dtype=np.int32)
            nat = len(species)
            inputs = {
                "species": species,
                "coordinates": np.asarray(atoms.get_positions(), dtype=self.dtype),
                "natoms": np.array([nat], dtype=np.int32),
                "batch_index": np.zeros(nat, dtype=np.int32),
            }
            if self._uses_pbc(atoms):
                inputs.update(self._cell_inputs(atoms))
            self._raw_inputs = convert_to_jax(inputs)
        else:
            # "pbc" is reported separately from "cell" by ASE; both can flip
            # the periodic/non-periodic state.
            if "cell" in system_changes or "pbc" in system_changes:
                had_cells = "cells" in self._raw_inputs
                if self._uses_pbc(atoms):
                    for key, value in self._cell_inputs(atoms).items():
                        self._raw_inputs[key] = jnp.asarray(value)
                elif had_cells:
                    del self._raw_inputs["cells"]
                    del self._raw_inputs["reciprocal_cells"]
                # The GPU path merges with the previous inputs, which would
                # resurrect stale cells (or miss the periodic setup), so
                # redo the full preprocessing when periodicity toggles.
                if ("cells" in self._raw_inputs) != had_cells:
                    force_cpu_preprocessing = True
            if "numbers" in system_changes:
                species = jnp.asarray(atoms.get_atomic_numbers(), dtype=jnp.int32)
                nat = len(species)
                self._raw_inputs["species"] = species
                self._raw_inputs["natoms"] = jnp.array([nat], dtype=np.int32)
                self._raw_inputs["batch_index"] = jnp.zeros(nat, dtype=np.int32)
                force_cpu_preprocessing = True
            if "positions" in system_changes:
                self._raw_inputs["coordinates"] = jnp.asarray(
                    atoms.get_positions(), dtype=self.dtype
                )

        if self.gpu_preprocessing and not force_cpu_preprocessing:
            _, inputs = self.model.preprocessing.atom_padding(
                self.model.preproc_state, self._raw_inputs
            )
            # Reuse previously computed entries not produced by atom_padding.
            inputs = {**self._fennol_inputs, **inputs}

            inputs = self.model.preprocessing.process(self.model.preproc_state, inputs)
            self.model.preproc_state, state_up, inputs, overflow = (
                self.model.preprocessing.check_reallocate(
                    self.model.preproc_state, inputs
                )
            )
            self._fennol_inputs = inputs
            if self.verbose and overflow:
                print("FENNIX nblist overflow => reallocating nblist")
                print("  size updates:", state_up)
        else:
            self._fennol_inputs = self.model.preprocess(**self._raw_inputs)

        return self._fennol_inputs
FENNIX calculator for ASE.
Arguments:
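model: str or FENNIX — The path to the model or the model itself. use_atom_padding: bool, default=False — Whether to use atom padding or not. Atom padding is useful to prevent recompiling the model if the number of atoms changes. If the number of atoms is expected to be fixed, it is recommended to set this to False. Note that this is not an explicit constructor argument. The calculator reads it from model.use_atom_padding; presumably it is supplied as a keyword argument forwarded through **kwargs to FENNIX.load when model is a path. gpu_preprocessing: bool, default=False — Whether to preprocess the data on the GPU or not. This is useful for large systems, but may not be necessary (or even slower) for small systems. atoms: ase.Atoms, default=None — The atoms object to be used for the calculation. If provided, the calculator will be initialized with the atoms object. verbose: bool, default=False — Whether to print nblist update information or not. energy_terms: list of str, default=None — The energy terms to include in the total energy. If None, this will default to the energy terms defined in the model. use_float64: bool, default=False — Whether to run in double precision. This enables JAX's global jax_enable_x64 flag and passes coordinates and cells to the model as float64 instead of float32. matmul_prec: str, default=None — One of 'default', 'high' or 'highest'. When given, it sets JAX's global jax_default_matmul_precision option. **kwargs — Extra keyword arguments forwarded to FENNIX.load when model is a path (ignored when a FENNIX instance is passed).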
35 def __init__( 36 self, 37 model: Union[str, FENNIX], 38 gpu_preprocessing: bool = False, 39 atoms: Optional[ase.Atoms] = None, 40 verbose: bool = False, 41 energy_terms: Optional[Sequence[str]] = None, 42 use_float64: bool = False, 43 matmul_prec: Optional[str] = None, 44 **kwargs 45 ): 46 super().__init__() 47 if use_float64: 48 jax.config.update("jax_enable_x64", True) 49 if matmul_prec is not None: 50 assert matmul_prec in [ 51 "default", 52 "high", 53 "highest", 54 ], "matmul_prec should be one of 'default', 'high', 'highest'" 55 jax.config.update("jax_default_matmul_precision", matmul_prec) 56 57 if isinstance(model, FENNIX): 58 self.model = model 59 else: 60 self.model = FENNIX.load(model, **kwargs) 61 if energy_terms is not None: 62 self.model.set_energy_terms(energy_terms) 63 self.dtype = "float64" if use_float64 else "float32" 64 self.gpu_preprocessing = gpu_preprocessing 65 self.verbose = verbose 66 self._fennol_inputs = None 67 self._raw_inputs = None 68 69 model_unit = au.get_multiplier(self.model.energy_unit) 70 self.energy_conv = ase.units.Hartree / model_unit 71 if atoms is not None: 72 self.preprocess(atoms)
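Basic calculator implementation. (This docstring is inherited from ASE's Calculator.__init__. FENNIXCalculator.__init__ calls super().__init__() without arguments, so the restart, ignore_bad_restart_file, directory and label parameters listed below are not forwarded to the base class. Its own atoms argument is only used to precompute the model inputs.)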
restart: str — Prefix for restart file. May contain a directory. Default is None: don't restart. ignore_bad_restart_file: bool — Ignore a broken or missing restart file (by default, it is an error if the restart file is missing or broken). Deprecated, please do not use: passing more than one positional argument to Calculator() is deprecated and will stop working in the future. directory: str or PurePath — Working directory in which to read and write files and perform calculations. label: str — Name used for all files. Not supported by all calculators. May contain a directory, but please use the directory parameter for that instead. atoms: Atoms object — Optional Atoms object to which the calculator will be attached. When restarting, atoms will get its positions and unit-cell updated from file.
Properties calculator can handle (energy, forces, ...)
74 def calculate( 75 self, 76 atoms=None, 77 properties=["energy"], 78 system_changes=ase.calculators.calculator.all_changes, 79 ): 80 super().calculate(atoms, properties, system_changes) 81 inputs = self.preprocess(self.atoms, system_changes=system_changes) 82 83 results = {} 84 if "stress" in properties: 85 e, f, virial, output = self.model._energy_and_forces_and_virial( 86 self.model.variables, inputs 87 ) 88 volume = self.atoms.get_volume() 89 stress = np.asarray(virial[0]) * self.energy_conv / volume 90 results["stress"] = full_3x3_to_voigt_6_stress(stress) 91 results["forces"] = np.asarray(f) * self.energy_conv 92 elif "forces" in properties: 93 e, f, output = self.model._energy_and_forces(self.model.variables, inputs) 94 results["forces"] = np.asarray(f) * self.energy_conv 95 else: 96 e, output = self.model._total_energy(self.model.variables, inputs) 97 98 results["energy"] = float(e[0]) * self.energy_conv 99 if self.model.use_atom_padding and "forces" in results: 100 mask = np.asarray(output["true_atoms"]) 101 results["forces"] = results["forces"][mask] 102 103 self.results.update(results)
Do the calculation.
properties: list of str — List of what needs to be calculated. Can be any combination of 'energy', 'forces', 'stress', 'dipole', 'charges', 'magmom' and 'magmoms' (FENNIXCalculator only implements 'energy', 'forces' and 'stress'; see implemented_properties). system_changes: list of str — List of what has changed since last calculation. Can be any combination of these six: 'positions', 'numbers', 'cell', 'pbc', 'initial_charges' and 'initial_magmoms'.
Subclasses need to implement this, but can ignore properties and system_changes if they want. Calculated properties should be inserted into results dictionary like shown in this dummy example::
self.results = {'energy': 0.0,
'forces': np.zeros((len(atoms), 3)),
'stress': np.zeros(6),
'dipole': np.zeros(3),
'charges': np.zeros(len(atoms)),
'magmom': 0.0,
'magmoms': np.zeros(len(atoms))}
The subclass implementation should first call this implementation to set the atoms attribute and create any missing directories.
@staticmethod
def _uses_pbc(atoms):
    """Return True if `atoms` is periodic in all directions, False if in none.

    Raises:
        NotImplementedError: if PBC is enabled in only some directions.
    """
    pbc = np.asarray(atoms.get_pbc(), dtype=bool)
    if np.all(pbc):
        return True
    if np.any(pbc):
        raise NotImplementedError("PBC should be activated in all directions.")
    return False

def _cell_inputs(self, atoms):
    """Return cell and reciprocal cell of `atoms` as (1, 3, 3) numpy arrays."""
    cell = np.asarray(atoms.get_cell(complete=True).array, dtype=self.dtype)
    return {
        "cells": cell.reshape(1, 3, 3),
        "reciprocal_cells": np.linalg.inv(cell).reshape(1, 3, 3),
    }

def preprocess(self, atoms, system_changes=ase.calculators.calculator.all_changes):
    """Build, or incrementally update, the FENNIX inputs for `atoms`.

    On the first call every raw input is built from scratch: species,
    coordinates, natoms, batch_index and, for fully periodic systems,
    cells/reciprocal_cells. Later calls only refresh the entries affected
    by `system_changes`.

    The raw inputs are then preprocessed in one of two ways.
    - Full preprocessing (`model.preprocess`) runs on the first call, when
      atomic numbers change, when periodicity is switched on or off, or
      when `gpu_preprocessing` is disabled.
    - Otherwise atom padding + `process` are re-run on top of the previous
      inputs, and the neighbor list is reallocated on overflow.

    Returns:
        dict: the preprocessed model inputs (also cached in
        `self._fennol_inputs`).

    Raises:
        NotImplementedError: if PBC is enabled in only some directions.
    """
    force_cpu_preprocessing = False
    if self._raw_inputs is None:
        force_cpu_preprocessing = True
        species = np.asarray(atoms.get_atomic_numbers(), dtype=np.int32)
        nat = len(species)
        inputs = {
            "species": species,
            "coordinates": np.asarray(atoms.get_positions(), dtype=self.dtype),
            "natoms": np.array([nat], dtype=np.int32),
            "batch_index": np.zeros(nat, dtype=np.int32),
        }
        if self._uses_pbc(atoms):
            inputs.update(self._cell_inputs(atoms))
        self._raw_inputs = convert_to_jax(inputs)
    else:
        # "pbc" is reported separately from "cell" by ASE; both can flip the
        # periodic/non-periodic state.
        if "cell" in system_changes or "pbc" in system_changes:
            had_cells = "cells" in self._raw_inputs
            if self._uses_pbc(atoms):
                for key, value in self._cell_inputs(atoms).items():
                    self._raw_inputs[key] = jnp.asarray(value)
            elif had_cells:
                del self._raw_inputs["cells"]
                del self._raw_inputs["reciprocal_cells"]
            # The GPU path merges with the previous inputs, which would
            # resurrect stale cells (or miss the periodic setup), so redo
            # the full preprocessing when periodicity toggles.
            if ("cells" in self._raw_inputs) != had_cells:
                force_cpu_preprocessing = True
        if "numbers" in system_changes:
            species = jnp.asarray(atoms.get_atomic_numbers(), dtype=jnp.int32)
            nat = len(species)
            self._raw_inputs["species"] = species
            self._raw_inputs["natoms"] = jnp.array([nat], dtype=np.int32)
            self._raw_inputs["batch_index"] = jnp.zeros(nat, dtype=np.int32)
            force_cpu_preprocessing = True
        if "positions" in system_changes:
            self._raw_inputs["coordinates"] = jnp.asarray(
                atoms.get_positions(), dtype=self.dtype
            )

    if self.gpu_preprocessing and not force_cpu_preprocessing:
        _, inputs = self.model.preprocessing.atom_padding(
            self.model.preproc_state, self._raw_inputs
        )
        # Reuse previously computed entries not produced by atom_padding.
        inputs = {**self._fennol_inputs, **inputs}

        inputs = self.model.preprocessing.process(self.model.preproc_state, inputs)
        self.model.preproc_state, state_up, inputs, overflow = (
            self.model.preprocessing.check_reallocate(
                self.model.preproc_state, inputs
            )
        )
        self._fennol_inputs = inputs
        if self.verbose and overflow:
            print("FENNIX nblist overflow => reallocating nblist")
            print("  size updates:", state_up)
    else:
        self._fennol_inputs = self.model.preprocess(**self._raw_inputs)

    return self._fennol_inputs