fennol.utils

 1from .spherical_harmonics import CG_SO3, generate_spherical_harmonics
 2from .atomic_units import AtomicUnits
 3from typing import Dict, Any
 4import jax
 5import jax.numpy as jnp
 6
 7
 8def minmaxone(x, name=""):
 9    print(name, x.min(), x.max(), (x**2).mean())
10
11
12def minmaxone_jax(x, name=""):
13    jax.debug.print(
14        "{name}  {min}  {max}  {mean}",
15        name=name,
16        min=x.min(),
17        max=x.max(),
18        mean=(x**2).mean(),
19    )
20
21
def mask_filter_1d(mask, max_size, *values_fill):
    """Compact the entries selected by ``mask`` into fixed-size padded arrays.

    Parameters
    ----------
    mask : array
        Selection mask (presumably 1D boolean; note ``jnp.cumsum`` without an
        axis flattens its input). Truthy entries are kept, in order.
    max_size : int
        Leading size of every output array. Selected entries whose rank is
        ``>= max_size`` are dropped.
    *values_fill : (value, fill) pairs
        ``value`` is an array whose leading axis is aligned with ``mask``;
        ``fill`` pads the slots that receive no selected entry.

    Returns
    -------
    outputs : list
        One array per pair, with shape ``value.shape`` except the leading
        axis, which is ``max_size``.
    scatter_idx : array
        Destination slot of each input entry (``max_size`` when unselected).
    count :
        Number of selected entries, possibly larger than ``max_size``.
        The plain int ``0`` when ``mask`` is empty.
    """
    positions = jnp.cumsum(mask, dtype=jnp.int32)
    # Unselected entries point at the out-of-range slot ``max_size``, and so
    # do not land anywhere thanks to ``mode="drop"``; the same mode discards
    # selected entries whose rank overflows ``max_size``.
    scatter_idx = jnp.where(mask, positions - 1, max_size)

    def _compact(value, fill):
        # Output keeps the trailing dims of ``value`` with leading dim max_size.
        out_shape = list(value.shape)
        out_shape[0] = max_size
        padded = jnp.full(out_shape, fill, dtype=value.dtype)
        return padded.at[scatter_idx].set(value, mode="drop")

    outputs = [_compact(value, fill) for value, fill in values_fill]
    count = positions[-1] if positions.size else 0
    return outputs, scatter_idx, count
38
39
def deep_update(
    mapping: Dict[Any, Any], *updating_mappings: Dict[Any, Any]
) -> Dict[Any, Any]:
    """Return a copy of ``mapping`` recursively merged with ``updating_mappings``.

    The updating mappings are applied left to right, so later ones win. When
    both the current value and the updating value for a key are dicts, they are
    merged recursively. Otherwise the updating value replaces the current one.

    ``mapping`` is not modified: it is shallow-copied, and nested dicts are
    copied (by the recursive call) only along merged paths. Values that are
    not merged are shared with the inputs, not copied.
    """
    merged = mapping.copy()
    for overrides in updating_mappings:
        for key, new_value in overrides.items():
            both_dicts = (
                isinstance(new_value, dict)
                and key in merged
                and isinstance(merged[key], dict)
            )
            merged[key] = (
                deep_update(merged[key], new_value) if both_dicts else new_value
            )
    return merged
55
56
class Counter:
    """Two-level step counter.

    ``count`` (stored in ``i``) is advanced by ``increment`` and wraps back to
    0 once it reaches ``nseg``. Every wrap increments ``count_avg`` (stored in
    ``i_avg``).

    Parameters
    ----------
    nseg : int
        Number of increments after which ``count`` wraps to 0.
    startsave : int, default 1
        Offset used by ``nsample``: ``count_avg - startsave + 1``, floored at 1.
    """

    def __init__(self, nseg, startsave=1):
        self.nseg = nseg
        self.startsave = startsave
        self.i = self.i_avg = 0

    @property
    def count(self):
        """Steps taken in the current segment."""
        return self.i

    @property
    def count_avg(self):
        """Number of times ``count`` has wrapped since the last reset."""
        return self.i_avg

    @property
    def nsample(self):
        """``count_avg - startsave + 1``, never less than 1."""
        raw = self.count_avg - self.startsave + 1
        return max(raw, 1)

    @property
    def is_reset_step(self):
        """True when ``count`` is 0 (initially, after a wrap, or after a reset)."""
        return self.count == 0

    def reset_avg(self):
        """Zero ``count_avg`` only; ``count`` is left untouched."""
        self.i_avg = 0

    def reset_all(self):
        """Zero both ``count`` and ``count_avg``."""
        self.i = 0
        self.i_avg = 0

    def increment(self):
        """Advance ``count`` by one; on reaching ``nseg``, wrap to 0 and bump ``count_avg``."""
        self.i += 1
        if self.i < self.nseg:
            return
        self.i = 0
        self.i_avg += 1
