fennol.models.misc.e3
import flax.linen as nn
from typing import Any, Optional, ClassVar
import jax.numpy as jnp
import jax
import numpy as np
from ...utils.spherical_harmonics import CG_SO3, spherical_to_cartesian_tensor

### e3nn version
# e3nn_jax is an optional dependency: record whether it could be imported and
# keep the exception so callers can report why it is unavailable.
try:
    import e3nn_jax as e3nn

    E3NN_AVAILABLE = True
    E3NN_EXCEPTION = None
    Irreps = e3nn.Irreps
except Exception as e:
    E3NN_AVAILABLE = False
    E3NN_EXCEPTION = e

    # Placeholder so that `Irreps` can still be referenced in annotations.
    class Irreps(tuple):
        pass


class FullTensorProduct(nn.Module):
    """Tensor product of two spherical harmonics.

    Every coupling ``l1 x l2 -> lout`` with ``|l1 - l2| <= lout <= l1 + l2`` and
    ``lout <= lmax_out`` is kept (subject to the parity rule unless
    ``ignore_parity`` is set). The resulting blocks are concatenated along the
    last axis, sorted by ``lout`` and then by parity.
    """

    lmax1: int
    """The maximum order of the first spherical tensor."""
    lmax2: int
    """The maximum order of the second spherical tensor."""
    lmax_out: Optional[int] = None
    """The maximum order of the output spherical tensor. If None, it is the sum of lmax1 and lmax2."""
    ignore_parity: bool = False
    """Whether to ignore the parity of the spherical tensors."""

    @nn.compact
    def __call__(self, x1, x2) -> tuple:
        """Contract ``x1`` and ``x2`` with the coupling tensor.

        Args:
            x1: array whose last axis has ``(lmax1 + 1) ** 2`` components.
            x2: array whose last axis has ``(lmax2 + 1) ** 2`` components.

        Returns:
            ``(out, irreps_out)``: the contracted array and the list of
            ``(l, parity)`` pairs describing the blocks of its last axis.
        """
        irreps_1 = [(l, (-1) ** l) for l in range(self.lmax1 + 1)]
        irreps_2 = [(l, (-1) ** l) for l in range(self.lmax2 + 1)]

        # Explicit None test: `lmax_out=0` (scalar outputs only) is a valid
        # request and must not fall back to the default.
        lmax_out = (
            self.lmax_out if self.lmax_out is not None else self.lmax1 + self.lmax2
        )

        # Enumerate every allowed path (i1, i2) -> lout.
        lsout, psout, i12 = [], [], []
        for i1, (l1, p1) in enumerate(irreps_1):
            for i2, (l2, p2) in enumerate(irreps_2):
                for lout in range(abs(l1 - l2), min(l1 + l2, lmax_out) + 1):
                    if not self.ignore_parity and p1 * p2 != (-1) ** lout:
                        continue
                    lsout.append(lout)
                    psout.append(p1 * p2)
                    i12.append((i1, i2))

        # Order the output blocks by lout first, parity second.
        lsout = np.array(lsout)
        psout = np.array(psout)
        idx = np.lexsort((psout, lsout))
        lsout = lsout[idx]
        psout = psout[idx]
        i12 = [i12[i] for i in idx]
        irreps_out = list(zip(lsout, psout))

        # Start offset of each irrep block along the flattened last axis.
        slices_1 = [0]
        for l, _ in irreps_1:
            slices_1.append(slices_1[-1] + 2 * l + 1)
        slices_2 = [0]
        for l, _ in irreps_2:
            slices_2.append(slices_2[-1] + 2 * l + 1)
        slices_out = [0]
        for l, _ in irreps_out:
            slices_out.append(slices_out[-1] + 2 * l + 1)

        assert slices_1[-1] == (self.lmax1 + 1) ** 2
        assert slices_2[-1] == (self.lmax2 + 1) ** 2

        # Dense coupling tensor: one (2*l1+1, 2*l2+1, 2*lout+1) block per path.
        shape = ((self.lmax1 + 1) ** 2, (self.lmax2 + 1) ** 2, slices_out[-1])
        w3js = np.zeros(shape)
        for iout, (lout, _) in enumerate(irreps_out):
            i1, i2 = i12[iout]
            l1 = irreps_1[i1][0]
            l2 = irreps_2[i2][0]
            # NOTE(review): CG_SO3 is presumed to return the coupling
            # coefficients with shape (2*l1+1, 2*l2+1, 2*lout+1) - confirm.
            w3j = CG_SO3(l1, l2, lout)
            scale = (2 * lout + 1) ** 0.5
            w3js[
                slices_1[i1] : slices_1[i1 + 1],
                slices_2[i2] : slices_2[i2 + 1],
                slices_out[iout] : slices_out[iout + 1],
            ] = (
                w3j * scale
            )
        w3j = jnp.asarray(w3js)

        return jnp.einsum("...a,...b,abc->...c", x1, x2, w3j), irreps_out


