fennol.md.initial

  1from pathlib import Path
  2
  3import numpy as np
  4import jax.numpy as jnp
  5
  6from flax.core import freeze, unfreeze
  7
  8from ..utils.io import last_xyz_frame
  9
 10
 11from ..models import FENNIX
 12
 13from ..utils.periodic_table import PERIODIC_TABLE_REV_IDX, ATOMIC_MASSES
 14from ..utils.atomic_units import AtomicUnits as au
 15
 16
 17def load_model(simulation_parameters):
 18    model_file = simulation_parameters.get("model_file")
 19    model_file = Path(str(model_file).strip())
 20    if not model_file.exists():
 21        raise FileNotFoundError(f"model file {model_file} not found")
 22    else:
 23        graph_config = simulation_parameters.get("graph_config", {})
 24        model = FENNIX.load(model_file, graph_config=graph_config)  # \
 25        print(f"# model_file: {model_file}")
 26
 27    if "energy_terms" in simulation_parameters:
 28        energy_terms = simulation_parameters["energy_terms"]
 29        if isinstance(energy_terms, str):
 30            energy_terms = energy_terms.split()
 31        model.set_energy_terms(energy_terms)
 32        print("# energy terms:", model.energy_terms)
 33
 34    return model
 35
 36
 37def load_system_data(simulation_parameters, fprec):
 38    ## LOAD SYSTEM CONFORMATION FROM FILES
 39    system_name = str(simulation_parameters.get("system", "system")).strip()
 40    indexed = simulation_parameters.get("xyz_input/indexed", True)
 41    has_comment_line = simulation_parameters.get("xyz_input/has_comment_line", False)
 42    xyzfile = Path(simulation_parameters.get("xyz_input/file", system_name + ".xyz"))
 43    if not xyzfile.exists():
 44        raise FileNotFoundError(f"xyz file {xyzfile} not found")
 45    system_name = str(simulation_parameters.get("system", xyzfile.stem)).strip()
 46    symbols, coordinates, _ = last_xyz_frame(
 47        xyzfile, indexed=indexed, has_comment_line=has_comment_line
 48    )
 49    coordinates = coordinates.astype(fprec)
 50    species = np.array([PERIODIC_TABLE_REV_IDX[s] for s in symbols], dtype=np.int32)
 51    nat = species.shape[0]
 52
 53    ## GET MASS
 54    mass_amu = np.array(ATOMIC_MASSES, dtype=fprec)[species]
 55    deuterate = simulation_parameters.get("deuterate", False)
 56    if deuterate:
 57        print("# Replacing all hydrogens with deuteriums")
 58        mass_amu[species == 1] *= 2.0
 59    mass = mass_amu * (au.MPROT * (au.FS / au.BOHR) ** 2)
 60
 61    ### GET TEMPERATURE
 62    temperature = np.clip(simulation_parameters.get("temperature", 300.0), 1.0e-6, None)
 63    kT = temperature / au.KELVIN
 64    totmass_amu = mass_amu.sum() / 6.02214129e-1
 65
 66    ### GET TOTAL CHARGE
 67    total_charge = simulation_parameters.get("total_charge", 0.0)
 68
 69    ### ENERGY UNIT
 70    energy_unit_str = simulation_parameters.get("energy_unit", "kcal/mol")
 71    energy_unit = au.get_multiplier(energy_unit_str)
 72
 73    ## SYSTEM DATA
 74    system_data = {
 75        "name": system_name,
 76        "nat": nat,
 77        "symbols": symbols,
 78        "species": species,
 79        "mass": mass,
 80        "temperature": temperature,
 81        "kT": kT,
 82        "totmass_amu": totmass_amu,
 83        "total_charge": total_charge,
 84        "energy_unit": energy_unit,
 85        "energy_unit_str": energy_unit_str,
 86    }
 87
 88    ### Set boundary conditions
 89    cell = simulation_parameters.get("cell", None)
 90    if cell is not None:
 91        cell = np.array(cell, dtype=fprec).reshape(3, 3)
 92        reciprocal_cell = np.linalg.inv(cell)
 93        volume = np.abs(np.linalg.det(cell))
 94        print("# cell matrix:")
 95        for l in cell:
 96            print("# ", l)
 97        # print(cell)
 98        dens = totmass_amu / volume
 99        print("# density: ", dens.item(), " g/cm^3")
