fennol.training.training

  1import os
  2import yaml
  3import sys
  4import jax
  5import io
  6import time
  7import jax.numpy as jnp
  8import numpy as np
  9import optax
 10from collections import defaultdict
 11import json
 12from copy import deepcopy
 13from pathlib import Path
 14import argparse
 15try:
 16    import torch
 17except ImportError:
 18    raise ImportError(
 19        "PyTorch is required for training models. Install the CPU version following instructions at https://pytorch.org/get-started/locally/"
 20    )
 21import random
 22from flax import traverse_util
 23import json
 24import shutil
 25import pickle
 26
 27from flax.core import freeze, unfreeze
 28from .io import (
 29    load_configuration,
 30    load_dataset,
 31    load_model,
 32    TeeLogger,
 33    copy_parameters,
 34)
 35from .utils import (
 36    get_loss_definition,
 37    get_train_step_function,
 38    get_validation_function,
 39    linear_schedule,
 40)
 41from .optimizers import get_optimizer, get_lr_schedule, multi_ema
 42from ..utils import deep_update, AtomicUnits as au
 43from ..utils.io import human_time_duration
 44
 45def save_restart_checkpoint(output_directory,stage,training_state,preproc_state,metrics_state):
 46    with open(output_directory + "/restart_checkpoint.pkl", "wb") as f:
  47        pickle.dump({
  48            "stage": stage,
  49            "training": training_state,
  50            "preproc_state": preproc_state,
  51            "metrics_state": metrics_state,
  52        }, f)
 53
 54def load_restart_checkpoint(output_directory):
 55    restart_checkpoint_file = output_directory + "/restart_checkpoint.pkl"
 56    if not os.path.exists(restart_checkpoint_file):
 57        raise FileNotFoundError(
 58            f"Training state file not found: {restart_checkpoint_file}"
 59        )
 60    with open(restart_checkpoint_file, "rb") as f:
 61        restart_checkpoint = pickle.load(f)
 62    return restart_checkpoint
 63
 64
 65def main():
 66    parser = argparse.ArgumentParser(prog="fennol_train")
 67    parser.add_argument("config_file", type=str)
 68    parser.add_argument("--model_file", type=str, default=None)
 69    args = parser.parse_args()
 70    config_file = args.config_file
 71    model_file = args.model_file
 72
 73    os.environ["OMP_NUM_THREADS"] = "1"
 74    sys.stdout = io.TextIOWrapper(
 75        open(sys.stdout.fileno(), "wb", 0), write_through=True
 76    )
 77
 78    restart_training = False
 79    if os.path.isdir(config_file):
 80        output_directory = Path(config_file).absolute().as_posix()
 81        restart_training = True
 82        while output_directory.endswith("/"):
 83            output_directory = output_directory[:-1]
 84        config_file = output_directory + "/config.yaml"
 85        backup_dir = output_directory + f"_backup_{time.strftime('%Y-%m-%d-%H-%M-%S')}"
 86        shutil.copytree(output_directory, backup_dir)
 87        print("Restarting training from", output_directory)
 88
 89    parameters = load_configuration(config_file)
 90
 91    ### Set the device
 92    if "FENNOL_DEVICE" in os.environ:
 93        device = os.environ["FENNOL_DEVICE"].lower()
 94        print(f"# Setting device from env FENNOL_DEVICE={device}")
 95    else:
 96        device: str = parameters.get("device", "cpu").lower()
 97    if device == "cpu":
 98        jax.config.update('jax_platforms', 'cpu')
 99        os.environ["CUDA_VISIBLE_DEVICES"] = ""
100    elif device.startswith("cuda") or device.startswith("gpu"):
101        if ":" in device:
102            num = device.split(":")[-1]
103            os.environ["CUDA_VISIBLE_DEVICES"] = num
104        else:
105            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
106        device = "gpu"
107
108    _device = jax.devices(device)[0]
109    jax.config.update("jax_default_device", _device)
110
111    if restart_training:
112        restart_checkpoint = load_restart_checkpoint(output_directory)
113    else:
114        restart_checkpoint = None
115
116    # output directory
117    if not restart_training:
118        output_directory = parameters.get("output_directory", None)
119        if output_directory is not None:
120            if "{now}" in output_directory:
121                output_directory = output_directory.replace(
122                    "{now}", time.strftime("%Y-%m-%d-%H-%M-%S")
123                )
124            output_directory = Path(output_directory).absolute()
125            if not output_directory.exists():
126                output_directory.mkdir(parents=True)
127            print("Output directory:", output_directory)
128        else:
129            output_directory = "."
130
131    output_directory = str(output_directory) + "/"
132
133    # copy config_file to output directory
134    # config_name = Path(config_file).name
135    config_ext = Path(config_file).suffix
136    with open(config_file) as f_in:
137        config_data = f_in.read()
138    with open(output_directory + "/config" + config_ext, "w") as f_out:
139        f_out.write(config_data)
140
141    # set log file
142    log_file = "train.log"  # parameters.get("log_file", None)
143    logger = TeeLogger(output_directory + log_file)
144    logger.bind_stdout()
145
146    # set matmul precision
147    enable_x64 = parameters.get("double_precision", False)
148    jax.config.update("jax_enable_x64", enable_x64)
149    fprec = "float64" if enable_x64 else "float32"
150    parameters["fprec"] = fprec
151    if enable_x64:
152        print("Double precision enabled.")
153
154    matmul_precision = parameters.get("matmul_prec", "highest").lower()
155    assert matmul_precision in [
156        "default",
157        "high",
158        "highest",
159    ], "matmul_prec must be one of 'default','high','highest'"
160    jax.config.update("jax_default_matmul_precision", matmul_precision)
161
162    # set random seed
163    rng_seed = parameters.get("rng_seed", np.random.randint(0, 2**32 - 1))
164    print(f"rng_seed: {rng_seed}")
165    rng_key = jax.random.PRNGKey(rng_seed)
166    torch.manual_seed(rng_seed)
167    np.random.seed(rng_seed)
168    random.seed(rng_seed)
169    np_rng = np.random.Generator(np.random.PCG64(rng_seed))
170
171    try:
172        if "stages" in parameters["training"]:
173            ## train in stages ##
174            params = deepcopy(parameters)
175            stages = params["training"].pop("stages")
176            assert isinstance(stages, dict), "'stages' must be a dict with named stages"
177            model_file_stage = model_file
178            print_stages_params = params["training"].get("print_stages_params", False)
179            for i, (stage, stage_params) in enumerate(stages.items()):
180                rng_key, subkey = jax.random.split(rng_key)
181                print("")
182                print(f"### STAGE {i+1}: {stage} ###")
183
184                ## remove end_event from previous stage ##
185                if i > 0 and "end_event" in params["training"]:
186                    params["training"].pop("end_event")
187
188                ## incrementally update training parameters ##
189                params = deep_update(params, {"training": stage_params})
190                if model_file_stage is not None:
191                    ## load model from previous stage ##
192                    params["model_file"] = model_file_stage
193
194                if restart_training and restart_checkpoint["stage"] != i + 1:
195                    print(f"Skipping stage {i+1} (already completed)")
196                    continue
197
198                if print_stages_params:
199                    print("stage parameters:")
200                    print(json.dumps(params, indent=2, sort_keys=False))
201
202                ## train stage ##
203                _, model_file_stage = train(
204                    (subkey, np_rng),
205                    params,
206                    stage=i + 1,
207                    output_directory=output_directory,
208                    restart_checkpoint=restart_checkpoint,
209                )
210
211                restart_training = False
212                restart_checkpoint = None
213        else:
214            ## single training stage ##
215            train(
216                (rng_key, np_rng),
217                parameters,
218                model_file=model_file,
219                output_directory=output_directory,
220                restart_checkpoint=restart_checkpoint,
221            )
222    except KeyboardInterrupt:
223        print("Training interrupted by user.")
224    finally:
225        if log_file is not None:
226            logger.unbind_stdout()
227            logger.close()
228
229
230def train(
231    rng,
232    parameters,
233    model_file=None,
234    stage=None,
235    output_directory=None,
236    restart_checkpoint=None,
237):
238    if output_directory is None:
239        output_directory = "./"
240    elif output_directory == "":
241        output_directory = "./"
242    elif not output_directory.endswith("/"):
243        output_directory += "/"
244    stage_prefix = f"_stage_{stage}" if stage is not None else ""
245
246    if isinstance(rng, tuple):
247        rng_key, np_rng = rng
248    else:
249        rng_key = rng
250        np_rng = np.random.Generator(np.random.PCG64(np.random.randint(0, 2**32 - 1)))
251
252    ### GET TRAINING PARAMETERS ###
253    training_parameters = parameters.get("training", {})
254
255    #### LOAD MODEL ####
256    if restart_checkpoint is not None:
257        model_key = None
258        model_file = output_directory + "latest_model.fnx"
259    else:
260        rng_key, model_key = jax.random.split(rng_key)
261    model = load_model(parameters, model_file, rng_key=model_key)
262
263    ### SAVE INITIAL MODEL ###
264    save_initial_model = (
265        restart_checkpoint is None 
266        and (stage is None or stage == 0)
267    )
268    if save_initial_model:
269        model.save(output_directory + "initial_model.fnx")
270
271    model_ref = None
272    if "model_ref" in training_parameters:
273        model_ref = load_model(parameters, training_parameters["model_ref"])
274        print("Reference model:", training_parameters["model_ref"])
275
276        if "ref_parameters" in training_parameters:
277            ref_parameters = training_parameters["ref_parameters"]
278            assert isinstance(
279                ref_parameters, list
280            ), "ref_parameters must be a list of str"
281            print("Reference parameters:", ref_parameters)
282            model.variables = copy_parameters(
283                model.variables, model_ref.variables, ref_parameters
284            )
285
286    ### SET FLOATING POINT PRECISION ###
287    fprec = parameters.get("fprec", "float32")
288
289    def convert_to_fprec(x):
290        if jnp.issubdtype(x.dtype, jnp.floating):
291            return x.astype(fprec)
292        return x
293
294    model.variables = jax.tree_map(convert_to_fprec, model.variables)
295
296    ### SET UP LOSS FUNCTION ###
297    loss_definition, used_keys, ref_keys = get_loss_definition(
298        training_parameters, model_energy_unit=model.energy_unit
299    )
300
301    coordinates_ref_key = training_parameters.get("coordinates_ref_key", None)
302    if coordinates_ref_key is not None:
303        compute_ref_coords = True
304        print("Reference coordinates:", coordinates_ref_key)
305    else:
306        compute_ref_coords = False
307
308    #### LOAD DATASET ####
309    dspath = training_parameters.get("dspath", None)
310    if dspath is None:
311        raise ValueError("Dataset path 'training/dspath' should be specified.")
312    batch_size = training_parameters.get("batch_size", 16)
313    batch_size_val = training_parameters.get("batch_size_val", None)
314    rename_refs = training_parameters.get("rename_refs", {})
315    training_iterator, validation_iterator = load_dataset(
316        dspath=dspath,
317        batch_size=batch_size,
318        batch_size_val=batch_size_val,
319        training_parameters=training_parameters,
320        infinite_iterator=True,
321        atom_padding=True,
322        ref_keys=ref_keys,
323        split_data_inputs=True,
324        np_rng=np_rng,
325        add_flags=["training"],
326        fprec=fprec,
327        rename_refs=rename_refs,
328    )
329
330    compute_forces = "forces" in used_keys
331    compute_virial = "virial_tensor" in used_keys or "virial" in used_keys
332    compute_stress = "stress_tensor" in used_keys or "stress" in used_keys
333    compute_pressure = "pressure" in used_keys or "pressure_tensor" in used_keys
334
335    ### get optimizer parameters ###
336    max_epochs = int(training_parameters.get("max_epochs", 2000))
337    epoch_format = len(str(max_epochs))
338    nbatch_per_epoch = training_parameters.get("nbatch_per_epoch", 200)
339    nbatch_per_validation = training_parameters.get("nbatch_per_validation", 20)
340
341    schedule, sch_state, schedule_metrics,adaptive_scheduler = get_lr_schedule(max_epochs,nbatch_per_epoch,training_parameters)
342    optimizer,ilr = get_optimizer(
343        training_parameters, model.variables, schedule(sch_state)[0]
344    )
345    opt_st = optimizer.init(model.variables)
346
347    ### exponential moving average of the parameters ###
348    ema_decay = training_parameters.get("ema_decay", None)
349    if ema_decay is not None:
350        if isinstance(ema_decay,float):
351            ema_decay = [ema_decay]
352        for decay in ema_decay:
353            assert isinstance(decay, float), "ema_decay must be a float or a list of floats"
 354            assert 0. <= decay < 1.0, "ema_decay must be in [0,1)"
355        
356        power_ema = training_parameters.get("ema_power", False)
357        ema_ival = int(training_parameters.get("ema_validate", 0))
358        ema = multi_ema(decays=ema_decay,power=power_ema)  
359    else:
360        ema = multi_ema(decays=[]) # no ema
361        ema_ival = 0
362    ema_st = ema.init(model.variables)
363    variables_ema, _ = ema.update(model.variables, ema_st)
364
365    #### end event ####
366    end_event = training_parameters.get("end_event", None)
 367    if end_event is None or (isinstance(end_event, str) and end_event.lower() == "none"):
368        is_end = lambda metrics: False
369    else:
370        assert len(end_event) == 2, "end_event must be a list of two elements"
371        is_end = lambda metrics: metrics[end_event[0]] < end_event[1]
372
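    # Illustrative example (values are not taken from this file): setting
    #     end_event: ["rmse_tot", 1.0e-3]
    # in the training section ends the current stage once the validation
    # metric rmse_tot drops below 1.0e-3.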
373    if "energy_terms" in training_parameters:
374        model.set_energy_terms(training_parameters["energy_terms"], jit=False)
375    print("energy terms:", model.energy_terms)
376
377    ### MODEL EVALUATION FUNCTION ###
378    pbc_training = training_parameters.get("pbc_training", False)
379    if compute_stress or compute_virial or compute_pressure:
380        virial_key = "virial" if "virial" in used_keys else "virial_tensor"
381        stress_key = "stress" if "stress" in used_keys else "stress_tensor"
382        pressure_key = "pressure" if "pressure" in used_keys else "pressure_tensor"
383        if compute_stress or compute_pressure:
 384            assert pbc_training, "PBC must be enabled for stress or pressure training"
385            print("Computing forces and stress tensor")
386            def evaluate(model, variables, data):
387                _, _, vir, output = model._energy_and_forces_and_virial(variables, data)
388                cells = output["cells"]
389                volume = jnp.abs(jnp.linalg.det(cells))
390                stress = vir / volume[:, None, None]
391                output[stress_key] = stress
392                output[virial_key] = vir
393                if pressure_key == "pressure":
394                    output[pressure_key] = -jnp.trace(stress, axis1=1, axis2=2) / 3.0
395                else:
396                    output[pressure_key] = -stress
397                return output
398
399        else:
400            print("Computing forces and virial tensor")
401            def evaluate(model, variables, data):
402                _, _, vir, output = model._energy_and_forces_and_virial(variables, data)
403                output[virial_key] = vir
404                return output
405
406    elif compute_forces:
407        print("Computing forces")
408        def evaluate(model, variables, data):
409            _, _, output = model._energy_and_forces(variables, data)
410            return output
411
412    elif model.energy_terms is not None:
413        def evaluate(model, variables, data):
414            _, output = model._total_energy(variables, data)
415            return output
416
417    else:
418        def evaluate(model, variables, data):
419            output = model.modules.apply(variables, data)
420            return output
421
422    ### TRAINING FUNCTIONS ###
423    train_step = get_train_step_function(
424        loss_definition=loss_definition,
425        model=model,
426        model_ref=model_ref,
427        compute_ref_coords=compute_ref_coords,
428        evaluate=evaluate,
429        optimizer=optimizer,
430        ema=ema,
431    )
432
433    validation = get_validation_function(
434        loss_definition=loss_definition,
435        model=model,
436        model_ref=model_ref,
437        compute_ref_coords=compute_ref_coords,
438        evaluate=evaluate,
439        return_targets=False,
440    )
441
442    #### CONFIGURE PREPROCESSING ####
443    gpu_preprocessing = training_parameters.get("gpu_preprocessing", False)
444    if gpu_preprocessing:
445        print("GPU preprocessing activated.")
446
447    preproc_state = unfreeze(model.preproc_state)
448    layer_state = []
449    for st in preproc_state["layers_state"]:
450        stnew = unfreeze(st)
451        #     st["nblist_skin"] = nblist_skin
452        #     if nblist_stride > 1:
453        #         st["skin_stride"] = nblist_stride
454        #         st["skin_count"] = nblist_stride
455        if "nblist_mult_size" in training_parameters:
456            stnew["nblist_mult_size"] = training_parameters["nblist_mult_size"]
457        if "nblist_add_neigh" in training_parameters:
458            stnew["add_neigh"] = training_parameters["nblist_add_neigh"]
459        if "nblist_add_atoms" in training_parameters:
460            stnew["add_atoms"] = training_parameters["nblist_add_atoms"]
461        layer_state.append(freeze(stnew))
462
463    preproc_state["layers_state"] = tuple(layer_state)
464    preproc_state["check_input"] = False
465    model.preproc_state = freeze(preproc_state)
466
467    ### INITIALIZE METRICS ###
468    maes_prev = defaultdict(lambda: np.inf)
469    metrics_beta = training_parameters.get("metrics_ema_decay", -1.0)
470    smoothen_metrics = metrics_beta < 1.0 and metrics_beta > 0.0
471    if smoothen_metrics:
472        print("Computing smoothed metrics with beta =", metrics_beta)
473        rmses_smooth = defaultdict(lambda: 0.0)
474        maes_smooth = defaultdict(lambda: 0.0)
475        rmse_tot_smooth = 0.0
476        mae_tot_smooth = 0.0
477        nsmooth = 0
478    max_restore_count = training_parameters.get("max_restore_count", 5)    
479
480    fmetrics = open(
481        output_directory + f"metrics{stage_prefix}.traj",
482        "w" if restart_checkpoint is None else "a",
483    )
484
485    keep_all_bests = training_parameters.get("keep_all_bests", False)
486    save_model_at_epochs = training_parameters.get("save_model_at_epochs", [])
487    save_model_at_epochs= set([int(x) for x in save_model_at_epochs])
488    save_model_every = int(training_parameters.get("save_model_every_epoch", 0))
489    if save_model_every > 0:
490        save_model_every = list(range(save_model_every, max_epochs+save_model_every, save_model_every))
491        save_model_at_epochs = save_model_at_epochs.union(save_model_every)
492    if "save_model_schedule" in training_parameters:
493        save_model_schedule_def = [int(i) for i in training_parameters["save_model_schedule"]]
494        save_model_schedule = []
495        while len(save_model_schedule_def) > 0:
496            i = save_model_schedule_def.pop(0)
497            if i > 0:
498                save_model_schedule.append(i)
499            elif i < 0:
500                start = -i
501                freq = save_model_schedule_def.pop(0)
502                assert freq > 0, "syntax error in save_model_schedule"
503                save_model_schedule.extend(
504                    list(range(start, max_epochs + 1, freq))
505                )
506        save_model_at_epochs = save_model_at_epochs.union(save_model_schedule)
507    save_model_at_epochs = list(save_model_at_epochs)
508    save_model_at_epochs.sort()
509
510    ### LR factor after restoring a model that has diverged
511    restore_scaling = training_parameters.get("restore_scaling", 1)
512    assert restore_scaling > 0.0, "restore_scaling must be > 0.0"
513    assert restore_scaling <= 1.0, "restore_scaling must be <= 1.0"
514    if restore_scaling != 1:
515        print("Applying LR scaling after restore:", restore_scaling)
516
517
518    metric_use_best = training_parameters.get("metric_best", "rmse_tot")  # .lower()
519    metrics_state = {}
520
521    ### INITIALIZE TRAINING STATE ###
522    if restart_checkpoint is not None:
523        training_state = deepcopy(restart_checkpoint["training"])
524        epoch_start = training_state["epoch"]
525        model.preproc_state = freeze(restart_checkpoint["preproc_state"])
526        if smoothen_metrics:
527            metrics_state = deepcopy(restart_checkpoint["metrics_state"])
528            rmses_smooth = metrics_state["rmses_smooth"]
529            maes_smooth = metrics_state["maes_smooth"]
530            rmse_tot_smooth = metrics_state["rmse_tot_smooth"]
531            mae_tot_smooth = metrics_state["mae_tot_smooth"]
532            nsmooth = metrics_state["nsmooth"]
533        print("Restored training state")
534    else:
535        epoch_start = 1
536        training_state = {
537            "rng_key": rng_key,
538            "opt_state": opt_st,
539            "ema_state": ema_st,
540            "sch_state": sch_state,
541            "variables": model.variables,
542            "variables_ema": variables_ema,
543            "step": 0,
544            "restore_count": 0,
545            "restore_scale": 1,
546            "epoch": 0,
547            "best_metric": np.inf,
548        }
549        if smoothen_metrics:
550            metrics_state = {
551                "rmses_smooth": dict(rmses_smooth),
552                "maes_smooth": dict(maes_smooth),
553                "rmse_tot_smooth": rmse_tot_smooth,
554                "mae_tot_smooth": mae_tot_smooth,
555                "nsmooth": nsmooth,
556            }
557    
558    del opt_st, ema_st, sch_state, preproc_state, variables_ema
559    variables_save = training_state['variables']
560    variables_ema_save = training_state['variables_ema']
561
562    ### Training loop ###
563    start = time.time()
564    print("Starting training...")
565    for epoch in range(epoch_start, max_epochs+1):
566        training_state["epoch"] = epoch
567        s = time.time()
568        for _ in range(nbatch_per_epoch):
569            # fetch data
570            inputs0, data = next(training_iterator)
571
572            # preprocess data
573            inputs = model.preprocess(verbose=True,use_gpu=gpu_preprocessing,**inputs0)
574
575            rng_key, subkey = jax.random.split(rng_key)
576            inputs["rng_key"] = subkey
577            inputs["training_epoch"] = epoch
578
579            # adjust learning rate
580            current_lr, training_state["sch_state"] = schedule(training_state["sch_state"])
581            current_lr *= training_state["restore_scale"]
582            training_state["opt_state"].inner_states["trainable"].inner_state[ilr].hyperparams[
583                "step_size"
584            ] = current_lr 
585
586            # train step
587            loss,training_state, _ = train_step(data,inputs,training_state)
588
589        rmses_avg = defaultdict(lambda: 0.0)
590        maes_avg = defaultdict(lambda: 0.0)
591        for _ in range(nbatch_per_validation):
592            inputs0, data = next(validation_iterator)
593
594            inputs = model.preprocess(verbose=True,use_gpu=gpu_preprocessing,**inputs0)
595
596            rng_key, subkey = jax.random.split(rng_key)
597            inputs["rng_key"] = subkey
598            inputs["training_epoch"] = epoch
599
600            rmses, maes, output_val = validation(
601                data=data,
602                inputs=inputs,
603                variables=training_state["variables_ema"][ema_ival],
604            )
605            for k, v in rmses.items():
606                rmses_avg[k] += v
607            for k, v in maes.items():
608                maes_avg[k] += v
609
610        jax.block_until_ready(output_val)
611
612        ### Timings ###
613        e = time.time()
614        epoch_time = e - s
615
616        elapsed_time = e - start
617        if not adaptive_scheduler:
618            remain_glob = elapsed_time * (max_epochs - epoch + 1) / epoch
619            remain_last = epoch_time * (max_epochs - epoch + 1)
620            # estimate remaining time via weighted average (put more weight on last at the beginning)
621            wremain = np.sin(0.5 * np.pi * epoch / max_epochs)
622            remaining_time = human_time_duration(
623                remain_glob * wremain + remain_last * (1 - wremain)
624            )
625        elapsed_time = human_time_duration(elapsed_time)
626        batch_time = human_time_duration(
627            epoch_time / (nbatch_per_epoch + nbatch_per_validation)
628        )
629        epoch_time = human_time_duration(epoch_time)
630
631        for k in rmses_avg.keys():
632            rmses_avg[k] /= nbatch_per_validation
633        for k in maes_avg.keys():
634            maes_avg[k] /= nbatch_per_validation
635
636        ### Print metrics ###
637        print("")
638        line = f"Epoch {epoch}, lr={current_lr:.3e}, loss = {loss:.3e}"
639        line += f", epoch time = {epoch_time}, batch time = {batch_time}"
640        line += f", elapsed time = {elapsed_time}"
641        if not adaptive_scheduler:
642            line += f", est. remaining time = {remaining_time}"
643        print(line)
644        rmse_tot = 0.0
645        mae_tot = 0.0
646        for k in rmses_avg.keys():
647            mult = loss_definition[k]["display_mult"]
648            loss_prms = loss_definition[k]
649            rmse_tot = rmse_tot + rmses_avg[k] * loss_prms["weight"] ** 0.5
650            mae_tot = mae_tot + maes_avg[k] * loss_prms["weight"]
651            unit = "(" + loss_prms["display_unit"] + ")" if "display_unit" in loss_prms else ""
652
653            weight_str = ""
654            if "weight_schedule" in loss_prms:
655                w = linear_schedule(epoch, *loss_prms["weight_schedule"])
656                weight_str = f"(w={w:.3f})"
657
658            if rmses_avg[k] / mult < 1.0e-2:
659                print(
660                    f"    rmse_{k}= {rmses_avg[k]/mult:10.3e} ; mae_{k}= {maes_avg[k]/mult:10.3e}   {unit}  {weight_str}"
661                )
662            else:
663                print(
664                    f"    rmse_{k}= {rmses_avg[k]/mult:10.3f} ; mae_{k}= {maes_avg[k]/mult:10.3f}   {unit}  {weight_str}"
665                )
666
667        ### CHECK IF WE SHOULD RESTORE PREVIOUS MODEL ###
668        restore = False
669        reinit = False
670        for k, mae in maes_avg.items():
671            if np.isnan(mae):
672                restore = True
673                reinit = True
674                print("NaN detected in mae")
675                # for k, v in inputs.items():
676                #     if hasattr(v,"shape"):
677                #         if np.isnan(v).any():
678                #             print(k,v)
679                # sys.exit(1)
680            if "threshold" in loss_definition[k]:
681                thr = loss_definition[k]["threshold"]
682                if mae > thr * maes_prev[k]:
683                    restore = True
684                    break
685
686        epoch_time_str = time.strftime("%Y-%m-%d-%H-%M-%S")
687        if epoch in save_model_at_epochs:
688            for i,v in enumerate(training_state["variables_ema"]):
689                model.variables = v
690                ema_str = f"_ema{i+1}" if len(training_state["variables_ema"]) > 1 else ""
691                filename = output_directory + f"model{stage_prefix}_epoch{epoch:0{epoch_format}}{ema_str}_{epoch_time_str}.fnx"
692                print("Scheduled model save :", filename)
693                model.save(filename)
694
695        if restore:
696            training_state["restore_count"] += 1
697            if training_state["restore_count"]  > max_restore_count:
698                if reinit:
699                    raise ValueError("Model diverged and could not be restored.")
700                else:
701                    training_state["restore_count"]  = 0
702                    print(
703                        f"{max_restore_count} unsuccessful restores, resuming training."
704                    )
705
706            training_state["variables"] = deepcopy(variables_save)
707            training_state["variables_ema"] = deepcopy(variables_ema_save)
708            training_state["restore_scale"] *= restore_scaling
709            print(f"Restored previous model after divergence. Restore scale: {training_state['restore_scale']:.3f}")
710            if reinit:
711                training_state["opt_state"] = optimizer.init(model.variables)
712                training_state["ema_state"] = ema.init(model.variables)
713                print("Reinitialized optimizer after divergence.")
714            save_restart_checkpoint(output_directory,stage,training_state,model.preproc_state,metrics_state)
715            continue
716
717        training_state["restore_count"]  = 0
718
719        variables_save = deepcopy(training_state["variables"])
720        variables_ema_save = deepcopy(training_state["variables_ema"])
721        maes_prev = maes_avg
722
723        ### SAVE MODEL ###
724        model.variables = training_state["variables_ema"][ema_ival]
725        model.save(output_directory + "latest_model.fnx")
726
727        # save metrics
728        step = int(training_state["step"])
729        metrics = {
730            "epoch": epoch,
731            "step": step,
732            "data_count": step * batch_size,
733            "elapsed_time": time.time() - start,
734            "lr": current_lr,
735            "loss": loss,
736        }
737        for k in rmses_avg.keys():
738            mult = loss_definition[k]["mult"]
739            metrics[f"rmse_{k}"] = rmses_avg[k] / mult
740            metrics[f"mae_{k}"] = maes_avg[k] / mult
741
742        metrics["rmse_tot"] = rmse_tot
743        metrics["mae_tot"] = mae_tot
744        if smoothen_metrics:
745            nsmooth += 1
746            for k in rmses_avg.keys():
747                mult = loss_definition[k]["mult"]
748                rmses_smooth[k] = (
749                    metrics_beta * rmses_smooth[k] + (1.0 - metrics_beta) * rmses_avg[k]
750                )
751                maes_smooth[k] = (
752                    metrics_beta * maes_smooth[k] + (1.0 - metrics_beta) * maes_avg[k]
753                )
754                metrics[f"rmse_smooth_{k}"] = (
755                    rmses_smooth[k] / (1.0 - metrics_beta**nsmooth) / mult
756                )
757                metrics[f"mae_smooth_{k}"] = (
758                    maes_smooth[k] / (1.0 - metrics_beta**nsmooth) / mult
759                )
760            rmse_tot_smooth = (
761                metrics_beta * rmse_tot_smooth + (1.0 - metrics_beta) * rmse_tot
762            )
763            mae_tot_smooth = (
764                metrics_beta * mae_tot_smooth + (1.0 - metrics_beta) * mae_tot
765            )
766            metrics["rmse_smooth_tot"] = rmse_tot_smooth / (1.0 - metrics_beta**nsmooth)
767            metrics["mae_smooth_tot"] = mae_tot_smooth / (1.0 - metrics_beta**nsmooth)
768
769            # update metrics state
770            metrics_state["rmses_smooth"] = dict(rmses_smooth)
771            metrics_state["maes_smooth"] = dict(maes_smooth)
772            metrics_state["rmse_tot_smooth"] = rmse_tot_smooth
773            metrics_state["mae_tot_smooth"] = mae_tot_smooth
774            metrics_state["nsmooth"] = nsmooth
775
776        assert (
777            metric_use_best in metrics
 778        ), f"Error: metric for selecting best model '{metric_use_best}' not in metrics"
779
780        metric_for_best = metrics[metric_use_best]
781        if metric_for_best <  training_state["best_metric"]:
782            training_state["best_metric"] = metric_for_best
783            metrics["best_metric"] = training_state["best_metric"]
784            best_epoch_str = f"_epoch{epoch:0{epoch_format}}_{epoch_time_str}" if keep_all_bests else ""
785            best_name = output_directory + f"best_model{stage_prefix}{best_epoch_str}.fnx"
786            model.save(best_name)
787            print("New best model saved to:", best_name)
788        else:
789            metrics["best_metric"] = training_state["best_metric"]
790
791        if epoch == 1:
792            headers = [f"{i+1}:{k}" for i, k in enumerate(metrics.keys())]
793            fmetrics.write("# " + " ".join(headers) + "\n")
794        fmetrics.write(" ".join([str(metrics[k]) for k in metrics.keys()]) + "\n")
795        fmetrics.flush()
796
797        # update learning rate using current metrics
798        assert (
799            schedule_metrics in metrics
800        ), f"Error: cannot update lr, '{schedule_metrics}' not in metrics"
801        current_lr, training_state["sch_state"] = schedule(training_state["sch_state"], metrics[schedule_metrics])
802
803        # update and save training state
804        save_restart_checkpoint(output_directory,stage,training_state,model.preproc_state,metrics_state)
805
806        
807        if is_end(metrics):
808            print("Stage finished.")
809            break
810
811    end = time.time()
812
813    print(f"Training time: {human_time_duration(end-start)}")
814    print("")
815
816    fmetrics.close()
817
818    filename = output_directory + f"final_model{stage_prefix}.fnx"
819    model.save(filename)
820    print("Final model saved to:", filename)
821
822    return metrics, filename
823
824
825if __name__ == "__main__":
826    main()
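
The module is intended to be run as a command-line program (the argparse prog name above is fennol_train): it takes a configuration file as positional argument, or an existing output directory to restart an interrupted run, plus an optional --model_file. The sketch below lists the top-level configuration keys read in main(), written as a Python dict of the kind load_configuration returns; the values are illustrative placeholders, not documented defaults.

    # Illustrative top-level configuration (only keys read in main(); values are examples).
    parameters = {
        "device": "cuda:0",               # or "cpu"; overridden by the FENNOL_DEVICE env variable
        "output_directory": "run_{now}",  # "{now}" is replaced by a timestamp
        "double_precision": False,        # enables jax_enable_x64 (fprec = "float64")
        "matmul_prec": "highest",         # one of "default", "high", "highest"
        "rng_seed": 12345,                # seeds jax, torch, numpy and random
        "training": {},                   # see the notes on train() below
    }
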
def save_restart_checkpoint(output_directory, stage, training_state, preproc_state, metrics_state):
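Serializes the current training state to restart_checkpoint.pkl in the output directory using pickle. The stored dict contains the stage index, the full training state (optimizer, EMA and scheduler states, parameters, epoch counter, best metric), the model preprocessing state and the smoothed-metrics state, so an interrupted run can be resumed from it.
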
def load_restart_checkpoint(output_directory):
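Loads restart_checkpoint.pkl from an output directory and returns the stored dict; a FileNotFoundError is raised if the file is missing. A minimal sketch of inspecting a checkpoint outside of a training run (the directory name is hypothetical):

    from fennol.training.training import load_restart_checkpoint

    ckpt = load_restart_checkpoint("runs/my_training_run")
    print(ckpt["stage"])                    # stage index stored by save_restart_checkpoint
    print(ckpt["training"]["epoch"])        # last saved epoch
    print(ckpt["training"]["best_metric"])  # best validation metric reached so far
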
def main():
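Command-line entry point. It parses the arguments, sets the compute device, matmul and floating-point precision and random seeds, copies the configuration file into the output directory, tees stdout to train.log, and then either runs a single train() call or iterates over the named stages of training.stages, feeding the model produced by each stage to the next. Passing an existing output directory instead of a configuration file restarts from the restart_checkpoint.pkl found there (after backing the directory up). Below is a hedged sketch of a staged configuration, using only keys that appear in the code above; stage names and values are placeholders.

    # Illustrative multi-stage setup; each stage overrides the "training" section via deep_update.
    parameters = {
        "training": {
            "dspath": "dataset.pkl",      # hypothetical dataset path, required
            "max_epochs": 500,
            "print_stages_params": True,  # print the merged parameters of each stage
            "stages": {
                "warmup": {"batch_size": 16, "nbatch_per_epoch": 100},
                "refine": {"batch_size": 64, "end_event": ["rmse_tot", 1.0e-3]},
            },
        },
    }
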
def train(rng, parameters, model_file=None, stage=None, output_directory=None, restart_checkpoint=None):
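Runs one training stage: it loads the model (optionally copying parameters from a reference model), builds the loss, optimizer, learning-rate schedule and EMA copies of the weights, then alternates training and validation batches for up to max_epochs epochs, writing per-epoch metrics to a metrics*.traj file and saving latest_model.fnx, best_model*.fnx, scheduled snapshots and periodic restart checkpoints. It returns the last metrics dict and the path of the final model. rng may be a bare jax PRNG key or a (key, numpy Generator) tuple; when restart_checkpoint is given, the model and all states are restored from the previous run. The sketch below shows a plausible training section restricted to keys read in this function; the values are illustrative only.

    # Illustrative "training" section (keys as read in train(); values are examples).
    training = {
        "dspath": "dataset.pkl",            # required path to the dataset
        "batch_size": 32,
        "batch_size_val": 64,
        "max_epochs": 1000,
        "nbatch_per_epoch": 200,
        "nbatch_per_validation": 20,
        "ema_decay": [0.99, 0.999],         # one or several EMA copies of the weights
        "ema_validate": 1,                  # index of the EMA copy used for validation and saving
        "metrics_ema_decay": 0.9,           # a value in (0,1) enables smoothed metrics
        "metric_best": "rmse_tot",          # metric used to select the best model
        "save_model_every_epoch": 100,      # periodic scheduled saves
        "gpu_preprocessing": True,
        "end_event": ["rmse_tot", 1.0e-3],  # stop the stage when the metric drops below the value
    }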
231def train(
232    rng,
233    parameters,
234    model_file=None,
235    stage=None,
236    output_directory=None,
237    restart_checkpoint=None,
238):
239    if output_directory is None:
240        output_directory = "./"
241    elif output_directory == "":
242        output_directory = "./"
243    elif not output_directory.endswith("/"):
244        output_directory += "/"
245    stage_prefix = f"_stage_{stage}" if stage is not None else ""
246
247    if isinstance(rng, tuple):
248        rng_key, np_rng = rng
249    else:
250        rng_key = rng
251        np_rng = np.random.Generator(np.random.PCG64(np.random.randint(0, 2**32 - 1)))
252
253    ### GET TRAINING PARAMETERS ###
254    training_parameters = parameters.get("training", {})
255
256    #### LOAD MODEL ####
257    if restart_checkpoint is not None:
258        model_key = None
259        model_file = output_directory + "latest_model.fnx"
260    else:
261        rng_key, model_key = jax.random.split(rng_key)
262    model = load_model(parameters, model_file, rng_key=model_key)
263
264    ### SAVE INITIAL MODEL ###
265    save_initial_model = (
266        restart_checkpoint is None 
267        and (stage is None or stage == 0)
268    )
269    if save_initial_model:
270        model.save(output_directory + "initial_model.fnx")
271
272    model_ref = None
273    if "model_ref" in training_parameters:
274        model_ref = load_model(parameters, training_parameters["model_ref"])
275        print("Reference model:", training_parameters["model_ref"])
276
277        if "ref_parameters" in training_parameters:
278            ref_parameters = training_parameters["ref_parameters"]
279            assert isinstance(
280                ref_parameters, list
281            ), "ref_parameters must be a list of str"
282            print("Reference parameters:", ref_parameters)
283            model.variables = copy_parameters(
284                model.variables, model_ref.variables, ref_parameters
285            )
286
287    ### SET FLOATING POINT PRECISION ###
288    fprec = parameters.get("fprec", "float32")
289
290    def convert_to_fprec(x):
291        if jnp.issubdtype(x.dtype, jnp.floating):
292            return x.astype(fprec)
293        return x
294
295    model.variables = jax.tree_map(convert_to_fprec, model.variables)
296
297    ### SET UP LOSS FUNCTION ###
298    loss_definition, used_keys, ref_keys = get_loss_definition(
299        training_parameters, model_energy_unit=model.energy_unit
300    )
301
302    coordinates_ref_key = training_parameters.get("coordinates_ref_key", None)
303    if coordinates_ref_key is not None:
304        compute_ref_coords = True
305        print("Reference coordinates:", coordinates_ref_key)
306    else:
307        compute_ref_coords = False
308
309    #### LOAD DATASET ####
310    dspath = training_parameters.get("dspath", None)
311    if dspath is None:
312        raise ValueError("Dataset path 'training/dspath' should be specified.")
313    batch_size = training_parameters.get("batch_size", 16)
314    batch_size_val = training_parameters.get("batch_size_val", None)
315    rename_refs = training_parameters.get("rename_refs", {})
316    training_iterator, validation_iterator = load_dataset(
317        dspath=dspath,
318        batch_size=batch_size,
319        batch_size_val=batch_size_val,
320        training_parameters=training_parameters,
321        infinite_iterator=True,
322        atom_padding=True,
323        ref_keys=ref_keys,
324        split_data_inputs=True,
325        np_rng=np_rng,
326        add_flags=["training"],
327        fprec=fprec,
328        rename_refs=rename_refs,
329    )
330
331    compute_forces = "forces" in used_keys
332    compute_virial = "virial_tensor" in used_keys or "virial" in used_keys
333    compute_stress = "stress_tensor" in used_keys or "stress" in used_keys
334    compute_pressure = "pressure" in used_keys or "pressure_tensor" in used_keys
335
336    ### get optimizer parameters ###
337    max_epochs = int(training_parameters.get("max_epochs", 2000))
338    epoch_format = len(str(max_epochs))
339    nbatch_per_epoch = training_parameters.get("nbatch_per_epoch", 200)
340    nbatch_per_validation = training_parameters.get("nbatch_per_validation", 20)
341
342    schedule, sch_state, schedule_metrics,adaptive_scheduler = get_lr_schedule(max_epochs,nbatch_per_epoch,training_parameters)
343    optimizer,ilr = get_optimizer(
344        training_parameters, model.variables, schedule(sch_state)[0]
345    )
346    opt_st = optimizer.init(model.variables)
347
348    ### exponential moving average of the parameters ###
349    ema_decay = training_parameters.get("ema_decay", None)
350    if ema_decay is not None:
351        if isinstance(ema_decay,float):
352            ema_decay = [ema_decay]
353        for decay in ema_decay:
354            assert isinstance(decay, float), "ema_decay must be a float or a list of floats"
355            assert 0. <= decay < 1.0, "ema_decay must be in (0,1)"
356        
357        power_ema = training_parameters.get("ema_power", False)
358        ema_ival = int(training_parameters.get("ema_validate", 0))
359        ema = multi_ema(decays=ema_decay,power=power_ema)  
360    else:
361        ema = multi_ema(decays=[]) # no ema
362        ema_ival = 0
363    ema_st = ema.init(model.variables)
364    variables_ema, _ = ema.update(model.variables, ema_st)
365
366    #### end event ####
367    end_event = training_parameters.get("end_event", None)
368    if end_event is None or isinstance(end_event, str) and end_event.lower() == "none":
369        is_end = lambda metrics: False
370    else:
371        assert len(end_event) == 2, "end_event must be a list of two elements"
372        is_end = lambda metrics: metrics[end_event[0]] < end_event[1]
373
374    if "energy_terms" in training_parameters:
375        model.set_energy_terms(training_parameters["energy_terms"], jit=False)
376    print("energy terms:", model.energy_terms)
377
378    ### MODEL EVALUATION FUNCTION ###
379    pbc_training = training_parameters.get("pbc_training", False)
380    if compute_stress or compute_virial or compute_pressure:
381        virial_key = "virial" if "virial" in used_keys else "virial_tensor"
382        stress_key = "stress" if "stress" in used_keys else "stress_tensor"
383        pressure_key = "pressure" if "pressure" in used_keys else "pressure_tensor"
384        if compute_stress or compute_pressure:
385            assert pbc_training, "PBC must be enabled for stress or virial training"
386            print("Computing forces and stress tensor")
387            def evaluate(model, variables, data):
388                _, _, vir, output = model._energy_and_forces_and_virial(variables, data)
389                cells = output["cells"]
390                volume = jnp.abs(jnp.linalg.det(cells))
391                stress = vir / volume[:, None, None]
392                output[stress_key] = stress
393                output[virial_key] = vir
394                if pressure_key == "pressure":
395                    output[pressure_key] = -jnp.trace(stress, axis1=1, axis2=2) / 3.0
396                else:
397                    output[pressure_key] = -stress
398                return output
399
400        else:
401            print("Computing forces and virial tensor")
402            def evaluate(model, variables, data):
403                _, _, vir, output = model._energy_and_forces_and_virial(variables, data)
404                output[virial_key] = vir
405                return output
406
407    elif compute_forces:
408        print("Computing forces")
409        def evaluate(model, variables, data):
410            _, _, output = model._energy_and_forces(variables, data)
411            return output
412
413    elif model.energy_terms is not None:
414        def evaluate(model, variables, data):
415            _, output = model._total_energy(variables, data)
416            return output
417
418    else:
419        def evaluate(model, variables, data):
420            output = model.modules.apply(variables, data)
421            return output
422
423    ### TRAINING FUNCTIONS ###
424    train_step = get_train_step_function(
425        loss_definition=loss_definition,
426        model=model,
427        model_ref=model_ref,
428        compute_ref_coords=compute_ref_coords,
429        evaluate=evaluate,
430        optimizer=optimizer,
431        ema=ema,
432    )
433
434    validation = get_validation_function(
435        loss_definition=loss_definition,
436        model=model,
437        model_ref=model_ref,
438        compute_ref_coords=compute_ref_coords,
439        evaluate=evaluate,
440        return_targets=False,
441    )
442
443    #### CONFIGURE PREPROCESSING ####
444    gpu_preprocessing = training_parameters.get("gpu_preprocessing", False)
445    if gpu_preprocessing:
446        print("GPU preprocessing activated.")
447
448    preproc_state = unfreeze(model.preproc_state)
449    layer_state = []
450    for st in preproc_state["layers_state"]:
451        stnew = unfreeze(st)
452        #     st["nblist_skin"] = nblist_skin
453        #     if nblist_stride > 1:
454        #         st["skin_stride"] = nblist_stride
455        #         st["skin_count"] = nblist_stride
456        if "nblist_mult_size" in training_parameters:
457            stnew["nblist_mult_size"] = training_parameters["nblist_mult_size"]
458        if "nblist_add_neigh" in training_parameters:
459            stnew["add_neigh"] = training_parameters["nblist_add_neigh"]
460        if "nblist_add_atoms" in training_parameters:
461            stnew["add_atoms"] = training_parameters["nblist_add_atoms"]
462        layer_state.append(freeze(stnew))
463
464    preproc_state["layers_state"] = tuple(layer_state)
465    preproc_state["check_input"] = False
466    model.preproc_state = freeze(preproc_state)
467
468    ### INITIALIZE METRICS ###
469    maes_prev = defaultdict(lambda: np.inf)
470    metrics_beta = training_parameters.get("metrics_ema_decay", -1.0)
471    smoothen_metrics = metrics_beta < 1.0 and metrics_beta > 0.0
472    if smoothen_metrics:
473        print("Computing smoothed metrics with beta =", metrics_beta)
474        rmses_smooth = defaultdict(lambda: 0.0)
475        maes_smooth = defaultdict(lambda: 0.0)
476        rmse_tot_smooth = 0.0
477        mae_tot_smooth = 0.0
478        nsmooth = 0
479    max_restore_count = training_parameters.get("max_restore_count", 5)    
480
481    fmetrics = open(
482        output_directory + f"metrics{stage_prefix}.traj",
483        "w" if restart_checkpoint is None else "a",
484    )
485
486    keep_all_bests = training_parameters.get("keep_all_bests", False)
487    save_model_at_epochs = training_parameters.get("save_model_at_epochs", [])
488    save_model_at_epochs= set([int(x) for x in save_model_at_epochs])
489    save_model_every = int(training_parameters.get("save_model_every_epoch", 0))
490    if save_model_every > 0:
491        save_model_every = list(range(save_model_every, max_epochs+save_model_every, save_model_every))
492        save_model_at_epochs = save_model_at_epochs.union(save_model_every)
493    if "save_model_schedule" in training_parameters:
494        save_model_schedule_def = [int(i) for i in training_parameters["save_model_schedule"]]
495        save_model_schedule = []
496        while len(save_model_schedule_def) > 0:
497            i = save_model_schedule_def.pop(0)
498            if i > 0:
499                save_model_schedule.append(i)
500            elif i < 0:
501                start = -i
502                freq = save_model_schedule_def.pop(0)
503                assert freq > 0, "syntax error in save_model_schedule"
504                save_model_schedule.extend(
505                    list(range(start, max_epochs + 1, freq))
506                )
507        save_model_at_epochs = save_model_at_epochs.union(save_model_schedule)
508    save_model_at_epochs = list(save_model_at_epochs)
509    save_model_at_epochs.sort()
510
511    ### LR factor after restoring a model that has diverged
512    restore_scaling = training_parameters.get("restore_scaling", 1)
513    assert restore_scaling > 0.0, "restore_scaling must be > 0.0"
514    assert restore_scaling <= 1.0, "restore_scaling must be <= 1.0"
515    if restore_scaling != 1:
516        print("Applying LR scaling after restore:", restore_scaling)
517
518
519    metric_use_best = training_parameters.get("metric_best", "rmse_tot")  # .lower()
520    metrics_state = {}
521
522    ### INITIALIZE TRAINING STATE ###
523    if restart_checkpoint is not None:
524        training_state = deepcopy(restart_checkpoint["training"])
525        epoch_start = training_state["epoch"]
526        model.preproc_state = freeze(restart_checkpoint["preproc_state"])
527        if smoothen_metrics:
528            metrics_state = deepcopy(restart_checkpoint["metrics_state"])
529            rmses_smooth = metrics_state["rmses_smooth"]
530            maes_smooth = metrics_state["maes_smooth"]
531            rmse_tot_smooth = metrics_state["rmse_tot_smooth"]
532            mae_tot_smooth = metrics_state["mae_tot_smooth"]
533            nsmooth = metrics_state["nsmooth"]
534        print("Restored training state")
535    else:
536        epoch_start = 1
537        training_state = {
538            "rng_key": rng_key,
539            "opt_state": opt_st,
540            "ema_state": ema_st,
541            "sch_state": sch_state,
542            "variables": model.variables,
543            "variables_ema": variables_ema,
544            "step": 0,
545            "restore_count": 0,
546            "restore_scale": 1,
547            "epoch": 0,
548            "best_metric": np.inf,
549        }
550        if smoothen_metrics:
551            metrics_state = {
552                "rmses_smooth": dict(rmses_smooth),
553                "maes_smooth": dict(maes_smooth),
554                "rmse_tot_smooth": rmse_tot_smooth,
555                "mae_tot_smooth": mae_tot_smooth,
556                "nsmooth": nsmooth,
557            }
558    
559    del opt_st, ema_st, sch_state, preproc_state, variables_ema
560    variables_save = training_state['variables']
561    variables_ema_save = training_state['variables_ema']
562
563    ### Training loop ###
564    start = time.time()
565    print("Starting training...")
566    for epoch in range(epoch_start, max_epochs+1):
567        training_state["epoch"] = epoch
568        s = time.time()
569        for _ in range(nbatch_per_epoch):
570            # fetch data
571            inputs0, data = next(training_iterator)
572
573            # preprocess data
574            inputs = model.preprocess(verbose=True, use_gpu=gpu_preprocessing, **inputs0)
575
576            rng_key, subkey = jax.random.split(rng_key)
577            inputs["rng_key"] = subkey
578            inputs["training_epoch"] = epoch
579
580            # adjust learning rate
581            current_lr, training_state["sch_state"] = schedule(training_state["sch_state"])
582            current_lr *= training_state["restore_scale"]
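               # write the scheduled LR directly into the optimizer state; the "trainable"
               # partition is expected to expose an injected "step_size" hyperparameter
               # (optax inject_hyperparams-style) at position ilr of its chain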
583            training_state["opt_state"].inner_states["trainable"].inner_state[ilr].hyperparams[
584                "step_size"
585            ] = current_lr
586
587            # train step
588            loss, training_state, _ = train_step(data, inputs, training_state)
589
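           # validation pass: accumulate RMSE/MAE over nbatch_per_validation batches,
           # evaluated with the EMA copy of the weights selected by ema_ival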
590        rmses_avg = defaultdict(lambda: 0.0)
591        maes_avg = defaultdict(lambda: 0.0)
592        for _ in range(nbatch_per_validation):
593            inputs0, data = next(validation_iterator)
594
595            inputs = model.preprocess(verbose=True, use_gpu=gpu_preprocessing, **inputs0)
596
597            rng_key, subkey = jax.random.split(rng_key)
598            inputs["rng_key"] = subkey
599            inputs["training_epoch"] = epoch
600
601            rmses, maes, output_val = validation(
602                data=data,
603                inputs=inputs,
604                variables=training_state["variables_ema"][ema_ival],
605            )
606            for k, v in rmses.items():
607                rmses_avg[k] += v
608            for k, v in maes.items():
609                maes_avg[k] += v
610
611        jax.block_until_ready(output_val)
612
613        ### Timings ###
614        e = time.time()
615        epoch_time = e - s
616
617        elapsed_time = e - start
618        if not adaptive_scheduler:
619            remain_glob = elapsed_time * (max_epochs - epoch + 1) / epoch
620            remain_last = epoch_time * (max_epochs - epoch + 1)
621            # estimate remaining time via weighted average (put more weight on last at the beginning)
622            wremain = np.sin(0.5 * np.pi * epoch / max_epochs)
623            remaining_time = human_time_duration(
624                remain_glob * wremain + remain_last * (1 - wremain)
625            )
626        elapsed_time = human_time_duration(elapsed_time)
627        batch_time = human_time_duration(
628            epoch_time / (nbatch_per_epoch + nbatch_per_validation)
629        )
630        epoch_time = human_time_duration(epoch_time)
631
632        for k in rmses_avg.keys():
633            rmses_avg[k] /= nbatch_per_validation
634        for k in maes_avg.keys():
635            maes_avg[k] /= nbatch_per_validation
636
637        ### Print metrics ###
638        print("")
639        line = f"Epoch {epoch}, lr={current_lr:.3e}, loss = {loss:.3e}"
640        line += f", epoch time = {epoch_time}, batch time = {batch_time}"
641        line += f", elapsed time = {elapsed_time}"
642        if not adaptive_scheduler:
643            line += f", est. remaining time = {remaining_time}"
644        print(line)
645        rmse_tot = 0.0
646        mae_tot = 0.0
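           # weighted totals: RMSEs enter with sqrt(weight) and MAEs with the weight
           # itself, so both totals scale consistently with the loss weights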
647        for k in rmses_avg.keys():
648            mult = loss_definition[k]["display_mult"]
649            loss_prms = loss_definition[k]
650            rmse_tot = rmse_tot + rmses_avg[k] * loss_prms["weight"] ** 0.5
651            mae_tot = mae_tot + maes_avg[k] * loss_prms["weight"]
652            unit = "(" + loss_prms["display_unit"] + ")" if "display_unit" in loss_prms else ""
653
654            weight_str = ""
655            if "weight_schedule" in loss_prms:
656                w = linear_schedule(epoch, *loss_prms["weight_schedule"])
657                weight_str = f"(w={w:.3f})"
658
659            if rmses_avg[k] / mult < 1.0e-2:
660                print(
661                    f"    rmse_{k}= {rmses_avg[k]/mult:10.3e} ; mae_{k}= {maes_avg[k]/mult:10.3e}   {unit}  {weight_str}"
662                )
663            else:
664                print(
665                    f"    rmse_{k}= {rmses_avg[k]/mult:10.3f} ; mae_{k}= {maes_avg[k]/mult:10.3f}   {unit}  {weight_str}"
666                )
667
668        ### CHECK IF WE SHOULD RESTORE PREVIOUS MODEL ###
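           # the model is considered diverged if any validation MAE is NaN (triggering an
           # optimizer re-init) or exceeds its "threshold" times the previous epoch's MAE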
669        restore = False
670        reinit = False
671        for k, mae in maes_avg.items():
672            if np.isnan(mae):
673                restore = True
674                reinit = True
675                print("NaN detected in mae")
676                # for k, v in inputs.items():
677                #     if hasattr(v,"shape"):
678                #         if np.isnan(v).any():
679                #             print(k,v)
680                # sys.exit(1)
681            if "threshold" in loss_definition[k]:
682                thr = loss_definition[k]["threshold"]
683                if mae > thr * maes_prev[k]:
684                    restore = True
685                    break
686
687        epoch_time_str = time.strftime("%Y-%m-%d-%H-%M-%S")
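           # scheduled snapshot: save one .fnx file per EMA copy of the weights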
688        if epoch in save_model_at_epochs:
689            for i,v in enumerate(training_state["variables_ema"]):
690                model.variables = v
691                ema_str = f"_ema{i+1}" if len(training_state["variables_ema"]) > 1 else ""
692                filename = output_directory + f"model{stage_prefix}_epoch{epoch:0{epoch_format}}{ema_str}_{epoch_time_str}.fnx"
693                print("Scheduled model save:", filename)
694                model.save(filename)
695
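           # on divergence: roll back to the last good weights, scale the LR by
           # restore_scaling, checkpoint and skip the rest of the epoch; after more than
           # max_restore_count consecutive rollbacks the run aborts if NaNs were seen,
           # otherwise the counter is simply reset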
696        if restore:
697            training_state["restore_count"] += 1
698            if training_state["restore_count"] > max_restore_count:
699                if reinit:
700                    raise ValueError("Model diverged and could not be restored.")
701                else:
702                    training_state["restore_count"] = 0
703                    print(
704                        f"{max_restore_count} unsuccessful restores, resuming training."
705                    )
706
707            training_state["variables"] = deepcopy(variables_save)
708            training_state["variables_ema"] = deepcopy(variables_ema_save)
709            training_state["restore_scale"] *= restore_scaling
710            print(f"Restored previous model after divergence. Restore scale: {training_state['restore_scale']:.3f}")
711            if reinit:
712                training_state["opt_state"] = optimizer.init(model.variables)
713                training_state["ema_state"] = ema.init(model.variables)
714                print("Reinitialized optimizer after divergence.")
715            save_restart_checkpoint(output_directory, stage, training_state, model.preproc_state, metrics_state)
716            continue
717
718        training_state["restore_count"] = 0
719
720        variables_save = deepcopy(training_state["variables"])
721        variables_ema_save = deepcopy(training_state["variables_ema"])
722        maes_prev = maes_avg
723
724        ### SAVE MODEL ###
725        model.variables = training_state["variables_ema"][ema_ival]
726        model.save(output_directory + "latest_model.fnx")
727
728        # save metrics
729        step = int(training_state["step"])
730        metrics = {
731            "epoch": epoch,
732            "step": step,
733            "data_count": step * batch_size,
734            "elapsed_time": time.time() - start,
735            "lr": current_lr,
736            "loss": loss,
737        }
738        for k in rmses_avg.keys():
739            mult = loss_definition[k]["mult"]
740            metrics[f"rmse_{k}"] = rmses_avg[k] / mult
741            metrics[f"mae_{k}"] = maes_avg[k] / mult
742
743        metrics["rmse_tot"] = rmse_tot
744        metrics["mae_tot"] = mae_tot
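           # exponentially-smoothed metrics with Adam-style bias correction
           # (division by 1 - metrics_beta**nsmooth)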
745        if smoothen_metrics:
746            nsmooth += 1
747            for k in rmses_avg.keys():
748                mult = loss_definition[k]["mult"]
749                rmses_smooth[k] = (
750                    metrics_beta * rmses_smooth[k] + (1.0 - metrics_beta) * rmses_avg[k]
751                )
752                maes_smooth[k] = (
753                    metrics_beta * maes_smooth[k] + (1.0 - metrics_beta) * maes_avg[k]
754                )
755                metrics[f"rmse_smooth_{k}"] = (
756                    rmses_smooth[k] / (1.0 - metrics_beta**nsmooth) / mult
757                )
758                metrics[f"mae_smooth_{k}"] = (
759                    maes_smooth[k] / (1.0 - metrics_beta**nsmooth) / mult
760                )
761            rmse_tot_smooth = (
762                metrics_beta * rmse_tot_smooth + (1.0 - metrics_beta) * rmse_tot
763            )
764            mae_tot_smooth = (
765                metrics_beta * mae_tot_smooth + (1.0 - metrics_beta) * mae_tot
766            )
767            metrics["rmse_smooth_tot"] = rmse_tot_smooth / (1.0 - metrics_beta**nsmooth)
768            metrics["mae_smooth_tot"] = mae_tot_smooth / (1.0 - metrics_beta**nsmooth)
769
770            # update metrics state
771            metrics_state["rmses_smooth"] = dict(rmses_smooth)
772            metrics_state["maes_smooth"] = dict(maes_smooth)
773            metrics_state["rmse_tot_smooth"] = rmse_tot_smooth
774            metrics_state["mae_tot_smooth"] = mae_tot_smooth
775            metrics_state["nsmooth"] = nsmooth
776
777        assert (
778            metric_use_best in metrics
779        ), f"Error: metric for selecting best model '{metric_use_best}' not in metrics"
780
781        metric_for_best = metrics[metric_use_best]
782        if metric_for_best < training_state["best_metric"]:
783            training_state["best_metric"] = metric_for_best
784            metrics["best_metric"] = training_state["best_metric"]
785            best_epoch_str = f"_epoch{epoch:0{epoch_format}}_{epoch_time_str}" if keep_all_bests else ""
786            best_name = output_directory + f"best_model{stage_prefix}{best_epoch_str}.fnx"
787            model.save(best_name)
788            print("New best model saved to:", best_name)
789        else:
790            metrics["best_metric"] = training_state["best_metric"]
791
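           # append the metrics as one whitespace-separated row; the column header is
           # written once at the first epoch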
792        if epoch == 1:
793            headers = [f"{i+1}:{k}" for i, k in enumerate(metrics.keys())]
794            fmetrics.write("# " + " ".join(headers) + "\n")
795        fmetrics.write(" ".join([str(metrics[k]) for k in metrics.keys()]) + "\n")
796        fmetrics.flush()
797
798        # update learning rate using current metrics
799        assert (
800            schedule_metrics in metrics
801        ), f"Error: cannot update lr, '{schedule_metrics}' not in metrics"
802        current_lr, training_state["sch_state"] = schedule(training_state["sch_state"], metrics[schedule_metrics])
803
804        # update and save training state
805        save_restart_checkpoint(output_directory, stage, training_state, model.preproc_state, metrics_state)
806
807        
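           # is_end is expected to implement the stage's early-stopping criterion on the
           # current metrics (e.g. for adaptive schedulers)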
808        if is_end(metrics):
809            print("Stage finished.")
810            break
811
812    end = time.time()
813
814    print(f"Training time: {human_time_duration(end-start)}")
815    print("")
816
817    fmetrics.close()
818
819    filename = output_directory + f"final_model{stage_prefix}.fnx"
820    model.save(filename)
821    print("Final model saved to:", filename)
822
823    return metrics, filename