class FilteredTensorProduct(nn.Module):
    """Tensor product of two spherical harmonics filtered to give back the irreps of the first input.

    Each allowed path ``l1 x l2 -> lout`` gets its own learnable weight (named
    ``"weights"``); the weighted paths are summed into a single coupling tensor.
    """

    lmax1: int
    """The maximum order of the first spherical tensor."""
    lmax2: int
    """The maximum order of the second spherical tensor."""
    lmax_out: Optional[int] = None
    """The maximum order of the output spherical tensor. If None, it is the same as lmax1."""
    ignore_parity: bool = False
    """Whether to ignore the parity of the spherical tensors."""
    weights_by_channel: bool = False
    """Whether to use different path weights for each channel."""

    @nn.compact
    def __call__(self, x1, x2) -> jnp.ndarray:
        """Apply the weighted tensor product to ``x1`` and ``x2``.

        Args:
            x1: array whose last axis holds the components of the first tensor.
                When ``weights_by_channel`` is set, axis ``-2`` is the channel axis.
            x2: array whose last axis holds the components of the second tensor
                (same channel axis as ``x1`` when ``weights_by_channel`` is set).

        Returns:
            The contracted array. Its last axis has
            ``max(x1.shape[-1], (lmax_out + 1) ** 2)`` components.
        """
        irreps_1 = [(l, (-1) ** l) for l in range(self.lmax1 + 1)]
        irreps_2 = [(l, (-1) ** l) for l in range(self.lmax2 + 1)]
        lmax_out = self.lmax_out if self.lmax_out is not None else self.lmax1
        irreps_out = [(l, (-1) ** l) for l in range(lmax_out + 1)]

        # Start offset of each irrep block along the flattened last axis.
        slices_1 = [0]
        for l, _ in irreps_1:
            slices_1.append(slices_1[-1] + 2 * l + 1)
        slices_2 = [0]
        for l, _ in irreps_2:
            slices_2.append(slices_2[-1] + 2 * l + 1)
        slices_out = [0]
        for l, _ in irreps_out:
            slices_out.append(slices_out[-1] + 2 * l + 1)

        # The output axis used to be hard-wired to x1.shape[-1], so any path to
        # lout > lmax1 was assigned into an empty slice and raised. It is now
        # wide enough for every requested output block, while x1.shape[-1] is
        # kept as a lower bound so results for lmax_out <= lmax1 are unchanged.
        nout = max(x1.shape[-1], slices_out[-1])
        shape = (x1.shape[-1], x2.shape[-1], nout)

        # One dense coupling tensor per path. The loop order fixes the meaning
        # of each entry of the "weights" parameter, so it must not change.
        w3js = []
        for iout, (lout, pout) in enumerate(irreps_out):
            scale = (2 * lout + 1) ** 0.5
            for i1, (l1, p1) in enumerate(irreps_1):
                for i2, (l2, p2) in enumerate(irreps_2):
                    if pout != p1 * p2 and not self.ignore_parity:
                        continue
                    if lout > l1 + l2 or lout < abs(l1 - l2):
                        continue
                    # NOTE(review): CG_SO3 is presumed to return the coupling
                    # coefficients with shape (2*l1+1, 2*l2+1, 2*lout+1) - confirm.
                    w3j = CG_SO3(l1, l2, lout)
                    w3j_full = np.zeros(shape)
                    w3j_full[
                        slices_1[i1] : slices_1[i1 + 1],
                        slices_2[i2] : slices_2[i2 + 1],
                        slices_out[iout] : slices_out[iout + 1],
                    ] = (
                        w3j * scale
                    )
                    w3js.append(w3j_full)
        npath = len(w3js)
        w3j = jnp.asarray(np.stack(w3js))

        if self.weights_by_channel:
            # Independent set of path weights for every channel.
            nchannels = x1.shape[-2]
            weights = self.param(
                "weights",
                jax.nn.initializers.normal(stddev=1.0 / npath**0.5),
                (nchannels, npath),
            )
            ww3j = jnp.einsum("np,pabc->nabc", weights, w3j)
            return jnp.einsum("...na,...nb,nabc->...nc", x1, x2, ww3j)

        weights = self.param(
            "weights",
            jax.nn.initializers.normal(stddev=1.0 / npath**0.5),
            (npath,),
        )
        ww3j = jnp.einsum("p,pabc->abc", weights, w3j)
        return jnp.einsum("...a,...b,abc->...c", x1, x2, ww3j)


