fennol.models.misc.uncertainty
```python
import flax.linen as nn
from typing import Any, Sequence, Callable, Union, ClassVar, Optional
import jax.numpy as jnp
import jax
import numpy as np
from functools import partial
import dataclasses
from ...utils.periodic_table import PERIODIC_TABLE


class EnsembleStatistics(nn.Module):
    """Computes the mean and variance of an ensemble.

    FID: ENSEMBLE_STAT
    """

    key: str
    """The key to access the input data from the `inputs` dictionary."""
    axis: int = -1
    """The axis along which to compute the mean and variance."""
    shuffle_ensemble: bool = False
    """Whether to shuffle the ensemble along `axis` during training
    (requires an `rng_key` in the inputs)."""
    weights_key: Optional[str] = None
    """The key to access optional ensemble weights from the `inputs` dictionary.
    If None, uniform weights are used."""
    mean_key: Optional[str] = None
    """If not None, an additional output key under which the ensemble mean is stored."""

    FID: ClassVar[str] = "ENSEMBLE_STAT"

    @nn.compact
    def __call__(self, inputs) -> Any:
        x = inputs[self.key]
        if self.weights_key is not None:
            weights = inputs[self.weights_key]
        else:
            weights = jnp.ones_like(x)
        weights = weights / jnp.sum(weights, axis=self.axis, keepdims=True)
        mean = jnp.sum(x * weights, axis=self.axis, keepdims=True)

        nsamples = x.shape[self.axis]
        if nsamples == 1:
            # shape-consistent with the general case below (ensemble axis reduced)
            var = jnp.zeros_like(jnp.squeeze(mean, axis=self.axis))
        else:
            # for uniform weights this equals jnp.var(x, axis=self.axis, ddof=1)
            var = jnp.sum(weights * (x - mean) ** 2, axis=self.axis) * (
                nsamples / (nsamples - 1.0)
            )

        mean = jnp.squeeze(mean, axis=self.axis)
        output = {**inputs, self.key + "_mean": mean, self.key + "_var": var}

        if self.mean_key is not None:
            output[self.mean_key] = mean

        training = "training" in inputs.get("flags", {})
        if self.shuffle_ensemble and training and "rng_key" in inputs:
            key, subkey = jax.random.split(inputs["rng_key"])
            x = jax.random.permutation(subkey, x, axis=self.axis, independent=True)
            output[self.key] = x
            output["rng_key"] = key
        return output


class EnsembleShift(nn.Module):
    """Shifts the mean of an ensemble to match a reference tensor.

    FID: ENSEMBLE_SHIFT
    """

    key: str
    """The key to access the input data from the `inputs` dictionary."""
    ref_key: str
    """The key to access the reference data from the `inputs` dictionary."""
    axis: int = -1
    """The axis of the ensemble."""
    output_key: Optional[str] = None
    """The key of the output ensemble. If None, the input key will be used."""

    FID: ClassVar[str] = "ENSEMBLE_SHIFT"

    @nn.compact
    def __call__(self, inputs) -> Any:
        x = inputs[self.key]
        mean = jnp.mean(x, axis=self.axis, keepdims=True)
        ref = inputs[self.ref_key].reshape(mean.shape)
        x = x - mean + ref
        output_key = self.key if self.output_key is None else self.output_key
        return {**inputs, output_key: x}


class ConstrainEvidence(nn.Module):
    """Constrains the parameters of an evidential model.

    FID: CONSTRAIN_EVIDENCE

    ### References
    - Amini et al., Deep Evidential Regression, NeurIPS 2020 (https://arxiv.org/abs/1910.02600)
    - Meinert et al., Multivariate Deep Evidential Regression (https://arxiv.org/pdf/2104.06135.pdf)
    """

    key: str
    """The key to access the input data from the `inputs` dictionary."""
    output_key: Optional[str] = None
    """The key to use for the constrained parameters in the output dictionary."""
    beta_scale: Union[str, float] = 1.0
    """The scale of the beta parameter."""
    alpha_init: float = 2.0
    """The initial value of the alpha parameter."""
    nu_init: float = 1.0
    """The initial value of the nu parameter."""
    chemical_shift: Optional[float] = None
    """The initial chemical shift of evidence."""
    trainable_beta: bool = False
    """Whether the beta parameter is trainable."""
    constant_beta: bool = False
    """Whether the beta parameter is a constant."""
    trainable_alpha: bool = False
    """Whether the alpha parameter is trainable."""
    constant_alpha: bool = False
    """Whether the alpha parameter is a constant."""
    trainable_nu: bool = False
    """Whether the nu parameter is trainable."""
    constant_nu: bool = False
    """Whether the nu parameter is a constant."""
    nualpha_coupling: Optional[float] = None
    """The coupling constant between nu and alpha."""
    graph_key: Optional[str] = None
    """The key to access the graph data from the `inputs` dictionary.
    Only used to obtain an environment-dependent chemical shift."""
    self_weight: float = 10.0
    """The weight of the self interaction in the environment-dependent chemical shift."""
    # target_dim: Optional[int] = None

    FID: ClassVar[str] = "CONSTRAIN_EVIDENCE"

    @nn.compact
    def __call__(self, inputs) -> Any:
        x = inputs[self.key]
        dim = 3
        assert x.shape[-1] == dim, f"Last dimension must be {dim}, got {x.shape[-1]}"
        nui, alphai, betai = jnp.split(x, 3, axis=-1)

        if self.chemical_shift is not None:
            # modify the nu evidence depending on the chemical species;
            # this is useful for predicting high uncertainties for unknown
            # species in known geometries
            nu_shift = jnp.abs(
                self.param(
                    "nu_shift",
                    lambda key, shape: jnp.ones(shape) * self.chemical_shift,
                    (len(PERIODIC_TABLE),),
                )
            )[inputs["species"]]
            if self.graph_key is not None:
                graph = inputs[self.graph_key]
                edge_dst = graph["edge_dst"]
                edge_src = graph["edge_src"]
                switch = graph["switch"]
                nushift_neigh = jax.ops.segment_sum(
                    nu_shift[edge_dst] * switch, edge_src, nu_shift.shape[0]
                )
                norm = self.self_weight + jax.ops.segment_sum(
                    switch, edge_src, nu_shift.shape[0]
                )
                nu_shift = (self.self_weight * nu_shift + nushift_neigh) / norm
            nu_shift = nu_shift[:, None]
        else:
            nu_shift = 1.0

        if self.nualpha_coupling is None:
            if self.constant_alpha:
                if self.trainable_alpha:
                    alphai = 1 + jnp.abs(
                        self.param(
                            "alpha",
                            lambda key: jnp.asarray(self.alpha_init - 1),
                        )
                    ) * jnp.ones_like(alphai)
                else:
                    assert self.alpha_init > 1, "alpha_init must be > 1"
                    alphai = self.alpha_init * jnp.ones_like(alphai)
            elif self.constant_nu:
                alphai = 1 + (self.alpha_init - 1) * nu_shift * jax.nn.softplus(alphai)
            else:
                alphai = 1 + (self.alpha_init - 1) * jax.nn.softplus(alphai)

            if self.constant_nu:
                if self.trainable_nu:
                    nui = 1.0e-5 + jnp.abs(
                        self.param(
                            "nu",
                            lambda key: jnp.asarray(self.nu_init),
                        )
                    ) * jnp.ones_like(nui)
                else:
                    nui = self.nu_init * jnp.ones_like(nui)
            else:
                nui = 1.0e-5 + self.nu_init * nu_shift * jax.nn.softplus(nui)
        else:
            # couple nu and alpha to remove the overparameterization
            # of the evidential model (see Meinert et al.)
            alphai = 1 + nu_shift * jax.nn.softplus(alphai)
            if self.trainable_nu:
                nualpha_coupling = jnp.abs(
                    self.param(
                        "nualpha_coupling",
                        lambda key: jnp.asarray(self.nualpha_coupling),
                    )
                )
            else:
                nualpha_coupling = self.nualpha_coupling
            nui = nualpha_coupling * 2 * alphai

        if self.constant_beta:
            if self.trainable_beta:
                betai = (
                    jax.nn.softplus(
                        self.param(
                            "beta",
                            lambda key: jnp.asarray(0.0),
                        )
                    )
                    / np.log(2.0)
                ) * jnp.ones_like(alphai)
            else:
                betai = jnp.ones_like(alphai)
        else:
            betai = jax.nn.softplus(betai)

        betai = self.beta_scale * betai

        output = jnp.concatenate([nui, alphai, betai], axis=-1)
        # if self.target_dim is not None:
        #     output = output[..., None, :].repeat(self.target_dim, axis=-2)

        output_key = self.key if self.output_key is None else self.output_key
        out = {
            **inputs,
            output_key: output,
            output_key + "_var": jnp.squeeze(betai / (nui * (alphai - 1)), axis=-1),
            output_key + "_aleatoric": jnp.squeeze(betai / (alphai - 1), axis=-1),
            output_key + "_wst2": jnp.squeeze(betai * (1 + nui) / (alphai * nui), axis=-1),
        }
        return out
```
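To make the input/output contract concrete, here is a minimal standalone sketch of `EnsembleStatistics` applied to a toy ensemble. The key name `"energies"`, the bare input dictionary, and the direct `init`/`apply` calls are illustrative assumptions, not the FeNNol training pipeline.

```python
import jax
import jax.numpy as jnp
from fennol.models.misc.uncertainty import EnsembleStatistics

# Hypothetical ensemble: one batch entry with 5 member predictions.
stat = EnsembleStatistics(key="energies", axis=-1)
inputs = {"energies": jnp.array([[1.0, 1.2, 0.9, 1.1, 1.0]])}

variables = stat.init(jax.random.PRNGKey(0), inputs)  # no parameters here
outputs = stat.apply(variables, inputs)

print(outputs["energies_mean"])  # shape (1,): (weighted) ensemble mean
print(outputs["energies_var"])   # shape (1,): unbiased ensemble variance
```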
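A similar hedged sketch for `EnsembleShift`: each member is translated so that the ensemble mean matches the reference tensor, while the spread between members is preserved. Key names are again hypothetical.

```python
import jax
import jax.numpy as jnp
from fennol.models.misc.uncertainty import EnsembleShift

shift = EnsembleShift(key="energies", ref_key="energies_ref")
inputs = {
    "energies": jnp.array([[1.0, 1.2, 0.9, 1.1, 1.0]]),  # (batch, ensemble)
    "energies_ref": jnp.array([2.0]),  # target mean per batch entry
}

variables = shift.init(jax.random.PRNGKey(0), inputs)
outputs = shift.apply(variables, inputs)

# The ensemble mean now equals the reference; deviations are unchanged.
print(jnp.mean(outputs["energies"], axis=-1))  # [2.]
```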
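The three auxiliary outputs assembled in `out` above follow the standard variance decomposition of the Normal-Inverse-Gamma evidential model of Amini et al.; `_wst2` reads as the squared scale of the Student-t predictive distribution with 2α degrees of freedom, an interpretation inferred from the formula rather than stated in the source:

```latex
% Variance decomposition of the Normal-Inverse-Gamma model (Amini et al., 2020),
% matching the "_var", "_aleatoric" and "_wst2" outputs computed above.
\begin{align*}
  \text{epistemic variance (\_var):} \quad
    \mathrm{Var}[\mu] &= \frac{\beta}{\nu\,(\alpha - 1)} \\
  \text{aleatoric variance (\_aleatoric):} \quad
    \mathbb{E}[\sigma^2] &= \frac{\beta}{\alpha - 1} \\
  \text{predictive Student-}t\text{ scale (\_wst2):} \quad
    s^2 &= \frac{\beta\,(1 + \nu)}{\alpha\,\nu}
\end{align*}
```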