fennol.utils.initializers
"""Weight-initializer helpers.

Provides a fan-scaled orthogonal initializer and a resolver that turns a
configuration value (None, callable or expression string) into an initializer.
"""
import flax.linen as nn
import jax
import jax.numpy as jnp
from typing import Callable, Union


def scaled_orthogonal(
    scale=1.0, mode="fan_avg", in_axis=-2, out_axis=-1, dtype=jnp.float_
):
    """Return an orthogonal initializer rescaled by an inverse-sqrt fan factor.

    The array produced by ``nn.initializers.orthogonal(scale, column_axis=out_axis)``
    is multiplied by ``fan ** -0.5``, where ``fan`` depends on ``mode``:

    * ``"fan_in"``:  ``shape[in_axis]``
    * ``"fan_out"``: ``shape[out_axis]``
    * ``"fan_avg"``: ``shape[in_axis] + shape[out_axis]`` (the sum, not the mean)

    Only the sizes of ``in_axis`` and ``out_axis`` enter the factor; any other
    axes of ``shape`` are ignored.

    Args:
        scale: Gain forwarded to ``nn.initializers.orthogonal``.
        mode: One of ``"fan_in"``, ``"fan_out"``, ``"fan_avg"``.
        in_axis: Axis of ``shape`` treated as the fan-in dimension.
        out_axis: Axis of ``shape`` treated as the fan-out dimension. It is also
            passed as ``column_axis`` to the orthogonal initializer.
        dtype: Default dtype of the returned initializer, used when the caller
            does not pass one.

    Returns:
        ``init(key, shape, dtype=dtype)`` returning the scaled array.

    Raises:
        AssertionError: If ``mode`` is not a supported value. This is a plain
            ``assert``, so the check is skipped under ``python -O``.
    """
    assert mode in ["fan_in", "fan_out", "fan_avg"], f"invalid mode {mode!r}"
    init_ortho = nn.initializers.orthogonal(
        scale=scale, column_axis=out_axis, dtype=dtype
    )

    def fan(shape):
        # Mode was validated above, so anything else is "fan_avg".
        if mode == "fan_in":
            return shape[in_axis]
        if mode == "fan_out":
            return shape[out_axis]
        return shape[in_axis] + shape[out_axis]

    # Default to the factory-level dtype (same convention as
    # jax.nn.initializers) so ``scaled_orthogonal(dtype=...)`` actually takes
    # effect when the initializer is called without an explicit dtype.
    def init(key, shape, dtype=dtype):
        return init_ortho(key, shape, dtype=dtype) * (fan(shape) ** -0.5)

    return init


def initializer_from_str(name: Union[str, Callable, None]) -> Callable:
    """Resolve an initializer specification to an initializer callable.

    Args:
        name: The specification, in one of three forms.

            * ``None``: returns ``nn.initializers.lecun_normal()``.
            * A callable: returned unchanged.
            * A string: evaluated as a Python expression. The public names of
              ``nn.initializers`` and ``scaled_orthogonal`` are in scope, e.g.
              ``"lecun_normal()"`` or ``"scaled_orthogonal(mode='fan_in')"``.
              Write the call: a bare factory name such as ``"lecun_normal"``
              evaluates to the factory itself, not to an initializer.

    Returns:
        The resolved callable.

    Raises:
        ValueError: If ``name`` is neither None, callable nor a string, or if
            the string does not evaluate to a callable.
    """
    if name is None:
        return nn.initializers.lecun_normal()
    if callable(name):
        return name
    if not isinstance(name, str):
        raise ValueError(f"Invalid initializer {name}")

    # Drop dunder entries, in particular the module's own ``__builtins__``.
    # Because locals are searched first, that entry would otherwise make the
    # real builtins reachable by name and defeat ``{"__builtins__": None}``.
    namespace = {
        key: value
        for key, value in nn.initializers.__dict__.items()
        if not key.startswith("__")
    }
    namespace["scaled_orthogonal"] = scaled_orthogonal

    # SECURITY: eval with builtins disabled is NOT a sandbox. Attribute-walking
    # expressions can still reach arbitrary objects, so only pass strings that
    # come from trusted configuration.
    initializer = eval(name, {"__builtins__": None}, namespace)
    if not callable(initializer):
        raise ValueError(
            f"Invalid initializer {name}: expression did not evaluate to a callable"
        )
    return initializer
