fennol.md.dynamic

  1import sys, os, io
  2import argparse
  3import time
  4from pathlib import Path
  5
  6import numpy as np
  7from collections import defaultdict
  8import jax
  9import jax.numpy as jnp
 10import pickle
 11import yaml
 12
 13from ..utils.io import (
 14    write_arc_frame,
 15    write_xyz_frame,
 16    write_extxyz_frame,
 17    human_time_duration,
 18)
 19from .utils import wrapbox, save_dynamics_restart
 20from ..utils import minmaxone, AtomicUnits as au,read_tinker_interval
 21from ..utils.input_parser import parse_input,convert_dict_units, InputFile
 22from .integrate import initialize_dynamics
 23
 24
 25def main():
 26    # os.environ["OMP_NUM_THREADS"] = "1"
 27    sys.stdout = io.TextIOWrapper(
 28        open(sys.stdout.fileno(), "wb", 0), write_through=True
 29    )
 30    
 31    ### Read the parameter file
 32    parser = argparse.ArgumentParser(prog="fennol_md")
 33    parser.add_argument("param_file", type=Path, help="Parameter file")
 34    args = parser.parse_args()
 35    param_file = args.param_file
 36
 37    return config_and_run_dynamic(param_file)
 38
 39def config_and_run_dynamic(param_file: Path):
 40    """ Run the dynamic simulation based on the provided parameter file."""
 41
 42    if not param_file.exists() and not param_file.is_file():
 43        raise FileNotFoundError(f"Parameter file {param_file} not found")
 44    
 45    if param_file.suffix in [".yaml", ".yml"]:
 46        with open(param_file, "r") as f:
 47            simulation_parameters = convert_dict_units(yaml.safe_load(f))
 48            simulation_parameters = InputFile(**simulation_parameters)
 49    elif param_file.suffix == ".fnl":
 50        simulation_parameters = parse_input(param_file)
 51    else:
 52        raise ValueError(
 53            f"Unknown parameter file format '{param_file.suffix}'. Supported formats are '.yaml', '.yml' and '.fnl'"
 54        )
 55
 56    ### Set the device
 57    if "FENNOL_DEVICE" in os.environ:
 58        device = os.environ["FENNOL_DEVICE"].lower()
 59        print(f"# Setting device from env FENNOL_DEVICE={device}")
 60    else:
 61        device = simulation_parameters.get("device", "cpu").lower()
 62    if device == "cpu":
 63        jax.config.update('jax_platforms', 'cpu')
 64        os.environ["CUDA_VISIBLE_DEVICES"] = ""
 65    elif device.startswith("cuda") or device.startswith("gpu"):
 66        if ":" in device:
 67            num = device.split(":")[-1]
 68            os.environ["CUDA_VISIBLE_DEVICES"] = num
 69        else:
 70            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 71        device = "gpu"
 72
 73    _device = jax.devices(device)[0]
 74    jax.config.update("jax_default_device", _device)
 75
 76    ### Set the precision
 77    enable_x64 = simulation_parameters.get("double_precision", False)
 78    jax.config.update("jax_enable_x64", enable_x64)
 79    fprec = "float64" if enable_x64 else "float32"
 80
 81    matmul_precision = simulation_parameters.get("matmul_prec", "highest").lower()
 82    assert matmul_precision in [
 83        "default",
 84        "high",
 85        "highest",
 86    ], "matmul_prec must be one of 'default','high','highest'"
 87    if matmul_precision != "highest":
 88        print(f"# Setting matmul precision to '{matmul_precision}'")
 89    if matmul_precision == "default" and fprec == "float32":
 90        print(
 91            "# Warning: default matmul precision involves float16 operations which may lead to large numerical errors on energy and pressure estimations ! It is recommended to set matmul_prec to 'high' or 'highest'."
 92        )
 93    jax.config.update("jax_default_matmul_precision", matmul_precision)
 94
 95    # with jax.default_device(_device):
 96    return dynamic(simulation_parameters, device, fprec)
 97
 98
 99def dynamic(simulation_parameters, device, fprec):
