fennol.md.dynamic

  1import sys, os, io
  2import argparse
  3import time
  4from pathlib import Path
  5
  6import numpy as np
  7from collections import defaultdict
  8import jax
  9import jax.numpy as jnp
 10import pickle
 11import yaml
 12
 13from ..utils.io import (
 14    write_arc_frame,
 15    write_xyz_frame,
 16    write_extxyz_frame,
 17    human_time_duration,
 18)
 19from .utils import wrapbox, save_dynamics_restart
 20from ..utils import minmaxone, AtomicUnits as au,read_tinker_interval
 21from ..utils.input_parser import parse_input,convert_dict_units, InputFile
 22from .integrate import initialize_dynamics
 23
 24
 25def main():
 26    # os.environ["OMP_NUM_THREADS"] = "1"
 27    sys.stdout = io.TextIOWrapper(
 28        open(sys.stdout.fileno(), "wb", 0), write_through=True
 29    )
 30    ### Read the parameter file
 31    parser = argparse.ArgumentParser(prog="fennol_md")
 32    parser.add_argument("param_file", type=Path, help="Parameter file")
 33    args = parser.parse_args()
 34
 35    param_file = args.param_file
 36    if not param_file.exists() and not param_file.is_file():
 37        raise FileNotFoundError(f"Parameter file {param_file} not found")
 38    
 39    if param_file.suffix in [".yaml", ".yml"]:
 40        with open(param_file, "r") as f:
 41            simulation_parameters = convert_dict_units(yaml.safe_load(f))
 42            simulation_parameters = InputFile(**simulation_parameters)
 43    elif param_file.suffix == ".fnl":
 44        simulation_parameters = parse_input(args.param_file)
 45    else:
 46        raise ValueError(
 47            f"Unknown parameter file format '{param_file.suffix}'. Supported formats are '.yaml', '.yml' and '.fnl'"
 48        )
 49
 50    ### Set the device
 51    if "FENNOL_DEVICE" in os.environ:
 52        device = os.environ["FENNOL_DEVICE"].lower()
 53        print(f"# Setting device from env FENNOL_DEVICE={device}")
 54    else:
 55        device = simulation_parameters.get("device", "cpu").lower()
 56    if device == "cpu":
 57        jax.config.update('jax_platforms', 'cpu')
 58        os.environ["CUDA_VISIBLE_DEVICES"] = ""
 59    elif device.startswith("cuda") or device.startswith("gpu"):
 60        if ":" in device:
 61            num = device.split(":")[-1]
 62            os.environ["CUDA_VISIBLE_DEVICES"] = num
 63        else:
 64            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 65        device = "gpu"
 66
 67    _device = jax.devices(device)[0]
 68    jax.config.update("jax_default_device", _device)
 69
 70    ### Set the precision
 71    enable_x64 = simulation_parameters.get("double_precision", False)
 72    jax.config.update("jax_enable_x64", enable_x64)
 73    fprec = "float64" if enable_x64 else "float32"
 74
 75    matmul_precision = simulation_parameters.get("matmul_prec", "highest").lower()
 76    assert matmul_precision in [
 77        "default",
 78        "high",
 79        "highest",
 80    ], "matmul_prec must be one of 'default','high','highest'"
 81    if matmul_precision != "highest":
 82        print(f"# Setting matmul precision to '{matmul_precision}'")
 83    if matmul_precision == "default" and fprec == "float32":
 84        print(
 85            "# Warning: default matmul precision involves float16 operations which may lead to large numerical errors on energy and pressure estimations ! It is recommended to set matmul_prec to 'high' or 'highest'."
 86        )
 87    jax.config.update("jax_default_matmul_precision", matmul_precision)
 88
 89    # with jax.default_device(_device):
 90    dynamic(simulation_parameters, device, fprec)
 91
 92
 93def dynamic(simulation_parameters, device, fprec):
 94    tstart_dyn = time.time()
 95
 96    random_seed = simulation_parameters.get(
 97        "random_seed", np.random.randint(0, 2**32 - 1)
 98    )
 99    print(f"# random_seed: {random_seed}")
