fennol.utils.initializers

 1import flax.linen as nn
 2import jax
 3import jax.numpy as jnp
 4from typing import Callable,Union
 5
 6def scaled_orthogonal(
 7    scale=1.0, mode="fan_avg", in_axis=-2, out_axis=-1, dtype=jnp.float_
 8):
 9    assert mode in ["fan_in", "fan_out", "fan_avg"]
10    init_ortho = nn.initializers.orthogonal(
11        scale=scale, column_axis=out_axis, dtype=dtype
12    )
13    if mode == "fan_in":
14
15        def init(key, shape, dtype=jnp.float32):
16            return init_ortho(key, shape, dtype=dtype) * (shape[in_axis] ** -0.5)
17
18    elif mode == "fan_out":
19
20        def init(key, shape, dtype=jnp.float32):
21            return init_ortho(key, shape, dtype=dtype) * (shape[out_axis] ** -0.5)
22
23    else:
24
25        def init(key, shape, dtype=jnp.float32):
26            return init_ortho(key, shape, dtype=dtype) * (
27                (shape[in_axis] + shape[out_axis]) ** -0.5
28            )
29
30    return init
31
32
def initializer_from_str(name: Union[str, Callable, None]) -> Callable:
    """Resolve an initializer specification into an initializer callable.

    Args:
        name: one of

            - ``None``: ``nn.initializers.lecun_normal()`` is returned.
            - a callable: returned unchanged.
            - a string: a Python expression evaluated against the public
              names of ``nn.initializers`` plus ``scaled_orthogonal``, e.g.
              ``"lecun_normal()"`` or ``"scaled_orthogonal(mode='fan_in')"``.
              A bare factory name such as ``"lecun_normal"`` evaluates to the
              factory itself, not to an initializer built from it.

    Returns:
        The resolved callable.

    Raises:
        ValueError: if ``name`` is not ``None``, a callable or a string, or if
            the string evaluates to something that is not callable.
    """
    if name is None:
        return nn.initializers.lecun_normal()
    if callable(name):
        return name
    if not isinstance(name, str):
        raise ValueError(f"Invalid initializer {name}")

    # Skip dunder entries: a module ``__dict__`` contains ``__builtins__``.
    # Copying it into the eval locals would let an expression such as
    # ``__builtins__['__import__']`` bypass the ``{"__builtins__": None}``
    # globals below.
    namespace = {
        key: value
        for key, value in nn.initializers.__dict__.items()
        if not key.startswith("__")
    }
    namespace["scaled_orthogonal"] = scaled_orthogonal

    # SECURITY: eval is not a sandbox even with builtins removed, because
    # attribute access on literals can still reach arbitrary objects. Only
    # pass strings that come from trusted configuration.
    initializer = eval(name, {"__builtins__": None}, namespace)
    if not callable(initializer):
        raise ValueError(
            f"Invalid initializer {name}: evaluated to non-callable {initializer!r}"
        )
    return initializer
def scaled_orthogonal( scale=1.0, mode='fan_avg', in_axis=-2, out_axis=-1, dtype=<class 'jax.numpy.float64'>):
 7def scaled_orthogonal(
 8    scale=1.0, mode="fan_avg", in_axis=-2, out_axis=-1, dtype=jnp.float_
 9):
10    assert mode in ["fan_in", "fan_out", "fan_avg"]
11    init_ortho = nn.initializers.orthogonal(
12        scale=scale, column_axis=out_axis, dtype=dtype
13    )
14    if mode == "fan_in":
15
16        def init(key, shape, dtype=jnp.float32):
17            return init_ortho(key, shape, dtype=dtype) * (shape[in_axis] ** -0.5)
18
19    elif mode == "fan_out":
20
21        def init(key, shape, dtype=jnp.float32):
22            return init_ortho(key, shape, dtype=dtype) * (shape[out_axis] ** -0.5)
23
24    else:
25
26        def init(key, shape, dtype=jnp.float32):
27            return init_ortho(key, shape, dtype=dtype) * (
28                (shape[in_axis] + shape[out_axis]) ** -0.5
29            )
30
31    return init
def initializer_from_str(name: Union[str, Callable, NoneType]) -> Callable:
def initializer_from_str(name: Union[str, Callable, None]) -> Callable:
    """Resolve an initializer specification into an initializer callable.

    Args:
        name: one of

            - ``None``: ``nn.initializers.lecun_normal()`` is returned.
            - a callable: returned unchanged.
            - a string: a Python expression evaluated against the public
              names of ``nn.initializers`` plus ``scaled_orthogonal``, e.g.
              ``"lecun_normal()"`` or ``"scaled_orthogonal(mode='fan_in')"``.
              A bare factory name such as ``"lecun_normal"`` evaluates to the
              factory itself, not to an initializer built from it.

    Returns:
        The resolved callable.

    Raises:
        ValueError: if ``name`` is not ``None``, a callable or a string, or if
            the string evaluates to something that is not callable.
    """
    if name is None:
        return nn.initializers.lecun_normal()
    if callable(name):
        return name
    if not isinstance(name, str):
        raise ValueError(f"Invalid initializer {name}")

    # Skip dunder entries: a module ``__dict__`` contains ``__builtins__``.
    # Copying it into the eval locals would let an expression such as
    # ``__builtins__['__import__']`` bypass the ``{"__builtins__": None}``
    # globals below.
    namespace = {
        key: value
        for key, value in nn.initializers.__dict__.items()
        if not key.startswith("__")
    }
    namespace["scaled_orthogonal"] = scaled_orthogonal

    # SECURITY: eval is not a sandbox even with builtins removed, because
    # attribute access on literals can still reach arbitrary objects. Only
    # pass strings that come from trusted configuration.
    initializer = eval(name, {"__builtins__": None}, namespace)
    if not callable(initializer):
        raise ValueError(
            f"Invalid initializer {name}: evaluated to non-callable {initializer!r}"
        )
    return initializer