fennol.models.modules

  1from typing import Sequence, Tuple, Dict, Optional, Any
  2import flax.linen as nn
  3from inspect import isclass, ismodule
  4from pkgutil import iter_modules
  5import importlib
  6import importlib.util
  7import os
  8import glob
  9
 10### python modules where FENNIX Modules are defined ###
 11from . import misc,physics, embeddings,preprocessing
 12
 13
 14MODULES: Dict[str, nn.Module] = {}
 15PREPROCESSING: Dict = {}
 16
 17
 18def available_fennix_modules():
 19    return list(MODULES.keys())
 20
 21def available_fennix_preprocessing():
 22    return list(PREPROCESSING.keys())
 23
 24
 25def register_fennix_module(module: nn.Module, FID: Optional[str] = None):
 26    if FID is not None:
 27        names = [FID.upper()]
 28    else:
 29        if not hasattr(module, "FID"):
 30            print(f"Warning: module {module.__name__} does not have a FID field and no explicit FID was provided. Module was NOT registered.")
 31        if not (isinstance(
 32            module.FID, str
 33        ) or isinstance(module.FID,tuple)):
 34            print(f"Warning: module {module.__name__} has an invalid FID field. Module was NOT registered.")
 35        if isinstance(module.FID, str):
 36            names = [module.FID.upper()]
 37        else:
 38            names = [fid.upper() for fid in module.FID]
 39    for name in names:
 40        if name in MODULES and MODULES[name] != module:
 41            raise ValueError(
 42                f"A different module identified as '{name}' is already registered !"
 43            )
 44        MODULES[name] = module
 45
 46def register_fennix_preprocessing(module, FPID: Optional[str] = None):
 47    if FPID is not None:
 48        names = [FPID.upper()]
 49    else:
 50        if not hasattr(module, "FPID"):
 51            print(f"Warning: module {module.__name__} does not have a FPID field and no explicit FPID was provided. Module was NOT registered.")
 52        if not (isinstance(
 53            module.FPID, str
 54        ) or isinstance(module.FPID,tuple)):
 55            print(f"Warning: module {module.__name__} has an invalid FPID field. Module was NOT registered.")
 56        if isinstance(module.FPID, str):
 57            names = [module.FPID.upper()]
 58        else:
 59            names = [fid.upper() for fid in module.FPID]
 60    for name in names:
 61        if name in PREPROCESSING and PREPROCESSING[name] != module:
 62            raise ValueError(
 63                f"A different module identified as '{name}' is already registered !"
 64            )
 65        PREPROCESSING[name] = module
 66
 67
 68def register_fennix_modules(module, recurs=0, max_recurs=2):
 69    if ismodule(module) and hasattr(module,"__path__"):
 70        for _, name, _ in iter_modules(module.__path__):
 71            sub_module = __import__(f"{module.__name__}.{name}", fromlist=[""])
 72            register_fennix_modules(sub_module, recurs=recurs + 1, max_recurs=max_recurs)
 73    for m in module.__dict__.values():
 74        if isclass(m) and issubclass(m, nn.Module) and m != nn.Module:
 75            if hasattr(m, "FID"):
 76                register_fennix_module(m)
 77        elif isclass(m) and hasattr(m, "FPID"):
 78            register_fennix_preprocessing(m)
 79
 80
 81### REGISTER DEFAULT MODULES #####################
 82for mods in [misc, physics, embeddings,preprocessing]:
 83    register_fennix_modules(mods)
 84module_path = os.environ.get("FENNOL_MODULES_PATH","").split(":")
 85for path in module_path:
 86    if os.path.exists(path):
 87        for file in glob.glob(f"{path}/*.py"):
 88            spec = importlib.util.spec_from_file_location("custom_modules", file)
 89            module = importlib.util.module_from_spec(spec)
 90            spec.loader.exec_module(module)
 91            register_fennix_modules(module)
 92
 93##################################################
 94
 95def get_modules_documentation():
 96    doc = {}
 97    for name, module in MODULES.items():
 98        doc[name] = module.__doc__
 99    return doc