100        minimum_image = simulation_parameters.get("minimum_image", True)
101        estimate_pressure = simulation_parameters.get("estimate_pressure", False)
102        print("# minimum_image: ", minimum_image)
103
104        crystal_input = simulation_parameters.get("xyz_input/crystal", False)
105        if crystal_input:
106            coordinates = coordinates @ cell
107
108        pbc_data = {
109            "cell": cell,
110            "reciprocal_cell": reciprocal_cell,
111            "volume": volume,
112            "minimum_image": minimum_image,
113            "estimate_pressure": estimate_pressure,
114        }
115    else:
116        pbc_data = None
117    system_data["pbc"] = pbc_data
118
119    ### PIMD
120    nbeads = simulation_parameters.get("nbeads", None)
121    nreplicas = simulation_parameters.get("nreplicas", None)
122    if nbeads is not None:
123        nbeads = int(nbeads)
124        print("# nbeads: ", nbeads)
125        system_data["nbeads"] = nbeads
126        coordinates = np.repeat(coordinates[None, :, :], nbeads, axis=0).reshape(-1, 3)
127        species = np.repeat(species[None, :], nbeads, axis=0).reshape(-1)
128        bead_index = np.arange(nbeads, dtype=np.int32).repeat(nat)
129        natoms = np.array([nat] * nbeads, dtype=np.int32)
130
131        eigmat = np.zeros((nbeads, nbeads))
132        for i in range(nbeads - 1):
133            eigmat[i, i] = 2.0
134            eigmat[i, i + 1] = -1.0
135            eigmat[i + 1, i] = -1.0
136        eigmat[-1, -1] = 2.0
137        eigmat[0, -1] = -1.0
138        eigmat[-1, 0] = -1.0
139        omk, eigmat = np.linalg.eigh(eigmat)
140        omk[0] = 0.0
141        omk = nbeads * kT * omk**0.5 / au.FS
142        for i in range(nbeads):
143            if eigmat[i, 0] < 0:
144                eigmat[i] *= -1.0
145        eigmat = jnp.asarray(eigmat, dtype=fprec)
146        system_data["omk"] = omk
147        system_data["eigmat"] = eigmat
148        nreplicas = None
149    elif nreplicas is not None:
150        nreplicas = int(nreplicas)
151        print("# nreplicas: ", nreplicas)
152        system_data["nreplicas"] = nreplicas
153        system_data["mass"] = np.repeat(mass[None, :], nreplicas, axis=0).reshape(-1)
154        system_data["species"] = np.repeat(species[None, :], nreplicas, axis=0).reshape(
155            -1
156        )
157        coordinates = np.repeat(coordinates[None, :, :], nreplicas, axis=0).reshape(
158            -1, 3
159        )
160        species = np.repeat(species[None, :], nreplicas, axis=0).reshape(-1)
161        bead_index = np.arange(nreplicas, dtype=np.int32).repeat(nat)
162        natoms = np.array([nat] * nreplicas, dtype=np.int32)
163    else:
164        system_data["nreplicas"] = 1
165        bead_index = np.array([0] * nat, dtype=np.int32)
166        natoms = np.array([nat], dtype=np.int32)
167
168    conformation = {
169        "species": species,
170        "coordinates": coordinates,
171        "batch_index": bead_index,
172        "natoms": natoms,
173        "total_charge": total_charge,
174    }
175    if cell is not None:
176        cell = cell[None, :, :]
177        reciprocal_cell = reciprocal_cell[None, :, :]
178        if nbeads is not None:
179            cell = np.repeat(cell, nbeads, axis=0)
180            reciprocal_cell = np.repeat(reciprocal_cell, nbeads, axis=0)
181        elif nreplicas is not None:
182            cell = np.repeat(cell, nreplicas, axis=0)
183            reciprocal_cell = np.repeat(reciprocal_cell, nreplicas, axis=0)
184        conformation["cells"] = cell
185        conformation["reciprocal_cells"] = reciprocal_cell
186
187    additional_keys = simulation_parameters.get("additional_keys", {})
188    for key, value in additional_keys.items():
189        conformation[key] = value
190
191    return system_data, conformation
192
193
def initialize_preprocessing(simulation_parameters, model, conformation, system_data):
    """Configure the model's preprocessing state and run it once on ``conformation``.

    Args:
        simulation_parameters: Parameter store supporting ``.get``, ``in``
            and ``[]``. Keys read here: ``nblist_verbose``, ``nblist_skin``,
            ``nblist_mult_size``, ``nblist_add_neigh`` and ``print_model``.
        model: Object exposing ``preproc_state``, ``preprocessing``,
            ``_graphs_properties`` and ``summarize``.
        conformation: Model input dictionary passed to ``model.preprocessing``.
        system_data: Dictionary whose optional ``"pbc"`` entry supplies the
            ``minimum_image`` setting.

    Returns:
        ``(preproc_state, conformation)`` as returned by
        ``model.preprocessing``, with ``check_input`` reset to ``False`` in
        the state.
    """
    nblist_verbose = simulation_parameters.get("nblist_verbose", False)
    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
    pbc_data = system_data.get("pbc", None)

    ### CONFIGURE PREPROCESSING
    # The overrides are identical for every layer, so collect them once.
    overrides = {}
    if pbc_data is not None:
        overrides["minimum_image"] = pbc_data["minimum_image"]
    if nblist_skin > 0:  # non-positive skin leaves the layer default in place
        overrides["nblist_skin"] = nblist_skin
    if "nblist_mult_size" in simulation_parameters:
        overrides["nblist_mult_size"] = simulation_parameters["nblist_mult_size"]
    if "nblist_add_neigh" in simulation_parameters:
        overrides["add_neigh"] = simulation_parameters["nblist_add_neigh"]

    preproc_state = unfreeze(model.preproc_state)
    layer_state = []
    for st in preproc_state["layers_state"]:
        stnew = unfreeze(st)
        stnew.update(overrides)
        layer_state.append(freeze(stnew))
    preproc_state["layers_state"] = layer_state
    preproc_state = freeze(preproc_state)

    ## initial preprocessing
    # Input checking is enabled for this first call only.
    preproc_state = preproc_state.copy({"check_input": True})
    preproc_state, conformation = model.preprocessing(preproc_state, conformation)
    preproc_state = preproc_state.copy({"check_input": False})

    if nblist_verbose:
        graphs_keys = list(model._graphs_properties.keys())
        print("# graphs_keys: ", graphs_keys)
        print("# nblist state:", preproc_state)

    ### print model
    if simulation_parameters.get("print_model", False):
        print(model.summarize(example_data=conformation))

    return preproc_state, conformation
