fennol.training.utils
import jax
import jax.numpy as jnp
import numpy as np
from typing import Callable, Optional, Dict, List, Tuple, Union, Any, NamedTuple
import optax
from copy import deepcopy
from functools import partial

from ..utils import deep_update
from ..utils.atomic_units import au
from ..models import FENNIX


@partial(jax.jit, static_argnums=(1, 2, 3, 4))
def linear_schedule(step, start_value, end_value, start_step, duration):
    return start_value + jnp.clip((step - start_step) / duration, 0.0, 1.0) * (
        end_value - start_value
    )


def get_training_parameters(
    parameters: Dict[str, Any], stage: int = -1
) -> Dict[str, Any]:
    params = deepcopy(parameters["training"])
    if "stages" not in params:
        return params

    stages: dict = params.pop("stages")
    stage_keys = list(stages.keys())
    if stage < 0:
        stage = len(stage_keys) + stage
    assert (
        0 <= stage < len(stage_keys)
    ), f"Stage {stage} not found in training parameters"
    for i in range(stage + 1):
        ## remove end_event from previous stage ##
        if i > 0 and "end_event" in params:
            params.pop("end_event")
        ## incrementally update training parameters ##
        stage_params = stages[stage_keys[i]]
        params = deep_update(params, stage_params)
    return params


def get_loss_definition(
    training_parameters: Dict[str, Any],
    model_energy_unit: str = "Ha",
) -> Tuple[Dict[str, Any], List[str], List[str]]:
    """
    Builds the loss definition from the training parameters.

    Args:
        training_parameters (dict): A dictionary containing training parameters.
        model_energy_unit (str): Energy unit used by the model outputs.

    Returns:
        tuple: A tuple containing:
            - loss_definition (dict): The processed loss definition.
            - used_keys (list): Model output keys used by the loss components.
            - ref_keys (list): Reference keys expected in the dataset.
    """
    default_loss_type = training_parameters.get("default_loss_type", "log_cosh")
    used_keys = []
    ref_keys = []
    energy_mult = au.get_multiplier(model_energy_unit)
    loss_definition = {}
    for k in training_parameters["loss"]:
        loss_prms = deepcopy(training_parameters["loss"][k])
        ## unit conversion factors ##
        if "energy_unit" in loss_prms:
            loss_prms["mult"] = energy_mult / au.get_multiplier(
                loss_prms["energy_unit"]
            )
            if "unit" in loss_prms:
                print(
                    f"Warning: Both 'unit' and 'energy_unit' are defined for loss component {k} -> using 'energy_unit'"
                )
            loss_prms["unit"] = loss_prms["energy_unit"]
        elif "unit" in loss_prms:
            loss_prms["mult"] = 1.0 / au.get_multiplier(loss_prms["unit"])
        else:
            loss_prms["mult"] = 1.0
        if "display_unit" not in loss_prms and "unit" in loss_prms:
            loss_prms["display_unit"] = loss_prms["unit"]
        au_convert = loss_prms["mult"] * au.get_multiplier(loss_prms.get("unit", "au"))
        loss_prms["display_mult"] = au_convert / au.get_multiplier(
            loss_prms.get("display_unit", "au")
        )

        ## defaults ##
        if "key" not in loss_prms:
            loss_prms["key"] = k
        if "type" not in loss_prms:
            loss_prms["type"] = default_loss_type
        if "weight" not in loss_prms:
            loss_prms["weight"] = 1.0
        assert loss_prms["weight"] >= 0.0, "Loss weight must be non-negative"
        if "threshold" in loss_prms:
            assert loss_prms["threshold"] > 1.0, "Threshold must be greater than 1.0"
        if "ref" in loss_prms:
            ref = loss_prms["ref"]
            if not (ref.startswith("model_ref/") or ref.startswith("model/")):
                ref_keys.append(ref)
        if "ds_weight" in loss_prms:
            ref_keys.append(loss_prms["ds_weight"])

        ## optional linear ramp of the loss weight ##
        if "weight_start" in loss_prms:
            weight_start = loss_prms["weight_start"]
            if "weight_ramp" in loss_prms:
                weight_ramp = loss_prms["weight_ramp"]
            else:
                weight_ramp = training_parameters.get("max_epochs")
            weight_ramp_start = loss_prms.get("weight_ramp_start", 0.0)
            weight_end = loss_prms["weight"]
            print(
                f"Weight ramp for {k} : {weight_start} -> {weight_end} in {weight_ramp} epochs"
            )
            loss_prms["weight_schedule"] = (
                weight_start,
                weight_end,
                weight_ramp_start,
                weight_ramp,
            )

        used_keys.append(loss_prms["key"])
        loss_definition[k] = loss_prms

    return loss_definition, list(set(used_keys)), list(set(ref_keys))


def get_train_step_function(
    loss_definition: Dict,
    model: FENNIX,
    evaluate: Callable,
    optimizer: optax.GradientTransformation,
    ema: Optional[optax.GradientTransformation] = None,
    model_ref: Optional[FENNIX] = None,
    compute_ref_coords: bool = False,
    jit: bool = True,
):
    def train_step(
        data,
        inputs,
        training_state,
    ):
        epoch = training_state["epoch"]

        def loss_fn(variables):
            if model_ref is not None:
                output_ref = evaluate(model_ref, model_ref.variables, inputs)
            output = evaluate(model, variables, inputs)
            natoms = jnp.where(inputs["true_sys"], inputs["natoms"], 1)
            nsys = inputs["natoms"].shape[0]
            ## masks for padding atoms/systems (and reference systems when system_sign is present) ##
            system_mask = inputs["true_sys"]
            atom_mask = inputs["true_atoms"]
            batch_index = inputs["batch_index"]
            if "system_sign" in inputs:
                system_sign = inputs["system_sign"] > 0
                system_mask = jnp.logical_and(system_mask, system_sign)
                atom_mask = jnp.logical_and(atom_mask, system_sign[batch_index])

            loss_tot = 0.0
            for loss_prms in loss_definition.values():
                use_ref_mask = False
                predicted = output[loss_prms["key"]]

                if "ref" in loss_prms:
                    if loss_prms["ref"].startswith("model_ref/"):
                        assert model_ref is not None, "model_ref must be provided"
                        try:
                            ref = output_ref[loss_prms["ref"][10:]] * loss_prms["mult"]
                        except KeyError:
                            raise KeyError(
                                f"Reference key '{loss_prms['ref'][10:]}' not found in model_ref output. Keys available: {output_ref.keys()}"
                            )
                        mask_key = loss_prms["ref"][10:] + "_mask"
                        if mask_key in data:
                            use_ref_mask = True
                            ref_mask = data[mask_key]
                        elif mask_key in output_ref:
                            use_ref_mask = True
                            ref_mask = output_ref[mask_key]
                        elif mask_key in output:
                            use_ref_mask = True
                            ref_mask = output[mask_key]
                    elif loss_prms["ref"].startswith("model/"):
                        try:
                            ref = output[loss_prms["ref"][6:]] * loss_prms["mult"]
                        except KeyError:
                            raise KeyError(
                                f"Reference key '{loss_prms['ref'][6:]}' not found in model output. Keys available: {output.keys()}"
                            )
                        mask_key = loss_prms["ref"][6:] + "_mask"
                        if mask_key in data:
                            use_ref_mask = True
                            ref_mask = data[mask_key]
                        elif mask_key in output:
                            use_ref_mask = True
                            ref_mask = output[mask_key]
                    else:
                        try:
                            ref = data[loss_prms["ref"]] * loss_prms["mult"]
                        except KeyError:
                            raise KeyError(
                                f"Reference key '{loss_prms['ref']}' not found in data. Keys available: {data.keys()}"
                            )
                        if loss_prms["ref"] + "_mask" in data:
                            use_ref_mask = True
                            ref_mask = data[loss_prms["ref"] + "_mask"]
                else:
                    ref = jnp.zeros_like(predicted)

                if "norm_axis" in loss_prms and loss_prms["norm_axis"] is not None:
                    norm_axis = loss_prms["norm_axis"]
                    predicted = jnp.linalg.norm(
                        predicted, axis=norm_axis, keepdims=True
                    )
                    ref = jnp.linalg.norm(ref, axis=norm_axis, keepdims=True)

                ensemble_loss = loss_prms["type"] in [
                    "ensemble_nll",
                    "ensemble_crps",
                    "gaussian_mixture",
                ]
                if ensemble_loss:
                    ensemble_axis = loss_prms.get("ensemble_axis", -1)
                    ensemble_weights_key = loss_prms.get("ensemble_weights", None)
                    ensemble_size = predicted.shape[ensemble_axis]
                    if ensemble_weights_key is not None:
                        ensemble_weights = output[ensemble_weights_key]
                    else:
                        ensemble_weights = jnp.ones_like(predicted)
                    if "ensemble_subsample" in loss_prms and "rng_key" in output:
                        ns = min(
                            loss_prms["ensemble_subsample"],
                            predicted.shape[ensemble_axis],
                        )
                        key, subkey = jax.random.split(output["rng_key"])

                        def get_subsample(x):
                            return jax.lax.slice_in_dim(
                                jax.random.permutation(
                                    subkey, x, axis=ensemble_axis, independent=True
                                ),
                                start_index=0,
                                limit_index=ns,
                                axis=ensemble_axis,
                            )

                        predicted = get_subsample(predicted)
                        ensemble_weights = get_subsample(ensemble_weights)
                        output["rng_key"] = key
                    else:
                        get_subsample = lambda x: x
                    ensemble_weights = ensemble_weights / jnp.sum(
                        ensemble_weights, axis=ensemble_axis, keepdims=True
                    )
                    predicted_ensemble = predicted
                    predicted = (predicted * ensemble_weights).sum(
                        axis=ensemble_axis, keepdims=True
                    )
                    predicted_var = (
                        ensemble_weights * (predicted_ensemble - predicted) ** 2
                    ).sum(axis=ensemble_axis) * (ensemble_size / (ensemble_size - 1.0))
                    predicted = jnp.squeeze(predicted, axis=ensemble_axis)

                if "remove_ref_sys" in loss_prms and loss_prms["remove_ref_sys"]:
                    assert compute_ref_coords, "compute_ref_coords must be True"
                    ## subtract predictions on reference systems (system_sign < 0) from their primary systems ##
                    shape_mask = [predicted.shape[0]] + [1] * (len(predicted.shape) - 1)
                    if predicted.shape[0] == nsys:
                        predicted_sign_mask = (
                            inputs["system_sign"].reshape(*shape_mask) < 0
                        )
                        predicted_shift = jax.ops.segment_sum(
                            predicted_sign_mask * predicted,
                            inputs["system_index"],
                            nsys,
                        )
                        predicted = predicted - predicted_shift
                    elif predicted.shape[0] == batch_index.shape[0]:
                        predicted_sign_mask = (inputs["system_sign"] < 0)[
                            batch_index
                        ].reshape(*shape_mask)
                        natshift = jnp.concatenate([jnp.array([0]), jnp.cumsum(natoms)])
                        iat = jnp.arange(batch_index.shape[0]) - natshift[batch_index]
                        iatdest = natshift[inputs["system_index"]][batch_index] + iat
                        predicted_shift = jax.ops.segment_sum(
                            predicted_sign_mask * predicted,
                            iatdest,
                            batch_index.shape[0],
                        )
                        predicted = predicted - predicted_shift
                    else:
                        raise ValueError(
                            "remove_ref_sys only works with system-level or atom-level predictions"
                        )

                    if ensemble_loss:
                        predicted_ensemble = predicted_ensemble - jnp.expand_dims(
                            predicted_shift, ensemble_axis
                        )
                        predicted_var = (
                            ensemble_weights
                            * (
                                predicted_ensemble
                                - jnp.expand_dims(predicted, ensemble_axis)
                            )
                            ** 2
                        ).sum(axis=ensemble_axis) * (
                            ensemble_size / (ensemble_size - 1.0)
                        )

                if predicted.ndim > 1 and predicted.shape[-1] == 1:
                    predicted = jnp.squeeze(predicted, axis=-1)

                if ref.ndim > 1 and ref.shape[-1] == 1:
                    ref = jnp.squeeze(ref, axis=-1)

                per_atom = False
                shape_mask = [ref.shape[0]] + [1] * (len(ref.shape) - 1)
                natscale = 1.0
                if ref.shape[0] == output["batch_index"].shape[0]:
                    ## shape is number of atoms
                    truth_mask = atom_mask
                    if "ds_weight" in loss_prms:
                        weight_key = loss_prms["ds_weight"]
                        natscale = data[weight_key]
                        if natscale.shape[0] == natoms.shape[0]:
                            natscale = natscale[output["batch_index"]]
                        natscale = natscale.reshape(*shape_mask)
                elif ref.shape[0] == natoms.shape[0]:
                    ## shape is number of systems
                    truth_mask = system_mask
                    if loss_prms.get("per_atom", False):
                        per_atom = True
                        ref = ref / natoms.reshape(*shape_mask)
                        predicted = predicted / natoms.reshape(*shape_mask)
                    if "nat_pow" in loss_prms:
                        natscale = (
                            1.0 / natoms.reshape(*shape_mask) ** loss_prms["nat_pow"]
                        )
                    if "ds_weight" in loss_prms:
                        weight_key = loss_prms["ds_weight"]
                        natscale = natscale * data[weight_key].reshape(*shape_mask)
                else:
                    truth_mask = jnp.ones(ref.shape[0], dtype=bool)

                if use_ref_mask:
                    truth_mask = truth_mask * ref_mask.astype(bool)

                nel = jnp.maximum(
                    (float(np.prod(ref.shape)) / float(truth_mask.shape[0]))
                    * jnp.sum(truth_mask).astype(jnp.float32),
                    1.0,
                )
                truth_mask = truth_mask.reshape(*shape_mask)

                ref = ref * truth_mask
                predicted = predicted * truth_mask

                natscale = natscale / nel

                loss_type = loss_prms["type"]
                if loss_type == "mse":
                    loss = jnp.sum(natscale * (predicted - ref) ** 2)
                elif loss_type == "log_cosh":
                    loss = jnp.sum(natscale * optax.log_cosh(predicted, ref))
                elif loss_type == "mae":
                    loss = jnp.sum(natscale * jnp.abs(predicted - ref))
                elif loss_type == "rel_mse":
                    eps = loss_prms.get("rel_eps", 1.0e-6)
                    rel_pow = loss_prms.get("rel_pow", 2.0)
                    loss = jnp.sum(
                        natscale
                        * (predicted - ref) ** 2
                        / (eps + jnp.abs(ref) ** rel_pow)
                    )
                elif loss_type == "rmse+mae":
                    loss = (
                        jnp.sum(natscale * (predicted - ref) ** 2)
                    ) ** 0.5 + jnp.sum(natscale * jnp.abs(predicted - ref))
                elif loss_type == "crps":
                    predicted_var_key = loss_prms.get(
                        "var_key", loss_prms["key"] + "_var"
                    )
                    lambda_crps = loss_prms.get("lambda_crps", 1.0)
                    predicted_var = output[predicted_var_key]
                    if per_atom:
                        predicted_var = predicted_var / natoms.reshape(*shape_mask) ** 2
                    predicted_var = predicted_var * truth_mask + (1.0 - truth_mask)
                    sigma = predicted_var**0.5
                    dy = (ref - predicted) / sigma
                    Phi = 0.5 * (1.0 + jax.scipy.special.erf(dy / 2**0.5))
                    phi = jnp.exp(-0.5 * dy**2) / (2 * jnp.pi) ** 0.5
                    loss = jnp.sum(
                        natscale
                        * sigma
                        * (dy * (2 * Phi - 1.0) + 2 * phi - lambda_crps / jnp.pi**0.5)
                    )
                elif loss_type == "ensemble_nll":
                    predicted_var = predicted_var * truth_mask + (1.0 - truth_mask)
                    loss = 0.5 * jnp.sum(
                        natscale
                        * truth_mask
                        * (
                            jnp.log(predicted_var)
                            + (ref - predicted) ** 2 / predicted_var
                        )
                    )
                elif loss_type == "ensemble_crps":
                    if per_atom:
                        predicted_var = predicted_var / natoms.reshape(*shape_mask) ** 2
                    lambda_crps = loss_prms.get("lambda_crps", 1.0)
                    predicted_var = predicted_var * truth_mask + (1.0 - truth_mask)
                    sigma = predicted_var**0.5
                    dy = (ref - predicted) / sigma
                    Phi = 0.5 * (1.0 + jax.scipy.special.erf(dy / 2**0.5))
                    phi = jnp.exp(-0.5 * dy**2) / (2 * jnp.pi) ** 0.5
                    loss = jnp.sum(
                        natscale
                        * truth_mask
                        * sigma
                        * (dy * (2 * Phi - 1.0) + 2 * phi - lambda_crps / jnp.pi**0.5)
                    )
                elif loss_type == "gaussian_mixture":
                    gm_w = get_subsample(output[loss_prms["key"] + "_w"])
                    gm_w = gm_w / jnp.sum(gm_w, axis=ensemble_axis, keepdims=True)
                    sigma = get_subsample(output[loss_prms["key"] + "_sigma"])
                    ref = jnp.expand_dims(ref, axis=ensemble_axis)

                    ## closed-form CRPS for a Gaussian mixture: E|X - y| - 0.5 E|X - X'| ##
                    def compute_crps(mu, sigma):
                        dy = mu / sigma
                        Phi = 0.5 * (1.0 + jax.scipy.special.erf(dy / 2**0.5))
                        phi = jnp.exp(-0.5 * dy**2) / (2 * jnp.pi) ** 0.5
                        return sigma * (
                            dy * (2 * Phi - 1.0) + 2 * phi - 1.0 / jnp.pi**0.5
                        )

                    crps1 = compute_crps(ref - predicted_ensemble, sigma)
                    gm_crps1 = (gm_w * crps1).sum(axis=ensemble_axis)

                    if ensemble_axis < 0:
                        ensemble_axis = predicted_ensemble.ndim + ensemble_axis
                    muij = jnp.expand_dims(
                        predicted_ensemble, axis=ensemble_axis
                    ) - jnp.expand_dims(predicted_ensemble, axis=ensemble_axis + 1)
                    sigma2 = sigma**2
                    sigmaij = (
                        jnp.expand_dims(sigma2, axis=ensemble_axis)
                        + jnp.expand_dims(sigma2, axis=ensemble_axis + 1)
                    ) ** 0.5
                    wij = jnp.expand_dims(gm_w, axis=ensemble_axis) * jnp.expand_dims(
                        gm_w, axis=ensemble_axis + 1
                    )
                    crps2 = compute_crps(muij, sigmaij)
                    gm_crps2 = (wij * crps2).sum(
                        axis=(ensemble_axis, ensemble_axis + 1)
                    )

                    crps = gm_crps1 - 0.5 * gm_crps2
                    loss = jnp.sum(natscale * truth_mask * crps)

                elif loss_type == "evidential":
                    evidence = loss_prms["evidence_key"]
                    nu, alpha, beta = jnp.split(output[evidence], 3, axis=-1)
                    gamma = predicted
                    nu = nu.reshape(shape_mask)
                    alpha = alpha.reshape(shape_mask)
                    beta = beta.reshape(shape_mask)
                    nu = jnp.where(truth_mask, nu, 1.0)
                    alpha = jnp.where(truth_mask, alpha, 1.0)
                    beta = jnp.where(truth_mask, beta, 1.0)
                    omega = 2 * beta * (1 + nu)
                    lg = jax.scipy.special.gammaln(alpha) - jax.scipy.special.gammaln(
                        alpha + 0.5
                    )
                    ls = 0.5 * jnp.log(jnp.pi / nu) - alpha * jnp.log(omega)
                    lt = (alpha + 0.5) * jnp.log(omega + nu * (gamma - ref) ** 2)
                    wst = (
                        (beta * (1 + nu) / (alpha * nu)) ** 0.5
                        if loss_prms.get("normalize_evidence", True)
                        else 1.0
                    )
                    lr = (
                        loss_prms.get("lambda_evidence", 1.0)
                        * jnp.abs(gamma - ref)
                        * nu
                        / wst
                    )
                    r = loss_prms.get("evidence_ratio", 1.0)
                    le = (
                        loss_prms.get("lambda_evidence_diff", 0.0)
                        * (nu - r * 2 * alpha) ** 2
                    )
                    lb = loss_prms.get("lambda_evidence_beta", 0.0) * beta
                    loss = lg + ls + lt + lr + le + lb

                    loss = jnp.sum(natscale * loss * truth_mask)
                elif loss_type == "raw":
                    loss = jnp.sum(natscale * predicted)
                elif loss_type == "binary_crossentropy":
                    ref = ref.astype(predicted.dtype)
                    if loss_prms.get("raw_logits", False):
                        loss = jnp.sum(
                            natscale
                            * truth_mask
                            * (
                                -ref * jax.nn.log_sigmoid(predicted)
                                - (1.0 - ref) * jax.nn.log_sigmoid(-predicted)
                            )
                        )
                    else:
                        loss = jnp.sum(
                            natscale
                            * truth_mask
                            * (
                                -ref * jnp.log(predicted + 1.0e-10)
                                - (1.0 - ref) * jnp.log(1.0 - predicted + 1.0e-10)
                            )
                        )
                else:
                    raise ValueError(f"Unknown loss type: {loss_type}")

                if "weight_schedule" in loss_prms:
                    w = linear_schedule(epoch, *loss_prms["weight_schedule"])
                else:
                    w = loss_prms["weight"]

                loss_tot = loss_tot + w * loss

            return loss_tot, output

        variables = training_state["variables"]
        new_training_state = {**training_state, "step": training_state["step"] + 1}

        (loss, o), grad = jax.value_and_grad(loss_fn, has_aux=True)(variables)
        updates, new_training_state["opt_state"] = optimizer.update(
            grad, training_state["opt_state"], params=variables
        )
        variables = optax.apply_updates(variables, updates)
        new_training_state["variables"] = variables

        if ema is not None:
            ema_st = training_state.get("ema_state", None)
            if ema_st is None:
                raise ValueError(
                    "train_step was set up with ema but 'ema_state' was not provided in training_state"
                )
            new_training_state["variables_ema"], new_training_state["ema_state"] = (
                ema.update(variables, ema_st)
            )

        return loss, new_training_state, o

    if jit:
        return jax.jit(train_step)
    return train_step


def get_validation_function(
    loss_definition: Dict,
    model: FENNIX,
    evaluate: Callable,
    model_ref: Optional[FENNIX] = None,
    compute_ref_coords: bool = False,
    return_targets: bool = False,
    jit: bool = True,
):

    def validation(data, inputs, variables, inputs_ref=None):
        if model_ref is not None:
            output_ref = evaluate(model_ref, model_ref.variables, inputs)
        output = evaluate(model, variables, inputs)
        rmses = {}
        maes = {}
        if return_targets:
            targets = {}

        natoms = jnp.where(inputs["true_sys"], inputs["natoms"], 1)
        nsys = inputs["natoms"].shape[0]
        system_mask = inputs["true_sys"]
        atom_mask = inputs["true_atoms"]
        batch_index = inputs["batch_index"]
        if "system_sign" in inputs:
            system_sign = inputs["system_sign"] > 0
            system_mask = jnp.logical_and(system_mask, system_sign)
            atom_mask = jnp.logical_and(atom_mask, system_sign[batch_index])

        for name, loss_prms in loss_definition.items():
            do_validation = loss_prms.get("validate", True)
            if not do_validation:
                continue
            predicted = output[loss_prms["key"]]
            use_ref_mask = False

            if "remove_ref_sys" in loss_prms and loss_prms["remove_ref_sys"]:
                assert compute_ref_coords, "compute_ref_coords must be True"
                ## signed sum of predictions over primary/reference system pairs ##
                shape_mask = [predicted.shape[0]] + [1] * (len(predicted.shape) - 1)
                if predicted.shape[0] == nsys:
                    system_sign = inputs["system_sign"].reshape(*shape_mask)
                    predicted = jax.ops.segment_sum(
                        system_sign * predicted, inputs["system_index"], nsys
                    )
                elif predicted.shape[0] == batch_index.shape[0]:
                    atom_sign = inputs["system_sign"][batch_index].reshape(*shape_mask)
                    natshift = jnp.concatenate([jnp.array([0]), jnp.cumsum(natoms)])
                    iat = jnp.arange(batch_index.shape[0]) - natshift[batch_index]
                    iatdest = natshift[inputs["system_index"]][batch_index] + iat
                    predicted = jax.ops.segment_sum(
                        atom_sign * predicted, iatdest, batch_index.shape[0]
                    )
                else:
                    raise ValueError(
                        "remove_ref_sys only works with system-level or atom-level predictions"
                    )

            if "ref" in loss_prms:
                if loss_prms["ref"].startswith("model_ref/"):
                    assert model_ref is not None, "model_ref must be provided"
                    ref = output_ref[loss_prms["ref"][10:]] * loss_prms["mult"]
                    mask_key = loss_prms["ref"][10:] + "_mask"
                    if mask_key in data:
                        use_ref_mask = True
                        ref_mask = data[mask_key]
                    elif mask_key in output_ref:
                        use_ref_mask = True
                        ref_mask = output_ref[mask_key]
                    elif mask_key in output:
                        use_ref_mask = True
                        ref_mask = output[mask_key]
                elif loss_prms["ref"].startswith("model/"):
                    ref = output[loss_prms["ref"][6:]] * loss_prms["mult"]
                    mask_key = loss_prms["ref"][6:] + "_mask"
                    if mask_key in data:
                        use_ref_mask = True
                        ref_mask = data[mask_key]
                    elif mask_key in output:
                        use_ref_mask = True
                        ref_mask = output[mask_key]
                else:
                    ref = data[loss_prms["ref"]] * loss_prms["mult"]
                    if loss_prms["ref"] + "_mask" in data:
                        use_ref_mask = True
                        ref_mask = data[loss_prms["ref"] + "_mask"]
            else:
                ref = jnp.zeros_like(predicted)

            if "norm_axis" in loss_prms and loss_prms["norm_axis"] is not None:
                norm_axis = loss_prms["norm_axis"]
                predicted = jnp.linalg.norm(predicted, axis=norm_axis, keepdims=True)
                ref = jnp.linalg.norm(ref, axis=norm_axis, keepdims=True)

            loss_type = loss_prms["type"].lower()
            if loss_type.startswith("ensemble"):
                axis = loss_prms.get("ensemble_axis", -1)
                ensemble_weights_key = loss_prms.get("ensemble_weights", None)
                if ensemble_weights_key is not None:
                    ensemble_weights = output[ensemble_weights_key]
                else:
                    ensemble_weights = jnp.ones_like(predicted)
                ensemble_weights = ensemble_weights / jnp.sum(
                    ensemble_weights, axis=axis, keepdims=True
                )
                predicted = (ensemble_weights * predicted).sum(axis=axis)

            if loss_prms["type"] in ["gaussian_mixture"]:
                axis = loss_prms.get("ensemble_axis", -1)
                gm_w = output[loss_prms["key"] + "_w"]
                gm_w = gm_w / jnp.sum(gm_w, axis=axis, keepdims=True)
                predicted = (predicted * gm_w).sum(axis=axis)

            if predicted.ndim > 1 and predicted.shape[-1] == 1:
                predicted = jnp.squeeze(predicted, axis=-1)

            if ref.ndim > 1 and ref.shape[-1] == 1:
                ref = jnp.squeeze(ref, axis=-1)

            shape_mask = [ref.shape[0]] + [1] * (len(predicted.shape) - 1)
            if ref.shape[0] == output["batch_index"].shape[0]:
                ## shape is number of atoms
                truth_mask = atom_mask
            elif ref.shape[0] == natoms.shape[0]:
                ## shape is number of systems
                truth_mask = system_mask
                if loss_prms.get("per_atom_validation", False):
                    ref = ref / natoms.reshape(*shape_mask)
                    predicted = predicted / natoms.reshape(*shape_mask)
            else:
                truth_mask = jnp.ones(ref.shape[0], dtype=bool)

            if use_ref_mask:
                truth_mask = truth_mask * ref_mask.astype(bool)

            nel = jnp.maximum(
                (float(np.prod(ref.shape)) / float(truth_mask.shape[0]))
                * jnp.sum(truth_mask).astype(predicted.dtype),
                1.0,
            )

            truth_mask = truth_mask.reshape(*shape_mask)

            ref = ref * truth_mask
            predicted = predicted * truth_mask

            if loss_type == "binary_crossentropy":
                ref = ref.astype(predicted.dtype)
                if loss_prms.get("raw_logits", False):
                    rmse = jnp.sum(
                        (
                            -ref * jax.nn.log_sigmoid(predicted)
                            - (1.0 - ref) * jax.nn.log_sigmoid(-predicted)
                        )
                        * truth_mask
                        / nel
                    )
                    predicted = truth_mask * jax.nn.sigmoid(predicted)
                else:
                    rmse = jnp.sum(
                        (
                            -ref * jnp.log(predicted + 1.0e-10)
                            - (1.0 - ref) * jnp.log(1.0 - predicted + 1.0e-10)
                        )
                        * truth_mask
                        / nel
                    )
                thr = loss_prms.get("class_threshold", 0.05)
                mae = jnp.sum(jnp.where(jnp.abs(predicted - ref) > thr, 1, 0) / nel) * 100
            else:
                rmse = jnp.sum((predicted - ref) ** 2 / nel) ** 0.5
                mae = jnp.sum(jnp.abs(predicted - ref) / nel)

            rmses[name] = rmse
            maes[name] = mae
            if return_targets:
                targets[name] = (predicted, ref, truth_mask)

        if return_targets:
            return rmses, maes, output, targets

        return rmses, maes, output

    if jit:
        return jax.jit(validation)
    return validation
@partial(jax.jit, static_argnums=(1, 2, 3, 4))
def linear_schedule(step, start_value, end_value, start_step, duration):
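Returns a value that ramps linearly from start_value to end_value, starting at start_step and lasting duration steps; the progress is clipped to [0, 1], so the value is constant outside the ramp window. A minimal sketch (the concrete values are chosen for illustration):

    # Hypothetical ramp from 0.0 to 1.0 starting at step 10, lasting 50 steps.
    linear_schedule(5, 0.0, 1.0, 10, 50)    # -> 0.0 (before the ramp)
    linear_schedule(35, 0.0, 1.0, 10, 50)   # -> 0.5 (halfway through the ramp)
    linear_schedule(100, 0.0, 1.0, 10, 50)  # -> 1.0 (after the ramp)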
def get_training_parameters(parameters: Dict[str, Any], stage: int = -1) -> Dict[str, Any]:
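Selects the training parameters for a given stage. Each stage dict under "stages" is applied on top of the base "training" dict with deep_update, one after the other up to the requested stage, and the "end_event" key of earlier stages is dropped. A sketch of the expected layout (the stage names and values below are made up for illustration):

    parameters = {
        "training": {
            "max_epochs": 100,
            "stages": {
                # applied in order; later stages patch earlier ones
                "warmup": {"max_epochs": 10, "end_event": ...},
                "main": {"max_epochs": 100},
            },
        }
    }
    params = get_training_parameters(parameters, stage=-1)  # base dict + all stage patches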
def get_loss_definition(
    training_parameters: Dict[str, Any],
    model_energy_unit: str = "Ha",
) -> Tuple[Dict[str, Any], List[str], List[str]]:
Builds the loss definition from the training parameters.

Args:
    training_parameters (dict): A dictionary containing training parameters.
    model_energy_unit (str): Energy unit used by the model outputs.

Returns:
    tuple: A tuple containing:
        - loss_definition (dict): The processed loss definition.
        - used_keys (list): Model output keys used by the loss components.
        - ref_keys (list): Reference keys expected in the dataset.
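A hypothetical configuration to illustrate the accepted fields ("key" defaults to the component name, "type" to default_loss_type, "weight" to 1.0; "weight_start" enables a linear weight ramp; unit strings are assumed to be known to fennol.utils.atomic_units):

    training_parameters = {
        "default_loss_type": "log_cosh",
        "max_epochs": 100,
        "loss": {
            "energy": {
                "key": "total_energy",
                "ref": "formation_energy",
                "energy_unit": "kcal/mol",  # assumed to be a unit known to `au`
                "per_atom": True,
            },
            "forces": {"ref": "forces", "type": "mse", "weight": 0.1},
        },
    }
    loss_definition, used_keys, ref_keys = get_loss_definition(
        training_parameters, model_energy_unit="Ha"
    )
    # used_keys: output keys the model must produce, e.g. ["total_energy", "forces"]
    # ref_keys: fields read from the dataset, e.g. ["formation_energy", "forces"]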
def get_train_step_function(
    loss_definition: Dict,
    model: FENNIX,
    evaluate: Callable,
    optimizer: optax.GradientTransformation,
    ema: Optional[optax.GradientTransformation] = None,
    model_ref: Optional[FENNIX] = None,
    compute_ref_coords: bool = False,
    jit: bool = True,
):
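The returned train_step(data, inputs, training_state) evaluates the weighted sum of all loss components, applies one optimizer update, and returns (loss, new_training_state, output). training_state must provide "variables", "opt_state", "epoch" and "step", plus "ema_state" when an ema transformation is passed. For the "crps" and "ensemble_crps" types, the loss is the closed-form CRPS of a Gaussian N(mu, sigma^2) against the target y, with z = (y - mu)/sigma: CRPS = sigma * (z*(2*Phi(z) - 1) + 2*phi(z) - lambda_crps/sqrt(pi)), which reduces to the standard expression for lambda_crps = 1. A usage sketch; the evaluate wrapper and the data/inputs batches are assumed to come from the rest of the training pipeline and are not defined in this module:

    optimizer = optax.adam(1.0e-3)
    ema = optax.ema(decay=0.999)
    train_step = get_train_step_function(
        loss_definition, model, evaluate, optimizer, ema=ema
    )
    training_state = {
        "variables": model.variables,
        "opt_state": optimizer.init(model.variables),
        "ema_state": ema.init(model.variables),
        "epoch": 0,
        "step": 0,
    }
    # one optimization step on a batch
    loss, training_state, output = train_step(data, inputs, training_state)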
def get_validation_function(
    loss_definition: Dict,
    model: FENNIX,
    evaluate: Callable,
    model_ref: Optional[FENNIX] = None,
    compute_ref_coords: bool = False,
    return_targets: bool = False,
    jit: bool = True,
):
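The returned validation(data, inputs, variables) computes per-component RMSE and MAE dicts (plus the raw model output, and optionally the masked predicted/reference pairs when return_targets is True). A sketch, assuming the same pipeline objects as above; "variables_ema" is only present in the training state when training with an EMA:

    validate = get_validation_function(
        loss_definition, model, evaluate, return_targets=False
    )
    rmses, maes, output = validate(data, inputs, training_state["variables_ema"])
    # metrics are in model units; the per-component "display_mult" factor in
    # loss_definition relates them to the configured display unit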