fennol.training.optimizers
```python
from typing import (
    Callable,
    Optional,
    Dict,
    List,
    Tuple,
    Union,
    Any,
    NamedTuple,
    Sequence,
)
import optax
import jax
import jax.numpy as jnp
import numpy as np
import operator
from flax import traverse_util
import json
import re

from optax._src import base
from optax._src import wrappers
import chex
from optax import tree_utils as otu


class AddWeightDiffState(NamedTuple):
    ref_weights: Any


def add_weights_difference(
    weight_decay: Union[float, jax.Array] = 0.0,
    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
) -> base.GradientTransformation:
    """Weight decay toward the initial weights stored at initialization."""

    def init_fn(params):
        return AddWeightDiffState(ref_weights=params)

    def update_fn(updates, state, params):
        if params is None:
            raise ValueError(base.NO_PARAMS_MSG)
        updates = jax.tree_util.tree_map(
            lambda g, p, pref: g + weight_decay * (p - pref),
            updates,
            params,
            state.ref_weights,
        )
        return updates, state

    # If mask is not `None`, apply mask to the gradient transformation.
    # E.g. it is common to skip weight decay on bias units and batch stats.
    if mask is not None:
        return wrappers.masked(base.GradientTransformation(init_fn, update_fn), mask)
    return base.GradientTransformation(init_fn, update_fn)


def add_grokfast(
    alpha: float = 0.9,
    l: float = 1.0,
) -> base.GradientTransformation:
    """Grokfast: amplify slow gradients by exponential moving average."""

    ema: base.GradientTransformation = optax.ema(decay=alpha, debias=False)

    def init_fn(params):
        return ema.init(params)

    def update_fn(updates, state, params=None):
        dupdates, state = ema.update(updates, state, params)
        # updates = updates + l * dupdates
        updates = jax.tree_util.tree_map(lambda g, d: g + l * d, updates, dupdates)
        return updates, state

    return base.GradientTransformation(init_fn, update_fn)


class PROFITState(NamedTuple):
    ref_weights: Any
    istep: int
    main_opt_state: Any
    internal_opt_state: Any


def profit(
    learning_rate: base.ScalarOrSchedule,
    nsteps_ref: int = 1,
    main_opt: str = "adam",
    main_opt_params: Dict[str, Any] = {},
    internal_opt: str = "sgd",
    internal_opt_params: Dict[str, Any] = {},
    **kwargs,
) -> base.GradientTransformation:
    """PROFIT optimizer for fine-tuning. https://arxiv.org/pdf/2412.01930"""

    main_opt_params = {"learning_rate": learning_rate, **main_opt_params}
    main_opt = eval(
        main_opt,
        {"__builtins__": None},
        {**optax.__dict__},
    )(**main_opt_params)

    internal_opt_params = {"learning_rate": 0.1, **internal_opt_params}
    internal_opt = eval(
        internal_opt,
        {"__builtins__": None},
        {**optax.__dict__},
    )(**internal_opt_params)

    def init_fn(params):
        return PROFITState(
            ref_weights=params,
            istep=0,
            main_opt_state=main_opt.init(params),
            internal_opt_state=internal_opt.init(params),
        )

    def update_main(gradients, main_opt_state, internal_opt_state, params, params_ref):
        delta = jax.tree_util.tree_map(lambda p, pref: p - pref, params, params_ref)
        dot = jax.tree.reduce(
            operator.add,
            jax.tree_util.tree_map(lambda g, d: (g * d).sum(), gradients, delta),
        )
        delta2 = jax.tree.reduce(
            operator.add, jax.tree_util.tree_map(lambda d: (d**2).sum(), delta)
        )
        proj = dot / (delta2 + 1.0e-6)

        # If the gradient disagrees with the probe displacement,
        # replace it by its projection onto that displacement.
        gradients = jax.lax.cond(
            dot >= 0,
            lambda g, d: g,
            lambda g, d: jax.tree_util.tree_map(lambda x: proj * x, d),
            gradients,
            delta,
        )
        updates, main_opt_state = main_opt.update(gradients, main_opt_state, params)
        # Subtract the probe displacement so the main update
        # effectively starts from the reference weights.
        updates = jax.tree_util.tree_map(lambda g, d: g - d, updates, delta)
        return updates, main_opt_state, internal_opt_state

    def update_internal(
        gradients, main_opt_state, internal_opt_state, params, params_ref
    ):
        updates, internal_opt_state = internal_opt.update(
            gradients, internal_opt_state, params
        )
        return updates, main_opt_state, internal_opt_state

    def update_fn(gradients, state, params):
        istep = state.istep % (nsteps_ref + 1)

        # Refresh the reference weights at the start of each cycle.
        params_ref = jax.lax.cond(
            istep == 0, lambda a, b: a, lambda a, b: b, params, state.ref_weights
        )

        updates, main_opt_state, internal_opt_state = jax.lax.cond(
            istep == nsteps_ref,
            update_main,
            update_internal,
            gradients,
            state.main_opt_state,
            state.internal_opt_state,
            params,
            params_ref,
        )

        new_state = PROFITState(
            ref_weights=params_ref,
            istep=state.istep + 1,
            main_opt_state=main_opt_state,
            internal_opt_state=internal_opt_state,
        )
        return updates, new_state

    return base.GradientTransformation(init_fn, update_fn)


class MultiEmaState(NamedTuple):
    """Holds exponential moving averages of past updates."""

    count: int
    ema: Sequence[base.Params]


def multi_ema(
    decays: Sequence[float],
    debias: bool = True,
    power: Union[bool, Sequence[bool]] = False,
) -> base.GradientTransformation:
    """Compute multiple power moving averages of past updates."""

    if len(decays) == 0:

        def init_fn(params):
            return base.EmptyState()

        def update_fn(updates, state, params=None):
            return [updates], state

        return base.GradientTransformation(init_fn, update_fn)

    if isinstance(power, bool):
        power = [power] * len(decays)
    assert len(power) == len(decays), "power and decays must have the same length"

    gammas = []
    for decay, p in zip(decays, power):
        if not p:
            gammas.append(None)
            continue
        # Solve for the exponent of the (1 - 1/t)**(gamma + 1) power profile.
        t = decay**-2
        gamma = np.roots([1, 7, 16 - t, 12 - t]).real.max()
        assert gamma > 0, f"Invalid gamma for decay {decay}: {gamma}"
        gammas.append(gamma)

    def init_fn(params):
        return MultiEmaState(
            count=0,
            ema=[otu.tree_zeros_like(params)] * len(decays),
        )

    def update_fn(params, state, other=None):
        count_inc = state.count + 1
        updates = []
        state_ema = []
        for decay, gamma, ema in zip(decays, gammas, state.ema):
            if gamma is not None:
                decay = (1.0 - 1.0 / count_inc) ** (gamma + 1)
            update = new_ema = otu.tree_update_moment(params, ema, decay, order=1)
            if debias and gamma is None:
                update = otu.tree_bias_correction(update, decay, count_inc)
            updates.append(update)
            state_ema.append(new_ema)
        return updates, MultiEmaState(count=count_inc, ema=state_ema)

    return base.GradientTransformation(init_fn, update_fn)


def get_optimizer(
    training_parameters: Dict[str, Any], variables: Dict, initial_lr: float
) -> Tuple[optax.GradientTransformation, int]:
    """
    Returns an optax.GradientTransformation object that can be used to optimize the model parameters.

    Args:
        - training_parameters: A dictionary containing the training parameters.
        - variables: A pytree containing the model parameters.
        - initial_lr: The initial learning rate.

    Returns:
        - An optax.GradientTransformation object that can be used to optimize the
          model parameters, and the (negative) index of the learning-rate scaling
          step in the optimizer chain.
    """

    default_status = str(training_parameters.get("default_status", "trainable")).lower()
    assert default_status in [
        "trainable",
        "frozen",
    ], f"Default status must be 'trainable' or 'frozen', got {default_status}"

    # find frozen and trainable parameters
    frozen = training_parameters.get("frozen", [])
    trainable = training_parameters.get("trainable", [])

    def training_status(full_path, v):
        full_path = "/".join(full_path[1:]).lower()
        status = (default_status, "")
        for path in frozen:
            if re.match(path.lower(), full_path):
                status = ("frozen", path)
        for path in trainable:
            if re.match(path.lower(), full_path):
                status = ("trainable", path)
        return status[0]

    params_partition = traverse_util.path_aware_map(training_status, variables)
    if len(frozen) > 0 or len(trainable) > 0:
        print("params partition:")
        print(json.dumps(params_partition, indent=2, sort_keys=False))

    ## Gradient preprocessing
    grad_processing = []

    # zero nans
    zero_nans = training_parameters.get("zero_nans", False)
    if zero_nans:
        grad_processing.append(optax.zero_nans())

    use_grokfast = training_parameters.get("use_grokfast", False)
    if use_grokfast:
        print("Using Grokfast")
        alpha_grokfast = training_parameters.get("alpha_grokfast", 0.9)
        l_grokfast = training_parameters.get("l_grokfast", 1.0)
        grad_processing.append(add_grokfast(alpha=alpha_grokfast, l=l_grokfast))

    # OPTIMIZER
    optimizer_name = training_parameters.get("optimizer", "adabelief")
    optimizer = eval(
        optimizer_name,
        {"__builtins__": None},
        {
            **optax.__dict__,
            "profit": profit,
        },
    )
    print("Optimizer:", optimizer_name)
    optimizer_configuration = training_parameters.get("optimizer_config", {})
    optimizer_configuration["learning_rate"] = 1.0
    grad_processing.append(optimizer(**optimizer_configuration))

    # weight decay
    weight_decay = training_parameters.get("weight_decay", 0.0)
    assert weight_decay >= 0.0, "Weight decay must be non-negative"
    decay_targets = training_parameters.get("decay_targets", [""])

    def decay_status(full_path, v):
        full_path = "/".join(full_path).lower()
        status = False
        for path in decay_targets:
            if re.match(r"^params/" + path.lower(), full_path):
                status = True
        return status

    decay_mask = traverse_util.path_aware_map(decay_status, variables)
    if weight_decay > 0.0:
        print("weight decay:", weight_decay)
        print(json.dumps(decay_mask, indent=2, sort_keys=False))
        grad_processing.append(
            optax.add_decayed_weights(weight_decay=-weight_decay, mask=decay_mask)
        )

    regularize_init_weight = training_parameters.get("regularize_init_weights", 0.0)
    if regularize_init_weight > 0.0:
        print(
            "Regularizing toward initial weights with L2 norm:", regularize_init_weight
        )
        if weight_decay <= 0.0:
            print(json.dumps(decay_mask, indent=2, sort_keys=False))

        grad_processing.append(
            add_weights_difference(
                weight_decay=-regularize_init_weight, mask=decay_mask
            )
        )

    if zero_nans:
        grad_processing.append(optax.zero_nans())

    # learning rate
    grad_processing.append(optax.inject_hyperparams(optax.scale)(step_size=initial_lr))
    ilr = -1

    # gradient clipping
    clip_threshold = training_parameters.get("gradient_clipping", -1.0)
    if clip_threshold > 0.0:
        print("Adaptive gradient clipping threshold:", clip_threshold)
        grad_processing.append(optax.adaptive_grad_clip(clip_threshold))
        ilr -= 1

    ## define optimizer chain
    optimizer_ = optax.chain(
        *grad_processing,
    )
    partition_optimizer = {"trainable": optimizer_, "frozen": optax.set_to_zero()}
    return optax.multi_transform(partition_optimizer, params_partition), ilr


def get_lr_schedule(max_epochs, nbatch_per_epoch, training_parameters):
    lr = training_parameters.get("lr", 1.0e-3)
    init_lr = training_parameters.get("init_lr", lr / 25)
    final_lr = training_parameters.get("final_lr", lr / 10000)

    #### LEARNING RATE SCHEDULER ####
    schedule_type = training_parameters.get("schedule_type", "cosine_onecycle").lower()
    schedule_type = training_parameters.get("scheduler", schedule_type).lower()
    schedule_metrics = training_parameters.get("schedule_metrics", "rmse_tot")

    adaptive_scheduler = False
    print("Schedule type:", schedule_type)
    if schedule_type == "cosine_onecycle":
        transition_epochs = training_parameters.get("onecycle_epochs", max_epochs)
        peak_epoch = training_parameters.get("peak_epoch", 0.3 * transition_epochs)
        schedule_ = optax.cosine_onecycle_schedule(
            peak_value=lr,
            div_factor=lr / init_lr,
            final_div_factor=init_lr / final_lr,
            transition_steps=transition_epochs * nbatch_per_epoch,
            pct_start=peak_epoch / transition_epochs,
        )
        sch_state = {"count": 0, "best": np.inf, "lr": init_lr}

        def schedule(state, rmse=None):
            new_state = {**state}
            lr = schedule_(state["count"])
            if rmse is None:
                new_state["count"] += 1
            new_state["lr"] = lr
            return lr, new_state

    elif schedule_type == "piecewise_interpolate":
        schedule_params = training_parameters.get("scheduler_parameters", {})
        schedule_ = optax.piecewise_interpolate_schedule(
            **{"init_value": lr, "interpolate_type": "linear", **schedule_params}
        )
        sch_state = {"count": 0, "best": np.inf, "lr": schedule_(0)}

        def schedule(state, rmse=None):
            new_state = {**state}
            lr = schedule_(state["count"])
            if rmse is None:
                new_state["count"] += 1
            new_state["lr"] = lr
            return lr, new_state

    elif schedule_type == "constant":
        sch_state = {"count": 0}

        def schedule(state, rmse=None):
            new_state = {**state}
            new_state["lr"] = lr
            if rmse is None:
                new_state["count"] += 1
            return lr, new_state

    elif schedule_type == "cosine":
        assert (
            "peak_epoch" in training_parameters
        ), "Cosine schedule requires 'peak_epoch' parameter"
        period = training_parameters["peak_epoch"] * nbatch_per_epoch
        peak_lr = lr
        sch_state = {"count": 0}

        def schedule(state, rmse=None):
            new_state = {**state}
            istep = state["count"]
            g = 0.5 * (1 + jnp.cos(jnp.pi * istep / period))
            lr = peak_lr + (init_lr - peak_lr) * g
            new_state["lr"] = lr
            if rmse is None:
                new_state["count"] += 1
            return lr, new_state

    elif schedule_type == "reduce_on_plateau":
        patience = training_parameters.get("patience", 10)
        factor = training_parameters.get("lr_factor", 0.5)
        patience_thr = training_parameters.get("patience_thr", 0.0)
        sch_state = {"count": 0, "best": np.inf, "lr": lr, "patience": patience}
        adaptive_scheduler = True

        def schedule(state, rmse=None):
            new_state = {**state}
            if rmse is None:
                new_state["count"] += 1
                return state["lr"], new_state
            if rmse <= state["best"] * (1.0 + patience_thr):
                if rmse < state["best"]:
                    new_state["best"] = rmse
                new_state["patience"] = 0
            else:
                new_state["patience"] += 1
                if new_state["patience"] >= patience:
                    new_state["lr"] = state["lr"] * factor
                    new_state["patience"] = 0
                    print("Reducing learning rate to", new_state["lr"])
            return new_state["lr"], new_state

    else:
        raise ValueError(f"Unknown schedule_type: {schedule_type}")

    stochastic_scheduler = training_parameters.get("stochastic_scheduler", False)
    if stochastic_scheduler:
        schedule_ = schedule
        # NOTE: `rng_key` was undefined in the original source; a fixed seed
        # is assumed here so the stochastic scheduler is self-contained.
        rng_key = jax.random.PRNGKey(0)
        rng_key, scheduler_key = jax.random.split(rng_key)
        sch_state["rng_key"] = scheduler_key
        sch_state["lr_max"] = lr
        sch_state["lr_min"] = final_lr

        def schedule(state, rmse=None):
            new_state = {**state, "lr": state["lr_max"]}
            if rmse is None:
                lr_max, new_state = schedule_(new_state, rmse=rmse)
                lr_min = new_state["lr_min"]
                new_state["rng_key"], subkey = jax.random.split(new_state["rng_key"])
                lr = lr_min + (lr_max - lr_min) * jax.random.uniform(subkey)
                new_state["lr"] = lr
                new_state["lr_max"] = lr_max

            return new_state["lr"], new_state

    return schedule, sch_state, schedule_metrics, adaptive_scheduler
```
```python
class AddWeightDiffState(typing.NamedTuple):
    ref_weights: Any
```
```python
def add_weights_difference(
    weight_decay: Union[float, jax.Array] = 0.0,
    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
) -> base.GradientTransformation
```
Weight decay toward the initial weights, i.e. the reference weights captured when the transformation is initialized.
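A minimal usage sketch (the toy pytree and the `-1e-4` coefficient are illustrative, not from the source). Chained after an optimizer, the transform adds `weight_decay * (p - p_ref)` to each update; since updates are added to the parameters, a negative coefficient pulls them back toward their values at `init`:

```python
import jax
import jax.numpy as jnp
import optax
from fennol.training.optimizers import add_weights_difference

params = {"w": jnp.ones((4, 4)), "b": jnp.zeros(4)}

# Negative coefficient => pull toward the reference weights,
# matching the sign convention used by get_optimizer below.
tx = optax.chain(
    optax.adam(1e-3),
    add_weights_difference(weight_decay=-1e-4),
)
state = tx.init(params)  # stores params as the reference weights

grads = jax.tree_util.tree_map(jnp.ones_like, params)
updates, state = tx.update(grads, state, params)  # params is required here
params = optax.apply_updates(params, updates)
```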
```python
def add_grokfast(
    alpha: float = 0.9,
    l: float = 1.0,
) -> base.GradientTransformation
```
Grokfast: amplify slow gradients by exponential moving average.
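A hedged sketch of how this could be chained (the `alpha`/`l` values are illustrative; the module defaults are 0.9 and 1.0). Each incoming gradient `g` is replaced by `g + l * EMA_alpha(g)`, so the slow-moving component of the gradient history is amplified before the optimizer sees it:

```python
import jax.numpy as jnp
import optax
from fennol.training.optimizers import add_grokfast

params = {"w": jnp.zeros(3)}

# Grokfast filter first, then the optimizer consumes the filtered gradients.
tx = optax.chain(add_grokfast(alpha=0.98, l=2.0), optax.adam(1e-3))
state = tx.init(params)

grads = {"w": jnp.array([0.1, -0.2, 0.3])}
updates, state = tx.update(grads, state, params)
params = optax.apply_updates(params, updates)
```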
```python
class PROFITState(typing.NamedTuple):
    ref_weights: Any
    istep: int
    main_opt_state: Any
    internal_opt_state: Any
```
```python
def profit(
    learning_rate: base.ScalarOrSchedule,
    nsteps_ref: int = 1,
    main_opt: str = "adam",
    main_opt_params: Dict[str, Any] = {},
    internal_opt: str = "sgd",
    internal_opt_params: Dict[str, Any] = {},
    **kwargs,
) -> base.GradientTransformation
```
PROFIT optimizer for fine-tuning (https://arxiv.org/pdf/2412.01930): runs `nsteps_ref` internal probe steps from a stored reference, then one main step whose gradient is projected onto the probe displacement whenever the two disagree, while rolling the parameters back to the reference.
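The update runs on a cycle of `nsteps_ref + 1` steps: the reference weights are refreshed when `istep == 0`, the internal optimizer takes `nsteps_ref` probe steps, and the final step applies the main optimizer to a gradient that is projected onto the probe displacement whenever the two disagree (`dot < 0`), subtracting the displacement so the main update effectively starts from the reference weights. A usage sketch with illustrative hyperparameters:

```python
import jax.numpy as jnp
import optax
from fennol.training.optimizers import profit

params = {"w": jnp.ones(3)}

# One projected adam step after every two internal sgd probe steps.
tx = profit(learning_rate=1e-4, nsteps_ref=2, main_opt="adam", internal_opt="sgd")
state = tx.init(params)

for _ in range(6):  # two full cycles of nsteps_ref + 1 = 3 steps
    grads = {"w": jnp.full((3,), 0.1)}  # stand-in for real gradients
    updates, state = tx.update(grads, state, params)
    params = optax.apply_updates(params, updates)
```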
```python
class MultiEmaState(typing.NamedTuple):
    count: int
    ema: Sequence[base.Params]
```
Holds exponential moving averages of past updates.
```python
def multi_ema(
    decays: Sequence[float],
    debias: bool = True,
    power: Union[bool, Sequence[bool]] = False,
) -> base.GradientTransformation
```
Compute multiple power moving averages of past updates.
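A sketch tracking two averages at once (all values illustrative). Note that the first argument of `update` is the quantity being averaged (typically the parameters), and that a list with one averaged pytree per entry of `decays` is returned. For entries with `power=True`, the `decays` value is not a classical decay: the cubic solve above turns it into the exponent `gamma` of a `(1 - 1/t)**(gamma + 1)` power-function profile, so it acts as a relative averaging width and appears to need to be small (roughly below 0.29) for a positive root to exist:

```python
import jax.numpy as jnp
from fennol.training.optimizers import multi_ema

params = {"w": jnp.ones(3)}

# One classical debiased EMA and one power-function moving average.
tx = multi_ema(decays=[0.99, 0.15], debias=True, power=[False, True])
state = tx.init(params)

for _ in range(10):
    params = {"w": params["w"] * 0.9}            # stand-in for a training step
    averages, state = tx.update(params, state)   # list: one pytree per decay

ema_classic, ema_power = averages
```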
```python
def get_optimizer(
    training_parameters: Dict[str, Any],
    variables: Dict,
    initial_lr: float,
) -> Tuple[optax.GradientTransformation, int]
```
Returns an optax.GradientTransformation object that can be used to optimize the model parameters.

Args:
- training_parameters: A dictionary containing the training parameters.
- variables: A pytree containing the model parameters.
- initial_lr: The initial learning rate.

Returns:
- An optax.GradientTransformation object that can be used to optimize the model parameters, together with the (negative) index of the learning-rate scaling step in the optimizer chain.
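A hypothetical configuration sketch (the pytree shape and values are made up; the keys are the ones this function reads). The named optimizer is built with `learning_rate=1.0`, and the actual step size is applied by a separate `inject_hyperparams(scale)` link whose position from the end of the chain is returned as `ilr`, presumably so the training loop can update the learning rate in the optimizer state:

```python
import jax.numpy as jnp
from fennol.training.optimizers import get_optimizer

# Hypothetical Flax-style variables pytree.
variables = {"params": {"dense": {"kernel": jnp.ones((4, 4)), "bias": jnp.zeros(4)}}}

training_parameters = {
    "optimizer": "adabelief",    # resolved by name in optax's namespace
    "weight_decay": 1.0e-4,
    "decay_targets": ["dense"],  # regex-matched against "params/..." paths
    "gradient_clipping": 0.1,    # enables adaptive gradient clipping
    "zero_nans": True,
}

optimizer, ilr = get_optimizer(training_parameters, variables, initial_lr=1.0e-3)
opt_state = optimizer.init(variables)
```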
```python
def get_lr_schedule(max_epochs, nbatch_per_epoch, training_parameters)
```
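The returned `schedule` is a stateful callable rather than an optax schedule: calling it with `rmse=None` advances the per-batch step counter, while passing the monitored metric lets adaptive schedules such as `reduce_on_plateau` react to validation results. A usage sketch with illustrative parameters:

```python
from fennol.training.optimizers import get_lr_schedule

training_parameters = {"lr": 1.0e-3, "scheduler": "reduce_on_plateau", "patience": 5}
schedule, sch_state, metric_name, adaptive = get_lr_schedule(
    max_epochs=100, nbatch_per_epoch=50, training_parameters=training_parameters
)

# Per-batch call: advances the step counter and returns the current lr.
lr, sch_state = schedule(sch_state)

# Validation call: reports the monitored metric so adaptive schedules can react.
lr, sch_state = schedule(sch_state, rmse=0.123)
```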