fennol.utils

 1from .spherical_harmonics import CG_SO3, generate_spherical_harmonics
 2from .atomic_units import AtomicUnits
 3from typing import Dict, Any
 4import jax
 5import jax.numpy as jnp
 6
 7
 8def minmaxone(x, name=""):
 9    print(name, x.min(), x.max(), (x**2).mean() ** 0.5)
10
def minmaxone_jax(x, name=""):
    """Print the minimum, maximum and mean square of ``x`` via ``jax.debug.print``.

    The last value is the mean of ``x**2``; no square root is taken here.

    Args:
        x: Array-like object providing ``min()``, ``max()``, ``mean()`` and
            ``** 2``.
        name: Label substituted for ``{name}`` at the start of the line.
    """
    stats = {
        "min": x.min(),
        "max": x.max(),
        "mean": (x**2).mean(),
    }
    jax.debug.print("{name}  {min}  {max}  {mean}", name=name, **stats)
19
20
def mask_filter_1d(mask, max_size, *values_fill):
    """Pack the entries selected by ``mask`` at the front of fixed-size buffers.

    Args:
        mask: 1D array; entries that evaluate as true are kept.
        max_size: Length of the leading axis of every output buffer.
        *values_fill: ``(value, fill)`` pairs. Each ``value`` is scattered
            into a buffer that has ``value``'s shape with the leading axis
            replaced by ``max_size`` and is pre-filled with ``fill``.

    Returns:
        A tuple ``(outputs, scatter_idx, count)``: the list of packed buffers
        (one per pair), the destination index of every input entry
        (``max_size`` for unselected entries), and the last element of the
        running sum of ``mask`` (``0`` when ``mask`` is empty).
    """
    running_total = jnp.cumsum(mask, dtype=jnp.int32)
    # Unselected entries are sent to index ``max_size``, which is out of
    # bounds, so the scatter below discards them.
    scatter_idx = jnp.where(mask, running_total - 1, max_size)

    def _compact(value, fill):
        dims = list(value.shape)
        dims[0] = max_size
        padded = jnp.full(dims, fill, dtype=value.dtype)
        # mode="drop" ignores every out-of-bounds destination index.
        return padded.at[scatter_idx].set(value, mode="drop")

    outputs = [_compact(value, fill) for value, fill in values_fill]
    n_selected = running_total[-1] if running_total.size != 0 else 0
    return outputs, scatter_idx, n_selected
37
38
def deep_update(
    mapping: Dict[Any, Any], *updating_mappings: Dict[Any, Any]
) -> Dict[Any, Any]:
    """Return a copy of ``mapping`` recursively merged with the given updates.

    Updates are applied from left to right. When a key holds a dict both in
    the current result and in the update, the two dicts are merged
    recursively; in every other case the update's value replaces the
    current one. ``mapping`` itself is not modified.

    Args:
        mapping: Base dictionary.
        *updating_mappings: Dictionaries whose entries override the base.

    Returns:
        The merged dictionary.
    """
    merged = mapping.copy()
    for overrides in updating_mappings:
        for key, new_value in overrides.items():
            merge_nested = isinstance(new_value, dict) and isinstance(
                merged.get(key), dict
            )
            merged[key] = (
                deep_update(merged[key], new_value) if merge_nested else new_value
            )
    return merged
54
55
class Counter:
    """Two-level step counter.

    ``i`` counts calls to ``increment`` and wraps back to zero once it
    reaches ``nseg``; every wrap adds one to ``i_avg``.
    """

    def __init__(self, nseg, startsave=1):
        """Store ``nseg`` and ``startsave`` and start both counts at zero."""
        self.nseg = nseg
        self.startsave = startsave
        self.i = self.i_avg = 0

    @property
    def count(self):
        """Current value of the inner count ``i``."""
        return self.i

    @property
    def count_avg(self):
        """Current value of the outer count ``i_avg``."""
        return self.i_avg

    @property
    def nsample(self):
        """``count_avg - startsave + 1``, never less than 1."""
        effective = self.count_avg - self.startsave + 1
        return 1 if effective < 1 else effective

    @property
    def is_reset_step(self):
        """Whether the inner count is currently zero."""
        return self.count == 0

    def reset_avg(self):
        """Set the outer count back to zero."""
        self.i_avg = 0

    def reset_all(self):
        """Set both the inner and the outer count back to zero."""
        self.i = self.i_avg = 0

    def increment(self):
        """Advance the inner count, wrapping it and bumping the outer count at ``nseg``."""
        self.i += 1
        if self.i >= self.nseg:
            self.i, self.i_avg = 0, self.i_avg + 1
def minmaxone(x, name=''):
 9def minmaxone(x, name=""):
10    print(name, x.min(), x.max(), (x**2).mean() ** 0.5)
def minmaxone_jax(x, name=''):
def minmaxone_jax(x, name=""):
    """Print the minimum, maximum and mean square of ``x`` via ``jax.debug.print``.

    The last value is the mean of ``x**2``; no square root is taken here.

    Args:
        x: Array-like object providing ``min()``, ``max()``, ``mean()`` and
            ``** 2``.
        name: Label substituted for ``{name}`` at the start of the line.
    """
    stats = {
        "min": x.min(),
        "max": x.max(),
        "mean": (x**2).mean(),
    }
    jax.debug.print("{name}  {min}  {max}  {mean}", name=name, **stats)