100    tstart_dyn = time.time()
101
102    random_seed = simulation_parameters.get(
103        "random_seed", np.random.randint(0, 2**32 - 1)
104    )
105    print(f"# random_seed: {random_seed}")
106    rng_key = jax.random.PRNGKey(random_seed)
107
108    ## INITIALIZE INTEGRATOR AND SYSTEM
109    rng_key, subkey = jax.random.split(rng_key)
110    step, update_conformation, system_data, dyn_state, conformation, system = (
111        initialize_dynamics(simulation_parameters, fprec, subkey)
112    )
113
114    nat = system_data["nat"]
115    dt = dyn_state["dt"]
116    ## get number of steps
117    nsteps = int(simulation_parameters.get("nsteps"))
118    start_time_ps = dyn_state.get("start_time_ps", 0.0)
119
120    ### Set I/O parameters
121    Tdump = simulation_parameters.get("tdump", 1.0 / au.PS) * au.FS
122    ndump = int(Tdump / dt)
123    system_name = system_data["name"]
124    estimate_pressure = dyn_state["estimate_pressure"]
125
126    @jax.jit
127    def check_nan(system):
128        return jnp.any(jnp.isnan(system["vel"])) | jnp.any(
129            jnp.isnan(system["coordinates"])
130        )
131
132    if system_data["pbc"] is not None:
133        cell = system["cell"]
134        reciprocal_cell = np.linalg.inv(cell)
135        do_wrap_box = simulation_parameters.get("wrap_box", False)
136        if do_wrap_box:
137            wrap_groups_def = simulation_parameters.get("wrap_groups",None)
138            if wrap_groups_def is None:
139                wrap_groups = None
140            else:
141                wrap_groups = {}
142                assert isinstance(wrap_groups_def, dict), "wrap_groups must be a dictionary"
143                for k, v in wrap_groups_def.items():
144                    wrap_groups[k]=read_tinker_interval(v)
145                # check that pairwise intersection of wrap groups is empty
146                wrap_groups_keys = list(wrap_groups.keys())
147                for i in range(len(wrap_groups_keys)):
148                    i_key = wrap_groups_keys[i]
149                    w1 = set(wrap_groups[i_key])
150                    for j in range(i + 1, len(wrap_groups_keys)):
151                        j_key = wrap_groups_keys[j]
152                        w2 = set(wrap_groups[j_key])
153                        if  w1.intersection(w2):
154                            raise ValueError(
155                                f"Wrap groups {i_key} and {j_key} have common atoms: {w1.intersection(w2)}"
156                            )
157                group_all = np.concatenate(list(wrap_groups.values()))
158                # get all atoms that are not in any wrap group
159                group_none = np.setdiff1d(np.arange(nat), group_all)
160                print(f"# Wrap groups: {wrap_groups}")
161                wrap_groups["__other"] = group_none
162                wrap_groups = ((k, v) for k, v in wrap_groups.items())
163
164    else:
165        cell = None
166        reciprocal_cell = None
167        do_wrap_box = False
168        wrap_groups = None
169
170    ### Energy units and print initial energy
171    model_energy_unit = system_data["model_energy_unit"]
172    model_energy_unit_str = system_data["model_energy_unit_str"]
173    per_atom_energy = simulation_parameters.get("per_atom_energy", True)
174    energy_unit_str = system_data["energy_unit_str"]
175    energy_unit = system_data["energy_unit"]
176    print("# Energy unit: ", energy_unit_str)
177    atom_energy_unit = energy_unit
178    atom_energy_unit_str = energy_unit_str
179    if per_atom_energy:
180        atom_energy_unit /= nat
181        atom_energy_unit_str = f"{energy_unit_str}/atom"
182        print("# Printing Energy per atom")
183    print(
184        f"# Initial potential energy: {system['epot']*atom_energy_unit}; kinetic energy: {system['ek']*atom_energy_unit}"
185    )
186    f = system["forces"]
187    minmaxone(jnp.abs(f * energy_unit), "# forces min/max/rms:")
188
189    ## printing options
190    print_timings = simulation_parameters.get("print_timings", False)
191    nprint = int(simulation_parameters.get("nprint", 10))
192    assert nprint > 0, "nprint must be > 0"
193    nsummary = simulation_parameters.get("nsummary", 100 * nprint)
194    assert nsummary > nprint, "nsummary must be > nprint"
195
196    save_keys = simulation_parameters.get("save_keys", [])
197    if save_keys:
198        print(f"# Saving keys: {save_keys}")
199        fkeys = open(f"{system_name}.traj.pkl", "wb+")
200    else:
201        fkeys = None
202    
203    ### initialize colvars
204    use_colvars = "colvars" in dyn_state
205    if use_colvars:
206        print(f"# Colvars: {dyn_state['colvars']}")
207        colvars_names = dyn_state["colvars"]
208        # open colvars file and print header
209        fcolvars = open(f"{system_name}.colvars.traj", "a")
210        fcolvars.write("#time[ps] ")
211        fcolvars.write(" ".join(colvars_names))
212        fcolvars.write("\n")
213        fcolvars.flush()
214
215    ### Print header
216    include_thermostat_energy = "thermostat_energy" in system["thermostat"]
217    thermostat_name = dyn_state["thermostat_name"]
218    pimd = dyn_state["pimd"]
219    variable_cell = dyn_state["variable_cell"]
220    nbeads = system_data.get("nbeads", 1)
221    dyn_name = "PIMD" if pimd else "MD"
222    print("#" * 84)
223    print(
224        f"# Running {nsteps:_} steps of {thermostat_name} {dyn_name} simulation on {device}"
225    )
226    header = f"#{'Step':>9} {'Time[ps]':>10} {'Etot':>10} {'Epot':>10} {'Ekin':>10} {'Temp[K]':>10}"
227    if pimd:
228        header += f" {'Temp_c[K]':>10}"
229    if include_thermostat_energy:
230        header += f" {'Etherm':>10}"
231    if estimate_pressure:
232        print_aniso_pressure = simulation_parameters.get("print_aniso_pressure", False)
233        pressure_unit_str = simulation_parameters.get("pressure_unit", "atm")
234        pressure_unit = au.get_multiplier(pressure_unit_str) * au.BOHR**3
235        p_str = f"  Press[{pressure_unit_str}]"
236        header += f" {p_str:>10}"
237    if variable_cell:
238        header += f" {'Density':>10}"
239    print(header)
240
241    ### Open trajectory file
242    traj_format = simulation_parameters.get("traj_format", "arc").lower()
243    if traj_format == "xyz":
244        traj_ext = ".traj.xyz"
245        write_frame = write_xyz_frame
246    elif traj_format == "extxyz":
247        traj_ext = ".traj.extxyz"
248        write_frame = write_extxyz_frame
249    elif traj_format == "arc":
250        traj_ext = ".arc"
251        write_frame = write_arc_frame
252    else:
253        raise ValueError(
254            f"Unknown trajectory format '{traj_format}'. Supported formats are 'arc' and 'xyz'"
255        )
256
257    write_all_beads = simulation_parameters.get("write_all_beads", False) and pimd
258
259    if write_all_beads:
260        fout = [
261            open(f"{system_name}_bead{i+1:03d}" + traj_ext, "a") for i in range(nbeads)
262        ]
263    else:
264        fout = open(system_name + traj_ext, "a")
265
266    ensemble_key = simulation_parameters.get("etot_ensemble_key", None)
267    if ensemble_key is not None:
268        fens = open(f"{system_name}.ensemble_weights.traj", "a")
269
270    write_centroid = simulation_parameters.get("write_centroid", False) and pimd
271    if write_centroid:
272        fcentroid = open(f"{system_name}_centroid" + traj_ext, "a")
273
274    ### initialize proprerty trajectories
275    properties_traj = defaultdict(list)
276
277    ### initialize counters and timers
278    t0 = time.time()
279    t0dump = t0
280    istep = 0
281    t0full = time.time()
282    force_preprocess = False
283
284    for istep in range(1, nsteps + 1):
285
286        ### update the system
287        dyn_state, system, conformation, model_output = step(
288            istep, dyn_state, system, conformation, force_preprocess
289        )
290
291        ### print properties
292        if istep % nprint == 0:
293            t1 = time.time()
294            tperstep = (t1 - t0) / nprint
295            t0 = t1
296            nsperday = (24 * 60 * 60 / tperstep) * dt / 1e6
297
298            ek = system["ek"]
299            epot = system["epot"]
300            etot = ek + epot
301            temper = 2 * ek / (3.0 * nat) * au.KELVIN
302
303            th_state = system["thermostat"]
304            if include_thermostat_energy:
305                etherm = th_state["thermostat_energy"]
306                etot = etot + etherm
307
308            properties_traj[f"Etot[{atom_energy_unit_str}]"].append(
309                etot * atom_energy_unit
310            )
311            properties_traj[f"Epot[{atom_energy_unit_str}]"].append(
312                epot * atom_energy_unit
313            )
314            properties_traj[f"Ekin[{atom_energy_unit_str}]"].append(
315                ek * atom_energy_unit
316            )
317            properties_traj["Temper[Kelvin]"].append(temper)
318            if pimd:
319                ek_c = system["ek_c"]
320                temper_c = 2 * ek_c / (3.0 * nat) * au.KELVIN
321                properties_traj["Temper_c[Kelvin]"].append(temper_c)
322
323            ### construct line of properties
324            simulated_time = start_time_ps + istep * dt / 1000
325            line = f"{istep:10.6g} {simulated_time:10.3f} {etot*atom_energy_unit:#10.4f}  {epot*atom_energy_unit:#10.4f} {ek*atom_energy_unit:#10.4f} {temper:10.2f}"
326            if pimd:
327                line += f" {temper_c:10.2f}"
328            if include_thermostat_energy:
329                line += f" {etherm*atom_energy_unit:#10.4f}"
330                properties_traj[f"Etherm[{atom_energy_unit_str}]"].append(
331                    etherm * atom_energy_unit
332                )
333            if estimate_pressure:
334                pres = system["pressure"] * pressure_unit
335                properties_traj[f"Pressure[{pressure_unit_str}]"].append(pres)
336                if print_aniso_pressure:
337                    pres_tensor = system["pressure_tensor"] * pressure_unit
338                    pres_tensor = 0.5 * (pres_tensor + pres_tensor.T)
339                    properties_traj[f"Pressure_xx[{pressure_unit_str}]"].append(
340                        pres_tensor[0, 0]
341                    )
342                    properties_traj[f"Pressure_yy[{pressure_unit_str}]"].append(
343                        pres_tensor[1, 1]
344                    )
345                    properties_traj[f"Pressure_zz[{pressure_unit_str}]"].append(
346                        pres_tensor[2, 2]
347                    )
348                    properties_traj[f"Pressure_xy[{pressure_unit_str}]"].append(
349                        pres_tensor[0, 1]
350                    )
351                    properties_traj[f"Pressure_xz[{pressure_unit_str}]"].append(
352                        pres_tensor[0, 2]
353                    )
354                    properties_traj[f"Pressure_yz[{pressure_unit_str}]"].append(
355                        pres_tensor[1, 2]
356                    )
357                line += f" {pres:10.3f}"
358            if variable_cell:
359                density = system["density"]
360                properties_traj["Density[g/cm^3]"].append(density)
361                if print_aniso_pressure:
362                    cell = system["cell"]
363                    properties_traj[f"Cell_Ax[Angstrom]"].append(cell[0, 0])
364                    properties_traj[f"Cell_Ay[Angstrom]"].append(cell[0, 1])
365                    properties_traj[f"Cell_Az[Angstrom]"].append(cell[0, 2])
366                    properties_traj[f"Cell_Bx[Angstrom]"].append(cell[1, 0])
367                    properties_traj[f"Cell_By[Angstrom]"].append(cell[1, 1])
368                    properties_traj[f"Cell_Bz[Angstrom]"].append(cell[1, 2])
369                    properties_traj[f"Cell_Cx[Angstrom]"].append(cell[2, 0])
370                    properties_traj[f"Cell_Cy[Angstrom]"].append(cell[2, 1])
371                    properties_traj[f"Cell_Cz[Angstrom]"].append(cell[2, 2])
372                line += f" {density:10.4f}"
373                if "piston_temperature" in system["barostat"]:
374                    piston_temperature = system["barostat"]["piston_temperature"]
375                    properties_traj["T_piston[Kelvin]"].append(piston_temperature)
376
377            print(line)
378
379            ### save colvars
380            if use_colvars:
381                colvars = system["colvars"]
382                fcolvars.write(f"{simulated_time:.3f} ")
383                fcolvars.write(" ".join([f"{colvars[k]:.6f}" for k in colvars_names]))
384                fcolvars.write("\n")
385                fcolvars.flush()
386
387            if save_keys:
388                data = {
389                    k: (
390                        np.asarray(model_output[k])
391                        if isinstance(model_output[k], jnp.ndarray)
392                        else model_output[k]
393                    )
394                    for k in save_keys
395                }
396                data["properties"] = {
397                    k: float(v) for k, v in zip(header.split()[1:], line.split())
398                }
399                data["properties"]["properties_energy_unit"] = (atom_energy_unit,atom_energy_unit_str)
400                data["properties"]["model_energy_unit"] = (model_energy_unit,model_energy_unit_str)
401
402                pickle.dump(data, fkeys)
403
404        ### save frame
405        if istep % ndump == 0:
406            line = "# Write XYZ frame"
407            if variable_cell:
408                cell = np.array(system["cell"])
409                reciprocal_cell = np.linalg.inv(cell)
410            if do_wrap_box:
411                if pimd:
412                    centroid = wrapbox(system["coordinates"][0], cell, reciprocal_cell,wrap_groups=wrap_groups)
413                    system["coordinates"] = system["coordinates"].at[0].set(centroid)
414                else:
415                    system["coordinates"] = wrapbox(
416                        system["coordinates"], cell, reciprocal_cell,wrap_groups=wrap_groups
417                    )
418                conformation = update_conformation(conformation, system)
419                line += " (atoms have been wrapped into the box)"
420                force_preprocess = True
421            print(line)
422
423            save_dynamics_restart(system_data, conformation, dyn_state, system)
424
425            properties = {
426                "energy": float(system["epot"]) * energy_unit,
427                "Time_ps": start_time_ps + istep * dt / 1000,
428                "energy_unit": energy_unit_str,
429            }
430
431            if write_all_beads:
432                coords = np.asarray(conformation["coordinates"].reshape(-1, nat, 3))
433                for i, fb in enumerate(fout):
434                    write_frame(
435                        fb,
436                        system_data["symbols"],
437                        coords[i],
438                        cell=cell,
439                        properties=properties,
440                        forces=None,  # np.asarray(system["forces"].reshape(nbeads, nat, 3)[0]) * energy_unit,
441                    )
442            else:
443                write_frame(
444                    fout,
445                    system_data["symbols"],
446                    np.asarray(conformation["coordinates"].reshape(-1, nat, 3)[0]),
447                    cell=cell,
448                    properties=properties,
449                    forces=None,  # np.asarray(system["forces"].reshape(nbeads, nat, 3)[0]) * energy_unit,
450                )
451            if write_centroid:
452                centroid = np.asarray(system["coordinates"][0])
453                write_frame(
454                    fcentroid,
455                    system_data["symbols"],
456                    centroid,
457                    cell=cell,
458                    properties=properties,
459                    forces=np.asarray(system["forces"].reshape(nbeads, nat, 3)[0])
460                    * energy_unit,
461                )
462            if ensemble_key is not None:
463                weights = " ".join(
464                    [f"{w:.6f}" for w in system["ensemble_weights"].tolist()]
465                )
466                fens.write(f"{weights}\n")
467                fens.flush()
468
469        ### summary over last nsummary steps
470        if istep % (nsummary) == 0:
471            if check_nan(system):
472                raise ValueError(f"dynamics crashed at step {istep}.")
473            tfull = time.time() - t0full
474            t0full = time.time()
475            tperstep = tfull / (nsummary)
476            nsperday = (24 * 60 * 60 / tperstep) * dt / 1e6
477            elapsed_time = time.time() - tstart_dyn
478            estimated_remaining_time = tperstep * (nsteps - istep)
479            estimated_total_time = elapsed_time + estimated_remaining_time
480
481            print("#" * 50)
482            print(f"# Step {istep:_} of {nsteps:_}  ({istep/nsteps*100:.5g} %)")
483            print(f"# Simulated time      : {istep * dt*1.e-3:.3f} ps")
484            print(f"# Tot. Simu. time     : {start_time_ps + istep * dt*1.e-3:.3f} ps")
485            print(f"# Tot. elapsed time   : {human_time_duration(elapsed_time)}")
486            print(
487                f"# Est. total duration   : {human_time_duration(estimated_total_time)}"
488            )
489            print(
490                f"# Est. remaining time : {human_time_duration(estimated_remaining_time)}"
491            )
492            print(f"# Time for {nsummary:_} steps : {human_time_duration(tfull)}")
493
494            corr_kin = system["thermostat"].get("corr_kin", None)
495            if corr_kin is not None:
496                print(f"# QTB kin. correction : {100*(corr_kin-1.):.2f} %")
497            print(f"# Averages over last {nsummary:_} steps :")
498            for k, v in properties_traj.items():
499                if len(properties_traj[k]) == 0:
500                    continue
501                mu = np.mean(properties_traj[k])
502                sig = np.std(properties_traj[k])
503                ksplit = k.split("[")
504                name = ksplit[0].strip()
505                unit = ksplit[1].replace("]", "").strip() if len(ksplit) > 1 else ""
506                print(f"#   {name:10} : {mu: #10.5g}   +/- {sig: #9.3g}  {unit}")
507
508            print(f"# Perf.: {nsperday:.2f} ns/day  ( {1.0 / tperstep:.2f} step/s )")
509            print("#" * 50)
510            if istep < nsteps:
511                print(header)
512            ## reset property trajectories
513            properties_traj = defaultdict(list)
514
515    print(f"# Run done in {human_time_duration(time.time()-tstart_dyn)}")
516    ### close trajectory file
517    fout.close()
518    if ensemble_key is not None:
519        fens.close()
520    if use_colvars:
521        fcolvars.close()
522    if fkeys is not None:
523        fkeys.close()
524    if write_centroid:
525        fcentroid.close()
526
527    return 0
528
529
if __name__ == "__main__":
    # Executed only when the module is run directly, not on import.
    main()
