fennol.models.modules
from typing import Sequence, Tuple, Dict, Optional, Any
import flax.linen as nn
from inspect import isclass, ismodule
from pkgutil import iter_modules
import importlib
import importlib.util
import os
import glob

### python modules where FENNIX Modules are defined ###
from . import misc, physics, embeddings, preprocessing


# Registry of FENNIX module classes, keyed by upper-cased FID.
MODULES: Dict[str, nn.Module] = {}
# Registry of preprocessing classes, keyed by upper-cased FPID.
PREPROCESSING: Dict = {}


def available_fennix_modules():
    """Return the identifiers currently registered in ``MODULES``."""
    return list(MODULES.keys())


def available_fennix_preprocessing():
    """Return the identifiers currently registered in ``PREPROCESSING``."""
    return list(PREPROCESSING.keys())


def register_fennix_module(module: nn.Module, FID: Optional[str] = None):
    """Register a FENNIX module class in ``MODULES``.

    Args:
        module: the module class to register.
        FID: optional explicit identifier. When omitted, the class attribute
            ``module.FID`` is used; it must be a string or a tuple of strings
            (each entry becomes an alias).

    Identifiers are upper-cased. If the class has no usable ``FID`` and none is
    given, a warning is printed and nothing is registered.

    Raises:
        ValueError: if a different class is already registered under one of
            the identifiers. In that case no alias is registered.
    """
    if FID is not None:
        names = [FID.upper()]
    else:
        if not hasattr(module, "FID"):
            print(f"Warning: module {module.__name__} does not have a FID field and no explicit FID was provided. Module was NOT registered.")
            # Honor the warning: without this return, module.FID below raises.
            return
        if not isinstance(module.FID, (str, tuple)):
            print(f"Warning: module {module.__name__} has an invalid FID field. Module was NOT registered.")
            return
        if isinstance(module.FID, str):
            names = [module.FID.upper()]
        else:
            names = [fid.upper() for fid in module.FID]
    # Validate every alias first so a conflict does not leave a partial registration.
    for name in names:
        if name in MODULES and MODULES[name] is not module:
            raise ValueError(
                f"A different module identified as '{name}' is already registered !"
            )
    for name in names:
        MODULES[name] = module


def register_fennix_preprocessing(module, FPID: Optional[str] = None):
    """Register a preprocessing class in ``PREPROCESSING``.

    Args:
        module: the preprocessing class to register.
        FPID: optional explicit identifier. When omitted, the class attribute
            ``module.FPID`` is used; it must be a string or a tuple of strings
            (each entry becomes an alias).

    Identifiers are upper-cased. If the class has no usable ``FPID`` and none
    is given, a warning is printed and nothing is registered.

    Raises:
        ValueError: if a different class is already registered under one of
            the identifiers. In that case no alias is registered.
    """
    if FPID is not None:
        names = [FPID.upper()]
    else:
        if not hasattr(module, "FPID"):
            print(f"Warning: module {module.__name__} does not have a FPID field and no explicit FPID was provided. Module was NOT registered.")
            # Honor the warning: without this return, module.FPID below raises.
            return
        if not isinstance(module.FPID, (str, tuple)):
            print(f"Warning: module {module.__name__} has an invalid FPID field. Module was NOT registered.")
            return
        if isinstance(module.FPID, str):
            names = [module.FPID.upper()]
        else:
            names = [fid.upper() for fid in module.FPID]
    # Validate every alias first so a conflict does not leave a partial registration.
    for name in names:
        if name in PREPROCESSING and PREPROCESSING[name] is not module:
            raise ValueError(
                f"A different module identified as '{name}' is already registered !"
            )
    for name in names:
        PREPROCESSING[name] = module


