fennol.training.training
import os
import yaml
import sys
import jax
import io
import time
import jax.numpy as jnp
import numpy as np
import optax
from collections import defaultdict
import json
from copy import deepcopy
from pathlib import Path
import argparse

try:
    import torch
except ImportError:
    raise ImportError(
        "PyTorch is required for training models. Install the CPU version following instructions at https://pytorch.org/get-started/locally/"
    )
import random
from flax import traverse_util
import json
import shutil
import pickle

from flax.core import freeze, unfreeze
from .io import (
    load_configuration,
    load_dataset,
    load_model,
    TeeLogger,
    copy_parameters,
)
from .utils import (
    get_loss_definition,
    get_train_step_function,
    get_validation_function,
    linear_schedule,
)
from .optimizers import get_optimizer, get_lr_schedule, multi_ema
from ..utils import deep_update, AtomicUnits as au
from ..utils.io import human_time_duration


def save_restart_checkpoint(output_directory, stage, training_state, preproc_state, metrics_state):
    with open(output_directory + "/restart_checkpoint.pkl", "wb") as f:
        pickle.dump(
            {
                "stage": stage,
                "training": training_state,
                "preproc_state": preproc_state,
                "metrics_state": metrics_state,
            },
            f,
        )


def load_restart_checkpoint(output_directory):
    restart_checkpoint_file = output_directory + "/restart_checkpoint.pkl"
    if not os.path.exists(restart_checkpoint_file):
        raise FileNotFoundError(
            f"Training state file not found: {restart_checkpoint_file}"
        )
    with open(restart_checkpoint_file, "rb") as f:
        restart_checkpoint = pickle.load(f)
    return restart_checkpoint


def main():
    parser = argparse.ArgumentParser(prog="fennol_train")
    parser.add_argument("config_file", type=str)
    parser.add_argument("--model_file", type=str, default=None)
    args = parser.parse_args()
    config_file = args.config_file
    model_file = args.model_file

    os.environ["OMP_NUM_THREADS"] = "1"
    sys.stdout = io.TextIOWrapper(
        open(sys.stdout.fileno(), "wb", 0), write_through=True
    )

    restart_training = False
    if os.path.isdir(config_file):
        output_directory = Path(config_file).absolute().as_posix()
        restart_training = True
        while output_directory.endswith("/"):
            output_directory = output_directory[:-1]
        config_file = output_directory + "/config.yaml"
        backup_dir = output_directory + f"_backup_{time.strftime('%Y-%m-%d-%H-%M-%S')}"
        shutil.copytree(output_directory, backup_dir)
        print("Restarting training from", output_directory)

    parameters = load_configuration(config_file)

    ### Set the device
    if "FENNOL_DEVICE" in os.environ:
        device = os.environ["FENNOL_DEVICE"].lower()
        print(f"# Setting device from env FENNOL_DEVICE={device}")
    else:
        device: str = parameters.get("device", "cpu").lower()
    if device == "cpu":
        jax.config.update("jax_platforms", "cpu")
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    elif device.startswith("cuda") or device.startswith("gpu"):
        if ":" in device:
            num = device.split(":")[-1]
            os.environ["CUDA_VISIBLE_DEVICES"] = num
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        device = "gpu"

    _device = jax.devices(device)[0]
    jax.config.update("jax_default_device", _device)

    if restart_training:
        restart_checkpoint = load_restart_checkpoint(output_directory)
    else:
        restart_checkpoint = None

    # output directory
    if not restart_training:
        output_directory = parameters.get("output_directory", None)
        if output_directory is not None:
            if "{now}" in output_directory:
                output_directory = output_directory.replace(
                    "{now}", time.strftime("%Y-%m-%d-%H-%M-%S")
                )
            output_directory = Path(output_directory).absolute()
            if not output_directory.exists():
                output_directory.mkdir(parents=True)
            print("Output directory:", output_directory)
        else:
            output_directory = "."

    output_directory = str(output_directory) + "/"

    # copy config_file to output directory
    # config_name = Path(config_file).name
    config_ext = Path(config_file).suffix
    with open(config_file) as f_in:
        config_data = f_in.read()
    with open(output_directory + "/config" + config_ext, "w") as f_out:
        f_out.write(config_data)

    # set log file
    log_file = "train.log"  # parameters.get("log_file", None)
    logger = TeeLogger(output_directory + log_file)
    logger.bind_stdout()

    # set matmul precision
    enable_x64 = parameters.get("double_precision", False)
    jax.config.update("jax_enable_x64", enable_x64)
    fprec = "float64" if enable_x64 else "float32"
    parameters["fprec"] = fprec
    if enable_x64:
        print("Double precision enabled.")

    matmul_precision = parameters.get("matmul_prec", "highest").lower()
    assert matmul_precision in [
        "default",
        "high",
        "highest",
    ], "matmul_prec must be one of 'default','high','highest'"
    jax.config.update("jax_default_matmul_precision", matmul_precision)

    # set random seed
    rng_seed = parameters.get("rng_seed", np.random.randint(0, 2**32 - 1))
    print(f"rng_seed: {rng_seed}")
    rng_key = jax.random.PRNGKey(rng_seed)
    torch.manual_seed(rng_seed)
    np.random.seed(rng_seed)
    random.seed(rng_seed)
    np_rng = np.random.Generator(np.random.PCG64(rng_seed))

    try:
        if "stages" in parameters["training"]:
            ## train in stages ##
            params = deepcopy(parameters)
            stages = params["training"].pop("stages")
            assert isinstance(stages, dict), "'stages' must be a dict with named stages"
            model_file_stage = model_file
            print_stages_params = params["training"].get("print_stages_params", False)
            for i, (stage, stage_params) in enumerate(stages.items()):
                rng_key, subkey = jax.random.split(rng_key)
                print("")
                print(f"### STAGE {i+1}: {stage} ###")

                ## remove end_event from previous stage ##
                if i > 0 and "end_event" in params["training"]:
                    params["training"].pop("end_event")

                ## incrementally update training parameters ##
                params = deep_update(params, {"training": stage_params})
                if model_file_stage is not None:
                    ## load model from previous stage ##
                    params["model_file"] = model_file_stage

                if restart_training and restart_checkpoint["stage"] != i + 1:
                    print(f"Skipping stage {i+1} (already completed)")
                    continue

                if print_stages_params:
                    print("stage parameters:")
                    print(json.dumps(params, indent=2, sort_keys=False))

                ## train stage ##
                _, model_file_stage = train(
                    (subkey, np_rng),
                    params,
                    stage=i + 1,
                    output_directory=output_directory,
                    restart_checkpoint=restart_checkpoint,
                )

                restart_training = False
                restart_checkpoint = None
        else:
            ## single training stage ##
            train(
                (rng_key, np_rng),
                parameters,
                model_file=model_file,
                output_directory=output_directory,
                restart_checkpoint=restart_checkpoint,
            )
    except KeyboardInterrupt:
        print("Training interrupted by user.")
    finally:
        if log_file is not None:
            logger.unbind_stdout()
            logger.close()


def train(
    rng,
    parameters,
    model_file=None,
    stage=None,
    output_directory=None,
    restart_checkpoint=None,
):
    if output_directory is None:
        output_directory = "./"
    elif output_directory == "":
        output_directory = "./"
    elif not output_directory.endswith("/"):
        output_directory += "/"
    stage_prefix = f"_stage_{stage}" if stage is not None else ""

    if isinstance(rng, tuple):
        rng_key, np_rng = rng
    else:
        rng_key = rng
        np_rng = np.random.Generator(np.random.PCG64(np.random.randint(0, 2**32 - 1)))

    ### GET TRAINING PARAMETERS ###
    training_parameters = parameters.get("training", {})

    #### LOAD MODEL ####
    if restart_checkpoint is not None:
        model_key = None
        model_file = output_directory + "latest_model.fnx"
    else:
        rng_key, model_key = jax.random.split(rng_key)
    model = load_model(parameters, model_file, rng_key=model_key)

    ### SAVE INITIAL MODEL ###
    save_initial_model = (
        restart_checkpoint is None
        and (stage is None or stage == 0)
    )
    if save_initial_model:
        model.save(output_directory + "initial_model.fnx")

    model_ref = None
    if "model_ref" in training_parameters:
        model_ref = load_model(parameters, training_parameters["model_ref"])
        print("Reference model:", training_parameters["model_ref"])

        if "ref_parameters" in training_parameters:
            ref_parameters = training_parameters["ref_parameters"]
            assert isinstance(
                ref_parameters, list
            ), "ref_parameters must be a list of str"
            print("Reference parameters:", ref_parameters)
            model.variables = copy_parameters(
                model.variables, model_ref.variables, ref_parameters
            )

    ### SET FLOATING POINT PRECISION ###
    fprec = parameters.get("fprec", "float32")

    def convert_to_fprec(x):
        if jnp.issubdtype(x.dtype, jnp.floating):
            return x.astype(fprec)
        return x

    model.variables = jax.tree_map(convert_to_fprec, model.variables)

    ### SET UP LOSS FUNCTION ###
    loss_definition, used_keys, ref_keys = get_loss_definition(
        training_parameters, model_energy_unit=model.energy_unit
    )

    coordinates_ref_key = training_parameters.get("coordinates_ref_key", None)
    if coordinates_ref_key is not None:
        compute_ref_coords = True
        print("Reference coordinates:", coordinates_ref_key)
    else:
        compute_ref_coords = False

    #### LOAD DATASET ####
    dspath = training_parameters.get("dspath", None)
    if dspath is None:
        raise ValueError("Dataset path 'training/dspath' should be specified.")
    batch_size = training_parameters.get("batch_size", 16)
    batch_size_val = training_parameters.get("batch_size_val", None)
    rename_refs = training_parameters.get("rename_refs", {})
    training_iterator, validation_iterator = load_dataset(
        dspath=dspath,
        batch_size=batch_size,
        batch_size_val=batch_size_val,
        training_parameters=training_parameters,
        infinite_iterator=True,
        atom_padding=True,
        ref_keys=ref_keys,
        split_data_inputs=True,
        np_rng=np_rng,
        add_flags=["training"],
        fprec=fprec,
        rename_refs=rename_refs,
    )

    compute_forces = "forces" in used_keys
    compute_virial = "virial_tensor" in used_keys or "virial" in used_keys
    compute_stress = "stress_tensor" in used_keys or "stress" in used_keys
    compute_pressure = "pressure" in used_keys or "pressure_tensor" in used_keys

    ### get optimizer parameters ###
    max_epochs = int(training_parameters.get("max_epochs", 2000))
    epoch_format = len(str(max_epochs))
    nbatch_per_epoch = training_parameters.get("nbatch_per_epoch", 200)
    nbatch_per_validation = training_parameters.get("nbatch_per_validation", 20)

    schedule, sch_state, schedule_metrics, adaptive_scheduler = get_lr_schedule(
        max_epochs, nbatch_per_epoch, training_parameters
    )
    optimizer, ilr = get_optimizer(
        training_parameters, model.variables, schedule(sch_state)[0]
    )
    opt_st = optimizer.init(model.variables)

    ### exponential moving average of the parameters ###
    ema_decay = training_parameters.get("ema_decay", None)
    if ema_decay is not None:
        if isinstance(ema_decay, float):
            ema_decay = [ema_decay]
        for decay in ema_decay:
            assert isinstance(
                decay, float
            ), "ema_decay must be a float or a list of floats"
            assert 0.0 <= decay < 1.0, "ema_decay must be in (0,1)"

        power_ema = training_parameters.get("ema_power", False)
        ema_ival = int(training_parameters.get("ema_validate", 0))
        ema = multi_ema(decays=ema_decay, power=power_ema)
    else:
        ema = multi_ema(decays=[])  # no ema
        ema_ival = 0
    ema_st = ema.init(model.variables)
    variables_ema, _ = ema.update(model.variables, ema_st)

    #### end event ####
    end_event = training_parameters.get("end_event", None)
    if end_event is None or isinstance(end_event, str) and end_event.lower() == "none":
        is_end = lambda metrics: False
    else:
        assert len(end_event) == 2, "end_event must be a list of two elements"
        is_end = lambda metrics: metrics[end_event[0]] < end_event[1]

    if "energy_terms" in training_parameters:
        model.set_energy_terms(training_parameters["energy_terms"], jit=False)
    print("energy terms:", model.energy_terms)

    ### MODEL EVALUATION FUNCTION ###
    pbc_training = training_parameters.get("pbc_training", False)
    if compute_stress or compute_virial or compute_pressure:
        virial_key = "virial" if "virial" in used_keys else "virial_tensor"
        stress_key = "stress" if "stress" in used_keys else "stress_tensor"
        pressure_key = "pressure" if "pressure" in used_keys else "pressure_tensor"
        if compute_stress or compute_pressure:
            assert pbc_training, "PBC must be enabled for stress or virial training"
            print("Computing forces and stress tensor")

            def evaluate(model, variables, data):
                _, _, vir, output = model._energy_and_forces_and_virial(variables, data)
                cells = output["cells"]
                volume = jnp.abs(jnp.linalg.det(cells))
                stress = vir / volume[:, None, None]
                output[stress_key] = stress
                output[virial_key] = vir
                if pressure_key == "pressure":
                    output[pressure_key] = -jnp.trace(stress, axis1=1, axis2=2) / 3.0
                else:
                    output[pressure_key] = -stress
                return output

        else:
            print("Computing forces and virial tensor")

            def evaluate(model, variables, data):
                _, _, vir, output = model._energy_and_forces_and_virial(variables, data)
                output[virial_key] = vir
                return output

    elif compute_forces:
        print("Computing forces")

        def evaluate(model, variables, data):
            _, _, output = model._energy_and_forces(variables, data)
            return output

    elif model.energy_terms is not None:

        def evaluate(model, variables, data):
            _, output = model._total_energy(variables, data)
            return output

    else:

        def evaluate(model, variables, data):
            output = model.modules.apply(variables, data)
            return output

    ### TRAINING FUNCTIONS ###
    train_step = get_train_step_function(
        loss_definition=loss_definition,
        model=model,
        model_ref=model_ref,
        compute_ref_coords=compute_ref_coords,
        evaluate=evaluate,
        optimizer=optimizer,
        ema=ema,
    )

    validation = get_validation_function(
        loss_definition=loss_definition,
        model=model,
        model_ref=model_ref,
        compute_ref_coords=compute_ref_coords,
        evaluate=evaluate,
        return_targets=False,
    )

    #### CONFIGURE PREPROCESSING ####
    gpu_preprocessing = training_parameters.get("gpu_preprocessing", False)
    if gpu_preprocessing:
        print("GPU preprocessing activated.")

    preproc_state = unfreeze(model.preproc_state)
    layer_state = []
    for st in preproc_state["layers_state"]:
        stnew = unfreeze(st)
        # st["nblist_skin"] = nblist_skin
        # if nblist_stride > 1:
        #     st["skin_stride"] = nblist_stride
        #     st["skin_count"] = nblist_stride
        if "nblist_mult_size" in training_parameters:
            stnew["nblist_mult_size"] = training_parameters["nblist_mult_size"]
        if "nblist_add_neigh" in training_parameters:
            stnew["add_neigh"] = training_parameters["nblist_add_neigh"]
        if "nblist_add_atoms" in training_parameters:
            stnew["add_atoms"] = training_parameters["nblist_add_atoms"]
        layer_state.append(freeze(stnew))

    preproc_state["layers_state"] = tuple(layer_state)
    preproc_state["check_input"] = False
    model.preproc_state = freeze(preproc_state)

    ### INITIALIZE METRICS ###
    maes_prev = defaultdict(lambda: np.inf)
    metrics_beta = training_parameters.get("metrics_ema_decay", -1.0)
    smoothen_metrics = metrics_beta < 1.0 and metrics_beta > 0.0
    if smoothen_metrics:
        print("Computing smoothed metrics with beta =", metrics_beta)
        rmses_smooth = defaultdict(lambda: 0.0)
        maes_smooth = defaultdict(lambda: 0.0)
        rmse_tot_smooth = 0.0
        mae_tot_smooth = 0.0
        nsmooth = 0
    max_restore_count = training_parameters.get("max_restore_count", 5)

    fmetrics = open(
        output_directory + f"metrics{stage_prefix}.traj",
        "w" if restart_checkpoint is None else "a",
    )

    keep_all_bests = training_parameters.get("keep_all_bests", False)
    save_model_at_epochs = training_parameters.get("save_model_at_epochs", [])
    save_model_at_epochs = set([int(x) for x in save_model_at_epochs])
    save_model_every = int(training_parameters.get("save_model_every_epoch", 0))
    if save_model_every > 0:
        save_model_every = list(
            range(save_model_every, max_epochs + save_model_every, save_model_every)
        )
        save_model_at_epochs = save_model_at_epochs.union(save_model_every)
    if "save_model_schedule" in training_parameters:
        save_model_schedule_def = [
            int(i) for i in training_parameters["save_model_schedule"]
        ]
        save_model_schedule = []
        while len(save_model_schedule_def) > 0:
            i = save_model_schedule_def.pop(0)
            if i > 0:
                save_model_schedule.append(i)
            elif i < 0:
                start = -i
                freq = save_model_schedule_def.pop(0)
                assert freq > 0, "syntax error in save_model_schedule"
                save_model_schedule.extend(list(range(start, max_epochs + 1, freq)))
        save_model_at_epochs = save_model_at_epochs.union(save_model_schedule)
    save_model_at_epochs = list(save_model_at_epochs)
    save_model_at_epochs.sort()

    ### LR factor after restoring a model that has diverged
    restore_scaling = training_parameters.get("restore_scaling", 1)
    assert restore_scaling > 0.0, "restore_scaling must be > 0.0"
    assert restore_scaling <= 1.0, "restore_scaling must be <= 1.0"
    if restore_scaling != 1:
        print("Applying LR scaling after restore:", restore_scaling)

    metric_use_best = training_parameters.get("metric_best", "rmse_tot")  # .lower()
    metrics_state = {}

    ### INITIALIZE TRAINING STATE ###
    if restart_checkpoint is not None:
        training_state = deepcopy(restart_checkpoint["training"])
        epoch_start = training_state["epoch"]
        model.preproc_state = freeze(restart_checkpoint["preproc_state"])
        if smoothen_metrics:
            metrics_state = deepcopy(restart_checkpoint["metrics_state"])
            rmses_smooth = metrics_state["rmses_smooth"]
            maes_smooth = metrics_state["maes_smooth"]
            rmse_tot_smooth = metrics_state["rmse_tot_smooth"]
            mae_tot_smooth = metrics_state["mae_tot_smooth"]
            nsmooth = metrics_state["nsmooth"]
        print("Restored training state")
    else:
        epoch_start = 1
        training_state = {
            "rng_key": rng_key,
            "opt_state": opt_st,
            "ema_state": ema_st,
            "sch_state": sch_state,
            "variables": model.variables,
            "variables_ema": variables_ema,
            "step": 0,
            "restore_count": 0,
            "restore_scale": 1,
            "epoch": 0,
            "best_metric": np.inf,
        }
        if smoothen_metrics:
            metrics_state = {
                "rmses_smooth": dict(rmses_smooth),
                "maes_smooth": dict(maes_smooth),
                "rmse_tot_smooth": rmse_tot_smooth,
                "mae_tot_smooth": mae_tot_smooth,
                "nsmooth": nsmooth,
            }

    del opt_st, ema_st, sch_state, preproc_state, variables_ema
    variables_save = training_state["variables"]
    variables_ema_save = training_state["variables_ema"]

    ### Training loop ###
    start = time.time()
    print("Starting training...")
    for epoch in range(epoch_start, max_epochs + 1):
        training_state["epoch"] = epoch
        s = time.time()
        for _ in range(nbatch_per_epoch):
            # fetch data
            inputs0, data = next(training_iterator)

            # preprocess data
            inputs = model.preprocess(verbose=True, use_gpu=gpu_preprocessing, **inputs0)

            rng_key, subkey = jax.random.split(rng_key)
            inputs["rng_key"] = subkey
            inputs["training_epoch"] = epoch

            # adjust learning rate
            current_lr, training_state["sch_state"] = schedule(training_state["sch_state"])
            current_lr *= training_state["restore_scale"]
            training_state["opt_state"].inner_states["trainable"].inner_state[ilr].hyperparams[
                "step_size"
            ] = current_lr

            # train step
            loss, training_state, _ = train_step(data, inputs, training_state)

        rmses_avg = defaultdict(lambda: 0.0)
        maes_avg = defaultdict(lambda: 0.0)
        for _ in range(nbatch_per_validation):
            inputs0, data = next(validation_iterator)

            inputs = model.preprocess(verbose=True, use_gpu=gpu_preprocessing, **inputs0)

            rng_key, subkey = jax.random.split(rng_key)
            inputs["rng_key"] = subkey
            inputs["training_epoch"] = epoch

            rmses, maes, output_val = validation(
                data=data,
                inputs=inputs,
                variables=training_state["variables_ema"][ema_ival],
            )
            for k, v in rmses.items():
                rmses_avg[k] += v
            for k, v in maes.items():
                maes_avg[k] += v

        jax.block_until_ready(output_val)

        ### Timings ###
        e = time.time()
        epoch_time = e - s

        elapsed_time = e - start
        if not adaptive_scheduler:
            remain_glob = elapsed_time * (max_epochs - epoch + 1) / epoch
            remain_last = epoch_time * (max_epochs - epoch + 1)
            # estimate remaining time via weighted average (put more weight on last at the beginning)
            wremain = np.sin(0.5 * np.pi * epoch / max_epochs)
            remaining_time = human_time_duration(
                remain_glob * wremain + remain_last * (1 - wremain)
            )
        elapsed_time = human_time_duration(elapsed_time)
        batch_time = human_time_duration(
            epoch_time / (nbatch_per_epoch + nbatch_per_validation)
        )
        epoch_time = human_time_duration(epoch_time)

        for k in rmses_avg.keys():
            rmses_avg[k] /= nbatch_per_validation
        for k in maes_avg.keys():
            maes_avg[k] /= nbatch_per_validation

        ### Print metrics ###
        print("")
        line = f"Epoch {epoch}, lr={current_lr:.3e}, loss = {loss:.3e}"
        line += f", epoch time = {epoch_time}, batch time = {batch_time}"
        line += f", elapsed time = {elapsed_time}"
        if not adaptive_scheduler:
            line += f", est. remaining time = {remaining_time}"
        print(line)
        rmse_tot = 0.0
        mae_tot = 0.0
        for k in rmses_avg.keys():
            mult = loss_definition[k]["display_mult"]
            loss_prms = loss_definition[k]
            rmse_tot = rmse_tot + rmses_avg[k] * loss_prms["weight"] ** 0.5
            mae_tot = mae_tot + maes_avg[k] * loss_prms["weight"]
            unit = (
                "(" + loss_prms["display_unit"] + ")"
                if "display_unit" in loss_prms
                else ""
            )

            weight_str = ""
            if "weight_schedule" in loss_prms:
                w = linear_schedule(epoch, *loss_prms["weight_schedule"])
                weight_str = f"(w={w:.3f})"

            if rmses_avg[k] / mult < 1.0e-2:
                print(
                    f" rmse_{k}= {rmses_avg[k]/mult:10.3e} ; mae_{k}= {maes_avg[k]/mult:10.3e} {unit} {weight_str}"
                )
            else:
                print(
                    f" rmse_{k}= {rmses_avg[k]/mult:10.3f} ; mae_{k}= {maes_avg[k]/mult:10.3f} {unit} {weight_str}"
                )

        ### CHECK IF WE SHOULD RESTORE PREVIOUS MODEL ###
        restore = False
        reinit = False
        for k, mae in maes_avg.items():
            if np.isnan(mae):
                restore = True
                reinit = True
                print("NaN detected in mae")
                # for k, v in inputs.items():
                #     if hasattr(v, "shape"):
                #         if np.isnan(v).any():
                #             print(k, v)
                # sys.exit(1)
            if "threshold" in loss_definition[k]:
                thr = loss_definition[k]["threshold"]
                if mae > thr * maes_prev[k]:
                    restore = True
                    break

        epoch_time_str = time.strftime("%Y-%m-%d-%H-%M-%S")
        if epoch in save_model_at_epochs:
            for i, v in enumerate(training_state["variables_ema"]):
                model.variables = v
                ema_str = (
                    f"_ema{i+1}" if len(training_state["variables_ema"]) > 1 else ""
                )
                filename = (
                    output_directory
                    + f"model{stage_prefix}_epoch{epoch:0{epoch_format}}{ema_str}_{epoch_time_str}.fnx"
                )
                print("Scheduled model save :", filename)
                model.save(filename)

        if restore:
            training_state["restore_count"] += 1
            if training_state["restore_count"] > max_restore_count:
                if reinit:
                    raise ValueError("Model diverged and could not be restored.")
                else:
                    training_state["restore_count"] = 0
                    print(
                        f"{max_restore_count} unsuccessful restores, resuming training."
                    )

            training_state["variables"] = deepcopy(variables_save)
            training_state["variables_ema"] = deepcopy(variables_ema_save)
            training_state["restore_scale"] *= restore_scaling
            print(
                f"Restored previous model after divergence. Restore scale: {training_state['restore_scale']:.3f}"
            )
            if reinit:
                training_state["opt_state"] = optimizer.init(model.variables)
                training_state["ema_state"] = ema.init(model.variables)
                print("Reinitialized optimizer after divergence.")
            save_restart_checkpoint(
                output_directory, stage, training_state, model.preproc_state, metrics_state
            )
            continue

        training_state["restore_count"] = 0

        variables_save = deepcopy(training_state["variables"])
        variables_ema_save = deepcopy(training_state["variables_ema"])
        maes_prev = maes_avg

        ### SAVE MODEL ###
        model.variables = training_state["variables_ema"][ema_ival]
        model.save(output_directory + "latest_model.fnx")

        # save metrics
        step = int(training_state["step"])
        metrics = {
            "epoch": epoch,
            "step": step,
            "data_count": step * batch_size,
            "elapsed_time": time.time() - start,
            "lr": current_lr,
            "loss": loss,
        }
        for k in rmses_avg.keys():
            mult = loss_definition[k]["mult"]
            metrics[f"rmse_{k}"] = rmses_avg[k] / mult
            metrics[f"mae_{k}"] = maes_avg[k] / mult

        metrics["rmse_tot"] = rmse_tot
        metrics["mae_tot"] = mae_tot
        if smoothen_metrics:
            nsmooth += 1
            for k in rmses_avg.keys():
                mult = loss_definition[k]["mult"]
                rmses_smooth[k] = (
                    metrics_beta * rmses_smooth[k] + (1.0 - metrics_beta) * rmses_avg[k]
                )
                maes_smooth[k] = (
                    metrics_beta * maes_smooth[k] + (1.0 - metrics_beta) * maes_avg[k]
                )
                metrics[f"rmse_smooth_{k}"] = (
                    rmses_smooth[k] / (1.0 - metrics_beta**nsmooth) / mult
                )
                metrics[f"mae_smooth_{k}"] = (
                    maes_smooth[k] / (1.0 - metrics_beta**nsmooth) / mult
                )
            rmse_tot_smooth = (
                metrics_beta * rmse_tot_smooth + (1.0 - metrics_beta) * rmse_tot
            )
            mae_tot_smooth = (
                metrics_beta * mae_tot_smooth + (1.0 - metrics_beta) * mae_tot
            )
            metrics["rmse_smooth_tot"] = rmse_tot_smooth / (1.0 - metrics_beta**nsmooth)
            metrics["mae_smooth_tot"] = mae_tot_smooth / (1.0 - metrics_beta**nsmooth)

            # update metrics state
            metrics_state["rmses_smooth"] = dict(rmses_smooth)
            metrics_state["maes_smooth"] = dict(maes_smooth)
            metrics_state["rmse_tot_smooth"] = rmse_tot_smooth
            metrics_state["mae_tot_smooth"] = mae_tot_smooth
            metrics_state["nsmooth"] = nsmooth

        assert (
            metric_use_best in metrics
        ), f"Error: metric for selecting best model '{metric_use_best}' not in metrics"

        metric_for_best = metrics[metric_use_best]
        if metric_for_best < training_state["best_metric"]:
            training_state["best_metric"] = metric_for_best
            metrics["best_metric"] = training_state["best_metric"]
            best_epoch_str = (
                f"_epoch{epoch:0{epoch_format}}_{epoch_time_str}" if keep_all_bests else ""
            )
            best_name = output_directory + f"best_model{stage_prefix}{best_epoch_str}.fnx"
            model.save(best_name)
            print("New best model saved to:", best_name)
        else:
            metrics["best_metric"] = training_state["best_metric"]

        if epoch == 1:
            headers = [f"{i+1}:{k}" for i, k in enumerate(metrics.keys())]
            fmetrics.write("# " + " ".join(headers) + "\n")
        fmetrics.write(" ".join([str(metrics[k]) for k in metrics.keys()]) + "\n")
        fmetrics.flush()

        # update learning rate using current metrics
        assert (
            schedule_metrics in metrics
        ), f"Error: cannot update lr, '{schedule_metrics}' not in metrics"
        current_lr, training_state["sch_state"] = schedule(
            training_state["sch_state"], metrics[schedule_metrics]
        )

        # update and save training state
        save_restart_checkpoint(
            output_directory, stage, training_state, model.preproc_state, metrics_state
        )

        if is_end(metrics):
            print("Stage finished.")
            break

    end = time.time()

    print(f"Training time: {human_time_duration(end-start)}")
    print("")

    fmetrics.close()

    filename = output_directory + f"final_model{stage_prefix}.fnx"
    model.save(filename)
    print("Final model saved to:", filename)

    return metrics, filename


if __name__ == "__main__":
    main()
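The functions below are driven almost entirely by the parameters dictionary parsed from the YAML configuration file. For orientation, here is a minimal sketch of that structure; the key names are taken from the parameters.get(...) calls above, but the values and file names are illustrative placeholders, not recommended settings.

# Minimal configuration sketch (illustrative only). Real configurations are YAML
# files read by load_configuration(); this dict mirrors the keys used above.
example_parameters = {
    "device": "cuda:0",                # or "cpu"; can be overridden by FENNOL_DEVICE
    "double_precision": False,         # sets jax_enable_x64 and the "fprec" entry
    "matmul_prec": "highest",          # one of "default", "high", "highest"
    "rng_seed": 12345,                 # drawn at random if omitted
    "output_directory": "runs/{now}",  # "{now}" is replaced by a timestamp
    "training": {
        "dspath": "dataset.h5",        # required dataset path (hypothetical name)
        "batch_size": 16,
        "max_epochs": 2000,
        "nbatch_per_epoch": 200,
        "nbatch_per_validation": 20,
        "ema_decay": 0.999,            # optional EMA of the model parameters
        "gpu_preprocessing": False,
    },
}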
def save_restart_checkpoint(output_directory, stage, training_state, preproc_state, metrics_state):
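save_restart_checkpoint writes a plain pickle named restart_checkpoint.pkl into the output directory, holding the stage index, the full training state (optimizer, EMA and scheduler states, model variables, epoch counters), the preprocessing state and the smoothed-metrics state; load_restart_checkpoint below reads it back when main() is restarted from an existing output directory. A minimal inspection sketch, assuming a previous run wrote its outputs to the hypothetical directory runs/example:

import pickle

with open("runs/example/restart_checkpoint.pkl", "rb") as f:
    ckpt = pickle.load(f)

print(ckpt["stage"])                    # stage index, or None for single-stage runs
print(ckpt["training"]["epoch"])        # last completed epoch
print(sorted(ckpt["training"].keys()))  # opt_state, ema_state, sch_state, variables, ...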
def load_restart_checkpoint(output_directory):
def main():
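main() is the command-line entry point (the argparse prog name is fennol_train). It takes either a YAML configuration file for a fresh run or, to restart, the path to an existing output directory containing config.yaml and restart_checkpoint.pkl; the directory is backed up before training resumes, and the device can be forced with the FENNOL_DEVICE environment variable. A sketch of driving it programmatically, assuming fennol_train is exposed as a console script backed by this main() and using placeholder file names:

import sys
from fennol.training.training import main

# Fresh run from a YAML configuration file:
sys.argv = ["fennol_train", "input.yaml"]
main()

# Restart from an existing output directory (it is backed up, then resumed):
# sys.argv = ["fennol_train", "runs/2024-01-01-00-00-00"]
# main()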
def train(rng, parameters, model_file=None, stage=None, output_directory=None, restart_checkpoint=None):
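train() can also be called directly, as main() does: it takes an (rng_key, np_rng) pair (or a bare JAX PRNG key), the parsed parameter dictionary, and optional model_file, stage, output_directory and restart_checkpoint arguments, and returns the final metrics dictionary together with the path of the saved final model. A minimal sketch, assuming a valid configuration file and dataset; the paths are placeholders and the output directory must already exist (main() normally creates it):

import jax
import numpy as np
from fennol.training.io import load_configuration  # same loader as imported at the top of the module
from fennol.training.training import train

parameters = load_configuration("input.yaml")  # hypothetical config path
rng_key = jax.random.PRNGKey(parameters.get("rng_seed", 0))
np_rng = np.random.Generator(np.random.PCG64(0))

metrics, final_model = train(
    (rng_key, np_rng),
    parameters,
    output_directory="runs/example",  # receives latest_model.fnx, best_model.fnx, metrics.traj, ...
)
print("Final model written to:", final_model)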