fennol.training.training

import os
import yaml
import sys
import jax
import io
import time
import jax.numpy as jnp
import numpy as np
import optax
from collections import defaultdict
import json
from copy import deepcopy
from pathlib import Path
import argparse
try:
    import torch
except ImportError:
    raise ImportError(
        "PyTorch is required for training models. Install the CPU version following instructions at https://pytorch.org/get-started/locally/"
    )
import random
from flax import traverse_util
import shutil
import pickle

from flax.core import freeze, unfreeze
from .io import (
    load_configuration,
    load_dataset,
    load_model,
    TeeLogger,
    copy_parameters,
)
from .utils import (
    get_loss_definition,
    get_train_step_function,
    get_validation_function,
    linear_schedule,
)
from .optimizers import get_optimizer, get_lr_schedule
from ..utils import deep_update, AtomicUnits as au
from ..utils.io import human_time_duration
from ..models.preprocessing import AtomPadding, check_input, convert_to_jax

def save_restart_checkpoint(output_directory, stage, training_state, preproc_state, metrics_state):
    """Pickle the full training state so a run can be restarted from its output directory."""
    with open(output_directory + "/restart_checkpoint.pkl", "wb") as f:
        pickle.dump(
            {
                "stage": stage,
                "training": training_state,
                "preproc_state": preproc_state,
                "metrics_state": metrics_state,
            },
            f,
        )

def load_restart_checkpoint(output_directory):
    """Load the pickled training state written by `save_restart_checkpoint`."""
    restart_checkpoint_file = output_directory + "/restart_checkpoint.pkl"
    if not os.path.exists(restart_checkpoint_file):
        raise FileNotFoundError(
            f"Training state file not found: {restart_checkpoint_file}"
        )
    with open(restart_checkpoint_file, "rb") as f:
        restart_checkpoint = pickle.load(f)
    return restart_checkpoint

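# Illustrative sketch (not part of the original module): the restart checkpoint is a
# plain pickle whose keys mirror the arguments of `save_restart_checkpoint` above.
# Assuming an existing run directory "runs/my_training" (hypothetical path), it can
# be inspected as:
#
#     ckpt = load_restart_checkpoint("runs/my_training")
#     ckpt["stage"]          # stage index (or None for single-stage training)
#     ckpt["training"]       # optimizer/EMA/scheduler states, model variables, counters
#     ckpt["preproc_state"]  # preprocessing (neighbor-list) state to be re-frozen
#     ckpt["metrics_state"]  # smoothed-metrics accumulators (empty dict if unused)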

def main():
    parser = argparse.ArgumentParser(prog="fennol_train")
    parser.add_argument("config_file", type=str)
    parser.add_argument("--model_file", type=str, default=None)
    args = parser.parse_args()
    config_file = args.config_file
    model_file = args.model_file

    os.environ["OMP_NUM_THREADS"] = "1"
    sys.stdout = io.TextIOWrapper(
        open(sys.stdout.fileno(), "wb", 0), write_through=True
    )

    restart_training = False
    if os.path.isdir(config_file):
        output_directory = Path(config_file).absolute().as_posix()
        restart_training = True
        while output_directory.endswith("/"):
            output_directory = output_directory[:-1]
        config_file = output_directory + "/config.yaml"
        backup_dir = output_directory + f"_backup_{time.strftime('%Y-%m-%d-%H-%M-%S')}"
        shutil.copytree(output_directory, backup_dir)
        print("Restarting training from", output_directory)

    parameters = load_configuration(config_file)

    ### Set the device
    if "FENNOL_DEVICE" in os.environ:
        device = os.environ["FENNOL_DEVICE"].lower()
        print(f"# Setting device from env FENNOL_DEVICE={device}")
    else:
        device: str = parameters.get("device", "cpu").lower()
    if device == "cpu":
        jax.config.update("jax_platforms", "cpu")
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    elif device.startswith("cuda") or device.startswith("gpu"):
        if ":" in device:
            num = device.split(":")[-1]
            os.environ["CUDA_VISIBLE_DEVICES"] = num
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        device = "gpu"

    _device = jax.devices(device)[0]
    jax.config.update("jax_default_device", _device)

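    # Illustrative note (sketch, not in the original source): the device can be set
    # either in the YAML configuration or via the FENNOL_DEVICE environment variable,
    # which takes precedence. Assuming the console script is installed under the
    # argparse prog name "fennol_train" and the config is "input.yaml":
    #
    #     device: cuda:0                                 # in input.yaml
    #     FENNOL_DEVICE=cuda:1 fennol_train input.yaml   # from the shell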
    if restart_training:
        restart_checkpoint = load_restart_checkpoint(output_directory)
    else:
        restart_checkpoint = None

    # output directory
    if not restart_training:
        output_directory = parameters.get("output_directory", None)
        if output_directory is not None:
            if "{now}" in output_directory:
                output_directory = output_directory.replace(
                    "{now}", time.strftime("%Y-%m-%d-%H-%M-%S")
                )
            output_directory = Path(output_directory).absolute()
            if not output_directory.exists():
                output_directory.mkdir(parents=True)
            print("Output directory:", output_directory)
        else:
            output_directory = "."

    output_directory = str(output_directory) + "/"

    # copy config_file to output directory
    # config_name = Path(config_file).name
    config_ext = Path(config_file).suffix
    with open(config_file) as f_in:
        config_data = f_in.read()
    with open(output_directory + "/config" + config_ext, "w") as f_out:
        f_out.write(config_data)

    # set log file
    log_file = "train.log"  # parameters.get("log_file", None)
    logger = TeeLogger(output_directory + log_file)
    logger.bind_stdout()

    # set matmul precision
    enable_x64 = parameters.get("double_precision", False)
    jax.config.update("jax_enable_x64", enable_x64)
    fprec = "float64" if enable_x64 else "float32"
    parameters["fprec"] = fprec
    if enable_x64:
        print("Double precision enabled.")

    matmul_precision = parameters.get("matmul_prec", "highest").lower()
    assert matmul_precision in [
        "default",
        "high",
        "highest",
    ], "matmul_prec must be one of 'default','high','highest'"
    jax.config.update("jax_default_matmul_precision", matmul_precision)

    # set random seed
    rng_seed = parameters.get("rng_seed", np.random.randint(0, 2**32 - 1))
    print(f"rng_seed: {rng_seed}")
    rng_key = jax.random.PRNGKey(rng_seed)
    torch.manual_seed(rng_seed)
    np.random.seed(rng_seed)
    random.seed(rng_seed)
    np_rng = np.random.Generator(np.random.PCG64(rng_seed))

    try:
        if "stages" in parameters["training"]:
            ## train in stages ##
            params = deepcopy(parameters)
            stages = params["training"].pop("stages")
            assert isinstance(stages, dict), "'stages' must be a dict with named stages"
            model_file_stage = model_file
            print_stages_params = params["training"].get("print_stages_params", False)
            for i, (stage, stage_params) in enumerate(stages.items()):
                rng_key, subkey = jax.random.split(rng_key)
                print("")
                print(f"### STAGE {i+1}: {stage} ###")

                ## remove end_event from previous stage ##
                if i > 0 and "end_event" in params["training"]:
                    params["training"].pop("end_event")

                ## incrementally update training parameters ##
                params = deep_update(params, {"training": stage_params})
                if model_file_stage is not None:
                    ## load model from previous stage ##
                    params["model_file"] = model_file_stage

                if restart_training and restart_checkpoint["stage"] != i + 1:
                    print(f"Skipping stage {i+1} (already completed)")
                    continue

                if print_stages_params:
                    print("stage parameters:")
                    print(json.dumps(params, indent=2, sort_keys=False))

                ## train stage ##
                _, model_file_stage = train(
                    (subkey, np_rng),
                    params,
                    stage=i + 1,
                    output_directory=output_directory,
                    restart_checkpoint=restart_checkpoint,
                )

                restart_training = False
                restart_checkpoint = None
        else:
            ## single training stage ##
            train(
                (rng_key, np_rng),
                parameters,
                model_file=model_file,
                output_directory=output_directory,
                restart_checkpoint=restart_checkpoint,
            )
    except KeyboardInterrupt:
        print("Training interrupted by user.")
    finally:
        if log_file is not None:
            logger.unbind_stdout()
            logger.close()

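# Illustrative sketch (not part of the original module): the staged-training branch
# above expects a YAML layout along these lines, where each named stage incrementally
# overrides the "training" section via deep_update. The keys "stages", "end_event",
# "dspath" and "max_epochs" appear in the code; the other names and values here are
# assumptions for illustration only:
#
#     training:
#       dspath: dataset.pkl
#       max_epochs: 500
#       stages:
#         warmup:
#           nbatch_per_epoch: 100
#         refine:
#           end_event: [rmse_tot, 1.e-3]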
def train(
    rng,
    parameters,
    model_file=None,
    stage=None,
    output_directory=None,
    restart_checkpoint=None,
):
    if output_directory is None:
        output_directory = "./"
    elif output_directory == "":
        output_directory = "./"
    elif not output_directory.endswith("/"):
        output_directory += "/"
    stage_prefix = f"_stage_{stage}" if stage is not None else ""

    if isinstance(rng, tuple):
        rng_key, np_rng = rng
    else:
        rng_key = rng
        np_rng = np.random.Generator(np.random.PCG64(np.random.randint(0, 2**32 - 1)))

    ### GET TRAINING PARAMETERS ###
    training_parameters = parameters.get("training", {})

    #### LOAD MODEL ####
    if restart_checkpoint is not None:
        model_key = None
        model_file = output_directory + "latest_model.fnx"
    else:
        rng_key, model_key = jax.random.split(rng_key)
    model = load_model(parameters, model_file, rng_key=model_key)

    ### SAVE INITIAL MODEL ###
    save_initial_model = (
        restart_checkpoint is None
        and (stage is None or stage == 0)
    )
    if save_initial_model:
        model.save(output_directory + "initial_model.fnx")

    model_ref = None
    if "model_ref" in training_parameters:
        model_ref = load_model(parameters, training_parameters["model_ref"])
        print("Reference model:", training_parameters["model_ref"])

        if "ref_parameters" in training_parameters:
            ref_parameters = training_parameters["ref_parameters"]
            assert isinstance(
                ref_parameters, list
            ), "ref_parameters must be a list of str"
            print("Reference parameters:", ref_parameters)
            model.variables = copy_parameters(
                model.variables, model_ref.variables, ref_parameters
            )

    ### SET FLOATING POINT PRECISION ###
    fprec = parameters.get("fprec", "float32")

    def convert_to_fprec(x):
        if jnp.issubdtype(x.dtype, jnp.floating):
            return x.astype(fprec)
        return x

    model.variables = jax.tree_map(convert_to_fprec, model.variables)

    ### SET UP LOSS FUNCTION ###
    loss_definition, used_keys, ref_keys = get_loss_definition(
        training_parameters, model_energy_unit=model.energy_unit
    )

    coordinates_ref_key = training_parameters.get("coordinates_ref_key", None)
    if coordinates_ref_key is not None:
        compute_ref_coords = True
        print("Reference coordinates:", coordinates_ref_key)
    else:
        compute_ref_coords = False

    #### LOAD DATASET ####
    dspath = training_parameters.get("dspath", None)
    if dspath is None:
        raise ValueError("Dataset path 'training/dspath' should be specified.")
    batch_size = training_parameters.get("batch_size", 16)
    batch_size_val = training_parameters.get("batch_size_val", None)
    rename_refs = training_parameters.get("rename_refs", {})
    training_iterator, validation_iterator = load_dataset(
        dspath=dspath,
        batch_size=batch_size,
        batch_size_val=batch_size_val,
        training_parameters=training_parameters,
        infinite_iterator=True,
        atom_padding=True,
        ref_keys=ref_keys,
        split_data_inputs=True,
        np_rng=np_rng,
        add_flags=["training"],
        fprec=fprec,
        rename_refs=rename_refs,
    )

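    # Illustrative note (assumption about load_dataset's contract, inferred from the
    # training loop below): both iterators are infinite and yield (inputs, refs)
    # pairs, where `inputs` is fed to model.preprocess(...) and `refs` holds the
    # reference targets consumed by the loss. A single batch could be peeked at with:
    #
    #     inputs0, data = next(training_iterator)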
    compute_forces = "forces" in used_keys
    compute_virial = "virial_tensor" in used_keys or "virial" in used_keys
    compute_stress = "stress_tensor" in used_keys or "stress" in used_keys
    compute_pressure = "pressure" in used_keys or "pressure_tensor" in used_keys

    ### get optimizer parameters ###
    max_epochs = int(training_parameters.get("max_epochs", 2000))
    epoch_format = len(str(max_epochs))
    nbatch_per_epoch = training_parameters.get("nbatch_per_epoch", 200)
    nbatch_per_validation = training_parameters.get("nbatch_per_validation", 20)

    schedule, sch_state, schedule_metrics, adaptive_scheduler = get_lr_schedule(
        max_epochs, nbatch_per_epoch, training_parameters
    )
    optimizer = get_optimizer(
        training_parameters, model.variables, schedule(sch_state)[0]
    )
    opt_st = optimizer.init(model.variables)

    ### exponential moving average of the parameters ###
    ema_decay = training_parameters.get("ema_decay", -1.0)
    if ema_decay > 0.0:
        assert ema_decay < 1.0, "ema_decay must be in (0,1)"
        ema = optax.ema(decay=ema_decay)
    else:
        ema = optax.identity()
    ema_st = ema.init(model.variables)

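    # Illustrative note (sketch restating standard optax.ema behaviour): with
    # 0 < ema_decay < 1, the averaged parameters used for validation and saving
    # follow the update
    #
    #     variables_ema <- ema_decay * variables_ema + (1 - ema_decay) * variables
    #
    # (optax applies its usual debiasing), while ema_decay <= 0 disables averaging
    # via optax.identity().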
    #### end event ####
    end_event = training_parameters.get("end_event", None)
    if end_event is None or (isinstance(end_event, str) and end_event.lower() == "none"):
        is_end = lambda metrics: False
    else:
        assert len(end_event) == 2, "end_event must be a list of two elements"
        is_end = lambda metrics: metrics[end_event[0]] < end_event[1]

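    # Illustrative sketch (the "end_event" key is real, the values are assumed): an
    # early-stop condition such as
    #
    #     end_event: [rmse_tot, 1.e-3]
    #
    # ends the stage once metrics["rmse_tot"] drops below 1.e-3; omitting end_event
    # (or setting it to "none") lets the stage run for the full max_epochs.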
    if "energy_terms" in training_parameters:
        model.set_energy_terms(training_parameters["energy_terms"], jit=False)
    print("energy terms:", model.energy_terms)

    ### MODEL EVALUATION FUNCTION ###
    pbc_training = training_parameters.get("pbc_training", False)
    if compute_stress or compute_virial or compute_pressure:
        virial_key = "virial" if "virial" in used_keys else "virial_tensor"
        stress_key = "stress" if "stress" in used_keys else "stress_tensor"
        pressure_key = "pressure" if "pressure" in used_keys else "pressure_tensor"
        if compute_stress or compute_pressure:
            assert pbc_training, "PBC must be enabled for stress or pressure training"
            print("Computing forces and stress tensor")
            def evaluate(model, variables, data):
                _, _, vir, output = model._energy_and_forces_and_virial(variables, data)
                cells = output["cells"]
                volume = jnp.abs(jnp.linalg.det(cells))
                stress = vir / volume[:, None, None]
                output[stress_key] = stress
                output[virial_key] = vir
                if pressure_key == "pressure":
                    output[pressure_key] = -jnp.trace(stress, axis1=1, axis2=2) / 3.0
                else:
                    output[pressure_key] = -stress
                return output

        else:
            print("Computing forces and virial tensor")
            def evaluate(model, variables, data):
                _, _, vir, output = model._energy_and_forces_and_virial(variables, data)
                output[virial_key] = vir
                return output

    elif compute_forces:
        print("Computing forces")
        def evaluate(model, variables, data):
            _, _, output = model._energy_and_forces(variables, data)
            return output

    elif model.energy_terms is not None:
        def evaluate(model, variables, data):
            _, output = model._total_energy(variables, data)
            return output

    else:
        def evaluate(model, variables, data):
            output = model.modules.apply(variables, data)
            return output

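    # Illustrative note (restating the relations used in evaluate above): for periodic
    # training the stress is obtained from the virial as stress = virial / volume
    # (per batch cell), the scalar pressure as p = -tr(stress) / 3, and a requested
    # pressure tensor is simply -stress.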
    ### TRAINING FUNCTIONS ###
    train_step = get_train_step_function(
        loss_definition=loss_definition,
        model=model,
        model_ref=model_ref,
        compute_ref_coords=compute_ref_coords,
        evaluate=evaluate,
        optimizer=optimizer,
        ema=ema,
    )

    validation = get_validation_function(
        loss_definition=loss_definition,
        model=model,
        model_ref=model_ref,
        compute_ref_coords=compute_ref_coords,
        evaluate=evaluate,
        return_targets=False,
    )

    #### CONFIGURE PREPROCESSING ####
    gpu_preprocessing = training_parameters.get("gpu_preprocessing", False)
    if gpu_preprocessing:
        print("GPU preprocessing activated.")

    preproc_state = unfreeze(model.preproc_state)
    layer_state = []
    for st in preproc_state["layers_state"]:
        stnew = unfreeze(st)
        #     st["nblist_skin"] = nblist_skin
        #     if nblist_stride > 1:
        #         st["skin_stride"] = nblist_stride
        #         st["skin_count"] = nblist_stride
        if "nblist_mult_size" in training_parameters:
            stnew["nblist_mult_size"] = training_parameters["nblist_mult_size"]
        if "nblist_add_neigh" in training_parameters:
            stnew["add_neigh"] = training_parameters["nblist_add_neigh"]
        if "nblist_add_atoms" in training_parameters:
            stnew["add_atoms"] = training_parameters["nblist_add_atoms"]
        layer_state.append(freeze(stnew))

    preproc_state["layers_state"] = tuple(layer_state)
    preproc_state["check_input"] = False
    model.preproc_state = freeze(preproc_state)

    ### INITIALIZE METRICS ###
    maes_prev = defaultdict(lambda: np.inf)
    metrics_beta = training_parameters.get("metrics_ema_decay", -1.0)
    smoothen_metrics = 0.0 < metrics_beta < 1.0
    if smoothen_metrics:
        print("Computing smoothed metrics with beta =", metrics_beta)
        rmses_smooth = defaultdict(lambda: 0.0)
        maes_smooth = defaultdict(lambda: 0.0)
        rmse_tot_smooth = 0.0
        mae_tot_smooth = 0.0
        nsmooth = 0
    max_restore_count = training_parameters.get("max_restore_count", 5)

    fmetrics = open(
        output_directory + f"metrics{stage_prefix}.traj",
        "w" if restart_checkpoint is None else "a",
    )

    keep_all_bests = training_parameters.get("keep_all_bests", False)
    save_model_at_epochs = training_parameters.get("save_model_at_epochs", [])
    save_model_at_epochs = set([int(x) for x in save_model_at_epochs])
    save_model_every = int(training_parameters.get("save_model_every_epoch", 0))
    if save_model_every > 0:
        save_model_at_epochs = save_model_at_epochs.union(
            range(save_model_every, max_epochs + save_model_every, save_model_every)
        )
    if "save_model_schedule" in training_parameters:
        save_model_schedule_def = [int(i) for i in training_parameters["save_model_schedule"]]
        save_model_schedule = []
        while len(save_model_schedule_def) > 0:
            i = save_model_schedule_def.pop(0)
            if i > 0:
                save_model_schedule.append(i)
            elif i < 0:
                start = -i
                freq = save_model_schedule_def.pop(0)
                assert freq > 0, "syntax error in save_model_schedule"
                save_model_schedule.extend(
                    list(range(start, max_epochs + 1, freq))
                )
        save_model_at_epochs = save_model_at_epochs.union(save_model_schedule)
    save_model_at_epochs = list(save_model_at_epochs)
    save_model_at_epochs.sort()

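    # Illustrative sketch of the checkpointing options parsed above (the keys are
    # real, the values are assumptions): explicit epochs, a fixed interval, or a
    # mixed schedule where a negative entry -N followed by F means "every F epochs
    # starting at epoch N":
    #
    #     save_model_at_epochs: [10, 50]
    #     save_model_every_epoch: 100
    #     save_model_schedule: [25, -200, 50]   # epoch 25, then every 50 from epoch 200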
    ### LR factor after restoring a model that has diverged
    restore_scaling = training_parameters.get("restore_scaling", 1)
    assert restore_scaling > 0.0, "restore_scaling must be > 0.0"
    assert restore_scaling <= 1.0, "restore_scaling must be <= 1.0"
    if restore_scaling != 1:
        print("Applying LR scaling after restore:", restore_scaling)

    metric_use_best = training_parameters.get("metric_best", "rmse_tot")  # .lower()
    metrics_state = {}

    ### INITIALIZE TRAINING STATE ###
    if restart_checkpoint is not None:
        training_state = deepcopy(restart_checkpoint["training"])
        epoch_start = training_state["epoch"]
        model.preproc_state = freeze(restart_checkpoint["preproc_state"])
        if smoothen_metrics:
            metrics_state = deepcopy(restart_checkpoint["metrics_state"])
            rmses_smooth = metrics_state["rmses_smooth"]
            maes_smooth = metrics_state["maes_smooth"]
            rmse_tot_smooth = metrics_state["rmse_tot_smooth"]
            mae_tot_smooth = metrics_state["mae_tot_smooth"]
            nsmooth = metrics_state["nsmooth"]
        print("Restored training state")
    else:
        epoch_start = 1
        training_state = {
            "rng_key": rng_key,
            "opt_state": opt_st,
            "ema_state": ema_st,
            "sch_state": sch_state,
            "variables": model.variables,
            "variables_ema": model.variables,
            "step": 0,
            "restore_count": 0,
            "restore_scale": 1,
            "epoch": 0,
            "best_metric": np.inf,
        }
        if smoothen_metrics:
            metrics_state = {
                "rmses_smooth": dict(rmses_smooth),
                "maes_smooth": dict(maes_smooth),
                "rmse_tot_smooth": rmse_tot_smooth,
                "mae_tot_smooth": mae_tot_smooth,
                "nsmooth": nsmooth,
            }

    del opt_st, ema_st, sch_state, preproc_state
    variables_save = training_state["variables"]
    variables_ema_save = training_state["variables_ema"]

    ### Training loop ###
    start = time.time()
    print("Starting training...")
    for epoch in range(epoch_start, max_epochs + 1):
        training_state["epoch"] = epoch
        s = time.time()
        for _ in range(nbatch_per_epoch):
            # fetch data
            inputs0, data = next(training_iterator)

            # preprocess data
            inputs = model.preprocess(verbose=True, use_gpu=gpu_preprocessing, **inputs0)

            rng_key, subkey = jax.random.split(rng_key)
            inputs["rng_key"] = subkey
            inputs["training_epoch"] = epoch

            # adjust learning rate
            current_lr, training_state["sch_state"] = schedule(training_state["sch_state"])
            current_lr *= training_state["restore_scale"]
            training_state["opt_state"].inner_states["trainable"].inner_state[-1].hyperparams[
                "step_size"
            ] = current_lr

            # train step
            loss, training_state, _ = train_step(data, inputs, training_state)

        rmses_avg = defaultdict(lambda: 0.0)
        maes_avg = defaultdict(lambda: 0.0)
        for _ in range(nbatch_per_validation):
            inputs0, data = next(validation_iterator)

            inputs = model.preprocess(verbose=True, use_gpu=gpu_preprocessing, **inputs0)

            rng_key, subkey = jax.random.split(rng_key)
            inputs["rng_key"] = subkey
            inputs["training_epoch"] = epoch

            rmses, maes, output_val = validation(
                data=data,
                inputs=inputs,
                variables=training_state["variables_ema"],
            )
            for k, v in rmses.items():
                rmses_avg[k] += v
            for k, v in maes.items():
                maes_avg[k] += v

        jax.block_until_ready(output_val)

        ### Timings ###
        e = time.time()
        epoch_time = e - s

        elapsed_time = e - start
        if not adaptive_scheduler:
            remain_glob = elapsed_time * (max_epochs - epoch + 1) / epoch
            remain_last = epoch_time * (max_epochs - epoch + 1)
            # estimate remaining time via a weighted average (more weight on the last epoch early in training)
            wremain = np.sin(0.5 * np.pi * epoch / max_epochs)
            remaining_time = human_time_duration(
                remain_glob * wremain + remain_last * (1 - wremain)
            )
        elapsed_time = human_time_duration(elapsed_time)
        batch_time = human_time_duration(
            epoch_time / (nbatch_per_epoch + nbatch_per_validation)
        )
        epoch_time = human_time_duration(epoch_time)

        for k in rmses_avg.keys():
            rmses_avg[k] /= nbatch_per_validation
        for k in maes_avg.keys():
            maes_avg[k] /= nbatch_per_validation

        ### Print metrics ###
        print("")
        line = f"Epoch {epoch}, lr={current_lr:.3e}, loss = {loss:.3e}"
        line += f", epoch time = {epoch_time}, batch time = {batch_time}"
        line += f", elapsed time = {elapsed_time}"
        if not adaptive_scheduler:
            line += f", est. remaining time = {remaining_time}"
        print(line)
        rmse_tot = 0.0
        mae_tot = 0.0
        for k in rmses_avg.keys():
            mult = loss_definition[k]["mult"]
            loss_prms = loss_definition[k]
            rmse_tot = rmse_tot + rmses_avg[k] * loss_prms["weight"] ** 0.5
            mae_tot = mae_tot + maes_avg[k] * loss_prms["weight"]
            unit = "(" + loss_prms["unit"] + ")" if "unit" in loss_prms else ""

            weight_str = ""
            if "weight_schedule" in loss_prms:
                w = linear_schedule(epoch, *loss_prms["weight_schedule"])
                weight_str = f"(w={w:.3f})"

            if rmses_avg[k] / mult < 1.0e-2:
                print(
                    f"    rmse_{k}= {rmses_avg[k]/mult:10.3e} ; mae_{k}= {maes_avg[k]/mult:10.3e}   {unit}  {weight_str}"
                )
            else:
                print(
                    f"    rmse_{k}= {rmses_avg[k]/mult:10.3f} ; mae_{k}= {maes_avg[k]/mult:10.3f}   {unit}  {weight_str}"
                )

        ### CHECK IF WE SHOULD RESTORE PREVIOUS MODEL ###
        restore = False
        reinit = False
        for k, mae in maes_avg.items():
            if np.isnan(mae):
                restore = True
                reinit = True
                print("NaN detected in mae")
                # for k, v in inputs.items():
                #     if hasattr(v,"shape"):
                #         if np.isnan(v).any():
                #             print(k,v)
                # sys.exit(1)
            if "threshold" in loss_definition[k]:
                thr = loss_definition[k]["threshold"]
                if mae > thr * maes_prev[k]:
                    restore = True
                    break

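        # Illustrative note (sketch restating the checks above): a restore of the
        # last good parameters is triggered either by a NaN validation MAE (which
        # also reinitializes the optimizer state) or by any monitored MAE exceeding
        # threshold * previous_mae, where "threshold" is an optional per-quantity
        # entry of the loss definition.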
        epoch_time_str = time.strftime("%Y-%m-%d-%H-%M-%S")
        model.variables = training_state["variables_ema"]
        if epoch in save_model_at_epochs:
            filename = output_directory + f"model{stage_prefix}_epoch{epoch:0{epoch_format}}_{epoch_time_str}.fnx"
            print("Scheduled model save:", filename)
            model.save(filename)

        if restore:
            training_state["restore_count"] += 1
            if training_state["restore_count"] > max_restore_count:
                if reinit:
                    raise ValueError("Model diverged and could not be restored.")
                else:
                    training_state["restore_count"] = 0
                    print(
                        f"{max_restore_count} unsuccessful restores, resuming training."
                    )

            training_state["variables"] = deepcopy(variables_save)
            training_state["variables_ema"] = deepcopy(variables_ema_save)
            model.variables = training_state["variables_ema"]
            training_state["restore_scale"] *= restore_scaling
            print(f"Restored previous model after divergence. Restore scale: {training_state['restore_scale']:.3f}")
            if reinit:
                training_state["opt_state"] = optimizer.init(model.variables)
                training_state["ema_state"] = ema.init(model.variables)
                print("Reinitialized optimizer after divergence.")
            save_restart_checkpoint(output_directory, stage, training_state, model.preproc_state, metrics_state)
            continue

        training_state["restore_count"] = 0

        variables_save = deepcopy(training_state["variables"])
        variables_ema_save = deepcopy(training_state["variables_ema"])
        maes_prev = maes_avg

        ### SAVE MODEL ###
        model.save(output_directory + "latest_model.fnx")

        # save metrics
        step = int(training_state["step"])
        metrics = {
            "epoch": epoch,
            "step": step,
            "data_count": step * batch_size,
            "elapsed_time": time.time() - start,
            "lr": current_lr,
            "loss": loss,
        }
        for k in rmses_avg.keys():
            mult = loss_definition[k]["mult"]
            metrics[f"rmse_{k}"] = rmses_avg[k] / mult
            metrics[f"mae_{k}"] = maes_avg[k] / mult

        metrics["rmse_tot"] = rmse_tot
        metrics["mae_tot"] = mae_tot
        if smoothen_metrics:
            nsmooth += 1
            for k in rmses_avg.keys():
                mult = loss_definition[k]["mult"]
                rmses_smooth[k] = (
                    metrics_beta * rmses_smooth[k] + (1.0 - metrics_beta) * rmses_avg[k]
                )
                maes_smooth[k] = (
                    metrics_beta * maes_smooth[k] + (1.0 - metrics_beta) * maes_avg[k]
                )
                metrics[f"rmse_smooth_{k}"] = (
                    rmses_smooth[k] / (1.0 - metrics_beta**nsmooth) / mult
                )
                metrics[f"mae_smooth_{k}"] = (
                    maes_smooth[k] / (1.0 - metrics_beta**nsmooth) / mult
                )
            rmse_tot_smooth = (
                metrics_beta * rmse_tot_smooth + (1.0 - metrics_beta) * rmse_tot
            )
            mae_tot_smooth = (
                metrics_beta * mae_tot_smooth + (1.0 - metrics_beta) * mae_tot
            )
            metrics["rmse_smooth_tot"] = rmse_tot_smooth / (1.0 - metrics_beta**nsmooth)
            metrics["mae_smooth_tot"] = mae_tot_smooth / (1.0 - metrics_beta**nsmooth)

            # update metrics state
            metrics_state["rmses_smooth"] = dict(rmses_smooth)
            metrics_state["maes_smooth"] = dict(maes_smooth)
            metrics_state["rmse_tot_smooth"] = rmse_tot_smooth
            metrics_state["mae_tot_smooth"] = mae_tot_smooth
            metrics_state["nsmooth"] = nsmooth

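        # Illustrative note (restating the smoothing above): the reported *_smooth_*
        # metrics are bias-corrected exponential moving averages,
        #
        #     smooth_t   = beta * smooth_{t-1} + (1 - beta) * value_t
        #     reported_t = smooth_t / (1 - beta**t)
        #
        # with beta = metrics_ema_decay, so early epochs are not biased toward zero.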
        assert (
            metric_use_best in metrics
        ), f"Error: metric for selecting best model '{metric_use_best}' not in metrics"

        metric_for_best = metrics[metric_use_best]
        if metric_for_best < training_state["best_metric"]:
            training_state["best_metric"] = metric_for_best
            metrics["best_metric"] = training_state["best_metric"]
            if keep_all_bests:
                best_name = (
                    output_directory
                    + f"best_model{stage_prefix}_epoch{epoch:0{epoch_format}}_{epoch_time_str}.fnx"
                )
                model.save(best_name)

            best_name = output_directory + f"best_model{stage_prefix}.fnx"
            model.save(best_name)
            print("New best model saved to:", best_name)
        else:
            metrics["best_metric"] = training_state["best_metric"]

        if epoch == 1:
            headers = [f"{i+1}:{k}" for i, k in enumerate(metrics.keys())]
            fmetrics.write("# " + " ".join(headers) + "\n")
        fmetrics.write(" ".join([str(metrics[k]) for k in metrics.keys()]) + "\n")
        fmetrics.flush()

        # update learning rate using current metrics
        assert (
            schedule_metrics in metrics
        ), f"Error: cannot update lr, '{schedule_metrics}' not in metrics"
        current_lr, training_state["sch_state"] = schedule(
            training_state["sch_state"], metrics[schedule_metrics]
        )

        # update and save training state
        save_restart_checkpoint(output_directory, stage, training_state, model.preproc_state, metrics_state)

        if is_end(metrics):
            print("Stage finished.")
            break

    end = time.time()

    print(f"Training time: {human_time_duration(end - start)}")
    print("")

    fmetrics.close()

    filename = output_directory + f"final_model{stage_prefix}.fnx"
    model.save(filename)
    print("Final model saved to:", filename)

    return metrics, filename


if __name__ == "__main__":
    main()
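# Illustrative usage (sketch; assumes the console script is installed under the
# argparse prog name "fennol_train" used in main()):
#
#     fennol_train input.yaml                 # start a new training run
#     fennol_train path/to/output_directory   # restart from a previous run
#                                             # (reads its config.yaml and restart_checkpoint.pkl)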
def save_restart_checkpoint( output_directory, stage, training_state, preproc_state, metrics_state):
47def save_restart_checkpoint(output_directory,stage,training_state,preproc_state,metrics_state):
48    with open(output_directory + "/restart_checkpoint.pkl", "wb") as f:
49            pickle.dump({
50                "stage": stage,
51                "training": training_state,
52                "preproc_state": preproc_state,
53                "metrics_state": metrics_state,
54            }, f)
def load_restart_checkpoint(output_directory):
56def load_restart_checkpoint(output_directory):
57    restart_checkpoint_file = output_directory + "/restart_checkpoint.pkl"
58    if not os.path.exists(restart_checkpoint_file):
59        raise FileNotFoundError(
60            f"Training state file not found: {restart_checkpoint_file}"
61        )
62    with open(restart_checkpoint_file, "rb") as f:
63        restart_checkpoint = pickle.load(f)
64    return restart_checkpoint
def main():
 67def main():
 68    parser = argparse.ArgumentParser(prog="fennol_train")
 69    parser.add_argument("config_file", type=str)
 70    parser.add_argument("--model_file", type=str, default=None)
 71    args = parser.parse_args()
 72    config_file = args.config_file
 73    model_file = args.model_file
 74
 75    os.environ["OMP_NUM_THREADS"] = "1"
 76    sys.stdout = io.TextIOWrapper(
 77        open(sys.stdout.fileno(), "wb", 0), write_through=True
 78    )
 79
 80    restart_training = False
 81    if os.path.isdir(config_file):
 82        output_directory = Path(config_file).absolute().as_posix()
 83        restart_training = True
 84        while output_directory.endswith("/"):
 85            output_directory = output_directory[:-1]
 86        config_file = output_directory + "/config.yaml"
 87        backup_dir = output_directory + f"_backup_{time.strftime('%Y-%m-%d-%H-%M-%S')}"
 88        shutil.copytree(output_directory, backup_dir)
 89        print("Restarting training from", output_directory)
 90
 91    parameters = load_configuration(config_file)
 92
 93    ### Set the device
 94    if "FENNOL_DEVICE" in os.environ:
 95        device = os.environ["FENNOL_DEVICE"].lower()
 96        print(f"# Setting device from env FENNOL_DEVICE={device}")
 97    else:
 98        device: str = parameters.get("device", "cpu").lower()
 99    if device == "cpu":
100        jax.config.update('jax_platforms', 'cpu')
101        os.environ["CUDA_VISIBLE_DEVICES"] = ""
102    elif device.startswith("cuda") or device.startswith("gpu"):
103        if ":" in device:
104            num = device.split(":")[-1]
105            os.environ["CUDA_VISIBLE_DEVICES"] = num
106        else:
107            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
108        device = "gpu"
109
110    _device = jax.devices(device)[0]
111    jax.config.update("jax_default_device", _device)
112
113    if restart_training:
114        restart_checkpoint = load_restart_checkpoint(output_directory)
115    else:
116        restart_checkpoint = None
117
118    # output directory
119    if not restart_training:
120        output_directory = parameters.get("output_directory", None)
121        if output_directory is not None:
122            if "{now}" in output_directory:
123                output_directory = output_directory.replace(
124                    "{now}", time.strftime("%Y-%m-%d-%H-%M-%S")
125                )
126            output_directory = Path(output_directory).absolute()
127            if not output_directory.exists():
128                output_directory.mkdir(parents=True)
129            print("Output directory:", output_directory)
130        else:
131            output_directory = "."
132
133    output_directory = str(output_directory) + "/"
134
135    # copy config_file to output directory
136    # config_name = Path(config_file).name
137    config_ext = Path(config_file).suffix
138    with open(config_file) as f_in:
139        config_data = f_in.read()
140    with open(output_directory + "/config" + config_ext, "w") as f_out:
141        f_out.write(config_data)
142
143    # set log file
144    log_file = "train.log"  # parameters.get("log_file", None)
145    logger = TeeLogger(output_directory + log_file)
146    logger.bind_stdout()
147
148    # set matmul precision
149    enable_x64 = parameters.get("double_precision", False)
150    jax.config.update("jax_enable_x64", enable_x64)
151    fprec = "float64" if enable_x64 else "float32"
152    parameters["fprec"] = fprec
153    if enable_x64:
154        print("Double precision enabled.")
155
156    matmul_precision = parameters.get("matmul_prec", "highest").lower()
157    assert matmul_precision in [
158        "default",
159        "high",
160        "highest",
161    ], "matmul_prec must be one of 'default','high','highest'"
162    jax.config.update("jax_default_matmul_precision", matmul_precision)
163
164    # set random seed
165    rng_seed = parameters.get("rng_seed", np.random.randint(0, 2**32 - 1))
166    print(f"rng_seed: {rng_seed}")
167    rng_key = jax.random.PRNGKey(rng_seed)
168    torch.manual_seed(rng_seed)
169    np.random.seed(rng_seed)
170    random.seed(rng_seed)
171    np_rng = np.random.Generator(np.random.PCG64(rng_seed))
172
173    try:
174        if "stages" in parameters["training"]:
175            ## train in stages ##
176            params = deepcopy(parameters)
177            stages = params["training"].pop("stages")
178            assert isinstance(stages, dict), "'stages' must be a dict with named stages"
179            model_file_stage = model_file
180            print_stages_params = params["training"].get("print_stages_params", False)
181            for i, (stage, stage_params) in enumerate(stages.items()):
182                rng_key, subkey = jax.random.split(rng_key)
183                print("")
184                print(f"### STAGE {i+1}: {stage} ###")
185
186                ## remove end_event from previous stage ##
187                if i > 0 and "end_event" in params["training"]:
188                    params["training"].pop("end_event")
189
190                ## incrementally update training parameters ##
191                params = deep_update(params, {"training": stage_params})
192                if model_file_stage is not None:
193                    ## load model from previous stage ##
194                    params["model_file"] = model_file_stage
195
196                if restart_training and restart_checkpoint["stage"] != i + 1:
197                    print(f"Skipping stage {i+1} (already completed)")
198                    continue
199
200                if print_stages_params:
201                    print("stage parameters:")
202                    print(json.dumps(params, indent=2, sort_keys=False))
203
204                ## train stage ##
205                _, model_file_stage = train(
206                    (subkey, np_rng),
207                    params,
208                    stage=i + 1,
209                    output_directory=output_directory,
210                    restart_checkpoint=restart_checkpoint,
211                )
212
213                restart_training = False
214                restart_checkpoint = None
215        else:
216            ## single training stage ##
217            train(
218                (rng_key, np_rng),
219                parameters,
220                model_file=model_file,
221                output_directory=output_directory,
222                restart_checkpoint=restart_checkpoint,
223            )
224    except KeyboardInterrupt:
225        print("Training interrupted by user.")
226    finally:
227        if log_file is not None:
228            logger.unbind_stdout()
229            logger.close()
def train( rng, parameters, model_file=None, stage=None, output_directory=None, restart_checkpoint=None):
232def train(
233    rng,
234    parameters,
235    model_file=None,
236    stage=None,
237    output_directory=None,
238    restart_checkpoint=None,
239):
240    if output_directory is None:
241        output_directory = "./"
242    elif output_directory == "":
243        output_directory = "./"
244    elif not output_directory.endswith("/"):
245        output_directory += "/"
246    stage_prefix = f"_stage_{stage}" if stage is not None else ""
247
248    if isinstance(rng, tuple):
249        rng_key, np_rng = rng
250    else:
251        rng_key = rng
252        np_rng = np.random.Generator(np.random.PCG64(np.random.randint(0, 2**32 - 1)))
253
254    ### GET TRAINING PARAMETERS ###
255    training_parameters = parameters.get("training", {})
256
257    #### LOAD MODEL ####
258    if restart_checkpoint is not None:
259        model_key = None
260        model_file = output_directory + "latest_model.fnx"
261    else:
262        rng_key, model_key = jax.random.split(rng_key)
263    model = load_model(parameters, model_file, rng_key=model_key)
264
265    ### SAVE INITIAL MODEL ###
266    save_initial_model = (
267        restart_checkpoint is None 
268        and (stage is None or stage == 0)
269    )
270    if save_initial_model:
271        model.save(output_directory + "initial_model.fnx")
272
273    model_ref = None
274    if "model_ref" in training_parameters:
275        model_ref = load_model(parameters, training_parameters["model_ref"])
276        print("Reference model:", training_parameters["model_ref"])
277
278        if "ref_parameters" in training_parameters:
279            ref_parameters = training_parameters["ref_parameters"]
280            assert isinstance(
281                ref_parameters, list
282            ), "ref_parameters must be a list of str"
283            print("Reference parameters:", ref_parameters)
284            model.variables = copy_parameters(
285                model.variables, model_ref.variables, ref_parameters
286            )
287
288    ### SET FLOATING POINT PRECISION ###
289    fprec = parameters.get("fprec", "float32")
290
291    def convert_to_fprec(x):
292        if jnp.issubdtype(x.dtype, jnp.floating):
293            return x.astype(fprec)
294        return x
295
296    model.variables = jax.tree_map(convert_to_fprec, model.variables)
297
298    ### SET UP LOSS FUNCTION ###
299    loss_definition, used_keys, ref_keys = get_loss_definition(
300        training_parameters, model_energy_unit=model.energy_unit
301    )
302
303    coordinates_ref_key = training_parameters.get("coordinates_ref_key", None)
304    if coordinates_ref_key is not None:
305        compute_ref_coords = True
306        print("Reference coordinates:", coordinates_ref_key)
307    else:
308        compute_ref_coords = False
309
310    #### LOAD DATASET ####
311    dspath = training_parameters.get("dspath", None)
312    if dspath is None:
313        raise ValueError("Dataset path 'training/dspath' should be specified.")
314    batch_size = training_parameters.get("batch_size", 16)
315    batch_size_val = training_parameters.get("batch_size_val", None)
316    rename_refs = training_parameters.get("rename_refs", {})
317    training_iterator, validation_iterator = load_dataset(
318        dspath=dspath,
319        batch_size=batch_size,
320        batch_size_val=batch_size_val,
321        training_parameters=training_parameters,
322        infinite_iterator=True,
323        atom_padding=True,
324        ref_keys=ref_keys,
325        split_data_inputs=True,
326        np_rng=np_rng,
327        add_flags=["training"],
328        fprec=fprec,
329        rename_refs=rename_refs,
330    )
331
332    compute_forces = "forces" in used_keys
333    compute_virial = "virial_tensor" in used_keys or "virial" in used_keys
334    compute_stress = "stress_tensor" in used_keys or "stress" in used_keys
335    compute_pressure = "pressure" in used_keys or "pressure_tensor" in used_keys
336
337    ### get optimizer parameters ###
338    max_epochs = int(training_parameters.get("max_epochs", 2000))
339    epoch_format = len(str(max_epochs))
340    nbatch_per_epoch = training_parameters.get("nbatch_per_epoch", 200)
341    nbatch_per_validation = training_parameters.get("nbatch_per_validation", 20)
342
343    schedule, sch_state, schedule_metrics,adaptive_scheduler = get_lr_schedule(max_epochs,nbatch_per_epoch,training_parameters)
344    optimizer = get_optimizer(
345        training_parameters, model.variables, schedule(sch_state)[0]
346    )
347    opt_st = optimizer.init(model.variables)
348
349    ### exponential moving average of the parameters ###
350    ema_decay = training_parameters.get("ema_decay", -1.0)
351    if ema_decay > 0.0:
352        assert ema_decay < 1.0, "ema_decay must be in (0,1)"
353        ema = optax.ema(decay=ema_decay)
354    else:
355        ema = optax.identity()
356    ema_st = ema.init(model.variables)
357
358    #### end event ####
359    end_event = training_parameters.get("end_event", None)
360    if end_event is None or isinstance(end_event, str) and end_event.lower() == "none":
361        is_end = lambda metrics: False
362    else:
363        assert len(end_event) == 2, "end_event must be a list of two elements"
364        is_end = lambda metrics: metrics[end_event[0]] < end_event[1]
365
366    if "energy_terms" in training_parameters:
367        model.set_energy_terms(training_parameters["energy_terms"], jit=False)
368    print("energy terms:", model.energy_terms)
369
370    ### MODEL EVALUATION FUNCTION ###
371    pbc_training = training_parameters.get("pbc_training", False)
372    if compute_stress or compute_virial or compute_pressure:
373        virial_key = "virial" if "virial" in used_keys else "virial_tensor"
374        stress_key = "stress" if "stress" in used_keys else "stress_tensor"
375        pressure_key = "pressure" if "pressure" in used_keys else "pressure_tensor"
376        if compute_stress or compute_pressure:
377            assert pbc_training, "PBC must be enabled for stress or virial training"
378            print("Computing forces and stress tensor")
379            def evaluate(model, variables, data):
380                _, _, vir, output = model._energy_and_forces_and_virial(variables, data)
381                cells = output["cells"]
382                volume = jnp.abs(jnp.linalg.det(cells))
383                stress = vir / volume[:, None, None]
384                output[stress_key] = stress
385                output[virial_key] = vir
386                if pressure_key == "pressure":
387                    output[pressure_key] = -jnp.trace(stress, axis1=1, axis2=2) / 3.0
388                else:
389                    output[pressure_key] = -stress
390                return output
391
392        else:
393            print("Computing forces and virial tensor")
394            def evaluate(model, variables, data):
395                _, _, vir, output = model._energy_and_forces_and_virial(variables, data)
396                output[virial_key] = vir
397                return output
398
399    elif compute_forces:
400        print("Computing forces")
401        def evaluate(model, variables, data):
402            _, _, output = model._energy_and_forces(variables, data)
403            return output
404
405    elif model.energy_terms is not None:
406        def evaluate(model, variables, data):
407            _, output = model._total_energy(variables, data)
408            return output
409
410    else:
411        def evaluate(model, variables, data):
412            output = model.modules.apply(variables, data)
413            return output
414
415    ### TRAINING FUNCTIONS ###
416    train_step = get_train_step_function(
417        loss_definition=loss_definition,
418        model=model,
419        model_ref=model_ref,
420        compute_ref_coords=compute_ref_coords,
421        evaluate=evaluate,
422        optimizer=optimizer,
423        ema=ema,
424    )
425
426    validation = get_validation_function(
427        loss_definition=loss_definition,
428        model=model,
429        model_ref=model_ref,
430        compute_ref_coords=compute_ref_coords,
431        evaluate=evaluate,
432        return_targets=False,
433    )
434
435    #### CONFIGURE PREPROCESSING ####
436    gpu_preprocessing = training_parameters.get("gpu_preprocessing", False)
437    if gpu_preprocessing:
438        print("GPU preprocessing activated.")
439
440    preproc_state = unfreeze(model.preproc_state)
441    layer_state = []
442    for st in preproc_state["layers_state"]:
443        stnew = unfreeze(st)
444        #     st["nblist_skin"] = nblist_skin
445        #     if nblist_stride > 1:
446        #         st["skin_stride"] = nblist_stride
447        #         st["skin_count"] = nblist_stride
448        if "nblist_mult_size" in training_parameters:
449            stnew["nblist_mult_size"] = training_parameters["nblist_mult_size"]
450        if "nblist_add_neigh" in training_parameters:
451            stnew["add_neigh"] = training_parameters["nblist_add_neigh"]
452        if "nblist_add_atoms" in training_parameters:
453            stnew["add_atoms"] = training_parameters["nblist_add_atoms"]
454        layer_state.append(freeze(stnew))
455
456    preproc_state["layers_state"] = tuple(layer_state)
457    preproc_state["check_input"] = False
458    model.preproc_state = freeze(preproc_state)
459
460    ### INITIALIZE METRICS ###
461    maes_prev = defaultdict(lambda: np.inf)
462    metrics_beta = training_parameters.get("metrics_ema_decay", -1.0)
463    smoothen_metrics = metrics_beta < 1.0 and metrics_beta > 0.0
464    if smoothen_metrics:
465        print("Computing smoothed metrics with beta =", metrics_beta)
466        rmses_smooth = defaultdict(lambda: 0.0)
467        maes_smooth = defaultdict(lambda: 0.0)
468        rmse_tot_smooth = 0.0
469        mae_tot_smooth = 0.0
470        nsmooth = 0
471    max_restore_count = training_parameters.get("max_restore_count", 5)    
472
473    fmetrics = open(
474        output_directory + f"metrics{stage_prefix}.traj",
475        "w" if restart_checkpoint is None else "a",
476    )
477
478    keep_all_bests = training_parameters.get("keep_all_bests", False)
479    save_model_at_epochs = training_parameters.get("save_model_at_epochs", [])
480    save_model_at_epochs= set([int(x) for x in save_model_at_epochs])
481    save_model_every = int(training_parameters.get("save_model_every_epoch", 0))
482    if save_model_every > 0:
483        save_model_every = list(range(save_model_every, max_epochs+save_model_every, save_model_every))
484        save_model_at_epochs = save_model_at_epochs.union(save_model_every)
485    if "save_model_schedule" in training_parameters:
486        save_model_schedule_def = [int(i) for i in training_parameters["save_model_schedule"]]
487        save_model_schedule = []
488        while len(save_model_schedule_def) > 0:
489            i = save_model_schedule_def.pop(0)
490            if i > 0:
491                save_model_schedule.append(i)
492            elif i < 0:
493                start = -i
494                freq = save_model_schedule_def.pop(0)
495                assert freq > 0, "syntax error in save_model_schedule"
496                save_model_schedule.extend(
497                    list(range(start, max_epochs + 1, freq))
498                )
499        save_model_at_epochs = save_model_at_epochs.union(save_model_schedule)
500    save_model_at_epochs = list(save_model_at_epochs)
501    save_model_at_epochs.sort()
502
503    ### LR factor after restoring a model that has diverged
504    restore_scaling = training_parameters.get("restore_scaling", 1)
505    assert restore_scaling > 0.0, "restore_scaling must be > 0.0"
506    assert restore_scaling <= 1.0, "restore_scaling must be <= 1.0"
507    if restore_scaling != 1:
508        print("Applying LR scaling after restore:", restore_scaling)
509
510
511    metric_use_best = training_parameters.get("metric_best", "rmse_tot")  # .lower()
512    metrics_state = {}
513
514    ### INITIALIZE TRAINING STATE ###
515    if restart_checkpoint is not None:
516        training_state = deepcopy(restart_checkpoint["training"])
517        epoch_start = training_state["epoch"]
518        model.preproc_state = freeze(restart_checkpoint["preproc_state"])
519        if smoothen_metrics:
520            metrics_state = deepcopy(restart_checkpoint["metrics_state"])
521            rmses_smooth = metrics_state["rmses_smooth"]
522            maes_smooth = metrics_state["maes_smooth"]
523            rmse_tot_smooth = metrics_state["rmse_tot_smooth"]
524            mae_tot_smooth = metrics_state["mae_tot_smooth"]
525            nsmooth = metrics_state["nsmooth"]
526        print("Restored training state")
527    else:
528        epoch_start = 1
529        training_state = {
530            "rng_key": rng_key,
531            "opt_state": opt_st,
532            "ema_state": ema_st,
533            "sch_state": sch_state,
534            "variables": model.variables,
535            "variables_ema": model.variables,
536            "step": 0,
537            "restore_count": 0,
538            "restore_scale": 1,
539            "epoch": 0,
540            "best_metric": np.inf,
541        }
542        if smoothen_metrics:
543            metrics_state = {
544                "rmses_smooth": dict(rmses_smooth),
545                "maes_smooth": dict(maes_smooth),
546                "rmse_tot_smooth": rmse_tot_smooth,
547                "mae_tot_smooth": mae_tot_smooth,
548                "nsmooth": nsmooth,
549            }
550    
551    del opt_st, ema_st, sch_state, preproc_state
552    variables_save = training_state['variables']
553    variables_ema_save = training_state['variables_ema']
554
555    ### Training loop ###
556    start = time.time()
557    print("Starting training...")
558    for epoch in range(epoch_start, max_epochs+1):
559        training_state["epoch"] = epoch
560        s = time.time()
561        for _ in range(nbatch_per_epoch):
562            # fetch data
563            inputs0, data = next(training_iterator)
564
565            # preprocess data
566            inputs = model.preprocess(verbose=True,use_gpu=gpu_preprocessing,**inputs0)
567
568            rng_key, subkey = jax.random.split(rng_key)
569            inputs["rng_key"] = subkey
570            inputs["training_epoch"] = epoch
571
572            # adjust learning rate
573            current_lr, training_state["sch_state"] = schedule(training_state["sch_state"])
574            current_lr *= training_state["restore_scale"]
575            training_state["opt_state"].inner_states["trainable"].inner_state[-1].hyperparams[
576                "step_size"
577            ] = current_lr 
578
579            # train step
580            loss,training_state, _ = train_step(data,inputs,training_state)
581
582        rmses_avg = defaultdict(lambda: 0.0)
583        maes_avg = defaultdict(lambda: 0.0)
584        for _ in range(nbatch_per_validation):
585            inputs0, data = next(validation_iterator)
586
587            inputs = model.preprocess(verbose=True, use_gpu=gpu_preprocessing, **inputs0)
588
589            rng_key, subkey = jax.random.split(rng_key)
590            inputs["rng_key"] = subkey
591            inputs["training_epoch"] = epoch
592
593            rmses, maes, output_val = validation(
594                data=data,
595                inputs=inputs,
596                variables=training_state["variables_ema"],
597            )
598            for k, v in rmses.items():
599                rmses_avg[k] += v
600            for k, v in maes.items():
601                maes_avg[k] += v
602
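           # wait for the asynchronously dispatched validation computation to finish so
           # that the epoch timing below measures actual compute time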
603        jax.block_until_ready(output_val)
604
605        ### Timings ###
606        e = time.time()
607        epoch_time = e - s
608
609        elapsed_time = e - start
610        if not adaptive_scheduler:
611            remain_glob = elapsed_time * (max_epochs - epoch + 1) / epoch
612            remain_last = epoch_time * (max_epochs - epoch + 1)
613            # estimate remaining time via weighted average (put more weight on last at the beginning)
614            wremain = np.sin(0.5 * np.pi * epoch / max_epochs)
615            remaining_time = human_time_duration(
616                remain_glob * wremain + remain_last * (1 - wremain)
617            )
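           # early on wremain ~ 0 and the estimate follows the last epoch's duration;
           # towards the end wremain -> 1 and the global per-epoch average dominates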
618        elapsed_time = human_time_duration(elapsed_time)
619        batch_time = human_time_duration(
620            epoch_time / (nbatch_per_epoch + nbatch_per_validation)
621        )
622        epoch_time = human_time_duration(epoch_time)
623
624        for k in rmses_avg.keys():
625            rmses_avg[k] /= nbatch_per_validation
626        for k in maes_avg.keys():
627            maes_avg[k] /= nbatch_per_validation
628
629        ### Print metrics ###
630        print("")
631        line = f"Epoch {epoch}, lr={current_lr:.3e}, loss = {loss:.3e}"
632        line += f", epoch time = {epoch_time}, batch time = {batch_time}"
633        line += f", elapsed time = {elapsed_time}"
634        if not adaptive_scheduler:
635            line += f", est. remaining time = {remaining_time}"
636        print(line)
637        rmse_tot = 0.0
638        mae_tot = 0.0
639        for k in rmses_avg.keys():
640            mult = loss_definition[k]["mult"]
641            loss_prms = loss_definition[k]
642            rmse_tot = rmse_tot + rmses_avg[k] * loss_prms["weight"] ** 0.5
643            mae_tot = mae_tot + maes_avg[k] * loss_prms["weight"]
644            unit = "(" + loss_prms["unit"] + ")" if "unit" in loss_prms else ""
645
646            weight_str = ""
647            if "weight_schedule" in loss_prms:
648                w = linear_schedule(epoch, *loss_prms["weight_schedule"])
649                weight_str = f"(w={w:.3f})"
650
651            if rmses_avg[k] / mult < 1.0e-2:
652                print(
653                    f"    rmse_{k}= {rmses_avg[k]/mult:10.3e} ; mae_{k}= {maes_avg[k]/mult:10.3e}   {unit}  {weight_str}"
654                )
655            else:
656                print(
657                    f"    rmse_{k}= {rmses_avg[k]/mult:10.3f} ; mae_{k}= {maes_avg[k]/mult:10.3f}   {unit}  {weight_str}"
658                )
659
660        ### CHECK IF WE SHOULD RESTORE PREVIOUS MODEL ###
661        restore = False
662        reinit = False
663        for k, mae in maes_avg.items():
664            if np.isnan(mae):
665                restore = True
666                reinit = True
667                print(f"NaN detected in validation MAE for '{k}'")
668                # for k, v in inputs.items():
669                #     if hasattr(v,"shape"):
670                #         if np.isnan(v).any():
671                #             print(k,v)
672                # sys.exit(1)
673            if "threshold" in loss_definition[k]:
674                thr = loss_definition[k]["threshold"]
675                if mae > thr * maes_prev[k]:
676                    restore = True
677                    break
678
679        epoch_time_str = time.strftime("%Y-%m-%d-%H-%M-%S")
680        model.variables = training_state["variables_ema"]
681        if epoch in save_model_at_epochs:
682            filename = output_directory + f"model{stage_prefix}_epoch{epoch:0{epoch_format}}_{epoch_time_str}.fnx"
683            print("Scheduled model save:", filename)
684            model.save(filename)
685
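           # `restore` rolls parameters back to the saved copies below; `reinit`
           # (set when NaNs appear) additionally resets the optimizer and EMA states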
686        if restore:
687            training_state["restore_count"] += 1
688            if training_state["restore_count"] > max_restore_count:
689                if reinit:
690                    raise ValueError("Model diverged and could not be restored.")
691                else:
692                    training_state["restore_count"] = 0
693                    print(
694                        f"{max_restore_count} unsuccessful restores, resuming training."
695                    )
696
697            training_state["variables"] = deepcopy(variables_save)
698            training_state["variables_ema"] = deepcopy(variables_ema_save)
699            model.variables = training_state["variables_ema"]
700            training_state["restore_scale"] *= restore_scaling
701            print(f"Restored previous model after divergence. Restore scale: {training_state['restore_scale']:.3f}")
702            if reinit:
703                training_state["opt_state"] = optimizer.init(model.variables)
704                training_state["ema_state"] = ema.init(model.variables)
705                print("Reinitialized optimizer after divergence.")
706            save_restart_checkpoint(output_directory, stage, training_state, model.preproc_state, metrics_state)
707            continue
708
709        training_state["restore_count"] = 0
710
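           # epoch finished without divergence: refresh the rollback copies and the
           # reference MAEs used by the threshold check in the next epoch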
711        variables_save = deepcopy(training_state["variables"])
712        variables_ema_save = deepcopy(training_state["variables_ema"])
713        maes_prev = maes_avg
714
715        ### SAVE MODEL ###
716        model.save(output_directory + "latest_model.fnx")
717
718        # save metrics
719        step = int(training_state["step"])
720        metrics = {
721            "epoch": epoch,
722            "step": step,
723            "data_count": step * batch_size,
724            "elapsed_time": time.time() - start,
725            "lr": current_lr,
726            "loss": loss,
727        }
728        for k in rmses_avg.keys():
729            mult = loss_definition[k]["mult"]
730            metrics[f"rmse_{k}"] = rmses_avg[k] / mult
731            metrics[f"mae_{k}"] = maes_avg[k] / mult
732
733        metrics["rmse_tot"] = rmse_tot
734        metrics["mae_tot"] = mae_tot
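           # optionally track exponentially smoothed metrics with bias correction
           # (Adam-style): s_t = beta*s_{t-1} + (1-beta)*x_t, reported as s_t/(1-beta**t)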
735        if smoothen_metrics:
736            nsmooth += 1
737            for k in rmses_avg.keys():
738                mult = loss_definition[k]["mult"]
739                rmses_smooth[k] = (
740                    metrics_beta * rmses_smooth[k] + (1.0 - metrics_beta) * rmses_avg[k]
741                )
742                maes_smooth[k] = (
743                    metrics_beta * maes_smooth[k] + (1.0 - metrics_beta) * maes_avg[k]
744                )
745                metrics[f"rmse_smooth_{k}"] = (
746                    rmses_smooth[k] / (1.0 - metrics_beta**nsmooth) / mult
747                )
748                metrics[f"mae_smooth_{k}"] = (
749                    maes_smooth[k] / (1.0 - metrics_beta**nsmooth) / mult
750                )
751            rmse_tot_smooth = (
752                metrics_beta * rmse_tot_smooth + (1.0 - metrics_beta) * rmse_tot
753            )
754            mae_tot_smooth = (
755                metrics_beta * mae_tot_smooth + (1.0 - metrics_beta) * mae_tot
756            )
757            metrics["rmse_smooth_tot"] = rmse_tot_smooth / (1.0 - metrics_beta**nsmooth)
758            metrics["mae_smooth_tot"] = mae_tot_smooth / (1.0 - metrics_beta**nsmooth)
759
760            # update metrics state
761            metrics_state["rmses_smooth"] = dict(rmses_smooth)
762            metrics_state["maes_smooth"] = dict(maes_smooth)
763            metrics_state["rmse_tot_smooth"] = rmse_tot_smooth
764            metrics_state["mae_tot_smooth"] = mae_tot_smooth
765            metrics_state["nsmooth"] = nsmooth
766
767        assert (
768            metric_use_best in metrics
769        ), f"Error: metric for selecting best model '{metric_use_best}' not in metrics"
770
771        metric_for_best = metrics[metric_use_best]
772        if metric_for_best < training_state["best_metric"]:
773            training_state["best_metric"] = metric_for_best
774            metrics["best_metric"] = training_state["best_metric"]
775            if keep_all_bests:
776                best_name = (
777                    output_directory
778                    + f"best_model{stage_prefix}_epoch{epoch:0{epoch_format}}_{epoch_time_str}.fnx"
779                )
780                model.save(best_name)
781
782            best_name = output_directory + f"best_model{stage_prefix}.fnx"
783            model.save(best_name)
784            print("New best model saved to:", best_name)
785        else:
786            metrics["best_metric"] = training_state["best_metric"]
787
788        if epoch == 1:
789            headers = [f"{i+1}:{k}" for i, k in enumerate(metrics.keys())]
790            fmetrics.write("# " + " ".join(headers) + "\n")
791        fmetrics.write(" ".join([str(metrics[k]) for k in metrics.keys()]) + "\n")
792        fmetrics.flush()
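           # the metrics file thus holds a commented header "# 1:epoch 2:step ..." and one
           # whitespace-separated row per epoch, so it can be parsed with e.g. np.loadtxt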
793
794        # update learning rate using current metrics
795        assert (
796            schedule_metrics in metrics
797        ), f"Error: cannot update lr, '{schedule_metrics}' not in metrics"
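           # the schedule optionally consumes the selected validation metric here,
           # e.g. for reduce-on-plateau-style adaptive schedulers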
798        current_lr, training_state["sch_state"] = schedule(training_state["sch_state"], metrics[schedule_metrics])
799
800        # update and save training state
801        save_restart_checkpoint(output_directory, stage, training_state, model.preproc_state, metrics_state)
802
803        
804        if is_end(metrics):
805            print("Stage finished.")
806            break
807
808    end = time.time()
809
810    print(f"Training time: {human_time_duration(end-start)}")
811    print("")
812
813    fmetrics.close()
814
815    filename = output_directory + f"final_model{stage_prefix}.fnx"
816    model.save(filename)
817    print("Final model saved to:", filename)
818
819    return metrics, filename