fennol.training.io
import os, io, sys
import numpy as np
from scipy.spatial.transform import Rotation
from collections import defaultdict
import pickle
import glob
from flax import traverse_util
from typing import Any, Dict, List, Tuple, Union, Optional, Callable, Sequence
from .databases import DBDataset, H5Dataset
from ..models.preprocessing import AtomPadding

import json
import yaml

try:
    import tomlkit
except ImportError:
    tomlkit = None

try:
    from torch.utils.data import DataLoader
except ImportError:
    raise ImportError(
        "PyTorch is required for training. Install the CPU version from https://pytorch.org/get-started/locally/"
    )

from ..models import FENNIX
def load_configuration(config_file: str) -> Dict[str, Any]:
    if config_file.endswith(".json"):
        parameters = json.load(open(config_file))
    elif config_file.endswith(".yaml") or config_file.endswith(".yml"):
        parameters = yaml.load(open(config_file), Loader=yaml.FullLoader)
    elif tomlkit is not None and config_file.endswith(".toml"):
        parameters = tomlkit.loads(open(config_file).read())
    else:
        supported_formats = [".json", ".yaml", ".yml"]
        if tomlkit is not None:
            supported_formats.append(".toml")
        raise ValueError(
            f"Unknown config file format. Supported formats: {supported_formats}"
        )
    return parameters
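A minimal usage sketch. The file name and the keys read from it are illustrative, not part of the library; any structure accepted by the training scripts can appear in the configuration file.

from fennol.training.io import load_configuration

# Parse a configuration file into a plain dictionary (JSON, YAML or, if tomlkit
# is installed, TOML). "input.yaml" and the keys below are hypothetical examples.
parameters = load_configuration("input.yaml")
print(parameters.get("model_file"), parameters["training"].get("batch_size"))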
def load_dataset(
    dspath: str,
    batch_size: int,
    rename_refs: dict = {},
    infinite_iterator: bool = False,
    atom_padding: bool = False,
    ref_keys: Optional[Sequence[str]] = None,
    split_data_inputs: bool = False,
    np_rng: Optional[np.random.Generator] = None,
    train_val_split: bool = True,
    training_parameters: dict = {},
    add_flags: Sequence[str] = ["training"],
    fprec: str = "float32",
):
    """
    Load a dataset and return dataloaders (or infinite iterators) over training and validation batches.

    Args:
        dspath (str): Path to the dataset: a '.db', '.h5'/'.hdf5', '.pkl'/'.pickle' file,
            or a directory containing pickled shards.
        batch_size (int): Number of samples per batch.
        rename_refs (dict, optional): Mapping from dataset keys to new names, applied after
            collation. Default is an empty dict.
        training_parameters (dict, optional): Additional options such as 'pbc_training',
            'random_rotation', 'noise_sigma' and 'shuffle_dataset'.

    Returns:
        tuple: A pair (training, validation) of dataloaders, or of infinite iterators if
            `infinite_iterator` is True. Only the training object is returned if
            `train_val_split` is False.
            Each dataset entry must provide a "species" key with the atomic numbers of the
            atoms in the system. Arrays are concatenated along the first axis and the
            following keys are added to distinguish between the systems:
            - 'natoms': np.ndarray. Number of atoms in each system.
            - 'batch_index': np.ndarray. Index of the system to which each atom belongs.
            Keys listed in `rename_refs` are renamed according to the provided mapping.
    """

    assert isinstance(
        training_parameters, dict
    ), "training_parameters must be a dictionary."
    assert isinstance(
        rename_refs, dict
    ), "rename_refs must be a dictionary with the keys to rename."

    pbc_training = training_parameters.get("pbc_training", False)
    minimum_image = training_parameters.get("minimum_image", False)

    coordinates_ref_key = training_parameters.get("coordinates_ref_key", None)

    input_keys = [
        "species",
        "coordinates",
        "natoms",
        "batch_index",
        "total_charge",
        "flags",
    ]
    if pbc_training:
        input_keys += ["cells", "reciprocal_cells"]
    if atom_padding:
        input_keys += ["true_atoms", "true_sys"]
    if coordinates_ref_key is not None:
        input_keys += ["system_index", "system_sign"]

    flags = {f: None for f in add_flags}
    if minimum_image and pbc_training:
        flags["minimum_image"] = None

    additional_input_keys = set(training_parameters.get("additional_input_keys", []))
    additional_input_keys_ = set()
    for key in additional_input_keys:
        if key not in input_keys:
            additional_input_keys_.add(key)
    additional_input_keys = additional_input_keys_

    all_inputs = set(input_keys + list(additional_input_keys))

    extract_all_keys = ref_keys is None
    if ref_keys is not None:
        ref_keys = set(ref_keys)
        ref_keys_ = set()
        for key in ref_keys:
            if key not in all_inputs:
                ref_keys_.add(key)

    random_rotation = training_parameters.get("random_rotation", False)
    if random_rotation:
        assert np_rng is not None, "np_rng must be provided for random rotations."

        apply_rotation = {
            1: lambda x, r: x @ r,
            -1: lambda x, r: np.einsum("...kn,kj->...jn", x, r),
            2: lambda x, r: np.einsum("li,...lk,kj->...ij", r, x, r),
        }
        valid_rotations = tuple(apply_rotation.keys())
        rotated_keys = {
            "coordinates": 1,
            "forces": 1,
            "virial_tensor": 2,
            "stress_tensor": 2,
            "virial": 2,
            "stress": 2,
        }
        if pbc_training:
            rotated_keys["cells"] = 1
        user_rotated_keys = dict(training_parameters.get("rotated_keys", {}))
        for k, v in user_rotated_keys.items():
            assert (
                v in valid_rotations
            ), f"Invalid rotation type for key {k}. Valid values are {valid_rotations}"
            rotated_keys[k] = v

        print(
            "Applying random rotations to the following keys if present:",
            list(rotated_keys.keys()),
        )

        def apply_random_rotations(output, nbatch):
            euler_angles = np_rng.uniform(0.0, 2 * np.pi, (nbatch, 3))
            r = [
                Rotation.from_euler("xyz", euler_angles[i]).as_matrix().T
                for i in range(nbatch)
            ]
            for k, l in rotated_keys.items():
                if k in output:
                    for i in range(nbatch):
                        output[k][i] = apply_rotation[l](output[k][i], r[i])

    else:

        def apply_random_rotations(output, nbatch):
            pass

    if pbc_training:
        print("Periodic boundary conditions are active.")
        length_nopbc = training_parameters.get("length_nopbc", 1000.0)

        def add_cell(d, output):
            if "cell" not in d:
                cell = np.asarray(
                    [
                        [length_nopbc, 0.0, 0.0],
                        [0.0, length_nopbc, 0.0],
                        [0.0, 0.0, length_nopbc],
                    ],
                    dtype=fprec,
                )
            else:
                cell = np.asarray(d["cell"], dtype=fprec)
            output["cells"].append(cell.reshape(1, 3, 3))

    else:

        def add_cell(d, output):
            if "cell" in d:
                print(
                    "Warning: 'cell' found in dataset but not training with pbc. Activate pbc_training to use periodic boundary conditions."
                )

    if extract_all_keys:

        def add_other_keys(d, output, atom_shift):
            for k, v in d.items():
                if k in ("cell", "total_charge"):
                    continue
                v_array = np.array(v)
                # Shift atom indices if necessary
                if k.endswith("_atidx"):
                    v_array = v_array + atom_shift
                output[k].append(v_array)

    else:

        def add_other_keys(d, output, atom_shift):
            output["species"].append(np.asarray(d["species"]))
            output["coordinates"].append(np.asarray(d["coordinates"], dtype=fprec))
            for k in additional_input_keys:
                v_array = np.array(d[k])
                # Shift atom indices if necessary
                if k.endswith("_atidx"):
                    v_array = v_array + atom_shift
                output[k].append(v_array)
            for k in ref_keys_:
                v_array = np.array(d[k])
                # Shift atom indices if necessary
                if k.endswith("_atidx"):
                    v_array = v_array + atom_shift
                output[k].append(v_array)
                if k + "_mask" in d:
                    output[k + "_mask"].append(np.asarray(d[k + "_mask"]))

    def add_keys(d, output, atom_shift, batch_index):
        nat = d["species"].shape[0]

        output["natoms"].append(np.asarray([nat]))
        output["batch_index"].append(np.asarray([batch_index] * nat))
        if "total_charge" not in d:
            total_charge = np.asarray(0.0, dtype=fprec)
        else:
            total_charge = np.asarray(d["total_charge"], dtype=fprec)
        output["total_charge"].append(total_charge)

        add_cell(d, output)
        add_other_keys(d, output, atom_shift)

        return atom_shift + nat

    def collate_fn_(batch):
        output = defaultdict(list)
        atom_shift = 0
        batch_index = 0

        for d in batch:
            atom_shift = add_keys(d, output, atom_shift, batch_index)
            batch_index += 1

            if coordinates_ref_key is not None:
                output["system_index"].append(np.asarray([batch_index - 1]))
                output["system_sign"].append(np.asarray([1]))
                if coordinates_ref_key in d:
                    dref = {**d, "coordinates": d[coordinates_ref_key]}
                    atom_shift = add_keys(dref, output, atom_shift, batch_index)
                    output["system_index"].append(np.asarray([batch_index - 1]))
                    output["system_sign"].append(np.asarray([-1]))
                    batch_index += 1

        nbatch_ = len(output["natoms"])
        apply_random_rotations(output, nbatch_)

        # Stack and concatenate the arrays
        for k, v in output.items():
            if v[0].ndim == 0:
                v = np.stack(v)
            else:
                v = np.concatenate(v, axis=0)
            if np.issubdtype(v.dtype, np.floating):
                v = v.astype(fprec)
            output[k] = v

        if "cells" in output and pbc_training:
            output["reciprocal_cells"] = np.linalg.inv(output["cells"])

        # Rename necessary keys
        for kold, knew in rename_refs.items():
            assert (
                knew not in output
            ), f"Cannot rename key {kold} to {knew}. Key {knew} already present."
            if kold in output:
                output[knew] = output.pop(kold)

        output["flags"] = flags
        return output

    collate_layers_train = [collate_fn_]
    collate_layers_valid = [collate_fn_]

    ### collate preprocessing
    # add noise to the training data
    noise_sigma = training_parameters.get("noise_sigma", None)
    if noise_sigma is not None:
        assert isinstance(noise_sigma, dict), "noise_sigma should be a dictionary"

        for sigma in noise_sigma.values():
            assert sigma >= 0, "Noise sigma should be a positive number"

        print("Adding noise to the training data:")
        for key, sigma in noise_sigma.items():
            print(f" - {key} with sigma = {sigma}")

        assert np_rng is not None, "np_rng must be provided for adding noise."

        def collate_with_noise(batch):
            for key, sigma in noise_sigma.items():
                if key in batch and sigma > 0:
                    batch[key] += np_rng.normal(0, sigma, batch[key].shape).astype(
                        batch[key].dtype
                    )
            return batch

        collate_layers_train.append(collate_with_noise)

    if atom_padding:
        padder = AtomPadding(add_sys=training_parameters.get("padder_add_sys", 0))
        padder_state = padder.init()

        def collate_with_padding(batch):
            padder_state_up, output = padder(padder_state, batch)
            padder_state.update(padder_state_up)
            return output

        collate_layers_train.append(collate_with_padding)
        collate_layers_valid.append(collate_with_padding)

    if split_data_inputs:
        print("Input keys:", all_inputs)
        print("Ref keys:", ref_keys)

        def collate_split(batch):
            inputs = {}
            refs = {}
            for k, v in batch.items():
                if k in all_inputs:
                    inputs[k] = v
                if k in ref_keys:
                    refs[k] = v
                if k.endswith("_mask") and k[:-5] in ref_keys:
                    refs[k] = v
            return inputs, refs

        collate_layers_train.append(collate_split)
        collate_layers_valid.append(collate_split)

    ### apply all collate preprocessing
    if len(collate_layers_train) == 1:
        collate_fn_train = collate_layers_train[0]
    else:

        def collate_fn_train(batch):
            for layer in collate_layers_train:
                batch = layer(batch)
            return batch

    if len(collate_layers_valid) == 1:
        collate_fn_valid = collate_layers_valid[0]
    else:

        def collate_fn_valid(batch):
            for layer in collate_layers_valid:
                batch = layer(batch)
            return batch

    print(f"Loading dataset from {dspath}...", end="")
    sharded_training = False
    if dspath.endswith(".db"):
        dataset = {}
        if train_val_split:
            dataset["training"] = DBDataset(dspath, table="training")
            dataset["validation"] = DBDataset(dspath, table="validation")
        else:
            dataset = DBDataset(dspath)
    elif dspath.endswith(".h5") or dspath.endswith(".hdf5"):
        dataset = {}
        if train_val_split:
            dataset["training"] = H5Dataset(dspath, table="training")
            dataset["validation"] = H5Dataset(dspath, table="validation")
        else:
            dataset = H5Dataset(dspath)
    elif dspath.endswith(".pkl") or dspath.endswith(".pickle"):
        with open(dspath, "rb") as f:
            dataset = pickle.load(f)
        if not train_val_split and isinstance(dataset, dict):
            dataset = dataset["training"]
    elif os.path.isdir(dspath):
        if train_val_split:
            dataset = {}
            with open(dspath + "/validation.pkl", "rb") as f:
                dataset["validation"] = pickle.load(f)
        else:
            dataset = None

        shard_files = sorted(glob.glob(dspath + "/training_*.pkl"))
        nshards = len(shard_files)
        if nshards == 0:
            raise ValueError("No dataset shards found.")
        elif nshards == 1:
            with open(shard_files[0], "rb") as f:
                if train_val_split:
                    dataset["training"] = pickle.load(f)
                else:
                    dataset = pickle.load(f)
        else:
            print(f"Found {nshards} dataset shards.")
            sharded_training = True

    else:
        raise ValueError(
            f"Unknown dataset format. Supported formats: '.db', '.h5', '.pkl', '.pickle'"
        )
    print(" done.")

    ### BUILD DATALOADERS
    shuffle = training_parameters.get("shuffle_dataset", True)
    if train_val_split:
        dataloader_validation = DataLoader(
            dataset["validation"],
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=collate_fn_valid,
        )

    if sharded_training:

        def iterate_sharded_dataset():
            indices = np.arange(nshards)
            if shuffle:
                assert np_rng is not None, "np_rng must be provided for shuffling."
                np_rng.shuffle(indices)
            for i in indices:
                filename = shard_files[i]
                print(f"# Loading dataset shard from {filename}...", end="")
                with open(filename, "rb") as f:
                    dataset = pickle.load(f)
                print(" done.")
                dataloader = DataLoader(
                    dataset,
                    batch_size=batch_size,
                    shuffle=shuffle,
                    collate_fn=collate_fn_train,
                )
                for batch in dataloader:
                    yield batch

        class DataLoaderSharded:
            def __iter__(self):
                return iterate_sharded_dataset()

        dataloader_training = DataLoaderSharded()
    else:
        dataloader_training = DataLoader(
            dataset["training"] if train_val_split else dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=collate_fn_train,
        )

    if not infinite_iterator:
        if train_val_split:
            return dataloader_training, dataloader_validation
        return dataloader_training

    def next_batch_factory(dataloader):
        while True:
            for batch in dataloader:
                yield batch

    training_iterator = next_batch_factory(dataloader_training)
    if train_val_split:
        validation_iterator = next_batch_factory(dataloader_validation)
        return training_iterator, validation_iterator
    return training_iterator
Load a dataset and return dataloaders (or infinite iterators) over training and validation batches.

Args:
    dspath (str): Path to the dataset: a '.db', '.h5'/'.hdf5', '.pkl'/'.pickle' file, or a directory containing pickled shards.
    batch_size (int): Number of samples per batch.
    rename_refs (dict, optional): Mapping from dataset keys to new names, applied after collation. Default is an empty dict.
    training_parameters (dict, optional): Additional options such as 'pbc_training', 'random_rotation', 'noise_sigma' and 'shuffle_dataset'.

Returns:
    tuple: A pair (training, validation) of dataloaders, or of infinite iterators if infinite_iterator is True. Only the training object is returned if train_val_split is False.
    Each dataset entry must provide a "species" key with the atomic numbers of the atoms in the system. Arrays are concatenated along the first axis and the following keys are added to distinguish between the systems:
    - 'natoms': np.ndarray. Number of atoms in each system.
    - 'batch_index': np.ndarray. Index of the system to which each atom belongs.
    Keys listed in rename_refs are renamed according to the provided mapping.
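For orientation, a hedged usage sketch. It assumes a pickled dataset at a hypothetical path "dataset.pkl" containing "training" and "validation" lists whose entries carry "species", "coordinates", "total_energy" and "forces"; all names besides the module's own are illustrative.

import numpy as np
from fennol.training.io import load_dataset

np_rng = np.random.default_rng(0)
train_dl, valid_dl = load_dataset(
    dspath="dataset.pkl",            # pickle with {"training": [...], "validation": [...]}
    batch_size=16,
    ref_keys=["total_energy", "forces"],
    split_data_inputs=True,          # each batch is returned as an (inputs, refs) pair
    np_rng=np_rng,
    training_parameters={"random_rotation": True, "noise_sigma": {"coordinates": 0.01}},
)
inputs, refs = next(iter(train_dl))
print(inputs["species"].shape, refs["forces"].shape)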
def load_model(
    parameters: Dict[str, Any],
    model_file: Optional[str] = None,
    rng_key: Optional[str] = None,
) -> FENNIX:
    """
    Load a FENNIX model from a file or create a new one.

    Args:
        parameters (dict): A dictionary of parameters for the model.
        model_file (str, optional): The path to a saved model file to load.
        rng_key (optional): Random key used to initialize a new model.
            Required if no saved model is provided.

    Returns:
        FENNIX: A FENNIX model object.
    """
    print_model = parameters["training"].get("print_model", False)
    if model_file is None:
        model_file = parameters.get("model_file", None)
    if model_file is not None and os.path.exists(model_file):
        model = FENNIX.load(model_file, use_atom_padding=False)
        if print_model:
            print(model.summarize())
        print(f"Restored model from '{model_file}'.")
    else:
        assert (
            rng_key is not None
        ), "rng_key must be specified if model_file is not provided."
        model_params = parameters["model"]
        if isinstance(model_params, str):
            assert os.path.exists(
                model_params
            ), f"Model file '{model_params}' not found."
            model = FENNIX.load(model_params, use_atom_padding=False)
            print(f"Restored model from '{model_params}'.")
        else:
            model = FENNIX(**model_params, rng_key=rng_key, use_atom_padding=False)
        if print_model:
            print(model.summarize())
    return model
Load a FENNIX model from a file or create a new one.

Args:
    parameters (dict): A dictionary of parameters for the model.
    model_file (str, optional): The path to a saved model file to load.
    rng_key (optional): Random key used to initialize a new model. Required if no saved model is provided.

Returns:
    FENNIX: A FENNIX model object.
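A minimal sketch of the two paths through load_model; the configuration file name and the model path are placeholders.

from fennol.training.io import load_configuration, load_model

parameters = load_configuration("input.yaml")   # must contain a "training" section
# Restore from an existing checkpoint...
model = load_model(parameters, model_file="path/to/model_file")
# ...or create a new model from parameters["model"] (an rng_key is then required):
# model = load_model(parameters, rng_key=rng_key)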
def copy_parameters(variables, variables_ref, params):
    """Merge two flax variable trees: leaves of `variables` whose path (with the leading
    collection name stripped, joined with '/', lowercased) starts with one of the prefixes
    in `params` are replaced by the corresponding leaves of `variables_ref`."""

    def merge_params(full_path_, v, v_ref):
        full_path = "/".join(full_path_[1:]).lower()
        status = (False, "")
        for path in params:
            if full_path.startswith(path.lower()) and len(path) > len(status[1]):
                status = (True, path)
        return v_ref if status[0] else v

    flat = traverse_util.flatten_dict(variables, keep_empty_nodes=False)
    flat_ref = traverse_util.flatten_dict(variables_ref, keep_empty_nodes=False)
    return traverse_util.unflatten_dict(
        {
            k: merge_params(k, v, flat_ref[k]) if k in flat_ref else v
            for k, v in flat.items()
        }
    )
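A hedged sketch of the intended use. The variable collections and the "embedding" module name are placeholders: since the leading collection name (e.g. "params") is stripped before the case-insensitive prefix match, passing "embedding" selects every leaf under params/embedding.

# variables_new: freshly initialized flax variables; variables_pretrained: reference
# variables to copy from. "embedding" is a hypothetical module name.
merged = copy_parameters(variables_new, variables_pretrained, params=["embedding"])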
class TeeLogger(object):
    """Write stream that duplicates its output to a log file and to the bound stdout."""

    def __init__(self, name):
        self.file = io.TextIOWrapper(open(name, "wb"), write_through=True)
        self.stdout = None

    def __del__(self):
        self.close()

    def write(self, data):
        self.file.write(data)
        self.stdout.write(data)
        self.flush()

    def close(self):
        self.file.close()

    def flush(self):
        self.file.flush()

    def bind_stdout(self):
        if isinstance(sys.stdout, TeeLogger):
            raise ValueError("stdout already bound to a Tee instance.")
        if self.stdout is not None:
            raise ValueError("stdout already bound.")
        self.stdout = sys.stdout
        sys.stdout = self

    def unbind_stdout(self):
        if self.stdout is None:
            raise ValueError("stdout is not bound.")
        sys.stdout = self.stdout
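A short sketch of how the logger is meant to be used; the log file name is arbitrary.

from fennol.training.io import TeeLogger

logger = TeeLogger("train.log")
logger.bind_stdout()            # from here on, print() writes to the terminal and to train.log
try:
    print("starting training...")
finally:
    logger.unbind_stdout()      # restore the original sys.stdout
    logger.close()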