fennol.md.dynamic

  1import sys, os, io
  2import argparse
  3import time
  4from pathlib import Path
  5
  6import numpy as np
  7from collections import defaultdict
  8import jax
  9import jax.numpy as jnp
 10import pickle
 11import yaml
 12
 13from ..utils.io import (
 14    write_arc_frame,
 15    write_xyz_frame,
 16    write_extxyz_frame,
 17    human_time_duration,
 18)
 19from .utils import wrapbox, save_dynamics_restart
 20from ..utils import minmaxone, AtomicUnits as au
 21from ..utils.input_parser import parse_input,convert_dict_units, InputFile
 22from .integrate import initialize_dynamics
 23
 24
 25def main():
 26    # os.environ["OMP_NUM_THREADS"] = "1"
 27    sys.stdout = io.TextIOWrapper(
 28        open(sys.stdout.fileno(), "wb", 0), write_through=True
 29    )
 30    ### Read the parameter file
 31    parser = argparse.ArgumentParser(prog="fennol_md")
 32    parser.add_argument("param_file", type=Path, help="Parameter file")
 33    args = parser.parse_args()
 34
 35    param_file = args.param_file
 36    if not param_file.exists() and not param_file.is_file():
 37        raise FileNotFoundError(f"Parameter file {param_file} not found")
 38    
 39    if param_file.suffix in [".yaml", ".yml"]:
 40        with open(param_file, "r") as f:
 41            simulation_parameters = convert_dict_units(yaml.safe_load(f))
 42            simulation_parameters = InputFile(**simulation_parameters)
 43    elif param_file.suffix == ".fnl":
 44        simulation_parameters = parse_input(args.param_file)
 45    else:
 46        raise ValueError(
 47            f"Unknown parameter file format '{param_file.suffix}'. Supported formats are '.yaml', '.yml' and '.fnl'"
 48        )
 49
 50    ### Set the device
 51    device: str = simulation_parameters.get("device", "cpu").lower()
 52    if device == "cpu":
 53        device = "cpu"
 54        os.environ["CUDA_VISIBLE_DEVICES"] = ""
 55    elif device.startswith("cuda") or device.startswith("gpu"):
 56        if ":" in device:
 57            num = device.split(":")[-1]
 58            os.environ["CUDA_VISIBLE_DEVICES"] = num
 59        else:
 60            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 61        device = "gpu"
 62
 63    _device = jax.devices(device)[0]
 64    jax.config.update("jax_default_device", _device)
 65
 66    ### Set the precision
 67    enable_x64 = simulation_parameters.get("double_precision", False)
 68    jax.config.update("jax_enable_x64", enable_x64)
 69    fprec = "float64" if enable_x64 else "float32"
 70
 71    matmul_precision = simulation_parameters.get("matmul_prec", "highest").lower()
 72    assert matmul_precision in [
 73        "default",
 74        "high",
 75        "highest",
 76    ], "matmul_prec must be one of 'default','high','highest'"
 77    if matmul_precision != "highest":
 78        print(f"# Setting matmul precision to '{matmul_precision}'")
 79    if matmul_precision == "default" and fprec == "float32":
 80        print(
 81            "# Warning: default matmul precision involves float16 operations which may lead to large numerical errors on energy and pressure estimations ! It is recommended to set matmul_prec to 'high' or 'highest'."
 82        )
 83    jax.config.update("jax_default_matmul_precision", matmul_precision)
 84
 85    # with jax.default_device(_device):
 86    dynamic(simulation_parameters, device, fprec)
 87
 88
 89def dynamic(simulation_parameters, device, fprec):
 90    tstart_dyn = time.time()
 91
 92    random_seed = simulation_parameters.get(
 93        "random_seed", np.random.randint(0, 2**32 - 1)
 94    )
 95    print(f"# random_seed: {random_seed}")
 96    rng_key = jax.random.PRNGKey(random_seed)
 97
 98    ## INITIALIZE INTEGRATOR AND SYSTEM
 99    rng_key, subkey = jax.random.split(rng_key)
