fennol.utils.initializers

 1import flax.linen as nn
 2import jax
 3import jax.numpy as jnp
 4from typing import Callable,Union
 5
 6def scaled_orthogonal(
 7    scale=1.0, mode="fan_avg", in_axis=-2, out_axis=-1, dtype=jnp.float_
 8):
 9    assert mode in ["fan_in", "fan_out", "fan_avg"]
10    init_ortho = nn.initializers.orthogonal(
11        scale=scale, column_axis=out_axis, dtype=dtype
12    )
13    if mode == "fan_in":
14
15        def init(key, shape, dtype=jnp.float32):
16            return init_ortho(key, shape, dtype=dtype) * (shape[in_axis] ** -0.5)
17
18    elif mode == "fan_out":
19
20        def init(key, shape, dtype=jnp.float32):
21            return init_ortho(key, shape, dtype=dtype) * (shape[out_axis] ** -0.5)
22
23    else:
24
25        def init(key, shape, dtype=jnp.float32):
26            return init_ortho(key, shape, dtype=dtype) * (
27                (shape[in_axis] + shape[out_axis]) ** -0.5
28            )
29
30    return init
31
32
def initializer_from_str(name: Union[str, Callable]) -> Callable:
    """Resolve an initializer given either a callable or an expression string.

    Args:
        name: A callable, returned unchanged, or a string expression that is
            evaluated against the public names of ``nn.initializers`` plus
            ``scaled_orthogonal`` (e.g. ``"scaled_orthogonal(mode='fan_in')"``).

    Returns:
        The callable itself, or the result of evaluating the expression.

    Raises:
        ValueError: If ``name`` is neither callable nor a string.
    """
    if callable(name):
        return name
    if not isinstance(name, str):
        raise ValueError(f"Invalid initializer {name}")
    # Expose only public names. A module's ``__dict__`` also carries
    # ``__builtins__``; copying it into the eval namespace would hand the
    # real builtins back to the expression and defeat the restriction below.
    namespace = {
        key: value
        for key, value in vars(nn.initializers).items()
        if not key.startswith("_")
    }
    namespace["scaled_orthogonal"] = scaled_orthogonal
    # SECURITY: ``eval`` with ``__builtins__`` disabled is not a real sandbox
    # (attribute access on exposed objects can still reach arbitrary code).
    # Only pass strings from trusted configuration.
    return eval(name, {"__builtins__": None}, namespace)
def scaled_orthogonal(scale=1.0, mode='fan_avg', in_axis=-2, out_axis=-1, dtype=jnp.float_):
 7def scaled_orthogonal(
 8    scale=1.0, mode="fan_avg", in_axis=-2, out_axis=-1, dtype=jnp.float_
 9):
10    assert mode in ["fan_in", "fan_out", "fan_avg"]
11    init_ortho = nn.initializers.orthogonal(
12        scale=scale, column_axis=out_axis, dtype=dtype
13    )
14    if mode == "fan_in":
15
16        def init(key, shape, dtype=jnp.float32):
17            return init_ortho(key, shape, dtype=dtype) * (shape[in_axis] ** -0.5)
18
19    elif mode == "fan_out":
20
21        def init(key, shape, dtype=jnp.float32):
22            return init_ortho(key, shape, dtype=dtype) * (shape[out_axis] ** -0.5)
23
24    else:
25
26        def init(key, shape, dtype=jnp.float32):
27            return init_ortho(key, shape, dtype=dtype) * (
28                (shape[in_axis] + shape[out_axis]) ** -0.5
29            )
30
31    return init
def initializer_from_str(name: Union[str, Callable]) -> Callable:
def initializer_from_str(name: Union[str, Callable]) -> Callable:
    """Resolve an initializer given either a callable or an expression string.

    Args:
        name: A callable, returned unchanged, or a string expression that is
            evaluated against the public names of ``nn.initializers`` plus
            ``scaled_orthogonal`` (e.g. ``"scaled_orthogonal(mode='fan_in')"``).

    Returns:
        The callable itself, or the result of evaluating the expression.

    Raises:
        ValueError: If ``name`` is neither callable nor a string.
    """
    if callable(name):
        return name
    if not isinstance(name, str):
        raise ValueError(f"Invalid initializer {name}")
    # Expose only public names. A module's ``__dict__`` also carries
    # ``__builtins__``; copying it into the eval namespace would hand the
    # real builtins back to the expression and defeat the restriction below.
    namespace = {
        key: value
        for key, value in vars(nn.initializers).items()
        if not key.startswith("_")
    }
    namespace["scaled_orthogonal"] = scaled_orthogonal
    # SECURITY: ``eval`` with ``__builtins__`` disabled is not a real sandbox
    # (attribute access on exposed objects can still reach arbitrary code).
    # Only pass strings from trusted configuration.
    return eval(name, {"__builtins__": None}, namespace)