def
scaled_orthogonal( scale=1.0, mode='fan_avg', in_axis=-2, out_axis=-1, dtype=<class 'jax.numpy.float64'>):
def scaled_orthogonal(
    scale=1.0, mode="fan_avg", in_axis=-2, out_axis=-1, dtype=jnp.float_
):
    """Return an orthogonal initializer rescaled by an inverse-sqrt fan factor.

    The array produced by ``nn.initializers.orthogonal(scale, column_axis=out_axis)``
    is multiplied by ``fan ** -0.5``, where ``fan`` depends on ``mode``:

    * ``"fan_in"``:  ``shape[in_axis]``
    * ``"fan_out"``: ``shape[out_axis]``
    * ``"fan_avg"``: ``shape[in_axis] + shape[out_axis]`` (the sum, not the mean)

    Only the sizes of ``in_axis`` and ``out_axis`` enter the factor; any other
    axes of ``shape`` are ignored.

    Args:
        scale: Gain forwarded to ``nn.initializers.orthogonal``.
        mode: One of ``"fan_in"``, ``"fan_out"``, ``"fan_avg"``.
        in_axis: Axis of ``shape`` treated as the fan-in dimension.
        out_axis: Axis of ``shape`` treated as the fan-out dimension. It is also
            passed as ``column_axis`` to the orthogonal initializer.
        dtype: Default dtype of the returned initializer, used when the caller
            does not pass one.

    Returns:
        ``init(key, shape, dtype=dtype)`` returning the scaled array.

    Raises:
        AssertionError: If ``mode`` is not a supported value. This is a plain
            ``assert``, so the check is skipped under ``python -O``.
    """
    assert mode in ["fan_in", "fan_out", "fan_avg"], f"invalid mode {mode!r}"
    init_ortho = nn.initializers.orthogonal(
        scale=scale, column_axis=out_axis, dtype=dtype
    )

    def fan(shape):
        # Mode was validated above, so anything else is "fan_avg".
        if mode == "fan_in":
            return shape[in_axis]
        if mode == "fan_out":
            return shape[out_axis]
        return shape[in_axis] + shape[out_axis]

    # Default to the factory-level dtype (same convention as
    # jax.nn.initializers) so ``scaled_orthogonal(dtype=...)`` actually takes
    # effect when the initializer is called without an explicit dtype.
    def init(key, shape, dtype=dtype):
        return init_ortho(key, shape, dtype=dtype) * (fan(shape) ** -0.5)

    return init
def
initializer_from_str(name: Union[str, Callable, NoneType]) -> Callable:
def initializer_from_str(name: Union[str, Callable, None]) -> Callable:
    """Resolve an initializer specification to an initializer callable.

    Args:
        name: The specification, in one of three forms.

            * ``None``: returns ``nn.initializers.lecun_normal()``.
            * A callable: returned unchanged.
            * A string: evaluated as a Python expression. The public names of
              ``nn.initializers`` and ``scaled_orthogonal`` are in scope, e.g.
              ``"lecun_normal()"`` or ``"scaled_orthogonal(mode='fan_in')"``.
              Write the call: a bare factory name such as ``"lecun_normal"``
              evaluates to the factory itself, not to an initializer.

    Returns:
        The resolved callable.

    Raises:
        ValueError: If ``name`` is neither None, callable nor a string, or if
            the string does not evaluate to a callable.
    """
    if name is None:
        return nn.initializers.lecun_normal()
    if callable(name):
        return name
    if not isinstance(name, str):
        raise ValueError(f"Invalid initializer {name}")

    # Drop dunder entries, in particular the module's own ``__builtins__``.
    # Because locals are searched first, that entry would otherwise make the
    # real builtins reachable by name and defeat ``{"__builtins__": None}``.
    namespace = {
        key: value
        for key, value in nn.initializers.__dict__.items()
        if not key.startswith("__")
    }
    namespace["scaled_orthogonal"] = scaled_orthogonal

    # SECURITY: eval with builtins disabled is NOT a sandbox. Attribute-walking
    # expressions can still reach arbitrary objects, so only pass strings that
    # come from trusted configuration.
    initializer = eval(name, {"__builtins__": None}, namespace)
    if not callable(initializer):
        raise ValueError(
            f"Invalid initializer {name}: expression did not evaluate to a callable"
        )
    return initializer