fennol.utils.activations
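Activation utilities for FeNNol: jit-compiled JAX activation functions (aptx, leaky_celu, tssr, tssr2, tssr3, pow, ssp, smooth_floor, smooth_round), trainable Flax modules (TrainableSiLU, TrainableCELU, TrainableLeakyCELU), and a helper, activation_from_str, that resolves an activation from a string, a callable, or None.

The plain functions are ordinary jitted JAX functions and can be called directly on arrays. A minimal sketch (input values chosen only for illustration):

import jax.numpy as jnp
from fennol.utils.activations import tssr, tssr2, ssp, smooth_round

x = jnp.array([-4.0, -0.5, 0.0, 0.5, 4.0])

print(tssr(x))          # identity inside [-1, 1]; tssr(4.0) = 2*sqrt(4) - 1 = 3.0
print(tssr2(x))         # smoother variant: cubic inside [-1, 1], sign(x)*sqrt(|x|) outside
print(ssp(x))           # shifted softplus, softplus(x) - log(2), so ssp(0) = 0
print(smooth_round(x))  # differentiable approximation of rounding (eps=0.99)

Module source: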
import jax
import jax.numpy as jnp
from functools import partial
import math
import flax.linen as nn
from typing import Union, Callable


class TrainableSiLU(nn.Module):
    @nn.compact
    def __call__(self, x):
        a = self.param("alpha", lambda k, s: jnp.ones(s), (1, x.shape[-1]))
        b = self.param("beta", lambda k, s: 1.702 * jnp.ones(s), (1, x.shape[-1]))
        shape = x.shape
        x = x.reshape(-1, shape[-1])
        return (a * x * jax.nn.sigmoid(b * x)).reshape(shape)


class TrainableCELU(nn.Module):
    alpha: float = 0.1

    @nn.compact
    def __call__(self, x):
        a = self.alpha * (
            1 + jax.nn.celu(self.param("alpha", lambda k, s: jnp.zeros(s), (1, x.shape[-1])), alpha=1.0)
        )
        shape = x.shape
        x = x.reshape(-1, shape[-1])
        return jax.nn.celu(x, a).reshape(shape)


class TrainableLeakyCELU(nn.Module):
    alpha: float = 0.05
    beta: float = 1.0

    @nn.compact
    def __call__(self, x):
        a = self.alpha + self.param("alpha", lambda k, s: jnp.zeros(s), (1, x.shape[-1]))
        b = self.beta * (
            1 + jax.nn.celu(self.param("beta", lambda k, s: jnp.zeros(s), (1, x.shape[-1])), alpha=1.0)
        )
        shape = x.shape
        x = x.reshape(-1, shape[-1])
        return leaky_celu(x, a, b).reshape(shape)


@jax.jit
def aptx(x):
    # smooth SiLU-like activation: (1 + tanh(x)) * x
    return (1.0 + jax.nn.tanh(x)) * x


@partial(jax.jit, static_argnums=(1,))
def leaky_celu(x, alpha=0.1, beta=1.0):
    # softplus-based smoothing of leaky ReLU: slope ~alpha for x -> -inf,
    # slope ~1 for x -> +inf, and exactly 0 at the origin
    return alpha * x + ((1.0 - alpha) / beta) * (
        jax.nn.softplus(beta * x) - math.log(2.0)
    )


@jax.jit
def tssr(x):
    # identity for |x| <= 1, square-root growth sign(x) * (2*sqrt(|x|) - 1) beyond
    # (continuous with matching slope at |x| = 1)
    ax = jnp.abs(x)
    mask = ax <= 1.0
    axx = jnp.where(mask, 1.0, ax)
    return jnp.where(mask, x, jnp.sign(x) * (2 * axx**0.5 - 1))


@jax.jit
def tssr2(x):
    # cubic inner branch joining sign(x)*sqrt(|x|) with matching value and slope at |x| = 1
    ax = jnp.abs(x)
    mask = ax <= 1.0
    axx = jnp.where(mask, 1.0, ax)
    return jnp.sign(x) * jnp.where(mask, 1.25 * ax - 0.25 * ax**3, axx**0.5)


@jax.jit
def tssr3(x):
    # quartic inner branch joining sign(x)*sqrt(|x|) with matching value, slope
    # and curvature at |x| = 1
    ax = jnp.abs(x)
    mask = ax <= 1.0
    axx = jnp.where(mask, 1.0, ax)
    ax2 = ax * ax
    dax2 = ax - ax2
    poly = 2.1875 * dax2 + ax2 * (ax + 0.3125 * dax2)
    return jnp.sign(x) * jnp.where(mask, poly, axx**0.5)


@jax.jit
def pow(x, a):
    # elementwise power
    return x**a


@jax.jit
def ssp(x):
    # shifted softplus: log(0.5*exp(x) + 0.5) = softplus(x) - log(2), so ssp(0) = 0
    return jnp.logaddexp(x + math.log(0.5), math.log(0.5))


@jax.jit
def smooth_floor(x, eps=0.99):
    # smooth, differentiable approximation of floor; approaches the exact floor as eps -> 1
    return (
        x
        - 0.5
        - jnp.atan(
            -eps * jnp.sin((-2 * jnp.pi) * x) / (eps * jnp.cos((2 * jnp.pi) * x) - 1.0)
        )
        / jnp.pi
    )


@jax.jit
def smooth_round(x, eps=0.99):
    # smooth, differentiable approximation of round; approaches exact rounding as eps -> 1
    return (
        x
        - jnp.atan(
            -eps
            * jnp.sin(-2 * jnp.pi * (x - 0.5))
            / (eps * jnp.cos(2 * jnp.pi * (x - 0.5)) - 1.0)
        )
        / jnp.pi
    )


def chain(*activations):
    # compose several activations, applied left to right
    # @jax.jit
    def act(x):
        for a in activations:
            x = a(x)
        return x

    return act


def activation_from_str(activation: Union[str, Callable, None]) -> Callable:
    if activation is None:
        return lambda x: x
    if callable(activation):
        return activation
    if not isinstance(activation, str):
        raise ValueError(f"Invalid activation {activation}")
    if activation.lower() in ["none", "linear", "identity"]:
        return lambda x: x
    try:
        return eval(
            activation,
            {"__builtins__": None},
            {
                **jax.nn.__dict__,
                **jax.numpy.__dict__,
                **jax.__dict__,
                "chain": chain,
                "pow": pow,
                "partial": partial,
                "leaky_celu": leaky_celu,
                "aptx": aptx,
                "tssr": tssr,
                "tssr2": tssr2,
                "tssr3": tssr3,
                "ssp": ssp,
                "smooth_floor": smooth_floor,
                "smooth_round": smooth_round,
                "TrainableSiLU": TrainableSiLU,
                "TrainableCELU": TrainableCELU,
            },
        )
    except Exception as e:
        raise ValueError(
            f"The following exception was raised while parsing the activation function {activation} : {e}"
        )
class TrainableSiLU(nn.Module):
    @nn.compact
    def __call__(self, x):
        a = self.param("alpha", lambda k, s: jnp.ones(s), (1, x.shape[-1]))
        b = self.param("beta", lambda k, s: 1.702 * jnp.ones(s), (1, x.shape[-1]))
        shape = x.shape
        x = x.reshape(-1, shape[-1])
        return (a * x * jax.nn.sigmoid(b * x)).reshape(shape)
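TrainableSiLU is a SiLU/Swish variant with a learnable per-channel output scale ("alpha", initialized to 1) and sigmoid slope ("beta", initialized to 1.702, so the initial function is the usual sigmoid approximation of GELU). A minimal usage sketch with Flax; the shapes and PRNG key are illustrative:

import jax
import jax.numpy as jnp
from fennol.utils.activations import TrainableSiLU

act = TrainableSiLU()
x = jnp.ones((8, 32))                        # batch of 8 vectors with 32 channels
params = act.init(jax.random.PRNGKey(0), x)  # creates "alpha" and "beta" of shape (1, 32)
y = act.apply(params, x)                     # same shape as x; initially x * sigmoid(1.702 * x)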
class TrainableCELU(nn.Module):
    alpha: float = 0.1

    @nn.compact
    def __call__(self, x):
        a = self.alpha * (
            1 + jax.nn.celu(self.param("alpha", lambda k, s: jnp.zeros(s), (1, x.shape[-1])), alpha=1.0)
        )
        shape = x.shape
        x = x.reshape(-1, shape[-1])
        return jax.nn.celu(x, a).reshape(shape)
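TrainableCELU applies CELU with a learnable per-channel alpha: the stored parameter (zero-initialized) is mapped through alpha * (1 + celu(p, 1.0)), which keeps the effective alpha strictly positive and equal to the configured value at initialization. A short sketch checking that behaviour (shapes and key are illustrative):

import jax
import jax.numpy as jnp
from fennol.utils.activations import TrainableCELU

act = TrainableCELU(alpha=0.2)
x = jnp.linspace(-3.0, 3.0, 16).reshape(1, 16)
params = act.init(jax.random.PRNGKey(0), x)  # per-channel "alpha" parameter, initialized to zero

# With zero-initialized parameters the effective alpha is exactly the configured
# one, so the module coincides with jax.nn.celu(x, 0.2) at initialization.
y = act.apply(params, x)
assert jnp.allclose(y, jax.nn.celu(x, 0.2))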
class TrainableLeakyCELU(nn.Module):
    alpha: float = 0.05
    beta: float = 1.0

    @nn.compact
    def __call__(self, x):
        a = self.alpha + self.param("alpha", lambda k, s: jnp.zeros(s), (1, x.shape[-1]))
        b = self.beta * (
            1 + jax.nn.celu(self.param("beta", lambda k, s: jnp.zeros(s), (1, x.shape[-1])), alpha=1.0)
        )
        shape = x.shape
        x = x.reshape(-1, shape[-1])
        return leaky_celu(x, a, b).reshape(shape)
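TrainableLeakyCELU wraps the module-level leaky_celu function, learning a per-channel additive correction to alpha and a strictly positive multiplicative correction to beta; both corrections are zero-initialized, so the module starts out as leaky_celu with the configured constants. The underlying function can also be used directly with fixed constants; a short sketch of its limiting behaviour (values are illustrative):

import jax.numpy as jnp
from fennol.utils.activations import leaky_celu

x = jnp.array([-50.0, 0.0, 50.0])

# leaky_celu is a softplus-based smoothing of leaky ReLU: slope ~alpha far into
# the negative branch, slope ~1 far into the positive branch, and exactly 0 at 0.
# Note: alpha is a static argument of the jitted function, so pass a Python float here.
y = leaky_celu(x, 0.1)   # alpha=0.1, beta defaults to 1.0
print(y)                 # approximately [-5.0 - 0.9*log(2), 0.0, 50.0 - 0.9*log(2)]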
def activation_from_str(activation: Union[str, Callable, None]) -> Callable:
    if activation is None:
        return lambda x: x
    if callable(activation):
        return activation
    if not isinstance(activation, str):
        raise ValueError(f"Invalid activation {activation}")
    if activation.lower() in ["none", "linear", "identity"]:
        return lambda x: x
    try:
        return eval(
            activation,
            {"__builtins__": None},
            {
                **jax.nn.__dict__,
                **jax.numpy.__dict__,
                **jax.__dict__,
                "chain": chain,
                "pow": pow,
                "partial": partial,
                "leaky_celu": leaky_celu,
                "aptx": aptx,
                "tssr": tssr,
                "tssr2": tssr2,
                "tssr3": tssr3,
                "ssp": ssp,
                "smooth_floor": smooth_floor,
                "smooth_round": smooth_round,
                "TrainableSiLU": TrainableSiLU,
                "TrainableCELU": TrainableCELU,
            },
        )
    except Exception as e:
        raise ValueError(
            f"The following exception was raised while parsing the activation function {activation} : {e}"
        )
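activation_from_str accepts None, an existing callable, or a string; strings are evaluated with builtins disabled in a namespace that exposes jax, jax.nn, jax.numpy and the activations defined in this module, so plain names, partial applications and chained compositions all work. A few illustrative calls:

import jax.numpy as jnp
from fennol.utils.activations import activation_from_str

x = jnp.linspace(-2.0, 2.0, 5)

act = activation_from_str("silu")                      # plain name from jax.nn
y = act(x)

act = activation_from_str("partial(celu, alpha=0.5)")  # parameterized via functools.partial
y = act(x)

act = activation_from_str("chain(ssp, tssr2)")         # composition: ssp first, then tssr2
y = act(x)

act = activation_from_str(None)                        # None, "none", "linear", "identity" -> identity
assert jnp.allclose(act(x), x)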