fennol.models.inspect

  1import argparse
  2from pathlib import Path
  3import yaml
  4import json
  5import os
  6from flax import traverse_util
  7import jax
  8import dataclasses
  9
 10from .fennix import FENNIX
 11
 12
 13class IndentDumper(yaml.Dumper):
 14    def increase_indent(self, flow=False, indentless=False):
 15        return super(IndentDumper, self).increase_indent(flow, False)
 16
 17
 18def main():
 19    # os.environ["CUDA_VISIBLE_DEVICES"] = ""
 20    _device = jax.devices("cpu")[0]
 21    jax.config.update("jax_default_device", _device)
 22    ### Read the parameter file
 23    parser = argparse.ArgumentParser(prog="fennol_inspect")
 24    parser.add_argument("model_file", type=Path, help="Model file")
 25    parser.add_argument(
 26        "-s",
 27        "--short",
 28        action="store_true",
 29        help="Print only module names in the list of modules",
 30    )
 31    parser.add_argument(
 32        "-a",
 33        "--all",
 34        action="store_true",
 35        help="Print all module attributes and automatically added modules",
 36    )
 37    parser.add_argument(
 38        "-p","--prm", action="store_true", help="Print parameter shapes"
 39    )
 40    args = parser.parse_args()
 41    model_file = args.model_file
 42    # param_shapes = args.param_shapes
 43    # short = args.short
 44    # all = args.all
 45
 46    ### Load the model
 47    if model_file.suffix == ".yaml" or model_file.suffix == ".yml":
 48        with open(model_file, "r") as f:
 49            model_dict = yaml.safe_load(f)
 50        model = FENNIX(**model_dict,rng_key=jax.random.PRNGKey(0))
 51    else:
 52        model = FENNIX.load(model_file)
 53    print(f"# filename: {model_file}")
 54    print(inspect_model(model, **vars(args)))
 55
 56
 57def inspect_model(model, prm=False, short=False, all=False, **kwargs):
 58
 59    model_dict = model._input_args
 60
 61    inspect_dict = {
 62        "energy_unit": model.energy_unit,
 63        "energy_terms": model.energy_terms,
 64        "cutoff": model.cutoff,
 65    }
 66    
 67    inspect_dict["preprocessing"] = dict(model_dict["preprocessing"])
 68    if not all:
 69        inspect_dict["modules"] = dict(model_dict["modules"])
 70    else:
 71        mods = {}
 72        for mod, inp in model.modules.layers:
 73            m = mod(**inp)
 74            key = m.name
 75            if key is None or key.startswith("_"):
 76                continue
 77            mods[key] = {}
 78            for field in dataclasses.fields(m):
 79                if field.name not in ["name", "parent"]:
 80                    mods[key][field.name] = getattr(m, field.name)
 81        inspect_dict["modules"] = mods
 82    
 83    if short:
 84        inspect_dict["preprocessing"] = list(inspect_dict["preprocessing"].keys())
 85        inspect_dict["modules"] = list(inspect_dict["modules"].keys())
 86    
 87    data = "# MODEL DESCRIPTION\n"
 88    data += yaml.dump(inspect_dict, sort_keys=False, Dumper=IndentDumper)
 89
 90    params = model.variables["params"] if "params" in model.variables else {}
 91
 92    if prm:
 93        shapes = traverse_util.path_aware_map(
 94            lambda p, v: f"[{','.join(str(i) for i in v.shape)}]",
 95            params,
 96        )
 97
 98        data = (
 99            data
100            + "\n\n# PARAMETER SHAPES\n"
101            + yaml.dump(
102                shapes,
103                sort_keys=False,
104            )
105        )
106
107    number_of_params = sum(
108        jax.tree.leaves(
109            traverse_util.path_aware_map(
110                lambda p, v: v.size,
111                params,
112            )
113        )
114    )
115
116    data += f"\n# NUMBER OF PARAMETERS: {number_of_params:_}"
117    return data
118
119
if __name__ == "__main__":
    # Script entry point. FENNIX is imported relatively, so run this as a
    # module of its package (python -m fennol.models.inspect).
    main()
class IndentDumper(yaml.dumper.Dumper):
class IndentDumper(yaml.Dumper):
    """YAML dumper that never emits "indentless" block sequences.

    PyYAML's emitter may ask for an indentless indentation increase (used for
    block sequences nested inside mappings). This dumper always forwards
    ``indentless=False`` to the base implementation, so nested list items are
    indented one level deeper than their parent key.
    """

    def increase_indent(self, flow=False, indentless=False):
        # The requested `indentless` value is deliberately ignored.
        return super().increase_indent(flow=flow, indentless=False)