class ChannelMixing(nn.Module):
    """Linear mixing of input channels.

    FID: CHANNEL_MIXING
    """

    lmax: int
    """The maximum order of the spherical tensor."""
    nchannels: int
    """The number of input channels."""
    nchannels_out: Optional[int] = None
    """The number of output channels. If None, it is the same as nchannels."""
    input_key: Optional[str] = None
    """The key in the input dictionary that corresponds to the input tensor."""
    output_key: Optional[str] = None
    """The key in the output dictionary where the computed tensor will be stored."""
    squeeze: bool = False
    """Whether to squeeze the output tensor to remove the channel dimension."""

    FID: ClassVar[str] = "CHANNEL_MIXING"

    @nn.compact
    def __call__(self, inputs):
        """Mix the channel axis (axis ``-2``) of the input with a learned matrix.

        ``inputs`` is either the tensor itself (when ``input_key`` is None) or
        a dictionary from which the tensor is read. In the dictionary case a
        copy of the dictionary with the result added is returned, unless no
        output key can be determined, in which case the bare result is returned.
        """
        keyed = self.input_key is not None
        if keyed:
            x = inputs[self.input_key]
        else:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            x = inputs

        n_out = self.nchannels_out or self.nchannels
        mixing = self.param(
            "weights",
            jax.nn.initializers.normal(stddev=1.0 / self.nchannels**0.5),
            (n_out, self.nchannels),
        )
        out = jnp.einsum("ij,...jk->...ik", mixing, x)
        if self.squeeze and n_out == 1:
            # Drop the singleton channel axis.
            out = jnp.squeeze(out, axis=-2)

        if not keyed:
            return out
        # Fall back to the module name when no explicit output key is given.
        output_key = self.output_key if self.output_key is not None else self.name
        if output_key is None:
            return out
        return {**inputs, output_key: out}


class ChannelMixingE3(nn.Module):
    """Linear mixing of input channels with different weight for each angular momentum.

    FID: CHANNEL_MIXING_E3
    """

    lmax: int
    """The maximum order of the spherical tensor."""
    nchannels: int
    """The number of input channels."""
    nchannels_out: Optional[int] = None
    """The number of output channels. If None, it is the same as nchannels."""
    input_key: Optional[str] = None
    """The key in the input dictionary that corresponds to the input tensor."""
    output_key: Optional[str] = None
    """The key in the output dictionary where the computed tensor will be stored."""
    squeeze: bool = False
    """Whether to squeeze the output tensor to remove the channel dimension."""

    FID: ClassVar[str] = "CHANNEL_MIXING_E3"

    @nn.compact
    def __call__(self, inputs):
        """Mix the channel axis (axis ``-2``) with one learned matrix per order ``l``.

        ``inputs`` is either the tensor itself (when ``input_key`` is None) or
        a dictionary from which the tensor is read. In the dictionary case a
        copy of the dictionary with the result added is returned, unless no
        output key can be determined, in which case the bare result is returned.
        """
        keyed = self.input_key is not None
        if keyed:
            x = inputs[self.input_key]
        else:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            x = inputs

        # Each order l owns 2*l+1 consecutive components of the last axis.
        multiplicities = np.array([2 * l + 1 for l in range(self.lmax + 1)])
        n_out = self.nchannels_out or self.nchannels
        per_l_weights = self.param(
            "weights",
            jax.nn.initializers.normal(stddev=1.0 / self.nchannels**0.5),
            (n_out, self.nchannels, self.lmax + 1),
        )
        # Expand the per-l weights so that all components of an order share one weight.
        mixing = jnp.repeat(per_l_weights, multiplicities, axis=-1)
        out = jnp.einsum("ijk,...jk->...ik", mixing, x)
        if self.squeeze and n_out == 1:
            # Drop the singleton channel axis.
            out = jnp.squeeze(out, axis=-2)

        if not keyed:
            return out
        # Fall back to the module name when no explicit output key is given.
        output_key = self.output_key if self.output_key is not None else self.name
        if output_key is None:
            return out
        return {**inputs, output_key: out}


