fennol.utils
from .spherical_harmonics import CG_SO3, generate_spherical_harmonics
from .atomic_units import AtomicUnits
from typing import Dict, Any
import jax
import jax.numpy as jnp


def minmaxone(x, name=""):
    """Print ``name`` followed by the min, the max and the mean of squares of ``x``.

    Uses the built-in ``print``; see ``minmaxone_jax`` for a variant based on
    ``jax.debug.print``.
    """
    smallest, largest = x.min(), x.max()
    mean_square = (x**2).mean()
    print(name, smallest, largest, mean_square)


def minmaxone_jax(x, name=""):
    """Print ``name``, the min, the max and the mean of squares of ``x``.

    Goes through ``jax.debug.print`` so it can be called on traced values
    inside transformed (e.g. jitted) functions.
    """
    stats = {"min": x.min(), "max": x.max(), "mean": (x**2).mean()}
    jax.debug.print("{name} {min} {max} {mean}", name=name, **stats)


def mask_filter_1d(mask, max_size, *values_fill):
    """Compact the entries selected by a 1-D mask into fixed-size arrays.

    Args:
        mask: 1-D mask (assumed boolean or 0/1); true entries are kept.
        max_size: leading dimension of every output array.
        *values_fill: ``(value, fill)`` pairs. ``value`` is indexed along its
            first axis in step with ``mask``. Kept rows are packed, in their
            original order, at the front of an array whose first dimension is
            ``max_size`` (other dimensions as in ``value``). The remaining
            rows hold ``fill``.

    Returns:
        ``(outputs, scatter_idx, count)``:
        - ``outputs``: list of packed arrays, one per pair.
        - ``scatter_idx``: destination row of every input entry; ``max_size``
          for masked-out entries.
        - ``count``: running total of ``mask`` at its last entry, i.e. the
          number of kept entries. It may exceed ``max_size``, in which case
          the surplus rows are dropped. It is a plain ``0`` for an empty mask.
    """
    positions = jnp.cumsum(mask, dtype=jnp.int32)
    # Masked-out entries target the out-of-range row ``max_size``; they, and
    # kept entries whose position is >= max_size, are discarded by mode="drop".
    scatter_idx = jnp.where(mask, positions - 1, max_size)

    def _pack(value, fill):
        # Same shape as ``value`` except for the leading dimension.
        out_shape = list(value.shape)
        out_shape[0] = max_size
        buffer = jnp.full(out_shape, fill, dtype=value.dtype)
        return buffer.at[scatter_idx].set(value, mode="drop")

    outputs = [_pack(value, fill) for value, fill in values_fill]
    count = 0 if positions.size == 0 else positions[-1]
    return outputs, scatter_idx, count


def deep_update(
    mapping: Dict[Any, Any], *updating_mappings: Dict[Any, Any]
) -> Dict[Any, Any]:
    """Recursively merge ``updating_mappings`` into a copy of ``mapping``.

    Updates are applied in order, left to right. When a key maps to a ``dict``
    both in the result built so far and in the current update, the two dicts
    are merged by a recursive call. In every other case the update's value
    overwrites the previous one.

    ``mapping`` is not mutated: it is shallow-copied via ``.copy()``, and
    nested dicts are copied by the recursive call before being merged. Values
    taken from the updates are inserted by reference, not copied.
    """
    merged = mapping.copy()
    for update in updating_mappings:
        for key, incoming in update.items():
            both_dicts = (
                key in merged
                and isinstance(merged[key], dict)
                and isinstance(incoming, dict)
            )
            merged[key] = deep_update(merged[key], incoming) if both_dicts else incoming
    return merged


class Counter:
    """Cyclic step counter with a second counter for completed cycles.

    ``count`` is advanced by ``increment`` and wraps back to 0 once it reaches
    ``nseg``; every wrap bumps ``count_avg``. ``nsample`` equals
    ``max(count_avg - startsave + 1, 1)``.
    """

    def __init__(self, nseg, startsave=1):
        # Position inside the current cycle, and number of wraps so far.
        self.i = self.i_avg = 0
        self.nseg = nseg
        self.startsave = startsave

    @property
    def count(self):
        """Current position within the cycle."""
        return self.i

    @property
    def count_avg(self):
        """Number of completed cycles since the last reset."""
        return self.i_avg

    @property
    def nsample(self):
        """Cycles counted from ``startsave`` onward, never less than 1."""
        completed = self.count_avg - self.startsave + 1
        return max(completed, 1)

    @property
    def is_reset_step(self):
        """True when the cycle position is 0."""
        return self.count == 0

    def reset_avg(self):
        """Zero the cycle counter, keeping the in-cycle position."""
        self.i_avg = 0

    def reset_all(self):
        """Zero both the in-cycle position and the cycle counter."""
        self.i = self.i_avg = 0

    def increment(self):
        """Advance one step; on reaching ``nseg`` wrap to 0 and count a cycle."""
        self.i += 1
        if self.i >= self.nseg:
            self.i = 0
            self.i_avg += 1
