fennol.utils
# Small helpers shared across the package: re-exports of spherical-harmonics
# and atomic-unit utilities, debug printers, a JAX masking/compaction helper,
# a recursive dict merge and a segment-based step counter.
from .spherical_harmonics import CG_SO3, generate_spherical_harmonics
from .atomic_units import AtomicUnits
from typing import Dict, Any
import jax
import jax.numpy as jnp


def minmaxone(x, name=""):
    """Print ``name`` followed by the min, max and root-mean-square of ``x``.

    Args:
        x: Array-like exposing ``min``, ``max`` and ``mean`` methods.
        name: Label printed before the statistics.
    """
    print(name, x.min(), x.max(), (x**2).mean() ** 0.5)


def minmaxone_jax(x, name=""):
    """Print ``name`` and the min, max and root-mean-square of ``x`` via ``jax.debug.print``.

    Counterpart of ``minmaxone`` that goes through ``jax.debug.print`` so it
    can be used on traced arrays (e.g. inside ``jax.jit``).

    Args:
        x: JAX array to summarise.
        name: Label printed before the statistics.
    """
    jax.debug.print(
        "{name} {min} {max} {rms}",
        name=name,
        min=x.min(),
        max=x.max(),
        # Square root so the last value is the RMS, consistent with
        # ``minmaxone`` (not the mean of squares).
        rms=(x**2).mean() ** 0.5,
    )


def mask_filter_1d(mask, max_size, *values_fill):
    """Pack the entries selected by ``mask`` at the front of fixed-size arrays.

    Args:
        mask: Boolean (or 0/1) array selecting entries along axis 0
            (presumably 1-D, as the name suggests).
        max_size: Length of axis 0 of every output array.
        *values_fill: ``(value, fill)`` pairs. ``value`` is scattered along its
            first axis; ``fill`` initialises the slots that receive nothing.

    Returns:
        ``(outputs, scatter_idx, count)``:
        ``outputs`` has one array per pair, with selected rows packed in order
        and ``fill`` elsewhere. ``scatter_idx`` is the destination slot of
        each entry (``max_size`` for unselected ones). ``count`` is the number
        of selected entries. Selected entries beyond ``max_size`` are dropped,
        so ``count`` may exceed ``max_size``.
    """
    cumsum = jnp.cumsum(mask, dtype=jnp.int32)
    # Unselected entries target the out-of-range slot ``max_size``, which the
    # ``mode="drop"`` scatter below silently discards.
    scatter_idx = jnp.where(mask, cumsum - 1, max_size)
    outputs = []
    for value, fill in values_fill:
        shape = list(value.shape)
        shape[0] = max_size
        output = (
            jnp.full(shape, fill, dtype=value.dtype)
            .at[scatter_idx]
            .set(value, mode="drop")
        )
        outputs.append(output)
    if cumsum.size == 0:
        # ``cumsum[-1]`` would fail on an empty mask.
        return outputs, scatter_idx, 0
    return outputs, scatter_idx, cumsum[-1]


def deep_update(
    mapping: Dict[Any, Any], *updating_mappings: Dict[Any, Any]
) -> Dict[Any, Any]:
    """Recursively merge ``updating_mappings`` (left to right) into a copy of ``mapping``.

    Where both the existing and the incoming value for a key are ``dict``
    instances, they are merged recursively. Otherwise the incoming value
    replaces the existing one. ``mapping`` is not mutated.
    """
    updated_mapping = mapping.copy()
    for updating_mapping in updating_mappings:
        for k, v in updating_mapping.items():
            if (
                k in updated_mapping
                and isinstance(updated_mapping[k], dict)
                and isinstance(v, dict)
            ):
                updated_mapping[k] = deep_update(updated_mapping[k], v)
            else:
                updated_mapping[k] = v
    return updated_mapping


class Counter:
    """Step counter that wraps around every ``nseg`` steps.

    ``count`` is the position inside the current segment and ``count_avg``
    the number of segments completed so far.
    """

    def __init__(self, nseg, startsave=1):
        self.i = 0
        self.i_avg = 0
        self.nseg = nseg
        self.startsave = startsave

    @property
    def count(self):
        """Steps taken in the current segment."""
        return self.i

    @property
    def count_avg(self):
        """Number of completed segments."""
        return self.i_avg

    @property
    def nsample(self):
        """``count_avg - startsave + 1``, clamped to at least 1."""
        return max(self.count_avg - self.startsave + 1, 1)

    @property
    def is_reset_step(self):
        """True when ``count`` is 0 (initial state or right after a wrap)."""
        return self.count == 0

    def reset_avg(self):
        """Reset the completed-segment count; keep the in-segment position."""
        self.i_avg = 0

    def reset_all(self):
        """Reset both counters to 0."""
        self.i = 0
        self.i_avg = 0

    def increment(self):
        """Advance one step; on reaching ``nseg`` wrap to 0 and count a segment."""
        self.i = self.i + 1
        if self.i >= self.nseg:
            self.i = 0
            self.i_avg = self.i_avg + 1
