fennol.training.optimizers

from typing import Callable, Optional, Dict, List, Tuple, Union, Any, NamedTuple
import optax
import jax
import jax.numpy as jnp
import numpy as np
import operator
from flax import traverse_util
import json
import re

from optax._src import base
from optax._src import wrappers

class AddWeightDiffState(NamedTuple):
    ref_weights: Any


def add_weights_difference(
    weight_decay: Union[float, jax.Array] = 0.0,
    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
) -> base.GradientTransformation:
    """Weight decay toward the initial weights instead of toward zero."""

    def init_fn(params):
        # Store the initial parameters as the reference point.
        return AddWeightDiffState(ref_weights=params)

    def update_fn(updates, state, params):
        if params is None:
            raise ValueError(base.NO_PARAMS_MSG)
        # Add weight_decay * (p - p_ref) to each update, pulling parameters
        # toward (or away from, depending on the sign) their initial values.
        updates = jax.tree_util.tree_map(
            lambda g, p, pref: g + weight_decay * (p - pref),
            updates, params, state.ref_weights,
        )
        return updates, state

    # If mask is not `None`, apply mask to the gradient transformation.
    # E.g. it is common to skip weight decay on bias units and batch stats.
    if mask is not None:
        return wrappers.masked(base.GradientTransformation(init_fn, update_fn), mask)
    return base.GradientTransformation(init_fn, update_fn)

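A minimal usage sketch (not part of the module; values are illustrative): chaining add_weights_difference after a standard optax optimizer so that parameters are attracted back toward their initial values. The negative coefficient mirrors how get_optimizer wires this transform further below.

# --- usage sketch (illustrative values) ---
params = {"w": jnp.ones(3), "b": jnp.zeros(2)}
tx = optax.chain(
    optax.adam(learning_rate=1e-3),
    add_weights_difference(weight_decay=-1e-4),  # attract toward init weights
)
opt_state = tx.init(params)
grads = jax.tree_util.tree_map(jnp.ones_like, params)
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
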
def add_grokfast(
    alpha: float = 0.9,
    l: float = 1.0,
) -> base.GradientTransformation:
    """Grokfast: amplify slow gradients via an exponential moving average."""

    ema: base.GradientTransformation = optax.ema(decay=alpha, debias=False)

    def init_fn(params):
        return ema.init(params)

    def update_fn(updates, state, params=None):
        dupdates, state = ema.update(updates, state, params)
        # updates <- updates + l * ema(updates)
        updates = jax.tree_util.tree_map(lambda g, d: g + l * d, updates, dupdates)
        return updates, state

    return base.GradientTransformation(init_fn, update_fn)

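A short sketch (toy gradient tree; values are illustrative) showing the filtering behavior: with a constant gradient g, the EMA converges to g, so the emitted update approaches (1 + l) * g.

# --- usage sketch (illustrative values) ---
gf = add_grokfast(alpha=0.9, l=2.0)
g = {"w": jnp.array([1.0, -1.0])}
gf_state = gf.init(g)
for _ in range(3):
    out, gf_state = gf.update(g, gf_state)
# After a few steps the EMA term grows toward g, amplifying slow components.
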
class PROFITState(NamedTuple):
    ref_weights: Any
    istep: int
    main_opt_state: Any
    internal_opt_state: Any


def profit(
    learning_rate: base.ScalarOrSchedule,
    nsteps_ref: int = 1,
    main_opt: str = "adam",
    main_opt_params: Dict[str, Any] = {},
    internal_opt: str = "sgd",
    internal_opt_params: Dict[str, Any] = {},
    **kwargs,
) -> base.GradientTransformation:
    """PROFIT optimizer for fine-tuning. https://arxiv.org/pdf/2412.01930"""

    # Resolve optimizer names to optax constructors (e.g. "adam" -> optax.adam).
    main_opt_params = {"learning_rate": learning_rate, **main_opt_params}
    main_opt = eval(
        main_opt,
        {"__builtins__": None},
        {**optax.__dict__},
    )(**main_opt_params)

    internal_opt_params = {"learning_rate": 0.1, **internal_opt_params}
    internal_opt = eval(
        internal_opt,
        {"__builtins__": None},
        {**optax.__dict__},
    )(**internal_opt_params)

    def init_fn(params):
        return PROFITState(
            ref_weights=params,
            istep=0,
            main_opt_state=main_opt.init(params),
            internal_opt_state=internal_opt.init(params),
        )

    def update_main(gradients, main_opt_state, internal_opt_state, params, params_ref):
        # Displacement accumulated since the reference weights were stored.
        delta = jax.tree_util.tree_map(lambda p, pref: p - pref, params, params_ref)
        dot = jax.tree.reduce(
            operator.add,
            jax.tree_util.tree_map(lambda g, d: (g * d).sum(), gradients, delta),
        )
        delta2 = jax.tree.reduce(
            operator.add,
            jax.tree_util.tree_map(lambda d: (d**2).sum(), delta),
        )
        proj = dot / (delta2 + 1.0e-6)

        # If the gradient is aligned with the displacement, keep it;
        # otherwise replace it by its projection onto the displacement.
        gradients = jax.lax.cond(
            dot >= 0,
            lambda g, d: g,
            lambda g, d: jax.tree_util.tree_map(lambda x: proj * x, d),
            gradients, delta,
        )
        updates, main_opt_state = main_opt.update(gradients, main_opt_state, params)
        # Subtract the displacement so the step restarts from the reference weights.
        updates = jax.tree_util.tree_map(lambda g, d: g - d, updates, delta)
        return updates, main_opt_state, internal_opt_state

    def update_internal(gradients, main_opt_state, internal_opt_state, params, params_ref):
        updates, internal_opt_state = internal_opt.update(gradients, internal_opt_state, params)
        return updates, main_opt_state, internal_opt_state

    def update_fn(gradients, state, params):
        istep = state.istep % (nsteps_ref + 1)

        # At the start of each cycle, snapshot the current weights as the reference.
        params_ref = jax.lax.cond(
            istep == 0, lambda a, b: a, lambda a, b: b, params, state.ref_weights
        )

        # Run nsteps_ref internal steps, then one projected main step.
        updates, main_opt_state, internal_opt_state = jax.lax.cond(
            istep == nsteps_ref,
            update_main,
            update_internal,
            gradients, state.main_opt_state, state.internal_opt_state, params, params_ref,
        )

        new_state = PROFITState(
            ref_weights=params_ref,
            istep=state.istep + 1,
            main_opt_state=main_opt_state,
            internal_opt_state=internal_opt_state,
        )
        return updates, new_state

    return base.GradientTransformation(init_fn, update_fn)

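An assumed construction sketch (hyper-parameter names and values are illustrative): one internal SGD step between each projected main step.

# --- usage sketch (illustrative values) ---
tx = profit(
    learning_rate=1e-4,    # forwarded to the main optimizer ("adam")
    nsteps_ref=1,          # internal steps per cycle before the main step
    internal_opt="sgd",
    internal_opt_params={"learning_rate": 0.05},
)
# tx behaves as a regular optax GradientTransformation (init/update).
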
def get_optimizer(
    training_parameters: Dict[str, Any], variables: Dict, initial_lr: float
) -> optax.GradientTransformation:
    """
    Returns an optax.GradientTransformation object that can be used to optimize the model parameters.

    Args:
    - training_parameters: A dictionary containing the training parameters.
    - variables: A pytree containing the model parameters.
    - initial_lr: The initial learning rate.

    Returns:
    - An optax.GradientTransformation object that can be used to optimize the model parameters.
    """

    default_status = str(training_parameters.get("default_status", "trainable")).lower()
    assert default_status in [
        "trainable",
        "frozen",
    ], f"Default status must be 'trainable' or 'frozen', got {default_status}"

    # find frozen and trainable parameters
    frozen = training_parameters.get("frozen", [])
    trainable = training_parameters.get("trainable", [])

    def training_status(full_path, v):
        # Drop the leading collection name (e.g. "params") and match regex patterns.
        full_path = "/".join(full_path[1:]).lower()
        status = (default_status, "")
        for path in frozen:
            if re.match(path.lower(), full_path):
                status = ("frozen", path)
        # "trainable" patterns take precedence over "frozen" ones.
        for path in trainable:
            if re.match(path.lower(), full_path):
                status = ("trainable", path)
        return status[0]

    params_partition = traverse_util.path_aware_map(training_status, variables)
    if len(frozen) > 0 or len(trainable) > 0:
        print("params partition:")
        print(json.dumps(params_partition, indent=2, sort_keys=False))

    ## Gradient preprocessing
    grad_processing = []

    # replace NaN gradients by zeros
    zero_nans = training_parameters.get("zero_nans", False)
    if zero_nans:
        grad_processing.append(optax.zero_nans())

    use_grokfast = training_parameters.get("use_grokfast", False)
    if use_grokfast:
        print("Using Grokfast")
        alpha_grokfast = training_parameters.get("alpha_grokfast", 0.9)
        l_grokfast = training_parameters.get("l_grokfast", 1.0)
        grad_processing.append(add_grokfast(alpha=alpha_grokfast, l=l_grokfast))

    # gradient clipping
    clip_threshold = training_parameters.get("gradient_clipping", -1.0)
    if clip_threshold > 0.0:
        print("Adaptive gradient clipping threshold:", clip_threshold)
        grad_processing.append(optax.adaptive_grad_clip(clip_threshold))

    # OPTIMIZER
    optimizer_name = training_parameters.get("optimizer", "adabelief")
    optimizer = eval(
        optimizer_name,
        {"__builtins__": None},
        {**optax.__dict__, "profit": profit},
    )
    print("Optimizer:", optimizer_name)
    optimizer_configuration = training_parameters.get("optimizer_config", {})
    # The inner optimizer runs with a unit learning rate; the actual rate is
    # applied by the final (hyperparameter-injected) scale transform below.
    optimizer_configuration["learning_rate"] = 1.0
    grad_processing.append(optimizer(**optimizer_configuration))

    # weight decay
    weight_decay = training_parameters.get("weight_decay", 0.0)
    assert weight_decay >= 0.0, "Weight decay must be non-negative"
    decay_targets = training_parameters.get("decay_targets", [""])

    def decay_status(full_path, v):
        full_path = "/".join(full_path).lower()
        status = False
        for path in decay_targets:
            if re.match(r"^params/" + path.lower(), full_path):
                status = True
        return status

    decay_mask = traverse_util.path_aware_map(decay_status, variables)
    if weight_decay > 0.0:
        print("weight decay:", weight_decay)
        print(json.dumps(decay_mask, indent=2, sort_keys=False))
        # negative coefficient: the chain ends with a positive scale,
        # so -wd * p produces decay
        grad_processing.append(
            optax.add_decayed_weights(weight_decay=-weight_decay, mask=decay_mask)
        )

    regularize_init_weight = training_parameters.get("regularize_init_weights", 0.0)
    if regularize_init_weight > 0.0:
        print("Regularizing toward initial weights with L2 norm:", regularize_init_weight)
        if weight_decay <= 0.0:
            print(json.dumps(decay_mask, indent=2, sort_keys=False))
        grad_processing.append(
            add_weights_difference(weight_decay=-regularize_init_weight, mask=decay_mask)
        )

    if zero_nans:
        grad_processing.append(optax.zero_nans())

    # learning rate
    grad_processing.append(optax.inject_hyperparams(optax.scale)(step_size=initial_lr))

    ## define optimizer chain
    optimizer_ = optax.chain(
        *grad_processing,
    )
    # Frozen parameters receive zero updates.
    partition_optimizer = {"trainable": optimizer_, "frozen": optax.set_to_zero()}
    return optax.multi_transform(partition_optimizer, params_partition)

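A hypothetical configuration sketch: the keys mirror the training_parameters lookups above; the values are illustrative only.

# --- configuration sketch (illustrative values) ---
example_parameters = {
    "optimizer": "adabelief",
    "optimizer_config": {"b1": 0.9, "b2": 0.999},
    "gradient_clipping": 1.0,          # adaptive_grad_clip threshold
    "weight_decay": 1.0e-4,
    "decay_targets": [""],             # regex prefixes matched under "params/"
    "frozen": [],                      # regex patterns for frozen subtrees
    "zero_nans": True,
}
# optimizer = get_optimizer(example_parameters, variables, initial_lr=1.0e-3)
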
def get_lr_schedule(max_epochs, nbatch_per_epoch, training_parameters):
    lr = training_parameters.get("lr", 1.0e-3)
    init_lr = training_parameters.get("init_lr", lr / 25)
    final_lr = training_parameters.get("final_lr", lr / 10000)

    #### LEARNING RATE SCHEDULER ####
    schedule_type = training_parameters.get("schedule_type", "cosine_onecycle").lower()
    # the "scheduler" key, if present, overrides "schedule_type"
    schedule_type = training_parameters.get("scheduler", schedule_type).lower()
    schedule_metrics = training_parameters.get("schedule_metrics", "rmse_tot")

    adaptive_scheduler = False
    print("Schedule type:", schedule_type)
    if schedule_type == "cosine_onecycle":
        transition_epochs = training_parameters.get("onecycle_epochs", max_epochs)
        peak_epoch = training_parameters.get("peak_epoch", 0.3 * transition_epochs)
        schedule_ = optax.cosine_onecycle_schedule(
            peak_value=lr,
            div_factor=lr / init_lr,
            final_div_factor=init_lr / final_lr,
            transition_steps=transition_epochs * nbatch_per_epoch,
            pct_start=peak_epoch / transition_epochs,
        )
        sch_state = {"count": 0, "best": np.inf, "lr": init_lr}

        def schedule(state, rmse=None):
            new_state = {**state}
            lr = schedule_(state["count"])
            # the step counter only advances on per-batch calls (rmse is None)
            if rmse is None:
                new_state["count"] += 1
            new_state["lr"] = lr
            return lr, new_state

    elif schedule_type == "piecewise_interpolate":
        schedule_params = training_parameters.get("scheduler_parameters", {})
        schedule_ = optax.piecewise_interpolate_schedule(
            **{"init_value": lr, "interpolate_type": "linear", **schedule_params}
        )
        sch_state = {"count": 0, "best": np.inf, "lr": schedule_(0)}

        def schedule(state, rmse=None):
            new_state = {**state}
            lr = schedule_(state["count"])
            if rmse is None:
                new_state["count"] += 1
            new_state["lr"] = lr
            return lr, new_state

    elif schedule_type == "constant":
        sch_state = {"count": 0}

        def schedule(state, rmse=None):
            new_state = {**state}
            new_state["lr"] = lr
            if rmse is None:
                new_state["count"] += 1
            return lr, new_state

    elif schedule_type == "reduce_on_plateau":
        patience = training_parameters.get("patience", 10)
        factor = training_parameters.get("lr_factor", 0.5)
        patience_thr = training_parameters.get("patience_thr", 0.0)
        sch_state = {"count": 0, "best": np.inf, "lr": lr, "patience": patience}
        adaptive_scheduler = True

        def schedule(state, rmse=None):
            new_state = {**state}
            if rmse is None:
                new_state["count"] += 1
                return state["lr"], new_state
            if rmse <= state["best"] * (1.0 + patience_thr):
                if rmse < state["best"]:
                    new_state["best"] = rmse
                new_state["patience"] = 0
            else:
                new_state["patience"] += 1
                if new_state["patience"] >= patience:
                    new_state["lr"] = state["lr"] * factor
                    new_state["patience"] = 0
                    print("Reducing learning rate to", new_state["lr"])
            return new_state["lr"], new_state

    else:
        raise ValueError(f"Unknown schedule_type: {schedule_type}")

    stochastic_scheduler = training_parameters.get("stochastic_scheduler", False)
    if stochastic_scheduler:
        # keep a stable reference to the wrapped schedule (a plain
        # `schedule_ = schedule` would shadow the closures defined above)
        base_schedule = schedule
        # derive the scheduler RNG key from an assumed "scheduler_seed" parameter
        rng_key = jax.random.PRNGKey(training_parameters.get("scheduler_seed", 0))
        rng_key, scheduler_key = jax.random.split(rng_key)
        sch_state["rng_key"] = scheduler_key
        sch_state["lr_max"] = lr
        sch_state["lr_min"] = final_lr

        def schedule(state, rmse=None):
            new_state = {**state, "lr": state["lr_max"]}
            if rmse is None:
                # sample the learning rate uniformly between lr_min and the
                # current value of the wrapped schedule
                lr_max, new_state = base_schedule(new_state, rmse=rmse)
                lr_min = new_state["lr_min"]
                new_state["rng_key"], subkey = jax.random.split(new_state["rng_key"])
                lr = lr_min + (lr_max - lr_min) * jax.random.uniform(subkey)
                new_state["lr"] = lr
                new_state["lr_max"] = lr_max

            return new_state["lr"], new_state

    return schedule, sch_state, schedule_metrics, adaptive_scheduler
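An illustrative usage sketch (names assumed): the returned schedule is stepped once per batch (no metric, which advances the counter) and once per validation pass (with the metric value).

# --- usage sketch (illustrative values) ---
schedule, sch_state, metric_name, adaptive = get_lr_schedule(
    max_epochs=100,
    nbatch_per_epoch=250,
    training_parameters={"lr": 1.0e-3, "schedule_type": "cosine_onecycle"},
)
current_lr, sch_state = schedule(sch_state)              # per-batch call
current_lr, sch_state = schedule(sch_state, rmse=0.12)   # per-validation call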