fennol.models.misc.e3

  1import flax.linen as nn
  2from typing import Any, Optional, ClassVar
  3import jax.numpy as jnp
  4import jax
  5import numpy as np
  6from ...utils.spherical_harmonics import CG_SO3, spherical_to_cartesian_tensor
  7
  8### e3nn version
  9try:
 10    import e3nn_jax as e3nn
 11    E3NN_AVAILABLE = True
 12    E3NN_EXCEPTION = None
 13    Irreps = e3nn.Irreps
 14except Exception as e:
 15    E3NN_AVAILABLE = False
 16    E3NN_EXCEPTION = e
 17    class Irreps(tuple):
 18        pass
 19
 20
 21class FullTensorProduct(nn.Module):
 22    """Tensor product of two spherical harmonics."""
 23
 24    lmax1: int
 25    """The maximum order of the first spherical tensor."""
 26    lmax2: int
 27    """The maximum order of the second spherical tensor."""
 28    lmax_out: Optional[int] = None
 29    """The maximum order of the output spherical tensor. If None, it is the sum of lmax1 and lmax2."""
 30    ignore_parity: bool = False
 31    """Whether to ignore the parity of the spherical tensors."""
 32
 33    @nn.compact
 34    def __call__(self, x1, x2) -> None:
 35        irreps_1 = [(l, (-1) ** l) for l in range(self.lmax1 + 1)]
 36        irreps_2 = [(l, (-1) ** l) for l in range(self.lmax2 + 1)]
 37        irreps_out = []
 38        lsout = []
 39        psout = []
 40        i12 = []
 41
 42        lmax_out = self.lmax_out or self.lmax1 + self.lmax2
 43
 44        for i1, (l1, p1) in enumerate(irreps_1):
 45            for i2, (l2, p2) in enumerate(irreps_2):
 46                for lout in range(abs(l1 - l2), l1 + l2 + 1):
 47                    if p1 * p2 != (-1) ** lout and not self.ignore_parity:
 48                        continue
 49                    if lout > lmax_out:
 50                        continue
 51                    lsout.append(lout)
 52                    psout.append(p1 * p2)
 53                    i12.append((i1, i2))
 54
 55        lsout = np.array(lsout)
 56        psout = np.array(psout)
 57        idx = np.lexsort((psout, lsout))
 58        lsout = lsout[idx]
 59        psout = psout[idx]
 60        i12 = [i12[i] for i in idx]
 61        irreps_out = [(l, p) for l, p in zip(lsout, psout)]
 62
 63        slices_1 = [0]
 64        for l, p in irreps_1:
 65            slices_1.append(slices_1[-1] + 2 * l + 1)
 66        slices_2 = [0]
 67        for l, p in irreps_2:
 68            slices_2.append(slices_2[-1] + 2 * l + 1)
 69        slices_out = [0]
 70        for l, p in irreps_out:
 71            slices_out.append(slices_out[-1] + 2 * l + 1)
 72
 73        assert slices_1[-1] == (self.lmax1 + 1) ** 2
 74        assert slices_2[-1] == (self.lmax2 + 1) ** 2
 75
 76        shape = ((self.lmax1 + 1) ** 2, (self.lmax2 + 1) ** 2, slices_out[-1])
 77        w3js = np.zeros(shape)
 78        for iout, (lout, pout) in enumerate(irreps_out):
 79            i1, i2 = i12[iout]
 80            l1, p1 = irreps_1[i1]
 81            l2, p2 = irreps_2[i2]
 82            w3j = CG_SO3(l1, l2, lout)
 83            scale = (2 * lout + 1) ** 0.5
 84            w3js[
 85                slices_1[i1] : slices_1[i1 + 1],
 86                slices_2[i2] : slices_2[i2 + 1],
 87                slices_out[iout] : slices_out[iout + 1],
 88            ] = (
 89                w3j * scale
 90            )
 91        w3j = jnp.asarray(w3js)
 92
 93        return jnp.einsum("...a,...b,abc->...c", x1, x2, w3j), irreps_out
 94
 95
 96class FilteredTensorProduct(nn.Module):
 97    """Tensor product of two spherical harmonics filtered to give back the irreps of the first input"""
 98
 99    lmax1: int
