fennol.models.fennix
```python
from typing import Any, Sequence, Callable, Union, Optional, Tuple, Dict
from copy import deepcopy
import dataclasses
from collections import OrderedDict

import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
from flax import serialization
from flax.core.frozen_dict import freeze, unfreeze, FrozenDict
from ..utils import AtomicUnits as au

from .preprocessing import (
    GraphGenerator,
    PreprocessingChain,
    JaxConverter,
    atom_unpadding,
    check_input,
)
from .modules import MODULES, PREPROCESSING, FENNIXModules


@dataclasses.dataclass
class FENNIX:
    """
    Static wrapper for FENNIX models

    The underlying model is a `fennol.models.modules.FENNIXModules` built from the `modules`
    dictionary, which references registered modules in `fennol.models.modules.MODULES` and
    provides their initialization parameters.

    Since the model is static and contains variables, it must be initialized right away with
    either `example_data`, `variables` or `rng_key`. If `variables` is provided, it is used
    directly. If `example_data` is provided, the model is initialized with `example_data` and
    the resulting variables are stored in the wrapper. If only `rng_key` is provided, the
    model is initialized with a dummy system and the resulting variables are stored in the
    wrapper.
    """

    cutoff: Union[float, None]
    modules: FENNIXModules
    variables: Dict
    preprocessing: PreprocessingChain
    _apply: Callable[[Dict, Dict], Dict]
    _total_energy: Callable[[Dict, Dict], Tuple[jnp.ndarray, Dict]]
    _energy_and_forces: Callable[[Dict, Dict], Tuple[jnp.ndarray, jnp.ndarray, Dict]]
    _input_args: Dict
    _graphs_properties: Dict
    preproc_state: Dict
    energy_terms: Optional[Sequence[str]] = None
    _initializing: bool = True
    use_atom_padding: bool = False

    def __init__(
        self,
        cutoff: float,
        modules: OrderedDict,
        preprocessing: OrderedDict = OrderedDict(),
        example_data=None,
        rng_key: Optional[jax.random.PRNGKey] = None,
        variables: Optional[dict] = None,
        energy_terms: Optional[Sequence[str]] = None,
        use_atom_padding: bool = False,
        graph_config: Dict = {},
        energy_unit: str = "Ha",
        **kwargs,
    ) -> None:
        """Initialize the FENNIX model

        Arguments:
        ----------
        cutoff: float
            The cutoff radius for the model.
        modules: OrderedDict
            The dictionary defining the sequence of FeNNol modules and their parameters.
        preprocessing: OrderedDict
            The dictionary defining the sequence of preprocessing modules and their parameters.
        example_data: dict
            Example data to initialize the model. If not provided, a dummy system is generated.
        rng_key: jax.random.PRNGKey
            The random key to initialize the model. If not provided, jax.random.PRNGKey(0)
            is used (this should be avoided).
        variables: dict
            The variables of the model (i.e. weights, biases and all other tunable parameters).
            If not provided, the variables are initialized (usually at random).
        energy_terms: Sequence[str]
            The energy terms in the model output that will be summed to compute the total energy.
            If None, the total energy is always zero (useful for non-PES models).
        use_atom_padding: bool
            If True, the model will use atom padding for the input data. This is useful when one
            plans to frequently change the number of atoms in the system (for example during training).
        graph_config: dict
            Edits the graph configuration. Mostly used to change a long-range cutoff as a
            function of the simulation box size.
        energy_unit: str
            The energy unit of the model output (default: "Ha").
        """
        self._input_args = {
            "cutoff": cutoff,
            "modules": OrderedDict(modules),
            "preprocessing": OrderedDict(preprocessing),
            "energy_unit": energy_unit,
        }
        self.energy_unit = energy_unit
        self.Ha_to_model_energy = au.get_multiplier(energy_unit)
        self.cutoff = cutoff
        self.energy_terms = energy_terms
        self.use_atom_padding = use_atom_padding

        # add non-differentiable/non-jittable modules
        preprocessing = deepcopy(preprocessing)
        if cutoff is None:
            preprocessing_modules = []
        else:
            prep_keys = list(preprocessing.keys())
            graph_params = {"cutoff": cutoff, "graph_key": "graph"}
            if len(prep_keys) > 0 and prep_keys[0] == "graph":
                graph_params = {
                    **graph_params,
                    **preprocessing.pop("graph"),
                }
            graph_params = {**graph_params, **graph_config}

            preprocessing_modules = [
                GraphGenerator(**graph_params),
            ]

        for name, params in preprocessing.items():
            key = str(params.pop("module_name")) if "module_name" in params else name
            key = str(params.pop("FID")) if "FID" in params else key
            mod = PREPROCESSING[key.upper()](**freeze(params))
            preprocessing_modules.append(mod)

        self.preprocessing = PreprocessingChain(
            tuple(preprocessing_modules), use_atom_padding
        )
        graphs_properties = self.preprocessing.get_graphs_properties()
        self._graphs_properties = freeze(graphs_properties)
        # add preprocessing modules that should be differentiated/jitted
        mods = [(JaxConverter, {})] + self.preprocessing.get_processors()
        # mods = self.preprocessing.get_processors(return_list=True)

        # build the model
        modules = deepcopy(modules)
        modules_names = []
        for name, params in modules.items():
            key = str(params.pop("module_name")) if "module_name" in params else name
            key = str(params.pop("FID")) if "FID" in params else key
            if name in modules_names:
                raise ValueError(f"Module {name} already exists")
            modules_names.append(name)
            params["name"] = name
            mod = MODULES[key.upper()]
            fields = [f.name for f in dataclasses.fields(mod)]
            if "_graphs_properties" in fields:
                params["_graphs_properties"] = graphs_properties
            if "_energy_unit" in fields:
                params["_energy_unit"] = energy_unit
            mods.append((mod, params))

        self.modules = FENNIXModules(mods)

        self.__apply = self.modules.apply
        self._apply = jax.jit(self.modules.apply)

        self.set_energy_terms(energy_terms)

        # initialize the model
        inputs, rng_key = self.reinitialize_preprocessing(rng_key, example_data)

        if variables is not None:
            self.variables = variables
        elif rng_key is not None:
            self.variables = self.modules.init(rng_key, inputs)
        else:
            raise ValueError(
                "Either variables or a jax.random.PRNGKey must be provided for initialization"
            )

        self._initializing = False

    def set_energy_terms(
        self, energy_terms: Union[Sequence[str], None], jit: bool = True
    ) -> None:
        """Set the energy terms to be computed by the model and prepare the energy and force functions."""
        object.__setattr__(self, "energy_terms", energy_terms)
        if isinstance(energy_terms, str):
            energy_terms = [energy_terms]

        if energy_terms is None or len(energy_terms) == 0:

            def total_energy(variables, data):
                out = self.__apply(variables, data)
                coords = out["coordinates"]
                nsys = out["natoms"].shape[0]
                nat = coords.shape[0]
                dtype = coords.dtype
                e = jnp.zeros(nsys, dtype=dtype)
                eat = jnp.zeros(nat, dtype=dtype)
                out["total_energy"] = e
                out["atomic_energies"] = eat
                return e, out

            def energy_and_forces(variables, data):
                e, out = total_energy(variables, data)
                f = jnp.zeros_like(out["coordinates"])
                out["forces"] = f
                return e, f, out

            def energy_and_forces_and_virial(variables, data):
                e, f, out = energy_and_forces(variables, data)
                v = jnp.zeros(
                    (out["natoms"].shape[0], 3, 3), dtype=out["coordinates"].dtype
                )
                out["virial_tensor"] = v
                return e, f, v, out

        else:
            # build the energy and force functions
            def total_energy(variables, data):
                out = self.__apply(variables, data)
                atomic_energies = 0.0
                system_energies = 0.0
                species = out["species"]
                nsys = out["natoms"].shape[0]
                for term in self.energy_terms:
                    e = out[term]
                    if e.ndim > 1 and e.shape[-1] == 1:
                        e = jnp.squeeze(e, axis=-1)
                    if e.shape[0] == nsys and nsys != species.shape[0]:
                        system_energies += e
                        continue
                    assert e.shape == species.shape
                    atomic_energies += e
                # atomic_energies = jnp.squeeze(atomic_energies, axis=-1)
                if isinstance(atomic_energies, jnp.ndarray):
                    if "true_atoms" in out:
                        atomic_energies = jnp.where(
                            out["true_atoms"], atomic_energies, 0.0
                        )
                    out["atomic_energies"] = atomic_energies
                    energies = jax.ops.segment_sum(
                        atomic_energies,
                        data["batch_index"],
                        num_segments=len(data["natoms"]),
                    )
                else:
                    energies = 0.0

                if isinstance(system_energies, jnp.ndarray):
                    if "true_sys" in out:
                        system_energies = jnp.where(
                            out["true_sys"], system_energies, 0.0
                        )
                    out["system_energies"] = system_energies

                out["total_energy"] = energies + system_energies
                return out["total_energy"], out

            def energy_and_forces(variables, data):
                def _etot(variables, coordinates):
                    energy, out = total_energy(
                        variables, {**data, "coordinates": coordinates}
                    )
                    return energy.sum(), out

                de, out = jax.grad(_etot, argnums=1, has_aux=True)(
                    variables, data["coordinates"]
                )
                out["forces"] = -de

                return out["total_energy"], out["forces"], out

            # def energy_and_forces_and_virial(variables, data):
            #     x = data["coordinates"]
            #     batch_index = data["batch_index"]
            #     if "cells" in data:
            #         cells = data["cells"]
            #         ## cells is a nbatch x 3 x 3 matrix whose rows are cell vectors
            #         ## (i.e. cells[0, 0, :] is the first cell vector of the first system)
            #
            #         def _etot(variables, coordinates, cells):
            #             reciprocal_cells = jnp.linalg.inv(cells)
            #             energy, out = total_energy(
            #                 variables,
            #                 {
            #                     **data,
            #                     "coordinates": coordinates,
            #                     "cells": cells,
            #                     "reciprocal_cells": reciprocal_cells,
            #                 },
            #             )
            #             return energy.sum(), out
            #
            #         (dedx, dedcells), out = jax.grad(
            #             _etot, argnums=(1, 2), has_aux=True
            #         )(variables, x, cells)
            #         f = -dedx
            #         out["forces"] = f
            #     else:
            #         _, f, out = energy_and_forces(variables, data)
            #
            #     vir = -jax.ops.segment_sum(
            #         f[:, :, None] * x[:, None, :],
            #         batch_index,
            #         num_segments=len(data["natoms"]),
            #     )
            #
            #     if "cells" in data:
            #         # dvir = jax.vmap(jnp.matmul)(dedcells, cells.transpose(0, 2, 1))
            #         dvir = jnp.einsum("...ki,...kj->...ij", dedcells, cells)
            #         nsys = data["natoms"].shape[0]
            #         if cells.shape[0] == 1 and nsys > 1:
            #             dvir = dvir / nsys
            #         vir = vir + dvir
            #
            #     out["virial_tensor"] = vir
            #
            #     return out["total_energy"], f, vir, out

            def energy_and_forces_and_virial(variables, data):
                x = data["coordinates"]
                scaling = jnp.asarray(
                    np.eye(3)[None, :, :].repeat(data["natoms"].shape[0], axis=0)
                )

                def _etot(variables, coordinates, scaling):
                    batch_index = data["batch_index"]
                    coordinates = jax.vmap(jnp.matmul)(
                        coordinates, scaling[batch_index]
                    )
                    inputs = {**data, "coordinates": coordinates}
                    if "cells" in data:
                        ## cells is a nbatch x 3 x 3 matrix whose rows are cell vectors
                        ## (i.e. cells[0, 0, :] is the first cell vector of the first system)
                        cells = jax.vmap(jnp.matmul)(data["cells"], scaling)
                        reciprocal_cells = jnp.linalg.inv(cells)
                        inputs["cells"] = cells
                        inputs["reciprocal_cells"] = reciprocal_cells
                    energy, out = total_energy(variables, inputs)
                    return energy.sum(), out

                (dedx, vir), out = jax.grad(_etot, argnums=(1, 2), has_aux=True)(
                    variables, x, scaling
                )
                f = -dedx
                out["forces"] = f
                out["virial_tensor"] = vir

                return out["total_energy"], f, vir, out

        object.__setattr__(self, "_total_energy_raw", total_energy)
        if jit:
            object.__setattr__(self, "_total_energy", jax.jit(total_energy))
            object.__setattr__(self, "_energy_and_forces", jax.jit(energy_and_forces))
            object.__setattr__(
                self,
                "_energy_and_forces_and_virial",
                jax.jit(energy_and_forces_and_virial),
            )
        else:
            object.__setattr__(self, "_total_energy", total_energy)
            object.__setattr__(self, "_energy_and_forces", energy_and_forces)
            object.__setattr__(
                self, "_energy_and_forces_and_virial", energy_and_forces_and_virial
            )

    def get_gradient_function(
        self,
        *gradient_keys: Sequence[str],
        jit: bool = True,
        variables_as_input: bool = False,
    ):
        """Return a function that computes the energy and the gradient of the energy with respect to the keys in gradient_keys"""

        def _energy_gradient(variables, data):
            def _etot(variables, inputs):
                if "strain" in inputs:
                    scaling = inputs["strain"]
                    batch_index = data["batch_index"]
                    coordinates = (
                        inputs["coordinates"]
                        if "coordinates" in inputs
                        else data["coordinates"]
                    )
                    coordinates = jax.vmap(jnp.matmul)(
                        coordinates, scaling[batch_index]
                    )
                    inputs = {**inputs, "coordinates": coordinates}
                    if "cells" in inputs or "cells" in data:
                        cells = inputs["cells"] if "cells" in inputs else data["cells"]
                        cells = jax.vmap(jnp.matmul)(cells, scaling)
                        inputs["cells"] = cells
                if "cells" in inputs:
                    reciprocal_cells = jnp.linalg.inv(inputs["cells"])
                    inputs = {**inputs, "reciprocal_cells": reciprocal_cells}
                energy, out = self._total_energy_raw(variables, {**data, **inputs})
                return energy.sum(), out

            if "strain" in gradient_keys and "strain" not in data:
                data = {
                    **data,
                    "strain": jnp.array(
                        np.eye(3)[None, :, :].repeat(data["natoms"].shape[0], axis=0)
                    ),
                }
            inputs = {k: data[k] for k in gradient_keys}
            de, out = jax.grad(_etot, argnums=1, has_aux=True)(variables, inputs)

            return (
                out["total_energy"],
                de,
                {**out, **{"dEd_" + k: de[k] for k in gradient_keys}},
            )

        if variables_as_input:
            energy_gradient = _energy_gradient
        else:

            def energy_gradient(data):
                return _energy_gradient(self.variables, data)

        if jit:
            return jax.jit(energy_gradient)
        else:
            return energy_gradient

    def preprocess(self, use_gpu=False, verbose=False, **inputs) -> Dict[str, Any]:
        """Apply preprocessing to the input data.

        !!! This is not a pure function => do not apply jax transforms !!!"""
        if self.preproc_state is None:
            # reinitialize_preprocessing already stores the new preproc_state
            out, _ = self.reinitialize_preprocessing(example_data=inputs)
            return out
        elif use_gpu:
            do_check_input = self.preproc_state.get("check_input", True)
            if do_check_input:
                inputs = check_input(inputs)
            preproc_state, inputs = self.preprocessing.atom_padding(
                self.preproc_state, inputs
            )
            inputs = self.preprocessing.process(preproc_state, inputs)
            preproc_state, state_up, out, overflow = (
                self.preprocessing.check_reallocate(preproc_state, inputs)
            )
            if verbose and overflow:
                print("GPU preprocessing: nblist overflow => reallocating nblist")
                print("size updates:", state_up)
        else:
            preproc_state, out = self.preprocessing(self.preproc_state, inputs)

        object.__setattr__(self, "preproc_state", preproc_state)
        return out

    def reinitialize_preprocessing(
        self, rng_key: Optional[jax.random.PRNGKey] = None, example_data=None
    ) -> Tuple[Dict[str, Any], Optional[jax.random.PRNGKey]]:
        """Reinitialize the preprocessing state from example_data (or a generated dummy system)."""
        if rng_key is None:
            rng_key_pre = jax.random.PRNGKey(0)
        else:
            rng_key, rng_key_pre = jax.random.split(rng_key)

        if example_data is None:
            rng_key_sys, rng_key_pre = jax.random.split(rng_key_pre)
            example_data = self.generate_dummy_system(rng_key_sys, n_atoms=10)

        preproc_state, inputs = self.preprocessing.init_with_output(example_data)
        object.__setattr__(self, "preproc_state", preproc_state)
        return inputs, rng_key

    def __call__(
        self, variables: Optional[dict] = None, gpu_preprocessing=False, **inputs
    ) -> Dict[str, Any]:
        """Apply the FENNIX model (preprocess + modules).

        !!! This is not a pure function => do not apply jax transforms !!!
        If you want to apply jax transforms, use self._apply(variables, inputs),
        which is pure, and preprocess the input using self.preprocess.
        """
        if variables is None:
            variables = self.variables
        inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
        output = self._apply(variables, inputs)
        if self.use_atom_padding:
            output = atom_unpadding(output)
        return output

    def total_energy(
        self,
        variables: Optional[dict] = None,
        unit: Optional[Union[float, str]] = None,
        gpu_preprocessing=False,
        **inputs,
    ) -> Tuple[jnp.ndarray, Dict]:
        """Compute the total energy of the system.

        !!! This is not a pure function => do not apply jax transforms !!!
        If you want to apply jax transforms, use self._total_energy(variables, inputs),
        which is pure, and preprocess the input using self.preprocess.
        """
        if variables is None:
            variables = self.variables
        inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
        # def print_shape(path, value):
        #     if isinstance(value, jnp.ndarray):
        #         print(path, value.shape)
        #     else:
        #         print(path, value)
        # jax.tree_util.tree_map_with_path(print_shape, inputs)
        _, output = self._total_energy(variables, inputs)
        if self.use_atom_padding:
            output = atom_unpadding(output)
        e = output["total_energy"]
        if unit is not None:
            model_energy_unit = self.Ha_to_model_energy
            if isinstance(unit, str):
                unit = au.get_multiplier(unit)
            e = e * (unit / model_energy_unit)
        return e, output

    def energy_and_forces(
        self,
        variables: Optional[dict] = None,
        unit: Optional[Union[float, str]] = None,
        gpu_preprocessing=False,
        **inputs,
    ) -> Tuple[jnp.ndarray, jnp.ndarray, Dict]:
        """Compute the total energy and forces of the system.

        !!! This is not a pure function => do not apply jax transforms !!!
        If you want to apply jax transforms, use self._energy_and_forces(variables, inputs),
        which is pure, and preprocess the input using self.preprocess.
        """
        if variables is None:
            variables = self.variables
        inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
        _, _, output = self._energy_and_forces(variables, inputs)
        if self.use_atom_padding:
            output = atom_unpadding(output)
        e = output["total_energy"]
        f = output["forces"]
        if unit is not None:
            model_energy_unit = self.Ha_to_model_energy
            if isinstance(unit, str):
                unit = au.get_multiplier(unit)
            e = e * (unit / model_energy_unit)
            f = f * (unit / model_energy_unit)
        return e, f, output

    def energy_and_forces_and_virial(
        self,
        variables: Optional[dict] = None,
        unit: Optional[Union[float, str]] = None,
        gpu_preprocessing=False,
        **inputs,
    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, Dict]:
        """Compute the total energy, forces and virial of the system.

        !!! This is not a pure function => do not apply jax transforms !!!
        If you want to apply jax transforms, use self._energy_and_forces_and_virial(variables, inputs),
        which is pure, and preprocess the input using self.preprocess.
        """
        if variables is None:
            variables = self.variables
        inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
        _, _, _, output = self._energy_and_forces_and_virial(variables, inputs)
        if self.use_atom_padding:
            output = atom_unpadding(output)
        e = output["total_energy"]
        f = output["forces"]
        v = output["virial_tensor"]
        if unit is not None:
            model_energy_unit = self.Ha_to_model_energy
            if isinstance(unit, str):
                unit = au.get_multiplier(unit)
            e = e * (unit / model_energy_unit)
            f = f * (unit / model_energy_unit)
            v = v * (unit / model_energy_unit)
        return e, f, v, output

    def remove_atom_padding(self, output):
        """Remove atom padding from the output."""
        return atom_unpadding(output)

    def get_model(self) -> Tuple[FENNIXModules, Dict]:
        """Return the model and its variables."""
        return self.modules, self.variables

    def get_preprocessing(self) -> Tuple[PreprocessingChain, Dict]:
        """Return the preprocessing chain and its state."""
        return self.preprocessing, self.preproc_state

    def __setattr__(self, __name: str, __value: Any) -> None:
        if __name == "variables":
            if __value is not None:
                if not (
                    isinstance(__value, dict)
                    or isinstance(__value, OrderedDict)
                    or isinstance(__value, FrozenDict)
                ):
                    raise ValueError(f"{__name} must be a dict")
                object.__setattr__(self, __name, JaxConverter()(__value))
            else:
                raise ValueError(f"{__name} cannot be None")
        elif __name == "preproc_state":
            if __value is not None:
                if not (
                    isinstance(__value, dict)
                    or isinstance(__value, OrderedDict)
                    or isinstance(__value, FrozenDict)
                ):
                    raise ValueError(f"{__name} must be a FrozenDict")
                object.__setattr__(self, __name, freeze(JaxConverter()(__value)))
            else:
                raise ValueError(f"{__name} cannot be None")
        elif self._initializing:
            object.__setattr__(self, __name, __value)
        else:
            raise ValueError(f"{__name} attribute of FENNIX model is immutable.")

    def generate_dummy_system(
        self, rng_key: jax.random.PRNGKey, box_size=None, n_atoms: int = 10
    ) -> Dict[str, Any]:
        """Generate a dummy system for initialization."""
        if box_size is None:
            box_size = 2 * self.cutoff
            for g in self._graphs_properties.values():
                cutoff = g["cutoff"]
                if cutoff is not None:
                    box_size = min(box_size, 2 * g["cutoff"])
        coordinates = np.array(
            jax.random.uniform(rng_key, (n_atoms, 3), maxval=box_size), dtype=np.float64
        )
        species = np.ones((n_atoms,), dtype=np.int32)
        batch_index = np.zeros((n_atoms,), dtype=np.int32)
        natoms = np.array([n_atoms], dtype=np.int32)
        return {
            "species": species,
            "coordinates": coordinates,
            # "graph": graph,
            "batch_index": batch_index,
            "natoms": natoms,
        }

    def summarize(
        self, rng_key: jax.random.PRNGKey = None, example_data=None, **kwargs
    ) -> str:
        """Summarize the model architecture and parameters."""
        head = "Summarizing with example data:\n"
        if rng_key is None:
            rng_key = jax.random.PRNGKey(0)
        if example_data is None:
            head = "Summarizing with dummy 10 atoms system:\n"
            rng_key, rng_key_sys = jax.random.split(rng_key)
            example_data = self.generate_dummy_system(rng_key_sys, n_atoms=10)
        rng_key, rng_key_pre = jax.random.split(rng_key)
        _, inputs = self.preprocessing.init_with_output(example_data)
        return head + nn.tabulate(self.modules, rng_key, **kwargs)(inputs)

    def to_dict(self):
        """Return a dictionary representation of the model."""
        return {
            **self._input_args,
            "energy_terms": self.energy_terms,
            "variables": deepcopy(self.variables),
        }

    def save(self, filename):
        """Save the model to a file."""
        state_dict = self.to_dict()
        state_dict["preprocessing"] = [
            [k, v] for k, v in state_dict["preprocessing"].items()
        ]
        state_dict["modules"] = [[k, v] for k, v in state_dict["modules"].items()]
        with open(filename, "wb") as f:
            f.write(serialization.msgpack_serialize(state_dict))

    @classmethod
    def load(
        cls,
        filename,
        use_atom_padding=False,
        graph_config={},
    ):
        """Load a model from a file."""
        with open(filename, "rb") as f:
            state_dict = serialization.msgpack_restore(f.read())
        state_dict["preprocessing"] = {k: v for k, v in state_dict["preprocessing"]}
        state_dict["modules"] = {k: v for k, v in state_dict["modules"]}
        return cls(
            **state_dict,
            graph_config=graph_config,
            use_atom_padding=use_atom_padding,
        )
```
class FENNIX
Static wrapper for FENNIX models.

The underlying model is a `fennol.models.modules.FENNIXModules` built from the `modules` dictionary, which references registered modules in `fennol.models.modules.MODULES` and provides their initialization parameters.

Since the model is static and contains variables, it must be initialized right away with either `example_data`, `variables` or `rng_key`. If `variables` is provided, it is used directly. If `example_data` is provided, the model is initialized with `example_data` and the resulting variables are stored in the wrapper. If only `rng_key` is provided, the model is initialized with a dummy system and the resulting variables are stored in the wrapper.
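A minimal usage sketch, assuming a model previously saved with `FENNIX.save` (the file name `"model.fnx"` is hypothetical) and that `AtomicUnits` knows the `"eV"` unit; the input keys follow `generate_dummy_system` below:

```python
import numpy as np
from fennol.models.fennix import FENNIX

model = FENNIX.load("model.fnx")  # hypothetical file path

# One water-like system of 3 atoms, described by flat per-atom arrays.
inputs = {
    "species": np.array([8, 1, 1], dtype=np.int32),            # atomic numbers
    "coordinates": np.array(
        [[0.0, 0.0, 0.0], [0.96, 0.0, 0.0], [-0.24, 0.93, 0.0]]
    ),
    "batch_index": np.zeros(3, dtype=np.int32),                # all atoms in system 0
    "natoms": np.array([3], dtype=np.int32),
}

# Stateful preprocessing + jitted model application under the hood.
e, f, out = model.energy_and_forces(unit="eV", **inputs)
```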
def __init__(self, cutoff: float, modules: OrderedDict, preprocessing: OrderedDict = OrderedDict(), example_data=None, rng_key: Optional[jax.random.PRNGKey] = None, variables: Optional[dict] = None, energy_terms: Optional[Sequence[str]] = None, use_atom_padding: bool = False, graph_config: Dict = {}, energy_unit: str = "Ha", **kwargs) -> None

Initialize the FENNIX model.

Arguments:
----------
cutoff: float
    The cutoff radius for the model.
modules: OrderedDict
    The dictionary defining the sequence of FeNNol modules and their parameters.
preprocessing: OrderedDict
    The dictionary defining the sequence of preprocessing modules and their parameters.
example_data: dict
    Example data to initialize the model. If not provided, a dummy system is generated.
rng_key: jax.random.PRNGKey
    The random key to initialize the model. If not provided, jax.random.PRNGKey(0) is used (this should be avoided).
variables: dict
    The variables of the model (i.e. weights, biases and all other tunable parameters). If not provided, the variables are initialized (usually at random).
energy_terms: Sequence[str]
    The energy terms in the model output that will be summed to compute the total energy. If None, the total energy is always zero (useful for non-PES models).
use_atom_padding: bool
    If True, the model will use atom padding for the input data. This is useful when one plans to frequently change the number of atoms in the system (for example during training).
graph_config: dict
    Edits the graph configuration. Mostly used to change a long-range cutoff as a function of the simulation box size.
energy_unit: str
    The energy unit of the model output (default: "Ha").
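A construction sketch from scratch. The FIDs `"MY_EMBEDDING"` and `"MY_READOUT"` and the output key `"energy"` are placeholders: valid names are whatever is registered in `fennol.models.modules.MODULES` (lookup is upper-cased) and whatever keys the chosen modules write to the output dictionary:

```python
from collections import OrderedDict
import jax
from fennol.models.fennix import FENNIX

# Each entry maps a module name to its parameter dict; the optional "FID"
# (or "module_name") selects the registered module class.
modules = OrderedDict(
    embedding={"FID": "MY_EMBEDDING"},   # placeholder FID
    energy={"FID": "MY_READOUT"},        # placeholder FID
)

model = FENNIX(
    cutoff=5.0,                          # neighbor-list cutoff radius
    modules=modules,
    rng_key=jax.random.PRNGKey(42),      # explicit key, avoiding the PRNGKey(0) fallback
    energy_terms=["energy"],             # output keys summed into the total energy
)
```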
179 def set_energy_terms( 180 self, energy_terms: Union[Sequence[str], None], jit: bool = True 181 ) -> None: 182 """Set the energy terms to be computed by the model and prepare the energy and force functions.""" 183 object.__setattr__(self, "energy_terms", energy_terms) 184 if isinstance(energy_terms, str): 185 energy_terms = [energy_terms] 186 187 if energy_terms is None or len(energy_terms) == 0: 188 189 def total_energy(variables, data): 190 out = self.__apply(variables, data) 191 coords = out["coordinates"] 192 nsys = out["natoms"].shape[0] 193 nat = coords.shape[0] 194 dtype = coords.dtype 195 e = jnp.zeros(nsys, dtype=dtype) 196 eat = jnp.zeros(nat, dtype=dtype) 197 out["total_energy"] = e 198 out["atomic_energies"] = eat 199 return e, out 200 201 def energy_and_forces(variables, data): 202 e, out = total_energy(variables, data) 203 f = jnp.zeros_like(out["coordinates"]) 204 out["forces"] = f 205 return e, f, out 206 207 def energy_and_forces_and_virial(variables, data): 208 e, f, out = energy_and_forces(variables, data) 209 v = jnp.zeros( 210 (out["natoms"].shape[0], 3, 3), dtype=out["coordinates"].dtype 211 ) 212 out["virial_tensor"] = v 213 return e, f, v, out 214 215 else: 216 # build the energy and force functions 217 def total_energy(variables, data): 218 out = self.__apply(variables, data) 219 atomic_energies = 0.0 220 system_energies = 0.0 221 species = out["species"] 222 nsys = out["natoms"].shape[0] 223 for term in self.energy_terms: 224 e = out[term] 225 if e.ndim > 1 and e.shape[-1] == 1: 226 e = jnp.squeeze(e, axis=-1) 227 if e.shape[0] == nsys and nsys != species.shape[0]: 228 system_energies += e 229 continue 230 assert e.shape == species.shape 231 atomic_energies += e 232 # atomic_energies = jnp.squeeze(atomic_energies, axis=-1) 233 if isinstance(atomic_energies, jnp.ndarray): 234 if "true_atoms" in out: 235 atomic_energies = jnp.where( 236 out["true_atoms"], atomic_energies, 0.0 237 ) 238 out["atomic_energies"] = atomic_energies 239 energies = jax.ops.segment_sum( 240 atomic_energies, 241 data["batch_index"], 242 num_segments=len(data["natoms"]), 243 ) 244 else: 245 energies = 0.0 246 247 if isinstance(system_energies, jnp.ndarray): 248 if "true_sys" in out: 249 system_energies = jnp.where( 250 out["true_sys"], system_energies, 0.0 251 ) 252 out["system_energies"] = system_energies 253 254 out["total_energy"] = energies + system_energies 255 return out["total_energy"], out 256 257 def energy_and_forces(variables, data): 258 def _etot(variables, coordinates): 259 energy, out = total_energy( 260 variables, {**data, "coordinates": coordinates} 261 ) 262 return energy.sum(), out 263 264 de, out = jax.grad(_etot, argnums=1, has_aux=True)( 265 variables, data["coordinates"] 266 ) 267 out["forces"] = -de 268 269 return out["total_energy"], out["forces"], out 270 271 # def energy_and_forces_and_virial(variables, data): 272 # x = data["coordinates"] 273 # batch_index = data["batch_index"] 274 # if "cells" in data: 275 # cells = data["cells"] 276 # ## cells is a nbatchx3x3 matrix which lines are cell vectors (i.e. 
cells[0,0,:] is the first cell vector of the first system) 277 278 # def _etot(variables, coordinates, cells): 279 # reciprocal_cells = jnp.linalg.inv(cells) 280 # energy, out = total_energy( 281 # variables, 282 # { 283 # **data, 284 # "coordinates": coordinates, 285 # "cells": cells, 286 # "reciprocal_cells": reciprocal_cells, 287 # }, 288 # ) 289 # return energy.sum(), out 290 291 # (dedx, dedcells), out = jax.grad(_etot, argnums=(1, 2), has_aux=True)( 292 # variables, x, cells 293 # ) 294 # f= -dedx 295 # out["forces"] = f 296 # else: 297 # _,f,out = energy_and_forces(variables, data) 298 299 # vir = -jax.ops.segment_sum( 300 # f[:, :, None] * x[:, None, :], 301 # batch_index, 302 # num_segments=len(data["natoms"]), 303 # ) 304 305 # if "cells" in data: 306 # # dvir = jax.vmap(jnp.matmul)(dedcells, cells.transpose(0, 2, 1)) 307 # dvir = jnp.einsum("...ki,...kj->...ij", dedcells, cells) 308 # nsys = data["natoms"].shape[0] 309 # if cells.shape[0]==1 and nsys>1: 310 # dvir = dvir / nsys 311 # vir = vir + dvir 312 313 # out["virial_tensor"] = vir 314 315 # return out["total_energy"], f, vir, out 316 317 def energy_and_forces_and_virial(variables, data): 318 x = data["coordinates"] 319 scaling = jnp.asarray( 320 np.eye(3)[None, :, :].repeat(data["natoms"].shape[0], axis=0) 321 ) 322 def _etot(variables, coordinates, scaling): 323 batch_index = data["batch_index"] 324 coordinates = jax.vmap(jnp.matmul)( 325 coordinates, scaling[batch_index] 326 ) 327 inputs = {**data, "coordinates": coordinates} 328 if "cells" in data: 329 ## cells is a nbatchx3x3 matrix which lines are cell vectors (i.e. cells[0,0,:] is the first cell vector of the first system) 330 cells = jax.vmap(jnp.matmul)(data["cells"], scaling) 331 reciprocal_cells = jnp.linalg.inv(cells) 332 inputs["cells"] = cells 333 inputs["reciprocal_cells"] = reciprocal_cells 334 energy, out = total_energy(variables, inputs) 335 return energy.sum(), out 336 337 (dedx, vir), out = jax.grad(_etot, argnums=(1, 2), has_aux=True)( 338 variables, x, scaling 339 ) 340 f = -dedx 341 out["forces"] = f 342 out["virial_tensor"] = vir 343 344 return out["total_energy"], f, vir, out 345 346 object.__setattr__(self, "_total_energy_raw", total_energy) 347 if jit: 348 object.__setattr__(self, "_total_energy", jax.jit(total_energy)) 349 object.__setattr__(self, "_energy_and_forces", jax.jit(energy_and_forces)) 350 object.__setattr__( 351 self, 352 "_energy_and_forces_and_virial", 353 jax.jit(energy_and_forces_and_virial), 354 ) 355 else: 356 object.__setattr__(self, "_total_energy", total_energy) 357 object.__setattr__(self, "_energy_and_forces", energy_and_forces) 358 object.__setattr__( 359 self, "_energy_and_forces_and_virial", energy_and_forces_and_virial 360 )
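A minimal usage sketch follows (all names are illustrative: `model` is an initialized FENNIX instance, `inputs` a dictionary already run through `model.preprocess`, and "energy" an assumed module output key, which depends on the `modules` configuration):

# Select which output keys are summed into the total energy; this rebuilds
# (and re-jits) the pure energy/force/virial functions.
model.set_energy_terms(["energy"])
e, out = model._total_energy(model.variables, inputs)

# Passing None disables the PES: energies, forces and virials are all zero.
model.set_energy_terms(None)
e, f, out = model._energy_and_forces(model.variables, inputs)  # e == 0, f == 0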
    def get_gradient_function(
        self,
        *gradient_keys: str,
        jit: bool = True,
        variables_as_input: bool = False,
    ):
        """Return a function that computes the total energy and its gradient with respect to the keys in `gradient_keys`."""

        def _energy_gradient(variables, data):
            def _etot(variables, inputs):
                if "strain" in inputs:
                    # apply the strain to coordinates (and cells, if present)
                    scaling = inputs["strain"]
                    batch_index = data["batch_index"]
                    coordinates = (
                        inputs["coordinates"]
                        if "coordinates" in inputs
                        else data["coordinates"]
                    )
                    coordinates = jax.vmap(jnp.matmul)(
                        coordinates, scaling[batch_index]
                    )
                    inputs = {**inputs, "coordinates": coordinates}
                    if "cells" in inputs or "cells" in data:
                        cells = inputs["cells"] if "cells" in inputs else data["cells"]
                        cells = jax.vmap(jnp.matmul)(cells, scaling)
                        inputs["cells"] = cells
                if "cells" in inputs:
                    reciprocal_cells = jnp.linalg.inv(inputs["cells"])
                    inputs = {**inputs, "reciprocal_cells": reciprocal_cells}
                energy, out = self._total_energy_raw(variables, {**data, **inputs})
                return energy.sum(), out

            if "strain" in gradient_keys and "strain" not in data:
                # default strain: identity scaling, one 3x3 matrix per system
                data = {
                    **data,
                    "strain": jnp.array(
                        np.eye(3)[None, :, :].repeat(data["natoms"].shape[0], axis=0)
                    ),
                }
            inputs = {k: data[k] for k in gradient_keys}
            de, out = jax.grad(_etot, argnums=1, has_aux=True)(variables, inputs)

            return (
                out["total_energy"],
                de,
                {**out, **{"dEd_" + k: de[k] for k in gradient_keys}},
            )

        if variables_as_input:
            energy_gradient = _energy_gradient
        else:

            def energy_gradient(data):
                return _energy_gradient(self.variables, data)

        if jit:
            return jax.jit(energy_gradient)
        else:
            return energy_gradient
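The returned function is pure (and jitted by default), so it expects already-preprocessed data; it bypasses `self.preprocess` and calls `_total_energy_raw` directly. A sketch, assuming `model` and a preprocessed `inputs` dictionary as above:

# Gradients of the total energy w.r.t. coordinates and a per-system strain.
egrad = model.get_gradient_function("coordinates", "strain")
e, de, out = egrad(inputs)
forces = -out["dEd_coordinates"]  # same sign convention as energy_and_forces
virial = out["dEd_strain"]  # equals virial_tensor from energy_and_forces_and_virial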
    def preprocess(
        self, use_gpu: bool = False, verbose: bool = False, **inputs
    ) -> Dict[str, Any]:
        """Apply preprocessing to the input data.

        !!! This is not a pure function => do not apply jax transforms !!!
        """
        if self.preproc_state is None:
            out, _ = self.reinitialize_preprocessing(example_data=inputs)
            # reinitialize_preprocessing stored a fresh state on the wrapper
            preproc_state = self.preproc_state
        elif use_gpu:
            do_check_input = self.preproc_state.get("check_input", True)
            if do_check_input:
                inputs = check_input(inputs)
            preproc_state, inputs = self.preprocessing.atom_padding(
                self.preproc_state, inputs
            )
            inputs = self.preprocessing.process(preproc_state, inputs)
            preproc_state, state_up, out, overflow = (
                self.preprocessing.check_reallocate(preproc_state, inputs)
            )
            if verbose and overflow:
                print("GPU preprocessing: nblist overflow => reallocating nblist")
                print("size updates:", state_up)
        else:
            preproc_state, out = self.preprocessing(self.preproc_state, inputs)

        object.__setattr__(self, "preproc_state", preproc_state)
        return out
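Because `preprocess` mutates `preproc_state` on the wrapper, it must stay outside of `jax.jit`/`jax.grad`. A sketch with illustrative values, using the module's own imports and the same input keys that `generate_dummy_system` produces:

# a hypothetical 3-atom (water-like) system in a single batch
species = np.array([8, 1, 1], dtype=np.int32)
coordinates = np.random.uniform(0.0, 5.0, (3, 3))
batch_index = np.zeros(3, dtype=np.int32)
natoms = np.array([3], dtype=np.int32)

inputs = model.preprocess(
    species=species, coordinates=coordinates,
    batch_index=batch_index, natoms=natoms,
)
# `inputs` now carries the neighbor graph(s) and can be fed to the pure,
# jitted functions such as model._total_energy(model.variables, inputs).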
    def reinitialize_preprocessing(
        self, rng_key: Optional[jax.random.PRNGKey] = None, example_data=None
    ) -> Tuple[Dict[str, Any], Optional[jax.random.PRNGKey]]:
        """Reinitialize the preprocessing state from `example_data` (or a dummy system) and return the preprocessed inputs along with the remaining rng key."""
        if rng_key is None:
            rng_key_pre = jax.random.PRNGKey(0)
        else:
            rng_key, rng_key_pre = jax.random.split(rng_key)

        if example_data is None:
            rng_key_sys, rng_key_pre = jax.random.split(rng_key_pre)
            example_data = self.generate_dummy_system(rng_key_sys, n_atoms=10)

        preproc_state, inputs = self.preprocessing.init_with_output(example_data)
        object.__setattr__(self, "preproc_state", preproc_state)
        return inputs, rng_key
    def total_energy(
        self,
        variables: Optional[dict] = None,
        unit: Optional[Union[float, str]] = None,
        gpu_preprocessing: bool = False,
        **inputs,
    ) -> Tuple[jnp.ndarray, Dict]:
        """Compute the total energy of the system.

        !!! This is not a pure function => do not apply jax transforms !!!
        If you want to apply jax transforms, use `self._total_energy(variables, inputs)`,
        which is pure, and preprocess the input with `self.preprocess`.
        """
        if variables is None:
            variables = self.variables
        inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
        _, output = self._total_energy(variables, inputs)
        if self.use_atom_padding:
            output = atom_unpadding(output)
        e = output["total_energy"]
        if unit is not None:
            model_energy_unit = self.Ha_to_model_energy
            if isinstance(unit, str):
                unit = au.get_multiplier(unit)
            e = e * (unit / model_energy_unit)
        return e, output
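The `unit` argument rescales the returned energy by `au.get_multiplier(unit) / self.Ha_to_model_energy`, i.e. from the model's energy unit to the requested one. A sketch, reusing the arrays from the preprocessing example above and assuming the `AtomicUnits` helper recognizes "eV":

e, out = model.total_energy(
    unit="eV",
    species=species, coordinates=coordinates,
    batch_index=batch_index, natoms=natoms,
)
print(e.shape)  # one energy per system, here (1,)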
    def energy_and_forces(
        self,
        variables: Optional[dict] = None,
        unit: Optional[Union[float, str]] = None,
        gpu_preprocessing: bool = False,
        **inputs,
    ) -> Tuple[jnp.ndarray, jnp.ndarray, Dict]:
        """Compute the total energy and forces of the system.

        !!! This is not a pure function => do not apply jax transforms !!!
        If you want to apply jax transforms, use `self._energy_and_forces(variables, inputs)`,
        which is pure, and preprocess the input with `self.preprocess`.
        """
        if variables is None:
            variables = self.variables
        inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
        _, _, output = self._energy_and_forces(variables, inputs)
        if self.use_atom_padding:
            output = atom_unpadding(output)
        e = output["total_energy"]
        f = output["forces"]
        if unit is not None:
            model_energy_unit = self.Ha_to_model_energy
            if isinstance(unit, str):
                unit = au.get_multiplier(unit)
            e = e * (unit / model_energy_unit)
            f = f * (unit / model_energy_unit)
        return e, f, output
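Forces are minus the gradient of the summed total energy with respect to the coordinates (see `set_energy_terms` above) and are rescaled by the same factor as the energy. A sketch reusing the arrays above:

e, f, out = model.energy_and_forces(
    species=species, coordinates=coordinates,
    batch_index=batch_index, natoms=natoms,
)
assert f.shape == coordinates.shape  # one 3-vector per atom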
    def energy_and_forces_and_virial(
        self,
        variables: Optional[dict] = None,
        unit: Optional[Union[float, str]] = None,
        gpu_preprocessing: bool = False,
        **inputs,
    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, Dict]:
        """Compute the total energy, forces and virial tensor of the system.

        !!! This is not a pure function => do not apply jax transforms !!!
        If you want to apply jax transforms, use `self._energy_and_forces_and_virial(variables, inputs)`,
        which is pure, and preprocess the input with `self.preprocess`.
        """
        if variables is None:
            variables = self.variables
        inputs = self.preprocess(use_gpu=gpu_preprocessing, **inputs)
        _, _, _, output = self._energy_and_forces_and_virial(variables, inputs)
        if self.use_atom_padding:
            output = atom_unpadding(output)
        e = output["total_energy"]
        f = output["forces"]
        v = output["virial_tensor"]
        if unit is not None:
            model_energy_unit = self.Ha_to_model_energy
            if isinstance(unit, str):
                unit = au.get_multiplier(unit)
            e = e * (unit / model_energy_unit)
            f = f * (unit / model_energy_unit)
            v = v * (unit / model_energy_unit)
        return e, f, v, output
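The virial tensor returned here is the raw strain derivative of the total energy (no minus sign is applied, see `set_energy_terms`). A common diagnostic built from it is an isotropic pressure estimate; the sketch below assumes a periodic system with `cells` provided and a convention where positive pressure means compression, so verify the sign against your application:

cells = 10.0 * np.eye(3)[None, :, :]  # hypothetical single cubic box
e, f, v, out = model.energy_and_forces_and_virial(
    species=species, coordinates=coordinates,
    batch_index=batch_index, natoms=natoms, cells=cells,
)
volume = jnp.abs(jnp.linalg.det(cells))                      # per-system volume
pressure = -jnp.trace(v, axis1=1, axis2=2) / (3.0 * volume)  # check sign convention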
    def remove_atom_padding(self, output):
        """Remove atom padding from the output."""
        return atom_unpadding(output)
    def get_model(self) -> Tuple[FENNIXModules, Dict]:
        """Return the model and its variables."""
        return self.modules, self.variables
    def get_preprocessing(self) -> Tuple[PreprocessingChain, Dict]:
        """Return the preprocessing chain and its state."""
        return self.preprocessing, self.preproc_state
    def generate_dummy_system(
        self, rng_key: jax.random.PRNGKey, box_size=None, n_atoms: int = 10
    ) -> Dict[str, Any]:
        """Generate a dummy system for initialization."""
        if box_size is None:
            # use the smallest (doubled) graph cutoff so all atoms can interact
            box_size = 2 * self.cutoff
            for g in self._graphs_properties.values():
                cutoff = g["cutoff"]
                if cutoff is not None:
                    box_size = min(box_size, 2 * cutoff)
        coordinates = np.array(
            jax.random.uniform(rng_key, (n_atoms, 3), maxval=box_size),
            dtype=np.float64,
        )
        species = np.ones((n_atoms,), dtype=np.int32)
        batch_index = np.zeros((n_atoms,), dtype=np.int32)
        natoms = np.array([n_atoms], dtype=np.int32)
        return {
            "species": species,
            "coordinates": coordinates,
            "batch_index": batch_index,
            "natoms": natoms,
        }
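A sketch using the dummy system to (re)initialize preprocessing, e.g. after changing the graph configuration:

dummy = model.generate_dummy_system(jax.random.PRNGKey(0), n_atoms=10)
inputs, _ = model.reinitialize_preprocessing(example_data=dummy)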
    def summarize(
        self, rng_key: Optional[jax.random.PRNGKey] = None, example_data=None, **kwargs
    ) -> str:
        """Summarize the model architecture and parameters."""
        head = "Summarizing with example data:\n"
        if rng_key is None:
            rng_key = jax.random.PRNGKey(0)
        if example_data is None:
            head = "Summarizing with a dummy 10-atom system:\n"
            rng_key, rng_key_sys = jax.random.split(rng_key)
            example_data = self.generate_dummy_system(rng_key_sys, n_atoms=10)
        rng_key, rng_key_pre = jax.random.split(rng_key)
        _, inputs = self.preprocessing.init_with_output(example_data)
        return head + nn.tabulate(self.modules, rng_key, **kwargs)(inputs)
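A sketch; keyword arguments are forwarded to `flax.linen.tabulate`:

print(model.summarize(depth=1))  # depth is a standard nn.tabulate kwarg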
    def to_dict(self):
        """Return a dictionary representation of the model."""
        return {
            **self._input_args,
            "energy_terms": self.energy_terms,
            "variables": deepcopy(self.variables),
        }
    def save(self, filename):
        """Save the model to a file."""
        state_dict = self.to_dict()
        # store module order explicitly as [key, value] lists for serialization
        state_dict["preprocessing"] = [
            [k, v] for k, v in state_dict["preprocessing"].items()
        ]
        state_dict["modules"] = [[k, v] for k, v in state_dict["modules"].items()]
        with open(filename, "wb") as f:
            f.write(serialization.msgpack_serialize(state_dict))
    @classmethod
    def load(
        cls,
        filename,
        use_atom_padding=False,
        graph_config={},
    ):
        """Load a model from a file."""
        with open(filename, "rb") as f:
            state_dict = serialization.msgpack_restore(f.read())
        # rebuild the ordered module dictionaries from the [key, value] lists
        state_dict["preprocessing"] = {k: v for k, v in state_dict["preprocessing"]}
        state_dict["modules"] = {k: v for k, v in state_dict["modules"]}
        return cls(
            **state_dict,
            graph_config=graph_config,
            use_atom_padding=use_atom_padding,
        )
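A round-trip sketch (the file name is arbitrary):

model.save("model.fnx")
model2 = FENNIX.load("model.fnx", use_atom_padding=True)
# model2 carries the same module definitions, variables and energy_terms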