100    rng_key = jax.random.PRNGKey(random_seed)
101
102    ## INITIALIZE INTEGRATOR AND SYSTEM
103    rng_key, subkey = jax.random.split(rng_key)
104    step, update_conformation, system_data, dyn_state, conformation, system = (
105        initialize_dynamics(simulation_parameters, fprec, subkey)
106    )
107
108    nat = system_data["nat"]
109    dt = dyn_state["dt"]
110    ## get number of steps
111    nsteps = int(simulation_parameters.get("nsteps"))
112    start_time_ps = dyn_state.get("start_time_ps", 0.0)
113
114    ### Set I/O parameters
115    Tdump = simulation_parameters.get("tdump", 1.0 / au.PS) * au.FS
116    ndump = int(Tdump / dt)
117    system_name = system_data["name"]
118    estimate_pressure = dyn_state["estimate_pressure"]
119
120    @jax.jit
121    def check_nan(system):
122        return jnp.any(jnp.isnan(system["vel"])) | jnp.any(
123            jnp.isnan(system["coordinates"])
124        )
125
126    if system_data["pbc"] is not None:
127        cell = system["cell"]
128        reciprocal_cell = np.linalg.inv(cell)
129        do_wrap_box = simulation_parameters.get("wrap_box", False)
130        if do_wrap_box:
131            wrap_groups_def = simulation_parameters.get("wrap_groups",None)
132            if wrap_groups_def is None:
133                wrap_groups = None
134            else:
135                wrap_groups = {}
136                assert isinstance(wrap_groups_def, dict), "wrap_groups must be a dictionary"
137                for k, v in wrap_groups_def.items():
138                    wrap_groups[k]=read_tinker_interval(v)
139                # check that pairwise intersection of wrap groups is empty
140                wrap_groups_keys = list(wrap_groups.keys())
141                for i in range(len(wrap_groups_keys)):
142                    i_key = wrap_groups_keys[i]
143                    w1 = set(wrap_groups[i_key])
144                    for j in range(i + 1, len(wrap_groups_keys)):
145                        j_key = wrap_groups_keys[j]
146                        w2 = set(wrap_groups[j_key])
147                        if  w1.intersection(w2):
148                            raise ValueError(
149                                f"Wrap groups {i_key} and {j_key} have common atoms: {w1.intersection(w2)}"
150                            )
151                group_all = np.concatenate(list(wrap_groups.values()))
152                # get all atoms that are not in any wrap group
153                group_none = np.setdiff1d(np.arange(nat), group_all)
154                print(f"# Wrap groups: {wrap_groups}")
155                wrap_groups["__other"] = group_none
156                wrap_groups = ((k, v) for k, v in wrap_groups.items())
157
158    else:
159        cell = None
160        reciprocal_cell = None
161        do_wrap_box = False
162        wrap_groups = None
163
164    ### Energy units and print initial energy
165    model_energy_unit = system_data["model_energy_unit"]
166    model_energy_unit_str = system_data["model_energy_unit_str"]
167    per_atom_energy = simulation_parameters.get("per_atom_energy", True)
168    energy_unit_str = system_data["energy_unit_str"]
169    energy_unit = system_data["energy_unit"]
170    print("# Energy unit: ", energy_unit_str)
171    atom_energy_unit = energy_unit
172    atom_energy_unit_str = energy_unit_str
173    if per_atom_energy:
174        atom_energy_unit /= nat
175        atom_energy_unit_str = f"{energy_unit_str}/atom"
176        print("# Printing Energy per atom")
177    print(
178        f"# Initial potential energy: {system['epot']*atom_energy_unit}; kinetic energy: {system['ek']*atom_energy_unit}"
179    )
180    f = system["forces"]
181    minmaxone(jnp.abs(f * energy_unit), "# forces min/max/rms:")
182
183    ## printing options
184    print_timings = simulation_parameters.get("print_timings", False)
185    nprint = int(simulation_parameters.get("nprint", 10))
186    assert nprint > 0, "nprint must be > 0"
187    nsummary = simulation_parameters.get("nsummary", 100 * nprint)
188    assert nsummary > nprint, "nsummary must be > nprint"
189
190    save_keys = simulation_parameters.get("save_keys", [])
191    if save_keys:
192        print(f"# Saving keys: {save_keys}")
193        fkeys = open(f"{system_name}.traj.pkl", "wb+")
194    else:
195        fkeys = None
196    
197    ### initialize colvars
198    use_colvars = "colvars" in dyn_state
199    if use_colvars:
200        print(f"# Colvars: {dyn_state['colvars']}")
201        colvars_names = dyn_state["colvars"]
202        # open colvars file and print header
203        fcolvars = open(f"{system_name}.colvars.traj", "a")
204        fcolvars.write("#time[ps] ")
205        fcolvars.write(" ".join(colvars_names))
206        fcolvars.write("\n")
207        fcolvars.flush()
208
209    ### Print header
210    include_thermostat_energy = "thermostat_energy" in system["thermostat"]
211    thermostat_name = dyn_state["thermostat_name"]
212    pimd = dyn_state["pimd"]
213    variable_cell = dyn_state["variable_cell"]
214    nbeads = system_data.get("nbeads", 1)
215    dyn_name = "PIMD" if pimd else "MD"
216    print("#" * 84)
217    print(
218        f"# Running {nsteps:_} steps of {thermostat_name} {dyn_name} simulation on {device}"
219    )
220    header = "#     Step   Time[ps]        Etot        Epot        Ekin    Temp[K]"
221    if pimd:
222        header += "  Temp_c[K]"
223    if include_thermostat_energy:
224        header += "      Etherm"
225    if estimate_pressure:
226        print_aniso_pressure = simulation_parameters.get("print_aniso_pressure", False)
227        pressure_unit_str = simulation_parameters.get("pressure_unit", "atm")
228        pressure_unit = au.get_multiplier(pressure_unit_str) * au.BOHR**3
229        header += f"  Press[{pressure_unit_str}]"
230    if variable_cell:
231        header += "   Density"
232    print(header)
233
234    ### Open trajectory file
235    traj_format = simulation_parameters.get("traj_format", "arc").lower()
236    if traj_format == "xyz":
237        traj_ext = ".traj.xyz"
238        write_frame = write_xyz_frame
239    elif traj_format == "extxyz":
240        traj_ext = ".traj.extxyz"
241        write_frame = write_extxyz_frame
242    elif traj_format == "arc":
243        traj_ext = ".arc"
244        write_frame = write_arc_frame
245    else:
246        raise ValueError(
247            f"Unknown trajectory format '{traj_format}'. Supported formats are 'arc' and 'xyz'"
248        )
249
250    write_all_beads = simulation_parameters.get("write_all_beads", False) and pimd
251
252    if write_all_beads:
253        fout = [
254            open(f"{system_name}_bead{i+1:03d}" + traj_ext, "a") for i in range(nbeads)
255        ]
256    else:
257        fout = open(system_name + traj_ext, "a")
258
259    ensemble_key = simulation_parameters.get("etot_ensemble_key", None)
260    if ensemble_key is not None:
261        fens = open(f"{system_name}.ensemble_weights.traj", "a")
262
263    write_centroid = simulation_parameters.get("write_centroid", False) and pimd
264    if write_centroid:
265        fcentroid = open(f"{system_name}_centroid" + traj_ext, "a")
266
267    ### initialize proprerty trajectories
268    properties_traj = defaultdict(list)
269
270    ### initialize counters and timers
271    t0 = time.time()
272    t0dump = t0
273    istep = 0
274    t0full = time.time()
275    force_preprocess = False
276
277    for istep in range(1, nsteps + 1):
278
279        ### update the system
280        dyn_state, system, conformation, model_output = step(
281            istep, dyn_state, system, conformation, force_preprocess
282        )
283
284        ### print properties
285        if istep % nprint == 0:
286            t1 = time.time()
287            tperstep = (t1 - t0) / nprint
288            t0 = t1
289            nsperday = (24 * 60 * 60 / tperstep) * dt / 1e6
290
291            ek = system["ek"]
292            epot = system["epot"]
293            etot = ek + epot
294            temper = 2 * ek / (3.0 * nat) * au.KELVIN
295
296            th_state = system["thermostat"]
297            if include_thermostat_energy:
298                etherm = th_state["thermostat_energy"]
299                etot = etot + etherm
300
301            properties_traj[f"Etot[{atom_energy_unit_str}]"].append(
302                etot * atom_energy_unit
303            )
304            properties_traj[f"Epot[{atom_energy_unit_str}]"].append(
305                epot * atom_energy_unit
306            )
307            properties_traj[f"Ekin[{atom_energy_unit_str}]"].append(
308                ek * atom_energy_unit
309            )
310            properties_traj["Temper[Kelvin]"].append(temper)
311            if pimd:
312                ek_c = system["ek_c"]
313                temper_c = 2 * ek_c / (3.0 * nat) * au.KELVIN
314                properties_traj["Temper_c[Kelvin]"].append(temper_c)
315
316            ### construct line of properties
317            simulated_time = start_time_ps + istep * dt / 1000
318            line = f"{istep:10.6g} {simulated_time: 10.3f}  {etot*atom_energy_unit: #10.4f}  {epot*atom_energy_unit: #10.4f}  {ek*atom_energy_unit: #10.4f} {temper: 10.2f}"
319            if pimd:
320                line += f" {temper_c: 10.2f}"
321            if include_thermostat_energy:
322                line += f"  {etherm*atom_energy_unit: #10.4f}"
323                properties_traj[f"Etherm[{atom_energy_unit_str}]"].append(
324                    etherm * atom_energy_unit
325                )
326            if estimate_pressure:
327                pres = system["pressure"] * pressure_unit
328                properties_traj[f"Pressure[{pressure_unit_str}]"].append(pres)
329                if print_aniso_pressure:
330                    pres_tensor = system["pressure_tensor"] * pressure_unit
331                    pres_tensor = 0.5 * (pres_tensor + pres_tensor.T)
332                    properties_traj[f"Pressure_xx[{pressure_unit_str}]"].append(
333                        pres_tensor[0, 0]
334                    )
335                    properties_traj[f"Pressure_yy[{pressure_unit_str}]"].append(
336                        pres_tensor[1, 1]
337                    )
338                    properties_traj[f"Pressure_zz[{pressure_unit_str}]"].append(
339                        pres_tensor[2, 2]
340                    )
341                    properties_traj[f"Pressure_xy[{pressure_unit_str}]"].append(
342                        pres_tensor[0, 1]
343                    )
344                    properties_traj[f"Pressure_xz[{pressure_unit_str}]"].append(
345                        pres_tensor[0, 2]
346                    )
347                    properties_traj[f"Pressure_yz[{pressure_unit_str}]"].append(
348                        pres_tensor[1, 2]
349                    )
350                line += f" {pres:10.3f}"
351            if variable_cell:
352                density = system["density"]
353                properties_traj["Density[g/cm^3]"].append(density)
354                if print_aniso_pressure:
355                    cell = system["cell"]
356                    properties_traj[f"Cell_Ax[Angstrom]"].append(cell[0, 0])
357                    properties_traj[f"Cell_Ay[Angstrom]"].append(cell[0, 1])
358                    properties_traj[f"Cell_Az[Angstrom]"].append(cell[0, 2])
359                    properties_traj[f"Cell_Bx[Angstrom]"].append(cell[1, 0])
360                    properties_traj[f"Cell_By[Angstrom]"].append(cell[1, 1])
361                    properties_traj[f"Cell_Bz[Angstrom]"].append(cell[1, 2])
362                    properties_traj[f"Cell_Cx[Angstrom]"].append(cell[2, 0])
363                    properties_traj[f"Cell_Cy[Angstrom]"].append(cell[2, 1])
364                    properties_traj[f"Cell_Cz[Angstrom]"].append(cell[2, 2])
365                line += f" {density:10.4f}"
366                if "piston_temperature" in system["barostat"]:
367                    piston_temperature = system["barostat"]["piston_temperature"]
368                    properties_traj["T_piston[Kelvin]"].append(piston_temperature)
369
370            print(line)
371
372            ### save colvars
373            if use_colvars:
374                colvars = system["colvars"]
375                fcolvars.write(f"{simulated_time:.3f} ")
376                fcolvars.write(" ".join([f"{colvars[k]:.6f}" for k in colvars_names]))
377                fcolvars.write("\n")
378                fcolvars.flush()
379
380            if save_keys:
381                data = {
382                    k: (
383                        np.asarray(model_output[k])
384                        if isinstance(model_output[k], jnp.ndarray)
385                        else model_output[k]
386                    )
387                    for k in save_keys
388                }
389                data["properties"] = {
390                    k: float(v) for k, v in zip(header.split()[1:], line.split())
391                }
392                data["properties"]["properties_energy_unit"] = (atom_energy_unit,atom_energy_unit_str)
393                data["properties"]["model_energy_unit"] = (model_energy_unit,model_energy_unit_str)
394
395                pickle.dump(data, fkeys)
396
397        ### save frame
398        if istep % ndump == 0:
399            line = "# Write XYZ frame"
400            if variable_cell:
401                cell = np.array(system["cell"])
402                reciprocal_cell = np.linalg.inv(cell)
403            if do_wrap_box:
404                if pimd:
405                    centroid = wrapbox(system["coordinates"][0], cell, reciprocal_cell,wrap_groups=wrap_groups)
406                    system["coordinates"] = system["coordinates"].at[0].set(centroid)
407                else:
408                    system["coordinates"] = wrapbox(
409                        system["coordinates"], cell, reciprocal_cell,wrap_groups=wrap_groups
410                    )
411                conformation = update_conformation(conformation, system)
412                line += " (atoms have been wrapped into the box)"
413                force_preprocess = True
414            print(line)
415
416            save_dynamics_restart(system_data, conformation, dyn_state, system)
417
418            properties = {
419                "energy": float(system["epot"]) * energy_unit,
420                "Time_ps": start_time_ps + istep * dt / 1000,
421                "energy_unit": energy_unit_str,
422            }
423
424            if write_all_beads:
425                coords = np.asarray(conformation["coordinates"].reshape(-1, nat, 3))
426                for i, fb in enumerate(fout):
427                    write_frame(
428                        fb,
429                        system_data["symbols"],
430                        coords[i],
431                        cell=cell,
432                        properties=properties,
433                        forces=None,  # np.asarray(system["forces"].reshape(nbeads, nat, 3)[0]) * energy_unit,
434                    )
435            else:
436                write_frame(
437                    fout,
438                    system_data["symbols"],
439                    np.asarray(conformation["coordinates"].reshape(-1, nat, 3)[0]),
440                    cell=cell,
441                    properties=properties,
442                    forces=None,  # np.asarray(system["forces"].reshape(nbeads, nat, 3)[0]) * energy_unit,
443                )
444            if write_centroid:
445                centroid = np.asarray(system["coordinates"][0])
446                write_frame(
447                    fcentroid,
448                    system_data["symbols"],
449                    centroid,
450                    cell=cell,
451                    properties=properties,
452                    forces=np.asarray(system["forces"].reshape(nbeads, nat, 3)[0])
453                    * energy_unit,
454                )
455            if ensemble_key is not None:
456                weights = " ".join(
457                    [f"{w:.6f}" for w in system["ensemble_weights"].tolist()]
458                )
459                fens.write(f"{weights}\n")
460                fens.flush()
461
462        ### summary over last nsummary steps
463        if istep % (nsummary) == 0:
464            if check_nan(system):
465                raise ValueError(f"dynamics crashed at step {istep}.")
466            tfull = time.time() - t0full
467            t0full = time.time()
468            tperstep = tfull / (nsummary)
469            nsperday = (24 * 60 * 60 / tperstep) * dt / 1e6
470            elapsed_time = time.time() - tstart_dyn
471            estimated_remaining_time = tperstep * (nsteps - istep)
472            estimated_total_time = elapsed_time + estimated_remaining_time
473
474            print("#" * 50)
475            print(f"# Step {istep:_} of {nsteps:_}  ({istep/nsteps*100:.5g} %)")
476            print(f"# Simulated time      : {istep * dt*1.e-3:.3f} ps")
477            print(f"# Tot. Simu. time     : {start_time_ps + istep * dt*1.e-3:.3f} ps")
478            print(f"# Tot. elapsed time   : {human_time_duration(elapsed_time)}")
479            print(
480                f"# Est. total duration   : {human_time_duration(estimated_total_time)}"
481            )
482            print(
483                f"# Est. remaining time : {human_time_duration(estimated_remaining_time)}"
484            )
485            print(f"# Time for {nsummary:_} steps : {human_time_duration(tfull)}")
486
487            corr_kin = system["thermostat"].get("corr_kin", None)
488            if corr_kin is not None:
489                print(f"# QTB kin. correction : {100*(corr_kin-1.):.2f} %")
490            print(f"# Averages over last {nsummary:_} steps :")
491            for k, v in properties_traj.items():
492                if len(properties_traj[k]) == 0:
493                    continue
494                mu = np.mean(properties_traj[k])
495                sig = np.std(properties_traj[k])
496                ksplit = k.split("[")
497                name = ksplit[0].strip()
498                unit = ksplit[1].replace("]", "").strip() if len(ksplit) > 1 else ""
499                print(f"#   {name:10} : {mu: #10.5g}   +/- {sig: #9.3g}  {unit}")
500
501            print(f"# Perf.: {nsperday:.2f} ns/day  ( {1.0 / tperstep:.2f} step/s )")
502            print("#" * 50)
503            if istep < nsteps:
504                print(header)
505            ## reset property trajectories
506            properties_traj = defaultdict(list)
507
508    print(f"# Run done in {human_time_duration(time.time()-tstart_dyn)}")
509    ### close trajectory file
510    fout.close()
511    if ensemble_key is not None:
512        fens.close()
513    if use_colvars:
514        fcolvars.close()
515    if fkeys is not None:
516        fkeys.close()
517    if write_centroid:
518        fcentroid.close()
519
520
if __name__ == "__main__":
    # Executed only when the module is run as a script, not on import.
    main()
