fennol.training.utils
import jax
import jax.numpy as jnp
import numpy as np
from typing import Callable, Optional, Dict, List, Tuple, Union, Any, NamedTuple
import optax
from copy import deepcopy
from functools import partial

from ..utils import deep_update, AtomicUnits as au
from ..models import FENNIX
56 """ 57 default_loss_type = training_parameters.get("default_loss_type", "log_cosh") 58 # loss_definition = deepcopy(training_parameters["loss"]) 59 used_keys = [] 60 ref_keys = [] 61 energy_mult = au.get_multiplier(model_energy_unit) 62 loss_definition = {} 63 for k in training_parameters["loss"]: 64 # loss_prms = loss_definition[k] 65 loss_prms = deepcopy(training_parameters["loss"][k]) 66 if "energy_unit" in loss_prms: 67 loss_prms["mult"] = energy_mult / au.get_multiplier( 68 loss_prms["energy_unit"] 69 ) 70 if "unit" in loss_prms: 71 print( 72 "Warning: Both 'unit' and 'energy_unit' are defined for loss component", 73 k, 74 " -> using 'energy_unit'", 75 ) 76 loss_prms["unit"] = loss_prms["energy_unit"] 77 elif "unit" in loss_prms: 78 loss_prms["mult"] = 1.0 / au.get_multiplier(loss_prms["unit"]) 79 else: 80 loss_prms["mult"] = 1.0 81 if "key" not in loss_prms: 82 loss_prms["key"] = k 83 if "type" not in loss_prms: 84 loss_prms["type"] = default_loss_type 85 if "weight" not in loss_prms: 86 loss_prms["weight"] = 1.0 87 assert loss_prms["weight"] >= 0.0, "Loss weight must be positive" 88 if "threshold" in loss_prms: 89 assert loss_prms["threshold"] > 1.0, "Threshold must be greater than 1.0" 90 if "ref" in loss_prms: 91 ref = loss_prms["ref"] 92 if not (ref.startswith("model_ref/") or ref.startswith("model/")): 93 ref_keys.append(ref) 94 if "ds_weight" in loss_prms: 95 ref_keys.append(loss_prms["ds_weight"]) 96 97 if "weight_start" in loss_prms: 98 weight_start = loss_prms["weight_start"] 99 if "weight_ramp" in loss_prms: 100 weight_ramp = loss_prms["weight_ramp"] 101 else: 102 weight_ramp = training_parameters.get("max_epochs") 103 weight_ramp_start = loss_prms.get("weight_ramp_start", 0.0) 104 weight_end = loss_prms["weight"] 105 print(f"Weight ramp for {k} : {weight_start} -> {weight_end} in {weight_ramp} epochs") 106 loss_prms["weight_schedule"] = (weight_start,weight_end,weight_ramp_start,weight_ramp) 107 108 used_keys.append(loss_prms["key"]) 109 loss_definition[k] = loss_prms 110 111 # rename_refs = list( 112 # set(["forces", "total_energy", "atomic_energies"] + manual_renames + used_keys) 113 # ) 114 115 # for k in loss_definition.keys(): 116 # loss_prms = loss_definition[k] 117 # if "ref" in loss_prms: 118 # if loss_prms["ref"] in rename_refs: 119 # loss_prms["ref"] = "true_" + loss_prms["ref"] 120 121 return loss_definition, list(set(used_keys)), list(set(ref_keys)) 122 123 124def get_train_step_function( 125 loss_definition: Dict, 126 model: FENNIX, 127 evaluate: Callable, 128 optimizer: optax.GradientTransformation, 129 ema: Optional[optax.GradientTransformation] = None, 130 model_ref: Optional[FENNIX] = None, 131 compute_ref_coords: bool = False, 132 jit: bool = True, 133): 134 def train_step( 135 data, 136 inputs, 137 training_state, 138 ): 139 epoch = training_state["epoch"] 140 141 def loss_fn(variables): 142 if model_ref is not None: 143 output_ref = evaluate(model_ref, model_ref.variables, inputs) 144 # _, _, output_ref = model_ref._energy_and_forces(model_ref.variables, data) 145 output = evaluate(model, variables, inputs) 146 # _, _, output = model._energy_and_forces(variables, data) 147 natoms = jnp.where(inputs["true_sys"], inputs["natoms"], 1) 148 nsys = inputs["natoms"].shape[0] 149 system_mask = inputs["true_sys"] 150 atom_mask = inputs["true_atoms"] 151 batch_index = inputs["batch_index"] 152 if "system_sign" in inputs: 153 system_sign = inputs["system_sign"] > 0 154 system_mask = jnp.logical_and(system_mask,system_sign) 155 atom_mask = 
jnp.logical_and(atom_mask, system_sign[batch_index]) 156 157 loss_tot = 0.0 158 for loss_prms in loss_definition.values(): 159 use_ref_mask = False 160 predicted = output[loss_prms["key"]] 161 162 163 if "ref" in loss_prms: 164 if loss_prms["ref"].startswith("model_ref/"): 165 assert model_ref is not None, "model_ref must be provided" 166 try: 167 ref = output_ref[loss_prms["ref"][10:]] * loss_prms["mult"] 168 except KeyError: 169 raise KeyError( 170 f"Reference key '{loss_prms['ref'][10:]}' not found in model_ref output. Keys available: {output_ref.keys()}" 171 ) 172 mask_key = loss_prms["ref"][6:] + "_mask" 173 if mask_key in data: 174 use_ref_mask = True 175 ref_mask = data[mask_key] 176 elif mask_key in output_ref: 177 use_ref_mask = True 178 ref_mask = output[mask_key] 179 elif mask_key in output: 180 use_ref_mask = True 181 ref_mask = output[mask_key] 182 elif loss_prms["ref"].startswith("model/"): 183 try: 184 ref = output[loss_prms["ref"][6:]] * loss_prms["mult"] 185 except KeyError: 186 raise KeyError( 187 f"Reference key '{loss_prms['ref'][6:]}' not found in model output. Keys available: {output.keys()}" 188 ) 189 mask_key = loss_prms["ref"][6:] + "_mask" 190 if mask_key in data: 191 use_ref_mask = True 192 ref_mask = data[mask_key] 193 elif mask_key in output: 194 use_ref_mask = True 195 ref_mask = output[mask_key] 196 else: 197 try: 198 ref = data[loss_prms["ref"]] * loss_prms["mult"] 199 except KeyError: 200 raise KeyError( 201 f"Reference key '{loss_prms['ref']}' not found in data. Keys available: {data.keys()}" 202 ) 203 if loss_prms["ref"] + "_mask" in data: 204 use_ref_mask = True 205 ref_mask = data[loss_prms["ref"] + "_mask"] 206 else: 207 ref = jnp.zeros_like(predicted) 208 209 if "norm_axis" in loss_prms and loss_prms["norm_axis"] is not None: 210 norm_axis = loss_prms["norm_axis"] 211 predicted = jnp.linalg.norm( 212 predicted, axis=norm_axis, keepdims=True 213 ) 214 ref = jnp.linalg.norm(ref, axis=norm_axis, keepdims=True) 215 216 ensemble_loss = loss_prms["type"] in ["ensemble_nll","ensemble_crps","gaussian_mixture",] 217 if ensemble_loss: 218 ensemble_axis = loss_prms.get("ensemble_axis", -1) 219 ensemble_weights_key = loss_prms.get("ensemble_weights", None) 220 ensemble_size = predicted.shape[ensemble_axis] 221 if ensemble_weights_key is not None: 222 ensemble_weights = output[ensemble_weights_key] 223 else: 224 ensemble_weights = jnp.ones_like(predicted) 225 if "ensemble_subsample" in loss_prms and "rng_key" in output: 226 ns = min( 227 loss_prms["ensemble_subsample"], 228 predicted.shape[ensemble_axis], 229 ) 230 key, subkey = jax.random.split(output["rng_key"]) 231 232 def get_subsample(x): 233 return jax.lax.slice_in_dim( 234 jax.random.permutation( 235 subkey, x, axis=ensemble_axis, independent=True 236 ), 237 start_index=0, 238 limit_index=ns, 239 axis=ensemble_axis, 240 ) 241 242 predicted = get_subsample(predicted) 243 ensemble_weights = get_subsample(ensemble_weights) 244 output["rng_key"] = key 245 else: 246 get_subsample = lambda x: x 247 ensemble_weights = ensemble_weights / jnp.sum( 248 ensemble_weights, axis=ensemble_axis, keepdims=True 249 ) 250 predicted_ensemble = predicted 251 predicted = (predicted * ensemble_weights).sum( 252 axis=ensemble_axis, keepdims=True 253 ) 254 predicted_var = ( 255 ensemble_weights * (predicted_ensemble - predicted) ** 2 256 ).sum(axis=ensemble_axis) * (ensemble_size / (ensemble_size - 1.0)) 257 predicted = jnp.squeeze(predicted, axis=ensemble_axis) 258 259 if "remove_ref_sys" in loss_prms and 
loss_prms["remove_ref_sys"]: 260 assert compute_ref_coords, "compute_ref_coords must be True" 261 # predicted = predicted - output_data_ref[loss_prms["key"]] 262 # assert predicted.shape[0] == nsys, "remove_ref_sys only works with system-level predictions" 263 shape_mask = [predicted.shape[0]] + [1] * (len(predicted.shape) - 1) 264 if predicted.shape[0] == nsys: 265 predicted_sign_mask = inputs["system_sign"].reshape(*shape_mask) < 0 266 predicted_shift = jax.ops.segment_sum(predicted_sign_mask*predicted,inputs["system_index"],nsys) 267 predicted = predicted - predicted_shift 268 elif predicted.shape[0] == batch_index.shape[0]: 269 predicted_sign_mask = (inputs["system_sign"]<0)[batch_index].reshape(*shape_mask) 270 natshift = jnp.concatenate([jnp.array([0]),jnp.cumsum(natoms)]) 271 iat = jnp.arange(batch_index.shape[0]) - natshift[batch_index] 272 iatdest = natshift[inputs["system_index"]][batch_index]+iat 273 predicted_shift = jax.ops.segment_sum(predicted_sign_mask*predicted,iatdest,batch_index.shape[0]) 274 predicted = predicted - predicted_shift 275 else: 276 raise ValueError("remove_ref_sys only works with system-level or atom-level predictions") 277 278 if ensemble_loss: 279 predicted_ensemble = predicted_ensemble - jnp.expand_dims(predicted_shift,ensemble_axis) 280 predicted_var = ( 281 ensemble_weights * (predicted_ensemble - jnp.expand_dims(predicted,ensemble_axis)) ** 2 282 ).sum(axis=ensemble_axis) * (ensemble_size / (ensemble_size - 1.0)) 283 284 if predicted.ndim > 1 and predicted.shape[-1] == 1: 285 predicted = jnp.squeeze(predicted, axis=-1) 286 287 if ref.ndim > 1 and ref.shape[-1] == 1: 288 ref = jnp.squeeze(ref, axis=-1) 289 290 per_atom = False 291 shape_mask = [ref.shape[0]] + [1] * (len(ref.shape) - 1) 292 # print(loss_prms["key"],predicted.shape,loss_prms["ref"],ref.shape) 293 natscale = 1.0 294 if ref.shape[0] == output["batch_index"].shape[0]: 295 ## shape is number of atoms 296 truth_mask = atom_mask 297 if "ds_weight" in loss_prms: 298 weight_key = loss_prms["ds_weight"] 299 natscale=data[weight_key] 300 if natscale.shape[0] == natoms.shape[0]: 301 natscale = natscale[output["batch_index"]] 302 natscale = natscale.reshape(*shape_mask) 303 elif ref.shape[0] == natoms.shape[0]: 304 ## shape is number of systems 305 truth_mask = system_mask 306 if loss_prms.get("per_atom", False): 307 per_atom = True 308 ref = ref / natoms.reshape(*shape_mask) 309 predicted = predicted / natoms.reshape(*shape_mask) 310 if "nat_pow" in loss_prms: 311 natscale = ( 312 1.0 / natoms.reshape(*shape_mask) ** loss_prms["nat_pow"] 313 ) 314 if "ds_weight" in loss_prms: 315 weight_key = loss_prms["ds_weight"] 316 natscale = natscale * data[weight_key].reshape(*shape_mask) 317 else: 318 truth_mask = jnp.ones(ref.shape[0], dtype=bool) 319 320 if use_ref_mask: 321 truth_mask = truth_mask * ref_mask.astype(bool) 322 323 nel = jnp.maximum( 324 (float(np.prod(ref.shape)) / float(truth_mask.shape[0])) 325 * jnp.sum(truth_mask).astype(jnp.float32), 326 1.0, 327 ) 328 truth_mask = truth_mask.reshape(*shape_mask) 329 330 ref = ref * truth_mask 331 predicted = predicted * truth_mask 332 333 natscale = natscale / nel 334 335 loss_type = loss_prms["type"] 336 if loss_type == "mse": 337 loss = jnp.sum(natscale * (predicted - ref) ** 2) 338 elif loss_type == "log_cosh": 339 loss = jnp.sum(natscale * optax.log_cosh(predicted, ref)) 340 elif loss_type == "mae": 341 loss = jnp.sum(natscale * jnp.abs(predicted - ref)) 342 elif loss_type == "rel_mse": 343 eps = loss_prms.get("rel_eps", 1.0e-6) 344 rel_pow = 
loss_prms.get("rel_pow", 2.0) 345 loss = jnp.sum( 346 natscale 347 * (predicted - ref) ** 2 348 / (eps + jnp.abs(ref) ** rel_pow) 349 ) 350 elif loss_type == "rmse+mae": 351 loss = ( 352 jnp.sum(natscale * (predicted - ref) ** 2) 353 ) ** 0.5 + jnp.sum(natscale * jnp.abs(predicted - ref)) 354 elif loss_type == "crps": 355 predicted_var_key = loss_prms.get( 356 "var_key", loss_prms["key"] + "_var" 357 ) 358 lamba_crps = loss_prms.get("lambda_crps", 1.0) 359 predicted_var = output[predicted_var_key] 360 if per_atom: 361 predicted_var = predicted_var / natoms.reshape(*shape_mask)**2 362 predicted_var = predicted_var * truth_mask + (1.0 - truth_mask) 363 sigma = predicted_var**0.5 364 dy = (ref - predicted) / sigma 365 Phi = 0.5 * (1.0 + jax.scipy.special.erf(dy / 2**0.5)) 366 phi = jnp.exp(-0.5 * dy**2) / (2 * jnp.pi) ** 0.5 367 loss = jnp.sum(natscale * sigma * (dy * (2 * Phi - 1.0) + 2 * phi - lamba_crps/jnp.pi**0.5)) 368 elif loss_type == "ensemble_nll": 369 predicted_var = predicted_var * truth_mask + (1.0 - truth_mask) 370 loss = 0.5 * jnp.sum( 371 natscale 372 * truth_mask 373 * ( 374 jnp.log(predicted_var) 375 + (ref - predicted) ** 2 / predicted_var 376 ) 377 ) 378 elif loss_type == "ensemble_crps": 379 if per_atom: 380 predicted_var = predicted_var / natoms.reshape(*shape_mask)**2 381 lamba_crps = loss_prms.get("lambda_crps", 1.0) 382 predicted_var = predicted_var * truth_mask + (1.0 - truth_mask) 383 sigma = predicted_var**0.5 384 dy = (ref - predicted) / sigma 385 Phi = 0.5 * (1.0 + jax.scipy.special.erf(dy / 2**0.5)) 386 phi = jnp.exp(-0.5 * dy**2) / (2 * jnp.pi) ** 0.5 387 loss = jnp.sum( 388 natscale * truth_mask * sigma * (dy * (2 * Phi - 1.0) + 2 * phi - lamba_crps/jnp.pi**0.5) 389 ) 390 elif loss_type == "gaussian_mixture": 391 gm_w = get_subsample(output[loss_prms["key"] + "_w"]) 392 gm_w = gm_w / jnp.sum(gm_w, axis=ensemble_axis, keepdims=True) 393 sigma = get_subsample(output[loss_prms["key"] + "_sigma"]) 394 ref = jnp.expand_dims(ref, axis=ensemble_axis) 395 # log_likelihoods = -0.5*jnp.log(2*np.pi*gm_sigma2) - ((ref - predicted_ensemble) ** 2) / (2 * gm_sigma2) 396 # log_w =jnp.log(gm_w + 1.e-10) 397 # negloglikelihood = -jax.scipy.special.logsumexp(log_w + log_likelihoods,axis=ensemble_axis) 398 # loss = jnp.sum(natscale * truth_mask * negloglikelihood) 399 400 # CRPS loss 401 def compute_crps(mu, sigma): 402 dy = mu / sigma 403 Phi = 0.5 * (1.0 + jax.scipy.special.erf(dy / 2**0.5)) 404 phi = jnp.exp(-0.5 * dy**2) / (2 * jnp.pi) ** 0.5 405 return sigma * (dy * (2 * Phi - 1.0) + 2 * phi - 1./jnp.pi**0.5) 406 407 crps1 = compute_crps(ref - predicted_ensemble, sigma) 408 gm_crps1 = (gm_w * crps1).sum(axis=ensemble_axis) 409 410 if ensemble_axis < 0: 411 ensemble_axis = predicted_ensemble.ndim + ensemble_axis 412 muij = jnp.expand_dims( 413 predicted_ensemble, axis=ensemble_axis 414 ) - jnp.expand_dims(predicted_ensemble, axis=ensemble_axis + 1) 415 sigma2 = sigma**2 416 sigmaij = ( 417 jnp.expand_dims(sigma2, axis=ensemble_axis) 418 + jnp.expand_dims(sigma2, axis=ensemble_axis + 1) 419 ) ** 0.5 420 wij = jnp.expand_dims(gm_w, axis=ensemble_axis) * jnp.expand_dims( 421 gm_w, axis=ensemble_axis + 1 422 ) 423 crps2 = compute_crps(muij, sigmaij) 424 gm_crps2 = (wij * crps2).sum( 425 axis=(ensemble_axis, ensemble_axis + 1) 426 ) 427 428 crps = gm_crps1 - 0.5 * gm_crps2 429 loss = jnp.sum(natscale * truth_mask * crps) 430 431 elif loss_type == "evidential": 432 evidence = loss_prms["evidence_key"] 433 nu, alpha, beta = jnp.split(output[evidence], 3, axis=-1) 434 gamma = 
predicted 435 nu = nu.reshape(shape_mask) 436 alpha = alpha.reshape(shape_mask) 437 beta = beta.reshape(shape_mask) 438 nu = jnp.where(truth_mask, nu, 1.0) 439 alpha = jnp.where(truth_mask, alpha, 1.0) 440 beta = jnp.where(truth_mask, beta, 1.0) 441 omega = 2 * beta * (1 + nu) 442 lg = jax.scipy.special.gammaln(alpha) - jax.scipy.special.gammaln( 443 alpha + 0.5 444 ) 445 ls = 0.5 * jnp.log(jnp.pi / nu) - alpha * jnp.log(omega) 446 lt = (alpha + 0.5) * jnp.log(omega + nu * (gamma - ref) ** 2) 447 wst = ( 448 (beta * (1 + nu) / (alpha * nu)) ** 0.5 449 if loss_prms.get("normalize_evidence", True) 450 else 1.0 451 ) 452 lr = ( 453 loss_prms.get("lambda_evidence", 1.0) 454 * jnp.abs(gamma - ref) 455 * nu 456 / wst 457 ) 458 r = loss_prms.get("evidence_ratio", 1.0) 459 le = ( 460 loss_prms.get("lambda_evidence_diff", 0.0) 461 * (nu - r * 2 * alpha) ** 2 462 ) 463 lb = loss_prms.get("lambda_evidence_beta", 0.0) * beta 464 loss = lg + ls + lt + lr + le + lb 465 466 loss = jnp.sum(natscale * loss * truth_mask) 467 elif loss_type == "raw": 468 loss = jnp.sum(natscale * predicted) 469 elif loss_type == "binary_crossentropy": 470 ref = ref.astype(predicted.dtype) 471 if loss_prms.get("raw_logits",False): 472 loss = jnp.sum( 473 natscale*truth_mask 474 * ( 475 -ref * jax.nn.log_sigmoid(predicted) 476 - (1.0 - ref) * jax.nn.log_sigmoid(-predicted) 477 ) 478 ) 479 else: 480 loss = jnp.sum( 481 natscale*truth_mask 482 * ( 483 -ref * jnp.log(predicted + 1.0e-10) 484 - (1.0 - ref) * jnp.log(1.0 - predicted + 1.0e-10) 485 ) 486 ) 487 else: 488 raise ValueError(f"Unknown loss type: {loss_type}") 489 490 if "weight_schedule" in loss_prms: 491 w = linear_schedule(epoch, *loss_prms["weight_schedule"]) 492 else: 493 w = loss_prms["weight"] 494 495 loss_tot = loss_tot + w * loss 496 497 return loss_tot, output 498 499 variables = training_state["variables"] 500 new_training_state = {**training_state, "step":training_state["step"]+1} 501 502 (loss, o), grad = jax.value_and_grad(loss_fn, has_aux=True)(variables) 503 updates, new_training_state["opt_state"] = optimizer.update(grad, training_state["opt_state"], params=variables) 504 variables = optax.apply_updates(variables, updates) 505 new_training_state["variables"] = variables 506 507 if ema is not None: 508 ema_st = training_state.get("ema_state",None) 509 if ema_st is None: 510 raise ValueError( 511 "train_step was setup with ema but either variables_ema or ema_st was not provided" 512 ) 513 new_training_state["variables_ema"], new_training_state["ema_state"] = ema.update(variables, ema_st) 514 515 return loss, new_training_state, o 516 517 if jit: 518 return jax.jit(train_step) 519 return train_step 520 521 522def get_validation_function( 523 loss_definition: Dict, 524 model: FENNIX, 525 evaluate: Callable, 526 model_ref: Optional[FENNIX] = None, 527 compute_ref_coords: bool = False, 528 return_targets: bool = False, 529 jit: bool = True, 530): 531 532 def validation(data, inputs, variables, inputs_ref=None): 533 if model_ref is not None: 534 output_ref = evaluate(model_ref, model_ref.variables, inputs) 535 # _, _, output_ref = model_ref._energy_and_forces(model_ref.variables, data) 536 output = evaluate(model, variables, inputs) 537 # _, _, output = model._energy_and_forces(variables, data) 538 rmses = {} 539 maes = {} 540 if return_targets: 541 targets = {} 542 543 natoms = jnp.where(inputs["true_sys"], inputs["natoms"], 1) 544 nsys = inputs["natoms"].shape[0] 545 system_mask = inputs["true_sys"] 546 atom_mask = inputs["true_atoms"] 547 batch_index = 
inputs["batch_index"] 548 if "system_sign" in inputs: 549 system_sign = inputs["system_sign"] > 0 550 system_mask = jnp.logical_and(system_mask,system_sign) 551 atom_mask = jnp.logical_and(atom_mask, system_sign[batch_index]) 552 553 for name, loss_prms in loss_definition.items(): 554 do_validation = loss_prms.get("validate", True) 555 if not do_validation: 556 continue 557 predicted = output[loss_prms["key"]] 558 use_ref_mask = False 559 560 if "remove_ref_sys" in loss_prms and loss_prms["remove_ref_sys"]: 561 assert compute_ref_coords, "compute_ref_coords must be True" 562 # predicted = predicted - output_data_ref[loss_prms["key"]] 563 shape_mask = [predicted.shape[0]] + [1] * (len(predicted.shape) - 1) 564 if predicted.shape[0] == nsys: 565 system_sign = inputs["system_sign"].reshape(*shape_mask) 566 predicted = jax.ops.segment_sum(system_sign*predicted,inputs["system_index"],nsys) 567 elif predicted.shape[0] == batch_index.shape[0]: 568 atom_sign = inputs["system_sign"][batch_index].reshape(*shape_mask) 569 natshift = jnp.concatenate([jnp.array([0]),jnp.cumsum(natoms)]) 570 iat = jnp.arange(batch_index.shape[0]) - natshift[batch_index] 571 iatdest = natshift[inputs["system_index"]][batch_index]+iat 572 predicted = jax.ops.segment_sum(atom_sign*predicted,iatdest,batch_index.shape[0]) 573 else: 574 raise ValueError("remove_ref_sys only works with system-level or atom-level predictions") 575 576 if "ref" in loss_prms: 577 if loss_prms["ref"].startswith("model_ref/"): 578 assert model_ref is not None, "model_ref must be provided" 579 ref = output_ref[loss_prms["ref"][10:]] * loss_prms["mult"] 580 mask_key = loss_prms["ref"][6:] + "_mask" 581 if mask_key in data: 582 use_ref_mask = True 583 ref_mask = data[mask_key] 584 elif mask_key in output_ref: 585 use_ref_mask = True 586 ref_mask = output[mask_key] 587 elif mask_key in output: 588 use_ref_mask = True 589 ref_mask = output[mask_key] 590 elif loss_prms["ref"].startswith("model/"): 591 ref = output[loss_prms["ref"][6:]] * loss_prms["mult"] 592 mask_key = loss_prms["ref"][6:] + "_mask" 593 if mask_key in data: 594 use_ref_mask = True 595 ref_mask = data[mask_key] 596 elif mask_key in output: 597 use_ref_mask = True 598 ref_mask = output[mask_key] 599 else: 600 ref = data[loss_prms["ref"]] * loss_prms["mult"] 601 if loss_prms["ref"] + "_mask" in data: 602 use_ref_mask = True 603 ref_mask = data[loss_prms["ref"] + "_mask"] 604 else: 605 ref = jnp.zeros_like(predicted) 606 607 if "norm_axis" in loss_prms and loss_prms["norm_axis"] is not None: 608 norm_axis = loss_prms["norm_axis"] 609 predicted = jnp.linalg.norm( 610 predicted, axis=norm_axis, keepdims=True 611 ) 612 ref = jnp.linalg.norm(ref, axis=norm_axis, keepdims=True) 613 614 loss_type = loss_prms["type"].lower() 615 if loss_type.startswith("ensemble"): 616 axis = loss_prms.get("ensemble_axis", -1) 617 ensemble_weights_key = loss_prms.get("ensemble_weights", None) 618 if ensemble_weights_key is not None: 619 ensemble_weights = output[ensemble_weights_key] 620 else: 621 ensemble_weights = jnp.ones_like(predicted) 622 ensemble_weights = ensemble_weights / jnp.sum( 623 ensemble_weights, axis=axis, keepdims=True 624 ) 625 626 # predicted = predicted.mean(axis=axis) 627 predicted = (ensemble_weights * predicted).sum(axis=axis) 628 629 if loss_prms["type"] in ["gaussian_mixture"]: 630 axis = loss_prms.get("ensemble_axis", -1) 631 gm_w = output[loss_prms["key"] + "_w"] 632 gm_w = gm_w / jnp.sum(gm_w, axis=axis, keepdims=True) 633 predicted = (predicted * gm_w).sum(axis=axis) 634 635 if 
predicted.ndim > 1 and predicted.shape[-1] == 1: 636 predicted = jnp.squeeze(predicted, axis=-1) 637 638 if ref.ndim > 1 and ref.shape[-1] == 1: 639 ref = jnp.squeeze(ref, axis=-1) 640 641 shape_mask = [ref.shape[0]] + [1] * (len(predicted.shape) - 1) 642 if ref.shape[0] == output["batch_index"].shape[0]: 643 ## shape is number of atoms 644 truth_mask = atom_mask 645 elif ref.shape[0] == natoms.shape[0]: 646 ## shape is number of systems 647 truth_mask = system_mask 648 if loss_prms.get("per_atom_validation", False): 649 ref = ref / natoms.reshape(*shape_mask) 650 predicted = predicted / natoms.reshape(*shape_mask) 651 else: 652 truth_mask = jnp.ones(ref.shape[0], dtype=bool) 653 654 if use_ref_mask: 655 truth_mask = truth_mask * ref_mask.astype(bool) 656 657 nel = jnp.maximum( 658 (float(np.prod(ref.shape)) / float(truth_mask.shape[0])) 659 * jnp.sum(truth_mask).astype(predicted.dtype), 660 1.0, 661 ) 662 663 truth_mask = truth_mask.reshape(*shape_mask) 664 665 ref = ref * truth_mask 666 predicted = predicted * truth_mask 667 668 if loss_type == "binary_crossentropy": 669 ref = ref.astype(predicted.dtype) 670 if loss_prms.get("raw_logits",False): 671 rmse = jnp.sum( 672 ( 673 -ref * jax.nn.log_sigmoid(predicted) 674 - (1.0 - ref) * jax.nn.log_sigmoid(-predicted) 675 )*truth_mask/nel 676 ) 677 predicted = truth_mask*jax.nn.sigmoid(predicted) 678 else: 679 rmse = jnp.sum( 680 ( 681 -ref * jnp.log(predicted + 1.0e-10) 682 - (1.0 - ref) * jnp.log(1.0 - predicted + 1.0e-10) 683 )*truth_mask/nel 684 ) 685 thr = loss_prms.get("class_threshold", 0.05) 686 mae = jnp.sum(jnp.where(jnp.abs(predicted - ref) > thr, 1, 0)/nel)*100 687 else: 688 rmse = jnp.sum((predicted - ref) ** 2 / nel) ** 0.5 689 mae = jnp.sum(jnp.abs(predicted - ref) / nel) 690 691 rmses[name] = rmse 692 maes[name] = mae 693 if return_targets: 694 targets[name] = (predicted, ref, truth_mask) 695 696 if return_targets: 697 return rmses, maes, output, targets 698 699 return rmses, maes, output 700 701 if jit: 702 return jax.jit(validation) 703 return validation
@partial(jax.jit, static_argnums=(1, 2, 3, 4))
def linear_schedule(step, start_value, end_value, start_step, duration):
    return start_value + jnp.clip((step - start_step) / duration, 0.0, 1.0) * (
        end_value - start_value
    )

Linearly ramps from start_value to end_value, starting at start_step and lasting duration steps; outside the ramp the value is clipped to the endpoints. Used below to schedule loss weights across epochs.
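A quick numeric check of the ramp:

linear_schedule(50, 0.0, 1.0, 20, 100)   # (50 - 20) / 100 = 0.3 -> returns 0.3
linear_schedule(10, 0.0, 1.0, 20, 100)   # before the ramp starts -> returns 0.0
linear_schedule(500, 0.5, 2.0, 20, 100)  # past the ramp end      -> returns 2.0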
def get_training_parameters(parameters: Dict[str, Any], stage: int = -1) -> Dict[str, Any]:

Extracts the training parameters for the requested stage: stage dictionaries are merged into the base 'training' block one after the other, and only the last applied stage keeps its 'end_event'.
def get_training_parameters(
    parameters: Dict[str, Any], stage: int = -1
) -> Dict[str, Any]:
    params = deepcopy(parameters["training"])
    if "stages" not in params:
        return params

    stages: dict = params.pop("stages")
    stage_keys = list(stages.keys())
    if stage < 0:
        stage = len(stage_keys) + stage
    assert (
        0 <= stage < len(stage_keys)
    ), f"Stage {stage} not found in training parameters"
    for i in range(stage + 1):
        ## remove end_event from previous stage ##
        if i > 0 and "end_event" in params:
            params.pop("end_event")
        ## incrementally update training parameters ##
        stage_params = stages[stage_keys[i]]
        params = deep_update(params, stage_params)
    return params
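For illustration, a hypothetical staged configuration (the keys and the 'end_event' format are illustrative, not prescribed by this module):

parameters = {
    "training": {
        "lr": 1.0e-4,
        "max_epochs": 100,
        "stages": {
            "warmup": {"lr": 1.0e-5, "end_event": ["epoch", 10]},
            "main": {"lr": 1.0e-4},
        },
    }
}
params = get_training_parameters(parameters, stage=-1)
# stage=-1 resolves to the last stage: 'warmup' then 'main' are merged in order,
# and 'warmup's end_event is removed before 'main' is applied
# -> {"lr": 1.0e-4, "max_epochs": 100}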
def get_loss_definition(training_parameters: Dict[str, Any], model_energy_unit: str = "Ha") -> Tuple[Dict[str, Any], List[str], List[str]]:
def get_loss_definition(
    training_parameters: Dict[str, Any],
    model_energy_unit: str = "Ha",
) -> Tuple[Dict[str, Any], List[str], List[str]]:
    """
    Builds the loss definition from the training parameters.

    Args:
        training_parameters (dict): A dictionary containing training parameters.
        model_energy_unit (str): Energy unit used by the model; references that
            declare an 'energy_unit' are converted to it.

    Returns:
        tuple: A tuple containing:
            - loss_definition (dict): The processed loss definition.
            - used_keys (list): Model output keys used by the loss components.
            - ref_keys (list): Reference keys expected from the dataset.
    """
    default_loss_type = training_parameters.get("default_loss_type", "log_cosh")
    used_keys = []
    ref_keys = []
    energy_mult = au.get_multiplier(model_energy_unit)
    loss_definition = {}
    for k in training_parameters["loss"]:
        loss_prms = deepcopy(training_parameters["loss"][k])
        if "energy_unit" in loss_prms:
            loss_prms["mult"] = energy_mult / au.get_multiplier(
                loss_prms["energy_unit"]
            )
            if "unit" in loss_prms:
                print(
                    "Warning: Both 'unit' and 'energy_unit' are defined for loss component",
                    k,
                    " -> using 'energy_unit'",
                )
                loss_prms["unit"] = loss_prms["energy_unit"]
        elif "unit" in loss_prms:
            loss_prms["mult"] = 1.0 / au.get_multiplier(loss_prms["unit"])
        else:
            loss_prms["mult"] = 1.0
        if "key" not in loss_prms:
            loss_prms["key"] = k
        if "type" not in loss_prms:
            loss_prms["type"] = default_loss_type
        if "weight" not in loss_prms:
            loss_prms["weight"] = 1.0
        assert loss_prms["weight"] >= 0.0, "Loss weight must be non-negative"
        if "threshold" in loss_prms:
            assert loss_prms["threshold"] > 1.0, "Threshold must be greater than 1.0"
        if "ref" in loss_prms:
            ref = loss_prms["ref"]
            if not (ref.startswith("model_ref/") or ref.startswith("model/")):
                ref_keys.append(ref)
        if "ds_weight" in loss_prms:
            ref_keys.append(loss_prms["ds_weight"])

        if "weight_start" in loss_prms:
            weight_start = loss_prms["weight_start"]
            if "weight_ramp" in loss_prms:
                weight_ramp = loss_prms["weight_ramp"]
            else:
                weight_ramp = training_parameters.get("max_epochs")
            weight_ramp_start = loss_prms.get("weight_ramp_start", 0.0)
            weight_end = loss_prms["weight"]
            print(
                f"Weight ramp for {k} : {weight_start} -> {weight_end} in {weight_ramp} epochs"
            )
            loss_prms["weight_schedule"] = (
                weight_start,
                weight_end,
                weight_ramp_start,
                weight_ramp,
            )

        used_keys.append(loss_prms["key"])
        loss_definition[k] = loss_prms

    return loss_definition, list(set(used_keys)), list(set(ref_keys))
Builds the loss definition from the training parameters.

Args:
    training_parameters (dict): A dictionary containing training parameters.
    model_energy_unit (str): Energy unit used by the model; references that declare an 'energy_unit' are converted to it.

Returns:
    tuple: A tuple containing:
        - loss_definition (dict): The processed loss definition.
        - used_keys (list): Model output keys used by the loss components.
        - ref_keys (list): Reference keys expected from the dataset.
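For illustration, a hypothetical loss specification and how it is resolved (the unit name assumes AtomicUnits recognizes "eV"; all keys and values here are illustrative):

training_parameters = {
    "default_loss_type": "log_cosh",
    "max_epochs": 500,
    "loss": {
        "energy": {
            "key": "total_energy",
            "ref": "formation_energy",
            "energy_unit": "eV",
            "per_atom": True,
        },
        "forces": {
            "ref": "forces",
            "type": "mse",
            "weight": 100.0,
            "weight_start": 1.0,
            "weight_ramp": 100,
        },
    },
}
loss_def, used_keys, ref_keys = get_loss_definition(
    training_parameters, model_energy_unit="Ha"
)
# loss_def["energy"]["mult"] converts eV references to the model's energy unit;
# the "energy" component defaults to type "log_cosh" and weight 1.0.
# loss_def["forces"]["key"] defaults to "forces"; its weight_schedule ramps
# the weight 1.0 -> 100.0 over 100 epochs.
# used_keys contains "total_energy" and "forces";
# ref_keys contains "formation_energy" and "forces".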
def get_train_step_function(
    loss_definition: Dict,
    model: FENNIX,
    evaluate: Callable,
    optimizer: optax.GradientTransformation,
    ema: Optional[optax.GradientTransformation] = None,
    model_ref: Optional[FENNIX] = None,
    compute_ref_coords: bool = False,
    jit: bool = True,
):

Builds the (optionally jitted) training step. The step evaluates the model, accumulates the weighted loss components defined in loss_definition, applies one optimizer update and, if an ema transform is given, maintains an exponential moving average of the variables.
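The 'crps', 'ensemble_crps' and 'gaussian_mixture' branches of the step below evaluate the closed-form continuous ranked probability score (CRPS) of a Gaussian predictive distribution:

\mathrm{CRPS}\big(\mathcal{N}(\mu,\sigma^2),\,y\big) = \sigma\left[\,z\,\big(2\Phi(z)-1\big) + 2\varphi(z) - \frac{1}{\sqrt{\pi}}\,\right], \qquad z = \frac{y-\mu}{\sigma},

where \Phi and \varphi are the standard normal CDF and PDF (the code computes \Phi through erf, and lambda_crps rescales the 1/\sqrt{\pi} term). For 'gaussian_mixture', the mixture CRPS is assembled from the same formula as \mathbb{E}|X-y| - \tfrac{1}{2}\mathbb{E}|X-X'| using pairwise Gaussian terms.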
def get_train_step_function(
    loss_definition: Dict,
    model: FENNIX,
    evaluate: Callable,
    optimizer: optax.GradientTransformation,
    ema: Optional[optax.GradientTransformation] = None,
    model_ref: Optional[FENNIX] = None,
    compute_ref_coords: bool = False,
    jit: bool = True,
):
    def train_step(
        data,
        inputs,
        training_state,
    ):
        epoch = training_state["epoch"]

        def loss_fn(variables):
            if model_ref is not None:
                output_ref = evaluate(model_ref, model_ref.variables, inputs)
            output = evaluate(model, variables, inputs)
            natoms = jnp.where(inputs["true_sys"], inputs["natoms"], 1)
            nsys = inputs["natoms"].shape[0]
            system_mask = inputs["true_sys"]
            atom_mask = inputs["true_atoms"]
            batch_index = inputs["batch_index"]
            if "system_sign" in inputs:
                system_sign = inputs["system_sign"] > 0
                system_mask = jnp.logical_and(system_mask, system_sign)
                atom_mask = jnp.logical_and(atom_mask, system_sign[batch_index])

            loss_tot = 0.0
            for loss_prms in loss_definition.values():
                use_ref_mask = False
                predicted = output[loss_prms["key"]]

                if "ref" in loss_prms:
                    if loss_prms["ref"].startswith("model_ref/"):
                        assert model_ref is not None, "model_ref must be provided"
                        try:
                            ref = output_ref[loss_prms["ref"][10:]] * loss_prms["mult"]
                        except KeyError:
                            raise KeyError(
                                f"Reference key '{loss_prms['ref'][10:]}' not found in model_ref output. Keys available: {output_ref.keys()}"
                            )
                        mask_key = loss_prms["ref"][10:] + "_mask"
                        if mask_key in data:
                            use_ref_mask = True
                            ref_mask = data[mask_key]
                        elif mask_key in output_ref:
                            use_ref_mask = True
                            ref_mask = output_ref[mask_key]
                        elif mask_key in output:
                            use_ref_mask = True
                            ref_mask = output[mask_key]
                    elif loss_prms["ref"].startswith("model/"):
                        try:
                            ref = output[loss_prms["ref"][6:]] * loss_prms["mult"]
                        except KeyError:
                            raise KeyError(
                                f"Reference key '{loss_prms['ref'][6:]}' not found in model output. Keys available: {output.keys()}"
                            )
                        mask_key = loss_prms["ref"][6:] + "_mask"
                        if mask_key in data:
                            use_ref_mask = True
                            ref_mask = data[mask_key]
                        elif mask_key in output:
                            use_ref_mask = True
                            ref_mask = output[mask_key]
                    else:
                        try:
                            ref = data[loss_prms["ref"]] * loss_prms["mult"]
                        except KeyError:
                            raise KeyError(
                                f"Reference key '{loss_prms['ref']}' not found in data. Keys available: {data.keys()}"
                            )
                        if loss_prms["ref"] + "_mask" in data:
                            use_ref_mask = True
                            ref_mask = data[loss_prms["ref"] + "_mask"]
                else:
                    ref = jnp.zeros_like(predicted)

                if "norm_axis" in loss_prms and loss_prms["norm_axis"] is not None:
                    norm_axis = loss_prms["norm_axis"]
                    predicted = jnp.linalg.norm(
                        predicted, axis=norm_axis, keepdims=True
                    )
                    ref = jnp.linalg.norm(ref, axis=norm_axis, keepdims=True)

                ensemble_loss = loss_prms["type"] in [
                    "ensemble_nll",
                    "ensemble_crps",
                    "gaussian_mixture",
                ]
                if ensemble_loss:
                    ensemble_axis = loss_prms.get("ensemble_axis", -1)
                    ensemble_weights_key = loss_prms.get("ensemble_weights", None)
                    ensemble_size = predicted.shape[ensemble_axis]
                    if ensemble_weights_key is not None:
                        ensemble_weights = output[ensemble_weights_key]
                    else:
                        ensemble_weights = jnp.ones_like(predicted)
                    if "ensemble_subsample" in loss_prms and "rng_key" in output:
                        ns = min(
                            loss_prms["ensemble_subsample"],
                            predicted.shape[ensemble_axis],
                        )
                        key, subkey = jax.random.split(output["rng_key"])

                        def get_subsample(x):
                            return jax.lax.slice_in_dim(
                                jax.random.permutation(
                                    subkey, x, axis=ensemble_axis, independent=True
                                ),
                                start_index=0,
                                limit_index=ns,
                                axis=ensemble_axis,
                            )

                        predicted = get_subsample(predicted)
                        ensemble_weights = get_subsample(ensemble_weights)
                        output["rng_key"] = key
                    else:
                        get_subsample = lambda x: x
                    ensemble_weights = ensemble_weights / jnp.sum(
                        ensemble_weights, axis=ensemble_axis, keepdims=True
                    )
                    predicted_ensemble = predicted
                    predicted = (predicted * ensemble_weights).sum(
                        axis=ensemble_axis, keepdims=True
                    )
                    predicted_var = (
                        ensemble_weights * (predicted_ensemble - predicted) ** 2
                    ).sum(axis=ensemble_axis) * (ensemble_size / (ensemble_size - 1.0))
                    predicted = jnp.squeeze(predicted, axis=ensemble_axis)

                if "remove_ref_sys" in loss_prms and loss_prms["remove_ref_sys"]:
                    assert compute_ref_coords, "compute_ref_coords must be True"
                    shape_mask = [predicted.shape[0]] + [1] * (len(predicted.shape) - 1)
                    if predicted.shape[0] == nsys:
                        ## system-level predictions: subtract the reference-system contributions
                        predicted_sign_mask = (
                            inputs["system_sign"].reshape(*shape_mask) < 0
                        )
                        predicted_shift = jax.ops.segment_sum(
                            predicted_sign_mask * predicted,
                            inputs["system_index"],
                            nsys,
                        )
                        predicted = predicted - predicted_shift
                    elif predicted.shape[0] == batch_index.shape[0]:
                        ## atom-level predictions: map each reference atom onto its destination atom
                        predicted_sign_mask = (inputs["system_sign"] < 0)[
                            batch_index
                        ].reshape(*shape_mask)
                        natshift = jnp.concatenate([jnp.array([0]), jnp.cumsum(natoms)])
                        iat = jnp.arange(batch_index.shape[0]) - natshift[batch_index]
                        iatdest = natshift[inputs["system_index"]][batch_index] + iat
                        predicted_shift = jax.ops.segment_sum(
                            predicted_sign_mask * predicted,
                            iatdest,
                            batch_index.shape[0],
                        )
                        predicted = predicted - predicted_shift
                    else:
                        raise ValueError(
                            "remove_ref_sys only works with system-level or atom-level predictions"
                        )

                    if ensemble_loss:
                        predicted_ensemble = predicted_ensemble - jnp.expand_dims(
                            predicted_shift, ensemble_axis
                        )
                        predicted_var = (
                            ensemble_weights
                            * (
                                predicted_ensemble
                                - jnp.expand_dims(predicted, ensemble_axis)
                            )
                            ** 2
                        ).sum(axis=ensemble_axis) * (
                            ensemble_size / (ensemble_size - 1.0)
                        )

                if predicted.ndim > 1 and predicted.shape[-1] == 1:
                    predicted = jnp.squeeze(predicted, axis=-1)

                if ref.ndim > 1 and ref.shape[-1] == 1:
                    ref = jnp.squeeze(ref, axis=-1)

                per_atom = False
                shape_mask = [ref.shape[0]] + [1] * (len(ref.shape) - 1)
                natscale = 1.0
                if ref.shape[0] == output["batch_index"].shape[0]:
                    ## shape is number of atoms
                    truth_mask = atom_mask
                    if "ds_weight" in loss_prms:
                        weight_key = loss_prms["ds_weight"]
                        natscale = data[weight_key]
                        if natscale.shape[0] == natoms.shape[0]:
                            natscale = natscale[output["batch_index"]]
                        natscale = natscale.reshape(*shape_mask)
                elif ref.shape[0] == natoms.shape[0]:
                    ## shape is number of systems
                    truth_mask = system_mask
                    if loss_prms.get("per_atom", False):
                        per_atom = True
                        ref = ref / natoms.reshape(*shape_mask)
                        predicted = predicted / natoms.reshape(*shape_mask)
                    if "nat_pow" in loss_prms:
                        natscale = (
                            1.0 / natoms.reshape(*shape_mask) ** loss_prms["nat_pow"]
                        )
                    if "ds_weight" in loss_prms:
                        weight_key = loss_prms["ds_weight"]
                        natscale = natscale * data[weight_key].reshape(*shape_mask)
                else:
                    truth_mask = jnp.ones(ref.shape[0], dtype=bool)

                if use_ref_mask:
                    truth_mask = truth_mask * ref_mask.astype(bool)

                nel = jnp.maximum(
                    (float(np.prod(ref.shape)) / float(truth_mask.shape[0]))
                    * jnp.sum(truth_mask).astype(jnp.float32),
                    1.0,
                )
                truth_mask = truth_mask.reshape(*shape_mask)

                ref = ref * truth_mask
                predicted = predicted * truth_mask

                natscale = natscale / nel

                loss_type = loss_prms["type"]
                if loss_type == "mse":
                    loss = jnp.sum(natscale * (predicted - ref) ** 2)
                elif loss_type == "log_cosh":
                    loss = jnp.sum(natscale * optax.log_cosh(predicted, ref))
                elif loss_type == "mae":
                    loss = jnp.sum(natscale * jnp.abs(predicted - ref))
                elif loss_type == "rel_mse":
                    eps = loss_prms.get("rel_eps", 1.0e-6)
                    rel_pow = loss_prms.get("rel_pow", 2.0)
                    loss = jnp.sum(
                        natscale
                        * (predicted - ref) ** 2
                        / (eps + jnp.abs(ref) ** rel_pow)
                    )
                elif loss_type == "rmse+mae":
                    loss = (
                        jnp.sum(natscale * (predicted - ref) ** 2)
                    ) ** 0.5 + jnp.sum(natscale * jnp.abs(predicted - ref))
                elif loss_type == "crps":
                    ## closed-form CRPS of a Gaussian N(predicted, predicted_var)
                    predicted_var_key = loss_prms.get(
                        "var_key", loss_prms["key"] + "_var"
                    )
                    lambda_crps = loss_prms.get("lambda_crps", 1.0)
                    predicted_var = output[predicted_var_key]
                    if per_atom:
                        predicted_var = predicted_var / natoms.reshape(*shape_mask) ** 2
                    predicted_var = predicted_var * truth_mask + (1.0 - truth_mask)
                    sigma = predicted_var**0.5
                    dy = (ref - predicted) / sigma
                    Phi = 0.5 * (1.0 + jax.scipy.special.erf(dy / 2**0.5))
                    phi = jnp.exp(-0.5 * dy**2) / (2 * jnp.pi) ** 0.5
                    loss = jnp.sum(
                        natscale
                        * sigma
                        * (dy * (2 * Phi - 1.0) + 2 * phi - lambda_crps / jnp.pi**0.5)
                    )
                elif loss_type == "ensemble_nll":
                    predicted_var = predicted_var * truth_mask + (1.0 - truth_mask)
                    loss = 0.5 * jnp.sum(
                        natscale
                        * truth_mask
                        * (
                            jnp.log(predicted_var)
                            + (ref - predicted) ** 2 / predicted_var
                        )
                    )
                elif loss_type == "ensemble_crps":
                    if per_atom:
                        predicted_var = predicted_var / natoms.reshape(*shape_mask) ** 2
                    lambda_crps = loss_prms.get("lambda_crps", 1.0)
                    predicted_var = predicted_var * truth_mask + (1.0 - truth_mask)
                    sigma = predicted_var**0.5
                    dy = (ref - predicted) / sigma
                    Phi = 0.5 * (1.0 + jax.scipy.special.erf(dy / 2**0.5))
                    phi = jnp.exp(-0.5 * dy**2) / (2 * jnp.pi) ** 0.5
                    loss = jnp.sum(
                        natscale
                        * truth_mask
                        * sigma
                        * (dy * (2 * Phi - 1.0) + 2 * phi - lambda_crps / jnp.pi**0.5)
                    )
                elif loss_type == "gaussian_mixture":
                    gm_w = get_subsample(output[loss_prms["key"] + "_w"])
                    gm_w = gm_w / jnp.sum(gm_w, axis=ensemble_axis, keepdims=True)
                    sigma = get_subsample(output[loss_prms["key"] + "_sigma"])
                    ref = jnp.expand_dims(ref, axis=ensemble_axis)

                    ## CRPS loss for a mixture of Gaussians
                    def compute_crps(mu, sigma):
                        dy = mu / sigma
                        Phi = 0.5 * (1.0 + jax.scipy.special.erf(dy / 2**0.5))
                        phi = jnp.exp(-0.5 * dy**2) / (2 * jnp.pi) ** 0.5
                        return sigma * (
                            dy * (2 * Phi - 1.0) + 2 * phi - 1.0 / jnp.pi**0.5
                        )

                    crps1 = compute_crps(ref - predicted_ensemble, sigma)
                    gm_crps1 = (gm_w * crps1).sum(axis=ensemble_axis)

                    if ensemble_axis < 0:
                        ensemble_axis = predicted_ensemble.ndim + ensemble_axis
                    muij = jnp.expand_dims(
                        predicted_ensemble, axis=ensemble_axis
                    ) - jnp.expand_dims(predicted_ensemble, axis=ensemble_axis + 1)
                    sigma2 = sigma**2
                    sigmaij = (
                        jnp.expand_dims(sigma2, axis=ensemble_axis)
                        + jnp.expand_dims(sigma2, axis=ensemble_axis + 1)
                    ) ** 0.5
                    wij = jnp.expand_dims(gm_w, axis=ensemble_axis) * jnp.expand_dims(
                        gm_w, axis=ensemble_axis + 1
                    )
                    crps2 = compute_crps(muij, sigmaij)
                    gm_crps2 = (wij * crps2).sum(
                        axis=(ensemble_axis, ensemble_axis + 1)
                    )

                    crps = gm_crps1 - 0.5 * gm_crps2
                    loss = jnp.sum(natscale * truth_mask * crps)

                elif loss_type == "evidential":
                    ## evidential regression loss (normal-inverse-gamma evidence head)
                    evidence = loss_prms["evidence_key"]
                    nu, alpha, beta = jnp.split(output[evidence], 3, axis=-1)
                    gamma = predicted
                    nu = nu.reshape(shape_mask)
                    alpha = alpha.reshape(shape_mask)
                    beta = beta.reshape(shape_mask)
                    nu = jnp.where(truth_mask, nu, 1.0)
                    alpha = jnp.where(truth_mask, alpha, 1.0)
                    beta = jnp.where(truth_mask, beta, 1.0)
                    omega = 2 * beta * (1 + nu)
                    lg = jax.scipy.special.gammaln(alpha) - jax.scipy.special.gammaln(
                        alpha + 0.5
                    )
                    ls = 0.5 * jnp.log(jnp.pi / nu) - alpha * jnp.log(omega)
                    lt = (alpha + 0.5) * jnp.log(omega + nu * (gamma - ref) ** 2)
                    wst = (
                        (beta * (1 + nu) / (alpha * nu)) ** 0.5
                        if loss_prms.get("normalize_evidence", True)
                        else 1.0
                    )
                    lr = (
                        loss_prms.get("lambda_evidence", 1.0)
                        * jnp.abs(gamma - ref)
                        * nu
                        / wst
                    )
                    r = loss_prms.get("evidence_ratio", 1.0)
                    le = (
                        loss_prms.get("lambda_evidence_diff", 0.0)
                        * (nu - r * 2 * alpha) ** 2
                    )
                    lb = loss_prms.get("lambda_evidence_beta", 0.0) * beta
                    loss = lg + ls + lt + lr + le + lb

                    loss = jnp.sum(natscale * loss * truth_mask)
                elif loss_type == "raw":
                    loss = jnp.sum(natscale * predicted)
                elif loss_type == "binary_crossentropy":
                    ref = ref.astype(predicted.dtype)
                    if loss_prms.get("raw_logits", False):
                        loss = jnp.sum(
                            natscale
                            * truth_mask
                            * (
                                -ref * jax.nn.log_sigmoid(predicted)
                                - (1.0 - ref) * jax.nn.log_sigmoid(-predicted)
                            )
                        )
                    else:
                        loss = jnp.sum(
                            natscale
                            * truth_mask
                            * (
                                -ref * jnp.log(predicted + 1.0e-10)
                                - (1.0 - ref) * jnp.log(1.0 - predicted + 1.0e-10)
                            )
                        )
                else:
                    raise ValueError(f"Unknown loss type: {loss_type}")

                if "weight_schedule" in loss_prms:
                    w = linear_schedule(epoch, *loss_prms["weight_schedule"])
                else:
                    w = loss_prms["weight"]

                loss_tot = loss_tot + w * loss

            return loss_tot, output

        variables = training_state["variables"]
        new_training_state = {**training_state, "step": training_state["step"] + 1}

        (loss, o), grad = jax.value_and_grad(loss_fn, has_aux=True)(variables)
        updates, new_training_state["opt_state"] = optimizer.update(
            grad, training_state["opt_state"], params=variables
        )
        variables = optax.apply_updates(variables, updates)
        new_training_state["variables"] = variables

        if ema is not None:
            ema_st = training_state.get("ema_state", None)
            if ema_st is None:
                raise ValueError(
                    "train_step was set up with ema but 'ema_state' was not found in training_state"
                )
            new_training_state["variables_ema"], new_training_state["ema_state"] = (
                ema.update(variables, ema_st)
            )

        return loss, new_training_state, o

    if jit:
        return jax.jit(train_step)
    return train_step
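A minimal wiring sketch. The evaluate callable, the data/inputs batches, and the optimizer/EMA settings shown here are illustrative placeholders for what the surrounding training pipeline provides:

optimizer = optax.adam(1.0e-3)
ema = optax.ema(decay=0.999)
train_step = get_train_step_function(
    loss_definition, model, evaluate, optimizer, ema=ema
)
training_state = {
    "epoch": 0,
    "step": 0,
    "variables": model.variables,
    "opt_state": optimizer.init(model.variables),
    "ema_state": ema.init(model.variables),
}
# data holds the reference targets, inputs the preprocessed model inputs
loss, training_state, output = train_step(data, inputs, training_state)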
def get_validation_function(
    loss_definition: Dict,
    model: FENNIX,
    evaluate: Callable,
    model_ref: Optional[FENNIX] = None,
    compute_ref_coords: bool = False,
    return_targets: bool = False,
    jit: bool = True,
):

Builds the (optionally jitted) validation function, which returns per-component RMSE and MAE dictionaries (and, if return_targets is set, the masked predictions and references).
def get_validation_function(
    loss_definition: Dict,
    model: FENNIX,
    evaluate: Callable,
    model_ref: Optional[FENNIX] = None,
    compute_ref_coords: bool = False,
    return_targets: bool = False,
    jit: bool = True,
):

    def validation(data, inputs, variables, inputs_ref=None):
        if model_ref is not None:
            output_ref = evaluate(model_ref, model_ref.variables, inputs)
        output = evaluate(model, variables, inputs)
        rmses = {}
        maes = {}
        if return_targets:
            targets = {}

        natoms = jnp.where(inputs["true_sys"], inputs["natoms"], 1)
        nsys = inputs["natoms"].shape[0]
        system_mask = inputs["true_sys"]
        atom_mask = inputs["true_atoms"]
        batch_index = inputs["batch_index"]
        if "system_sign" in inputs:
            system_sign = inputs["system_sign"] > 0
            system_mask = jnp.logical_and(system_mask, system_sign)
            atom_mask = jnp.logical_and(atom_mask, system_sign[batch_index])

        for name, loss_prms in loss_definition.items():
            do_validation = loss_prms.get("validate", True)
            if not do_validation:
                continue
            predicted = output[loss_prms["key"]]
            use_ref_mask = False

            if "remove_ref_sys" in loss_prms and loss_prms["remove_ref_sys"]:
                assert compute_ref_coords, "compute_ref_coords must be True"
                shape_mask = [predicted.shape[0]] + [1] * (len(predicted.shape) - 1)
                if predicted.shape[0] == nsys:
                    system_sign = inputs["system_sign"].reshape(*shape_mask)
                    predicted = jax.ops.segment_sum(
                        system_sign * predicted, inputs["system_index"], nsys
                    )
                elif predicted.shape[0] == batch_index.shape[0]:
                    atom_sign = inputs["system_sign"][batch_index].reshape(*shape_mask)
                    natshift = jnp.concatenate([jnp.array([0]), jnp.cumsum(natoms)])
                    iat = jnp.arange(batch_index.shape[0]) - natshift[batch_index]
                    iatdest = natshift[inputs["system_index"]][batch_index] + iat
                    predicted = jax.ops.segment_sum(
                        atom_sign * predicted, iatdest, batch_index.shape[0]
                    )
                else:
                    raise ValueError(
                        "remove_ref_sys only works with system-level or atom-level predictions"
                    )

            if "ref" in loss_prms:
                if loss_prms["ref"].startswith("model_ref/"):
                    assert model_ref is not None, "model_ref must be provided"
                    ref = output_ref[loss_prms["ref"][10:]] * loss_prms["mult"]
                    mask_key = loss_prms["ref"][10:] + "_mask"
                    if mask_key in data:
                        use_ref_mask = True
                        ref_mask = data[mask_key]
                    elif mask_key in output_ref:
                        use_ref_mask = True
                        ref_mask = output_ref[mask_key]
                    elif mask_key in output:
                        use_ref_mask = True
                        ref_mask = output[mask_key]
                elif loss_prms["ref"].startswith("model/"):
                    ref = output[loss_prms["ref"][6:]] * loss_prms["mult"]
                    mask_key = loss_prms["ref"][6:] + "_mask"
                    if mask_key in data:
                        use_ref_mask = True
                        ref_mask = data[mask_key]
                    elif mask_key in output:
                        use_ref_mask = True
                        ref_mask = output[mask_key]
                else:
                    ref = data[loss_prms["ref"]] * loss_prms["mult"]
                    if loss_prms["ref"] + "_mask" in data:
                        use_ref_mask = True
                        ref_mask = data[loss_prms["ref"] + "_mask"]
            else:
                ref = jnp.zeros_like(predicted)

            if "norm_axis" in loss_prms and loss_prms["norm_axis"] is not None:
                norm_axis = loss_prms["norm_axis"]
                predicted = jnp.linalg.norm(predicted, axis=norm_axis, keepdims=True)
                ref = jnp.linalg.norm(ref, axis=norm_axis, keepdims=True)

            loss_type = loss_prms["type"].lower()
            if loss_type.startswith("ensemble"):
                axis = loss_prms.get("ensemble_axis", -1)
                ensemble_weights_key = loss_prms.get("ensemble_weights", None)
                if ensemble_weights_key is not None:
                    ensemble_weights = output[ensemble_weights_key]
                else:
                    ensemble_weights = jnp.ones_like(predicted)
                ensemble_weights = ensemble_weights / jnp.sum(
                    ensemble_weights, axis=axis, keepdims=True
                )
                ## validate on the (weighted) ensemble mean
                predicted = (ensemble_weights * predicted).sum(axis=axis)

            if loss_prms["type"] in ["gaussian_mixture"]:
                axis = loss_prms.get("ensemble_axis", -1)
                gm_w = output[loss_prms["key"] + "_w"]
                gm_w = gm_w / jnp.sum(gm_w, axis=axis, keepdims=True)
                predicted = (predicted * gm_w).sum(axis=axis)

            if predicted.ndim > 1 and predicted.shape[-1] == 1:
                predicted = jnp.squeeze(predicted, axis=-1)

            if ref.ndim > 1 and ref.shape[-1] == 1:
                ref = jnp.squeeze(ref, axis=-1)

            shape_mask = [ref.shape[0]] + [1] * (len(predicted.shape) - 1)
            if ref.shape[0] == output["batch_index"].shape[0]:
                ## shape is number of atoms
                truth_mask = atom_mask
            elif ref.shape[0] == natoms.shape[0]:
                ## shape is number of systems
                truth_mask = system_mask
                if loss_prms.get("per_atom_validation", False):
                    ref = ref / natoms.reshape(*shape_mask)
                    predicted = predicted / natoms.reshape(*shape_mask)
            else:
                truth_mask = jnp.ones(ref.shape[0], dtype=bool)

            if use_ref_mask:
                truth_mask = truth_mask * ref_mask.astype(bool)

            nel = jnp.maximum(
                (float(np.prod(ref.shape)) / float(truth_mask.shape[0]))
                * jnp.sum(truth_mask).astype(predicted.dtype),
                1.0,
            )

            truth_mask = truth_mask.reshape(*shape_mask)

            ref = ref * truth_mask
            predicted = predicted * truth_mask

            if loss_type == "binary_crossentropy":
                ## for classification, 'rmse' holds the cross-entropy and 'mae' the
                ## percentage of predictions further than class_threshold from the label
                ref = ref.astype(predicted.dtype)
                if loss_prms.get("raw_logits", False):
                    rmse = jnp.sum(
                        (
                            -ref * jax.nn.log_sigmoid(predicted)
                            - (1.0 - ref) * jax.nn.log_sigmoid(-predicted)
                        )
                        * truth_mask
                        / nel
                    )
                    predicted = truth_mask * jax.nn.sigmoid(predicted)
                else:
                    rmse = jnp.sum(
                        (
                            -ref * jnp.log(predicted + 1.0e-10)
                            - (1.0 - ref) * jnp.log(1.0 - predicted + 1.0e-10)
                        )
                        * truth_mask
                        / nel
                    )
                thr = loss_prms.get("class_threshold", 0.05)
                mae = jnp.sum(jnp.where(jnp.abs(predicted - ref) > thr, 1, 0) / nel) * 100
            else:
                rmse = jnp.sum((predicted - ref) ** 2 / nel) ** 0.5
                mae = jnp.sum(jnp.abs(predicted - ref) / nel)

            rmses[name] = rmse
            maes[name] = mae
            if return_targets:
                targets[name] = (predicted, ref, truth_mask)

        if return_targets:
            return rmses, maes, output, targets

        return rmses, maes, output

    if jit:
        return jax.jit(validation)
    return validation
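A matching validation sketch, assuming the same loss_definition, model, evaluate and training_state as in the training example above (EMA weights, when present, are what one typically validates):

validate = get_validation_function(loss_definition, model, evaluate)
variables = training_state.get("variables_ema", training_state["variables"])
rmses, maes, output = validate(data, inputs, variables)
for name in rmses:
    print(f"{name}: RMSE={rmses[name]:.4f}  MAE={maes[name]:.4f}")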