fennol.training.utils

  1import jax
  2import jax.numpy as jnp
  3import numpy as np
  4from typing import Callable, Optional, Dict, List, Tuple, Union, Any, NamedTuple
  5import optax
  6from copy import deepcopy
  7from functools import partial
  8
  9from ..utils import deep_update
 10from ..utils.atomic_units import au
 11from ..models import FENNIX
 12        
 13
 14@partial(jax.jit, static_argnums=(1,2,3,4))
 15def linear_schedule(step, start_value,end_value, start_step, duration):
 16    return start_value + jnp.clip((step - start_step) / duration, 0.0, 1.0) * (end_value - start_value)
 17
 18
 19def get_training_parameters(
 20    parameters: Dict[str, any], stage: int = -1
 21) -> Dict[str, any]:
 22    params = deepcopy(parameters["training"])
 23    if "stages" not in params:
 24        return params
 25
 26    stages: dict = params.pop("stages")
 27    stage_keys = list(stages.keys())
 28    if stage < 0:
 29        stage = len(stage_keys) + stage
 30    assert stage >= 0 and stage < len(
 31        stage_keys
 32    ), f"Stage {stage} not found in training parameters"
 33    for i in range(stage + 1):
 34        ## remove end_event from previous stage ##
 35        if i > 0 and "end_event" in params:
 36            params.pop("end_event")
 37        ## incrementally update training parameters ##
 38        stage_params = stages[stage_keys[i]]
 39        params = deep_update(params, stage_params)
 40    return params
 41
 42
 43def get_loss_definition(
 44    training_parameters: Dict[str, any],
 45    model_energy_unit: str = "Ha",  # , manual_renames: List[str] = []
 46) -> Tuple[Dict[str, any], List[str], List[str]]:
 47    """
 48    Returns the loss definition and a list of renamed references.
 49
 50    Args:
 51        training_parameters (dict): A dictionary containing training parameters.
 52
 53    Returns:
 54        tuple: A tuple containing:
 55            - loss_definition (dict): A dictionary containing the loss definition.
 56            - rename_refs (list): A list of renamed references.
 57    """
 58    default_loss_type = training_parameters.get("default_loss_type", "log_cosh")
 59    # loss_definition = deepcopy(training_parameters["loss"])
 60    used_keys = []
 61    ref_keys = []
 62    energy_mult = au.get_multiplier(model_energy_unit)
 63    loss_definition = {}
 64    for k in training_parameters["loss"]:
 65        # loss_prms = loss_definition[k]
 66        loss_prms = deepcopy(training_parameters["loss"][k])
 67        if "energy_unit" in loss_prms:
 68            loss_prms["mult"] = energy_mult / au.get_multiplier(
 69                loss_prms["energy_unit"]
 70            )
 71            if "unit" in loss_prms:
 72                print(
 73                    "Warning: Both 'unit' and 'energy_unit' are defined for loss component",
 74                    k,
 75                    " -> using 'energy_unit'",
 76                )
 77            loss_prms["unit"] = loss_prms["energy_unit"]
 78        elif "unit" in loss_prms:
 79            loss_prms["mult"] = 1.0 / au.get_multiplier(loss_prms["unit"])
 80        else:
 81            loss_prms["mult"] = 1.0
 82        if "display_unit" not in loss_prms and "unit" in loss_prms:
 83            loss_prms["display_unit"] = loss_prms["unit"]
 84        au_convert = loss_prms["mult"] * au.get_multiplier(loss_prms.get("unit", "au"))
 85        loss_prms["display_mult"] = au_convert/au.get_multiplier(loss_prms.get("display_unit", "au"))
 86
 87        if "key" not in loss_prms:
 88            loss_prms["key"] = k
 89        if "type" not in loss_prms:
 90            loss_prms["type"] = default_loss_type
 91        if "weight" not in loss_prms:
 92            loss_prms["weight"] = 1.0
 93        assert loss_prms["weight"] >= 0.0, "Loss weight must be positive"
 94        if "threshold" in loss_prms:
 95            assert loss_prms["threshold"] > 1.0, "Threshold must be greater than 1.0"
 96        if "ref" in loss_prms:
 97            ref = loss_prms["ref"]
 98            if not (ref.startswith("model_ref/") or ref.startswith("model/")):
 99                ref_keys.append(ref)
100        if "ds_weight" in loss_prms:
101            ref_keys.append(loss_prms["ds_weight"])
102
103        if "weight_start" in loss_prms:
104            weight_start = loss_prms["weight_start"]
105            if "weight_ramp" in loss_prms:
106                weight_ramp = loss_prms["weight_ramp"]
107            else:
108                weight_ramp = training_parameters.get("max_epochs")
109            weight_ramp_start = loss_prms.get("weight_ramp_start", 0.0)
110            weight_end = loss_prms["weight"]
111            print(f"Weight ramp for {k} : {weight_start} -> {weight_end} in {weight_ramp} epochs")
112            loss_prms["weight_schedule"] = (weight_start,weight_end,weight_ramp_start,weight_ramp)
113
114        used_keys.append(loss_prms["key"])
115        loss_definition[k] = loss_prms
116
117    # rename_refs = list(
118    #     set(["forces", "total_energy", "atomic_energies"] + manual_renames + used_keys)
119    # )
120
121    # for k in loss_definition.keys():
122    #     loss_prms = loss_definition[k]
123    #     if "ref" in loss_prms:
124    #         if loss_prms["ref"] in rename_refs:
125    #             loss_prms["ref"] = "true_" + loss_prms["ref"]
126
127    return loss_definition, list(set(used_keys)), list(set(ref_keys))
128
129
def get_train_step_function(
    loss_definition: Dict,
    model: FENNIX,
    evaluate: Callable,
    optimizer: optax.GradientTransformation,
    ema: Optional[optax.GradientTransformation] = None,
    model_ref: Optional[FENNIX] = None,
    compute_ref_coords: bool = False,
    jit: bool = True,
):
    """Build the function that performs one optimization step.

    Args:
        loss_definition: Mapping of loss-component name to its parameters, as
            produced by ``get_loss_definition`` (each entry needs at least
            ``"key"``, ``"type"``, ``"mult"`` and ``"weight"``).
        model: Model being trained; passed as first argument to ``evaluate``.
        evaluate: Callable invoked as ``evaluate(model, variables, inputs)``
            and returning the dictionary of model outputs.
        optimizer: Optax transformation used to update the variables.
        ema: Optional transformation called as ``ema.update(variables,
            ema_state)``; its result is stored under ``"variables_ema"`` and
            ``"ema_state"`` of the training state.
        model_ref: Optional reference model, evaluated with its own
            ``variables``; required by references starting with
            ``"model_ref/"``.
        compute_ref_coords: Must be True when a loss component enables
            ``"remove_ref_sys"``.
        jit: Whether to wrap the returned function with ``jax.jit``.

    Returns:
        ``train_step(data, inputs, training_state)`` returning
        ``(loss, new_training_state, output)``. ``training_state`` must
        provide ``"epoch"``, ``"step"``, ``"variables"`` and ``"opt_state"``
        (plus ``"ema_state"`` when ``ema`` is given).
    """

    def train_step(
        data,
        inputs,
        training_state,
    ):
        epoch = training_state["epoch"]

        def loss_fn(variables):
            """Return the total weighted loss and the model output (aux)."""
            if model_ref is not None:
                output_ref = evaluate(model_ref, model_ref.variables, inputs)
            output = evaluate(model, variables, inputs)
            # Entries that are not true systems get natoms = 1
            # (presumably so that divisions by natoms stay finite).
            natoms = jnp.where(inputs["true_sys"], inputs["natoms"], 1)
            nsys = inputs["natoms"].shape[0]
            system_mask = inputs["true_sys"]
            atom_mask = inputs["true_atoms"]
            batch_index = inputs["batch_index"]
            if "system_sign" in inputs:
                # Only systems with a positive sign contribute to the loss.
                system_sign = inputs["system_sign"] > 0
                system_mask = jnp.logical_and(system_mask, system_sign)
                atom_mask = jnp.logical_and(atom_mask, system_sign[batch_index])

            loss_tot = 0.0
            for loss_prms in loss_definition.values():
                use_ref_mask = False
                predicted = output[loss_prms["key"]]

                ## fetch the reference values and an optional "<name>_mask" ##
                if "ref" in loss_prms:
                    if loss_prms["ref"].startswith("model_ref/"):
                        assert model_ref is not None, "model_ref must be provided"
                        # strip the 10-character "model_ref/" prefix
                        ref_name = loss_prms["ref"][10:]
                        try:
                            ref = output_ref[ref_name] * loss_prms["mult"]
                        except KeyError as err:
                            raise KeyError(
                                f"Reference key '{ref_name}' not found in model_ref output. Keys available: {output_ref.keys()}"
                            ) from err
                        mask_key = ref_name + "_mask"
                        if mask_key in data:
                            use_ref_mask = True
                            ref_mask = data[mask_key]
                        elif mask_key in output_ref:
                            use_ref_mask = True
                            ref_mask = output_ref[mask_key]
                        elif mask_key in output:
                            use_ref_mask = True
                            ref_mask = output[mask_key]
                    elif loss_prms["ref"].startswith("model/"):
                        # strip the 6-character "model/" prefix
                        ref_name = loss_prms["ref"][6:]
                        try:
                            ref = output[ref_name] * loss_prms["mult"]
                        except KeyError as err:
                            raise KeyError(
                                f"Reference key '{ref_name}' not found in model output. Keys available: {output.keys()}"
                            ) from err
                        mask_key = ref_name + "_mask"
                        if mask_key in data:
                            use_ref_mask = True
                            ref_mask = data[mask_key]
                        elif mask_key in output:
                            use_ref_mask = True
                            ref_mask = output[mask_key]
                    else:
                        try:
                            ref = data[loss_prms["ref"]] * loss_prms["mult"]
                        except KeyError as err:
                            raise KeyError(
                                f"Reference key '{loss_prms['ref']}' not found in data. Keys available: {data.keys()}"
                            ) from err
                        if loss_prms["ref"] + "_mask" in data:
                            use_ref_mask = True
                            ref_mask = data[loss_prms["ref"] + "_mask"]
                else:
                    # without reference, the prediction is driven towards zero
                    ref = jnp.zeros_like(predicted)

                if loss_prms.get("norm_axis") is not None:
                    norm_axis = loss_prms["norm_axis"]
                    predicted = jnp.linalg.norm(
                        predicted, axis=norm_axis, keepdims=True
                    )
                    ref = jnp.linalg.norm(ref, axis=norm_axis, keepdims=True)

                ## ensemble losses: weighted mean and variance over the ensemble axis ##
                ensemble_loss = loss_prms["type"] in [
                    "ensemble_nll",
                    "ensemble_crps",
                    "gaussian_mixture",
                ]
                if ensemble_loss:
                    ensemble_axis = loss_prms.get("ensemble_axis", -1)
                    ensemble_weights_key = loss_prms.get("ensemble_weights", None)
                    ensemble_size = predicted.shape[ensemble_axis]
                    if ensemble_weights_key is not None:
                        ensemble_weights = output[ensemble_weights_key]
                    else:
                        ensemble_weights = jnp.ones_like(predicted)
                    if "ensemble_subsample" in loss_prms and "rng_key" in output:
                        ns = min(
                            loss_prms["ensemble_subsample"],
                            predicted.shape[ensemble_axis],
                        )
                        key, subkey = jax.random.split(output["rng_key"])

                        def get_subsample(x):
                            # The same subkey is used for every array, so
                            # they are all permuted consistently.
                            return jax.lax.slice_in_dim(
                                jax.random.permutation(
                                    subkey, x, axis=ensemble_axis, independent=True
                                ),
                                start_index=0,
                                limit_index=ns,
                                axis=ensemble_axis,
                            )

                        predicted = get_subsample(predicted)
                        ensemble_weights = get_subsample(ensemble_weights)
                        output["rng_key"] = key
                    else:

                        def get_subsample(x):
                            return x

                    ensemble_weights = ensemble_weights / jnp.sum(
                        ensemble_weights, axis=ensemble_axis, keepdims=True
                    )
                    predicted_ensemble = predicted
                    predicted = (predicted * ensemble_weights).sum(
                        axis=ensemble_axis, keepdims=True
                    )
                    predicted_var = (
                        ensemble_weights * (predicted_ensemble - predicted) ** 2
                    ).sum(axis=ensemble_axis) * (ensemble_size / (ensemble_size - 1.0))
                    predicted = jnp.squeeze(predicted, axis=ensemble_axis)

                ## subtract the contribution of negative-sign systems ##
                if loss_prms.get("remove_ref_sys", False):
                    assert compute_ref_coords, "compute_ref_coords must be True"
                    shape_mask = [predicted.shape[0]] + [1] * (len(predicted.shape) - 1)
                    if predicted.shape[0] == nsys:
                        ## system-level prediction
                        predicted_sign_mask = (
                            inputs["system_sign"].reshape(*shape_mask) < 0
                        )
                        predicted_shift = jax.ops.segment_sum(
                            predicted_sign_mask * predicted,
                            inputs["system_index"],
                            nsys,
                        )
                        predicted = predicted - predicted_shift
                    elif predicted.shape[0] == batch_index.shape[0]:
                        ## atom-level prediction
                        predicted_sign_mask = (inputs["system_sign"] < 0)[
                            batch_index
                        ].reshape(*shape_mask)
                        # first-atom offset of each system, then position of
                        # each atom inside its own system
                        natshift = jnp.concatenate(
                            [jnp.array([0]), jnp.cumsum(natoms)]
                        )
                        iat = jnp.arange(batch_index.shape[0]) - natshift[batch_index]
                        # same position inside the system designated by system_index
                        iatdest = natshift[inputs["system_index"]][batch_index] + iat
                        predicted_shift = jax.ops.segment_sum(
                            predicted_sign_mask * predicted,
                            iatdest,
                            batch_index.shape[0],
                        )
                        predicted = predicted - predicted_shift
                    else:
                        raise ValueError("remove_ref_sys only works with system-level or atom-level predictions")

                    if ensemble_loss:
                        predicted_ensemble = predicted_ensemble - jnp.expand_dims(
                            predicted_shift, ensemble_axis
                        )
                        predicted_var = (
                            ensemble_weights
                            * (
                                predicted_ensemble
                                - jnp.expand_dims(predicted, ensemble_axis)
                            )
                            ** 2
                        ).sum(axis=ensemble_axis) * (
                            ensemble_size / (ensemble_size - 1.0)
                        )

                if predicted.ndim > 1 and predicted.shape[-1] == 1:
                    predicted = jnp.squeeze(predicted, axis=-1)

                if ref.ndim > 1 and ref.shape[-1] == 1:
                    ref = jnp.squeeze(ref, axis=-1)

                ## choose the padding mask and the scaling from the leading dimension ##
                per_atom = False
                shape_mask = [ref.shape[0]] + [1] * (len(ref.shape) - 1)
                natscale = 1.0
                if ref.shape[0] == output["batch_index"].shape[0]:
                    ## shape is number of atoms
                    truth_mask = atom_mask
                    if "ds_weight" in loss_prms:
                        weight_key = loss_prms["ds_weight"]
                        natscale = data[weight_key]
                        if natscale.shape[0] == natoms.shape[0]:
                            # per-system weights are broadcast to the atoms
                            natscale = natscale[output["batch_index"]]
                        natscale = natscale.reshape(*shape_mask)
                elif ref.shape[0] == natoms.shape[0]:
                    ## shape is number of systems
                    truth_mask = system_mask
                    if loss_prms.get("per_atom", False):
                        per_atom = True
                        ref = ref / natoms.reshape(*shape_mask)
                        predicted = predicted / natoms.reshape(*shape_mask)
                    if "nat_pow" in loss_prms:
                        natscale = (
                            1.0 / natoms.reshape(*shape_mask) ** loss_prms["nat_pow"]
                        )
                    if "ds_weight" in loss_prms:
                        weight_key = loss_prms["ds_weight"]
                        natscale = natscale * data[weight_key].reshape(*shape_mask)
                else:
                    truth_mask = jnp.ones(ref.shape[0], dtype=bool)

                if use_ref_mask:
                    truth_mask = truth_mask * ref_mask.astype(bool)

                # number of unmasked elements (at least 1), used to average the loss
                nel = jnp.maximum(
                    (float(np.prod(ref.shape)) / float(truth_mask.shape[0]))
                    * jnp.sum(truth_mask).astype(jnp.float32),
                    1.0,
                )
                truth_mask = truth_mask.reshape(*shape_mask)

                ref = ref * truth_mask
                predicted = predicted * truth_mask

                natscale = natscale / nel

                loss_type = loss_prms["type"]
                if loss_type == "mse":
                    loss = jnp.sum(natscale * (predicted - ref) ** 2)
                elif loss_type == "log_cosh":
                    loss = jnp.sum(natscale * optax.log_cosh(predicted, ref))
                elif loss_type == "mae":
                    loss = jnp.sum(natscale * jnp.abs(predicted - ref))
                elif loss_type == "rel_mse":
                    eps = loss_prms.get("rel_eps", 1.0e-6)
                    rel_pow = loss_prms.get("rel_pow", 2.0)
                    loss = jnp.sum(
                        natscale
                        * (predicted - ref) ** 2
                        / (eps + jnp.abs(ref) ** rel_pow)
                    )
                elif loss_type == "rmse+mae":
                    loss = (
                        jnp.sum(natscale * (predicted - ref) ** 2)
                    ) ** 0.5 + jnp.sum(natscale * jnp.abs(predicted - ref))
                elif loss_type == "crps":
                    predicted_var_key = loss_prms.get(
                        "var_key", loss_prms["key"] + "_var"
                    )
                    lambda_crps = loss_prms.get("lambda_crps", 1.0)
                    predicted_var = output[predicted_var_key]
                    if per_atom:
                        predicted_var = predicted_var / natoms.reshape(*shape_mask) ** 2
                    # masked entries get unit variance to keep the division finite
                    predicted_var = predicted_var * truth_mask + (1.0 - truth_mask)
                    sigma = predicted_var**0.5
                    dy = (ref - predicted) / sigma
                    Phi = 0.5 * (1.0 + jax.scipy.special.erf(dy / 2**0.5))
                    phi = jnp.exp(-0.5 * dy**2) / (2 * jnp.pi) ** 0.5
                    loss = jnp.sum(
                        natscale
                        * sigma
                        * (
                            dy * (2 * Phi - 1.0)
                            + 2 * phi
                            - lambda_crps / jnp.pi**0.5
                        )
                    )
                elif loss_type == "ensemble_nll":
                    predicted_var = predicted_var * truth_mask + (1.0 - truth_mask)
                    loss = 0.5 * jnp.sum(
                        natscale
                        * truth_mask
                        * (
                            jnp.log(predicted_var)
                            + (ref - predicted) ** 2 / predicted_var
                        )
                    )
                elif loss_type == "ensemble_crps":
                    if per_atom:
                        predicted_var = predicted_var / natoms.reshape(*shape_mask) ** 2
                    lambda_crps = loss_prms.get("lambda_crps", 1.0)
                    predicted_var = predicted_var * truth_mask + (1.0 - truth_mask)
                    sigma = predicted_var**0.5
                    dy = (ref - predicted) / sigma
                    Phi = 0.5 * (1.0 + jax.scipy.special.erf(dy / 2**0.5))
                    phi = jnp.exp(-0.5 * dy**2) / (2 * jnp.pi) ** 0.5
                    loss = jnp.sum(
                        natscale
                        * truth_mask
                        * sigma
                        * (
                            dy * (2 * Phi - 1.0)
                            + 2 * phi
                            - lambda_crps / jnp.pi**0.5
                        )
                    )
                elif loss_type == "gaussian_mixture":
                    gm_w = get_subsample(output[loss_prms["key"] + "_w"])
                    gm_w = gm_w / jnp.sum(gm_w, axis=ensemble_axis, keepdims=True)
                    sigma = get_subsample(output[loss_prms["key"] + "_sigma"])
                    ref = jnp.expand_dims(ref, axis=ensemble_axis)

                    # CRPS loss
                    def compute_crps(mu, sigma):
                        dy = mu / sigma
                        Phi = 0.5 * (1.0 + jax.scipy.special.erf(dy / 2**0.5))
                        phi = jnp.exp(-0.5 * dy**2) / (2 * jnp.pi) ** 0.5
                        return sigma * (dy * (2 * Phi - 1.0) + 2 * phi - 1./jnp.pi**0.5)

                    # term between the reference and each mixture component
                    crps1 = compute_crps(ref - predicted_ensemble, sigma)
                    gm_crps1 = (gm_w * crps1).sum(axis=ensemble_axis)

                    # pairwise term between mixture components; it needs a
                    # non-negative axis index to insert the two new axes
                    if ensemble_axis < 0:
                        ensemble_axis = predicted_ensemble.ndim + ensemble_axis
                    muij = jnp.expand_dims(
                        predicted_ensemble, axis=ensemble_axis
                    ) - jnp.expand_dims(predicted_ensemble, axis=ensemble_axis + 1)
                    sigma2 = sigma**2
                    sigmaij = (
                        jnp.expand_dims(sigma2, axis=ensemble_axis)
                        + jnp.expand_dims(sigma2, axis=ensemble_axis + 1)
                    ) ** 0.5
                    wij = jnp.expand_dims(gm_w, axis=ensemble_axis) * jnp.expand_dims(
                        gm_w, axis=ensemble_axis + 1
                    )
                    crps2 = compute_crps(muij, sigmaij)
                    gm_crps2 = (wij * crps2).sum(
                        axis=(ensemble_axis, ensemble_axis + 1)
                    )

                    crps = gm_crps1 - 0.5 * gm_crps2
                    loss = jnp.sum(natscale * truth_mask * crps)

                elif loss_type == "evidential":
                    evidence = loss_prms["evidence_key"]
                    nu, alpha, beta = jnp.split(output[evidence], 3, axis=-1)
                    gamma = predicted
                    nu = nu.reshape(shape_mask)
                    alpha = alpha.reshape(shape_mask)
                    beta = beta.reshape(shape_mask)
                    # masked entries are set to 1.0 so the logs and ratios stay finite
                    nu = jnp.where(truth_mask, nu, 1.0)
                    alpha = jnp.where(truth_mask, alpha, 1.0)
                    beta = jnp.where(truth_mask, beta, 1.0)
                    omega = 2 * beta * (1 + nu)
                    lg = jax.scipy.special.gammaln(alpha) - jax.scipy.special.gammaln(
                        alpha + 0.5
                    )
                    ls = 0.5 * jnp.log(jnp.pi / nu) - alpha * jnp.log(omega)
                    lt = (alpha + 0.5) * jnp.log(omega + nu * (gamma - ref) ** 2)
                    wst = (
                        (beta * (1 + nu) / (alpha * nu)) ** 0.5
                        if loss_prms.get("normalize_evidence", True)
                        else 1.0
                    )
                    lr = (
                        loss_prms.get("lambda_evidence", 1.0)
                        * jnp.abs(gamma - ref)
                        * nu
                        / wst
                    )
                    r = loss_prms.get("evidence_ratio", 1.0)
                    le = (
                        loss_prms.get("lambda_evidence_diff", 0.0)
                        * (nu - r * 2 * alpha) ** 2
                    )
                    lb = loss_prms.get("lambda_evidence_beta", 0.0) * beta
                    loss = lg + ls + lt + lr + le + lb

                    loss = jnp.sum(natscale * loss * truth_mask)
                elif loss_type == "raw":
                    loss = jnp.sum(natscale * predicted)
                elif loss_type == "binary_crossentropy":
                    ref = ref.astype(predicted.dtype)
                    if loss_prms.get("raw_logits", False):
                        loss = jnp.sum(
                            natscale
                            * truth_mask
                            * (
                                -ref * jax.nn.log_sigmoid(predicted)
                                - (1.0 - ref) * jax.nn.log_sigmoid(-predicted)
                            )
                        )
                    else:
                        loss = jnp.sum(
                            natscale
                            * truth_mask
                            * (
                                -ref * jnp.log(predicted + 1.0e-10)
                                - (1.0 - ref) * jnp.log(1.0 - predicted + 1.0e-10)
                            )
                        )
                else:
                    raise ValueError(f"Unknown loss type: {loss_type}")

                if "weight_schedule" in loss_prms:
                    w = linear_schedule(epoch, *loss_prms["weight_schedule"])
                else:
                    w = loss_prms["weight"]

                loss_tot = loss_tot + w * loss

            return loss_tot, output

        variables = training_state["variables"]
        new_training_state = {**training_state, "step": training_state["step"] + 1}

        (loss, o), grad = jax.value_and_grad(loss_fn, has_aux=True)(variables)
        updates, new_training_state["opt_state"] = optimizer.update(
            grad, training_state["opt_state"], params=variables
        )
        variables = optax.apply_updates(variables, updates)
        new_training_state["variables"] = variables

        if ema is not None:
            ema_st = training_state.get("ema_state", None)
            if ema_st is None:
                raise ValueError(
                    "train_step was setup with ema but either variables_ema or ema_st was not provided"
                )
            new_training_state["variables_ema"], new_training_state["ema_state"] = (
                ema.update(variables, ema_st)
            )

        return loss, new_training_state, o

    if jit:
        return jax.jit(train_step)
    return train_step
526
527
def get_validation_function(
    loss_definition: Dict,
    model: FENNIX,
    evaluate: Callable,
    model_ref: Optional[FENNIX] = None,
    compute_ref_coords: bool = False,
    return_targets: bool = False,
    jit: bool = True,
):
    """Build the function that computes validation metrics for a batch.

    Args:
        loss_definition: Mapping of loss-component name to its parameters, as
            produced by ``get_loss_definition``. Components with
            ``"validate": False`` are skipped.
        model: Model being validated; passed as first argument to
            ``evaluate``.
        evaluate: Callable invoked as ``evaluate(model, variables, inputs)``
            and returning the dictionary of model outputs.
        model_ref: Optional reference model, evaluated with its own
            ``variables``; required by references starting with
            ``"model_ref/"``.
        compute_ref_coords: Must be True when a loss component enables
            ``"remove_ref_sys"``.
        return_targets: If True, also return for each component the tuple
            ``(predicted, ref, truth_mask)`` used to compute the metrics.
        jit: Whether to wrap the returned function with ``jax.jit``.

    Returns:
        ``validation(data, inputs, variables, inputs_ref=None)`` returning
        ``(rmses, maes, output)``, or ``(rmses, maes, output, targets)`` when
        ``return_targets`` is True. For ``"binary_crossentropy"`` components
        the ``rmses`` entry holds the mean cross-entropy and the ``maes``
        entry holds the percentage of elements whose absolute error exceeds
        ``"class_threshold"`` (default 0.05). ``inputs_ref`` is accepted but
        not used.
    """

    def validation(data, inputs, variables, inputs_ref=None):
        if model_ref is not None:
            output_ref = evaluate(model_ref, model_ref.variables, inputs)
        output = evaluate(model, variables, inputs)
        rmses = {}
        maes = {}
        if return_targets:
            targets = {}

        # Entries that are not true systems get natoms = 1
        # (presumably so that divisions by natoms stay finite).
        natoms = jnp.where(inputs["true_sys"], inputs["natoms"], 1)
        nsys = inputs["natoms"].shape[0]
        system_mask = inputs["true_sys"]
        atom_mask = inputs["true_atoms"]
        batch_index = inputs["batch_index"]
        if "system_sign" in inputs:
            # Only systems with a positive sign contribute to the metrics.
            system_sign = inputs["system_sign"] > 0
            system_mask = jnp.logical_and(system_mask, system_sign)
            atom_mask = jnp.logical_and(atom_mask, system_sign[batch_index])

        for name, loss_prms in loss_definition.items():
            do_validation = loss_prms.get("validate", True)
            if not do_validation:
                continue
            predicted = output[loss_prms["key"]]
            use_ref_mask = False

            ## signed sum of the predictions over the systems sharing a system_index ##
            if loss_prms.get("remove_ref_sys", False):
                assert compute_ref_coords, "compute_ref_coords must be True"
                shape_mask = [predicted.shape[0]] + [1] * (len(predicted.shape) - 1)
                if predicted.shape[0] == nsys:
                    ## system-level prediction
                    system_sign = inputs["system_sign"].reshape(*shape_mask)
                    predicted = jax.ops.segment_sum(
                        system_sign * predicted, inputs["system_index"], nsys
                    )
                elif predicted.shape[0] == batch_index.shape[0]:
                    ## atom-level prediction
                    atom_sign = inputs["system_sign"][batch_index].reshape(*shape_mask)
                    # first-atom offset of each system, then position of each
                    # atom inside its own system
                    natshift = jnp.concatenate([jnp.array([0]), jnp.cumsum(natoms)])
                    iat = jnp.arange(batch_index.shape[0]) - natshift[batch_index]
                    # same position inside the system designated by system_index
                    iatdest = natshift[inputs["system_index"]][batch_index] + iat
                    predicted = jax.ops.segment_sum(
                        atom_sign * predicted, iatdest, batch_index.shape[0]
                    )
                else:
                    raise ValueError("remove_ref_sys only works with system-level or atom-level predictions")

            ## fetch the reference values and an optional "<name>_mask" ##
            if "ref" in loss_prms:
                if loss_prms["ref"].startswith("model_ref/"):
                    assert model_ref is not None, "model_ref must be provided"
                    # strip the 10-character "model_ref/" prefix
                    ref_name = loss_prms["ref"][10:]
                    ref = output_ref[ref_name] * loss_prms["mult"]
                    mask_key = ref_name + "_mask"
                    if mask_key in data:
                        use_ref_mask = True
                        ref_mask = data[mask_key]
                    elif mask_key in output_ref:
                        use_ref_mask = True
                        ref_mask = output_ref[mask_key]
                    elif mask_key in output:
                        use_ref_mask = True
                        ref_mask = output[mask_key]
                elif loss_prms["ref"].startswith("model/"):
                    # strip the 6-character "model/" prefix
                    ref_name = loss_prms["ref"][6:]
                    ref = output[ref_name] * loss_prms["mult"]
                    mask_key = ref_name + "_mask"
                    if mask_key in data:
                        use_ref_mask = True
                        ref_mask = data[mask_key]
                    elif mask_key in output:
                        use_ref_mask = True
                        ref_mask = output[mask_key]
                else:
                    ref = data[loss_prms["ref"]] * loss_prms["mult"]
                    if loss_prms["ref"] + "_mask" in data:
                        use_ref_mask = True
                        ref_mask = data[loss_prms["ref"] + "_mask"]
            else:
                ref = jnp.zeros_like(predicted)

            if loss_prms.get("norm_axis") is not None:
                norm_axis = loss_prms["norm_axis"]
                predicted = jnp.linalg.norm(
                    predicted, axis=norm_axis, keepdims=True
                )
                ref = jnp.linalg.norm(ref, axis=norm_axis, keepdims=True)

            ## ensemble and mixture predictions are reduced to their weighted mean ##
            loss_type = loss_prms["type"].lower()
            if loss_type.startswith("ensemble"):
                axis = loss_prms.get("ensemble_axis", -1)
                ensemble_weights_key = loss_prms.get("ensemble_weights", None)
                if ensemble_weights_key is not None:
                    ensemble_weights = output[ensemble_weights_key]
                else:
                    ensemble_weights = jnp.ones_like(predicted)
                ensemble_weights = ensemble_weights / jnp.sum(
                    ensemble_weights, axis=axis, keepdims=True
                )

                predicted = (ensemble_weights * predicted).sum(axis=axis)

            if loss_prms["type"] in ["gaussian_mixture"]:
                axis = loss_prms.get("ensemble_axis", -1)
                gm_w = output[loss_prms["key"] + "_w"]
                gm_w = gm_w / jnp.sum(gm_w, axis=axis, keepdims=True)
                predicted = (predicted * gm_w).sum(axis=axis)

            if predicted.ndim > 1 and predicted.shape[-1] == 1:
                predicted = jnp.squeeze(predicted, axis=-1)

            if ref.ndim > 1 and ref.shape[-1] == 1:
                ref = jnp.squeeze(ref, axis=-1)

            ## choose the padding mask from the leading dimension ##
            shape_mask = [ref.shape[0]] + [1] * (len(predicted.shape) - 1)
            if ref.shape[0] == output["batch_index"].shape[0]:
                ## shape is number of atoms
                truth_mask = atom_mask
            elif ref.shape[0] == natoms.shape[0]:
                ## shape is number of systems
                truth_mask = system_mask
                if loss_prms.get("per_atom_validation", False):
                    ref = ref / natoms.reshape(*shape_mask)
                    predicted = predicted / natoms.reshape(*shape_mask)
            else:
                truth_mask = jnp.ones(ref.shape[0], dtype=bool)

            if use_ref_mask:
                truth_mask = truth_mask * ref_mask.astype(bool)

            # number of unmasked elements (at least 1), used to average the metrics
            nel = jnp.maximum(
                (float(np.prod(ref.shape)) / float(truth_mask.shape[0]))
                * jnp.sum(truth_mask).astype(predicted.dtype),
                1.0,
            )

            truth_mask = truth_mask.reshape(*shape_mask)

            ref = ref * truth_mask
            predicted = predicted * truth_mask

            if loss_type == "binary_crossentropy":
                ref = ref.astype(predicted.dtype)
                if loss_prms.get("raw_logits", False):
                    rmse = jnp.sum(
                        (
                            -ref * jax.nn.log_sigmoid(predicted)
                            - (1.0 - ref) * jax.nn.log_sigmoid(-predicted)
                        )
                        * truth_mask
                        / nel
                    )
                    # convert logits to probabilities for the threshold test below
                    predicted = truth_mask * jax.nn.sigmoid(predicted)
                else:
                    rmse = jnp.sum(
                        (
                            -ref * jnp.log(predicted + 1.0e-10)
                            - (1.0 - ref) * jnp.log(1.0 - predicted + 1.0e-10)
                        )
                        * truth_mask
                        / nel
                    )
                thr = loss_prms.get("class_threshold", 0.05)
                mae = jnp.sum(jnp.where(jnp.abs(predicted - ref) > thr, 1, 0) / nel) * 100
            else:
                rmse = jnp.sum((predicted - ref) ** 2 / nel) ** 0.5
                mae = jnp.sum(jnp.abs(predicted - ref) / nel)

            rmses[name] = rmse
            maes[name] = mae
            if return_targets:
                targets[name] = (predicted, ref, truth_mask)

        if return_targets:
            return rmses, maes, output, targets

        return rmses, maes, output

    if jit:
        return jax.jit(validation)
    return validation
@partial(jax.jit, static_argnums=(1, 2, 3, 4))
def linear_schedule(step, start_value, end_value, start_step, duration):
    """Linearly interpolate between ``start_value`` and ``end_value``.

    The value is ``start_value`` up to ``start_step``, ramps linearly over
    ``duration`` steps and stays at ``end_value`` afterwards.

    Args:
        step: Current step (may be a traced JAX value).
        start_value: Value returned before the ramp starts (static argument).
        end_value: Value returned once the ramp is finished (static argument).
        start_step: Step at which the ramp starts (static argument).
        duration: Length of the ramp, in the same unit as ``step``
            (static argument).

    Returns:
        The scheduled value as a JAX scalar.
    """
    if duration == 0:
        # A zero-length ramp is a step function. Dividing by zero would give
        # 0/0 = NaN exactly at step == start_step, so branch here instead
        # (duration is a static argument, hence a plain Python value).
        progress = jnp.where(step >= start_step, 1.0, 0.0)
    else:
        progress = jnp.clip((step - start_step) / duration, 0.0, 1.0)
    return start_value + progress * (end_value - start_value)
def get_training_parameters(
    parameters: Dict[str, Any], stage: int = -1
) -> Dict[str, Any]:
    """Return the training parameters that apply to a given training stage.

    The ``"training"`` section of ``parameters`` is deep-copied. If it has no
    ``"stages"`` entry, the copy is returned as is. Otherwise the stages are
    applied incrementally, in their dictionary order, up to and including the
    requested one, each stage overriding the parameters accumulated so far
    through ``deep_update``.

    Args:
        parameters: Full parameter dictionary; must contain a ``"training"`` key.
        stage: Index of the stage to build the parameters for. Negative values
            count from the last stage (``-1`` is the last stage).

    Returns:
        The training parameters for the requested stage, without the
        ``"stages"`` entry.

    Raises:
        AssertionError: If ``stage`` does not refer to an existing stage.
    """
    params = deepcopy(parameters["training"])
    if "stages" not in params:
        return params

    stages: dict = params.pop("stages")
    stage_keys = list(stages.keys())
    if stage < 0:
        stage = len(stage_keys) + stage
    assert (
        0 <= stage < len(stage_keys)
    ), f"Stage {stage} not found in training parameters"

    for i, stage_key in enumerate(stage_keys[: stage + 1]):
        if i > 0:
            # The end_event of a previous stage must not leak into later
            # stages; a stage that wants one has to define it itself.
            params.pop("end_event", None)
        # Incrementally update the training parameters with this stage.
        params = deep_update(params, stages[stage_key])
    return params
def get_loss_definition(
    training_parameters: Dict[str, Any],
    model_energy_unit: str = "Ha",
) -> Tuple[Dict[str, Any], List[str], List[str]]:
    """
    Build the loss definition from the training parameters.

    Each entry of ``training_parameters["loss"]`` is deep-copied and completed
    with defaults and derived values:

    - ``"mult"``: multiplier applied to the reference values, derived from
      ``"energy_unit"`` (relative to ``model_energy_unit``), from ``"unit"``,
      or 1.0 when neither is given. ``"energy_unit"`` takes precedence over
      ``"unit"`` and overwrites it.
    - ``"display_unit"`` (defaults to ``"unit"`` when available) and
      ``"display_mult"``.
    - ``"key"`` (defaults to the name of the loss component), ``"type"``
      (defaults to ``training_parameters["default_loss_type"]``, itself
      defaulting to ``"log_cosh"``) and ``"weight"`` (defaults to 1.0).
    - ``"weight_schedule"``: tuple ``(weight_start, weight_end,
      weight_ramp_start, weight_ramp)``, only when ``"weight_start"`` is given.
      ``weight_ramp`` falls back to ``training_parameters["max_epochs"]``.

    Args:
        training_parameters (dict): A dictionary containing training parameters.
        model_energy_unit (str): Energy unit used to convert loss components
            that define an ``"energy_unit"``.

    Returns:
        tuple: A tuple containing:
            - loss_definition (dict): The completed loss definition.
            - used_keys (list): Unique ``"key"`` values of the loss components.
            - ref_keys (list): Unique ``"ref"`` values that do not start with
              ``"model_ref/"`` or ``"model/"``, plus all ``"ds_weight"`` keys.

    Raises:
        AssertionError: If a weight is negative or a threshold is not > 1.0.
    """
    default_loss_type = training_parameters.get("default_loss_type", "log_cosh")
    used_keys = []
    ref_keys = []
    energy_mult = au.get_multiplier(model_energy_unit)
    loss_definition = {}
    for k in training_parameters["loss"]:
        # Work on a copy so the caller's parameters are never modified.
        loss_prms = deepcopy(training_parameters["loss"][k])

        # --- unit handling ---
        if "energy_unit" in loss_prms:
            loss_prms["mult"] = energy_mult / au.get_multiplier(
                loss_prms["energy_unit"]
            )
            if "unit" in loss_prms:
                print(
                    "Warning: Both 'unit' and 'energy_unit' are defined for loss component",
                    k,
                    " -> using 'energy_unit'",
                )
            loss_prms["unit"] = loss_prms["energy_unit"]
        elif "unit" in loss_prms:
            loss_prms["mult"] = 1.0 / au.get_multiplier(loss_prms["unit"])
        else:
            loss_prms["mult"] = 1.0
        if "display_unit" not in loss_prms and "unit" in loss_prms:
            loss_prms["display_unit"] = loss_prms["unit"]
        au_convert = loss_prms["mult"] * au.get_multiplier(loss_prms.get("unit", "au"))
        loss_prms["display_mult"] = au_convert / au.get_multiplier(
            loss_prms.get("display_unit", "au")
        )

        # --- defaults and sanity checks ---
        loss_prms.setdefault("key", k)
        loss_prms.setdefault("type", default_loss_type)
        loss_prms.setdefault("weight", 1.0)
        assert loss_prms["weight"] >= 0.0, "Loss weight must be non-negative"
        if "threshold" in loss_prms:
            assert loss_prms["threshold"] > 1.0, "Threshold must be greater than 1.0"

        # --- keys that are not taken from a model output ---
        if "ref" in loss_prms:
            ref = loss_prms["ref"]
            if not (ref.startswith("model_ref/") or ref.startswith("model/")):
                ref_keys.append(ref)
        if "ds_weight" in loss_prms:
            ref_keys.append(loss_prms["ds_weight"])

        # --- optional linear ramp of the loss weight ---
        if "weight_start" in loss_prms:
            weight_start = loss_prms["weight_start"]
            if "weight_ramp" in loss_prms:
                weight_ramp = loss_prms["weight_ramp"]
            else:
                # NOTE(review): this is None when 'max_epochs' is not set.
                weight_ramp = training_parameters.get("max_epochs")
            weight_ramp_start = loss_prms.get("weight_ramp_start", 0.0)
            weight_end = loss_prms["weight"]
            print(f"Weight ramp for {k} : {weight_start} -> {weight_end} in {weight_ramp} epochs")
            loss_prms["weight_schedule"] = (
                weight_start,
                weight_end,
                weight_ramp_start,
                weight_ramp,
            )

        used_keys.append(loss_prms["key"])
        loss_definition[k] = loss_prms

    return loss_definition, list(set(used_keys)), list(set(ref_keys))

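Returns the loss definition together with the lists of output keys and reference keys it uses.

Args: training_parameters (dict): A dictionary containing training parameters. model_energy_unit (str): Energy unit used to convert loss components that define an 'energy_unit' (default 'Ha').

Returns: tuple: A tuple containing: - loss_definition (dict): A dictionary containing the loss definition. - used_keys (list): The unique 'key' values of the loss components. - ref_keys (list): The unique 'ref' values that do not start with 'model_ref/' or 'model/', plus all 'ds_weight' keys.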

def get_train_step_function(
    loss_definition: Dict,
    model: FENNIX,
    evaluate: Callable,
    optimizer: optax.GradientTransformation,
    ema: Optional[optax.GradientTransformation] = None,
    model_ref: Optional[FENNIX] = None,
    compute_ref_coords: bool = False,
    jit: bool = True,
):
    """Build the function that performs one optimization step.

    Args:
        loss_definition: Dictionary of loss components. Each component is a
            dictionary providing at least ``"key"``, ``"type"``, ``"mult"`` and
            ``"weight"`` (presumably as produced by ``get_loss_definition``).
        model: Model being trained.
        evaluate: Callable invoked as ``evaluate(model, variables, inputs)``
            and returning the dictionary of model outputs.
        optimizer: Optax gradient transformation used to update the variables.
        ema: Optional transformation invoked as ``ema.update(variables,
            ema_state)`` and returning ``(variables_ema, ema_state)``.
        model_ref: Optional reference model, evaluated with its own variables.
            Required by loss components whose ``"ref"`` starts with
            ``"model_ref/"``.
        compute_ref_coords: Must be True when a loss component sets
            ``"remove_ref_sys"``.
        jit: If True, the returned function is wrapped in ``jax.jit``.

    Returns:
        ``train_step(data, inputs, training_state)`` returning
        ``(loss, new_training_state, output)``. ``training_state`` is a
        dictionary that must contain ``"epoch"``, ``"step"``, ``"variables"``
        and ``"opt_state"`` (and ``"ema_state"`` when ``ema`` is given).
    """

    def train_step(
        data,
        inputs,
        training_state,
    ):
        epoch = training_state["epoch"]

        def loss_fn(variables):
            if model_ref is not None:
                output_ref = evaluate(model_ref, model_ref.variables, inputs)
            output = evaluate(model, variables, inputs)
            # Padded systems get natoms=1 to avoid dividing by zero below.
            natoms = jnp.where(inputs["true_sys"], inputs["natoms"], 1)
            nsys = inputs["natoms"].shape[0]
            system_mask = inputs["true_sys"]
            atom_mask = inputs["true_atoms"]
            batch_index = inputs["batch_index"]
            if "system_sign" in inputs:
                # Only systems with a positive sign contribute to the loss.
                system_sign = inputs["system_sign"] > 0
                system_mask = jnp.logical_and(system_mask, system_sign)
                atom_mask = jnp.logical_and(atom_mask, system_sign[batch_index])

            loss_tot = 0.0
            for loss_prms in loss_definition.values():
                use_ref_mask = False
                predicted = output[loss_prms["key"]]

                # --- fetch the reference values and an optional mask ---
                if "ref" in loss_prms:
                    if loss_prms["ref"].startswith("model_ref/"):
                        assert model_ref is not None, "model_ref must be provided"
                        try:
                            ref = output_ref[loss_prms["ref"][10:]] * loss_prms["mult"]
                        except KeyError:
                            raise KeyError(
                                f"Reference key '{loss_prms['ref'][10:]}' not found in model_ref output. Keys available: {output_ref.keys()}"
                            )
                        # Strip the full "model_ref/" prefix (10 characters),
                        # consistently with the reference key above.
                        mask_key = loss_prms["ref"][10:] + "_mask"
                        if mask_key in data:
                            use_ref_mask = True
                            ref_mask = data[mask_key]
                        elif mask_key in output_ref:
                            use_ref_mask = True
                            # Read the mask from the dictionary it was found in.
                            ref_mask = output_ref[mask_key]
                        elif mask_key in output:
                            use_ref_mask = True
                            ref_mask = output[mask_key]
                    elif loss_prms["ref"].startswith("model/"):
                        try:
                            ref = output[loss_prms["ref"][6:]] * loss_prms["mult"]
                        except KeyError:
                            raise KeyError(
                                f"Reference key '{loss_prms['ref'][6:]}' not found in model output. Keys available: {output.keys()}"
                            )
                        mask_key = loss_prms["ref"][6:] + "_mask"
                        if mask_key in data:
                            use_ref_mask = True
                            ref_mask = data[mask_key]
                        elif mask_key in output:
                            use_ref_mask = True
                            ref_mask = output[mask_key]
                    else:
                        try:
                            ref = data[loss_prms["ref"]] * loss_prms["mult"]
                        except KeyError:
                            raise KeyError(
                                f"Reference key '{loss_prms['ref']}' not found in data. Keys available: {data.keys()}"
                            )
                        if loss_prms["ref"] + "_mask" in data:
                            use_ref_mask = True
                            ref_mask = data[loss_prms["ref"] + "_mask"]
                else:
                    # Without a reference, the prediction is compared to zero.
                    ref = jnp.zeros_like(predicted)

                if "norm_axis" in loss_prms and loss_prms["norm_axis"] is not None:
                    norm_axis = loss_prms["norm_axis"]
                    predicted = jnp.linalg.norm(
                        predicted, axis=norm_axis, keepdims=True
                    )
                    ref = jnp.linalg.norm(ref, axis=norm_axis, keepdims=True)

                # --- ensemble losses: reduce to weighted mean and variance ---
                ensemble_loss = loss_prms["type"] in [
                    "ensemble_nll",
                    "ensemble_crps",
                    "gaussian_mixture",
                ]
                if ensemble_loss:
                    ensemble_axis = loss_prms.get("ensemble_axis", -1)
                    ensemble_weights_key = loss_prms.get("ensemble_weights", None)
                    ensemble_size = predicted.shape[ensemble_axis]
                    if ensemble_weights_key is not None:
                        ensemble_weights = output[ensemble_weights_key]
                    else:
                        ensemble_weights = jnp.ones_like(predicted)
                    if "ensemble_subsample" in loss_prms and "rng_key" in output:
                        ns = min(
                            loss_prms["ensemble_subsample"],
                            predicted.shape[ensemble_axis],
                        )
                        key, subkey = jax.random.split(output["rng_key"])

                        def get_subsample(x):
                            # Random permutation along the ensemble axis,
                            # keeping the first ns members.
                            return jax.lax.slice_in_dim(
                                jax.random.permutation(
                                    subkey, x, axis=ensemble_axis, independent=True
                                ),
                                start_index=0,
                                limit_index=ns,
                                axis=ensemble_axis,
                            )

                        predicted = get_subsample(predicted)
                        ensemble_weights = get_subsample(ensemble_weights)
                        output["rng_key"] = key
                    else:
                        get_subsample = lambda x: x
                    ensemble_weights = ensemble_weights / jnp.sum(
                        ensemble_weights, axis=ensemble_axis, keepdims=True
                    )
                    predicted_ensemble = predicted
                    predicted = (predicted * ensemble_weights).sum(
                        axis=ensemble_axis, keepdims=True
                    )
                    predicted_var = (
                        ensemble_weights * (predicted_ensemble - predicted) ** 2
                    ).sum(axis=ensemble_axis) * (ensemble_size / (ensemble_size - 1.0))
                    predicted = jnp.squeeze(predicted, axis=ensemble_axis)

                # --- subtract the prediction of the reference systems ---
                if "remove_ref_sys" in loss_prms and loss_prms["remove_ref_sys"]:
                    assert compute_ref_coords, "compute_ref_coords must be True"
                    shape_mask = [predicted.shape[0]] + [1] * (len(predicted.shape) - 1)
                    if predicted.shape[0] == nsys:
                        # System-level prediction.
                        predicted_sign_mask = inputs["system_sign"].reshape(*shape_mask) < 0
                        predicted_shift = jax.ops.segment_sum(
                            predicted_sign_mask * predicted,
                            inputs["system_index"],
                            nsys,
                        )
                        predicted = predicted - predicted_shift
                    elif predicted.shape[0] == batch_index.shape[0]:
                        # Atom-level prediction.
                        predicted_sign_mask = (inputs["system_sign"] < 0)[
                            batch_index
                        ].reshape(*shape_mask)
                        natshift = jnp.concatenate([jnp.array([0]), jnp.cumsum(natoms)])
                        iat = jnp.arange(batch_index.shape[0]) - natshift[batch_index]
                        iatdest = natshift[inputs["system_index"]][batch_index] + iat
                        predicted_shift = jax.ops.segment_sum(
                            predicted_sign_mask * predicted,
                            iatdest,
                            batch_index.shape[0],
                        )
                        predicted = predicted - predicted_shift
                    else:
                        raise ValueError("remove_ref_sys only works with system-level or atom-level predictions")

                    if ensemble_loss:
                        predicted_ensemble = predicted_ensemble - jnp.expand_dims(
                            predicted_shift, ensemble_axis
                        )
                        predicted_var = (
                            ensemble_weights
                            * (predicted_ensemble - jnp.expand_dims(predicted, ensemble_axis))
                            ** 2
                        ).sum(axis=ensemble_axis) * (ensemble_size / (ensemble_size - 1.0))

                if predicted.ndim > 1 and predicted.shape[-1] == 1:
                    predicted = jnp.squeeze(predicted, axis=-1)

                if ref.ndim > 1 and ref.shape[-1] == 1:
                    ref = jnp.squeeze(ref, axis=-1)

                # --- choose the mask and per-element scaling ---
                per_atom = False
                shape_mask = [ref.shape[0]] + [1] * (len(ref.shape) - 1)
                natscale = 1.0
                if ref.shape[0] == output["batch_index"].shape[0]:
                    ## shape is number of atoms
                    truth_mask = atom_mask
                    if "ds_weight" in loss_prms:
                        weight_key = loss_prms["ds_weight"]
                        natscale = data[weight_key]
                        if natscale.shape[0] == natoms.shape[0]:
                            # Per-system weights are broadcast to the atoms.
                            natscale = natscale[output["batch_index"]]
                        natscale = natscale.reshape(*shape_mask)
                elif ref.shape[0] == natoms.shape[0]:
                    ## shape is number of systems
                    truth_mask = system_mask
                    if loss_prms.get("per_atom", False):
                        per_atom = True
                        ref = ref / natoms.reshape(*shape_mask)
                        predicted = predicted / natoms.reshape(*shape_mask)
                    if "nat_pow" in loss_prms:
                        natscale = (
                            1.0 / natoms.reshape(*shape_mask) ** loss_prms["nat_pow"]
                        )
                    if "ds_weight" in loss_prms:
                        weight_key = loss_prms["ds_weight"]
                        natscale = natscale * data[weight_key].reshape(*shape_mask)
                else:
                    truth_mask = jnp.ones(ref.shape[0], dtype=bool)

                if use_ref_mask:
                    truth_mask = truth_mask * ref_mask.astype(bool)

                # Number of unmasked elements (at least 1), used to average.
                nel = jnp.maximum(
                    (float(np.prod(ref.shape)) / float(truth_mask.shape[0]))
                    * jnp.sum(truth_mask).astype(jnp.float32),
                    1.0,
                )
                truth_mask = truth_mask.reshape(*shape_mask)

                ref = ref * truth_mask
                predicted = predicted * truth_mask

                natscale = natscale / nel

                # --- evaluate the loss of this component ---
                loss_type = loss_prms["type"]
                if loss_type == "mse":
                    loss = jnp.sum(natscale * (predicted - ref) ** 2)
                elif loss_type == "log_cosh":
                    loss = jnp.sum(natscale * optax.log_cosh(predicted, ref))
                elif loss_type == "mae":
                    loss = jnp.sum(natscale * jnp.abs(predicted - ref))
                elif loss_type == "rel_mse":
                    eps = loss_prms.get("rel_eps", 1.0e-6)
                    rel_pow = loss_prms.get("rel_pow", 2.0)
                    loss = jnp.sum(
                        natscale
                        * (predicted - ref) ** 2
                        / (eps + jnp.abs(ref) ** rel_pow)
                    )
                elif loss_type == "rmse+mae":
                    loss = (
                        jnp.sum(natscale * (predicted - ref) ** 2)
                    ) ** 0.5 + jnp.sum(natscale * jnp.abs(predicted - ref))
                elif loss_type == "crps":
                    predicted_var_key = loss_prms.get(
                        "var_key", loss_prms["key"] + "_var"
                    )
                    lamba_crps = loss_prms.get("lambda_crps", 1.0)
                    predicted_var = output[predicted_var_key]
                    if per_atom:
                        predicted_var = predicted_var / natoms.reshape(*shape_mask) ** 2
                    # Masked elements get unit variance to keep the math finite.
                    predicted_var = predicted_var * truth_mask + (1.0 - truth_mask)
                    sigma = predicted_var**0.5
                    dy = (ref - predicted) / sigma
                    Phi = 0.5 * (1.0 + jax.scipy.special.erf(dy / 2**0.5))
                    phi = jnp.exp(-0.5 * dy**2) / (2 * jnp.pi) ** 0.5
                    loss = jnp.sum(
                        natscale
                        * sigma
                        * (dy * (2 * Phi - 1.0) + 2 * phi - lamba_crps / jnp.pi**0.5)
                    )
                elif loss_type == "ensemble_nll":
                    predicted_var = predicted_var * truth_mask + (1.0 - truth_mask)
                    loss = 0.5 * jnp.sum(
                        natscale
                        * truth_mask
                        * (
                            jnp.log(predicted_var)
                            + (ref - predicted) ** 2 / predicted_var
                        )
                    )
                elif loss_type == "ensemble_crps":
                    if per_atom:
                        predicted_var = predicted_var / natoms.reshape(*shape_mask) ** 2
                    lamba_crps = loss_prms.get("lambda_crps", 1.0)
                    predicted_var = predicted_var * truth_mask + (1.0 - truth_mask)
                    sigma = predicted_var**0.5
                    dy = (ref - predicted) / sigma
                    Phi = 0.5 * (1.0 + jax.scipy.special.erf(dy / 2**0.5))
                    phi = jnp.exp(-0.5 * dy**2) / (2 * jnp.pi) ** 0.5
                    loss = jnp.sum(
                        natscale
                        * truth_mask
                        * sigma
                        * (dy * (2 * Phi - 1.0) + 2 * phi - lamba_crps / jnp.pi**0.5)
                    )
                elif loss_type == "gaussian_mixture":
                    gm_w = get_subsample(output[loss_prms["key"] + "_w"])
                    gm_w = gm_w / jnp.sum(gm_w, axis=ensemble_axis, keepdims=True)
                    sigma = get_subsample(output[loss_prms["key"] + "_sigma"])
                    ref = jnp.expand_dims(ref, axis=ensemble_axis)

                    # CRPS loss
                    def compute_crps(mu, sigma):
                        dy = mu / sigma
                        Phi = 0.5 * (1.0 + jax.scipy.special.erf(dy / 2**0.5))
                        phi = jnp.exp(-0.5 * dy**2) / (2 * jnp.pi) ** 0.5
                        return sigma * (dy * (2 * Phi - 1.0) + 2 * phi - 1.0 / jnp.pi**0.5)

                    # Term between the reference and each mixture component.
                    crps1 = compute_crps(ref - predicted_ensemble, sigma)
                    gm_crps1 = (gm_w * crps1).sum(axis=ensemble_axis)

                    # Pairwise term between mixture components; a non-negative
                    # axis index is needed to insert the two pair axes.
                    if ensemble_axis < 0:
                        ensemble_axis = predicted_ensemble.ndim + ensemble_axis
                    muij = jnp.expand_dims(
                        predicted_ensemble, axis=ensemble_axis
                    ) - jnp.expand_dims(predicted_ensemble, axis=ensemble_axis + 1)
                    sigma2 = sigma**2
                    sigmaij = (
                        jnp.expand_dims(sigma2, axis=ensemble_axis)
                        + jnp.expand_dims(sigma2, axis=ensemble_axis + 1)
                    ) ** 0.5
                    wij = jnp.expand_dims(gm_w, axis=ensemble_axis) * jnp.expand_dims(
                        gm_w, axis=ensemble_axis + 1
                    )
                    crps2 = compute_crps(muij, sigmaij)
                    gm_crps2 = (wij * crps2).sum(
                        axis=(ensemble_axis, ensemble_axis + 1)
                    )

                    crps = gm_crps1 - 0.5 * gm_crps2
                    loss = jnp.sum(natscale * truth_mask * crps)

                elif loss_type == "evidential":
                    evidence = loss_prms["evidence_key"]
                    nu, alpha, beta = jnp.split(output[evidence], 3, axis=-1)
                    gamma = predicted
                    nu = nu.reshape(shape_mask)
                    alpha = alpha.reshape(shape_mask)
                    beta = beta.reshape(shape_mask)
                    # Masked elements get neutral parameters (1.0).
                    nu = jnp.where(truth_mask, nu, 1.0)
                    alpha = jnp.where(truth_mask, alpha, 1.0)
                    beta = jnp.where(truth_mask, beta, 1.0)
                    omega = 2 * beta * (1 + nu)
                    lg = jax.scipy.special.gammaln(alpha) - jax.scipy.special.gammaln(
                        alpha + 0.5
                    )
                    ls = 0.5 * jnp.log(jnp.pi / nu) - alpha * jnp.log(omega)
                    lt = (alpha + 0.5) * jnp.log(omega + nu * (gamma - ref) ** 2)
                    wst = (
                        (beta * (1 + nu) / (alpha * nu)) ** 0.5
                        if loss_prms.get("normalize_evidence", True)
                        else 1.0
                    )
                    lr = (
                        loss_prms.get("lambda_evidence", 1.0)
                        * jnp.abs(gamma - ref)
                        * nu
                        / wst
                    )
                    r = loss_prms.get("evidence_ratio", 1.0)
                    le = (
                        loss_prms.get("lambda_evidence_diff", 0.0)
                        * (nu - r * 2 * alpha) ** 2
                    )
                    lb = loss_prms.get("lambda_evidence_beta", 0.0) * beta
                    loss = lg + ls + lt + lr + le + lb

                    loss = jnp.sum(natscale * loss * truth_mask)
                elif loss_type == "raw":
                    loss = jnp.sum(natscale * predicted)
                elif loss_type == "binary_crossentropy":
                    ref = ref.astype(predicted.dtype)
                    if loss_prms.get("raw_logits", False):
                        loss = jnp.sum(
                            natscale
                            * truth_mask
                            * (
                                -ref * jax.nn.log_sigmoid(predicted)
                                - (1.0 - ref) * jax.nn.log_sigmoid(-predicted)
                            )
                        )
                    else:
                        loss = jnp.sum(
                            natscale
                            * truth_mask
                            * (
                                -ref * jnp.log(predicted + 1.0e-10)
                                - (1.0 - ref) * jnp.log(1.0 - predicted + 1.0e-10)
                            )
                        )
                else:
                    raise ValueError(f"Unknown loss type: {loss_type}")

                # --- weight of this component (possibly epoch-dependent) ---
                if "weight_schedule" in loss_prms:
                    w = linear_schedule(epoch, *loss_prms["weight_schedule"])
                else:
                    w = loss_prms["weight"]

                loss_tot = loss_tot + w * loss

            return loss_tot, output

        variables = training_state["variables"]
        new_training_state = {**training_state, "step": training_state["step"] + 1}

        (loss, o), grad = jax.value_and_grad(loss_fn, has_aux=True)(variables)
        updates, new_training_state["opt_state"] = optimizer.update(
            grad, training_state["opt_state"], params=variables
        )
        variables = optax.apply_updates(variables, updates)
        new_training_state["variables"] = variables

        if ema is not None:
            ema_st = training_state.get("ema_state", None)
            if ema_st is None:
                raise ValueError(
                    "train_step was setup with ema but either variables_ema or ema_st was not provided"
                )
            new_training_state["variables_ema"], new_training_state["ema_state"] = ema.update(
                variables, ema_st
            )

        return loss, new_training_state, o

    if jit:
        return jax.jit(train_step)
    return train_step
def get_validation_function( loss_definition: Dict, model: fennol.models.fennix.FENNIX, evaluate: Callable, model_ref: Optional[fennol.models.fennix.FENNIX] = None, compute_ref_coords: bool = False, return_targets: bool = False, jit: bool = True):
def get_validation_function(
    loss_definition: Dict,
    model: FENNIX,
    evaluate: Callable,
    model_ref: Optional[FENNIX] = None,
    compute_ref_coords: bool = False,
    return_targets: bool = False,
    jit: bool = True,
):
    """Build a function that computes validation metrics for every loss term.

    Args:
        loss_definition: Mapping ``name -> parameters``. Each entry must
            provide ``"key"`` (key of the prediction in the model output) and
            ``"type"``. When ``"ref"`` is present, ``"mult"`` is required too.
            Optional entries read here: ``"validate"``, ``"remove_ref_sys"``,
            ``"norm_axis"``, ``"ensemble_axis"``, ``"ensemble_weights"``,
            ``"per_atom_validation"``, ``"raw_logits"``, ``"class_threshold"``.
        model: Model being validated; only forwarded to ``evaluate``.
        evaluate: Called as ``evaluate(model, variables, inputs)``. Its result
            is indexed like a dict and must contain ``"batch_index"``.
        model_ref: Optional reference model, evaluated with its own
            ``variables`` attribute. Required for ``"model_ref/..."`` refs.
        compute_ref_coords: Only checked (via ``assert``) for loss terms that
            set ``"remove_ref_sys"``.
        return_targets: If True, also return ``name -> (predicted, ref,
            truth_mask)`` for every validated term.
        jit: If True, wrap the returned function in ``jax.jit``.

    Returns:
        ``validation(data, inputs, variables, inputs_ref=None)`` returning
        ``(rmses, maes, output)`` or, with ``return_targets``,
        ``(rmses, maes, output, targets)``. ``inputs_ref`` is accepted but
        not used. For ``"binary_crossentropy"`` terms the ``rmses`` entry
        holds the masked mean cross-entropy and the ``maes`` entry holds the
        percentage of entries with ``|predicted - ref| > class_threshold``.

    Raises:
        AssertionError: If a term needs ``model_ref`` or
            ``compute_ref_coords`` and it was not provided.
        ValueError: If ``"remove_ref_sys"`` is used on a prediction that is
            neither system-level nor atom-level.
    """
    model_ref_prefix = "model_ref/"
    model_prefix = "model/"

    def _first_mask(mask_key, *sources):
        # Return the mask from the first mapping that contains it, else None.
        for source in sources:
            if mask_key in source:
                return source[mask_key]
        return None

    def validation(data, inputs, variables, inputs_ref=None):
        if model_ref is not None:
            output_ref = evaluate(model_ref, model_ref.variables, inputs)
        output = evaluate(model, variables, inputs)
        rmses = {}
        maes = {}
        if return_targets:
            targets = {}

        # Masked-out systems get natoms=1 so per-atom divisions stay finite.
        natoms = jnp.where(inputs["true_sys"], inputs["natoms"], 1)
        nsys = inputs["natoms"].shape[0]
        system_mask = inputs["true_sys"]
        atom_mask = inputs["true_atoms"]
        batch_index = inputs["batch_index"]
        if "system_sign" in inputs:
            # Only systems with a positive sign (and their atoms) are scored.
            positive_sys = inputs["system_sign"] > 0
            system_mask = jnp.logical_and(system_mask, positive_sys)
            atom_mask = jnp.logical_and(atom_mask, positive_sys[batch_index])

        for name, loss_prms in loss_definition.items():
            if not loss_prms.get("validate", True):
                continue
            predicted = output[loss_prms["key"]]
            ref_mask = None

            if loss_prms.get("remove_ref_sys", False):
                assert compute_ref_coords, "compute_ref_coords must be True"
                # Signed sum of predictions scattered to inputs["system_index"];
                # presumably combines each system with its reference system
                # (TODO confirm against the data loader).
                shape_mask = [predicted.shape[0]] + [1] * (len(predicted.shape) - 1)
                if predicted.shape[0] == nsys:
                    sys_sign = inputs["system_sign"].reshape(*shape_mask)
                    predicted = jax.ops.segment_sum(
                        sys_sign * predicted, inputs["system_index"], nsys
                    )
                elif predicted.shape[0] == batch_index.shape[0]:
                    atom_sign = inputs["system_sign"][batch_index].reshape(*shape_mask)
                    # First-atom offset of every system, then the index of each
                    # atom inside its own system.
                    natshift = jnp.concatenate([jnp.array([0]), jnp.cumsum(natoms)])
                    iat = jnp.arange(batch_index.shape[0]) - natshift[batch_index]
                    iatdest = natshift[inputs["system_index"]][batch_index] + iat
                    predicted = jax.ops.segment_sum(
                        atom_sign * predicted, iatdest, batch_index.shape[0]
                    )
                else:
                    raise ValueError("remove_ref_sys only works with system-level or atom-level predictions")

            if "ref" in loss_prms:
                ref_name = loss_prms["ref"]
                if ref_name.startswith(model_ref_prefix):
                    assert model_ref is not None, "model_ref must be provided"
                    # Strip the full "model_ref/" prefix for both the value and
                    # its mask, and read the mask from the mapping it was
                    # actually found in.
                    ref_key = ref_name[len(model_ref_prefix):]
                    ref = output_ref[ref_key] * loss_prms["mult"]
                    ref_mask = _first_mask(ref_key + "_mask", data, output_ref, output)
                elif ref_name.startswith(model_prefix):
                    ref_key = ref_name[len(model_prefix):]
                    ref = output[ref_key] * loss_prms["mult"]
                    ref_mask = _first_mask(ref_key + "_mask", data, output)
                else:
                    ref = data[ref_name] * loss_prms["mult"]
                    ref_mask = _first_mask(ref_name + "_mask", data)
            else:
                ref = jnp.zeros_like(predicted)

            norm_axis = loss_prms.get("norm_axis")
            if norm_axis is not None:
                # Compare norms along the requested axis instead of components.
                predicted = jnp.linalg.norm(predicted, axis=norm_axis, keepdims=True)
                ref = jnp.linalg.norm(ref, axis=norm_axis, keepdims=True)

            loss_type = loss_prms["type"].lower()
            if loss_type.startswith("ensemble"):
                # Collapse the ensemble axis to a (weighted) mean prediction.
                axis = loss_prms.get("ensemble_axis", -1)
                ensemble_weights_key = loss_prms.get("ensemble_weights", None)
                if ensemble_weights_key is not None:
                    ensemble_weights = output[ensemble_weights_key]
                else:
                    ensemble_weights = jnp.ones_like(predicted)
                ensemble_weights = ensemble_weights / jnp.sum(
                    ensemble_weights, axis=axis, keepdims=True
                )
                predicted = (ensemble_weights * predicted).sum(axis=axis)

            # Compared on the lowercased type, like the other type checks here.
            if loss_type == "gaussian_mixture":
                axis = loss_prms.get("ensemble_axis", -1)
                gm_w = output[loss_prms["key"] + "_w"]
                gm_w = gm_w / jnp.sum(gm_w, axis=axis, keepdims=True)
                predicted = (predicted * gm_w).sum(axis=axis)

            if predicted.ndim > 1 and predicted.shape[-1] == 1:
                predicted = jnp.squeeze(predicted, axis=-1)

            if ref.ndim > 1 and ref.shape[-1] == 1:
                ref = jnp.squeeze(ref, axis=-1)

            # Pick the padding mask from the leading dimension of the reference.
            shape_mask = [ref.shape[0]] + [1] * (len(predicted.shape) - 1)
            if ref.shape[0] == output["batch_index"].shape[0]:
                ## shape is number of atoms
                truth_mask = atom_mask
            elif ref.shape[0] == natoms.shape[0]:
                ## shape is number of systems
                truth_mask = system_mask
                if loss_prms.get("per_atom_validation", False):
                    ref = ref / natoms.reshape(*shape_mask)
                    predicted = predicted / natoms.reshape(*shape_mask)
            else:
                truth_mask = jnp.ones(ref.shape[0], dtype=bool)

            if ref_mask is not None:
                truth_mask = truth_mask * ref_mask.astype(bool)

            # Number of unmasked scalar entries (at least 1 to avoid 0-division):
            # entries per leading row times the number of unmasked rows.
            nel = jnp.maximum(
                (float(np.prod(ref.shape)) / float(truth_mask.shape[0]))
                * jnp.sum(truth_mask).astype(predicted.dtype),
                1.0,
            )

            truth_mask = truth_mask.reshape(*shape_mask)

            ref = ref * truth_mask
            predicted = predicted * truth_mask

            if loss_type == "binary_crossentropy":
                ref = ref.astype(predicted.dtype)
                if loss_prms.get("raw_logits", False):
                    rmse = jnp.sum(
                        (
                            -ref * jax.nn.log_sigmoid(predicted)
                            - (1.0 - ref) * jax.nn.log_sigmoid(-predicted)
                        )
                        * truth_mask
                        / nel
                    )
                    # Convert logits to probabilities for the threshold test.
                    predicted = truth_mask * jax.nn.sigmoid(predicted)
                else:
                    rmse = jnp.sum(
                        (
                            -ref * jnp.log(predicted + 1.0e-10)
                            - (1.0 - ref) * jnp.log(1.0 - predicted + 1.0e-10)
                        )
                        * truth_mask
                        / nel
                    )
                thr = loss_prms.get("class_threshold", 0.05)
                mae = jnp.sum(jnp.where(jnp.abs(predicted - ref) > thr, 1, 0) / nel) * 100
            else:
                rmse = jnp.sum((predicted - ref) ** 2 / nel) ** 0.5
                mae = jnp.sum(jnp.abs(predicted - ref) / nel)

            rmses[name] = rmse
            maes[name] = mae
            if return_targets:
                targets[name] = (predicted, ref, truth_mask)

        if return_targets:
            return rmses, maes, output, targets

        return rmses, maes, output

    if jit:
        return jax.jit(validation)
    return validation