fennol.models.misc.uncertainty

  1import flax.linen as nn
  2from typing import Any, Sequence, Callable, Union, ClassVar,Optional
  3import jax.numpy as jnp
  4import jax
  5import numpy as np
  6from functools import partial
  7import dataclasses
  8from ...utils.periodic_table import PERIODIC_TABLE
  9
 10
 11class EnsembleStatistics(nn.Module):
 12    """Computes the mean and variance of an ensemble.
 13    
 14    FID: ENSEMBLE_STAT
 15    """
 16
 17    key: str
 18    """The key to access the input data from the `inputs` dictionary."""
 19    axis: int = -1
 20    """The axis along which to compute the mean and variance."""
 21    shuffle_ensemble: bool = False
 22    """Whether to shuffle the ensemble."""
 23    weights_key: Optional[str] = None
 24    mean_key: Optional[str] = None
 25
 26    FID: ClassVar[str] = "ENSEMBLE_STAT"
 27
 28    @nn.compact
 29    def __call__(self, inputs) -> Any:
 30        x = inputs[self.key]
 31        if self.weights_key is not None:
 32            weights = inputs[self.weights_key]
 33        else:
 34            weights = jnp.ones_like(x)
 35        weights = weights/jnp.sum(weights,axis=self.axis,keepdims=True)
 36        mean = jnp.sum(x*weights, axis=self.axis,keepdims=True)
 37
 38        nsamples = x.shape[self.axis]
 39        if nsamples == 1:
 40            var = jnp.zeros_like(mean)
 41        else:
 42            # var = jnp.var(x, axis=self.axis, ddof=1)
 43            var = jnp.sum(weights*(x-mean)**2,axis=self.axis)*(nsamples/(nsamples-1.))
 44        
 45        mean = jnp.squeeze(mean,axis=self.axis)
 46        output = {**inputs, self.key + "_mean": mean, self.key + "_var": var}
 47
 48        if self.mean_key is not None:
 49            output[self.mean_key] = mean
 50
 51        training = "training" in inputs.get("flags", {})
 52        if self.shuffle_ensemble and training and "rng_key" in inputs:
 53            key, subkey = jax.random.split(inputs["rng_key"])
 54            x = jax.random.permutation(subkey, x, axis=self.axis, independent=True)
 55            output[self.key] = x
 56            output["rng_key"] = key
 57        return output
 58
 59
 60class EnsembleShift(nn.Module):
 61    """Shifts the mean of an ensemble to match a reference tensor.
 62    
 63    FID: ENSEMBLE_SHIFT
 64    """
 65
 66    key: str
 67    """The key to access the input data from the `inputs` dictionary."""
 68    ref_key: str
 69    """The key to access the reference data from the `inputs` dictionary."""
 70    axis: int = -1
 71    """The axis of the ensemble."""
 72    output_key: Optional[str] = None
 73    """The key of the output ensemble. If None, the input key will be used."""
 74
 75
 76    FID: ClassVar[str] = "ENSEMBLE_SHIFT"
 77
 78    @nn.compact
 79    def __call__(self, inputs) -> Any:
 80        x = inputs[self.key]
 81        mean = jnp.mean(x, axis=self.axis, keepdims=True)
 82        ref = inputs[self.ref_key].reshape(mean.shape)
 83        x = x - mean + ref
 84        output_key = self.key if self.output_key is None else self.output_key
 85        return {**inputs, output_key: x}
 86
 87
 88class ConstrainEvidence(nn.Module):
 89    """ Constrain the parameters of an evidential model
 90
 91    FID: CONSTRAIN_EVIDENCE
 92
 93    ### References
 94    - Amini et al, Deep Evidential Regression, NeurIPS 2020 (https://arxiv.org/abs/1910.02600)
 95    - Meinert et al, Multivariate Deep Evidential Regression, (https://arxiv.org/pdf/2104.06135.pdf)
 96
 97    """
 98    key: str
 99    """The key to access the input data from the `inputs` dictionary."""
100    output_key: Optional[str] = None
101    """The key to use for the constrained paramters in the `output` dictionary."""
102    beta_scale: Union[str, float] = 1.0
103    """The scale of the beta parameter."""
104    alpha_init: float = 2.0
105    """The initial value of the alpha parameter."""
106    nu_init: float = 1.0
107    """The initial value of the nu parameter."""
108    chemical_shift: Optional[float] = None
109    """The initial chemical shift of evidence."""
110    trainable_beta: bool = False
111    """Whether the beta parameter is trainable."""
112    constant_beta: bool = False
113    """Whether the beta parameter is a constant."""
114    trainable_alpha: bool = False
115    """Whether the alpha parameter is trainable."""
116    constant_alpha: bool = False
117    """Whether the alpha parameter is a constant."""
118    trainable_nu: bool = False
119    """Whether the nu parameter is trainable."""
120    constant_nu: bool = False
121    """Whether the nu parameter is a constant."""
122    nualpha_coupling: Optional[float] = None
123    """The coupling constant between nu and alpha."""
124    graph_key: Optional[str] = None
125    """The key to access the graph data from the `inputs` dictionary. 
126        Only used to obtain an environment-dependent chemical shift."""
127    self_weight: float = 10.0
128    """The weight of the self interaction in environment-dependent chemical shift."""
129    # target_dim: Optional[int] = None
130
131    FID: ClassVar[str] = "CONSTRAIN_EVIDENCE"
132
133    @nn.compact
134    def __call__(self, inputs) -> Any:
135        x = inputs[self.key]
136        dim = 3
137        assert x.shape[-1] == dim, f"Dimension {-1} must be {dim}, got {x.shape[-1]}"
138        nui, alphai, betai = jnp.split(x, 3, axis=-1)
139
140        if self.chemical_shift is not None:
141            # modify nu evidence depending on the chemical species 
142            #   this is useful for predicting high uncertainties for unknown species
143            #   in known geometries
144            nu_shift = jnp.abs(
145                self.param(
146                    "nu_shift",
147                    lambda key, shape: jnp.ones(shape) * self.chemical_shift,
148                    (len(PERIODIC_TABLE),),
149                )
150            )[inputs["species"]]
151            if self.graph_key is not None:
152                graph = inputs[self.graph_key]
153                edge_dst = graph["edge_dst"]
154                edge_src = graph["edge_src"]
155                switch = graph["switch"]
156                nushift_neigh = jax.ops.segment_sum(
157                    nu_shift[edge_dst] * switch, edge_src, nu_shift.shape[0]
158                )
159                norm = self.self_weight + jax.ops.segment_sum(
160                    switch, edge_src, nu_shift.shape[0]
161                )
162                nu_shift = (self.self_weight * nu_shift + nushift_neigh) / norm
163            nu_shift = nu_shift[:, None]
164        else:
165            nu_shift = 1.0
166
167        if self.nualpha_coupling is None:
168            # couple nu and alpha to remove overparameterization
169            #  of the evidential model (see Meinert et al)
170            if self.constant_alpha:
171                if self.trainable_alpha:
172                    alphai = 1 + jnp.abs(
173                        self.param(
174                            "alpha",
175                            lambda key: jnp.asarray(self.alpha_init - 1),
176                        )
177                    ) * jnp.ones_like(alphai)
178                else:
179                    assert self.alpha_init > 1, "alpha_init must be >1"
180                    alphai = self.alpha_init * jnp.ones_like(alphai)
181            elif self.constant_nu:
182                alphai = 1 + (self.alpha_init - 1) * nu_shift * jax.nn.softplus(alphai)
183            else:
184                alphai = 1 + (self.alpha_init - 1) * jax.nn.softplus(alphai)
185
186            if self.constant_nu:
187                if self.trainable_nu:
188                    nui = 1.0e-5 + jnp.abs(
189                        self.param(
190                            "nu",
191                            lambda key: jnp.asarray(self.nu_init),
192                        )
193                    ) * jnp.ones_like(nui)
194                else:
195                    nui = self.nu_init * jnp.ones_like(nui)
196            else:
197                nui = 1.0e-5 + self.nu_init * nu_shift * jax.nn.softplus(nui)
198        else:
199            alphai = 1 + nu_shift * jax.nn.softplus(alphai)
200            if self.trainable_nu:
201                nualpha_coupling = jnp.abs(
202                    self.param(
203                        "nualpha_coupling",
204                        lambda key: jnp.asarray(self.nualpha_coupling),
205                    )
206                )
207            else:
208                nualpha_coupling = self.nualpha_coupling
209            nui = nualpha_coupling * 2 * alphai
210
211        if self.constant_beta:
212            if self.trainable_beta:
213                betai = (
214                    jax.nn.softplus(
215                        self.param(
216                            "beta",
217                            lambda key: jnp.asarray(0.0),
218                        )
219                    )
220                    / np.log(2.0)
221                ) * jnp.ones_like(alphai)
222            else:
223                betai = jnp.ones_like(alphai)
224        else:
225            betai = jax.nn.softplus(betai)
226
227        betai = self.beta_scale * betai
228
229        output = jnp.concatenate([nui, alphai, betai], axis=-1)
230        # if self.target_dim is not None:
231        #     output = output[...,None,:].repeat(self.target_dim, axis=-2)
232
233        output_key = self.key if self.output_key is None else self.output_key
234        out = {
235            **inputs,
236            output_key: output,
237            output_key + "_var": jnp.squeeze(betai / (nui * (alphai - 1)), axis=-1),
238            output_key + "_aleatoric": jnp.squeeze(betai / (alphai - 1), axis=-1),
239            output_key
240            + "_wst2": jnp.squeeze(betai * (1 + nui) / (alphai * nui), axis=-1),
241        }
242        return out
class EnsembleStatistics(flax.linen.module.Module):
class EnsembleStatistics(nn.Module):
    """Computes the (optionally weighted) mean and variance of an ensemble.

    The ensemble members are stacked along `axis` of `inputs[key]`. The
    module adds `key + "_mean"` and `key + "_var"` to the output dictionary;
    both have the ensemble axis removed and therefore share the same shape.

    FID: ENSEMBLE_STAT
    """

    key: str
    """The key to access the input data from the `inputs` dictionary."""
    axis: int = -1
    """The axis along which to compute the mean and variance."""
    shuffle_ensemble: bool = False
    """Whether to shuffle the ensemble (only in training mode and when `rng_key` is provided)."""
    weights_key: Optional[str] = None
    """Optional key of per-member weights in `inputs`; they are normalized along `axis`.
    If None, all members are weighted equally."""
    mean_key: Optional[str] = None
    """If not None, the ensemble mean is additionally stored under this key."""

    FID: ClassVar[str] = "ENSEMBLE_STAT"

    @nn.compact
    def __call__(self, inputs) -> Any:
        x = inputs[self.key]
        if self.weights_key is not None:
            # assumes the weights broadcast against x — TODO confirm with callers
            weights = inputs[self.weights_key]
        else:
            weights = jnp.ones_like(x)
        # Normalize so the weights sum to one along the ensemble axis.
        weights = weights / jnp.sum(weights, axis=self.axis, keepdims=True)
        # Keep the ensemble axis so the mean broadcasts against x below.
        mean = jnp.sum(x * weights, axis=self.axis, keepdims=True)

        nsamples = x.shape[self.axis]
        if nsamples == 1:
            # A single member has no spread, and the n/(n-1) factor would divide
            # by zero. Reduce the ensemble axis so `_var` has the same shape as
            # `_mean` (and as the n > 1 branch).
            var = jnp.zeros_like(jnp.squeeze(mean, axis=self.axis))
        else:
            # Bessel-type correction n/(n-1), exact for uniform weights.
            # NOTE(review): for non-uniform weights the unbiased correction
            # would be 1 / (1 - sum(w**2)); the uniform factor is kept here.
            var = jnp.sum(weights * (x - mean) ** 2, axis=self.axis) * (
                nsamples / (nsamples - 1.0)
            )

        mean = jnp.squeeze(mean, axis=self.axis)
        output = {**inputs, self.key + "_mean": mean, self.key + "_var": var}

        if self.mean_key is not None:
            output[self.mean_key] = mean

        training = "training" in inputs.get("flags", {})
        if self.shuffle_ensemble and training and "rng_key" in inputs:
            rng_key, subkey = jax.random.split(inputs["rng_key"])
            # Each vector along `axis` is permuted independently.
            x = jax.random.permutation(subkey, x, axis=self.axis, independent=True)
            output[self.key] = x
            # Hand the advanced PRNG key back so later modules do not reuse it.
            output["rng_key"] = rng_key
        return output

Computes the mean and variance of an ensemble.

FID: ENSEMBLE_STAT

EnsembleStatistics( key: str, axis: int = -1, shuffle_ensemble: bool = False, weights_key: Optional[str] = None, mean_key: Optional[str] = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
key: str

The key to access the input data from the inputs dictionary.

axis: int = -1

The axis along which to compute the mean and variance.

shuffle_ensemble: bool = False

Whether to shuffle the ensemble.

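weights_key: Optional[str] = None

Optional key of per-member weights in the inputs dictionary; the weights are normalized along axis. If None, all ensemble members are weighted equally.

mean_key: Optional[str] = None

If not None, the ensemble mean is additionally stored under this key.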
FID: ClassVar[str] = 'ENSEMBLE_STAT'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class EnsembleShift(flax.linen.module.Module):
class EnsembleShift(nn.Module):
    """Shifts the mean of an ensemble to match a reference tensor.

    Every member along `axis` is translated by the same offset, so the
    ensemble average becomes `inputs[ref_key]` while the spread around the
    average is unchanged.

    FID: ENSEMBLE_SHIFT
    """

    key: str
    """The key to access the input data from the `inputs` dictionary."""
    ref_key: str
    """The key to access the reference data from the `inputs` dictionary."""
    axis: int = -1
    """The axis of the ensemble."""
    output_key: Optional[str] = None
    """The key of the output ensemble. If None, the input key will be used."""

    FID: ClassVar[str] = "ENSEMBLE_SHIFT"

    @nn.compact
    def __call__(self, inputs) -> Any:
        ensemble = inputs[self.key]
        # Per-slice ensemble average; the ensemble axis is kept (size 1) for broadcasting.
        center = jnp.mean(ensemble, axis=self.axis, keepdims=True)
        # The reference must contain exactly one value per ensemble slice.
        target = inputs[self.ref_key].reshape(center.shape)
        if self.output_key is not None:
            destination = self.output_key
        else:
            destination = self.key
        shifted = ensemble - center + target
        return {**inputs, destination: shifted}

Shifts the mean of an ensemble to match a reference tensor.

FID: ENSEMBLE_SHIFT

EnsembleShift( key: str, ref_key: str, axis: int = -1, output_key: Optional[str] = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
key: str

The key to access the input data from the inputs dictionary.

ref_key: str

The key to access the reference data from the inputs dictionary.

axis: int = -1

The axis of the ensemble.

output_key: Optional[str] = None

The key of the output ensemble. If None, the input key will be used.

FID: ClassVar[str] = 'ENSEMBLE_SHIFT'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class ConstrainEvidence(flax.linen.module.Module):
 89class ConstrainEvidence(nn.Module):
 90    """ Constrain the parameters of an evidential model
 91
 92    FID: CONSTRAIN_EVIDENCE
 93
 94    ### References
 95    - Amini et al, Deep Evidential Regression, NeurIPS 2020 (https://arxiv.org/abs/1910.02600)
 96    - Meinert et al, Multivariate Deep Evidential Regression, (https://arxiv.org/pdf/2104.06135.pdf)
 97
 98    """
 99    key: str
100    """The key to access the input data from the `inputs` dictionary."""
101    output_key: Optional[str] = None
102    """The key to use for the constrained paramters in the `output` dictionary."""
103    beta_scale: Union[str, float] = 1.0
104    """The scale of the beta parameter."""
105    alpha_init: float = 2.0
106    """The initial value of the alpha parameter."""
107    nu_init: float = 1.0
108    """The initial value of the nu parameter."""
109    chemical_shift: Optional[float] = None
110    """The initial chemical shift of evidence."""
111    trainable_beta: bool = False
112    """Whether the beta parameter is trainable."""
113    constant_beta: bool = False
114    """Whether the beta parameter is a constant."""
115    trainable_alpha: bool = False
116    """Whether the alpha parameter is trainable."""
117    constant_alpha: bool = False
118    """Whether the alpha parameter is a constant."""
119    trainable_nu: bool = False
120    """Whether the nu parameter is trainable."""
121    constant_nu: bool = False
122    """Whether the nu parameter is a constant."""
123    nualpha_coupling: Optional[float] = None
124    """The coupling constant between nu and alpha."""
125    graph_key: Optional[str] = None
126    """The key to access the graph data from the `inputs` dictionary. 
127        Only used to obtain an environment-dependent chemical shift."""
128    self_weight: float = 10.0
129    """The weight of the self interaction in environment-dependent chemical shift."""
130    # target_dim: Optional[int] = None
131
132    FID: ClassVar[str] = "CONSTRAIN_EVIDENCE"
133
134    @nn.compact
135    def __call__(self, inputs) -> Any:
136        x = inputs[self.key]
137        dim = 3
138        assert x.shape[-1] == dim, f"Dimension {-1} must be {dim}, got {x.shape[-1]}"
139        nui, alphai, betai = jnp.split(x, 3, axis=-1)
140
141        if self.chemical_shift is not None:
142            # modify nu evidence depending on the chemical species 
143            #   this is useful for predicting high uncertainties for unknown species
144            #   in known geometries
145            nu_shift = jnp.abs(
146                self.param(
147                    "nu_shift",
148                    lambda key, shape: jnp.ones(shape) * self.chemical_shift,
149                    (len(PERIODIC_TABLE),),
150                )
151            )[inputs["species"]]
152            if self.graph_key is not None:
153                graph = inputs[self.graph_key]
154                edge_dst = graph["edge_dst"]
155                edge_src = graph["edge_src"]
156                switch = graph["switch"]
157                nushift_neigh = jax.ops.segment_sum(
158                    nu_shift[edge_dst] * switch, edge_src, nu_shift.shape[0]
159                )
160                norm = self.self_weight + jax.ops.segment_sum(
161                    switch, edge_src, nu_shift.shape[0]
162                )
163                nu_shift = (self.self_weight * nu_shift + nushift_neigh) / norm
164            nu_shift = nu_shift[:, None]
165        else:
166            nu_shift = 1.0
167
168        if self.nualpha_coupling is None:
169            # couple nu and alpha to remove overparameterization
170            #  of the evidential model (see Meinert et al)
171            if self.constant_alpha:
172                if self.trainable_alpha:
173                    alphai = 1 + jnp.abs(
174                        self.param(
175                            "alpha",
176                            lambda key: jnp.asarray(self.alpha_init - 1),
177                        )
178                    ) * jnp.ones_like(alphai)
179                else:
180                    assert self.alpha_init > 1, "alpha_init must be >1"
181                    alphai = self.alpha_init * jnp.ones_like(alphai)
182            elif self.constant_nu:
183                alphai = 1 + (self.alpha_init - 1) * nu_shift * jax.nn.softplus(alphai)
184            else:
185                alphai = 1 + (self.alpha_init - 1) * jax.nn.softplus(alphai)
186
187            if self.constant_nu:
188                if self.trainable_nu:
189                    nui = 1.0e-5 + jnp.abs(
190                        self.param(
191                            "nu",
192                            lambda key: jnp.asarray(self.nu_init),
193                        )
194                    ) * jnp.ones_like(nui)
195                else:
196                    nui = self.nu_init * jnp.ones_like(nui)
197            else:
198                nui = 1.0e-5 + self.nu_init * nu_shift * jax.nn.softplus(nui)
199        else:
200            alphai = 1 + nu_shift * jax.nn.softplus(alphai)
201            if self.trainable_nu:
202                nualpha_coupling = jnp.abs(
203                    self.param(
204                        "nualpha_coupling",
205                        lambda key: jnp.asarray(self.nualpha_coupling),
206                    )
207                )
208            else:
209                nualpha_coupling = self.nualpha_coupling
210            nui = nualpha_coupling * 2 * alphai
211
212        if self.constant_beta:
213            if self.trainable_beta:
214                betai = (
215                    jax.nn.softplus(
216                        self.param(
217                            "beta",
218                            lambda key: jnp.asarray(0.0),
219                        )
220                    )
221                    / np.log(2.0)
222                ) * jnp.ones_like(alphai)
223            else:
224                betai = jnp.ones_like(alphai)
225        else:
226            betai = jax.nn.softplus(betai)
227
228        betai = self.beta_scale * betai
229
230        output = jnp.concatenate([nui, alphai, betai], axis=-1)
231        # if self.target_dim is not None:
232        #     output = output[...,None,:].repeat(self.target_dim, axis=-2)
233
234        output_key = self.key if self.output_key is None else self.output_key
235        out = {
236            **inputs,
237            output_key: output,
238            output_key + "_var": jnp.squeeze(betai / (nui * (alphai - 1)), axis=-1),
239            output_key + "_aleatoric": jnp.squeeze(betai / (alphai - 1), axis=-1),
240            output_key
241            + "_wst2": jnp.squeeze(betai * (1 + nui) / (alphai * nui), axis=-1),
242        }
243        return out

Constrain the parameters of an evidential model

FID: CONSTRAIN_EVIDENCE

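References

- Amini et al, Deep Evidential Regression, NeurIPS 2020 (https://arxiv.org/abs/1910.02600)
- Meinert et al, Multivariate Deep Evidential Regression (https://arxiv.org/pdf/2104.06135.pdf)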

ConstrainEvidence( key: str, output_key: Optional[str] = None, beta_scale: Union[str, float] = 1.0, alpha_init: float = 2.0, nu_init: float = 1.0, chemical_shift: Optional[float] = None, trainable_beta: bool = False, constant_beta: bool = False, trainable_alpha: bool = False, constant_alpha: bool = False, trainable_nu: bool = False, constant_nu: bool = False, nualpha_coupling: Optional[float] = None, graph_key: Optional[str] = None, self_weight: float = 10.0, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
key: str

The key to access the input data from the inputs dictionary.

output_key: Optional[str] = None

The key to use for the constrained parameters in the output dictionary. If None, the input key will be used.

beta_scale: Union[str, float] = 1.0

The scale of the beta parameter.

alpha_init: float = 2.0

The initial value of the alpha parameter.

nu_init: float = 1.0

The initial value of the nu parameter.

chemical_shift: Optional[float] = None

The initial chemical shift of evidence.

trainable_beta: bool = False

Whether the beta parameter is trainable.

constant_beta: bool = False

Whether the beta parameter is a constant.

trainable_alpha: bool = False

Whether the alpha parameter is trainable.

constant_alpha: bool = False

Whether the alpha parameter is a constant.

trainable_nu: bool = False

Whether the nu parameter is trainable.

constant_nu: bool = False

Whether the nu parameter is a constant.

nualpha_coupling: Optional[float] = None

The coupling constant between nu and alpha.

graph_key: Optional[str] = None

The key to access the graph data from the inputs dictionary. Only used to obtain an environment-dependent chemical shift.

self_weight: float = 10.0

The weight of the self interaction in environment-dependent chemical shift.

FID: ClassVar[str] = 'CONSTRAIN_EVIDENCE'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None