def minmaxone(x, name=""):
    """Print ``name`` followed by the min, the max and the mean of squares of ``x``.

    Plain ``print`` debug helper: evaluates ``x.min()``, ``x.max()`` and
    ``(x**2).mean()`` eagerly and writes them to stdout.
    """
    lowest = x.min()
    highest = x.max()
    mean_square = (x**2).mean()
    print(name, lowest, highest, mean_square)
def minmaxone_jax(x, name=''):
13def minmaxone_jax(x, name=""):
14    jax.debug.print(
15        "{name}  {min}  {max}  {mean}",
16        name=name,
17        min=x.min(),
18        max=x.max(),
19        mean=(x**2).mean(),
20    )
def mask_filter_1d(mask, max_size, *values_fill):
    """Compact the entries selected by ``mask`` into fixed-size padded arrays.

    Parameters
    ----------
    mask : array
        Selection mask (presumably 1D boolean; note ``jnp.cumsum`` without an
        axis flattens its input). Truthy entries are kept, in order.
    max_size : int
        Leading size of every output array. Selected entries whose rank is
        ``>= max_size`` are dropped.
    *values_fill : (value, fill) pairs
        ``value`` is an array whose leading axis is aligned with ``mask``;
        ``fill`` pads the slots that receive no selected entry.

    Returns
    -------
    outputs : list
        One array per pair, with shape ``value.shape`` except the leading
        axis, which is ``max_size``.
    scatter_idx : array
        Destination slot of each input entry (``max_size`` when unselected).
    count :
        Number of selected entries, possibly larger than ``max_size``.
        The plain int ``0`` when ``mask`` is empty.
    """
    positions = jnp.cumsum(mask, dtype=jnp.int32)
    # Unselected entries point at the out-of-range slot ``max_size``, and so
    # do not land anywhere thanks to ``mode="drop"``; the same mode discards
    # selected entries whose rank overflows ``max_size``.
    scatter_idx = jnp.where(mask, positions - 1, max_size)

    def _compact(value, fill):
        # Output keeps the trailing dims of ``value`` with leading dim max_size.
        out_shape = list(value.shape)
        out_shape[0] = max_size
        padded = jnp.full(out_shape, fill, dtype=value.dtype)
        return padded.at[scatter_idx].set(value, mode="drop")

    outputs = [_compact(value, fill) for value, fill in values_fill]
    count = positions[-1] if positions.size else 0
    return outputs, scatter_idx, count
def deep_update(
    mapping: Dict[Any, Any], *updating_mappings: Dict[Any, Any]
) -> Dict[Any, Any]:
    """Return a copy of ``mapping`` recursively merged with ``updating_mappings``.

    The updating mappings are applied left to right, so later ones win. When
    both the current value and the updating value for a key are dicts, they are
    merged recursively. Otherwise the updating value replaces the current one.

    ``mapping`` is not modified: it is shallow-copied, and nested dicts are
    copied (by the recursive call) only along merged paths. Values that are
    not merged are shared with the inputs, not copied.
    """
    merged = mapping.copy()
    for overrides in updating_mappings:
        for key, new_value in overrides.items():
            both_dicts = (
                isinstance(new_value, dict)
                and key in merged
                and isinstance(merged[key], dict)
            )
            merged[key] = (
                deep_update(merged[key], new_value) if both_dicts else new_value
            )
    return merged
class Counter:
    """Two-level step counter.

    ``count`` (stored in ``i``) is advanced by ``increment`` and wraps back to
    0 once it reaches ``nseg``. Every wrap increments ``count_avg`` (stored in
    ``i_avg``).

    Parameters
    ----------
    nseg : int
        Number of increments after which ``count`` wraps to 0.
    startsave : int, default 1
        Offset used by ``nsample``: ``count_avg - startsave + 1``, floored at 1.
    """

    def __init__(self, nseg, startsave=1):
        self.nseg = nseg
        self.startsave = startsave
        self.i = self.i_avg = 0

    @property
    def count(self):
        """Steps taken in the current segment."""
        return self.i

    @property
    def count_avg(self):
        """Number of times ``count`` has wrapped since the last reset."""
        return self.i_avg

    @property
    def nsample(self):
        """``count_avg - startsave + 1``, never less than 1."""
        raw = self.count_avg - self.startsave + 1
        return max(raw, 1)

    @property
    def is_reset_step(self):
        """True when ``count`` is 0 (initially, after a wrap, or after a reset)."""
        return self.count == 0

    def reset_avg(self):
        """Zero ``count_avg`` only; ``count`` is left untouched."""
        self.i_avg = 0

    def reset_all(self):
        """Zero both ``count`` and ``count_avg``."""
        self.i = 0
        self.i_avg = 0

    def increment(self):
        """Advance ``count`` by one; on reaching ``nseg``, wrap to 0 and bump ``count_avg``."""
        self.i += 1
        if self.i < self.nseg:
            return
        self.i = 0
        self.i_avg += 1