fennol.utils.initializers
import flax.linen as nn
import jax
import jax.numpy as jnp
from typing import Callable, Union


def scaled_orthogonal(
    scale=1.0, mode="fan_avg", in_axis=-2, out_axis=-1, dtype=jnp.float_
):
    """Return an orthogonal initializer rescaled by ``fan ** -0.5``.

    The base matrix comes from ``jax.nn.initializers.orthogonal`` (the same
    function that ``flax.linen.initializers.orthogonal`` re-exports), using
    ``out_axis`` as its column axis. It is then multiplied by ``fan ** -0.5``
    where ``fan`` is chosen by ``mode``:

    * ``"fan_in"``  -> ``shape[in_axis]``
    * ``"fan_out"`` -> ``shape[out_axis]``
    * ``"fan_avg"`` -> ``shape[in_axis] + shape[out_axis]``

    Args:
        scale: Multiplier forwarded to the underlying orthogonal initializer.
        mode: One of ``"fan_in"``, ``"fan_out"``, ``"fan_avg"``.
        in_axis: Axis of ``shape`` used as the input fan.
        out_axis: Axis of ``shape`` used as the output fan and as the
            orthogonal initializer's ``column_axis``.
        dtype: Default dtype of the returned initializer when the caller does
            not pass one explicitly.

    Returns:
        A function ``init(key, shape, dtype=dtype)`` returning the scaled array.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    if mode not in ("fan_in", "fan_out", "fan_avg"):
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        raise ValueError(
            f"Invalid mode {mode!r}; expected 'fan_in', 'fan_out' or 'fan_avg'"
        )
    init_ortho = jax.nn.initializers.orthogonal(
        scale=scale, column_axis=out_axis, dtype=dtype
    )

    def _fan(shape):
        # NOTE: "fan_avg" divides by the *sum* of both fans (not their mean, as
        # jax's variance_scaling does); kept as-is so existing initial weights
        # are reproduced.
        if mode == "fan_in":
            return shape[in_axis]
        if mode == "fan_out":
            return shape[out_axis]
        return shape[in_axis] + shape[out_axis]

    # Default to the factory's ``dtype`` so that argument actually takes effect
    # when the caller omits one (flax/jax initializer convention).
    def init(key, shape, dtype=dtype):
        return init_ortho(key, shape, dtype=dtype) * (_fan(shape) ** -0.5)

    return init


def initializer_from_str(name: Union[str, Callable]) -> Callable:
    """Resolve an initializer given either as a callable or as an expression string.

    Args:
        name: A callable (returned unchanged) or a string evaluated as a Python
            expression over the names in ``flax.linen.initializers`` plus
            ``scaled_orthogonal``, e.g. ``"lecun_normal()"`` or
            ``"scaled_orthogonal(mode='fan_in')"``.

    Returns:
        ``name`` itself if callable, otherwise the value of the expression.

    Raises:
        ValueError: If ``name`` is neither callable nor a string, or if the
            string contains a double underscore.
        Exception: Whatever ``eval`` raises for unknown names or malformed
            expressions (e.g. ``NameError``, ``SyntaxError``).
    """
    if callable(name):
        return name
    if not isinstance(name, str):
        raise ValueError(f"Invalid initializer {name}")
    # SECURITY: ``eval`` is not a sandbox. The namespace below is a module
    # ``__dict__`` and therefore contains the real ``__builtins__``; rejecting
    # dunders blocks that and the usual ``obj.__class__`` escapes, but callers
    # must still only pass trusted (e.g. own-config) strings.
    if "__" in name:
        raise ValueError(f"Invalid initializer {name!r}: dunder names are not allowed")
    return eval(
        name,
        {"__builtins__": None},
        {
            **nn.initializers.__dict__,
            "scaled_orthogonal": scaled_orthogonal,
        },
    )
def scaled_orthogonal(scale=1.0, mode='fan_avg', in_axis=-2, out_axis=-1, dtype=jnp.float_):
def scaled_orthogonal(
    scale=1.0, mode="fan_avg", in_axis=-2, out_axis=-1, dtype=jnp.float_
):
    """Return an orthogonal initializer rescaled by ``fan ** -0.5``.

    The base matrix comes from ``jax.nn.initializers.orthogonal`` (the same
    function that ``flax.linen.initializers.orthogonal`` re-exports), using
    ``out_axis`` as its column axis. It is then multiplied by ``fan ** -0.5``
    where ``fan`` is chosen by ``mode``:

    * ``"fan_in"``  -> ``shape[in_axis]``
    * ``"fan_out"`` -> ``shape[out_axis]``
    * ``"fan_avg"`` -> ``shape[in_axis] + shape[out_axis]``

    Args:
        scale: Multiplier forwarded to the underlying orthogonal initializer.
        mode: One of ``"fan_in"``, ``"fan_out"``, ``"fan_avg"``.
        in_axis: Axis of ``shape`` used as the input fan.
        out_axis: Axis of ``shape`` used as the output fan and as the
            orthogonal initializer's ``column_axis``.
        dtype: Default dtype of the returned initializer when the caller does
            not pass one explicitly.

    Returns:
        A function ``init(key, shape, dtype=dtype)`` returning the scaled array.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    if mode not in ("fan_in", "fan_out", "fan_avg"):
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        raise ValueError(
            f"Invalid mode {mode!r}; expected 'fan_in', 'fan_out' or 'fan_avg'"
        )
    init_ortho = jax.nn.initializers.orthogonal(
        scale=scale, column_axis=out_axis, dtype=dtype
    )

    def _fan(shape):
        # NOTE: "fan_avg" divides by the *sum* of both fans (not their mean, as
        # jax's variance_scaling does); kept as-is so existing initial weights
        # are reproduced.
        if mode == "fan_in":
            return shape[in_axis]
        if mode == "fan_out":
            return shape[out_axis]
        return shape[in_axis] + shape[out_axis]

    # Default to the factory's ``dtype`` so that argument actually takes effect
    # when the caller omits one (flax/jax initializer convention).
    def init(key, shape, dtype=dtype):
        return init_ortho(key, shape, dtype=dtype) * (_fan(shape) ** -0.5)

    return init
def initializer_from_str(name: Union[str, Callable]) -> Callable:
def initializer_from_str(name: Union[str, Callable]) -> Callable:
    """Resolve an initializer given either as a callable or as an expression string.

    Args:
        name: A callable (returned unchanged) or a string evaluated as a Python
            expression over the names in ``flax.linen.initializers`` plus
            ``scaled_orthogonal``, e.g. ``"lecun_normal()"`` or
            ``"scaled_orthogonal(mode='fan_in')"``.

    Returns:
        ``name`` itself if callable, otherwise the value of the expression.

    Raises:
        ValueError: If ``name`` is neither callable nor a string, or if the
            string contains a double underscore.
        Exception: Whatever ``eval`` raises for unknown names or malformed
            expressions (e.g. ``NameError``, ``SyntaxError``).
    """
    if callable(name):
        return name
    if not isinstance(name, str):
        raise ValueError(f"Invalid initializer {name}")
    # SECURITY: ``eval`` is not a sandbox. The namespace below is a module
    # ``__dict__`` and therefore contains the real ``__builtins__``; rejecting
    # dunders blocks that and the usual ``obj.__class__`` escapes, but callers
    # must still only pass trusted (e.g. own-config) strings.
    if "__" in name:
        raise ValueError(f"Invalid initializer {name!r}: dunder names are not allowed")
    return eval(
        name,
        {"__builtins__": None},
        {
            **nn.initializers.__dict__,
            "scaled_orthogonal": scaled_orthogonal,
        },
    )