fennol.training.training
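
"""Training driver for FeNNol models.

Parses a configuration file, sets up the JAX device and precision, loads the
dataset and model, and runs the (optionally staged) training loop with
checkpointing and restart support."""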

import os
import yaml
import sys
import jax
import io
import time
import jax.numpy as jnp
import numpy as np
import optax
from collections import defaultdict
import json
from copy import deepcopy
from pathlib import Path
import argparse
try:
    import torch
except ImportError:
    raise ImportError(
        "PyTorch is required for training models. Install the CPU version following instructions at https://pytorch.org/get-started/locally/"
    )
import random
from flax import traverse_util
import shutil
import pickle

from flax.core import freeze, unfreeze
from .io import (
    load_configuration,
    load_dataset,
    load_model,
    TeeLogger,
    copy_parameters,
)
from .utils import (
    get_loss_definition,
    get_train_step_function,
    get_validation_function,
    linear_schedule,
)
from .optimizers import get_optimizer, get_lr_schedule
from ..utils import deep_update, AtomicUnits as au
from ..utils.io import human_time_duration
from ..models.preprocessing import AtomPadding, check_input, convert_to_jax

def save_restart_checkpoint(output_directory, stage, training_state, preproc_state, metrics_state):
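    """Pickle the state needed to resume an interrupted run (stage index,
    training state, preprocessing state and smoothed-metrics state)."""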
    with open(output_directory + "/restart_checkpoint.pkl", "wb") as f:
        pickle.dump({
            "stage": stage,
            "training": training_state,
            "preproc_state": preproc_state,
            "metrics_state": metrics_state,
        }, f)

def load_restart_checkpoint(output_directory):
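    """Load a restart checkpoint written by save_restart_checkpoint."""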
    restart_checkpoint_file = output_directory + "/restart_checkpoint.pkl"
    if not os.path.exists(restart_checkpoint_file):
        raise FileNotFoundError(
            f"Training state file not found: {restart_checkpoint_file}"
        )
    with open(restart_checkpoint_file, "rb") as f:
        restart_checkpoint = pickle.load(f)
    return restart_checkpoint


def main():
    parser = argparse.ArgumentParser(prog="fennol_train")
    parser.add_argument("config_file", type=str)
    parser.add_argument("--model_file", type=str, default=None)
    args = parser.parse_args()
    config_file = args.config_file
    model_file = args.model_file

    os.environ["OMP_NUM_THREADS"] = "1"
    sys.stdout = io.TextIOWrapper(
        open(sys.stdout.fileno(), "wb", 0), write_through=True
    )

    restart_training = False
    if os.path.isdir(config_file):
        output_directory = Path(config_file).absolute().as_posix()
        restart_training = True
        while output_directory.endswith("/"):
            output_directory = output_directory[:-1]
        config_file = output_directory + "/config.yaml"
        backup_dir = output_directory + f"_backup_{time.strftime('%Y-%m-%d-%H-%M-%S')}"
        shutil.copytree(output_directory, backup_dir)
        print("Restarting training from", output_directory)

    parameters = load_configuration(config_file)

    ### Set the device
    device: str = parameters.get("device", "cpu").lower()
    if device == "cpu":
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    elif device.startswith("cuda") or device.startswith("gpu"):
        if ":" in device:
            num = device.split(":")[-1]
            os.environ["CUDA_VISIBLE_DEVICES"] = num
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        device = "gpu"

    _device = jax.devices(device)[0]
    jax.config.update("jax_default_device", _device)

    if restart_training:
        restart_checkpoint = load_restart_checkpoint(output_directory)
    else:
        restart_checkpoint = None

    # output directory
    if not restart_training:
        output_directory = parameters.get("output_directory", None)
        if output_directory is not None:
            if "{now}" in output_directory:
                output_directory = output_directory.replace(
                    "{now}", time.strftime("%Y-%m-%d-%H-%M-%S")
                )
            output_directory = Path(output_directory).absolute()
            if not output_directory.exists():
                output_directory.mkdir(parents=True)
            print("Output directory:", output_directory)
        else:
            output_directory = "."

    output_directory = str(output_directory) + "/"

    # copy config file to output directory
    config_ext = Path(config_file).suffix
    with open(config_file) as f_in:
        config_data = f_in.read()
    with open(output_directory + "config" + config_ext, "w") as f_out:
        f_out.write(config_data)

    # set log file
    log_file = "train.log"  # parameters.get("log_file", None)
    logger = TeeLogger(output_directory + log_file)
    logger.bind_stdout()

    # set floating-point and matmul precision
    enable_x64 = parameters.get("double_precision", False)
    jax.config.update("jax_enable_x64", enable_x64)
    fprec = "float64" if enable_x64 else "float32"
    parameters["fprec"] = fprec
    if enable_x64:
        print("Double precision enabled.")

    matmul_precision = parameters.get("matmul_prec", "highest").lower()
    assert matmul_precision in [
        "default",
        "high",
        "highest",
    ], "matmul_prec must be one of 'default','high','highest'"
    jax.config.update("jax_default_matmul_precision", matmul_precision)

    # set random seed
    rng_seed = parameters.get("rng_seed", np.random.randint(0, 2**32 - 1))
    print(f"rng_seed: {rng_seed}")
    rng_key = jax.random.PRNGKey(rng_seed)
    torch.manual_seed(rng_seed)
    np.random.seed(rng_seed)
    random.seed(rng_seed)
    np_rng = np.random.Generator(np.random.PCG64(rng_seed))

    try:
        if "stages" in parameters["training"]:
            ## train in stages ##
            params = deepcopy(parameters)
            stages = params["training"].pop("stages")
            assert isinstance(stages, dict), "'stages' must be a dict with named stages"
            model_file_stage = model_file
            print_stages_params = params["training"].get("print_stages_params", False)
            for i, (stage, stage_params) in enumerate(stages.items()):
                rng_key, subkey = jax.random.split(rng_key)
                print("")
                print(f"### STAGE {i+1}: {stage} ###")

                ## remove end_event from previous stage ##
                if i > 0 and "end_event" in params["training"]:
                    params["training"].pop("end_event")

                ## incrementally update training parameters ##
                params = deep_update(params, {"training": stage_params})
                if model_file_stage is not None:
                    ## load model from previous stage ##
                    params["model_file"] = model_file_stage

                if restart_training and restart_checkpoint["stage"] != i + 1:
                    print(f"Skipping stage {i+1} (already completed)")
                    continue

                if print_stages_params:
                    print("stage parameters:")
                    print(json.dumps(params, indent=2, sort_keys=False))

                ## train stage ##
                _, model_file_stage = train(
                    (subkey, np_rng),
                    params,
                    stage=i + 1,
                    output_directory=output_directory,
                    restart_checkpoint=restart_checkpoint,
                )

                restart_training = False
                restart_checkpoint = None
        else:
            ## single training stage ##
            train(
                (rng_key, np_rng),
                parameters,
                model_file=model_file,
                output_directory=output_directory,
                restart_checkpoint=restart_checkpoint,
            )
    except KeyboardInterrupt:
        print("Training interrupted by user.")
    finally:
        if log_file is not None:
            logger.unbind_stdout()
            logger.close()


def train(
    rng,
    parameters,
    model_file=None,
    stage=None,
    output_directory=None,
    restart_checkpoint=None,
):
    if output_directory is None or output_directory == "":
        output_directory = "./"
    elif not output_directory.endswith("/"):
        output_directory += "/"
    stage_prefix = f"_stage_{stage}" if stage is not None else ""

    if isinstance(rng, tuple):
        rng_key, np_rng = rng
    else:
        rng_key = rng
        np_rng = np.random.Generator(np.random.PCG64(np.random.randint(0, 2**32 - 1)))

    ### GET TRAINING PARAMETERS ###
    training_parameters = parameters.get("training", {})

    #### LOAD MODEL ####
    if restart_checkpoint is not None:
        model_key = None
        model_file = output_directory + "latest_model.fnx"
    else:
        rng_key, model_key = jax.random.split(rng_key)
    model = load_model(parameters, model_file, rng_key=model_key)

    ### SAVE INITIAL MODEL ###
    save_initial_model = (
        restart_checkpoint is None
        and (stage is None or stage == 0)
    )
    if save_initial_model:
        model.save(output_directory + "initial_model.fnx")

    model_ref = None
    if "model_ref" in training_parameters:
        model_ref = load_model(parameters, training_parameters["model_ref"])
        print("Reference model:", training_parameters["model_ref"])

        if "ref_parameters" in training_parameters:
            ref_parameters = training_parameters["ref_parameters"]
            assert isinstance(
                ref_parameters, list
            ), "ref_parameters must be a list of str"
            print("Reference parameters:", ref_parameters)
            model.variables = copy_parameters(
                model.variables, model_ref.variables, ref_parameters
            )

    ### SET FLOATING POINT PRECISION ###
    fprec = parameters.get("fprec", "float32")

    def convert_to_fprec(x):
        if jnp.issubdtype(x.dtype, jnp.floating):
            return x.astype(fprec)
        return x

    model.variables = jax.tree_util.tree_map(convert_to_fprec, model.variables)

    ### SET UP LOSS FUNCTION ###
    loss_definition, used_keys, ref_keys = get_loss_definition(
        training_parameters, model_energy_unit=model.energy_unit
    )

    coordinates_ref_key = training_parameters.get("coordinates_ref_key", None)
    if coordinates_ref_key is not None:
        compute_ref_coords = True
        print("Reference coordinates:", coordinates_ref_key)
    else:
        compute_ref_coords = False

    #### LOAD DATASET ####
    dspath = training_parameters.get("dspath", None)
    if dspath is None:
        raise ValueError("Dataset path 'training/dspath' should be specified.")
    batch_size = training_parameters.get("batch_size", 16)
    batch_size_val = training_parameters.get("batch_size_val", None)
    rename_refs = training_parameters.get("rename_refs", {})
    training_iterator, validation_iterator = load_dataset(
        dspath=dspath,
        batch_size=batch_size,
        batch_size_val=batch_size_val,
        training_parameters=training_parameters,
        infinite_iterator=True,
        atom_padding=True,
        ref_keys=ref_keys,
        split_data_inputs=True,
        np_rng=np_rng,
        add_flags=["training"],
        fprec=fprec,
        rename_refs=rename_refs,
    )

    compute_forces = "forces" in used_keys
    compute_virial = "virial_tensor" in used_keys or "virial" in used_keys
    compute_stress = "stress_tensor" in used_keys or "stress" in used_keys
    compute_pressure = "pressure" in used_keys or "pressure_tensor" in used_keys

    ### get optimizer parameters ###
    max_epochs = int(training_parameters.get("max_epochs", 2000))
    epoch_format = len(str(max_epochs))
    nbatch_per_epoch = training_parameters.get("nbatch_per_epoch", 200)
    nbatch_per_validation = training_parameters.get("nbatch_per_validation", 20)

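    # get_lr_schedule returns (schedule_fn, initial_state, metric_name, adaptive_flag);
    # calling schedule_fn(state[, metric]) yields the current lr and the updated
    # state (inferred from the call sites in this module)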
    schedule, sch_state, schedule_metrics, adaptive_scheduler = get_lr_schedule(
        max_epochs, nbatch_per_epoch, training_parameters
    )
    optimizer = get_optimizer(
        training_parameters, model.variables, schedule(sch_state)[0]
    )
    opt_st = optimizer.init(model.variables)

    ### exponential moving average of the parameters ###
    ema_decay = training_parameters.get("ema_decay", -1.0)
    if ema_decay > 0.0:
        assert ema_decay < 1.0, "ema_decay must be in (0,1)"
        ema = optax.ema(decay=ema_decay)
    else:
        ema = optax.identity()
    ema_st = ema.init(model.variables)

    #### end event ####
    end_event = training_parameters.get("end_event", None)
    if end_event is None or (isinstance(end_event, str) and end_event.lower() == "none"):
        is_end = lambda metrics: False
    else:
        assert len(end_event) == 2, "end_event must be a list of two elements"
        is_end = lambda metrics: metrics[end_event[0]] < end_event[1]

    if "energy_terms" in training_parameters:
        model.set_energy_terms(training_parameters["energy_terms"], jit=False)
    print("energy terms:", model.energy_terms)

    ### MODEL EVALUATION FUNCTION ###
    pbc_training = training_parameters.get("pbc_training", False)
    if compute_stress or compute_virial or compute_pressure:
        virial_key = "virial" if "virial" in used_keys else "virial_tensor"
        stress_key = "stress" if "stress" in used_keys else "stress_tensor"
        pressure_key = "pressure" if "pressure" in used_keys else "pressure_tensor"
        if compute_stress or compute_pressure:
            assert pbc_training, "PBC must be enabled for stress or pressure training"
            print("Computing forces and stress tensor")

            def evaluate(model, variables, data):
                _, _, vir, output = model._energy_and_forces_and_virial(variables, data)
                cells = output["cells"]
                volume = jnp.abs(jnp.linalg.det(cells))
                stress = vir / volume[:, None, None]
                output[stress_key] = stress
                output[virial_key] = vir
                if pressure_key == "pressure":
                    output[pressure_key] = -jnp.trace(stress, axis1=1, axis2=2) / 3.0
                else:
                    output[pressure_key] = -stress
                return output

        else:
            print("Computing forces and virial tensor")

            def evaluate(model, variables, data):
                _, _, vir, output = model._energy_and_forces_and_virial(variables, data)
                output[virial_key] = vir
                return output

    elif compute_forces:
        print("Computing forces")

        def evaluate(model, variables, data):
            _, _, output = model._energy_and_forces(variables, data)
            return output

    elif model.energy_terms is not None:

        def evaluate(model, variables, data):
            _, output = model._total_energy(variables, data)
            return output

    else:

        def evaluate(model, variables, data):
            output = model.modules.apply(variables, data)
            return output

    ### TRAINING FUNCTIONS ###
    train_step = get_train_step_function(
        loss_definition=loss_definition,
        model=model,
        model_ref=model_ref,
        compute_ref_coords=compute_ref_coords,
        evaluate=evaluate,
        optimizer=optimizer,
        ema=ema,
    )

    validation = get_validation_function(
        loss_definition=loss_definition,
        model=model,
        model_ref=model_ref,
        compute_ref_coords=compute_ref_coords,
        evaluate=evaluate,
        return_targets=False,
    )

    #### CONFIGURE PREPROCESSING ####
    gpu_preprocessing = training_parameters.get("gpu_preprocessing", False)
    if gpu_preprocessing:
        print("GPU preprocessing activated.")

    minimum_image = training_parameters.get("minimum_image", False)
    preproc_state = unfreeze(model.preproc_state)
    layer_state = []
    for st in preproc_state["layers_state"]:
        stnew = unfreeze(st)
        if pbc_training:
            stnew["minimum_image"] = minimum_image
        if "nblist_mult_size" in training_parameters:
            stnew["nblist_mult_size"] = training_parameters["nblist_mult_size"]
        if "nblist_add_neigh" in training_parameters:
            stnew["add_neigh"] = training_parameters["nblist_add_neigh"]
        if "nblist_add_atoms" in training_parameters:
            stnew["add_atoms"] = training_parameters["nblist_add_atoms"]
        layer_state.append(freeze(stnew))

    preproc_state["layers_state"] = tuple(layer_state)
    preproc_state["check_input"] = False
    model.preproc_state = freeze(preproc_state)

    ### INITIALIZE METRICS ###
    maes_prev = defaultdict(lambda: np.inf)
    metrics_beta = training_parameters.get("metrics_ema_decay", -1.0)
    smoothen_metrics = 0.0 < metrics_beta < 1.0
    if smoothen_metrics:
        print("Computing smoothed metrics with beta =", metrics_beta)
        rmses_smooth = defaultdict(lambda: 0.0)
        maes_smooth = defaultdict(lambda: 0.0)
        rmse_tot_smooth = 0.0
        mae_tot_smooth = 0.0
        nsmooth = 0
    max_restore_count = training_parameters.get("max_restore_count", 5)

    fmetrics = open(
        output_directory + f"metrics{stage_prefix}.traj",
        "w" if restart_checkpoint is None else "a",
    )

    keep_all_bests = training_parameters.get("keep_all_bests", False)
    save_model_at_epochs = set(
        int(x) for x in training_parameters.get("save_model_at_epochs", [])
    )
    save_model_every = int(training_parameters.get("save_model_every_epoch", 0))
    if save_model_every > 0:
        save_model_at_epochs.update(
            range(save_model_every, max_epochs + save_model_every, save_model_every)
        )
    save_model_at_epochs = sorted(save_model_at_epochs)

    ### LR factor after restoring a model that has diverged
    restore_scaling = training_parameters.get("restore_scaling", 1)
    assert 0.0 < restore_scaling <= 1.0, "restore_scaling must be in (0,1]"
    if restore_scaling != 1:
        print("Applying LR scaling after restore:", restore_scaling)

    metric_use_best = training_parameters.get("metric_best", "rmse_tot")
    metrics_state = {}

    ### INITIALIZE TRAINING STATE ###
    if restart_checkpoint is not None:
        training_state = deepcopy(restart_checkpoint["training"])
        epoch_start = training_state["epoch"]
        model.preproc_state = freeze(restart_checkpoint["preproc_state"])
        if smoothen_metrics:
            metrics_state = deepcopy(restart_checkpoint["metrics_state"])
            rmses_smooth = metrics_state["rmses_smooth"]
            maes_smooth = metrics_state["maes_smooth"]
            rmse_tot_smooth = metrics_state["rmse_tot_smooth"]
            mae_tot_smooth = metrics_state["mae_tot_smooth"]
            nsmooth = metrics_state["nsmooth"]
        print("Restored training state")
    else:
        epoch_start = 1
        training_state = {
            "rng_key": rng_key,
            "opt_state": opt_st,
            "ema_state": ema_st,
            "sch_state": sch_state,
            "variables": model.variables,
            "variables_ema": model.variables,
            "step": 0,
            "restore_count": 0,
            "restore_scale": 1,
            "epoch": 0,
            "best_metric": np.inf,
        }
        if smoothen_metrics:
            metrics_state = {
                "rmses_smooth": dict(rmses_smooth),
                "maes_smooth": dict(maes_smooth),
                "rmse_tot_smooth": rmse_tot_smooth,
                "mae_tot_smooth": mae_tot_smooth,
                "nsmooth": nsmooth,
            }

    del opt_st, ema_st, sch_state, preproc_state
    variables_save = training_state["variables"]
    variables_ema_save = training_state["variables_ema"]

    ### Training loop ###
    start = time.time()
    print("Starting training...")
    for epoch in range(epoch_start, max_epochs + 1):
        training_state["epoch"] = epoch
        s = time.time()
        for _ in range(nbatch_per_epoch):
            # fetch data
            inputs0, data = next(training_iterator)

            # preprocess data
            inputs = model.preprocess(verbose=True, use_gpu=gpu_preprocessing, **inputs0)

            rng_key, subkey = jax.random.split(rng_key)
            inputs["rng_key"] = subkey
            inputs["training_epoch"] = epoch

            # adjust learning rate
            current_lr, training_state["sch_state"] = schedule(training_state["sch_state"])
            current_lr *= training_state["restore_scale"]
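            # NOTE: this reaches into the optax state produced by get_optimizer;
            # it assumes a multi_transform-style optimizer whose last "trainable"
            # transform exposes a mutable "step_size" hyperparameter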
            training_state["opt_state"].inner_states["trainable"].inner_state[-1].hyperparams[
                "step_size"
            ] = current_lr

            # train step
            loss, training_state, _ = train_step(data, inputs, training_state)

        rmses_avg = defaultdict(lambda: 0.0)
        maes_avg = defaultdict(lambda: 0.0)
        for _ in range(nbatch_per_validation):
            inputs0, data = next(validation_iterator)

            inputs = model.preprocess(verbose=True, use_gpu=gpu_preprocessing, **inputs0)

            rng_key, subkey = jax.random.split(rng_key)
            inputs["rng_key"] = subkey
            inputs["training_epoch"] = epoch

            rmses, maes, output_val = validation(
                data=data,
                inputs=inputs,
                variables=training_state["variables_ema"],
            )
            for k, v in rmses.items():
                rmses_avg[k] += v
            for k, v in maes.items():
                maes_avg[k] += v

        jax.block_until_ready(output_val)

        ### Timings ###
        e = time.time()
        epoch_time = e - s

        elapsed_time = e - start
        if not adaptive_scheduler:
            remain_glob = elapsed_time * (max_epochs - epoch + 1) / epoch
            remain_last = epoch_time * (max_epochs - epoch + 1)
            # estimate remaining time via weighted average (put more weight on last at the beginning)
            wremain = np.sin(0.5 * np.pi * epoch / max_epochs)
            remaining_time = human_time_duration(
                remain_glob * wremain + remain_last * (1 - wremain)
            )
        elapsed_time = human_time_duration(elapsed_time)
        batch_time = human_time_duration(
            epoch_time / (nbatch_per_epoch + nbatch_per_validation)
        )
        epoch_time = human_time_duration(epoch_time)

        for k in rmses_avg.keys():
            rmses_avg[k] /= nbatch_per_validation
        for k in maes_avg.keys():
            maes_avg[k] /= nbatch_per_validation

        ### Print metrics ###
        print("")
        line = f"Epoch {epoch}, lr={current_lr:.3e}, loss = {loss:.3e}"
        line += f", epoch time = {epoch_time}, batch time = {batch_time}"
        line += f", elapsed time = {elapsed_time}"
        if not adaptive_scheduler:
            line += f", est. remaining time = {remaining_time}"
        print(line)
        rmse_tot = 0.0
        mae_tot = 0.0
        for k in rmses_avg.keys():
            loss_prms = loss_definition[k]
            mult = loss_prms["mult"]
            rmse_tot = rmse_tot + rmses_avg[k] * loss_prms["weight"] ** 0.5
            mae_tot = mae_tot + maes_avg[k] * loss_prms["weight"]
            unit = "(" + loss_prms["unit"] + ")" if "unit" in loss_prms else ""

            weight_str = ""
            if "weight_schedule" in loss_prms:
                w = linear_schedule(epoch, *loss_prms["weight_schedule"])
                weight_str = f"(w={w:.3f})"

            if rmses_avg[k] / mult < 1.0e-2:
                print(
                    f"    rmse_{k}= {rmses_avg[k]/mult:10.3e} ; mae_{k}= {maes_avg[k]/mult:10.3e}   {unit}  {weight_str}"
                )
            else:
                print(
                    f"    rmse_{k}= {rmses_avg[k]/mult:10.3f} ; mae_{k}= {maes_avg[k]/mult:10.3f}   {unit}  {weight_str}"
                )

        ### CHECK IF WE SHOULD RESTORE PREVIOUS MODEL ###
        restore = False
        reinit = False
        for k, mae in maes_avg.items():
            if np.isnan(mae):
                restore = True
                reinit = True
                print("NaN detected in mae")
            if "threshold" in loss_definition[k]:
                thr = loss_definition[k]["threshold"]
                if mae > thr * maes_prev[k]:
                    restore = True
                    break

        epoch_time_str = time.strftime("%Y-%m-%d-%H-%M-%S")
        model.variables = training_state["variables_ema"]
        if epoch in save_model_at_epochs:
            filename = output_directory + f"model{stage_prefix}_epoch{epoch:0{epoch_format}}_{epoch_time_str}.fnx"
            print("Scheduled model save:", filename)
            model.save(filename)

        if restore:
            training_state["restore_count"] += 1
            if training_state["restore_count"] > max_restore_count:
                if reinit:
                    raise ValueError("Model diverged and could not be restored.")
                else:
                    training_state["restore_count"] = 0
                    print(
                        f"{max_restore_count} unsuccessful restores, resuming training."
                    )

            training_state["variables"] = deepcopy(variables_save)
            training_state["variables_ema"] = deepcopy(variables_ema_save)
            model.variables = training_state["variables_ema"]
            training_state["restore_scale"] *= restore_scaling
            print(f"Restored previous model after divergence. Restore scale: {training_state['restore_scale']:.3f}")
            if reinit:
                training_state["opt_state"] = optimizer.init(model.variables)
                training_state["ema_state"] = ema.init(model.variables)
                print("Reinitialized optimizer after divergence.")
            save_restart_checkpoint(output_directory, stage, training_state, model.preproc_state, metrics_state)
            continue

        training_state["restore_count"] = 0

        variables_save = deepcopy(training_state["variables"])
        variables_ema_save = deepcopy(training_state["variables_ema"])
        maes_prev = maes_avg

        ### SAVE MODEL ###
        model.save(output_directory + "latest_model.fnx")

        # save metrics
        step = int(training_state["step"])
        metrics = {
            "epoch": epoch,
            "step": step,
            "data_count": step * batch_size,
            "elapsed_time": time.time() - start,
            "lr": current_lr,
            "loss": loss,
        }
        for k in rmses_avg.keys():
            mult = loss_definition[k]["mult"]
            metrics[f"rmse_{k}"] = rmses_avg[k] / mult
            metrics[f"mae_{k}"] = maes_avg[k] / mult

        metrics["rmse_tot"] = rmse_tot
        metrics["mae_tot"] = mae_tot
        if smoothen_metrics:
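            # exponentially smoothed (EMA) metrics; the division by
            # (1 - beta**nsmooth) below is the standard zero-initialization
            # bias correction, as in Adam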
            nsmooth += 1
            for k in rmses_avg.keys():
                mult = loss_definition[k]["mult"]
                rmses_smooth[k] = (
                    metrics_beta * rmses_smooth[k] + (1.0 - metrics_beta) * rmses_avg[k]
                )
                maes_smooth[k] = (
                    metrics_beta * maes_smooth[k] + (1.0 - metrics_beta) * maes_avg[k]
                )
                metrics[f"rmse_smooth_{k}"] = (
                    rmses_smooth[k] / (1.0 - metrics_beta**nsmooth) / mult
                )
                metrics[f"mae_smooth_{k}"] = (
                    maes_smooth[k] / (1.0 - metrics_beta**nsmooth) / mult
                )
            rmse_tot_smooth = (
                metrics_beta * rmse_tot_smooth + (1.0 - metrics_beta) * rmse_tot
            )
            mae_tot_smooth = (
                metrics_beta * mae_tot_smooth + (1.0 - metrics_beta) * mae_tot
            )
            metrics["rmse_smooth_tot"] = rmse_tot_smooth / (1.0 - metrics_beta**nsmooth)
            metrics["mae_smooth_tot"] = mae_tot_smooth / (1.0 - metrics_beta**nsmooth)

            # update metrics state
            metrics_state["rmses_smooth"] = dict(rmses_smooth)
            metrics_state["maes_smooth"] = dict(maes_smooth)
            metrics_state["rmse_tot_smooth"] = rmse_tot_smooth
            metrics_state["mae_tot_smooth"] = mae_tot_smooth
            metrics_state["nsmooth"] = nsmooth

        assert (
            metric_use_best in metrics
        ), f"Error: metric for selecting best model '{metric_use_best}' not in metrics"

        metric_for_best = metrics[metric_use_best]
        if metric_for_best < training_state["best_metric"]:
            training_state["best_metric"] = metric_for_best
            metrics["best_metric"] = training_state["best_metric"]
            if keep_all_bests:
                best_name = (
                    output_directory
                    + f"best_model{stage_prefix}_epoch{epoch:0{epoch_format}}_{epoch_time_str}.fnx"
                )
                model.save(best_name)

            best_name = output_directory + f"best_model{stage_prefix}.fnx"
            model.save(best_name)
            print("New best model saved to:", best_name)
        else:
            metrics["best_metric"] = training_state["best_metric"]

        if epoch == 1:
            headers = [f"{i+1}:{k}" for i, k in enumerate(metrics.keys())]
            fmetrics.write("# " + " ".join(headers) + "\n")
        fmetrics.write(" ".join([str(metrics[k]) for k in metrics.keys()]) + "\n")
        fmetrics.flush()

        # update learning rate using current metrics
        assert (
            schedule_metrics in metrics
        ), f"Error: cannot update lr, '{schedule_metrics}' not in metrics"
        current_lr, training_state["sch_state"] = schedule(
            training_state["sch_state"], metrics[schedule_metrics]
        )

        # update and save training state
        save_restart_checkpoint(output_directory, stage, training_state, model.preproc_state, metrics_state)

        if is_end(metrics):
            print("Stage finished.")
            break

    end = time.time()

    print(f"Training time: {human_time_duration(end - start)}")
    print("")

    fmetrics.close()

    filename = output_directory + f"final_model{stage_prefix}.fnx"
    model.save(filename)
    print("Final model saved to:", filename)

    return metrics, filename


if __name__ == "__main__":
    main()
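
The module can be run directly with `python -m fennol.training.training config.yaml` (the argparse program name suggests a `fennol_train` console entry point); passing a previous output directory instead of a config file restarts from its checkpoint. Below is a minimal sketch of programmatic use: the dataset path, starting model file and run directory are placeholders, and only configuration keys actually read by `train` above are set.

import os
import jax
import numpy as np

from fennol.training.training import train

# minimal sketch, assuming "dataset.pkl" and "start_model.fnx" exist and are
# valid inputs for load_dataset/load_model; all keys below are read in train()
parameters = {
    "fprec": "float32",
    "training": {
        "dspath": "dataset.pkl",  # required ('training/dspath')
        "batch_size": 16,
        "max_epochs": 100,
        "nbatch_per_epoch": 200,
        "nbatch_per_validation": 20,
    },
}

os.makedirs("run", exist_ok=True)  # train() does not create the directory
rng = (jax.random.PRNGKey(0), np.random.Generator(np.random.PCG64(0)))
metrics, final_model = train(
    rng,
    parameters,
    model_file="start_model.fnx",  # placeholder starting model
    output_directory="run",
)
print("final model written to", final_model)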
def save_restart_checkpoint( output_directory, stage, training_state, preproc_state, metrics_state):
47def save_restart_checkpoint(output_directory,stage,training_state,preproc_state,metrics_state):
48    with open(output_directory + "/restart_checkpoint.pkl", "wb") as f:
49            pickle.dump({
50                "stage": stage,
51                "training": training_state,
52                "preproc_state": preproc_state,
53                "metrics_state": metrics_state,
54            }, f)
def load_restart_checkpoint(output_directory):
56def load_restart_checkpoint(output_directory):
57    restart_checkpoint_file = output_directory + "/restart_checkpoint.pkl"
58    if not os.path.exists(restart_checkpoint_file):
59        raise FileNotFoundError(
60            f"Training state file not found: {restart_checkpoint_file}"
61        )
62    with open(restart_checkpoint_file, "rb") as f:
63        restart_checkpoint = pickle.load(f)
64    return restart_checkpoint
def main():
 67def main():
 68    parser = argparse.ArgumentParser(prog="fennol_train")
 69    parser.add_argument("config_file", type=str)
 70    parser.add_argument("--model_file", type=str, default=None)
 71    args = parser.parse_args()
 72    config_file = args.config_file
 73    model_file = args.model_file
 74
 75    os.environ["OMP_NUM_THREADS"] = "1"
 76    sys.stdout = io.TextIOWrapper(
 77        open(sys.stdout.fileno(), "wb", 0), write_through=True
 78    )
 79
 80    restart_training = False
 81    if os.path.isdir(config_file):
 82        output_directory = Path(config_file).absolute().as_posix()
 83        restart_training = True
 84        while output_directory.endswith("/"):
 85            output_directory = output_directory[:-1]
 86        config_file = output_directory + "/config.yaml"
 87        backup_dir = output_directory + f"_backup_{time.strftime('%Y-%m-%d-%H-%M-%S')}"
 88        shutil.copytree(output_directory, backup_dir)
 89        print("Restarting training from", output_directory)
 90
 91    parameters = load_configuration(config_file)
 92
 93    ### Set the device
 94    device: str = parameters.get("device", "cpu").lower()
 95    if device == "cpu":
 96        os.environ["CUDA_VISIBLE_DEVICES"] = ""
 97    elif device.startswith("cuda") or device.startswith("gpu"):
 98        if ":" in device:
 99            num = device.split(":")[-1]
100            os.environ["CUDA_VISIBLE_DEVICES"] = num
101        else:
102            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
103        device = "gpu"
104
105    _device = jax.devices(device)[0]
106    jax.config.update("jax_default_device", _device)
107
108    if restart_training:
109        restart_checkpoint = load_restart_checkpoint(output_directory)
110    else:
111        restart_checkpoint = None
112
113    # output directory
114    if not restart_training:
115        output_directory = parameters.get("output_directory", None)
116        if output_directory is not None:
117            if "{now}" in output_directory:
118                output_directory = output_directory.replace(
119                    "{now}", time.strftime("%Y-%m-%d-%H-%M-%S")
120                )
121            output_directory = Path(output_directory).absolute()
122            if not output_directory.exists():
123                output_directory.mkdir(parents=True)
124            print("Output directory:", output_directory)
125        else:
126            output_directory = "."
127
128    output_directory = str(output_directory) + "/"
129
130    # copy config_file to output directory
131    # config_name = Path(config_file).name
132    config_ext = Path(config_file).suffix
133    with open(config_file) as f_in:
134        config_data = f_in.read()
135    with open(output_directory + "/config" + config_ext, "w") as f_out:
136        f_out.write(config_data)
137
138    # set log file
139    log_file = "train.log"  # parameters.get("log_file", None)
140    logger = TeeLogger(output_directory + log_file)
141    logger.bind_stdout()
142
143    # set matmul precision
144    enable_x64 = parameters.get("double_precision", False)
145    jax.config.update("jax_enable_x64", enable_x64)
146    fprec = "float64" if enable_x64 else "float32"
147    parameters["fprec"] = fprec
148    if enable_x64:
149        print("Double precision enabled.")
150
151    matmul_precision = parameters.get("matmul_prec", "highest").lower()
152    assert matmul_precision in [
153        "default",
154        "high",
155        "highest",
156    ], "matmul_prec must be one of 'default','high','highest'"
157    jax.config.update("jax_default_matmul_precision", matmul_precision)
158
159    # set random seed
160    rng_seed = parameters.get("rng_seed", np.random.randint(0, 2**32 - 1))
161    print(f"rng_seed: {rng_seed}")
162    rng_key = jax.random.PRNGKey(rng_seed)
163    torch.manual_seed(rng_seed)
164    np.random.seed(rng_seed)
165    random.seed(rng_seed)
166    np_rng = np.random.Generator(np.random.PCG64(rng_seed))
167
168    try:
169        if "stages" in parameters["training"]:
170            ## train in stages ##
171            params = deepcopy(parameters)
172            stages = params["training"].pop("stages")
173            assert isinstance(stages, dict), "'stages' must be a dict with named stages"
174            model_file_stage = model_file
175            print_stages_params = params["training"].get("print_stages_params", False)
176            for i, (stage, stage_params) in enumerate(stages.items()):
177                rng_key, subkey = jax.random.split(rng_key)
178                print("")
179                print(f"### STAGE {i+1}: {stage} ###")
180
181                ## remove end_event from previous stage ##
182                if i > 0 and "end_event" in params["training"]:
183                    params["training"].pop("end_event")
184
185                ## incrementally update training parameters ##
186                params = deep_update(params, {"training": stage_params})
187                if model_file_stage is not None:
188                    ## load model from previous stage ##
189                    params["model_file"] = model_file_stage
190
191                if restart_training and restart_checkpoint["stage"] != i + 1:
192                    print(f"Skipping stage {i+1} (already completed)")
193                    continue
194
195                if print_stages_params:
196                    print("stage parameters:")
197                    print(json.dumps(params, indent=2, sort_keys=False))
198
199                ## train stage ##
200                _, model_file_stage = train(
201                    (subkey, np_rng),
202                    params,
203                    stage=i + 1,
204                    output_directory=output_directory,
205                    restart_checkpoint=restart_checkpoint,
206                )
207
208                restart_training = False
209                restart_checkpoint = None
210        else:
211            ## single training stage ##
212            train(
213                (rng_key, np_rng),
214                parameters,
215                model_file=model_file,
216                output_directory=output_directory,
217                restart_checkpoint=restart_checkpoint,
218            )
219    except KeyboardInterrupt:
220        print("Training interrupted by user.")
221    finally:
222        if log_file is not None:
223            logger.unbind_stdout()
224            logger.close()
def train( rng, parameters, model_file=None, stage=None, output_directory=None, restart_checkpoint=None):
227def train(
228    rng,
229    parameters,
230    model_file=None,
231    stage=None,
232    output_directory=None,
233    restart_checkpoint=None,
234):
235    if output_directory is None:
236        output_directory = "./"
237    elif output_directory == "":
238        output_directory = "./"
239    elif not output_directory.endswith("/"):
240        output_directory += "/"
241    stage_prefix = f"_stage_{stage}" if stage is not None else ""
242
243    if isinstance(rng, tuple):
244        rng_key, np_rng = rng
245    else:
246        rng_key = rng
247        np_rng = np.random.Generator(np.random.PCG64(np.random.randint(0, 2**32 - 1)))
248
249    ### GET TRAINING PARAMETERS ###
250    training_parameters = parameters.get("training", {})
251
252    #### LOAD MODEL ####
253    if restart_checkpoint is not None:
254        model_key = None
255        model_file = output_directory + "latest_model.fnx"
256    else:
257        rng_key, model_key = jax.random.split(rng_key)
258    model = load_model(parameters, model_file, rng_key=model_key)
259
260    ### SAVE INITIAL MODEL ###
261    save_initial_model = (
262        restart_checkpoint is None 
263        and (stage is None or stage == 0)
264    )
265    if save_initial_model:
266        model.save(output_directory + "initial_model.fnx")
267
268    model_ref = None
269    if "model_ref" in training_parameters:
270        model_ref = load_model(parameters, training_parameters["model_ref"])
271        print("Reference model:", training_parameters["model_ref"])
272
273        if "ref_parameters" in training_parameters:
274            ref_parameters = training_parameters["ref_parameters"]
275            assert isinstance(
276                ref_parameters, list
277            ), "ref_parameters must be a list of str"
278            print("Reference parameters:", ref_parameters)
279            model.variables = copy_parameters(
280                model.variables, model_ref.variables, ref_parameters
281            )
282
283    ### SET FLOATING POINT PRECISION ###
284    fprec = parameters.get("fprec", "float32")
285
286    def convert_to_fprec(x):
287        if jnp.issubdtype(x.dtype, jnp.floating):
288            return x.astype(fprec)
289        return x
290
291    model.variables = jax.tree_map(convert_to_fprec, model.variables)
292
293    ### SET UP LOSS FUNCTION ###
294    loss_definition, used_keys, ref_keys = get_loss_definition(
295        training_parameters, model_energy_unit=model.energy_unit
296    )
297
298    coordinates_ref_key = training_parameters.get("coordinates_ref_key", None)
299    if coordinates_ref_key is not None:
300        compute_ref_coords = True
301        print("Reference coordinates:", coordinates_ref_key)
302    else:
303        compute_ref_coords = False
304
305    #### LOAD DATASET ####
306    dspath = training_parameters.get("dspath", None)
307    if dspath is None:
308        raise ValueError("Dataset path 'training/dspath' should be specified.")
309    batch_size = training_parameters.get("batch_size", 16)
310    batch_size_val = training_parameters.get("batch_size_val", None)
311    rename_refs = training_parameters.get("rename_refs", {})
312    training_iterator, validation_iterator = load_dataset(
313        dspath=dspath,
314        batch_size=batch_size,
315        batch_size_val=batch_size_val,
316        training_parameters=training_parameters,
317        infinite_iterator=True,
318        atom_padding=True,
319        ref_keys=ref_keys,
320        split_data_inputs=True,
321        np_rng=np_rng,
322        add_flags=["training"],
323        fprec=fprec,
324        rename_refs=rename_refs,
325    )
326
327    compute_forces = "forces" in used_keys
328    compute_virial = "virial_tensor" in used_keys or "virial" in used_keys
329    compute_stress = "stress_tensor" in used_keys or "stress" in used_keys
330    compute_pressure = "pressure" in used_keys or "pressure_tensor" in used_keys
331
332    ### get optimizer parameters ###
333    max_epochs = int(training_parameters.get("max_epochs", 2000))
334    epoch_format = len(str(max_epochs))
335    nbatch_per_epoch = training_parameters.get("nbatch_per_epoch", 200)
336    nbatch_per_validation = training_parameters.get("nbatch_per_validation", 20)
337
338    schedule, sch_state, schedule_metrics,adaptive_scheduler = get_lr_schedule(max_epochs,nbatch_per_epoch,training_parameters)
339    optimizer = get_optimizer(
340        training_parameters, model.variables, schedule(sch_state)[0]
341    )
342    opt_st = optimizer.init(model.variables)
343
344    ### exponential moving average of the parameters ###
345    ema_decay = training_parameters.get("ema_decay", -1.0)
346    if ema_decay > 0.0:
347        assert ema_decay < 1.0, "ema_decay must be in (0,1)"
348        ema = optax.ema(decay=ema_decay)
349    else:
350        ema = optax.identity()
351    ema_st = ema.init(model.variables)
352
353    #### end event ####
354    end_event = training_parameters.get("end_event", None)
355    if end_event is None or isinstance(end_event, str) and end_event.lower() == "none":
356        is_end = lambda metrics: False
357    else:
358        assert len(end_event) == 2, "end_event must be a list of two elements"
359        is_end = lambda metrics: metrics[end_event[0]] < end_event[1]
360
361    if "energy_terms" in training_parameters:
362        model.set_energy_terms(training_parameters["energy_terms"], jit=False)
363    print("energy terms:", model.energy_terms)
364
365    ### MODEL EVALUATION FUNCTION ###
366    pbc_training = training_parameters.get("pbc_training", False)
367    if compute_stress or compute_virial or compute_pressure:
368        virial_key = "virial" if "virial" in used_keys else "virial_tensor"
369        stress_key = "stress" if "stress" in used_keys else "stress_tensor"
370        pressure_key = "pressure" if "pressure" in used_keys else "pressure_tensor"
371        if compute_stress or compute_pressure:
372            assert pbc_training, "PBC must be enabled for stress or virial training"
373            print("Computing forces and stress tensor")
374            def evaluate(model, variables, data):
375                _, _, vir, output = model._energy_and_forces_and_virial(variables, data)
376                cells = output["cells"]
377                volume = jnp.abs(jnp.linalg.det(cells))
378                stress = vir / volume[:, None, None]
379                output[stress_key] = stress
380                output[virial_key] = vir
381                if pressure_key == "pressure":
382                    output[pressure_key] = -jnp.trace(stress, axis1=1, axis2=2) / 3.0
383                else:
384                    output[pressure_key] = -stress
385                return output
386
387        else:
388            print("Computing forces and virial tensor")
389            def evaluate(model, variables, data):
390                _, _, vir, output = model._energy_and_forces_and_virial(variables, data)
391                output[virial_key] = vir
392                return output
393
394    elif compute_forces:
395        print("Computing forces")
396        def evaluate(model, variables, data):
397            _, _, output = model._energy_and_forces(variables, data)
398            return output
399
400    elif model.energy_terms is not None:
401        def evaluate(model, variables, data):
402            _, output = model._total_energy(variables, data)
403            return output
404
405    else:
406        def evaluate(model, variables, data):
407            output = model.modules.apply(variables, data)
408            return output
409
410    ### TRAINING FUNCTIONS ###
411    train_step = get_train_step_function(
412        loss_definition=loss_definition,
413        model=model,
414        model_ref=model_ref,
415        compute_ref_coords=compute_ref_coords,
416        evaluate=evaluate,
417        optimizer=optimizer,
418        ema=ema,
419    )
420
421    validation = get_validation_function(
422        loss_definition=loss_definition,
423        model=model,
424        model_ref=model_ref,
425        compute_ref_coords=compute_ref_coords,
426        evaluate=evaluate,
427        return_targets=False,
428    )
429
430    #### CONFIGURE PREPROCESSING ####
431    gpu_preprocessing = training_parameters.get("gpu_preprocessing", False)
432    if gpu_preprocessing:
433        print("GPU preprocessing activated.")
434
435    minimum_image = training_parameters.get("minimum_image", False)
436    preproc_state = unfreeze(model.preproc_state)
437    layer_state = []
438    for st in preproc_state["layers_state"]:
439        stnew = unfreeze(st)
440        #     st["nblist_skin"] = nblist_skin
441        #     if nblist_stride > 1:
442        #         st["skin_stride"] = nblist_stride
443        #         st["skin_count"] = nblist_stride
444        if pbc_training:
445            stnew["minimum_image"] = minimum_image
446        if "nblist_mult_size" in training_parameters:
447            stnew["nblist_mult_size"] = training_parameters["nblist_mult_size"]
448        if "nblist_add_neigh" in training_parameters:
449            stnew["add_neigh"] = training_parameters["nblist_add_neigh"]
450        if "nblist_add_atoms" in training_parameters:
451            stnew["add_atoms"] = training_parameters["nblist_add_atoms"]
452        layer_state.append(freeze(stnew))
453
454    preproc_state["layers_state"] = tuple(layer_state)
455    preproc_state["check_input"] = False
456    model.preproc_state = freeze(preproc_state)
457
458    ### INITIALIZE METRICS ###
459    maes_prev = defaultdict(lambda: np.inf)
460    metrics_beta = training_parameters.get("metrics_ema_decay", -1.0)
461    smoothen_metrics = metrics_beta < 1.0 and metrics_beta > 0.0
462    if smoothen_metrics:
463        print("Computing smoothed metrics with beta =", metrics_beta)
464        rmses_smooth = defaultdict(lambda: 0.0)
465        maes_smooth = defaultdict(lambda: 0.0)
466        rmse_tot_smooth = 0.0
467        mae_tot_smooth = 0.0
468        nsmooth = 0
469    max_restore_count = training_parameters.get("max_restore_count", 5)    
470
471    fmetrics = open(
472        output_directory + f"metrics{stage_prefix}.traj",
473        "w" if restart_checkpoint is None else "a",
474    )
475
476    keep_all_bests = training_parameters.get("keep_all_bests", False)
477    save_model_at_epochs = training_parameters.get("save_model_at_epochs", [])
478    save_model_at_epochs= set([int(x) for x in save_model_at_epochs])
479    save_model_every = int(training_parameters.get("save_model_every_epoch", 0))
480    if save_model_every > 0:
481        save_model_every = list(range(save_model_every, max_epochs+save_model_every, save_model_every))
482        save_model_at_epochs = save_model_at_epochs.union(save_model_every)
483    save_model_at_epochs = list(save_model_at_epochs)
484    save_model_at_epochs.sort()
485
486    ### LR factor after restoring a model that has diverged
487    restore_scaling = training_parameters.get("restore_scaling", 1)
488    assert restore_scaling > 0.0, "restore_scaling must be > 0.0"
489    assert restore_scaling <= 1.0, "restore_scaling must be <= 1.0"
490    if restore_scaling != 1:
491        print("Applying LR scaling after restore:", restore_scaling)
492
493
494    metric_use_best = training_parameters.get("metric_best", "rmse_tot")  # .lower()
495    metrics_state = {}
496
497    ### INITIALIZE TRAINING STATE ###
498    if restart_checkpoint is not None:
499        training_state = deepcopy(restart_checkpoint["training"])
500        epoch_start = training_state["epoch"]
501        model.preproc_state = freeze(restart_checkpoint["preproc_state"])
502        if smoothen_metrics:
503            metrics_state = deepcopy(restart_checkpoint["metrics_state"])
504            rmses_smooth = metrics_state["rmses_smooth"]
505            maes_smooth = metrics_state["maes_smooth"]
506            rmse_tot_smooth = metrics_state["rmse_tot_smooth"]
507            mae_tot_smooth = metrics_state["mae_tot_smooth"]
508            nsmooth = metrics_state["nsmooth"]
509        print("Restored training state")
510    else:
511        epoch_start = 1
512        training_state = {
513            "rng_key": rng_key,
514            "opt_state": opt_st,
515            "ema_state": ema_st,
516            "sch_state": sch_state,
517            "variables": model.variables,
518            "variables_ema": model.variables,
519            "step": 0,
520            "restore_count": 0,
521            "restore_scale": 1,
522            "epoch": 0,
523            "best_metric": np.inf,
524        }
525        if smoothen_metrics:
526            metrics_state = {
527                "rmses_smooth": dict(rmses_smooth),
528                "maes_smooth": dict(maes_smooth),
529                "rmse_tot_smooth": rmse_tot_smooth,
530                "mae_tot_smooth": mae_tot_smooth,
531                "nsmooth": nsmooth,
532            }
533    
534    del opt_st, ema_st, sch_state, preproc_state
535    variables_save = training_state['variables']
536    variables_ema_save = training_state['variables_ema']
537
538    ### Training loop ###
539    start = time.time()
540    print("Starting training...")
541    for epoch in range(epoch_start, max_epochs+1):
542        training_state["epoch"] = epoch
543        s = time.time()
544        for _ in range(nbatch_per_epoch):
545            # fetch data
546            inputs0, data = next(training_iterator)
547
548            # preprocess data
549            inputs = model.preprocess(verbose=True,use_gpu=gpu_preprocessing,**inputs0)
550
551            rng_key, subkey = jax.random.split(rng_key)
552            inputs["rng_key"] = subkey
553            inputs["training_epoch"] = epoch
554
555            # adjust learning rate
556            current_lr, training_state["sch_state"] = schedule(training_state["sch_state"])
557            current_lr *= training_state["restore_scale"]
558            training_state["opt_state"].inner_states["trainable"].inner_state[-1].hyperparams[
559                "step_size"
560            ] = current_lr 
561
562            # train step
563            loss,training_state, _ = train_step(data,inputs,training_state)
564
565        rmses_avg = defaultdict(lambda: 0.0)
566        maes_avg = defaultdict(lambda: 0.0)
567        for _ in range(nbatch_per_validation):
568            inputs0, data = next(validation_iterator)
569
570            inputs = model.preprocess(verbose=True,use_gpu=gpu_preprocessing,**inputs0)
571
572            rng_key, subkey = jax.random.split(rng_key)
573            inputs["rng_key"] = subkey
574            inputs["training_epoch"] = epoch
575
576            rmses, maes, output_val = validation(
577                data=data,
578                inputs=inputs,
579                variables=training_state["variables_ema"],
580            )
581            for k, v in rmses.items():
582                rmses_avg[k] += v
583            for k, v in maes.items():
584                maes_avg[k] += v
585
586        jax.block_until_ready(output_val)
587
588        ### Timings ###
589        e = time.time()
590        epoch_time = e - s
591
592        elapsed_time = e - start
593        if not adaptive_scheduler:
594            remain_glob = elapsed_time * (max_epochs - epoch + 1) / epoch
595            remain_last = epoch_time * (max_epochs - epoch + 1)
596            # estimate remaining time via weighted average (put more weight on last at the beginning)
597            wremain = np.sin(0.5 * np.pi * epoch / max_epochs)
598            remaining_time = human_time_duration(
599                remain_glob * wremain + remain_last * (1 - wremain)
600            )
601        elapsed_time = human_time_duration(elapsed_time)
602        batch_time = human_time_duration(
603            epoch_time / (nbatch_per_epoch + nbatch_per_validation)
604        )
605        epoch_time = human_time_duration(epoch_time)

        for k in rmses_avg.keys():
            rmses_avg[k] /= nbatch_per_validation
        for k in maes_avg.keys():
            maes_avg[k] /= nbatch_per_validation

        ### Print metrics ###
        print("")
        line = f"Epoch {epoch}, lr={current_lr:.3e}, loss = {loss:.3e}"
        line += f", epoch time = {epoch_time}, batch time = {batch_time}"
        line += f", elapsed time = {elapsed_time}"
        if not adaptive_scheduler:
            line += f", est. remaining time = {remaining_time}"
        print(line)
        rmse_tot = 0.0
        mae_tot = 0.0
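        # aggregate per-property errors into totals: RMSEs are weighted by the
        # square root of the loss weight (the weight applies to squared errors)
        # and MAEs by the weight itself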
        for k in rmses_avg.keys():
            loss_prms = loss_definition[k]
            mult = loss_prms["mult"]
            rmse_tot = rmse_tot + rmses_avg[k] * loss_prms["weight"] ** 0.5
            mae_tot = mae_tot + maes_avg[k] * loss_prms["weight"]
            unit = "(" + loss_prms["unit"] + ")" if "unit" in loss_prms else ""

            weight_str = ""
            if "weight_schedule" in loss_prms:
                w = linear_schedule(epoch, *loss_prms["weight_schedule"])
                weight_str = f"(w={w:.3f})"

            fmt = "10.3e" if rmses_avg[k] / mult < 1.0e-2 else "10.3f"
            print(
                f"    rmse_{k}= {rmses_avg[k]/mult:{fmt}} ; mae_{k}= {maes_avg[k]/mult:{fmt}}   {unit}  {weight_str}"
            )

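        # divergence control: if any validation MAE is NaN or exceeds
        # "threshold" times its value at the previous epoch, roll back to the
        # last saved parameters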
        ### CHECK IF WE SHOULD RESTORE PREVIOUS MODEL ###
        restore = False
        reinit = False
        for k, mae in maes_avg.items():
            if np.isnan(mae):
                restore = True
                reinit = True
                print("NaN detected in validation MAE")
                # for k, v in inputs.items():
                #     if hasattr(v, "shape"):
                #         if np.isnan(v).any():
                #             print(k, v)
                # sys.exit(1)
            if "threshold" in loss_definition[k]:
                thr = loss_definition[k]["threshold"]
                if mae > thr * maes_prev[k]:
                    restore = True
                    break

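        # export snapshots (with EMA weights) at the epochs requested in the
        # configuration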
        epoch_time_str = time.strftime("%Y-%m-%d-%H-%M-%S")
        model.variables = training_state["variables_ema"]
        if epoch in save_model_at_epochs:
            filename = output_directory + f"model{stage_prefix}_epoch{epoch:0{epoch_format}}_{epoch_time_str}.fnx"
            print("Scheduled model save:", filename)
            model.save(filename)

        if restore:
            training_state["restore_count"] += 1
            if training_state["restore_count"] > max_restore_count:
                if reinit:
                    raise ValueError("Model diverged and could not be restored.")
                else:
                    training_state["restore_count"] = 0
                    print(
                        f"{max_restore_count} unsuccessful restores, resuming training."
                    )

            training_state["variables"] = deepcopy(variables_save)
            training_state["variables_ema"] = deepcopy(variables_ema_save)
            model.variables = training_state["variables_ema"]
            training_state["restore_scale"] *= restore_scaling
            print(
                f"Restored previous model after divergence. Restore scale: {training_state['restore_scale']:.3f}"
            )
            if reinit:
                training_state["opt_state"] = optimizer.init(model.variables)
                training_state["ema_state"] = ema.init(model.variables)
                print("Reinitialized optimizer after divergence.")
            save_restart_checkpoint(output_directory, stage, training_state, model.preproc_state, metrics_state)
            continue

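        # the epoch completed without divergence: reset the restore counter and
        # refresh the rollback snapshots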
        training_state["restore_count"] = 0

        variables_save = deepcopy(training_state["variables"])
        variables_ema_save = deepcopy(training_state["variables_ema"])
        maes_prev = maes_avg

        ### SAVE MODEL ###
        model.save(output_directory + "latest_model.fnx")

        # save metrics
        step = int(training_state["step"])
        metrics = {
            "epoch": epoch,
            "step": step,
            "data_count": step * batch_size,
            "elapsed_time": time.time() - start,
            "lr": current_lr,
            "loss": loss,
        }
        for k in rmses_avg.keys():
            mult = loss_definition[k]["mult"]
            metrics[f"rmse_{k}"] = rmses_avg[k] / mult
            metrics[f"mae_{k}"] = maes_avg[k] / mult

        metrics["rmse_tot"] = rmse_tot
        metrics["mae_tot"] = mae_tot
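        # exponential moving average of the metrics with Adam-style bias
        # correction: dividing by (1 - metrics_beta**nsmooth) removes the bias
        # towards the zero initialization in early epochs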
        if smoothen_metrics:
            nsmooth += 1
            for k in rmses_avg.keys():
                mult = loss_definition[k]["mult"]
                rmses_smooth[k] = (
                    metrics_beta * rmses_smooth[k] + (1.0 - metrics_beta) * rmses_avg[k]
                )
                maes_smooth[k] = (
                    metrics_beta * maes_smooth[k] + (1.0 - metrics_beta) * maes_avg[k]
                )
                metrics[f"rmse_smooth_{k}"] = (
                    rmses_smooth[k] / (1.0 - metrics_beta**nsmooth) / mult
                )
                metrics[f"mae_smooth_{k}"] = (
                    maes_smooth[k] / (1.0 - metrics_beta**nsmooth) / mult
                )
            rmse_tot_smooth = (
                metrics_beta * rmse_tot_smooth + (1.0 - metrics_beta) * rmse_tot
            )
            mae_tot_smooth = (
                metrics_beta * mae_tot_smooth + (1.0 - metrics_beta) * mae_tot
            )
            metrics["rmse_smooth_tot"] = rmse_tot_smooth / (1.0 - metrics_beta**nsmooth)
            metrics["mae_smooth_tot"] = mae_tot_smooth / (1.0 - metrics_beta**nsmooth)

            # update metrics state
            metrics_state["rmses_smooth"] = dict(rmses_smooth)
            metrics_state["maes_smooth"] = dict(maes_smooth)
            metrics_state["rmse_tot_smooth"] = rmse_tot_smooth
            metrics_state["mae_tot_smooth"] = mae_tot_smooth
            metrics_state["nsmooth"] = nsmooth

        assert (
            metric_use_best in metrics
        ), f"Error: metric for selecting best model '{metric_use_best}' not in metrics"

        metric_for_best = metrics[metric_use_best]
        if metric_for_best < training_state["best_metric"]:
            training_state["best_metric"] = metric_for_best
            metrics["best_metric"] = training_state["best_metric"]
            if keep_all_bests:
                best_name = (
                    output_directory
                    + f"best_model{stage_prefix}_epoch{epoch:0{epoch_format}}_{epoch_time_str}.fnx"
                )
                model.save(best_name)

            best_name = output_directory + f"best_model{stage_prefix}.fnx"
            model.save(best_name)
            print("New best model saved to:", best_name)
        else:
            metrics["best_metric"] = training_state["best_metric"]

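        # append metrics to the metrics file as a whitespace-separated table;
        # a "column_index:name" header is written at the first epoch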
        if epoch == 1:
            headers = [f"{i+1}:{k}" for i, k in enumerate(metrics.keys())]
            fmetrics.write("# " + " ".join(headers) + "\n")
        fmetrics.write(" ".join([str(metrics[k]) for k in metrics.keys()]) + "\n")
        fmetrics.flush()

        # update learning rate using current metrics (used by metric-driven
        # schedules such as reduce-on-plateau)
        assert (
            schedule_metrics in metrics
        ), f"Error: cannot update lr, '{schedule_metrics}' not in metrics"
        current_lr, training_state["sch_state"] = schedule(training_state["sch_state"], metrics[schedule_metrics])

        # update and save training state
        save_restart_checkpoint(output_directory, stage, training_state, model.preproc_state, metrics_state)

        if is_end(metrics):
            print("Stage finished.")
            break

    end = time.time()

    print(f"Training time: {human_time_duration(end - start)}")
    print("")

    fmetrics.close()

    filename = output_directory + f"final_model{stage_prefix}.fnx"
    model.save(filename)
    print("Final model saved to:", filename)

    return metrics, filename