def increase_indent(self, flow=False, indentless=False):
15    def increase_indent(self, flow=False, indentless=False):
16        return super(IndentDumper, self).increase_indent(flow, False)
def main():
def main():
    """Command-line entry point of ``fennol_inspect``.

    Parses the command line and loads a FENNIX model. Files with a ``.yaml``
    or ``.yml`` suffix (case-insensitive) are read as a mapping of FENNIX
    constructor arguments; any other file goes through ``FENNIX.load``. The
    report built by ``inspect_model`` is then printed to stdout.
    """
    ### Read the command line
    parser = argparse.ArgumentParser(prog="fennol_inspect")
    parser.add_argument("model_file", type=Path, help="Model file")
    parser.add_argument(
        "-s",
        "--short",
        action="store_true",
        help="Print only module names in the list of modules",
    )
    parser.add_argument(
        "-a",
        "--all",
        action="store_true",
        help="Print all module attributes and automatically added modules",
    )
    parser.add_argument(
        "-p", "--prm", action="store_true", help="Print parameter shapes"
    )
    args = parser.parse_args()
    model_file = args.model_file

    # Inspection only needs the CPU: make it JAX's default device. This is done
    # after argument parsing so `--help` and usage errors do not initialize JAX.
    jax.config.update("jax_default_device", jax.devices("cpu")[0])

    ### Load the model
    if model_file.suffix.lower() in (".yaml", ".yml"):
        with open(model_file, "r", encoding="utf-8") as f:
            model_dict = yaml.safe_load(f)
        if not isinstance(model_dict, dict):
            # An empty file loads as None (a scalar/list is not a mapping
            # either); FENNIX(**model_dict) would fail with an opaque TypeError.
            parser.error(
                f"{model_file}: expected a YAML mapping of model arguments"
            )
        model = FENNIX(**model_dict, rng_key=jax.random.PRNGKey(0))
    else:
        model = FENNIX.load(model_file)
    print(f"# filename: {model_file}")
    # `model_file` is also forwarded here; inspect_model absorbs it in **kwargs.
    print(inspect_model(model, **vars(args)))
def inspect_model(model, prm=False, short=False, all=False, **kwargs):
 58def inspect_model(model, prm=False, short=False, all=False, **kwargs):
 59
 60    model_dict = model._input_args
 61
 62    inspect_dict = {
 63        "energy_unit": model.energy_unit,
 64        "energy_terms": model.energy_terms,
 65        "cutoff": model.cutoff,
 66    }
 67    
 68    inspect_dict["preprocessing"] = dict(model_dict["preprocessing"])
 69    if not all:
 70        inspect_dict["modules"] = dict(model_dict["modules"])
 71    else:
 72        mods = {}
 73        for mod, inp in model.modules.layers:
 74            m = mod(**inp)
 75            key = m.name
 76            if key is None or key.startswith("_"):
 77                continue
 78            mods[key] = {}
 79            for field in dataclasses.fields(m):
 80                if field.name not in ["name", "parent"]:
 81                    mods[key][field.name] = getattr(m, field.name)
 82        inspect_dict["modules"] = mods
 83    
 84    if short:
 85        inspect_dict["preprocessing"] = list(inspect_dict["preprocessing"].keys())
 86        inspect_dict["modules"] = list(inspect_dict["modules"].keys())
 87    
 88    data = "# MODEL DESCRIPTION\n"
 89    data += yaml.dump(inspect_dict, sort_keys=False, Dumper=IndentDumper)
 90
 91    params = model.variables["params"] if "params" in model.variables else {}
 92
 93    if prm:
 94        shapes = traverse_util.path_aware_map(
 95            lambda p, v: f"[{','.join(str(i) for i in v.shape)}]",
 96            params,
 97        )
 98
 99        data = (
100            data
101            + "\n\n# PARAMETER SHAPES\n"
102            + yaml.dump(
103                shapes,
104                sort_keys=False,
105            )
106        )
107
108    number_of_params = sum(
109        jax.tree.leaves(
110            traverse_util.path_aware_map(
111                lambda p, v: v.size,
112                params,
113            )
114        )
115    )
116
117    data += f"\n# NUMBER OF PARAMETERS: {number_of_params:_}"
118    return data