def register_fennix_modules(module, recurs=0, max_recurs=2):
    """Scan a Python module or package and register every FENNIX class in it.

    Rules:
        - ``nn.Module`` subclasses that define ``FID`` go to ``MODULES``.
        - Other classes that define ``FPID`` go to ``PREPROCESSING``.

    Args:
        module: a Python module object. If it is a package (has ``__path__``),
            its submodules are imported and scanned first.
        recurs: current recursion depth (0 for the top-level call).
        max_recurs: packages found at depth ``max_recurs`` or deeper are not
            descended into.
    """
    if ismodule(module) and hasattr(module, "__path__") and recurs < max_recurs:
        for _, name, _ in iter_modules(module.__path__):
            sub_module = importlib.import_module(f"{module.__name__}.{name}")
            register_fennix_modules(sub_module, recurs=recurs + 1, max_recurs=max_recurs)
    for m in list(module.__dict__.values()):
        if not isclass(m):
            continue
        if issubclass(m, nn.Module) and m is not nn.Module:
            if hasattr(m, "FID"):
                register_fennix_module(m)
        elif hasattr(m, "FPID"):
            register_fennix_preprocessing(m)


### REGISTER DEFAULT MODULES #####################
for mods in [misc, physics, embeddings, preprocessing]:
    register_fennix_modules(mods)

# User-defined modules. FENNOL_MODULES_PATH is a colon-separated list of
# directories. Every *.py file in an existing directory is executed and
# scanned. When the variable is unset this yields [""], which fails the
# existence check.
# NOTE: this executes arbitrary code from those files; point the variable only
# at trusted directories.
module_path = os.environ.get("FENNOL_MODULES_PATH", "").split(":")
for path in module_path:
    if os.path.exists(path):
        for file in glob.glob(f"{path}/*.py"):
            spec = importlib.util.spec_from_file_location("custom_modules", file)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            register_fennix_modules(module)

##################################################


def get_modules_documentation():
    """Return a mapping from each registered FID to its class docstring."""
    return {name: module.__doc__ for name, module in MODULES.items()}


class FENNIXModules(nn.Module):
    r"""Sequential module that applies a sequence of FENNIX modules.

    Attributes:
        layers (Sequence[Tuple[nn.Module, Dict]]): Sequence of tuples (module, parameters) to apply.

    """

    layers: Sequence[Tuple[nn.Module, Dict]]

    def __post_init__(self):
        if not isinstance(self.layers, Sequence):
            raise ValueError(
                f"'layers' must be a sequence, got '{type(self.layers).__name__}'."
            )
        super().__post_init__()

    @nn.compact
    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        if not self.layers:
            raise ValueError(f"Empty Sequential module {self.name}.")

        # Each layer is instantiated with its parameters and applied to the
        # previous layer's output dictionary.
        outputs = inputs
        for layer, prms in self.layers:
            outputs = layer(**prms)(outputs)
        return outputs