100    """The maximum order of the first spherical tensor."""
101    lmax2: int
102    """The maximum order of the second spherical tensor."""
103    lmax_out: Optional[int] = None
104    """The maximum order of the output spherical tensor. If None, it is the same as lmax1."""
105    ignore_parity: bool = False
106    """Whether to ignore the parity of the spherical tensors."""
107    weights_by_channel: bool = False
108    """Whether to use different path weights for each channel."""
109
110    @nn.compact
111    def __call__(self, x1, x2) -> None:
112        irreps_1 = [(l, (-1) ** l) for l in range(self.lmax1 + 1)]
113        irreps_2 = [(l, (-1) ** l) for l in range(self.lmax2 + 1)]
114        lmax_out = self.lmax_out if self.lmax_out is not None else self.lmax1
115        irreps_out = [(l, (-1) ** l) for l in range(lmax_out + 1)]
116
117        slices_1 = [0]
118        for l, p in irreps_1:
119            slices_1.append(slices_1[-1] + 2 * l + 1)
120        slices_2 = [0]
121        for l, p in irreps_2:
122            slices_2.append(slices_2[-1] + 2 * l + 1)
123        slices_out = [0]
124        for l, p in irreps_out:
125            slices_out.append(slices_out[-1] + 2 * l + 1)
126
127        shape = (x1.shape[-1], x2.shape[-1], x1.shape[-1])
128        w3js = []
129        for iout, (lout, pout) in enumerate(irreps_out):
130            for i1, (l1, p1) in enumerate(irreps_1):
131                for i2, (l2, p2) in enumerate(irreps_2):
132                    if pout != p1 * p2 and not self.ignore_parity:
133                        continue
134                    if lout > l1 + l2 or lout < abs(l1 - l2):
135                        continue
136                    w3j = CG_SO3(l1, l2, lout)
137                    w3j_full = np.zeros(shape)
138                    scale = (2 * lout + 1) ** 0.5
139                    w3j_full[
140                        slices_1[i1] : slices_1[i1 + 1],
141                        slices_2[i2] : slices_2[i2 + 1],
142                        slices_out[iout] : slices_out[iout + 1],
143                    ] = (
144                        w3j * scale
145                    )
146                    w3js.append(w3j_full)
147        npath = len(w3js)
148        w3j = jnp.asarray(np.stack(w3js))
149
150        if self.weights_by_channel:
151            nchannels = x1.shape[-2]
152            weights = self.param("weights", jax.nn.initializers.normal(stddev=1./npath**0.5), (nchannels,npath,))
153            ww3j = jnp.einsum("np,pabc->nabc", weights, w3j)
154            return jnp.einsum("...na,...nb,nabc->...nc", x1, x2, ww3j)
155
156
157        weights = self.param("weights", jax.nn.initializers.normal(stddev=1./npath**0.5), (npath,))
158
159        ww3j = jnp.einsum("p,pabc->abc", weights, w3j)
160        return jnp.einsum("...a,...b,abc->...c", x1, x2, ww3j)
161
162
class ChannelMixing(nn.Module):
    """Linear mixing of input channels.

    A learned ``(nchannels_out, nchannels)`` matrix is contracted with the
    second-to-last axis of the input (einsum ``"ij,...jk->...ik"``), so every
    component along the last axis is mixed with the same weights. The input
    is either a bare array or, when ``input_key`` is set, a dictionary entry;
    in the dictionary case the result is stored under ``output_key``
    (defaulting to the module name).

    FID: CHANNEL_MIXING
    """

    lmax: int
    """The maximum order of the spherical tensor."""
    nchannels: int
    """The number of input channels."""
    nchannels_out: Optional[int] = None
    """The number of output channels. If None, it is the same as nchannels."""
    input_key: Optional[str] = None
    """The key in the input dictionary that corresponds to the input tensor."""
    output_key: Optional[str] = None
    """The key in the output dictionary where the computed tensor will be stored."""
    squeeze: bool = False
    """Whether to squeeze the output tensor to remove the channel dimension."""

    FID: ClassVar[str] = "CHANNEL_MIXING"

    @nn.compact
    def __call__(self, inputs):
        from_dict = self.input_key is not None
        if from_dict:
            x = inputs[self.input_key]
        else:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            x = inputs

        n_out = self.nchannels_out or self.nchannels
        init = jax.nn.initializers.normal(stddev=1.0 / self.nchannels**0.5)
        mixing = self.param("weights", init, (n_out, self.nchannels))
        out = jnp.einsum("ij,...jk->...ik", mixing, x)
        # Only a single output channel can be squeezed away.
        if n_out == 1 and self.squeeze:
            out = jnp.squeeze(out, axis=-2)

        if not from_dict:
            return out
        key = self.name if self.output_key is None else self.output_key
        return out if key is None else {**inputs, key: out}
