fennol.analyze
import sys, os, io
import argparse
from pathlib import Path

import numpy as np
import jax
import jax.numpy as jnp

from .models import FENNIX
from .utils.io import xyz_reader
from .utils.periodic_table import PERIODIC_TABLE_REV_IDX
from .utils.atomic_units import AtomicUnits as au
import yaml
import pickle
import json


def main():
    parser = argparse.ArgumentParser(prog="fennol_analyze")
    parser.add_argument(
        "input_file", type=Path, help="file containing the geometries to analyze"
    )
    parser.add_argument(
        "model_file", type=Path, help="file containing the model to use"
    )
    parser.add_argument("--output", "-o", type=Path, help="file to write the output to")
    parser.add_argument(
        "--batch_size",
        "-b",
        type=int,
        default=1,
        help="batch size to use for the model",
    )
    parser.add_argument(
        "--device", type=str, default="cpu", help="device to use for the model"
    )
    parser.add_argument("-f64", action="store_true", help="use double precision")
    parser.add_argument("--periodic", action="store_true", help="use PBC")
    parser.add_argument(
        "--format", type=str, default="xyz", help="format of the input file"
    )
    parser.add_argument(
        "--output_keys",
        type=str,
        nargs="+",
        help="keys to output",
        default=["total_energy"],
    )
    args = parser.parse_args()

    # set the device
    device = args.device.lower()
    if device == "cpu":
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    elif device.startswith("cuda") or device.startswith("gpu"):
        if ":" in device:
            num = device.split(":")[-1]
            os.environ["CUDA_VISIBLE_DEVICES"] = num
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        device = "gpu"

    _device = jax.devices(device)[0]
    jax.config.update("jax_default_device", _device)

    # set the precision
    if args.f64:
        jax.config.update("jax_enable_x64", True)
        fprec = "float64"
    else:
        fprec = "float32"
    jax.config.update("jax_default_matmul_precision", "highest")

    # check the input file
    input_file = args.input_file.resolve()
    assert input_file.exists(), f"Input file {input_file} does not exist"
    assert args.format in [
        "xyz",
        "arc",
        "pkl",
    ], "Only xyz, arc and pkl formats are supported for now"
    output_keys = args.output_keys
    xyz_indexed = args.format == "arc"

    # load the model
    model_file: Path = args.model_file.resolve()
    assert model_file.exists(), f"Model file {model_file} does not exist"
    model = FENNIX.load(model_file, use_atom_padding=True)

    # check the output file
    if args.output is not None:
        output_file: Path = args.output.resolve()
        assert not output_file.exists(), f"Output file {args.output} already exists"
        assert output_file.suffix in [
            ".pkl",
            ".json",
            ".yaml",
        ], f"Unsupported output file format {output_file.suffix}"

    # print metadata
    metadata = {
        "input_file": str(input_file),
        "model_file": str(model_file),
        "output_keys": output_keys,
        "energy_unit": model.energy_unit,
    }
    print("metadata:")
    for k, v in metadata.items():
        print(f"  {k}: {v}")

    # define the model prediction function
    def model_predict(batch):
        # concatenate all frames of the batch into flat arrays
        natoms = np.array([frame["natoms"] for frame in batch])
        batch_index = np.concatenate([frame["batch_index"] for frame in batch])
        species = np.concatenate([frame["species"] for frame in batch])
        xyz = np.concatenate([frame["coordinates"] for frame in batch], axis=0)
        inputs = {
            "species": species,
            "coordinates": xyz,
            "batch_index": batch_index,
            "natoms": natoms,
        }
        if args.periodic:
            cells = np.concatenate([frame["cell"] for frame in batch], axis=0)
            inputs["cells"] = cells

        # forces are only computed when explicitly requested
        if "forces" in output_keys:
            e, f, output = model.energy_and_forces(**inputs)
        else:
            e, output = model.total_energy(**inputs)
        return output

    # define the function to process a batch
    def process_batch(batch):
        output = model_predict(batch)
        natoms = np.array([frame["natoms"] for frame in batch])
        if args.periodic:
            cells = np.array(output["cells"])
        species = np.array(output["species"])
        coordinates = np.array(output["coordinates"])
        # cumulative atom offsets delimiting each frame in the flat arrays
        natshift = np.concatenate([np.array([0], dtype=np.int32), np.cumsum(natoms)])
        frames_data = []
        for i in range(len(batch)):
            frame_data = {
                "species": species[natshift[i] : natshift[i + 1]].tolist(),
                "coordinates": coordinates[natshift[i] : natshift[i + 1]].tolist(),
            }
            if args.periodic:
                frame_data["cell"] = cells[i].tolist()

            for k in output_keys:
                if k not in output:
                    raise ValueError(f"Output key {k} not found")
                v = output[k]
                if v.shape[0] == species.shape[0]:
                    # per-atom property: slice out this frame's atoms
                    frame_data[k] = v[natshift[i] : natshift[i + 1]].tolist()
                elif v.shape[0] == natoms.shape[0]:
                    # per-frame property: take this frame's entry
                    frame_data[k] = v[i].tolist()
                else:
                    raise ValueError(f"Output key {k} has wrong shape")

            frames_data.append(frame_data)
        return frames_data

    ### start processing the input file
    if args.format in ["arc", "xyz"]:

        def reader():
            for symbols, xyz, comment in xyz_reader(
                input_file, has_comment_line=True, indexed=xyz_indexed
            ):
                species = np.array([PERIODIC_TABLE_REV_IDX[s] for s in symbols])
                frame = {
                    "species": species,
                    "coordinates": xyz,
                    "natoms": species.shape[0],
                }
                if args.periodic:
                    # read the cell vectors from the comment line
                    cell = np.array([float(x) for x in comment.split()]).reshape(1, 3, 3)
                    frame["cell"] = cell
                yield frame

    elif args.format == "pkl":
        with open(input_file, "rb") as f:
            data = pickle.load(f)
        if isinstance(data, dict):
            if "frames" in data:
                frames = data["frames"]
            elif "training" in data:
                frames = data["training"]
                if "validation" in data:
                    frames.extend(data["validation"])
            else:
                raise ValueError("No frames found in the input file")
        else:
            frames = data

        def reader():
            for input_frame in frames:
                species = np.array(input_frame["species"])
                coordinates = np.array(input_frame["coordinates"])
                frame = {
                    "species": species,
                    "coordinates": coordinates,
                    "natoms": species.shape[0],
                }
                if args.periodic:
                    # take the cell from the input frame, not the rebuilt one
                    cell = np.array(input_frame["cell"]).reshape(1, 3, 3)
                    frame["cell"] = cell
                yield frame

    batch = []
    output_data = []
    ibatch = 0
    for frame in reader():
        # per-frame index inside the current batch
        batch_index = np.full(frame["natoms"], ibatch, dtype=np.int32)
        frame["batch_index"] = batch_index
        batch.append(frame)
        ibatch += 1
        if len(batch) == args.batch_size:
            frames_data = process_batch(batch)
            if args.output is None:
                for frame_data in frames_data:
                    print("\n---\n")
                    for k in frame_data:
                        print(f"{k}: {frame_data[k]}")
            output_data.extend(frames_data)
            batch = []
            ibatch = 0
    # process the last (possibly partial) batch
    if len(batch) > 0:
        frames_data = process_batch(batch)
        if args.output is None:
            for frame_data in frames_data:
                print("\n---\n")
                for k in frame_data:
                    print(f"{k}: {frame_data[k]}")
        output_data.extend(frames_data)

    # write the output to a file
    if args.output is not None:
        output_data = {
            "metadata": metadata,
            "frames": output_data,
        }
        if output_file.suffix == ".json":
            with open(args.output, "w") as f:
                json.dump(output_data, f, indent=2)
        elif output_file.suffix == ".yaml":
            with open(args.output, "w") as f:
                yaml.dump(output_data, f, sort_keys=False)
        elif output_file.suffix == ".pkl":
            with open(args.output, "wb") as f:
                pickle.dump(output_data, f)


if __name__ == "__main__":
    main()
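Usage sketch (not part of the module source; the file names are placeholders and the model-file extension is an assumption):

    fennol_analyze geometries.xyz model.fnx -o results.yaml --batch_size 8 --output_keys total_energy forces

Without -o/--output, the per-frame results are printed to stdout separated by "---" markers. With an output file, a dictionary holding the run metadata and the list of per-frame results is written in .json, .yaml or .pkl form according to the file suffix.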
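For the pkl input path, the reader above accepts either a bare list of frames or a dictionary with a "frames" key (or "training"/"validation" keys). A minimal sketch of building such a file follows; the file name and geometry values are illustrative only, and the coordinate unit is assumed to be whatever the loaded model expects:

    import pickle
    import numpy as np

    # one frame: atomic numbers plus an (natoms, 3) coordinate array
    frames = [
        {
            "species": np.array([8, 1, 1]),  # O, H, H
            "coordinates": np.array(
                [[0.00, 0.00, 0.00], [0.96, 0.00, 0.00], [-0.24, 0.93, 0.00]]
            ),
            # "cell": np.eye(3) * 20.0,  # (3, 3); only read when --periodic is set
        }
    ]

    with open("water.pkl", "wb") as f:
        pickle.dump({"frames": frames}, f)  # a bare list of frames also works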