fennol.training.training
import os
import yaml
import sys
import jax
import io
import time
import jax.numpy as jnp
import numpy as np
import optax
from collections import defaultdict
import json
from copy import deepcopy
from pathlib import Path
import argparse
try:
    import torch
except ImportError:
    raise ImportError(
        "PyTorch is required for training models. Install the CPU version following instructions at https://pytorch.org/get-started/locally/"
    )
import random
from flax import traverse_util
import json
import shutil
import pickle

from flax.core import freeze, unfreeze
from .io import (
    load_configuration,
    load_dataset,
    load_model,
    TeeLogger,
    copy_parameters,
)
from .utils import (
    get_loss_definition,
    get_train_step_function,
    get_validation_function,
    linear_schedule,
)
from .optimizers import get_optimizer, get_lr_schedule
from ..utils import deep_update, AtomicUnits as au
from ..utils.io import human_time_duration
from ..models.preprocessing import AtomPadding, check_input, convert_to_jax


def save_restart_checkpoint(output_directory, stage, training_state, preproc_state, metrics_state):
    with open(output_directory + "/restart_checkpoint.pkl", "wb") as f:
        pickle.dump({
            "stage": stage,
            "training": training_state,
            "preproc_state": preproc_state,
            "metrics_state": metrics_state,
        }, f)


def load_restart_checkpoint(output_directory):
    restart_checkpoint_file = output_directory + "/restart_checkpoint.pkl"
    if not os.path.exists(restart_checkpoint_file):
        raise FileNotFoundError(
            f"Training state file not found: {restart_checkpoint_file}"
        )
    with open(restart_checkpoint_file, "rb") as f:
        restart_checkpoint = pickle.load(f)
    return restart_checkpoint


def main():
    parser = argparse.ArgumentParser(prog="fennol_train")
    parser.add_argument("config_file", type=str)
    parser.add_argument("--model_file", type=str, default=None)
    args = parser.parse_args()
    config_file = args.config_file
    model_file = args.model_file

    os.environ["OMP_NUM_THREADS"] = "1"
    sys.stdout = io.TextIOWrapper(
        open(sys.stdout.fileno(), "wb", 0), write_through=True
    )

    restart_training = False
    if os.path.isdir(config_file):
        output_directory = Path(config_file).absolute().as_posix()
        restart_training = True
        while output_directory.endswith("/"):
            output_directory = output_directory[:-1]
        config_file = output_directory + "/config.yaml"
        backup_dir = output_directory + f"_backup_{time.strftime('%Y-%m-%d-%H-%M-%S')}"
        shutil.copytree(output_directory, backup_dir)
        print("Restarting training from", output_directory)

    parameters = load_configuration(config_file)

    ### Set the device
    if "FENNOL_DEVICE" in os.environ:
        device = os.environ["FENNOL_DEVICE"].lower()
        print(f"# Setting device from env FENNOL_DEVICE={device}")
    else:
        device: str = parameters.get("device", "cpu").lower()
    if device == "cpu":
        jax.config.update('jax_platforms', 'cpu')
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    elif device.startswith("cuda") or device.startswith("gpu"):
        if ":" in device:
            num = device.split(":")[-1]
            os.environ["CUDA_VISIBLE_DEVICES"] = num
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        device = "gpu"

    _device = jax.devices(device)[0]
    jax.config.update("jax_default_device", _device)

    if restart_training:
        restart_checkpoint = load_restart_checkpoint(output_directory)
    else:
        restart_checkpoint = None

    # output directory
    if not restart_training:
        output_directory = parameters.get("output_directory", None)
        if output_directory is not None:
            if "{now}" in output_directory:
                output_directory = output_directory.replace(
                    "{now}", time.strftime("%Y-%m-%d-%H-%M-%S")
                )
            output_directory = Path(output_directory).absolute()
            if not output_directory.exists():
                output_directory.mkdir(parents=True)
            print("Output directory:", output_directory)
        else:
            output_directory = "."

    output_directory = str(output_directory) + "/"

    # copy config_file to output directory
    # config_name = Path(config_file).name
    config_ext = Path(config_file).suffix
    with open(config_file) as f_in:
        config_data = f_in.read()
    with open(output_directory + "/config" + config_ext, "w") as f_out:
        f_out.write(config_data)

    # set log file
    log_file = "train.log"  # parameters.get("log_file", None)
    logger = TeeLogger(output_directory + log_file)
    logger.bind_stdout()

    # set matmul precision
    enable_x64 = parameters.get("double_precision", False)
    jax.config.update("jax_enable_x64", enable_x64)
    fprec = "float64" if enable_x64 else "float32"
    parameters["fprec"] = fprec
    if enable_x64:
        print("Double precision enabled.")

    matmul_precision = parameters.get("matmul_prec", "highest").lower()
    assert matmul_precision in [
        "default",
        "high",
        "highest",
    ], "matmul_prec must be one of 'default','high','highest'"
    jax.config.update("jax_default_matmul_precision", matmul_precision)

    # set random seed
    rng_seed = parameters.get("rng_seed", np.random.randint(0, 2**32 - 1))
    print(f"rng_seed: {rng_seed}")
    rng_key = jax.random.PRNGKey(rng_seed)
    torch.manual_seed(rng_seed)
    np.random.seed(rng_seed)
    random.seed(rng_seed)
    np_rng = np.random.Generator(np.random.PCG64(rng_seed))

    try:
        if "stages" in parameters["training"]:
            ## train in stages ##
            params = deepcopy(parameters)
            stages = params["training"].pop("stages")
            assert isinstance(stages, dict), "'stages' must be a dict with named stages"
            model_file_stage = model_file
            print_stages_params = params["training"].get("print_stages_params", False)
            for i, (stage, stage_params) in enumerate(stages.items()):
                rng_key, subkey = jax.random.split(rng_key)
                print("")
                print(f"### STAGE {i+1}: {stage} ###")

                ## remove end_event from previous stage ##
                if i > 0 and "end_event" in params["training"]:
                    params["training"].pop("end_event")

                ## incrementally update training parameters ##
                params = deep_update(params, {"training": stage_params})
                if model_file_stage is not None:
                    ## load model from previous stage ##
                    params["model_file"] = model_file_stage

                if restart_training and restart_checkpoint["stage"] != i + 1:
                    print(f"Skipping stage {i+1} (already completed)")
                    continue

                if print_stages_params:
                    print("stage parameters:")
                    print(json.dumps(params, indent=2, sort_keys=False))

                ## train stage ##
                _, model_file_stage = train(
                    (subkey, np_rng),
                    params,
                    stage=i + 1,
                    output_directory=output_directory,
                    restart_checkpoint=restart_checkpoint,
                )

                restart_training = False
                restart_checkpoint = None
        else:
            ## single training stage ##
            train(
                (rng_key, np_rng),
                parameters,
                model_file=model_file,
                output_directory=output_directory,
                restart_checkpoint=restart_checkpoint,
            )
    except KeyboardInterrupt:
        print("Training interrupted by user.")
    finally:
        if log_file is not None:
            logger.unbind_stdout()
            logger.close()


def train(
    rng,
    parameters,
    model_file=None,
    stage=None,
    output_directory=None,
    restart_checkpoint=None,
):
    if output_directory is None:
        output_directory = "./"
    elif output_directory == "":
        output_directory = "./"
    elif not output_directory.endswith("/"):
        output_directory += "/"
    stage_prefix = f"_stage_{stage}" if stage is not None else ""

    if isinstance(rng, tuple):
        rng_key, np_rng = rng
    else:
        rng_key = rng
        np_rng = np.random.Generator(np.random.PCG64(np.random.randint(0, 2**32 - 1)))

    ### GET TRAINING PARAMETERS ###
    training_parameters = parameters.get("training", {})

    #### LOAD MODEL ####
    if restart_checkpoint is not None:
        model_key = None
        model_file = output_directory + "latest_model.fnx"
    else:
        rng_key, model_key = jax.random.split(rng_key)
    model = load_model(parameters, model_file, rng_key=model_key)

    ### SAVE INITIAL MODEL ###
    save_initial_model = (
        restart_checkpoint is None
        and (stage is None or stage == 0)
    )
    if save_initial_model:
        model.save(output_directory + "initial_model.fnx")

    model_ref = None
    if "model_ref" in training_parameters:
        model_ref = load_model(parameters, training_parameters["model_ref"])
        print("Reference model:", training_parameters["model_ref"])

        if "ref_parameters" in training_parameters:
            ref_parameters = training_parameters["ref_parameters"]
            assert isinstance(
                ref_parameters, list
            ), "ref_parameters must be a list of str"
            print("Reference parameters:", ref_parameters)
            model.variables = copy_parameters(
                model.variables, model_ref.variables, ref_parameters
            )

    ### SET FLOATING POINT PRECISION ###
    fprec = parameters.get("fprec", "float32")

    def convert_to_fprec(x):
        if jnp.issubdtype(x.dtype, jnp.floating):
            return x.astype(fprec)
        return x

    model.variables = jax.tree_map(convert_to_fprec, model.variables)

    ### SET UP LOSS FUNCTION ###
    loss_definition, used_keys, ref_keys = get_loss_definition(
        training_parameters, model_energy_unit=model.energy_unit
    )

    coordinates_ref_key = training_parameters.get("coordinates_ref_key", None)
    if coordinates_ref_key is not None:
        compute_ref_coords = True
        print("Reference coordinates:", coordinates_ref_key)
    else:
        compute_ref_coords = False

    #### LOAD DATASET ####
    dspath = training_parameters.get("dspath", None)
    if dspath is None:
        raise ValueError("Dataset path 'training/dspath' should be specified.")
    batch_size = training_parameters.get("batch_size", 16)
    batch_size_val = training_parameters.get("batch_size_val", None)
    rename_refs = training_parameters.get("rename_refs", {})
    training_iterator, validation_iterator = load_dataset(
        dspath=dspath,
        batch_size=batch_size,
        batch_size_val=batch_size_val,
        training_parameters=training_parameters,
        infinite_iterator=True,
        atom_padding=True,
        ref_keys=ref_keys,
        split_data_inputs=True,
        np_rng=np_rng,
        add_flags=["training"],
        fprec=fprec,
        rename_refs=rename_refs,
    )

    compute_forces = "forces" in used_keys
    compute_virial = "virial_tensor" in used_keys or "virial" in used_keys
    compute_stress = "stress_tensor" in used_keys or "stress" in used_keys
    compute_pressure = "pressure" in used_keys or "pressure_tensor" in used_keys

    ### get optimizer parameters ###
    max_epochs = int(training_parameters.get("max_epochs", 2000))
    epoch_format = len(str(max_epochs))
    nbatch_per_epoch = training_parameters.get("nbatch_per_epoch", 200)
    nbatch_per_validation = training_parameters.get("nbatch_per_validation", 20)

    schedule, sch_state, schedule_metrics, adaptive_scheduler = get_lr_schedule(max_epochs, nbatch_per_epoch, training_parameters)
    optimizer = get_optimizer(
        training_parameters, model.variables, schedule(sch_state)[0]
    )
    opt_st = optimizer.init(model.variables)

    ### exponential moving average of the parameters ###
    ema_decay = training_parameters.get("ema_decay", -1.0)
    if ema_decay > 0.0:
        assert ema_decay < 1.0, "ema_decay must be in (0,1)"
        ema = optax.ema(decay=ema_decay)
    else:
        ema = optax.identity()
    ema_st = ema.init(model.variables)

    #### end event ####
    end_event = training_parameters.get("end_event", None)
    if end_event is None or isinstance(end_event, str) and end_event.lower() == "none":
        is_end = lambda metrics: False
    else:
        assert len(end_event) == 2, "end_event must be a list of two elements"
        is_end = lambda metrics: metrics[end_event[0]] < end_event[1]

    if "energy_terms" in training_parameters:
        model.set_energy_terms(training_parameters["energy_terms"], jit=False)
    print("energy terms:", model.energy_terms)

    ### MODEL EVALUATION FUNCTION ###
    pbc_training = training_parameters.get("pbc_training", False)
    if compute_stress or compute_virial or compute_pressure:
        virial_key = "virial" if "virial" in used_keys else "virial_tensor"
        stress_key = "stress" if "stress" in used_keys else "stress_tensor"
        pressure_key = "pressure" if "pressure" in used_keys else "pressure_tensor"
        if compute_stress or compute_pressure:
            assert pbc_training, "PBC must be enabled for stress or virial training"
            print("Computing forces and stress tensor")

            def evaluate(model, variables, data):
                _, _, vir, output = model._energy_and_forces_and_virial(variables, data)
                cells = output["cells"]
                volume = jnp.abs(jnp.linalg.det(cells))
                stress = vir / volume[:, None, None]
                output[stress_key] = stress
                output[virial_key] = vir
                if pressure_key == "pressure":
                    output[pressure_key] = -jnp.trace(stress, axis1=1, axis2=2) / 3.0
                else:
                    output[pressure_key] = -stress
                return output

        else:
            print("Computing forces and virial tensor")

            def evaluate(model, variables, data):
                _, _, vir, output = model._energy_and_forces_and_virial(variables, data)
                output[virial_key] = vir
                return output

    elif compute_forces:
        print("Computing forces")

        def evaluate(model, variables, data):
            _, _, output = model._energy_and_forces(variables, data)
            return output

    elif model.energy_terms is not None:

        def evaluate(model, variables, data):
            _, output = model._total_energy(variables, data)
            return output

    else:

        def evaluate(model, variables, data):
            output = model.modules.apply(variables, data)
            return output

    ### TRAINING FUNCTIONS ###
    train_step = get_train_step_function(
        loss_definition=loss_definition,
        model=model,
        model_ref=model_ref,
        compute_ref_coords=compute_ref_coords,
        evaluate=evaluate,
        optimizer=optimizer,
        ema=ema,
    )

    validation = get_validation_function(
        loss_definition=loss_definition,
        model=model,
        model_ref=model_ref,
        compute_ref_coords=compute_ref_coords,
        evaluate=evaluate,
        return_targets=False,
    )

    #### CONFIGURE PREPROCESSING ####
    gpu_preprocessing = training_parameters.get("gpu_preprocessing", False)
    if gpu_preprocessing:
        print("GPU preprocessing activated.")

    preproc_state = unfreeze(model.preproc_state)
    layer_state = []
    for st in preproc_state["layers_state"]:
        stnew = unfreeze(st)
        # st["nblist_skin"] = nblist_skin
        # if nblist_stride > 1:
        #     st["skin_stride"] = nblist_stride
        #     st["skin_count"] = nblist_stride
        if "nblist_mult_size" in training_parameters:
            stnew["nblist_mult_size"] = training_parameters["nblist_mult_size"]
        if "nblist_add_neigh" in training_parameters:
            stnew["add_neigh"] = training_parameters["nblist_add_neigh"]
        if "nblist_add_atoms" in training_parameters:
            stnew["add_atoms"] = training_parameters["nblist_add_atoms"]
        layer_state.append(freeze(stnew))

    preproc_state["layers_state"] = tuple(layer_state)
    preproc_state["check_input"] = False
    model.preproc_state = freeze(preproc_state)

    ### INITIALIZE METRICS ###
    maes_prev = defaultdict(lambda: np.inf)
    metrics_beta = training_parameters.get("metrics_ema_decay", -1.0)
    smoothen_metrics = metrics_beta < 1.0 and metrics_beta > 0.0
    if smoothen_metrics:
        print("Computing smoothed metrics with beta =", metrics_beta)
        rmses_smooth = defaultdict(lambda: 0.0)
        maes_smooth = defaultdict(lambda: 0.0)
        rmse_tot_smooth = 0.0
        mae_tot_smooth = 0.0
        nsmooth = 0
    max_restore_count = training_parameters.get("max_restore_count", 5)

    fmetrics = open(
        output_directory + f"metrics{stage_prefix}.traj",
        "w" if restart_checkpoint is None else "a",
    )

    keep_all_bests = training_parameters.get("keep_all_bests", False)
    save_model_at_epochs = training_parameters.get("save_model_at_epochs", [])
    save_model_at_epochs = set([int(x) for x in save_model_at_epochs])
    save_model_every = int(training_parameters.get("save_model_every_epoch", 0))
    if save_model_every > 0:
        save_model_every = list(range(save_model_every, max_epochs + save_model_every, save_model_every))
        save_model_at_epochs = save_model_at_epochs.union(save_model_every)
    if "save_model_schedule" in training_parameters:
        save_model_schedule_def = [int(i) for i in training_parameters["save_model_schedule"]]
        save_model_schedule = []
        while len(save_model_schedule_def) > 0:
            i = save_model_schedule_def.pop(0)
            if i > 0:
                save_model_schedule.append(i)
            elif i < 0:
                start = -i
                freq = save_model_schedule_def.pop(0)
                assert freq > 0, "syntax error in save_model_schedule"
                save_model_schedule.extend(
                    list(range(start, max_epochs + 1, freq))
                )
        save_model_at_epochs = save_model_at_epochs.union(save_model_schedule)
    save_model_at_epochs = list(save_model_at_epochs)
    save_model_at_epochs.sort()

    ### LR factor after restoring a model that has diverged
    restore_scaling = training_parameters.get("restore_scaling", 1)
    assert restore_scaling > 0.0, "restore_scaling must be > 0.0"
    assert restore_scaling <= 1.0, "restore_scaling must be <= 1.0"
    if restore_scaling != 1:
        print("Applying LR scaling after restore:", restore_scaling)

    metric_use_best = training_parameters.get("metric_best", "rmse_tot")  # .lower()
    metrics_state = {}

    ### INITIALIZE TRAINING STATE ###
    if restart_checkpoint is not None:
        training_state = deepcopy(restart_checkpoint["training"])
        epoch_start = training_state["epoch"]
        model.preproc_state = freeze(restart_checkpoint["preproc_state"])
        if smoothen_metrics:
            metrics_state = deepcopy(restart_checkpoint["metrics_state"])
            rmses_smooth = metrics_state["rmses_smooth"]
            maes_smooth = metrics_state["maes_smooth"]
            rmse_tot_smooth = metrics_state["rmse_tot_smooth"]
            mae_tot_smooth = metrics_state["mae_tot_smooth"]
            nsmooth = metrics_state["nsmooth"]
        print("Restored training state")
    else:
        epoch_start = 1
        training_state = {
            "rng_key": rng_key,
            "opt_state": opt_st,
            "ema_state": ema_st,
            "sch_state": sch_state,
            "variables": model.variables,
            "variables_ema": model.variables,
            "step": 0,
            "restore_count": 0,
            "restore_scale": 1,
            "epoch": 0,
            "best_metric": np.inf,
        }
        if smoothen_metrics:
            metrics_state = {
                "rmses_smooth": dict(rmses_smooth),
                "maes_smooth": dict(maes_smooth),
                "rmse_tot_smooth": rmse_tot_smooth,
                "mae_tot_smooth": mae_tot_smooth,
                "nsmooth": nsmooth,
            }

    del opt_st, ema_st, sch_state, preproc_state
    variables_save = training_state['variables']
    variables_ema_save = training_state['variables_ema']

    ### Training loop ###
    start = time.time()
    print("Starting training...")
    for epoch in range(epoch_start, max_epochs + 1):
        training_state["epoch"] = epoch
        s = time.time()
        for _ in range(nbatch_per_epoch):
            # fetch data
            inputs0, data = next(training_iterator)

            # preprocess data
            inputs = model.preprocess(verbose=True, use_gpu=gpu_preprocessing, **inputs0)

            rng_key, subkey = jax.random.split(rng_key)
            inputs["rng_key"] = subkey
            inputs["training_epoch"] = epoch

            # adjust learning rate
            current_lr, training_state["sch_state"] = schedule(training_state["sch_state"])
            current_lr *= training_state["restore_scale"]
            training_state["opt_state"].inner_states["trainable"].inner_state[-1].hyperparams[
                "step_size"
            ] = current_lr

            # train step
            loss, training_state, _ = train_step(data, inputs, training_state)

        rmses_avg = defaultdict(lambda: 0.0)
        maes_avg = defaultdict(lambda: 0.0)
        for _ in range(nbatch_per_validation):
            inputs0, data = next(validation_iterator)

            inputs = model.preprocess(verbose=True, use_gpu=gpu_preprocessing, **inputs0)

            rng_key, subkey = jax.random.split(rng_key)
            inputs["rng_key"] = subkey
            inputs["training_epoch"] = epoch

            rmses, maes, output_val = validation(
                data=data,
                inputs=inputs,
                variables=training_state["variables_ema"],
            )
            for k, v in rmses.items():
                rmses_avg[k] += v
            for k, v in maes.items():
                maes_avg[k] += v

        jax.block_until_ready(output_val)

        ### Timings ###
        e = time.time()
        epoch_time = e - s

        elapsed_time = e - start
        if not adaptive_scheduler:
            remain_glob = elapsed_time * (max_epochs - epoch + 1) / epoch
            remain_last = epoch_time * (max_epochs - epoch + 1)
            # estimate remaining time via weighted average (put more weight on last at the beginning)
            wremain = np.sin(0.5 * np.pi * epoch / max_epochs)
            remaining_time = human_time_duration(
                remain_glob * wremain + remain_last * (1 - wremain)
            )
        elapsed_time = human_time_duration(elapsed_time)
        batch_time = human_time_duration(
            epoch_time / (nbatch_per_epoch + nbatch_per_validation)
        )
        epoch_time = human_time_duration(epoch_time)

        for k in rmses_avg.keys():
            rmses_avg[k] /= nbatch_per_validation
        for k in maes_avg.keys():
            maes_avg[k] /= nbatch_per_validation

        ### Print metrics ###
        print("")
        line = f"Epoch {epoch}, lr={current_lr:.3e}, loss = {loss:.3e}"
        line += f", epoch time = {epoch_time}, batch time = {batch_time}"
        line += f", elapsed time = {elapsed_time}"
        if not adaptive_scheduler:
            line += f", est. remaining time = {remaining_time}"
        print(line)
        rmse_tot = 0.0
        mae_tot = 0.0
        for k in rmses_avg.keys():
            mult = loss_definition[k]["mult"]
            loss_prms = loss_definition[k]
            rmse_tot = rmse_tot + rmses_avg[k] * loss_prms["weight"] ** 0.5
            mae_tot = mae_tot + maes_avg[k] * loss_prms["weight"]
            unit = "(" + loss_prms["unit"] + ")" if "unit" in loss_prms else ""

            weight_str = ""
            if "weight_schedule" in loss_prms:
                w = linear_schedule(epoch, *loss_prms["weight_schedule"])
                weight_str = f"(w={w:.3f})"

            if rmses_avg[k] / mult < 1.0e-2:
                print(
                    f" rmse_{k}= {rmses_avg[k]/mult:10.3e} ; mae_{k}= {maes_avg[k]/mult:10.3e} {unit} {weight_str}"
                )
            else:
                print(
                    f" rmse_{k}= {rmses_avg[k]/mult:10.3f} ; mae_{k}= {maes_avg[k]/mult:10.3f} {unit} {weight_str}"
                )

        ### CHECK IF WE SHOULD RESTORE PREVIOUS MODEL ###
        restore = False
        reinit = False
        for k, mae in maes_avg.items():
            if np.isnan(mae):
                restore = True
                reinit = True
                print("NaN detected in mae")
                # for k, v in inputs.items():
                #     if hasattr(v, "shape"):
                #         if np.isnan(v).any():
                #             print(k, v)
                # sys.exit(1)
            if "threshold" in loss_definition[k]:
                thr = loss_definition[k]["threshold"]
                if mae > thr * maes_prev[k]:
                    restore = True
                    break

        epoch_time_str = time.strftime("%Y-%m-%d-%H-%M-%S")
        model.variables = training_state["variables_ema"]
        if epoch in save_model_at_epochs:
            filename = output_directory + f"model{stage_prefix}_epoch{epoch:0{epoch_format}}_{epoch_time_str}.fnx"
            print("Scheduled model save :", filename)
            model.save(filename)

        if restore:
            training_state["restore_count"] += 1
            if training_state["restore_count"] > max_restore_count:
                if reinit:
                    raise ValueError("Model diverged and could not be restored.")
                else:
                    training_state["restore_count"] = 0
                    print(
                        f"{max_restore_count} unsuccessful restores, resuming training."
                    )

            training_state["variables"] = deepcopy(variables_save)
            training_state["variables_ema"] = deepcopy(variables_ema_save)
            model.variables = training_state["variables_ema"]
            training_state["restore_scale"] *= restore_scaling
            print(f"Restored previous model after divergence. Restore scale: {training_state['restore_scale']:.3f}")
            if reinit:
                training_state["opt_state"] = optimizer.init(model.variables)
                training_state["ema_state"] = ema.init(model.variables)
                print("Reinitialized optimizer after divergence.")
            save_restart_checkpoint(output_directory, stage, training_state, model.preproc_state, metrics_state)
            continue

        training_state["restore_count"] = 0

        variables_save = deepcopy(training_state["variables"])
        variables_ema_save = deepcopy(training_state["variables_ema"])
        maes_prev = maes_avg

        ### SAVE MODEL ###
        model.save(output_directory + "latest_model.fnx")

        # save metrics
        step = int(training_state["step"])
        metrics = {
            "epoch": epoch,
            "step": step,
            "data_count": step * batch_size,
            "elapsed_time": time.time() - start,
            "lr": current_lr,
            "loss": loss,
        }
        for k in rmses_avg.keys():
            mult = loss_definition[k]["mult"]
            metrics[f"rmse_{k}"] = rmses_avg[k] / mult
            metrics[f"mae_{k}"] = maes_avg[k] / mult

        metrics["rmse_tot"] = rmse_tot
        metrics["mae_tot"] = mae_tot
        if smoothen_metrics:
            nsmooth += 1
            for k in rmses_avg.keys():
                mult = loss_definition[k]["mult"]
                rmses_smooth[k] = (
                    metrics_beta * rmses_smooth[k] + (1.0 - metrics_beta) * rmses_avg[k]
                )
                maes_smooth[k] = (
                    metrics_beta * maes_smooth[k] + (1.0 - metrics_beta) * maes_avg[k]
                )
                metrics[f"rmse_smooth_{k}"] = (
                    rmses_smooth[k] / (1.0 - metrics_beta**nsmooth) / mult
                )
                metrics[f"mae_smooth_{k}"] = (
                    maes_smooth[k] / (1.0 - metrics_beta**nsmooth) / mult
                )
            rmse_tot_smooth = (
                metrics_beta * rmse_tot_smooth + (1.0 - metrics_beta) * rmse_tot
            )
            mae_tot_smooth = (
                metrics_beta * mae_tot_smooth + (1.0 - metrics_beta) * mae_tot
            )
            metrics["rmse_smooth_tot"] = rmse_tot_smooth / (1.0 - metrics_beta**nsmooth)
            metrics["mae_smooth_tot"] = mae_tot_smooth / (1.0 - metrics_beta**nsmooth)

            # update metrics state
            metrics_state["rmses_smooth"] = dict(rmses_smooth)
            metrics_state["maes_smooth"] = dict(maes_smooth)
            metrics_state["rmse_tot_smooth"] = rmse_tot_smooth
            metrics_state["mae_tot_smooth"] = mae_tot_smooth
            metrics_state["nsmooth"] = nsmooth

        assert (
            metric_use_best in metrics
        ), f"Error: metric for selecting best model '{metric_use_best}' not in metrics"

        metric_for_best = metrics[metric_use_best]
        if metric_for_best < training_state["best_metric"]:
            training_state["best_metric"] = metric_for_best
            metrics["best_metric"] = training_state["best_metric"]
            if keep_all_bests:
                best_name = (
                    output_directory
                    + f"best_model{stage_prefix}_epoch{epoch:0{epoch_format}}_{epoch_time_str}.fnx"
                )
                model.save(best_name)

            best_name = output_directory + f"best_model{stage_prefix}.fnx"
            model.save(best_name)
            print("New best model saved to:", best_name)
        else:
            metrics["best_metric"] = training_state["best_metric"]

        if epoch == 1:
            headers = [f"{i+1}:{k}" for i, k in enumerate(metrics.keys())]
            fmetrics.write("# " + " ".join(headers) + "\n")
        fmetrics.write(" ".join([str(metrics[k]) for k in metrics.keys()]) + "\n")
        fmetrics.flush()

        # update learning rate using current metrics
        assert (
            schedule_metrics in metrics
        ), f"Error: cannot update lr, '{schedule_metrics}' not in metrics"
        current_lr, training_state["sch_state"] = schedule(training_state["sch_state"], metrics[schedule_metrics])

        # update and save training state
        save_restart_checkpoint(output_directory, stage, training_state, model.preproc_state, metrics_state)

        if is_end(metrics):
            print("Stage finished.")
            break

    end = time.time()

    print(f"Training time: {human_time_duration(end-start)}")
    print("")

    fmetrics.close()

    filename = output_directory + f"final_model{stage_prefix}.fnx"
    model.save(filename)
    print("Final model saved to:", filename)

    return metrics, filename


if __name__ == "__main__":
    main()
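The metrics{stage_prefix}.traj file written by the training loop is a plain-text table: a header line starting with "#" that numbers and names the columns, followed by one space-separated row per epoch. As a minimal sketch (not part of the module), here is one way to read it back with numpy; it assumes a single-stage run (file named metrics.traj) and that the script is run from the output directory:

import numpy as np

with open("metrics.traj") as f:
    header = f.readline().lstrip("# ").split()      # e.g. ["1:epoch", "2:step", ...]
columns = [h.split(":", 1)[1] for h in header]

values = np.loadtxt("metrics.traj", ndmin=2)        # '#' lines are skipped by default
rmse_tot = values[:, columns.index("rmse_tot")]     # total validation RMSE per epoch
print(columns)
print(rmse_tot[-5:])                                # last five epochs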
def save_restart_checkpoint(output_directory, stage, training_state, preproc_state, metrics_state):

Pickle the current stage index, training state, preprocessing state and metrics state to restart_checkpoint.pkl in the output directory.
def load_restart_checkpoint(output_directory):

Load restart_checkpoint.pkl from the output directory and return its contents; raise FileNotFoundError if the file does not exist.
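For reference, a minimal sketch (not part of the module) of peeking inside a checkpoint written by save_restart_checkpoint(); the directory name is hypothetical, and unpickling requires the same libraries (jax, optax, flax) to be importable since the training state contains their objects:

import pickle

with open("training_run/restart_checkpoint.pkl", "rb") as f:     # hypothetical output directory
    ckpt = pickle.load(f)

print(ckpt["stage"])                    # stage index, or None for a single-stage run
print(ckpt["training"]["epoch"])        # last saved epoch
print(ckpt["training"]["best_metric"])  # best value of the model-selection metric so far
print(sorted(ckpt["training"].keys()))  # opt_state, ema_state, sch_state, variables, variables_ema, ...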
def main():

Entry point of the fennol_train command line tool. It parses the command line, loads the YAML configuration (or restarts from an existing output directory, backing it up first), configures the device, output directory, logging, precision and random seeds, then runs either a single call to train() or a sequence of named training stages.
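Training is normally launched as fennol_train config.yaml; passing an existing output directory instead of a file restarts from its restart_checkpoint.pkl. As a rough illustration of the configuration keys that main() and train() actually read, here is a minimal sketch parsed the same way load_configuration() would parse a YAML file; the file contents, paths and values are hypothetical, the model-definition section is omitted, and a training: stages: block with named sub-sections would switch main() into staged training:

import yaml

config = yaml.safe_load("""
device: cuda:0
output_directory: runs/{now}     # "{now}" is replaced by a timestamp
double_precision: false
rng_seed: 12345
training:
  dspath: dataset.h5             # hypothetical dataset path
  batch_size: 16
  max_epochs: 500
  nbatch_per_epoch: 200
  nbatch_per_validation: 20
  ema_decay: 0.99
  gpu_preprocessing: true
""")
print(config["training"]["dspath"])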
def train(rng, parameters, model_file=None, stage=None, output_directory=None, restart_checkpoint=None):

Run one training stage: load the model and dataset, build the loss definition, optimizer, learning-rate schedule and parameter EMA, then loop over epochs with periodic validation, divergence detection and restore, metric logging and model checkpointing. Returns the final metrics dictionary and the path of the saved final model.
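train() can also be called directly instead of going through the fennol_train entry point. The sketch below is a hypothetical minimal call, not a recipe: the paths and values are placeholders, the model is assumed to come from an existing .fnx file, and the set of training keys truly required also depends on the defaults of get_optimizer, get_lr_schedule and get_loss_definition, so a real run will typically need loss and optimizer settings as well:

import os
import jax
import numpy as np
from fennol.training.training import train

rng_key = jax.random.PRNGKey(0)
np_rng = np.random.Generator(np.random.PCG64(0))

parameters = {
    "fprec": "float32",
    "training": {
        "dspath": "dataset.h5",          # hypothetical dataset path
        "batch_size": 16,
        "max_epochs": 10,
        "nbatch_per_epoch": 50,
        "nbatch_per_validation": 10,
    },
}

os.makedirs("runs/manual", exist_ok=True)    # train() does not create the output directory itself
metrics, final_model = train(
    (rng_key, np_rng),                       # rng may be a (PRNGKey, numpy Generator) tuple
    parameters,
    model_file="initial_model.fnx",          # hypothetical pre-built FeNNol model
    output_directory="runs/manual",
)
print(metrics["rmse_tot"], final_model)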