100    step, update_conformation, system_data, dyn_state, conformation, system = (
101        initialize_dynamics(simulation_parameters, fprec, subkey)
102    )
103
104    nat = system_data["nat"]
105    dt = dyn_state["dt"]
106    ## get number of steps
107    nsteps = int(simulation_parameters.get("nsteps"))
108    start_time_ps = dyn_state.get("start_time_ps", 0.0)
109
110    ### Set I/O parameters
111    Tdump = simulation_parameters.get("tdump", 1.0 / au.PS) * au.FS
112    ndump = int(Tdump / dt)
113    system_name = system_data["name"]
114    estimate_pressure = dyn_state["estimate_pressure"]
115
116    @jax.jit
117    def check_nan(system):
118        return jnp.any(jnp.isnan(system["vel"])) | jnp.any(
119            jnp.isnan(system["coordinates"])
120        )
121
122    if system_data["pbc"] is not None:
123        cell = system["cell"]
124        reciprocal_cell = np.linalg.inv(cell)
125        do_wrap_box = simulation_parameters.get("wrap_box", False)
126    else:
127        cell = None
128        reciprocal_cell = None
129        do_wrap_box = False
130
131    ### Energy units and print initial energy
132    per_atom_energy = simulation_parameters.get("per_atom_energy", True)
133    energy_unit_str = system_data["energy_unit_str"]
134    energy_unit = system_data["energy_unit"]
135    print("# Energy unit: ", energy_unit_str)
136    atom_energy_unit = energy_unit
137    atom_energy_unit_str = energy_unit_str
138    if per_atom_energy:
139        atom_energy_unit /= nat
140        atom_energy_unit_str = f"{energy_unit_str}/atom"
141        print("# Printing Energy per atom")
142    print(
143        f"# Initial potential energy: {system['epot']*atom_energy_unit}; kinetic energy: {system['ek']*atom_energy_unit}"
144    )
145    f = system["forces"]
146    minmaxone(jnp.abs(f * energy_unit), "# forces min/max/rms:")
147
148    ## printing options
149    print_timings = simulation_parameters.get("print_timings", False)
150    nprint = int(simulation_parameters.get("nprint", 10))
151    assert nprint > 0, "nprint must be > 0"
152    nsummary = simulation_parameters.get("nsummary", 100 * nprint)
153    assert nsummary > nprint, "nsummary must be > nprint"
154
155    save_keys = simulation_parameters.get("save_keys", [])
156    if save_keys:
157        print(f"# Saving keys: {save_keys}")
158        fkeys = open(f"{system_name}.traj.pkl", "wb+")
159    else:
160        fkeys = None
161    
162    ### initialize colvars
163    use_colvars = "colvars" in dyn_state
164    if use_colvars:
165        print(f"# Colvars: {dyn_state['colvars']}")
166        colvars_names = dyn_state["colvars"]
167        # open colvars file and print header
168        fcolvars = open(f"{system_name}.colvars.traj", "a")
169        fcolvars.write("#time[ps] ")
170        fcolvars.write(" ".join(colvars_names))
171        fcolvars.write("\n")
172        fcolvars.flush()
173
174    ### Print header
175    include_thermostat_energy = "thermostat_energy" in system["thermostat"]
176    thermostat_name = dyn_state["thermostat_name"]
177    pimd = dyn_state["pimd"]
178    variable_cell = dyn_state["variable_cell"]
179    nbeads = system_data.get("nbeads", 1)
180    dyn_name = "PIMD" if pimd else "MD"
181    print("#" * 84)
182    print(
183        f"# Running {nsteps:_} steps of {thermostat_name} {dyn_name} simulation on {device}"
184    )
185    header = "#     Step   Time[ps]        Etot        Epot        Ekin    Temp[K]"
186    if pimd:
187        header += "  Temp_c[K]"
188    if include_thermostat_energy:
189        header += "      Etherm"
190    if estimate_pressure:
191        print_aniso_pressure = simulation_parameters.get("print_aniso_pressure", False)
192        pressure_unit_str = simulation_parameters.get("pressure_unit", "atm")
193        pressure_unit = au.get_multiplier(pressure_unit_str) * au.BOHR**3
194        header += f"  Press[{pressure_unit_str}]"
195    if variable_cell:
196        header += "   Density"
197    print(header)
198
199    ### Open trajectory file
200    traj_format = simulation_parameters.get("traj_format", "arc").lower()
201    if traj_format == "xyz":
202        traj_ext = ".traj.xyz"
203        write_frame = write_xyz_frame
204    elif traj_format == "extxyz":
205        traj_ext = ".traj.extxyz"
206        write_frame = write_extxyz_frame
207    elif traj_format == "arc":
208        traj_ext = ".arc"
209        write_frame = write_arc_frame
210    else:
211        raise ValueError(
212            f"Unknown trajectory format '{traj_format}'. Supported formats are 'arc' and 'xyz'"
213        )
214
215    write_all_beads = simulation_parameters.get("write_all_beads", False) and pimd
216
217    if write_all_beads:
218        fout = [
219            open(f"{system_name}_bead{i+1:03d}" + traj_ext, "a") for i in range(nbeads)
220        ]
221    else:
222        fout = open(system_name + traj_ext, "a")
223
224    ensemble_key = simulation_parameters.get("etot_ensemble_key", None)
225    if ensemble_key is not None:
226        fens = open(f"{system_name}.ensemble_weights.traj", "a")
227
228    write_centroid = simulation_parameters.get("write_centroid", False) and pimd
229    if write_centroid:
230        fcentroid = open(f"{system_name}_centroid" + traj_ext, "a")
231
232    ### initialize proprerty trajectories
233    properties_traj = defaultdict(list)
234
235    ### initialize counters and timers
236    t0 = time.time()
237    t0dump = t0
238    istep = 0
239    t0full = time.time()
240    force_preprocess = False
241
242    for istep in range(1, nsteps + 1):
243
244        ### update the system
245        dyn_state, system, conformation, model_output = step(
246            istep, dyn_state, system, conformation, force_preprocess
247        )
248
249        ### print properties
250        if istep % nprint == 0:
251            t1 = time.time()
252            tperstep = (t1 - t0) / nprint
253            t0 = t1
254            nsperday = (24 * 60 * 60 / tperstep) * dt / 1e6
255
256            ek = system["ek"]
257            epot = system["epot"]
258            etot = ek + epot
259            temper = 2 * ek / (3.0 * nat) * au.KELVIN
260
261            th_state = system["thermostat"]
262            if include_thermostat_energy:
263                etherm = th_state["thermostat_energy"]
264                etot = etot + etherm
265
266            properties_traj[f"Etot[{atom_energy_unit_str}]"].append(
267                etot * atom_energy_unit
268            )
269            properties_traj[f"Epot[{atom_energy_unit_str}]"].append(
270                epot * atom_energy_unit
271            )
272            properties_traj[f"Ekin[{atom_energy_unit_str}]"].append(
273                ek * atom_energy_unit
274            )
275            properties_traj["Temper[Kelvin]"].append(temper)
276            if pimd:
277                ek_c = system["ek_c"]
278                temper_c = 2 * ek_c / (3.0 * nat) * au.KELVIN
279                properties_traj["Temper_c[Kelvin]"].append(temper_c)
280
281            ### construct line of properties
282            simulated_time = start_time_ps + istep * dt / 1000
283            line = f"{istep:10.6g} {simulated_time: 10.3f}  {etot*atom_energy_unit: #10.4f}  {epot*atom_energy_unit: #10.4f}  {ek*atom_energy_unit: #10.4f} {temper: 10.2f}"
284            if pimd:
285                line += f" {temper_c: 10.2f}"
286            if include_thermostat_energy:
287                line += f"  {etherm*atom_energy_unit: #10.4f}"
288                properties_traj[f"Etherm[{atom_energy_unit_str}]"].append(
289                    etherm * atom_energy_unit
290                )
291            if estimate_pressure:
292                pres = system["pressure"] * pressure_unit
293                properties_traj[f"Pressure[{pressure_unit_str}]"].append(pres)
294                if print_aniso_pressure:
295                    pres_tensor = system["pressure_tensor"] * pressure_unit
296                    pres_tensor = 0.5 * (pres_tensor + pres_tensor.T)
297                    properties_traj[f"Pressure_xx[{pressure_unit_str}]"].append(
298                        pres_tensor[0, 0]
299                    )
300                    properties_traj[f"Pressure_yy[{pressure_unit_str}]"].append(
301                        pres_tensor[1, 1]
302                    )
303                    properties_traj[f"Pressure_zz[{pressure_unit_str}]"].append(
304                        pres_tensor[2, 2]
305                    )
306                    properties_traj[f"Pressure_xy[{pressure_unit_str}]"].append(
307                        pres_tensor[0, 1]
308                    )
309                    properties_traj[f"Pressure_xz[{pressure_unit_str}]"].append(
310                        pres_tensor[0, 2]
311                    )
312                    properties_traj[f"Pressure_yz[{pressure_unit_str}]"].append(
313                        pres_tensor[1, 2]
314                    )
315                line += f" {pres:10.3f}"
316            if variable_cell:
317                density = system["density"]
318                properties_traj["Density[g/cm^3]"].append(density)
319                if print_aniso_pressure:
320                    cell = system["cell"]
321                    properties_traj[f"Cell_Ax[Angstrom]"].append(cell[0, 0])
322                    properties_traj[f"Cell_Ay[Angstrom]"].append(cell[0, 1])
323                    properties_traj[f"Cell_Az[Angstrom]"].append(cell[0, 2])
324                    properties_traj[f"Cell_Bx[Angstrom]"].append(cell[1, 0])
325                    properties_traj[f"Cell_By[Angstrom]"].append(cell[1, 1])
326                    properties_traj[f"Cell_Bz[Angstrom]"].append(cell[1, 2])
327                    properties_traj[f"Cell_Cx[Angstrom]"].append(cell[2, 0])
328                    properties_traj[f"Cell_Cy[Angstrom]"].append(cell[2, 1])
329                    properties_traj[f"Cell_Cz[Angstrom]"].append(cell[2, 2])
330                line += f" {density:10.4f}"
331                if "piston_temperature" in system["barostat"]:
332                    piston_temperature = system["barostat"]["piston_temperature"]
333                    properties_traj["T_piston[Kelvin]"].append(piston_temperature)
334
335            print(line)
336
337            ### save colvars
338            if use_colvars:
339                colvars = system["colvars"]
340                fcolvars.write(f"{simulated_time:.3f} ")
341                fcolvars.write(" ".join([f"{colvars[k]:.6f}" for k in colvars_names]))
342                fcolvars.write("\n")
343                fcolvars.flush()
344
345            if save_keys:
346                data = {
347                    k: (
348                        np.asarray(model_output[k])
349                        if isinstance(model_output[k], jnp.ndarray)
350                        else model_output[k]
351                    )
352                    for k in save_keys
353                }
354                data["properties"] = {
355                    k: float(v) for k, v in zip(header.split()[1:], line.split())
356                }
357                pickle.dump(data, fkeys)
358
359        ### save frame
360        if istep % ndump == 0:
361            line = "# Write XYZ frame"
362            if variable_cell:
363                cell = np.array(system["cell"])
364                reciprocal_cell = np.linalg.inv(cell)
365            if do_wrap_box:
366                if pimd:
367                    centroid = wrapbox(system["coordinates"][0], cell, reciprocal_cell)
368                    system["coordinates"] = system["coordinates"].at[0].set(centroid)
369                else:
370                    system["coordinates"] = wrapbox(
371                        system["coordinates"], cell, reciprocal_cell
372                    )
373                conformation = update_conformation(conformation, system)
374                line += " (atoms have been wrapped into the box)"
375                force_preprocess = True
376            print(line)
377
378            save_dynamics_restart(system_data, conformation, dyn_state, system)
379
380            properties = {
381                "energy": float(system["epot"]) * energy_unit,
382                "Time_ps": start_time_ps + istep * dt / 1000,
383                "energy_unit": energy_unit_str,
384            }
385
386            if write_all_beads:
387                coords = np.asarray(conformation["coordinates"].reshape(-1, nat, 3))
388                for i, fb in enumerate(fout):
389                    write_frame(
390                        fb,
391                        system_data["symbols"],
392                        coords[i],
393                        cell=cell,
394                        properties=properties,
395                        forces=None,  # np.asarray(system["forces"].reshape(nbeads, nat, 3)[0]) * energy_unit,
396                    )
397            else:
398                write_frame(
399                    fout,
400                    system_data["symbols"],
401                    np.asarray(conformation["coordinates"].reshape(-1, nat, 3)[0]),
402                    cell=cell,
403                    properties=properties,
404                    forces=None,  # np.asarray(system["forces"].reshape(nbeads, nat, 3)[0]) * energy_unit,
405                )
406            if write_centroid:
407                centroid = np.asarray(system["coordinates"][0])
408                write_frame(
409                    fcentroid,
410                    system_data["symbols"],
411                    centroid,
412                    cell=cell,
413                    properties=properties,
414                    forces=np.asarray(system["forces"].reshape(nbeads, nat, 3)[0])
415                    * energy_unit,
416                )
417            if ensemble_key is not None:
418                weights = " ".join(
419                    [f"{w:.6f}" for w in system["ensemble_weights"].tolist()]
420                )
421                fens.write(f"{weights}\n")
422                fens.flush()
423
424        ### summary over last nsummary steps
425        if istep % (nsummary) == 0:
426            if check_nan(system):
427                raise ValueError(f"dynamics crashed at step {istep}.")
428            tfull = time.time() - t0full
429            t0full = time.time()
430            tperstep = tfull / (nsummary)
431            nsperday = (24 * 60 * 60 / tperstep) * dt / 1e6
432            elapsed_time = time.time() - tstart_dyn
433            estimated_remaining_time = tperstep * (nsteps - istep)
434            estimated_total_time = elapsed_time + estimated_remaining_time
435
436            print("#" * 50)
437            print(f"# Step {istep:_} of {nsteps:_}  ({istep/nsteps*100:.5g} %)")
438            print(f"# Simulated time      : {istep * dt*1.e-3:.3f} ps")
439            print(f"# Tot. Simu. time     : {start_time_ps + istep * dt*1.e-3:.3f} ps")
440            print(f"# Tot. elapsed time   : {human_time_duration(elapsed_time)}")
441            print(
442                f"# Est. total duration   : {human_time_duration(estimated_total_time)}"
443            )
444            print(
445                f"# Est. remaining time : {human_time_duration(estimated_remaining_time)}"
446            )
447            print(f"# Time for {nsummary:_} steps : {human_time_duration(tfull)}")
448
449            corr_kin = system["thermostat"].get("corr_kin", None)
450            if corr_kin is not None:
451                print(f"# QTB kin. correction : {100*(corr_kin-1.):.2f} %")
452            print(f"# Averages over last {nsummary:_} steps :")
453            for k, v in properties_traj.items():
454                if len(properties_traj[k]) == 0:
455                    continue
456                mu = np.mean(properties_traj[k])
457                sig = np.std(properties_traj[k])
458                ksplit = k.split("[")
459                name = ksplit[0].strip()
460                unit = ksplit[1].replace("]", "").strip() if len(ksplit) > 1 else ""
461                print(f"#   {name:10} : {mu: #10.5g}   +/- {sig: #9.3g}  {unit}")
462
463            print(f"# Perf.: {nsperday:.2f} ns/day  ( {1.0 / tperstep:.2f} step/s )")
464            print("#" * 50)
465            if istep < nsteps:
466                print(header)
467            ## reset property trajectories
468            properties_traj = defaultdict(list)
469
470    print(f"# Run done in {human_time_duration(time.time()-tstart_dyn)}")
471    ### close trajectory file
472    fout.close()
473    if ensemble_key is not None:
474        fens.close()
475    if use_colvars:
476        fcolvars.close()
477    if fkeys is not None:
478        fkeys.close()
479    if write_centroid:
480        fcentroid.close()
481
482
# Script entry point: call main() only when this module is executed directly.
483if __name__ == "__main__":
484    main()
def main():
    """Command-line entry point of ``fennol_md``.

    Parses the single positional ``param_file`` argument, loads the simulation
    parameters from it (``.yaml``/``.yml`` through ``yaml.safe_load`` +
    ``convert_dict_units`` + ``InputFile``, or ``.fnl`` through
    ``parse_input``), configures the JAX device and numerical precision, and
    finally calls ``dynamic``.

    Raises:
        FileNotFoundError: if ``param_file`` is not an existing regular file.
        ValueError: if the file suffix is not supported, or if ``matmul_prec``
            is not one of 'default', 'high', 'highest'.
    """
    # os.environ["OMP_NUM_THREADS"] = "1"
    # Rebind stdout to an unbuffered binary stream (buffering=0) wrapped in
    # write-through text mode, so every print reaches the log immediately.
    sys.stdout = io.TextIOWrapper(
        open(sys.stdout.fileno(), "wb", 0), write_through=True
    )

    ### Read the parameter file
    parser = argparse.ArgumentParser(prog="fennol_md")
    parser.add_argument("param_file", type=Path, help="Parameter file")
    args = parser.parse_args()

    param_file = args.param_file
    # is_file() is False both for a missing path and for an existing
    # non-file (e.g. a directory), so this single test covers both cases.
    if not param_file.is_file():
        raise FileNotFoundError(f"Parameter file {param_file} not found")

    if param_file.suffix in [".yaml", ".yml"]:
        with open(param_file, "r") as f:
            simulation_parameters = convert_dict_units(yaml.safe_load(f))
            simulation_parameters = InputFile(**simulation_parameters)
    elif param_file.suffix == ".fnl":
        simulation_parameters = parse_input(args.param_file)
    else:
        raise ValueError(
            f"Unknown parameter file format '{param_file.suffix}'. Supported formats are '.yaml', '.yml' and '.fnl'"
        )

    ### Set the device
    device: str = simulation_parameters.get("device", "cpu").lower()
    if device == "cpu":
        # Empty CUDA_VISIBLE_DEVICES: no CUDA device is exposed to the process.
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    elif device.startswith(("cuda", "gpu")):
        # "cuda:N" / "gpu:N" exposes device N only; without an index, device 0.
        os.environ["CUDA_VISIBLE_DEVICES"] = (
            device.split(":")[-1] if ":" in device else "0"
        )
        device = "gpu"
    # Any other value is handed to jax.devices() unchanged.

    _device = jax.devices(device)[0]
    jax.config.update("jax_default_device", _device)

    ### Set the precision
    enable_x64 = simulation_parameters.get("double_precision", False)
    jax.config.update("jax_enable_x64", enable_x64)
    fprec = "float64" if enable_x64 else "float32"

    matmul_precision = simulation_parameters.get("matmul_prec", "highest").lower()
    # Explicit raise rather than assert: asserts vanish under `python -O`,
    # which would let an invalid user value through.
    if matmul_precision not in ("default", "high", "highest"):
        raise ValueError("matmul_prec must be one of 'default','high','highest'")
    if matmul_precision != "highest":
        print(f"# Setting matmul precision to '{matmul_precision}'")
    if matmul_precision == "default" and fprec == "float32":
        print(
            "# Warning: default matmul precision involves float16 operations which may lead to large numerical errors on energy and pressure estimations ! It is recommended to set matmul_prec to 'high' or 'highest'."
        )
    jax.config.update("jax_default_matmul_precision", matmul_precision)

    # with jax.default_device(_device):
    dynamic(simulation_parameters, device, fprec)