210
211
class ChannelMixingE3(nn.Module):
    """Linear mixing of input channels with different weight for each angular momentum.

    One weight is learned per (output channel, input channel, order ``l``) and
    repeated over the ``2l + 1`` components of that order, so the last axis of
    the input is expected to hold orders ``0..lmax`` in sequence
    (``(lmax + 1) ** 2`` components). The channel axis is the second-to-last
    one. With ``input_key`` set, the array is read from and written back to a
    dictionary (output key defaults to the module name).

    FID: CHANNEL_MIXING_E3
    """

    lmax: int
    """The maximum order of the spherical tensor."""
    nchannels: int
    """The number of input channels."""
    nchannels_out: Optional[int] = None
    """The number of output channels. If None, it is the same as nchannels."""
    input_key: Optional[str] = None
    """The key in the input dictionary that corresponds to the input tensor."""
    output_key: Optional[str] = None
    """The key in the output dictionary where the computed tensor will be stored."""
    squeeze: bool = False
    """Whether to squeeze the output tensor to remove the channel dimension."""

    FID: ClassVar[str] = "CHANNEL_MIXING_E3"

    @nn.compact
    def __call__(self, inputs):
        from_dict = self.input_key is not None
        if from_dict:
            x = inputs[self.input_key]
        else:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            x = inputs

        n_out = self.nchannels_out or self.nchannels
        per_order = self.param(
            "weights",
            jax.nn.initializers.normal(stddev=1.0 / self.nchannels**0.5),
            (n_out, self.nchannels, self.lmax + 1),
        )
        # Broadcast each order's weight over its 2l+1 components.
        multiplicities = np.array([2 * l + 1 for l in range(self.lmax + 1)])
        per_component = jnp.repeat(per_order, multiplicities, axis=-1)
        out = jnp.einsum("ijk,...jk->...ik", per_component, x)
        if n_out == 1 and self.squeeze:
            out = jnp.squeeze(out, axis=-2)

        if not from_dict:
            return out
        key = self.name if self.output_key is None else self.output_key
        return out if key is None else {**inputs, key: out}
264
265
class SphericalToCartesian(nn.Module):
    """Convert spherical tensors to cartesian tensors.

    Thin wrapper around ``spherical_to_cartesian_tensor``. The input is either
    a bare array or, when ``input_key`` is set, a dictionary entry; in the
    dictionary case the result is stored under ``output_key``, which defaults
    to ``input_key`` (overwriting the spherical tensor).

    FID: SPHERICAL_TO_CARTESIAN
    """

    lmax: int
    """The maximum order of the spherical tensor. Only implemented for lmax up to 2."""
    input_key: Optional[str] = None
    """The key in the input dictionary that corresponds to the input tensor."""
    output_key: Optional[str] = None
    """The key in the output dictionary where the computed tensor will be stored."""

    FID: ClassVar[str] = "SPHERICAL_TO_CARTESIAN"

    @nn.compact
    def __call__(self, inputs) -> Any:
        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            return spherical_to_cartesian_tensor(inputs, self.lmax)

        converted = spherical_to_cartesian_tensor(inputs[self.input_key], self.lmax)
        # input_key is set here, so the resolved key is never None.
        key = self.input_key if self.output_key is None else self.output_key
        return {**inputs, key: converted}
