fennol.ase
import ase
import ase.calculators.calculator
import ase.units
import numpy as np
import jax.numpy as jnp
from . import FENNIX
from .models.preprocessing import convert_to_jax
from typing import Sequence, Union, Optional
from ase.stress import full_3x3_to_voigt_6_stress
import jax
from .utils import AtomicUnits as au


class FENNIXCalculator(ase.calculators.calculator.Calculator):
    """FENNIX calculator for ASE.

    Arguments:
    ----------
    model: str or FENNIX
        The path to the model or the model itself.
    gpu_preprocessing: bool, default=False
        Whether to preprocess the data on the GPU or not. This is useful for large
        systems, but may not be necessary (and may even be slower) for small
        systems. The first call, and any call where the atomic numbers or the
        periodicity change, always uses the regular preprocessing.
    atoms: ase.Atoms, default=None
        If provided, the model inputs are preprocessed from these atoms at
        construction time.
    verbose: bool, default=False
        Whether to print nblist update information or not.
    energy_terms: list of str, default=None
        The energy terms to include in the total energy. If None, this will default
        to the energy terms defined in the model.
    use_float64: bool, default=False
        If True, enable JAX double precision (``jax_enable_x64``, a process-wide
        setting) and build coordinates/cells as float64 instead of float32.
    matmul_prec: str, default=None
        If given, one of 'default', 'high' or 'highest'; written to the
        process-wide ``jax_default_matmul_precision`` setting.
    save_raw_output: bool, default=False
        If True, the raw model output is stored in ``results["raw_output"]``.
    **kwargs
        Forwarded to ``FENNIX.load`` when ``model`` is a path (ignored when a
        FENNIX instance is passed). Atom padding (``use_atom_padding``: avoids
        recompiling the model when the number of atoms changes; if the number of
        atoms is fixed it is recommended to leave it False) is presumably set
        this way -- confirm against ``FENNIX.load``.
    """

    implemented_properties = ["energy", "forces", "stress"]

    def __init__(
        self,
        model: Union[str, FENNIX],
        gpu_preprocessing: bool = False,
        atoms: Optional[ase.Atoms] = None,
        verbose: bool = False,
        energy_terms: Optional[Sequence[str]] = None,
        use_float64: bool = False,
        matmul_prec: Optional[str] = None,
        save_raw_output: bool = False,
        **kwargs
    ):
        super().__init__()
        # Process-wide JAX settings: they affect every JAX computation.
        if use_float64:
            jax.config.update("jax_enable_x64", True)
        if matmul_prec is not None:
            assert matmul_prec in [
                "default",
                "high",
                "highest",
            ], "matmul_prec should be one of 'default', 'high', 'highest'"
            jax.config.update("jax_default_matmul_precision", matmul_prec)

        if isinstance(model, FENNIX):
            self.model = model
        else:
            self.model = FENNIX.load(model, **kwargs)
        if energy_terms is not None:
            self.model.set_energy_terms(energy_terms)
        self.dtype = "float64" if use_float64 else "float32"
        self.gpu_preprocessing = gpu_preprocessing
        self.verbose = verbose
        # Caches filled by preprocess(): raw arrays and model-ready inputs.
        self._fennol_inputs = None
        self._raw_inputs = None
        self.save_raw_output = save_raw_output

        # Factor applied to model energies: ase.units.Hartree divided by the
        # multiplier of the model energy unit (presumably model unit -> eV).
        model_unit = au.get_multiplier(self.model.energy_unit)
        self.energy_conv = ase.units.Hartree / model_unit
        if atoms is not None:
            self.preprocess(atoms)

    def calculate(
        self,
        atoms=None,
        properties=("energy",),
        system_changes=ase.calculators.calculator.all_changes,
    ):
        """Run the model and store energy (and forces/stress if requested).

        Parameters
        ----------
        atoms: ase.Atoms, optional
            Passed to the base ``Calculator.calculate``; the atoms used
            afterwards are read from ``self.atoms``.
        properties: sequence of str
            Requesting "stress" computes energy, forces and stress; "forces"
            computes energy and forces; otherwise only the energy is computed.
        system_changes: sequence of str
            What changed since the last call; forwarded to ``preprocess`` so
            only the affected inputs are refreshed.

        Energies and forces are multiplied by ``self.energy_conv``. The stress
        is the model virial times ``self.energy_conv`` divided by the cell
        volume, converted to Voigt notation.
        """
        super().calculate(atoms, properties, system_changes)
        inputs = self.preprocess(self.atoms, system_changes=system_changes)

        # Initial charges are floats: round to the nearest integer instead of
        # truncating, so that e.g. a sum of 0.9999999999999999 gives 1, not 0.
        total_charge = self.atoms.get_initial_charges().sum()
        inputs["total_charge"] = int(round(float(total_charge)))

        results = {}
        if "stress" in properties:
            e, f, virial, output = self.model._energy_and_forces_and_virial(
                self.model.variables, inputs
            )
            volume = self.atoms.get_volume()
            stress = np.asarray(virial[0]) * self.energy_conv / volume
            results["stress"] = full_3x3_to_voigt_6_stress(stress)
            results["forces"] = np.asarray(f) * self.energy_conv
        elif "forces" in properties:
            e, f, output = self.model._energy_and_forces(self.model.variables, inputs)
            results["forces"] = np.asarray(f) * self.energy_conv
        else:
            e, output = self.model._total_energy(self.model.variables, inputs)

        results["energy"] = float(e[0]) * self.energy_conv

        # With atom padding, keep only the force rows selected by
        # output["true_atoms"] (presumably the non-padding atoms).
        if self.model.use_atom_padding and "forces" in results:
            mask = np.asarray(output["true_atoms"])
            results["forces"] = results["forces"][mask]

        if self.save_raw_output:
            results["raw_output"] = output

        self.results.update(results)

    def preprocess(self, atoms, system_changes=ase.calculators.calculator.all_changes):
        """Build (or incrementally update) the model inputs for ``atoms``.

        On the first call the raw inputs (species, coordinates, natoms,
        batch_index and, for fully periodic systems, cells/reciprocal_cells) are
        built from scratch. On later calls only the entries affected by
        ``system_changes`` are refreshed ("cell"/"pbc", "numbers", "positions").

        The raw inputs are converted to model inputs either with
        ``self.model.preprocess`` or, when ``gpu_preprocessing`` is enabled and
        neither the atomic numbers nor the periodicity changed, with the
        padding/process/check_reallocate sequence that reuses previous inputs.

        Returns the model inputs, also cached in ``self._fennol_inputs``.

        Raises NotImplementedError if PBC are enabled in some but not all
        directions.
        """

        def _is_periodic():
            # Only fully periodic or fully non-periodic systems are supported.
            pbc = np.asarray(atoms.get_pbc(), dtype=bool)
            if np.all(pbc):
                return True
            if np.any(pbc):
                raise NotImplementedError("PBC should be activated in all directions.")
            return False

        def _cell_arrays():
            # Cell matrix and its inverse, each reshaped to (1, 3, 3).
            cell = np.asarray(atoms.get_cell(complete=True).array, dtype=self.dtype)
            reciprocal_cell = np.linalg.inv(cell)
            return cell.reshape(1, 3, 3), reciprocal_cell.reshape(1, 3, 3)

        # The GPU path merges into the previous model inputs, so it needs them.
        force_cpu_preprocessing = self._fennol_inputs is None
        if self._raw_inputs is None:
            force_cpu_preprocessing = True
            use_pbc = _is_periodic()
            species = np.asarray(atoms.get_atomic_numbers(), dtype=np.int32)
            inputs = {
                "species": species,
                "coordinates": np.asarray(atoms.get_positions(), dtype=self.dtype),
                "natoms": np.array([len(species)], dtype=np.int32),
                "batch_index": np.zeros(len(species), dtype=np.int32),
            }
            if use_pbc:
                inputs["cells"], inputs["reciprocal_cells"] = _cell_arrays()
            self._raw_inputs = convert_to_jax(inputs)
        else:
            # ASE reports "pbc" separately from "cell": toggling periodicity
            # without touching the cell must also update the cell inputs.
            if "cell" in system_changes or "pbc" in system_changes:
                had_pbc = "cells" in self._raw_inputs
                use_pbc = _is_periodic()
                if use_pbc:
                    cells, reciprocal_cells = _cell_arrays()
                    self._raw_inputs["cells"] = jnp.asarray(cells)
                    self._raw_inputs["reciprocal_cells"] = jnp.asarray(reciprocal_cells)
                elif had_pbc:
                    del self._raw_inputs["cells"]
                    del self._raw_inputs["reciprocal_cells"]
                if use_pbc != had_pbc:
                    # The GPU path would merge the previous inputs, which could
                    # keep stale cell entries: redo a full preprocessing.
                    force_cpu_preprocessing = True
            if "numbers" in system_changes:
                species = jnp.asarray(atoms.get_atomic_numbers(), dtype=jnp.int32)
                self._raw_inputs["species"] = species
                self._raw_inputs["natoms"] = jnp.array([len(species)], dtype=np.int32)
                self._raw_inputs["batch_index"] = jnp.zeros(len(species), dtype=np.int32)
                # Atomic numbers (and possibly the atom count) changed.
                force_cpu_preprocessing = True
            if "positions" in system_changes:
                self._raw_inputs["coordinates"] = jnp.asarray(
                    atoms.get_positions(), dtype=self.dtype
                )

        if self.gpu_preprocessing and not force_cpu_preprocessing:
            _, inputs = self.model.preprocessing.atom_padding(
                self.model.preproc_state, self._raw_inputs
            )
            # Keep every previous entry that was not recomputed above.
            inputs = {**self._fennol_inputs, **inputs}

            inputs = self.model.preprocessing.process(self.model.preproc_state, inputs)
            self.model.preproc_state, state_up, inputs, overflow = (
                self.model.preprocessing.check_reallocate(
                    self.model.preproc_state, inputs
                )
            )
            self._fennol_inputs = inputs
            if self.verbose and overflow:
                print("FENNIX nblist overflow => reallocating nblist")
                print(" size updates:", state_up)
        else:
            self._fennol_inputs = self.model.preprocess(**self._raw_inputs)

        return self._fennol_inputs
class FENNIXCalculator(ase.calculators.calculator.Calculator):
    """FENNIX calculator for ASE.

    Arguments:
    ----------
    model: str or FENNIX
        The path to the model or the model itself.
    gpu_preprocessing: bool, default=False
        Whether to preprocess the data on the GPU or not. This is useful for large
        systems, but may not be necessary (and may even be slower) for small
        systems. The first call, and any call where the atomic numbers or the
        periodicity change, always uses the regular preprocessing.
    atoms: ase.Atoms, default=None
        If provided, the model inputs are preprocessed from these atoms at
        construction time.
    verbose: bool, default=False
        Whether to print nblist update information or not.
    energy_terms: list of str, default=None
        The energy terms to include in the total energy. If None, this will default
        to the energy terms defined in the model.
    use_float64: bool, default=False
        If True, enable JAX double precision (``jax_enable_x64``, a process-wide
        setting) and build coordinates/cells as float64 instead of float32.
    matmul_prec: str, default=None
        If given, one of 'default', 'high' or 'highest'; written to the
        process-wide ``jax_default_matmul_precision`` setting.
    save_raw_output: bool, default=False
        If True, the raw model output is stored in ``results["raw_output"]``.
    **kwargs
        Forwarded to ``FENNIX.load`` when ``model`` is a path (ignored when a
        FENNIX instance is passed). Atom padding (``use_atom_padding``: avoids
        recompiling the model when the number of atoms changes; if the number of
        atoms is fixed it is recommended to leave it False) is presumably set
        this way -- confirm against ``FENNIX.load``.
    """

    implemented_properties = ["energy", "forces", "stress"]

    def __init__(
        self,
        model: Union[str, FENNIX],
        gpu_preprocessing: bool = False,
        atoms: Optional[ase.Atoms] = None,
        verbose: bool = False,
        energy_terms: Optional[Sequence[str]] = None,
        use_float64: bool = False,
        matmul_prec: Optional[str] = None,
        save_raw_output: bool = False,
        **kwargs
    ):
        super().__init__()
        # Process-wide JAX settings: they affect every JAX computation.
        if use_float64:
            jax.config.update("jax_enable_x64", True)
        if matmul_prec is not None:
            assert matmul_prec in [
                "default",
                "high",
                "highest",
            ], "matmul_prec should be one of 'default', 'high', 'highest'"
            jax.config.update("jax_default_matmul_precision", matmul_prec)

        if isinstance(model, FENNIX):
            self.model = model
        else:
            self.model = FENNIX.load(model, **kwargs)
        if energy_terms is not None:
            self.model.set_energy_terms(energy_terms)
        self.dtype = "float64" if use_float64 else "float32"
        self.gpu_preprocessing = gpu_preprocessing
        self.verbose = verbose
        # Caches filled by preprocess(): raw arrays and model-ready inputs.
        self._fennol_inputs = None
        self._raw_inputs = None
        self.save_raw_output = save_raw_output

        # Factor applied to model energies: ase.units.Hartree divided by the
        # multiplier of the model energy unit (presumably model unit -> eV).
        model_unit = au.get_multiplier(self.model.energy_unit)
        self.energy_conv = ase.units.Hartree / model_unit
        if atoms is not None:
            self.preprocess(atoms)

    def calculate(
        self,
        atoms=None,
        properties=("energy",),
        system_changes=ase.calculators.calculator.all_changes,
    ):
        """Run the model and store energy (and forces/stress if requested).

        Parameters
        ----------
        atoms: ase.Atoms, optional
            Passed to the base ``Calculator.calculate``; the atoms used
            afterwards are read from ``self.atoms``.
        properties: sequence of str
            Requesting "stress" computes energy, forces and stress; "forces"
            computes energy and forces; otherwise only the energy is computed.
        system_changes: sequence of str
            What changed since the last call; forwarded to ``preprocess`` so
            only the affected inputs are refreshed.

        Energies and forces are multiplied by ``self.energy_conv``. The stress
        is the model virial times ``self.energy_conv`` divided by the cell
        volume, converted to Voigt notation.
        """
        super().calculate(atoms, properties, system_changes)
        inputs = self.preprocess(self.atoms, system_changes=system_changes)

        # Initial charges are floats: round to the nearest integer instead of
        # truncating, so that e.g. a sum of 0.9999999999999999 gives 1, not 0.
        total_charge = self.atoms.get_initial_charges().sum()
        inputs["total_charge"] = int(round(float(total_charge)))

        results = {}
        if "stress" in properties:
            e, f, virial, output = self.model._energy_and_forces_and_virial(
                self.model.variables, inputs
            )
            volume = self.atoms.get_volume()
            stress = np.asarray(virial[0]) * self.energy_conv / volume
            results["stress"] = full_3x3_to_voigt_6_stress(stress)
            results["forces"] = np.asarray(f) * self.energy_conv
        elif "forces" in properties:
            e, f, output = self.model._energy_and_forces(self.model.variables, inputs)
            results["forces"] = np.asarray(f) * self.energy_conv
        else:
            e, output = self.model._total_energy(self.model.variables, inputs)

        results["energy"] = float(e[0]) * self.energy_conv

        # With atom padding, keep only the force rows selected by
        # output["true_atoms"] (presumably the non-padding atoms).
        if self.model.use_atom_padding and "forces" in results:
            mask = np.asarray(output["true_atoms"])
            results["forces"] = results["forces"][mask]

        if self.save_raw_output:
            results["raw_output"] = output

        self.results.update(results)

    def preprocess(self, atoms, system_changes=ase.calculators.calculator.all_changes):
        """Build (or incrementally update) the model inputs for ``atoms``.

        On the first call the raw inputs (species, coordinates, natoms,
        batch_index and, for fully periodic systems, cells/reciprocal_cells) are
        built from scratch. On later calls only the entries affected by
        ``system_changes`` are refreshed ("cell"/"pbc", "numbers", "positions").

        The raw inputs are converted to model inputs either with
        ``self.model.preprocess`` or, when ``gpu_preprocessing`` is enabled and
        neither the atomic numbers nor the periodicity changed, with the
        padding/process/check_reallocate sequence that reuses previous inputs.

        Returns the model inputs, also cached in ``self._fennol_inputs``.

        Raises NotImplementedError if PBC are enabled in some but not all
        directions.
        """

        def _is_periodic():
            # Only fully periodic or fully non-periodic systems are supported.
            pbc = np.asarray(atoms.get_pbc(), dtype=bool)
            if np.all(pbc):
                return True
            if np.any(pbc):
                raise NotImplementedError("PBC should be activated in all directions.")
            return False

        def _cell_arrays():
            # Cell matrix and its inverse, each reshaped to (1, 3, 3).
            cell = np.asarray(atoms.get_cell(complete=True).array, dtype=self.dtype)
            reciprocal_cell = np.linalg.inv(cell)
            return cell.reshape(1, 3, 3), reciprocal_cell.reshape(1, 3, 3)

        # The GPU path merges into the previous model inputs, so it needs them.
        force_cpu_preprocessing = self._fennol_inputs is None
        if self._raw_inputs is None:
            force_cpu_preprocessing = True
            use_pbc = _is_periodic()
            species = np.asarray(atoms.get_atomic_numbers(), dtype=np.int32)
            inputs = {
                "species": species,
                "coordinates": np.asarray(atoms.get_positions(), dtype=self.dtype),
                "natoms": np.array([len(species)], dtype=np.int32),
                "batch_index": np.zeros(len(species), dtype=np.int32),
            }
            if use_pbc:
                inputs["cells"], inputs["reciprocal_cells"] = _cell_arrays()
            self._raw_inputs = convert_to_jax(inputs)
        else:
            # ASE reports "pbc" separately from "cell": toggling periodicity
            # without touching the cell must also update the cell inputs.
            if "cell" in system_changes or "pbc" in system_changes:
                had_pbc = "cells" in self._raw_inputs
                use_pbc = _is_periodic()
                if use_pbc:
                    cells, reciprocal_cells = _cell_arrays()
                    self._raw_inputs["cells"] = jnp.asarray(cells)
                    self._raw_inputs["reciprocal_cells"] = jnp.asarray(reciprocal_cells)
                elif had_pbc:
                    del self._raw_inputs["cells"]
                    del self._raw_inputs["reciprocal_cells"]
                if use_pbc != had_pbc:
                    # The GPU path would merge the previous inputs, which could
                    # keep stale cell entries: redo a full preprocessing.
                    force_cpu_preprocessing = True
            if "numbers" in system_changes:
                species = jnp.asarray(atoms.get_atomic_numbers(), dtype=jnp.int32)
                self._raw_inputs["species"] = species
                self._raw_inputs["natoms"] = jnp.array([len(species)], dtype=np.int32)
                self._raw_inputs["batch_index"] = jnp.zeros(len(species), dtype=np.int32)
                # Atomic numbers (and possibly the atom count) changed.
                force_cpu_preprocessing = True
            if "positions" in system_changes:
                self._raw_inputs["coordinates"] = jnp.asarray(
                    atoms.get_positions(), dtype=self.dtype
                )

        if self.gpu_preprocessing and not force_cpu_preprocessing:
            _, inputs = self.model.preprocessing.atom_padding(
                self.model.preproc_state, self._raw_inputs
            )
            # Keep every previous entry that was not recomputed above.
            inputs = {**self._fennol_inputs, **inputs}

            inputs = self.model.preprocessing.process(self.model.preproc_state, inputs)
            self.model.preproc_state, state_up, inputs, overflow = (
                self.model.preprocessing.check_reallocate(
                    self.model.preproc_state, inputs
                )
            )
            self._fennol_inputs = inputs
            if self.verbose and overflow:
                print("FENNIX nblist overflow => reallocating nblist")
                print(" size updates:", state_up)
        else:
            self._fennol_inputs = self.model.preprocess(**self._raw_inputs)

        return self._fennol_inputs
FENNIX calculator for ASE.
Arguments:
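model: str or FENNIX — The path to the model, or the model itself.
use_atom_padding: bool, default=False — Whether to use atom padding. Atom padding prevents recompiling the model when the number of atoms changes; if the number of atoms is expected to stay fixed, it is recommended to leave this set to False. (This is not a named constructor argument: it is presumably passed through **kwargs to FENNIX.load when model is a path.)
gpu_preprocessing: bool, default=False — Whether to preprocess the data on the GPU. This is useful for large systems, but may be unnecessary (and may even be slower) for small systems.
atoms: ase.Atoms, default=None — If provided, the model inputs are preprocessed from this atoms object when the calculator is created.
verbose: bool, default=False — Whether to print neighbor-list (nblist) update information.
energy_terms: list of str, default=None — The energy terms to include in the total energy. If None, the energy terms defined in the model are used.
use_float64: bool, default=False — Enable JAX double precision (a process-wide setting) and use float64 coordinates and cells.
matmul_prec: str, default=None — If given, one of 'default', 'high' or 'highest'; sets JAX's default matmul precision (process-wide).
save_raw_output: bool, default=False — Store the raw model output in results["raw_output"].
**kwargs — Forwarded to FENNIX.load when model is a path.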
def __init__(
    self,
    model: Union[str, FENNIX],
    gpu_preprocessing: bool = False,
    atoms: Optional[ase.Atoms] = None,
    verbose: bool = False,
    energy_terms: Optional[Sequence[str]] = None,
    use_float64: bool = False,
    matmul_prec: Optional[str] = None,
    save_raw_output: bool = False,
    **kwargs
):
    """Configure JAX precision, obtain the FENNIX model and reset caches.

    ``model`` is used as-is when it is already a FENNIX instance; otherwise it
    is handed to ``FENNIX.load`` together with ``**kwargs``. ``use_float64``
    and ``matmul_prec`` change process-wide JAX settings. When ``atoms`` is
    given, its inputs are preprocessed immediately.
    """
    super().__init__()

    # Process-wide JAX settings.
    if use_float64:
        jax.config.update("jax_enable_x64", True)
    if matmul_prec is not None:
        valid_precisions = ("default", "high", "highest")
        assert (
            matmul_prec in valid_precisions
        ), "matmul_prec should be one of 'default', 'high', 'highest'"
        jax.config.update("jax_default_matmul_precision", matmul_prec)

    # Reuse a ready-made model, otherwise load it from `model`.
    self.model = model if isinstance(model, FENNIX) else FENNIX.load(model, **kwargs)
    if energy_terms is not None:
        self.model.set_energy_terms(energy_terms)

    # Floating-point dtype for coordinates and cells.
    self.dtype = "float64" if use_float64 else "float32"
    self.gpu_preprocessing = gpu_preprocessing
    self.verbose = verbose
    # Caches filled by preprocess(): model-ready inputs and raw arrays.
    self._fennol_inputs = None
    self._raw_inputs = None
    self.save_raw_output = save_raw_output

    # ase.units.Hartree divided by the multiplier of the model energy unit.
    self.energy_conv = ase.units.Hartree / au.get_multiplier(self.model.energy_unit)

    if atoms is not None:
        self.preprocess(atoms)
Basic calculator implementation.
restart: str Prefix for restart file. May contain a directory. Default is None: don't restart. ignore_bad_restart_file: bool Deprecated, please do not use. Passing more than one positional argument to Calculator() is deprecated and will stop working in the future. Ignore broken or missing restart file. By default, it is an error if the restart file is missing or broken. directory: str or PurePath Working directory in which to read and write files and perform calculations. label: str Name used for all files. Not supported by all calculators. May contain a directory, but please use the directory parameter for that instead. atoms: Atoms object Optional Atoms object to which the calculator will be attached. When restarting, atoms will get its positions and unit-cell updated from file.
Properties the calculator can compute (for FENNIXCalculator: energy, forces and stress).
def calculate(
    self,
    atoms=None,
    properties=("energy",),
    system_changes=ase.calculators.calculator.all_changes,
):
    """Run the model and store energy (and forces/stress if requested).

    Parameters
    ----------
    atoms: ase.Atoms, optional
        Passed to the base ``Calculator.calculate``; the atoms used afterwards
        are read from ``self.atoms``.
    properties: sequence of str
        Requesting "stress" computes energy, forces and stress; "forces"
        computes energy and forces; otherwise only the energy is computed.
    system_changes: sequence of str
        What changed since the last call; forwarded to ``preprocess`` so only
        the affected inputs are refreshed.

    Energies and forces are multiplied by ``self.energy_conv``. The stress is
    the model virial times ``self.energy_conv`` divided by the cell volume,
    converted to Voigt notation.
    """
    super().calculate(atoms, properties, system_changes)
    inputs = self.preprocess(self.atoms, system_changes=system_changes)

    # Initial charges are floats: round to the nearest integer instead of
    # truncating, so that e.g. a sum of 0.9999999999999999 gives 1, not 0.
    total_charge = self.atoms.get_initial_charges().sum()
    inputs["total_charge"] = int(round(float(total_charge)))

    results = {}
    if "stress" in properties:
        e, f, virial, output = self.model._energy_and_forces_and_virial(
            self.model.variables, inputs
        )
        volume = self.atoms.get_volume()
        stress = np.asarray(virial[0]) * self.energy_conv / volume
        results["stress"] = full_3x3_to_voigt_6_stress(stress)
        results["forces"] = np.asarray(f) * self.energy_conv
    elif "forces" in properties:
        e, f, output = self.model._energy_and_forces(self.model.variables, inputs)
        results["forces"] = np.asarray(f) * self.energy_conv
    else:
        e, output = self.model._total_energy(self.model.variables, inputs)

    results["energy"] = float(e[0]) * self.energy_conv

    # With atom padding, keep only the force rows selected by
    # output["true_atoms"] (presumably the non-padding atoms).
    if self.model.use_atom_padding and "forces" in results:
        mask = np.asarray(output["true_atoms"])
        results["forces"] = results["forces"][mask]

    if self.save_raw_output:
        results["raw_output"] = output

    self.results.update(results)
Do the calculation.
properties: list of str List of what needs to be calculated. Can be any combination of 'energy', 'forces', 'stress', 'dipole', 'charges', 'magmom' and 'magmoms'. system_changes: list of str List of what has changed since last calculation. Can be any combination of these six: 'positions', 'numbers', 'cell', 'pbc', 'initial_charges' and 'initial_magmoms'.
Subclasses need to implement this, but can ignore properties and system_changes if they want. Calculated properties should be inserted into results dictionary like shown in this dummy example::
self.results = {'energy': 0.0,
'forces': np.zeros((len(atoms), 3)),
'stress': np.zeros(6),
'dipole': np.zeros(3),
'charges': np.zeros(len(atoms)),
'magmom': 0.0,
'magmoms': np.zeros(len(atoms))}
The subclass implementation should first call this implementation to set the atoms attribute and create any missing directories.
def preprocess(self, atoms, system_changes=ase.calculators.calculator.all_changes):
    """Build (or incrementally update) the model inputs for ``atoms``.

    On the first call the raw inputs (species, coordinates, natoms,
    batch_index and, for fully periodic systems, cells/reciprocal_cells) are
    built from scratch. On later calls only the entries affected by
    ``system_changes`` are refreshed ("cell"/"pbc", "numbers", "positions").

    The raw inputs are converted to model inputs either with
    ``self.model.preprocess`` or, when ``gpu_preprocessing`` is enabled and
    neither the atomic numbers nor the periodicity changed, with the
    padding/process/check_reallocate sequence that reuses previous inputs.

    Returns the model inputs, also cached in ``self._fennol_inputs``.

    Raises NotImplementedError if PBC are enabled in some but not all
    directions.
    """

    def _is_periodic():
        # Only fully periodic or fully non-periodic systems are supported.
        pbc = np.asarray(atoms.get_pbc(), dtype=bool)
        if np.all(pbc):
            return True
        if np.any(pbc):
            raise NotImplementedError("PBC should be activated in all directions.")
        return False

    def _cell_arrays():
        # Cell matrix and its inverse, each reshaped to (1, 3, 3).
        cell = np.asarray(atoms.get_cell(complete=True).array, dtype=self.dtype)
        reciprocal_cell = np.linalg.inv(cell)
        return cell.reshape(1, 3, 3), reciprocal_cell.reshape(1, 3, 3)

    # The GPU path merges into the previous model inputs, so it needs them.
    force_cpu_preprocessing = self._fennol_inputs is None
    if self._raw_inputs is None:
        force_cpu_preprocessing = True
        use_pbc = _is_periodic()
        species = np.asarray(atoms.get_atomic_numbers(), dtype=np.int32)
        inputs = {
            "species": species,
            "coordinates": np.asarray(atoms.get_positions(), dtype=self.dtype),
            "natoms": np.array([len(species)], dtype=np.int32),
            "batch_index": np.zeros(len(species), dtype=np.int32),
        }
        if use_pbc:
            inputs["cells"], inputs["reciprocal_cells"] = _cell_arrays()
        self._raw_inputs = convert_to_jax(inputs)
    else:
        # ASE reports "pbc" separately from "cell": toggling periodicity
        # without touching the cell must also update the cell inputs.
        if "cell" in system_changes or "pbc" in system_changes:
            had_pbc = "cells" in self._raw_inputs
            use_pbc = _is_periodic()
            if use_pbc:
                cells, reciprocal_cells = _cell_arrays()
                self._raw_inputs["cells"] = jnp.asarray(cells)
                self._raw_inputs["reciprocal_cells"] = jnp.asarray(reciprocal_cells)
            elif had_pbc:
                del self._raw_inputs["cells"]
                del self._raw_inputs["reciprocal_cells"]
            if use_pbc != had_pbc:
                # The GPU path would merge the previous inputs, which could
                # keep stale cell entries: redo a full preprocessing.
                force_cpu_preprocessing = True
        if "numbers" in system_changes:
            species = jnp.asarray(atoms.get_atomic_numbers(), dtype=jnp.int32)
            self._raw_inputs["species"] = species
            self._raw_inputs["natoms"] = jnp.array([len(species)], dtype=np.int32)
            self._raw_inputs["batch_index"] = jnp.zeros(len(species), dtype=np.int32)
            # Atomic numbers (and possibly the atom count) changed.
            force_cpu_preprocessing = True
        if "positions" in system_changes:
            self._raw_inputs["coordinates"] = jnp.asarray(
                atoms.get_positions(), dtype=self.dtype
            )

    if self.gpu_preprocessing and not force_cpu_preprocessing:
        _, inputs = self.model.preprocessing.atom_padding(
            self.model.preproc_state, self._raw_inputs
        )
        # Keep every previous entry that was not recomputed above.
        inputs = {**self._fennol_inputs, **inputs}

        inputs = self.model.preprocessing.process(self.model.preproc_state, inputs)
        self.model.preproc_state, state_up, inputs, overflow = (
            self.model.preprocessing.check_reallocate(
                self.model.preproc_state, inputs
            )
        )
        self._fennol_inputs = inputs
        if self.verbose and overflow:
            print("FENNIX nblist overflow => reallocating nblist")
            print(" size updates:", state_up)
    else:
        self._fennol_inputs = self.model.preprocess(**self._raw_inputs)

    return self._fennol_inputs