class SphericalToCartesian(nn.Module):
    """Convert spherical tensors to cartesian tensors.

    FID: SPHERICAL_TO_CARTESIAN
    """

    lmax: int
    """The maximum order of the spherical tensor. Only implemented for lmax up to 2."""
    input_key: Optional[str] = None
    """The key in the input dictionary that corresponds to the input tensor."""
    output_key: Optional[str] = None
    """The key in the output dictionary where the computed tensor will be stored."""

    FID: ClassVar[str] = "SPHERICAL_TO_CARTESIAN"

    @nn.compact
    def __call__(self, inputs) -> Any:
        """Run ``spherical_to_cartesian_tensor`` on the input tensor.

        ``inputs`` is either the tensor itself (when ``input_key`` is None) or
        a dictionary from which the tensor is read. In the dictionary case a
        copy of the dictionary is returned with the result stored under
        ``output_key``, or under ``input_key`` when no output key is set.
        """
        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            return spherical_to_cartesian_tensor(inputs, self.lmax)

        out = spherical_to_cartesian_tensor(inputs[self.input_key], self.lmax)
        # input_key is not None here, so the target key is always defined.
        target = self.output_key if self.output_key is not None else self.input_key
        return {**inputs, target: out}
class FullTensorProduct(nn.Module):
    """Tensor product of two spherical harmonics.

    Every coupling ``l1 x l2 -> lout`` with ``|l1 - l2| <= lout <= l1 + l2`` and
    ``lout <= lmax_out`` is kept (subject to the parity rule unless
    ``ignore_parity`` is set). The resulting blocks are concatenated along the
    last axis, sorted by ``lout`` and then by parity.
    """

    lmax1: int
    """The maximum order of the first spherical tensor."""
    lmax2: int
    """The maximum order of the second spherical tensor."""
    lmax_out: Optional[int] = None
    """The maximum order of the output spherical tensor. If None, it is the sum of lmax1 and lmax2."""
    ignore_parity: bool = False
    """Whether to ignore the parity of the spherical tensors."""

    @nn.compact
    def __call__(self, x1, x2) -> tuple:
        """Contract ``x1`` and ``x2`` with the coupling tensor.

        Args:
            x1: array whose last axis has ``(lmax1 + 1) ** 2`` components.
            x2: array whose last axis has ``(lmax2 + 1) ** 2`` components.

        Returns:
            ``(out, irreps_out)``: the contracted array and the list of
            ``(l, parity)`` pairs describing the blocks of its last axis.
        """
        irreps_1 = [(l, (-1) ** l) for l in range(self.lmax1 + 1)]
        irreps_2 = [(l, (-1) ** l) for l in range(self.lmax2 + 1)]

        # Explicit None test: `lmax_out=0` (scalar outputs only) is a valid
        # request and must not fall back to the default.
        lmax_out = (
            self.lmax_out if self.lmax_out is not None else self.lmax1 + self.lmax2
        )

        # Enumerate every allowed path (i1, i2) -> lout.
        lsout, psout, i12 = [], [], []
        for i1, (l1, p1) in enumerate(irreps_1):
            for i2, (l2, p2) in enumerate(irreps_2):
                for lout in range(abs(l1 - l2), min(l1 + l2, lmax_out) + 1):
                    if not self.ignore_parity and p1 * p2 != (-1) ** lout:
                        continue
                    lsout.append(lout)
                    psout.append(p1 * p2)
                    i12.append((i1, i2))

        # Order the output blocks by lout first, parity second.
        lsout = np.array(lsout)
        psout = np.array(psout)
        idx = np.lexsort((psout, lsout))
        lsout = lsout[idx]
        psout = psout[idx]
        i12 = [i12[i] for i in idx]
        irreps_out = list(zip(lsout, psout))

        # Start offset of each irrep block along the flattened last axis.
        slices_1 = [0]
        for l, _ in irreps_1:
            slices_1.append(slices_1[-1] + 2 * l + 1)
        slices_2 = [0]
        for l, _ in irreps_2:
            slices_2.append(slices_2[-1] + 2 * l + 1)
        slices_out = [0]
        for l, _ in irreps_out:
            slices_out.append(slices_out[-1] + 2 * l + 1)

        assert slices_1[-1] == (self.lmax1 + 1) ** 2
        assert slices_2[-1] == (self.lmax2 + 1) ** 2

        # Dense coupling tensor: one (2*l1+1, 2*l2+1, 2*lout+1) block per path.
        shape = ((self.lmax1 + 1) ** 2, (self.lmax2 + 1) ** 2, slices_out[-1])
        w3js = np.zeros(shape)
        for iout, (lout, _) in enumerate(irreps_out):
            i1, i2 = i12[iout]
            l1 = irreps_1[i1][0]
            l2 = irreps_2[i2][0]
            # NOTE(review): CG_SO3 is presumed to return the coupling
            # coefficients with shape (2*l1+1, 2*l2+1, 2*lout+1) - confirm.
            w3j = CG_SO3(l1, l2, lout)
            scale = (2 * lout + 1) ** 0.5
            w3js[
                slices_1[i1] : slices_1[i1 + 1],
                slices_2[i2] : slices_2[i2 + 1],
                slices_out[iout] : slices_out[iout + 1],
            ] = (
                w3j * scale
            )
        w3j = jnp.asarray(w3js)

        return jnp.einsum("...a,...b,abc->...c", x1, x2, w3j), irreps_out