class FullTensorProduct(flax.linen.module.Module):
class FullTensorProduct(nn.Module):
    """Tensor product of two spherical harmonics.

    Both inputs are expected to store every order ``l = 0..lmax`` concatenated
    along their last axis (``(lmax + 1) ** 2`` components), each order with
    parity ``(-1) ** l``. Every coupling path ``(l1, l2) -> lout`` allowed by
    the triangle rule, by parity (unless ``ignore_parity``) and by
    ``lmax_out`` is kept as its own output irrep, using the ``CG_SO3``
    coefficients scaled by ``sqrt(2 * lout + 1)``. Output irreps are sorted by
    order, then parity (stable for ties).

    Returns:
        ``(out, irreps_out)``: ``out`` has a last axis of size
        ``sum(2 * l + 1 for l, _ in irreps_out)``; ``irreps_out`` is the list
        of ``(l, parity)`` pairs (Python ints) describing that axis.
    """

    lmax1: int
    """The maximum order of the first spherical tensor."""
    lmax2: int
    """The maximum order of the second spherical tensor."""
    lmax_out: Optional[int] = None
    """The maximum order of the output spherical tensor. If None, it is the sum of lmax1 and lmax2."""
    ignore_parity: bool = False
    """Whether to ignore the parity of the spherical tensors."""

    @nn.compact
    def __call__(self, x1, x2) -> tuple:
        irreps_1 = [(l, (-1) ** l) for l in range(self.lmax1 + 1)]
        irreps_2 = [(l, (-1) ** l) for l in range(self.lmax2 + 1)]
        # `is None` rather than `or`: an explicit lmax_out=0 must be honoured.
        lmax_out = (
            self.lmax1 + self.lmax2 if self.lmax_out is None else self.lmax_out
        )

        # Enumerate every allowed coupling path as (lout, pout, i1, i2).
        paths = []
        for i1, (l1, p1) in enumerate(irreps_1):
            for i2, (l2, p2) in enumerate(irreps_2):
                for lout in range(abs(l1 - l2), min(l1 + l2, lmax_out) + 1):
                    if p1 * p2 != (-1) ** lout and not self.ignore_parity:
                        continue
                    paths.append((lout, p1 * p2, i1, i2))

        # Sort by order (primary key) then parity; np.lexsort keeps ties in
        # enumeration order, which fixes the output layout.
        order = np.lexsort(
            (
                np.array([pout for _, pout, _, _ in paths]),
                np.array([lout for lout, _, _, _ in paths]),
            )
        )
        paths = [paths[i] for i in order]
        irreps_out = [(int(lout), int(pout)) for lout, pout, _, _ in paths]

        def offsets(irreps):
            # Start index of each irrep block along the component axis,
            # followed by the total number of components.
            return [0] + np.cumsum([2 * l + 1 for l, _ in irreps]).tolist()

        slices_1 = offsets(irreps_1)
        slices_2 = offsets(irreps_2)
        slices_out = offsets(irreps_out)

        w3js = np.zeros(
            ((self.lmax1 + 1) ** 2, (self.lmax2 + 1) ** 2, slices_out[-1])
        )
        for iout, (lout, _, i1, i2) in enumerate(paths):
            l1 = irreps_1[i1][0]
            l2 = irreps_2[i2][0]
            w3js[
                slices_1[i1] : slices_1[i1 + 1],
                slices_2[i2] : slices_2[i2 + 1],
                slices_out[iout] : slices_out[iout + 1],
            ] = CG_SO3(l1, l2, lout) * (2 * lout + 1) ** 0.5
        w3j = jnp.asarray(w3js)

        return jnp.einsum("...a,...b,abc->...c", x1, x2, w3j), irreps_out