MODULES: Dict[str, flax.linen.module.Module] =
{'CHANNEL_MIXING': <class 'fennol.models.misc.e3.ChannelMixing'>, 'CHANNEL_MIXING_E3': <class 'fennol.models.misc.e3.ChannelMixingE3'>, 'SPHERICAL_TO_CARTESIAN': <class 'fennol.models.misc.e3.SphericalToCartesian'>, 'NEURAL_NET': <class 'fennol.models.misc.nets.FullyConnectedNet'>, 'SPECIES_ENCODING': <class 'fennol.models.misc.encodings.SpeciesEncoding'>, 'RADIAL_BASIS': <class 'fennol.models.misc.encodings.RadialBasis'>, 'APPLY_SWITCH': <class 'fennol.models.misc.misc.ApplySwitch'>, 'ATOM_TO_EDGE': <class 'fennol.models.misc.misc.AtomToEdge'>, 'SCATTER_EDGES': <class 'fennol.models.misc.misc.ScatterEdges'>, 'EDGE_CONCATENATE': <class 'fennol.models.misc.misc.EdgeConcatenate'>, 'SCATTER_SYSTEM': <class 'fennol.models.misc.misc.ScatterSystem'>, 'SYSTEM_TO_ATOMS': <class 'fennol.models.misc.misc.SystemToAtoms'>, 'SUM_AXIS': <class 'fennol.models.misc.misc.SumAxis'>, 'SPLIT': <class 'fennol.models.misc.misc.Split'>, 'CONCATENATE': <class 'fennol.models.misc.misc.Concatenate'>, 'ACTIVATION': <class 'fennol.models.misc.misc.Activation'>, 'SCALE': <class 'fennol.models.misc.misc.Scale'>, 'ADD': <class 'fennol.models.misc.misc.Add'>, 'MULTIPLY': <class 'fennol.models.misc.misc.Multiply'>, 'TRANSPOSE': <class 'fennol.models.misc.misc.Transpose'>, 'RESHAPE': <class 'fennol.models.misc.misc.Reshape'>, 'CHEMICAL_CONSTANT': <class 'fennol.models.misc.misc.ChemicalConstant'>, 'SWITCH_FUNCTION': <class 'fennol.models.misc.misc.SwitchFunction'>, 'RES_MLP': <class 'fennol.models.misc.nets.ResMLP'>, 'SKIP_NET': <class 'fennol.models.misc.nets.FullyResidualNet'>, 'HIERARCHICAL_NET': <class 'fennol.models.misc.nets.HierarchicalNet'>, 'SPECIES_INDEX_NET': <class 'fennol.models.misc.nets.SpeciesIndexNet'>, 'CHEMICAL_NET': <class 'fennol.models.misc.nets.ChemicalNet'>, 'MOE_NET': <class 'fennol.models.misc.nets.MOENet'>, 'CHANNEL_NET': <class 'fennol.models.misc.nets.ChannelNet'>, 'GATED_PERCEPTRON': <class 'fennol.models.misc.nets.GatedPerceptron'>, 'ZACNET': <class 
'fennol.models.misc.nets.ZAcNet'>, 'ZLORANET': <class 'fennol.models.misc.nets.ZLoRANet'>, 'BLOCK_INDEX_NET': <class 'fennol.models.misc.nets.BlockIndexNet'>, 'ENSEMBLE_STAT': <class 'fennol.models.misc.uncertainty.EnsembleStatistics'>, 'ENSEMBLE_SHIFT': <class 'fennol.models.misc.uncertainty.EnsembleShift'>, 'CONSTRAIN_EVIDENCE': <class 'fennol.models.misc.uncertainty.ConstrainEvidence'>, 'CN_D4': <class 'fennol.models.physics.bond.CND4'>, 'SUM_SWITCH': <class 'fennol.models.physics.bond.SumSwitch'>, 'CN_SHIFT': <class 'fennol.models.physics.bond.CNShift'>, 'CN_STORE': <class 'fennol.models.physics.bond.CNStore'>, 'FLAT_BOTTOM': <class 'fennol.models.physics.bond.FlatBottom'>, 'VDW_OQDO': <class 'fennol.models.physics.dispersion.VdwOQDO'>, 'ELECTRIC_FIELD': <class 'fennol.models.physics.electric_field.ElectricField'>, 'COULOMB': <class 'fennol.models.physics.electrostatics.Coulomb'>, 'QEQ_D4': <class 'fennol.models.physics.electrostatics.QeqD4'>, 'CHARGE_CORRECTION': <class 'fennol.models.physics.electrostatics.ChargeCorrection'>, 'DISTRIBUTE_ELECTRONS': <class 'fennol.models.physics.electrostatics.DistributeElectrons'>, 'POLARIZATION': <class 'fennol.models.physics.polarisation.Polarization'>, 'REPULSION_ZBL': <class 'fennol.models.physics.repulsion.RepulsionZBL'>, 'AIMNET': <class 'fennol.models.embeddings.aimnet.AIMNet'>, 'ALLEGRO': <class 'fennol.models.embeddings.allegro.AllegroEmbedding'>, 'ALLEGRO_E3NN': <class 'fennol.models.embeddings.allegro.AllegroE3NNEmbedding'>, 'ANI_AEV': <class 'fennol.models.embeddings.ani.ANIAEV'>, 'CAIMAN': <class 'fennol.models.embeddings.caiman.CaimanEmbedding'>, 'CHARGE_HYPOTHESIS': <class 'fennol.models.embeddings.charge_embeddings.ChargeHypothesis'>, 'CHGNET': <class 'fennol.models.embeddings.chgnet.CHGNetEmbedding'>, 'CRATE': <class 'fennol.models.embeddings.crate.CRATEmbedding'>, 'DEEPPOT': <class 'fennol.models.embeddings.deeppot.DeepPotEmbedding'>, 'DEEPPOT_E3': <class 
'fennol.models.embeddings.deeppot.DeepPotE3Embedding'>, 'EEACSF': <class 'fennol.models.embeddings.eeacsf.EEACSF'>, 'FOAM': <class 'fennol.models.embeddings.foam.FOAMEmbedding'>, 'GAUSSIAN_MOMENTS': <class 'fennol.models.embeddings.gaussian_moments.GaussianMomentsEmbedding'>, 'HIPNN': <class 'fennol.models.embeddings.hipnn.HIPNNEmbedding'>, 'MACE': <class 'fennol.models.embeddings.mace.MACE'>, 'MINIMACE': <class 'fennol.models.embeddings.minimace.MiniMaceEmbedding'>, 'NEWTONNET': <class 'fennol.models.embeddings.newtonnet.NewtonNetEmbedding'>, 'PAINN': <class 'fennol.models.embeddings.painn.PAINNEmbedding'>, 'SCHNET': <class 'fennol.models.embeddings.schnet.SchNetEmbedding'>, 'SPOOKYNET': <class 'fennol.models.embeddings.spookynet.SpookyNetEmbedding'>}
PREPROCESSING: Dict =
{'GRAPH': <class 'fennol.models.preprocessing.GraphGenerator'>, 'GRAPH_FILTER': <class 'fennol.models.preprocessing.GraphFilter'>, 'GRAPH_ANGULAR_EXTENSION': <class 'fennol.models.preprocessing.GraphAngularExtension'>, 'SPECIES_INDEXER': <class 'fennol.models.preprocessing.SpeciesIndexer'>, 'BLOCK_INDEXER': <class 'fennol.models.preprocessing.BlockIndexer'>}
def
available_fennix_modules():
def
available_fennix_preprocessing():
def
register_fennix_module(module: flax.linen.module.Module, FID: Optional[str] = None):
def register_fennix_module(module: nn.Module, FID: Optional[str] = None):
    """Register a FENNIX module class in ``MODULES``.

    Args:
        module: the module class to register.
        FID: optional explicit identifier. When omitted, the class attribute
            ``module.FID`` is used; it must be a string or a tuple of strings
            (each entry becomes an alias).

    Identifiers are upper-cased. If the class has no usable ``FID`` and none is
    given, a warning is printed and nothing is registered.

    Raises:
        ValueError: if a different class is already registered under one of
            the identifiers. In that case no alias is registered.
    """
    if FID is not None:
        names = [FID.upper()]
    else:
        if not hasattr(module, "FID"):
            print(f"Warning: module {module.__name__} does not have a FID field and no explicit FID was provided. Module was NOT registered.")
            # Honor the warning: without this return, module.FID below raises.
            return
        if not isinstance(module.FID, (str, tuple)):
            print(f"Warning: module {module.__name__} has an invalid FID field. Module was NOT registered.")
            return
        if isinstance(module.FID, str):
            names = [module.FID.upper()]
        else:
            names = [fid.upper() for fid in module.FID]
    # Validate every alias first so a conflict does not leave a partial registration.
    for name in names:
        if name in MODULES and MODULES[name] is not module:
            raise ValueError(
                f"A different module identified as '{name}' is already registered !"
            )
    for name in names:
        MODULES[name] = module