def minmaxone(x, name=""):
    """Print ``name`` followed by the min, max and root-mean-square of ``x``.

    Args:
        x: Array-like exposing ``min``, ``max`` and ``mean`` methods.
        name: Label printed before the statistics (an empty label still
            produces a leading separator space).
    """
    smallest = x.min()
    largest = x.max()
    rms = (x**2).mean() ** 0.5
    print(name, smallest, largest, rms)
def
minmaxone_jax(x, name=''):
def mask_filter_1d(mask, max_size, *values_fill):
    """Compact the entries picked by ``mask`` into arrays of length ``max_size``.

    Args:
        mask: Boolean (or 0/1) selector over axis 0 (presumably 1-D, as the
            name suggests).
        max_size: Size of axis 0 for every produced array.
        *values_fill: ``(value, fill)`` pairs. Rows of ``value`` are scattered
            along axis 0; ``fill`` is the padding for slots left empty.

    Returns:
        A triple ``(outputs, scatter_idx, count)``:
        ``outputs`` lists one packed array per pair, with selected rows first
        and in order, followed by ``fill``. ``scatter_idx`` is the target slot
        of every input row (``max_size`` for rows not selected). ``count`` is
        how many rows were selected. It can exceed ``max_size``, in which case
        the surplus rows are discarded.
    """
    running = jnp.cumsum(mask, dtype=jnp.int32)
    # Non-selected rows are aimed at slot ``max_size`` (one past the end);
    # the ``mode="drop"`` scatter throws those writes away.
    scatter_idx = jnp.where(mask, running - 1, max_size)

    def _compact(value, fill):
        # Same trailing shape as ``value``, leading axis resized to max_size.
        target_shape = list(value.shape)
        target_shape[0] = max_size
        canvas = jnp.full(target_shape, fill, dtype=value.dtype)
        return canvas.at[scatter_idx].set(value, mode="drop")

    outputs = [_compact(value, fill) for value, fill in values_fill]
    # An empty mask has no last cumulative element to read.
    n_selected = running[-1] if running.size != 0 else 0
    return outputs, scatter_idx, n_selected
def deep_update(
    mapping: Dict[Any, Any], *updating_mappings: Dict[Any, Any]
) -> Dict[Any, Any]:
    """Return a copy of ``mapping`` with ``updating_mappings`` merged in, left to right.

    Keys whose existing and incoming values are both ``dict`` instances are
    merged recursively. Any other incoming value simply overwrites the
    existing one. ``mapping`` is left untouched, because each merged level is
    shallow-copied. Values taken from the updates are inserted as-is, without
    copying.
    """
    merged = mapping.copy()
    for overrides in updating_mappings:
        for key, incoming in overrides.items():
            current_is_dict = key in merged and isinstance(merged[key], dict)
            if current_is_dict and isinstance(incoming, dict):
                merged[key] = deep_update(merged[key], incoming)
            else:
                merged[key] = incoming
    return merged
class Counter:
    """Step counter that wraps around every ``nseg`` steps.

    ``count`` is the position inside the current segment (``0`` to
    ``nseg - 1`` for a positive integer ``nseg``). ``count_avg`` is the number
    of segments completed so far. ``nsample`` is derived from ``count_avg``
    and ``startsave``.
    """

    def __init__(self, nseg, startsave=1):
        # Mutable state: position within the segment, completed segments.
        self.i = 0
        self.i_avg = 0
        # Configuration.
        self.nseg = nseg
        self.startsave = startsave

    @property
    def count(self):
        """Steps taken in the current segment."""
        return self.i

    @property
    def count_avg(self):
        """Number of completed segments."""
        return self.i_avg

    @property
    def nsample(self):
        """``count_avg - startsave + 1``, but never less than 1."""
        return max(self.count_avg - self.startsave + 1, 1)

    @property
    def is_reset_step(self):
        """True when ``count`` is 0 (initial state or right after a wrap)."""
        return self.count == 0

    def reset_avg(self):
        """Forget completed segments; the in-segment position is kept."""
        self.i_avg = 0

    def reset_all(self):
        """Return both counters to 0."""
        self.i, self.i_avg = 0, 0

    def increment(self):
        """Take one step; on reaching ``nseg`` wrap to 0 and count a segment."""
        self.i += 1
        if self.i >= self.nseg:
            self.i = 0
            self.i_avg += 1