fennol.training.training
import os
import yaml
import sys
import jax
import io
import time
import jax.numpy as jnp
import numpy as np
import optax
from collections import defaultdict
import json
from copy import deepcopy
from pathlib import Path
import argparse
try:
    import torch
except ImportError:
    raise ImportError(
        "PyTorch is required for training models. Install the CPU version following instructions at https://pytorch.org/get-started/locally/"
    )
import random
from flax import traverse_util
import json
import shutil
import pickle

from flax.core import freeze, unfreeze
from .io import (
    load_configuration,
    load_dataset,
    load_model,
    TeeLogger,
    copy_parameters,
)
from .utils import (
    get_loss_definition,
    get_train_step_function,
    get_validation_function,
    linear_schedule,
)
from .optimizers import get_optimizer, get_lr_schedule
from ..utils import deep_update, AtomicUnits as au
from ..utils.io import human_time_duration
from ..models.preprocessing import AtomPadding, check_input, convert_to_jax


def save_restart_checkpoint(output_directory, stage, training_state, preproc_state, metrics_state):
    with open(output_directory + "/restart_checkpoint.pkl", "wb") as f:
        pickle.dump({
            "stage": stage,
            "training": training_state,
            "preproc_state": preproc_state,
            "metrics_state": metrics_state,
        }, f)


def load_restart_checkpoint(output_directory):
    restart_checkpoint_file = output_directory + "/restart_checkpoint.pkl"
    if not os.path.exists(restart_checkpoint_file):
        raise FileNotFoundError(
            f"Training state file not found: {restart_checkpoint_file}"
        )
    with open(restart_checkpoint_file, "rb") as f:
        restart_checkpoint = pickle.load(f)
    return restart_checkpoint


def main():
    parser = argparse.ArgumentParser(prog="fennol_train")
    parser.add_argument("config_file", type=str)
    parser.add_argument("--model_file", type=str, default=None)
    args = parser.parse_args()
    config_file = args.config_file
    model_file = args.model_file

    os.environ["OMP_NUM_THREADS"] = "1"
    sys.stdout = io.TextIOWrapper(
        open(sys.stdout.fileno(), "wb", 0), write_through=True
    )

    restart_training = False
    if os.path.isdir(config_file):
        output_directory = Path(config_file).absolute().as_posix()
        restart_training = True
        while output_directory.endswith("/"):
            output_directory = output_directory[:-1]
        config_file = output_directory + "/config.yaml"
        backup_dir = output_directory + f"_backup_{time.strftime('%Y-%m-%d-%H-%M-%S')}"
        shutil.copytree(output_directory, backup_dir)
        print("Restarting training from", output_directory)

    parameters = load_configuration(config_file)

    ### Set the device
    device: str = parameters.get("device", "cpu").lower()
    if device == "cpu":
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    elif device.startswith("cuda") or device.startswith("gpu"):
        if ":" in device:
            num = device.split(":")[-1]
            os.environ["CUDA_VISIBLE_DEVICES"] = num
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        device = "gpu"

    _device = jax.devices(device)[0]
    jax.config.update("jax_default_device", _device)

    if restart_training:
        restart_checkpoint = load_restart_checkpoint(output_directory)
    else:
        restart_checkpoint = None

    # output directory
    if not restart_training:
        output_directory = parameters.get("output_directory", None)
        if output_directory is not None:
            if "{now}" in output_directory:
                output_directory = output_directory.replace(
                    "{now}", time.strftime("%Y-%m-%d-%H-%M-%S")
                )
            output_directory = Path(output_directory).absolute()
            if not output_directory.exists():
                output_directory.mkdir(parents=True)
            print("Output directory:", output_directory)
        else:
            output_directory = "."

    output_directory = str(output_directory) + "/"

    # copy config_file to output directory
    # config_name = Path(config_file).name
    config_ext = Path(config_file).suffix
    with open(config_file) as f_in:
        config_data = f_in.read()
    with open(output_directory + "/config" + config_ext, "w") as f_out:
        f_out.write(config_data)

    # set log file
    log_file = "train.log"  # parameters.get("log_file", None)
    logger = TeeLogger(output_directory + log_file)
    logger.bind_stdout()

    # set matmul precision
    enable_x64 = parameters.get("double_precision", False)
    jax.config.update("jax_enable_x64", enable_x64)
    fprec = "float64" if enable_x64 else "float32"
    parameters["fprec"] = fprec
    if enable_x64:
        print("Double precision enabled.")

    matmul_precision = parameters.get("matmul_prec", "highest").lower()
    assert matmul_precision in [
        "default",
        "high",
        "highest",
    ], "matmul_prec must be one of 'default','high','highest'"
    jax.config.update("jax_default_matmul_precision", matmul_precision)

    # set random seed
    rng_seed = parameters.get("rng_seed", np.random.randint(0, 2**32 - 1))
    print(f"rng_seed: {rng_seed}")
    rng_key = jax.random.PRNGKey(rng_seed)
    torch.manual_seed(rng_seed)
    np.random.seed(rng_seed)
    random.seed(rng_seed)
    np_rng = np.random.Generator(np.random.PCG64(rng_seed))

    try:
        if "stages" in parameters["training"]:
            ## train in stages ##
            params = deepcopy(parameters)
            stages = params["training"].pop("stages")
            assert isinstance(stages, dict), "'stages' must be a dict with named stages"
            model_file_stage = model_file
            print_stages_params = params["training"].get("print_stages_params", False)
            for i, (stage, stage_params) in enumerate(stages.items()):
                rng_key, subkey = jax.random.split(rng_key)
                print("")
                print(f"### STAGE {i+1}: {stage} ###")

                ## remove end_event from previous stage ##
                if i > 0 and "end_event" in params["training"]:
                    params["training"].pop("end_event")

                ## incrementally update training parameters ##
                params = deep_update(params, {"training": stage_params})
                if model_file_stage is not None:
                    ## load model from previous stage ##
                    params["model_file"] = model_file_stage

                if restart_training and restart_checkpoint["stage"] != i + 1:
                    print(f"Skipping stage {i+1} (already completed)")
                    continue

                if print_stages_params:
                    print("stage parameters:")
                    print(json.dumps(params, indent=2, sort_keys=False))

                ## train stage ##
                _, model_file_stage = train(
                    (subkey, np_rng),
                    params,
                    stage=i + 1,
                    output_directory=output_directory,
                    restart_checkpoint=restart_checkpoint,
                )

                restart_training = False
                restart_checkpoint = None
        else:
            ## single training stage ##
            train(
                (rng_key, np_rng),
                parameters,
                model_file=model_file,
                output_directory=output_directory,
                restart_checkpoint=restart_checkpoint,
            )
    except KeyboardInterrupt:
        print("Training interrupted by user.")
    finally:
        if log_file is not None:
            logger.unbind_stdout()
            logger.close()


def train(
    rng,
    parameters,
    model_file=None,
    stage=None,
    output_directory=None,
    restart_checkpoint=None,
):
    if output_directory is None:
        output_directory = "./"
    elif output_directory == "":
        output_directory = "./"
    elif not output_directory.endswith("/"):
        output_directory += "/"
    stage_prefix = f"_stage_{stage}" if stage is not None else ""

    if isinstance(rng, tuple):
        rng_key, np_rng = rng
    else:
        rng_key = rng
        np_rng = np.random.Generator(np.random.PCG64(np.random.randint(0, 2**32 - 1)))

    ### GET TRAINING PARAMETERS ###
    training_parameters = parameters.get("training", {})

    #### LOAD MODEL ####
    if restart_checkpoint is not None:
        model_key = None
        model_file = output_directory + "latest_model.fnx"
    else:
        rng_key, model_key = jax.random.split(rng_key)
    model = load_model(parameters, model_file, rng_key=model_key)

    ### SAVE INITIAL MODEL ###
    save_initial_model = (
        restart_checkpoint is None
        and (stage is None or stage == 0)
    )
    if save_initial_model:
        model.save(output_directory + "initial_model.fnx")

    model_ref = None
    if "model_ref" in training_parameters:
        model_ref = load_model(parameters, training_parameters["model_ref"])
        print("Reference model:", training_parameters["model_ref"])

        if "ref_parameters" in training_parameters:
            ref_parameters = training_parameters["ref_parameters"]
            assert isinstance(
                ref_parameters, list
            ), "ref_parameters must be a list of str"
            print("Reference parameters:", ref_parameters)
            model.variables = copy_parameters(
                model.variables, model_ref.variables, ref_parameters
            )

    ### SET FLOATING POINT PRECISION ###
    fprec = parameters.get("fprec", "float32")

    def convert_to_fprec(x):
        if jnp.issubdtype(x.dtype, jnp.floating):
            return x.astype(fprec)
        return x

    model.variables = jax.tree_map(convert_to_fprec, model.variables)

    ### SET UP LOSS FUNCTION ###
    loss_definition, used_keys, ref_keys = get_loss_definition(
        training_parameters, model_energy_unit=model.energy_unit
    )

    coordinates_ref_key = training_parameters.get("coordinates_ref_key", None)
    if coordinates_ref_key is not None:
        compute_ref_coords = True
        print("Reference coordinates:", coordinates_ref_key)
    else:
        compute_ref_coords = False

    #### LOAD DATASET ####
    dspath = training_parameters.get("dspath", None)
    if dspath is None:
        raise ValueError("Dataset path 'training/dspath' should be specified.")
    batch_size = training_parameters.get("batch_size", 16)
    batch_size_val = training_parameters.get("batch_size_val", None)
    rename_refs = training_parameters.get("rename_refs", {})
    training_iterator, validation_iterator = load_dataset(
        dspath=dspath,
        batch_size=batch_size,
        batch_size_val=batch_size_val,
        training_parameters=training_parameters,
        infinite_iterator=True,
        atom_padding=True,
        ref_keys=ref_keys,
        split_data_inputs=True,
        np_rng=np_rng,
        add_flags=["training"],
        fprec=fprec,
        rename_refs=rename_refs,
    )

    compute_forces = "forces" in used_keys
    compute_virial = "virial_tensor" in used_keys or "virial" in used_keys
    compute_stress = "stress_tensor" in used_keys or "stress" in used_keys
    compute_pressure = "pressure" in used_keys or "pressure_tensor" in used_keys

    ### get optimizer parameters ###
    max_epochs = int(training_parameters.get("max_epochs", 2000))
    epoch_format = len(str(max_epochs))
    nbatch_per_epoch = training_parameters.get("nbatch_per_epoch", 200)
    nbatch_per_validation = training_parameters.get("nbatch_per_validation", 20)

    schedule, sch_state, schedule_metrics, adaptive_scheduler = get_lr_schedule(
        max_epochs, nbatch_per_epoch, training_parameters
    )
    optimizer = get_optimizer(
        training_parameters, model.variables, schedule(sch_state)[0]
    )
    opt_st = optimizer.init(model.variables)

    ### exponential moving average of the parameters ###
    ema_decay = training_parameters.get("ema_decay", -1.0)
    if ema_decay > 0.0:
        assert ema_decay < 1.0, "ema_decay must be in (0,1)"
        ema = optax.ema(decay=ema_decay)
    else:
        ema = optax.identity()
    ema_st = ema.init(model.variables)

    #### end event ####
    end_event = training_parameters.get("end_event", None)
    if end_event is None or isinstance(end_event, str) and end_event.lower() == "none":
        is_end = lambda metrics: False
    else:
        assert len(end_event) == 2, "end_event must be a list of two elements"
        is_end = lambda metrics: metrics[end_event[0]] < end_event[1]

    if "energy_terms" in training_parameters:
        model.set_energy_terms(training_parameters["energy_terms"], jit=False)
        print("energy terms:", model.energy_terms)

    ### MODEL EVALUATION FUNCTION ###
    pbc_training = training_parameters.get("pbc_training", False)
    if compute_stress or compute_virial or compute_pressure:
        virial_key = "virial" if "virial" in used_keys else "virial_tensor"
        stress_key = "stress" if "stress" in used_keys else "stress_tensor"
        pressure_key = "pressure" if "pressure" in used_keys else "pressure_tensor"
        if compute_stress or compute_pressure:
            assert pbc_training, "PBC must be enabled for stress or virial training"
            print("Computing forces and stress tensor")

            def evaluate(model, variables, data):
                _, _, vir, output = model._energy_and_forces_and_virial(variables, data)
                cells = output["cells"]
                volume = jnp.abs(jnp.linalg.det(cells))
                stress = vir / volume[:, None, None]
                output[stress_key] = stress
                output[virial_key] = vir
                if pressure_key == "pressure":
                    output[pressure_key] = -jnp.trace(stress, axis1=1, axis2=2) / 3.0
                else:
                    output[pressure_key] = -stress
                return output

        else:
            print("Computing forces and virial tensor")

            def evaluate(model, variables, data):
                _, _, vir, output = model._energy_and_forces_and_virial(variables, data)
                output[virial_key] = vir
                return output

    elif compute_forces:
        print("Computing forces")

        def evaluate(model, variables, data):
            _, _, output = model._energy_and_forces(variables, data)
            return output

    elif model.energy_terms is not None:

        def evaluate(model, variables, data):
            _, output = model._total_energy(variables, data)
            return output

    else:

        def evaluate(model, variables, data):
            output = model.modules.apply(variables, data)
            return output

    ### TRAINING FUNCTIONS ###
    train_step = get_train_step_function(
        loss_definition=loss_definition,
        model=model,
        model_ref=model_ref,
        compute_ref_coords=compute_ref_coords,
        evaluate=evaluate,
        optimizer=optimizer,
        ema=ema,
    )

    validation = get_validation_function(
        loss_definition=loss_definition,
        model=model,
        model_ref=model_ref,
        compute_ref_coords=compute_ref_coords,
        evaluate=evaluate,
        return_targets=False,
    )

    #### CONFIGURE PREPROCESSING ####
    gpu_preprocessing = training_parameters.get("gpu_preprocessing", False)
    if gpu_preprocessing:
        print("GPU preprocessing activated.")

    minimum_image = training_parameters.get("minimum_image", False)
    preproc_state = unfreeze(model.preproc_state)
    layer_state = []
    for st in preproc_state["layers_state"]:
        stnew = unfreeze(st)
        # st["nblist_skin"] = nblist_skin
        # if nblist_stride > 1:
        #     st["skin_stride"] = nblist_stride
        #     st["skin_count"] = nblist_stride
        if pbc_training:
            stnew["minimum_image"] = minimum_image
        if "nblist_mult_size" in training_parameters:
            stnew["nblist_mult_size"] = training_parameters["nblist_mult_size"]
        if "nblist_add_neigh" in training_parameters:
            stnew["add_neigh"] = training_parameters["nblist_add_neigh"]
        if "nblist_add_atoms" in training_parameters:
            stnew["add_atoms"] = training_parameters["nblist_add_atoms"]
        layer_state.append(freeze(stnew))

    preproc_state["layers_state"] = tuple(layer_state)
    preproc_state["check_input"] = False
    model.preproc_state = freeze(preproc_state)

    ### INITIALIZE METRICS ###
    maes_prev = defaultdict(lambda: np.inf)
    metrics_beta = training_parameters.get("metrics_ema_decay", -1.0)
    smoothen_metrics = metrics_beta < 1.0 and metrics_beta > 0.0
    if smoothen_metrics:
        print("Computing smoothed metrics with beta =", metrics_beta)
        rmses_smooth = defaultdict(lambda: 0.0)
        maes_smooth = defaultdict(lambda: 0.0)
        rmse_tot_smooth = 0.0
        mae_tot_smooth = 0.0
        nsmooth = 0
    max_restore_count = training_parameters.get("max_restore_count", 5)

    fmetrics = open(
        output_directory + f"metrics{stage_prefix}.traj",
        "w" if restart_checkpoint is None else "a",
    )

    keep_all_bests = training_parameters.get("keep_all_bests", False)
    save_model_at_epochs = training_parameters.get("save_model_at_epochs", [])
    save_model_at_epochs = set([int(x) for x in save_model_at_epochs])
    save_model_every = int(training_parameters.get("save_model_every_epoch", 0))
    if save_model_every > 0:
        save_model_every = list(range(save_model_every, max_epochs + save_model_every, save_model_every))
        save_model_at_epochs = save_model_at_epochs.union(save_model_every)
    save_model_at_epochs = list(save_model_at_epochs)
    save_model_at_epochs.sort()

    ### LR factor after restoring a model that has diverged
    restore_scaling = training_parameters.get("restore_scaling", 1)
    assert restore_scaling > 0.0, "restore_scaling must be > 0.0"
    assert restore_scaling <= 1.0, "restore_scaling must be <= 1.0"
    if restore_scaling != 1:
        print("Applying LR scaling after restore:", restore_scaling)

    metric_use_best = training_parameters.get("metric_best", "rmse_tot")  # .lower()
    metrics_state = {}

    ### INITIALIZE TRAINING STATE ###
    if restart_checkpoint is not None:
        training_state = deepcopy(restart_checkpoint["training"])
        epoch_start = training_state["epoch"]
        model.preproc_state = freeze(restart_checkpoint["preproc_state"])
        if smoothen_metrics:
            metrics_state = deepcopy(restart_checkpoint["metrics_state"])
            rmses_smooth = metrics_state["rmses_smooth"]
            maes_smooth = metrics_state["maes_smooth"]
            rmse_tot_smooth = metrics_state["rmse_tot_smooth"]
            mae_tot_smooth = metrics_state["mae_tot_smooth"]
            nsmooth = metrics_state["nsmooth"]
        print("Restored training state")
    else:
        epoch_start = 1
        training_state = {
            "rng_key": rng_key,
            "opt_state": opt_st,
            "ema_state": ema_st,
            "sch_state": sch_state,
            "variables": model.variables,
            "variables_ema": model.variables,
            "step": 0,
            "restore_count": 0,
            "restore_scale": 1,
            "epoch": 0,
            "best_metric": np.inf,
        }
        if smoothen_metrics:
            metrics_state = {
                "rmses_smooth": dict(rmses_smooth),
                "maes_smooth": dict(maes_smooth),
                "rmse_tot_smooth": rmse_tot_smooth,
                "mae_tot_smooth": mae_tot_smooth,
                "nsmooth": nsmooth,
            }

    del opt_st, ema_st, sch_state, preproc_state
    variables_save = training_state["variables"]
    variables_ema_save = training_state["variables_ema"]

    ### Training loop ###
    start = time.time()
    print("Starting training...")
    for epoch in range(epoch_start, max_epochs + 1):
        training_state["epoch"] = epoch
        s = time.time()
        for _ in range(nbatch_per_epoch):
            # fetch data
            inputs0, data = next(training_iterator)

            # preprocess data
            inputs = model.preprocess(verbose=True, use_gpu=gpu_preprocessing, **inputs0)

            rng_key, subkey = jax.random.split(rng_key)
            inputs["rng_key"] = subkey
            inputs["training_epoch"] = epoch

            # adjust learning rate
            current_lr, training_state["sch_state"] = schedule(training_state["sch_state"])
            current_lr *= training_state["restore_scale"]
            training_state["opt_state"].inner_states["trainable"].inner_state[-1].hyperparams[
                "step_size"
            ] = current_lr

            # train step
            loss, training_state, _ = train_step(data, inputs, training_state)

        rmses_avg = defaultdict(lambda: 0.0)
        maes_avg = defaultdict(lambda: 0.0)
        for _ in range(nbatch_per_validation):
            inputs0, data = next(validation_iterator)

            inputs = model.preprocess(verbose=True, use_gpu=gpu_preprocessing, **inputs0)

            rng_key, subkey = jax.random.split(rng_key)
            inputs["rng_key"] = subkey
            inputs["training_epoch"] = epoch

            rmses, maes, output_val = validation(
                data=data,
                inputs=inputs,
                variables=training_state["variables_ema"],
            )
            for k, v in rmses.items():
                rmses_avg[k] += v
            for k, v in maes.items():
                maes_avg[k] += v

        jax.block_until_ready(output_val)

        ### Timings ###
        e = time.time()
        epoch_time = e - s

        elapsed_time = e - start
        if not adaptive_scheduler:
            remain_glob = elapsed_time * (max_epochs - epoch + 1) / epoch
            remain_last = epoch_time * (max_epochs - epoch + 1)
            # estimate remaining time via weighted average (put more weight on last at the beginning)
            wremain = np.sin(0.5 * np.pi * epoch / max_epochs)
            remaining_time = human_time_duration(
                remain_glob * wremain + remain_last * (1 - wremain)
            )
        elapsed_time = human_time_duration(elapsed_time)
        batch_time = human_time_duration(
            epoch_time / (nbatch_per_epoch + nbatch_per_validation)
        )
        epoch_time = human_time_duration(epoch_time)

        for k in rmses_avg.keys():
            rmses_avg[k] /= nbatch_per_validation
        for k in maes_avg.keys():
            maes_avg[k] /= nbatch_per_validation

        ### Print metrics ###
        print("")
        line = f"Epoch {epoch}, lr={current_lr:.3e}, loss = {loss:.3e}"
        line += f", epoch time = {epoch_time}, batch time = {batch_time}"
        line += f", elapsed time = {elapsed_time}"
        if not adaptive_scheduler:
            line += f", est. remaining time = {remaining_time}"
        print(line)
        rmse_tot = 0.0
        mae_tot = 0.0
        for k in rmses_avg.keys():
            mult = loss_definition[k]["mult"]
            loss_prms = loss_definition[k]
            rmse_tot = rmse_tot + rmses_avg[k] * loss_prms["weight"] ** 0.5
            mae_tot = mae_tot + maes_avg[k] * loss_prms["weight"]
            unit = "(" + loss_prms["unit"] + ")" if "unit" in loss_prms else ""

            weight_str = ""
            if "weight_schedule" in loss_prms:
                w = linear_schedule(epoch, *loss_prms["weight_schedule"])
                weight_str = f"(w={w:.3f})"

            if rmses_avg[k] / mult < 1.0e-2:
                print(
                    f" rmse_{k}= {rmses_avg[k]/mult:10.3e} ; mae_{k}= {maes_avg[k]/mult:10.3e} {unit} {weight_str}"
                )
            else:
                print(
                    f" rmse_{k}= {rmses_avg[k]/mult:10.3f} ; mae_{k}= {maes_avg[k]/mult:10.3f} {unit} {weight_str}"
                )

        ### CHECK IF WE SHOULD RESTORE PREVIOUS MODEL ###
        restore = False
        reinit = False
        for k, mae in maes_avg.items():
            if np.isnan(mae):
                restore = True
                reinit = True
                print("NaN detected in mae")
                # for k, v in inputs.items():
                #     if hasattr(v, "shape"):
                #         if np.isnan(v).any():
                #             print(k, v)
                # sys.exit(1)
            if "threshold" in loss_definition[k]:
                thr = loss_definition[k]["threshold"]
                if mae > thr * maes_prev[k]:
                    restore = True
                    break

        epoch_time_str = time.strftime("%Y-%m-%d-%H-%M-%S")
        model.variables = training_state["variables_ema"]
        if epoch in save_model_at_epochs:
            filename = output_directory + f"model{stage_prefix}_epoch{epoch:0{epoch_format}}_{epoch_time_str}.fnx"
            print("Scheduled model save :", filename)
            model.save(filename)

        if restore:
            training_state["restore_count"] += 1
            if training_state["restore_count"] > max_restore_count:
                if reinit:
                    raise ValueError("Model diverged and could not be restored.")
                else:
                    training_state["restore_count"] = 0
                    print(
                        f"{max_restore_count} unsuccessful restores, resuming training."
                    )

            training_state["variables"] = deepcopy(variables_save)
            training_state["variables_ema"] = deepcopy(variables_ema_save)
            model.variables = training_state["variables_ema"]
            training_state["restore_scale"] *= restore_scaling
            print(f"Restored previous model after divergence. Restore scale: {training_state['restore_scale']:.3f}")
            if reinit:
                training_state["opt_state"] = optimizer.init(model.variables)
                training_state["ema_state"] = ema.init(model.variables)
                print("Reinitialized optimizer after divergence.")
            save_restart_checkpoint(output_directory, stage, training_state, model.preproc_state, metrics_state)
            continue

        training_state["restore_count"] = 0

        variables_save = deepcopy(training_state["variables"])
        variables_ema_save = deepcopy(training_state["variables_ema"])
        maes_prev = maes_avg

        ### SAVE MODEL ###
        model.save(output_directory + "latest_model.fnx")

        # save metrics
        step = int(training_state["step"])
        metrics = {
            "epoch": epoch,
            "step": step,
            "data_count": step * batch_size,
            "elapsed_time": time.time() - start,
            "lr": current_lr,
            "loss": loss,
        }
        for k in rmses_avg.keys():
            mult = loss_definition[k]["mult"]
            metrics[f"rmse_{k}"] = rmses_avg[k] / mult
            metrics[f"mae_{k}"] = maes_avg[k] / mult

        metrics["rmse_tot"] = rmse_tot
        metrics["mae_tot"] = mae_tot
        if smoothen_metrics:
            nsmooth += 1
            for k in rmses_avg.keys():
                mult = loss_definition[k]["mult"]
                rmses_smooth[k] = (
                    metrics_beta * rmses_smooth[k] + (1.0 - metrics_beta) * rmses_avg[k]
                )
                maes_smooth[k] = (
                    metrics_beta * maes_smooth[k] + (1.0 - metrics_beta) * maes_avg[k]
                )
                metrics[f"rmse_smooth_{k}"] = (
                    rmses_smooth[k] / (1.0 - metrics_beta**nsmooth) / mult
                )
                metrics[f"mae_smooth_{k}"] = (
                    maes_smooth[k] / (1.0 - metrics_beta**nsmooth) / mult
                )
            rmse_tot_smooth = (
                metrics_beta * rmse_tot_smooth + (1.0 - metrics_beta) * rmse_tot
            )
            mae_tot_smooth = (
                metrics_beta * mae_tot_smooth + (1.0 - metrics_beta) * mae_tot
            )
            metrics["rmse_smooth_tot"] = rmse_tot_smooth / (1.0 - metrics_beta**nsmooth)
            metrics["mae_smooth_tot"] = mae_tot_smooth / (1.0 - metrics_beta**nsmooth)

            # update metrics state
            metrics_state["rmses_smooth"] = dict(rmses_smooth)
            metrics_state["maes_smooth"] = dict(maes_smooth)
            metrics_state["rmse_tot_smooth"] = rmse_tot_smooth
            metrics_state["mae_tot_smooth"] = mae_tot_smooth
            metrics_state["nsmooth"] = nsmooth

        assert (
            metric_use_best in metrics
        ), f"Error: metric for selecting best model '{metric_use_best}' not in metrics"

        metric_for_best = metrics[metric_use_best]
        if metric_for_best < training_state["best_metric"]:
            training_state["best_metric"] = metric_for_best
            metrics["best_metric"] = training_state["best_metric"]
            if keep_all_bests:
                best_name = (
                    output_directory
                    + f"best_model{stage_prefix}_epoch{epoch:0{epoch_format}}_{epoch_time_str}.fnx"
                )
                model.save(best_name)

            best_name = output_directory + f"best_model{stage_prefix}.fnx"
            model.save(best_name)
            print("New best model saved to:", best_name)
        else:
            metrics["best_metric"] = training_state["best_metric"]

        if epoch == 1:
            headers = [f"{i+1}:{k}" for i, k in enumerate(metrics.keys())]
            fmetrics.write("# " + " ".join(headers) + "\n")
        fmetrics.write(" ".join([str(metrics[k]) for k in metrics.keys()]) + "\n")
        fmetrics.flush()

        # update learning rate using current metrics
        assert (
            schedule_metrics in metrics
        ), f"Error: cannot update lr, '{schedule_metrics}' not in metrics"
        current_lr, training_state["sch_state"] = schedule(
            training_state["sch_state"], metrics[schedule_metrics]
        )

        # update and save training state
        save_restart_checkpoint(output_directory, stage, training_state, model.preproc_state, metrics_state)

        if is_end(metrics):
            print("Stage finished.")
            break

    end = time.time()

    print(f"Training time: {human_time_duration(end-start)}")
    print("")

    fmetrics.close()

    filename = output_directory + f"final_model{stage_prefix}.fnx"
    model.save(filename)
    print("Final model saved to:", filename)

    return metrics, filename


if __name__ == "__main__":
    main()

def save_restart_checkpoint(output_directory, stage, training_state, preproc_state, metrics_state):

def load_restart_checkpoint(output_directory):

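The checkpoint is a plain pickle of the dict written by save_restart_checkpoint, so it can be inspected directly. A minimal sketch, assuming a hypothetical run directory "runs/my_run" that contains restart_checkpoint.pkl:

    from fennol.training.training import load_restart_checkpoint

    ckpt = load_restart_checkpoint("runs/my_run")
    print(ckpt["stage"])                    # stage index recorded at save time
    print(ckpt["training"]["epoch"])        # last saved epoch
    print(ckpt["training"]["best_metric"])  # best value of the selection metric so far
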
def main():

def train(rng, parameters, model_file=None, stage=None, output_directory=None, restart_checkpoint=None):

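In normal use these parameters come from the YAML file passed to the fennol_train command-line tool defined by main() (fennol_train config.yaml, or fennol_train <run_directory> to restart from a previous run). The sketch below instead drives train() directly; it is a minimal, hypothetical illustration: the dataset path and model file are placeholders, and the loss/model sections required in practice are omitted, since their accepted keys are defined in fennol.training.io.load_dataset and fennol.training.utils.get_loss_definition rather than in this module.

    import jax
    import numpy as np
    from fennol.training.training import train

    # Only keys actually read by train() are shown; values are placeholders.
    parameters = {
        "fprec": "float32",
        "training": {
            "dspath": "dataset.pkl",      # placeholder dataset path (required)
            "batch_size": 16,
            "max_epochs": 100,
            "nbatch_per_epoch": 200,
            "nbatch_per_validation": 20,
        },
    }

    # train() accepts a (jax PRNG key, numpy Generator) tuple as its rng argument.
    rng = (jax.random.PRNGKey(0), np.random.Generator(np.random.PCG64(0)))
    metrics, final_model = train(
        rng, parameters, model_file="initial_model.fnx", output_directory="./run"
    )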