def
register_fennix_preprocessing(module, FPID: Optional[str] = None):
def register_fennix_preprocessing(module, FPID: Optional[str] = None):
    """Register a preprocessing class in ``PREPROCESSING``.

    Args:
        module: the preprocessing class to register.
        FPID: optional explicit identifier. When omitted, the class attribute
            ``module.FPID`` is used; it must be a string or a tuple of strings
            (each entry becomes an alias).

    Identifiers are upper-cased. If the class has no usable ``FPID`` and none
    is given, a warning is printed and nothing is registered.

    Raises:
        ValueError: if a different class is already registered under one of
            the identifiers. In that case no alias is registered.
    """
    if FPID is not None:
        names = [FPID.upper()]
    else:
        if not hasattr(module, "FPID"):
            print(f"Warning: module {module.__name__} does not have a FPID field and no explicit FPID was provided. Module was NOT registered.")
            # Honor the warning: without this return, module.FPID below raises.
            return
        if not isinstance(module.FPID, (str, tuple)):
            print(f"Warning: module {module.__name__} has an invalid FPID field. Module was NOT registered.")
            return
        if isinstance(module.FPID, str):
            names = [module.FPID.upper()]
        else:
            names = [fid.upper() for fid in module.FPID]
    # Validate every alias first so a conflict does not leave a partial registration.
    for name in names:
        if name in PREPROCESSING and PREPROCESSING[name] is not module:
            raise ValueError(
                f"A different module identified as '{name}' is already registered !"
            )
    for name in names:
        PREPROCESSING[name] = module
