fennol.analyze
import sys, os, io
import argparse
from pathlib import Path

import numpy as np
import jax
import jax.numpy as jnp

from .models import FENNIX
from .utils import parse_cell
from .utils.io import xyz_reader, human_time_duration
from .utils.periodic_table import PERIODIC_TABLE_REV_IDX
from .utils.atomic_units import AtomicUnits as au
import yaml
import pickle
from flax.core import freeze, unfreeze
import time
import multiprocessing as mp


def get_file_reader(input_file, file_format=None, has_comment_line=True, periodic=False):
    # check the input file
    input_file = Path(input_file).resolve()
    assert input_file.exists(), f"Input file {input_file} does not exist"
    if file_format is None:
        file_format = input_file.suffix[1:]  # remove the dot
    file_format = file_format.lower()
    if file_format == "pickle":
        file_format = "pkl"

    assert file_format in [
        "xyz",
        "arc",
        "pkl",
    ], "Only xyz, arc and pkl formats are supported for now"
    xyz_indexed = file_format == "arc"

    if file_format in ["arc", "xyz"]:

        def reader():
            for symbols, xyz, comment in xyz_reader(
                input_file, has_comment_line=has_comment_line, indexed=xyz_indexed
            ):
                species = np.array([PERIODIC_TABLE_REV_IDX[s] for s in symbols])
                inputs = {
                    "species": species,
                    "coordinates": xyz,
                    "natoms": species.shape[0],
                    "total_charge": 0,  # default total charge
                }
                if periodic:
                    # cell parameters are read from the comment line
                    box = np.array(comment.split(), dtype=float)
                    cell = parse_cell(box)
                    inputs["cell"] = cell
                yield inputs

    elif file_format == "pkl":
        with open(input_file, "rb") as f:
            data = pickle.load(f)
        if isinstance(data, dict):
            if "frames" in data:
                frames = data["frames"]
            elif "training" in data:
                frames = data["training"]
                if "validation" in data:
                    frames.extend(data["validation"])
            elif "coordinates" in data or "output_keys" in data:
                # data comes from fennol_analyze or fennol_md:
                # the file is a stream of pickled frames
                frames = []
                with open(input_file, "rb") as f:
                    if "output_keys" in data:  # skip metadata
                        frame = pickle.load(f)
                    try:
                        while True:
                            frame = pickle.load(f)
                            frames.append(frame)
                    except EOFError:
                        pass
            else:
                raise ValueError("No frames found in the input file")
        else:
            frames = data

        def reader():
            _periodic = periodic
            cell_key = "cells" if "cells" in frames[0] else "cell"
            if cell_key in frames[0]:
                _periodic = True
            for frame in frames:
                species = np.array(frame["species"])
                coordinates = np.array(frame["coordinates"])
                inputs = {
                    "species": species,
                    "coordinates": coordinates,
                    "natoms": species.shape[0],
                    "total_charge": frame.get("total_charge", 0),
                }
                if _periodic:
                    cell_key = "cells" if "cells" in frame else "cell"
                    cell = np.array(frame[cell_key]).reshape(3, 3)
                    inputs["cell"] = cell
                yield inputs

    return reader


def fennix_analyzer(input_file, model, output_keys, file_format=None, periodic=False, flags={}, has_comment_line=True, batch_size=1):

    reader = get_file_reader(
        input_file, file_format=file_format, periodic=periodic, has_comment_line=has_comment_line
    )

    # define the model prediction function
    def model_predict(batch):
        natoms = np.array([frame["natoms"] for frame in batch], dtype=np.int32)
        batch_index = np.concatenate([frame["batch_index"] for frame in batch])
        species = np.concatenate([frame["species"] for frame in batch])
        xyz = np.concatenate([frame["coordinates"] for frame in batch], axis=0)
        total_charge = np.array([frame["total_charge"] for frame in batch], dtype=np.int32)
        inputs = {
            "species": species,
            "coordinates": xyz,
            "batch_index": batch_index,
            "natoms": natoms,
            "total_charge": total_charge,
        }
        if "cell" in batch[0]:
            cells = np.stack([frame["cell"] for frame in batch], axis=0)
            inputs["cells"] = cells
            inputs["reciprocal_cells"] = np.linalg.inv(cells)

        inputs["flags"] = flags

        if "forces" in output_keys:
            e, f, output = model.energy_and_forces(**inputs, gpu_preprocessing=True)
        else:
            e, output = model.total_energy(**inputs, gpu_preprocessing=True)
        return output

    # define the function to process a batch
    def process_batch(batch):
        output = model_predict(batch)
        natoms = np.array([frame["natoms"] for frame in batch])
        _periodic = "cells" in output
        if _periodic:
            cells = np.array(output["cells"])
        species = np.array(output["species"])
        coordinates = np.array(output["coordinates"])
        natshift = np.concatenate([np.array([0], dtype=np.int32), np.cumsum(natoms)])
        frames_data = []
        for i in range(len(batch)):
            frame_data = {
                "species": species[natshift[i] : natshift[i + 1]],
                "coordinates": coordinates[natshift[i] : natshift[i + 1]],
                "total_charge": int(output["total_charge"][i]),
            }
            if _periodic:
                frame_data["cell"] = cells[i]

            for k in output_keys:
                if k not in output:
                    raise ValueError(f"Output key {k} not found")
                v = np.asarray(output[k])
                if v.shape[0] == species.shape[0]:
                    # per-atom quantity: slice out this frame's atoms
                    frame_data[k] = v[natshift[i] : natshift[i + 1]]
                elif v.shape[0] == natoms.shape[0]:
                    # per-frame quantity: take this frame's entry
                    frame_data[k] = v[i]
                else:
                    raise ValueError(
                        f"Output key {k} has wrong shape {v.shape} {natoms.shape} {species.shape}"
                    )

            frames_data.append(frame_data)

        return frames_data

    batch = []
    iframe = 0
    for frame in reader():
        batch_index = np.full(frame["natoms"], iframe, dtype=np.int32)
        frame["batch_index"] = batch_index
        batch.append(frame)
        iframe += 1
        if len(batch) == batch_size:
            frames_data = process_batch(batch)
            for frame_data in frames_data:
                yield frame_data
            batch = []
            iframe = 0
    # process the last batch
    if len(batch) > 0:
        frames_data = process_batch(batch)
        for frame_data in frames_data:
            yield frame_data


def main():
    parser = argparse.ArgumentParser(prog="fennol_analyze")
    parser.add_argument(
        "input_file", type=Path, help="file containing the geometries to analyze"
    )
    parser.add_argument(
        "model_file", type=Path, help="file containing the model to use"
    )
    parser.add_argument(
        "-o", "--outfile", type=Path, help="file to write the output to"
    )
    parser.add_argument(
        "--batch_size",
        "-b",
        type=int,
        default=1,
        help="batch size to use for the model",
    )
    parser.add_argument("--device", type=str, help="device to use for the model")
    parser.add_argument("-f64", action="store_true", help="use double precision")
    parser.add_argument("--periodic", action="store_true", help="use PBC")
    parser.add_argument(
        "--format",
        type=str,
        help="format of the input file. Default: auto-detect from file extension",
    )
    parser.add_argument(
        "-c",
        "--nocomment",
        action="store_true",
        help="flag to indicate that the input file does not have a comment line. Only used for xyz and arc formats",
    )
    parser.add_argument(
        "--output_keys",
        type=str,
        nargs="+",
        help="keys to output",
        default=["total_energy"],
    )
    parser.add_argument(
        "--nblist",
        nargs=3,
        metavar=("mult_size", "add_neigh", "add_atoms"),
        help="neighbour list parameters: mult_size, add_neigh, add_atoms. If not provided, the default values from the model will be used.",
    )
    parser.add_argument(
        "-m",
        "--metadata",
        action="store_true",
        help="add metadata as a first frame.",
    )
    parser.add_argument(
        "--flags",
        type=str,
        nargs="*",
        default=[],
        help="additional flags to pass to the model. These will be added to the inputs as a dictionary with the flag name as key and None as value.",
    )
    args = parser.parse_args()

    # set the device
    if args.device:
        device = args.device.lower()
    elif "FENNOL_DEVICE" in os.environ:
        device = os.environ["FENNOL_DEVICE"].lower()
        print(f"# Setting device from env FENNOL_DEVICE={device}")
    else:
        device = "cpu"
    if device == "cpu":
        jax.config.update("jax_platforms", "cpu")
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    elif device.startswith("cuda") or device.startswith("gpu"):
        if ":" in device:
            num = device.split(":")[-1]
            os.environ["CUDA_VISIBLE_DEVICES"] = num
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        device = "gpu"

    _device = jax.devices(device)[0]
    jax.config.update("jax_default_device", _device)

    # set the precision
    if args.f64:
        jax.config.update("jax_enable_x64", True)
        fprec = "float64"
    else:
        fprec = "float32"
    jax.config.update("jax_default_matmul_precision", "highest")

    output_keys = args.output_keys

    # load the model
    model_file: Path = args.model_file.resolve()
    assert model_file.exists(), f"Model file {model_file} does not exist"
    model = FENNIX.load(model_file, use_atom_padding=True)

    # override neighbour-list parameters in the preprocessing state if requested
    preproc_state = unfreeze(model.preproc_state)
    layer_state = []
    for st in preproc_state["layers_state"]:
        stnew = unfreeze(st)
        if args.nblist is not None:
            mult_size, add_neigh, add_atoms = args.nblist
            stnew["nblist_mult_size"] = float(mult_size)
            stnew["add_neigh"] = int(add_neigh)
            stnew["add_atoms"] = int(add_atoms)
        layer_state.append(freeze(stnew))

    preproc_state["layers_state"] = tuple(layer_state)
    preproc_state["check_input"] = False
    model.preproc_state = freeze(preproc_state)

    # check the output file
    if args.outfile is not None:
        output_file: Path = args.outfile.resolve()
        assert not output_file.exists(), f"Output file {args.outfile} already exists"
        out_format = output_file.suffix.lower()[1:]
        if out_format == "yml":
            out_format = "yaml"
        if out_format == "pickle":
            out_format = "pkl"
        if out_format == "hdf5":
            out_format = "h5"
        if out_format == "msgpack":
            out_format = "mpk"
        assert out_format in [
            "pkl",
            "yaml",
            "h5",
            "mpk",
        ], f"Unsupported output file format {output_file.suffix}"

    flags = {flag.strip(): None for flag in args.flags}
    # print metadata
    metadata = {
        "input_file": str(args.input_file),
        "model_file": str(model_file),
        "output_keys": output_keys,
        "energy_unit": model.energy_unit,
        "flags": flags,
    }
    print("metadata:")
    for k, v in metadata.items():
        print(f"  {k}: {v}")

    def make_dumpable(data):
        for k in data:
            if isinstance(data[k], np.ndarray):
                data[k] = data[k].tolist()
        return data

    def queue_reader(queue):
        if args.outfile is not None:
            if out_format == "yaml":
                of = open(args.outfile, "w")

                def dump(data):
                    of.write("\n---\n")
                    i = yaml.dump(make_dumpable(data), of, sort_keys=False)
                    of.flush()
                    return i

            elif out_format == "pkl":
                of = open(args.outfile, "wb")

                def dump(data):
                    i = pickle.dump(data, of)
                    of.flush()
                    return i

            elif out_format == "h5":
                import h5py

                of = h5py.File(args.outfile, "w")
                global iframeh5
                iframeh5 = 0  # frame groups are numbered from 0

                def dump(data):
                    global iframeh5
                    of.create_group(f"{iframeh5}", track_order=True)
                    for k in data:
                        of[f"{iframeh5}/{k}"] = data[k]
                    iframeh5 += 1
                    return None

            elif out_format == "mpk":
                import msgpack

                of = open(args.outfile, "wb")

                def dump(data):
                    i = msgpack.pack(make_dumpable(data), of)
                    of.flush()
                    return i

            else:
                raise ValueError(f"Unsupported output file format {out_format}")
        else:

            def dump(data):
                print("\n---\n")
                data = make_dumpable(data)
                for k in data:
                    print(f"{k}: {data[k]}")
                return None

        while True:
            data = queue.get()
            if data is None:
                break
            dump(data)
        if args.outfile is not None:
            of.close()

    # create a multiprocessing queue and a separate process to dump the output
    queue = mp.Queue()
    p = mp.Process(target=queue_reader, args=(queue,))
    p.start()
    if args.metadata and args.outfile is not None:
        queue.put(metadata)

    error = None
    try:
        time_start = time.time()
        ibatch = 0
        for iframe, frame in enumerate(
            fennix_analyzer(
                args.input_file,
                model,
                output_keys=output_keys,
                file_format=args.format,
                periodic=args.periodic,
                flags=flags,
                has_comment_line=not args.nocomment,
                batch_size=args.batch_size,
            )
        ):
            queue.put(frame)
            if (iframe + 1) % args.batch_size == 0 and args.outfile is not None:
                ibatch += 1
                elapsed = time.time() - time_start
                print(f"# Processed batch {ibatch}. Elapsed time: {human_time_duration(elapsed)}")
    except KeyboardInterrupt:
        print("# Interrupted by user. Exiting...")
    except Exception as err:
        error = err
        print(f"# Exiting with Exception: {error}")

    # signal the writer process and wait for it to finish
    queue.put(None)
    print("# Waiting for the writing process to finish...")
    p.join()
    print("# All done in", human_time_duration(time.time() - time_start))

    if error is not None:
        raise error


if __name__ == "__main__":
    main()
def get_file_reader(input_file, file_format=None, has_comment_line=True, periodic=False):
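get_file_reader builds a zero-argument generator factory: calling the returned reader() yields one dict per frame with "species", "coordinates", "natoms" and "total_charge" (plus "cell" for periodic inputs). A minimal usage sketch, assuming a hypothetical file geometries.xyz in the working directory:

from fennol.analyze import get_file_reader

# "geometries.xyz" is a placeholder; the format is auto-detected from the suffix
reader = get_file_reader("geometries.xyz")
for frame in reader():
    # each frame is a plain dict of numpy arrays and scalars
    print(frame["natoms"], frame["species"][:5], frame["coordinates"].shape)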
def fennix_analyzer(input_file, model, output_keys, file_format=None, periodic=False, flags={}, has_comment_line=True, batch_size=1):
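fennix_analyzer is itself a generator: it groups frames into batches, evaluates the model (via energy_and_forces when "forces" is among the requested keys, total_energy otherwise), and yields one result dict per frame with the requested output_keys sliced back per frame. A minimal programmatic sketch, assuming hypothetical file names model.fnx and geometries.xyz:

from fennol.models import FENNIX
from fennol.analyze import fennix_analyzer

# load a trained model; the file name is a placeholder
model = FENNIX.load("model.fnx", use_atom_padding=True)
for frame_data in fennix_analyzer(
    "geometries.xyz",
    model,
    output_keys=["total_energy"],
    batch_size=8,
):
    print(frame_data["total_energy"])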
def main():
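main backs the fennol_analyze console script (the prog name set on the argument parser). A typical invocation, with placeholder file names, could be:

fennol_analyze geometries.xyz model.fnx -o results.yaml -b 8 --output_keys total_energy -m

Requesting a key like "forces" switches the model call to the energy_and_forces path. The output format is inferred from the outfile suffix (yaml/yml, pkl/pickle, h5/hdf5, mpk/msgpack); with no -o, frames are printed to stdout.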
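When writing to a .pkl outfile, results are appended as a stream of pickled dicts, preceded by the metadata dict when -m is passed. A sketch of reading such a stream back, mirroring the loop in get_file_reader (results.pkl is a placeholder name):

import pickle

frames = []
with open("results.pkl", "rb") as f:
    first = pickle.load(f)
    if "output_keys" in first:
        # first record is the metadata dict written with -m
        print("metadata:", first)
    else:
        frames.append(first)
    try:
        while True:
            frames.append(pickle.load(f))
    except EOFError:
        pass
print(f"read {len(frames)} frames")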