fennol.training.utils

import jax
import jax.numpy as jnp
import numpy as np
from typing import Any, Callable, Optional, Dict, List, Tuple
import optax
from copy import deepcopy
from flax import traverse_util
import json
import re
from functools import partial

from ..utils import deep_update, AtomicUnits as au
from ..models import FENNIX

@partial(jax.jit, static_argnums=(1, 2, 3, 4))
def linear_schedule(step, start_value, end_value, start_step, duration):
    return start_value + jnp.clip((step - start_step) / duration, 0.0, 1.0) * (end_value - start_value)


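# Usage sketch (hypothetical values): ramp a quantity from 0.0 to 1.0,
# starting at step 10 over a duration of 50 steps. Progress is clipped to
# [0, 1], so the value stays at `start_value` before the ramp and at
# `end_value` after it. All arguments after `step` are static for jit and
# must therefore be plain Python numbers:
#   w = linear_schedule(epoch, 0.0, 1.0, 10, 50)  # 0.0 for epoch <= 10, 1.0 for epoch >= 60

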
def get_training_parameters(
    parameters: Dict[str, Any], stage: int = -1
) -> Dict[str, Any]:
    params = deepcopy(parameters["training"])
    if "stages" not in params:
        return params

    stages: dict = params.pop("stages")
    stage_keys = list(stages.keys())
    if stage < 0:
        stage = len(stage_keys) + stage
    assert 0 <= stage < len(
        stage_keys
    ), f"Stage {stage} not found in training parameters"
    for i in range(stage + 1):
        ## remove end_event from previous stage ##
        if i > 0 and "end_event" in params:
            params.pop("end_event")
        ## incrementally update training parameters ##
        stage_params = stages[stage_keys[i]]
        params = deep_update(params, stage_params)
    return params


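# Usage sketch (hypothetical parameter dictionary): stages are applied
# cumulatively, so a later stage inherits the earlier stages' settings and
# only overrides what it redefines.
#   parameters = {
#       "training": {
#           "lr": 1.0e-3,
#           "stages": {
#               "warmup": {"max_epochs": 10},
#               "main": {"lr": 5.0e-4, "max_epochs": 100},
#           },
#       }
#   }
#   get_training_parameters(parameters, stage=1)
#   # -> {"lr": 5.0e-4, "max_epochs": 100}

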
def get_loss_definition(
    training_parameters: Dict[str, Any],
    model_energy_unit: str = "Ha",  # , manual_renames: List[str] = []
) -> Tuple[Dict[str, Any], List[str], List[str]]:
    """
    Builds the loss definition from the training parameters.

    Args:
        training_parameters (dict): A dictionary containing training parameters.
        model_energy_unit (str): The energy unit of the model output (default: "Ha").

    Returns:
        tuple: A tuple containing:
            - loss_definition (dict): A dictionary defining each loss component.
            - used_keys (list): The model output keys used by the loss components.
            - ref_keys (list): The dataset keys used as references or weights.
    """
    default_loss_type = training_parameters.get("default_loss_type", "log_cosh")
    # loss_definition = deepcopy(training_parameters["loss"])
    used_keys = []
    ref_keys = []
    energy_mult = au.get_multiplier(model_energy_unit)
    loss_definition = {}
    for k in training_parameters["loss"]:
        # loss_prms = loss_definition[k]
        loss_prms = deepcopy(training_parameters["loss"][k])
        if "energy_unit" in loss_prms:
            loss_prms["mult"] = energy_mult / au.get_multiplier(
                loss_prms["energy_unit"]
            )
            if "unit" in loss_prms:
                print(
                    "Warning: Both 'unit' and 'energy_unit' are defined for loss component",
                    k,
                    " -> using 'energy_unit'",
                )
            loss_prms["unit"] = loss_prms["energy_unit"]
        elif "unit" in loss_prms:
            loss_prms["mult"] = 1.0 / au.get_multiplier(loss_prms["unit"])
        else:
            loss_prms["mult"] = 1.0
        if "key" not in loss_prms:
            loss_prms["key"] = k
        if "type" not in loss_prms:
            loss_prms["type"] = default_loss_type
        if "weight" not in loss_prms:
            loss_prms["weight"] = 1.0
        assert loss_prms["weight"] >= 0.0, "Loss weight must be non-negative"
        if "threshold" in loss_prms:
            assert loss_prms["threshold"] > 1.0, "Threshold must be greater than 1.0"
        if "ref" in loss_prms:
            ref = loss_prms["ref"]
            if not (ref.startswith("model_ref/") or ref.startswith("model/")):
                ref_keys.append(ref)
        if "ds_weight" in loss_prms:
            ref_keys.append(loss_prms["ds_weight"])

        if "weight_start" in loss_prms:
            weight_start = loss_prms["weight_start"]
            if "weight_ramp" in loss_prms:
                weight_ramp = loss_prms["weight_ramp"]
            else:
                weight_ramp = training_parameters.get("max_epochs")
            weight_ramp_start = loss_prms.get("weight_ramp_start", 0.0)
            weight_end = loss_prms["weight"]
            print(
                "Weight ramp for",
                k,
                ":",
                weight_start,
                "->",
                loss_prms["weight"],
                " in",
                weight_ramp,
                "epochs",
            )
            loss_prms["weight_schedule"] = (
                weight_start,
                weight_end,
                weight_ramp_start,
                weight_ramp,
            )
            # loss_prms["weight_schedule"] = lambda e: weight_start + jnp.clip(
            #     (e - float(weight_ramp_start)) / float(weight_ramp), 0.0, 1.0
            # ) * (weight_end - weight_start)

        used_keys.append(loss_prms["key"])
        loss_definition[k] = loss_prms

    # rename_refs = list(
    #     set(["forces", "total_energy", "atomic_energies"] + manual_renames + used_keys)
    # )

    # for k in loss_definition.keys():
    #     loss_prms = loss_definition[k]
    #     if "ref" in loss_prms:
    #         if loss_prms["ref"] in rename_refs:
    #             loss_prms["ref"] = "true_" + loss_prms["ref"]

    return loss_definition, list(set(used_keys)), list(set(ref_keys))


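# Usage sketch (hypothetical loss configuration; unit names must be known to
# AtomicUnits): an energy term trained per atom with a weight ramp, plus a
# force term. `energy_unit`/`unit` set the conversion factor `mult` applied
# to the reference values.
#   training_parameters = {
#       "loss": {
#           "total_energy": {
#               "ref": "formation_energy",
#               "energy_unit": "kcal/mol",
#               "per_atom": True,
#               "weight": 1.0,
#               "weight_start": 0.1,
#               "weight_ramp": 50,
#           },
#           "forces": {"ref": "forces", "type": "mse", "weight": 100.0},
#       }
#   }
#   loss_definition, used_keys, ref_keys = get_loss_definition(training_parameters)

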
def get_optimizer(
    training_parameters: Dict[str, Any], variables: Dict, initial_lr: float
) -> optax.GradientTransformation:
    """
    Returns an optax.GradientTransformation object that can be used to optimize the model parameters.

    Args:
    - training_parameters: A dictionary containing the training parameters.
    - variables: A pytree containing the model parameters.
    - initial_lr: The initial learning rate.

    Returns:
    - An optax.GradientTransformation object that can be used to optimize the model parameters.
    """

    default_status = str(training_parameters.get("default_status", "trainable")).lower()
    assert default_status in [
        "trainable",
        "frozen",
    ], f"Default status must be 'trainable' or 'frozen', got {default_status}"

    # find frozen and trainable parameters
    frozen = training_parameters.get("frozen", [])
    trainable = training_parameters.get("trainable", [])

    def training_status(full_path, v):
        # the most specific (longest) matching path wins
        full_path = "/".join(full_path[1:]).lower()
        status = (default_status, "")
        for path in frozen:
            if full_path.startswith(path.lower()) and len(path) > len(status[1]):
                status = ("frozen", path)
        for path in trainable:
            if full_path.startswith(path.lower()) and len(path) > len(status[1]):
                status = ("trainable", path)
        return status[0]

    params_partition = traverse_util.path_aware_map(training_status, variables)
    if len(frozen) > 0 or len(trainable) > 0:
        print("params partition:")
        print(json.dumps(params_partition, indent=2, sort_keys=False))

    ## Gradient preprocessing
    grad_processing = []

    # zero nans
    zero_nans = training_parameters.get("zero_nans", False)
    if zero_nans:
        grad_processing.append(optax.zero_nans())

    # gradient clipping
    clip_threshold = training_parameters.get("gradient_clipping", -1.0)
    if clip_threshold > 0.0:
        print("Adaptive gradient clipping threshold:", clip_threshold)
        grad_processing.append(optax.adaptive_grad_clip(clip_threshold))

    # OPTIMIZER: resolve the optimizer factory by name in the optax namespace
    optimizer_name = training_parameters.get("optimizer", "adabelief")
    optimizer = eval(
        optimizer_name,
        {"__builtins__": None},
        {**optax.__dict__},
    )
    print("Optimizer:", optimizer_name)
    # the learning rate is applied by the final scale transformation,
    # so the inner optimizer runs with a unit learning rate
    optimizer_configuration = training_parameters.get("optimizer_config", {})
    optimizer_configuration["learning_rate"] = 1.0
    grad_processing.append(optimizer(**optimizer_configuration))

    # weight decay
    weight_decay = training_parameters.get("weight_decay", 0.0)
    assert weight_decay >= 0.0, "Weight decay must be non-negative"
    decay_targets = training_parameters.get("decay_targets", [""])

    def decay_status(full_path, v):
        full_path = "/".join(full_path).lower()
        status = False
        # print(full_path,re.match(r'^params\/', full_path))
        for path in decay_targets:
            if re.match(r"^params/" + path.lower(), full_path):
                status = True
            # if full_path.startswith("params/" + path.lower()):
            #     status = True
        return status

    decay_mask = traverse_util.path_aware_map(decay_status, variables)
    if weight_decay > 0.0:
        print("weight decay:", weight_decay)
        print(json.dumps(decay_mask, indent=2, sort_keys=False))
        # the final scale step uses a positive step size (the inner optimizer
        # already outputs descent-direction updates), so the decay term is
        # negated here to shrink the parameters
        grad_processing.append(
            optax.add_decayed_weights(weight_decay=-weight_decay, mask=decay_mask)
        )

    if zero_nans:
        grad_processing.append(optax.zero_nans())

    # learning rate (injected as a hyperparameter so it can be updated during training)
    grad_processing.append(optax.inject_hyperparams(optax.scale)(step_size=initial_lr))

    ## define optimizer chain
    optimizer_ = optax.chain(
        *grad_processing,
    )
    partition_optimizer = {"trainable": optimizer_, "frozen": optax.set_to_zero()}
    return optax.multi_transform(partition_optimizer, params_partition)


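# Usage sketch (hypothetical training parameters; the module paths in
# `decay_targets` and `frozen` are placeholders): an AdaBelief optimizer with
# adaptive gradient clipping, weight decay restricted to the output heads,
# and a frozen embedding.
#   training_parameters = {
#       "optimizer": "adabelief",
#       "gradient_clipping": 0.1,
#       "weight_decay": 1.0e-4,
#       "decay_targets": ["output"],
#       "frozen": ["embedding"],
#   }
#   optimizer = get_optimizer(training_parameters, variables, initial_lr=1.0e-3)
#   opt_st = optimizer.init(variables)

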
def get_train_step_function(
    loss_definition: Dict,
    model: FENNIX,
    evaluate: Callable,
    optimizer: optax.GradientTransformation,
    ema: Optional[optax.GradientTransformation] = None,
    model_ref: Optional[FENNIX] = None,
    compute_ref_coords: bool = False,
    jit: bool = True,
):
    def train_step(
        epoch,
        data,
        inputs,
        variables,
        opt_st,
        variables_ema=None,
        ema_st=None,
    ):

        def loss_fn(variables):
            if model_ref is not None:
                output_ref = evaluate(model_ref, model_ref.variables, inputs)
                # _, _, output_ref = model_ref._energy_and_forces(model_ref.variables, data)
            output = evaluate(model, variables, inputs)
            # _, _, output = model._energy_and_forces(variables, data)
            # padding systems get natoms=1 to avoid divisions by zero
            natoms = jnp.where(inputs["true_sys"], inputs["natoms"], 1)
            nsys = inputs["natoms"].shape[0]
            # masks selecting real (non-padding) systems and atoms
            system_mask = inputs["true_sys"]
            atom_mask = inputs["true_atoms"]
            if "system_sign" in inputs:
                system_sign = inputs["system_sign"] > 0
                system_mask = jnp.logical_and(system_mask, system_sign)
                atom_mask = jnp.logical_and(atom_mask, system_sign[inputs["batch_index"]])

            loss_tot = 0.0
            for loss_prms in loss_definition.values():
                use_ref_mask = False
                predicted = output[loss_prms["key"]]
                if "remove_ref_sys" in loss_prms and loss_prms["remove_ref_sys"]:
                    assert compute_ref_coords, "compute_ref_coords must be True"
                    # predicted = predicted - output_data_ref[loss_prms["key"]]
                    assert predicted.shape[0] == nsys, "remove_ref_sys only works with system-level predictions"
                    shape_mask = [predicted.shape[0]] + [1] * (len(predicted.shape) - 1)
                    system_sign = inputs["system_sign"].reshape(*shape_mask)
                    predicted = jax.ops.segment_sum(system_sign * predicted, inputs["system_index"], nsys)

                if "ref" in loss_prms:
                    if loss_prms["ref"].startswith("model_ref/"):
                        assert model_ref is not None, "model_ref must be provided"
                        try:
                            ref = output_ref[loss_prms["ref"][10:]] * loss_prms["mult"]
                        except KeyError:
                            raise KeyError(
                                f"Reference key '{loss_prms['ref'][10:]}' not found in model_ref output. Keys available: {output_ref.keys()}"
                            )
                    elif loss_prms["ref"].startswith("model/"):
                        try:
                            ref = output[loss_prms["ref"][6:]] * loss_prms["mult"]
                        except KeyError:
                            raise KeyError(
                                f"Reference key '{loss_prms['ref'][6:]}' not found in model output. Keys available: {output.keys()}"
                            )
                    else:
                        try:
                            ref = data[loss_prms["ref"]] * loss_prms["mult"]
                        except KeyError:
                            raise KeyError(
                                f"Reference key '{loss_prms['ref']}' not found in data. Keys available: {data.keys()}"
                            )
                        if loss_prms["ref"] + "_mask" in data:
                            use_ref_mask = True
                            ref_mask = data[loss_prms["ref"] + "_mask"]
                else:
                    ref = jnp.zeros_like(predicted)

                if "norm_axis" in loss_prms and loss_prms["norm_axis"] is not None:
                    norm_axis = loss_prms["norm_axis"]
                    predicted = jnp.linalg.norm(
                        predicted, axis=norm_axis, keepdims=True
                    )
                    ref = jnp.linalg.norm(ref, axis=norm_axis, keepdims=True)

                if loss_prms["type"] in [
                    "ensemble_nll",
                    "ensemble_crps",
                    "gaussian_mixture",
                ]:
                    ensemble_axis = loss_prms.get("ensemble_axis", -1)
                    ensemble_weights_key = loss_prms.get("ensemble_weights", None)
                    ensemble_size = predicted.shape[ensemble_axis]
                    if ensemble_weights_key is not None:
                        ensemble_weights = output[ensemble_weights_key]
                    else:
                        ensemble_weights = jnp.ones_like(predicted)
                    if "ensemble_subsample" in loss_prms and "rng_key" in output:
                        ns = min(
                            loss_prms["ensemble_subsample"],
                            predicted.shape[ensemble_axis],
                        )
                        key, subkey = jax.random.split(output["rng_key"])

                        def get_subsample(x):
                            return jax.lax.slice_in_dim(
                                jax.random.permutation(
                                    subkey, x, axis=ensemble_axis, independent=True
                                ),
                                start_index=0,
                                limit_index=ns,
                                axis=ensemble_axis,
                            )

                        predicted = get_subsample(predicted)
                        ensemble_weights = get_subsample(ensemble_weights)
                        output["rng_key"] = key
                    else:
                        get_subsample = lambda x: x
                    ensemble_weights = ensemble_weights / jnp.sum(
                        ensemble_weights, axis=ensemble_axis, keepdims=True
                    )
                    predicted_ensemble = predicted
                    predicted = (predicted * ensemble_weights).sum(
                        axis=ensemble_axis, keepdims=True
                    )
                    predicted_var = (
                        ensemble_weights * (predicted_ensemble - predicted) ** 2
                    ).sum(axis=ensemble_axis) * (ensemble_size / (ensemble_size - 1.0))
                    predicted = jnp.squeeze(predicted, axis=ensemble_axis)

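                # A sketch of the ensemble statistics computed above, with w_i
                # the normalized weights: the mean is sum_i w_i x_i and the
                # variance is sum_i w_i (x_i - mean)^2, rescaled by the
                # Bessel-like factor n/(n-1). Note that `ensemble_size` is read
                # before the optional subsampling, so the correction always
                # uses the full ensemble size.
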
                if predicted.ndim > 1 and predicted.shape[-1] == 1:
                    predicted = jnp.squeeze(predicted, axis=-1)

                if ref.ndim > 1 and ref.shape[-1] == 1:
                    ref = jnp.squeeze(ref, axis=-1)

                per_atom = False
                shape_mask = [ref.shape[0]] + [1] * (len(ref.shape) - 1)
                # print(loss_prms["key"],predicted.shape,loss_prms["ref"],ref.shape)
                natscale = 1.0
                if ref.shape[0] == output["batch_index"].shape[0]:
                    ## shape is number of atoms
                    truth_mask = atom_mask
                    if "ds_weight" in loss_prms:
                        weight_key = loss_prms["ds_weight"]
                        natscale = data[weight_key][output["batch_index"]].reshape(
                            *shape_mask
                        )
                elif ref.shape[0] == natoms.shape[0]:
                    ## shape is number of systems
                    truth_mask = system_mask
                    if loss_prms.get("per_atom", False):
                        per_atom = True
                        ref = ref / natoms.reshape(*shape_mask)
                        predicted = predicted / natoms.reshape(*shape_mask)
                    if "nat_pow" in loss_prms:
                        natscale = (
                            1.0 / natoms.reshape(*shape_mask) ** loss_prms["nat_pow"]
                        )
                    if "ds_weight" in loss_prms:
                        weight_key = loss_prms["ds_weight"]
                        natscale = natscale * data[weight_key].reshape(*shape_mask)
                else:
                    truth_mask = jnp.ones(ref.shape[0], dtype=bool)

                if use_ref_mask:
                    truth_mask = truth_mask * ref_mask.astype(bool)

                # effective number of scalar elements contributing to the loss
                nel = jnp.maximum(
                    (float(np.prod(ref.shape)) / float(truth_mask.shape[0]))
                    * jnp.sum(truth_mask).astype(jnp.float32),
                    1.0,
                )
                truth_mask = truth_mask.reshape(*shape_mask)

                ref = ref * truth_mask
                predicted = predicted * truth_mask

                natscale = natscale / nel

                loss_type = loss_prms["type"]
                if loss_type == "mse":
                    loss = jnp.sum(natscale * (predicted - ref) ** 2)
                elif loss_type == "log_cosh":
                    loss = jnp.sum(natscale * optax.log_cosh(predicted, ref))
                elif loss_type == "mae":
                    loss = jnp.sum(natscale * jnp.abs(predicted - ref))
                elif loss_type == "rel_mse":
                    eps = loss_prms.get("rel_eps", 1.0e-6)
                    rel_pow = loss_prms.get("rel_pow", 2.0)
                    loss = jnp.sum(
                        natscale
                        * (predicted - ref) ** 2
                        / (eps + jnp.abs(ref) ** rel_pow)
                    )
                elif loss_type == "rmse+mae":
                    loss = (
                        jnp.sum(natscale * (predicted - ref) ** 2)
                    ) ** 0.5 + jnp.sum(natscale * jnp.abs(predicted - ref))
                elif loss_type == "crps":
                    predicted_var_key = loss_prms.get(
                        "var_key", loss_prms["key"] + "_var"
                    )
                    predicted_var = output[predicted_var_key]
                    if per_atom:
                        predicted_var = predicted_var / natoms.reshape(*shape_mask)
                    predicted_var = predicted_var * truth_mask + (1.0 - truth_mask)
                    sigma = predicted_var**0.5
                    dy = (ref - predicted) / sigma
                    Phi = 0.5 * (1.0 + jax.scipy.special.erf(dy / 2**0.5))
                    phi = jnp.exp(-0.5 * dy**2) / (2 * jnp.pi) ** 0.5
                    loss = jnp.sum(natscale * sigma * (dy * (2 * Phi - 1.0) + 2 * phi))
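                    # For reference, the closed-form CRPS of a Gaussian
                    # N(mu, sigma^2) at observation y, with z = (y - mu)/sigma, is
                    #   sigma * (z * (2*Phi(z) - 1) + 2*phi(z) - 1/sqrt(pi)),
                    # where Phi and phi are the standard normal CDF and PDF.
                    # The branch above evaluates this expression without the
                    # constant -sigma/sqrt(pi) term.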
                elif loss_type == "ensemble_nll":
                    predicted_var = predicted_var * truth_mask + (1.0 - truth_mask)
                    loss = 0.5 * jnp.sum(
                        natscale
                        * truth_mask
                        * (
                            jnp.log(predicted_var)
                            + (ref - predicted) ** 2 / predicted_var
                        )
                    )
                elif loss_type == "ensemble_crps":
                    if per_atom:
                        predicted_var = predicted_var / natoms.reshape(*shape_mask)
                    predicted_var = predicted_var * truth_mask + (1.0 - truth_mask)
                    sigma = predicted_var**0.5
                    dy = (ref - predicted) / sigma
                    Phi = 0.5 * (1.0 + jax.scipy.special.erf(dy / 2**0.5))
                    phi = jnp.exp(-0.5 * dy**2) / (2 * jnp.pi) ** 0.5
                    loss = jnp.sum(
                        natscale * truth_mask * sigma * (dy * (2 * Phi - 1.0) + 2 * phi)
                    )
                elif loss_type == "gaussian_mixture":
                    gm_w = get_subsample(output[loss_prms["key"] + "_w"])
                    gm_w = gm_w / jnp.sum(gm_w, axis=ensemble_axis, keepdims=True)
                    sigma = get_subsample(output[loss_prms["key"] + "_sigma"])
                    ref = jnp.expand_dims(ref, axis=ensemble_axis)
                    # log_likelihoods = -0.5*jnp.log(2*np.pi*gm_sigma2) - ((ref - predicted_ensemble) ** 2) / (2 * gm_sigma2)
                    # log_w =jnp.log(gm_w + 1.e-10)
                    # negloglikelihood = -jax.scipy.special.logsumexp(log_w + log_likelihoods,axis=ensemble_axis)
                    # loss = jnp.sum(natscale * truth_mask * negloglikelihood)

                    # CRPS loss
                    def compute_crps(mu, sigma):
                        dy = mu / sigma
                        Phi = 0.5 * (1.0 + jax.scipy.special.erf(dy / 2**0.5))
                        phi = jnp.exp(-0.5 * dy**2) / (2 * jnp.pi) ** 0.5
                        return sigma * (dy * (2 * Phi - 1.0) + 2 * phi)

                    crps1 = compute_crps(ref - predicted_ensemble, sigma)
                    gm_crps1 = (gm_w * crps1).sum(axis=ensemble_axis)

                    if ensemble_axis < 0:
                        ensemble_axis = predicted_ensemble.ndim + ensemble_axis
                    muij = jnp.expand_dims(
                        predicted_ensemble, axis=ensemble_axis
                    ) - jnp.expand_dims(predicted_ensemble, axis=ensemble_axis + 1)
                    sigma2 = sigma**2
                    sigmaij = (
                        jnp.expand_dims(sigma2, axis=ensemble_axis)
                        + jnp.expand_dims(sigma2, axis=ensemble_axis + 1)
                    ) ** 0.5
                    wij = jnp.expand_dims(gm_w, axis=ensemble_axis) * jnp.expand_dims(
                        gm_w, axis=ensemble_axis + 1
                    )
                    crps2 = compute_crps(muij, sigmaij)
                    gm_crps2 = (wij * crps2).sum(
                        axis=(ensemble_axis, ensemble_axis + 1)
                    )

                    crps = gm_crps1 - 0.5 * gm_crps2
                    loss = jnp.sum(natscale * truth_mask * crps)

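                    # A sketch of the identity used above: for a mixture
                    # sum_i w_i N(mu_i, sigma_i^2), the CRPS decomposes as
                    #   sum_i w_i A(y - mu_i, sigma_i)
                    #     - 0.5 * sum_ij w_i w_j A(mu_i - mu_j, sqrt(sigma_i^2 + sigma_j^2)),
                    # where A(mu, sigma) = sigma * (z*(2*Phi(z) - 1) + 2*phi(z))
                    # with z = mu/sigma is what `compute_crps` evaluates.
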
                elif loss_type == "evidential":
                    # evidential regression: negative log-likelihood of a
                    # Normal-Inverse-Gamma evidence model plus regularizers
                    evidence = loss_prms["evidence_key"]
                    nu, alpha, beta = jnp.split(output[evidence], 3, axis=-1)
                    gamma = predicted
                    nu = nu.reshape(shape_mask)
                    alpha = alpha.reshape(shape_mask)
                    beta = beta.reshape(shape_mask)
                    nu = jnp.where(truth_mask, nu, 1.0)
                    alpha = jnp.where(truth_mask, alpha, 1.0)
                    beta = jnp.where(truth_mask, beta, 1.0)
                    omega = 2 * beta * (1 + nu)
                    lg = jax.scipy.special.gammaln(alpha) - jax.scipy.special.gammaln(
                        alpha + 0.5
                    )
                    ls = 0.5 * jnp.log(jnp.pi / nu) - alpha * jnp.log(omega)
                    lt = (alpha + 0.5) * jnp.log(omega + nu * (gamma - ref) ** 2)
                    wst = (
                        (beta * (1 + nu) / (alpha * nu)) ** 0.5
                        if loss_prms.get("normalize_evidence", True)
                        else 1.0
                    )
                    lr = (
                        loss_prms.get("lambda_evidence", 1.0)
                        * jnp.abs(gamma - ref)
                        * nu
                        / wst
                    )
                    r = loss_prms.get("evidence_ratio", 1.0)
                    le = (
                        loss_prms.get("lambda_evidence_diff", 0.0)
                        * (nu - r * 2 * alpha) ** 2
                    )
                    lb = loss_prms.get("lambda_evidence_beta", 0.0) * beta
                    loss = lg + ls + lt + lr + le + lb

                    loss = jnp.sum(natscale * loss * truth_mask)
                elif loss_type == "raw":
                    loss = jnp.sum(natscale * predicted)
                else:
                    raise ValueError(f"Unknown loss type: {loss_type}")

                if "weight_schedule" in loss_prms:
                    w = linear_schedule(epoch, *loss_prms["weight_schedule"])
                else:
                    w = loss_prms["weight"]

                loss_tot = loss_tot + w * loss

            return loss_tot, output

        (loss, o), grad = jax.value_and_grad(loss_fn, has_aux=True)(variables)
        updates, opt_st = optimizer.update(grad, opt_st, params=variables)
        variables = optax.apply_updates(variables, updates)
        if ema is not None:
            if variables_ema is None or ema_st is None:
                raise ValueError(
                    "train_step was set up with ema but either variables_ema or ema_st was not provided"
                )
            variables_ema, ema_st = ema.update(variables, ema_st)
            return loss, variables, opt_st, variables_ema, ema_st, o
        else:
            return loss, variables, opt_st, o

    if jit:
        return jax.jit(train_step)
    return train_step


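# Usage sketch (assuming `model`, `evaluate`, `optimizer`, and a batch
# prepared as `data`/`inputs` are available):
#   train_step = get_train_step_function(loss_definition, model, evaluate, optimizer)
#   loss, variables, opt_st, output = train_step(epoch, data, inputs, variables, opt_st)
# With an EMA transformation, two extra states are threaded through:
#   loss, variables, opt_st, variables_ema, ema_st, output = train_step(
#       epoch, data, inputs, variables, opt_st, variables_ema, ema_st
#   )

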
def get_validation_function(
    loss_definition: Dict,
    model: FENNIX,
    evaluate: Callable,
    model_ref: Optional[FENNIX] = None,
    compute_ref_coords: bool = False,
    return_targets: bool = False,
    jit: bool = True,
):

    def validation(data, inputs, variables, inputs_ref=None):
        if model_ref is not None:
            output_ref = evaluate(model_ref, model_ref.variables, inputs)
            # _, _, output_ref = model_ref._energy_and_forces(model_ref.variables, data)
        output = evaluate(model, variables, inputs)
        # _, _, output = model._energy_and_forces(variables, data)
        rmses = {}
        maes = {}
        if return_targets:
            targets = {}

        natoms = jnp.where(inputs["true_sys"], inputs["natoms"], 1)
        nsys = inputs["natoms"].shape[0]
        system_mask = inputs["true_sys"]
        atom_mask = inputs["true_atoms"]
        if "system_sign" in inputs:
            system_sign = inputs["system_sign"] > 0
            system_mask = jnp.logical_and(system_mask, system_sign)
            atom_mask = jnp.logical_and(atom_mask, system_sign[inputs["batch_index"]])

        for name, loss_prms in loss_definition.items():
            do_validation = loss_prms.get("validate", True)
            if not do_validation:
                continue
            predicted = output[loss_prms["key"]]
            use_ref_mask = False

            if "remove_ref_sys" in loss_prms and loss_prms["remove_ref_sys"]:
                assert compute_ref_coords, "compute_ref_coords must be True"
                # predicted = predicted - output_data_ref[loss_prms["key"]]
                assert predicted.shape[0] == nsys, "remove_ref_sys only works with system-level predictions"
                shape_mask = [predicted.shape[0]] + [1] * (len(predicted.shape) - 1)
                system_sign = inputs["system_sign"].reshape(*shape_mask)
                predicted = jax.ops.segment_sum(system_sign * predicted, inputs["system_index"], nsys)

            if "ref" in loss_prms:
                if loss_prms["ref"].startswith("model_ref/"):
                    assert model_ref is not None, "model_ref must be provided"
                    ref = output_ref[loss_prms["ref"][10:]] * loss_prms["mult"]
                elif loss_prms["ref"].startswith("model/"):
                    ref = output[loss_prms["ref"][6:]] * loss_prms["mult"]
                else:
                    ref = data[loss_prms["ref"]] * loss_prms["mult"]
                    if loss_prms["ref"] + "_mask" in data:
                        use_ref_mask = True
                        ref_mask = data[loss_prms["ref"] + "_mask"]
            else:
                ref = jnp.zeros_like(predicted)

            if "norm_axis" in loss_prms and loss_prms["norm_axis"] is not None:
                norm_axis = loss_prms["norm_axis"]
                predicted = jnp.linalg.norm(
                    predicted, axis=norm_axis, keepdims=True
                )
                ref = jnp.linalg.norm(ref, axis=norm_axis, keepdims=True)

            if loss_prms["type"].startswith("ensemble"):
                axis = loss_prms.get("ensemble_axis", -1)
                ensemble_weights_key = loss_prms.get("ensemble_weights", None)
                if ensemble_weights_key is not None:
                    ensemble_weights = output[ensemble_weights_key]
                else:
                    ensemble_weights = jnp.ones_like(predicted)
                ensemble_weights = ensemble_weights / jnp.sum(
                    ensemble_weights, axis=axis, keepdims=True
                )

                # predicted = predicted.mean(axis=axis)
                predicted = (ensemble_weights * predicted).sum(axis=axis)

            if loss_prms["type"] in ["gaussian_mixture"]:
                axis = loss_prms.get("ensemble_axis", -1)
                gm_w = output[loss_prms["key"] + "_w"]
                gm_w = gm_w / jnp.sum(gm_w, axis=axis, keepdims=True)
                predicted = (predicted * gm_w).sum(axis=axis)

            if predicted.ndim > 1 and predicted.shape[-1] == 1:
                predicted = jnp.squeeze(predicted, axis=-1)

            if ref.ndim > 1 and ref.shape[-1] == 1:
                ref = jnp.squeeze(ref, axis=-1)

            shape_mask = [ref.shape[0]] + [1] * (len(predicted.shape) - 1)
            if ref.shape[0] == output["batch_index"].shape[0]:
                ## shape is number of atoms
                truth_mask = atom_mask
            elif ref.shape[0] == natoms.shape[0]:
                ## shape is number of systems
                truth_mask = system_mask
                if loss_prms.get("per_atom_validation", False):
                    ref = ref / natoms.reshape(*shape_mask)
                    predicted = predicted / natoms.reshape(*shape_mask)
            else:
                truth_mask = jnp.ones(ref.shape[0], dtype=bool)

            if use_ref_mask:
                truth_mask = truth_mask * ref_mask.astype(bool)

            nel = jnp.maximum(
                (float(np.prod(ref.shape)) / float(truth_mask.shape[0]))
                * jnp.sum(truth_mask).astype(predicted.dtype),
                1.0,
            )

            truth_mask = truth_mask.reshape(*shape_mask)

            ref = ref * truth_mask
            predicted = predicted * truth_mask

            rmse = jnp.sum((predicted - ref) ** 2 / nel) ** 0.5
            mae = jnp.sum(jnp.abs(predicted - ref) / nel)

            rmses[name] = rmse
            maes[name] = mae
            if return_targets:
                targets[name] = (predicted, ref, truth_mask)

        if return_targets:
            return rmses, maes, output, targets

        return rmses, maes, output

    if jit:
        return jax.jit(validation)
    return validation
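

# Usage sketch (assuming a validation batch `data`/`inputs`):
#   validation = get_validation_function(loss_definition, model, evaluate)
#   rmses, maes, output = validation(data, inputs, variables)
# With `return_targets=True`, the per-component (predicted, ref, mask)
# triplets are also returned for further analysis.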
@partial(jax.jit, static_argnums=(1, 2, 3, 4))
def linear_schedule(step, start_value, end_value, start_step, duration):
16@partial(jax.jit, static_argnums=(1,2,3,4))
17def linear_schedule(step, start_value,end_value, start_step, duration):
18    return start_value + jnp.clip((step - start_step) / duration, 0.0, 1.0) * (end_value - start_value)
def get_training_parameters( parameters: Dict[str, <built-in function any>], stage: int = -1) -> Dict[str, <built-in function any>]:
21def get_training_parameters(
22    parameters: Dict[str, any], stage: int = -1
23) -> Dict[str, any]:
24    params = deepcopy(parameters["training"])
25    if "stages" not in params:
26        return params
27
28    stages: dict = params.pop("stages")
29    stage_keys = list(stages.keys())
30    if stage < 0:
31        stage = len(stage_keys) + stage
32    assert stage >= 0 and stage < len(
33        stage_keys
34    ), f"Stage {stage} not found in training parameters"
35    for i in range(stage + 1):
36        ## remove end_event from previous stage ##
37        if i > 0 and "end_event" in params:
38            params.pop("end_event")
39        ## incrementally update training parameters ##
40        stage_params = stages[stage_keys[i]]
41        params = deep_update(params, stage_params)
42    return params
def get_loss_definition( training_parameters: Dict[str, <built-in function any>], model_energy_unit: str = 'Ha') -> Tuple[Dict[str, <built-in function any>], List[str], List[str]]:
 45def get_loss_definition(
 46    training_parameters: Dict[str, any],
 47    model_energy_unit: str = "Ha",  # , manual_renames: List[str] = []
 48) -> Tuple[Dict[str, any], List[str], List[str]]:
 49    """
 50    Returns the loss definition and a list of renamed references.
 51
 52    Args:
 53        training_parameters (dict): A dictionary containing training parameters.
 54
 55    Returns:
 56        tuple: A tuple containing:
 57            - loss_definition (dict): A dictionary containing the loss definition.
 58            - rename_refs (list): A list of renamed references.
 59    """
 60    default_loss_type = training_parameters.get("default_loss_type", "log_cosh")
 61    # loss_definition = deepcopy(training_parameters["loss"])
 62    used_keys = []
 63    ref_keys = []
 64    energy_mult = au.get_multiplier(model_energy_unit)
 65    loss_definition = {}
 66    for k in training_parameters["loss"]:
 67        # loss_prms = loss_definition[k]
 68        loss_prms = deepcopy(training_parameters["loss"][k])
 69        if "energy_unit" in loss_prms:
 70            loss_prms["mult"] = energy_mult / au.get_multiplier(
 71                loss_prms["energy_unit"]
 72            )
 73            if "unit" in loss_prms:
 74                print(
 75                    "Warning: Both 'unit' and 'energy_unit' are defined for loss component",
 76                    k,
 77                    " -> using 'energy_unit'",
 78                )
 79            loss_prms["unit"] = loss_prms["energy_unit"]
 80        elif "unit" in loss_prms:
 81            loss_prms["mult"] = 1.0 / au.get_multiplier(loss_prms["unit"])
 82        else:
 83            loss_prms["mult"] = 1.0
 84        if "key" not in loss_prms:
 85            loss_prms["key"] = k
 86        if "type" not in loss_prms:
 87            loss_prms["type"] = default_loss_type
 88        if "weight" not in loss_prms:
 89            loss_prms["weight"] = 1.0
 90        assert loss_prms["weight"] >= 0.0, "Loss weight must be positive"
 91        if "threshold" in loss_prms:
 92            assert loss_prms["threshold"] > 1.0, "Threshold must be greater than 1.0"
 93        if "ref" in loss_prms:
 94            ref = loss_prms["ref"]
 95            if not (ref.startswith("model_ref/") or ref.startswith("model/")):
 96                ref_keys.append(ref)
 97        if "ds_weight" in loss_prms:
 98            ref_keys.append(loss_prms["ds_weight"])
 99
100        if "weight_start" in loss_prms:
101            weight_start = loss_prms["weight_start"]
102            if "weight_ramp" in loss_prms:
103                weight_ramp = loss_prms["weight_ramp"]
104            else:
105                weight_ramp = training_parameters.get("max_epochs")
106            weight_ramp_start = loss_prms.get("weight_ramp_start", 0.0)
107            weight_end = loss_prms["weight"]
108            print(
109                "Weight ramp for",
110                k,
111                ":",
112                weight_start,
113                "->",
114                loss_prms["weight"],
115                " in",
116                weight_ramp,
117                "epochs",
118            )
119            loss_prms["weight_schedule"] = (weight_start,weight_end,weight_ramp_start,weight_ramp)
120            # loss_prms["weight_schedule"] = lambda e: weight_start + jnp.clip(
121                # (e - float(weight_ramp_start)) / float(weight_ramp), 0.0, 1.0
122            # ) * (weight_end - weight_start)
123
124        used_keys.append(loss_prms["key"])
125        loss_definition[k] = loss_prms
126
127    # rename_refs = list(
128    #     set(["forces", "total_energy", "atomic_energies"] + manual_renames + used_keys)
129    # )
130
131    # for k in loss_definition.keys():
132    #     loss_prms = loss_definition[k]
133    #     if "ref" in loss_prms:
134    #         if loss_prms["ref"] in rename_refs:
135    #             loss_prms["ref"] = "true_" + loss_prms["ref"]
136
137    return loss_definition, list(set(used_keys)), list(set(ref_keys))

Returns the loss definition and a list of renamed references.

Args: training_parameters (dict): A dictionary containing training parameters.

Returns: tuple: A tuple containing: - loss_definition (dict): A dictionary containing the loss definition. - rename_refs (list): A list of renamed references.

def get_optimizer( training_parameters: Dict[str, <built-in function any>], variables: Dict, initial_lr: float) -> optax._src.base.GradientTransformation:
140def get_optimizer(
141    training_parameters: Dict[str, any], variables: Dict, initial_lr: float
142) -> optax.GradientTransformation:
143    """
144    Returns an optax.GradientTransformation object that can be used to optimize the model parameters.
145
146    Args:
147    - training_parameters: A dictionary containing the training parameters.
148    - variables: A  pytree containing the model parameters.
149    - initial_lr: The initial learning rate.
150
151    Returns:
152    - An optax.GradientTransformation object that can be used to optimize the model parameters.
153    """
154
155    default_status = str(training_parameters.get("default_status", "trainable")).lower()
156    assert default_status in [
157        "trainable",
158        "frozen",
159    ], f"Default status must be 'trainable' or 'frozen', got {default_status}"
160
161    # find frozen and trainable parameters
162    frozen = training_parameters.get("frozen", [])
163    trainable = training_parameters.get("trainable", [])
164
165    def training_status(full_path, v):
166        full_path = "/".join(full_path[1:]).lower()
167        status = (default_status, "")
168        for path in frozen:
169            if full_path.startswith(path.lower()) and len(path) > len(status[1]):
170                status = ("frozen", path)
171        for path in trainable:
172            if full_path.startswith(path.lower()) and len(path) > len(status[1]):
173                status = ("trainable", path)
174        return status[0]
175
176    params_partition = traverse_util.path_aware_map(training_status, variables)
177    if len(frozen) > 0 or len(trainable) > 0:
178        print("params partition:")
179        print(json.dumps(params_partition, indent=2, sort_keys=False))
180
181    ## Gradient preprocessing
182    grad_processing = []
183
184    # zero nans
185    zero_nans = training_parameters.get("zero_nans", False)
186    if zero_nans:
187        grad_processing.append(optax.zero_nans())
188
189    # gradient clipping
190    clip_threshold = training_parameters.get("gradient_clipping", -1.0)
191    if clip_threshold > 0.0:
192        print("Adaptive gradient clipping threshold:", clip_threshold)
193        grad_processing.append(optax.adaptive_grad_clip(clip_threshold))
194
195    # OPTIMIZER
196    optimizer_name = training_parameters.get("optimizer", "adabelief")
197    optimizer = eval(
198        optimizer_name,
199        {"__builtins__": None},
200        {**optax.__dict__},
201    )
202    print("Optimizer:", optimizer_name)
203    optimizer_configuration = training_parameters.get("optimizer_config", {})
204    optimizer_configuration["learning_rate"] = 1.0
205    grad_processing.append(optimizer(**optimizer_configuration))
206
207    # weight decay
208    weight_decay = training_parameters.get("weight_decay", 0.0)
209    assert weight_decay >= 0.0, "Weight decay must be positive"
210    decay_targets = training_parameters.get("decay_targets", [""])
211
212    def decay_status(full_path, v):
213        full_path = "/".join(full_path).lower()
214        status = False
215        # print(full_path,re.match(r'^params\/', full_path))
216        for path in decay_targets:
217            if re.match(r"^params/" + path.lower(), full_path):
218                status = True
219            # if full_path.startswith("params/" + path.lower()):
220            #     status = True
221        return status
222
223    decay_mask = traverse_util.path_aware_map(decay_status, variables)
224    if weight_decay > 0.0:
225        print("weight decay:", weight_decay)
226        print(json.dumps(decay_mask, indent=2, sort_keys=False))
227        grad_processing.append(
228            optax.add_decayed_weights(weight_decay=-weight_decay, mask=decay_mask)
229        )
230
231    if zero_nans:
232        grad_processing.append(optax.zero_nans())
233
234    # learning rate
235    grad_processing.append(optax.inject_hyperparams(optax.scale)(step_size=initial_lr))
236
237    ## define optimizer chain
238    optimizer_ = optax.chain(
239        *grad_processing,
240    )
241    partition_optimizer = {"trainable": optimizer_, "frozen": optax.set_to_zero()}
242    return optax.multi_transform(partition_optimizer, params_partition)

Returns an optax.GradientTransformation object that can be used to optimize the model parameters.

Args:

  • training_parameters: A dictionary containing the training parameters.
  • variables: A pytree containing the model parameters.
  • initial_lr: The initial learning rate.

Returns:

  • An optax.GradientTransformation object that can be used to optimize the model parameters.
def get_train_step_function( loss_definition: Dict, model: fennol.models.fennix.FENNIX, evaluate: Callable, optimizer: optax._src.base.GradientTransformation, ema: Optional[optax._src.base.GradientTransformation] = None, model_ref: Optional[fennol.models.fennix.FENNIX] = None, compute_ref_coords: bool = False, jit: bool = True):
245def get_train_step_function(
246    loss_definition: Dict,
247    model: FENNIX,
248    evaluate: Callable,
249    optimizer: optax.GradientTransformation,
250    ema: Optional[optax.GradientTransformation] = None,
251    model_ref: Optional[FENNIX] = None,
252    compute_ref_coords: bool = False,
253    jit: bool = True,
254):
255    def train_step(
256        epoch,
257        data,
258        inputs,
259        variables,
260        opt_st,
261        variables_ema=None,
262        ema_st=None,
263    ):
264
265        def loss_fn(variables):
266            if model_ref is not None:
267                output_ref = evaluate(model_ref, model_ref.variables, inputs)
268                # _, _, output_ref = model_ref._energy_and_forces(model_ref.variables, data)
269            output = evaluate(model, variables, inputs)
270            # _, _, output = model._energy_and_forces(variables, data)
271            natoms = jnp.where(inputs["true_sys"], inputs["natoms"], 1)
272            nsys = inputs["natoms"].shape[0]
273            system_mask = inputs["true_sys"]
274            atom_mask = inputs["true_atoms"]
275            if "system_sign" in inputs:
276                system_sign = inputs["system_sign"] > 0
277                system_mask = jnp.logical_and(system_mask,system_sign)
278                atom_mask = jnp.logical_and(atom_mask, system_sign[inputs["batch_index"]])
279
280            loss_tot = 0.0
281            for loss_prms in loss_definition.values():
282                use_ref_mask = False
283                predicted = output[loss_prms["key"]]
284                if "remove_ref_sys" in loss_prms and loss_prms["remove_ref_sys"]:
285                    assert compute_ref_coords, "compute_ref_coords must be True"
286                    # predicted = predicted - output_data_ref[loss_prms["key"]]
287                    assert predicted.shape[0] == nsys, "remove_ref_sys only works with system-level predictions"
288                    shape_mask = [predicted.shape[0]] + [1] * (len(predicted.shape) - 1)
289                    system_sign = inputs["system_sign"].reshape(*shape_mask)
290                    predicted = jax.ops.segment_sum(system_sign*predicted,inputs["system_index"],nsys)
291
292                if "ref" in loss_prms:
293                    if loss_prms["ref"].startswith("model_ref/"):
294                        assert model_ref is not None, "model_ref must be provided"
295                        try:
296                            ref = output_ref[loss_prms["ref"][10:]] * loss_prms["mult"]
297                        except KeyError:
298                            raise KeyError(
299                                f"Reference key '{loss_prms['ref'][10:]}' not found in model_ref output. Keys available: {output_ref.keys()}"
300                            )
301                    elif loss_prms["ref"].startswith("model/"):
302                        try:
303                            ref = output[loss_prms["ref"][6:]] * loss_prms["mult"]
304                        except KeyError:
305                            raise KeyError(
306                                f"Reference key '{loss_prms['ref'][6:]}' not found in model output. Keys available: {output.keys()}"
307                            )
308                    else:
309                        try:
310                            ref = data[loss_prms["ref"]] * loss_prms["mult"]
311                        except KeyError:
312                            raise KeyError(
313                                f"Reference key '{loss_prms['ref']}' not found in data. Keys available: {data.keys()}"
314                            )
315                        if loss_prms["ref"] + "_mask" in data:
316                            use_ref_mask = True
317                            ref_mask = data[loss_prms["ref"] + "_mask"]
318                else:
319                    ref = jnp.zeros_like(predicted)
320
321                if "norm_axis" in loss_prms and loss_prms["norm_axis"] is not None:
322                    norm_axis = loss_prms["norm_axis"]
323                    predicted = jnp.linalg.norm(
324                        predicted, axis=norm_axis, keepdims=True
325                    )
326                    ref = jnp.linalg.norm(ref, axis=norm_axis, keepdims=True)
327
328                if loss_prms["type"] in [
329                    "ensemble_nll",
330                    "ensemble_crps",
331                    "gaussian_mixture",
332                ]:
333                    ensemble_axis = loss_prms.get("ensemble_axis", -1)
334                    ensemble_weights_key = loss_prms.get("ensemble_weights", None)
335                    ensemble_size = predicted.shape[ensemble_axis]
336                    if ensemble_weights_key is not None:
337                        ensemble_weights = output[ensemble_weights_key]
338                    else:
339                        ensemble_weights = jnp.ones_like(predicted)
340                    if "ensemble_subsample" in loss_prms and "rng_key" in output:
341                        ns = min(
342                            loss_prms["ensemble_subsample"],
343                            predicted.shape[ensemble_axis],
344                        )
345                        key, subkey = jax.random.split(output["rng_key"])
346
347                        def get_subsample(x):
348                            return jax.lax.slice_in_dim(
349                                jax.random.permutation(
350                                    subkey, x, axis=ensemble_axis, independent=True
351                                ),
352                                start_index=0,
353                                limit_index=ns,
354                                axis=ensemble_axis,
355                            )
356
357                        predicted = get_subsample(predicted)
358                        ensemble_weights = get_subsample(ensemble_weights)
359                        output["rng_key"] = key
360                    else:
361                        get_subsample = lambda x: x
362                    ensemble_weights = ensemble_weights / jnp.sum(
363                        ensemble_weights, axis=ensemble_axis, keepdims=True
364                    )
365                    predicted_ensemble = predicted
366                    predicted = (predicted * ensemble_weights).sum(
367                        axis=ensemble_axis, keepdims=True
368                    )
369                    predicted_var = (
370                        ensemble_weights * (predicted_ensemble - predicted) ** 2
371                    ).sum(axis=ensemble_axis) * (ensemble_size / (ensemble_size - 1.0))
372                    predicted = jnp.squeeze(predicted, axis=ensemble_axis)
373
374                if predicted.ndim > 1 and predicted.shape[-1] == 1:
375                    predicted = jnp.squeeze(predicted, axis=-1)
376
377                if ref.ndim > 1 and ref.shape[-1] == 1:
378                    ref = jnp.squeeze(ref, axis=-1)
379
380                per_atom = False
381                shape_mask = [ref.shape[0]] + [1] * (len(ref.shape) - 1)
382                # print(loss_prms["key"],predicted.shape,loss_prms["ref"],ref.shape)
383                natscale = 1.0
384                if ref.shape[0] == output["batch_index"].shape[0]:
385                    ## shape is number of atoms
386                    truth_mask = atom_mask
387                    if "ds_weight" in loss_prms:
388                        weight_key = loss_prms["ds_weight"]
389                        natscale = data[weight_key][output["batch_index"]].reshape(
390                            *shape_mask
391                        )
392                elif ref.shape[0] == natoms.shape[0]:
393                    ## shape is number of systems
394                    truth_mask = system_mask
395                    if loss_prms.get("per_atom", False):
396                        per_atom = True
397                        ref = ref / natoms.reshape(*shape_mask)
398                        predicted = predicted / natoms.reshape(*shape_mask)
399                    if "nat_pow" in loss_prms:
400                        natscale = (
401                            1.0 / natoms.reshape(*shape_mask) ** loss_prms["nat_pow"]
402                        )
403                    if "ds_weight" in loss_prms:
404                        weight_key = loss_prms["ds_weight"]
405                        natscale = natscale * data[weight_key].reshape(*shape_mask)
406                else:
407                    truth_mask = jnp.ones(ref.shape[0], dtype=bool)
408
409                if use_ref_mask:
410                    truth_mask = truth_mask * ref_mask.astype(bool)
411
412                nel = jnp.maximum(
413                    (float(np.prod(ref.shape)) / float(truth_mask.shape[0]))
414                    * jnp.sum(truth_mask).astype(jnp.float32),
415                    1.0,
416                )
417                truth_mask = truth_mask.reshape(*shape_mask)
418
419                ref = ref * truth_mask
420                predicted = predicted * truth_mask
421
422                natscale = natscale / nel
423
424                loss_type = loss_prms["type"]
425                if loss_type == "mse":
426                    loss = jnp.sum(natscale * (predicted - ref) ** 2)
427                elif loss_type == "log_cosh":
428                    loss = jnp.sum(natscale * optax.log_cosh(predicted, ref))
429                elif loss_type == "mae":
430                    loss = jnp.sum(natscale * jnp.abs(predicted - ref))
431                elif loss_type == "rel_mse":
432                    eps = loss_prms.get("rel_eps", 1.0e-6)
433                    rel_pow = loss_prms.get("rel_pow", 2.0)
434                    loss = jnp.sum(
435                        natscale
436                        * (predicted - ref) ** 2
437                        / (eps + jnp.abs(ref) ** rel_pow)
438                    )
439                elif loss_type == "rmse+mae":
440                    loss = (
441                        jnp.sum(natscale * (predicted - ref) ** 2)
442                    ) ** 0.5 + jnp.sum(natscale * jnp.abs(predicted - ref))
443                elif loss_type == "crps":
444                    predicted_var_key = loss_prms.get(
445                        "var_key", loss_prms["key"] + "_var"
446                    )
447                    predicted_var = output[predicted_var_key]
448                    if per_atom:
449                        predicted_var = predicted_var / natoms.reshape(*shape_mask)
450                    predicted_var = predicted_var * truth_mask + (1.0 - truth_mask)
451                    sigma = predicted_var**0.5
452                    dy = (ref - predicted) / sigma
453                    Phi = 0.5 * (1.0 + jax.scipy.special.erf(dy / 2**0.5))
454                    phi = jnp.exp(-0.5 * dy**2) / (2 * jnp.pi) ** 0.5
455                    loss = jnp.sum(natscale * sigma * (dy * (2 * Phi - 1.0) + 2 * phi))
456                elif loss_type == "ensemble_nll":
                    predicted_var = predicted_var * truth_mask + (1.0 - truth_mask)
                    loss = 0.5 * jnp.sum(
                        natscale
                        * truth_mask
                        * (
                            jnp.log(predicted_var)
                            + (ref - predicted) ** 2 / predicted_var
                        )
                    )
                elif loss_type == "ensemble_crps":
                    if per_atom:
                        predicted_var = predicted_var / natoms.reshape(*shape_mask)
                    predicted_var = predicted_var * truth_mask + (1.0 - truth_mask)
                    sigma = predicted_var**0.5
                    dy = (ref - predicted) / sigma
                    Phi = 0.5 * (1.0 + jax.scipy.special.erf(dy / 2**0.5))
                    phi = jnp.exp(-0.5 * dy**2) / (2 * jnp.pi) ** 0.5
                    loss = jnp.sum(
                        natscale * truth_mask * sigma * (dy * (2 * Phi - 1.0) + 2 * phi)
                    )
                elif loss_type == "gaussian_mixture":
                    gm_w = get_subsample(output[loss_prms["key"] + "_w"])
                    gm_w = gm_w / jnp.sum(gm_w, axis=ensemble_axis, keepdims=True)
                    sigma = get_subsample(output[loss_prms["key"] + "_sigma"])
                    ref = jnp.expand_dims(ref, axis=ensemble_axis)
                    # log_likelihoods = -0.5*jnp.log(2*np.pi*gm_sigma2) - ((ref - predicted_ensemble) ** 2) / (2 * gm_sigma2)
                    # log_w =jnp.log(gm_w + 1.e-10)
                    # negloglikelihood = -jax.scipy.special.logsumexp(log_w + log_likelihoods,axis=ensemble_axis)
                    # loss = jnp.sum(natscale * truth_mask * negloglikelihood)

                    # CRPS loss
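                    # Closed form for the CRPS of a Gaussian mixture (Grimit et al., 2006):
                    #   CRPS = sum_i w_i A(y - mu_i, s_i)
                    #          - 0.5 * sum_ij w_i w_j A(mu_i - mu_j, sqrt(s_i^2 + s_j^2))
                    # where A(mu, s) = mu*(2*Phi(mu/s) - 1) + 2*s*phi(mu/s) is the
                    # kernel implemented by compute_crps below.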
                    def compute_crps(mu, sigma):
                        dy = mu / sigma
                        Phi = 0.5 * (1.0 + jax.scipy.special.erf(dy / 2**0.5))
                        phi = jnp.exp(-0.5 * dy**2) / (2 * jnp.pi) ** 0.5
                        return sigma * (dy * (2 * Phi - 1.0) + 2 * phi)

                    crps1 = compute_crps(ref - predicted_ensemble, sigma)
                    gm_crps1 = (gm_w * crps1).sum(axis=ensemble_axis)

                    if ensemble_axis < 0:
                        ensemble_axis = predicted_ensemble.ndim + ensemble_axis
                    muij = jnp.expand_dims(
                        predicted_ensemble, axis=ensemble_axis
                    ) - jnp.expand_dims(predicted_ensemble, axis=ensemble_axis + 1)
                    sigma2 = sigma**2
                    sigmaij = (
                        jnp.expand_dims(sigma2, axis=ensemble_axis)
                        + jnp.expand_dims(sigma2, axis=ensemble_axis + 1)
                    ) ** 0.5
                    wij = jnp.expand_dims(gm_w, axis=ensemble_axis) * jnp.expand_dims(
                        gm_w, axis=ensemble_axis + 1
                    )
                    crps2 = compute_crps(muij, sigmaij)
                    gm_crps2 = (wij * crps2).sum(
                        axis=(ensemble_axis, ensemble_axis + 1)
                    )

                    crps = gm_crps1 - 0.5 * gm_crps2
                    loss = jnp.sum(natscale * truth_mask * crps)

                elif loss_type == "evidential":
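                    # Variant of the deep evidential regression loss (cf. Amini
                    # et al., 2020): the model outputs Normal-Inverse-Gamma
                    # parameters (gamma, nu, alpha, beta); lg + ls + lt is the
                    # NIG negative log-likelihood and lr, le, lb are evidence
                    # regularizers.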
                    evidence = loss_prms["evidence_key"]
                    nu, alpha, beta = jnp.split(output[evidence], 3, axis=-1)
                    gamma = predicted
                    nu = nu.reshape(shape_mask)
                    alpha = alpha.reshape(shape_mask)
                    beta = beta.reshape(shape_mask)
                    nu = jnp.where(truth_mask, nu, 1.0)
                    alpha = jnp.where(truth_mask, alpha, 1.0)
                    beta = jnp.where(truth_mask, beta, 1.0)
                    omega = 2 * beta * (1 + nu)
                    lg = jax.scipy.special.gammaln(alpha) - jax.scipy.special.gammaln(
                        alpha + 0.5
                    )
                    ls = 0.5 * jnp.log(jnp.pi / nu) - alpha * jnp.log(omega)
                    lt = (alpha + 0.5) * jnp.log(omega + nu * (gamma - ref) ** 2)
                    wst = (
                        (beta * (1 + nu) / (alpha * nu)) ** 0.5
                        if loss_prms.get("normalize_evidence", True)
                        else 1.0
                    )
                    lr = (
                        loss_prms.get("lambda_evidence", 1.0)
                        * jnp.abs(gamma - ref)
                        * nu
                        / wst
                    )
                    r = loss_prms.get("evidence_ratio", 1.0)
                    le = (
                        loss_prms.get("lambda_evidence_diff", 0.0)
                        * (nu - r * 2 * alpha) ** 2
                    )
                    lb = loss_prms.get("lambda_evidence_beta", 0.0) * beta
                    loss = lg + ls + lt + lr + le + lb

                    loss = jnp.sum(natscale * loss * truth_mask)
                elif loss_type == "raw":
                    loss = jnp.sum(natscale * predicted)
                else:
                    raise ValueError(f"Unknown loss type: {loss_type}")

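                ## each component is weighted either by a fixed weight or by a
                ## linear schedule over epochs (see linear_schedule) ##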
                if "weight_schedule" in loss_prms:
                    w = linear_schedule(epoch, *loss_prms["weight_schedule"])
                else:
                    w = loss_prms["weight"]

                loss_tot = loss_tot + w * loss

            return loss_tot, output

        (loss, o), grad = jax.value_and_grad(loss_fn, has_aux=True)(variables)
        updates, opt_st = optimizer.update(grad, opt_st, params=variables)
        variables = optax.apply_updates(variables, updates)
        if ema is not None:
            if variables_ema is None or ema_st is None:
                raise ValueError(
                    "train_step was set up with EMA but variables_ema or ema_st was not provided"
                )
            variables_ema, ema_st = ema.update(variables, ema_st)
            return loss, variables, opt_st, variables_ema, ema_st, o
        else:
            return loss, variables, opt_st, o

    if jit:
        return jax.jit(train_step)
    return train_step


def get_validation_function(
    loss_definition: Dict,
    model: FENNIX,
    evaluate: Callable,
    model_ref: Optional[FENNIX] = None,
    compute_ref_coords: bool = False,
    return_targets: bool = False,
    jit: bool = True,
):
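    """
    Returns a function that computes validation metrics (RMSE and MAE)
    for each component of the loss definition.

    Args:
        loss_definition (dict): A dictionary containing the loss definition.
        model (FENNIX): The model to validate.
        evaluate (Callable): A function mapping (model, variables, inputs)
            to the model output dictionary.
        model_ref (FENNIX, optional): A reference model, used when a loss
            component takes its reference from 'model_ref/...'.
        compute_ref_coords (bool): Whether reference coordinates are available;
            required by components using 'remove_ref_sys'.
        return_targets (bool): If True, the validation function also returns
            the masked (predicted, reference, mask) tuple for each component.
        jit (bool): Whether to jit-compile the validation function.

    Returns:
        Callable: The validation function.
    """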

    def validation(data, inputs, variables, inputs_ref=None):
        if model_ref is not None:
            output_ref = evaluate(model_ref, model_ref.variables, inputs)
            # _, _, output_ref = model_ref._energy_and_forces(model_ref.variables, data)
        output = evaluate(model, variables, inputs)
        # _, _, output = model._energy_and_forces(variables, data)
        rmses = {}
        maes = {}
        if return_targets:
            targets = {}

        natoms = jnp.where(inputs["true_sys"], inputs["natoms"], 1)
        nsys = inputs["natoms"].shape[0]
        system_mask = inputs["true_sys"]
        atom_mask = inputs["true_atoms"]
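        # Masks over valid entries: true_sys/true_atoms flag real (non-padding)
        # systems/atoms in the batch; if system_sign is present, only systems
        # with positive sign contribute to the metrics.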
        if "system_sign" in inputs:
            system_sign = inputs["system_sign"] > 0
            system_mask = jnp.logical_and(system_mask, system_sign)
            atom_mask = jnp.logical_and(atom_mask, system_sign[inputs["batch_index"]])

        for name, loss_prms in loss_definition.items():
            do_validation = loss_prms.get("validate", True)
            if not do_validation:
                continue
            predicted = output[loss_prms["key"]]
            use_ref_mask = False

            if "remove_ref_sys" in loss_prms and loss_prms["remove_ref_sys"]:
                assert compute_ref_coords, "compute_ref_coords must be True"
                # predicted = predicted - output_data_ref[loss_prms["key"]]
                assert (
                    predicted.shape[0] == nsys
                ), "remove_ref_sys only works with system-level predictions"
                shape_mask = [predicted.shape[0]] + [1] * (len(predicted.shape) - 1)
                system_sign = inputs["system_sign"].reshape(*shape_mask)
                predicted = jax.ops.segment_sum(
                    system_sign * predicted, inputs["system_index"], nsys
                )

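            # Resolve the reference: 'model_ref/...' reads from the reference
            # model output, 'model/...' from the current model output, anything
            # else from the dataset (with an optional '<ref>_mask' entry).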
            if "ref" in loss_prms:
                if loss_prms["ref"].startswith("model_ref/"):
                    assert model_ref is not None, "model_ref must be provided"
                    ref = output_ref[loss_prms["ref"][10:]] * loss_prms["mult"]
                elif loss_prms["ref"].startswith("model/"):
                    ref = output[loss_prms["ref"][6:]] * loss_prms["mult"]
                else:
                    ref = data[loss_prms["ref"]] * loss_prms["mult"]
                    if loss_prms["ref"] + "_mask" in data:
                        use_ref_mask = True
                        ref_mask = data[loss_prms["ref"] + "_mask"]
            else:
                ref = jnp.zeros_like(predicted)

            if "norm_axis" in loss_prms and loss_prms["norm_axis"] is not None:
                norm_axis = loss_prms["norm_axis"]
                predicted = jnp.linalg.norm(predicted, axis=norm_axis, keepdims=True)
                ref = jnp.linalg.norm(ref, axis=norm_axis, keepdims=True)

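            # For ensemble and mixture models, validate the (weight-)averaged
            # prediction rather than the full predictive distribution.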
            if loss_prms["type"].startswith("ensemble"):
                axis = loss_prms.get("ensemble_axis", -1)
                ensemble_weights_key = loss_prms.get("ensemble_weights", None)
                if ensemble_weights_key is not None:
                    ensemble_weights = output[ensemble_weights_key]
                else:
                    ensemble_weights = jnp.ones_like(predicted)
                ensemble_weights = ensemble_weights / jnp.sum(
                    ensemble_weights, axis=axis, keepdims=True
                )

                # predicted = predicted.mean(axis=axis)
                predicted = (ensemble_weights * predicted).sum(axis=axis)

            if loss_prms["type"] in ["gaussian_mixture"]:
                axis = loss_prms.get("ensemble_axis", -1)
                gm_w = output[loss_prms["key"] + "_w"]
                gm_w = gm_w / jnp.sum(gm_w, axis=axis, keepdims=True)
                predicted = (predicted * gm_w).sum(axis=axis)

            if predicted.ndim > 1 and predicted.shape[-1] == 1:
                predicted = jnp.squeeze(predicted, axis=-1)

            if ref.ndim > 1 and ref.shape[-1] == 1:
                ref = jnp.squeeze(ref, axis=-1)

            shape_mask = [ref.shape[0]] + [1] * (len(predicted.shape) - 1)
            if ref.shape[0] == output["batch_index"].shape[0]:
                ## shape is number of atoms
                truth_mask = atom_mask
            elif ref.shape[0] == natoms.shape[0]:
                ## shape is number of systems
                truth_mask = system_mask
                if loss_prms.get("per_atom_validation", False):
                    ref = ref / natoms.reshape(*shape_mask)
                    predicted = predicted / natoms.reshape(*shape_mask)
            else:
                truth_mask = jnp.ones(ref.shape[0], dtype=bool)

            if use_ref_mask:
                truth_mask = truth_mask * ref_mask.astype(bool)

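            # nel is the effective number of scalar elements passing the mask
            # (elements per entry times number of valid entries), clipped to at
            # least 1 to avoid division by zero.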
            nel = jnp.maximum(
                (float(np.prod(ref.shape)) / float(truth_mask.shape[0]))
                * jnp.sum(truth_mask).astype(predicted.dtype),
                1.0,
            )

            truth_mask = truth_mask.reshape(*shape_mask)

            ref = ref * truth_mask
            predicted = predicted * truth_mask

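            # Masked entries contribute zero to the sums, so dividing by nel
            # yields means over valid elements only.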
            rmse = jnp.sum((predicted - ref) ** 2 / nel) ** 0.5
            mae = jnp.sum(jnp.abs(predicted - ref) / nel)

            rmses[name] = rmse
            maes[name] = mae
            if return_targets:
                targets[name] = (predicted, ref, truth_mask)

        if return_targets:
            return rmses, maes, output, targets

        return rmses, maes, output

    if jit:
        return jax.jit(validation)
    return validation
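
# Usage sketch (hypothetical names, for illustration): `evaluate` is expected
# to map (model, variables, inputs) to the model output dictionary, as in the
# calls above.
#
#   validate = get_validation_function(loss_definition, model, evaluate=my_evaluate)
#   rmses, maes, output = validate(data, inputs, model.variables)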