Tensor product of two spherical harmonics.
The maximum order of the output spherical tensor. If None, it is the sum of lmax1 and lmax2.
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.
class FilteredTensorProduct(nn.Module):
    """Tensor product of two spherical harmonics filtered to give back the irreps of the first input.

    Each allowed path ``l1 x l2 -> lout`` gets its own learnable weight (named
    ``"weights"``); the weighted paths are summed into a single coupling tensor.
    """

    lmax1: int
    """The maximum order of the first spherical tensor."""
    lmax2: int
    """The maximum order of the second spherical tensor."""
    lmax_out: Optional[int] = None
    """The maximum order of the output spherical tensor. If None, it is the same as lmax1."""
    ignore_parity: bool = False
    """Whether to ignore the parity of the spherical tensors."""
    weights_by_channel: bool = False
    """Whether to use different path weights for each channel."""

    @nn.compact
    def __call__(self, x1, x2) -> jnp.ndarray:
        """Apply the weighted tensor product to ``x1`` and ``x2``.

        Args:
            x1: array whose last axis holds the components of the first tensor.
                When ``weights_by_channel`` is set, axis ``-2`` is the channel axis.
            x2: array whose last axis holds the components of the second tensor
                (same channel axis as ``x1`` when ``weights_by_channel`` is set).

        Returns:
            The contracted array. Its last axis has
            ``max(x1.shape[-1], (lmax_out + 1) ** 2)`` components.
        """
        irreps_1 = [(l, (-1) ** l) for l in range(self.lmax1 + 1)]
        irreps_2 = [(l, (-1) ** l) for l in range(self.lmax2 + 1)]
        lmax_out = self.lmax_out if self.lmax_out is not None else self.lmax1
        irreps_out = [(l, (-1) ** l) for l in range(lmax_out + 1)]

        # Start offset of each irrep block along the flattened last axis.
        slices_1 = [0]
        for l, _ in irreps_1:
            slices_1.append(slices_1[-1] + 2 * l + 1)
        slices_2 = [0]
        for l, _ in irreps_2:
            slices_2.append(slices_2[-1] + 2 * l + 1)
        slices_out = [0]
        for l, _ in irreps_out:
            slices_out.append(slices_out[-1] + 2 * l + 1)

        # The output axis used to be hard-wired to x1.shape[-1], so any path to
        # lout > lmax1 was assigned into an empty slice and raised. It is now
        # wide enough for every requested output block, while x1.shape[-1] is
        # kept as a lower bound so results for lmax_out <= lmax1 are unchanged.
        nout = max(x1.shape[-1], slices_out[-1])
        shape = (x1.shape[-1], x2.shape[-1], nout)

        # One dense coupling tensor per path. The loop order fixes the meaning
        # of each entry of the "weights" parameter, so it must not change.
        w3js = []
        for iout, (lout, pout) in enumerate(irreps_out):
            scale = (2 * lout + 1) ** 0.5
            for i1, (l1, p1) in enumerate(irreps_1):
                for i2, (l2, p2) in enumerate(irreps_2):
                    if pout != p1 * p2 and not self.ignore_parity:
                        continue
                    if lout > l1 + l2 or lout < abs(l1 - l2):
                        continue
                    # NOTE(review): CG_SO3 is presumed to return the coupling
                    # coefficients with shape (2*l1+1, 2*l2+1, 2*lout+1) - confirm.
                    w3j = CG_SO3(l1, l2, lout)
                    w3j_full = np.zeros(shape)
                    w3j_full[
                        slices_1[i1] : slices_1[i1 + 1],
                        slices_2[i2] : slices_2[i2 + 1],
                        slices_out[iout] : slices_out[iout + 1],
                    ] = (
                        w3j * scale
                    )
                    w3js.append(w3j_full)
        npath = len(w3js)
        w3j = jnp.asarray(np.stack(w3js))

        if self.weights_by_channel:
            # Independent set of path weights for every channel.
            nchannels = x1.shape[-2]
            weights = self.param(
                "weights",
                jax.nn.initializers.normal(stddev=1.0 / npath**0.5),
                (nchannels, npath),
            )
            ww3j = jnp.einsum("np,pabc->nabc", weights, w3j)
            return jnp.einsum("...na,...nb,nabc->...nc", x1, x2, ww3j)

        weights = self.param(
            "weights",
            jax.nn.initializers.normal(stddev=1.0 / npath**0.5),
            (npath,),
        )
        ww3j = jnp.einsum("p,pabc->abc", weights, w3j)
        return jnp.einsum("...a,...b,abc->...c", x1, x2, ww3j)
