fennol.training.io

  1import os, io, sys
  2import numpy as np
  3from scipy.spatial.transform import Rotation
  4from collections import defaultdict
  5import pickle
  6import glob
  7from flax import traverse_util
  8from typing import Dict, List, Tuple, Union, Optional, Callable, Sequence
  9from .databases import DBDataset, H5Dataset
 10from ..models.preprocessing import AtomPadding
 11
 12import json
 13import yaml
 14
 15try:
 16    import tomlkit
 17except ImportError:
 18    tomlkit = None
 19
 20try:
 21    from torch.utils.data import DataLoader
 22except ImportError:
 23    raise ImportError(
 24        "PyTorch is required for training. Install the CPU version from https://pytorch.org/get-started/locally/"
 25    )
 26
 27from ..models import FENNIX
 28
 29
 30def load_configuration(config_file: str) -> Dict[str, any]:
 31    if config_file.endswith(".json"):
 32        parameters = json.load(open(config_file))
 33    elif config_file.endswith(".yaml") or config_file.endswith(".yml"):
 34        parameters = yaml.load(open(config_file), Loader=yaml.FullLoader)
 35    elif tomlkit is not None and config_file.endswith(".toml"):
 36        parameters = tomlkit.loads(open(config_file).read())
 37    else:
 38        supported_formats = [".json", ".yaml", ".yml"]
 39        if tomlkit is not None:
 40            supported_formats.append(".toml")
 41        raise ValueError(
 42            f"Unknown config file format. Supported formats: {supported_formats}"
 43        )
 44    return parameters
 45
 46
 47def load_dataset(
 48    dspath: str,
 49    batch_size: int,
 50    rename_refs: dict = {},
 51    infinite_iterator: bool = False,
 52    atom_padding: bool = False,
 53    ref_keys: Optional[Sequence[str]] = None,
 54    split_data_inputs: bool = False,
 55    np_rng: Optional[np.random.Generator] = None,
 56    train_val_split: bool = True,
 57    training_parameters: dict = {},
 58    add_flags: Sequence[str] = ["training"],
 59    fprec: str = "float32",
 60):
 61    """
 62    Load a dataset from a pickle file and return two iterators for training and validation batches.
 63
 64    Args:
 65        training_parameters (dict): A dictionary with the following keys:
 66            - 'dspath': str. Path to the pickle file containing the dataset.
 67            - 'batch_size': int. Number of samples per batch.
 68        rename_refs (list, optional): A list of strings with the names of the reference properties to rename.
 69            Default is an empty list.
 70
 71    Returns:
 72        tuple: A tuple of two infinite iterators, one for training batches and one for validation batches.
 73            For each element in the batch, we expect a "species" key with the atomic numbers of the atoms in the system. Arrays are concatenated along the first axis and the following keys are added to distinguish between the systems:
 74            - 'natoms': np.ndarray. Array with the number of atoms in each system.
 75            - 'batch_index': np.ndarray. Array with the index of the system to which each atom
 76            if the keys "forces", "total_energy", "atomic_energies" or any of the elements in rename_refs are present, the keys are renamed by prepending "true_" to the key name.
 77    """
 78
 79    assert isinstance(
 80        training_parameters, dict
 81    ), "training_parameters must be a dictionary."
 82    assert isinstance(
 83        rename_refs, dict
 84    ), "rename_refs must be a dictionary with the keys to rename."
 85
 86    pbc_training = training_parameters.get("pbc_training", False)
 87    minimum_image = training_parameters.get("minimum_image", False)
 88
 89    coordinates_ref_key = training_parameters.get("coordinates_ref_key", None)
 90
 91    input_keys = [
 92        "species",
 93        "coordinates",
 94        "natoms",
 95        "batch_index",
 96        "total_charge",
 97        "flags",
 98    ]
 99    if pbc_training:
