fennol.training.training

import os
import yaml
import sys
import jax
import io
import time
import jax.numpy as jnp
import numpy as np
import optax
from collections import defaultdict
import json
from copy import deepcopy
from pathlib import Path
import argparse
import torch
import random
from flax import traverse_util
import shutil
import pickle

from flax.core import freeze, unfreeze
from .io import (
    load_configuration,
    load_dataset,
    load_model,
    TeeLogger,
    copy_parameters,
)
from .utils import (
    get_loss_definition,
    get_train_step_function,
    get_validation_function,
    get_optimizer,
    linear_schedule,
)
from ..utils import deep_update, AtomicUnits as au
from ..utils.io import human_time_duration
from ..models.preprocessing import AtomPadding, check_input, convert_to_jax


def main():
    parser = argparse.ArgumentParser(prog="fennol_train")
    parser.add_argument("config_file", type=str)
    parser.add_argument("--model_file", type=str, default=None)
    args = parser.parse_args()
    config_file = args.config_file
    model_file = args.model_file

    os.environ["OMP_NUM_THREADS"] = "1"
    sys.stdout = io.TextIOWrapper(
        open(sys.stdout.fileno(), "wb", 0), write_through=True
    )

    restart_training = False
    if os.path.isdir(config_file):
        output_directory = Path(config_file).absolute().as_posix()
        config_file = output_directory + "/config.yaml"
        restart_training = True
        training_state_file = output_directory + "/train_state"
        if not os.path.exists(training_state_file):
            raise FileNotFoundError(
                f"Training state file not found: {training_state_file}"
            )
        while output_directory.endswith("/"):
            output_directory = output_directory[:-1]
        backup_dir = output_directory + f"_backup_{time.strftime('%Y-%m-%d-%H-%M-%S')}"
        shutil.copytree(output_directory, backup_dir)

        with open(training_state_file, "rb") as f:
            training_state = pickle.load(f)
        print("Restarting training from", output_directory)
    else:
        training_state = None

    parameters = load_configuration(config_file)

    ### Set the device
    device: str = parameters.get("device", "cpu").lower()
    if device == "cpu":
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    elif device.startswith("cuda") or device.startswith("gpu"):
        if ":" in device:
            num = device.split(":")[-1]
            os.environ["CUDA_VISIBLE_DEVICES"] = num
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        device = "gpu"

    _device = jax.devices(device)[0]
    jax.config.update("jax_default_device", _device)

    # output directory
    if not restart_training:
        output_directory = parameters.get("output_directory", None)
        if output_directory is not None:
            if "{now}" in output_directory:
                output_directory = output_directory.replace(
                    "{now}", time.strftime("%Y-%m-%d-%H-%M-%S")
                )
            output_directory = Path(output_directory).absolute()
            if not output_directory.exists():
                output_directory.mkdir(parents=True)
            print("Output directory:", output_directory)
        else:
            output_directory = "."

    output_directory = str(output_directory) + "/"

    # copy config_file to output directory
    # config_name = Path(config_file).name
    config_ext = Path(config_file).suffix
    with open(config_file) as f_in:
        config_data = f_in.read()
    with open(output_directory + "/config" + config_ext, "w") as f_out:
        f_out.write(config_data)

    # set log file
    log_file = "train.log"  # parameters.get("log_file", None)
    logger = TeeLogger(output_directory + log_file)
    logger.bind_stdout()

    # set matmul precision
    enable_x64 = parameters.get("double_precision", False)
    jax.config.update("jax_enable_x64", enable_x64)
    fprec = "float64" if enable_x64 else "float32"
    parameters["fprec"] = fprec
    if enable_x64:
        print("Double precision enabled.")

    matmul_precision = parameters.get("matmul_prec", "highest").lower()
    assert matmul_precision in [
        "default",
        "high",
        "highest",
    ], "matmul_prec must be one of 'default','high','highest'"
    jax.config.update("jax_default_matmul_precision", matmul_precision)

    # set random seed
    rng_seed = parameters.get("rng_seed", np.random.randint(0, 2**32 - 1))
    print(f"rng_seed: {rng_seed}")
    rng_key = jax.random.PRNGKey(rng_seed)
    torch.manual_seed(rng_seed)
    np.random.seed(rng_seed)
    random.seed(rng_seed)
    np_rng = np.random.Generator(np.random.PCG64(rng_seed))

    try:
        if "stages" in parameters["training"]:
            ## train in stages ##
            params = deepcopy(parameters)
            stages = params["training"].pop("stages")
            assert isinstance(stages, dict), "'stages' must be a dict with named stages"
            model_file_stage = model_file
            print_stages_params = params["training"].get("print_stages_params", False)
            for i, (stage, stage_params) in enumerate(stages.items()):
                rng_key, subkey = jax.random.split(rng_key)
                print("")
                print(f"### STAGE {i+1}: {stage} ###")

                ## remove end_event from previous stage ##
                if i > 0 and "end_event" in params["training"]:
                    params["training"].pop("end_event")

                ## incrementally update training parameters ##
                params = deep_update(params, {"training": stage_params})
                if model_file_stage is not None:
                    ## load model from previous stage ##
                    params["model_file"] = model_file_stage

                if restart_training and training_state["stage"] != i + 1:
                    print(f"Skipping stage {i+1} (already completed)")
                    continue

                if print_stages_params:
                    print("stage parameters:")
                    print(json.dumps(params, indent=2, sort_keys=False))

                ## train stage ##
                _, model_file_stage = train(
                    (subkey, np_rng),
                    params,
                    stage=i + 1,
                    output_directory=output_directory,
                    training_state=training_state,
                )
                training_state = None
                restart_training = False
        else:
            ## single training stage ##
            train(
                (rng_key, np_rng),
                parameters,
                model_file=model_file,
                output_directory=output_directory,
                training_state=training_state,
            )
    except KeyboardInterrupt:
        print("Training interrupted by user.")
    finally:
        if log_file is not None:
            logger.unbind_stdout()
            logger.close()


def train(
    rng,
    parameters,
    model_file=None,
    stage=None,
    output_directory=None,
    training_state=None,
):
    if output_directory is None:
        output_directory = "./"
    elif not output_directory.endswith("/"):
        output_directory += "/"
    stage_prefix = f"_stage_{stage}" if stage is not None else ""

    if isinstance(rng, tuple):
        rng_key, np_rng = rng
    else:
        rng_key = rng
        np_rng = np.random.Generator(np.random.PCG64(np.random.randint(0, 2**32 - 1)))

    if training_state is not None:
        model_key = None
        model_file = output_directory + "latest_model.fnx"
    else:
        rng_key, model_key = jax.random.split(rng_key)
    model = load_model(parameters, model_file, rng_key=model_key)

    training_parameters = parameters.get("training", {})
    model_ref = None
    if "model_ref" in training_parameters:
        model_ref = load_model(parameters, training_parameters["model_ref"])
        print("Reference model:", training_parameters["model_ref"])

        if "ref_parameters" in training_parameters:
            ref_parameters = training_parameters["ref_parameters"]
            assert isinstance(
                ref_parameters, list
            ), "ref_parameters must be a list of str"
            print("Reference parameters:", ref_parameters)
            model.variables = copy_parameters(
                model.variables, model_ref.variables, ref_parameters
            )
    fprec = parameters.get("fprec", "float32")

    def convert_to_fprec(x):
        if jnp.issubdtype(x.dtype, jnp.floating):
            return x.astype(fprec)
        return x

    model.variables = jax.tree_util.tree_map(convert_to_fprec, model.variables)

    loss_definition, used_keys, ref_keys = get_loss_definition(
        training_parameters, model_energy_unit=model.energy_unit
    )

    coordinates_ref_key = training_parameters.get("coordinates_ref_key", None)
    if coordinates_ref_key is not None:
        compute_ref_coords = True
        print("Reference coordinates:", coordinates_ref_key)
    else:
        compute_ref_coords = False

    dspath = training_parameters.get("dspath", None)
    if dspath is None:
        raise ValueError("Dataset path 'training/dspath' must be specified.")
    batch_size = training_parameters.get("batch_size", 16)
    rename_refs = training_parameters.get("rename_refs", {})
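    # load_dataset returns infinite batch iterators for training and validation;
    # atom padding is enabled so that batch array shapes stay consistent.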
    training_iterator, validation_iterator = load_dataset(
        dspath=dspath,
        batch_size=batch_size,
        training_parameters=training_parameters,
        infinite_iterator=True,
        atom_padding=True,
        ref_keys=ref_keys,
        split_data_inputs=True,
        np_rng=np_rng,
        add_flags=["training"],
        fprec=fprec,
        rename_refs=rename_refs,
    )

    compute_forces = "forces" in used_keys
    compute_virial = "virial_tensor" in used_keys or "virial" in used_keys
    compute_stress = "stress_tensor" in used_keys or "stress" in used_keys
    compute_pressure = "pressure" in used_keys or "pressure_tensor" in used_keys

    # get optimizer parameters
    lr = training_parameters.get("lr", 1.0e-3)
    max_epochs = training_parameters.get("max_epochs", 2000)
    nbatch_per_epoch = training_parameters.get("nbatch_per_epoch", 200)
    nbatch_per_validation = training_parameters.get("nbatch_per_validation", 20)
    init_lr = training_parameters.get("init_lr", lr / 25)
    final_lr = training_parameters.get("final_lr", lr / 10000)

    schedule_type = training_parameters.get("schedule_type", "cosine_onecycle").lower()
    schedule_type = training_parameters.get("scheduler", schedule_type).lower()
    schedule_metrics = training_parameters.get("schedule_metrics", "rmse_tot")

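    # Each schedule is a pure function schedule(state, rmse=None) -> (lr, new_state):
    # it is called with rmse=None once per training batch (advancing the step count)
    # and with the validation metric once per epoch (for adaptive schedulers).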
    adaptive_scheduler = False
    print("Schedule type:", schedule_type)
    if schedule_type == "cosine_onecycle":
        transition_epochs = training_parameters.get("onecycle_epochs", max_epochs)
        peak_epoch = training_parameters.get("peak_epoch", 0.3 * transition_epochs)
        schedule_ = optax.cosine_onecycle_schedule(
            peak_value=lr,
            div_factor=lr / init_lr,
            final_div_factor=init_lr / final_lr,
            transition_steps=transition_epochs * nbatch_per_epoch,
            pct_start=peak_epoch / transition_epochs,
        )
        sch_state = {"count": 0, "best": np.inf, "lr": init_lr}

        def schedule(state, rmse=None):
            new_state = {**state}
            lr = schedule_(state["count"])
            if rmse is None:
                new_state["count"] += 1
            new_state["lr"] = lr
            return lr, new_state

    elif schedule_type == "constant":
        sch_state = {"count": 0}

        def schedule(state, rmse=None):
            new_state = {**state}
            new_state["lr"] = lr
            if rmse is None:
                new_state["count"] += 1
            return lr, new_state

    elif schedule_type == "reduce_on_plateau":
        patience = training_parameters.get("patience", 10)
        factor = training_parameters.get("lr_factor", 0.5)
        patience_thr = training_parameters.get("patience_thr", 0.0)
        sch_state = {"count": 0, "best": np.inf, "lr": lr, "patience": patience}
        adaptive_scheduler = True

        def schedule(state, rmse=None):
            new_state = {**state}
            if rmse is None:
                new_state["count"] += 1
                return state["lr"], new_state
            if rmse <= state["best"] * (1.0 + patience_thr):
                if rmse < state["best"]:
                    new_state["best"] = rmse
                new_state["patience"] = 0
            else:
                new_state["patience"] += 1
                if new_state["patience"] >= patience:
                    new_state["lr"] = state["lr"] * factor
                    new_state["patience"] = 0
                    print("Reducing learning rate to", new_state["lr"])
            return new_state["lr"], new_state

    else:
        raise ValueError(f"Unknown schedule_type: {schedule_type}")

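    # Optional stochastic wrapper: the base schedule provides the upper bound lr_max
    # and the learning rate is drawn uniformly in [lr_min, lr_max] at each step.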
    stochastic_scheduler = training_parameters.get("stochastic_scheduler", False)
    if stochastic_scheduler:
        schedule_ = schedule
        rng_key, scheduler_key = jax.random.split(rng_key)
        sch_state["rng_key"] = scheduler_key
        sch_state["lr_max"] = lr
        sch_state["lr_min"] = final_lr

        def schedule(state, rmse=None):
            new_state = {**state, "lr": state["lr_max"]}
            if rmse is None:
                lr_max, new_state = schedule_(new_state, rmse=rmse)
                lr_min = new_state["lr_min"]
                new_state["rng_key"], subkey = jax.random.split(new_state["rng_key"])
                lr = lr_min + (lr_max - lr_min) * jax.random.uniform(subkey)
                new_state["lr"] = lr
                new_state["lr_max"] = lr_max

            return new_state["lr"], new_state

    optimizer = get_optimizer(
        training_parameters, model.variables, schedule(sch_state)[0]
    )
    opt_st = optimizer.init(model.variables)

    # exponential moving average of the parameters
    ema_decay = training_parameters.get("ema_decay", -1.0)
    if ema_decay > 0.0:
        assert ema_decay < 1.0, "ema_decay must be in (0,1)"
        ema = optax.ema(decay=ema_decay)
    else:
        ema = optax.identity()
    ema_st = ema.init(model.variables)

    # end event
    end_event = training_parameters.get("end_event", None)
    if end_event is None or (isinstance(end_event, str) and end_event.lower() == "none"):
        is_end = lambda metrics: False
    else:
        assert len(end_event) == 2, "end_event must be a list of two elements"
        is_end = lambda metrics: metrics[end_event[0]] < end_event[1]

    print_timings = parameters.get("print_timings", False)

    if "energy_terms" in training_parameters:
        model.set_energy_terms(training_parameters["energy_terms"], jit=False)
    print("energy terms:", model.energy_terms)

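    # Choose the evaluation function according to which outputs the loss requires:
    # energy only, energy+forces, or energy+forces+virial/stress
    # (stress and pressure training additionally require PBC).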
    pbc_training = training_parameters.get("pbc_training", False)
    if compute_stress or compute_virial or compute_pressure:
        virial_key = "virial" if "virial" in used_keys else "virial_tensor"
        stress_key = "stress" if "stress" in used_keys else "stress_tensor"
        pressure_key = "pressure" if "pressure" in used_keys else "pressure_tensor"
        if compute_stress or compute_pressure:
            assert pbc_training, "PBC must be enabled for stress or pressure training"
            print("Computing forces and stress tensor")

            def evaluate(model, variables, data):
                _, _, vir, output = model._energy_and_forces_and_virial(variables, data)
                cells = output["cells"]
                volume = jnp.abs(jnp.linalg.det(cells))
                stress = vir / volume[:, None, None]
                output[stress_key] = stress
                output[virial_key] = vir
                if pressure_key == "pressure":
                    output[pressure_key] = -jnp.trace(stress, axis1=1, axis2=2) / 3.0
                else:
                    output[pressure_key] = -stress
                return output

        else:
            print("Computing forces and virial tensor")

            def evaluate(model, variables, data):
                _, _, vir, output = model._energy_and_forces_and_virial(variables, data)
                output[virial_key] = vir
                return output

    elif compute_forces:
        print("Computing forces")

        def evaluate(model, variables, data):
            _, _, output = model._energy_and_forces(variables, data)
            return output

    elif model.energy_terms is not None:

        def evaluate(model, variables, data):
            _, output = model._total_energy(variables, data)
            return output

    else:

        def evaluate(model, variables, data):
            output = model.modules.apply(variables, data)
            return output

    train_step = get_train_step_function(
        loss_definition=loss_definition,
        model=model,
        model_ref=model_ref,
        compute_ref_coords=compute_ref_coords,
        evaluate=evaluate,
        optimizer=optimizer,
        ema=ema,
    )

    validation = get_validation_function(
        loss_definition=loss_definition,
        model=model,
        model_ref=model_ref,
        compute_ref_coords=compute_ref_coords,
        evaluate=evaluate,
        return_targets=False,
    )

    ## configure preprocessing ##
    minimum_image = training_parameters.get("minimum_image", False)
    preproc_state = unfreeze(model.preproc_state)
    layer_state = []
    for st in preproc_state["layers_state"]:
        stnew = unfreeze(st)
        #     st["nblist_skin"] = nblist_skin
        #     if nblist_stride > 1:
        #         st["skin_stride"] = nblist_stride
        #         st["skin_count"] = nblist_stride
        if pbc_training:
            stnew["minimum_image"] = minimum_image
        if "nblist_mult_size" in training_parameters:
            stnew["nblist_mult_size"] = training_parameters["nblist_mult_size"]
        if "nblist_add_neigh" in training_parameters:
            stnew["add_neigh"] = training_parameters["nblist_add_neigh"]
        if "nblist_add_atoms" in training_parameters:
            stnew["add_atoms"] = training_parameters["nblist_add_atoms"]
        layer_state.append(freeze(stnew))

    preproc_state["layers_state"] = tuple(layer_state)
    model.preproc_state = freeze(preproc_state)

    # inputs,data = next(training_iterator)
    # inputs = model.preprocess(**inputs)
    # print("preproc_state:",model.preproc_state)
    if training_parameters.get("gpu_preprocessing", False):
        print("GPU preprocessing activated.")

        def preprocessing(model, inputs):
            preproc_state = model.preproc_state
            outputs = model.preprocessing.process(preproc_state, inputs)
            preproc_state, state_up, outputs, overflow = (
                model.preprocessing.check_reallocate(preproc_state, outputs)
            )
            if overflow:
                print("GPU preprocessing: nblist overflow => reallocating nblist")
                print("size updates:", state_up)
            model.preproc_state = preproc_state
            return outputs

    else:
        preprocessing = lambda model, inputs: model.preprocess(**inputs)

    fetch_time = 0.0
    preprocess_time = 0.0
    step_time = 0.0

    maes_prev = defaultdict(lambda: np.inf)
    metrics_beta = training_parameters.get("metrics_ema_decay", -1.0)
    smoothen_metrics = 0.0 < metrics_beta < 1.0
    if smoothen_metrics:
        print("Computing smoothed metrics with beta =", metrics_beta)
        rmses_smooth = defaultdict(lambda: 0.0)
        maes_smooth = defaultdict(lambda: 0.0)
        rmse_tot_smooth = 0.0
        mae_tot_smooth = 0.0
        nsmooth = 0
    count = 0
    restore_count = 0
    max_restore_count = training_parameters.get("max_restore_count", 5)
    variables = deepcopy(model.variables)
    variables_save = deepcopy(variables)
    variables_ema_save = deepcopy(model.variables)

    fmetrics = open(
        output_directory + f"metrics{stage_prefix}.traj",
        "w" if training_state is None else "a",
    )

    keep_all_bests = training_parameters.get("keep_all_bests", False)
    previous_best_name = None
    best_metric = np.inf
    metric_use_best = training_parameters.get("metric_best", "rmse_tot")  # .lower()
    # authorized_metrics = ["mae", "rmse"]
    # if smoothen_metrics:
    #     authorized_metrics+= ["mae_smooth", "rmse_smooth"]
    # assert metric_use_best in authorized_metrics, f"metric_best must be one of {authorized_metrics}"

    if training_state is not None:
        preproc_state = training_state["preproc_state"]
        model.preproc_state = freeze(preproc_state)
        opt_st = training_state["opt_state"]
        ema_st = training_state["ema_state"]
        sch_state = training_state["sch_state"]
        variables = training_state["variables"]
        model.variables = training_state["variables_ema"]
        count = training_state["count"]
        restore_count = training_state["restore_count"]
        epoch_start = training_state["epoch"]
        rng_key = training_state["rng_key"]
        best_metric = training_state["best_metric"]
        if smoothen_metrics:
            rmses_smooth = training_state["rmses_smooth"]
            maes_smooth = training_state["maes_smooth"]
            rmse_tot_smooth = training_state["rmse_tot_smooth"]
            mae_tot_smooth = training_state["mae_tot_smooth"]
            nsmooth = training_state["nsmooth"]
        print("Restored training state")
    else:
        epoch_start = 0
        training_state = {
            "preproc_state": preproc_state,
            "rng_key": rng_key,
            "opt_state": opt_st,
            "ema_state": ema_st,
            "sch_state": sch_state,
            "variables": variables,
            "variables_ema": model.variables,
            "count": count,
            "restore_count": restore_count,
            "epoch": 0,
            "stage": stage,
            "best_metric": np.inf,
        }
        if smoothen_metrics:
            training_state["rmses_smooth"] = dict(rmses_smooth)
            training_state["maes_smooth"] = dict(maes_smooth)
            training_state["rmse_tot_smooth"] = rmse_tot_smooth
            training_state["mae_tot_smooth"] = mae_tot_smooth
            training_state["nsmooth"] = nsmooth

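    # The full training state (optimizer, EMA, scheduler, RNG and smoothed metrics)
    # is pickled to 'train_state' after every epoch, so a run can be resumed by
    # passing the output directory as the config argument of fennol_train.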
    ### Training loop ###
    start = time.time()
    print("Starting training...")
    for epoch in range(epoch_start, max_epochs):
        s = time.time()
        for _ in range(nbatch_per_epoch):
            # fetch data
            inputs0, data = next(training_iterator)

            # preprocess data
            inputs = preprocessing(model, inputs0)
            # inputs = model.preprocess(**inputs0)

            rng_key, subkey = jax.random.split(rng_key)
            inputs["rng_key"] = subkey
            # if compute_ref_coords:
            #     inputs_ref = {**inputs0, "coordinates": data[coordinates_ref_key]}
            #     inputs_ref = model.preprocess(**inputs_ref)
            # else:
            #     inputs_ref = None
            # if print_timings:
            #     jax.block_until_ready(inputs["coordinates"])

            # train step
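            # advance the schedule and inject the new learning rate directly into
            # the optimizer state (injected hyperparameters of the trainable partition)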
            current_lr, sch_state = schedule(sch_state)
            opt_st.inner_states["trainable"].inner_state[-1].hyperparams[
                "step_size"
            ] = current_lr
            loss, variables, opt_st, model.variables, ema_st, output = train_step(
                epoch=epoch,
                data=data,
                inputs=inputs,
                variables=variables,
                variables_ema=model.variables,
                opt_st=opt_st,
                ema_st=ema_st,
            )
            count += 1

        rmses_avg = defaultdict(lambda: 0.0)
        maes_avg = defaultdict(lambda: 0.0)
        for _ in range(nbatch_per_validation):
            inputs0, data = next(validation_iterator)

            inputs = preprocessing(model, inputs0)
            # inputs = model.preprocess(**inputs0)

            rng_key, subkey = jax.random.split(rng_key)
            inputs["rng_key"] = subkey

            # if compute_ref_coords:
            #     inputs_ref = {**inputs0, "coordinates": data[coordinates_ref_key]}
            #     inputs_ref = model.preprocess(**inputs_ref)
            # else:
            #     inputs_ref = None
            rmses, maes, output_val = validation(
                data=data,
                inputs=inputs,
                variables=model.variables,
                # inputs_ref=inputs_ref,
            )
            for k, v in rmses.items():
                rmses_avg[k] += v
            for k, v in maes.items():
                maes_avg[k] += v

        jax.block_until_ready(output_val)
        e = time.time()
        epoch_time = e - s

        elapsed_time = e - start
        if not adaptive_scheduler:
            remain_glob = elapsed_time * (max_epochs - epoch) / (epoch + 1)
            remain_last = epoch_time * (max_epochs - epoch)
            # estimate the remaining time as a weighted average of the global and
            # last-epoch rates (more weight on the last epoch early in training)
            wremain = np.sin(0.5 * np.pi * (epoch + 1) / max_epochs)
            remaining_time = human_time_duration(
                remain_glob * wremain + remain_last * (1 - wremain)
            )
        elapsed_time = human_time_duration(elapsed_time)
        batch_time = human_time_duration(
            epoch_time / (nbatch_per_epoch + nbatch_per_validation)
        )
        epoch_time = human_time_duration(epoch_time)

        for k in rmses_avg.keys():
            rmses_avg[k] /= nbatch_per_validation
        for k in maes_avg.keys():
            maes_avg[k] /= nbatch_per_validation

        step_time /= nbatch_per_epoch
        fetch_time /= nbatch_per_epoch
        preprocess_time /= nbatch_per_epoch

        print("")
        line = f"Epoch {epoch+1}, lr={current_lr:.3e}, loss = {loss:.3e}"
        line += f", epoch time = {epoch_time}, batch time = {batch_time}"
        line += f", elapsed time = {elapsed_time}"
        if not adaptive_scheduler:
            line += f", est. remaining time = {remaining_time}"
        print(line)
        rmse_tot = 0.0
        mae_tot = 0.0
        for k in rmses_avg.keys():
            mult = loss_definition[k]["mult"]
            loss_prms = loss_definition[k]
            rmse_tot = rmse_tot + rmses_avg[k] * loss_prms["weight"] ** 0.5
            mae_tot = mae_tot + maes_avg[k] * loss_prms["weight"]
            unit = "(" + loss_prms["unit"] + ")" if "unit" in loss_prms else ""

            weight_str = ""
            if "weight_schedule" in loss_prms:
                w = linear_schedule(epoch, *loss_prms["weight_schedule"])
                weight_str = f"(w={w:.3f})"

            if rmses_avg[k] / mult < 1.0e-2:
                print(
                    f"    rmse_{k}= {rmses_avg[k]/mult:10.3e} ; mae_{k}= {maes_avg[k]/mult:10.3e}   {unit}  {weight_str}"
                )
            else:
                print(
                    f"    rmse_{k}= {rmses_avg[k]/mult:10.3f} ; mae_{k}= {maes_avg[k]/mult:10.3f}   {unit}  {weight_str}"
                )

        # if print_timings:
        #     print(
        #         f"    Timings per batch: fetch time = {fetch_time:.5f}; preprocess time = {preprocess_time:.5f}; train time = {step_time:.5f}"
        #     )
        fetch_time = 0.0
        step_time = 0.0
        preprocess_time = 0.0
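        # Divergence control: if a NaN appears in the validation MAEs, or a
        # monitored MAE exceeds its 'threshold' times the previous epoch's value,
        # restore the last saved parameters (and reinitialize the optimizer and
        # EMA states after a NaN).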
        restore = False
        reinit = False
        for k, mae in maes_avg.items():
            if np.isnan(mae):
                restore = True
                reinit = True
                print("NaN detected in mae")
                # for k, v in inputs.items():
                #     if hasattr(v,"shape"):
                #         if np.isnan(v).any():
                #             print(k,v)
                # sys.exit(1)
            if "threshold" in loss_definition[k]:
                thr = loss_definition[k]["threshold"]
                if mae > thr * maes_prev[k]:
                    restore = True
                    break

        if restore:
            restore_count += 1
            if restore_count > max_restore_count:
                if reinit:
                    raise ValueError("Model diverged and could not be restored.")
                else:
                    restore_count = 0
                    print(
                        f"{max_restore_count} unsuccessful restores, resuming training."
                    )
                    continue

            variables = deepcopy(variables_save)
            model.variables = deepcopy(variables_ema_save)
            print("Restored previous model after divergence.")
            if reinit:
                opt_st = optimizer.init(model.variables)
                ema_st = ema.init(model.variables)
                print("Reinitialized optimizer after divergence.")
            continue

        restore_count = 0
        variables_save = deepcopy(variables)
        variables_ema_save = deepcopy(model.variables)
        maes_prev = maes_avg

        model.save(output_directory + "latest_model.fnx")

        # save metrics
        metrics = {
            "epoch": epoch + 1,
            "step": count,
            "data_count": count * batch_size,
            "elapsed_time": time.time() - start,
            "lr": current_lr,
            "loss": loss,
        }
        for k in rmses_avg.keys():
            mult = loss_definition[k]["mult"]
            metrics[f"rmse_{k}"] = rmses_avg[k] / mult
            metrics[f"mae_{k}"] = maes_avg[k] / mult

        metrics["rmse_tot"] = rmse_tot
        metrics["mae_tot"] = mae_tot
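        # exponential moving average of the metrics, debiased by (1 - beta**nsmooth)
        # in the same spirit as Adam's bias correction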
        if smoothen_metrics:
            nsmooth += 1
            for k in rmses_avg.keys():
                mult = loss_definition[k]["mult"]
                rmses_smooth[k] = (
                    metrics_beta * rmses_smooth[k] + (1.0 - metrics_beta) * rmses_avg[k]
                )
                maes_smooth[k] = (
                    metrics_beta * maes_smooth[k] + (1.0 - metrics_beta) * maes_avg[k]
                )
                metrics[f"rmse_smooth_{k}"] = (
                    rmses_smooth[k] / (1.0 - metrics_beta**nsmooth) / mult
                )
                metrics[f"mae_smooth_{k}"] = (
                    maes_smooth[k] / (1.0 - metrics_beta**nsmooth) / mult
                )
            rmse_tot_smooth = (
                metrics_beta * rmse_tot_smooth + (1.0 - metrics_beta) * rmse_tot
            )
            mae_tot_smooth = (
                metrics_beta * mae_tot_smooth + (1.0 - metrics_beta) * mae_tot
            )
            metrics["rmse_smooth_tot"] = rmse_tot_smooth / (1.0 - metrics_beta**nsmooth)
            metrics["mae_smooth_tot"] = mae_tot_smooth / (1.0 - metrics_beta**nsmooth)

        assert (
            metric_use_best in metrics
        ), f"Error: metric for selecting best model '{metric_use_best}' not in metrics"
        metric_for_best = metrics[metric_use_best]
        if metric_for_best < best_metric:
            best_metric = metric_for_best
            metrics["best_metric"] = best_metric
            if keep_all_bests:
                best_name = (
                    output_directory
                    + f"best_model{stage_prefix}_{time.strftime('%Y-%m-%d-%H-%M-%S')}.fnx"
                )
                model.save(best_name)

            best_name = output_directory + f"best_model{stage_prefix}.fnx"
            model.save(best_name)
            print("New best model saved to:", best_name)
        else:
            metrics["best_metric"] = best_metric

        if epoch == 0:
            headers = [f"{i+1}:{k}" for i, k in enumerate(metrics.keys())]
            fmetrics.write("# " + " ".join(headers) + "\n")
        fmetrics.write(" ".join([str(metrics[k]) for k in metrics.keys()]) + "\n")
        fmetrics.flush()

        # update learning rate using current metrics
        assert (
            schedule_metrics in metrics
        ), f"Error: cannot update lr, '{schedule_metrics}' not in metrics"
        current_lr, sch_state = schedule(sch_state, metrics[schedule_metrics])

        # update and save training state
        training_state["preproc_state"] = model.preproc_state
        training_state["opt_state"] = opt_st
        training_state["ema_state"] = ema_st
        training_state["sch_state"] = sch_state
        training_state["variables"] = variables
        training_state["variables_ema"] = model.variables
        training_state["count"] = count
        training_state["restore_count"] = restore_count
        training_state["epoch"] = epoch
        training_state["best_metric"] = best_metric
        if smoothen_metrics:
            training_state["rmses_smooth"] = dict(rmses_smooth)
            training_state["maes_smooth"] = dict(maes_smooth)
            training_state["rmse_tot_smooth"] = rmse_tot_smooth
            training_state["mae_tot_smooth"] = mae_tot_smooth
            training_state["nsmooth"] = nsmooth

        with open(output_directory + "train_state", "wb") as f:
            pickle.dump(training_state, f)

        if is_end(metrics):
            print("Stage finished.")
            break

    end = time.time()

    print(f"Training time: {human_time_duration(end-start)}")
    print("")

    fmetrics.close()

    filename = output_directory + f"final_model{stage_prefix}.fnx"
    model.save(filename)
    print("Final model saved to:", filename)

    return metrics, filename


if __name__ == "__main__":
    main()
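
# Example usage (a minimal sketch, not part of the module): the script is driven
# by a YAML configuration whose keys match the parameters.get(...) calls above.
# The dataset path and file names below are hypothetical.
#
#   device: cuda:0
#   output_directory: runs/model_{now}
#   rng_seed: 42
#   training:
#     dspath: dataset.pkl          # hypothetical dataset file
#     batch_size: 16
#     lr: 1.e-3
#     max_epochs: 2000
#     schedule_type: cosine_onecycle
#
# Run with `fennol_train config.yaml`; passing an existing output directory
# instead of a config file resumes from its 'train_state' checkpoint.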
def main():
 43def main():
 44    parser = argparse.ArgumentParser(prog="fennol_train")
 45    parser.add_argument("config_file", type=str)
 46    parser.add_argument("--model_file", type=str, default=None)
 47    args = parser.parse_args()
 48    config_file = args.config_file
 49    model_file = args.model_file
 50
 51    os.environ["OMP_NUM_THREADS"] = "1"
 52    sys.stdout = io.TextIOWrapper(
 53        open(sys.stdout.fileno(), "wb", 0), write_through=True
 54    )
 55
 56    restart_training = False
 57    if os.path.isdir(config_file):
 58        output_directory = Path(config_file).absolute().as_posix()
 59        config_file = output_directory + "/config.yaml"
 60        restart_training = True
 61        training_state_file = output_directory + "/train_state"
 62        if not os.path.exists(training_state_file):
 63            raise FileNotFoundError(
 64                f"Training state file not found: {training_state_file}"
 65            )
 66        while output_directory.endswith("/"):
 67            output_directory = output_directory[:-1]
 68        backup_dir = output_directory + f"_backup_{time.strftime('%Y-%m-%d-%H-%M-%S')}"
 69        shutil.copytree(output_directory, backup_dir)
 70
 71        with open(training_state_file, "rb") as f:
 72            training_state = pickle.load(f)
 73        print("Restarting training from", output_directory)
 74    else:
 75        training_state = None
 76
 77    parameters = load_configuration(config_file)
 78
 79    ### Set the device
 80    device: str = parameters.get("device", "cpu").lower()
 81    if device == "cpu":
 82        os.environ["CUDA_VISIBLE_DEVICES"] = ""
 83    elif device.startswith("cuda") or device.startswith("gpu"):
 84        if ":" in device:
 85            num = device.split(":")[-1]
 86            os.environ["CUDA_VISIBLE_DEVICES"] = num
 87        else:
 88            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 89        device = "gpu"
 90
 91    _device = jax.devices(device)[0]
 92    jax.config.update("jax_default_device", _device)
 93
 94    # output directory
 95    if not restart_training:
 96        output_directory = parameters.get("output_directory", None)
 97        if output_directory is not None:
 98            if "{now}" in output_directory:
 99                output_directory = output_directory.replace(
100                    "{now}", time.strftime("%Y-%m-%d-%H-%M-%S")
101                )
102            output_directory = Path(output_directory).absolute()
103            if not output_directory.exists():
104                output_directory.mkdir(parents=True)
105            print("Output directory:", output_directory)
106        else:
107            output_directory = "."
108
109    output_directory = str(output_directory) + "/"
110
111    # copy config_file to output directory
112    # config_name = Path(config_file).name
113    config_ext = Path(config_file).suffix
114    with open(config_file) as f_in:
115        config_data = f_in.read()
116    with open(output_directory + "/config" + config_ext, "w") as f_out:
117        f_out.write(config_data)
118
119    # set log file
120    log_file = "train.log"  # parameters.get("log_file", None)
121    logger = TeeLogger(output_directory + log_file)
122    logger.bind_stdout()
123
124    # set matmul precision
125    enable_x64 = parameters.get("double_precision", False)
126    jax.config.update("jax_enable_x64", enable_x64)
127    fprec = "float64" if enable_x64 else "float32"
128    parameters["fprec"] = fprec
129    if enable_x64:
130        print("Double precision enabled.")
131
132    matmul_precision = parameters.get("matmul_prec", "highest").lower()
133    assert matmul_precision in [
134        "default",
135        "high",
136        "highest",
137    ], "matmul_prec must be one of 'default','high','highest'"
138    jax.config.update("jax_default_matmul_precision", matmul_precision)
139
140    # set random seed
141    rng_seed = parameters.get("rng_seed", np.random.randint(0, 2**32 - 1))
142    print(f"rng_seed: {rng_seed}")
143    rng_key = jax.random.PRNGKey(rng_seed)
144    torch.manual_seed(rng_seed)
145    np.random.seed(rng_seed)
146    random.seed(rng_seed)
147    np_rng = np.random.Generator(np.random.PCG64(rng_seed))
148
149    try:
150        if "stages" in parameters["training"]:
151            ## train in stages ##
152            params = deepcopy(parameters)
153            stages = params["training"].pop("stages")
154            assert isinstance(stages, dict), "'stages' must be a dict with named stages"
155            model_file_stage = model_file
156            print_stages_params = params["training"].get("print_stages_params", False)
157            for i, (stage, stage_params) in enumerate(stages.items()):
158                rng_key, subkey = jax.random.split(rng_key)
159                print("")
160                print(f"### STAGE {i+1}: {stage} ###")
161
162                ## remove end_event from previous stage ##
163                if i > 0 and "end_event" in params["training"]:
164                    params["training"].pop("end_event")
165
166                ## incrementally update training parameters ##
167                params = deep_update(params, {"training": stage_params})
168                if model_file_stage is not None:
169                    ## load model from previous stage ##
170                    params["model_file"] = model_file_stage
171
172                if restart_training and training_state["stage"] != i + 1:
173                    print(f"Skipping stage {i+1} (already completed)")
174                    continue
175
176                if print_stages_params:
177                    print("stage parameters:")
178                    print(json.dumps(params, indent=2, sort_keys=False))
179
180                ## train stage ##
181                _, model_file_stage = train(
182                    (subkey, np_rng),
183                    params,
184                    stage=i + 1,
185                    output_directory=output_directory,
186                    training_state=training_state,
187                )
188                training_state = None
189                restart_training = False
190        else:
191            ## single training stage ##
192            train(
193                (rng_key, np_rng),
194                parameters,
195                model_file=model_file,
196                output_directory=output_directory,
197                training_state=training_state,
198            )
199    except KeyboardInterrupt:
200        print("Training interrupted by user.")
201    finally:
202        if log_file is not None:
203            logger.unbind_stdout()
204            logger.close()
def train( rng, parameters, model_file=None, stage=None, output_directory=None, training_state=None):
207def train(
208    rng,
209    parameters,
210    model_file=None,
211    stage=None,
212    output_directory=None,
213    training_state=None,
214):
215    if output_directory is None:
216        output_directory = "./"
217    elif not output_directory.endswith("/"):
218        output_directory += "/"
219    stage_prefix = f"_stage_{stage}" if stage is not None else ""
220
221    if isinstance(rng, tuple):
222        rng_key, np_rng = rng
223    else:
224        rng_key = rng
225        np_rng = np.random.Generator(np.random.PCG64(np.random.randint(0, 2**32 - 1)))
226
227    if training_state is not None:
228        model_key = None
229        model_file = output_directory + "latest_model.fnx"
230    else:
231        rng_key, model_key = jax.random.split(rng_key)
232    model = load_model(parameters, model_file, rng_key=model_key)
233
234    training_parameters = parameters.get("training", {})
235    model_ref = None
236    if "model_ref" in training_parameters:
237        model_ref = load_model(parameters, training_parameters["model_ref"])
238        print("Reference model:", training_parameters["model_ref"])
239
240        if "ref_parameters" in training_parameters:
241            ref_parameters = training_parameters["ref_parameters"]
242            assert isinstance(
243                ref_parameters, list
244            ), "ref_parameters must be a list of str"
245            print("Reference parameters:", ref_parameters)
246            model.variables = copy_parameters(
247                model.variables, model_ref.variables, ref_parameters
248            )
249    fprec = parameters.get("fprec", "float32")
250
251    def convert_to_fprec(x):
252        if jnp.issubdtype(x.dtype, jnp.floating):
253            return x.astype(fprec)
254        return x
255
256    model.variables = jax.tree_map(convert_to_fprec, model.variables)
257
258    loss_definition, used_keys, ref_keys = get_loss_definition(
259        training_parameters, model_energy_unit=model.energy_unit
260    )
261
262    coordinates_ref_key = training_parameters.get("coordinates_ref_key", None)
263    if coordinates_ref_key is not None:
264        compute_ref_coords = True
265        print("Reference coordinates:", coordinates_ref_key)
266    else:
267        compute_ref_coords = False
268
269    dspath = training_parameters.get("dspath", None)
270    if dspath is None:
271        raise ValueError("Dataset path 'training/dspath' should be specified.")
272    batch_size = training_parameters.get("batch_size", 16)
273    rename_refs = training_parameters.get("rename_refs", {})
274    training_iterator, validation_iterator = load_dataset(
275        dspath=dspath,
276        batch_size=batch_size,
277        training_parameters=training_parameters,
278        infinite_iterator=True,
279        atom_padding=True,
280        ref_keys=ref_keys,
281        split_data_inputs=True,
282        np_rng=np_rng,
283        add_flags=["training"],
284        fprec=fprec,
285        rename_refs=rename_refs,
286    )
287
288    compute_forces = "forces" in used_keys
289    compute_virial = "virial_tensor" in used_keys or "virial" in used_keys
290    compute_stress = "stress_tensor" in used_keys or "stress" in used_keys
291    compute_pressure = "pressure" in used_keys or "pressure_tensor" in used_keys
292
293    # get optimizer parameters
294    lr = training_parameters.get("lr", 1.0e-3)
295    max_epochs = training_parameters.get("max_epochs", 2000)
296    nbatch_per_epoch = training_parameters.get("nbatch_per_epoch", 200)
297    nbatch_per_validation = training_parameters.get("nbatch_per_validation", 20)
298    init_lr = training_parameters.get("init_lr", lr / 25)
299    final_lr = training_parameters.get("final_lr", lr / 10000)
300
301    schedule_type = training_parameters.get("schedule_type", "cosine_onecycle").lower()
302    schedule_type = training_parameters.get("scheduler", schedule_type).lower()
303    schedule_metrics = training_parameters.get("schedule_metrics", "rmse_tot")
304
305    adaptive_scheduler = False
306    print("Schedule type:", schedule_type)
307    if schedule_type == "cosine_onecycle":
308        transition_epochs = training_parameters.get("onecycle_epochs", max_epochs)
309        peak_epoch = training_parameters.get("peak_epoch", 0.3 * transition_epochs)
310        schedule_ = optax.cosine_onecycle_schedule(
311            peak_value=lr,
312            div_factor=lr / init_lr,
313            final_div_factor=init_lr / final_lr,
314            transition_steps=transition_epochs * nbatch_per_epoch,
315            pct_start=peak_epoch / transition_epochs,
316        )
317        sch_state = {"count": 0, "best": np.inf, "lr": init_lr}
318
319        def schedule(state, rmse=None):
320            new_state = {**state}
321            lr = schedule_(state["count"])
322            if rmse is None:
323                new_state["count"] += 1
324            new_state["lr"] = lr
325            return lr, new_state
326
327    elif schedule_type == "constant":
328        sch_state = {"count": 0}
329
330        def schedule(state, rmse=None):
331            new_state = {**state}
332            new_state["lr"] = lr
333            if rmse is None:
334                new_state["count"] += 1
335            return lr, new_state
336
337    elif schedule_type == "reduce_on_plateau":
338        patience = training_parameters.get("patience", 10)
339        factor = training_parameters.get("lr_factor", 0.5)
340        patience_thr = training_parameters.get("patience_thr", 0.0)
341        sch_state = {"count": 0, "best": np.inf, "lr": lr, "patience": patience}
342        adaptive_scheduler = True
343
344        def schedule(state, rmse=None):
345            new_state = {**state}
346            if rmse is None:
347                new_state["count"] += 1
348                return state["lr"], new_state
349            if rmse <= state["best"] * (1.0 + patience_thr):
350                if rmse < state["best"]:
351                    new_state["best"] = rmse
352                new_state["patience"] = 0
353            else:
354                new_state["patience"] += 1
355                if new_state["patience"] >= patience:
356                    new_state["lr"] = state["lr"] * factor
357                    new_state["patience"] = 0
358                    print("Reducing learning rate to", new_state["lr"])
359            return new_state["lr"], new_state
360
361    else:
362        raise ValueError(f"Unknown schedule_type: {schedule_type}")
363
364    stochastic_scheduler = training_parameters.get("stochastic_scheduler", False)
365    if stochastic_scheduler:
366        schedule_ = schedule
367        rng_key, scheduler_key = jax.random.split(rng_key)
368        sch_state["rng_key"] = scheduler_key
369        sch_state["lr_max"] = lr
370        sch_state["lr_min"] = final_lr
371
372        def schedule(state, rmse=None):
373            new_state = {**state, "lr": state["lr_max"]}
374            if rmse is None:
375                lr_max, new_state = schedule_(new_state, rmse=rmse)
376                lr_min = new_state["lr_min"]
377                new_state["rng_key"], subkey = jax.random.split(new_state["rng_key"])
378                lr = lr_min + (lr_max - lr_min) * jax.random.uniform(subkey)
379                new_state["lr"] = lr
380                new_state["lr_max"] = lr_max
381
382            return new_state["lr"], new_state
383
384    optimizer = get_optimizer(
385        training_parameters, model.variables, schedule(sch_state)[0]
386    )
387    opt_st = optimizer.init(model.variables)
388
389    # exponential moving average of the parameters
390    ema_decay = training_parameters.get("ema_decay", -1.0)
391    if ema_decay > 0.0:
392        assert ema_decay < 1.0, "ema_decay must be in (0,1)"
393        ema = optax.ema(decay=ema_decay)
394    else:
395        ema = optax.identity()
396    ema_st = ema.init(model.variables)
397
398    # end event
399    end_event = training_parameters.get("end_event", None)
400    if end_event is None or isinstance(end_event, str) and end_event.lower() == "none":
401        is_end = lambda metrics: False
402    else:
403        assert len(end_event) == 2, "end_event must be a list of two elements"
404        is_end = lambda metrics: metrics[end_event[0]] < end_event[1]
405
406    print_timings = parameters.get("print_timings", False)
407
408    if "energy_terms" in training_parameters:
409        model.set_energy_terms(training_parameters["energy_terms"], jit=False)
410    print("energy terms:", model.energy_terms)
411
412    pbc_training = training_parameters.get("pbc_training", False)
413    if compute_stress or compute_virial or compute_pressure:
414        virial_key = "virial" if "virial" in used_keys else "virial_tensor"
415        stress_key = "stress" if "stress" in used_keys else "stress_tensor"
416        pressure_key = "pressure" if "pressure" in used_keys else "pressure_tensor"
417        if compute_stress or compute_pressure:
418            assert pbc_training, "PBC must be enabled for stress or virial training"
419            print("Computing forces and stress tensor")
420
421            def evaluate(model, variables, data):
422                _, _, vir, output = model._energy_and_forces_and_virial(variables, data)
423                cells = output["cells"]
424                volume = jnp.abs(jnp.linalg.det(cells))
425                stress = vir / volume[:, None, None]
426                output[stress_key] = stress
427                output[virial_key] = vir
428                if pressure_key == "pressure":
429                    output[pressure_key] = -jnp.trace(stress, axis1=1, axis2=2) / 3.0
430                else:
431                    output[pressure_key] = -stress
432                return output
433
434        else:
435            print("Computing forces and virial tensor")
436
437            def evaluate(model, variables, data):
438                _, _, vir, output = model._energy_and_forces_and_virial(variables, data)
439                output[virial_key] = vir
440                return output
441
442    elif compute_forces:
443        print("Computing forces")
444
445        def evaluate(model, variables, data):
446            _, _, output = model._energy_and_forces(variables, data)
447            return output
448
449    elif model.energy_terms is not None:
450
451        def evaluate(model, variables, data):
452            _, output = model._total_energy(variables, data)
453            return output
454
455    else:
456
457        def evaluate(model, variables, data):
458            output = model.modules.apply(variables, data)
459            return output
460
461    train_step = get_train_step_function(
462        loss_definition=loss_definition,
463        model=model,
464        model_ref=model_ref,
465        compute_ref_coords=compute_ref_coords,
466        evaluate=evaluate,
467        optimizer=optimizer,
468        ema=ema,
469    )
470
471    validation = get_validation_function(
472        loss_definition=loss_definition,
473        model=model,
474        model_ref=model_ref,
475        compute_ref_coords=compute_ref_coords,
476        evaluate=evaluate,
477        return_targets=False,
478    )
479
480    ## configure preprocessing ##
481    minimum_image = training_parameters.get("minimum_image", False)
482    preproc_state = unfreeze(model.preproc_state)
483    layer_state = []
484    for st in preproc_state["layers_state"]:
485        stnew = unfreeze(st)
486        #     st["nblist_skin"] = nblist_skin
487        #     if nblist_stride > 1:
488        #         st["skin_stride"] = nblist_stride
489        #         st["skin_count"] = nblist_stride
490        if pbc_training:
491            stnew["minimum_image"] = minimum_image
492        if "nblist_mult_size" in training_parameters:
493            stnew["nblist_mult_size"] = training_parameters["nblist_mult_size"]
494        if "nblist_add_neigh" in training_parameters:
495            stnew["add_neigh"] = training_parameters["nblist_add_neigh"]
496        if "nblist_add_atoms" in training_parameters:
497            stnew["add_atoms"] = training_parameters["nblist_add_atoms"]
498        layer_state.append(freeze(stnew))
499
500    preproc_state["layers_state"] = tuple(layer_state)
501    model.preproc_state = freeze(preproc_state)
502
503    # inputs,data = next(training_iterator)
504    # inputs = model.preprocess(**inputs)
505    # print("preproc_state:",model.preproc_state)
506
507    if training_parameters.get("gpu_preprocessing", False):
508        print("GPU preprocessing activated.")
509
510        def preprocessing(model, inputs):
511            preproc_state = model.preproc_state
512            outputs = model.preprocessing.process(preproc_state, inputs)
513            preproc_state, state_up, outputs, overflow = (
514                model.preprocessing.check_reallocate(preproc_state, outputs)
515            )
516            if overflow:
517                print("GPU preprocessing: nblist overflow => reallocating nblist")
518                print("size updates:", state_up)
519            model.preproc_state = preproc_state
520            return outputs
521
522    else:
523        preprocessing = lambda model, inputs: model.preprocess(**inputs)
524
525    fetch_time = 0.0
526    preprocess_time = 0.0
527    step_time = 0.0
528
529    maes_prev = defaultdict(lambda: np.inf)
530    metrics_beta = training_parameters.get("metrics_ema_decay", -1.0)
531    smoothen_metrics = metrics_beta < 1.0 and metrics_beta > 0.0
532    if smoothen_metrics:
533        print("Computing smoothed metrics with beta =", metrics_beta)
534        rmses_smooth = defaultdict(lambda: 0.0)
535        maes_smooth = defaultdict(lambda: 0.0)
536        rmse_tot_smooth = 0.0
537        mae_tot_smooth = 0.0
538        nsmooth = 0
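           # Smoothed metrics follow an EMA with bias correction (as in Adam):
           #   s_n = beta * s_{n-1} + (1 - beta) * x_n, reported as
           #   s_n / (1 - beta**n), so early epochs are not biased toward the
           #   zero initialization.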
539    count = 0
540    restore_count = 0
541    max_restore_count = training_parameters.get("max_restore_count", 5)
542    variables = deepcopy(model.variables)
543    variables_save = deepcopy(variables)
544    variables_ema_save = deepcopy(model.variables)
545
546    fmetrics = open(
547        output_directory + f"metrics{stage_prefix}.traj",
548        "w" if training_state is None else "a",
549    )
550
551    keep_all_bests = training_parameters.get("keep_all_bests", False)
552    previous_best_name = None
553    best_metric = np.inf
554    metric_use_best = training_parameters.get("metric_best", "rmse_tot")
559
560    if training_state is not None:
561        preproc_state = training_state["preproc_state"]
562        model.preproc_state = freeze(preproc_state)
563        opt_st = training_state["opt_state"]
564        ema_st = training_state["ema_state"]
565        sch_state = training_state["sch_state"]
566        variables = training_state["variables"]
567        model.variables = training_state["variables_ema"]
568        count = training_state["count"]
569        restore_count = training_state["restore_count"]
570        epoch_start = training_state["epoch"]
571        rng_key = training_state["rng_key"]
572        best_metric = training_state["best_metric"]
573        if smoothen_metrics:
574            rmses_smooth = training_state["rmses_smooth"]
575            maes_smooth = training_state["maes_smooth"]
576            rmse_tot_smooth = training_state["rmse_tot_smooth"]
577            mae_tot_smooth = training_state["mae_tot_smooth"]
578            nsmooth = training_state["nsmooth"]
579        print("Restored training state")
580    else:
581        epoch_start = 0
582        training_state = {
583            "preproc_state": preproc_state,
584            "rng_key": rng_key,
585            "opt_state": opt_st,
586            "ema_state": ema_st,
587            "sch_state": sch_state,
588            "variables": variables,
589            "variables_ema": model.variables,
590            "count": count,
591            "restore_count": restore_count,
592            "epoch": 0,
593            "stage": stage,
594            "best_metric": np.inf,
595        }
596        if smoothen_metrics:
597            training_state["rmses_smooth"] = dict(rmses_smooth)
598            training_state["maes_smooth"] = dict(maes_smooth)
599            training_state["rmse_tot_smooth"] = rmse_tot_smooth
600            training_state["mae_tot_smooth"] = mae_tot_smooth
601            training_state["nsmooth"] = nsmooth
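       # `training_state` is pickled at the end of every epoch (see the main
       # loop below), so an interrupted run can be resumed from its output
       # directory.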
602
603    ### Training loop ###
604    start = time.time()
605    print("Starting training...")
606    for epoch in range(epoch_start, max_epochs):
607        s = time.time()
608        for _ in range(nbatch_per_epoch):
609            # fetch data
610            inputs0, data = next(training_iterator)
611
612            # preprocess data
613            inputs = preprocessing(model, inputs0)
615
616            rng_key, subkey = jax.random.split(rng_key)
617            inputs["rng_key"] = subkey
625
626            # train step
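               # Update the learning rate in place: `schedule` yields the new
               # lr and its updated state, and the lr is written into the
               # optimizer state. This layout assumes an optax
               # multi_transform-style optimizer whose "trainable" branch ends
               # in an inject_hyperparams-wrapped transform exposing a mutable
               # `hyperparams["step_size"]` (see get_optimizer).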
630            current_lr, sch_state = schedule(sch_state)
631            opt_st.inner_states["trainable"].inner_state[-1].hyperparams[
632                "step_size"
633            ] = current_lr
634            loss, variables, opt_st, model.variables, ema_st, output = train_step(
635                epoch=epoch,
636                data=data,
637                inputs=inputs,
638                variables=variables,
639                variables_ema=model.variables,
640                opt_st=opt_st,
641                ema_st=ema_st,
642            )
643            count += 1
644
645        rmses_avg = defaultdict(lambda: 0.0)
646        maes_avg = defaultdict(lambda: 0.0)
647        for _ in range(nbatch_per_validation):
648            inputs0, data = next(validation_iterator)
649
650            inputs = preprocessing(model, inputs0)
652
653            rng_key, subkey = jax.random.split(rng_key)
654            inputs["rng_key"] = subkey
655
661            rmses, maes, output_val = validation(
662                data=data,
663                inputs=inputs,
664                variables=model.variables,
666            )
667            for k, v in rmses.items():
668                rmses_avg[k] += v
669            for k, v in maes.items():
670                maes_avg[k] += v
671
672        jax.block_until_ready(output_val)
673        e = time.time()
674        epoch_time = e - s
675
676        elapsed_time = e - start
677        if not adaptive_scheduler:
678            remain_glob = elapsed_time * (max_epochs - epoch - 1) / (epoch + 1)
679            remain_last = epoch_time * (max_epochs - epoch - 1)
680            # estimate remaining time via weighted average (put more weight on last at the beginning)
681            wremain = np.sin(0.5 * np.pi * (epoch + 1) / max_epochs)
682            remaining_time = human_time_duration(
683                remain_glob * wremain + remain_last * (1 - wremain)
684            )
685        elapsed_time = human_time_duration(elapsed_time)
686        batch_time = human_time_duration(
687            epoch_time / (nbatch_per_epoch + nbatch_per_validation)
688        )
689        epoch_time = human_time_duration(epoch_time)
690
691        for k in rmses_avg.keys():
692            rmses_avg[k] /= nbatch_per_validation
693        for k in maes_avg.keys():
694            maes_avg[k] /= nbatch_per_validation
695
696        step_time /= nbatch_per_epoch
697        fetch_time /= nbatch_per_epoch
698        preprocess_time /= nbatch_per_epoch
699
700        print("")
701        line = f"Epoch {epoch+1}, lr={current_lr:.3e}, loss = {loss:.3e}"
702        line += f", epoch time = {epoch_time}, batch time = {batch_time}"
703        line += f", elapsed time = {elapsed_time}"
704        if not adaptive_scheduler:
705            line += f", est. remaining time = {remaining_time}"
706        print(line)
707        rmse_tot = 0.0
708        mae_tot = 0.0
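           # Combine per-target metrics into weighted totals: RMSEs with
           # sqrt(weight) (consistent with a squared-error loss), MAEs with the
           # weight itself; `mult` converts values to the display unit.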
709        for k in rmses_avg.keys():
710            loss_prms = loss_definition[k]
711            mult = loss_prms["mult"]
712            rmse_tot = rmse_tot + rmses_avg[k] * loss_prms["weight"] ** 0.5
713            mae_tot = mae_tot + maes_avg[k] * loss_prms["weight"]
714            unit = "(" + loss_prms["unit"] + ")" if "unit" in loss_prms else ""
715
716            weight_str = ""
717            if "weight_schedule" in loss_prms:
718                w = linear_schedule(epoch, *loss_prms["weight_schedule"])
719                weight_str = f"(w={w:.3f})"
720
721            if rmses_avg[k] / mult < 1.0e-2:
722                print(
723                    f"    rmse_{k}= {rmses_avg[k]/mult:10.3e} ; mae_{k}= {maes_avg[k]/mult:10.3e}   {unit}  {weight_str}"
724                )
725            else:
726                print(
727                    f"    rmse_{k}= {rmses_avg[k]/mult:10.3f} ; mae_{k}= {maes_avg[k]/mult:10.3f}   {unit}  {weight_str}"
728                )
729
730        # reset per-epoch timing accumulators
734        fetch_time = 0.0
735        step_time = 0.0
736        preprocess_time = 0.0
737        restore = False
738        reinit = False
739        for k, mae in maes_avg.items():
740            if np.isnan(mae):
741                restore = True
742                reinit = True
743                print("NaN detected in mae")
749            if "threshold" in loss_definition[k]:
750                thr = loss_definition[k]["threshold"]
751                if mae > thr * maes_prev[k]:
752                    restore = True
753                    break
754
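           # Divergence control: a NaN MAE, or a MAE that grew past `threshold`
           # times last epoch's value, rolls the parameters back to the last
           # good copy; NaNs additionally reinitialize the optimizer and EMA
           # states. After `max_restore_count` consecutive rollbacks, training
           # aborts in the NaN case and otherwise resumes without rolling back.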
755        if restore:
756            restore_count += 1
757            if restore_count > max_restore_count:
758                if reinit:
759                    raise ValueError("Model diverged and could not be restored.")
760                else:
761                    restore_count = 0
762                    print(
763                        f"{max_restore_count} unsuccessful restores, resuming training."
764                    )
765                    continue
766
767            variables = deepcopy(variables_save)
768            model.variables = deepcopy(variables_ema_save)
769            print("Restored previous model after divergence.")
770            if reinit:
771                opt_st = optimizer.init(model.variables)
772                ema_st = ema.init(model.variables)
773                print("Reinitialized optimizer after divergence.")
774            continue
775
776        restore_count = 0
777        variables_save = deepcopy(variables)
778        variables_ema_save = deepcopy(model.variables)
779        maes_prev = maes_avg
780
781        model.save(output_directory + "latest_model.fnx")
782
783        # save metrics
784        metrics = {
785            "epoch": epoch + 1,
786            "step": count,
787            "data_count": count * batch_size,
788            "elapsed_time": time.time() - start,
789            "lr": current_lr,
790            "loss": loss,
791        }
792        for k in rmses_avg.keys():
793            mult = loss_definition[k]["mult"]
794            metrics[f"rmse_{k}"] = rmses_avg[k] / mult
795            metrics[f"mae_{k}"] = maes_avg[k] / mult
796
797        metrics["rmse_tot"] = rmse_tot
798        metrics["mae_tot"] = mae_tot
799        if smoothen_metrics:
800            nsmooth += 1
801            for k in rmses_avg.keys():
802                mult = loss_definition[k]["mult"]
803                rmses_smooth[k] = (
804                    metrics_beta * rmses_smooth[k] + (1.0 - metrics_beta) * rmses_avg[k]
805                )
806                maes_smooth[k] = (
807                    metrics_beta * maes_smooth[k] + (1.0 - metrics_beta) * maes_avg[k]
808                )
809                metrics[f"rmse_smooth_{k}"] = (
810                    rmses_smooth[k] / (1.0 - metrics_beta**nsmooth) / mult
811                )
812                metrics[f"mae_smooth_{k}"] = (
813                    maes_smooth[k] / (1.0 - metrics_beta**nsmooth) / mult
814                )
815            rmse_tot_smooth = (
816                metrics_beta * rmse_tot_smooth + (1.0 - metrics_beta) * rmse_tot
817            )
818            mae_tot_smooth = (
819                metrics_beta * mae_tot_smooth + (1.0 - metrics_beta) * mae_tot
820            )
821            metrics["rmse_smooth_tot"] = rmse_tot_smooth / (1.0 - metrics_beta**nsmooth)
822            metrics["mae_smooth_tot"] = mae_tot_smooth / (1.0 - metrics_beta**nsmooth)
823
824        assert (
825            metric_use_best in metrics
826        ), f"Error: metric for selecting best model '{metric_use_best}' not in metrics"
827        metric_for_best = metrics[metric_use_best]
828        if metric_for_best < best_metric:
829            best_metric = metric_for_best
830            metrics["best_metric"] = best_metric
831            if keep_all_bests:
832                best_name = (
833                    output_directory
834                    + f"best_model{stage_prefix}_{time.strftime('%Y-%m-%d-%H-%M-%S')}.fnx"
835                )
836                model.save(best_name)
837
838            best_name = output_directory + f"best_model{stage_prefix}.fnx"
839            model.save(best_name)
840            print("New best model saved to:", best_name)
841        else:
842            metrics["best_metric"] = best_metric
843
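           # One whitespace-separated row per epoch is appended to the metrics
           # file; the commented header naming the columns is only written on
           # the very first epoch (restarted runs reopen the file in append
           # mode).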
844        if epoch == 0:
845            headers = [f"{i+1}:{k}" for i, k in enumerate(metrics.keys())]
846            fmetrics.write("# " + " ".join(headers) + "\n")
847        fmetrics.write(" ".join([str(metrics[k]) for k in metrics.keys()]) + "\n")
848        fmetrics.flush()
849
850        # update learning rate using current metrics
851        assert (
852            schedule_metrics in metrics
853        ), f"Error: cannot update lr, '{schedule_metrics}' not in metrics"
854        current_lr, sch_state = schedule(sch_state, metrics[schedule_metrics])
855
856        # update and save training state
857        training_state["preproc_state"] = model.preproc_state
858        training_state["opt_state"] = opt_st
859        training_state["ema_state"] = ema_st
860        training_state["sch_state"] = sch_state
861        training_state["variables"] = variables
862        training_state["variables_ema"] = model.variables
863        training_state["count"] = count
864        training_state["restore_count"] = restore_count
865        training_state["epoch"] = epoch + 1  # resume from the next epoch
866        training_state["best_metric"] = best_metric
867        if smoothen_metrics:
868            training_state["rmses_smooth"] = dict(rmses_smooth)
869            training_state["maes_smooth"] = dict(maes_smooth)
870            training_state["rmse_tot_smooth"] = rmse_tot_smooth
871            training_state["mae_tot_smooth"] = mae_tot_smooth
872            training_state["nsmooth"] = nsmooth
873
874        with open(output_directory + "train_state", "wb") as f:
875            pickle.dump(training_state, f)
876
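           # `is_end` evaluates this stage's stopping criterion on the current
           # metrics (e.g. for adaptive multi-stage schedules).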
877        if is_end(metrics):
878            print("Stage finished.")
879            break
880
881    end = time.time()
882
883    print(f"Training time: {human_time_duration(end-start)}")
884    print("")
885
886    fmetrics.close()
887
888    filename = output_directory + f"final_model{stage_prefix}.fnx"
889    model.save(filename)
890    print("Final model saved to:", filename)
891
892    return metrics, filename