fennol.utils.activations

import jax
import jax.numpy as jnp
from functools import partial
import math
import flax.linen as nn
import numpy as np
from typing import Union, Callable


# SiLU / Swish with trainable per-feature parameters: alpha scales the output and
# beta controls the sigmoid sharpness (initialized to 1.702, so the initial function
# is close to the sigmoid approximation of GELU).
class TrainableSiLU(nn.Module):
    @nn.compact
    def __call__(self, x):
        a = self.param("alpha", lambda k, s: jnp.ones(s), (1, x.shape[-1]))
        b = self.param("beta", lambda k, s: 1.702 * jnp.ones(s), (1, x.shape[-1]))
        shape = x.shape
        x = x.reshape(-1, shape[-1])
        return (a * x * jax.nn.sigmoid(b * x)).reshape(shape)

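Usage sketch (illustrative, not part of the module source): the classes in this file are ordinary flax.linen modules, so TrainableSiLU and the activations below all follow the usual init/apply pattern. The key, input shape and variable names here are arbitrary.

key = jax.random.PRNGKey(0)
x = jnp.ones((8, 16))
act = TrainableSiLU()
params = act.init(key, x)   # creates "alpha" and "beta" parameters of shape (1, 16)
y = act.apply(params, x)    # same shape as x
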
# Dynamic Tanh (DyT): scale * tanh(alpha * x) + bias, with trainable per-feature
# scale/bias and a single scalar alpha initialized to alpha0.
class DyT(nn.Module):
    alpha0: float = 0.5

    @nn.compact
    def __call__(self, x):
        shape = x.shape
        dim = shape[-1]
        a = self.param("scale", lambda k, s: jnp.ones(s, dtype=x.dtype), (1, dim))
        b = self.param("bias", lambda k, s: jnp.zeros(s, dtype=x.dtype), (1, dim))
        alpha = self.param("alpha", lambda k: jnp.asarray(self.alpha0, dtype=x.dtype))
        x = x.reshape(-1, dim)
        return (a * jnp.tanh(alpha * x) + b).reshape(shape)

# Generalization of DyT to an arbitrary activation: scale * act(alpha * x) + bias.
# alpha is a single scalar, or one value per feature if channelwise=True.
class DynAct(nn.Module):
    activation: Union[str, Callable]
    alpha0: float = 1.0
    channelwise: bool = False

    @nn.compact
    def __call__(self, x):
        shape = x.shape
        dim = shape[-1]
        a = self.param("scale", lambda k, s: jnp.ones(s, dtype=x.dtype), (1, dim))
        b = self.param("bias", lambda k, s: jnp.zeros(s, dtype=x.dtype), (1, dim))

        shape_alpha = (1, dim) if self.channelwise else (1, 1)
        alpha = self.param("alpha", lambda k, s: self.alpha0 * jnp.ones(s, dtype=x.dtype), shape_alpha)

        x = x.reshape(-1, dim)
        act = activation_from_str(self.activation)
        return (a * act(alpha * x) + b).reshape(shape)

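A hedged example of DynAct wrapping an activation given by name (resolved through activation_from_str, defined at the end of this file); shapes and values are placeholders.

act = DynAct("silu", alpha0=1.0, channelwise=True)
params = act.init(jax.random.PRNGKey(1), jnp.ones((4, 32)))
y = act.apply(params, jnp.ones((4, 32)))   # "alpha" has shape (1, 32) because channelwise=True
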
# CELU with a trainable per-feature alpha, reparameterized as alpha * (1 + celu(p))
# so that it stays positive and equals alpha at initialization (p = 0).
class TrainableCELU(nn.Module):
    alpha: float = 0.1

    @nn.compact
    def __call__(self, x):
        a = self.alpha * (
            1
            + jax.nn.celu(
                self.param("alpha", lambda k, s: jnp.zeros(s), (1, x.shape[-1])),
                alpha=1.0,
            )
        )
        shape = x.shape
        x = x.reshape(-1, shape[-1])
        return jax.nn.celu(x, a).reshape(shape)


# leaky_celu (defined below) with a trainable per-feature negative-side slope (alpha)
# and transition sharpness (beta); beta is kept positive via the same
# 1 + celu reparameterization used in TrainableCELU.
class TrainableLeakyCELU(nn.Module):
    alpha: float = 0.05
    beta: float = 1.0

    @nn.compact
    def __call__(self, x):
        a = self.alpha + self.param("alpha", lambda k, s: jnp.zeros(s), (1, x.shape[-1]))
        b = self.beta * (
            1
            + jax.nn.celu(
                self.param("beta", lambda k, s: jnp.zeros(s), (1, x.shape[-1])),
                alpha=1.0,
            )
        )
        shape = x.shape
        x = x.reshape(-1, shape[-1])
        return leaky_celu(x, a, b).reshape(shape)


# Swish plus a learnable correction expanded on a truncated Fourier basis:
# nmax cosine and nmax sine harmonics per feature, with weights w and bias b.
class FourierActivation(nn.Module):
    nmax: int = 4

    @nn.compact
    def __call__(self, x):
        out = jax.nn.swish(x)
        nfeatures = x.shape[-1]
        n = 2 * np.pi * np.arange(1, self.nmax + 1)
        shape = [1] * x.ndim + [self.nmax]
        n = n.reshape(shape)
        x = jnp.expand_dims(x, axis=-1)
        cx = jnp.cos(n * x)
        sx = jnp.sin(n * x)
        x = jnp.concatenate((cx, sx), axis=-1)

        shape = [1] * (out.ndim - 1) + [nfeatures]
        b = self.param("b", jax.nn.initializers.zeros, shape)

        w = self.param(
            "w", jax.nn.initializers.normal(stddev=0.1), (*shape, 2 * self.nmax)
        )

        return out + (x * w).sum(axis=-1) + b


# Swish (with a trainable input scale w0) plus a learnable correction expanded on
# nbasis Gaussian bumps per feature, with trainable centers (x0, initialized on a
# uniform grid over [xmin, xmax]) and trainable inverse widths (sigmas).
class GausianBasisActivation(nn.Module):
    nbasis: int = 10
    xmin: float = -1.0
    xmax: float = 3.0

    @nn.compact
    def __call__(self, x):
        nfeatures = x.shape[-1]
        shape = [1] * (x.ndim - 1) + [nfeatures]
        b = self.param("b", jax.nn.initializers.zeros, shape)
        w0 = self.param("w0", jax.nn.initializers.ones, shape)
        out = jax.nn.swish(w0 * x)

        shape += [self.nbasis]
        x0 = self.param(
            "x0",
            lambda key: jnp.asarray(
                np.linspace(self.xmin, self.xmax, self.nbasis)[None, :]
                .repeat(nfeatures, axis=0)
                .reshape(shape),
                dtype=x.dtype,
            ),
        )

        sigma0 = np.abs(self.xmax - self.xmin) / self.nbasis
        alphas = (0.5**0.5) * self.param(
            "sigmas",
            lambda key: jnp.full(shape, 1.0 / sigma0, dtype=x.dtype),
        )

        w = self.param("w", jax.nn.initializers.normal(stddev=0.1), shape)

        ex = jnp.exp(-((x[..., None] - x0) * alphas) ** 2)

        return out + (ex * w).sum(axis=-1) + b

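A quick sketch (assumed shapes, not taken from the source) of the parameters this module creates for an input with 3 features and nbasis=5:

mod = GausianBasisActivation(nbasis=5)
params = mod.init(jax.random.PRNGKey(0), jnp.zeros((2, 3)))
jax.tree_util.tree_map(jnp.shape, params)
# "b" and "w0" have shape (1, 3); "x0", "sigmas" and "w" have shape (1, 3, 5)
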
# square root with the argument clipped from below at eps
# (avoids NaNs and infinite gradients at or below zero)
def safe_sqrt(x, eps=1.0e-5):
    return jnp.sqrt(jnp.clip(x, min=eps))


@jax.jit
def aptx(x):
    return (1.0 + jax.nn.tanh(x)) * x


@jax.jit
def serf(x):
    return x * jax.scipy.special.erf(jax.nn.softplus(x))


# smooth leaky-ReLU-like function: f(0) = 0, the slope tends to alpha as x -> -inf
# and to 1 as x -> +inf; beta controls the sharpness of the transition
@jax.jit
def leaky_celu(x, alpha=0.1, beta=1.0):
    return alpha * x + ((1.0 - alpha) / beta) * (
        jax.nn.softplus(beta * x) - math.log(2.0)
    )


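For intuition, a small numerical sketch with the default alpha=0.1, beta=1.0 (values rounded):

xs = jnp.array([-10.0, 0.0, 10.0])
leaky_celu(xs)   # ≈ [-1.62, 0.0, 9.38]: slope alpha far to the left, x minus a constant far to the right
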
# sign-preserving "soft square root": identity for |x| <= 1 and
# sign(x) * (2 * sqrt(|x|) - 1) beyond, with matching value and slope at |x| = 1
@jax.jit
def tssr(x):
    ax = jnp.abs(x)
    mask = ax <= 1.0
    axx = jnp.where(mask, 1.0, ax)
    return jnp.where(mask, x, jnp.sign(x) * (2 * axx**0.5 - 1))


# variant of tssr that uses a cubic polynomial on [-1, 1] and sign(x) * sqrt(|x|)
# outside, matched in value and slope at |x| = 1
@jax.jit
def tssr2(x):
    ax = jnp.abs(x)
    mask = ax <= 1.0
    axx = jnp.where(mask, 1.0, ax)
    return jnp.sign(x) * jnp.where(mask, 1.25 * ax - 0.25 * ax**3, axx**0.5)


# higher-order polynomial variant of tssr2, matched to sqrt(|x|) up to the
# second derivative at |x| = 1
@jax.jit
def tssr3(x):
    ax = jnp.abs(x)
    mask = ax <= 1.0
    axx = jnp.where(mask, 1.0, ax)
    ax2 = ax * ax
    dax2 = ax - ax2
    poly = 2.1875 * dax2 + ax2 * (ax + 0.3125 * dax2)
    return jnp.sign(x) * jnp.where(mask, poly, axx**0.5)


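A few reference values (rounded) illustrating the sign-preserving square-root behaviour:

xs = jnp.array([-4.0, -0.5, 0.5, 4.0])
tssr(xs)    # ≈ [-3.0, -0.5, 0.5, 3.0]    (identity inside [-1, 1]; 2*sqrt(4) - 1 = 3 outside)
tssr2(xs)   # ≈ [-2.0, -0.594, 0.594, 2.0]    (cubic inside [-1, 1]; sqrt(4) = 2 outside)
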
@jax.jit
def pow(x, a):
    return x**a


# shifted softplus: log(0.5 * exp(x) + 0.5) = softplus(x) - log(2), so ssp(0) = 0
@jax.jit
def ssp(x):
    return jnp.logaddexp(x + math.log(0.5), math.log(0.5))


# smooth, differentiable approximation of floor(x); eps in (0, 1) controls how
# sharp the steps are (eps -> 1 approaches the exact staircase)
@jax.jit
def smooth_floor(x, eps=0.99):
    return (
        x
        - 0.5
        - jnp.atan(
            -eps * jnp.sin((-2 * jnp.pi) * x) / (eps * jnp.cos((2 * jnp.pi) * x) - 1.0)
        )
        / jnp.pi
    )


# smooth approximation of round(x), built from the same construction as smooth_floor
@jax.jit
def smooth_round(x, eps=0.99):
    return (
        x
        - jnp.atan(
            -eps
            * jnp.sin(-2 * jnp.pi * (x - 0.5))
            / (eps * jnp.cos(2 * jnp.pi * (x - 0.5)) - 1.0)
        )
        / jnp.pi
    )


# compose several activations left to right: chain(f, g)(x) == g(f(x))
def chain(*activations):
    # @jax.jit
    def act(x):
        for a in activations:
            x = a(x)
        return x

    return act


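For instance (an illustrative composition, not taken from the package):

act = chain(jnp.tanh, tssr, ssp)
y = act(jnp.linspace(-3.0, 3.0, 7))   # ssp(tssr(tanh(x))), elementwise
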
def normalize_activation(
    phi: Callable[[float], float], return_scale=False
) -> Callable[[float], float]:
    r"""Normalize a function, :math:`\psi(x)=\phi(x)/c` where :math:`c` is the normalization constant such that

    .. math::

        \int_{-\infty}^{\infty} \psi(x)^2 \frac{e^{-x^2/2}}{\sqrt{2\pi}} dx = 1

    ! Adapted from e3nn_jax !
    """
    with jax.ensure_compile_time_eval():
        # k = jax.random.PRNGKey(0)
        # x = jax.random.normal(k, (1_000_000,))
        n = 1_000_001
        x = jnp.sqrt(2) * jax.scipy.special.erfinv(jnp.linspace(-1.0, 1.0, n + 2)[1:-1])
        c = jnp.mean(phi(x) ** 2) ** 0.5
        c = c.item()

        if jnp.allclose(c, 1.0):
            rho = phi
        else:

            def rho(x):
                return phi(x) / c

        if return_scale:
            return rho, 1.0 / c
        return rho


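Example (hedged sketch): rescaling SiLU so that its second moment under a standard normal input is one.

rho, scale = normalize_activation(jax.nn.silu, return_scale=True)
y = rho(jnp.array([1.0, 2.0]))   # rho(x) = silu(x) * scale, with E[rho(z)^2] ≈ 1 for z ~ N(0, 1)
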
def activation_from_str(activation: Union[str, Callable, None]) -> Callable:
    if activation is None:
        return lambda x: x
    if callable(activation):
        return activation
    if not isinstance(activation, str):
        raise ValueError(f"Invalid activation {activation}")
    if activation.lower() in ["none", "linear", "identity"]:
        return lambda x: x
    try:
        return eval(
            activation,
            {"__builtins__": None},
            {
                **jax.nn.__dict__,
                **jax.numpy.__dict__,
                **jax.__dict__,
                "chain": chain,
                "pow": pow,
                "partial": partial,
                "leaky_celu": leaky_celu,
                "aptx": aptx,
                "tssr": tssr,
                "tssr2": tssr2,
                "tssr3": tssr3,
                "ssp": ssp,
                "smooth_floor": smooth_floor,
                "smooth_round": smooth_round,
                "TrainableSiLU": TrainableSiLU,
                "TrainableCELU": TrainableCELU,
                "FourierActivation": FourierActivation,
                "GaussianBasis": GausianBasisActivation,
                "normalize": normalize_activation,
                "DyT": DyT,
                "DynAct": DynAct,
            },
        )
    except Exception as e:
        raise ValueError(
            f"The following exception was raised while parsing the activation function {activation} : {e}"
        )
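A few strings that activation_from_str accepts (illustrative; any expression over the names listed above plus the jax.nn and jax.numpy namespaces works):

act = activation_from_str("silu")                            # plain jax.nn activation
act = activation_from_str("partial(leaky_celu, alpha=0.2)")  # parameterized variant
act = activation_from_str("chain(tanh, tssr)")               # tssr applied after tanh
y = act(jnp.linspace(-2.0, 2.0, 5))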