Tensor product of two spherical harmonics, filtered to give back the irreps of the first input.
The maximum order of the output spherical tensor. If None, it is the same as lmax1.
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.
class ChannelMixing(nn.Module):
    """Linear mixing of input channels.

    FID: CHANNEL_MIXING
    """

    lmax: int
    """The maximum order of the spherical tensor."""
    nchannels: int
    """The number of input channels."""
    nchannels_out: Optional[int] = None
    """The number of output channels. If None, it is the same as nchannels."""
    input_key: Optional[str] = None
    """The key in the input dictionary that corresponds to the input tensor."""
    output_key: Optional[str] = None
    """The key in the output dictionary where the computed tensor will be stored."""
    squeeze: bool = False
    """Whether to squeeze the output tensor to remove the channel dimension."""

    FID: ClassVar[str] = "CHANNEL_MIXING"

    @nn.compact
    def __call__(self, inputs):
        """Mix the channel axis (axis ``-2``) of the input with a learned matrix.

        ``inputs`` is either the tensor itself (when ``input_key`` is None) or
        a dictionary from which the tensor is read. In the dictionary case a
        copy of the dictionary with the result added is returned, unless no
        output key can be determined, in which case the bare result is returned.
        """
        keyed = self.input_key is not None
        if keyed:
            x = inputs[self.input_key]
        else:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            x = inputs

        n_out = self.nchannels_out or self.nchannels
        mixing = self.param(
            "weights",
            jax.nn.initializers.normal(stddev=1.0 / self.nchannels**0.5),
            (n_out, self.nchannels),
        )
        out = jnp.einsum("ij,...jk->...ik", mixing, x)
        if self.squeeze and n_out == 1:
            # Drop the singleton channel axis.
            out = jnp.squeeze(out, axis=-2)

        if not keyed:
            return out
        # Fall back to the module name when no explicit output key is given.
        output_key = self.output_key if self.output_key is not None else self.name
        if output_key is None:
            return out
        return {**inputs, output_key: out}
