fennol.training.io
import os, io, sys
import numpy as np
from scipy.spatial.transform import Rotation
from collections import defaultdict
import pickle
import glob
from flax import traverse_util
from typing import Dict, List, Tuple, Union, Optional, Callable, Sequence
from .databases import DBDataset, H5Dataset
from ..models.preprocessing import AtomPadding
import re
import json
import yaml

try:
    import tomlkit
except ImportError:
    tomlkit = None

try:
    from torch.utils.data import DataLoader
except ImportError:
    raise ImportError(
        "PyTorch is required for training models. Install the CPU version following instructions at https://pytorch.org/get-started/locally/"
    )

from ..models import FENNIX


def load_configuration(config_file: str) -> Dict[str, any]:
    if config_file.endswith(".json"):
        parameters = json.load(open(config_file))
    elif config_file.endswith(".yaml") or config_file.endswith(".yml"):
        parameters = yaml.load(open(config_file), Loader=yaml.FullLoader)
    elif tomlkit is not None and config_file.endswith(".toml"):
        parameters = tomlkit.loads(open(config_file).read())
    else:
        supported_formats = [".json", ".yaml", ".yml"]
        if tomlkit is not None:
            supported_formats.append(".toml")
        raise ValueError(
            f"Unknown config file format. Supported formats: {supported_formats}"
        )
    return parameters


def load_dataset(
    dspath: str,
    batch_size: int,
    batch_size_val: Optional[int] = None,
    rename_refs: dict = {},
    infinite_iterator: bool = False,
    atom_padding: bool = False,
    ref_keys: Optional[Sequence[str]] = None,
    split_data_inputs: bool = False,
    np_rng: Optional[np.random.Generator] = None,
    train_val_split: bool = True,
    training_parameters: dict = {},
    add_flags: Sequence[str] = ["training"],
    fprec: str = "float32",
):
    """
    Load a dataset and return iterators over training and validation batches.

    Args:
        dspath (str): Path to the dataset. Supported formats: '.db' (SQLite), '.h5'/'.hdf5'
            (HDF5), '.pkl'/'.pickle' (pickle), or a directory containing 'training_*.pkl'
            shards and an optional 'validation.pkl'.
        batch_size (int): Number of systems per training batch.
        batch_size_val (int, optional): Number of systems per validation batch.
            Defaults to batch_size.
        rename_refs (dict, optional): Mapping from dataset key names to the key names
            expected downstream; matching keys are renamed in each collated batch.
            Default is an empty dictionary.

    Returns:
        tuple: A training iterator and, if train_val_split is True, a validation iterator.
            Each dataset entry must provide a "species" key with the atomic numbers of the
            atoms in the system. Per-system arrays are concatenated along the first axis and
            the following keys are added to distinguish between systems:
            - 'natoms': np.ndarray. Number of atoms in each system.
            - 'batch_index': np.ndarray. Index of the system to which each atom belongs.
    """

    assert isinstance(
        training_parameters, dict
    ), "training_parameters must be a dictionary."
    assert isinstance(
        rename_refs, dict
    ), "rename_refs must be a dictionary with the keys to rename."

    pbc_training = training_parameters.get("pbc_training", False)
    minimum_image = training_parameters.get("minimum_image", False)

    coordinates_ref_key = training_parameters.get("coordinates_ref_key", None)

    input_keys = [
        "species",
        "coordinates",
        "natoms",
        "batch_index",
        "total_charge",
        "flags",
    ]
    if pbc_training:
        input_keys += ["cells", "reciprocal_cells"]
    if atom_padding:
        input_keys += ["true_atoms", "true_sys"]
    if coordinates_ref_key is not None:
        input_keys += ["system_index", "system_sign"]

    flags = {f: None for f in add_flags}
    if minimum_image and pbc_training:
        flags["minimum_image"] = None

    additional_input_keys = set(training_parameters.get("additional_input_keys", []))
    additional_input_keys_ = set()
    for key in additional_input_keys:
        if key not in input_keys:
            additional_input_keys_.add(key)
    additional_input_keys = additional_input_keys_

    all_inputs = set(input_keys + list(additional_input_keys))

    extract_all_keys = ref_keys is None
    if ref_keys is not None:
        ref_keys = set(ref_keys)
        ref_keys_ = set()
        for key in ref_keys:
            if key not in all_inputs:
                ref_keys_.add(key)

    random_rotation = training_parameters.get("random_rotation", False)
    if random_rotation:
        assert np_rng is not None, "np_rng must be provided for random rotations."

        apply_rotation = {
            1: lambda x, r: x @ r,
            -1: lambda x, r: np.einsum("...kn,kj->...jn", x, r),
            2: lambda x, r: np.einsum("li,...lk,kj->...ij", r, x, r),
        }

        def rotate_2f(x, r):
            assert x.shape[-1] == 6
            # select from 6 components (xx,yy,zz,xy,xz,yz) to form the 3x3 tensor
            indices = np.array([0, 3, 4, 3, 1, 5, 4, 5, 2])
            x = x[..., indices].reshape(*x.shape[:-1], 3, 3)
            x = np.einsum("li,...lk,kj->...ij", r, x, r)
            # select back the 6 components
            indices = np.array([[0, 0], [1, 1], [2, 2], [0, 1], [0, 2], [1, 2]])
            x = x[..., indices[:, 0], indices[:, 1]]
            return x

        apply_rotation[-2] = rotate_2f

        valid_rotations = tuple(apply_rotation.keys())
        rotated_keys = {
            "coordinates": 1,
            "forces": 1,
            "virial_tensor": 2,
            "stress_tensor": 2,
            "virial": 2,
            "stress": 2,
        }
        if pbc_training:
            rotated_keys["cells"] = 1
        user_rotated_keys = dict(training_parameters.get("rotated_keys", {}))
        for k, v in user_rotated_keys.items():
            assert (
                v in valid_rotations
            ), f"Invalid rotation type for key {k}. Valid values are {valid_rotations}"
            rotated_keys[k] = v

        # rotated_vector_keys = set(
        #     ["coordinates", "forces"]
        #     + list(training_parameters.get("rotated_vector_keys", []))
        # )
        # if pbc_training:
        #     rotated_vector_keys.add("cells")

        # rotated_tensor_keys = set(
        #     ["virial_tensor", "stress_tensor", "virial", "stress"]
        #     + list(training_parameters.get("rotated_tensor_keys", []))
        # )
        # assert rotated_vector_keys.isdisjoint(
        #     rotated_tensor_keys
        # ), "Rotated vector keys and rotated tensor keys must be disjoint."
        # rotated_keys = rotated_vector_keys.union(rotated_tensor_keys)

        print(
            "Applying random rotations to the following keys if present:",
            list(rotated_keys.keys()),
        )

        def apply_random_rotations(output, nbatch):
            euler_angles = np_rng.uniform(0.0, 2 * np.pi, (nbatch, 3))
            r = [
                Rotation.from_euler("xyz", euler_angles[i]).as_matrix().T
                for i in range(nbatch)
            ]
            for k, l in rotated_keys.items():
                if k in output:
                    for i in range(nbatch):
                        output[k][i] = apply_rotation[l](output[k][i], r[i])

    else:

        def apply_random_rotations(output, nbatch):
            pass

    flow_matching = training_parameters.get("flow_matching", False)
    if flow_matching:
        if ref_keys is not None:
            ref_keys.add("flow_matching_target")
            if "flow_matching_target" in ref_keys_:
                ref_keys_.remove("flow_matching_target")
        all_inputs.add("flow_matching_time")

        def add_flow_matching(output, nbatch):
            ts = np_rng.uniform(0.0, 1.0, (nbatch,))
            targets = []
            for i in range(nbatch):
                x1 = output["coordinates"][i]
                com = x1.mean(axis=0, keepdims=True)
                x1 = x1 - com
                x0 = np_rng.normal(0.0, 1.0, x1.shape)
                xt = (1 - ts[i]) * x0 + ts[i] * x1
                output["coordinates"][i] = xt
                targets.append(x1 - x0)
            output["flow_matching_target"] = targets
            output["flow_matching_time"] = [np.array(t) for t in ts]

    else:

        def add_flow_matching(output, nbatch):
            pass

    if pbc_training:
        print("Periodic boundary conditions are active.")
        length_nopbc = training_parameters.get("length_nopbc", 1000.0)

        def add_cell(d, output):
            if "cell" not in d:
                cell = np.asarray(
                    [
                        [length_nopbc, 0.0, 0.0],
                        [0.0, length_nopbc, 0.0],
                        [0.0, 0.0, length_nopbc],
                    ],
                    dtype=fprec,
                )
            else:
                cell = np.asarray(d["cell"], dtype=fprec)
            output["cells"].append(cell.reshape(1, 3, 3))

    else:

        def add_cell(d, output):
            if "cell" in d:
                print(
                    "Warning: 'cell' found in dataset but not training with pbc. Activate pbc_training to use periodic boundary conditions."
                )

    if extract_all_keys:

        def add_other_keys(d, output, atom_shift):
            for k, v in d.items():
                if k in ("cell", "total_charge"):
                    continue
                v_array = np.array(v)
                # Shift atom number if necessary
                if k.endswith("_atidx"):
                    v_array = v_array + atom_shift
                output[k].append(v_array)

    else:

        def add_other_keys(d, output, atom_shift):
            output["species"].append(np.asarray(d["species"]))
            output["coordinates"].append(np.asarray(d["coordinates"], dtype=fprec))
            for k in additional_input_keys:
                v_array = np.array(d[k])
                # Shift atom number if necessary
                if k.endswith("_atidx"):
                    v_array = v_array + atom_shift
                output[k].append(v_array)
            for k in ref_keys_:
                v_array = np.array(d[k])
                # Shift atom number if necessary
                if k.endswith("_atidx"):
                    v_array = v_array + atom_shift
                output[k].append(v_array)
                if k + "_mask" in d:
                    output[k + "_mask"].append(np.asarray(d[k + "_mask"]))

    def add_keys(d, output, atom_shift, batch_index):
        nat = d["species"].shape[0]

        output["natoms"].append(np.asarray([nat]))
        output["batch_index"].append(np.asarray([batch_index] * nat))
        if "total_charge" not in d:
            total_charge = np.asarray(0.0, dtype=fprec)
        else:
            total_charge = np.asarray(d["total_charge"], dtype=fprec)
        output["total_charge"].append(total_charge)

        add_cell(d, output)
        add_other_keys(d, output, atom_shift)

        return atom_shift + nat

    def collate_fn_(batch):
        output = defaultdict(list)
        atom_shift = 0
        batch_index = 0

        for d in batch:
            atom_shift = add_keys(d, output, atom_shift, batch_index)
            batch_index += 1

            if coordinates_ref_key is not None:
                output["system_index"].append(np.asarray([batch_index - 1]))
                output["system_sign"].append(np.asarray([1]))
                if coordinates_ref_key in d:
                    dref = {**d, "coordinates": d[coordinates_ref_key]}
                    atom_shift = add_keys(dref, output, atom_shift, batch_index)
                    output["system_index"].append(np.asarray([batch_index - 1]))
                    output["system_sign"].append(np.asarray([-1]))
                    batch_index += 1

        nbatch_ = len(output["natoms"])
        apply_random_rotations(output, nbatch_)
        add_flow_matching(output, nbatch_)

        # Stack and concatenate the arrays
        for k, v in output.items():
            if v[0].ndim == 0:
                v = np.stack(v)
            else:
                v = np.concatenate(v, axis=0)
            if np.issubdtype(v.dtype, np.floating):
                v = v.astype(fprec)
            output[k] = v

        if "cells" in output and pbc_training:
            output["reciprocal_cells"] = np.linalg.inv(output["cells"])

        # Rename necessary keys
        # for key in rename_refs:
        #     if key in output:
        #         output["true_" + key] = output.pop(key)
        for kold, knew in rename_refs.items():
            assert (
                knew not in output
            ), f"Cannot rename key {kold} to {knew}. Key {knew} already present."
            if kold in output:
                output[knew] = output.pop(kold)

        output["flags"] = flags
        return output

    collate_layers_train = [collate_fn_]
    collate_layers_valid = [collate_fn_]

    ### collate preprocessing
    # add noise to the training data
    noise_sigma = training_parameters.get("noise_sigma", None)
    if noise_sigma is not None:
        assert isinstance(noise_sigma, dict), "noise_sigma should be a dictionary"

        for sigma in noise_sigma.values():
            assert sigma >= 0, "Noise sigma should be a positive number"

        print("Adding noise to the training data:")
        for key, sigma in noise_sigma.items():
            print(f" - {key} with sigma = {sigma}")

        assert np_rng is not None, "np_rng must be provided for adding noise."

        def collate_with_noise(batch):
            for key, sigma in noise_sigma.items():
                if key in batch and sigma > 0:
                    batch[key] += np_rng.normal(0, sigma, batch[key].shape).astype(
                        batch[key].dtype
                    )
            return batch

        collate_layers_train.append(collate_with_noise)

    if atom_padding:
        padder = AtomPadding(add_sys=training_parameters.get("padder_add_sys", 0))
        padder_state = padder.init()

        def collate_with_padding(batch):
            padder_state_up, output = padder(padder_state, batch)
            padder_state.update(padder_state_up)
            return output

        collate_layers_train.append(collate_with_padding)
        collate_layers_valid.append(collate_with_padding)

    if split_data_inputs:
        # input_keys += additional_input_keys
        # input_keys = set(input_keys)
        print("Input keys:", all_inputs)
        print("Ref keys:", ref_keys)

        def collate_split(batch):
            inputs = {}
            refs = {}
            for k, v in batch.items():
                if k in all_inputs:
                    inputs[k] = v
                if k in ref_keys:
                    refs[k] = v
                if k.endswith("_mask") and k[:-5] in ref_keys:
                    refs[k] = v
            return inputs, refs

        collate_layers_train.append(collate_split)
        collate_layers_valid.append(collate_split)

    ### apply all collate preprocessing
    if len(collate_layers_train) == 1:
        collate_fn_train = collate_layers_train[0]
    else:

        def collate_fn_train(batch):
            for layer in collate_layers_train:
                batch = layer(batch)
            return batch

    if len(collate_layers_valid) == 1:
        collate_fn_valid = collate_layers_valid[0]
    else:

        def collate_fn_valid(batch):
            for layer in collate_layers_valid:
                batch = layer(batch)
            return batch

    if not os.path.exists(dspath):
        raise ValueError(f"Dataset file '{dspath}' not found.")
    # dspath = training_parameters.get("dspath", None)
    print(f"Loading dataset from {dspath}...", end="")
    # print(f" the following keys will be renamed if present : {rename_refs}")
    sharded_training = False
    if dspath.endswith(".db"):
        dataset = {}
        if train_val_split:
            dataset["training"] = DBDataset(dspath, table="training")
            dataset["validation"] = DBDataset(dspath, table="validation")
        else:
            dataset = DBDataset(dspath)
    elif dspath.endswith(".h5") or dspath.endswith(".hdf5"):
        dataset = {}
        if train_val_split:
            dataset["training"] = H5Dataset(dspath, table="training")
            dataset["validation"] = H5Dataset(dspath, table="validation")
        else:
            dataset = H5Dataset(dspath)
    elif dspath.endswith(".pkl") or dspath.endswith(".pickle"):
        with open(dspath, "rb") as f:
            dataset = pickle.load(f)
        if not train_val_split and isinstance(dataset, dict):
            dataset = dataset["training"]
    elif os.path.isdir(dspath):
        if train_val_split:
            dataset = {}
            with open(dspath + "/validation.pkl", "rb") as f:
                dataset["validation"] = pickle.load(f)
        else:
            dataset = None

        shard_files = sorted(glob.glob(dspath + "/training_*.pkl"))
        nshards = len(shard_files)
        if nshards == 0:
            raise ValueError("No dataset shards found.")
        elif nshards == 1:
            with open(shard_files[0], "rb") as f:
                if train_val_split:
                    dataset["training"] = pickle.load(f)
                else:
                    dataset = pickle.load(f)
        else:
            print(f"Found {nshards} dataset shards.")
            sharded_training = True

    else:
        raise ValueError(
            f"Unknown dataset format. Supported formats: '.db', '.h5', '.pkl', '.pickle'"
        )
    print(" done.")

    ### BUILD DATALOADERS
    # batch_size = training_parameters.get("batch_size", 16)
    shuffle = training_parameters.get("shuffle_dataset", True)
    if train_val_split:
        if batch_size_val is None:
            batch_size_val = batch_size
        dataloader_validation = DataLoader(
            dataset["validation"],
            batch_size=batch_size_val,
            shuffle=shuffle,
            collate_fn=collate_fn_valid,
        )

    if sharded_training:

        def iterate_sharded_dataset():
            indices = np.arange(nshards)
            if shuffle:
                assert np_rng is not None, "np_rng must be provided for shuffling."
                np_rng.shuffle(indices)
            for i in indices:
                filename = shard_files[i]
                print(f"# Loading dataset shard from {filename}...", end="")
                with open(filename, "rb") as f:
                    dataset = pickle.load(f)
                print(" done.")
                dataloader = DataLoader(
                    dataset,
                    batch_size=batch_size,
                    shuffle=shuffle,
                    collate_fn=collate_fn_train,
                )
                for batch in dataloader:
                    yield batch

        class DataLoaderSharded:
            def __iter__(self):
                return iterate_sharded_dataset()

        dataloader_training = DataLoaderSharded()
    else:
        dataloader_training = DataLoader(
            dataset["training"] if train_val_split else dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=collate_fn_train,
        )

    if not infinite_iterator:
        if train_val_split:
            return dataloader_training, dataloader_validation
        return dataloader_training

    def next_batch_factory(dataloader):
        while True:
            for batch in dataloader:
                yield batch

    training_iterator = next_batch_factory(dataloader_training)
    if train_val_split:
        validation_iterator = next_batch_factory(dataloader_validation)
        return training_iterator, validation_iterator
    return training_iterator


def load_model(
    parameters: Dict[str, any],
    model_file: Optional[str] = None,
    rng_key: Optional[str] = None,
) -> FENNIX:
    """
    Load a FENNIX model from a file or create a new one.

    Args:
        parameters (dict): A dictionary of parameters for the model.
        model_file (str, optional): The path to a saved model file to load.

    Returns:
        FENNIX: A FENNIX model object.
    """
    print_model = parameters["training"].get("print_model", False)
    if model_file is None:
        model_file = parameters.get("model_file", None)
    if model_file is not None and os.path.exists(model_file):
        model = FENNIX.load(model_file, use_atom_padding=False)
        if print_model:
            print(model.summarize())
        print(f"Restored model from '{model_file}'.")
    else:
        assert (
            rng_key is not None
        ), "rng_key must be specified if model_file is not provided."
        model_params = parameters["model"]
        if isinstance(model_params, str):
            assert os.path.exists(
                model_params
            ), f"Model file '{model_params}' not found."
            model = FENNIX.load(model_params, use_atom_padding=False)
            print(f"Restored model from '{model_params}'.")
        else:
            model = FENNIX(**model_params, rng_key=rng_key, use_atom_padding=False)
        if print_model:
            print(model.summarize())
    return model


def copy_parameters(variables, variables_ref, params=[".*"]):
    def merge_params(full_path_, v, v_ref):
        full_path = "/".join(full_path_[1:]).lower()
        # status = (False, "")
        for path in params:
            if re.match(path.lower(), full_path):
                # if full_path.startswith(path.lower()) and len(path) > len(status[1]):
                return v_ref
        return v
        # return v_ref if status[0] else v

    flat = traverse_util.flatten_dict(variables, keep_empty_nodes=False)
    flat_ref = traverse_util.flatten_dict(variables_ref, keep_empty_nodes=False)
    return traverse_util.unflatten_dict(
        {
            k: merge_params(k, v, flat_ref[k]) if k in flat_ref else v
            for k, v in flat.items()
        }
    )


class TeeLogger(object):
    def __init__(self, name):
        self.file = io.TextIOWrapper(open(name, "wb"), write_through=True)
        self.stdout = None

    def __del__(self):
        self.close()

    def write(self, data):
        self.file.write(data)
        self.stdout.write(data)
        self.flush()

    def close(self):
        self.file.close()

    def flush(self):
        self.file.flush()

    def bind_stdout(self):
        if isinstance(sys.stdout, TeeLogger):
            raise ValueError("stdout already bound to a Tee instance.")
        if self.stdout is not None:
            raise ValueError("stdout already bound.")
        self.stdout = sys.stdout
        sys.stdout = self

    def unbind_stdout(self):
        if self.stdout is None:
            raise ValueError("stdout is not bound.")
        sys.stdout = self.stdout

load_configuration(config_file: str) -> Dict[str, any]

Load a parameter dictionary from a configuration file. Supported formats are '.json', '.yaml'/'.yml' and, when the optional tomlkit dependency is installed, '.toml'. Any other extension raises a ValueError listing the supported formats.
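
A minimal usage sketch; 'input.yaml' is a placeholder path:

    from fennol.training.io import load_configuration

    # '.json', '.yml' and (with tomlkit installed) '.toml' files are handled the same way.
    parameters = load_configuration("input.yaml")
    print(list(parameters.keys()))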

load_dataset(dspath, batch_size, batch_size_val=None, rename_refs={}, infinite_iterator=False, atom_padding=False, ref_keys=None, split_data_inputs=False, np_rng=None, train_val_split=True, training_parameters={}, add_flags=["training"], fprec="float32")
Load a dataset and return iterators over training and validation batches.

Args:
    dspath (str): Path to the dataset. Supported formats: '.db' (SQLite), '.h5'/'.hdf5' (HDF5), '.pkl'/'.pickle' (pickle), or a directory containing 'training_*.pkl' shards and an optional 'validation.pkl'.
    batch_size (int): Number of systems per training batch.
    batch_size_val (int, optional): Number of systems per validation batch. Defaults to batch_size.
    rename_refs (dict, optional): Mapping from dataset key names to the key names expected downstream; matching keys are renamed in each collated batch. Default is an empty dictionary.

Returns:
    tuple: A training iterator and, if train_val_split is True, a validation iterator. Each dataset entry must provide a "species" key with the atomic numbers of the atoms in the system. Per-system arrays are concatenated along the first axis and the following keys are added to distinguish between systems:
    - 'natoms': np.ndarray. Number of atoms in each system.
    - 'batch_index': np.ndarray. Index of the system to which each atom belongs.
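
A minimal usage sketch, assuming a pickled dataset at 'dataset.pkl' (a placeholder path) with 'training' and 'validation' entries; the training_parameters options shown are ones read by the function above, and the reference keys 'total_energy'/'forces' are only examples of what a dataset may provide:

    import numpy as np
    from fennol.training.io import load_dataset

    np_rng = np.random.default_rng(0)
    training_parameters = {
        "pbc_training": False,                 # set True if each entry provides a "cell"
        "random_rotation": True,               # randomly rotate coordinates/forces/virials
        "noise_sigma": {"coordinates": 0.01},  # Gaussian noise added to selected batch keys
    }

    train_iter, valid_iter = load_dataset(
        dspath="dataset.pkl",
        batch_size=16,
        training_parameters=training_parameters,
        ref_keys=["total_energy", "forces"],
        split_data_inputs=True,
        np_rng=np_rng,
        infinite_iterator=True,
    )

    # With split_data_inputs=True each batch is a (model inputs, reference data) pair.
    inputs, refs = next(train_iter)
    print(inputs["species"].shape, inputs["natoms"], refs["forces"].shape)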

load_model(parameters: Dict[str, any], model_file: Optional[str] = None, rng_key: Optional[str] = None) -> FENNIX
Load a FENNIX model from a file or create a new one.

Args:
    parameters (dict): A dictionary of parameters for the model, with "model" and "training" sections.
    model_file (str, optional): The path to a saved model file to load.
    rng_key (optional): Random key used to initialize a new model; required when no model file is provided.

Returns:
    FENNIX: A FENNIX model object.
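
A minimal sketch, assuming a configuration with "model" and "training" sections such as the one returned by load_configuration; 'input.yaml' and 'model.fnx' are placeholder paths, and passing a jax.random.PRNGKey as rng_key is an assumption about what the FENNIX constructor expects:

    import jax
    from fennol.training.io import load_configuration, load_model

    parameters = load_configuration("input.yaml")

    # Build a new model from parameters["model"]; rng_key is required in this case.
    model = load_model(parameters, rng_key=jax.random.PRNGKey(0))

    # Or restore a previously saved model instead:
    # model = load_model(parameters, model_file="model.fnx")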

copy_parameters(variables, variables_ref, params=[".*"])

Copy parameters from variables_ref into variables. Both variable trees are flattened with flax.traverse_util, and a leaf is taken from variables_ref whenever its path (joined with '/', lower-cased, and with the leading collection name dropped) matches one of the regular expressions in params; all other leaves keep their value from variables.
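
A minimal sketch with hypothetical flax-style variable trees; only the leaves whose path matches one of the patterns are taken from the reference tree:

    from fennol.training.io import copy_parameters

    variables = {"params": {"embedding": {"kernel": 0.0}, "readout": {"kernel": 1.0}}}
    variables_ref = {"params": {"embedding": {"kernel": 10.0}, "readout": {"kernel": 20.0}}}

    # The leading collection name ("params") is dropped from the matched path,
    # so "embedding.*" matches "embedding/kernel" but not "readout/kernel".
    merged = copy_parameters(variables, variables_ref, params=["embedding.*"])
    assert merged["params"]["embedding"]["kernel"] == 10.0
    assert merged["params"]["readout"]["kernel"] == 1.0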

class TeeLogger

Simple logger that mirrors everything written to it into both a log file and the original standard output. bind_stdout() replaces sys.stdout with the logger, unbind_stdout() restores the original stream, and close() closes the underlying file.
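
A minimal usage sketch; 'train.log' is a placeholder file name:

    from fennol.training.io import TeeLogger

    logger = TeeLogger("train.log")
    logger.bind_stdout()               # sys.stdout now goes through the logger
    try:
        print("training started")      # written to the terminal and to train.log
    finally:
        logger.unbind_stdout()         # restore the original sys.stdout
        logger.close()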