fennol.utils.activations
```python
import jax
import jax.numpy as jnp
from functools import partial
import math
import flax.linen as nn
import numpy as np
from typing import Union, Callable


class TrainableSiLU(nn.Module):
    @nn.compact
    def __call__(self, x):
        a = self.param("alpha", lambda k, s: jnp.ones(s), (1, x.shape[-1]))
        b = self.param("beta", lambda k, s: 1.702 * jnp.ones(s), (1, x.shape[-1]))
        shape = x.shape
        x = x.reshape(-1, shape[-1])
        return (a * x * jax.nn.sigmoid(b * x)).reshape(shape)


class DyT(nn.Module):
    alpha0: float = 0.5

    @nn.compact
    def __call__(self, x):
        shape = x.shape
        dim = shape[-1]
        a = self.param("scale", lambda k, s: jnp.ones(s, dtype=x.dtype), (1, dim))
        b = self.param("bias", lambda k, s: jnp.zeros(s, dtype=x.dtype), (1, dim))
        alpha = self.param("alpha", lambda k: jnp.asarray(self.alpha0, dtype=x.dtype))
        x = x.reshape(-1, dim)
        return (a * jnp.tanh(alpha * x) + b).reshape(shape)


class DynAct(nn.Module):
    activation: Union[str, Callable]
    alpha0: float = 1.0
    channelwise: bool = False

    @nn.compact
    def __call__(self, x):
        shape = x.shape
        dim = shape[-1]
        a = self.param("scale", lambda k, s: jnp.ones(s, dtype=x.dtype), (1, dim))
        b = self.param("bias", lambda k, s: jnp.zeros(s, dtype=x.dtype), (1, dim))

        shape_alpha = (1, dim) if self.channelwise else (1, 1)
        alpha = self.param(
            "alpha", lambda k, s: self.alpha0 * jnp.ones(s, dtype=x.dtype), shape_alpha
        )

        x = x.reshape(-1, dim)
        act = activation_from_str(self.activation)
        return (a * act(alpha * x) + b).reshape(shape)


class TrainableCELU(nn.Module):
    alpha: float = 0.1

    @nn.compact
    def __call__(self, x):
        a = self.alpha * (
            1
            + jax.nn.celu(
                self.param("alpha", lambda k, s: jnp.zeros(s), (1, x.shape[-1])),
                alpha=1.0,
            )
        )
        shape = x.shape
        x = x.reshape(-1, shape[-1])
        return jax.nn.celu(x, a).reshape(shape)


class TrainableLeakyCELU(nn.Module):
    alpha: float = 0.05
    beta: float = 1.0

    @nn.compact
    def __call__(self, x):
        a = self.alpha + self.param(
            "alpha", lambda k, s: jnp.zeros(s), (1, x.shape[-1])
        )
        b = self.beta * (
            1
            + jax.nn.celu(
                self.param("beta", lambda k, s: jnp.zeros(s), (1, x.shape[-1])),
                alpha=1.0,
            )
        )
        shape = x.shape
        x = x.reshape(-1, shape[-1])
        return leaky_celu(x, a, b).reshape(shape)


class FourierActivation(nn.Module):
    nmax: int = 4

    @nn.compact
    def __call__(self, x):
        out = jax.nn.swish(x)
        nfeatures = x.shape[-1]
        n = 2 * np.pi * np.arange(1, self.nmax + 1)
        shape = [1] * x.ndim + [self.nmax]
        n = n.reshape(shape)
        x = jnp.expand_dims(x, axis=-1)
        cx = jnp.cos(n * x)
        sx = jnp.sin(n * x)
        x = jnp.concatenate((cx, sx), axis=-1)

        shape = [1] * (out.ndim - 1) + [nfeatures]
        b = self.param("b", jax.nn.initializers.zeros, shape)

        w = self.param(
            "w", jax.nn.initializers.normal(stddev=0.1), (*shape, 2 * self.nmax)
        )

        return out + (x * w).sum(axis=-1) + b


class GausianBasisActivation(nn.Module):
    nbasis: int = 10
    xmin: float = -1.0
    xmax: float = 3.0

    @nn.compact
    def __call__(self, x):
        nfeatures = x.shape[-1]
        shape = [1] * (x.ndim - 1) + [nfeatures]
        b = self.param("b", jax.nn.initializers.zeros, shape)
        w0 = self.param("w0", jax.nn.initializers.ones, shape)
        out = jax.nn.swish(w0 * x)

        shape += [self.nbasis]
        x0 = self.param(
            "x0",
            lambda key: jnp.asarray(
                np.linspace(self.xmin, self.xmax, self.nbasis)[None, :]
                .repeat(nfeatures, axis=0)
                .reshape(shape),
                dtype=x.dtype,
            ),
        )

        sigma0 = np.abs(self.xmax - self.xmin) / self.nbasis
        alphas = (0.5**0.5) * self.param(
            "sigmas",
            lambda key: jnp.full(shape, 1.0 / sigma0, dtype=x.dtype),
        )

        w = self.param("w", jax.nn.initializers.normal(stddev=0.1), shape)

        ex = jnp.exp(-((x[..., None] - x0) * alphas) ** 2)

        return out + (ex * w).sum(axis=-1) + b


def safe_sqrt(x, eps=1.0e-5):
    return jnp.sqrt(jnp.clip(x, min=eps))


@jax.jit
def aptx(x):
    return (1.0 + jax.nn.tanh(x)) * x


@jax.jit
def serf(x):
    return x * jax.scipy.special.erf(jax.nn.softplus(x))


@partial(jax.jit, static_argnums=(1,))
def leaky_celu(x, alpha=0.1, beta=1.0):
    return alpha * x + ((1.0 - alpha) / beta) * (
        jax.nn.softplus(beta * x) - math.log(2.0)
    )


@jax.jit
def tssr(x):
    ax = jnp.abs(x)
    mask = ax <= 1.0
    axx = jnp.where(mask, 1.0, ax)
    return jnp.where(mask, x, jnp.sign(x) * (2 * axx**0.5 - 1))


@jax.jit
def tssr2(x):
    ax = jnp.abs(x)
    mask = ax <= 1.0
    axx = jnp.where(mask, 1.0, ax)
    return jnp.sign(x) * jnp.where(mask, 1.25 * ax - 0.25 * ax**3, axx**0.5)


@jax.jit
def tssr3(x):
    ax = jnp.abs(x)
    mask = ax <= 1.0
    axx = jnp.where(mask, 1.0, ax)
    ax2 = ax * ax
    dax2 = ax - ax2
    poly = 2.1875 * dax2 + ax2 * (ax + 0.3125 * dax2)
    return jnp.sign(x) * jnp.where(mask, poly, axx**0.5)


@jax.jit
def pow(x, a):
    return x**a


@jax.jit
def ssp(x):
    return jnp.logaddexp(x + math.log(0.5), math.log(0.5))


@jax.jit
def smooth_floor(x, eps=0.99):
    return (
        x
        - 0.5
        - jnp.atan(
            -eps * jnp.sin((-2 * jnp.pi) * x) / (eps * jnp.cos((2 * jnp.pi) * x) - 1.0)
        )
        / jnp.pi
    )


@jax.jit
def smooth_round(x, eps=0.99):
    return (
        x
        - jnp.atan(
            -eps
            * jnp.sin(-2 * jnp.pi * (x - 0.5))
            / (eps * jnp.cos(2 * jnp.pi * (x - 0.5)) - 1.0)
        )
        / jnp.pi
    )


def chain(*activations):
    # @jax.jit
    def act(x):
        for a in activations:
            x = a(x)
        return x

    return act


def normalize_activation(
    phi: Callable[[float], float], return_scale=False
) -> Callable[[float], float]:
    r"""Normalize a function, :math:`\psi(x)=\phi(x)/c` where :math:`c` is the normalization constant such that

    .. math::

        \int_{-\infty}^{\infty} \psi(x)^2 \frac{e^{-x^2/2}}{\sqrt{2\pi}} dx = 1

    ! Adapted from e3nn_jax !
    """
    with jax.ensure_compile_time_eval():
        # k = jax.random.PRNGKey(0)
        # x = jax.random.normal(k, (1_000_000,))
        n = 1_000_001
        x = jnp.sqrt(2) * jax.scipy.special.erfinv(jnp.linspace(-1.0, 1.0, n + 2)[1:-1])
        c = jnp.mean(phi(x) ** 2) ** 0.5
        c = c.item()

    if jnp.allclose(c, 1.0):
        rho = phi
    else:

        def rho(x):
            return phi(x) / c

    if return_scale:
        return rho, 1.0 / c
    return rho


def activation_from_str(activation: Union[str, Callable, None]) -> Callable:
    if activation is None:
        return lambda x: x
    if callable(activation):
        return activation
    if not isinstance(activation, str):
        raise ValueError(f"Invalid activation {activation}")
    if activation.lower() in ["none", "linear", "identity"]:
        return lambda x: x
    try:
        return eval(
            activation,
            {"__builtins__": None},
            {
                **jax.nn.__dict__,
                **jax.numpy.__dict__,
                **jax.__dict__,
                "chain": chain,
                "pow": pow,
                "partial": partial,
                "leaky_celu": leaky_celu,
                "aptx": aptx,
                "tssr": tssr,
                "tssr2": tssr2,
                "tssr3": tssr3,
                "ssp": ssp,
                "smooth_floor": smooth_floor,
                "smooth_round": smooth_round,
                "TrainableSiLU": TrainableSiLU,
                "TrainableCELU": TrainableCELU,
                "FourierActivation": FourierActivation,
                "GaussianBasis": GausianBasisActivation,
                "normalize": normalize_activation,
                "DyT": DyT,
                "DynAct": DynAct,
            },
        )
    except Exception as e:
        raise ValueError(
            f"The following exception was raised while parsing the activation function {activation} : {e}"
        )
```
`normalize_activation` rescales a function, \( \psi(x)=\phi(x)/c \), where \( c \) is the normalization constant such that

$$\int_{-\infty}^{\infty} \psi(x)^2 \, \frac{e^{-x^2/2}}{\sqrt{2\pi}} \, dx = 1.$$

(Adapted from e3nn_jax.)
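A minimal check sketch; the sample size is illustrative:

```python
import jax
import jax.numpy as jnp
from fennol.utils.activations import normalize_activation

# Rescale silu so that E[rho(x)^2] = 1 for x ~ N(0, 1).
rho, scale = normalize_activation(jax.nn.silu, return_scale=True)

x = jax.random.normal(jax.random.PRNGKey(0), (100_000,))
print(jnp.mean(rho(x) ** 2))   # close to 1.0
print(scale)                   # the constant 1/c applied to silu
```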
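Finally, `activation_from_str` resolves a string by evaluating it against `jax.nn`, `jax.numpy` and the helpers defined in this module, so both plain names and small expressions work; `None`, `"none"`, `"linear"` and `"identity"` all map to the identity. The specific strings below are only examples of what that namespace supports:

```python
import jax
import jax.numpy as jnp
from fennol.utils.activations import activation_from_str

x = jnp.linspace(-2.0, 2.0, 5)

# A plain name resolves to the corresponding jax.nn function.
silu = activation_from_str("silu")
assert jnp.allclose(silu(x), jax.nn.silu(x))

# Small expressions are evaluated in the same namespace, e.g. composing
# helpers from this module with chain (activations applied left to right).
act = activation_from_str("chain(ssp, tanh)")
print(act(x))

ident = activation_from_str(None)
assert jnp.allclose(ident(x), x)
```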