def dynamic(simulation_parameters, device, fprec):
 90def dynamic(simulation_parameters, device, fprec):
 91    tstart_dyn = time.time()
 92
 93    random_seed = simulation_parameters.get(
 94        "random_seed", np.random.randint(0, 2**32 - 1)
 95    )
 96    print(f"# random_seed: {random_seed}")
 97    rng_key = jax.random.PRNGKey(random_seed)
 98
 99    ## INITIALIZE INTEGRATOR AND SYSTEM
100    rng_key, subkey = jax.random.split(rng_key)
101    step, update_conformation, system_data, dyn_state, conformation, system = (
102        initialize_dynamics(simulation_parameters, fprec, subkey)
103    )
104
105    nat = system_data["nat"]
106    dt = dyn_state["dt"]
107    ## get number of steps
108    nsteps = int(simulation_parameters.get("nsteps"))
109    start_time_ps = dyn_state.get("start_time_ps", 0.0)
110
111    ### Set I/O parameters
112    Tdump = simulation_parameters.get("tdump", 1.0 / au.PS) * au.FS
113    ndump = int(Tdump / dt)
114    system_name = system_data["name"]
115    estimate_pressure = dyn_state["estimate_pressure"]
116
117    @jax.jit
118    def check_nan(system):
119        return jnp.any(jnp.isnan(system["vel"])) | jnp.any(
120            jnp.isnan(system["coordinates"])
121        )
122
123    if system_data["pbc"] is not None:
124        cell = system["cell"]
125        reciprocal_cell = np.linalg.inv(cell)
126        do_wrap_box = simulation_parameters.get("wrap_box", False)
127    else:
128        cell = None
129        reciprocal_cell = None
130        do_wrap_box = False
131
132    ### Energy units and print initial energy
133    per_atom_energy = simulation_parameters.get("per_atom_energy", True)
134    energy_unit_str = system_data["energy_unit_str"]
135    energy_unit = system_data["energy_unit"]
136    print("# Energy unit: ", energy_unit_str)
137    atom_energy_unit = energy_unit
138    atom_energy_unit_str = energy_unit_str
139    if per_atom_energy:
140        atom_energy_unit /= nat
141        atom_energy_unit_str = f"{energy_unit_str}/atom"
142        print("# Printing Energy per atom")
143    print(
144        f"# Initial potential energy: {system['epot']*atom_energy_unit}; kinetic energy: {system['ek']*atom_energy_unit}"
145    )
146    f = system["forces"]
147    minmaxone(jnp.abs(f * energy_unit), "# forces min/max/rms:")
148
149    ## printing options
150    print_timings = simulation_parameters.get("print_timings", False)
151    nprint = int(simulation_parameters.get("nprint", 10))
152    assert nprint > 0, "nprint must be > 0"
153    nsummary = simulation_parameters.get("nsummary", 100 * nprint)
154    assert nsummary > nprint, "nsummary must be > nprint"
155
156    save_keys = simulation_parameters.get("save_keys", [])
157    if save_keys:
158        print(f"# Saving keys: {save_keys}")
159        fkeys = open(f"{system_name}.traj.pkl", "wb+")
160    else:
161        fkeys = None
162    
163    ### initialize colvars
164    use_colvars = "colvars" in dyn_state
165    if use_colvars:
166        print(f"# Colvars: {dyn_state['colvars']}")
167        colvars_names = dyn_state["colvars"]
168        # open colvars file and print header
169        fcolvars = open(f"{system_name}.colvars.traj", "a")
170        fcolvars.write("#time[ps] ")
171        fcolvars.write(" ".join(colvars_names))
172        fcolvars.write("\n")
173        fcolvars.flush()
174
175    ### Print header
176    include_thermostat_energy = "thermostat_energy" in system["thermostat"]
177    thermostat_name = dyn_state["thermostat_name"]
178    pimd = dyn_state["pimd"]
179    variable_cell = dyn_state["variable_cell"]
180    nbeads = system_data.get("nbeads", 1)
181    dyn_name = "PIMD" if pimd else "MD"
182    print("#" * 84)
183    print(
184        f"# Running {nsteps:_} steps of {thermostat_name} {dyn_name} simulation on {device}"
185    )
186    header = "#     Step   Time[ps]        Etot        Epot        Ekin    Temp[K]"
187    if pimd:
188        header += "  Temp_c[K]"
189    if include_thermostat_energy:
190        header += "      Etherm"
191    if estimate_pressure:
192        print_aniso_pressure = simulation_parameters.get("print_aniso_pressure", False)
193        pressure_unit_str = simulation_parameters.get("pressure_unit", "atm")
194        pressure_unit = au.get_multiplier(pressure_unit_str) * au.BOHR**3
195        header += f"  Press[{pressure_unit_str}]"
196    if variable_cell:
197        header += "   Density"
198    print(header)
199
200    ### Open trajectory file
201    traj_format = simulation_parameters.get("traj_format", "arc").lower()
202    if traj_format == "xyz":
203        traj_ext = ".traj.xyz"
204        write_frame = write_xyz_frame
205    elif traj_format == "extxyz":
206        traj_ext = ".traj.extxyz"
207        write_frame = write_extxyz_frame
208    elif traj_format == "arc":
209        traj_ext = ".arc"
210        write_frame = write_arc_frame
211    else:
212        raise ValueError(
213            f"Unknown trajectory format '{traj_format}'. Supported formats are 'arc' and 'xyz'"
214        )
215
216    write_all_beads = simulation_parameters.get("write_all_beads", False) and pimd
217
218    if write_all_beads:
219        fout = [
220            open(f"{system_name}_bead{i+1:03d}" + traj_ext, "a") for i in range(nbeads)
221        ]
222    else:
223        fout = open(system_name + traj_ext, "a")
224
225    ensemble_key = simulation_parameters.get("etot_ensemble_key", None)
226    if ensemble_key is not None:
227        fens = open(f"{system_name}.ensemble_weights.traj", "a")
228
229    write_centroid = simulation_parameters.get("write_centroid", False) and pimd
230    if write_centroid:
231        fcentroid = open(f"{system_name}_centroid" + traj_ext, "a")
232
233    ### initialize proprerty trajectories
234    properties_traj = defaultdict(list)
235
236    ### initialize counters and timers
237    t0 = time.time()
238    t0dump = t0
239    istep = 0
240    t0full = time.time()
241    force_preprocess = False
242
243    for istep in range(1, nsteps + 1):
244
245        ### update the system
246        dyn_state, system, conformation, model_output = step(
247            istep, dyn_state, system, conformation, force_preprocess
248        )
249
250        ### print properties
251        if istep % nprint == 0:
252            t1 = time.time()
253            tperstep = (t1 - t0) / nprint
254            t0 = t1
255            nsperday = (24 * 60 * 60 / tperstep) * dt / 1e6
256
257            ek = system["ek"]
258            epot = system["epot"]
259            etot = ek + epot
260            temper = 2 * ek / (3.0 * nat) * au.KELVIN
261
262            th_state = system["thermostat"]
263            if include_thermostat_energy:
264                etherm = th_state["thermostat_energy"]
265                etot = etot + etherm
266
267            properties_traj[f"Etot[{atom_energy_unit_str}]"].append(
268                etot * atom_energy_unit
269            )
270            properties_traj[f"Epot[{atom_energy_unit_str}]"].append(
271                epot * atom_energy_unit
272            )
273            properties_traj[f"Ekin[{atom_energy_unit_str}]"].append(
274                ek * atom_energy_unit
275            )
276            properties_traj["Temper[Kelvin]"].append(temper)
277            if pimd:
278                ek_c = system["ek_c"]
279                temper_c = 2 * ek_c / (3.0 * nat) * au.KELVIN
280                properties_traj["Temper_c[Kelvin]"].append(temper_c)
281
282            ### construct line of properties
283            simulated_time = start_time_ps + istep * dt / 1000
284            line = f"{istep:10.6g} {simulated_time: 10.3f}  {etot*atom_energy_unit: #10.4f}  {epot*atom_energy_unit: #10.4f}  {ek*atom_energy_unit: #10.4f} {temper: 10.2f}"
285            if pimd:
286                line += f" {temper_c: 10.2f}"
287            if include_thermostat_energy:
288                line += f"  {etherm*atom_energy_unit: #10.4f}"
289                properties_traj[f"Etherm[{atom_energy_unit_str}]"].append(
290                    etherm * atom_energy_unit
291                )
292            if estimate_pressure:
293                pres = system["pressure"] * pressure_unit
294                properties_traj[f"Pressure[{pressure_unit_str}]"].append(pres)
295                if print_aniso_pressure:
296                    pres_tensor = system["pressure_tensor"] * pressure_unit
297                    pres_tensor = 0.5 * (pres_tensor + pres_tensor.T)
298                    properties_traj[f"Pressure_xx[{pressure_unit_str}]"].append(
299                        pres_tensor[0, 0]
300                    )
301                    properties_traj[f"Pressure_yy[{pressure_unit_str}]"].append(
302                        pres_tensor[1, 1]
303                    )
304                    properties_traj[f"Pressure_zz[{pressure_unit_str}]"].append(
305                        pres_tensor[2, 2]
306                    )
307                    properties_traj[f"Pressure_xy[{pressure_unit_str}]"].append(
308                        pres_tensor[0, 1]
309                    )
310                    properties_traj[f"Pressure_xz[{pressure_unit_str}]"].append(
311                        pres_tensor[0, 2]
312                    )
313                    properties_traj[f"Pressure_yz[{pressure_unit_str}]"].append(
314                        pres_tensor[1, 2]
315                    )
316                line += f" {pres:10.3f}"
317            if variable_cell:
318                density = system["density"]
319                properties_traj["Density[g/cm^3]"].append(density)
320                if print_aniso_pressure:
321                    cell = system["cell"]
322                    properties_traj[f"Cell_Ax[Angstrom]"].append(cell[0, 0])
323                    properties_traj[f"Cell_Ay[Angstrom]"].append(cell[0, 1])
324                    properties_traj[f"Cell_Az[Angstrom]"].append(cell[0, 2])
325                    properties_traj[f"Cell_Bx[Angstrom]"].append(cell[1, 0])
326                    properties_traj[f"Cell_By[Angstrom]"].append(cell[1, 1])
327                    properties_traj[f"Cell_Bz[Angstrom]"].append(cell[1, 2])
328                    properties_traj[f"Cell_Cx[Angstrom]"].append(cell[2, 0])
329                    properties_traj[f"Cell_Cy[Angstrom]"].append(cell[2, 1])
330                    properties_traj[f"Cell_Cz[Angstrom]"].append(cell[2, 2])
331                line += f" {density:10.4f}"
332                if "piston_temperature" in system["barostat"]:
333                    piston_temperature = system["barostat"]["piston_temperature"]
334                    properties_traj["T_piston[Kelvin]"].append(piston_temperature)
335
336            print(line)
337
338            ### save colvars
339            if use_colvars:
340                colvars = system["colvars"]
341                fcolvars.write(f"{simulated_time:.3f} ")
342                fcolvars.write(" ".join([f"{colvars[k]:.6f}" for k in colvars_names]))
343                fcolvars.write("\n")
344                fcolvars.flush()
345
346            if save_keys:
347                data = {
348                    k: (
349                        np.asarray(model_output[k])
350                        if isinstance(model_output[k], jnp.ndarray)
351                        else model_output[k]
352                    )
353                    for k in save_keys
354                }
355                data["properties"] = {
356                    k: float(v) for k, v in zip(header.split()[1:], line.split())
357                }
358                pickle.dump(data, fkeys)
359
360        ### save frame
361        if istep % ndump == 0:
362            line = "# Write XYZ frame"
363            if variable_cell:
364                cell = np.array(system["cell"])
365                reciprocal_cell = np.linalg.inv(cell)
366            if do_wrap_box:
367                if pimd:
368                    centroid = wrapbox(system["coordinates"][0], cell, reciprocal_cell)
369                    system["coordinates"] = system["coordinates"].at[0].set(centroid)
370                else:
371                    system["coordinates"] = wrapbox(
372                        system["coordinates"], cell, reciprocal_cell
373                    )
374                conformation = update_conformation(conformation, system)
375                line += " (atoms have been wrapped into the box)"
376                force_preprocess = True
377            print(line)
378
379            save_dynamics_restart(system_data, conformation, dyn_state, system)
380
381            properties = {
382                "energy": float(system["epot"]) * energy_unit,
383                "Time_ps": start_time_ps + istep * dt / 1000,
384                "energy_unit": energy_unit_str,
385            }
386
387            if write_all_beads:
388                coords = np.asarray(conformation["coordinates"].reshape(-1, nat, 3))
389                for i, fb in enumerate(fout):
390                    write_frame(
391                        fb,
392                        system_data["symbols"],
393                        coords[i],
394                        cell=cell,
395                        properties=properties,
396                        forces=None,  # np.asarray(system["forces"].reshape(nbeads, nat, 3)[0]) * energy_unit,
397                    )
398            else:
399                write_frame(
400                    fout,
401                    system_data["symbols"],
402                    np.asarray(conformation["coordinates"].reshape(-1, nat, 3)[0]),
403                    cell=cell,
404                    properties=properties,
405                    forces=None,  # np.asarray(system["forces"].reshape(nbeads, nat, 3)[0]) * energy_unit,
406                )
407            if write_centroid:
408                centroid = np.asarray(system["coordinates"][0])
409                write_frame(
410                    fcentroid,
411                    system_data["symbols"],
412                    centroid,
413                    cell=cell,
414                    properties=properties,
415                    forces=np.asarray(system["forces"].reshape(nbeads, nat, 3)[0])
416                    * energy_unit,
417                )
418            if ensemble_key is not None:
419                weights = " ".join(
420                    [f"{w:.6f}" for w in system["ensemble_weights"].tolist()]
421                )
422                fens.write(f"{weights}\n")
423                fens.flush()
424
425        ### summary over last nsummary steps
426        if istep % (nsummary) == 0:
427            if check_nan(system):
428                raise ValueError(f"dynamics crashed at step {istep}.")
429            tfull = time.time() - t0full
430            t0full = time.time()
431            tperstep = tfull / (nsummary)
432            nsperday = (24 * 60 * 60 / tperstep) * dt / 1e6
433            elapsed_time = time.time() - tstart_dyn
434            estimated_remaining_time = tperstep * (nsteps - istep)
435            estimated_total_time = elapsed_time + estimated_remaining_time
436
437            print("#" * 50)
438            print(f"# Step {istep:_} of {nsteps:_}  ({istep/nsteps*100:.5g} %)")
439            print(f"# Simulated time      : {istep * dt*1.e-3:.3f} ps")
440            print(f"# Tot. Simu. time     : {start_time_ps + istep * dt*1.e-3:.3f} ps")
441            print(f"# Tot. elapsed time   : {human_time_duration(elapsed_time)}")
442            print(
443                f"# Est. total duration   : {human_time_duration(estimated_total_time)}"
444            )
445            print(
446                f"# Est. remaining time : {human_time_duration(estimated_remaining_time)}"
447            )
448            print(f"# Time for {nsummary:_} steps : {human_time_duration(tfull)}")
449
450            corr_kin = system["thermostat"].get("corr_kin", None)
451            if corr_kin is not None:
452                print(f"# QTB kin. correction : {100*(corr_kin-1.):.2f} %")
453            print(f"# Averages over last {nsummary:_} steps :")
454            for k, v in properties_traj.items():
455                if len(properties_traj[k]) == 0:
456                    continue
457                mu = np.mean(properties_traj[k])
458                sig = np.std(properties_traj[k])
459                ksplit = k.split("[")
460                name = ksplit[0].strip()
461                unit = ksplit[1].replace("]", "").strip() if len(ksplit) > 1 else ""
462                print(f"#   {name:10} : {mu: #10.5g}   +/- {sig: #9.3g}  {unit}")
463
464            print(f"# Perf.: {nsperday:.2f} ns/day  ( {1.0 / tperstep:.2f} step/s )")
465            print("#" * 50)
466            if istep < nsteps:
467                print(header)
468            ## reset property trajectories
469            properties_traj = defaultdict(list)
470
471    print(f"# Run done in {human_time_duration(time.time()-tstart_dyn)}")
472    ### close trajectory file
473    fout.close()
474    if ensemble_key is not None:
475        fens.close()
476    if use_colvars:
477        fcolvars.close()
478    if fkeys is not None:
479        fkeys.close()
480    if write_centroid:
481        fcentroid.close()