fennol.models.inspect
"""Command-line inspection of FENNIX models.

Prints a YAML description of a model, optionally the shape of every
parameter, and the total number of parameters.
"""

import argparse
import dataclasses
import json
import os
from pathlib import Path

import jax
import yaml
from flax import traverse_util

from .fennix import FENNIX


class IndentDumper(yaml.Dumper):
    """YAML dumper that always calls ``increase_indent`` with ``indentless=False``."""

    def increase_indent(self, flow=False, indentless=False):
        # The caller's ``indentless`` value is deliberately discarded.
        # NOTE(review): in PyYAML this presumably makes list items nested in a
        # mapping indented under their key -- confirm against the emitter.
        return super().increase_indent(flow, False)


def main():
    """Entry point of the ``fennol_inspect`` command-line tool.

    Parses the command line, builds the model (from a YAML description or
    through ``FENNIX.load``) and prints the report returned by
    ``inspect_model``.
    """
    # Make the first CPU device the JAX default device.
    cpu_device = jax.devices("cpu")[0]
    jax.config.update("jax_default_device", cpu_device)

    ### Read the command line
    parser = argparse.ArgumentParser(prog="fennol_inspect")
    parser.add_argument("model_file", type=Path, help="Model file")
    parser.add_argument(
        "-s",
        "--short",
        action="store_true",
        help="Print only module names in the list of modules",
    )
    parser.add_argument(
        "-a",
        "--all",
        action="store_true",
        help="Print all module attributes and automatically added modules",
    )
    parser.add_argument(
        "-p", "--prm", action="store_true", help="Print parameter shapes"
    )
    args = parser.parse_args()
    model_file = args.model_file

    ### Load the model
    if model_file.suffix in (".yaml", ".yml"):
        # Explicit encoding: without it the decoding depends on the locale.
        with open(model_file, "r", encoding="utf-8") as f:
            model_dict = yaml.safe_load(f)
        model = FENNIX(**model_dict, rng_key=jax.random.PRNGKey(0))
    else:
        model = FENNIX.load(model_file)

    print(f"# filename: {model_file}")
    # vars(args) also carries ``model_file``; inspect_model absorbs it in **kwargs.
    print(inspect_model(model, **vars(args)))


def inspect_model(model, prm=False, short=False, all=False, **kwargs):
    """Build a text report describing ``model``.

    Args:
        model: object exposing ``_input_args``, ``energy_unit``,
            ``energy_terms``, ``cutoff`` and ``variables``; ``modules.layers``
            is also read when ``all`` is true.
        prm: if true, append a section with the shape of every parameter.
        short: if true, list only the names of the preprocessing entries and
            of the modules.
        all: if true, instantiate every ``(module, inputs)`` pair of
            ``model.modules.layers`` and report all dataclass fields except
            ``name`` and ``parent``; modules whose name is ``None`` or starts
            with ``_`` are skipped. Otherwise the ``modules`` entry of
            ``model._input_args`` is reported.
        **kwargs: ignored; allows callers to forward unrelated options.

    Returns:
        The report as a single string.
    """
    input_args = model._input_args

    description = {
        "energy_unit": model.energy_unit,
        "energy_terms": model.energy_terms,
        "cutoff": model.cutoff,
        "preprocessing": dict(input_args["preprocessing"]),
    }

    if all:
        modules = {}
        for module_cls, module_inputs in model.modules.layers:
            instance = module_cls(**module_inputs)
            module_name = instance.name
            if module_name is None or module_name.startswith("_"):
                continue
            modules[module_name] = {
                field.name: getattr(instance, field.name)
                for field in dataclasses.fields(instance)
                if field.name not in ("name", "parent")
            }
        description["modules"] = modules
    else:
        description["modules"] = dict(input_args["modules"])

    if short:
        # Keep only the entry names, in their original order.
        for section in ("preprocessing", "modules"):
            description[section] = list(description[section])

    report = "# MODEL DESCRIPTION\n" + yaml.dump(
        description, sort_keys=False, Dumper=IndentDumper
    )

    if "params" in model.variables:
        params = model.variables["params"]
    else:
        params = {}

    if prm:

        def _shape_text(_path, value):
            return "[" + ",".join(str(dim) for dim in value.shape) + "]"

        shapes = traverse_util.path_aware_map(_shape_text, params)
        report += "\n\n# PARAMETER SHAPES\n" + yaml.dump(shapes, sort_keys=False)

    sizes = traverse_util.path_aware_map(lambda _path, value: value.size, params)
    number_of_params = sum(jax.tree.leaves(sizes))

    report += f"\n# NUMBER OF PARAMETERS: {number_of_params:_}"
    return report