def mask_filter_1d(mask, max_size, *values_fill):
def mask_filter_1d(mask, max_size, *values_fill):
    """Pack the entries selected by ``mask`` at the front of fixed-size buffers.

    Args:
        mask: 1D array; entries that evaluate as true are kept.
        max_size: Length of the leading axis of every output buffer.
        *values_fill: ``(value, fill)`` pairs. Each ``value`` is scattered
            into a buffer that has ``value``'s shape with the leading axis
            replaced by ``max_size`` and is pre-filled with ``fill``.

    Returns:
        A tuple ``(outputs, scatter_idx, count)``: the list of packed buffers
        (one per pair), the destination index of every input entry
        (``max_size`` for unselected entries), and the last element of the
        running sum of ``mask`` (``0`` when ``mask`` is empty).
    """
    running_total = jnp.cumsum(mask, dtype=jnp.int32)
    # Unselected entries are sent to index ``max_size``, which is out of
    # bounds, so the scatter below discards them.
    scatter_idx = jnp.where(mask, running_total - 1, max_size)

    def _compact(value, fill):
        dims = list(value.shape)
        dims[0] = max_size
        padded = jnp.full(dims, fill, dtype=value.dtype)
        # mode="drop" ignores every out-of-bounds destination index.
        return padded.at[scatter_idx].set(value, mode="drop")

    outputs = [_compact(value, fill) for value, fill in values_fill]
    n_selected = running_total[-1] if running_total.size != 0 else 0
    return outputs, scatter_idx, n_selected
def deep_update( mapping: Dict[Any, Any], *updating_mappings: Dict[Any, Any]) -> Dict[Any, Any]:
def deep_update(
    mapping: Dict[Any, Any], *updating_mappings: Dict[Any, Any]
) -> Dict[Any, Any]:
    """Return a copy of ``mapping`` recursively merged with the given updates.

    Updates are applied from left to right. When a key holds a dict both in
    the current result and in the update, the two dicts are merged
    recursively; in every other case the update's value replaces the
    current one. ``mapping`` itself is not modified.

    Args:
        mapping: Base dictionary.
        *updating_mappings: Dictionaries whose entries override the base.

    Returns:
        The merged dictionary.
    """
    merged = mapping.copy()
    for overrides in updating_mappings:
        for key, new_value in overrides.items():
            merge_nested = isinstance(new_value, dict) and isinstance(
                merged.get(key), dict
            )
            merged[key] = (
                deep_update(merged[key], new_value) if merge_nested else new_value
            )
    return merged
class Counter:
class Counter:
    """Two-level step counter.

    ``i`` counts calls to ``increment`` and wraps back to zero once it
    reaches ``nseg``; every wrap adds one to ``i_avg``.
    """

    def __init__(self, nseg, startsave=1):
        """Store ``nseg`` and ``startsave`` and start both counts at zero."""
        self.nseg = nseg
        self.startsave = startsave
        self.i = self.i_avg = 0

    @property
    def count(self):
        """Current value of the inner count ``i``."""
        return self.i

    @property
    def count_avg(self):
        """Current value of the outer count ``i_avg``."""
        return self.i_avg

    @property
    def nsample(self):
        """``count_avg - startsave + 1``, never less than 1."""
        effective = self.count_avg - self.startsave + 1
        return 1 if effective < 1 else effective

    @property
    def is_reset_step(self):
        """Whether the inner count is currently zero."""
        return self.count == 0

    def reset_avg(self):
        """Set the outer count back to zero."""
        self.i_avg = 0

    def reset_all(self):
        """Set both the inner and the outer count back to zero."""
        self.i = self.i_avg = 0

    def increment(self):
        """Advance the inner count, wrapping it and bumping the outer count at ``nseg``."""
        self.i += 1
        if self.i >= self.nseg:
            self.i, self.i_avg = 0, self.i_avg + 1
Counter(nseg, startsave=1)
58    def __init__(self, nseg, startsave=1):
59        self.i = 0
60        self.i_avg = 0
61        self.nseg = nseg
62        self.startsave = startsave
i
i_avg
nseg
startsave
count
64    @property
65    def count(self):
66        return self.i
count_avg
68    @property
69    def count_avg(self):
70        return self.i_avg
nsample
72    @property
73    def nsample(self):
74        return max(self.count_avg - self.startsave + 1, 1)
is_reset_step
76    @property
77    def is_reset_step(self):
78        return self.count == 0
def reset_avg(self):
80    def reset_avg(self):
81        self.i_avg = 0
def reset_all(self):
83    def reset_all(self):
84        self.i = 0
85        self.i_avg = 0
def increment(self):
87    def increment(self):
88        self.i = self.i + 1
89        if self.i >= self.nseg:
90            self.i = 0
91            self.i_avg = self.i_avg + 1