100        input_keys += ["cells", "reciprocal_cells"]
101    if atom_padding:
102        input_keys += ["true_atoms", "true_sys"]
103    if coordinates_ref_key is not None:
104        input_keys += ["system_index", "system_sign"]
105
106    flags = {f: None for f in add_flags}
107    if minimum_image and pbc_training:
108        flags["minimum_image"] = None
109
110    additional_input_keys = set(training_parameters.get("additional_input_keys", []))
111    additional_input_keys_ = set()
112    for key in additional_input_keys:
113        if key not in input_keys:
114            additional_input_keys_.add(key)
115    additional_input_keys = additional_input_keys_
116
117    all_inputs = set(input_keys + list(additional_input_keys))
118
119    extract_all_keys = ref_keys is None
120    if ref_keys is not None:
121        ref_keys = set(ref_keys)
122        ref_keys_ = set()
123        for key in ref_keys:
124            if key not in all_inputs:
125                ref_keys_.add(key)
126
127    random_rotation = training_parameters.get("random_rotation", False)
128    if random_rotation:
129        assert np_rng is not None, "np_rng must be provided for adding noise."
130
131        apply_rotation = {
132            1: lambda x, r: x @ r,
133            -1: lambda x, r: np.einsum("...kn,kj->...jn", x, r),
134            2: lambda x, r: np.einsum("li,...lk,kj->...ij", r, x, r),
135        }
136        valid_rotations = tuple(apply_rotation.keys())
137        rotated_keys = {
138            "coordinates": 1,
139            "forces": 1,
140            "virial_tensor": 2,
141            "stress_tensor": 2,
142            "virial": 2,
143            "stress": 2,
144        }
145        if pbc_training:
146            rotated_keys["cells"] = 1
147        user_rotated_keys = dict(training_parameters.get("rotated_keys", {}))
148        for k, v in user_rotated_keys.items():
149            assert (
150                v in valid_rotations
151            ), f"Invalid rotation type for key {k}. Valid values are {valid_rotations}"
152            rotated_keys[k] = v
153
154        # rotated_vector_keys = set(
155        #     ["coordinates", "forces"]
156        #     + list(training_parameters.get("rotated_vector_keys", []))
157        # )
158        # if pbc_training:
159        #     rotated_vector_keys.add("cells")
160
161        # rotated_tensor_keys = set(
162        #     ["virial_tensor", "stress_tensor", "virial", "stress"]
163        #     + list(training_parameters.get("rotated_tensor_keys", []))
164        # )
165        # assert rotated_vector_keys.isdisjoint(
166        #     rotated_tensor_keys
167        # ), "Rotated vector keys and rotated tensor keys must be disjoint."
168        # rotated_keys = rotated_vector_keys.union(rotated_tensor_keys)
169
170        print(
171            "Applying random rotations to the following keys if present:",
172            list(rotated_keys.keys()),
173        )
174
175        def apply_random_rotations(output, nbatch):
176            euler_angles = np_rng.uniform(0.0, 2 * np.pi, (nbatch, 3))
177            r = [
178                Rotation.from_euler("xyz", euler_angles[i]).as_matrix().T
179                for i in range(nbatch)
180            ]
181            for k, l in rotated_keys.items():
182                if k in output:
183                    for i in range(nbatch):
184                        output[k][i] = apply_rotation[l](output[k][i], r[i])
185
186    else:
187
188        def apply_random_rotations(output, nbatch):
189            pass
190
191    if pbc_training:
192        print("Periodic boundary conditions are active.")
193        length_nopbc = training_parameters.get("length_nopbc", 1000.0)
194
195        def add_cell(d, output):
196            if "cell" not in d:
197                cell = np.asarray(
198                    [
199                        [length_nopbc, 0.0, 0.0],
200                        [0.0, length_nopbc, 0.0],
201                        [0.0, 0.0, length_nopbc],
202                    ],
203                    dtype=fprec,
204                )
205            else:
206                cell = np.asarray(d["cell"], dtype=fprec)
207            output["cells"].append(cell.reshape(1, 3, 3))
208
209    else:
210
211        def add_cell(d, output):
212            if "cell" in d:
213                print(
214                    "Warning: 'cell' found in dataset but not training with pbc. Activate pbc_training to use periodic boundary conditions."
215                )
216
217    if extract_all_keys:
218
219        def add_other_keys(d, output, atom_shift):
220            for k, v in d.items():
221                if k in ("cell", "total_charge"):
222                    continue
223                v_array = np.array(v)
224                # Shift atom number if necessary
225                if k.endswith("_atidx"):
226                    v_array = v_array + atom_shift
227                output[k].append(v_array)
228
229    else:
230
231        def add_other_keys(d, output, atom_shift):
232            output["species"].append(np.asarray(d["species"]))
233            output["coordinates"].append(np.asarray(d["coordinates"], dtype=fprec))
234            for k in additional_input_keys:
235                v_array = np.array(d[k])
236                # Shift atom number if necessary
237                if k.endswith("_atidx"):
238                    v_array = v_array + atom_shift
239                output[k].append(v_array)
240            for k in ref_keys_:
241                v_array = np.array(d[k])
242                # Shift atom number if necessary
243                if k.endswith("_atidx"):
244                    v_array = v_array + atom_shift
245                output[k].append(v_array)
246                if k + "_mask" in d:
247                    output[k + "_mask"].append(np.asarray(d[k + "_mask"]))
248
249    def add_keys(d, output, atom_shift, batch_index):
250        nat = d["species"].shape[0]
251
252        output["natoms"].append(np.asarray([nat]))
253        output["batch_index"].append(np.asarray([batch_index] * nat))
254        if "total_charge" not in d:
255            total_charge = np.asarray(0.0, dtype=fprec)
256        else:
257            total_charge = np.asarray(d["total_charge"], dtype=fprec)
258        output["total_charge"].append(total_charge)
259
260        add_cell(d, output)
261        add_other_keys(d, output, atom_shift)
262
263        return atom_shift + nat
264
265    def collate_fn_(batch):
266        output = defaultdict(list)
267        atom_shift = 0
268        batch_index = 0
269
270        for d in batch:
271            atom_shift = add_keys(d, output, atom_shift, batch_index)
272            batch_index += 1
273
274            if coordinates_ref_key is not None:
275                output["system_index"].append(np.asarray([batch_index - 1]))
276                output["system_sign"].append(np.asarray([1]))
277                if coordinates_ref_key in d:
278                    dref = {**d, "coordinates": d[coordinates_ref_key]}
279                    atom_shift = add_keys(dref, output, atom_shift, batch_index)
280                    output["system_index"].append(np.asarray([batch_index - 1]))
281                    output["system_sign"].append(np.asarray([-1]))
282                    batch_index += 1
283
284        nbatch_ = len(output["natoms"])
285        apply_random_rotations(output,nbatch_)
286
287        # Stack and concatenate the arrays
288        for k, v in output.items():
289            if v[0].ndim == 0:
290                v = np.stack(v)
291            else:
292                v = np.concatenate(v, axis=0)
293            if np.issubdtype(v.dtype, np.floating):
294                v = v.astype(fprec)
295            output[k] = v
296
297        if "cells" in output and pbc_training:
298            output["reciprocal_cells"] = np.linalg.inv(output["cells"])
299
300        # Rename necessary keys
301        # for key in rename_refs:
302        #     if key in output:
303        #         output["true_" + key] = output.pop(key)
304        for kold, knew in rename_refs.items():
305            assert (
306                knew not in output
307            ), f"Cannot rename key {kold} to {knew}. Key {knew} already present."
308            if kold in output:
309                output[knew] = output.pop(kold)
310
311        output["flags"] = flags
312        return output
313
314    collate_layers_train = [collate_fn_]
315    collate_layers_valid = [collate_fn_]
316
317    ### collate preprocessing
318    # add noise to the training data
319    noise_sigma = training_parameters.get("noise_sigma", None)
320    if noise_sigma is not None:
321        assert isinstance(noise_sigma, dict), "noise_sigma should be a dictionary"
322
323        for sigma in noise_sigma.values():
324            assert sigma >= 0, "Noise sigma should be a positive number"
325
326        print("Adding noise to the training data:")
327        for key, sigma in noise_sigma.items():
328            print(f"  - {key} with sigma = {sigma}")
329
330        assert np_rng is not None, "np_rng must be provided for adding noise."
331
332        def collate_with_noise(batch):
333            for key, sigma in noise_sigma.items():
334                if key in batch and sigma > 0:
335                    batch[key] += np_rng.normal(0, sigma, batch[key].shape).astype(
336                        batch[key].dtype
337                    )
338            return batch
339
340        collate_layers_train.append(collate_with_noise)
341
342    if atom_padding:
343        padder = AtomPadding(add_sys=training_parameters.get("padder_add_sys", 0))
344        padder_state = padder.init()
345
346        def collate_with_padding(batch):
347            padder_state_up, output = padder(padder_state, batch)
348            padder_state.update(padder_state_up)
349            return output
350
351        collate_layers_train.append(collate_with_padding)
352        collate_layers_valid.append(collate_with_padding)
353
354    if split_data_inputs:
355
356        # input_keys += additional_input_keys
357        # input_keys = set(input_keys)
358        print("Input keys:", all_inputs)
359        print("Ref keys:", ref_keys)
360
361        def collate_split(batch):
362            inputs = {}
363            refs = {}
364            for k, v in batch.items():
365                if k in all_inputs:
366                    inputs[k] = v
367                if k in ref_keys:
368                    refs[k] = v
369                if k.endswith("_mask") and k[:-5] in ref_keys:
370                    refs[k] = v
371            return inputs, refs
372
373        collate_layers_train.append(collate_split)
374        collate_layers_valid.append(collate_split)
375
376    ### apply all collate preprocessing
377    if len(collate_layers_train) == 1:
378        collate_fn_train = collate_layers_train[0]
379    else:
380
381        def collate_fn_train(batch):
382            for layer in collate_layers_train:
383                batch = layer(batch)
384            return batch
385
386    if len(collate_layers_valid) == 1:
387        collate_fn_valid = collate_layers_valid[0]
388    else:
389
390        def collate_fn_valid(batch):
391            for layer in collate_layers_valid:
392                batch = layer(batch)
393            return batch
394
395    # dspath = training_parameters.get("dspath", None)
396    print(f"Loading dataset from {dspath}...", end="")
397    # print(f"   the following keys will be renamed if present : {rename_refs}")
398    sharded_training = False
399    if dspath.endswith(".db"):
400        dataset = {}
401        if train_val_split:
402            dataset["training"] = DBDataset(dspath, table="training")
403            dataset["validation"] = DBDataset(dspath, table="validation")
404        else:
405            dataset = DBDataset(dspath)
406    elif dspath.endswith(".h5") or dspath.endswith(".hdf5"):
407        dataset = {}
408        if train_val_split:
409            dataset["training"] = H5Dataset(dspath, table="training")
410            dataset["validation"] = H5Dataset(dspath, table="validation")
411        else:
412            dataset = H5Dataset(dspath)
413    elif dspath.endswith(".pkl") or dspath.endswith(".pickle"):
414        with open(dspath, "rb") as f:
415            dataset = pickle.load(f)
416        if not train_val_split and isinstance(dataset, dict):
417            dataset = dataset["training"]
418    elif os.path.isdir(dspath):
419        if train_val_split:
420            dataset = {}
421            with open(dspath + "/validation.pkl", "rb") as f:
422                dataset["validation"] = pickle.load(f)
423        else:
424            dataset = None
425
426        shard_files = sorted(glob.glob(dspath + "/training_*.pkl"))
427        nshards = len(shard_files)
428        if nshards == 0:
429            raise ValueError("No dataset shards found.")
430        elif nshards == 1:
431            with open(shard_files[0], "rb") as f:
432                if train_val_split:
433                    dataset["training"] = pickle.load(f)
434                else:
435                    dataset = pickle.load(f)
436        else:
437            print(f"Found {nshards} dataset shards.")
438            sharded_training = True
439
440    else:
441        raise ValueError(
442            f"Unknown dataset format. Supported formats: '.db', '.h5', '.pkl', '.pickle'"
443        )
444    print(" done.")
445
446    ### BUILD DATALOADERS
447    # batch_size = training_parameters.get("batch_size", 16)
448    shuffle = training_parameters.get("shuffle_dataset", True)
449    if train_val_split:
450        dataloader_validation = DataLoader(
451            dataset["validation"],
452            batch_size=batch_size,
453            shuffle=shuffle,
454            collate_fn=collate_fn_valid,
455        )
456
457    if sharded_training:
458
459        def iterate_sharded_dataset():
460            indices = np.arange(nshards)
461            if shuffle:
462                assert np_rng is not None, "np_rng must be provided for shuffling."
463                np_rng.shuffle(indices)
464            for i in indices:
465                filename = shard_files[i]
466                print(f"# Loading dataset shard from {filename}...", end="")
467                with open(filename, "rb") as f:
468                    dataset = pickle.load(f)
469                print(" done.")
470                dataloader = DataLoader(
471                    dataset,
472                    batch_size=batch_size,
473                    shuffle=shuffle,
474                    collate_fn=collate_fn_train,
475                )
476                for batch in dataloader:
477                    yield batch
478
479        class DataLoaderSharded:
480            def __iter__(self):
481                return iterate_sharded_dataset()
482
483        dataloader_training = DataLoaderSharded()
484    else:
485        dataloader_training = DataLoader(
486            dataset["training"] if train_val_split else dataset,
487            batch_size=batch_size,
488            shuffle=shuffle,
489            collate_fn=collate_fn_train,
490        )
491
492    if not infinite_iterator:
493        if train_val_split:
494            return dataloader_training, dataloader_validation
495        return dataloader_training
496
497    def next_batch_factory(dataloader):
498        while True:
499            for batch in dataloader:
500                yield batch
501
502    training_iterator = next_batch_factory(dataloader_training)
503    if train_val_split:
504        validation_iterator = next_batch_factory(dataloader_validation)
505        return training_iterator, validation_iterator
506    return training_iterator
507
508
509def load_model(
510    parameters: Dict[str, any],
511    model_file: Optional[str] = None,
512    rng_key: Optional[str] = None,
513) -> FENNIX:
514    """
515    Load a FENNIX model from a file or create a new one.
516
517    Args:
518        parameters (dict): A dictionary of parameters for the model.
519        model_file (str, optional): The path to a saved model file to load.
520
521    Returns:
522        FENNIX: A FENNIX model object.
523    """
524    print_model = parameters["training"].get("print_model", False)
525    if model_file is None:
526        model_file = parameters.get("model_file", None)
527    if model_file is not None and os.path.exists(model_file):
528        model = FENNIX.load(model_file, use_atom_padding=False)
529        if print_model:
530            print(model.summarize())
531        print(f"Restored model from '{model_file}'.")
532    else:
533        assert (
534            rng_key is not None
535        ), "rng_key must be specified if model_file is not provided."
536        model_params = parameters["model"]
537        if isinstance(model_params, str):
538            assert os.path.exists(
539                model_params
540            ), f"Model file '{model_params}' not found."
541            model = FENNIX.load(model_params, use_atom_padding=False)
542            print(f"Restored model from '{model_params}'.")
543        else:
544            model = FENNIX(**model_params, rng_key=rng_key, use_atom_padding=False)
545        if print_model:
546            print(model.summarize())
547    return model
548
549
550def copy_parameters(variables, variables_ref, params):
551    def merge_params(full_path_, v, v_ref):
552        full_path = "/".join(full_path_[1:]).lower()
553        status = (False, "")
554        for path in params:
555            if full_path.startswith(path.lower()) and len(path) > len(status[1]):
556                status = (True, path)
557        return v_ref if status[0] else v
558
559    flat = traverse_util.flatten_dict(variables, keep_empty_nodes=False)
560    flat_ref = traverse_util.flatten_dict(variables_ref, keep_empty_nodes=False)
561    return traverse_util.unflatten_dict(
562        {
563            k: merge_params(k, v, flat_ref[k]) if k in flat_ref else v
564            for k, v in flat.items()
565        }
566    )
567
568
569class TeeLogger(object):
570    def __init__(self, name):
571        self.file = io.TextIOWrapper(open(name, "wb"), write_through=True)
572        self.stdout = None
573
574    def __del__(self):
575        self.close()
576
577    def write(self, data):
578        self.file.write(data)
579        self.stdout.write(data)
580        self.flush()
581
582    def close(self):
583        self.file.close()
584
585    def flush(self):
586        self.file.flush()
587
588    def bind_stdout(self):
589        if isinstance(sys.stdout, TeeLogger):
590            raise ValueError("stdout already bound to a Tee instance.")
591        if self.stdout is not None:
592            raise ValueError("stdout already bound.")
593        self.stdout = sys.stdout
594        sys.stdout = self
595
596    def unbind_stdout(self):
597        if self.stdout is None:
598            raise ValueError("stdout is not bound.")
599        sys.stdout = self.stdout
def load_configuration(config_file: str) -> Dict[str, <built-in function any>]:
31def load_configuration(config_file: str) -> Dict[str, any]:
32    if config_file.endswith(".json"):
33        parameters = json.load(open(config_file))
34    elif config_file.endswith(".yaml") or config_file.endswith(".yml"):
35        parameters = yaml.load(open(config_file), Loader=yaml.FullLoader)
36    elif tomlkit is not None and config_file.endswith(".toml"):
37        parameters = tomlkit.loads(open(config_file).read())
38    else:
39        supported_formats = [".json", ".yaml", ".yml"]
40        if tomlkit is not None:
41            supported_formats.append(".toml")
42        raise ValueError(
43            f"Unknown config file format. Supported formats: {supported_formats}"
44        )
45    return parameters
def load_dataset( dspath: str, batch_size: int, rename_refs: dict = {}, infinite_iterator: bool = False, atom_padding: bool = False, ref_keys: Optional[Sequence[str]] = None, split_data_inputs: bool = False, np_rng: Optional[numpy.random._generator.Generator] = None, train_val_split: bool = True, training_parameters: dict = {}, add_flags: Sequence[str] = ['training'], fprec: str = 'float32'):
 48def load_dataset(
 49    dspath: str,
 50    batch_size: int,
 51    rename_refs: dict = {},
 52    infinite_iterator: bool = False,
 53    atom_padding: bool = False,
 54    ref_keys: Optional[Sequence[str]] = None,
 55    split_data_inputs: bool = False,
 56    np_rng: Optional[np.random.Generator] = None,
 57    train_val_split: bool = True,
 58    training_parameters: dict = {},
 59    add_flags: Sequence[str] = ["training"],
 60    fprec: str = "float32",
 61):
 62    """
 63    Load a dataset from a pickle file and return two iterators for training and validation batches.
 64
 65    Args:
 66        training_parameters (dict): A dictionary with the following keys:
 67            - 'dspath': str. Path to the pickle file containing the dataset.
 68            - 'batch_size': int. Number of samples per batch.
 69        rename_refs (list, optional): A list of strings with the names of the reference properties to rename.
 70            Default is an empty list.
 71
 72    Returns:
 73        tuple: A tuple of two infinite iterators, one for training batches and one for validation batches.
 74            For each element in the batch, we expect a "species" key with the atomic numbers of the atoms in the system. Arrays are concatenated along the first axis and the following keys are added to distinguish between the systems:
 75            - 'natoms': np.ndarray. Array with the number of atoms in each system.
 76            - 'batch_index': np.ndarray. Array with the index of the system to which each atom
 77            if the keys "forces", "total_energy", "atomic_energies" or any of the elements in rename_refs are present, the keys are renamed by prepending "true_" to the key name.
 78    """
 79
 80    assert isinstance(
 81        training_parameters, dict
 82    ), "training_parameters must be a dictionary."
 83    assert isinstance(
 84        rename_refs, dict
 85    ), "rename_refs must be a dictionary with the keys to rename."
 86
 87    pbc_training = training_parameters.get("pbc_training", False)
 88    minimum_image = training_parameters.get("minimum_image", False)
 89
 90    coordinates_ref_key = training_parameters.get("coordinates_ref_key", None)
 91
 92    input_keys = [
 93        "species",
 94        "coordinates",
 95        "natoms",
 96        "batch_index",
 97        "total_charge",
 98        "flags",
 99    ]