def load_model(simulation_parameters):
    """Load the FENNIX model designated by the ``model_file`` parameter.

    Args:
        simulation_parameters: Parameter store supporting ``.get``, ``in``
            and ``[]``. Keys read here:

            * ``model_file``: path of the model file (required).
            * ``graph_config``: optional value forwarded to ``FENNIX.load``
              (defaults to ``{}``).
            * ``energy_terms``: optional sequence of names, or a single
              whitespace-separated string, passed to
              ``model.set_energy_terms``.

    Returns:
        The object returned by ``FENNIX.load``.

    Raises:
        FileNotFoundError: If ``model_file`` is absent, blank, or does not
            exist on disk.
    """
    raw_path = simulation_parameters.get("model_file")
    model_name = "" if raw_path is None else str(raw_path).strip()
    if not model_name:
        # Without this guard, None would become Path("None") and an empty
        # string would become Path("."), which exists (the current directory).
        raise FileNotFoundError("no model_file specified in simulation parameters")

    model_file = Path(model_name)
    if not model_file.exists():
        raise FileNotFoundError(f"model file {model_file} not found")

    graph_config = simulation_parameters.get("graph_config", {})
    model = FENNIX.load(model_file, graph_config=graph_config)
    print(f"# model_file: {model_file}")

    if "energy_terms" in simulation_parameters:
        energy_terms = simulation_parameters["energy_terms"]
        if isinstance(energy_terms, str):
            # Accept "term1 term2 ..." as a convenience form.
            energy_terms = energy_terms.split()
        model.set_energy_terms(energy_terms)
        print("# energy terms:", model.energy_terms)

    return model
