fennol.training.optimizers

from typing import (
    Callable,
    Optional,
    Dict,
    List,
    Tuple,
    Union,
    Any,
    NamedTuple,
    Sequence,
)
import optax
import jax
import jax.numpy as jnp
import numpy as np
import operator
from flax import traverse_util
import json
import re

from optax._src import base
from optax._src import wrappers
import chex
from optax import tree_utils as otu


class AddWeightDiffState(NamedTuple):
    ref_weights: Any


def add_weights_difference(
    weight_decay: Union[float, jax.Array] = 0.0,
    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
) -> base.GradientTransformation:
    """weight decay toward initial weights."""

    def init_fn(params):
        return AddWeightDiffState(ref_weights=params)

    def update_fn(updates, state, params):
        if params is None:
            raise ValueError(base.NO_PARAMS_MSG)
        updates = jax.tree_util.tree_map(
            lambda g, p, pref: g + weight_decay * (p - pref),
            updates,
            params,
            state.ref_weights,
        )
        return updates, state

    # If mask is not `None`, apply mask to the gradient transformation.
    # E.g. it is common to skip weight decay on bias units and batch stats.
    if mask is not None:
        return wrappers.masked(base.GradientTransformation(init_fn, update_fn), mask)
    return base.GradientTransformation(init_fn, update_fn)
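

# Example (illustrative sketch, not part of the original module): with a zero
# gradient, add_weights_difference pulls the update back toward the stored
# reference weights. The helper name below is hypothetical.
def _example_add_weights_difference():
    params = {"w": jnp.ones(3)}
    tx = add_weights_difference(weight_decay=0.1)
    state = tx.init(params)  # `params` becomes the reference weights
    drifted = {"w": jnp.full(3, 2.0)}  # parameters after some training steps
    grads = {"w": jnp.zeros(3)}
    updates, state = tx.update(grads, state, drifted)
    # update = g + wd * (p - p_ref) = 0 + 0.1 * (2 - 1) = 0.1 per entry
    return updates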


def add_grokfast(
    alpha: float = 0.9,
    l: float = 1.0,
) -> base.GradientTransformation:
    """Grokfast: amplify slow gradients by exponential moving average."""

    ema: base.GradientTransformation = optax.ema(decay=alpha, debias=False)

    def init_fn(params):
        return ema.init(params)

    def update_fn(updates, state, params=None):
        dupdates, state = ema.update(updates, state, params)
        # updates = updates + l*dupdates
        updates = jax.tree_util.tree_map(lambda g, d: g + l * d, updates, dupdates)
        return updates, state

    return base.GradientTransformation(init_fn, update_fn)
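

# Example (illustrative sketch, not part of the original module): for a
# constant gradient stream the EMA converges to the gradient itself, so the
# Grokfast update tends to g * (1 + l). The helper name below is hypothetical.
def _example_add_grokfast():
    tx = add_grokfast(alpha=0.9, l=2.0)
    params = {"w": jnp.ones(2)}
    state = tx.init(params)
    g = {"w": jnp.ones(2)}
    for _ in range(100):  # long enough for the EMA to converge
        updates, state = tx.update(g, state, params)
    # updates -> g * (1 + l) = 3 * g once the EMA has converged
    return updates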


class PROFITState(NamedTuple):
    ref_weights: Any
    istep: int
    main_opt_state: Any
    internal_opt_state: Any


def profit(
    learning_rate: base.ScalarOrSchedule,
    nsteps_ref: int = 1,
    main_opt: str = "adam",
    main_opt_params: Dict[str, Any] = {},
    internal_opt: str = "sgd",
    internal_opt_params: Dict[str, Any] = {},
    **kwargs,
) -> base.GradientTransformation:
    """PROFIT optimizer for fine-tuning. https://arxiv.org/pdf/2412.01930"""

    main_opt_params = {"learning_rate": learning_rate, **main_opt_params}
    main_opt = eval(
        main_opt,
        {"__builtins__": None},
        {**optax.__dict__},
    )(**main_opt_params)

    internal_opt_params = {"learning_rate": 0.1, **internal_opt_params}
    internal_opt = eval(
        internal_opt,
        {"__builtins__": None},
        {**optax.__dict__},
    )(**internal_opt_params)

    def init_fn(params):
        return PROFITState(
            ref_weights=params,
            istep=0,
            main_opt_state=main_opt.init(params),
            internal_opt_state=internal_opt.init(params),
        )

    def update_main(gradients, main_opt_state, internal_opt_state, params, params_ref):
        delta = jax.tree_util.tree_map(lambda p, pref: p - pref, params, params_ref)
        dot = jax.tree.reduce(
            operator.add,
            jax.tree_util.tree_map(lambda g, d: (g * d).sum(), gradients, delta),
        )
        delta2 = jax.tree.reduce(
            operator.add, jax.tree_util.tree_map(lambda d: (d**2).sum(), delta)
        )
        proj = dot / (delta2 + 1.0e-6)

        # if the gradient is aligned with the drift from the reference weights,
        # keep it; otherwise replace it by its projection on the drift direction
        gradients = jax.lax.cond(
            dot >= 0,
            lambda g, d: g,
            lambda g, d: jax.tree_util.tree_map(lambda x: proj * x, d),
            gradients,
            delta,
        )
        updates, main_opt_state = main_opt.update(gradients, main_opt_state, params)
        updates = jax.tree_util.tree_map(lambda g, d: g - d, updates, delta)
        return updates, main_opt_state, internal_opt_state

    def update_internal(
        gradients, main_opt_state, internal_opt_state, params, params_ref
    ):
        updates, internal_opt_state = internal_opt.update(
            gradients, internal_opt_state, params
        )
        return updates, main_opt_state, internal_opt_state

    def update_fn(gradients, state, params):
        istep = state.istep % (nsteps_ref + 1)

        # refresh the reference weights at the start of each cycle
        params_ref = jax.lax.cond(
            istep == 0, lambda a, b: a, lambda a, b: b, params, state.ref_weights
        )

        updates, main_opt_state, internal_opt_state = jax.lax.cond(
            istep == nsteps_ref,
            update_main,
            update_internal,
            gradients,
            state.main_opt_state,
            state.internal_opt_state,
            params,
            params_ref,
        )

        new_state = PROFITState(
            ref_weights=params_ref,
            istep=state.istep + 1,
            main_opt_state=main_opt_state,
            internal_opt_state=internal_opt_state,
        )
        return updates, new_state

    return base.GradientTransformation(init_fn, update_fn)
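

# Example (illustrative sketch, not part of the original module): profit is a
# drop-in optax transformation. With nsteps_ref=2, steps 0-1 of each cycle use
# the internal optimizer and step 2 applies the projected main update.
def _example_profit():
    tx = profit(learning_rate=1.0e-3, nsteps_ref=2, main_opt="adam", internal_opt="sgd")
    params = {"w": jnp.ones(4)}
    state = tx.init(params)
    grads = {"w": jnp.full(4, 0.1)}
    for _ in range(3):  # one full internal/main cycle
        updates, state = tx.update(grads, state, params)
        params = optax.apply_updates(params, updates)
    return params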


class MultiEmaState(NamedTuple):
    """Holds exponential moving averages of past updates."""

    count: int
    ema: Sequence[base.Params]


def multi_ema(
    decays: Sequence[float],
    debias: bool = True,
    power: Union[bool, Sequence[bool]] = False,
) -> base.GradientTransformation:
    """Compute multiple (optionally power-law) moving averages of past updates."""

    if len(decays) == 0:

        def init_fn(params):
            return base.EmptyState()

        def update_fn(updates, state, params=None):
            return [updates], state

        return base.GradientTransformation(init_fn, update_fn)

    if isinstance(power, bool):
        power = [power] * len(decays)
    assert len(power) == len(decays), "power and decays must have the same length"

    gammas = []
    for decay, p in zip(decays, power):
        if not p:
            gammas.append(None)
            continue
        # solve the cubic linking the power-average exponent gamma to the
        # requested decay; keep the largest real root
        t = decay**-2
        gamma = np.roots([1, 7, 16 - t, 12 - t]).real.max()
        assert gamma > 0, f"Invalid gamma for decay {decay}: {gamma}"
        gammas.append(gamma)

    def init_fn(params):
        return MultiEmaState(
            count=0,
            ema=[otu.tree_zeros_like(params)] * len(decays),
        )

    def update_fn(params, state, other=None):
        count_inc = state.count + 1
        updates = []
        state_ema = []
        for decay, gamma, ema in zip(decays, gammas, state.ema):
            if gamma is not None:
                # power moving average: step-dependent effective decay
                decay = (1.0 - 1.0 / count_inc) ** (gamma + 1)
            update = new_ema = otu.tree_update_moment(params, ema, decay, order=1)
            if debias and gamma is None:
                update = otu.tree_bias_correction(update, decay, count_inc)
            updates.append(update)
            state_ema.append(new_ema)
        return updates, MultiEmaState(count=count_inc, ema=state_ema)

    return base.GradientTransformation(init_fn, update_fn)
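

# Example (illustrative sketch, not part of the original module): multi_ema
# returns one smoothed copy of its input per decay, here a fast and a slow
# debiased EMA of the same pytree. The helper name below is hypothetical.
def _example_multi_ema():
    tx = multi_ema(decays=[0.9, 0.999], debias=True)
    params = {"w": jnp.ones(2)}
    state = tx.init(params)
    emas, state = tx.update(params, state)
    # emas[0] tracks its input quickly (decay 0.9), emas[1] slowly (decay 0.999)
    return emas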


def get_optimizer(
    training_parameters: Dict[str, Any], variables: Dict, initial_lr: float
) -> Tuple[optax.GradientTransformation, int]:
    """
    Builds the optax.GradientTransformation used to optimize the model parameters.

    Args:
    - training_parameters: A dictionary containing the training parameters.
    - variables: A pytree containing the model parameters.
    - initial_lr: The initial learning rate.

    Returns:
    - An optax.GradientTransformation object that can be used to optimize the model parameters.
    - The index (counted from the end of the chain) of the learning-rate scaling transform.
    """

    default_status = str(training_parameters.get("default_status", "trainable")).lower()
    assert default_status in [
        "trainable",
        "frozen",
    ], f"Default status must be 'trainable' or 'frozen', got {default_status}"

    # find frozen and trainable parameters
    frozen = training_parameters.get("frozen", [])
    trainable = training_parameters.get("trainable", [])

    def training_status(full_path, v):
        full_path = "/".join(full_path[1:]).lower()
        status = (default_status, "")
        for path in frozen:
            if re.match(path.lower(), full_path):
                status = ("frozen", path)
        for path in trainable:
            if re.match(path.lower(), full_path):
                status = ("trainable", path)
        return status[0]

    params_partition = traverse_util.path_aware_map(training_status, variables)
    if len(frozen) > 0 or len(trainable) > 0:
        print("params partition:")
        print(json.dumps(params_partition, indent=2, sort_keys=False))

    ## Gradient preprocessing
    grad_processing = []

    # zero nans
    zero_nans = training_parameters.get("zero_nans", False)
    if zero_nans:
        grad_processing.append(optax.zero_nans())

    use_grokfast = training_parameters.get("use_grokfast", False)
    if use_grokfast:
        print("Using Grokfast")
        alpha_grokfast = training_parameters.get("alpha_grokfast", 0.9)
        l_grokfast = training_parameters.get("l_grokfast", 1.0)
        grad_processing.append(add_grokfast(alpha=alpha_grokfast, l=l_grokfast))

    # OPTIMIZER
    optimizer_name = training_parameters.get("optimizer", "adabelief")
    optimizer = eval(
        optimizer_name,
        {"__builtins__": None},
        {
            **optax.__dict__,
            "profit": profit,
        },
    )
    print("Optimizer:", optimizer_name)
    optimizer_configuration = training_parameters.get("optimizer_config", {})
    optimizer_configuration["learning_rate"] = 1.0
    grad_processing.append(optimizer(**optimizer_configuration))

    # weight decay
    weight_decay = training_parameters.get("weight_decay", 0.0)
    assert weight_decay >= 0.0, "Weight decay must be non-negative"
    decay_targets = training_parameters.get("decay_targets", [""])

    def decay_status(full_path, v):
        full_path = "/".join(full_path).lower()
        status = False
        for path in decay_targets:
            if re.match(r"^params/" + path.lower(), full_path):
                status = True
        return status

    decay_mask = traverse_util.path_aware_map(decay_status, variables)
    if weight_decay > 0.0:
        print("weight decay:", weight_decay)
        print(json.dumps(decay_mask, indent=2, sort_keys=False))
        # negative sign: the chain ends with a positive step-size scaling,
        # so the decay term must be added with opposite sign to the weights
        grad_processing.append(
            optax.add_decayed_weights(weight_decay=-weight_decay, mask=decay_mask)
        )

    regularize_init_weight = training_parameters.get("regularize_init_weights", 0.0)
    if regularize_init_weight > 0.0:
        print(
            "Regularizing toward initial weights with L2 norm:", regularize_init_weight
        )
        if weight_decay <= 0.0:
            print(json.dumps(decay_mask, indent=2, sort_keys=False))

        grad_processing.append(
            add_weights_difference(
                weight_decay=-regularize_init_weight, mask=decay_mask
            )
        )

    if zero_nans:
        grad_processing.append(optax.zero_nans())

    # learning rate
    grad_processing.append(optax.inject_hyperparams(optax.scale)(step_size=initial_lr))
    # index (from the end of the chain) of the learning-rate scaling transform
    ilr = -1

    # gradient clipping
    clip_threshold = training_parameters.get("gradient_clipping", -1.0)
    if clip_threshold > 0.0:
        print("Adaptive gradient clipping threshold:", clip_threshold)
        grad_processing.append(optax.adaptive_grad_clip(clip_threshold))
        ilr -= 1

    ## define optimizer chain
    optimizer_ = optax.chain(
        *grad_processing,
    )
    partition_optimizer = {"trainable": optimizer_, "frozen": optax.set_to_zero()}
    return optax.multi_transform(partition_optimizer, params_partition), ilr
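

# Example (illustrative sketch, not part of the original module): a
# hypothetical `training_parameters` dict wired to the keys read above; all
# values are placeholders, not recommendations.
_example_training_parameters = {
    "optimizer": "adabelief",          # any optax optimizer name, or "profit"
    "optimizer_config": {"b1": 0.9},   # extra kwargs (learning_rate is overridden to 1.0)
    "weight_decay": 1.0e-4,            # masked by decay_targets
    "decay_targets": [""],             # regexes matched against "params/..."
    "frozen": ["embedding"],           # regexes marking frozen parameter paths
    "gradient_clipping": 0.1,          # adaptive_grad_clip threshold
    "zero_nans": True,
}
# optimizer, ilr = get_optimizer(_example_training_parameters, variables, initial_lr=1.0e-3)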


def get_lr_schedule(max_epochs, nbatch_per_epoch, training_parameters):
    lr = training_parameters.get("lr", 1.0e-3)
    init_lr = training_parameters.get("init_lr", lr / 25)
    final_lr = training_parameters.get("final_lr", lr / 10000)

    #### LEARNING RATE SCHEDULER ####
    schedule_type = training_parameters.get("schedule_type", "cosine_onecycle").lower()
    schedule_type = training_parameters.get("scheduler", schedule_type).lower()
    schedule_metrics = training_parameters.get("schedule_metrics", "rmse_tot")

    adaptive_scheduler = False
    print("Schedule type:", schedule_type)
    if schedule_type == "cosine_onecycle":
        transition_epochs = training_parameters.get("onecycle_epochs", max_epochs)
        peak_epoch = training_parameters.get("peak_epoch", 0.3 * transition_epochs)
        schedule_ = optax.cosine_onecycle_schedule(
            peak_value=lr,
            div_factor=lr / init_lr,
            final_div_factor=init_lr / final_lr,
            transition_steps=transition_epochs * nbatch_per_epoch,
            pct_start=peak_epoch / transition_epochs,
        )
        sch_state = {"count": 0, "best": np.inf, "lr": init_lr}

        def schedule(state, rmse=None):
            new_state = {**state}
            lr = schedule_(state["count"])
            if rmse is None:
                new_state["count"] += 1
            new_state["lr"] = lr
            return lr, new_state

    elif schedule_type == "piecewise_interpolate":
        schedule_params = training_parameters.get("scheduler_parameters", {})
        schedule_ = optax.piecewise_interpolate_schedule(
            **{"init_value": lr, "interpolate_type": "linear", **schedule_params}
        )
        sch_state = {"count": 0, "best": np.inf, "lr": schedule_(0)}

        def schedule(state, rmse=None):
            new_state = {**state}
            lr = schedule_(state["count"])
            if rmse is None:
                new_state["count"] += 1
            new_state["lr"] = lr
            return lr, new_state

    elif schedule_type == "constant":
        sch_state = {"count": 0}

        def schedule(state, rmse=None):
            new_state = {**state}
            new_state["lr"] = lr
            if rmse is None:
                new_state["count"] += 1
            return lr, new_state

    elif schedule_type == "cosine":
        assert (
            "peak_epoch" in training_parameters
        ), "Cosine schedule requires 'peak_epoch' parameter"
        period = training_parameters["peak_epoch"] * nbatch_per_epoch
        peak_lr = lr
        sch_state = {"count": 0}

        def schedule(state, rmse=None):
            new_state = {**state}
            istep = state["count"]
            g = 0.5 * (1 + jnp.cos(jnp.pi * istep / period))
            lr = peak_lr + (init_lr - peak_lr) * g
            new_state["lr"] = lr
            if rmse is None:
                new_state["count"] += 1
            return lr, new_state

    elif schedule_type == "reduce_on_plateau":
        patience = training_parameters.get("patience", 10)
        factor = training_parameters.get("lr_factor", 0.5)
        patience_thr = training_parameters.get("patience_thr", 0.0)
        sch_state = {"count": 0, "best": np.inf, "lr": lr, "patience": patience}
        adaptive_scheduler = True

        def schedule(state, rmse=None):
            new_state = {**state}
            if rmse is None:
                new_state["count"] += 1
                return state["lr"], new_state
            if rmse <= state["best"] * (1.0 + patience_thr):
                if rmse < state["best"]:
                    new_state["best"] = rmse
                new_state["patience"] = 0
            else:
                new_state["patience"] += 1
                if new_state["patience"] >= patience:
                    new_state["lr"] = state["lr"] * factor
                    new_state["patience"] = 0
                    print("Reducing learning rate to", new_state["lr"])
            return new_state["lr"], new_state

    else:
        raise ValueError(f"Unknown schedule_type: {schedule_type}")

    stochastic_scheduler = training_parameters.get("stochastic_scheduler", False)
    if stochastic_scheduler:
        schedule_ = schedule
        # NOTE: the original code referenced an undefined `rng_key` here; a key
        # is now seeded locally (the "schedule_seed" parameter is an assumption)
        rng_key = jax.random.PRNGKey(training_parameters.get("schedule_seed", 0))
        rng_key, scheduler_key = jax.random.split(rng_key)
        sch_state["rng_key"] = scheduler_key
        sch_state["lr_max"] = lr
        sch_state["lr_min"] = final_lr

        def schedule(state, rmse=None):
            new_state = {**state, "lr": state["lr_max"]}
            if rmse is None:
                # sample the next learning rate uniformly in [lr_min, lr_max]
                lr_max, new_state = schedule_(new_state, rmse=rmse)
                lr_min = new_state["lr_min"]
                new_state["rng_key"], subkey = jax.random.split(new_state["rng_key"])
                lr = lr_min + (lr_max - lr_min) * jax.random.uniform(subkey)
                new_state["lr"] = lr
                new_state["lr_max"] = lr_max

            return new_state["lr"], new_state

    return schedule, sch_state, schedule_metrics, adaptive_scheduler
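

# Example (illustrative sketch, not part of the original module): the returned
# `schedule` is a stateful callable. Calling it without `rmse` advances the
# step count; passing `rmse` lets adaptive schedulers react to a validation
# metric without advancing the count. The helper name below is hypothetical.
def _example_lr_schedule():
    schedule, sch_state, metrics, adaptive = get_lr_schedule(
        max_epochs=10,
        nbatch_per_epoch=100,
        training_parameters={"scheduler": "constant", "lr": 1.0e-3},
    )
    lr, sch_state = schedule(sch_state)            # one training step
    lr, sch_state = schedule(sch_state, rmse=0.5)  # validation-time update
    return lr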