fennol.models.fennix
from typing import Any, Sequence, Callable, Union, Optional, Tuple, Dict
from copy import deepcopy
import dataclasses
from collections import OrderedDict

import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
from flax import serialization
from flax.core.frozen_dict import freeze, unfreeze, FrozenDict
from ..utils import AtomicUnits as au

from .preprocessing import (
    GraphGenerator,
    PreprocessingChain,
    JaxConverter,
    atom_unpadding,
    check_input,
)
from .modules import MODULES, PREPROCESSING, FENNIXModules


@dataclasses.dataclass
class FENNIX:
    """
    Static wrapper for FENNIX models

    The underlying model is a `fennol.models.modules.FENNIXModules` built from the `modules` dictionary,
    which references registered modules in `fennol.models.modules.MODULES` and provides the parameters
    for their initialization.

    Since the model is static and contains variables, it must be initialized right away with either
    `example_data`, `variables` or `rng_key`. If `variables` is provided, it is used directly. If
    `example_data` is provided, the model is initialized with `example_data` and the resulting variables
    are stored in the wrapper. If only `rng_key` is provided, the model is initialized with a dummy
    system and the resulting variables are stored in the wrapper.
    """

    cutoff: Union[float, None]
    modules: FENNIXModules
    variables: Dict
    preprocessing: PreprocessingChain
    _apply: Callable[[Dict, Dict], Dict]
    _total_energy: Callable[[Dict, Dict], Tuple[jnp.ndarray, Dict]]
    _energy_and_forces: Callable[[Dict, Dict], Tuple[jnp.ndarray, jnp.ndarray, Dict]]
    _input_args: Dict
    _graphs_properties: Dict
    preproc_state: Dict
    energy_terms: Optional[Sequence[str]] = None
    _initializing: bool = True
    use_atom_padding: bool = False

    def __init__(
        self,
        cutoff: float,
        modules: OrderedDict,
        preprocessing: OrderedDict = OrderedDict(),
        example_data=None,
        rng_key: Optional[jax.random.PRNGKey] = None,
        variables: Optional[dict] = None,
        energy_terms: Optional[Sequence[str]] = None,
        use_atom_padding: bool = False,
        graph_config: Dict = {},
        energy_unit: str = "Ha",
        **kwargs,
    ) -> None:
        """Initialize the FENNIX model

        Arguments:
        ----------
        cutoff: float
            The cutoff radius for the model
        modules: OrderedDict
            The dictionary defining the sequence of FeNNol modules and their parameters.
        preprocessing: OrderedDict
            The dictionary defining the sequence of preprocessing modules and their parameters.
        example_data: dict
            Example data to initialize the model. If not provided, a dummy system is generated.
        rng_key: jax.random.PRNGKey
            The random key to initialize the model. If not provided, jax.random.PRNGKey(0) is used (should be avoided).
        variables: dict
            The variables of the model (i.e. weights, biases and all other tunable parameters).
            If not provided, the variables are initialized (usually at random).
        energy_terms: Sequence[str]
            The energy terms in the model output that will be summed to compute the total energy.
            If None, the total energy is always zero (useful for non-PES models).
        use_atom_padding: bool
            If True, the model will use atom padding for the input data.
            This is useful when one plans to frequently change the number of atoms in the system (for example during training).
        graph_config: dict
            Edit the graph configuration. Mostly used to change a long-range cutoff as a function of a simulation box size.
        energy_unit: str
            The unit in which the model expresses energies (resolved via AtomicUnits.get_multiplier).
        """
        self._input_args = {
            "cutoff": cutoff,
            "modules": OrderedDict(modules),
            "preprocessing": OrderedDict(preprocessing),
            "energy_unit": energy_unit,
        }
        self.energy_unit = energy_unit
        self.Ha_to_model_energy = au.get_multiplier(energy_unit)
        self.cutoff = cutoff
        self.energy_terms = energy_terms
        self.use_atom_padding = use_atom_padding

        # add non-differentiable/non-jittable modules
        preprocessing = deepcopy(preprocessing)
        if cutoff is None:
            preprocessing_modules = []
        else:
            prep_keys = list(preprocessing.keys())
            graph_params = {"cutoff": cutoff, "graph_key": "graph"}
            if len(prep_keys) > 0 and prep_keys[0] == "graph":
                graph_params = {
                    **graph_params,
                    **preprocessing.pop("graph"),
                }
            graph_params = {**graph_params, **graph_config}

            preprocessing_modules = [
                GraphGenerator(**graph_params),
            ]

        for name, params in preprocessing.items():
            key = str(params.pop("module_name")) if "module_name" in params else name
            key = str(params.pop("FID")) if "FID" in params else key
            mod = PREPROCESSING[key.upper()](**freeze(params))
            preprocessing_modules.append(mod)

        self.preprocessing = PreprocessingChain(
            tuple(preprocessing_modules), use_atom_padding
        )
        graphs_properties = self.preprocessing.get_graphs_properties()
        self._graphs_properties = freeze(graphs_properties)
        # add preprocessing modules that should be differentiated/jitted
        mods = [(JaxConverter, {})] + self.preprocessing.get_processors()

        # build the model
        modules = deepcopy(modules)
        modules_names = []
        for name, params in modules.items():
            key = str(params.pop("module_name")) if "module_name" in params else name
            key = str(params.pop("FID")) if "FID" in params else key
            if name in modules_names:
                raise ValueError(f"Module {name} already exists")
            modules_names.append(name)
            params["name"] = name
            mod = MODULES[key.upper()]
            fields = [f.name for f in dataclasses.fields(mod)]
            if "_graphs_properties" in fields:
                params["_graphs_properties"] = graphs_properties
            if "_energy_unit" in fields:
                params["_energy_unit"] = energy_unit
            mods.append((mod, params))

        self.modules = FENNIXModules(mods)

        self.__apply = self.modules.apply
        self._apply = jax.jit(self.modules.apply)

        self.set_energy_terms(energy_terms)

        # initialize the model
        inputs, rng_key = self.reinitialize_preprocessing(rng_key, example_data)

        if variables is not None:
            self.variables = variables
        elif rng_key is not None:
            self.variables = self.modules.init(rng_key, inputs)
        else:
            raise ValueError(
                "Either variables or a jax.random.PRNGKey must be provided for initialization"
            )

        self._initializing = False

    def set_energy_terms(
        self, energy_terms: Union[Sequence[str], None], jit: bool = True
    ) -> None:
        """Set the energy terms to be computed by the model and prepare the energy and force functions."""
        object.__setattr__(self, "energy_terms", energy_terms)
        if energy_terms is None or len(energy_terms) == 0:

            def total_energy(variables, data):
                out = self.__apply(variables, data)
                coords = out["coordinates"]
                nsys = out["natoms"].shape[0]
                nat = coords.shape[0]
                dtype = coords.dtype
                e = jnp.zeros(nsys, dtype=dtype)
                eat = jnp.zeros(nat, dtype=dtype)
                out["total_energy"] = e
                out["atomic_energies"] = eat
                return e, out

            def energy_and_forces(variables, data):
                e, out = total_energy(variables, data)
                f = jnp.zeros_like(out["coordinates"])
                out["forces"] = f
                return e, f, out

            def energy_and_forces_and_virial(variables, data):
                e, f, out = energy_and_forces(variables, data)
                v = jnp.zeros(
                    (out["natoms"].shape[0], 3, 3), dtype=out["coordinates"].dtype
                )
                out["virial_tensor"] = v
                return e, f, v, out

        else:
            # build the energy and force functions
            def total_energy(variables, data):
                out = self.__apply(variables, data)
                atomic_energies = 0.0
                system_energies = 0.0
                species = out["species"]
                nsys = out["natoms"].shape[0]
                for term in self.energy_terms:
                    e = out[term]
                    if e.ndim > 1 and e.shape[-1] == 1:
                        e = jnp.squeeze(e, axis=-1)
                    if e.shape[0] == nsys and nsys != species.shape[0]:
                        system_energies += e
                        continue
                    assert e.shape == species.shape
                    atomic_energies += e
                if isinstance(atomic_energies, jnp.ndarray):
                    if "true_atoms" in out:
                        atomic_energies = jnp.where(
                            out["true_atoms"], atomic_energies, 0.0
                        )
                    out["atomic_energies"] = atomic_energies
                    energies = jax.ops.segment_sum(
                        atomic_energies,
                        data["batch_index"],
                        num_segments=len(data["natoms"]),
                    )
                else:
                    energies = 0.0

                if isinstance(system_energies, jnp.ndarray):
                    if "true_sys" in out:
                        system_energies = jnp.where(
                            out["true_sys"], system_energies, 0.0
                        )
                    out["system_energies"] = system_energies

                out["total_energy"] = energies + system_energies
                return out["total_energy"], out

            def energy_and_forces(variables, data):
                def _etot(variables, coordinates):
                    energy, out = total_energy(
                        variables, {**data, "coordinates": coordinates}
                    )
                    return energy.sum(), out

                de, out = jax.grad(_etot, argnums=1, has_aux=True)(
                    variables, data["coordinates"]
                )
                out["forces"] = -de

                return out["total_energy"], out["forces"], out

            # previous implementation of the virial, kept for reference:
            # def energy_and_forces_and_virial(variables, data):
            #     x = data["coordinates"]
            #     batch_index = data["batch_index"]
            #     if "cells" in data:
            #         cells = data["cells"]
            #         ## cells is a nbatch x 3 x 3 matrix whose rows are cell vectors
            #         ## (i.e. cells[0, 0, :] is the first cell vector of the first system)
            #         def _etot(variables, coordinates, cells):
            #             reciprocal_cells = jnp.linalg.inv(cells)
            #             energy, out = total_energy(
            #                 variables,
            #                 {
            #                     **data,
            #                     "coordinates": coordinates,
            #                     "cells": cells,
            #                     "reciprocal_cells": reciprocal_cells,
            #                 },
            #             )
            #             return energy.sum(), out
            #
            #         (dedx, dedcells), out = jax.grad(
            #             _etot, argnums=(1, 2), has_aux=True
            #         )(variables, x, cells)
            #         f = -dedx
            #         out["forces"] = f
            #     else:
            #         _, f, out = energy_and_forces(variables, data)
            #
            #     vir = -jax.ops.segment_sum(
            #         f[:, :, None] * x[:, None, :],
            #         batch_index,
            #         num_segments=len(data["natoms"]),
            #     )
            #
            #     if "cells" in data:
            #         # dvir = jax.vmap(jnp.matmul)(dedcells, cells.transpose(0, 2, 1))
            #         dvir = jnp.einsum("...ki,...kj->...ij", dedcells, cells)
            #         nsys = data["natoms"].shape[0]
            #         if cells.shape[0] == 1 and nsys > 1:
            #             dvir = dvir / nsys
            #         vir = vir + dvir
            #
            #     out["virial_tensor"] = vir
            #
            #     return out["total_energy"], f, vir, out

            def energy_and_forces_and_virial(variables, data):
                x = data["coordinates"]
                scaling = jnp.asarray(
                    np.eye(3)[None, :, :].repeat(data["natoms"].shape[0], axis=0)
                )

                def _etot(variables, coordinates, scaling):
                    batch_index = data["batch_index"]
                    coordinates = jax.vmap(jnp.matmul)(
                        coordinates, scaling[batch_index]
                    )
                    inputs = {**data, "coordinates": coordinates}
                    if "cells" in data:
                        ## cells is a nbatch x 3 x 3 matrix whose rows are cell vectors
                        ## (i.e. cells[0, 0, :] is the first cell vector of the first system)
                        cells = jax.vmap(jnp.matmul)(data["cells"], scaling)
                        reciprocal_cells = jnp.linalg.inv(cells)
                        inputs["cells"] = cells
                        inputs["reciprocal_cells"] = reciprocal_cells
                    energy, out = total_energy(variables, inputs)
                    return energy.sum(), out

                (dedx, vir), out = jax.grad(_etot, argnums=(1, 2), has_aux=True)(
                    variables, x, scaling
                )
                f = -dedx
                out["forces"] = f
                out["virial_tensor"] = vir

                return out["total_energy"], f, vir, out

        object.__setattr__(self, "_total_energy_raw", total_energy)
        if jit:
            object.__setattr__(self, "_total_energy", jax.jit(total_energy))
            object.__setattr__(self, "_energy_and_forces", jax.jit(energy_and_forces))
            object.__setattr__(
                self,
                "_energy_and_forces_and_virial",
                jax.jit(energy_and_forces_and_virial),
            )
        else:
            object.__setattr__(self, "_total_energy", total_energy)
            object.__setattr__(self, "_energy_and_forces", energy_and_forces)
            object.__setattr__(
                self, "_energy_and_forces_and_virial", energy_and_forces_and_virial
            )

    def get_gradient_function(
        self,
        *gradient_keys: Sequence[str],
        jit: bool = True,
        variables_as_input: bool = False,
    ):
        """Return a function that computes the energy and the gradient of the energy with respect to the keys in gradient_keys"""

        def _energy_gradient(variables, data):
            def _etot(variables, inputs):
                if "cells" in inputs:
                    reciprocal_cells = jnp.linalg.inv(inputs["cells"])
                    inputs = {**inputs, "reciprocal_cells": reciprocal_cells}
                energy, out = self._total_energy_raw(variables, {**data, **inputs})
                return energy.sum(), out

            inputs = {k: data[k] for k in gradient_keys}
            de, out = jax.grad(_etot, argnums=1, has_aux=True)(variables, inputs)

            return (
                out["total_energy"],
                de,
                {**out, **{"dEd_" + k: de[k] for k in gradient_keys}},
            )

        if variables_as_input:
            energy_gradient = _energy_gradient
        else:

            def energy_gradient(data):
                return _energy_gradient(self.variables, data)

        if jit:
            return jax.jit(energy_gradient)
        else:
            return energy_gradient

    def preprocess(self, use_gpu=False, verbose=False, **inputs) -> Dict[str, Any]:
        """apply preprocessing to the input data

        !!! This is not a pure function => do not apply jax transforms !!!"""
        if self.preproc_state is None:
            # reinitialize_preprocessing already stores the new preproc_state
            out, _ = self.reinitialize_preprocessing(example_data=inputs)
            return out
        elif use_gpu:
            do_check_input = self.preproc_state.get("check_input", True)
            if do_check_input:
                inputs = check_input(inputs)
            preproc_state, inputs = self.preprocessing.atom_padding(
                self.preproc_state, inputs
            )
            inputs = self.preprocessing.process(preproc_state, inputs)
            preproc_state, state_up, out, overflow = (
                self.preprocessing.check_reallocate(preproc_state, inputs)
            )
            if verbose and overflow:
                print("GPU preprocessing: nblist overflow => reallocating nblist")
                print("size updates:", state_up)
        else:
            preproc_state, out = self.preprocessing(self.preproc_state, inputs)

        object.__setattr__(self, "preproc_state", preproc_state)
        return out

    def reinitialize_preprocessing(
        self, rng_key: Optional[jax.random.PRNGKey] = None, example_data=None
    ) -> Tuple[Dict[str, Any], Optional[jax.random.PRNGKey]]:
        if rng_key is None:
            rng_key_pre = jax.random.PRNGKey(0)
        else:
            rng_key, rng_key_pre = jax.random.split(rng_key)

        if example_data is None:
            rng_key_sys, rng_key_pre = jax.random.split(rng_key_pre)
            example_data = self.generate_dummy_system(rng_key_sys, n_atoms=10)

        preproc_state, inputs = self.preprocessing.init_with_output(example_data)
        object.__setattr__(self, "preproc_state", preproc_state)
        return inputs, rng_key

    def __call__(
        self, variables: Optional[dict] = None, gpu_preprocessing=False, **inputs
    ) -> Dict[str, Any]:
        """Apply the FENNIX model (preprocess + modules)

        !!! This is not a pure function => do not apply jax transforms !!!
        If you want to apply jax transforms, use self._apply(variables, inputs),
        which is pure, and preprocess the inputs beforehand with self.preprocess.
        """
        if variables is None:
            variables = self.variables
        inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
        output = self._apply(variables, inputs)
        if self.use_atom_padding:
            output = atom_unpadding(output)
        return output

    def total_energy(
        self,
        variables: Optional[dict] = None,
        unit: Optional[Union[float, str]] = None,
        gpu_preprocessing=False,
        **inputs,
    ) -> Tuple[jnp.ndarray, Dict]:
        """compute the total energy of the system

        !!! This is not a pure function => do not apply jax transforms !!!
        If you want to apply jax transforms, use self._total_energy(variables, inputs),
        which is pure, and preprocess the inputs beforehand with self.preprocess.
        """
        if variables is None:
            variables = self.variables
        inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
        _, output = self._total_energy(variables, inputs)
        if self.use_atom_padding:
            output = atom_unpadding(output)
        e = output["total_energy"]
        if unit is not None:
            model_energy_unit = self.Ha_to_model_energy
            if isinstance(unit, str):
                unit = au.get_multiplier(unit)
            e = e * (unit / model_energy_unit)
        return e, output

    def energy_and_forces(
        self,
        variables: Optional[dict] = None,
        unit: Optional[Union[float, str]] = None,
        gpu_preprocessing=False,
        **inputs,
    ) -> Tuple[jnp.ndarray, jnp.ndarray, Dict]:
        """compute the total energy and forces of the system

        !!! This is not a pure function => do not apply jax transforms !!!
        If you want to apply jax transforms, use self._energy_and_forces(variables, inputs),
        which is pure, and preprocess the inputs beforehand with self.preprocess.
        """
        if variables is None:
            variables = self.variables
        inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
        _, _, output = self._energy_and_forces(variables, inputs)
        if self.use_atom_padding:
            output = atom_unpadding(output)
        e = output["total_energy"]
        f = output["forces"]
        if unit is not None:
            model_energy_unit = self.Ha_to_model_energy
            if isinstance(unit, str):
                unit = au.get_multiplier(unit)
            e = e * (unit / model_energy_unit)
            f = f * (unit / model_energy_unit)
        return e, f, output

    def energy_and_forces_and_virial(
        self,
        variables: Optional[dict] = None,
        unit: Optional[Union[float, str]] = None,
        gpu_preprocessing=False,
        **inputs,
    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, Dict]:
        """compute the total energy, forces and virial tensor of the system

        !!! This is not a pure function => do not apply jax transforms !!!
        If you want to apply jax transforms, use self._energy_and_forces_and_virial(variables, inputs),
        which is pure, and preprocess the inputs beforehand with self.preprocess.
        """
        if variables is None:
            variables = self.variables
        inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
        _, _, _, output = self._energy_and_forces_and_virial(variables, inputs)
        if self.use_atom_padding:
            output = atom_unpadding(output)
        e = output["total_energy"]
        f = output["forces"]
        v = output["virial_tensor"]
        if unit is not None:
            model_energy_unit = self.Ha_to_model_energy
            if isinstance(unit, str):
                unit = au.get_multiplier(unit)
            e = e * (unit / model_energy_unit)
            f = f * (unit / model_energy_unit)
            v = v * (unit / model_energy_unit)
        return e, f, v, output

    def remove_atom_padding(self, output):
        """remove atom padding from the output"""
        return atom_unpadding(output)

    def get_model(self) -> Tuple[FENNIXModules, Dict]:
        """return the model and its variables"""
        return self.modules, self.variables

    def get_preprocessing(self) -> Tuple[PreprocessingChain, Dict]:
        """return the preprocessing chain and its state"""
        return self.preprocessing, self.preproc_state

    def __setattr__(self, __name: str, __value: Any) -> None:
        if __name == "variables":
            if __value is not None:
                if not (
                    isinstance(__value, dict)
                    or isinstance(__value, OrderedDict)
                    or isinstance(__value, FrozenDict)
                ):
                    raise ValueError(f"{__name} must be a dict")
                object.__setattr__(self, __name, JaxConverter()(__value))
            else:
                raise ValueError(f"{__name} cannot be None")
        elif __name == "preproc_state":
            if __value is not None:
                if not (
                    isinstance(__value, dict)
                    or isinstance(__value, OrderedDict)
                    or isinstance(__value, FrozenDict)
                ):
                    raise ValueError(f"{__name} must be a FrozenDict")
                object.__setattr__(self, __name, freeze(JaxConverter()(__value)))
            else:
                raise ValueError(f"{__name} cannot be None")

        elif self._initializing:
            object.__setattr__(self, __name, __value)
        else:
            raise ValueError(f"{__name} attribute of FENNIX model is immutable.")

    def generate_dummy_system(
        self, rng_key: jax.random.PRNGKey, box_size=None, n_atoms: int = 10
    ) -> Dict[str, Any]:
        """
        Generate a dummy system for initialization
        """
        if box_size is None:
            box_size = 2 * self.cutoff
            for g in self._graphs_properties.values():
                cutoff = g["cutoff"]
                if cutoff is not None:
                    box_size = min(box_size, 2 * g["cutoff"])
        coordinates = np.array(
            jax.random.uniform(rng_key, (n_atoms, 3), maxval=box_size), dtype=np.float64
        )
        species = np.ones((n_atoms,), dtype=np.int32)
        batch_index = np.zeros((n_atoms,), dtype=np.int32)
        natoms = np.array([n_atoms], dtype=np.int32)
        return {
            "species": species,
            "coordinates": coordinates,
            "batch_index": batch_index,
            "natoms": natoms,
        }

    def summarize(
        self, rng_key: Optional[jax.random.PRNGKey] = None, example_data=None, **kwargs
    ) -> str:
        """Summarize the model architecture and parameters"""
        # default header; replaced below when a dummy system is generated
        head = "Summarizing with example data:\n"
        if rng_key is None:
            rng_key = jax.random.PRNGKey(0)
        if example_data is None:
            head = "Summarizing with dummy 10 atoms system:\n"
            rng_key, rng_key_sys = jax.random.split(rng_key)
            example_data = self.generate_dummy_system(rng_key_sys, n_atoms=10)
        rng_key, rng_key_pre = jax.random.split(rng_key)
        _, inputs = self.preprocessing.init_with_output(example_data)
        return head + nn.tabulate(self.modules, rng_key, **kwargs)(inputs)

    def to_dict(self):
        """return a dictionary representation of the model"""
        return {
            **self._input_args,
            "energy_terms": self.energy_terms,
            "variables": deepcopy(self.variables),
        }

    def save(self, filename):
        """save the model to a file"""
        state_dict = self.to_dict()
        # store the ordered module/preprocessing mappings as [key, value] pairs for serialization
        state_dict["preprocessing"] = [
            [k, v] for k, v in state_dict["preprocessing"].items()
        ]
        state_dict["modules"] = [[k, v] for k, v in state_dict["modules"].items()]
        with open(filename, "wb") as f:
            f.write(serialization.msgpack_serialize(state_dict))

    @classmethod
    def load(
        cls,
        filename,
        use_atom_padding=False,
        graph_config={},
    ):
        """load a model from a file"""
        with open(filename, "rb") as f:
            state_dict = serialization.msgpack_restore(f.read())
        state_dict["preprocessing"] = {k: v for k, v in state_dict["preprocessing"]}
        state_dict["modules"] = {k: v for k, v in state_dict["modules"]}
        return cls(
            **state_dict,
            graph_config=graph_config,
            use_atom_padding=use_atom_padding,
        )
class FENNIX
Static wrapper for FENNIX models

The underlying model is a fennol.models.modules.FENNIXModules built from the modules dictionary, which references registered modules in fennol.models.modules.MODULES and provides the parameters for their initialization.

Since the model is static and contains variables, it must be initialized right away with either example_data, variables or rng_key. If variables is provided, it is used directly. If example_data is provided, the model is initialized with example_data and the resulting variables are stored in the wrapper. If only rng_key is provided, the model is initialized with a dummy system and the resulting variables are stored in the wrapper.
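A minimal usage sketch. The checkpoint filename is hypothetical (any file written by FENNIX.save would do), and the unit string assumes "eV" is known to AtomicUnits.get_multiplier; the input keys follow generate_dummy_system:

import numpy as np
from fennol.models.fennix import FENNIX

# load a model previously written by FENNIX.save (filename is hypothetical)
model = FENNIX.load("model.fnx")

# a single water-like system, described by flat per-atom arrays
raw_inputs = {
    "species": np.array([8, 1, 1], dtype=np.int32),
    "coordinates": np.array(
        [[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]], dtype=np.float64
    ),
    "batch_index": np.zeros(3, dtype=np.int32),
    "natoms": np.array([3], dtype=np.int32),
}

# impure convenience path: preprocessing (neighbor lists, padding) + forward pass
e, f, output = model.energy_and_forces(unit="eV", **raw_inputs)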
FENNIX(cutoff: float, modules: OrderedDict, preprocessing: OrderedDict = OrderedDict(), example_data=None, rng_key: Optional[jax.random.PRNGKey] = None, variables: Optional[dict] = None, energy_terms: Optional[Sequence[str]] = None, use_atom_padding: bool = False, graph_config: Dict = {}, energy_unit: str = "Ha", **kwargs)
Initialize the FENNIX model

Arguments:
cutoff: float
    The cutoff radius for the model.
modules: OrderedDict
    The dictionary defining the sequence of FeNNol modules and their parameters.
preprocessing: OrderedDict
    The dictionary defining the sequence of preprocessing modules and their parameters.
example_data: dict
    Example data to initialize the model. If not provided, a dummy system is generated.
rng_key: jax.random.PRNGKey
    The random key to initialize the model. If not provided, jax.random.PRNGKey(0) is used (should be avoided).
variables: dict
    The variables of the model (i.e. weights, biases and all other tunable parameters). If not provided, the variables are initialized (usually at random).
energy_terms: Sequence[str]
    The energy terms in the model output that will be summed to compute the total energy. If None, the total energy is always zero (useful for non-PES models).
use_atom_padding: bool
    If True, the model will use atom padding for the input data. This is useful when one plans to frequently change the number of atoms in the system (for example during training).
graph_config: dict
    Edit the graph configuration. Mostly used to change a long-range cutoff as a function of a simulation box size.
energy_unit: str
    The unit in which the model expresses energies (resolved via AtomicUnits.get_multiplier).
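For illustration only, a sketch of direct construction. The module names and FID values below are placeholders, not registered FeNNol modules; the call only runs once they are replaced by actual entries of fennol.models.modules.MODULES:

from collections import OrderedDict
import jax
from fennol.models.fennix import FENNIX

modules = OrderedDict(
    embedding={"FID": "SOME_EMBEDDING"},   # placeholder FID
    energy={"FID": "SOME_ENERGY_MODULE"},  # placeholder FID, assumed to write an "energy" output
)

model = FENNIX(
    cutoff=5.0,                      # graph cutoff radius
    modules=modules,
    rng_key=jax.random.PRNGKey(42),  # random initialization of the variables
    energy_terms=["energy"],         # output keys summed into total_energy
    energy_unit="Ha",
)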
def set_energy_terms(self, energy_terms: Union[Sequence[str], None], jit: bool = True) -> None:
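For illustration, a minimal usage sketch; the file name "model.fnx" and the term name "energy" are assumptions, since valid term names depend on the output keys produced by the loaded model's modules:

from fennol.models.fennix import FENNIX

model = FENNIX.load("model.fnx")  # hypothetical model file

# Select which output keys are summed into the total energy.
model.set_energy_terms(["energy"])

# Passing None disables the PES: the energy/force functions return zeros.
model.set_energy_terms(None)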
    def get_gradient_function(
        self,
        *gradient_keys: str,
        jit: bool = True,
        variables_as_input: bool = False,
    ):
        """Return a function that computes the total energy and its gradient with respect to the input keys in `gradient_keys`."""

        def _energy_gradient(variables, data):
            def _etot(variables, inputs):
                if "cells" in inputs:
                    reciprocal_cells = jnp.linalg.inv(inputs["cells"])
                    inputs = {**inputs, "reciprocal_cells": reciprocal_cells}
                energy, out = self._total_energy_raw(variables, {**data, **inputs})
                return energy.sum(), out

            inputs = {k: data[k] for k in gradient_keys}
            de, out = jax.grad(_etot, argnums=1, has_aux=True)(variables, inputs)

            return (
                out["total_energy"],
                de,
                {**out, **{"dEd_" + k: de[k] for k in gradient_keys}},
            )

        if variables_as_input:
            energy_gradient = _energy_gradient
        else:

            def energy_gradient(data):
                return _energy_gradient(self.variables, data)

        if jit:
            return jax.jit(energy_gradient)
        else:
            return energy_gradient
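As a usage sketch, the returned closure differentiates the total energy with respect to the requested input arrays; the inputs must be preprocessed first since the underlying raw energy function is pure. Here `model` is assumed to be an initialized FENNIX instance and the arrays are shaped as in generate_dummy_system:

inputs = model.preprocess(
    species=species,
    coordinates=coordinates,
    natoms=natoms,
    batch_index=batch_index,
)
egrad = model.get_gradient_function("coordinates")
e, de, out = egrad(inputs)
forces = -out["dEd_coordinates"]  # minus the gradient w.r.t. coordinates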
    def preprocess(self, use_gpu: bool = False, verbose: bool = False, **inputs) -> Dict[str, Any]:
        """Apply preprocessing to the input data.

        !!! This is not a pure function => do not apply jax transforms !!!"""
        if self.preproc_state is None:
            out, _ = self.reinitialize_preprocessing(example_data=inputs)
            # reinitialize_preprocessing stores the new state on the wrapper
            preproc_state = self.preproc_state
        elif use_gpu:
            do_check_input = self.preproc_state.get("check_input", True)
            if do_check_input:
                inputs = check_input(inputs)
            preproc_state, inputs = self.preprocessing.atom_padding(
                self.preproc_state, inputs
            )
            inputs = self.preprocessing.process(preproc_state, inputs)
            preproc_state, state_up, out, overflow = (
                self.preprocessing.check_reallocate(preproc_state, inputs)
            )
            if verbose and overflow:
                print("GPU preprocessing: nblist overflow => reallocating nblist")
                print("size updates:", state_up)
        else:
            preproc_state, out = self.preprocessing(self.preproc_state, inputs)

        object.__setattr__(self, "preproc_state", preproc_state)
        return out
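A minimal sketch of the intended two-step workflow: preprocess once (impure, stateful), then call the pure jitted energy function, which is safe to wrap in further jax transforms. Here `model` is assumed to be an initialized FENNIX instance:

import jax

data = model.generate_dummy_system(jax.random.PRNGKey(0), n_atoms=10)
inputs = model.preprocess(**data)
energy, output = model._total_energy(model.variables, inputs)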
    def reinitialize_preprocessing(
        self, rng_key: Optional[jax.random.PRNGKey] = None, example_data=None
    ) -> Tuple[Dict[str, Any], Optional[jax.random.PRNGKey]]:
        """Reinitialize the preprocessing state (generating a dummy system if no example data is provided) and return the preprocessed inputs together with the remaining rng key."""
        if rng_key is None:
            rng_key_pre = jax.random.PRNGKey(0)
        else:
            rng_key, rng_key_pre = jax.random.split(rng_key)

        if example_data is None:
            rng_key_sys, rng_key_pre = jax.random.split(rng_key_pre)
            example_data = self.generate_dummy_system(rng_key_sys, n_atoms=10)

        preproc_state, inputs = self.preprocessing.init_with_output(example_data)
        object.__setattr__(self, "preproc_state", preproc_state)
        return inputs, rng_key
    def total_energy(
        self,
        variables: Optional[dict] = None,
        unit: Optional[Union[float, str]] = None,
        gpu_preprocessing: bool = False,
        **inputs,
    ) -> Tuple[jnp.ndarray, Dict]:
        """Compute the total energy of the system.

        !!! This is not a pure function => do not apply jax transforms !!!
        If you want to apply jax transforms, use self._total_energy(variables, inputs),
        which is pure, and preprocess the input using self.preprocess.
        """
        if variables is None:
            variables = self.variables
        inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
        _, output = self._total_energy(variables, inputs)
        if self.use_atom_padding:
            output = atom_unpadding(output)
        e = output["total_energy"]
        if unit is not None:
            model_energy_unit = self.Ha_to_model_energy
            if isinstance(unit, str):
                unit = au.get_multiplier(unit)
            e = e * (unit / model_energy_unit)
        return e, output
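Usage sketch for the unit conversion, with `data` built as in the preprocessing example above; "eV" is assumed to be a unit string recognized by au.get_multiplier (the model energy unit defaults to "Ha"):

e_ev, out = model.total_energy(unit="eV", **data)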
    def energy_and_forces(
        self,
        variables: Optional[dict] = None,
        unit: Optional[Union[float, str]] = None,
        gpu_preprocessing: bool = False,
        **inputs,
    ) -> Tuple[jnp.ndarray, jnp.ndarray, Dict]:
        """Compute the total energy and forces of the system.

        !!! This is not a pure function => do not apply jax transforms !!!
        If you want to apply jax transforms, use self._energy_and_forces(variables, inputs),
        which is pure, and preprocess the input using self.preprocess.
        """
        if variables is None:
            variables = self.variables
        inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
        _, _, output = self._energy_and_forces(variables, inputs)
        if self.use_atom_padding:
            output = atom_unpadding(output)
        e = output["total_energy"]
        f = output["forces"]
        if unit is not None:
            model_energy_unit = self.Ha_to_model_energy
            if isinstance(unit, str):
                unit = au.get_multiplier(unit)
            e = e * (unit / model_energy_unit)
            f = f * (unit / model_energy_unit)
        return e, f, output
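Usage sketch (same assumed `data` dictionary as above); the forces come back as minus the coordinate gradient, one 3-vector per atom:

e, f, out = model.energy_and_forces(**data)
assert f.shape == data["coordinates"].shape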
    def energy_and_forces_and_virial(
        self,
        variables: Optional[dict] = None,
        unit: Optional[Union[float, str]] = None,
        gpu_preprocessing: bool = False,
        **inputs,
    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, Dict]:
        """Compute the total energy, forces and virial tensor of the system.

        !!! This is not a pure function => do not apply jax transforms !!!
        If you want to apply jax transforms, use self._energy_and_forces_and_virial(variables, inputs),
        which is pure, and preprocess the input using self.preprocess.
        """
        if variables is None:
            variables = self.variables
        inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
        _, _, _, output = self._energy_and_forces_and_virial(variables, inputs)
        if self.use_atom_padding:
            output = atom_unpadding(output)
        e = output["total_energy"]
        f = output["forces"]
        v = output["virial_tensor"]
        if unit is not None:
            model_energy_unit = self.Ha_to_model_energy
            if isinstance(unit, str):
                unit = au.get_multiplier(unit)
            e = e * (unit / model_energy_unit)
            f = f * (unit / model_energy_unit)
            v = v * (unit / model_energy_unit)
        return e, f, v, output
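Sketch of an isotropic pressure estimate from the returned virial for a single periodic system; with the strain-derivative convention used in set_energy_terms, the potential (kinetic-free) pressure is -tr(virial)/(3V). The `cells` array is an assumption here: one (3, 3) matrix of cell vectors per system.

import jax.numpy as jnp

e, f, vir, out = model.energy_and_forces_and_virial(cells=cells, **data)
volume = jnp.abs(jnp.linalg.det(cells[0]))
pressure = -jnp.trace(vir[0]) / (3.0 * volume)  # potential part only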
    def remove_atom_padding(self, output):
        """Remove atom padding from the output."""
        return atom_unpadding(output)
    def get_model(self) -> Tuple[FENNIXModules, Dict]:
        """Return the model and its variables."""
        return self.modules, self.variables
    def get_preprocessing(self) -> Tuple[PreprocessingChain, Dict]:
        """Return the preprocessing chain and its state."""
        return self.preprocessing, self.preproc_state
    def generate_dummy_system(
        self, rng_key: jax.random.PRNGKey, box_size=None, n_atoms: int = 10
    ) -> Dict[str, Any]:
        """Generate a dummy system for initialization."""
        if box_size is None:
            box_size = 2 * self.cutoff
            for g in self._graphs_properties.values():
                cutoff = g["cutoff"]
                if cutoff is not None:
                    box_size = min(box_size, 2 * cutoff)
        coordinates = np.array(
            jax.random.uniform(rng_key, (n_atoms, 3), maxval=box_size),
            dtype=np.float64,
        )
        species = np.ones((n_atoms,), dtype=np.int32)
        batch_index = np.zeros((n_atoms,), dtype=np.int32)
        natoms = np.array([n_atoms], dtype=np.int32)
        return {
            "species": species,
            "coordinates": coordinates,
            "batch_index": batch_index,
            "natoms": natoms,
        }
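The generator is also handy for smoke tests; a sketch assuming an initialized model:

import jax

data = model.generate_dummy_system(jax.random.PRNGKey(42), n_atoms=10)
e, out = model.total_energy(**data)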
    def summarize(
        self, rng_key: Optional[jax.random.PRNGKey] = None, example_data=None, **kwargs
    ) -> str:
        """Summarize the model architecture and parameters."""
        head = "Summarizing with example data:\n"
        if rng_key is None:
            rng_key = jax.random.PRNGKey(0)
        if example_data is None:
            head = "Summarizing with dummy 10 atoms system:\n"
            rng_key, rng_key_sys = jax.random.split(rng_key)
            example_data = self.generate_dummy_system(rng_key_sys, n_atoms=10)
        rng_key, rng_key_pre = jax.random.split(rng_key)
        _, inputs = self.preprocessing.init_with_output(example_data)
        return head + nn.tabulate(self.modules, rng_key, **kwargs)(inputs)
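Usage sketch; extra keyword arguments are forwarded to flax's nn.tabulate:

print(model.summarize())
print(model.summarize(depth=1))  # depth is an nn.tabulate option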
    def to_dict(self):
        """Return a dictionary representation of the model."""
        return {
            **self._input_args,
            "energy_terms": self.energy_terms,
            "variables": deepcopy(self.variables),
        }
    def save(self, filename):
        """Save the model to a file."""
        state_dict = self.to_dict()
        # store the ordered mappings as lists of [key, value] pairs so that
        # module order survives msgpack serialization
        state_dict["preprocessing"] = [
            [k, v] for k, v in state_dict["preprocessing"].items()
        ]
        state_dict["modules"] = [[k, v] for k, v in state_dict["modules"].items()]
        with open(filename, "wb") as f:
            f.write(serialization.msgpack_serialize(state_dict))
    @classmethod
    def load(
        cls,
        filename,
        use_atom_padding=False,
        graph_config={},
    ):
        """Load a model from a file."""
        with open(filename, "rb") as f:
            state_dict = serialization.msgpack_restore(f.read())
        # restore the ordered mappings saved as lists of [key, value] pairs
        state_dict["preprocessing"] = {k: v for k, v in state_dict["preprocessing"]}
        state_dict["modules"] = {k: v for k, v in state_dict["modules"]}
        return cls(
            **state_dict,
            graph_config=graph_config,
            use_atom_padding=use_atom_padding,
        )
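Round-trip sketch (the file name "model.fnx" is an assumption; any path works):

model.save("model.fnx")
model2 = FENNIX.load("model.fnx", use_atom_padding=True)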