fennol.training.utils
import jax
import jax.numpy as jnp
import numpy as np
from typing import Callable, Optional, Dict, List, Tuple
import optax
from copy import deepcopy
from flax import traverse_util
import json
import re
from functools import partial

from ..utils import deep_update, AtomicUnits as au
from ..models import FENNIX


@partial(jax.jit, static_argnums=(1, 2, 3, 4))
def linear_schedule(step, start_value, end_value, start_step, duration):
    return start_value + jnp.clip(
        (step - start_step) / duration, 0.0, 1.0
    ) * (end_value - start_value)


def get_training_parameters(
    parameters: Dict[str, any], stage: int = -1
) -> Dict[str, any]:
    params = deepcopy(parameters["training"])
    if "stages" not in params:
        return params

    stages: dict = params.pop("stages")
    stage_keys = list(stages.keys())
    if stage < 0:
        stage = len(stage_keys) + stage
    assert stage >= 0 and stage < len(
        stage_keys
    ), f"Stage {stage} not found in training parameters"
    for i in range(stage + 1):
        ## remove end_event from previous stage ##
        if i > 0 and "end_event" in params:
            params.pop("end_event")
        ## incrementally update training parameters ##
        stage_params = stages[stage_keys[i]]
        params = deep_update(params, stage_params)
    return params


def get_loss_definition(
    training_parameters: Dict[str, any],
    model_energy_unit: str = "Ha",
) -> Tuple[Dict[str, any], List[str], List[str]]:
    """
    Builds the processed loss definition from the "loss" block of the training parameters.

    Args:
        training_parameters (dict): A dictionary containing training parameters.
        model_energy_unit (str): Energy unit used by the model, for unit conversion of references.

    Returns:
        tuple: A tuple containing:
            - loss_definition (dict): The processed loss definition.
            - used_keys (list): Model output keys used by the loss components.
            - ref_keys (list): Reference keys to be fetched from the dataset.
    """
    default_loss_type = training_parameters.get("default_loss_type", "log_cosh")
    used_keys = []
    ref_keys = []
    energy_mult = au.get_multiplier(model_energy_unit)
    loss_definition = {}
    for k in training_parameters["loss"]:
        loss_prms = deepcopy(training_parameters["loss"][k])
        if "energy_unit" in loss_prms:
            loss_prms["mult"] = energy_mult / au.get_multiplier(
                loss_prms["energy_unit"]
            )
            if "unit" in loss_prms:
                print(
                    "Warning: both 'unit' and 'energy_unit' are defined for loss component",
                    k,
                    "-> using 'energy_unit'",
                )
                loss_prms["unit"] = loss_prms["energy_unit"]
        elif "unit" in loss_prms:
            loss_prms["mult"] = 1.0 / au.get_multiplier(loss_prms["unit"])
        else:
            loss_prms["mult"] = 1.0
        if "key" not in loss_prms:
            loss_prms["key"] = k
        if "type" not in loss_prms:
            loss_prms["type"] = default_loss_type
        if "weight" not in loss_prms:
            loss_prms["weight"] = 1.0
        assert loss_prms["weight"] >= 0.0, "Loss weight must be non-negative"
        if "threshold" in loss_prms:
            assert loss_prms["threshold"] > 1.0, "Threshold must be greater than 1.0"
        if "ref" in loss_prms:
            ref = loss_prms["ref"]
            if not (ref.startswith("model_ref/") or ref.startswith("model/")):
                ref_keys.append(ref)
        if "ds_weight" in loss_prms:
            ref_keys.append(loss_prms["ds_weight"])

        if "weight_start" in loss_prms:
            weight_start = loss_prms["weight_start"]
            if "weight_ramp" in loss_prms:
                weight_ramp = loss_prms["weight_ramp"]
            else:
                weight_ramp = training_parameters.get("max_epochs")
            weight_ramp_start = loss_prms.get("weight_ramp_start", 0.0)
            weight_end = loss_prms["weight"]
            print(
                "Weight ramp for",
                k,
                ":",
                weight_start,
                "->",
                weight_end,
                "in",
                weight_ramp,
                "epochs",
            )
            loss_prms["weight_schedule"] = (
                weight_start,
                weight_end,
                weight_ramp_start,
                weight_ramp,
            )

        used_keys.append(loss_prms["key"])
        loss_definition[k] = loss_prms

    return loss_definition, list(set(used_keys)), list(set(ref_keys))


def get_optimizer(
    training_parameters: Dict[str, any], variables: Dict, initial_lr: float
) -> optax.GradientTransformation:
    """
    Returns an optax.GradientTransformation object that can be used to optimize the model parameters.

    Args:
        - training_parameters: A dictionary containing the training parameters.
        - variables: A pytree containing the model parameters.
        - initial_lr: The initial learning rate.

    Returns:
        - An optax.GradientTransformation object that can be used to optimize the model parameters.
    """

    default_status = str(training_parameters.get("default_status", "trainable")).lower()
    assert default_status in [
        "trainable",
        "frozen",
    ], f"Default status must be 'trainable' or 'frozen', got {default_status}"

    # find frozen and trainable parameters
    frozen = training_parameters.get("frozen", [])
    trainable = training_parameters.get("trainable", [])

    def training_status(full_path, v):
        full_path = "/".join(full_path[1:]).lower()
        status = (default_status, "")
        for path in frozen:
            if full_path.startswith(path.lower()) and len(path) > len(status[1]):
                status = ("frozen", path)
        for path in trainable:
            if full_path.startswith(path.lower()) and len(path) > len(status[1]):
                status = ("trainable", path)
        return status[0]

    params_partition = traverse_util.path_aware_map(training_status, variables)
    if len(frozen) > 0 or len(trainable) > 0:
        print("params partition:")
        print(json.dumps(params_partition, indent=2, sort_keys=False))

    ## Gradient preprocessing
    grad_processing = []

    # zero nans
    zero_nans = training_parameters.get("zero_nans", False)
    if zero_nans:
        grad_processing.append(optax.zero_nans())

    # gradient clipping
    clip_threshold = training_parameters.get("gradient_clipping", -1.0)
    if clip_threshold > 0.0:
        print("Adaptive gradient clipping threshold:", clip_threshold)
        grad_processing.append(optax.adaptive_grad_clip(clip_threshold))

    # OPTIMIZER: resolve the factory by name inside the optax namespace
    optimizer_name = training_parameters.get("optimizer", "adabelief")
    optimizer = eval(
        optimizer_name,
        {"__builtins__": None},
        {**optax.__dict__},
    )
    print("Optimizer:", optimizer_name)
    optimizer_configuration = training_parameters.get("optimizer_config", {})
    # the actual step size is applied by the final scale transform below
    optimizer_configuration["learning_rate"] = 1.0
    grad_processing.append(optimizer(**optimizer_configuration))

    # weight decay
    weight_decay = training_parameters.get("weight_decay", 0.0)
    assert weight_decay >= 0.0, "Weight decay must be non-negative"
    decay_targets = training_parameters.get("decay_targets", [""])

    def decay_status(full_path, v):
        full_path = "/".join(full_path).lower()
        status = False
        for path in decay_targets:
            if re.match(r"^params/" + path.lower(), full_path):
                status = True
        return status

    decay_mask = traverse_util.path_aware_map(decay_status, variables)
    if weight_decay > 0.0:
        print("weight decay:", weight_decay)
        print(json.dumps(decay_mask, indent=2, sort_keys=False))
        grad_processing.append(
            optax.add_decayed_weights(weight_decay=-weight_decay, mask=decay_mask)
        )

    if zero_nans:
        grad_processing.append(optax.zero_nans())

    # learning rate
    grad_processing.append(optax.inject_hyperparams(optax.scale)(step_size=initial_lr))

    ## define optimizer chain
    optimizer_ = optax.chain(
        *grad_processing,
    )
    partition_optimizer = {"trainable": optimizer_, "frozen": optax.set_to_zero()}
    return optax.multi_transform(partition_optimizer, params_partition)


def get_train_step_function(
    loss_definition: Dict,
    model: FENNIX,
    evaluate: Callable,
    optimizer: optax.GradientTransformation,
    ema: Optional[optax.GradientTransformation] = None,
    model_ref: Optional[FENNIX] = None,
    compute_ref_coords: bool = False,
    jit: bool = True,
):
    def train_step(
        epoch,
        data,
        inputs,
        variables,
        opt_st,
        variables_ema=None,
        ema_st=None,
    ):

        def loss_fn(variables):
            if model_ref is not None:
                output_ref = evaluate(model_ref, model_ref.variables, inputs)
            output = evaluate(model, variables, inputs)
            natoms = jnp.where(inputs["true_sys"], inputs["natoms"], 1)
            nsys = inputs["natoms"].shape[0]
            system_mask = inputs["true_sys"]
            atom_mask = inputs["true_atoms"]
            if "system_sign" in inputs:
                system_sign = inputs["system_sign"] > 0
                system_mask = jnp.logical_and(system_mask, system_sign)
                atom_mask = jnp.logical_and(
                    atom_mask, system_sign[inputs["batch_index"]]
                )

            loss_tot = 0.0
            for loss_prms in loss_definition.values():
                use_ref_mask = False
                predicted = output[loss_prms["key"]]
                if "remove_ref_sys" in loss_prms and loss_prms["remove_ref_sys"]:
                    assert compute_ref_coords, "compute_ref_coords must be True"
                    assert (
                        predicted.shape[0] == nsys
                    ), "remove_ref_sys only works with system-level predictions"
                    shape_mask = [predicted.shape[0]] + [1] * (len(predicted.shape) - 1)
                    system_sign = inputs["system_sign"].reshape(*shape_mask)
                    predicted = jax.ops.segment_sum(
                        system_sign * predicted, inputs["system_index"], nsys
                    )

                if "ref" in loss_prms:
                    if loss_prms["ref"].startswith("model_ref/"):
                        assert model_ref is not None, "model_ref must be provided"
                        try:
                            ref = output_ref[loss_prms["ref"][10:]] * loss_prms["mult"]
                        except KeyError:
                            raise KeyError(
                                f"Reference key '{loss_prms['ref'][10:]}' not found in model_ref output. Keys available: {output_ref.keys()}"
                            )
                    elif loss_prms["ref"].startswith("model/"):
                        try:
                            ref = output[loss_prms["ref"][6:]] * loss_prms["mult"]
                        except KeyError:
                            raise KeyError(
                                f"Reference key '{loss_prms['ref'][6:]}' not found in model output. Keys available: {output.keys()}"
                            )
                    else:
                        try:
                            ref = data[loss_prms["ref"]] * loss_prms["mult"]
                        except KeyError:
                            raise KeyError(
                                f"Reference key '{loss_prms['ref']}' not found in data. Keys available: {data.keys()}"
                            )
                    if loss_prms["ref"] + "_mask" in data:
                        use_ref_mask = True
                        ref_mask = data[loss_prms["ref"] + "_mask"]
                else:
                    ref = jnp.zeros_like(predicted)

                if "norm_axis" in loss_prms and loss_prms["norm_axis"] is not None:
                    norm_axis = loss_prms["norm_axis"]
                    predicted = jnp.linalg.norm(
                        predicted, axis=norm_axis, keepdims=True
                    )
                    ref = jnp.linalg.norm(ref, axis=norm_axis, keepdims=True)

                if loss_prms["type"] in [
                    "ensemble_nll",
                    "ensemble_crps",
                    "gaussian_mixture",
                ]:
                    ensemble_axis = loss_prms.get("ensemble_axis", -1)
                    ensemble_weights_key = loss_prms.get("ensemble_weights", None)
                    ensemble_size = predicted.shape[ensemble_axis]
                    if ensemble_weights_key is not None:
                        ensemble_weights = output[ensemble_weights_key]
                    else:
                        ensemble_weights = jnp.ones_like(predicted)
                    if "ensemble_subsample" in loss_prms and "rng_key" in output:
                        ns = min(
                            loss_prms["ensemble_subsample"],
                            predicted.shape[ensemble_axis],
                        )
                        key, subkey = jax.random.split(output["rng_key"])

                        def get_subsample(x):
                            return jax.lax.slice_in_dim(
                                jax.random.permutation(
                                    subkey, x, axis=ensemble_axis, independent=True
                                ),
                                start_index=0,
                                limit_index=ns,
                                axis=ensemble_axis,
                            )

                        predicted = get_subsample(predicted)
                        ensemble_weights = get_subsample(ensemble_weights)
                        output["rng_key"] = key
                    else:
                        get_subsample = lambda x: x
                    ensemble_weights = ensemble_weights / jnp.sum(
                        ensemble_weights, axis=ensemble_axis, keepdims=True
                    )
                    predicted_ensemble = predicted
                    predicted = (predicted * ensemble_weights).sum(
                        axis=ensemble_axis, keepdims=True
                    )
                    predicted_var = (
                        ensemble_weights * (predicted_ensemble - predicted) ** 2
                    ).sum(axis=ensemble_axis) * (ensemble_size / (ensemble_size - 1.0))
                    predicted = jnp.squeeze(predicted, axis=ensemble_axis)

                if predicted.ndim > 1 and predicted.shape[-1] == 1:
                    predicted = jnp.squeeze(predicted, axis=-1)

                if ref.ndim > 1 and ref.shape[-1] == 1:
                    ref = jnp.squeeze(ref, axis=-1)

                per_atom = False
                shape_mask = [ref.shape[0]] + [1] * (len(ref.shape) - 1)
                natscale = 1.0
                if ref.shape[0] == output["batch_index"].shape[0]:
                    ## shape is number of atoms
                    truth_mask = atom_mask
                    if "ds_weight" in loss_prms:
                        weight_key = loss_prms["ds_weight"]
                        natscale = data[weight_key][output["batch_index"]].reshape(
                            *shape_mask
                        )
                elif ref.shape[0] == natoms.shape[0]:
                    ## shape is number of systems
                    truth_mask = system_mask
                    if loss_prms.get("per_atom", False):
                        per_atom = True
                        ref = ref / natoms.reshape(*shape_mask)
                        predicted = predicted / natoms.reshape(*shape_mask)
                    if "nat_pow" in loss_prms:
                        natscale = (
                            1.0 / natoms.reshape(*shape_mask) ** loss_prms["nat_pow"]
                        )
                    if "ds_weight" in loss_prms:
                        weight_key = loss_prms["ds_weight"]
                        natscale = natscale * data[weight_key].reshape(*shape_mask)
                else:
                    truth_mask = jnp.ones(ref.shape[0], dtype=bool)

                if use_ref_mask:
                    truth_mask = truth_mask * ref_mask.astype(bool)

                nel = jnp.maximum(
                    (float(np.prod(ref.shape)) / float(truth_mask.shape[0]))
                    * jnp.sum(truth_mask).astype(jnp.float32),
                    1.0,
                )
                truth_mask = truth_mask.reshape(*shape_mask)

                ref = ref * truth_mask
                predicted = predicted * truth_mask

                natscale = natscale / nel

                loss_type = loss_prms["type"]
                if loss_type == "mse":
                    loss = jnp.sum(natscale * (predicted - ref) ** 2)
                elif loss_type == "log_cosh":
                    loss = jnp.sum(natscale * optax.log_cosh(predicted, ref))
                elif loss_type == "mae":
                    loss = jnp.sum(natscale * jnp.abs(predicted - ref))
                elif loss_type == "rel_mse":
                    eps = loss_prms.get("rel_eps", 1.0e-6)
                    rel_pow = loss_prms.get("rel_pow", 2.0)
                    loss = jnp.sum(
                        natscale
                        * (predicted - ref) ** 2
                        / (eps + jnp.abs(ref) ** rel_pow)
                    )
                elif loss_type == "rmse+mae":
                    loss = (
                        jnp.sum(natscale * (predicted - ref) ** 2)
                    ) ** 0.5 + jnp.sum(natscale * jnp.abs(predicted - ref))
                elif loss_type == "crps":
                    predicted_var_key = loss_prms.get(
                        "var_key", loss_prms["key"] + "_var"
                    )
                    predicted_var = output[predicted_var_key]
                    if per_atom:
                        predicted_var = predicted_var / natoms.reshape(*shape_mask)
                    predicted_var = predicted_var * truth_mask + (1.0 - truth_mask)
                    sigma = predicted_var**0.5
                    dy = (ref - predicted) / sigma
                    Phi = 0.5 * (1.0 + jax.scipy.special.erf(dy / 2**0.5))
                    phi = jnp.exp(-0.5 * dy**2) / (2 * jnp.pi) ** 0.5
                    loss = jnp.sum(natscale * sigma * (dy * (2 * Phi - 1.0) + 2 * phi))
                elif loss_type == "ensemble_nll":
                    predicted_var = predicted_var * truth_mask + (1.0 - truth_mask)
                    loss = 0.5 * jnp.sum(
                        natscale
                        * truth_mask
                        * (
                            jnp.log(predicted_var)
                            + (ref - predicted) ** 2 / predicted_var
                        )
                    )
                elif loss_type == "ensemble_crps":
                    if per_atom:
                        predicted_var = predicted_var / natoms.reshape(*shape_mask)
                    predicted_var = predicted_var * truth_mask + (1.0 - truth_mask)
                    sigma = predicted_var**0.5
                    dy = (ref - predicted) / sigma
                    Phi = 0.5 * (1.0 + jax.scipy.special.erf(dy / 2**0.5))
                    phi = jnp.exp(-0.5 * dy**2) / (2 * jnp.pi) ** 0.5
                    loss = jnp.sum(
                        natscale * truth_mask * sigma * (dy * (2 * Phi - 1.0) + 2 * phi)
                    )
                elif loss_type == "gaussian_mixture":
                    gm_w = get_subsample(output[loss_prms["key"] + "_w"])
                    gm_w = gm_w / jnp.sum(gm_w, axis=ensemble_axis, keepdims=True)
                    sigma = get_subsample(output[loss_prms["key"] + "_sigma"])
                    ref = jnp.expand_dims(ref, axis=ensemble_axis)

                    # CRPS loss
                    def compute_crps(mu, sigma):
                        dy = mu / sigma
                        Phi = 0.5 * (1.0 + jax.scipy.special.erf(dy / 2**0.5))
                        phi = jnp.exp(-0.5 * dy**2) / (2 * jnp.pi) ** 0.5
                        return sigma * (dy * (2 * Phi - 1.0) + 2 * phi)

                    crps1 = compute_crps(ref - predicted_ensemble, sigma)
                    gm_crps1 = (gm_w * crps1).sum(axis=ensemble_axis)

                    if ensemble_axis < 0:
                        ensemble_axis = predicted_ensemble.ndim + ensemble_axis
                    muij = jnp.expand_dims(
                        predicted_ensemble, axis=ensemble_axis
                    ) - jnp.expand_dims(predicted_ensemble, axis=ensemble_axis + 1)
                    sigma2 = sigma**2
                    sigmaij = (
                        jnp.expand_dims(sigma2, axis=ensemble_axis)
                        + jnp.expand_dims(sigma2, axis=ensemble_axis + 1)
                    ) ** 0.5
                    wij = jnp.expand_dims(gm_w, axis=ensemble_axis) * jnp.expand_dims(
                        gm_w, axis=ensemble_axis + 1
                    )
                    crps2 = compute_crps(muij, sigmaij)
                    gm_crps2 = (wij * crps2).sum(
                        axis=(ensemble_axis, ensemble_axis + 1)
                    )

                    crps = gm_crps1 - 0.5 * gm_crps2
                    loss = jnp.sum(natscale * truth_mask * crps)

                elif loss_type == "evidential":
                    evidence = loss_prms["evidence_key"]
                    nu, alpha, beta = jnp.split(output[evidence], 3, axis=-1)
                    gamma = predicted
                    nu = nu.reshape(shape_mask)
                    alpha = alpha.reshape(shape_mask)
                    beta = beta.reshape(shape_mask)
                    nu = jnp.where(truth_mask, nu, 1.0)
                    alpha = jnp.where(truth_mask, alpha, 1.0)
                    beta = jnp.where(truth_mask, beta, 1.0)
                    omega = 2 * beta * (1 + nu)
                    lg = jax.scipy.special.gammaln(alpha) - jax.scipy.special.gammaln(
                        alpha + 0.5
                    )
                    ls = 0.5 * jnp.log(jnp.pi / nu) - alpha * jnp.log(omega)
                    lt = (alpha + 0.5) * jnp.log(omega + nu * (gamma - ref) ** 2)
                    wst = (
                        (beta * (1 + nu) / (alpha * nu)) ** 0.5
                        if loss_prms.get("normalize_evidence", True)
                        else 1.0
                    )
                    lr = (
                        loss_prms.get("lambda_evidence", 1.0)
                        * jnp.abs(gamma - ref)
                        * nu
                        / wst
                    )
                    r = loss_prms.get("evidence_ratio", 1.0)
                    le = (
                        loss_prms.get("lambda_evidence_diff", 0.0)
                        * (nu - r * 2 * alpha) ** 2
                    )
                    lb = loss_prms.get("lambda_evidence_beta", 0.0) * beta
                    loss = lg + ls + lt + lr + le + lb

                    loss = jnp.sum(natscale * loss * truth_mask)
                elif loss_type == "raw":
                    loss = jnp.sum(natscale * predicted)
                else:
                    raise ValueError(f"Unknown loss type: {loss_type}")

                if "weight_schedule" in loss_prms:
                    w = linear_schedule(epoch, *loss_prms["weight_schedule"])
                else:
                    w = loss_prms["weight"]

                loss_tot = loss_tot + w * loss

            return loss_tot, output

        (loss, o), grad = jax.value_and_grad(loss_fn, has_aux=True)(variables)
        updates, opt_st = optimizer.update(grad, opt_st, params=variables)
        variables = optax.apply_updates(variables, updates)
        if ema is not None:
            if variables_ema is None or ema_st is None:
                raise ValueError(
                    "train_step was setup with ema but either variables_ema or ema_st was not provided"
                )
            variables_ema, ema_st = ema.update(variables, ema_st)
            return loss, variables, opt_st, variables_ema, ema_st, o
        else:
            return loss, variables, opt_st, o

    if jit:
        return jax.jit(train_step)
    return train_step


def get_validation_function(
    loss_definition: Dict,
    model: FENNIX,
    evaluate: Callable,
    model_ref: Optional[FENNIX] = None,
    compute_ref_coords: bool = False,
    return_targets: bool = False,
    jit: bool = True,
):

    def validation(data, inputs, variables, inputs_ref=None):
        if model_ref is not None:
            output_ref = evaluate(model_ref, model_ref.variables, inputs)
        output = evaluate(model, variables, inputs)
        rmses = {}
        maes = {}
        if return_targets:
            targets = {}

        natoms = jnp.where(inputs["true_sys"], inputs["natoms"], 1)
        nsys = inputs["natoms"].shape[0]
        system_mask = inputs["true_sys"]
        atom_mask = inputs["true_atoms"]
        if "system_sign" in inputs:
            system_sign = inputs["system_sign"] > 0
            system_mask = jnp.logical_and(system_mask, system_sign)
            atom_mask = jnp.logical_and(atom_mask, system_sign[inputs["batch_index"]])

        for name, loss_prms in loss_definition.items():
            do_validation = loss_prms.get("validate", True)
            if not do_validation:
                continue
            predicted = output[loss_prms["key"]]
            use_ref_mask = False

            if "remove_ref_sys" in loss_prms and loss_prms["remove_ref_sys"]:
                assert compute_ref_coords, "compute_ref_coords must be True"
                assert (
                    predicted.shape[0] == nsys
                ), "remove_ref_sys only works with system-level predictions"
                shape_mask = [predicted.shape[0]] + [1] * (len(predicted.shape) - 1)
                system_sign = inputs["system_sign"].reshape(*shape_mask)
                predicted = jax.ops.segment_sum(
                    system_sign * predicted, inputs["system_index"], nsys
                )

            if "ref" in loss_prms:
                if loss_prms["ref"].startswith("model_ref/"):
                    assert model_ref is not None, "model_ref must be provided"
                    ref = output_ref[loss_prms["ref"][10:]] * loss_prms["mult"]
                elif loss_prms["ref"].startswith("model/"):
                    ref = output[loss_prms["ref"][6:]] * loss_prms["mult"]
                else:
                    ref = data[loss_prms["ref"]] * loss_prms["mult"]
                if loss_prms["ref"] + "_mask" in data:
                    use_ref_mask = True
                    ref_mask = data[loss_prms["ref"] + "_mask"]
            else:
                ref = jnp.zeros_like(predicted)

            if "norm_axis" in loss_prms and loss_prms["norm_axis"] is not None:
                norm_axis = loss_prms["norm_axis"]
                predicted = jnp.linalg.norm(
                    predicted, axis=norm_axis, keepdims=True
                )
                ref = jnp.linalg.norm(ref, axis=norm_axis, keepdims=True)

            if loss_prms["type"].startswith("ensemble"):
                axis = loss_prms.get("ensemble_axis", -1)
                ensemble_weights_key = loss_prms.get("ensemble_weights", None)
                if ensemble_weights_key is not None:
                    ensemble_weights = output[ensemble_weights_key]
                else:
                    ensemble_weights = jnp.ones_like(predicted)
                ensemble_weights = ensemble_weights / jnp.sum(
                    ensemble_weights, axis=axis, keepdims=True
                )
                predicted = (ensemble_weights * predicted).sum(axis=axis)

            if loss_prms["type"] in ["gaussian_mixture"]:
                axis = loss_prms.get("ensemble_axis", -1)
                gm_w = output[loss_prms["key"] + "_w"]
                gm_w = gm_w / jnp.sum(gm_w, axis=axis, keepdims=True)
                predicted = (predicted * gm_w).sum(axis=axis)

            if predicted.ndim > 1 and predicted.shape[-1] == 1:
                predicted = jnp.squeeze(predicted, axis=-1)

            if ref.ndim > 1 and ref.shape[-1] == 1:
                ref = jnp.squeeze(ref, axis=-1)

            shape_mask = [ref.shape[0]] + [1] * (len(predicted.shape) - 1)
            if ref.shape[0] == output["batch_index"].shape[0]:
                ## shape is number of atoms
                truth_mask = atom_mask
            elif ref.shape[0] == natoms.shape[0]:
                ## shape is number of systems
                truth_mask = system_mask
                if loss_prms.get("per_atom_validation", False):
                    ref = ref / natoms.reshape(*shape_mask)
                    predicted = predicted / natoms.reshape(*shape_mask)
            else:
                truth_mask = jnp.ones(ref.shape[0], dtype=bool)

            if use_ref_mask:
                truth_mask = truth_mask * ref_mask.astype(bool)

            nel = jnp.maximum(
                (float(np.prod(ref.shape)) / float(truth_mask.shape[0]))
                * jnp.sum(truth_mask).astype(predicted.dtype),
                1.0,
            )

            truth_mask = truth_mask.reshape(*shape_mask)

            ref = ref * truth_mask
            predicted = predicted * truth_mask

            rmse = jnp.sum((predicted - ref) ** 2 / nel) ** 0.5
            mae = jnp.sum(jnp.abs(predicted - ref) / nel)

            rmses[name] = rmse
            maes[name] = mae
            if return_targets:
                targets[name] = (predicted, ref, truth_mask)

        if return_targets:
            return rmses, maes, output, targets

        return rmses, maes, output

    if jit:
        return jax.jit(validation)
    return validation
@partial(jax.jit, static_argnums=(1, 2, 3, 4))
def linear_schedule(step, start_value, end_value, start_step, duration):

Linearly interpolates from start_value to end_value as step goes from start_step to start_step + duration, clamping outside that window. This is the function used to evaluate the weight_schedule tuples built by get_loss_definition.
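A quick sketch of the schedule's behavior (the values are illustrative; the schedule arguments are static under jit, so they must be passed positionally as plain Python numbers):

w = linear_schedule(50, 0.01, 1.0, 0, 100)  # halfway through the ramp
# (50 - 0) / 100 = 0.5 -> w = 0.01 + 0.5 * (1.0 - 0.01) = 0.505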
def get_training_parameters(parameters: Dict[str, any], stage: int = -1) -> Dict[str, any]:

Extracts the "training" block from the input parameters and, if a "stages" block is present, incrementally applies the stage overrides up to the requested stage (negative indices count from the last stage).
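A minimal sketch with a hypothetical two-stage setup (the stage names and keys are illustrative; each stage's keys override the base "training" block via deep_update):

parameters = {
    "training": {
        "max_epochs": 100,
        "batch_size": 16,
        "stages": {
            "warmup": {"max_epochs": 10},
            "production": {"max_epochs": 90, "batch_size": 32},
        },
    }
}

get_training_parameters(parameters, stage=0)
# -> {"max_epochs": 10, "batch_size": 16}
get_training_parameters(parameters)  # stage=-1, i.e. the last stage
# -> {"max_epochs": 90, "batch_size": 32}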
def get_loss_definition(training_parameters: Dict[str, any], model_energy_unit: str = "Ha") -> Tuple[Dict[str, any], List[str], List[str]]:
Builds the processed loss definition from the "loss" block of the training parameters.

Args:
    training_parameters (dict): A dictionary containing training parameters.
    model_energy_unit (str): Energy unit used by the model, for unit conversion of references.

Returns:
    tuple: A tuple containing:
        - loss_definition (dict): The processed loss definition.
        - used_keys (list): Model output keys used by the loss components.
        - ref_keys (list): Reference keys to be fetched from the dataset.
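A hedged sketch of a typical "loss" block; the dataset keys and unit strings are illustrative and must match what the dataset and AtomicUnits actually provide:

training_parameters = {
    "default_loss_type": "log_cosh",
    "max_epochs": 100,
    "loss": {
        "energy": {
            "key": "total_energy",      # model output key
            "ref": "formation_energy",  # dataset reference key (illustrative)
            "energy_unit": "kcal/mol",  # reference unit, converted to the model unit
            "type": "mse",
            "weight": 1.0,
            "per_atom": True,
        },
        "forces": {                     # "key" defaults to the component name
            "ref": "forces",
            "weight": 1.0,
            "weight_start": 0.01,       # enables a linear weight ramp
            "weight_ramp": 50,          # ramp duration in epochs
        },
    },
}

loss_definition, used_keys, ref_keys = get_loss_definition(
    training_parameters, model_energy_unit="Ha"
)
# used_keys contains "total_energy" and "forces";
# ref_keys contains "formation_energy" and "forces"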
def get_optimizer(training_parameters: Dict[str, any], variables: Dict, initial_lr: float) -> optax.GradientTransformation:
Returns an optax.GradientTransformation that optimizes the trainable model parameters.

Args:
    - training_parameters: A dictionary containing the training parameters.
    - variables: A pytree containing the model parameters.
    - initial_lr: The initial learning rate.

Returns:
    - An optax.GradientTransformation chaining gradient preprocessing, the base optimizer, weight decay, and learning-rate scaling, with frozen parameters receiving zero updates.
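A hedged configuration sketch, assuming variables is a Flax variables pytree. The optimizer name is resolved inside the optax namespace; its learning_rate is forced to 1.0 and the actual step size is applied by the final scale transform. The submodule names in "frozen" and "decay_targets" are hypothetical:

training_parameters = {
    "optimizer": "adabelief",           # any optax optimizer factory
    "optimizer_config": {"b1": 0.9, "b2": 0.999},
    "gradient_clipping": 0.1,           # adaptive_grad_clip threshold
    "weight_decay": 1.0e-4,
    "decay_targets": ["output_layer"],  # hypothetical parameter path
    "frozen": ["embedding"],            # hypothetical parameter path
    "zero_nans": False,
}

optimizer = get_optimizer(training_parameters, variables, initial_lr=1.0e-3)
opt_st = optimizer.init(variables)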
def get_train_step_function(loss_definition: Dict, model: FENNIX, evaluate: Callable, optimizer: optax.GradientTransformation, ema: Optional[optax.GradientTransformation] = None, model_ref: Optional[FENNIX] = None, compute_ref_coords: bool = False, jit: bool = True):

Builds the (optionally jitted) train_step closure that evaluates the model, accumulates the weighted loss components, and applies one optimizer update (plus an EMA update of the weights when ema is provided).
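A usage sketch, assuming evaluate(model, variables, inputs) returns the model output dictionary consumed by the loss components (the actual evaluate callable, along with the data and inputs batch dictionaries, is supplied by FeNNol's training loop):

train_step = get_train_step_function(loss_definition, model, evaluate, optimizer)
loss, variables, opt_st, output = train_step(epoch, data, inputs, variables, opt_st)

# With an exponential moving average of the weights:
ema = optax.ema(decay=0.999)
ema_st = ema.init(variables)
variables_ema = variables
train_step_ema = get_train_step_function(
    loss_definition, model, evaluate, optimizer, ema=ema
)
loss, variables, opt_st, variables_ema, ema_st, output = train_step_ema(
    epoch, data, inputs, variables, opt_st, variables_ema, ema_st
)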
def get_validation_function(loss_definition: Dict, model: FENNIX, evaluate: Callable, model_ref: Optional[FENNIX] = None, compute_ref_coords: bool = False, return_targets: bool = False, jit: bool = True):

Builds the (optionally jitted) validation closure that returns per-component RMSE and MAE dictionaries, and optionally the masked predicted/reference targets.
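A matching validation sketch, under the same assumptions about evaluate, data, and inputs as the train-step example above:

validate = get_validation_function(loss_definition, model, evaluate)
rmses, maes, output = validate(data, inputs, variables)

validate_t = get_validation_function(loss_definition, model, evaluate, return_targets=True)
rmses, maes, output, targets = validate_t(data, inputs, variables)
# targets[name] -> (predicted, ref, truth_mask) for each validated loss component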