fennol.training.optimizers
from typing import Callable, Optional, Dict, List, Tuple, Union, Any, NamedTuple
import optax
import jax
import jax.numpy as jnp
import numpy as np
import operator
from flax import traverse_util
import json
import re

from optax._src import base
from optax._src import wrappers


class AddWeightDiffState(NamedTuple):
    ref_weights: Any


def add_weights_difference(
    weight_decay: Union[float, jax.Array] = 0.0,
    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
) -> base.GradientTransformation:
    """Weight decay toward the initial weights."""

    def init_fn(params):
        # Capture the weights at initialization as the reference point.
        return AddWeightDiffState(ref_weights=params)

    def update_fn(updates, state, params):
        if params is None:
            raise ValueError(base.NO_PARAMS_MSG)
        updates = jax.tree_util.tree_map(
            lambda g, p, pref: g + weight_decay * (p - pref),
            updates, params, state.ref_weights,
        )
        return updates, state

    # If mask is not `None`, apply mask to the gradient transformation.
    # E.g. it is common to skip weight decay on bias units and batch stats.
    if mask is not None:
        return wrappers.masked(base.GradientTransformation(init_fn, update_fn), mask)
    return base.GradientTransformation(init_fn, update_fn)


def add_grokfast(
    alpha: float = 0.9,
    l: float = 1.0,
) -> base.GradientTransformation:
    """Grokfast: amplify slow gradients by an exponential moving average."""

    ema: base.GradientTransformation = optax.ema(decay=alpha, debias=False)

    def init_fn(params):
        return ema.init(params)

    def update_fn(updates, state, params=None):
        # updates = updates + l * ema(updates)
        dupdates, state = ema.update(updates, state, params)
        updates = jax.tree_util.tree_map(lambda g, d: g + l * d, updates, dupdates)
        return updates, state

    return base.GradientTransformation(init_fn, update_fn)


class PROFITState(NamedTuple):
    ref_weights: Any
    istep: int
    main_opt_state: Any
    internal_opt_state: Any


def profit(
    learning_rate: base.ScalarOrSchedule,
    nsteps_ref: int = 1,
    main_opt: str = "adam",
    main_opt_params: Dict[str, Any] = {},
    internal_opt: str = "sgd",
    internal_opt_params: Dict[str, Any] = {},
    **kwargs
) -> base.GradientTransformation:
    """PROFIT optimizer for fine-tuning (https://arxiv.org/pdf/2412.01930)."""

    # Resolve the optimizer factories by name from optax's namespace
    # (restricted eval: builtins are disabled).
    main_opt_params = {"learning_rate": learning_rate, **main_opt_params}
    main_opt = eval(
        main_opt,
        {"__builtins__": None},
        {**optax.__dict__},
    )(**main_opt_params)

    internal_opt_params = {"learning_rate": 0.1, **internal_opt_params}
    internal_opt = eval(
        internal_opt,
        {"__builtins__": None},
        {**optax.__dict__},
    )(**internal_opt_params)

    def init_fn(params):
        return PROFITState(
            ref_weights=params,
            istep=0,
            main_opt_state=main_opt.init(params),
            internal_opt_state=internal_opt.init(params),
        )

    def update_main(gradients, main_opt_state, internal_opt_state, params, params_ref):
        # Drift of the current weights from the saved reference point.
        delta = jax.tree_util.tree_map(lambda p, pref: p - pref, params, params_ref)
        dot = jax.tree.reduce(
            operator.add,
            jax.tree_util.tree_map(lambda g, d: (g * d).sum(), gradients, delta),
        )
        delta2 = jax.tree.reduce(
            operator.add,
            jax.tree_util.tree_map(lambda d: (d**2).sum(), delta),
        )
        proj = dot / (delta2 + 1.0e-6)

        # If the gradient opposes the drift, replace it by its projection
        # onto the drift direction.
        gradients = jax.lax.cond(
            dot >= 0,
            lambda g, d: g,
            lambda g, d: jax.tree_util.tree_map(lambda x: proj * x, d),
            gradients, delta,
        )
        updates, main_opt_state = main_opt.update(gradients, main_opt_state, params)
        # Also subtract the drift so the main step is taken from the reference weights.
        updates = jax.tree_util.tree_map(lambda g, d: g - d, updates, delta)
        return updates, main_opt_state, internal_opt_state

    def update_internal(gradients, main_opt_state, internal_opt_state, params, params_ref):
        updates, internal_opt_state = internal_opt.update(gradients, internal_opt_state, params)
        return updates, main_opt_state, internal_opt_state

    def update_fn(gradients, state, params):
        istep = state.istep % (nsteps_ref + 1)

        # Refresh the reference weights at the start of each cycle.
        params_ref = jax.lax.cond(
            istep == 0, lambda a, b: a, lambda a, b: b, params, state.ref_weights
        )

        # nsteps_ref internal steps, then one main (projected) step.
        updates, main_opt_state, internal_opt_state = jax.lax.cond(
            istep == nsteps_ref,
            update_main,
            update_internal,
            gradients, state.main_opt_state, state.internal_opt_state, params, params_ref,
        )

        new_state = PROFITState(
            ref_weights=params_ref,
            istep=state.istep + 1,
            main_opt_state=main_opt_state,
            internal_opt_state=internal_opt_state,
        )
        return updates, new_state

    return base.GradientTransformation(init_fn, update_fn)


def get_optimizer(
    training_parameters: Dict[str, Any], variables: Dict, initial_lr: float
) -> optax.GradientTransformation:
    """
    Returns an optax.GradientTransformation object that can be used to optimize the model parameters.

    Args:
    - training_parameters: A dictionary containing the training parameters.
    - variables: A pytree containing the model parameters.
    - initial_lr: The initial learning rate.

    Returns:
    - An optax.GradientTransformation object that can be used to optimize the model parameters.
    """

    default_status = str(training_parameters.get("default_status", "trainable")).lower()
    assert default_status in [
        "trainable",
        "frozen",
    ], f"Default status must be 'trainable' or 'frozen', got {default_status}"

    # find frozen and trainable parameters
    frozen = training_parameters.get("frozen", [])
    trainable = training_parameters.get("trainable", [])

    def training_status(full_path, v):
        # Match regexes against the path below the top-level 'params' key;
        # 'trainable' patterns take precedence over 'frozen' ones.
        full_path = "/".join(full_path[1:]).lower()
        status = (default_status, "")
        for path in frozen:
            if re.match(path.lower(), full_path):
                status = ("frozen", path)
        for path in trainable:
            if re.match(path.lower(), full_path):
                status = ("trainable", path)
        return status[0]

    params_partition = traverse_util.path_aware_map(training_status, variables)
    if len(frozen) > 0 or len(trainable) > 0:
        print("params partition:")
        print(json.dumps(params_partition, indent=2, sort_keys=False))

    ## Gradient preprocessing
    grad_processing = []

    # zero nans
    zero_nans = training_parameters.get("zero_nans", False)
    if zero_nans:
        grad_processing.append(optax.zero_nans())

    use_grokfast = training_parameters.get("use_grokfast", False)
    if use_grokfast:
        print("Using Grokfast")
        alpha_grokfast = training_parameters.get("alpha_grokfast", 0.9)
        l_grokfast = training_parameters.get("l_grokfast", 1.0)
        grad_processing.append(add_grokfast(alpha=alpha_grokfast, l=l_grokfast))

    # gradient clipping
    clip_threshold = training_parameters.get("gradient_clipping", -1.0)
    if clip_threshold > 0.0:
        print("Adaptive gradient clipping threshold:", clip_threshold)
        grad_processing.append(optax.adaptive_grad_clip(clip_threshold))

    # OPTIMIZER
    optimizer_name = training_parameters.get("optimizer", "adabelief")
    optimizer = eval(
        optimizer_name,
        {"__builtins__": None},
        {**optax.__dict__, "profit": profit},
    )
    print("Optimizer:", optimizer_name)
    optimizer_configuration = training_parameters.get("optimizer_config", {})
    # The learning rate is handled by the final scale transform below.
    optimizer_configuration["learning_rate"] = 1.0
    grad_processing.append(optimizer(**optimizer_configuration))

    # weight decay (negative sign: at this point in the chain the updates are
    # already descent directions, i.e. negated by the optimizer)
    weight_decay = training_parameters.get("weight_decay", 0.0)
    assert weight_decay >= 0.0, "Weight decay must be non-negative"
    decay_targets = training_parameters.get("decay_targets", [""])

    def decay_status(full_path, v):
        full_path = "/".join(full_path).lower()
        status = False
        for path in decay_targets:
            if re.match(r"^params/" + path.lower(), full_path):
                status = True
        return status

    decay_mask = traverse_util.path_aware_map(decay_status, variables)
    if weight_decay > 0.0:
        print("weight decay:", weight_decay)
        print(json.dumps(decay_mask, indent=2, sort_keys=False))
        grad_processing.append(
            optax.add_decayed_weights(weight_decay=-weight_decay, mask=decay_mask)
        )

    regularize_init_weight = training_parameters.get("regularize_init_weights", 0.0)
    if regularize_init_weight > 0.0:
        print("Regularizing toward initial weights with L2 norm:", regularize_init_weight)
        if weight_decay <= 0.0:
            print(json.dumps(decay_mask, indent=2, sort_keys=False))
        grad_processing.append(
            add_weights_difference(weight_decay=-regularize_init_weight, mask=decay_mask)
        )

    if zero_nans:
        grad_processing.append(optax.zero_nans())

    # learning rate (injected as a hyperparameter so a scheduler can update it)
    grad_processing.append(optax.inject_hyperparams(optax.scale)(step_size=initial_lr))

    ## define optimizer chain
    optimizer_ = optax.chain(*grad_processing)
    partition_optimizer = {"trainable": optimizer_, "frozen": optax.set_to_zero()}
    return optax.multi_transform(partition_optimizer, params_partition)


def get_lr_schedule(max_epochs, nbatch_per_epoch, training_parameters):
    lr = training_parameters.get("lr", 1.0e-3)
    init_lr = training_parameters.get("init_lr", lr / 25)
    final_lr = training_parameters.get("final_lr", lr / 10000)

    #### LEARNING RATE SCHEDULER ####
    schedule_type = training_parameters.get("schedule_type", "cosine_onecycle").lower()
    schedule_type = training_parameters.get("scheduler", schedule_type).lower()
    schedule_metrics = training_parameters.get("schedule_metrics", "rmse_tot")

    adaptive_scheduler = False
    print("Schedule type:", schedule_type)
    if schedule_type == "cosine_onecycle":
        transition_epochs = training_parameters.get("onecycle_epochs", max_epochs)
        peak_epoch = training_parameters.get("peak_epoch", 0.3 * transition_epochs)
        schedule_ = optax.cosine_onecycle_schedule(
            peak_value=lr,
            div_factor=lr / init_lr,
            final_div_factor=init_lr / final_lr,
            transition_steps=transition_epochs * nbatch_per_epoch,
            pct_start=peak_epoch / transition_epochs,
        )
        sch_state = {"count": 0, "best": np.inf, "lr": init_lr}

        def schedule(state, rmse=None):
            new_state = {**state}
            lr = schedule_(state["count"])
            if rmse is None:
                new_state["count"] += 1
            new_state["lr"] = lr
            return lr, new_state

    elif schedule_type == "piecewise_interpolate":
        schedule_params = training_parameters.get("scheduler_parameters", {})
        schedule_ = optax.piecewise_interpolate_schedule(
            **{"init_value": lr, "interpolate_type": "linear", **schedule_params}
        )
        sch_state = {"count": 0, "best": np.inf, "lr": schedule_(0)}

        def schedule(state, rmse=None):
            new_state = {**state}
            lr = schedule_(state["count"])
            if rmse is None:
                new_state["count"] += 1
            new_state["lr"] = lr
            return lr, new_state

    elif schedule_type == "constant":
        sch_state = {"count": 0}

        def schedule(state, rmse=None):
            new_state = {**state}
            new_state["lr"] = lr
            if rmse is None:
                new_state["count"] += 1
            return lr, new_state

    elif schedule_type == "reduce_on_plateau":
        patience = training_parameters.get("patience", 10)
        factor = training_parameters.get("lr_factor", 0.5)
        patience_thr = training_parameters.get("patience_thr", 0.0)
        sch_state = {"count": 0, "best": np.inf, "lr": lr, "patience": patience}
        adaptive_scheduler = True

        def schedule(state, rmse=None):
            new_state = {**state}
            if rmse is None:
                # Per-batch call: just advance the step counter.
                new_state["count"] += 1
                return state["lr"], new_state
            if rmse <= state["best"] * (1.0 + patience_thr):
                if rmse < state["best"]:
                    new_state["best"] = rmse
                new_state["patience"] = 0
            else:
                new_state["patience"] += 1
                if new_state["patience"] >= patience:
                    new_state["lr"] = state["lr"] * factor
                    new_state["patience"] = 0
                    print("Reducing learning rate to", new_state["lr"])
            return new_state["lr"], new_state

    else:
        raise ValueError(f"Unknown schedule_type: {schedule_type}")

    stochastic_scheduler = training_parameters.get("stochastic_scheduler", False)
    if stochastic_scheduler:
        schedule_ = schedule
        # NOTE: the original code split an undefined `rng_key` here; a fixed
        # seed is assumed so that the scheduler key is well-defined.
        sch_state["rng_key"] = jax.random.PRNGKey(0)
        sch_state["lr_max"] = lr
        sch_state["lr_min"] = final_lr

        def schedule(state, rmse=None):
            new_state = {**state, "lr": state["lr_max"]}
            if rmse is None:
                # Advance the wrapped schedule, then draw the actual learning
                # rate uniformly between lr_min and the scheduled maximum.
                lr_max, new_state = schedule_(new_state, rmse=rmse)
                lr_min = new_state["lr_min"]
                new_state["rng_key"], subkey = jax.random.split(new_state["rng_key"])
                lr = lr_min + (lr_max - lr_min) * jax.random.uniform(subkey)
                new_state["lr"] = lr
                new_state["lr_max"] = lr_max

            return new_state["lr"], new_state

    return schedule, sch_state, schedule_metrics, adaptive_scheduler
class AddWeightDiffState(typing.NamedTuple):

AddWeightDiffState(ref_weights)
def add_weights_difference(
    weight_decay: Union[float, jax.Array] = 0.0,
    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
) -> base.GradientTransformation:
Weight decay toward the initial weights.
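A minimal usage sketch (the import path is assumed from this module's name, and the hyperparameter values are illustrative). Chained before an optimizer, the transform adds weight_decay * (p - p_ref) to the incoming gradients, i.e. the gradient of the L2 penalty 0.5 * weight_decay * ||p - p_ref||^2 anchoring the parameters to their values at init:

import jax
import jax.numpy as jnp
import optax
from fennol.training.optimizers import add_weights_difference  # assumed path

params = {"w": jnp.ones(3), "b": jnp.zeros(3)}

# Anchor the parameters to their initial values.
tx = optax.chain(add_weights_difference(weight_decay=1e-3), optax.adam(1e-2))
state = tx.init(params)  # init captures `params` as the reference weights

grads = jax.tree_util.tree_map(jnp.ones_like, params)
updates, state = tx.update(grads, state, params)
params = optax.apply_updates(params, updates)

Note that get_optimizer below instead appends this transform after the optimizer with a negated coefficient, since the updates at that point are already descent directions.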
def add_grokfast(alpha: float = 0.9, l: float = 1.0) -> base.GradientTransformation:
Grokfast: amplify slow gradients by an exponential moving average.
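A placement sketch (hyperparameter values are illustrative; values around alpha ≈ 0.98 and l ≈ 2 are reported in the Grokfast paper's EMA experiments). The filter sits in front of the optimizer, so that g + l * ema(g) amplifies the slowly varying component of the gradient stream:

import jax.numpy as jnp
import optax
from fennol.training.optimizers import add_grokfast  # assumed path

# Low-pass filter the gradients, then feed the amplified signal to Adam.
tx = optax.chain(add_grokfast(alpha=0.98, l=2.0), optax.adam(1e-3))
params = {"w": jnp.zeros(4)}
state = tx.init(params)

grads = {"w": jnp.full((4,), 0.1)}
updates, state = tx.update(grads, state, params)  # g + l * ema(g), then Adam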
class PROFITState(typing.NamedTuple):
PROFITState(ref_weights, istep, main_opt_state, internal_opt_state)
def profit(
    learning_rate: base.ScalarOrSchedule,
    nsteps_ref: int = 1,
    main_opt: str = 'adam',
    main_opt_params: Dict[str, Any] = {},
    internal_opt: str = 'sgd',
    internal_opt_params: Dict[str, Any] = {},
    **kwargs
) -> base.GradientTransformation:
PROFIT optimizer for fine-tuning (https://arxiv.org/pdf/2412.01930).
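A usage sketch with hypothetical values. With nsteps_ref=2, every third update is the main (projected) Adam step taken from the saved reference weights, and the two updates in between are internal SGD probes around that reference point:

import jax.numpy as jnp
import optax
from fennol.training.optimizers import profit  # assumed path

tx = profit(learning_rate=1e-4, nsteps_ref=2,
            internal_opt_params={"learning_rate": 0.05})
params = {"w": jnp.ones(2)}
state = tx.init(params)

for _ in range(6):  # two full cycles of (2 internal + 1 main) steps
    grads = {"w": jnp.array([0.1, -0.2])}
    updates, state = tx.update(grads, state, params)
    params = optax.apply_updates(params, updates)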
def get_optimizer(
    training_parameters: Dict[str, Any], variables: Dict, initial_lr: float
) -> optax.GradientTransformation:
Returns an optax.GradientTransformation object that can be used to optimize the model parameters.
Args:
- training_parameters: A dictionary containing the training parameters.
- variables: A pytree containing the model parameters.
- initial_lr: The initial learning rate.
Returns:
- An optax.GradientTransformation object that can be used to optimize the model parameters.
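A hypothetical configuration sketch. The key names match the training_parameters.get(...) lookups in the function body; the parameter pytree layout (a top-level 'params' key) matches what the path handling expects, and the regex values and shapes are illustrative:

import jax.numpy as jnp
from fennol.training.optimizers import get_optimizer  # assumed path

variables = {"params": {"embedding": {"w": jnp.zeros((4, 4))},
                        "output": {"w": jnp.zeros((4, 1)), "b": jnp.zeros(1)}}}
training_parameters = {
    "optimizer": "adabelief",
    "optimizer_config": {"b1": 0.9, "b2": 0.999},  # learning_rate is overridden to 1.0
    "gradient_clipping": 0.1,                      # enables adaptive_grad_clip
    "weight_decay": 1e-4,
    "decay_targets": ["output"],                   # regex on 'params/...' paths
    "frozen": ["embedding"],                       # regex on paths below 'params'
    "zero_nans": True,
}
tx = get_optimizer(training_parameters, variables, initial_lr=1e-3)
opt_state = tx.init(variables)

Here the embedding leaves are frozen via optax.set_to_zero, weight decay only touches the output block, and the actual step size lives in the injected step_size hyperparameter so a scheduler can update it between steps.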
def get_lr_schedule(max_epochs, nbatch_per_epoch, training_parameters):
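The returned schedule is a stateful callable rather than a plain optax schedule: call it with rmse=None once per optimization step to advance the counter and read the current learning rate, and, when adaptive_scheduler is True, call it again with the validation metric named by schedule_metrics. A sketch (hypothetical loop sizes and metric value):

from fennol.training.optimizers import get_lr_schedule  # assumed path

training_parameters = {"lr": 1e-3, "scheduler": "reduce_on_plateau", "patience": 5}
schedule, sch_state, metric_name, adaptive = get_lr_schedule(
    max_epochs=100, nbatch_per_epoch=50, training_parameters=training_parameters
)

for epoch in range(2):       # illustrative training loop
    for _ in range(50):      # per-batch call: advance and read the lr
        lr, sch_state = schedule(sch_state)
    if adaptive:             # per-epoch call: feed the validation metric
        lr, sch_state = schedule(sch_state, rmse=0.42)

The returned lr is meant to be written into the step_size hyperparameter that get_optimizer injects via optax.inject_hyperparams.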