def main():
def main():
    """Command-line entry point: ``fennol_md <param_file>``.

    Replaces ``sys.stdout`` by a write-through text wrapper around an
    unbuffered handle on the same file descriptor, parses the single
    positional argument and runs the simulation it describes.
    """
    raw_stdout = open(sys.stdout.fileno(), "wb", 0)
    sys.stdout = io.TextIOWrapper(raw_stdout, write_through=True)

    ### Read the parameter file
    cli = argparse.ArgumentParser(prog="fennol_md")
    cli.add_argument("param_file", type=Path, help="Parameter file")
    return config_and_run_dynamic(cli.parse_args().param_file)
def config_and_run_dynamic(param_file: pathlib.Path):
def config_and_run_dynamic(param_file: Path):
    """Run the dynamic simulation based on the provided parameter file.

    Reads the parameters, configures the JAX device and the numerical
    precision, then delegates to ``dynamic``.

    Args:
        param_file: path (``Path`` or ``str``) to a ``.yaml``/``.yml`` or
            ``.fnl`` parameter file.

    Returns:
        Whatever ``dynamic`` returns.

    Raises:
        FileNotFoundError: if ``param_file`` is not an existing regular file.
        ValueError: if the file extension is not supported.
    """
    param_file = Path(param_file)

    # Must be an existing regular file: the previous `not exists() and not
    # is_file()` test let directories through to `open()`.
    if not param_file.is_file():
        raise FileNotFoundError(f"Parameter file {param_file} not found")

    if param_file.suffix in [".yaml", ".yml"]:
        with open(param_file, "r") as f:
            simulation_parameters = convert_dict_units(yaml.safe_load(f))
            simulation_parameters = InputFile(**simulation_parameters)
    elif param_file.suffix == ".fnl":
        simulation_parameters = parse_input(param_file)
    else:
        raise ValueError(
            f"Unknown parameter file format '{param_file.suffix}'. Supported formats are '.yaml', '.yml' and '.fnl'"
        )

    ### Set the device (the FENNOL_DEVICE environment variable wins over the input)
    if "FENNOL_DEVICE" in os.environ:
        device = os.environ["FENNOL_DEVICE"].lower()
        print(f"# Setting device from env FENNOL_DEVICE={device}")
    else:
        device = simulation_parameters.get("device", "cpu").lower()
    if device == "cpu":
        jax.config.update('jax_platforms', 'cpu')
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    elif device.startswith(("cuda", "gpu")):
        # "cuda:N"/"gpu:N" selects GPU N by restricting the visible devices
        if ":" in device:
            num = device.split(":")[-1]
            os.environ["CUDA_VISIBLE_DEVICES"] = num
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        device = "gpu"

    _device = jax.devices(device)[0]
    jax.config.update("jax_default_device", _device)

    ### Set the precision
    enable_x64 = simulation_parameters.get("double_precision", False)
    jax.config.update("jax_enable_x64", enable_x64)
    fprec = "float64" if enable_x64 else "float32"

    matmul_precision = simulation_parameters.get("matmul_prec", "highest").lower()
    assert matmul_precision in [
        "default",
        "high",
        "highest",
    ], "matmul_prec must be one of 'default','high','highest'"
    if matmul_precision != "highest":
        print(f"# Setting matmul precision to '{matmul_precision}'")
    if matmul_precision == "default" and fprec == "float32":
        print(
            "# Warning: default matmul precision involves float16 operations which may lead to large numerical errors on energy and pressure estimations ! It is recommended to set matmul_prec to 'high' or 'highest'."
        )
    jax.config.update("jax_default_matmul_precision", matmul_precision)

    return dynamic(simulation_parameters, device, fprec)

