fennol.utils.activations

import jax
import jax.numpy as jnp
from functools import partial
import math
import flax.linen as nn
from typing import Union, Callable


class TrainableSiLU(nn.Module):
    """SiLU with per-feature trainable amplitude (alpha) and gate slope (beta).

    beta is initialized to 1.702, so the initial activation approximates a GELU.
    """

    @nn.compact
    def __call__(self, x):
        a = self.param("alpha", lambda k, s: jnp.ones(s), (1, x.shape[-1]))
        b = self.param("beta", lambda k, s: 1.702 * jnp.ones(s), (1, x.shape[-1]))
        shape = x.shape
        x = x.reshape(-1, shape[-1])
        return (a * x * jax.nn.sigmoid(b * x)).reshape(shape)


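The trainable activations are ordinary flax.linen modules and follow the usual init/apply workflow. A minimal usage sketch (the shapes and PRNG key below are arbitrary, chosen only for illustration):

import jax
import jax.numpy as jnp
from fennol.utils.activations import TrainableSiLU

x = jnp.ones((4, 16))                          # (batch, features)
act = TrainableSiLU()
params = act.init(jax.random.PRNGKey(0), x)    # creates "alpha" and "beta" of shape (1, 16)
y = act.apply(params, x)                       # same shape as x, per-feature trainable SiLU
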
class TrainableCELU(nn.Module):
    """CELU with a per-feature trainable alpha.

    The learned offset is passed through 1 + celu(.), so the effective alpha
    stays positive; the `alpha` attribute sets its initial value and scale.
    """

    alpha: float = 0.1

    @nn.compact
    def __call__(self, x):
        a = self.alpha * (
            1 + jax.nn.celu(
                self.param("alpha", lambda k, s: jnp.zeros(s), (1, x.shape[-1])),
                alpha=1.0,
            )
        )
        shape = x.shape
        x = x.reshape(-1, shape[-1])
        return jax.nn.celu(x, a).reshape(shape)


class TrainableLeakyCELU(nn.Module):
    """leaky_celu with per-feature trainable parameters.

    The leak `alpha` is learned as an additive offset around its initial value,
    while `beta` is passed through 1 + celu(.) so it stays positive.
    """

    alpha: float = 0.05
    beta: float = 1.0

    @nn.compact
    def __call__(self, x):
        a = self.alpha + self.param("alpha", lambda k, s: jnp.zeros(s), (1, x.shape[-1]))
        b = self.beta * (
            1 + jax.nn.celu(
                self.param("beta", lambda k, s: jnp.zeros(s), (1, x.shape[-1])),
                alpha=1.0,
            )
        )
        shape = x.shape
        x = x.reshape(-1, shape[-1])
        return leaky_celu(x, a, b).reshape(shape)


@jax.jit
def aptx(x):
    # x * (1 + tanh(x)) == 2*x*sigmoid(2*x), i.e. silu(2*x)
    return (1.0 + jax.nn.tanh(x)) * x


@partial(jax.jit, static_argnums=(1,))
def leaky_celu(x, alpha=0.1, beta=1.0):
    # smooth "leaky" unit: ~alpha*x for large negative x, ~x for large positive x,
    # with leaky_celu(0) = 0; beta controls the sharpness of the transition.
    # alpha is a static argument, so it must be a hashable Python value when
    # this jitted function is called directly.
    return alpha * x + ((1.0 - alpha) / beta) * (
        jax.nn.softplus(beta * x) - math.log(2.0)
    )


@jax.jit
def tssr(x):
    # signed square-root with a linear core: identity for |x| <= 1,
    # sign(x) * (2*sqrt(|x|) - 1) for |x| > 1 (value and slope match at |x| = 1)
    ax = jnp.abs(x)
    mask = ax <= 1.0
    axx = jnp.where(mask, 1.0, ax)
    return jnp.where(mask, x, jnp.sign(x) * (2 * axx**0.5 - 1))


@jax.jit
def tssr2(x):
    # sqrt(|x|) tail with a cubic core (1.25*|x| - 0.25*|x|**3) that matches
    # sqrt(|x|) in value and slope at |x| = 1
    ax = jnp.abs(x)
    mask = ax <= 1.0
    axx = jnp.where(mask, 1.0, ax)
    return jnp.sign(x) * jnp.where(mask, 1.25 * ax - 0.25 * ax**3, axx**0.5)


@jax.jit
def tssr3(x):
    # same sqrt(|x|) tail as tssr2 with a quartic polynomial core that also
    # matches sqrt(|x|) in value and slope at |x| = 1
    ax = jnp.abs(x)
    mask = ax <= 1.0
    axx = jnp.where(mask, 1.0, ax)
    ax2 = ax * ax
    dax2 = ax - ax2
    poly = 2.1875 * dax2 + ax2 * (ax + 0.3125 * dax2)
    return jnp.sign(x) * jnp.where(mask, poly, axx**0.5)


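All three tssr variants switch from a polynomial (or identity) core to a sqrt(|x|)-type tail at |x| = 1, matching value and slope at the crossover. A quick sanity check of that matching (the test points are arbitrary):

import jax.numpy as jnp
from fennol.utils.activations import tssr, tssr2, tssr3

for name, f in [("tssr", tssr), ("tssr2", tssr2), ("tssr3", tssr3)]:
    below = float(f(jnp.asarray(1.0 - 1e-4)))
    above = float(f(jnp.asarray(1.0 + 1e-4)))
    print(name, below, above)  # both values are close to 1.0 for every variant
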
@jax.jit
def pow(x, a):
    return x**a


@jax.jit
def ssp(x):
    # shifted softplus: log(0.5*exp(x) + 0.5) = softplus(x) - log(2), so ssp(0) = 0
    return jnp.logaddexp(x + math.log(0.5), math.log(0.5))


@jax.jit
def smooth_floor(x, eps=0.99):
    # smooth, differentiable approximation of floor(x); eps < 1 controls the
    # sharpness of the steps (eps -> 1 approaches the exact floor)
    return (
        x
        - 0.5
        - jnp.atan(
            -eps * jnp.sin((-2 * jnp.pi) * x) / (eps * jnp.cos((2 * jnp.pi) * x) - 1.0)
        )
        / jnp.pi
    )


@jax.jit
def smooth_round(x, eps=0.99):
    # smooth, differentiable approximation of round(x): the same construction
    # as smooth_floor, shifted by one half
    return (
        x
        - jnp.atan(
            -eps
            * jnp.sin(-2 * jnp.pi * (x - 0.5))
            / (eps * jnp.cos(2 * jnp.pi * (x - 0.5)) - 1.0)
        )
        / jnp.pi
    )


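smooth_floor and smooth_round are differentiable surrogates for floor and round; eps controls how closely they approach the exact step. A small illustration (outputs are approximate, quoted here to about three decimals):

import jax.numpy as jnp
from fennol.utils.activations import smooth_floor, smooth_round

print(float(smooth_floor(jnp.asarray(2.7))))   # ~2.0 (floor(2.7) = 2)
print(float(smooth_round(jnp.asarray(2.3))))   # ~2.0 (round(2.3) = 2)
print(float(smooth_round(jnp.asarray(2.6))))   # ~3.0 (round(2.6) = 3)
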
def chain(*activations):
    # compose several activations; the first argument is applied first
    # @jax.jit
    def act(x):
        for a in activations:
            x = a(x)
        return x

    return act


def activation_from_str(activation: Union[str, Callable, None]) -> Callable:
    # resolve an activation specification (None, a callable or a string) into a callable
    if activation is None:
        return lambda x: x
    if callable(activation):
        return activation
    if not isinstance(activation, str):
        raise ValueError(f"Invalid activation {activation}")
    if activation.lower() in ["none", "linear", "identity"]:
        return lambda x: x
    try:
        # evaluate the string in a restricted namespace exposing jax.nn,
        # jax.numpy and the helpers defined in this module
        return eval(
            activation,
            {"__builtins__": None},
            {
                **jax.nn.__dict__,
                **jax.numpy.__dict__,
                **jax.__dict__,
                "chain": chain,
                "pow": pow,
                "partial": partial,
                "leaky_celu": leaky_celu,
                "aptx": aptx,
                "tssr": tssr,
                "tssr2": tssr2,
                "tssr3": tssr3,
                "ssp": ssp,
                "smooth_floor": smooth_floor,
                "smooth_round": smooth_round,
                "TrainableSiLU": TrainableSiLU,
                "TrainableCELU": TrainableCELU,
            },
        )
    except Exception as e:
        raise ValueError(
            f"The following exception was raised while parsing the activation function {activation}: {e}"
        )
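activation_from_str resolves a configuration string against jax.nn, jax.numpy and the helpers defined above, so plain names, parameterized expressions and compositions are all accepted. A usage sketch (the example strings are illustrative, not an exhaustive list):

import jax.numpy as jnp
from fennol.utils.activations import activation_from_str

silu = activation_from_str("silu")                            # resolved from jax.nn
shifted = activation_from_str("ssp")                          # helper defined in this module
custom = activation_from_str("partial(celu, alpha=0.2)")      # parameterized jax.nn.celu
composed = activation_from_str("chain(ssp, tanh)")            # applied left to right

x = jnp.linspace(-2.0, 2.0, 5)
y = composed(x)                                               # equivalent to tanh(ssp(x))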