def load_system_data(simulation_parameters, fprec):
    """Read the starting structure and build the system/model-input dictionaries.

    Args:
        simulation_parameters: Parameter store supporting ``.get(key, default)``.
            Keys read here: ``system``, ``xyz_input/file``,
            ``xyz_input/indexed``, ``xyz_input/has_comment_line``,
            ``xyz_input/crystal``, ``deuterate``, ``temperature``,
            ``total_charge``, ``energy_unit``, ``cell``, ``minimum_image``,
            ``estimate_pressure``, ``nbeads``, ``nreplicas`` and
            ``additional_keys``.
        fprec: dtype applied to coordinates, masses, the cell matrix and the
            bead transformation matrix.

    Returns:
        A ``(system_data, conformation)`` tuple of dictionaries.
        ``system_data`` holds per-system quantities (name, ``nat``, symbols,
        species, mass, temperature, ``kT``, total mass, total charge, energy
        unit, ``pbc`` and bead/replica information). ``conformation`` holds
        ``species``, ``coordinates``, ``batch_index``, ``natoms``,
        ``total_charge``, optionally ``cells``/``reciprocal_cells``, plus
        every entry of ``additional_keys``.

    Raises:
        FileNotFoundError: If the xyz file does not exist.
    """
    ## LOAD SYSTEM CONFORMATION FROM FILES
    system_name = str(simulation_parameters.get("system", "system")).strip()
    indexed = simulation_parameters.get("xyz_input/indexed", True)
    has_comment_line = simulation_parameters.get("xyz_input/has_comment_line", False)
    xyzfile = Path(simulation_parameters.get("xyz_input/file", system_name + ".xyz"))
    if not xyzfile.exists():
        raise FileNotFoundError(f"xyz file {xyzfile} not found")
    # When no explicit "system" name is given, fall back to the file stem.
    system_name = str(simulation_parameters.get("system", xyzfile.stem)).strip()
    symbols, coordinates, _ = last_xyz_frame(
        xyzfile, indexed=indexed, has_comment_line=has_comment_line
    )
    coordinates = coordinates.astype(fprec)
    species = np.array([PERIODIC_TABLE_REV_IDX[s] for s in symbols], dtype=np.int32)
    nat = species.shape[0]

    ## GET MASS
    mass_amu = np.array(ATOMIC_MASSES, dtype=fprec)[species]
    if simulation_parameters.get("deuterate", False):
        print("# Replacing all hydrogens with deuteriums")
        mass_amu[species == 1] *= 2.0
    # Rescale the tabulated masses with the AtomicUnits constants
    # (presumably amu -> internal MD mass unit; confirm against AtomicUnits).
    mass = mass_amu * (au.MPROT * (au.FS / au.BOHR) ** 2)

    ### GET TEMPERATURE
    # Clip from below so that kT never reaches zero.
    temperature = np.clip(simulation_parameters.get("temperature", 300.0), 1.0e-6, None)
    kT = temperature / au.KELVIN
    # Scaled total mass; divided by the cell volume below it is reported as g/cm^3.
    totmass_amu = mass_amu.sum() / 6.02214129e-1

    ### GET TOTAL CHARGE
    total_charge = simulation_parameters.get("total_charge", 0.0)

    ### ENERGY UNIT
    energy_unit_str = simulation_parameters.get("energy_unit", "kcal/mol")
    energy_unit = au.get_multiplier(energy_unit_str)

    ## SYSTEM DATA
    system_data = {
        "name": system_name,
        "nat": nat,
        "symbols": symbols,
        "species": species,
        "mass": mass,
        "temperature": temperature,
        "kT": kT,
        "totmass_amu": totmass_amu,
        "total_charge": total_charge,
        "energy_unit": energy_unit,
        "energy_unit_str": energy_unit_str,
    }

    ### Set boundary conditions
    cell = simulation_parameters.get("cell", None)
    reciprocal_cell = None
    if cell is not None:
        cell = np.array(cell, dtype=fprec).reshape(3, 3)
        reciprocal_cell = np.linalg.inv(cell)
        volume = np.abs(np.linalg.det(cell))
        print("# cell matrix:")
        for row in cell:
            print("# ", row)
        dens = totmass_amu / volume
        print("# density: ", dens.item(), " g/cm^3")
        minimum_image = simulation_parameters.get("minimum_image", True)
        estimate_pressure = simulation_parameters.get("estimate_pressure", False)
        print("# minimum_image: ", minimum_image)

        # Crystal input: the file stores coordinates to be multiplied by the
        # cell matrix. The flag is only honoured when a cell is provided.
        if simulation_parameters.get("xyz_input/crystal", False):
            coordinates = coordinates @ cell

        pbc_data = {
            "cell": cell,
            "reciprocal_cell": reciprocal_cell,
            "volume": volume,
            "minimum_image": minimum_image,
            "estimate_pressure": estimate_pressure,
        }
    else:
        pbc_data = None
    system_data["pbc"] = pbc_data

    ### PIMD
    nbeads = simulation_parameters.get("nbeads", None)
    nreplicas = simulation_parameters.get("nreplicas", None)
    if nbeads is not None:
        nbeads = int(nbeads)
        print("# nbeads: ", nbeads)
        system_data["nbeads"] = nbeads
        ncopies = nbeads
        nreplicas = None  # "nbeads" takes precedence over "nreplicas"

        # Cyclic tridiagonal matrix: 2 on the diagonal, -1 on the first
        # off-diagonals and in the two wrap-around corners.
        ring = 2.0 * np.eye(nbeads) - np.eye(nbeads, k=1) - np.eye(nbeads, k=-1)
        ring[0, -1] = -1.0
        ring[-1, 0] = -1.0
        omk, eigmat = np.linalg.eigh(ring)
        # eigh returns eigenvalues in ascending order; force the lowest to
        # exactly zero before taking the square root.
        omk[0] = 0.0
        omk = nbeads * kT * omk**0.5 / au.FS
        # Flip the sign of every row whose first-column entry is negative.
        eigmat[eigmat[:, 0] < 0] *= -1.0
        system_data["omk"] = omk
        system_data["eigmat"] = jnp.asarray(eigmat, dtype=fprec)
    elif nreplicas is not None:
        nreplicas = int(nreplicas)
        print("# nreplicas: ", nreplicas)
        system_data["nreplicas"] = nreplicas
        ncopies = nreplicas
    else:
        system_data["nreplicas"] = 1
        ncopies = None

    if ncopies is None:
        bead_index = np.zeros(nat, dtype=np.int32)
        natoms = np.array([nat], dtype=np.int32)
    else:
        # Stack ncopies identical copies of the structure, copy after copy.
        coordinates = np.tile(coordinates, (ncopies, 1))
        species = np.tile(species, ncopies)
        bead_index = np.arange(ncopies, dtype=np.int32).repeat(nat)
        natoms = np.full(ncopies, nat, dtype=np.int32)
        if nreplicas is not None:
            # Only independent replicas get replicated masses/species in
            # system_data; in bead mode system_data keeps one copy.
            system_data["mass"] = np.tile(mass, ncopies)
            system_data["species"] = species

    conformation = {
        "species": species,
        "coordinates": coordinates,
        "batch_index": bead_index,
        "natoms": natoms,
        "total_charge": total_charge,
    }
    if cell is not None:
        # One cell (and reciprocal cell) per copy of the system.
        cells = cell[None, :, :]
        reciprocal_cells = reciprocal_cell[None, :, :]
        if ncopies is not None:
            cells = np.repeat(cells, ncopies, axis=0)
            reciprocal_cells = np.repeat(reciprocal_cells, ncopies, axis=0)
        conformation["cells"] = cells
        conformation["reciprocal_cells"] = reciprocal_cells

    # User-supplied extra entries are copied verbatim (and may override the
    # keys set above).
    additional_keys = simulation_parameters.get("additional_keys", {})
    for key, value in additional_keys.items():
        conformation[key] = value

    return system_data, conformation