def main():
def main():
    """Command-line entry point of ``fennol_md``.

    Reads the parameter file given as the single positional argument
    (``.yaml``/``.yml`` or ``.fnl``), configures the JAX device and numerical
    precision, then runs :func:`dynamic`.

    The device is taken from the ``FENNOL_DEVICE`` environment variable when it
    is set, otherwise from the ``device`` parameter (default ``"cpu"``).
    ``"cuda:N"`` / ``"gpu:N"`` sets ``CUDA_VISIBLE_DEVICES`` to ``N``; a bare
    ``"cuda"`` / ``"gpu"`` sets it to ``"0"``.

    Raises:
        FileNotFoundError: if the parameter file does not exist or is not a
            regular file.
        ValueError: if the parameter file format is unsupported, a YAML file
            does not contain a mapping, or ``matmul_prec`` is invalid.
    """
    # os.environ["OMP_NUM_THREADS"] = "1"
    # Replace stdout by a write-through text wrapper around an unbuffered
    # binary stream so that every print is emitted immediately.
    sys.stdout = io.TextIOWrapper(
        open(sys.stdout.fileno(), "wb", 0), write_through=True
    )

    ### Read the parameter file
    parser = argparse.ArgumentParser(prog="fennol_md")
    parser.add_argument("param_file", type=Path, help="Parameter file")
    args = parser.parse_args()

    param_file = args.param_file
    # is_file() is False both for missing paths and for directories,
    # so both are rejected here with the same error.
    if not param_file.is_file():
        raise FileNotFoundError(f"Parameter file {param_file} not found")

    suffix = param_file.suffix.lower()
    if suffix in (".yaml", ".yml"):
        with open(param_file, "r") as f:
            raw_parameters = yaml.safe_load(f)
        # safe_load returns None for an empty file; InputFile(**...) needs a mapping
        if not isinstance(raw_parameters, dict):
            raise ValueError(
                f"Parameter file {param_file} must contain a YAML mapping"
            )
        simulation_parameters = InputFile(**convert_dict_units(raw_parameters))
    elif suffix == ".fnl":
        simulation_parameters = parse_input(param_file)
    else:
        raise ValueError(
            f"Unknown parameter file format '{param_file.suffix}'. Supported formats are '.yaml', '.yml' and '.fnl'"
        )

    ### Set the device
    if "FENNOL_DEVICE" in os.environ:
        device = os.environ["FENNOL_DEVICE"].lower()
        print(f"# Setting device from env FENNOL_DEVICE={device}")
    else:
        device = simulation_parameters.get("device", "cpu").lower()
    if device == "cpu":
        jax.config.update("jax_platforms", "cpu")
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    elif device.startswith(("cuda", "gpu")):
        # "cuda:N" / "gpu:N" selects GPU N; without an index GPU "0" is used
        os.environ["CUDA_VISIBLE_DEVICES"] = (
            device.split(":")[-1] if ":" in device else "0"
        )
        device = "gpu"

    _device = jax.devices(device)[0]
    jax.config.update("jax_default_device", _device)

    ### Set the precision
    enable_x64 = simulation_parameters.get("double_precision", False)
    jax.config.update("jax_enable_x64", enable_x64)
    fprec = "float64" if enable_x64 else "float32"

    matmul_precision = simulation_parameters.get("matmul_prec", "highest").lower()
    # explicit check instead of assert: asserts are stripped under `python -O`
    if matmul_precision not in ("default", "high", "highest"):
        raise ValueError("matmul_prec must be one of 'default','high','highest'")
    if matmul_precision != "highest":
        print(f"# Setting matmul precision to '{matmul_precision}'")
    if matmul_precision == "default" and fprec == "float32":
        print(
            "# Warning: default matmul precision involves float16 operations which may lead to large numerical errors on energy and pressure estimations ! It is recommended to set matmul_prec to 'high' or 'highest'."
        )
    jax.config.update("jax_default_matmul_precision", matmul_precision)

    # with jax.default_device(_device):
    dynamic(simulation_parameters, device, fprec)