100    if pbc_training:
101        input_keys += ["cells", "reciprocal_cells"]
102    if atom_padding:
103        input_keys += ["true_atoms", "true_sys"]
104    if coordinates_ref_key is not None:
105        input_keys += ["system_index", "system_sign"]
106
107    flags = {f: None for f in add_flags}
108    if minimum_image and pbc_training:
109        flags["minimum_image"] = None
110
111    additional_input_keys = set(training_parameters.get("additional_input_keys", []))
112    additional_input_keys_ = set()
113    for key in additional_input_keys:
114        if key not in input_keys:
115            additional_input_keys_.add(key)
116    additional_input_keys = additional_input_keys_
117
118    all_inputs = set(input_keys + list(additional_input_keys))
119
120    extract_all_keys = ref_keys is None
121    if ref_keys is not None:
122        ref_keys = set(ref_keys)
123        ref_keys_ = set()
124        for key in ref_keys:
125            if key not in all_inputs:
126                ref_keys_.add(key)
127
128    random_rotation = training_parameters.get("random_rotation", False)
129    if random_rotation:
130        assert np_rng is not None, "np_rng must be provided for adding noise."
131
132        apply_rotation = {
133            1: lambda x, r: x @ r,
134            -1: lambda x, r: np.einsum("...kn,kj->...jn", x, r),
135            2: lambda x, r: np.einsum("li,...lk,kj->...ij", r, x, r),
136        }
137        valid_rotations = tuple(apply_rotation.keys())
138        rotated_keys = {
139            "coordinates": 1,
140            "forces": 1,
141            "virial_tensor": 2,
142            "stress_tensor": 2,
143            "virial": 2,
144            "stress": 2,
145        }
146        if pbc_training:
147            rotated_keys["cells"] = 1
148        user_rotated_keys = dict(training_parameters.get("rotated_keys", {}))
149        for k, v in user_rotated_keys.items():
150            assert (
151                v in valid_rotations
152            ), f"Invalid rotation type for key {k}. Valid values are {valid_rotations}"
153            rotated_keys[k] = v
154
155        # rotated_vector_keys = set(
156        #     ["coordinates", "forces"]
157        #     + list(training_parameters.get("rotated_vector_keys", []))
158        # )
159        # if pbc_training:
160        #     rotated_vector_keys.add("cells")
161
162        # rotated_tensor_keys = set(
163        #     ["virial_tensor", "stress_tensor", "virial", "stress"]
164        #     + list(training_parameters.get("rotated_tensor_keys", []))
165        # )
166        # assert rotated_vector_keys.isdisjoint(
167        #     rotated_tensor_keys
168        # ), "Rotated vector keys and rotated tensor keys must be disjoint."
169        # rotated_keys = rotated_vector_keys.union(rotated_tensor_keys)
170
171        print(
172            "Applying random rotations to the following keys if present:",
173            list(rotated_keys.keys()),
174        )
175
176        def apply_random_rotations(output, nbatch):
177            euler_angles = np_rng.uniform(0.0, 2 * np.pi, (nbatch, 3))
178            r = [
179                Rotation.from_euler("xyz", euler_angles[i]).as_matrix().T
180                for i in range(nbatch)
181            ]
182            for k, l in rotated_keys.items():
183                if k in output:
184                    for i in range(nbatch):
185                        output[k][i] = apply_rotation[l](output[k][i], r[i])
186
187    else:
188
189        def apply_random_rotations(output, nbatch):
190            pass
191
192    if pbc_training:
193        print("Periodic boundary conditions are active.")
194        length_nopbc = training_parameters.get("length_nopbc", 1000.0)
195
196        def add_cell(d, output):
197            if "cell" not in d:
198                cell = np.asarray(
199                    [
200                        [length_nopbc, 0.0, 0.0],
201                        [0.0, length_nopbc, 0.0],
202                        [0.0, 0.0, length_nopbc],
203                    ],
204                    dtype=fprec,
205                )
206            else:
207                cell = np.asarray(d["cell"], dtype=fprec)
208            output["cells"].append(cell.reshape(1, 3, 3))
209
210    else:
211
212        def add_cell(d, output):
213            if "cell" in d:
214                print(
215                    "Warning: 'cell' found in dataset but not training with pbc. Activate pbc_training to use periodic boundary conditions."
216                )
217
218    if extract_all_keys:
219
220        def add_other_keys(d, output, atom_shift):
221            for k, v in d.items():
222                if k in ("cell", "total_charge"):
223                    continue
224                v_array = np.array(v)
225                # Shift atom number if necessary
226                if k.endswith("_atidx"):
227                    v_array = v_array + atom_shift
228                output[k].append(v_array)
229
230    else:
231
232        def add_other_keys(d, output, atom_shift):
233            output["species"].append(np.asarray(d["species"]))
234            output["coordinates"].append(np.asarray(d["coordinates"], dtype=fprec))
235            for k in additional_input_keys:
236                v_array = np.array(d[k])
237                # Shift atom number if necessary
238                if k.endswith("_atidx"):
239                    v_array = v_array + atom_shift
240                output[k].append(v_array)
241            for k in ref_keys_:
242                v_array = np.array(d[k])
243                # Shift atom number if necessary
244                if k.endswith("_atidx"):
245                    v_array = v_array + atom_shift
246                output[k].append(v_array)
247                if k + "_mask" in d:
248                    output[k + "_mask"].append(np.asarray(d[k + "_mask"]))
249
250    def add_keys(d, output, atom_shift, batch_index):
251        nat = d["species"].shape[0]
252
253        output["natoms"].append(np.asarray([nat]))
254        output["batch_index"].append(np.asarray([batch_index] * nat))
255        if "total_charge" not in d:
256            total_charge = np.asarray(0.0, dtype=fprec)
257        else:
258            total_charge = np.asarray(d["total_charge"], dtype=fprec)
259        output["total_charge"].append(total_charge)
260
261        add_cell(d, output)
262        add_other_keys(d, output, atom_shift)
263
264        return atom_shift + nat
265
266    def collate_fn_(batch):
267        output = defaultdict(list)
268        atom_shift = 0
269        batch_index = 0
270
271        for d in batch:
272            atom_shift = add_keys(d, output, atom_shift, batch_index)
273            batch_index += 1
274
275            if coordinates_ref_key is not None:
276                output["system_index"].append(np.asarray([batch_index - 1]))
277                output["system_sign"].append(np.asarray([1]))
278                if coordinates_ref_key in d:
279                    dref = {**d, "coordinates": d[coordinates_ref_key]}
280                    atom_shift = add_keys(dref, output, atom_shift, batch_index)
281                    output["system_index"].append(np.asarray([batch_index - 1]))
282                    output["system_sign"].append(np.asarray([-1]))
283                    batch_index += 1
284
285        nbatch_ = len(output["natoms"])
286        apply_random_rotations(output,nbatch_)
287
288        # Stack and concatenate the arrays
289        for k, v in output.items():
290            if v[0].ndim == 0:
291                v = np.stack(v)
292            else:
293                v = np.concatenate(v, axis=0)
294            if np.issubdtype(v.dtype, np.floating):
295                v = v.astype(fprec)
296            output[k] = v
297
298        if "cells" in output and pbc_training:
299            output["reciprocal_cells"] = np.linalg.inv(output["cells"])
300
301        # Rename necessary keys
302        # for key in rename_refs:
303        #     if key in output:
304        #         output["true_" + key] = output.pop(key)
305        for kold, knew in rename_refs.items():
306            assert (
307                knew not in output
308            ), f"Cannot rename key {kold} to {knew}. Key {knew} already present."
309            if kold in output:
310                output[knew] = output.pop(kold)
311
312        output["flags"] = flags
313        return output
314
315    collate_layers_train = [collate_fn_]
316    collate_layers_valid = [collate_fn_]
317
318    ### collate preprocessing
319    # add noise to the training data
320    noise_sigma = training_parameters.get("noise_sigma", None)
321    if noise_sigma is not None:
322        assert isinstance(noise_sigma, dict), "noise_sigma should be a dictionary"
323
324        for sigma in noise_sigma.values():
325            assert sigma >= 0, "Noise sigma should be a positive number"
326
327        print("Adding noise to the training data:")
328        for key, sigma in noise_sigma.items():
329            print(f"  - {key} with sigma = {sigma}")
330
331        assert np_rng is not None, "np_rng must be provided for adding noise."
332
333        def collate_with_noise(batch):
334            for key, sigma in noise_sigma.items():
335                if key in batch and sigma > 0:
336                    batch[key] += np_rng.normal(0, sigma, batch[key].shape).astype(
337                        batch[key].dtype
338                    )
339            return batch
340
341        collate_layers_train.append(collate_with_noise)
342
343    if atom_padding:
344        padder = AtomPadding(add_sys=training_parameters.get("padder_add_sys", 0))
345        padder_state = padder.init()
346
347        def collate_with_padding(batch):
348            padder_state_up, output = padder(padder_state, batch)
349            padder_state.update(padder_state_up)
350            return output
351
352        collate_layers_train.append(collate_with_padding)
353        collate_layers_valid.append(collate_with_padding)
354
355    if split_data_inputs:
356
357        # input_keys += additional_input_keys
358        # input_keys = set(input_keys)
359        print("Input keys:", all_inputs)
360        print("Ref keys:", ref_keys)
361
362        def collate_split(batch):
363            inputs = {}
364            refs = {}
365            for k, v in batch.items():
366                if k in all_inputs:
367                    inputs[k] = v
368                if k in ref_keys:
369                    refs[k] = v
370                if k.endswith("_mask") and k[:-5] in ref_keys:
371                    refs[k] = v
372            return inputs, refs
373
374        collate_layers_train.append(collate_split)
375        collate_layers_valid.append(collate_split)
376
377    ### apply all collate preprocessing
378    if len(collate_layers_train) == 1:
379        collate_fn_train = collate_layers_train[0]
380    else:
381
382        def collate_fn_train(batch):
383            for layer in collate_layers_train:
384                batch = layer(batch)
385            return batch
386
387    if len(collate_layers_valid) == 1:
388        collate_fn_valid = collate_layers_valid[0]
389    else:
390
391        def collate_fn_valid(batch):
392            for layer in collate_layers_valid:
393                batch = layer(batch)
394            return batch
395
396    # dspath = training_parameters.get("dspath", None)
397    print(f"Loading dataset from {dspath}...", end="")
398    # print(f"   the following keys will be renamed if present : {rename_refs}")
399    sharded_training = False
400    if dspath.endswith(".db"):
401        dataset = {}
402        if train_val_split:
403            dataset["training"] = DBDataset(dspath, table="training")
404            dataset["validation"] = DBDataset(dspath, table="validation")
405        else:
406            dataset = DBDataset(dspath)
407    elif dspath.endswith(".h5") or dspath.endswith(".hdf5"):
408        dataset = {}
409        if train_val_split:
410            dataset["training"] = H5Dataset(dspath, table="training")
411            dataset["validation"] = H5Dataset(dspath, table="validation")
412        else:
413            dataset = H5Dataset(dspath)
414    elif dspath.endswith(".pkl") or dspath.endswith(".pickle"):
415        with open(dspath, "rb") as f:
416            dataset = pickle.load(f)
417        if not train_val_split and isinstance(dataset, dict):
418            dataset = dataset["training"]
419    elif os.path.isdir(dspath):
420        if train_val_split:
421            dataset = {}
422            with open(dspath + "/validation.pkl", "rb") as f:
423                dataset["validation"] = pickle.load(f)
424        else:
425            dataset = None
426
427        shard_files = sorted(glob.glob(dspath + "/training_*.pkl"))
428        nshards = len(shard_files)
429        if nshards == 0:
430            raise ValueError("No dataset shards found.")
431        elif nshards == 1:
432            with open(shard_files[0], "rb") as f:
433                if train_val_split:
434                    dataset["training"] = pickle.load(f)
435                else:
436                    dataset = pickle.load(f)
437        else:
438            print(f"Found {nshards} dataset shards.")
439            sharded_training = True
440
441    else:
442        raise ValueError(
443            f"Unknown dataset format. Supported formats: '.db', '.h5', '.pkl', '.pickle'"
444        )
445    print(" done.")
446
447    ### BUILD DATALOADERS
448    # batch_size = training_parameters.get("batch_size", 16)
449    shuffle = training_parameters.get("shuffle_dataset", True)
450    if train_val_split:
451        dataloader_validation = DataLoader(
452            dataset["validation"],
453            batch_size=batch_size,
454            shuffle=shuffle,
455            collate_fn=collate_fn_valid,
456        )
457
458    if sharded_training:
459
460        def iterate_sharded_dataset():
461            indices = np.arange(nshards)
462            if shuffle:
463                assert np_rng is not None, "np_rng must be provided for shuffling."
464                np_rng.shuffle(indices)
465            for i in indices:
466                filename = shard_files[i]
467                print(f"# Loading dataset shard from {filename}...", end="")
468                with open(filename, "rb") as f:
469                    dataset = pickle.load(f)
470                print(" done.")
471                dataloader = DataLoader(
472                    dataset,
473                    batch_size=batch_size,
474                    shuffle=shuffle,
475                    collate_fn=collate_fn_train,
476                )
477                for batch in dataloader:
478                    yield batch
479
480        class DataLoaderSharded:
481            def __iter__(self):
482                return iterate_sharded_dataset()
483
484        dataloader_training = DataLoaderSharded()
485    else:
486        dataloader_training = DataLoader(
487            dataset["training"] if train_val_split else dataset,
488            batch_size=batch_size,
489            shuffle=shuffle,
490            collate_fn=collate_fn_train,
491        )
492
493    if not infinite_iterator:
494        if train_val_split:
495            return dataloader_training, dataloader_validation
496        return dataloader_training
497
498    def next_batch_factory(dataloader):
499        while True:
500            for batch in dataloader:
501                yield batch
502
503    training_iterator = next_batch_factory(dataloader_training)
504    if train_val_split:
505        validation_iterator = next_batch_factory(dataloader_validation)
506        return training_iterator, validation_iterator
507    return training_iterator