Run the dynamic simulation based on the provided parameter file.

def dynamic(simulation_parameters, device, fprec):
def dynamic(simulation_parameters, device, fprec):
    """Run a molecular dynamics (or PIMD) simulation from parsed parameters.

    The integrator, system and initial state come from ``initialize_dynamics``.
    The system is then propagated for ``nsteps`` steps:

    * every ``nprint`` steps a line of properties is printed (and colvars /
      ``save_keys`` data are written when enabled);
    * every ``tdump`` (converted to a number of steps ``ndump``) a trajectory
      frame and a restart file are written, wrapping atoms into the box first
      when ``wrap_box`` is set;
    * every ``nsummary`` steps a NaN check is performed and averages and
      timings over the last block are printed.

    All output files are closed when the function exits, including when it
    exits with an exception.

    Args:
        simulation_parameters: parsed input parameters, accessed with ``.get``.
        device: device name; only used in the printed run header.
        fprec: floating point precision string (e.g. ``"float32"``),
            forwarded to ``initialize_dynamics``.

    Returns:
        0 once all steps have been performed.

    Raises:
        AssertionError: invalid ``nprint``, ``nsummary``, ``tdump`` or
            ``wrap_groups`` settings.
        ValueError: unknown ``traj_format``, overlapping ``wrap_groups``, or
            NaN velocities/coordinates detected at a summary step.
    """
    # local import: only this function needs it
    from contextlib import ExitStack

    tstart_dyn = time.time()

    random_seed = simulation_parameters.get(
        "random_seed", np.random.randint(0, 2**32 - 1)
    )
    print(f"# random_seed: {random_seed}")
    rng_key = jax.random.PRNGKey(random_seed)

    ## INITIALIZE INTEGRATOR AND SYSTEM
    rng_key, subkey = jax.random.split(rng_key)
    step, update_conformation, system_data, dyn_state, conformation, system = (
        initialize_dynamics(simulation_parameters, fprec, subkey)
    )

    nat = system_data["nat"]
    dt = dyn_state["dt"]
    ## get number of steps
    nsteps = int(simulation_parameters.get("nsteps"))
    start_time_ps = dyn_state.get("start_time_ps", 0.0)

    ### Set I/O parameters
    Tdump = simulation_parameters.get("tdump", 1.0 / au.PS) * au.FS
    ndump = int(Tdump / dt)
    # ndump == 0 (tdump shorter than dt) would make `istep % ndump` divide by zero
    assert ndump > 0, "tdump must be >= dt"
    system_name = system_data["name"]
    estimate_pressure = dyn_state["estimate_pressure"]

    @jax.jit
    def check_nan(system):
        """Return True if any velocity or coordinate is NaN."""
        return jnp.any(jnp.isnan(system["vel"])) | jnp.any(
            jnp.isnan(system["coordinates"])
        )

    ### Periodic boundary conditions and box wrapping
    if system_data["pbc"] is not None:
        cell = system["cell"]
        reciprocal_cell = np.linalg.inv(cell)
        do_wrap_box = simulation_parameters.get("wrap_box", False)
    else:
        cell = None
        reciprocal_cell = None
        do_wrap_box = False
    wrap_groups = None
    if do_wrap_box:
        wrap_groups = _parse_wrap_groups(
            simulation_parameters.get("wrap_groups", None), nat
        )

    ### Energy units and print initial energy
    model_energy_unit = system_data["model_energy_unit"]
    model_energy_unit_str = system_data["model_energy_unit_str"]
    per_atom_energy = simulation_parameters.get("per_atom_energy", True)
    energy_unit_str = system_data["energy_unit_str"]
    energy_unit = system_data["energy_unit"]
    print("# Energy unit: ", energy_unit_str)
    atom_energy_unit = energy_unit
    atom_energy_unit_str = energy_unit_str
    if per_atom_energy:
        # out-of-place division: never modify energy_unit through an alias
        atom_energy_unit = energy_unit / nat
        atom_energy_unit_str = f"{energy_unit_str}/atom"
        print("# Printing Energy per atom")
    print(
        f"# Initial potential energy: {system['epot']*atom_energy_unit}; kinetic energy: {system['ek']*atom_energy_unit}"
    )
    minmaxone(jnp.abs(system["forces"] * energy_unit), "# forces min/max/rms:")

    ## printing options
    nprint = int(simulation_parameters.get("nprint", 10))
    assert nprint > 0, "nprint must be > 0"
    nsummary = int(simulation_parameters.get("nsummary", 100 * nprint))
    assert nsummary > nprint, "nsummary must be > nprint"

    # every file opened below is registered here and closed on exit (normal or not)
    with ExitStack() as files:
        save_keys = simulation_parameters.get("save_keys", [])
        if save_keys:
            print(f"# Saving keys: {save_keys}")
            fkeys = files.enter_context(open(f"{system_name}.traj.pkl", "wb+"))
        else:
            fkeys = None

        ### initialize colvars
        use_colvars = "colvars" in dyn_state
        if use_colvars:
            print(f"# Colvars: {dyn_state['colvars']}")
            colvars_names = dyn_state["colvars"]
            # open colvars file and print header
            fcolvars = files.enter_context(open(f"{system_name}.colvars.traj", "a"))
            fcolvars.write("#time[ps] ")
            fcolvars.write(" ".join(colvars_names))
            fcolvars.write("\n")
            fcolvars.flush()

        ### Print header
        include_thermostat_energy = "thermostat_energy" in system["thermostat"]
        thermostat_name = dyn_state["thermostat_name"]
        pimd = dyn_state["pimd"]
        variable_cell = dyn_state["variable_cell"]
        nbeads = system_data.get("nbeads", 1)
        dyn_name = "PIMD" if pimd else "MD"
        print("#" * 84)
        print(
            f"# Running {nsteps:_} steps of {thermostat_name} {dyn_name} simulation on {device}"
        )
        header = f"#{'Step':>9} {'Time[ps]':>10} {'Etot':>10} {'Epot':>10} {'Ekin':>10} {'Temp[K]':>10}"
        if pimd:
            header += f" {'Temp_c[K]':>10}"
        if include_thermostat_energy:
            header += f" {'Etherm':>10}"
        # read unconditionally: the variable-cell branch below also depends on it
        print_aniso_pressure = simulation_parameters.get("print_aniso_pressure", False)
        if estimate_pressure:
            pressure_unit_str = simulation_parameters.get("pressure_unit", "atm")
            pressure_unit = au.get_multiplier(pressure_unit_str) * au.BOHR**3
            p_str = f"  Press[{pressure_unit_str}]"
            header += f" {p_str:>10}"
        if variable_cell:
            header += f" {'Density':>10}"
        print(header)

        ### Open trajectory files
        traj_format = simulation_parameters.get("traj_format", "arc").lower()
        traj_ext, write_frame = _trajectory_writer(traj_format)

        write_all_beads = simulation_parameters.get("write_all_beads", False) and pimd
        if write_all_beads:
            fout = [
                files.enter_context(
                    open(f"{system_name}_bead{i+1:03d}" + traj_ext, "a")
                )
                for i in range(nbeads)
            ]
        else:
            fout = files.enter_context(open(system_name + traj_ext, "a"))

        ensemble_key = simulation_parameters.get("etot_ensemble_key", None)
        if ensemble_key is not None:
            fens = files.enter_context(
                open(f"{system_name}.ensemble_weights.traj", "a")
            )

        write_centroid = simulation_parameters.get("write_centroid", False) and pimd
        if write_centroid:
            fcentroid = files.enter_context(
                open(f"{system_name}_centroid" + traj_ext, "a")
            )

        ### property trajectories accumulated between two summaries
        properties_traj = defaultdict(list)

        t0full = time.time()
        force_preprocess = False

        for istep in range(1, nsteps + 1):

            ### update the system
            dyn_state, system, conformation, model_output = step(
                istep, dyn_state, system, conformation, force_preprocess
            )

            ### print properties
            if istep % nprint == 0:
                ek = system["ek"]
                epot = system["epot"]
                etot = ek + epot
                temper = 2 * ek / (3.0 * nat) * au.KELVIN

                if include_thermostat_energy:
                    etherm = system["thermostat"]["thermostat_energy"]
                    etot = etot + etherm

                properties_traj[f"Etot[{atom_energy_unit_str}]"].append(
                    etot * atom_energy_unit
                )
                properties_traj[f"Epot[{atom_energy_unit_str}]"].append(
                    epot * atom_energy_unit
                )
                properties_traj[f"Ekin[{atom_energy_unit_str}]"].append(
                    ek * atom_energy_unit
                )
                properties_traj["Temper[Kelvin]"].append(temper)
                if pimd:
                    ek_c = system["ek_c"]
                    temper_c = 2 * ek_c / (3.0 * nat) * au.KELVIN
                    properties_traj["Temper_c[Kelvin]"].append(temper_c)

                ### construct line of properties
                simulated_time = start_time_ps + istep * dt / 1000
                line = f"{istep:10.6g} {simulated_time:10.3f} {etot*atom_energy_unit:#10.4f}  {epot*atom_energy_unit:#10.4f} {ek*atom_energy_unit:#10.4f} {temper:10.2f}"
                if pimd:
                    line += f" {temper_c:10.2f}"
                if include_thermostat_energy:
                    line += f" {etherm*atom_energy_unit:#10.4f}"
                    properties_traj[f"Etherm[{atom_energy_unit_str}]"].append(
                        etherm * atom_energy_unit
                    )
                if estimate_pressure:
                    pres = system["pressure"] * pressure_unit
                    properties_traj[f"Pressure[{pressure_unit_str}]"].append(pres)
                    if print_aniso_pressure:
                        pres_tensor = system["pressure_tensor"] * pressure_unit
                        # symmetrize before reporting the off-diagonal components
                        pres_tensor = 0.5 * (pres_tensor + pres_tensor.T)
                        for comp, (a, b) in (
                            ("xx", (0, 0)),
                            ("yy", (1, 1)),
                            ("zz", (2, 2)),
                            ("xy", (0, 1)),
                            ("xz", (0, 2)),
                            ("yz", (1, 2)),
                        ):
                            properties_traj[
                                f"Pressure_{comp}[{pressure_unit_str}]"
                            ].append(pres_tensor[a, b])
                    line += f" {pres:10.3f}"
                if variable_cell:
                    density = system["density"]
                    properties_traj["Density[g/cm^3]"].append(density)
                    if print_aniso_pressure:
                        # separate name: must not clobber the `cell` used for frame output
                        current_cell = system["cell"]
                        for a, vec in enumerate("ABC"):
                            for b, axis in enumerate("xyz"):
                                properties_traj[f"Cell_{vec}{axis}[Angstrom]"].append(
                                    current_cell[a, b]
                                )
                    line += f" {density:10.4f}"
                    if "piston_temperature" in system["barostat"]:
                        piston_temperature = system["barostat"]["piston_temperature"]
                        properties_traj["T_piston[Kelvin]"].append(piston_temperature)

                print(line)

                ### save colvars
                if use_colvars:
                    colvars = system["colvars"]
                    fcolvars.write(f"{simulated_time:.3f} ")
                    fcolvars.write(" ".join([f"{colvars[k]:.6f}" for k in colvars_names]))
                    fcolvars.write("\n")
                    fcolvars.flush()

                if save_keys:
                    data = {
                        k: (
                            np.asarray(model_output[k])
                            if isinstance(model_output[k], jnp.ndarray)
                            else model_output[k]
                        )
                        for k in save_keys
                    }
                    # pair every printed column name (minus the leading '#') with its value
                    data["properties"] = {
                        k: float(v) for k, v in zip(header.split()[1:], line.split())
                    }
                    data["properties"]["properties_energy_unit"] = (atom_energy_unit, atom_energy_unit_str)
                    data["properties"]["model_energy_unit"] = (model_energy_unit, model_energy_unit_str)

                    pickle.dump(data, fkeys)

            ### save frame
            if istep % ndump == 0:
                msg = "# Write XYZ frame"
                if variable_cell:
                    cell = np.array(system["cell"])
                    reciprocal_cell = np.linalg.inv(cell)
                if do_wrap_box:
                    if pimd:
                        centroid = wrapbox(system["coordinates"][0], cell, reciprocal_cell, wrap_groups=wrap_groups)
                        system["coordinates"] = system["coordinates"].at[0].set(centroid)
                    else:
                        system["coordinates"] = wrapbox(
                            system["coordinates"], cell, reciprocal_cell, wrap_groups=wrap_groups
                        )
                    conformation = update_conformation(conformation, system)
                    msg += " (atoms have been wrapped into the box)"
                    # NOTE(review): force_preprocess is never reset to False, so after the
                    # first wrap every later step() call receives True — confirm intended.
                    force_preprocess = True
                print(msg)

                save_dynamics_restart(system_data, conformation, dyn_state, system)

                properties = {
                    "energy": float(system["epot"]) * energy_unit,
                    "Time_ps": start_time_ps + istep * dt / 1000,
                    "energy_unit": energy_unit_str,
                }

                if write_all_beads:
                    coords = np.asarray(conformation["coordinates"].reshape(-1, nat, 3))
                    for i, fb in enumerate(fout):
                        write_frame(
                            fb,
                            system_data["symbols"],
                            coords[i],
                            cell=cell,
                            properties=properties,
                            forces=None,
                        )
                else:
                    write_frame(
                        fout,
                        system_data["symbols"],
                        np.asarray(conformation["coordinates"].reshape(-1, nat, 3)[0]),
                        cell=cell,
                        properties=properties,
                        forces=None,
                    )
                if write_centroid:
                    centroid = np.asarray(system["coordinates"][0])
                    write_frame(
                        fcentroid,
                        system_data["symbols"],
                        centroid,
                        cell=cell,
                        properties=properties,
                        forces=np.asarray(system["forces"].reshape(nbeads, nat, 3)[0])
                        * energy_unit,
                    )
                if ensemble_key is not None:
                    weights = " ".join(
                        [f"{w:.6f}" for w in system["ensemble_weights"].tolist()]
                    )
                    fens.write(f"{weights}\n")
                    fens.flush()

            ### summary over last nsummary steps
            if istep % nsummary == 0:
                if check_nan(system):
                    raise ValueError(f"dynamics crashed at step {istep}.")
                tfull = time.time() - t0full
                t0full = time.time()
                tperstep = tfull / nsummary
                nsperday = (24 * 60 * 60 / tperstep) * dt / 1e6
                elapsed_time = time.time() - tstart_dyn
                estimated_remaining_time = tperstep * (nsteps - istep)
                estimated_total_time = elapsed_time + estimated_remaining_time

                print("#" * 50)
                print(f"# Step {istep:_} of {nsteps:_}  ({istep/nsteps*100:.5g} %)")
                print(f"# Simulated time      : {istep * dt*1.e-3:.3f} ps")
                print(f"# Tot. Simu. time     : {start_time_ps + istep * dt*1.e-3:.3f} ps")
                print(f"# Tot. elapsed time   : {human_time_duration(elapsed_time)}")
                print(
                    f"# Est. total duration   : {human_time_duration(estimated_total_time)}"
                )
                print(
                    f"# Est. remaining time : {human_time_duration(estimated_remaining_time)}"
                )
                print(f"# Time for {nsummary:_} steps : {human_time_duration(tfull)}")

                corr_kin = system["thermostat"].get("corr_kin", None)
                if corr_kin is not None:
                    print(f"# QTB kin. correction : {100*(corr_kin-1.):.2f} %")
                print(f"# Averages over last {nsummary:_} steps :")
                _print_property_averages(properties_traj)

                print(f"# Perf.: {nsperday:.2f} ns/day  ( {1.0 / tperstep:.2f} step/s )")
                print("#" * 50)
                if istep < nsteps:
                    print(header)
                ## reset property trajectories
                properties_traj = defaultdict(list)

        print(f"# Run done in {human_time_duration(time.time()-tstart_dyn)}")

    return 0