def
register_fennix_modules(module, recurs=0, max_recurs=2):
def register_fennix_modules(module, recurs=0, max_recurs=2):
    """Scan a Python module or package and register every FENNIX class in it.

    Rules:
        - ``nn.Module`` subclasses that define ``FID`` are registered with
          ``register_fennix_module``.
        - Other classes that define ``FPID`` are registered with
          ``register_fennix_preprocessing``.

    Args:
        module: a Python module object. If it is a package (has ``__path__``),
            its submodules are imported and scanned first.
        recurs: current recursion depth (0 for the top-level call).
        max_recurs: packages found at depth ``max_recurs`` or deeper are not
            descended into. This bounds the recursion, which previously
            ignored the parameter.
    """
    if ismodule(module) and hasattr(module, "__path__") and recurs < max_recurs:
        for _, name, _ in iter_modules(module.__path__):
            sub_module = importlib.import_module(f"{module.__name__}.{name}")
            register_fennix_modules(sub_module, recurs=recurs + 1, max_recurs=max_recurs)
    # Snapshot the namespace so registration cannot be affected by later mutation.
    for m in list(module.__dict__.values()):
        if not isclass(m):
            continue
        if issubclass(m, nn.Module) and m is not nn.Module:
            if hasattr(m, "FID"):
                register_fennix_module(m)
        elif hasattr(m, "FPID"):
            register_fennix_preprocessing(m)
module_path =
['']
def
get_modules_documentation():
class
FENNIXModules(flax.linen.module.Module):
class FENNIXModules(nn.Module):
    r"""Chain FENNIX modules, feeding each one's output dict into the next.

    Every entry of ``layers`` pairs a module class with the keyword arguments
    used to build it. At call time each class is instantiated with its
    arguments and applied, in order, starting from the input dictionary.

    Attributes:
        layers (Sequence[Tuple[nn.Module, Dict]]): ordered (module class,
            constructor kwargs) pairs to apply.

    """

    layers: Sequence[Tuple[nn.Module, Dict]]

    def __post_init__(self):
        # Reject non-sequence containers before flax finishes module setup.
        if isinstance(self.layers, Sequence):
            super().__post_init__()
            return
        kind = type(self.layers).__name__
        raise ValueError(f"'layers' must be a sequence, got '{kind}'.")

    @nn.compact
    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        if not self.layers:
            raise ValueError(f"Empty Sequential module {self.name}.")

        state = inputs
        for module_cls, kwargs in self.layers:
            submodule = module_cls(**kwargs)
            state = submodule(state)
        return state
Sequential module that applies a sequence of FENNIX modules.
Attributes: layers (Sequence[Tuple[nn.Module, Dict]]): Sequence of tuples (module, parameters) to apply.
FENNIXModules( layers: Sequence[Tuple[flax.linen.module.Module, Dict]], parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.
Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.