Load a dataset from a pickle file and return two iterators for training and validation batches.

Args: training_parameters (dict): A dictionary with the following keys: - 'dspath': str. Path to the pickle file containing the dataset. - 'batch_size': int. Number of samples per batch. rename_refs (list, optional): A list of strings with the names of the reference properties to rename. Default is an empty list.

Returns: tuple: A tuple of two infinite iterators, one for training batches and one for validation batches. For each element in the batch, we expect a "species" key with the atomic numbers of the atoms in the system. Arrays are concatenated along the first axis and the following keys are added to distinguish between the systems: - 'natoms': np.ndarray. Array with the number of atoms in each system. - 'batch_index': np.ndarray. Array with the index of the system to which each atom if the keys "forces", "total_energy", "atomic_energies" or any of the elements in rename_refs are present, the keys are renamed by prepending "true_" to the key name.

def load_model( parameters: Dict[str, <built-in function any>], model_file: Optional[str] = None, rng_key: Optional[str] = None) -> fennol.models.fennix.FENNIX:
510def load_model(
511    parameters: Dict[str, any],
512    model_file: Optional[str] = None,
513    rng_key: Optional[str] = None,
514) -> FENNIX:
515    """
516    Load a FENNIX model from a file or create a new one.
517
518    Args:
519        parameters (dict): A dictionary of parameters for the model.
520        model_file (str, optional): The path to a saved model file to load.
521
522    Returns:
523        FENNIX: A FENNIX model object.
524    """
525    print_model = parameters["training"].get("print_model", False)
526    if model_file is None:
527        model_file = parameters.get("model_file", None)
528    if model_file is not None and os.path.exists(model_file):
529        model = FENNIX.load(model_file, use_atom_padding=False)
530        if print_model:
531            print(model.summarize())
532        print(f"Restored model from '{model_file}'.")
533    else:
534        assert (
535            rng_key is not None
536        ), "rng_key must be specified if model_file is not provided."
537        model_params = parameters["model"]
538        if isinstance(model_params, str):
539            assert os.path.exists(
540                model_params
541            ), f"Model file '{model_params}' not found."
542            model = FENNIX.load(model_params, use_atom_padding=False)
543            print(f"Restored model from '{model_params}'.")
544        else:
545            model = FENNIX(**model_params, rng_key=rng_key, use_atom_padding=False)
546        if print_model:
547            print(model.summarize())
548    return model

