fennol.analyze
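
Command-line analysis tool: runs a FENNIX model over the frames of a trajectory (xyz/arc or pickle) and writes the requested output keys for each frame to yaml, pickle, hdf5 or msgpack (or to stdout when no output file is given).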

import sys, os, io
import argparse
from pathlib import Path

import numpy as np
import jax
import jax.numpy as jnp

from .models import FENNIX
from .utils import parse_cell
from .utils.io import xyz_reader, human_time_duration
from .utils.periodic_table import PERIODIC_TABLE_REV_IDX
from .utils.atomic_units import AtomicUnits as au
import yaml
import pickle
from flax.core import freeze, unfreeze
import time
import multiprocessing as mp
def get_file_reader(input_file, file_format=None, has_comment_line=True, periodic=False):
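    """Return a generator factory that yields per-frame model inputs.

    Depending on the (auto-detected) file format, frames are read either from an
    xyz/arc trajectory or from a pickle file. Each yielded dict contains at least
    "species", "coordinates", "natoms" and "total_charge", plus "cell" when
    periodic boundary conditions are requested or detected.
    """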
    # check the input file
    input_file = Path(input_file).resolve()
    assert input_file.exists(), f"Input file {input_file} does not exist"
    if file_format is None:
        file_format = input_file.suffix[1:]  # remove the dot
    file_format = file_format.lower()
    if file_format == "pickle":
        file_format = "pkl"

    assert file_format in [
        "xyz",
        "arc",
        "pkl",
    ], f"Unsupported format '{file_format}': only xyz, arc and pkl are supported for now"
    xyz_indexed = file_format == "arc"

    if file_format in ["arc", "xyz"]:

        def reader():
            for symbols, xyz, comment in xyz_reader(
                input_file, has_comment_line=has_comment_line, indexed=xyz_indexed
            ):
                species = np.array([PERIODIC_TABLE_REV_IDX[s] for s in symbols])
                inputs = {
                    "species": species,
                    "coordinates": xyz,
                    "natoms": species.shape[0],
                    "total_charge": 0,  # default total charge
                }
                if periodic:
                    box = np.array(comment.split(), dtype=float)
                    cell = parse_cell(box)
                    inputs["cell"] = cell
                yield inputs

    elif file_format == "pkl":
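        # Accepted pickle layouts (as handled below): a dict with a "frames" key,
        # a dict with "training"/"validation" splits, a stream of per-frame dicts
        # written by fennol_analyze/fennol_md (optionally preceded by a metadata
        # dict containing "output_keys"), or a plain list of frame dicts.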
        with open(input_file, "rb") as f:
            data = pickle.load(f)
        if isinstance(data, dict):
            if "frames" in data:
                frames = data["frames"]
            elif "training" in data:
                frames = data["training"]
                if "validation" in data:
                    frames.extend(data["validation"])
            elif "coordinates" in data or "output_keys" in data:
                # data comes from fennol_analyze or fennol_md: re-read the stream
                frames = []
                with open(input_file, "rb") as f:
                    if "output_keys" in data:  # skip the metadata frame
                        frame = pickle.load(f)
                    try:
                        while True:
                            frame = pickle.load(f)
                            frames.append(frame)
                    except EOFError:
                        pass
            else:
                raise ValueError("No frames found in the input file")
        else:
            frames = data

        def reader():
            # periodic if requested or if the first frame provides a cell
            _periodic = periodic or "cells" in frames[0] or "cell" in frames[0]
            for frame in frames:
                species = np.array(frame["species"])
                coordinates = np.array(frame["coordinates"])
                inputs = {
                    "species": species,
                    "coordinates": coordinates,
                    "natoms": species.shape[0],
                    "total_charge": frame.get("total_charge", 0),
                }
                if _periodic:
                    cell_key = "cells" if "cells" in frame else "cell"
                    cell = np.array(frame[cell_key]).reshape(3, 3)
                    inputs["cell"] = cell
                yield inputs

    return reader

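# A minimal usage sketch (hypothetical file name): iterate over the frames of a
# trajectory without running a model.
#
#   reader = get_file_reader("traj.xyz", periodic=False)
#   for inputs in reader():
#       print(inputs["natoms"])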

def fennix_analyzer(input_file, model, output_keys, file_format=None, periodic=False, flags=None, has_comment_line=True, batch_size=1):
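    """Run the model on every frame of input_file and yield one result dict per frame.

    Frames are accumulated into batches of batch_size, evaluated in a single model
    call, then split back into per-frame dicts containing the requested output_keys
    (per-atom keys are sliced, per-frame keys are indexed).
    """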

    if flags is None:
        flags = {}  # avoid a mutable default argument
    reader = get_file_reader(
        input_file, file_format=file_format, periodic=periodic, has_comment_line=has_comment_line
    )

    # define the model prediction function
    def model_predict(batch):
        # concatenate all frames of the batch into flat arrays;
        # batch_index maps each atom back to its frame
        natoms = np.array([frame["natoms"] for frame in batch], dtype=np.int32)
        batch_index = np.concatenate([frame["batch_index"] for frame in batch])
        species = np.concatenate([frame["species"] for frame in batch])
        xyz = np.concatenate([frame["coordinates"] for frame in batch], axis=0)
        total_charge = np.array([frame["total_charge"] for frame in batch], dtype=np.int32)
        inputs = {
            "species": species,
            "coordinates": xyz,
            "batch_index": batch_index,
            "natoms": natoms,
            "total_charge": total_charge,
        }
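        # For periodic systems, per-frame cells are stacked to shape (nbatch, 3, 3)
        # and their matrix inverses are passed along as reciprocal cells.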
        if "cell" in batch[0]:
            cells = np.stack([frame["cell"] for frame in batch], axis=0)
            inputs["cells"] = cells
            inputs["reciprocal_cells"] = np.linalg.inv(cells)

        inputs["flags"] = flags

        if "forces" in output_keys:
            e, f, output = model.energy_and_forces(**inputs, gpu_preprocessing=True)
        else:
            e, output = model.total_energy(**inputs, gpu_preprocessing=True)
        return output

    # define the function to process a batch
    def process_batch(batch):
        output = model_predict(batch)
        natoms = np.array([frame["natoms"] for frame in batch])
        _periodic = "cells" in output
        if _periodic:
            cells = np.array(output["cells"])
        species = np.array(output["species"])
        coordinates = np.array(output["coordinates"])
        # atom offsets of each frame in the flat arrays
        natshift = np.concatenate([np.array([0], dtype=np.int32), np.cumsum(natoms)])
        frames_data = []
        for i in range(len(batch)):
            frame_data = {
                "species": species[natshift[i] : natshift[i + 1]],
                "coordinates": coordinates[natshift[i] : natshift[i + 1]],
                "total_charge": int(output["total_charge"][i]),
            }
            if _periodic:
                frame_data["cell"] = cells[i]

            for k in output_keys:
                if k not in output:
                    raise ValueError(f"Output key {k} not found in model output")
                v = np.asarray(output[k])
                if v.shape[0] == species.shape[0]:
                    # per-atom quantity: slice this frame's atoms
                    frame_data[k] = v[natshift[i] : natshift[i + 1]]
                elif v.shape[0] == natoms.shape[0]:
                    # per-frame quantity: index this frame
                    frame_data[k] = v[i]
                else:
                    raise ValueError(
                        f"Output key {k} has unexpected shape {v.shape}:"
                        f" leading dimension should be the number of atoms"
                        f" {species.shape[0]} or the number of frames {natoms.shape[0]}"
                    )

            frames_data.append(frame_data)

        return frames_data

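    # Streaming main loop: frames are read lazily, grouped into batches, and the
    # per-frame results are yielded as soon as their batch has been evaluated.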
    batch = []
    iframe = 0
    for frame in reader():
        batch_index = np.full(frame["natoms"], iframe, dtype=np.int32)
        frame["batch_index"] = batch_index
        batch.append(frame)
        iframe += 1
        if len(batch) == batch_size:
            frames_data = process_batch(batch)
            for frame_data in frames_data:
                yield frame_data
            batch = []
            iframe = 0
    # process the last (incomplete) batch
    if len(batch) > 0:
        frames_data = process_batch(batch)
        for frame_data in frames_data:
            yield frame_data

def main():
    parser = argparse.ArgumentParser(prog="fennol_analyze")
    parser.add_argument(
        "input_file", type=Path, help="file containing the geometries to analyze"
    )
    parser.add_argument(
        "model_file", type=Path, help="file containing the model to use"
    )
    parser.add_argument(
        "-o", "--outfile", type=Path, help="file to write the output to"
    )
    parser.add_argument(
        "--batch_size",
        "-b",
        type=int,
        default=1,
        help="batch size to use for the model",
    )
    parser.add_argument("--device", type=str, help="device to use for the model")
    parser.add_argument("-f64", action="store_true", help="use double precision")
    parser.add_argument("--periodic", action="store_true", help="use PBC")
    parser.add_argument(
        "--format",
        type=str,
        help="format of the input file. Default: auto-detect from file extension",
    )
    parser.add_argument(
        "-c",
        "--nocomment",
        action="store_true",
        help="indicate that the input file does not have a comment line (only used for xyz and arc formats)",
    )
    parser.add_argument(
        "--output_keys",
        type=str,
        nargs="+",
        help="keys to output",
        default=["total_energy"],
    )
    parser.add_argument(
        "--nblist",
        nargs=3,
        metavar=("mult_size", "add_neigh", "add_atoms"),
        help="neighbour list parameters: mult_size, add_neigh, add_atoms. If not provided, the default values from the model will be used.",
    )
    parser.add_argument(
        "-m",
        "--metadata",
        action="store_true",
        help="add metadata as a first frame",
    )
    parser.add_argument(
        "--flags",
        type=str,
        nargs="*",
        default=[],
        help="additional flags to pass to the model. These are added to the inputs as a dictionary with the flag name as key and None as value.",
    )
    args = parser.parse_args()

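    # Device selection: --device takes precedence, then the FENNOL_DEVICE
    # environment variable (e.g. FENNOL_DEVICE=cuda:0), then CPU. "cuda:N"/"gpu:N"
    # selects GPU N via CUDA_VISIBLE_DEVICES.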
    # set the device
    if args.device:
        device = args.device.lower()
    elif "FENNOL_DEVICE" in os.environ:
        device = os.environ["FENNOL_DEVICE"].lower()
        print(f"# Setting device from env FENNOL_DEVICE={device}")
    else:
        device = "cpu"
    if device == "cpu":
        jax.config.update("jax_platforms", "cpu")
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    elif device.startswith("cuda") or device.startswith("gpu"):
        if ":" in device:
            num = device.split(":")[-1]
            os.environ["CUDA_VISIBLE_DEVICES"] = num
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        device = "gpu"

    _device = jax.devices(device)[0]
    jax.config.update("jax_default_device", _device)

    # set the precision
    if args.f64:
        jax.config.update("jax_enable_x64", True)
        fprec = "float64"
    else:
        fprec = "float32"
    jax.config.update("jax_default_matmul_precision", "highest")

    output_keys = args.output_keys

    # load the model
    model_file: Path = args.model_file.resolve()
    assert model_file.exists(), f"Model file {model_file} does not exist"
    model = FENNIX.load(model_file, use_atom_padding=True)

    # optionally override the neighbour-list parameters stored in the model
    preproc_state = unfreeze(model.preproc_state)
    layer_state = []
    for st in preproc_state["layers_state"]:
        stnew = unfreeze(st)
        if args.nblist is not None:
            mult_size, add_neigh, add_atoms = args.nblist
            stnew["nblist_mult_size"] = float(mult_size)
            stnew["add_neigh"] = int(add_neigh)
            stnew["add_atoms"] = int(add_atoms)
        layer_state.append(freeze(stnew))

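    # re-freeze the updated preprocessing state and skip input checks for
    # subsequent calls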
    preproc_state["layers_state"] = tuple(layer_state)
    preproc_state["check_input"] = False
    model.preproc_state = freeze(preproc_state)

    # check the output file
    if args.outfile is not None:
        output_file: Path = args.outfile.resolve()
        assert not output_file.exists(), f"Output file {args.outfile} already exists"
        # normalize the extension to a canonical format name
        out_format = output_file.suffix.lower()[1:]
        if out_format == "yml":
            out_format = "yaml"
        if out_format == "pickle":
            out_format = "pkl"
        if out_format == "hdf5":
            out_format = "h5"
        if out_format == "msgpack":
            out_format = "mpk"
        assert out_format in [
            "pkl",
            "yaml",
            "h5",
            "mpk",
        ], f"Unsupported output file format {output_file.suffix}"

    flags = {flag.strip(): None for flag in args.flags}
    # print metadata
    metadata = {
        "input_file": str(args.input_file),
        "model_file": str(model_file),
        "output_keys": output_keys,
        "energy_unit": model.energy_unit,
        "flags": flags,
    }
    print("metadata:")
    for k, v in metadata.items():
        print(f"  {k}: {v}")

    def make_dumpable(data):
        # convert numpy arrays to lists so the frame can be serialized
        for k in data:
            if isinstance(data[k], np.ndarray):
                data[k] = data[k].tolist()
        return data

    def queue_reader(queue):
        # writer loop: build a format-specific dump function, then consume frames
        # from the queue until the None sentinel is received
        if args.outfile is not None:
            if out_format == "yaml":
                of = open(args.outfile, "w")

                def dump(data):
                    of.write("\n---\n")
                    i = yaml.dump(make_dumpable(data), of, sort_keys=False)
                    of.flush()
                    return i

            elif out_format == "pkl":
                of = open(args.outfile, "wb")

                def dump(data):
                    i = pickle.dump(data, of)
                    of.flush()
                    return i

            elif out_format == "h5":
                import h5py

                of = h5py.File(args.outfile, "w")
                iframeh5 = 0  # frame groups are numbered from 0

                def dump(data):
                    nonlocal iframeh5
                    of.create_group(f"{iframeh5}", track_order=True)
                    for k in data:
                        of[f"{iframeh5}/{k}"] = data[k]
                    iframeh5 += 1
                    return None

            elif out_format == "mpk":
                import msgpack

                of = open(args.outfile, "wb")

                def dump(data):
                    i = msgpack.pack(make_dumpable(data), of)
                    of.flush()
                    return i

            else:
                raise ValueError(f"Unsupported output file format {out_format}")
        else:
            # no output file: print the frames to stdout
            def dump(data):
                print("\n---\n")
                data = make_dumpable(data)
                for k in data:
                    print(f"{k}: {data[k]}")
                return None

        while True:
            data = queue.get()
            if data is None:
                break
            dump(data)
        if args.outfile is not None:
            of.close()

    # create a multiprocessing queue to dump the output
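    # The writer runs in a separate process so that file I/O can overlap with
    # model evaluation; a None sentinel (sent after the loop below) shuts it down.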
    queue = mp.Queue()
    p = mp.Process(target=queue_reader, args=(queue,))
    p.start()
    if args.metadata and args.outfile is not None:
        queue.put(metadata)

    error = None
    try:
        time_start = time.time()
        ibatch = 0
        for iframe, frame in enumerate(fennix_analyzer(
            args.input_file,
            model,
            output_keys=output_keys,
            file_format=args.format,
            periodic=args.periodic,
            flags=flags,
            has_comment_line=not args.nocomment,
            batch_size=args.batch_size,
        )):
            queue.put(frame)
            if (iframe + 1) % args.batch_size == 0 and args.outfile is not None:
                ibatch += 1
                elapsed = time.time() - time_start
                print(f"# Processed batch {ibatch}. Elapsed time: {human_time_duration(elapsed)}")
    except KeyboardInterrupt:
        print("# Interrupted by user. Exiting...")
    except Exception as err:
        error = err
        print(f"# Exiting with Exception: {error}")

    # signal the writer process to stop and wait for it to finish
    queue.put(None)
    print("# Waiting for the writing process to finish...")
    p.join()
    print("# All done in ", human_time_duration(time.time() - time_start))

    if error is not None:
        raise error


if __name__ == "__main__":
    main()
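
# Example invocation (hypothetical file names), writing one YAML document per frame:
#   fennol_analyze traj.xyz model.fnx -o results.yaml \
#       --output_keys total_energy forces --batch_size 8 --device cuda:0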