fennol.training.io

import os, io, sys
import numpy as np
from scipy.spatial.transform import Rotation
from collections import defaultdict
import pickle
import glob
from flax import traverse_util
from typing import Any, Dict, List, Tuple, Union, Optional, Callable, Sequence
from .databases import DBDataset, H5Dataset
from ..models.preprocessing import AtomPadding
import re
import json
import yaml

try:
    import tomlkit
except ImportError:
    tomlkit = None

try:
    from torch.utils.data import DataLoader
except ImportError:
    raise ImportError(
        "PyTorch is required for training models. Install the CPU version following instructions at https://pytorch.org/get-started/locally/"
    )

from ..models import FENNIX


def load_configuration(config_file: str) -> Dict[str, Any]:
    if config_file.endswith(".json"):
        parameters = json.load(open(config_file))
    elif config_file.endswith(".yaml") or config_file.endswith(".yml"):
        parameters = yaml.load(open(config_file), Loader=yaml.FullLoader)
    elif tomlkit is not None and config_file.endswith(".toml"):
        parameters = tomlkit.loads(open(config_file).read())
    else:
        supported_formats = [".json", ".yaml", ".yml"]
        if tomlkit is not None:
            supported_formats.append(".toml")
        raise ValueError(
            f"Unknown config file format. Supported formats: {supported_formats}"
        )
    return parameters


def load_dataset(
    dspath: str,
    batch_size: int,
    batch_size_val: Optional[int] = None,
    rename_refs: dict = {},
    infinite_iterator: bool = False,
    atom_padding: bool = False,
    ref_keys: Optional[Sequence[str]] = None,
    split_data_inputs: bool = False,
    np_rng: Optional[np.random.Generator] = None,
    train_val_split: bool = True,
    training_parameters: dict = {},
    add_flags: Sequence[str] = ["training"],
    fprec: str = "float32",
):
    """
    Load a dataset from a pickle file, an SQL database (.db), an HDF5 file or a directory
    of pickled shards, and build dataloaders (or infinite iterators) of collated batches.

    Args:
        dspath (str): Path to the dataset file or shard directory.
        batch_size (int): Number of samples per training batch.
        batch_size_val (int, optional): Number of samples per validation batch. Defaults to batch_size.
        rename_refs (dict, optional): Mapping from old to new key names, applied to each collated batch.
        infinite_iterator (bool, optional): If True, return infinite batch iterators instead of dataloaders.
        atom_padding (bool, optional): If True, pad each batch with the AtomPadding preprocessing.
        ref_keys (Sequence[str], optional): Reference keys to extract from the dataset. If None, all keys are extracted.
        split_data_inputs (bool, optional): If True, each batch is returned as an (inputs, refs) tuple.
        np_rng (np.random.Generator, optional): Random generator, required for random rotations, noise,
            flow matching and shard shuffling.
        train_val_split (bool, optional): If True, return a (training, validation) pair.
        training_parameters (dict, optional): Additional options such as 'pbc_training', 'minimum_image',
            'additional_input_keys', 'random_rotation', 'noise_sigma' or 'shuffle_dataset'.
        add_flags (Sequence[str], optional): Flag names added to the 'flags' entry of each batch.
        fprec (str, optional): Floating point precision used for the arrays.

    Returns:
        The training dataloader/iterator, or a (training, validation) pair if train_val_split is True.
        Each dataset entry must provide a "species" key with the atomic numbers of the atoms in the system.
        Arrays are concatenated along the first axis and the following keys are added to distinguish between systems:
        - 'natoms': np.ndarray. Number of atoms in each system.
        - 'batch_index': np.ndarray. Index of the system to which each atom belongs.
        Keys listed in rename_refs are renamed after collation.
    """

    assert isinstance(
        training_parameters, dict
    ), "training_parameters must be a dictionary."
    assert isinstance(
        rename_refs, dict
    ), "rename_refs must be a dictionary with the keys to rename."

    pbc_training = training_parameters.get("pbc_training", False)
    minimum_image = training_parameters.get("minimum_image", False)

    coordinates_ref_key = training_parameters.get("coordinates_ref_key", None)

    input_keys = [
        "species",
        "coordinates",
        "natoms",
        "batch_index",
        "total_charge",
        "flags",
    ]
    if pbc_training:
        input_keys += ["cells", "reciprocal_cells"]
    if atom_padding:
        input_keys += ["true_atoms", "true_sys"]
    if coordinates_ref_key is not None:
        input_keys += ["system_index", "system_sign"]

    flags = {f: None for f in add_flags}
    if minimum_image and pbc_training:
        flags["minimum_image"] = None

    additional_input_keys = set(training_parameters.get("additional_input_keys", []))
    additional_input_keys_ = set()
    for key in additional_input_keys:
        if key not in input_keys:
            additional_input_keys_.add(key)
    additional_input_keys = additional_input_keys_

    all_inputs = set(input_keys + list(additional_input_keys))

    extract_all_keys = ref_keys is None
    if ref_keys is not None:
        ref_keys = set(ref_keys)
        ref_keys_ = set()
        for key in ref_keys:
            if key not in all_inputs:
                ref_keys_.add(key)

    random_rotation = training_parameters.get("random_rotation", False)
    if random_rotation:
        assert np_rng is not None, "np_rng must be provided for random rotations."

        apply_rotation = {
            1: lambda x, r: x @ r,
            -1: lambda x, r: np.einsum("...kn,kj->...jn", x, r),
            2: lambda x, r: np.einsum("li,...lk,kj->...ij", r, x, r),
        }

        def rotate_2f(x, r):
            assert x.shape[-1] == 6
            # select from the 6 components (xx,yy,zz,xy,xz,yz) to form the 3x3 tensor
            indices = np.array([0, 3, 4, 3, 1, 5, 4, 5, 2])
            x = x[..., indices].reshape(*x.shape[:-1], 3, 3)
            x = np.einsum("li,...lk,kj->...ij", r, x, r)
            # select back the 6 components
            indices = np.array([[0, 0], [1, 1], [2, 2], [0, 1], [0, 2], [1, 2]])
            x = x[..., indices[:, 0], indices[:, 1]]
            return x

        apply_rotation[-2] = rotate_2f

        valid_rotations = tuple(apply_rotation.keys())
        rotated_keys = {
            "coordinates": 1,
            "forces": 1,
            "virial_tensor": 2,
            "stress_tensor": 2,
            "virial": 2,
            "stress": 2,
        }
        if pbc_training:
            rotated_keys["cells"] = 1
        user_rotated_keys = dict(training_parameters.get("rotated_keys", {}))
        for k, v in user_rotated_keys.items():
            assert (
                v in valid_rotations
            ), f"Invalid rotation type for key {k}. Valid values are {valid_rotations}"
            rotated_keys[k] = v

        # rotated_vector_keys = set(
        #     ["coordinates", "forces"]
        #     + list(training_parameters.get("rotated_vector_keys", []))
        # )
        # if pbc_training:
        #     rotated_vector_keys.add("cells")

        # rotated_tensor_keys = set(
        #     ["virial_tensor", "stress_tensor", "virial", "stress"]
        #     + list(training_parameters.get("rotated_tensor_keys", []))
        # )
        # assert rotated_vector_keys.isdisjoint(
        #     rotated_tensor_keys
        # ), "Rotated vector keys and rotated tensor keys must be disjoint."
        # rotated_keys = rotated_vector_keys.union(rotated_tensor_keys)

        print(
            "Applying random rotations to the following keys if present:",
            list(rotated_keys.keys()),
        )

        def apply_random_rotations(output, nbatch):
            euler_angles = np_rng.uniform(0.0, 2 * np.pi, (nbatch, 3))
            r = [
                Rotation.from_euler("xyz", euler_angles[i]).as_matrix().T
                for i in range(nbatch)
            ]
            for k, l in rotated_keys.items():
                if k in output:
                    for i in range(nbatch):
                        output[k][i] = apply_rotation[l](output[k][i], r[i])

    else:

        def apply_random_rotations(output, nbatch):
            pass

    flow_matching = training_parameters.get("flow_matching", False)
    if flow_matching:
        if ref_keys is not None:
            ref_keys.add("flow_matching_target")
            if "flow_matching_target" in ref_keys_:
                ref_keys_.remove("flow_matching_target")
        all_inputs.add("flow_matching_time")

        def add_flow_matching(output, nbatch):
            ts = np_rng.uniform(0.0, 1.0, (nbatch,))
            targets = []
            for i in range(nbatch):
                x1 = output["coordinates"][i]
                com = x1.mean(axis=0, keepdims=True)
                x1 = x1 - com
                x0 = np_rng.normal(0.0, 1.0, x1.shape)
                xt = (1 - ts[i]) * x0 + ts[i] * x1
                output["coordinates"][i] = xt
                targets.append(x1 - x0)
            output["flow_matching_target"] = targets
            output["flow_matching_time"] = [np.array(t) for t in ts]

    else:

        def add_flow_matching(output, nbatch):
            pass

    if pbc_training:
        print("Periodic boundary conditions are active.")
        length_nopbc = training_parameters.get("length_nopbc", 1000.0)

        def add_cell(d, output):
            if "cell" not in d:
                cell = np.asarray(
                    [
                        [length_nopbc, 0.0, 0.0],
                        [0.0, length_nopbc, 0.0],
                        [0.0, 0.0, length_nopbc],
                    ],
                    dtype=fprec,
                )
            else:
                cell = np.asarray(d["cell"], dtype=fprec)
            output["cells"].append(cell.reshape(1, 3, 3))

    else:

        def add_cell(d, output):
            if "cell" in d:
                print(
                    "Warning: 'cell' found in dataset but not training with pbc. Activate pbc_training to use periodic boundary conditions."
                )

    if extract_all_keys:

        def add_other_keys(d, output, atom_shift):
            for k, v in d.items():
                if k in ("cell", "total_charge"):
                    continue
                v_array = np.array(v)
                # Shift atom number if necessary
                if k.endswith("_atidx"):
                    v_array = v_array + atom_shift
                output[k].append(v_array)

    else:

        def add_other_keys(d, output, atom_shift):
            output["species"].append(np.asarray(d["species"]))
            output["coordinates"].append(np.asarray(d["coordinates"], dtype=fprec))
            for k in additional_input_keys:
                v_array = np.array(d[k])
                # Shift atom number if necessary
                if k.endswith("_atidx"):
                    v_array = v_array + atom_shift
                output[k].append(v_array)
            for k in ref_keys_:
                v_array = np.array(d[k])
                # Shift atom number if necessary
                if k.endswith("_atidx"):
                    v_array = v_array + atom_shift
                output[k].append(v_array)
                if k + "_mask" in d:
                    output[k + "_mask"].append(np.asarray(d[k + "_mask"]))

    def add_keys(d, output, atom_shift, batch_index):
        nat = d["species"].shape[0]

        output["natoms"].append(np.asarray([nat]))
        output["batch_index"].append(np.asarray([batch_index] * nat))
        if "total_charge" not in d:
            total_charge = np.asarray(0.0, dtype=fprec)
        else:
            total_charge = np.asarray(d["total_charge"], dtype=fprec)
        output["total_charge"].append(total_charge)

        add_cell(d, output)
        add_other_keys(d, output, atom_shift)

        return atom_shift + nat

    def collate_fn_(batch):
        output = defaultdict(list)
        atom_shift = 0
        batch_index = 0

        for d in batch:
            atom_shift = add_keys(d, output, atom_shift, batch_index)
            batch_index += 1

            if coordinates_ref_key is not None:
                output["system_index"].append(np.asarray([batch_index - 1]))
                output["system_sign"].append(np.asarray([1]))
                if coordinates_ref_key in d:
                    dref = {**d, "coordinates": d[coordinates_ref_key]}
                    atom_shift = add_keys(dref, output, atom_shift, batch_index)
                    output["system_index"].append(np.asarray([batch_index - 1]))
                    output["system_sign"].append(np.asarray([-1]))
                    batch_index += 1

        nbatch_ = len(output["natoms"])
        apply_random_rotations(output, nbatch_)
        add_flow_matching(output, nbatch_)

        # Stack and concatenate the arrays
        for k, v in output.items():
            if v[0].ndim == 0:
                v = np.stack(v)
            else:
                v = np.concatenate(v, axis=0)
            if np.issubdtype(v.dtype, np.floating):
                v = v.astype(fprec)
            output[k] = v

        if "cells" in output and pbc_training:
            output["reciprocal_cells"] = np.linalg.inv(output["cells"])

        # Rename necessary keys
        # for key in rename_refs:
        #     if key in output:
        #         output["true_" + key] = output.pop(key)
        for kold, knew in rename_refs.items():
            assert (
                knew not in output
            ), f"Cannot rename key {kold} to {knew}. Key {knew} already present."
            if kold in output:
                output[knew] = output.pop(kold)

        output["flags"] = flags
        return output

    collate_layers_train = [collate_fn_]
    collate_layers_valid = [collate_fn_]

    ### collate preprocessing
    # add noise to the training data
    noise_sigma = training_parameters.get("noise_sigma", None)
    if noise_sigma is not None:
        assert isinstance(noise_sigma, dict), "noise_sigma should be a dictionary"

        for sigma in noise_sigma.values():
            assert sigma >= 0, "Noise sigma should be a positive number"

        print("Adding noise to the training data:")
        for key, sigma in noise_sigma.items():
            print(f"  - {key} with sigma = {sigma}")

        assert np_rng is not None, "np_rng must be provided for adding noise."

        def collate_with_noise(batch):
            for key, sigma in noise_sigma.items():
                if key in batch and sigma > 0:
                    batch[key] += np_rng.normal(0, sigma, batch[key].shape).astype(
                        batch[key].dtype
                    )
            return batch

        collate_layers_train.append(collate_with_noise)

    if atom_padding:
        padder = AtomPadding(add_sys=training_parameters.get("padder_add_sys", 0))
        padder_state = padder.init()

        def collate_with_padding(batch):
            padder_state_up, output = padder(padder_state, batch)
            padder_state.update(padder_state_up)
            return output

        collate_layers_train.append(collate_with_padding)
        collate_layers_valid.append(collate_with_padding)

    if split_data_inputs:

        # input_keys += additional_input_keys
        # input_keys = set(input_keys)
        print("Input keys:", all_inputs)
        print("Ref keys:", ref_keys)

        def collate_split(batch):
            inputs = {}
            refs = {}
            for k, v in batch.items():
                if k in all_inputs:
                    inputs[k] = v
                if k in ref_keys:
                    refs[k] = v
                if k.endswith("_mask") and k[:-5] in ref_keys:
                    refs[k] = v
            return inputs, refs

        collate_layers_train.append(collate_split)
        collate_layers_valid.append(collate_split)

    ### apply all collate preprocessing
    if len(collate_layers_train) == 1:
        collate_fn_train = collate_layers_train[0]
    else:

        def collate_fn_train(batch):
            for layer in collate_layers_train:
                batch = layer(batch)
            return batch

    if len(collate_layers_valid) == 1:
        collate_fn_valid = collate_layers_valid[0]
    else:

        def collate_fn_valid(batch):
            for layer in collate_layers_valid:
                batch = layer(batch)
            return batch

    if not os.path.exists(dspath):
        raise ValueError(f"Dataset file '{dspath}' not found.")
    # dspath = training_parameters.get("dspath", None)
    print(f"Loading dataset from {dspath}...", end="")
    # print(f"   the following keys will be renamed if present : {rename_refs}")
    sharded_training = False
    if dspath.endswith(".db"):
        dataset = {}
        if train_val_split:
            dataset["training"] = DBDataset(dspath, table="training")
            dataset["validation"] = DBDataset(dspath, table="validation")
        else:
            dataset = DBDataset(dspath)
    elif dspath.endswith(".h5") or dspath.endswith(".hdf5"):
        dataset = {}
        if train_val_split:
            dataset["training"] = H5Dataset(dspath, table="training")
            dataset["validation"] = H5Dataset(dspath, table="validation")
        else:
            dataset = H5Dataset(dspath)
    elif dspath.endswith(".pkl") or dspath.endswith(".pickle"):
        with open(dspath, "rb") as f:
            dataset = pickle.load(f)
        if not train_val_split and isinstance(dataset, dict):
            dataset = dataset["training"]
    elif os.path.isdir(dspath):
        if train_val_split:
            dataset = {}
            with open(dspath + "/validation.pkl", "rb") as f:
                dataset["validation"] = pickle.load(f)
        else:
            dataset = None

        shard_files = sorted(glob.glob(dspath + "/training_*.pkl"))
        nshards = len(shard_files)
        if nshards == 0:
            raise ValueError("No dataset shards found.")
        elif nshards == 1:
            with open(shard_files[0], "rb") as f:
                if train_val_split:
                    dataset["training"] = pickle.load(f)
                else:
                    dataset = pickle.load(f)
        else:
            print(f"Found {nshards} dataset shards.")
            sharded_training = True

    else:
        raise ValueError(
            "Unknown dataset format. Supported formats: '.db', '.h5', '.hdf5', '.pkl', '.pickle' or a directory of pickled shards."
        )
    print(" done.")

    ### BUILD DATALOADERS
    # batch_size = training_parameters.get("batch_size", 16)
    shuffle = training_parameters.get("shuffle_dataset", True)
    if train_val_split:
        if batch_size_val is None:
            batch_size_val = batch_size
        dataloader_validation = DataLoader(
            dataset["validation"],
            batch_size=batch_size_val,
            shuffle=shuffle,
            collate_fn=collate_fn_valid,
        )

    if sharded_training:

        def iterate_sharded_dataset():
            indices = np.arange(nshards)
            if shuffle:
                assert np_rng is not None, "np_rng must be provided for shuffling."
                np_rng.shuffle(indices)
            for i in indices:
                filename = shard_files[i]
                print(f"# Loading dataset shard from {filename}...", end="")
                with open(filename, "rb") as f:
                    dataset = pickle.load(f)
                print(" done.")
                dataloader = DataLoader(
                    dataset,
                    batch_size=batch_size,
                    shuffle=shuffle,
                    collate_fn=collate_fn_train,
                )
                for batch in dataloader:
                    yield batch

        class DataLoaderSharded:
            def __iter__(self):
                return iterate_sharded_dataset()

        dataloader_training = DataLoaderSharded()
    else:
        dataloader_training = DataLoader(
            dataset["training"] if train_val_split else dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=collate_fn_train,
        )

    if not infinite_iterator:
        if train_val_split:
            return dataloader_training, dataloader_validation
        return dataloader_training

    def next_batch_factory(dataloader):
        while True:
            for batch in dataloader:
                yield batch

    training_iterator = next_batch_factory(dataloader_training)
    if train_val_split:
        validation_iterator = next_batch_factory(dataloader_validation)
        return training_iterator, validation_iterator
    return training_iterator


def load_model(
    parameters: Dict[str, Any],
    model_file: Optional[str] = None,
    rng_key: Optional[str] = None,
) -> FENNIX:
    """
    Load a FENNIX model from a file or create a new one.

    Args:
        parameters (dict): A dictionary of parameters for the model.
        model_file (str, optional): The path to a saved model file to load.
        rng_key (optional): Random key used to initialize a new model when no model file is provided.

    Returns:
        FENNIX: A FENNIX model object.
    """
    print_model = parameters["training"].get("print_model", False)
    if model_file is None:
        model_file = parameters.get("model_file", None)
    if model_file is not None and os.path.exists(model_file):
        model = FENNIX.load(model_file, use_atom_padding=False)
        if print_model:
            print(model.summarize())
        print(f"Restored model from '{model_file}'.")
    else:
        assert (
            rng_key is not None
        ), "rng_key must be specified if model_file is not provided."
        model_params = parameters["model"]
        if isinstance(model_params, str):
            assert os.path.exists(
                model_params
            ), f"Model file '{model_params}' not found."
            model = FENNIX.load(model_params, use_atom_padding=False)
            print(f"Restored model from '{model_params}'.")
        else:
            model = FENNIX(**model_params, rng_key=rng_key, use_atom_padding=False)
        if print_model:
            print(model.summarize())
    return model


def copy_parameters(variables, variables_ref, params=[".*"]):
    def merge_params(full_path_, v, v_ref):
        full_path = "/".join(full_path_[1:]).lower()
        # status = (False, "")
        for path in params:
            if re.match(path.lower(), full_path):
                # if full_path.startswith(path.lower()) and len(path) > len(status[1]):
                return v_ref
        return v
        # return v_ref if status[0] else v

    flat = traverse_util.flatten_dict(variables, keep_empty_nodes=False)
    flat_ref = traverse_util.flatten_dict(variables_ref, keep_empty_nodes=False)
    return traverse_util.unflatten_dict(
        {
            k: merge_params(k, v, flat_ref[k]) if k in flat_ref else v
            for k, v in flat.items()
        }
    )


class TeeLogger(object):
    def __init__(self, name):
        self.file = io.TextIOWrapper(open(name, "wb"), write_through=True)
        self.stdout = None

    def __del__(self):
        self.close()

    def write(self, data):
        self.file.write(data)
        self.stdout.write(data)
        self.flush()

    def close(self):
        self.file.close()

    def flush(self):
        self.file.flush()

    def bind_stdout(self):
        if isinstance(sys.stdout, TeeLogger):
            raise ValueError("stdout already bound to a Tee instance.")
        if self.stdout is not None:
            raise ValueError("stdout already bound.")
        self.stdout = sys.stdout
        sys.stdout = self

    def unbind_stdout(self):
        if self.stdout is None:
            raise ValueError("stdout is not bound.")
        sys.stdout = self.stdout
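
The following are illustrative usage sketches for the functions defined above; file names, paths and dataset keys are hypothetical examples, not values required by the module. load_configuration dispatches on the file extension:

    from fennol.training.io import load_configuration

    # Parse a training configuration; the parser is chosen from the extension
    # (.json, .yaml/.yml, or .toml when tomlkit is installed).
    parameters = load_configuration("input.yaml")  # hypothetical path
    training_parameters = parameters.get("training", {})
    print(training_parameters.get("batch_size", 16))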
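
A minimal sketch of driving load_dataset, assuming a pickled dataset with 'training' and 'validation' entries at a hypothetical path; the reference keys are illustrative:

    import numpy as np
    from fennol.training.io import load_dataset

    np_rng = np.random.default_rng(42)  # needed for rotations, noise or shard shuffling
    training_iterator, validation_iterator = load_dataset(
        dspath="dataset.pkl",                 # hypothetical path
        batch_size=16,
        infinite_iterator=True,               # infinite iterators instead of DataLoaders
        atom_padding=True,                    # pad systems with the AtomPadding preprocessing
        ref_keys=["forces", "total_energy"],  # illustrative reference keys
        split_data_inputs=True,               # each batch becomes an (inputs, refs) pair
        np_rng=np_rng,
        training_parameters={"random_rotation": True},
    )
    inputs, refs = next(training_iterator)
    # inputs holds 'species', 'coordinates', 'natoms', 'batch_index', ...
    # refs holds the requested reference arrays.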
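
A corresponding sketch for load_model, assuming the configuration provides 'model' and 'training' sections and that a JAX random key is used to initialize a fresh model (the key type is an assumption here; the model file path is hypothetical):

    import jax
    from fennol.training.io import load_configuration, load_model

    parameters = load_configuration("input.yaml")  # must contain 'model' and 'training' sections
    model = load_model(parameters, rng_key=jax.random.PRNGKey(0))   # build a new FENNIX model
    # model = load_model(parameters, model_file="model.fnx")        # or restore a saved model (hypothetical path)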
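
copy_parameters overwrites entries of a flax variable tree with the corresponding entries of a reference tree whenever their flattened path (levels after the first, joined with '/' and lowercased) matches one of the given regular expressions. A small sketch with made-up parameter trees:

    from fennol.training.io import copy_parameters

    variables = {"params": {"embedding": {"kernel": 0.0}, "output": {"kernel": 1.0}}}
    variables_ref = {"params": {"embedding": {"kernel": 9.0}, "output": {"kernel": 9.0}}}

    # Copy only the parameters whose path matches 'embedding.*' from the reference tree.
    merged = copy_parameters(variables, variables_ref, params=["embedding.*"])
    # merged["params"]["embedding"]["kernel"] is now 9.0; the 'output' kernel keeps its value of 1.0.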
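
TeeLogger mirrors everything written to stdout into a log file while still printing to the console. A minimal sketch (the log file name is arbitrary):

    from fennol.training.io import TeeLogger

    logger = TeeLogger("train.log")  # open the log file
    logger.bind_stdout()             # from here on, print() goes to both the console and train.log
    try:
        print("training started")
    finally:
        logger.unbind_stdout()       # restore the original stdout
        logger.close()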