if __name__ == "__main__":
    main()
class
IndentDumper(yaml.dumper.Dumper):
def main():
    """Entry point of the ``fennol_inspect`` command-line tool.

    Parses the command line, builds the model (from a YAML description or
    through ``FENNIX.load``) and prints the report returned by
    ``inspect_model``.
    """
    # Make the first CPU device the JAX default device.
    cpu_device = jax.devices("cpu")[0]
    jax.config.update("jax_default_device", cpu_device)

    ### Read the command line
    parser = argparse.ArgumentParser(prog="fennol_inspect")
    parser.add_argument("model_file", type=Path, help="Model file")
    parser.add_argument(
        "-s",
        "--short",
        action="store_true",
        help="Print only module names in the list of modules",
    )
    parser.add_argument(
        "-a",
        "--all",
        action="store_true",
        help="Print all module attributes and automatically added modules",
    )
    parser.add_argument(
        "-p", "--prm", action="store_true", help="Print parameter shapes"
    )
    args = parser.parse_args()
    model_file = args.model_file

    ### Load the model
    if model_file.suffix in (".yaml", ".yml"):
        # Explicit encoding: without it the decoding depends on the locale.
        with open(model_file, "r", encoding="utf-8") as f:
            model_dict = yaml.safe_load(f)
        model = FENNIX(**model_dict, rng_key=jax.random.PRNGKey(0))
    else:
        model = FENNIX.load(model_file)

    print(f"# filename: {model_file}")
    # vars(args) also carries ``model_file``; inspect_model absorbs it in **kwargs.
    print(inspect_model(model, **vars(args)))
def inspect_model(model, prm=False, short=False, all=False, **kwargs):
    """Build a text report describing ``model``.

    Args:
        model: object exposing ``_input_args``, ``energy_unit``,
            ``energy_terms``, ``cutoff`` and ``variables``; ``modules.layers``
            is also read when ``all`` is true.
        prm: if true, append a section with the shape of every parameter.
        short: if true, list only the names of the preprocessing entries and
            of the modules.
        all: if true, instantiate every ``(module, inputs)`` pair of
            ``model.modules.layers`` and report all dataclass fields except
            ``name`` and ``parent``; modules whose name is ``None`` or starts
            with ``_`` are skipped. Otherwise the ``modules`` entry of
            ``model._input_args`` is reported.
        **kwargs: ignored; allows callers to forward unrelated options.

    Returns:
        The report as a single string.
    """
    input_args = model._input_args

    description = {
        "energy_unit": model.energy_unit,
        "energy_terms": model.energy_terms,
        "cutoff": model.cutoff,
        "preprocessing": dict(input_args["preprocessing"]),
    }

    if all:
        modules = {}
        for module_cls, module_inputs in model.modules.layers:
            instance = module_cls(**module_inputs)
            module_name = instance.name
            if module_name is None or module_name.startswith("_"):
                continue
            modules[module_name] = {
                field.name: getattr(instance, field.name)
                for field in dataclasses.fields(instance)
                if field.name not in ("name", "parent")
            }
        description["modules"] = modules
    else:
        description["modules"] = dict(input_args["modules"])

    if short:
        # Keep only the entry names, in their original order.
        for section in ("preprocessing", "modules"):
            description[section] = list(description[section])

    report = "# MODEL DESCRIPTION\n" + yaml.dump(
        description, sort_keys=False, Dumper=IndentDumper
    )

    if "params" in model.variables:
        params = model.variables["params"]
    else:
        params = {}

    if prm:

        def _shape_text(_path, value):
            return "[" + ",".join(str(dim) for dim in value.shape) + "]"

        shapes = traverse_util.path_aware_map(_shape_text, params)
        report += "\n\n# PARAMETER SHAPES\n" + yaml.dump(shapes, sort_keys=False)

    sizes = traverse_util.path_aware_map(lambda _path, value: value.size, params)
    number_of_params = sum(jax.tree.leaves(sizes))

    report += f"\n# NUMBER OF PARAMETERS: {number_of_params:_}"
    return report