fennol.analyze
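
Command-line entry point that loads a FENNIX model, evaluates it on the
geometries of an input file (xyz, arc or pkl) in batches, and prints the
requested output keys or writes them to a json, yaml or pkl file.

Minimal usage sketch (file names are placeholders; the entry point name
matches the prog="fennol_analyze" set below):

    fennol_analyze geometries.xyz mymodel.fnx -o results.yaml \
        --batch_size 8 --output_keys total_energy forces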

import os
import argparse
from pathlib import Path

import numpy as np
import jax

from .models import FENNIX
from .utils.io import xyz_reader
from .utils.periodic_table import PERIODIC_TABLE_REV_IDX
import yaml
import pickle
import json


def main():
    parser = argparse.ArgumentParser(prog="fennol_analyze")
    parser.add_argument(
        "input_file", type=Path, help="file containing the geometries to analyze"
    )
    parser.add_argument(
        "model_file", type=Path, help="file containing the model to use"
    )
    parser.add_argument("--output", "-o", type=Path, help="file to write the output to")
    parser.add_argument(
        "--batch_size",
        "-b",
        type=int,
        default=1,
        help="batch size to use for the model",
    )
    parser.add_argument(
        "--device", type=str, default="cpu", help="device to use for the model"
    )
    parser.add_argument("-f64", action="store_true", help="use double precision")
    parser.add_argument("--periodic", action="store_true", help="use PBC")
    parser.add_argument(
        "--format", type=str, default="xyz", help="format of the input file"
    )
    parser.add_argument(
        "--output_keys",
        type=str,
        nargs="+",
        help="keys to output",
        default=["total_energy"],
    )
    args = parser.parse_args()

    # set the device
    device = args.device.lower()
    if device == "cpu":
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    elif device.startswith("cuda") or device.startswith("gpu"):
        if ":" in device:
            num = device.split(":")[-1]
            os.environ["CUDA_VISIBLE_DEVICES"] = num
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        device = "gpu"

    _device = jax.devices(device)[0]
    jax.config.update("jax_default_device", _device)

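    # Resulting device selection, e.g.:
    #   --device cpu     -> CUDA_VISIBLE_DEVICES=""  (run on CPU)
    #   --device cuda    -> CUDA_VISIBLE_DEVICES="0" (first GPU)
    #   --device cuda:1  -> CUDA_VISIBLE_DEVICES="1" (second GPU)
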
    # set the precision
    if args.f64:
        jax.config.update("jax_enable_x64", True)
        fprec = "float64"
    else:
        fprec = "float32"
    jax.config.update("jax_default_matmul_precision", "highest")

    # check the input file
    input_file = args.input_file.resolve()
    assert input_file.exists(), f"Input file {input_file} does not exist"
    assert args.format in [
        "xyz",
        "arc",
        "pkl",
    ], "Only xyz, arc and pkl formats are supported for now"
    output_keys = args.output_keys
    # arc files are xyz files whose atom lines start with an atom index
    xyz_indexed = args.format == "arc"

    # load the model
    model_file: Path = args.model_file.resolve()
    assert model_file.exists(), f"Model file {model_file} does not exist"
    model = FENNIX.load(model_file, use_atom_padding=True)

    # check the output file
    if args.output is not None:
        output_file: Path = args.output.resolve()
        assert not output_file.exists(), f"Output file {output_file} already exists"
        assert output_file.suffix in [
            ".pkl",
            ".json",
            ".yaml",
        ], f"Unsupported output file format {output_file.suffix}"

    # print metadata
    metadata = {
        "input_file": str(input_file),
        "model_file": str(model_file),
        "output_keys": output_keys,
        "energy_unit": model.energy_unit,
    }
    print("metadata:")
    for k, v in metadata.items():
        print(f"  {k}: {v}")

    # define the model prediction function
    def model_predict(batch):
        natoms = np.array([frame["natoms"] for frame in batch])
        batch_index = np.concatenate([frame["batch_index"] for frame in batch])
        species = np.concatenate([frame["species"] for frame in batch])
        xyz = np.concatenate([frame["coordinates"] for frame in batch], axis=0)
        inputs = {
            "species": species,
            "coordinates": xyz,
            "batch_index": batch_index,
            "natoms": natoms,
        }
        if args.periodic:
            cells = np.concatenate([frame["cell"] for frame in batch], axis=0)
            inputs["cells"] = cells

        if "forces" in output_keys:
            e, f, output = model.energy_and_forces(**inputs)
        else:
            e, output = model.total_energy(**inputs)
        return output

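    # Example of the concatenated batch layout for two frames with 3 and 2
    # atoms (illustrative values):
    #   natoms      = [3, 2]
    #   batch_index = [0, 0, 0, 1, 1]
    #   species     = [8, 1, 1, 8, 1]       (atomic numbers)
    #   coordinates = array of shape (5, 3) (positions stacked along axis 0)
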
    # define the function to process a batch
    def process_batch(batch):
        output = model_predict(batch)
        natoms = np.array([frame["natoms"] for frame in batch])
        if args.periodic:
            cells = np.array(output["cells"])
        species = np.array(output["species"])
        coordinates = np.array(output["coordinates"])
        # natshift[i] is the index of the first atom of frame i in the
        # concatenated arrays
        natshift = np.concatenate([np.array([0], dtype=np.int32), np.cumsum(natoms)])
        frames_data = []
        for i in range(len(batch)):
            frame_data = {
                "species": species[natshift[i] : natshift[i + 1]].tolist(),
                "coordinates": coordinates[natshift[i] : natshift[i + 1]].tolist(),
            }
            if args.periodic:
                frame_data["cell"] = cells[i].tolist()

            for k in output_keys:
                if k not in output:
                    raise ValueError(f"Output key {k} not found")
                v = output[k]
                # per-atom outputs are sliced, per-frame outputs are indexed
                if v.shape[0] == species.shape[0]:
                    frame_data[k] = v[natshift[i] : natshift[i + 1]].tolist()
                elif v.shape[0] == natoms.shape[0]:
                    frame_data[k] = v[i].tolist()
                else:
                    raise ValueError(f"Output key {k} has wrong shape")

            frames_data.append(frame_data)
        return frames_data

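    # Shape of a resulting frame record, e.g. with --output_keys total_energy
    # forces (illustrative values; per-frame keys yield one entry, per-atom
    # keys a slice of length natoms[i]):
    #   {"species": [...], "coordinates": [...],
    #    "total_energy": <float>, "forces": [[fx, fy, fz], ...]}
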
    ### start processing the input file
    if args.format in ["arc", "xyz"]:

        def reader():
            for symbols, xyz, comment in xyz_reader(
                input_file, has_comment_line=True, indexed=xyz_indexed
            ):
                species = np.array([PERIODIC_TABLE_REV_IDX[s] for s in symbols])
                frame = {
                    "species": species,
                    "coordinates": xyz,
                    "natoms": species.shape[0],
                }
                if args.periodic:
                    cell = np.array([float(x) for x in comment.split()]).reshape(
                        1, 3, 3
                    )
                    frame["cell"] = cell
                yield frame

    elif args.format == "pkl":
        with open(input_file, "rb") as f:
            data = pickle.load(f)
        if isinstance(data, dict):
            if "frames" in data:
                frames = data["frames"]
            elif "training" in data:
                frames = data["training"]
                if "validation" in data:
                    frames.extend(data["validation"])
            else:
                raise ValueError("No frames found in the input file")
        else:
            frames = data

        def reader():
            for raw_frame in frames:
                species = np.array(raw_frame["species"])
                coordinates = np.array(raw_frame["coordinates"])
                frame = {
                    "species": species,
                    "coordinates": coordinates,
                    "natoms": species.shape[0],
                }
                if args.periodic:
                    # take the cell from the source frame, not the new dict
                    cell = np.array(raw_frame["cell"]).reshape(1, 3, 3)
                    frame["cell"] = cell
                yield frame

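    # Note on periodic inputs: for xyz/arc files the comment line must hold
    # the 9 cell components (reshaped to a 3x3 matrix), e.g.
    #   12.0 0.0 0.0  0.0 12.0 0.0  0.0 0.0 12.0
    # and pkl frames must provide a "cell" entry with 9 components per frame.
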
    batch = []
    output_data = []
    ibatch = 0
    for frame in reader():
        # label each atom with the index of its frame within the current batch
        batch_index = np.full(frame["natoms"], ibatch, dtype=np.int32)
        frame["batch_index"] = batch_index
        batch.append(frame)
        ibatch += 1
        if len(batch) == args.batch_size:
            frames_data = process_batch(batch)
            if args.output is None:
                for frame_data in frames_data:
                    print("\n---\n")
                    for k in frame_data:
                        print(f"{k}: {frame_data[k]}")
            output_data.extend(frames_data)
            batch = []
            ibatch = 0
    # process the last (possibly incomplete) batch
    if len(batch) > 0:
        frames_data = process_batch(batch)
        if args.output is None:
            for frame_data in frames_data:
                print("\n---\n")
                for k in frame_data:
                    print(f"{k}: {frame_data[k]}")
        output_data.extend(frames_data)

    # write the output to a file
    if args.output is not None:
        output_data = {
            "metadata": metadata,
            "frames": output_data,
        }
        if output_file.suffix == ".json":
            with open(output_file, "w") as f:
                json.dump(output_data, f, indent=2)
        elif output_file.suffix == ".yaml":
            with open(output_file, "w") as f:
                yaml.dump(output_data, f, sort_keys=False)
        elif output_file.suffix == ".pkl":
            with open(output_file, "wb") as f:
                pickle.dump(output_data, f)


if __name__ == "__main__":
    main()
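
Sketch of the saved output structure (YAML shown schematically; paths, the
unit and all numbers are placeholders):

    metadata:
      input_file: /path/to/geometries.xyz
      model_file: /path/to/mymodel.fnx
      output_keys: [total_energy]
      energy_unit: kcal/mol
    frames:
    - species: [8, 1, 1]
      coordinates: [[...], [...], [...]]
      total_energy: -76.4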