def _parse_wrap_groups(wrap_groups_def, nat):
    """Validate ``wrap_groups`` and build the groups argument passed to ``wrapbox``.

    Args:
        wrap_groups_def: ``None`` or a dict mapping a group name to an atom
            interval specification understood by ``read_tinker_interval``.
        nat: total number of atoms.

    Returns:
        ``None`` when ``wrap_groups_def`` is ``None``. Otherwise an iterable
        of ``(name, atom_indices)`` pairs, where the atoms not listed in any
        group are collected under the ``"__other"`` name.

    Raises:
        AssertionError: ``wrap_groups_def`` is not a dict.
        ValueError: two groups share at least one atom.
    """
    if wrap_groups_def is None:
        return None
    assert isinstance(wrap_groups_def, dict), "wrap_groups must be a dictionary"
    wrap_groups = {k: read_tinker_interval(v) for k, v in wrap_groups_def.items()}

    # groups must be pairwise disjoint
    keys = list(wrap_groups)
    for i, i_key in enumerate(keys):
        w1 = set(wrap_groups[i_key])
        for j_key in keys[i + 1 :]:
            w2 = set(wrap_groups[j_key])
            if w1.intersection(w2):
                raise ValueError(
                    f"Wrap groups {i_key} and {j_key} have common atoms: {w1.intersection(w2)}"
                )

    if wrap_groups:
        group_all = np.concatenate(list(wrap_groups.values()))
    else:
        # np.concatenate rejects an empty list; with no group every atom is "other"
        group_all = np.array([], dtype=int)
    # get all atoms that are not in any wrap group
    group_none = np.setdiff1d(np.arange(nat), group_all)
    print(f"# Wrap groups: {wrap_groups}")
    wrap_groups["__other"] = group_none
    # NOTE(review): this single-use generator is handed to ``wrapbox`` at every
    # dump. That only works repeatedly if wrapbox caches on the argument identity
    # (e.g. a jax.jit static argument); a wrapbox that simply iterates it would
    # get an empty sequence from the second dump on. Confirm against wrapbox
    # before switching to tuple(...) (which is unhashable with ndarray values).
    return ((k, v) for k, v in wrap_groups.items())


def _trajectory_writer(traj_format):
    """Return the ``(file extension, frame writer)`` pair for a trajectory format."""
    writers = {
        "xyz": (".traj.xyz", write_xyz_frame),
        "extxyz": (".traj.extxyz", write_extxyz_frame),
        "arc": (".arc", write_arc_frame),
    }
    try:
        return writers[traj_format]
    except KeyError:
        raise ValueError(
            f"Unknown trajectory format '{traj_format}'. Supported formats are 'arc', 'xyz' and 'extxyz'"
        ) from None


def _print_property_averages(properties_traj):
    """Print mean and standard deviation of every non-empty property trajectory.

    Keys are expected as ``"Name[unit]"``; the part in brackets is printed as
    the unit column (empty when there are no brackets).
    """
    for key, values in properties_traj.items():
        if not values:
            continue
        mu = np.mean(values)
        sig = np.std(values)
        ksplit = key.split("[")
        name = ksplit[0].strip()
        unit = ksplit[1].replace("]", "").strip() if len(ksplit) > 1 else ""
        print(f"#   {name:10} : {mu: #10.5g}   +/- {sig: #9.3g}  {unit}")