fennol.training.training
import os
import yaml
import sys
import jax
import io
import time
import jax.numpy as jnp
import numpy as np
import optax
from collections import defaultdict
import json
from copy import deepcopy
from pathlib import Path
import argparse
import torch
import random
from flax import traverse_util
import json
import shutil
import pickle

from flax.core import freeze, unfreeze
from .io import (
    load_configuration,
    load_dataset,
    load_model,
    TeeLogger,
    copy_parameters,
)
from .utils import (
    get_loss_definition,
    get_train_step_function,
    get_validation_function,
    get_optimizer,
    linear_schedule,
)
from ..utils import deep_update, AtomicUnits as au
from ..utils.io import human_time_duration
from ..models.preprocessing import AtomPadding, check_input, convert_to_jax


def main():
    parser = argparse.ArgumentParser(prog="fennol_train")
    parser.add_argument("config_file", type=str)
    parser.add_argument("--model_file", type=str, default=None)
    args = parser.parse_args()
    config_file = args.config_file
    model_file = args.model_file

    os.environ["OMP_NUM_THREADS"] = "1"
    sys.stdout = io.TextIOWrapper(
        open(sys.stdout.fileno(), "wb", 0), write_through=True
    )

    restart_training = False
    if os.path.isdir(config_file):
        output_directory = Path(config_file).absolute().as_posix()
        config_file = output_directory + "/config.yaml"
        restart_training = True
        training_state_file = output_directory + "/train_state"
        if not os.path.exists(training_state_file):
            raise FileNotFoundError(
                f"Training state file not found: {training_state_file}"
            )
        while output_directory.endswith("/"):
            output_directory = output_directory[:-1]
        backup_dir = output_directory + f"_backup_{time.strftime('%Y-%m-%d-%H-%M-%S')}"
        shutil.copytree(output_directory, backup_dir)

        with open(training_state_file, "rb") as f:
            training_state = pickle.load(f)
        print("Restarting training from", output_directory)
    else:
        training_state = None

    parameters = load_configuration(config_file)

    ### Set the device
    device: str = parameters.get("device", "cpu").lower()
    if device == "cpu":
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    elif device.startswith("cuda") or device.startswith("gpu"):
        if ":" in device:
            num = device.split(":")[-1]
            os.environ["CUDA_VISIBLE_DEVICES"] = num
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        device = "gpu"

    _device = jax.devices(device)[0]
    jax.config.update("jax_default_device", _device)

    # output directory
    if not restart_training:
        output_directory = parameters.get("output_directory", None)
        if output_directory is not None:
            if "{now}" in output_directory:
                output_directory = output_directory.replace(
                    "{now}", time.strftime("%Y-%m-%d-%H-%M-%S")
                )
            output_directory = Path(output_directory).absolute()
            if not output_directory.exists():
                output_directory.mkdir(parents=True)
            print("Output directory:", output_directory)
        else:
            output_directory = "."

    output_directory = str(output_directory) + "/"

    # copy config_file to output directory
    # config_name = Path(config_file).name
    config_ext = Path(config_file).suffix
    with open(config_file) as f_in:
        config_data = f_in.read()
    with open(output_directory + "/config" + config_ext, "w") as f_out:
        f_out.write(config_data)

    # set log file
    log_file = "train.log"  # parameters.get("log_file", None)
    logger = TeeLogger(output_directory + log_file)
    logger.bind_stdout()

    # set matmul precision
    enable_x64 = parameters.get("double_precision", False)
    jax.config.update("jax_enable_x64", enable_x64)
    fprec = "float64" if enable_x64 else "float32"
    parameters["fprec"] = fprec
    if enable_x64:
        print("Double precision enabled.")

    matmul_precision = parameters.get("matmul_prec", "highest").lower()
    assert matmul_precision in [
        "default",
        "high",
        "highest",
    ], "matmul_prec must be one of 'default','high','highest'"
    jax.config.update("jax_default_matmul_precision", matmul_precision)

    # set random seed
    rng_seed = parameters.get("rng_seed", np.random.randint(0, 2**32 - 1))
    print(f"rng_seed: {rng_seed}")
    rng_key = jax.random.PRNGKey(rng_seed)
    torch.manual_seed(rng_seed)
    np.random.seed(rng_seed)
    random.seed(rng_seed)
    np_rng = np.random.Generator(np.random.PCG64(rng_seed))

    try:
        if "stages" in parameters["training"]:
            ## train in stages ##
            params = deepcopy(parameters)
            stages = params["training"].pop("stages")
            assert isinstance(stages, dict), "'stages' must be a dict with named stages"
            model_file_stage = model_file
            print_stages_params = params["training"].get("print_stages_params", False)
            for i, (stage, stage_params) in enumerate(stages.items()):
                rng_key, subkey = jax.random.split(rng_key)
                print("")
                print(f"### STAGE {i+1}: {stage} ###")

                ## remove end_event from previous stage ##
                if i > 0 and "end_event" in params["training"]:
                    params["training"].pop("end_event")

                ## incrementally update training parameters ##
                params = deep_update(params, {"training": stage_params})
                if model_file_stage is not None:
                    ## load model from previous stage ##
                    params["model_file"] = model_file_stage

                if restart_training and training_state["stage"] != i + 1:
                    print(f"Skipping stage {i+1} (already completed)")
                    continue

                if print_stages_params:
                    print("stage parameters:")
                    print(json.dumps(params, indent=2, sort_keys=False))

                ## train stage ##
                _, model_file_stage = train(
                    (subkey, np_rng),
                    params,
                    stage=i + 1,
                    output_directory=output_directory,
                    training_state=training_state,
                )
                training_state = None
                restart_training = False
        else:
            ## single training stage ##
            train(
                (rng_key, np_rng),
                parameters,
                model_file=model_file,
                output_directory=output_directory,
                training_state=training_state,
            )
    except KeyboardInterrupt:
        print("Training interrupted by user.")
    finally:
        if log_file is not None:
            logger.unbind_stdout()
            logger.close()


def train(
    rng,
    parameters,
    model_file=None,
    stage=None,
    output_directory=None,
    training_state=None,
):
    if output_directory is None:
        output_directory = "./"
    elif not output_directory.endswith("/"):
        output_directory += "/"
    stage_prefix = f"_stage_{stage}" if stage is not None else ""

    if isinstance(rng, tuple):
        rng_key, np_rng = rng
    else:
        rng_key = rng
        np_rng = np.random.Generator(np.random.PCG64(np.random.randint(0, 2**32 - 1)))

    if training_state is not None:
        model_key = None
        model_file = output_directory + "latest_model.fnx"
    else:
        rng_key, model_key = jax.random.split(rng_key)
    model = load_model(parameters, model_file, rng_key=model_key)

    training_parameters = parameters.get("training", {})
    model_ref = None
    if "model_ref" in training_parameters:
        model_ref = load_model(parameters, training_parameters["model_ref"])
        print("Reference model:", training_parameters["model_ref"])

        if "ref_parameters" in training_parameters:
            ref_parameters = training_parameters["ref_parameters"]
            assert isinstance(
                ref_parameters, list
            ), "ref_parameters must be a list of str"
            print("Reference parameters:", ref_parameters)
            model.variables = copy_parameters(
                model.variables, model_ref.variables, ref_parameters
            )
    fprec = parameters.get("fprec", "float32")

    def convert_to_fprec(x):
        if jnp.issubdtype(x.dtype, jnp.floating):
            return x.astype(fprec)
        return x

    model.variables = jax.tree_map(convert_to_fprec, model.variables)

    loss_definition, used_keys, ref_keys = get_loss_definition(
        training_parameters, model_energy_unit=model.energy_unit
    )

    coordinates_ref_key = training_parameters.get("coordinates_ref_key", None)
    if coordinates_ref_key is not None:
        compute_ref_coords = True
        print("Reference coordinates:", coordinates_ref_key)
    else:
        compute_ref_coords = False

    dspath = training_parameters.get("dspath", None)
    if dspath is None:
        raise ValueError("Dataset path 'training/dspath' should be specified.")
    batch_size = training_parameters.get("batch_size", 16)
    rename_refs = training_parameters.get("rename_refs", {})
    training_iterator, validation_iterator = load_dataset(
        dspath=dspath,
        batch_size=batch_size,
        training_parameters=training_parameters,
        infinite_iterator=True,
        atom_padding=True,
        ref_keys=ref_keys,
        split_data_inputs=True,
        np_rng=np_rng,
        add_flags=["training"],
        fprec=fprec,
        rename_refs=rename_refs,
    )

    compute_forces = "forces" in used_keys
    compute_virial = "virial_tensor" in used_keys or "virial" in used_keys
    compute_stress = "stress_tensor" in used_keys or "stress" in used_keys
    compute_pressure = "pressure" in used_keys or "pressure_tensor" in used_keys

    # get optimizer parameters
    lr = training_parameters.get("lr", 1.0e-3)
    max_epochs = training_parameters.get("max_epochs", 2000)
    nbatch_per_epoch = training_parameters.get("nbatch_per_epoch", 200)
    nbatch_per_validation = training_parameters.get("nbatch_per_validation", 20)
    init_lr = training_parameters.get("init_lr", lr / 25)
    final_lr = training_parameters.get("final_lr", lr / 10000)

    schedule_type = training_parameters.get("schedule_type", "cosine_onecycle").lower()
    schedule_type = training_parameters.get("scheduler", schedule_type).lower()
    schedule_metrics = training_parameters.get("schedule_metrics", "rmse_tot")

    adaptive_scheduler = False
    print("Schedule type:", schedule_type)
    if schedule_type == "cosine_onecycle":
        transition_epochs = training_parameters.get("onecycle_epochs", max_epochs)
        peak_epoch = training_parameters.get("peak_epoch", 0.3 * transition_epochs)
        schedule_ = optax.cosine_onecycle_schedule(
            peak_value=lr,
            div_factor=lr / init_lr,
            final_div_factor=init_lr / final_lr,
            transition_steps=transition_epochs * nbatch_per_epoch,
            pct_start=peak_epoch / transition_epochs,
        )
        sch_state = {"count": 0, "best": np.inf, "lr": init_lr}

        def schedule(state, rmse=None):
            new_state = {**state}
            lr = schedule_(state["count"])
            if rmse is None:
                new_state["count"] += 1
            new_state["lr"] = lr
            return lr, new_state

    elif schedule_type == "constant":
        sch_state = {"count": 0}

        def schedule(state, rmse=None):
            new_state = {**state}
            new_state["lr"] = lr
            if rmse is None:
                new_state["count"] += 1
            return lr, new_state

    elif schedule_type == "reduce_on_plateau":
        patience = training_parameters.get("patience", 10)
        factor = training_parameters.get("lr_factor", 0.5)
        patience_thr = training_parameters.get("patience_thr", 0.0)
        sch_state = {"count": 0, "best": np.inf, "lr": lr, "patience": patience}
        adaptive_scheduler = True

        def schedule(state, rmse=None):
            new_state = {**state}
            if rmse is None:
                new_state["count"] += 1
                return state["lr"], new_state
            if rmse <= state["best"] * (1.0 + patience_thr):
                if rmse < state["best"]:
                    new_state["best"] = rmse
                new_state["patience"] = 0
            else:
                new_state["patience"] += 1
                if new_state["patience"] >= patience:
                    new_state["lr"] = state["lr"] * factor
                    new_state["patience"] = 0
                    print("Reducing learning rate to", new_state["lr"])
            return new_state["lr"], new_state

    else:
        raise ValueError(f"Unknown schedule_type: {schedule_type}")

    stochastic_scheduler = training_parameters.get("stochastic_scheduler", False)
    if stochastic_scheduler:
        schedule_ = schedule
        rng_key, scheduler_key = jax.random.split(rng_key)
        sch_state["rng_key"] = scheduler_key
        sch_state["lr_max"] = lr
        sch_state["lr_min"] = final_lr

        def schedule(state, rmse=None):
            new_state = {**state, "lr": state["lr_max"]}
            if rmse is None:
                lr_max, new_state = schedule_(new_state, rmse=rmse)
                lr_min = new_state["lr_min"]
                new_state["rng_key"], subkey = jax.random.split(new_state["rng_key"])
                lr = lr_min + (lr_max - lr_min) * jax.random.uniform(subkey)
                new_state["lr"] = lr
                new_state["lr_max"] = lr_max

            return new_state["lr"], new_state

    optimizer = get_optimizer(
        training_parameters, model.variables, schedule(sch_state)[0]
    )
    opt_st = optimizer.init(model.variables)

    # exponential moving average of the parameters
    ema_decay = training_parameters.get("ema_decay", -1.0)
    if ema_decay > 0.0:
        assert ema_decay < 1.0, "ema_decay must be in (0,1)"
        ema = optax.ema(decay=ema_decay)
    else:
        ema = optax.identity()
    ema_st = ema.init(model.variables)

    # end event
    end_event = training_parameters.get("end_event", None)
    if end_event is None or isinstance(end_event, str) and end_event.lower() == "none":
        is_end = lambda metrics: False
    else:
        assert len(end_event) == 2, "end_event must be a list of two elements"
        is_end = lambda metrics: metrics[end_event[0]] < end_event[1]

    print_timings = parameters.get("print_timings", False)

    if "energy_terms" in training_parameters:
        model.set_energy_terms(training_parameters["energy_terms"], jit=False)
        print("energy terms:", model.energy_terms)

    pbc_training = training_parameters.get("pbc_training", False)
    if compute_stress or compute_virial or compute_pressure:
        virial_key = "virial" if "virial" in used_keys else "virial_tensor"
        stress_key = "stress" if "stress" in used_keys else "stress_tensor"
        pressure_key = "pressure" if "pressure" in used_keys else "pressure_tensor"
        if compute_stress or compute_pressure:
            assert pbc_training, "PBC must be enabled for stress or virial training"
            print("Computing forces and stress tensor")

            def evaluate(model, variables, data):
                _, _, vir, output = model._energy_and_forces_and_virial(variables, data)
                cells = output["cells"]
                volume = jnp.abs(jnp.linalg.det(cells))
                stress = vir / volume[:, None, None]
                output[stress_key] = stress
                output[virial_key] = vir
                if pressure_key == "pressure":
                    output[pressure_key] = -jnp.trace(stress, axis1=1, axis2=2) / 3.0
                else:
                    output[pressure_key] = -stress
                return output

        else:
            print("Computing forces and virial tensor")

            def evaluate(model, variables, data):
                _, _, vir, output = model._energy_and_forces_and_virial(variables, data)
                output[virial_key] = vir
                return output

    elif compute_forces:
        print("Computing forces")

        def evaluate(model, variables, data):
            _, _, output = model._energy_and_forces(variables, data)
            return output

    elif model.energy_terms is not None:

        def evaluate(model, variables, data):
            _, output = model._total_energy(variables, data)
            return output

    else:

        def evaluate(model, variables, data):
            output = model.modules.apply(variables, data)
            return output

    train_step = get_train_step_function(
        loss_definition=loss_definition,
        model=model,
        model_ref=model_ref,
        compute_ref_coords=compute_ref_coords,
        evaluate=evaluate,
        optimizer=optimizer,
        ema=ema,
    )

    validation = get_validation_function(
        loss_definition=loss_definition,
        model=model,
        model_ref=model_ref,
        compute_ref_coords=compute_ref_coords,
        evaluate=evaluate,
        return_targets=False,
    )

    ## configure preprocessing ##
    minimum_image = training_parameters.get("minimum_image", False)
    preproc_state = unfreeze(model.preproc_state)
    layer_state = []
    for st in preproc_state["layers_state"]:
        stnew = unfreeze(st)
        # st["nblist_skin"] = nblist_skin
        # if nblist_stride > 1:
        #     st["skin_stride"] = nblist_stride
        #     st["skin_count"] = nblist_stride
        if pbc_training:
            stnew["minimum_image"] = minimum_image
        if "nblist_mult_size" in training_parameters:
            stnew["nblist_mult_size"] = training_parameters["nblist_mult_size"]
        if "nblist_add_neigh" in training_parameters:
            stnew["add_neigh"] = training_parameters["nblist_add_neigh"]
        if "nblist_add_atoms" in training_parameters:
            stnew["add_atoms"] = training_parameters["nblist_add_atoms"]
        layer_state.append(freeze(stnew))

    preproc_state["layers_state"] = tuple(layer_state)
    model.preproc_state = freeze(preproc_state)

    # inputs,data = next(training_iterator)
    # inputs = model.preprocess(**inputs)
    # print("preproc_state:",model.preproc_state)

    if training_parameters.get("gpu_preprocessing", False):
        print("GPU preprocessing activated.")

        def preprocessing(model, inputs):
            preproc_state = model.preproc_state
            outputs = model.preprocessing.process(preproc_state, inputs)
            preproc_state, state_up, outputs, overflow = (
                model.preprocessing.check_reallocate(preproc_state, outputs)
            )
            if overflow:
                print("GPU preprocessing: nblist overflow => reallocating nblist")
                print("size updates:", state_up)
            model.preproc_state = preproc_state
            return outputs

    else:
        preprocessing = lambda model, inputs: model.preprocess(**inputs)

    fetch_time = 0.0
    preprocess_time = 0.0
    step_time = 0.0

    maes_prev = defaultdict(lambda: np.inf)
    metrics_beta = training_parameters.get("metrics_ema_decay", -1.0)
    smoothen_metrics = metrics_beta < 1.0 and metrics_beta > 0.0
    if smoothen_metrics:
        print("Computing smoothed metrics with beta =", metrics_beta)
        rmses_smooth = defaultdict(lambda: 0.0)
        maes_smooth = defaultdict(lambda: 0.0)
        rmse_tot_smooth = 0.0
        mae_tot_smooth = 0.0
        nsmooth = 0
    count = 0
    restore_count = 0
    max_restore_count = training_parameters.get("max_restore_count", 5)
    variables = deepcopy(model.variables)
    variables_save = deepcopy(variables)
    variables_ema_save = deepcopy(model.variables)

    fmetrics = open(
        output_directory + f"metrics{stage_prefix}.traj",
        "w" if training_state is None else "a",
    )

    keep_all_bests = training_parameters.get("keep_all_bests", False)
    previous_best_name = None
    best_metric = np.inf
    metric_use_best = training_parameters.get("metric_best", "rmse_tot")  # .lower()
    # authorized_metrics = ["mae", "rmse"]
    # if smoothen_metrics:
    #     authorized_metrics+= ["mae_smooth", "rmse_smooth"]
    # assert metric_use_best in authorized_metrics, f"metric_best must be one of {authorized_metrics}"

    if training_state is not None:
        preproc_state = training_state["preproc_state"]
        model.preproc_state = freeze(preproc_state)
        opt_st = training_state["opt_state"]
        ema_st = training_state["ema_state"]
        sch_state = training_state["sch_state"]
        variables = training_state["variables"]
        model.variables = training_state["variables_ema"]
        count = training_state["count"]
        restore_count = training_state["restore_count"]
        epoch_start = training_state["epoch"]
        rng_key = training_state["rng_key"]
        best_metric = training_state["best_metric"]
        if smoothen_metrics:
            rmses_smooth = training_state["rmses_smooth"]
            maes_smooth = training_state["maes_smooth"]
            rmse_tot_smooth = training_state["rmse_tot_smooth"]
            mae_tot_smooth = training_state["mae_tot_smooth"]
            nsmooth = training_state["nsmooth"]
        print("Restored training state")
    else:
        epoch_start = 0
        training_state = {
            "preproc_state": preproc_state,
            "rng_key": rng_key,
            "opt_state": opt_st,
            "ema_state": ema_st,
            "sch_state": sch_state,
            "variables": variables,
            "variables_ema": model.variables,
            "count": count,
            "restore_count": restore_count,
            "epoch": 0,
            "stage": stage,
            "best_metric": np.inf,
        }
        if smoothen_metrics:
            training_state["rmses_smooth"] = dict(rmses_smooth)
            training_state["maes_smooth"] = dict(maes_smooth)
            training_state["rmse_tot_smooth"] = rmse_tot_smooth
            training_state["mae_tot_smooth"] = mae_tot_smooth
            training_state["nsmooth"] = nsmooth

    ### Training loop ###
    start = time.time()
    print("Starting training...")
    for epoch in range(epoch_start, max_epochs):
        s = time.time()
        for _ in range(nbatch_per_epoch):
            # fetch data
            inputs0, data = next(training_iterator)

            # preprocess data
            inputs = preprocessing(model, inputs0)
            # inputs = model.preprocess(**inputs0)

            rng_key, subkey = jax.random.split(rng_key)
            inputs["rng_key"] = subkey
            # if compute_ref_coords:
            #     inputs_ref = {**inputs0, "coordinates": data[coordinates_ref_key]}
            #     inputs_ref = model.preprocess(**inputs_ref)
            # else:
            #     inputs_ref = None
            # if print_timings:
            #     jax.block_until_ready(inputs["coordinates"])

            # train step
            # opt_st.inner_states["trainable"].inner_state[1].hyperparams[
            #     "learning_rate"
            # ] = schedule(count)
            current_lr, sch_state = schedule(sch_state)
            opt_st.inner_states["trainable"].inner_state[-1].hyperparams[
                "step_size"
            ] = current_lr
            loss, variables, opt_st, model.variables, ema_st, output = train_step(
                epoch=epoch,
                data=data,
                inputs=inputs,
                variables=variables,
                variables_ema=model.variables,
                opt_st=opt_st,
                ema_st=ema_st,
            )
            count += 1

        rmses_avg = defaultdict(lambda: 0.0)
        maes_avg = defaultdict(lambda: 0.0)
        for _ in range(nbatch_per_validation):
            inputs0, data = next(validation_iterator)

            inputs = preprocessing(model, inputs0)
            # inputs = model.preprocess(**inputs0)

            rng_key, subkey = jax.random.split(rng_key)
            inputs["rng_key"] = subkey

            # if compute_ref_coords:
            #     inputs_ref = {**inputs0, "coordinates": data[coordinates_ref_key]}
            #     inputs_ref = model.preprocess(**inputs_ref)
            # else:
            #     inputs_ref = None
            rmses, maes, output_val = validation(
                data=data,
                inputs=inputs,
                variables=model.variables,
                # inputs_ref=inputs_ref,
            )
            for k, v in rmses.items():
                rmses_avg[k] += v
            for k, v in maes.items():
                maes_avg[k] += v

        jax.block_until_ready(output_val)
        e = time.time()
        epoch_time = e - s

        elapsed_time = e - start
        if not adaptive_scheduler:
            remain_glob = elapsed_time * (max_epochs - epoch) / (epoch + 1)
            remain_last = epoch_time * (max_epochs - epoch)
            # estimate remaining time via weighted average (put more weight on last at the beginning)
            wremain = np.sin(0.5 * np.pi * (epoch + 1) / max_epochs)
            remaining_time = human_time_duration(
                remain_glob * wremain + remain_last * (1 - wremain)
            )
        elapsed_time = human_time_duration(elapsed_time)
        batch_time = human_time_duration(
            epoch_time / (nbatch_per_epoch + nbatch_per_validation)
        )
        epoch_time = human_time_duration(epoch_time)

        for k in rmses_avg.keys():
            rmses_avg[k] /= nbatch_per_validation
        for k in maes_avg.keys():
            maes_avg[k] /= nbatch_per_validation

        step_time /= nbatch_per_epoch
        fetch_time /= nbatch_per_epoch
        preprocess_time /= nbatch_per_epoch

        print("")
        line = f"Epoch {epoch+1}, lr={current_lr:.3e}, loss = {loss:.3e}"
        line += f", epoch time = {epoch_time}, batch time = {batch_time}"
        line += f", elapsed time = {elapsed_time}"
        if not adaptive_scheduler:
            line += f", est. remaining time = {remaining_time}"
        print(line)
        rmse_tot = 0.0
        mae_tot = 0.0
        for k in rmses_avg.keys():
            mult = loss_definition[k]["mult"]
            loss_prms = loss_definition[k]
            rmse_tot = rmse_tot + rmses_avg[k] * loss_prms["weight"] ** 0.5
            mae_tot = mae_tot + maes_avg[k] * loss_prms["weight"]
            unit = "(" + loss_prms["unit"] + ")" if "unit" in loss_prms else ""

            weight_str = ""
            if "weight_schedule" in loss_prms:
                w = linear_schedule(epoch, *loss_prms["weight_schedule"])
                weight_str = f"(w={w:.3f})"

            if rmses_avg[k] / mult < 1.0e-2:
                print(
                    f" rmse_{k}= {rmses_avg[k]/mult:10.3e} ; mae_{k}= {maes_avg[k]/mult:10.3e} {unit} {weight_str}"
                )
            else:
                print(
                    f" rmse_{k}= {rmses_avg[k]/mult:10.3f} ; mae_{k}= {maes_avg[k]/mult:10.3f} {unit} {weight_str}"
                )

        # if print_timings:
        #     print(
        #         f" Timings per batch: fetch time = {fetch_time:.5f}; preprocess time = {preprocess_time:.5f}; train time = {step_time:.5f}"
        #     )
        fetch_time = 0.0
        step_time = 0.0
        preprocess_time = 0.0
        restore = False
        reinit = False
        for k, mae in maes_avg.items():
            if np.isnan(mae):
                restore = True
                reinit = True
                print("NaN detected in mae")
                # for k, v in inputs.items():
                #     if hasattr(v,"shape"):
                #         if np.isnan(v).any():
                #             print(k,v)
                # sys.exit(1)
            if "threshold" in loss_definition[k]:
                thr = loss_definition[k]["threshold"]
                if mae > thr * maes_prev[k]:
                    restore = True
                    break

        if restore:
            restore_count += 1
            if restore_count > max_restore_count:
                if reinit:
                    raise ValueError("Model diverged and could not be restored.")
                else:
                    restore_count = 0
                    print(
                        f"{max_restore_count} unsuccessful restores, resuming training."
                    )
                    continue

            variables = deepcopy(variables_save)
            model.variables = deepcopy(variables_ema_save)
            print("Restored previous model after divergence.")
            if reinit:
                opt_st = optimizer.init(model.variables)
                ema_st = ema.init(model.variables)
                print("Reinitialized optimizer after divergence.")
            continue

        restore_count = 0
        variables_save = deepcopy(variables)
        variables_ema_save = deepcopy(model.variables)
        maes_prev = maes_avg

        model.save(output_directory + "latest_model.fnx")

        # save metrics
        metrics = {
            "epoch": epoch + 1,
            "step": count,
            "data_count": count * batch_size,
            "elapsed_time": time.time() - start,
            "lr": current_lr,
            "loss": loss,
        }
        for k in rmses_avg.keys():
            mult = loss_definition[k]["mult"]
            metrics[f"rmse_{k}"] = rmses_avg[k] / mult
            metrics[f"mae_{k}"] = maes_avg[k] / mult

        metrics["rmse_tot"] = rmse_tot
        metrics["mae_tot"] = mae_tot
        if smoothen_metrics:
            nsmooth += 1
            for k in rmses_avg.keys():
                mult = loss_definition[k]["mult"]
                rmses_smooth[k] = (
                    metrics_beta * rmses_smooth[k] + (1.0 - metrics_beta) * rmses_avg[k]
                )
                maes_smooth[k] = (
                    metrics_beta * maes_smooth[k] + (1.0 - metrics_beta) * maes_avg[k]
                )
                metrics[f"rmse_smooth_{k}"] = (
                    rmses_smooth[k] / (1.0 - metrics_beta**nsmooth) / mult
                )
                metrics[f"mae_smooth_{k}"] = (
                    maes_smooth[k] / (1.0 - metrics_beta**nsmooth) / mult
                )
            rmse_tot_smooth = (
                metrics_beta * rmse_tot_smooth + (1.0 - metrics_beta) * rmse_tot
            )
            mae_tot_smooth = (
                metrics_beta * mae_tot_smooth + (1.0 - metrics_beta) * mae_tot
            )
            metrics["rmse_smooth_tot"] = rmse_tot_smooth / (1.0 - metrics_beta**nsmooth)
            metrics["mae_smooth_tot"] = mae_tot_smooth / (1.0 - metrics_beta**nsmooth)

        assert (
            metric_use_best in metrics
        ), f"Error: metric for selectring best model '{metric_use_best}' not in metrics"
        metric_for_best = metrics[metric_use_best]
        if metric_for_best < best_metric:
            best_metric = metric_for_best
            metrics["best_metric"] = best_metric
            if keep_all_bests:
                best_name = (
                    output_directory
                    + f"best_model{stage_prefix}_{time.strftime('%Y-%m-%d-%H-%M-%S')}.fnx"
                )
                model.save(best_name)

            best_name = output_directory + f"best_model{stage_prefix}.fnx"
            model.save(best_name)
            print("New best model saved to:", best_name)
        else:
            metrics["best_metric"] = best_metric

        if epoch == 0:
            headers = [f"{i+1}:{k}" for i, k in enumerate(metrics.keys())]
            fmetrics.write("# " + " ".join(headers) + "\n")
        fmetrics.write(" ".join([str(metrics[k]) for k in metrics.keys()]) + "\n")
        fmetrics.flush()

        # update learning rate using current metrics
        assert (
            schedule_metrics in metrics
        ), f"Error: cannot update lr, '{schedule_metrics}' not in metrics"
        current_lr, sch_state = schedule(sch_state, metrics[schedule_metrics])

        # update and save training state
        training_state["preproc_state"] = model.preproc_state
        training_state["opt_state"] = opt_st
        training_state["ema_state"] = ema_st
        training_state["sch_state"] = sch_state
        training_state["variables"] = variables
        training_state["variables_ema"] = model.variables
        training_state["count"] = count
        training_state["restore_count"] = restore_count
        training_state["epoch"] = epoch
        training_state["best_metric"] = best_metric
        if smoothen_metrics:
            training_state["rmses_smooth"] = dict(rmses_smooth)
            training_state["maes_smooth"] = dict(maes_smooth)
            training_state["rmse_tot_smooth"] = rmse_tot_smooth
            training_state["mae_tot_smooth"] = mae_tot_smooth
            training_state["nsmooth"] = nsmooth

        with open(output_directory + "train_state", "wb") as f:
            pickle.dump(training_state, f)

        if is_end(metrics):
            print("Stage finished.")
            break

    end = time.time()

    print(f"Training time: {human_time_duration(end-start)}")
    print("")

    fmetrics.close()

    filename = output_directory + f"final_model{stage_prefix}.fnx"
    model.save(filename)
    print("Final model saved to:", filename)

    return metrics, filename


if __name__ == "__main__":
    main()
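During training, train() writes one line per epoch to metrics{stage_prefix}.traj in the output directory: a single commented header of the form "# 1:epoch 2:step ..." followed by space-separated numeric columns (epoch, step, data_count, elapsed_time, lr, loss, per-target rmse_*/mae_*, rmse_tot, mae_tot, optional smoothed variants, best_metric). A minimal reading sketch, assuming a single-stage run (so the file is named metrics.traj) and using only NumPy:

import numpy as np

# Recover the column names from the commented header line, e.g. "# 1:epoch 2:step ..."
with open("metrics.traj") as f:
    header = f.readline().lstrip("# ").split()
names = [h.split(":", 1)[1] for h in header]

data = np.loadtxt("metrics.traj", ndmin=2)   # '#' lines are skipped automatically
rmse_tot = data[:, names.index("rmse_tot")]  # total weighted validation RMSE per epoch
print(f"last epoch: rmse_tot = {rmse_tot[-1]:.4f}")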
def main():

Command-line entry point (installed as fennol_train). It parses the arguments (a configuration file, or the output directory of a previous run to restart from its train_state checkpoint, plus an optional --model_file), configures the device, precision, output directory, logging and random seeds, and then calls train() once, or once per stage when the configuration defines training.stages.
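For orientation, a minimal sketch of the configuration that main() consumes, written here as a Python dict; in practice these keys live in the YAML file passed on the command line (fennol_train config.yaml), and only keys actually read in this module are shown. The loss and target definitions are handled by get_loss_definition and are omitted; values below are placeholders, not recommended settings:

parameters = {
    "device": "cuda:0",                  # "cpu", "gpu" or "cuda[:N]"
    "output_directory": "run_{now}",     # "{now}" is replaced by a timestamp
    "double_precision": False,
    "matmul_prec": "highest",            # "default", "high" or "highest"
    "rng_seed": 1234,
    "training": {
        "dspath": "path/to/dataset",     # required; format is handled by load_dataset
        "batch_size": 16,
        "lr": 1.0e-3,
        "max_epochs": 2000,
        "nbatch_per_epoch": 200,
        "nbatch_per_validation": 20,
        "schedule_type": "cosine_onecycle",  # or "constant", "reduce_on_plateau"
        "ema_decay": 0.999,
        "metric_best": "rmse_tot",
        # "stages": {...}  # optional dict of named stages; each stage
        #                  # deep-updates this "training" block in turn
    },
}

Passing a previous output directory instead of a config file (fennol_train my_run/) restarts from my_run/config.yaml and my_run/train_state after backing the directory up.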
def train(rng, parameters, model_file=None, stage=None, output_directory=None, training_state=None):

Runs a single training stage: loads the model (and an optional reference model), builds the dataset iterators, loss definition, optimizer, learning-rate schedule and parameter EMA, then loops over epochs of training and validation batches. Each epoch it prints and logs metrics to metrics{stage_prefix}.traj, saves the latest and best models, checkpoints the full training state to train_state for restarts, and restores the previously saved parameters if validation diverges. Returns the last metrics dict and the path of the final saved model.
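train() can also be driven directly for a single stage, mirroring what main() does (note that main() additionally selects the JAX device, precision, logging and output directory before calling it). A minimal sketch, assuming a parsed configuration with the structure shown above; config.yaml and run1 are placeholder paths:

import numpy as np
import jax

from fennol.training.io import load_configuration
from fennol.training.training import train

parameters = load_configuration("config.yaml")

# train() accepts either a jax PRNG key or a (jax key, numpy Generator) tuple
rng_seed = parameters.get("rng_seed", 0)
rng = (jax.random.PRNGKey(rng_seed), np.random.Generator(np.random.PCG64(rng_seed)))

# Runs one stage and returns the last metrics dict and the path of the final .fnx model
metrics, model_path = train(rng, parameters, output_directory="run1")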