fennol.training.utils

  1import jax
  2import jax.numpy as jnp
  3import numpy as np
  4from typing import Callable, Optional, Dict, List, Tuple, Union, Any, NamedTuple
  5import optax
  6from copy import deepcopy
  7from functools import partial
  8
  9from ..utils import deep_update, AtomicUnits as au
 10from ..models import FENNIX
 11        
 12
 13@partial(jax.jit, static_argnums=(1,2,3,4))
 14def linear_schedule(step, start_value,end_value, start_step, duration):
 15    return start_value + jnp.clip((step - start_step) / duration, 0.0, 1.0) * (end_value - start_value)
 16
 17
 18def get_training_parameters(
 19    parameters: Dict[str, any], stage: int = -1
 20) -> Dict[str, any]:
 21    params = deepcopy(parameters["training"])
 22    if "stages" not in params:
 23        return params
 24
 25    stages: dict = params.pop("stages")
 26    stage_keys = list(stages.keys())
 27    if stage < 0:
 28        stage = len(stage_keys) + stage
 29    assert stage >= 0 and stage < len(
 30        stage_keys
 31    ), f"Stage {stage} not found in training parameters"
 32    for i in range(stage + 1):
 33        ## remove end_event from previous stage ##
 34        if i > 0 and "end_event" in params:
 35            params.pop("end_event")
 36        ## incrementally update training parameters ##
 37        stage_params = stages[stage_keys[i]]
 38        params = deep_update(params, stage_params)
 39    return params
 40
 41
 42def get_loss_definition(
 43    training_parameters: Dict[str, any],
 44    model_energy_unit: str = "Ha",  # , manual_renames: List[str] = []
 45) -> Tuple[Dict[str, any], List[str], List[str]]:
 46    """
 47    Returns the loss definition and a list of renamed references.
 48
 49    Args:
 50        training_parameters (dict): A dictionary containing training parameters.
 51
 52    Returns:
 53        tuple: A tuple containing:
 54            - loss_definition (dict): A dictionary containing the loss definition.
 55            - rename_refs (list): A list of renamed references.
 56    """
 57    default_loss_type = training_parameters.get("default_loss_type", "log_cosh")
 58    # loss_definition = deepcopy(training_parameters["loss"])
 59    used_keys = []
 60    ref_keys = []
 61    energy_mult = au.get_multiplier(model_energy_unit)
 62    loss_definition = {}
 63    for k in training_parameters["loss"]:
 64        # loss_prms = loss_definition[k]
 65        loss_prms = deepcopy(training_parameters["loss"][k])
 66        if "energy_unit" in loss_prms:
 67            loss_prms["mult"] = energy_mult / au.get_multiplier(
 68                loss_prms["energy_unit"]
 69            )
 70            if "unit" in loss_prms:
 71                print(
 72                    "Warning: Both 'unit' and 'energy_unit' are defined for loss component",
 73                    k,
 74                    " -> using 'energy_unit'",
 75                )
 76            loss_prms["unit"] = loss_prms["energy_unit"]
 77        elif "unit" in loss_prms:
 78            loss_prms["mult"] = 1.0 / au.get_multiplier(loss_prms["unit"])
 79        else:
 80            loss_prms["mult"] = 1.0
 81        if "key" not in loss_prms:
 82            loss_prms["key"] = k
 83        if "type" not in loss_prms:
 84            loss_prms["type"] = default_loss_type
 85        if "weight" not in loss_prms:
 86            loss_prms["weight"] = 1.0
 87        assert loss_prms["weight"] >= 0.0, "Loss weight must be positive"
 88        if "threshold" in loss_prms:
 89            assert loss_prms["threshold"] > 1.0, "Threshold must be greater than 1.0"
 90        if "ref" in loss_prms:
 91            ref = loss_prms["ref"]
 92            if not (ref.startswith("model_ref/") or ref.startswith("model/")):
 93                ref_keys.append(ref)
 94        if "ds_weight" in loss_prms:
 95            ref_keys.append(loss_prms["ds_weight"])
 96
 97        if "weight_start" in loss_prms:
 98            weight_start = loss_prms["weight_start"]
 99            if "weight_ramp" in loss_prms:
100                weight_ramp = loss_prms["weight_ramp"]
101            else:
102                weight_ramp = training_parameters.get("max_epochs")
103            weight_ramp_start = loss_prms.get("weight_ramp_start", 0.0)
104            weight_end = loss_prms["weight"]
105            print(f"Weight ramp for {k} : {weight_start} -> {weight_end} in {weight_ramp} epochs")
106            loss_prms["weight_schedule"] = (weight_start,weight_end,weight_ramp_start,weight_ramp)
107
108        used_keys.append(loss_prms["key"])
109        loss_definition[k] = loss_prms
110
111    # rename_refs = list(
112    #     set(["forces", "total_energy", "atomic_energies"] + manual_renames + used_keys)
113    # )
114
115    # for k in loss_definition.keys():
116    #     loss_prms = loss_definition[k]
117    #     if "ref" in loss_prms:
118    #         if loss_prms["ref"] in rename_refs:
119    #             loss_prms["ref"] = "true_" + loss_prms["ref"]
120
121    return loss_definition, list(set(used_keys)), list(set(ref_keys))
122
123
def get_train_step_function(
    loss_definition: Dict,
    model: FENNIX,
    evaluate: Callable,
    optimizer: optax.GradientTransformation,
    ema: Optional[optax.GradientTransformation] = None,
    model_ref: Optional[FENNIX] = None,
    compute_ref_coords: bool = False,
    jit: bool = True,
):
    """Build the function that performs one optimization step.

    Args:
        loss_definition: Normalized loss components (see ``get_loss_definition``);
            each provides at least ``key``, ``type``, ``weight`` and ``mult``.
        model: Model being trained, passed to ``evaluate``.
        evaluate: Callable ``evaluate(model, variables, inputs)`` returning a
            mapping of model outputs.
        optimizer: Optax transformation used to update the variables.
        ema: Optional optax transformation applied to the updated variables;
            its result is stored in ``variables_ema`` / ``ema_state``.
        model_ref: Optional reference model whose outputs can be used as
            targets through ``"ref": "model_ref/<key>"``.
        compute_ref_coords: Must be True for components using ``remove_ref_sys``.
        jit: Whether to wrap the returned function with ``jax.jit``.

    Returns:
        ``train_step(data, inputs, training_state)`` returning
        ``(loss, new_training_state, output)``. ``training_state`` must provide
        ``variables``, ``opt_state``, ``epoch``, ``step`` and, when ``ema`` is
        given, ``ema_state``.
    """

    def _get_reference(loss_prms, data, output, output_ref):
        """Return ``(ref, ref_mask)`` for one loss component.

        ``loss_prms["ref"]`` selects the source: ``"model_ref/<key>"`` reads the
        reference-model output, ``"model/<key>"`` the trained-model output and
        any other value reads ``data``. Values are scaled by ``loss_prms["mult"]``.
        ``ref_mask`` is ``None`` when no ``<key>_mask`` entry is found.
        """
        ref_name = loss_prms["ref"]
        ref_mask = None
        if ref_name.startswith("model_ref/"):
            assert model_ref is not None, "model_ref must be provided"
            key = ref_name[len("model_ref/"):]
            try:
                raw = output_ref[key]
            except KeyError as err:
                raise KeyError(
                    f"Reference key '{key}' not found in model_ref output. Keys available: {output_ref.keys()}"
                ) from err
            mask_key = key + "_mask"
            if mask_key in data:
                ref_mask = data[mask_key]
            elif mask_key in output_ref:
                ref_mask = output_ref[mask_key]
            elif mask_key in output:
                ref_mask = output[mask_key]
        elif ref_name.startswith("model/"):
            key = ref_name[len("model/"):]
            try:
                raw = output[key]
            except KeyError as err:
                raise KeyError(
                    f"Reference key '{key}' not found in model output. Keys available: {output.keys()}"
                ) from err
            mask_key = key + "_mask"
            if mask_key in data:
                ref_mask = data[mask_key]
            elif mask_key in output:
                ref_mask = output[mask_key]
        else:
            try:
                raw = data[ref_name]
            except KeyError as err:
                raise KeyError(
                    f"Reference key '{ref_name}' not found in data. Keys available: {data.keys()}"
                ) from err
            if ref_name + "_mask" in data:
                ref_mask = data[ref_name + "_mask"]
        return raw * loss_prms["mult"], ref_mask

    def train_step(
        data,
        inputs,
        training_state,
    ):
        epoch = training_state["epoch"]

        def loss_fn(variables):
            output_ref = None
            if model_ref is not None:
                # does not depend on ``variables``: no gradient flows through it
                output_ref = evaluate(model_ref, model_ref.variables, inputs)
            output = evaluate(model, variables, inputs)

            # systems flagged as not true count as one atom so divisions stay finite
            natoms = jnp.where(inputs["true_sys"], inputs["natoms"], 1)
            nsys = inputs["natoms"].shape[0]
            system_mask = inputs["true_sys"]
            atom_mask = inputs["true_atoms"]
            batch_index = inputs["batch_index"]
            if "system_sign" in inputs:
                # only systems with a positive sign enter the loss
                positive = inputs["system_sign"] > 0
                system_mask = jnp.logical_and(system_mask, positive)
                atom_mask = jnp.logical_and(atom_mask, positive[batch_index])

            loss_tot = 0.0
            for loss_prms in loss_definition.values():
                predicted = output[loss_prms["key"]]

                if "ref" in loss_prms:
                    ref, ref_mask = _get_reference(loss_prms, data, output, output_ref)
                else:
                    ref, ref_mask = jnp.zeros_like(predicted), None

                norm_axis = loss_prms.get("norm_axis")
                if norm_axis is not None:
                    predicted = jnp.linalg.norm(predicted, axis=norm_axis, keepdims=True)
                    ref = jnp.linalg.norm(ref, axis=norm_axis, keepdims=True)

                loss_type = loss_prms["type"]
                ensemble_loss = loss_type in (
                    "ensemble_nll",
                    "ensemble_crps",
                    "gaussian_mixture",
                )
                if ensemble_loss:
                    ensemble_axis = loss_prms.get("ensemble_axis", -1)
                    ensemble_weights_key = loss_prms.get("ensemble_weights", None)
                    if ensemble_weights_key is not None:
                        ensemble_weights = output[ensemble_weights_key]
                    else:
                        ensemble_weights = jnp.ones_like(predicted)

                    if "ensemble_subsample" in loss_prms and "rng_key" in output:
                        ns = min(
                            loss_prms["ensemble_subsample"],
                            predicted.shape[ensemble_axis],
                        )
                        key, subkey = jax.random.split(output["rng_key"])

                        def get_subsample(x):
                            # the same subkey is used for every array, so arrays of
                            # equal shape keep the same ensemble members
                            return jax.lax.slice_in_dim(
                                jax.random.permutation(
                                    subkey, x, axis=ensemble_axis, independent=True
                                ),
                                start_index=0,
                                limit_index=ns,
                                axis=ensemble_axis,
                            )

                        predicted = get_subsample(predicted)
                        ensemble_weights = get_subsample(ensemble_weights)
                        output["rng_key"] = key
                    else:

                        def get_subsample(x):
                            return x

                    # the unbiased-variance correction must use the number of
                    # members actually kept after subsampling
                    ensemble_size = predicted.shape[ensemble_axis]
                    ensemble_weights = ensemble_weights / jnp.sum(
                        ensemble_weights, axis=ensemble_axis, keepdims=True
                    )
                    predicted_ensemble = predicted
                    predicted = (predicted * ensemble_weights).sum(
                        axis=ensemble_axis, keepdims=True
                    )
                    predicted_var = (
                        ensemble_weights * (predicted_ensemble - predicted) ** 2
                    ).sum(axis=ensemble_axis) * (ensemble_size / (ensemble_size - 1.0))
                    predicted = jnp.squeeze(predicted, axis=ensemble_axis)

                if loss_prms.get("remove_ref_sys", False):
                    assert compute_ref_coords, "compute_ref_coords must be True"
                    # subtract from each system the predictions of the negative-sign
                    # systems mapped onto it through ``system_index``
                    shape_mask = [predicted.shape[0]] + [1] * (predicted.ndim - 1)
                    if predicted.shape[0] == nsys:
                        negative = inputs["system_sign"].reshape(*shape_mask) < 0
                        predicted_shift = jax.ops.segment_sum(
                            negative * predicted, inputs["system_index"], nsys
                        )
                    elif predicted.shape[0] == batch_index.shape[0]:
                        negative = (inputs["system_sign"] < 0)[batch_index].reshape(*shape_mask)
                        natshift = jnp.concatenate([jnp.array([0]), jnp.cumsum(natoms)])
                        iat = jnp.arange(batch_index.shape[0]) - natshift[batch_index]
                        iatdest = natshift[inputs["system_index"]][batch_index] + iat
                        predicted_shift = jax.ops.segment_sum(
                            negative * predicted, iatdest, batch_index.shape[0]
                        )
                    else:
                        raise ValueError("remove_ref_sys only works with system-level or atom-level predictions")
                    predicted = predicted - predicted_shift

                    if ensemble_loss:
                        predicted_ensemble = predicted_ensemble - jnp.expand_dims(
                            predicted_shift, ensemble_axis
                        )
                        predicted_var = (
                            ensemble_weights
                            * (predicted_ensemble - jnp.expand_dims(predicted, ensemble_axis)) ** 2
                        ).sum(axis=ensemble_axis) * (ensemble_size / (ensemble_size - 1.0))

                if predicted.ndim > 1 and predicted.shape[-1] == 1:
                    predicted = jnp.squeeze(predicted, axis=-1)
                if ref.ndim > 1 and ref.shape[-1] == 1:
                    ref = jnp.squeeze(ref, axis=-1)

                per_atom = False
                shape_mask = [ref.shape[0]] + [1] * (ref.ndim - 1)
                natscale = 1.0
                if ref.shape[0] == output["batch_index"].shape[0]:
                    # atom-level quantity
                    truth_mask = atom_mask
                    if "ds_weight" in loss_prms:
                        natscale = data[loss_prms["ds_weight"]]
                        if natscale.shape[0] == natoms.shape[0]:
                            # system-level weights broadcast to their atoms
                            natscale = natscale[output["batch_index"]]
                        natscale = natscale.reshape(*shape_mask)
                elif ref.shape[0] == natoms.shape[0]:
                    # system-level quantity
                    truth_mask = system_mask
                    if loss_prms.get("per_atom", False):
                        per_atom = True
                        ref = ref / natoms.reshape(*shape_mask)
                        predicted = predicted / natoms.reshape(*shape_mask)
                    if "nat_pow" in loss_prms:
                        natscale = 1.0 / natoms.reshape(*shape_mask) ** loss_prms["nat_pow"]
                    if "ds_weight" in loss_prms:
                        natscale = natscale * data[loss_prms["ds_weight"]].reshape(*shape_mask)
                else:
                    truth_mask = jnp.ones(ref.shape[0], dtype=bool)

                if ref_mask is not None:
                    truth_mask = truth_mask * ref_mask.astype(bool)

                # number of valid scalar entries: valid rows times elements per row
                nel = jnp.maximum(
                    (float(np.prod(ref.shape)) / float(truth_mask.shape[0]))
                    * jnp.sum(truth_mask).astype(jnp.float32),
                    1.0,
                )
                truth_mask = truth_mask.reshape(*shape_mask)
                ref = ref * truth_mask
                predicted = predicted * truth_mask
                natscale = natscale / nel

                if loss_type == "mse":
                    loss = jnp.sum(natscale * (predicted - ref) ** 2)
                elif loss_type == "log_cosh":
                    loss = jnp.sum(natscale * optax.log_cosh(predicted, ref))
                elif loss_type == "mae":
                    loss = jnp.sum(natscale * jnp.abs(predicted - ref))
                elif loss_type == "rel_mse":
                    eps = loss_prms.get("rel_eps", 1.0e-6)
                    rel_pow = loss_prms.get("rel_pow", 2.0)
                    loss = jnp.sum(
                        natscale * (predicted - ref) ** 2 / (eps + jnp.abs(ref) ** rel_pow)
                    )
                elif loss_type == "rmse+mae":
                    loss = (jnp.sum(natscale * (predicted - ref) ** 2)) ** 0.5 + jnp.sum(
                        natscale * jnp.abs(predicted - ref)
                    )
                elif loss_type == "crps":
                    predicted_var_key = loss_prms.get("var_key", loss_prms["key"] + "_var")
                    lambda_crps = loss_prms.get("lambda_crps", 1.0)
                    predicted_var = output[predicted_var_key]
                    if per_atom:
                        predicted_var = predicted_var / natoms.reshape(*shape_mask) ** 2
                    # masked entries get unit variance so that sigma stays finite
                    predicted_var = predicted_var * truth_mask + (1.0 - truth_mask)
                    sigma = predicted_var**0.5
                    dy = (ref - predicted) / sigma
                    Phi = 0.5 * (1.0 + jax.scipy.special.erf(dy / 2**0.5))
                    phi = jnp.exp(-0.5 * dy**2) / (2 * jnp.pi) ** 0.5
                    loss = jnp.sum(
                        natscale * sigma * (dy * (2 * Phi - 1.0) + 2 * phi - lambda_crps / jnp.pi**0.5)
                    )
                elif loss_type == "ensemble_nll":
                    if per_atom:
                        # the targets were divided by natoms, so the variance scales by natoms**2
                        predicted_var = predicted_var / natoms.reshape(*shape_mask) ** 2
                    predicted_var = predicted_var * truth_mask + (1.0 - truth_mask)
                    loss = 0.5 * jnp.sum(
                        natscale
                        * truth_mask
                        * (jnp.log(predicted_var) + (ref - predicted) ** 2 / predicted_var)
                    )
                elif loss_type == "ensemble_crps":
                    if per_atom:
                        predicted_var = predicted_var / natoms.reshape(*shape_mask) ** 2
                    lambda_crps = loss_prms.get("lambda_crps", 1.0)
                    predicted_var = predicted_var * truth_mask + (1.0 - truth_mask)
                    sigma = predicted_var**0.5
                    dy = (ref - predicted) / sigma
                    Phi = 0.5 * (1.0 + jax.scipy.special.erf(dy / 2**0.5))
                    phi = jnp.exp(-0.5 * dy**2) / (2 * jnp.pi) ** 0.5
                    loss = jnp.sum(
                        natscale
                        * truth_mask
                        * sigma
                        * (dy * (2 * Phi - 1.0) + 2 * phi - lambda_crps / jnp.pi**0.5)
                    )
                elif loss_type == "gaussian_mixture":
                    gm_w = get_subsample(output[loss_prms["key"] + "_w"])
                    gm_w = gm_w / jnp.sum(gm_w, axis=ensemble_axis, keepdims=True)
                    sigma = get_subsample(output[loss_prms["key"] + "_sigma"])
                    ref = jnp.expand_dims(ref, axis=ensemble_axis)

                    def abs_moment(mu, sigma):
                        # E|X| for X ~ N(mu, sigma**2)
                        dy = mu / sigma
                        Phi = 0.5 * (1.0 + jax.scipy.special.erf(dy / 2**0.5))
                        phi = jnp.exp(-0.5 * dy**2) / (2 * jnp.pi) ** 0.5
                        return sigma * (dy * (2 * Phi - 1.0) + 2 * phi)

                    # closed-form CRPS of a Gaussian mixture:
                    # sum_i w_i A(y - mu_i, s_i^2) - 1/2 sum_ij w_i w_j A(mu_i - mu_j, s_i^2 + s_j^2)
                    crps1 = abs_moment(ref - predicted_ensemble, sigma)
                    gm_crps1 = (gm_w * crps1).sum(axis=ensemble_axis)

                    axis = (
                        ensemble_axis
                        if ensemble_axis >= 0
                        else predicted_ensemble.ndim + ensemble_axis
                    )
                    muij = jnp.expand_dims(predicted_ensemble, axis=axis) - jnp.expand_dims(
                        predicted_ensemble, axis=axis + 1
                    )
                    sigma2 = sigma**2
                    sigmaij = (
                        jnp.expand_dims(sigma2, axis=axis) + jnp.expand_dims(sigma2, axis=axis + 1)
                    ) ** 0.5
                    wij = jnp.expand_dims(gm_w, axis=axis) * jnp.expand_dims(gm_w, axis=axis + 1)
                    gm_crps2 = (wij * abs_moment(muij, sigmaij)).sum(axis=(axis, axis + 1))

                    crps = gm_crps1 - 0.5 * gm_crps2
                    loss = jnp.sum(natscale * truth_mask * crps)
                elif loss_type == "evidential":
                    evidence = loss_prms["evidence_key"]
                    nu, alpha, beta = jnp.split(output[evidence], 3, axis=-1)
                    gamma = predicted
                    nu = jnp.where(truth_mask, nu.reshape(shape_mask), 1.0)
                    alpha = jnp.where(truth_mask, alpha.reshape(shape_mask), 1.0)
                    beta = jnp.where(truth_mask, beta.reshape(shape_mask), 1.0)
                    omega = 2 * beta * (1 + nu)
                    lg = jax.scipy.special.gammaln(alpha) - jax.scipy.special.gammaln(alpha + 0.5)
                    ls = 0.5 * jnp.log(jnp.pi / nu) - alpha * jnp.log(omega)
                    lt = (alpha + 0.5) * jnp.log(omega + nu * (gamma - ref) ** 2)
                    wst = (
                        (beta * (1 + nu) / (alpha * nu)) ** 0.5
                        if loss_prms.get("normalize_evidence", True)
                        else 1.0
                    )
                    lr = loss_prms.get("lambda_evidence", 1.0) * jnp.abs(gamma - ref) * nu / wst
                    r = loss_prms.get("evidence_ratio", 1.0)
                    le = loss_prms.get("lambda_evidence_diff", 0.0) * (nu - r * 2 * alpha) ** 2
                    lb = loss_prms.get("lambda_evidence_beta", 0.0) * beta
                    loss = jnp.sum(natscale * (lg + ls + lt + lr + le + lb) * truth_mask)
                elif loss_type == "raw":
                    loss = jnp.sum(natscale * predicted)
                elif loss_type == "binary_crossentropy":
                    ref = ref.astype(predicted.dtype)
                    if loss_prms.get("raw_logits", False):
                        loss = jnp.sum(
                            natscale
                            * truth_mask
                            * (
                                -ref * jax.nn.log_sigmoid(predicted)
                                - (1.0 - ref) * jax.nn.log_sigmoid(-predicted)
                            )
                        )
                    else:
                        loss = jnp.sum(
                            natscale
                            * truth_mask
                            * (
                                -ref * jnp.log(predicted + 1.0e-10)
                                - (1.0 - ref) * jnp.log(1.0 - predicted + 1.0e-10)
                            )
                        )
                else:
                    raise ValueError(f"Unknown loss type: {loss_type}")

                if "weight_schedule" in loss_prms:
                    w = linear_schedule(epoch, *loss_prms["weight_schedule"])
                else:
                    w = loss_prms["weight"]
                loss_tot = loss_tot + w * loss

            return loss_tot, output

        variables = training_state["variables"]
        new_training_state = {**training_state, "step": training_state["step"] + 1}

        (loss, output), grad = jax.value_and_grad(loss_fn, has_aux=True)(variables)
        updates, new_training_state["opt_state"] = optimizer.update(
            grad, training_state["opt_state"], params=variables
        )
        variables = optax.apply_updates(variables, updates)
        new_training_state["variables"] = variables

        if ema is not None:
            ema_st = training_state.get("ema_state", None)
            if ema_st is None:
                raise ValueError(
                    "train_step was setup with ema but either variables_ema or ema_st was not provided"
                )
            new_training_state["variables_ema"], new_training_state["ema_state"] = ema.update(
                variables, ema_st
            )

        return loss, new_training_state, output

    if jit:
        return jax.jit(train_step)
    return train_step
520
521
def get_validation_function(
    loss_definition: Dict,
    model: FENNIX,
    evaluate: Callable,
    model_ref: Optional[FENNIX] = None,
    compute_ref_coords: bool = False,
    return_targets: bool = False,
    jit: bool = True,
):
    """Build the function computing validation metrics for each loss component.

    Args:
        loss_definition: Normalized loss components (see ``get_loss_definition``);
            components with ``validate=False`` are skipped.
        model: Model being validated, passed to ``evaluate``.
        evaluate: Callable ``evaluate(model, variables, inputs)`` returning a
            mapping of model outputs.
        model_ref: Optional reference model whose outputs can be used as
            targets through ``"ref": "model_ref/<key>"``.
        compute_ref_coords: Must be True for components using ``remove_ref_sys``.
        return_targets: Also return ``{name: (predicted, ref, truth_mask)}``.
        jit: Whether to wrap the returned function with ``jax.jit``.

    Returns:
        ``validation(data, inputs, variables, inputs_ref=None)`` returning
        ``(rmses, maes, output)`` or ``(rmses, maes, output, targets)``.
        For ``binary_crossentropy`` components the "rmse" entry holds the mean
        cross-entropy and the "mae" entry the percentage of entries whose
        error exceeds ``class_threshold`` (default 0.05).
    """

    def _get_reference(loss_prms, data, output, output_ref):
        """Return ``(ref, ref_mask)`` for one loss component.

        ``loss_prms["ref"]`` selects the source: ``"model_ref/<key>"`` reads the
        reference-model output, ``"model/<key>"`` the validated-model output and
        any other value reads ``data``. Values are scaled by ``loss_prms["mult"]``.
        ``ref_mask`` is ``None`` when no ``<key>_mask`` entry is found.
        """
        ref_name = loss_prms["ref"]
        ref_mask = None
        if ref_name.startswith("model_ref/"):
            assert model_ref is not None, "model_ref must be provided"
            key = ref_name[len("model_ref/"):]
            try:
                raw = output_ref[key]
            except KeyError as err:
                raise KeyError(
                    f"Reference key '{key}' not found in model_ref output. Keys available: {output_ref.keys()}"
                ) from err
            mask_key = key + "_mask"
            if mask_key in data:
                ref_mask = data[mask_key]
            elif mask_key in output_ref:
                ref_mask = output_ref[mask_key]
            elif mask_key in output:
                ref_mask = output[mask_key]
        elif ref_name.startswith("model/"):
            key = ref_name[len("model/"):]
            try:
                raw = output[key]
            except KeyError as err:
                raise KeyError(
                    f"Reference key '{key}' not found in model output. Keys available: {output.keys()}"
                ) from err
            mask_key = key + "_mask"
            if mask_key in data:
                ref_mask = data[mask_key]
            elif mask_key in output:
                ref_mask = output[mask_key]
        else:
            try:
                raw = data[ref_name]
            except KeyError as err:
                raise KeyError(
                    f"Reference key '{ref_name}' not found in data. Keys available: {data.keys()}"
                ) from err
            if ref_name + "_mask" in data:
                ref_mask = data[ref_name + "_mask"]
        return raw * loss_prms["mult"], ref_mask

    def validation(data, inputs, variables, inputs_ref=None):
        # NOTE(review): ``inputs_ref`` is accepted but not used here.
        output_ref = None
        if model_ref is not None:
            output_ref = evaluate(model_ref, model_ref.variables, inputs)
        output = evaluate(model, variables, inputs)
        rmses = {}
        maes = {}
        targets = {}

        # systems flagged as not true count as one atom so divisions stay finite
        natoms = jnp.where(inputs["true_sys"], inputs["natoms"], 1)
        nsys = inputs["natoms"].shape[0]
        system_mask = inputs["true_sys"]
        atom_mask = inputs["true_atoms"]
        batch_index = inputs["batch_index"]
        if "system_sign" in inputs:
            # only systems with a positive sign are evaluated
            positive = inputs["system_sign"] > 0
            system_mask = jnp.logical_and(system_mask, positive)
            atom_mask = jnp.logical_and(atom_mask, positive[batch_index])

        for name, loss_prms in loss_definition.items():
            if not loss_prms.get("validate", True):
                continue
            predicted = output[loss_prms["key"]]

            if loss_prms.get("remove_ref_sys", False):
                assert compute_ref_coords, "compute_ref_coords must be True"
                # signed sum onto the destination given by ``system_index``
                shape_mask = [predicted.shape[0]] + [1] * (predicted.ndim - 1)
                if predicted.shape[0] == nsys:
                    sign = inputs["system_sign"].reshape(*shape_mask)
                    predicted = jax.ops.segment_sum(
                        sign * predicted, inputs["system_index"], nsys
                    )
                elif predicted.shape[0] == batch_index.shape[0]:
                    atom_sign = inputs["system_sign"][batch_index].reshape(*shape_mask)
                    natshift = jnp.concatenate([jnp.array([0]), jnp.cumsum(natoms)])
                    iat = jnp.arange(batch_index.shape[0]) - natshift[batch_index]
                    iatdest = natshift[inputs["system_index"]][batch_index] + iat
                    predicted = jax.ops.segment_sum(
                        atom_sign * predicted, iatdest, batch_index.shape[0]
                    )
                else:
                    raise ValueError("remove_ref_sys only works with system-level or atom-level predictions")

            if "ref" in loss_prms:
                ref, ref_mask = _get_reference(loss_prms, data, output, output_ref)
            else:
                ref, ref_mask = jnp.zeros_like(predicted), None

            norm_axis = loss_prms.get("norm_axis")
            if norm_axis is not None:
                predicted = jnp.linalg.norm(predicted, axis=norm_axis, keepdims=True)
                ref = jnp.linalg.norm(ref, axis=norm_axis, keepdims=True)

            loss_type = loss_prms["type"].lower()
            if loss_type.startswith("ensemble"):
                # validate the (weighted) ensemble mean
                axis = loss_prms.get("ensemble_axis", -1)
                ensemble_weights_key = loss_prms.get("ensemble_weights", None)
                if ensemble_weights_key is not None:
                    ensemble_weights = output[ensemble_weights_key]
                else:
                    ensemble_weights = jnp.ones_like(predicted)
                ensemble_weights = ensemble_weights / jnp.sum(
                    ensemble_weights, axis=axis, keepdims=True
                )
                predicted = (ensemble_weights * predicted).sum(axis=axis)

            if loss_prms["type"] == "gaussian_mixture":
                # validate the mixture mean
                axis = loss_prms.get("ensemble_axis", -1)
                gm_w = output[loss_prms["key"] + "_w"]
                gm_w = gm_w / jnp.sum(gm_w, axis=axis, keepdims=True)
                predicted = (predicted * gm_w).sum(axis=axis)

            if predicted.ndim > 1 and predicted.shape[-1] == 1:
                predicted = jnp.squeeze(predicted, axis=-1)
            if ref.ndim > 1 and ref.shape[-1] == 1:
                ref = jnp.squeeze(ref, axis=-1)

            shape_mask = [ref.shape[0]] + [1] * (predicted.ndim - 1)
            if ref.shape[0] == output["batch_index"].shape[0]:
                # atom-level quantity
                truth_mask = atom_mask
            elif ref.shape[0] == natoms.shape[0]:
                # system-level quantity
                truth_mask = system_mask
                if loss_prms.get("per_atom_validation", False):
                    ref = ref / natoms.reshape(*shape_mask)
                    predicted = predicted / natoms.reshape(*shape_mask)
            else:
                truth_mask = jnp.ones(ref.shape[0], dtype=bool)

            if ref_mask is not None:
                truth_mask = truth_mask * ref_mask.astype(bool)

            # number of valid scalar entries: valid rows times elements per row
            nel = jnp.maximum(
                (float(np.prod(ref.shape)) / float(truth_mask.shape[0]))
                * jnp.sum(truth_mask).astype(predicted.dtype),
                1.0,
            )
            truth_mask = truth_mask.reshape(*shape_mask)
            ref = ref * truth_mask
            predicted = predicted * truth_mask

            if loss_type == "binary_crossentropy":
                ref = ref.astype(predicted.dtype)
                if loss_prms.get("raw_logits", False):
                    rmse = jnp.sum(
                        (
                            -ref * jax.nn.log_sigmoid(predicted)
                            - (1.0 - ref) * jax.nn.log_sigmoid(-predicted)
                        )
                        * truth_mask
                        / nel
                    )
                    predicted = truth_mask * jax.nn.sigmoid(predicted)
                else:
                    rmse = jnp.sum(
                        (
                            -ref * jnp.log(predicted + 1.0e-10)
                            - (1.0 - ref) * jnp.log(1.0 - predicted + 1.0e-10)
                        )
                        * truth_mask
                        / nel
                    )
                thr = loss_prms.get("class_threshold", 0.05)
                mae = jnp.sum(jnp.where(jnp.abs(predicted - ref) > thr, 1, 0) / nel) * 100
            else:
                rmse = jnp.sum((predicted - ref) ** 2 / nel) ** 0.5
                mae = jnp.sum(jnp.abs(predicted - ref) / nel)

            rmses[name] = rmse
            maes[name] = mae
            if return_targets:
                targets[name] = (predicted, ref, truth_mask)

        if return_targets:
            return rmses, maes, output, targets
        return rmses, maes, output

    if jit:
        return jax.jit(validation)
    return validation
@partial(jax.jit, static_argnums=(1, 2, 3, 4))
def linear_schedule(step, start_value, end_value, start_step, duration):
    """Linearly interpolate from ``start_value`` to ``end_value`` as a function of ``step``.

    The result is ``start_value`` for ``step <= start_step``, ``end_value`` for
    ``step >= start_step + duration`` and varies linearly in between.

    Args:
        step: Current step/epoch (traced by ``jax.jit``).
        start_value: Value before the ramp (static argument).
        end_value: Value after the ramp (static argument).
        start_step: Step at which the ramp starts (static argument).
        duration: Length of the ramp in steps (static argument). A duration of
            0 is treated as an instantaneous switch to ``end_value`` at
            ``start_step``.

    Returns:
        The scheduled value as a JAX scalar.
    """
    if duration == 0:
        # (step - start_step) / 0 would give 0/0 = NaN at step == start_step;
        # duration is static so this Python branch is resolved at trace time.
        progress = jnp.where(step >= start_step, 1.0, 0.0)
    else:
        progress = jnp.clip((step - start_step) / duration, 0.0, 1.0)
    return start_value + progress * (end_value - start_value)
def get_training_parameters(
    parameters: Dict[str, Any], stage: int = -1
) -> Dict[str, Any]:
    """Return the training parameters for a given training stage.

    ``parameters["training"]`` is deep-copied. If it has no ``"stages"`` entry,
    the copy is returned as is. Otherwise the stages (in insertion order) are
    applied incrementally with ``deep_update`` up to and including ``stage``.
    The ``"end_event"`` entry of a previous stage is removed before the next
    stage is applied.

    Args:
        parameters: Full parameter dictionary containing a ``"training"`` entry.
        stage: Index of the stage to build; negative values count from the end
            (``-1`` is the last stage).

    Returns:
        A new dictionary with the merged training parameters (without
        ``"stages"``).

    Raises:
        AssertionError: If ``stage`` does not refer to an existing stage.
    """
    params = deepcopy(parameters["training"])
    if "stages" not in params:
        return params

    stages: dict = params.pop("stages")
    stage_keys = list(stages.keys())
    requested_stage = stage
    if stage < 0:
        stage = len(stage_keys) + stage
    # report the stage the caller asked for, not the wrapped index
    assert 0 <= stage < len(stage_keys), (
        f"Stage {requested_stage} not found in training parameters"
        f" ({len(stage_keys)} stages available)"
    )
    for i, stage_key in enumerate(stage_keys[: stage + 1]):
        if i > 0:
            # end_event of the previous stage must not leak into later stages
            params.pop("end_event", None)
        params = deep_update(params, stages[stage_key])
    return params
def get_loss_definition(
    training_parameters: Dict[str, Any],
    model_energy_unit: str = "Ha",
) -> Tuple[Dict[str, Any], List[str], List[str]]:
    """Build the completed loss definition from the training parameters.

    Each entry of ``training_parameters["loss"]`` is deep-copied and completed
    with defaults:

    - ``key``: the entry name;
    - ``type``: ``training_parameters["default_loss_type"]`` (default ``"log_cosh"``);
    - ``weight``: 1.0;
    - ``mult``: a multiplicative factor computed with ``au.get_multiplier``
      from ``energy_unit`` (relative to ``model_energy_unit``) or ``unit``,
      or 1.0 if neither is given.

    If ``weight_start`` is given, a ``weight_schedule`` tuple
    ``(weight_start, weight, weight_ramp_start, weight_ramp)`` is added, where
    ``weight_ramp`` defaults to ``training_parameters["max_epochs"]``.

    Args:
        training_parameters: Training parameters containing a ``"loss"`` dict.
        model_energy_unit: Energy unit of the model outputs.

    Returns:
        tuple: ``(loss_definition, used_keys, ref_keys)`` where

        - loss_definition (dict): completed parameters for each loss component;
        - used_keys (list): unique model output keys used by the losses;
        - ref_keys (list): unique keys read from the data, i.e. references not
          prefixed by ``model/`` or ``model_ref/`` and ``ds_weight`` keys.

    Raises:
        AssertionError: If a weight is negative or a threshold is <= 1.
        ValueError: If ``weight_start`` is given but no ramp length can be
            determined (neither ``weight_ramp`` nor ``max_epochs`` is set).
    """
    default_loss_type = training_parameters.get("default_loss_type", "log_cosh")
    energy_mult = au.get_multiplier(model_energy_unit)
    loss_definition = {}
    # dicts used as insertion-ordered sets: deterministic, duplicate-free key lists
    used_keys = {}
    ref_keys = {}
    for k, raw_prms in training_parameters["loss"].items():
        loss_prms = deepcopy(raw_prms)
        if "energy_unit" in loss_prms:
            loss_prms["mult"] = energy_mult / au.get_multiplier(
                loss_prms["energy_unit"]
            )
            if "unit" in loss_prms:
                print(
                    "Warning: Both 'unit' and 'energy_unit' are defined for loss component",
                    k,
                    " -> using 'energy_unit'",
                )
            loss_prms["unit"] = loss_prms["energy_unit"]
        elif "unit" in loss_prms:
            loss_prms["mult"] = 1.0 / au.get_multiplier(loss_prms["unit"])
        else:
            loss_prms["mult"] = 1.0
        loss_prms.setdefault("key", k)
        loss_prms.setdefault("type", default_loss_type)
        loss_prms.setdefault("weight", 1.0)
        assert loss_prms["weight"] >= 0.0, "Loss weight must be non-negative"
        if "threshold" in loss_prms:
            assert loss_prms["threshold"] > 1.0, "Threshold must be greater than 1.0"
        if "ref" in loss_prms:
            ref = loss_prms["ref"]
            # model-computed references are not read from the data
            if not ref.startswith(("model_ref/", "model/")):
                ref_keys[ref] = None
        if "ds_weight" in loss_prms:
            ref_keys[loss_prms["ds_weight"]] = None

        if "weight_start" in loss_prms:
            weight_start = loss_prms["weight_start"]
            if "weight_ramp" in loss_prms:
                weight_ramp = loss_prms["weight_ramp"]
            else:
                weight_ramp = training_parameters.get("max_epochs")
            if weight_ramp is None:
                # a None ramp would only fail later, inside linear_schedule
                raise ValueError(
                    f"Loss component '{k}' defines 'weight_start' but neither "
                    "'weight_ramp' nor 'max_epochs' is set"
                )
            weight_ramp_start = loss_prms.get("weight_ramp_start", 0.0)
            weight_end = loss_prms["weight"]
            print(f"Weight ramp for {k} : {weight_start} -> {weight_end} in {weight_ramp} epochs")
            loss_prms["weight_schedule"] = (
                weight_start,
                weight_end,
                weight_ramp_start,
                weight_ramp,
            )

        used_keys[loss_prms["key"]] = None
        loss_definition[k] = loss_prms

    return loss_definition, list(used_keys), list(ref_keys)

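Returns the loss definition, the list of model output keys used by the losses, and the list of reference keys that must be read from the data.

Args:
- training_parameters (dict): A dictionary containing training parameters.
- model_energy_unit (str): Energy unit of the model outputs (default "Ha").

Returns: tuple: A tuple containing:
- loss_definition (dict): A dictionary containing the loss definition.
- used_keys (list): The unique model output keys used by the loss components.
- ref_keys (list): The unique reference and dataset-weight keys read from the data.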

def get_train_step_function(
    loss_definition: Dict,
    model: FENNIX,
    evaluate: Callable,
    optimizer: optax.GradientTransformation,
    ema: Optional[optax.GradientTransformation] = None,
    model_ref: Optional[FENNIX] = None,
    compute_ref_coords: bool = False,
    jit: bool = True,
):
    """Build the training step function.

    Args:
        loss_definition: Loss definition as returned by ``get_loss_definition``.
        model: Model being trained.
        evaluate: Callable ``evaluate(model, variables, inputs) -> output``
            used to run ``model`` (and ``model_ref``).
        optimizer: Optax transformation applied to the gradients.
        ema: Optional transformation applied to the updated variables; its
            state is read from ``training_state["ema_state"]``.
        model_ref: Optional reference model whose outputs can be used as
            references with ``"ref": "model_ref/<key>"``.
        compute_ref_coords: Must be True for losses using ``remove_ref_sys``.
        jit: If True, the returned function is wrapped with ``jax.jit``.

    Returns:
        ``train_step(data, inputs, training_state)`` returning
        ``(loss, new_training_state, output)``.
    """

    def train_step(
        data,
        inputs,
        training_state,
    ):
        """Compute the loss, apply one optimizer update and return the new state."""
        epoch = training_state["epoch"]

        def loss_fn(variables):
            """Return the total weighted loss and the model output."""
            if model_ref is not None:
                output_ref = evaluate(model_ref, model_ref.variables, inputs)
            output = evaluate(model, variables, inputs)
            natoms = jnp.where(inputs["true_sys"], inputs["natoms"], 1)
            nsys = inputs["natoms"].shape[0]
            system_mask = inputs["true_sys"]
            atom_mask = inputs["true_atoms"]
            batch_index = inputs["batch_index"]
            if "system_sign" in inputs:
                # only systems with a positive sign contribute to the loss
                system_sign = inputs["system_sign"] > 0
                system_mask = jnp.logical_and(system_mask, system_sign)
                atom_mask = jnp.logical_and(atom_mask, system_sign[batch_index])

            loss_tot = 0.0
            for loss_prms in loss_definition.values():
                use_ref_mask = False
                predicted = output[loss_prms["key"]]

                if "ref" in loss_prms:
                    if loss_prms["ref"].startswith("model_ref/"):
                        assert model_ref is not None, "model_ref must be provided"
                        try:
                            ref = output_ref[loss_prms["ref"][10:]] * loss_prms["mult"]
                        except KeyError:
                            raise KeyError(
                                f"Reference key '{loss_prms['ref'][10:]}' not found in model_ref output. Keys available: {output_ref.keys()}"
                            )
                        # strip the full "model_ref/" prefix (10 characters)
                        mask_key = loss_prms["ref"][10:] + "_mask"
                        if mask_key in data:
                            use_ref_mask = True
                            ref_mask = data[mask_key]
                        elif mask_key in output_ref:
                            use_ref_mask = True
                            ref_mask = output_ref[mask_key]
                        elif mask_key in output:
                            use_ref_mask = True
                            ref_mask = output[mask_key]
                    elif loss_prms["ref"].startswith("model/"):
                        try:
                            ref = output[loss_prms["ref"][6:]] * loss_prms["mult"]
                        except KeyError:
                            raise KeyError(
                                f"Reference key '{loss_prms['ref'][6:]}' not found in model output. Keys available: {output.keys()}"
                            )
                        mask_key = loss_prms["ref"][6:] + "_mask"
                        if mask_key in data:
                            use_ref_mask = True
                            ref_mask = data[mask_key]
                        elif mask_key in output:
                            use_ref_mask = True
                            ref_mask = output[mask_key]
                    else:
                        try:
                            ref = data[loss_prms["ref"]] * loss_prms["mult"]
                        except KeyError:
                            raise KeyError(
                                f"Reference key '{loss_prms['ref']}' not found in data. Keys available: {data.keys()}"
                            )
                        if loss_prms["ref"] + "_mask" in data:
                            use_ref_mask = True
                            ref_mask = data[loss_prms["ref"] + "_mask"]
                else:
                    ref = jnp.zeros_like(predicted)

                if loss_prms.get("norm_axis") is not None:
                    norm_axis = loss_prms["norm_axis"]
                    predicted = jnp.linalg.norm(
                        predicted, axis=norm_axis, keepdims=True
                    )
                    ref = jnp.linalg.norm(ref, axis=norm_axis, keepdims=True)

                ensemble_loss = loss_prms["type"] in [
                    "ensemble_nll",
                    "ensemble_crps",
                    "gaussian_mixture",
                ]
                if ensemble_loss:
                    ensemble_axis = loss_prms.get("ensemble_axis", -1)
                    ensemble_weights_key = loss_prms.get("ensemble_weights", None)
                    ensemble_size = predicted.shape[ensemble_axis]
                    if ensemble_weights_key is not None:
                        ensemble_weights = output[ensemble_weights_key]
                    else:
                        ensemble_weights = jnp.ones_like(predicted)
                    if "ensemble_subsample" in loss_prms and "rng_key" in output:
                        ns = min(
                            loss_prms["ensemble_subsample"],
                            predicted.shape[ensemble_axis],
                        )
                        key, subkey = jax.random.split(output["rng_key"])

                        def get_subsample(x):
                            return jax.lax.slice_in_dim(
                                jax.random.permutation(
                                    subkey, x, axis=ensemble_axis, independent=True
                                ),
                                start_index=0,
                                limit_index=ns,
                                axis=ensemble_axis,
                            )

                        predicted = get_subsample(predicted)
                        ensemble_weights = get_subsample(ensemble_weights)
                        output["rng_key"] = key
                        # the variance below is estimated from ns members only
                        ensemble_size = ns
                    else:
                        get_subsample = lambda x: x
                    ensemble_weights = ensemble_weights / jnp.sum(
                        ensemble_weights, axis=ensemble_axis, keepdims=True
                    )
                    predicted_ensemble = predicted
                    predicted = (predicted * ensemble_weights).sum(
                        axis=ensemble_axis, keepdims=True
                    )
                    # weighted ensemble variance with n/(n-1) correction
                    predicted_var = (
                        ensemble_weights * (predicted_ensemble - predicted) ** 2
                    ).sum(axis=ensemble_axis) * (ensemble_size / (ensemble_size - 1.0))
                    predicted = jnp.squeeze(predicted, axis=ensemble_axis)

                if loss_prms.get("remove_ref_sys", False):
                    assert compute_ref_coords, "compute_ref_coords must be True"
                    # subtract the prediction of the negative-sign (reference) system
                    shape_mask = [predicted.shape[0]] + [1] * (len(predicted.shape) - 1)
                    if predicted.shape[0] == nsys:
                        predicted_sign_mask = (
                            inputs["system_sign"].reshape(*shape_mask) < 0
                        )
                        predicted_shift = jax.ops.segment_sum(
                            predicted_sign_mask * predicted,
                            inputs["system_index"],
                            nsys,
                        )
                        predicted = predicted - predicted_shift
                    elif predicted.shape[0] == batch_index.shape[0]:
                        predicted_sign_mask = (inputs["system_sign"] < 0)[
                            batch_index
                        ].reshape(*shape_mask)
                        natshift = jnp.concatenate([jnp.array([0]), jnp.cumsum(natoms)])
                        iat = jnp.arange(batch_index.shape[0]) - natshift[batch_index]
                        iatdest = natshift[inputs["system_index"]][batch_index] + iat
                        predicted_shift = jax.ops.segment_sum(
                            predicted_sign_mask * predicted,
                            iatdest,
                            batch_index.shape[0],
                        )
                        predicted = predicted - predicted_shift
                    else:
                        raise ValueError(
                            "remove_ref_sys only works with system-level or atom-level predictions"
                        )

                    if ensemble_loss:
                        predicted_ensemble = predicted_ensemble - jnp.expand_dims(
                            predicted_shift, ensemble_axis
                        )
                        predicted_var = (
                            ensemble_weights
                            * (
                                predicted_ensemble
                                - jnp.expand_dims(predicted, ensemble_axis)
                            )
                            ** 2
                        ).sum(axis=ensemble_axis) * (ensemble_size / (ensemble_size - 1.0))

                if predicted.ndim > 1 and predicted.shape[-1] == 1:
                    predicted = jnp.squeeze(predicted, axis=-1)

                if ref.ndim > 1 and ref.shape[-1] == 1:
                    ref = jnp.squeeze(ref, axis=-1)

                per_atom = False
                shape_mask = [ref.shape[0]] + [1] * (len(ref.shape) - 1)
                natscale = 1.0
                if ref.shape[0] == output["batch_index"].shape[0]:
                    # leading dimension is the number of atoms
                    truth_mask = atom_mask
                    if "ds_weight" in loss_prms:
                        weight_key = loss_prms["ds_weight"]
                        natscale = data[weight_key]
                        if natscale.shape[0] == natoms.shape[0]:
                            # per-system weights broadcast to their atoms
                            natscale = natscale[output["batch_index"]]
                        natscale = natscale.reshape(*shape_mask)
                elif ref.shape[0] == natoms.shape[0]:
                    # leading dimension is the number of systems
                    truth_mask = system_mask
                    if loss_prms.get("per_atom", False):
                        per_atom = True
                        ref = ref / natoms.reshape(*shape_mask)
                        predicted = predicted / natoms.reshape(*shape_mask)
                    if "nat_pow" in loss_prms:
                        natscale = (
                            1.0 / natoms.reshape(*shape_mask) ** loss_prms["nat_pow"]
                        )
                    if "ds_weight" in loss_prms:
                        weight_key = loss_prms["ds_weight"]
                        natscale = natscale * data[weight_key].reshape(*shape_mask)
                else:
                    truth_mask = jnp.ones(ref.shape[0], dtype=bool)

                if use_ref_mask:
                    truth_mask = truth_mask * ref_mask.astype(bool)

                # number of valid elements: (elements per leading index) * (valid leading indices)
                nel = jnp.maximum(
                    (float(np.prod(ref.shape)) / float(truth_mask.shape[0]))
                    * jnp.sum(truth_mask).astype(jnp.float32),
                    1.0,
                )
                truth_mask = truth_mask.reshape(*shape_mask)

                ref = ref * truth_mask
                predicted = predicted * truth_mask

                natscale = natscale / nel

                loss_type = loss_prms["type"]
                if loss_type == "mse":
                    loss = jnp.sum(natscale * (predicted - ref) ** 2)
                elif loss_type == "log_cosh":
                    loss = jnp.sum(natscale * optax.log_cosh(predicted, ref))
                elif loss_type == "mae":
                    loss = jnp.sum(natscale * jnp.abs(predicted - ref))
                elif loss_type == "rel_mse":
                    eps = loss_prms.get("rel_eps", 1.0e-6)
                    rel_pow = loss_prms.get("rel_pow", 2.0)
                    loss = jnp.sum(
                        natscale
                        * (predicted - ref) ** 2
                        / (eps + jnp.abs(ref) ** rel_pow)
                    )
                elif loss_type == "rmse+mae":
                    loss = (
                        jnp.sum(natscale * (predicted - ref) ** 2)
                    ) ** 0.5 + jnp.sum(natscale * jnp.abs(predicted - ref))
                elif loss_type == "crps":
                    predicted_var_key = loss_prms.get(
                        "var_key", loss_prms["key"] + "_var"
                    )
                    lamba_crps = loss_prms.get("lambda_crps", 1.0)
                    predicted_var = output[predicted_var_key]
                    if per_atom:
                        predicted_var = predicted_var / natoms.reshape(*shape_mask) ** 2
                    # masked entries get unit variance to avoid division by zero
                    predicted_var = predicted_var * truth_mask + (1.0 - truth_mask)
                    sigma = predicted_var**0.5
                    dy = (ref - predicted) / sigma
                    Phi = 0.5 * (1.0 + jax.scipy.special.erf(dy / 2**0.5))
                    phi = jnp.exp(-0.5 * dy**2) / (2 * jnp.pi) ** 0.5
                    loss = jnp.sum(
                        natscale
                        * sigma
                        * (dy * (2 * Phi - 1.0) + 2 * phi - lamba_crps / jnp.pi**0.5)
                    )
                elif loss_type == "ensemble_nll":
                    if per_atom:
                        # predicted/ref were divided by natoms: rescale the variance too
                        predicted_var = predicted_var / natoms.reshape(*shape_mask) ** 2
                    predicted_var = predicted_var * truth_mask + (1.0 - truth_mask)
                    loss = 0.5 * jnp.sum(
                        natscale
                        * truth_mask
                        * (
                            jnp.log(predicted_var)
                            + (ref - predicted) ** 2 / predicted_var
                        )
                    )
                elif loss_type == "ensemble_crps":
                    if per_atom:
                        predicted_var = predicted_var / natoms.reshape(*shape_mask) ** 2
                    lamba_crps = loss_prms.get("lambda_crps", 1.0)
                    predicted_var = predicted_var * truth_mask + (1.0 - truth_mask)
                    sigma = predicted_var**0.5
                    dy = (ref - predicted) / sigma
                    Phi = 0.5 * (1.0 + jax.scipy.special.erf(dy / 2**0.5))
                    phi = jnp.exp(-0.5 * dy**2) / (2 * jnp.pi) ** 0.5
                    loss = jnp.sum(
                        natscale
                        * truth_mask
                        * sigma
                        * (dy * (2 * Phi - 1.0) + 2 * phi - lamba_crps / jnp.pi**0.5)
                    )
                elif loss_type == "gaussian_mixture":
                    gm_w = get_subsample(output[loss_prms["key"] + "_w"])
                    gm_w = gm_w / jnp.sum(gm_w, axis=ensemble_axis, keepdims=True)
                    sigma = get_subsample(output[loss_prms["key"] + "_sigma"])
                    ref = jnp.expand_dims(ref, axis=ensemble_axis)

                    # CRPS of a Gaussian mixture
                    def compute_crps(mu, sigma):
                        dy = mu / sigma
                        Phi = 0.5 * (1.0 + jax.scipy.special.erf(dy / 2**0.5))
                        phi = jnp.exp(-0.5 * dy**2) / (2 * jnp.pi) ** 0.5
                        return sigma * (dy * (2 * Phi - 1.0) + 2 * phi - 1.0 / jnp.pi**0.5)

                    crps1 = compute_crps(ref - predicted_ensemble, sigma)
                    gm_crps1 = (gm_w * crps1).sum(axis=ensemble_axis)

                    if ensemble_axis < 0:
                        ensemble_axis = predicted_ensemble.ndim + ensemble_axis
                    muij = jnp.expand_dims(
                        predicted_ensemble, axis=ensemble_axis
                    ) - jnp.expand_dims(predicted_ensemble, axis=ensemble_axis + 1)
                    sigma2 = sigma**2
                    sigmaij = (
                        jnp.expand_dims(sigma2, axis=ensemble_axis)
                        + jnp.expand_dims(sigma2, axis=ensemble_axis + 1)
                    ) ** 0.5
                    wij = jnp.expand_dims(gm_w, axis=ensemble_axis) * jnp.expand_dims(
                        gm_w, axis=ensemble_axis + 1
                    )
                    crps2 = compute_crps(muij, sigmaij)
                    gm_crps2 = (wij * crps2).sum(
                        axis=(ensemble_axis, ensemble_axis + 1)
                    )

                    crps = gm_crps1 - 0.5 * gm_crps2
                    loss = jnp.sum(natscale * truth_mask * crps)

                elif loss_type == "evidential":
                    evidence = loss_prms["evidence_key"]
                    nu, alpha, beta = jnp.split(output[evidence], 3, axis=-1)
                    gamma = predicted
                    nu = nu.reshape(shape_mask)
                    alpha = alpha.reshape(shape_mask)
                    beta = beta.reshape(shape_mask)
                    nu = jnp.where(truth_mask, nu, 1.0)
                    alpha = jnp.where(truth_mask, alpha, 1.0)
                    beta = jnp.where(truth_mask, beta, 1.0)
                    omega = 2 * beta * (1 + nu)
                    lg = jax.scipy.special.gammaln(alpha) - jax.scipy.special.gammaln(
                        alpha + 0.5
                    )
                    ls = 0.5 * jnp.log(jnp.pi / nu) - alpha * jnp.log(omega)
                    lt = (alpha + 0.5) * jnp.log(omega + nu * (gamma - ref) ** 2)
                    wst = (
                        (beta * (1 + nu) / (alpha * nu)) ** 0.5
                        if loss_prms.get("normalize_evidence", True)
                        else 1.0
                    )
                    lr = (
                        loss_prms.get("lambda_evidence", 1.0)
                        * jnp.abs(gamma - ref)
                        * nu
                        / wst
                    )
                    r = loss_prms.get("evidence_ratio", 1.0)
                    le = (
                        loss_prms.get("lambda_evidence_diff", 0.0)
                        * (nu - r * 2 * alpha) ** 2
                    )
                    lb = loss_prms.get("lambda_evidence_beta", 0.0) * beta
                    loss = lg + ls + lt + lr + le + lb

                    loss = jnp.sum(natscale * loss * truth_mask)
                elif loss_type == "raw":
                    loss = jnp.sum(natscale * predicted)
                elif loss_type == "binary_crossentropy":
                    ref = ref.astype(predicted.dtype)
                    if loss_prms.get("raw_logits", False):
                        loss = jnp.sum(
                            natscale
                            * truth_mask
                            * (
                                -ref * jax.nn.log_sigmoid(predicted)
                                - (1.0 - ref) * jax.nn.log_sigmoid(-predicted)
                            )
                        )
                    else:
                        loss = jnp.sum(
                            natscale
                            * truth_mask
                            * (
                                -ref * jnp.log(predicted + 1.0e-10)
                                - (1.0 - ref) * jnp.log(1.0 - predicted + 1.0e-10)
                            )
                        )
                else:
                    raise ValueError(f"Unknown loss type: {loss_type}")

                if "weight_schedule" in loss_prms:
                    w = linear_schedule(epoch, *loss_prms["weight_schedule"])
                else:
                    w = loss_prms["weight"]

                loss_tot = loss_tot + w * loss

            return loss_tot, output

        variables = training_state["variables"]
        new_training_state = {**training_state, "step": training_state["step"] + 1}

        (loss, o), grad = jax.value_and_grad(loss_fn, has_aux=True)(variables)
        updates, new_training_state["opt_state"] = optimizer.update(
            grad, training_state["opt_state"], params=variables
        )
        variables = optax.apply_updates(variables, updates)
        new_training_state["variables"] = variables

        if ema is not None:
            ema_st = training_state.get("ema_state", None)
            if ema_st is None:
                raise ValueError(
                    "train_step was setup with ema but either variables_ema or ema_st was not provided"
                )
            (
                new_training_state["variables_ema"],
                new_training_state["ema_state"],
            ) = ema.update(variables, ema_st)

        return loss, new_training_state, o

    if jit:
        return jax.jit(train_step)
    return train_step
def _resolve_validation_reference(loss_prms, data, output, output_ref):
    """Return the scaled reference values and optional mask for one loss term.

    The source of the reference is selected by the prefix of ``loss_prms["ref"]``:

    - ``"model_ref/<key>"``: ``<key>`` is read from the reference-model output;
      a ``"<key>_mask"`` entry is searched in ``data``, then ``output_ref``,
      then ``output``.
    - ``"model/<key>"``: ``<key>`` is read from the trained-model output; the
      mask is searched in ``data``, then ``output``.
    - anything else: the key is read from ``data``; the mask is searched in
      ``data`` only.

    Args:
        loss_prms: Parameters of one loss term; must contain ``"ref"`` and ``"mult"``.
        data: Batch data dictionary.
        output: Output dictionary of the model being validated.
        output_ref: Output dictionary of the reference model, or ``None`` when
            no reference model was given.

    Returns:
        ``(ref, ref_mask)``: ``ref`` multiplied by ``loss_prms["mult"]``, and the
        first mask found (``None`` if no mask entry exists).

    Raises:
        AssertionError: if a ``"model_ref/"`` reference is requested without a
            reference model.
    """
    ref_spec = loss_prms["ref"]
    if ref_spec.startswith("model_ref/"):
        assert output_ref is not None, "model_ref must be provided"
        key = ref_spec[len("model_ref/"):]
        ref = output_ref[key] * loss_prms["mult"]
        mask_sources = (data, output_ref, output)
    elif ref_spec.startswith("model/"):
        key = ref_spec[len("model/"):]
        ref = output[key] * loss_prms["mult"]
        mask_sources = (data, output)
    else:
        key = ref_spec
        ref = data[key] * loss_prms["mult"]
        mask_sources = (data,)

    mask_key = key + "_mask"
    for source in mask_sources:
        if mask_key in source:
            return ref, source[mask_key]
    return ref, None


def get_validation_function(
    loss_definition: Dict,
    model: "FENNIX",
    evaluate: Callable,
    model_ref: Optional["FENNIX"] = None,
    compute_ref_coords: bool = False,
    return_targets: bool = False,
    jit: bool = True,
):
    """Build a function that computes per-term validation errors.

    Args:
        loss_definition: Mapping ``name -> loss_prms``. Keys read from each
            ``loss_prms``: ``"key"`` (model output to validate), ``"ref"`` and
            ``"mult"`` (reference and its scale), ``"type"``, and the optional
            ``"validate"``, ``"remove_ref_sys"``, ``"norm_axis"``,
            ``"ensemble_axis"``, ``"ensemble_weights"``, ``"per_atom_validation"``,
            ``"raw_logits"`` and ``"class_threshold"``.
        model: Model being validated; passed to ``evaluate``.
        evaluate: Callable ``evaluate(model, variables, inputs) -> dict``.
        model_ref: Optional reference model, evaluated with its own
            ``variables``; needed for ``"model_ref/..."`` references.
        compute_ref_coords: Must be True for terms using ``"remove_ref_sys"``.
        return_targets: If True, the validation function also returns, per term,
            the masked ``(predicted, ref, truth_mask)``.
        jit: If True, the returned function is wrapped in ``jax.jit``.

    Returns:
        ``validation(data, inputs, variables, inputs_ref=None)``, returning
        ``(rmses, maes, output)`` or ``(rmses, maes, output, targets)``.
        For ``binary_crossentropy`` terms, ``rmses`` holds the mean
        cross-entropy and ``maes`` the percentage of entries with
        ``|predicted - ref| > class_threshold`` (default 0.05).
    """

    def validation(data, inputs, variables, inputs_ref=None):
        """Evaluate the model(s) on one batch and compute per-term errors.

        ``inputs_ref`` is not used by this function.
        """
        output_ref = None
        if model_ref is not None:
            output_ref = evaluate(model_ref, model_ref.variables, inputs)
        output = evaluate(model, variables, inputs)

        rmses = {}
        maes = {}
        targets = {}

        # systems flagged False in true_sys count as one atom, keeping the
        # per-atom division below finite
        natoms = jnp.where(inputs["true_sys"], inputs["natoms"], 1)
        nsys = inputs["natoms"].shape[0]
        system_mask = inputs["true_sys"]
        atom_mask = inputs["true_atoms"]
        batch_index = inputs["batch_index"]
        if "system_sign" in inputs:
            # only systems with a positive sign (and their atoms) are scored
            positive_sys = inputs["system_sign"] > 0
            system_mask = jnp.logical_and(system_mask, positive_sys)
            atom_mask = jnp.logical_and(atom_mask, positive_sys[batch_index])

        for name, loss_prms in loss_definition.items():
            if not loss_prms.get("validate", True):
                continue
            predicted = output[loss_prms["key"]]

            if loss_prms.get("remove_ref_sys", False):
                assert compute_ref_coords, "compute_ref_coords must be True"
                # accumulate sign-weighted predictions onto the destination
                # given by inputs["system_index"] (system-level) or the
                # matching atom slot of that system (atom-level)
                shape_mask = [predicted.shape[0]] + [1] * (predicted.ndim - 1)
                if predicted.shape[0] == nsys:
                    sign = inputs["system_sign"].reshape(*shape_mask)
                    predicted = jax.ops.segment_sum(
                        sign * predicted, inputs["system_index"], nsys
                    )
                elif predicted.shape[0] == batch_index.shape[0]:
                    sign = inputs["system_sign"][batch_index].reshape(*shape_mask)
                    natshift = jnp.concatenate([jnp.array([0]), jnp.cumsum(natoms)])
                    iat = jnp.arange(batch_index.shape[0]) - natshift[batch_index]
                    iatdest = natshift[inputs["system_index"]][batch_index] + iat
                    predicted = jax.ops.segment_sum(
                        sign * predicted, iatdest, batch_index.shape[0]
                    )
                else:
                    raise ValueError("remove_ref_sys only works with system-level or atom-level predictions")

            ref_mask = None
            if "ref" in loss_prms:
                ref, ref_mask = _resolve_validation_reference(
                    loss_prms, data, output, output_ref
                )
            else:
                ref = jnp.zeros_like(predicted)

            norm_axis = loss_prms.get("norm_axis")
            if norm_axis is not None:
                predicted = jnp.linalg.norm(predicted, axis=norm_axis, keepdims=True)
                ref = jnp.linalg.norm(ref, axis=norm_axis, keepdims=True)

            loss_type = loss_prms["type"].lower()
            if loss_type.startswith("ensemble"):
                # validate the (weighted) ensemble mean
                axis = loss_prms.get("ensemble_axis", -1)
                ensemble_weights_key = loss_prms.get("ensemble_weights", None)
                if ensemble_weights_key is not None:
                    ensemble_weights = output[ensemble_weights_key]
                else:
                    ensemble_weights = jnp.ones_like(predicted)
                ensemble_weights = ensemble_weights / jnp.sum(
                    ensemble_weights, axis=axis, keepdims=True
                )
                predicted = (ensemble_weights * predicted).sum(axis=axis)

            if loss_type == "gaussian_mixture":
                # validate the mixture mean, weights read from "<key>_w"
                axis = loss_prms.get("ensemble_axis", -1)
                gm_w = output[loss_prms["key"] + "_w"]
                gm_w = gm_w / jnp.sum(gm_w, axis=axis, keepdims=True)
                predicted = (predicted * gm_w).sum(axis=axis)

            if predicted.ndim > 1 and predicted.shape[-1] == 1:
                predicted = jnp.squeeze(predicted, axis=-1)
            if ref.ndim > 1 and ref.shape[-1] == 1:
                ref = jnp.squeeze(ref, axis=-1)

            shape_mask = [ref.shape[0]] + [1] * (predicted.ndim - 1)
            if ref.shape[0] == output["batch_index"].shape[0]:
                # one row per atom
                truth_mask = atom_mask
            elif ref.shape[0] == natoms.shape[0]:
                # one row per system
                truth_mask = system_mask
                if loss_prms.get("per_atom_validation", False):
                    ref = ref / natoms.reshape(*shape_mask)
                    predicted = predicted / natoms.reshape(*shape_mask)
            else:
                truth_mask = jnp.ones(ref.shape[0], dtype=bool)

            if ref_mask is not None:
                truth_mask = jnp.logical_and(truth_mask, ref_mask.astype(bool))

            # number of valid elements = valid rows * elements per row
            nel = jnp.maximum(
                (float(np.prod(ref.shape)) / float(truth_mask.shape[0]))
                * jnp.sum(truth_mask).astype(predicted.dtype),
                1.0,
            )

            truth_mask = truth_mask.reshape(*shape_mask)
            ref = ref * truth_mask
            predicted = predicted * truth_mask

            if loss_type == "binary_crossentropy":
                ref = ref.astype(predicted.dtype)
                if loss_prms.get("raw_logits", False):
                    rmse = jnp.sum(
                        (
                            -ref * jax.nn.log_sigmoid(predicted)
                            - (1.0 - ref) * jax.nn.log_sigmoid(-predicted)
                        )
                        * truth_mask
                        / nel
                    )
                    predicted = truth_mask * jax.nn.sigmoid(predicted)
                else:
                    rmse = jnp.sum(
                        (
                            -ref * jnp.log(predicted + 1.0e-10)
                            - (1.0 - ref) * jnp.log(1.0 - predicted + 1.0e-10)
                        )
                        * truth_mask
                        / nel
                    )
                # percentage of entries misclassified beyond the threshold
                thr = loss_prms.get("class_threshold", 0.05)
                mae = jnp.sum(jnp.where(jnp.abs(predicted - ref) > thr, 1, 0) / nel) * 100
            else:
                rmse = jnp.sum((predicted - ref) ** 2 / nel) ** 0.5
                mae = jnp.sum(jnp.abs(predicted - ref) / nel)

            rmses[name] = rmse
            maes[name] = mae
            if return_targets:
                targets[name] = (predicted, ref, truth_mask)

        if return_targets:
            return rmses, maes, output, targets
        return rmses, maes, output

    if jit:
        return jax.jit(validation)
    return validation