def initialize_preprocessing(simulation_parameters, model, conformation, system_data):
    """Configure the model's preprocessing state and run it once on ``conformation``.

    Args:
        simulation_parameters: Parameter store supporting ``.get``, ``in``
            and ``[]``. Keys read here: ``nblist_verbose``, ``nblist_skin``,
            ``nblist_mult_size``, ``nblist_add_neigh`` and ``print_model``.
        model: Object exposing ``preproc_state``, ``preprocessing``,
            ``_graphs_properties`` and ``summarize``.
        conformation: Model input dictionary passed to ``model.preprocessing``.
        system_data: Dictionary whose optional ``"pbc"`` entry supplies the
            ``minimum_image`` setting.

    Returns:
        ``(preproc_state, conformation)`` as returned by
        ``model.preprocessing``, with ``check_input`` reset to ``False`` in
        the state.
    """
    nblist_verbose = simulation_parameters.get("nblist_verbose", False)
    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
    pbc_data = system_data.get("pbc", None)

    ### CONFIGURE PREPROCESSING
    # The overrides are identical for every layer, so collect them once.
    overrides = {}
    if pbc_data is not None:
        overrides["minimum_image"] = pbc_data["minimum_image"]
    if nblist_skin > 0:  # non-positive skin leaves the layer default in place
        overrides["nblist_skin"] = nblist_skin
    if "nblist_mult_size" in simulation_parameters:
        overrides["nblist_mult_size"] = simulation_parameters["nblist_mult_size"]
    if "nblist_add_neigh" in simulation_parameters:
        overrides["add_neigh"] = simulation_parameters["nblist_add_neigh"]

    preproc_state = unfreeze(model.preproc_state)
    layer_state = []
    for st in preproc_state["layers_state"]:
        stnew = unfreeze(st)
        stnew.update(overrides)
        layer_state.append(freeze(stnew))
    preproc_state["layers_state"] = layer_state
    preproc_state = freeze(preproc_state)

    ## initial preprocessing
    # Input checking is enabled for this first call only.
    preproc_state = preproc_state.copy({"check_input": True})
    preproc_state, conformation = model.preprocessing(preproc_state, conformation)
    preproc_state = preproc_state.copy({"check_input": False})

    if nblist_verbose:
        graphs_keys = list(model._graphs_properties.keys())
        print("# graphs_keys: ", graphs_keys)
        print("# nblist state:", preproc_state)

    ### print model
    if simulation_parameters.get("print_model", False):
        print(model.summarize(example_data=conformation))

    return preproc_state, conformation