100
101
class FENNIXModules(nn.Module):
    r"""Apply a sequence of FENNIX modules one after another.

    Each entry of ``layers`` is a ``(module_class, kwargs)`` pair. When called,
    every class is instantiated with its keyword arguments and applied to the
    output of the previous step, starting from ``inputs``.

    Attributes:
        layers (Sequence[Tuple[nn.Module, Dict]]): Sequence of tuples (module, parameters) to apply.
    """

    # (module class, constructor kwargs) pairs, applied in order.
    layers: Sequence[Tuple[nn.Module, Dict]]

    def __post_init__(self):
        # Validate the layer specification before flax's own post-init runs.
        layer_spec = self.layers
        if isinstance(layer_spec, Sequence):
            super().__post_init__()
            return
        raise ValueError(
            f"'layers' must be a sequence, got '{type(layer_spec).__name__}'."
        )

    @nn.compact
    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Run ``inputs`` through every configured layer and return the final output.

        Raises:
            ValueError: If ``layers`` is empty.
        """
        layer_specs = self.layers
        if not layer_specs:
            raise ValueError(f"Empty Sequential module {self.name}.")

        state = inputs
        for module_cls, kwargs in layer_specs:
            # Sub-modules are created inline, as @nn.compact allows.
            state = module_cls(**kwargs)(state)
        return state
MODULES: Dict[str, flax.linen.module.Module] = {'CHANNEL_MIXING': <class 'fennol.models.misc.e3.ChannelMixing'>, 'CHANNEL_MIXING_E3': <class 'fennol.models.misc.e3.ChannelMixingE3'>, 'SPHERICAL_TO_CARTESIAN': <class 'fennol.models.misc.e3.SphericalToCartesian'>, 'NEURAL_NET': <class 'fennol.models.misc.nets.FullyConnectedNet'>, 'SPECIES_ENCODING': <class 'fennol.models.misc.encodings.SpeciesEncoding'>, 'RADIAL_BASIS': <class 'fennol.models.misc.encodings.RadialBasis'>, 'APPLY_SWITCH': <class 'fennol.models.misc.misc.ApplySwitch'>, 'ATOM_TO_EDGE': <class 'fennol.models.misc.misc.AtomToEdge'>, 'SCATTER_EDGES': <class 'fennol.models.misc.misc.ScatterEdges'>, 'EDGE_CONCATENATE': <class 'fennol.models.misc.misc.EdgeConcatenate'>, 'SCATTER_SYSTEM': <class 'fennol.models.misc.misc.ScatterSystem'>, 'SYSTEM_TO_ATOMS': <class 'fennol.models.misc.misc.SystemToAtoms'>, 'SUM_AXIS': <class 'fennol.models.misc.misc.SumAxis'>, 'SPLIT': <class 'fennol.models.misc.misc.Split'>, 'CONCATENATE': <class 'fennol.models.misc.misc.Concatenate'>, 'ACTIVATION': <class 'fennol.models.misc.misc.Activation'>, 'SCALE': <class 'fennol.models.misc.misc.Scale'>, 'ADD': <class 'fennol.models.misc.misc.Add'>, 'MULTIPLY': <class 'fennol.models.misc.misc.Multiply'>, 'TRANSPOSE': <class 'fennol.models.misc.misc.Transpose'>, 'RESHAPE': <class 'fennol.models.misc.misc.Reshape'>, 'CHEMICAL_CONSTANT': <class 'fennol.models.misc.misc.ChemicalConstant'>, 'SWITCH_FUNCTION': <class 'fennol.models.misc.misc.SwitchFunction'>, 'RES_MLP': <class 'fennol.models.misc.nets.ResMLP'>, 'SKIP_NET': <class 'fennol.models.misc.nets.FullyResidualNet'>, 'HIERARCHICAL_NET': <class 'fennol.models.misc.nets.HierarchicalNet'>, 'SPECIES_INDEX_NET': <class 'fennol.models.misc.nets.SpeciesIndexNet'>, 'CHEMICAL_NET': <class 'fennol.models.misc.nets.ChemicalNet'>, 'MOE_NET': <class 'fennol.models.misc.nets.MOENet'>, 'CHANNEL_NET': <class 'fennol.models.misc.nets.ChannelNet'>, 'GATED_PERCEPTRON': <class 
'fennol.models.misc.nets.GatedPerceptron'>, 'ZACNET': <class 'fennol.models.misc.nets.ZAcNet'>, 'ZLORANET': <class 'fennol.models.misc.nets.ZLoRANet'>, 'BLOCK_INDEX_NET': <class 'fennol.models.misc.nets.BlockIndexNet'>, 'ENSEMBLE_STAT': <class 'fennol.models.misc.uncertainty.EnsembleStatistics'>, 'ENSEMBLE_SHIFT': <class 'fennol.models.misc.uncertainty.EnsembleShift'>, 'CONSTRAIN_EVIDENCE': <class 'fennol.models.misc.uncertainty.ConstrainEvidence'>, 'CN_D4': <class 'fennol.models.physics.bond.CND4'>, 'SUM_SWITCH': <class 'fennol.models.physics.bond.SumSwitch'>, 'CN_SHIFT': <class 'fennol.models.physics.bond.CNShift'>, 'CN_STORE': <class 'fennol.models.physics.bond.CNStore'>, 'FLAT_BOTTOM': <class 'fennol.models.physics.bond.FlatBottom'>, 'VDW_OQDO': <class 'fennol.models.physics.dispersion.VdwOQDO'>, 'ELECTRIC_FIELD': <class 'fennol.models.physics.electric_field.ElectricField'>, 'COULOMB': <class 'fennol.models.physics.electrostatics.Coulomb'>, 'QEQ_D4': <class 'fennol.models.physics.electrostatics.QeqD4'>, 'CHARGE_CORRECTION': <class 'fennol.models.physics.electrostatics.ChargeCorrection'>, 'DISTRIBUTE_ELECTRONS': <class 'fennol.models.physics.electrostatics.DistributeElectrons'>, 'POLARIZATION': <class 'fennol.models.physics.polarisation.Polarization'>, 'REPULSION_ZBL': <class 'fennol.models.physics.repulsion.RepulsionZBL'>, 'AIMNET': <class 'fennol.models.embeddings.aimnet.AIMNet'>, 'ALLEGRO': <class 'fennol.models.embeddings.allegro.AllegroEmbedding'>, 'ALLEGRO_E3NN': <class 'fennol.models.embeddings.allegro.AllegroE3NNEmbedding'>, 'ANI_AEV': <class 'fennol.models.embeddings.ani.ANIAEV'>, 'CAIMAN': <class 'fennol.models.embeddings.caiman.CaimanEmbedding'>, 'CHARGE_HYPOTHESIS': <class 'fennol.models.embeddings.charge_embeddings.ChargeHypothesis'>, 'CHGNET': <class 'fennol.models.embeddings.chgnet.CHGNetEmbedding'>, 'CRATE': <class 'fennol.models.embeddings.crate.CRATEmbedding'>, 'DEEPPOT': <class 'fennol.models.embeddings.deeppot.DeepPotEmbedding'>, 
'DEEPPOT_E3': <class 'fennol.models.embeddings.deeppot.DeepPotE3Embedding'>, 'EEACSF': <class 'fennol.models.embeddings.eeacsf.EEACSF'>, 'FOAM': <class 'fennol.models.embeddings.foam.FOAMEmbedding'>, 'GAUSSIAN_MOMENTS': <class 'fennol.models.embeddings.gaussian_moments.GaussianMomentsEmbedding'>, 'HIPNN': <class 'fennol.models.embeddings.hipnn.HIPNNEmbedding'>, 'MACE': <class 'fennol.models.embeddings.mace.MACE'>, 'MINIMACE': <class 'fennol.models.embeddings.minimace.MiniMaceEmbedding'>, 'NEWTONNET': <class 'fennol.models.embeddings.newtonnet.NewtonNetEmbedding'>, 'PAINN': <class 'fennol.models.embeddings.painn.PAINNEmbedding'>, 'SCHNET': <class 'fennol.models.embeddings.schnet.SchNetEmbedding'>, 'SPOOKYNET': <class 'fennol.models.embeddings.spookynet.SpookyNetEmbedding'>}
PREPROCESSING: Dict = {'GRAPH': <class 'fennol.models.preprocessing.GraphGenerator'>, 'GRAPH_FILTER': <class 'fennol.models.preprocessing.GraphFilter'>, 'GRAPH_ANGULAR_EXTENSION': <class 'fennol.models.preprocessing.GraphAngularExtension'>, 'SPECIES_INDEXER': <class 'fennol.models.preprocessing.SpeciesIndexer'>, 'BLOCK_INDEXER': <class 'fennol.models.preprocessing.BlockIndexer'>}
def available_fennix_modules():
    """Return the identifiers (upper-cased FIDs) registered in ``MODULES``, in insertion order."""
    return [*MODULES]
def available_fennix_preprocessing():
    """Return the identifiers (upper-cased FPIDs) registered in ``PREPROCESSING``, in insertion order."""
    return [*PREPROCESSING]
def register_fennix_module(module: nn.Module, FID: Optional[str] = None):
    """Register a FENNIX module class in the global ``MODULES`` registry.

    Args:
        module: The module class to register.
        FID: Explicit identifier. When None, the class attribute ``module.FID``
            is used instead; it may be a single string or a tuple/list of
            strings (aliases). Identifiers are upper-cased before registration.

    If no explicit FID is given and the class has no usable ``FID``
    attribute, a warning is printed and nothing is registered.

    Raises:
        ValueError: If one of the identifiers is already registered to a
            different module. In that case none of the identifiers are added.
    """
    if FID is not None:
        names = [FID.upper()]
    else:
        if not hasattr(module, "FID"):
            print(f"Warning: module {module.__name__} does not have a FID field and no explicit FID was provided. Module was NOT registered.")
            # Return early: reading module.FID below would raise AttributeError.
            return
        fids = module.FID
        if isinstance(fids, str):
            fids = (fids,)
        # Accept a tuple or list of strings; anything else is rejected as the
        # warning message promises (instead of crashing or registering anyway).
        if not (
            isinstance(fids, (tuple, list))
            and all(isinstance(fid, str) for fid in fids)
        ):
            print(f"Warning: module {module.__name__} has an invalid FID field. Module was NOT registered.")
            return
        names = [fid.upper() for fid in fids]

    # Validate every alias before mutating the registry so that a conflict
    # does not leave the module partially registered.
    for name in names:
        if name in MODULES and MODULES[name] != module:
            raise ValueError(
                f"A different module identified as '{name}' is already registered !"
            )
    for name in names:
        MODULES[name] = module
def register_fennix_preprocessing(module, FPID: Optional[str] = None):
    """Register a preprocessing class in the global ``PREPROCESSING`` registry.

    Args:
        module: The preprocessing class to register.
        FPID: Explicit identifier. When None, the class attribute
            ``module.FPID`` is used instead; it may be a single string or a
            tuple/list of strings (aliases). Identifiers are upper-cased
            before registration.

    If no explicit FPID is given and the class has no usable ``FPID``
    attribute, a warning is printed and nothing is registered.

    Raises:
        ValueError: If one of the identifiers is already registered to a
            different class. In that case none of the identifiers are added.
    """
    if FPID is not None:
        names = [FPID.upper()]
    else:
        if not hasattr(module, "FPID"):
            print(f"Warning: module {module.__name__} does not have a FPID field and no explicit FPID was provided. Module was NOT registered.")
            # Return early: reading module.FPID below would raise AttributeError.
            return
        fpids = module.FPID
        if isinstance(fpids, str):
            fpids = (fpids,)
        # Accept a tuple or list of strings; anything else is rejected as the
        # warning message promises (instead of crashing or registering anyway).
        if not (
            isinstance(fpids, (tuple, list))
            and all(isinstance(fpid, str) for fpid in fpids)
        ):
            print(f"Warning: module {module.__name__} has an invalid FPID field. Module was NOT registered.")
            return
        names = [fpid.upper() for fpid in fpids]

    # Validate every alias before mutating the registry so that a conflict
    # does not leave the class partially registered.
    for name in names:
        if name in PREPROCESSING and PREPROCESSING[name] != module:
            raise ValueError(
                f"A different module identified as '{name}' is already registered !"
            )
    for name in names:
        PREPROCESSING[name] = module
def register_fennix_modules(module, recurs=0, max_recurs=2):
    """Discover FENNIX modules and preprocessing classes in ``module`` and register them.

    Args:
        module: A Python module or package. For a package, its sub-modules
            are imported and scanned recursively.
        recurs: Current recursion depth (internal; callers normally leave it
            at 0).
        max_recurs: Maximum sub-package depth to descend into. A package
            reached at this depth is still scanned, but its own sub-modules
            are not imported.

    Every class found in the module namespace (including classes imported
    into it) is handled as follows:

    - ``nn.Module`` subclasses other than ``nn.Module`` itself that define
      ``FID`` are passed to ``register_fennix_module``;
    - other classes that define ``FPID`` are passed to
      ``register_fennix_preprocessing``.

    Raises:
        ValueError: Propagated from the registration helpers on identifier
            conflicts.
    """
    # Enforce the depth limit; max_recurs was previously accepted but ignored.
    if ismodule(module) and hasattr(module, "__path__") and recurs < max_recurs:
        for _, name, _ in iter_modules(module.__path__):
            sub_module = importlib.import_module(f"{module.__name__}.{name}")
            register_fennix_modules(sub_module, recurs=recurs + 1, max_recurs=max_recurs)
    # Iterate over a snapshot of the namespace in case registration side
    # effects touch the module dict.
    for m in list(vars(module).values()):
        if not isclass(m):
            continue
        if issubclass(m, nn.Module) and m != nn.Module:
            if hasattr(m, "FID"):
                register_fennix_module(m)
        elif hasattr(m, "FPID"):
            register_fennix_preprocessing(m)
module_path = ['']
def get_modules_documentation():
    """Map every registered FID to the docstring of its module class.

    Returns:
        dict: ``{FID: module.__doc__}`` for each entry of ``MODULES``, in
        registration order. Values are None for classes without a docstring.
    """
    return {fid: cls.__doc__ for fid, cls in MODULES.items()}
class FENNIXModules(nn.Module):
    r"""Apply a sequence of FENNIX modules one after another.

    Each entry of ``layers`` is a ``(module_class, kwargs)`` pair. When called,
    every class is instantiated with its keyword arguments and applied to the
    output of the previous step, starting from ``inputs``.

    Attributes:
        layers (Sequence[Tuple[nn.Module, Dict]]): Sequence of tuples (module, parameters) to apply.
    """

    # (module class, constructor kwargs) pairs, applied in order.
    layers: Sequence[Tuple[nn.Module, Dict]]

    def __post_init__(self):
        # Validate the layer specification before flax's own post-init runs.
        layer_spec = self.layers
        if isinstance(layer_spec, Sequence):
            super().__post_init__()
            return
        raise ValueError(
            f"'layers' must be a sequence, got '{type(layer_spec).__name__}'."
        )

    @nn.compact
    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Run ``inputs`` through every configured layer and return the final output.

        Raises:
            ValueError: If ``layers`` is empty.
        """
        layer_specs = self.layers
        if not layer_specs:
            raise ValueError(f"Empty Sequential module {self.name}.")

        state = inputs
        for module_cls, kwargs in layer_specs:
            # Sub-modules are created inline, as @nn.compact allows.
            state = module_cls(**kwargs)(state)
        return state

Sequential module that applies a sequence of FENNIX modules.

Attributes: layers (Sequence[Tuple[nn.Module, Dict]]): Sequence of tuples (module, parameters) to apply.

FENNIXModules( layers: Sequence[Tuple[flax.linen.module.Module, Dict]], parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
layers: Sequence[Tuple[flax.linen.module.Module, Dict]]
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None