Load a FENNIX model from a file or create a new one.

Args: parameters (dict): A dictionary of parameters for the model. model_file (str, optional): The path to a saved model file to load.

Returns: FENNIX: A FENNIX model object.

def copy_parameters(variables, variables_ref, params):
551def copy_parameters(variables, variables_ref, params):
552    def merge_params(full_path_, v, v_ref):
553        full_path = "/".join(full_path_[1:]).lower()
554        status = (False, "")
555        for path in params:
556            if full_path.startswith(path.lower()) and len(path) > len(status[1]):
557                status = (True, path)
558        return v_ref if status[0] else v
559
560    flat = traverse_util.flatten_dict(variables, keep_empty_nodes=False)
561    flat_ref = traverse_util.flatten_dict(variables_ref, keep_empty_nodes=False)
562    return traverse_util.unflatten_dict(
563        {
564            k: merge_params(k, v, flat_ref[k]) if k in flat_ref else v
565            for k, v in flat.items()
566        }
567    )
class TeeLogger:
570class TeeLogger(object):
571    def __init__(self, name):
572        self.file = io.TextIOWrapper(open(name, "wb"), write_through=True)
573        self.stdout = None
574
575    def __del__(self):
576        self.close()
577
578    def write(self, data):
579        self.file.write(data)
580        self.stdout.write(data)
581        self.flush()
582
583    def close(self):
584        self.file.close()
585
586    def flush(self):
587        self.file.flush()
588
589    def bind_stdout(self):
590        if isinstance(sys.stdout, TeeLogger):
591            raise ValueError("stdout already bound to a Tee instance.")
592        if self.stdout is not None:
593            raise ValueError("stdout already bound.")
594        self.stdout = sys.stdout
595        sys.stdout = self
596
597    def unbind_stdout(self):
598        if self.stdout is None:
599            raise ValueError("stdout is not bound.")
600        sys.stdout = self.stdout
TeeLogger(name)
571    def __init__(self, name):
572        self.file = io.TextIOWrapper(open(name, "wb"), write_through=True)
573        self.stdout = None
file
stdout
def write(self, data):
578    def write(self, data):
579        self.file.write(data)
580        self.stdout.write(data)
581        self.flush()
def close(self):
583    def close(self):
584        self.file.close()
def flush(self):
586    def flush(self):
587        self.file.flush()
def bind_stdout(self):
589    def bind_stdout(self):
590        if isinstance(sys.stdout, TeeLogger):
591            raise ValueError("stdout already bound to a Tee instance.")
592        if self.stdout is not None:
593            raise ValueError("stdout already bound.")
594        self.stdout = sys.stdout
595        sys.stdout = self
def unbind_stdout(self):
597    def unbind_stdout(self):
598        if self.stdout is None:
599            raise ValueError("stdout is not bound.")
600        sys.stdout = self.stdout