Tensor product of two spherical harmonics.

FullTensorProduct( lmax1: int, lmax2: int, lmax_out: Optional[int] = None, ignore_parity: bool = False, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
lmax1: int

The maximum order of the first spherical tensor.

lmax2: int

The maximum order of the second spherical tensor.

lmax_out: Optional[int] = None

The maximum order of the output spherical tensor. If None, it is the sum of lmax1 and lmax2.

ignore_parity: bool = False

Whether to ignore the parity of the spherical tensors.

parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class FilteredTensorProduct(flax.linen.module.Module):
 97class FilteredTensorProduct(nn.Module):
 98    """Tensor product of two spherical harmonics filtered to give back the irreps of the first input"""
 99
100    lmax1: int
101    """The maximum order of the first spherical tensor."""
102    lmax2: int
103    """The maximum order of the second spherical tensor."""
104    lmax_out: Optional[int] = None
105    """The maximum order of the output spherical tensor. If None, it is the same as lmax1."""
106    ignore_parity: bool = False
107    """Whether to ignore the parity of the spherical tensors."""
108    weights_by_channel: bool = False
109    """Whether to use different path weights for each channel."""
110
111    @nn.compact
112    def __call__(self, x1, x2) -> None:
113        irreps_1 = [(l, (-1) ** l) for l in range(self.lmax1 + 1)]
114        irreps_2 = [(l, (-1) ** l) for l in range(self.lmax2 + 1)]
115        lmax_out = self.lmax_out if self.lmax_out is not None else self.lmax1
116        irreps_out = [(l, (-1) ** l) for l in range(lmax_out + 1)]
117
118        slices_1 = [0]
119        for l, p in irreps_1:
120            slices_1.append(slices_1[-1] + 2 * l + 1)
121        slices_2 = [0]
122        for l, p in irreps_2:
123            slices_2.append(slices_2[-1] + 2 * l + 1)
124        slices_out = [0]
125        for l, p in irreps_out:
126            slices_out.append(slices_out[-1] + 2 * l + 1)
127
128        shape = (x1.shape[-1], x2.shape[-1], x1.shape[-1])
129        w3js = []
130        for iout, (lout, pout) in enumerate(irreps_out):
131            for i1, (l1, p1) in enumerate(irreps_1):
132                for i2, (l2, p2) in enumerate(irreps_2):
133                    if pout != p1 * p2 and not self.ignore_parity:
134                        continue
135                    if lout > l1 + l2 or lout < abs(l1 - l2):
136                        continue
137                    w3j = CG_SO3(l1, l2, lout)
138                    w3j_full = np.zeros(shape)
139                    scale = (2 * lout + 1) ** 0.5
140                    w3j_full[
141                        slices_1[i1] : slices_1[i1 + 1],
142                        slices_2[i2] : slices_2[i2 + 1],
143                        slices_out[iout] : slices_out[iout + 1],
144                    ] = (
145                        w3j * scale
146                    )
147                    w3js.append(w3j_full)
148        npath = len(w3js)
149        w3j = jnp.asarray(np.stack(w3js))
150
151        if self.weights_by_channel:
152            nchannels = x1.shape[-2]
153            weights = self.param("weights", jax.nn.initializers.normal(stddev=1./npath**0.5), (nchannels,npath,))
154            ww3j = jnp.einsum("np,pabc->nabc", weights, w3j)
155            return jnp.einsum("...na,...nb,nabc->...nc", x1, x2, ww3j)
156
157
158        weights = self.param("weights", jax.nn.initializers.normal(stddev=1./npath**0.5), (npath,))
159
160        ww3j = jnp.einsum("p,pabc->abc", weights, w3j)
161        return jnp.einsum("...a,...b,abc->...c", x1, x2, ww3j)

Tensor product of two spherical harmonics, filtered to give back the irreps of the first input.

FilteredTensorProduct( lmax1: int, lmax2: int, lmax_out: Optional[int] = None, ignore_parity: bool = False, weights_by_channel: bool = False, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
lmax1: int

The maximum order of the first spherical tensor.

lmax2: int

The maximum order of the second spherical tensor.

lmax_out: Optional[int] = None

The maximum order of the output spherical tensor. If None, it is the same as lmax1.

ignore_parity: bool = False

Whether to ignore the parity of the spherical tensors.

weights_by_channel: bool = False

Whether to use different path weights for each channel.

parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class ChannelMixing(flax.linen.module.Module):
class ChannelMixing(nn.Module):
    """Linear mixing of input channels.

    A learned ``(nchannels_out, nchannels)`` matrix is contracted with the
    second-to-last axis of the input (einsum ``"ij,...jk->...ik"``), so every
    component along the last axis is mixed with the same weights. The input
    is either a bare array or, when ``input_key`` is set, a dictionary entry;
    in the dictionary case the result is stored under ``output_key``
    (defaulting to the module name).

    FID: CHANNEL_MIXING
    """

    lmax: int
    """The maximum order of the spherical tensor."""
    nchannels: int
    """The number of input channels."""
    nchannels_out: Optional[int] = None
    """The number of output channels. If None, it is the same as nchannels."""
    input_key: Optional[str] = None
    """The key in the input dictionary that corresponds to the input tensor."""
    output_key: Optional[str] = None
    """The key in the output dictionary where the computed tensor will be stored."""
    squeeze: bool = False
    """Whether to squeeze the output tensor to remove the channel dimension."""

    FID: ClassVar[str] = "CHANNEL_MIXING"

    @nn.compact
    def __call__(self, inputs):
        from_dict = self.input_key is not None
        if from_dict:
            x = inputs[self.input_key]
        else:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            x = inputs

        n_out = self.nchannels_out or self.nchannels
        init = jax.nn.initializers.normal(stddev=1.0 / self.nchannels**0.5)
        mixing = self.param("weights", init, (n_out, self.nchannels))
        out = jnp.einsum("ij,...jk->...ik", mixing, x)
        # Only a single output channel can be squeezed away.
        if n_out == 1 and self.squeeze:
            out = jnp.squeeze(out, axis=-2)

        if not from_dict:
            return out
        key = self.name if self.output_key is None else self.output_key
        return out if key is None else {**inputs, key: out}

Linear mixing of input channels.

FID: CHANNEL_MIXING

ChannelMixing( lmax: int, nchannels: int, nchannels_out: Optional[int] = None, input_key: Optional[str] = None, output_key: Optional[str] = None, squeeze: bool = False, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
lmax: int

The maximum order of the spherical tensor.

nchannels: int

The number of input channels.

nchannels_out: Optional[int] = None

The number of output channels. If None, it is the same as nchannels.

input_key: Optional[str] = None

The key in the input dictionary that corresponds to the input tensor.

output_key: Optional[str] = None

The key in the output dictionary where the computed tensor will be stored.

squeeze: bool = False

Whether to squeeze the output tensor to remove the channel dimension.

FID: ClassVar[str] = 'CHANNEL_MIXING'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class ChannelMixingE3(flax.linen.module.Module):
class ChannelMixingE3(nn.Module):
    """Linear mixing of input channels with different weight for each angular momentum.

    One weight is learned per (output channel, input channel, order ``l``) and
    repeated over the ``2l + 1`` components of that order, so the last axis of
    the input is expected to hold orders ``0..lmax`` in sequence
    (``(lmax + 1) ** 2`` components). The channel axis is the second-to-last
    one. With ``input_key`` set, the array is read from and written back to a
    dictionary (output key defaults to the module name).

    FID: CHANNEL_MIXING_E3
    """

    lmax: int
    """The maximum order of the spherical tensor."""
    nchannels: int
    """The number of input channels."""
    nchannels_out: Optional[int] = None
    """The number of output channels. If None, it is the same as nchannels."""
    input_key: Optional[str] = None
    """The key in the input dictionary that corresponds to the input tensor."""
    output_key: Optional[str] = None
    """The key in the output dictionary where the computed tensor will be stored."""
    squeeze: bool = False
    """Whether to squeeze the output tensor to remove the channel dimension."""

    FID: ClassVar[str] = "CHANNEL_MIXING_E3"

    @nn.compact
    def __call__(self, inputs):
        from_dict = self.input_key is not None
        if from_dict:
            x = inputs[self.input_key]
        else:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            x = inputs

        n_out = self.nchannels_out or self.nchannels
        per_order = self.param(
            "weights",
            jax.nn.initializers.normal(stddev=1.0 / self.nchannels**0.5),
            (n_out, self.nchannels, self.lmax + 1),
        )
        # Broadcast each order's weight over its 2l+1 components.
        multiplicities = np.array([2 * l + 1 for l in range(self.lmax + 1)])
        per_component = jnp.repeat(per_order, multiplicities, axis=-1)
        out = jnp.einsum("ijk,...jk->...ik", per_component, x)
        if n_out == 1 and self.squeeze:
            out = jnp.squeeze(out, axis=-2)

        if not from_dict:
            return out
        key = self.name if self.output_key is None else self.output_key
        return out if key is None else {**inputs, key: out}

Linear mixing of input channels with different weight for each angular momentum.

FID: CHANNEL_MIXING_E3

ChannelMixingE3( lmax: int, nchannels: int, nchannels_out: Optional[int] = None, input_key: Optional[str] = None, output_key: Optional[str] = None, squeeze: bool = False, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
lmax: int

The maximum order of the spherical tensor.

nchannels: int

The number of input channels.

nchannels_out: Optional[int] = None

The number of output channels. If None, it is the same as nchannels.

input_key: Optional[str] = None

The key in the input dictionary that corresponds to the input tensor.

output_key: Optional[str] = None

The key in the output dictionary where the computed tensor will be stored.

squeeze: bool = False

Whether to squeeze the output tensor to remove the channel dimension.

FID: ClassVar[str] = 'CHANNEL_MIXING_E3'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class SphericalToCartesian(flax.linen.module.Module):
class SphericalToCartesian(nn.Module):
    """Convert spherical tensors to cartesian tensors.

    Thin wrapper around ``spherical_to_cartesian_tensor``. The input is either
    a bare array or, when ``input_key`` is set, a dictionary entry; in the
    dictionary case the result is stored under ``output_key``, which defaults
    to ``input_key`` (overwriting the spherical tensor).

    FID: SPHERICAL_TO_CARTESIAN
    """

    lmax: int
    """The maximum order of the spherical tensor. Only implemented for lmax up to 2."""
    input_key: Optional[str] = None
    """The key in the input dictionary that corresponds to the input tensor."""
    output_key: Optional[str] = None
    """The key in the output dictionary where the computed tensor will be stored."""

    FID: ClassVar[str] = "SPHERICAL_TO_CARTESIAN"

    @nn.compact
    def __call__(self, inputs) -> Any:
        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            return spherical_to_cartesian_tensor(inputs, self.lmax)

        converted = spherical_to_cartesian_tensor(inputs[self.input_key], self.lmax)
        # input_key is set here, so the resolved key is never None.
        key = self.input_key if self.output_key is None else self.output_key
        return {**inputs, key: converted}

Convert spherical tensors to cartesian tensors.

FID: SPHERICAL_TO_CARTESIAN

SphericalToCartesian( lmax: int, input_key: Optional[str] = None, output_key: Optional[str] = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
lmax: int

The maximum order of the spherical tensor. Only implemented for lmax up to 2.

input_key: Optional[str] = None

The key in the input dictionary that corresponds to the input tensor.

output_key: Optional[str] = None

The key in the output dictionary where the computed tensor will be stored.

FID: ClassVar[str] = 'SPHERICAL_TO_CARTESIAN'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None