def minmaxone(x, name=''):
def minmaxone_jax(x, name=''):
def mask_filter_1d(mask, max_size, *values_fill):
def mask_filter_1d(mask, max_size, *values_fill):
    """Compact the entries selected by a 1-D mask into fixed-size arrays.

    Args:
        mask: 1-D mask (assumed boolean or 0/1); true entries are kept.
        max_size: leading dimension of every output array.
        *values_fill: ``(value, fill)`` pairs. ``value`` is indexed along its
            first axis in step with ``mask``. Kept rows are packed, in their
            original order, at the front of an array whose first dimension is
            ``max_size`` (other dimensions as in ``value``). The remaining
            rows hold ``fill``.

    Returns:
        ``(outputs, scatter_idx, count)``:
        - ``outputs``: list of packed arrays, one per pair.
        - ``scatter_idx``: destination row of every input entry; ``max_size``
          for masked-out entries.
        - ``count``: running total of ``mask`` at its last entry, i.e. the
          number of kept entries. It may exceed ``max_size``, in which case
          the surplus rows are dropped. It is a plain ``0`` for an empty mask.
    """
    positions = jnp.cumsum(mask, dtype=jnp.int32)
    # Masked-out entries target the out-of-range row ``max_size``; they, and
    # kept entries whose position is >= max_size, are discarded by mode="drop".
    scatter_idx = jnp.where(mask, positions - 1, max_size)

    def _pack(value, fill):
        # Same shape as ``value`` except for the leading dimension.
        out_shape = list(value.shape)
        out_shape[0] = max_size
        buffer = jnp.full(out_shape, fill, dtype=value.dtype)
        return buffer.at[scatter_idx].set(value, mode="drop")

    outputs = [_pack(value, fill) for value, fill in values_fill]
    count = 0 if positions.size == 0 else positions[-1]
    return outputs, scatter_idx, count
def deep_update(mapping: Dict[Any, Any], *updating_mappings: Dict[Any, Any]) -> Dict[Any, Any]:
def deep_update(
    mapping: Dict[Any, Any], *updating_mappings: Dict[Any, Any]
) -> Dict[Any, Any]:
    """Recursively merge ``updating_mappings`` into a copy of ``mapping``.

    Updates are applied in order, left to right. When a key maps to a ``dict``
    both in the result built so far and in the current update, the two dicts
    are merged by a recursive call. In every other case the update's value
    overwrites the previous one.

    ``mapping`` is not mutated: it is shallow-copied via ``.copy()``, and
    nested dicts are copied by the recursive call before being merged. Values
    taken from the updates are inserted by reference, not copied.
    """
    merged = mapping.copy()
    for update in updating_mappings:
        for key, incoming in update.items():
            both_dicts = (
                key in merged
                and isinstance(merged[key], dict)
                and isinstance(incoming, dict)
            )
            merged[key] = deep_update(merged[key], incoming) if both_dicts else incoming
    return merged
class Counter:
class Counter:
    """Cyclic step counter with a second counter for completed cycles.

    ``count`` is advanced by ``increment`` and wraps back to 0 once it reaches
    ``nseg``; every wrap bumps ``count_avg``. ``nsample`` equals
    ``max(count_avg - startsave + 1, 1)``.
    """

    def __init__(self, nseg, startsave=1):
        # Position inside the current cycle, and number of wraps so far.
        self.i = self.i_avg = 0
        self.nseg = nseg
        self.startsave = startsave

    @property
    def count(self):
        """Current position within the cycle."""
        return self.i

    @property
    def count_avg(self):
        """Number of completed cycles since the last reset."""
        return self.i_avg

    @property
    def nsample(self):
        """Cycles counted from ``startsave`` onward, never less than 1."""
        completed = self.count_avg - self.startsave + 1
        return max(completed, 1)

    @property
    def is_reset_step(self):
        """True when the cycle position is 0."""
        return self.count == 0

    def reset_avg(self):
        """Zero the cycle counter, keeping the in-cycle position."""
        self.i_avg = 0

    def reset_all(self):
        """Zero both the in-cycle position and the cycle counter."""
        self.i = self.i_avg = 0

    def increment(self):
        """Advance one step; on reaching ``nseg`` wrap to 0 and count a cycle."""
        self.i += 1
        if self.i >= self.nseg:
            self.i = 0
            self.i_avg += 1