Linear mixing of input channels.
FID: CHANNEL_MIXING
The number of output channels. If None, it is the same as nchannels.
The key in the input dictionary that corresponds to the input tensor.
The key in the output dictionary where the computed tensor will be stored.
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.
class ChannelMixingE3(nn.Module):
    """Linear mixing of input channels with different weight for each angular momentum.

    FID: CHANNEL_MIXING_E3
    """

    lmax: int
    """The maximum order of the spherical tensor."""
    nchannels: int
    """The number of input channels."""
    nchannels_out: Optional[int] = None
    """The number of output channels. If None, it is the same as nchannels."""
    input_key: Optional[str] = None
    """The key in the input dictionary that corresponds to the input tensor."""
    output_key: Optional[str] = None
    """The key in the output dictionary where the computed tensor will be stored."""
    squeeze: bool = False
    """Whether to squeeze the output tensor to remove the channel dimension."""

    FID: ClassVar[str] = "CHANNEL_MIXING_E3"

    @nn.compact
    def __call__(self, inputs):
        """Mix the channel axis (axis ``-2``) with one learned matrix per order ``l``.

        ``inputs`` is either the tensor itself (when ``input_key`` is None) or
        a dictionary from which the tensor is read. In the dictionary case a
        copy of the dictionary with the result added is returned, unless no
        output key can be determined, in which case the bare result is returned.
        """
        keyed = self.input_key is not None
        if keyed:
            x = inputs[self.input_key]
        else:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            x = inputs

        # Each order l owns 2*l+1 consecutive components of the last axis.
        multiplicities = np.array([2 * l + 1 for l in range(self.lmax + 1)])
        n_out = self.nchannels_out or self.nchannels
        per_l_weights = self.param(
            "weights",
            jax.nn.initializers.normal(stddev=1.0 / self.nchannels**0.5),
            (n_out, self.nchannels, self.lmax + 1),
        )
        # Expand the per-l weights so that all components of an order share one weight.
        mixing = jnp.repeat(per_l_weights, multiplicities, axis=-1)
        out = jnp.einsum("ijk,...jk->...ik", mixing, x)
        if self.squeeze and n_out == 1:
            # Drop the singleton channel axis.
            out = jnp.squeeze(out, axis=-2)

        if not keyed:
            return out
        # Fall back to the module name when no explicit output key is given.
        output_key = self.output_key if self.output_key is not None else self.name
        if output_key is None:
            return out
        return {**inputs, output_key: out}
Linear mixing of input channels with different weight for each angular momentum.
FID: CHANNEL_MIXING_E3
The number of output channels. If None, it is the same as nchannels.
The key in the input dictionary that corresponds to the input tensor.
The key in the output dictionary where the computed tensor will be stored.
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.
class SphericalToCartesian(nn.Module):
    """Convert spherical tensors to cartesian tensors.

    FID: SPHERICAL_TO_CARTESIAN
    """

    lmax: int
    """The maximum order of the spherical tensor. Only implemented for lmax up to 2."""
    input_key: Optional[str] = None
    """The key in the input dictionary that corresponds to the input tensor."""
    output_key: Optional[str] = None
    """The key in the output dictionary where the computed tensor will be stored."""

    FID: ClassVar[str] = "SPHERICAL_TO_CARTESIAN"

    @nn.compact
    def __call__(self, inputs) -> Any:
        """Run ``spherical_to_cartesian_tensor`` on the input tensor.

        ``inputs`` is either the tensor itself (when ``input_key`` is None) or
        a dictionary from which the tensor is read. In the dictionary case a
        copy of the dictionary is returned with the result stored under
        ``output_key``, or under ``input_key`` when no output key is set.
        """
        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            return spherical_to_cartesian_tensor(inputs, self.lmax)

        out = spherical_to_cartesian_tensor(inputs[self.input_key], self.lmax)
        # input_key is not None here, so the target key is always defined.
        target = self.output_key if self.output_key is not None else self.input_key
        return {**inputs, target: out}
Convert spherical tensors to cartesian tensors.
FID: SPHERICAL_TO_CARTESIAN
The key in the input dictionary that corresponds to the input tensor.
The key in the output dictionary where the computed tensor will be stored.
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.