def dynamic(simulation_parameters, device, fprec):
    """Run a (PI)MD simulation from parsed input parameters.

    The integrator, system and dynamics state are built by
    ``initialize_dynamics``; the returned ``step`` function is then called
    for ``nsteps`` steps. Along the way:

    * every ``nprint`` steps a line of properties (energies, temperature and,
      when enabled, thermostat energy, pressure and density) is printed and
      accumulated; colvars and the requested ``save_keys`` model outputs are
      written to ``<name>.colvars.traj`` / ``<name>.traj.pkl``;
    * every ``ndump`` steps atoms are optionally wrapped into the box, a
      restart file is saved and a trajectory frame is written;
    * every ``nsummary`` steps velocities/coordinates are checked for NaNs and
      averages and timings of the accumulated properties are printed.

    All output files are closed when the function returns or raises.

    Args:
        simulation_parameters: parsed input parameters (dict-like, read with
            ``.get``).
        device: device name, only used in the run banner.
        fprec: floating point precision (``"float32"`` or ``"float64"``),
            forwarded to ``initialize_dynamics``.

    Raises:
        ValueError: if two ``wrap_groups`` share atoms, if ``traj_format`` is
            unknown, or if NaNs appear in the velocities or coordinates.
    """
    # Local import keeps this function self-contained; ExitStack guarantees
    # that every output file opened below is closed, even if the run crashes.
    from contextlib import ExitStack

    tstart_dyn = time.time()

    random_seed = simulation_parameters.get(
        "random_seed", np.random.randint(0, 2**32 - 1)
    )
    print(f"# random_seed: {random_seed}")
    rng_key = jax.random.PRNGKey(random_seed)

    ## INITIALIZE INTEGRATOR AND SYSTEM
    rng_key, subkey = jax.random.split(rng_key)
    step, update_conformation, system_data, dyn_state, conformation, system = (
        initialize_dynamics(simulation_parameters, fprec, subkey)
    )

    nat = system_data["nat"]
    dt = dyn_state["dt"]
    ## get number of steps
    nsteps = int(simulation_parameters.get("nsteps"))
    start_time_ps = dyn_state.get("start_time_ps", 0.0)

    ### Set I/O parameters
    Tdump = simulation_parameters.get("tdump", 1.0 / au.PS) * au.FS
    # Rounding to 6 decimals absorbs floating-point noise from the unit
    # conversion before truncating (e.g. 1999.9999999 -> 2000, not 1999).
    # Clamped to 1: ndump == 0 would make `istep % ndump` divide by zero.
    ndump = max(1, int(round(float(Tdump / dt), 6)))
    system_name = system_data["name"]
    estimate_pressure = dyn_state["estimate_pressure"]

    @jax.jit
    def check_nan(system):
        """Return True if any velocity or coordinate is NaN."""
        return jnp.any(jnp.isnan(system["vel"])) | jnp.any(
            jnp.isnan(system["coordinates"])
        )

    ### Periodic box and optional wrapping of atoms into it
    cell = None
    reciprocal_cell = None
    do_wrap_box = False
    wrap_groups = None
    if system_data["pbc"] is not None:
        cell = system["cell"]
        reciprocal_cell = np.linalg.inv(cell)
        do_wrap_box = simulation_parameters.get("wrap_box", False)
        if do_wrap_box:
            wrap_groups_def = simulation_parameters.get("wrap_groups", None)
            if wrap_groups_def is not None:
                assert isinstance(wrap_groups_def, dict), "wrap_groups must be a dictionary"
                groups = {
                    k: read_tinker_interval(v) for k, v in wrap_groups_def.items()
                }
                # wrap groups must be pairwise disjoint
                group_names = list(groups)
                for i, i_key in enumerate(group_names):
                    w1 = set(groups[i_key])
                    for j_key in group_names[i + 1 :]:
                        common = w1.intersection(set(groups[j_key]))
                        if common:
                            raise ValueError(
                                f"Wrap groups {i_key} and {j_key} have common atoms: {common}"
                            )
                # atoms that are not in any wrap group form the "__other" group
                group_all = np.concatenate(list(groups.values()))
                group_none = np.setdiff1d(np.arange(nat), group_all)
                print(f"# Wrap groups: {groups}")
                groups["__other"] = group_none
                # NOTE(review): this is a one-shot generator; it is exhausted
                # after being iterated once. Whether later `wrapbox` calls still
                # see the groups depends on how `wrapbox` consumes it (e.g. a
                # cached jit trace with a static argument) — confirm.
                wrap_groups = ((k, v) for k, v in groups.items())

    ### Energy units and print initial energy
    model_energy_unit = system_data["model_energy_unit"]
    model_energy_unit_str = system_data["model_energy_unit_str"]
    per_atom_energy = simulation_parameters.get("per_atom_energy", True)
    energy_unit_str = system_data["energy_unit_str"]
    energy_unit = system_data["energy_unit"]
    print("# Energy unit: ", energy_unit_str)
    atom_energy_unit = energy_unit
    atom_energy_unit_str = energy_unit_str
    if per_atom_energy:
        # not an in-place division: that would also modify `energy_unit`
        # if it were a mutable array
        atom_energy_unit = energy_unit / nat
        atom_energy_unit_str = f"{energy_unit_str}/atom"
        print("# Printing Energy per atom")
    print(
        f"# Initial potential energy: {system['epot']*atom_energy_unit}; kinetic energy: {system['ek']*atom_energy_unit}"
    )
    forces = system["forces"]
    minmaxone(jnp.abs(forces * energy_unit), "# forces min/max/rms:")

    ## printing options
    # NOTE(review): read but not used in this function.
    print_timings = simulation_parameters.get("print_timings", False)
    nprint = int(simulation_parameters.get("nprint", 10))
    assert nprint > 0, "nprint must be > 0"
    nsummary = int(simulation_parameters.get("nsummary", 100 * nprint))
    assert nsummary > nprint, "nsummary must be > nprint"

    with ExitStack() as open_files:
        save_keys = simulation_parameters.get("save_keys", [])
        if save_keys:
            print(f"# Saving keys: {save_keys}")
            fkeys = open_files.enter_context(open(f"{system_name}.traj.pkl", "wb+"))

        ### initialize colvars
        use_colvars = "colvars" in dyn_state
        if use_colvars:
            print(f"# Colvars: {dyn_state['colvars']}")
            colvars_names = dyn_state["colvars"]
            # open colvars file and print header
            fcolvars = open_files.enter_context(
                open(f"{system_name}.colvars.traj", "a")
            )
            fcolvars.write("#time[ps] ")
            fcolvars.write(" ".join(colvars_names))
            fcolvars.write("\n")
            fcolvars.flush()

        ### Print header
        include_thermostat_energy = "thermostat_energy" in system["thermostat"]
        thermostat_name = dyn_state["thermostat_name"]
        pimd = dyn_state["pimd"]
        variable_cell = dyn_state["variable_cell"]
        nbeads = system_data.get("nbeads", 1)
        dyn_name = "PIMD" if pimd else "MD"
        print("#" * 84)
        print(
            f"# Running {nsteps:_} steps of {thermostat_name} {dyn_name} simulation on {device}"
        )
        header = "#     Step   Time[ps]        Etot        Epot        Ekin    Temp[K]"
        if pimd:
            header += "  Temp_c[K]"
        if include_thermostat_energy:
            header += "      Etherm"
        # Also consulted by the variable-cell branch below, so it must exist
        # even when the pressure is not estimated.
        print_aniso_pressure = False
        if estimate_pressure:
            print_aniso_pressure = simulation_parameters.get("print_aniso_pressure", False)
            pressure_unit_str = simulation_parameters.get("pressure_unit", "atm")
            pressure_unit = au.get_multiplier(pressure_unit_str) * au.BOHR**3
            header += f"  Press[{pressure_unit_str}]"
        if variable_cell:
            header += "   Density"
        print(header)

        ### Open trajectory file
        traj_format = simulation_parameters.get("traj_format", "arc").lower()
        if traj_format == "xyz":
            traj_ext = ".traj.xyz"
            write_frame = write_xyz_frame
        elif traj_format == "extxyz":
            traj_ext = ".traj.extxyz"
            write_frame = write_extxyz_frame
        elif traj_format == "arc":
            traj_ext = ".arc"
            write_frame = write_arc_frame
        else:
            raise ValueError(
                f"Unknown trajectory format '{traj_format}'. Supported formats are 'arc', 'xyz' and 'extxyz'"
            )

        write_all_beads = simulation_parameters.get("write_all_beads", False) and pimd

        if write_all_beads:
            # one trajectory file per bead, each registered for closing
            fout = [
                open_files.enter_context(
                    open(f"{system_name}_bead{i+1:03d}" + traj_ext, "a")
                )
                for i in range(nbeads)
            ]
        else:
            fout = open_files.enter_context(open(system_name + traj_ext, "a"))

        ensemble_key = simulation_parameters.get("etot_ensemble_key", None)
        if ensemble_key is not None:
            fens = open_files.enter_context(
                open(f"{system_name}.ensemble_weights.traj", "a")
            )

        write_centroid = simulation_parameters.get("write_centroid", False) and pimd
        if write_centroid:
            fcentroid = open_files.enter_context(
                open(f"{system_name}_centroid" + traj_ext, "a")
            )

        ### initialize property trajectories
        properties_traj = defaultdict(list)

        ### initialize summary timer
        t0full = time.time()
        force_preprocess = False

        for istep in range(1, nsteps + 1):

            ### update the system
            dyn_state, system, conformation, model_output = step(
                istep, dyn_state, system, conformation, force_preprocess
            )

            ### print properties
            if istep % nprint == 0:
                ek = system["ek"]
                epot = system["epot"]
                etot = ek + epot
                temper = 2 * ek / (3.0 * nat) * au.KELVIN

                th_state = system["thermostat"]
                if include_thermostat_energy:
                    etherm = th_state["thermostat_energy"]
                    etot = etot + etherm

                properties_traj[f"Etot[{atom_energy_unit_str}]"].append(
                    etot * atom_energy_unit
                )
                properties_traj[f"Epot[{atom_energy_unit_str}]"].append(
                    epot * atom_energy_unit
                )
                properties_traj[f"Ekin[{atom_energy_unit_str}]"].append(
                    ek * atom_energy_unit
                )
                properties_traj["Temper[Kelvin]"].append(temper)
                if pimd:
                    ek_c = system["ek_c"]
                    temper_c = 2 * ek_c / (3.0 * nat) * au.KELVIN
                    properties_traj["Temper_c[Kelvin]"].append(temper_c)

                ### construct line of properties
                simulated_time = start_time_ps + istep * dt / 1000
                line = f"{istep:10.6g} {simulated_time: 10.3f}  {etot*atom_energy_unit: #10.4f}  {epot*atom_energy_unit: #10.4f}  {ek*atom_energy_unit: #10.4f} {temper: 10.2f}"
                if pimd:
                    line += f" {temper_c: 10.2f}"
                if include_thermostat_energy:
                    line += f"  {etherm*atom_energy_unit: #10.4f}"
                    properties_traj[f"Etherm[{atom_energy_unit_str}]"].append(
                        etherm * atom_energy_unit
                    )
                if estimate_pressure:
                    pres = system["pressure"] * pressure_unit
                    properties_traj[f"Pressure[{pressure_unit_str}]"].append(pres)
                    if print_aniso_pressure:
                        pres_tensor = system["pressure_tensor"] * pressure_unit
                        # symmetrize before reporting the components
                        pres_tensor = 0.5 * (pres_tensor + pres_tensor.T)
                        properties_traj[f"Pressure_xx[{pressure_unit_str}]"].append(
                            pres_tensor[0, 0]
                        )
                        properties_traj[f"Pressure_yy[{pressure_unit_str}]"].append(
                            pres_tensor[1, 1]
                        )
                        properties_traj[f"Pressure_zz[{pressure_unit_str}]"].append(
                            pres_tensor[2, 2]
                        )
                        properties_traj[f"Pressure_xy[{pressure_unit_str}]"].append(
                            pres_tensor[0, 1]
                        )
                        properties_traj[f"Pressure_xz[{pressure_unit_str}]"].append(
                            pres_tensor[0, 2]
                        )
                        properties_traj[f"Pressure_yz[{pressure_unit_str}]"].append(
                            pres_tensor[1, 2]
                        )
                    line += f" {pres:10.3f}"
                if variable_cell:
                    density = system["density"]
                    properties_traj["Density[g/cm^3]"].append(density)
                    if print_aniso_pressure:
                        # local name: must not clobber `cell` used for frames
                        box = system["cell"]
                        properties_traj["Cell_Ax[Angstrom]"].append(box[0, 0])
                        properties_traj["Cell_Ay[Angstrom]"].append(box[0, 1])
                        properties_traj["Cell_Az[Angstrom]"].append(box[0, 2])
                        properties_traj["Cell_Bx[Angstrom]"].append(box[1, 0])
                        properties_traj["Cell_By[Angstrom]"].append(box[1, 1])
                        properties_traj["Cell_Bz[Angstrom]"].append(box[1, 2])
                        properties_traj["Cell_Cx[Angstrom]"].append(box[2, 0])
                        properties_traj["Cell_Cy[Angstrom]"].append(box[2, 1])
                        properties_traj["Cell_Cz[Angstrom]"].append(box[2, 2])
                    line += f" {density:10.4f}"
                    if "piston_temperature" in system["barostat"]:
                        piston_temperature = system["barostat"]["piston_temperature"]
                        properties_traj["T_piston[Kelvin]"].append(piston_temperature)

                print(line)

                ### save colvars
                if use_colvars:
                    colvars = system["colvars"]
                    fcolvars.write(f"{simulated_time:.3f} ")
                    fcolvars.write(" ".join([f"{colvars[k]:.6f}" for k in colvars_names]))
                    fcolvars.write("\n")
                    fcolvars.flush()

                if save_keys:
                    data = {
                        k: (
                            np.asarray(model_output[k])
                            if isinstance(model_output[k], jnp.ndarray)
                            else model_output[k]
                        )
                        for k in save_keys
                    }
                    # pair header column names (skipping the leading "#")
                    # with the values of the printed line
                    data["properties"] = {
                        k: float(v) for k, v in zip(header.split()[1:], line.split())
                    }
                    data["properties"]["properties_energy_unit"] = (atom_energy_unit,atom_energy_unit_str)
                    data["properties"]["model_energy_unit"] = (model_energy_unit,model_energy_unit_str)

                    pickle.dump(data, fkeys)

            ### save frame
            if istep % ndump == 0:
                line = "# Write XYZ frame"
                if variable_cell:
                    cell = np.array(system["cell"])
                    reciprocal_cell = np.linalg.inv(cell)
                if do_wrap_box:
                    if pimd:
                        # only the first bead slice is wrapped
                        centroid = wrapbox(
                            system["coordinates"][0],
                            cell,
                            reciprocal_cell,
                            wrap_groups=wrap_groups,
                        )
                        system["coordinates"] = system["coordinates"].at[0].set(centroid)
                    else:
                        system["coordinates"] = wrapbox(
                            system["coordinates"],
                            cell,
                            reciprocal_cell,
                            wrap_groups=wrap_groups,
                        )
                    conformation = update_conformation(conformation, system)
                    line += " (atoms have been wrapped into the box)"
                    # presumably makes `step` redo its preprocessing now that
                    # coordinates have jumped — confirm in initialize_dynamics
                    force_preprocess = True
                print(line)

                save_dynamics_restart(system_data, conformation, dyn_state, system)

                properties = {
                    "energy": float(system["epot"]) * energy_unit,
                    "Time_ps": start_time_ps + istep * dt / 1000,
                    "energy_unit": energy_unit_str,
                }

                if write_all_beads:
                    coords = np.asarray(conformation["coordinates"].reshape(-1, nat, 3))
                    for i, fb in enumerate(fout):
                        write_frame(
                            fb,
                            system_data["symbols"],
                            coords[i],
                            cell=cell,
                            properties=properties,
                            forces=None,  # np.asarray(system["forces"].reshape(nbeads, nat, 3)[0]) * energy_unit,
                        )
                else:
                    write_frame(
                        fout,
                        system_data["symbols"],
                        np.asarray(conformation["coordinates"].reshape(-1, nat, 3)[0]),
                        cell=cell,
                        properties=properties,
                        forces=None,  # np.asarray(system["forces"].reshape(nbeads, nat, 3)[0]) * energy_unit,
                    )
                if write_centroid:
                    centroid = np.asarray(system["coordinates"][0])
                    write_frame(
                        fcentroid,
                        system_data["symbols"],
                        centroid,
                        cell=cell,
                        properties=properties,
                        forces=np.asarray(system["forces"].reshape(nbeads, nat, 3)[0])
                        * energy_unit,
                    )
                if ensemble_key is not None:
                    weights = " ".join(
                        [f"{w:.6f}" for w in system["ensemble_weights"].tolist()]
                    )
                    fens.write(f"{weights}\n")
                    fens.flush()

            ### summary over last nsummary steps
            if istep % nsummary == 0:
                if check_nan(system):
                    raise ValueError(f"dynamics crashed at step {istep}.")
                tfull = time.time() - t0full
                t0full = time.time()
                tperstep = tfull / nsummary
                nsperday = (24 * 60 * 60 / tperstep) * dt / 1e6
                elapsed_time = time.time() - tstart_dyn
                estimated_remaining_time = tperstep * (nsteps - istep)
                estimated_total_time = elapsed_time + estimated_remaining_time

                print("#" * 50)
                print(f"# Step {istep:_} of {nsteps:_}  ({istep/nsteps*100:.5g} %)")
                print(f"# Simulated time      : {istep * dt*1.e-3:.3f} ps")
                print(f"# Tot. Simu. time     : {start_time_ps + istep * dt*1.e-3:.3f} ps")
                print(f"# Tot. elapsed time   : {human_time_duration(elapsed_time)}")
                print(
                    f"# Est. total duration   : {human_time_duration(estimated_total_time)}"
                )
                print(
                    f"# Est. remaining time : {human_time_duration(estimated_remaining_time)}"
                )
                print(f"# Time for {nsummary:_} steps : {human_time_duration(tfull)}")

                corr_kin = system["thermostat"].get("corr_kin", None)
                if corr_kin is not None:
                    print(f"# QTB kin. correction : {100*(corr_kin-1.):.2f} %")
                print(f"# Averages over last {nsummary:_} steps :")
                for k, values in properties_traj.items():
                    if len(values) == 0:
                        continue
                    mu = np.mean(values)
                    sig = np.std(values)
                    # keys look like "Name[unit]"
                    ksplit = k.split("[")
                    name = ksplit[0].strip()
                    unit = ksplit[1].replace("]", "").strip() if len(ksplit) > 1 else ""
                    print(f"#   {name:10} : {mu: #10.5g}   +/- {sig: #9.3g}  {unit}")

                print(f"# Perf.: {nsperday:.2f} ns/day  ( {1.0 / tperstep:.2f} step/s )")
                print("#" * 50)
                if istep < nsteps:
                    print(header)
                ## reset property trajectories
                properties_traj = defaultdict(list)

        print(f"# Run done in {human_time_duration(time.time()-tstart_dyn)}")
    # leaving the `with` block closes every output file, including each
    # per-bead trajectory file when write_all_beads is set