fennol.md.dynamic

  1import sys, os, io
  2import argparse
  3import time
  4from pathlib import Path
  5import math
  6
  7import numpy as np
  8from typing import Optional, Callable
  9from collections import defaultdict
 10from functools import partial
 11import jax
 12import jax.numpy as jnp
 13
 14from flax.core import freeze, unfreeze
 15
 16
 17from ..models import FENNIX
 18from ..utils.io import (
 19    write_arc_frame,
 20    last_xyz_frame,
 21    write_xyz_frame,
 22    write_extxyz_frame,
 23    human_time_duration,
 24)
 25from ..utils.periodic_table import PERIODIC_TABLE_REV_IDX, ATOMIC_MASSES
 26from ..utils.atomic_units import AtomicUnits as au
 27from ..utils.input_parser import parse_input
 28from .initial import load_model, load_system_data, initialize_preprocessing
 29from .integrate import initialize_dynamics
 30
 31from copy import deepcopy
 32
 33
 34def minmaxone(x, name=""):
 35    print(name, x.min(), x.max(), (x**2).mean() ** 0.5)
 36
 37
 38@jax.jit
 39def wrapbox(x, cell, reciprocal_cell):
 40    # q = jnp.einsum("ji,sj->si", reciprocal_cell, x)
 41    q = x @ reciprocal_cell
 42    q = q - jnp.floor(q)
 43    # return jnp.einsum("ji,sj->si", cell, q)
 44    return q @ cell
 45
 46
 47def main():
 48    # os.environ["OMP_NUM_THREADS"] = "1"
 49    sys.stdout = io.TextIOWrapper(
 50        open(sys.stdout.fileno(), "wb", 0), write_through=True
 51    )
 52    ### Read the parameter file
 53    parser = argparse.ArgumentParser(prog="fennol_md")
 54    parser.add_argument("param_file", type=Path, help="Parameter file")
 55    args = parser.parse_args()
 56    simulation_parameters = parse_input(args.param_file)
 57
 58    ### Set the device
 59    device: str = simulation_parameters.get("device", "cpu").lower()
 60    if device == "cpu":
 61        device = "cpu"
 62        os.environ["CUDA_VISIBLE_DEVICES"] = ""
 63    elif device.startswith("cuda") or device.startswith("gpu"):
 64        if ":" in device:
 65            num = device.split(":")[-1]
 66            os.environ["CUDA_VISIBLE_DEVICES"] = num
 67        else:
 68            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 69        device = "gpu"
 70
 71    _device = jax.devices(device)[0]
 72    jax.config.update("jax_default_device", _device)
 73
 74    ### Set the precision
 75    enable_x64 = simulation_parameters.get("double_precision", False)
 76    jax.config.update("jax_enable_x64", enable_x64)
 77    fprec = "float64" if enable_x64 else "float32"
 78
 79    matmul_precision = simulation_parameters.get("matmul_prec", "highest").lower()
 80    assert matmul_precision in [
 81        "default",
 82        "high",
 83        "highest",
 84    ], "matmul_prec must be one of 'default','high','highest'"
 85    if matmul_precision != "highest":
 86        print(f"# Setting matmul precision to '{matmul_precision}'")
 87    if matmul_precision == "default" and fprec == "float32":
 88        print(
 89            "# Warning: default matmul precision involves float16 operations which may lead to large numerical errors on energy and pressure estimations ! It is recommended to set matmul_prec to 'high' or 'highest'."
 90        )
 91    jax.config.update("jax_default_matmul_precision", matmul_precision)
 92
 93    # with jax.default_device(_device):
 94    dynamic(simulation_parameters, device, fprec)
 95
 96
 97def dynamic(simulation_parameters, device, fprec):
 98    tstart_dyn = time.time()
 99
100    ### Initialize the model
101    model = load_model(simulation_parameters)
102
103    ### Get the coordinates and species from the xyz file
104    system_data, conformation = load_system_data(simulation_parameters, fprec)
105    nat = system_data["nat"]
106
107    preproc_state, conformation = initialize_preprocessing(
108        simulation_parameters, model, conformation, system_data
109    )
110
111    random_seed = simulation_parameters.get(
112        "random_seed", np.random.randint(0, 2**32 - 1)
113    )
114    print(f"# random_seed: {random_seed}")
115    rng_key = jax.random.PRNGKey(random_seed)
116    rng_key, subkey = jax.random.split(rng_key)
117    ## INITIALIZE INTEGRATOR AND SYSTEM
118    step, update_conformation, dyn_state, system = initialize_dynamics(
119        simulation_parameters, system_data, conformation, model, fprec, subkey
120    )
121
122    dt = dyn_state["dt"]
123    ## get number of steps
124    nsteps = int(simulation_parameters.get("nsteps"))
125    start_time = 0.0
126    start_step = 0
127
128
129    ### Set I/O parameters
130    Tdump = simulation_parameters.get("tdump", 1.0 / au.PS) * au.FS
131    ndump = int(Tdump / dt)
132    system_name = system_data["name"]
133
134    model_energy_unit = au.get_multiplier(model.energy_unit)
135    ### Print initial pressure
136    estimate_pressure = dyn_state["estimate_pressure"]
137    if estimate_pressure and fprec == "float64":
138        volume = system_data["pbc"]["volume"]
139        coordinates = conformation["coordinates"]
140        cell = conformation["cells"][0]
141        # temper = 2 * ek / (3.0 * nat) * au.KELVIN
142        ek = 1.5*nat * system_data["kT"] 
143        Pkin = (2 * au.KBAR) * ek / ((3.0 / au.BOHR**3) * volume)
144        e, f, vir_t, _ = model._energy_and_forces_and_virial(
145            model.variables, conformation
146        )
147        KBAR = au.KBAR/model_energy_unit
148        Pvir = -(np.trace(vir_t[0]) * KBAR) / ((3.0 / au.BOHR**3) * volume)
149        vstep = volume * 0.000001
150        scalep = ((volume + vstep) / volume) ** (1.0 / 3.0)
151        cellp = cell * scalep
152        reciprocal_cell = np.linalg.inv(cellp)
153        sysp = model.preprocess(
154            **{
155                **conformation,
156                "coordinates": coordinates * scalep,
157                "cells": cellp[None, :, :],
158                "reciprocal_cells": reciprocal_cell[None, :, :],
159            }
160        )
161        ep, _ = model._total_energy(model.variables, sysp)
162        scalem = ((volume - vstep) / volume) ** (1.0 / 3.0)
163        cellm = cell * scalem
164        reciprocal_cell = np.linalg.inv(cellm)
165        sysm = model.preprocess(
166            **{
167                **conformation,
168                "coordinates": coordinates * scalem,
169                "cells": cellm[None, :, :],
170                "reciprocal_cells": reciprocal_cell[None, :, :],
171            }
172        )
173        em, _ = model._total_energy(model.variables, sysm)
174        Pvir_fd = -(ep[0] * KBAR - em[0] * KBAR) / (2.0 * vstep / au.BOHR**3)
175        print(
176            f"# Initial pressure: {Pkin+Pvir:.3f} (virial); {Pkin+Pvir_fd:.3f} (finite difference) ; Pkin: {Pkin:.3f} ; Pvir: {Pvir:.3f} ; Pvir_fd: {Pvir_fd:.3f}"
177        )
178
179    @jax.jit
180    def check_nan(system):
181        return jnp.any(jnp.isnan(system["vel"])) | jnp.any(
182            jnp.isnan(system["coordinates"])
183        )
184
185    if system_data["pbc"] is not None:
186        cell = system_data["pbc"]["cell"]
187        reciprocal_cell = system_data["pbc"]["reciprocal_cell"]
188        do_wrap_box = simulation_parameters.get("wrap_box", False)
189    else:
190        cell = None
191        reciprocal_cell = None
192        do_wrap_box = False
193
194    ### Energy units and print initial energy
195    per_atom_energy = simulation_parameters.get("per_atom_energy", True)
196    energy_unit_str = simulation_parameters.get("energy_unit", "kcal/mol")
197    print("# Energy unit: ", energy_unit_str)
198    energy_unit = au.get_multiplier(energy_unit_str)
199    atom_energy_unit = energy_unit
200    atom_energy_unit_str = energy_unit_str
201    if per_atom_energy:
202        atom_energy_unit /= nat
203        atom_energy_unit_str = f"{energy_unit_str}/atom"
204        print("# Printing Energy per atom")
205    print(
206        f"# Initial potential energy: {system['epot']*atom_energy_unit}; kinetic energy: {system['ek']*atom_energy_unit}"
207    )
208    f = system["forces"]
209    minmaxone(jnp.abs(f * energy_unit), "# forces min/max/rms:")
210
211    ## printing options
212    print_timings = simulation_parameters.get("print_timings", False)
213    nprint = int(simulation_parameters.get("nprint", 10))
214    assert nprint > 0, "nprint must be > 0"
215    nsummary = simulation_parameters.get("nsummary", 100 * nprint)
216    assert nsummary > nprint, "nsummary must be > nprint"
217
218    ### Print header
219    include_thermostat_energy = "thermostat_energy" in system["thermostat"]
220    thermostat_name = dyn_state["thermostat_name"]
221    pimd = dyn_state["pimd"]
222    variable_cell = dyn_state["variable_cell"]
223    nbeads = system_data.get("nbeads", 1)
224    dyn_name = "PIMD" if pimd else "MD"
225    print("#" * 84)
226    print(
227        f"# Running {nsteps:_} steps of {thermostat_name} {dyn_name} simulation on {device}"
228    )
229    header = "#     Step   Time[ps]        Etot        Epot        Ekin    Temp[K]"
230    if pimd:
231        header += "  Temp_c[K]"
232    if include_thermostat_energy:
233        header += "      Etherm"
234    if estimate_pressure:
235        print_aniso_pressure = simulation_parameters.get("print_aniso_pressure", False)
236        pressure_unit_str = simulation_parameters.get("pressure_unit", "atm")
237        pressure_unit = au.get_multiplier(pressure_unit_str)*au.BOHR**3
238        header += f"  Press[{pressure_unit_str}]"
239    if variable_cell:
240        header += "   Density"
241    print(header)
242
243    ### Open trajectory file
244    traj_format = simulation_parameters.get("traj_format", "arc").lower()
245    if traj_format == "xyz":
246        traj_ext = ".traj.xyz"
247        write_frame = write_xyz_frame
248    elif traj_format == "extxyz":
249        traj_ext =  ".traj.extxyz"
250        write_frame = write_extxyz_frame
251    elif traj_format == "arc":
252        traj_ext =  ".arc"
253        write_frame = write_arc_frame
254    else:
255        raise ValueError(
256            f"Unknown trajectory format '{traj_format}'. Supported formats are 'arc' and 'xyz'"
257        )
258    
259    write_all_beads = simulation_parameters.get("write_all_beads", False) and pimd
260    
261    if write_all_beads:
262        fout = [open(f"{system_name}_bead{i+1:03d}"+traj_ext, "w") for i in range(nbeads)]
263    else:
264        fout = open(system_name+traj_ext, "a+")
265
266    ensemble_key = simulation_parameters.get("etot_ensemble_key", None)
267    if ensemble_key is not None:
268        fens = open(f"{system_name}.ensemble_weights.traj", "a+")
269
270
271    ### initialize proprerty trajectories
272    properties_traj = defaultdict(list)
273    if print_timings:
274        timings = defaultdict(lambda: 0.0)
275
276    ### initialize counters and timers
277    t0 = time.time()
278    t0dump = t0
279    istep = 0
280    t0full = time.time()
281    force_preprocess = False
282
283    for istep in range(1, nsteps + 1):
284
285        ### update the system
286        dyn_state, system, conformation, preproc_state = step(
287            istep, dyn_state, system, conformation, preproc_state, force_preprocess
288        )
289
290        ### print properties
291        if istep % nprint == 0:
292            t1 = time.time()
293            tperstep = (t1 - t0) / nprint
294            t0 = t1
295            nsperday = (24 * 60 * 60 / tperstep) * dt / 1e6
296
297            ek = system["ek"]
298            epot = system["epot"]
299            etot = ek + epot
300            temper = 2 * ek / (3.0 * nat) * au.KELVIN
301
302            th_state = system["thermostat"]
303            if include_thermostat_energy:
304                etherm = th_state["thermostat_energy"]
305                etot = etot + etherm
306
307            properties_traj[f"Etot[{atom_energy_unit_str}]"].append(
308                etot * atom_energy_unit
309            )
310            properties_traj[f"Epot[{atom_energy_unit_str}]"].append(
311                epot * atom_energy_unit
312            )
313            properties_traj[f"Ekin[{atom_energy_unit_str}]"].append(
314                ek * atom_energy_unit
315            )
316            properties_traj["Temper[Kelvin]"].append(temper)
317            if pimd:
318                ek_c = system["ek_c"]
319                temper_c = 2 * ek_c / (3.0 * nat) * au.KELVIN
320                properties_traj["Temper_c[Kelvin]"].append(temper_c)
321
322            ### construct line of properties
323            line = f"{istep:10.6g} {(start_time+istep*dt)/1000: 10.3f}  {etot*atom_energy_unit: #10.4f}  {epot*atom_energy_unit: #10.4f}  {ek*atom_energy_unit: #10.4f} {temper: 10.2f}"
324            if pimd:
325                line += f" {temper_c: 10.2f}"
326            if include_thermostat_energy:
327                line += f"  {etherm*atom_energy_unit: #10.4f}"
328                properties_traj[f"Etherm[{atom_energy_unit_str}]"].append(
329                    etherm * atom_energy_unit
330                )
331            if estimate_pressure:
332                pres = system["pressure"]*pressure_unit
333                properties_traj[f"Pressure[{pressure_unit_str}]"].append(pres)
334                if print_aniso_pressure:
335                    pres_tensor = system["pressure_tensor"]*pressure_unit
336                    pres_tensor = 0.5*(pres_tensor + pres_tensor.T)
337                    properties_traj[f"Pressure_xx[{pressure_unit_str}]"].append(pres_tensor[0,0])
338                    properties_traj[f"Pressure_yy[{pressure_unit_str}]"].append(pres_tensor[1,1])
339                    properties_traj[f"Pressure_zz[{pressure_unit_str}]"].append(pres_tensor[2,2])
340                    properties_traj[f"Pressure_xy[{pressure_unit_str}]"].append(pres_tensor[0,1])
341                    properties_traj[f"Pressure_xz[{pressure_unit_str}]"].append(pres_tensor[0,2])
342                    properties_traj[f"Pressure_yz[{pressure_unit_str}]"].append(pres_tensor[1,2])
343                line += f" {pres:10.3f}"
344            if variable_cell:
345                density = system["density"]
346                properties_traj["Density[g/cm^3]"].append(density)
347                if print_aniso_pressure:
348                    cell = system["cell"]
349                    properties_traj[f"Cell_Ax[Angstrom]"].append(cell[0,0])
350                    properties_traj[f"Cell_Ay[Angstrom]"].append(cell[0,1])
351                    properties_traj[f"Cell_Az[Angstrom]"].append(cell[0,2])
352                    properties_traj[f"Cell_Bx[Angstrom]"].append(cell[1,0])
353                    properties_traj[f"Cell_By[Angstrom]"].append(cell[1,1])
354                    properties_traj[f"Cell_Bz[Angstrom]"].append(cell[1,2])
355                    properties_traj[f"Cell_Cx[Angstrom]"].append(cell[2,0])
356                    properties_traj[f"Cell_Cy[Angstrom]"].append(cell[2,1])
357                    properties_traj[f"Cell_Cz[Angstrom]"].append(cell[2,2])
358                line += f" {density:10.4f}"
359                if "piston_temperature" in system["barostat"]:
360                    piston_temperature = system["barostat"]["piston_temperature"]
361                    properties_traj["T_piston[Kelvin]"].append(piston_temperature)
362
363            print(line)
364
365        ### save frame
366        if istep % ndump == 0:
367            line = "# Write XYZ frame"
368            if variable_cell:
369                cell = np.array(system["cell"])
370                reciprocal_cell = np.linalg.inv(cell)
371            if do_wrap_box:
372                if pimd:
373                    centroid = wrapbox(system["coordinates"][0], cell, reciprocal_cell)
374                    system["coordinates"] = system["coordinates"].at[0].set(centroid)
375                else:
376                    system["coordinates"] = wrapbox(
377                        system["coordinates"], cell, reciprocal_cell
378                    )
379                conformation = update_conformation(conformation,system)
380                line += " (atoms have been wrapped into the box)"
381                force_preprocess = True
382            print(line)
383            properties = {
384                "energy": float(system["epot"]) * energy_unit,
385                "Time": start_time + istep * dt,
386                "energy_unit": energy_unit_str,
387            }
388            
389            if write_all_beads:
390                coords = np.asarray(conformation["coordinates"].reshape(nbeads, nat, 3))
391                for i,fb in enumerate(fout):
392                    write_frame(
393                        fb,
394                        system_data["symbols"],
395                        coords[i],
396                        cell=cell,
397                        properties=properties,
398                        forces=None,  # np.asarray(system["forces"].reshape(nbeads, nat, 3)[0]) * energy_unit,
399                    )
400            else:
401                write_frame(
402                    fout,
403                    system_data["symbols"],
404                    np.asarray(conformation["coordinates"].reshape(nbeads, nat, 3)[0]),
405                    cell=cell,
406                    properties=properties,
407                    forces=None,  # np.asarray(system["forces"].reshape(nbeads, nat, 3)[0]) * energy_unit,
408                )
409            if ensemble_key is not None:
410                weights = " ".join([f'{w:.6f}' for w in system["ensemble_weights"].tolist()])
411                fens.write(f"{weights}\n")
412                fens.flush()
413
414
415        ### summary over last nsummary steps
416        if istep % (nsummary) == 0:
417            if check_nan(system):
418                raise ValueError(f"dynamics crashed at step {istep}.")
419            tfull = time.time() - t0full
420            t0full = time.time()
421            tperstep = tfull / (nsummary)
422            nsperday = (24 * 60 * 60 / tperstep) * dt / 1e6
423            elapsed_time = time.time() - tstart_dyn
424            estimated_remaining_time = tperstep * (nsteps - istep)
425            estimated_total_time = elapsed_time + estimated_remaining_time
426
427            print("#" * 50)
428            print(f"# Step {istep:_} of {nsteps:_}  ({istep/nsteps*100:.5g} %)")
429            print(f"# Tot. elapsed time   : {human_time_duration(elapsed_time)}")
430            print(
431                f"# Est. total time     : {human_time_duration(estimated_total_time)}"
432            )
433            print(
434                f"# Est. remaining time : {human_time_duration(estimated_remaining_time)}"
435            )
436            print(f"# Time for {nsummary:_} steps : {human_time_duration(tfull)}")
437
438            if print_timings:
439                print(f"# Detailed per-step timings :")
440                dsteps = nsummary
441                tother = tfull - sum([t for t in timings.values()])
442                timings["Other"] = tother
443                # sort timings
444                timings = {
445                    k: v
446                    for k, v in sorted(
447                        timings.items(), key=lambda item: item[1], reverse=True
448                    )
449                }
450                for k, v in timings.items():
451                    print(
452                        f"#   {k:15} : {human_time_duration(v/dsteps):>12} ({v/tfull*100:5.3g} %)"
453                    )
454                print(f"#   {'Total':15} : {human_time_duration(tfull/dsteps):>12}")
455                ## reset timings
456                timings = defaultdict(lambda: 0.0)
457            
458            corr_kin = system["thermostat"].get("corr_kin",None)
459            if corr_kin is not None:
460                print(f"# QTB kin. correction : {100*(corr_kin-1.):.2f} %")
461            print(f"# Averages over last {nsummary:_} steps :")
462            for k, v in properties_traj.items():
463                if len(properties_traj[k]) == 0:
464                    continue
465                mu = np.mean(properties_traj[k])
466                sig = np.std(properties_traj[k])
467                ksplit = k.split("[")
468                name = ksplit[0].strip()
469                unit = ksplit[1].replace("]", "").strip() if len(ksplit) > 1 else ""
470                print(f"#   {name:10} : {mu: #10.5g}   +/- {sig: #9.3g}  {unit}")
471
472            print(f"# Perf.: {nsperday:.2f} ns/day  ( {1.0 / tperstep:.2f} step/s )")
473            print("#" * 50)
474            if istep < nsteps:
475                print(header)
476            ## reset property trajectories
477            properties_traj = defaultdict(list)
478
479    print(f"# Run done in {human_time_duration(time.time()-tstart_dyn)}")
480    ### close trajectory file
481    fout.close()
482    if ensemble_key is not None:
483        fens.close()
484
485
if __name__ == "__main__":
    # Script entry point: run the command-line driver when executed directly.
    main()
def minmaxone(x, name=""):
    """Print ``name`` followed by the minimum, maximum and root-mean-square of ``x``.

    Args:
        x: array-like object exposing ``min``, ``max`` and ``mean`` methods.
        name: label printed before the three statistics.
    """
    lowest = x.min()
    highest = x.max()
    rms = (x**2).mean() ** 0.5
    print(name, lowest, highest, rms)
@jax.jit
def wrapbox(x, cell, reciprocal_cell):
    """Wrap Cartesian positions back into the periodic cell.

    Positions are converted to fractional coordinates with ``reciprocal_cell``,
    reduced to ``[0, 1)`` by subtracting their floor, and mapped back to
    Cartesian coordinates with ``cell``.

    Args:
        x: positions, one row per atom (row-vector convention: ``x @ reciprocal_cell``).
        cell: cell matrix used to map fractional back to Cartesian coordinates.
        reciprocal_cell: inverse of ``cell``.

    Returns:
        The wrapped positions, same shape as ``x``.
    """
    fractional = jnp.matmul(x, reciprocal_cell)
    fractional = fractional - jnp.floor(fractional)
    return jnp.matmul(fractional, cell)
def main():
    """Command-line entry point: ``fennol_md PARAM_FILE``.

    Reads the parameter file, configures the JAX device, floating-point
    precision and matmul precision from it, then runs :func:`dynamic`.
    """
    # os.environ["OMP_NUM_THREADS"] = "1"
    # Rebuild stdout on an unbuffered binary stream with write-through text
    # wrapping, so every printed line is emitted immediately.
    raw_stdout = open(sys.stdout.fileno(), "wb", 0)
    sys.stdout = io.TextIOWrapper(raw_stdout, write_through=True)

    ### Read the parameter file
    cli = argparse.ArgumentParser(prog="fennol_md")
    cli.add_argument("param_file", type=Path, help="Parameter file")
    params = parse_input(cli.parse_args().param_file)

    ### Set the device
    device: str = params.get("device", "cpu").lower()
    if device == "cpu":
        # hide every GPU from the process
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    elif device.startswith(("cuda", "gpu")):
        # "cuda:N" / "gpu:N" selects GPU N; a bare "cuda"/"gpu" selects GPU 0
        _, separator, gpu_index = device.rpartition(":")
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu_index if separator else "0"
        device = "gpu"

    jax.config.update("jax_default_device", jax.devices(device)[0])

    ### Set the precision
    enable_x64 = params.get("double_precision", False)
    jax.config.update("jax_enable_x64", enable_x64)
    fprec = "float64" if enable_x64 else "float32"

    matmul_precision = params.get("matmul_prec", "highest").lower()
    assert matmul_precision in ("default", "high", "highest"), (
        "matmul_prec must be one of 'default','high','highest'"
    )
    if matmul_precision != "highest":
        print(f"# Setting matmul precision to '{matmul_precision}'")
        if matmul_precision == "default" and fprec == "float32":
            print(
                "# Warning: default matmul precision involves float16 operations which may lead to large numerical errors on energy and pressure estimations ! It is recommended to set matmul_prec to 'high' or 'highest'."
            )
    jax.config.update("jax_default_matmul_precision", matmul_precision)

    dynamic(params, device, fprec)
def dynamic(simulation_parameters, device, fprec):
 98def dynamic(simulation_parameters, device, fprec):
 99    tstart_dyn = time.time()
100
101    ### Initialize the model
102    model = load_model(simulation_parameters)
103
104    ### Get the coordinates and species from the xyz file
105    system_data, conformation = load_system_data(simulation_parameters, fprec)
106    nat = system_data["nat"]
107
108    preproc_state, conformation = initialize_preprocessing(
109        simulation_parameters, model, conformation, system_data
110    )
111
112    random_seed = simulation_parameters.get(
113        "random_seed", np.random.randint(0, 2**32 - 1)
114    )
115    print(f"# random_seed: {random_seed}")
116    rng_key = jax.random.PRNGKey(random_seed)
117    rng_key, subkey = jax.random.split(rng_key)
118    ## INITIALIZE INTEGRATOR AND SYSTEM
119    step, update_conformation, dyn_state, system = initialize_dynamics(
120        simulation_parameters, system_data, conformation, model, fprec, subkey
121    )
122
123    dt = dyn_state["dt"]
124    ## get number of steps
125    nsteps = int(simulation_parameters.get("nsteps"))
126    start_time = 0.0
127    start_step = 0
128
129
130    ### Set I/O parameters
131    Tdump = simulation_parameters.get("tdump", 1.0 / au.PS) * au.FS
132    ndump = int(Tdump / dt)
133    system_name = system_data["name"]
134
135    model_energy_unit = au.get_multiplier(model.energy_unit)
136    ### Print initial pressure
137    estimate_pressure = dyn_state["estimate_pressure"]
138    if estimate_pressure and fprec == "float64":
139        volume = system_data["pbc"]["volume"]
140        coordinates = conformation["coordinates"]
141        cell = conformation["cells"][0]
142        # temper = 2 * ek / (3.0 * nat) * au.KELVIN
143        ek = 1.5*nat * system_data["kT"] 
144        Pkin = (2 * au.KBAR) * ek / ((3.0 / au.BOHR**3) * volume)
145        e, f, vir_t, _ = model._energy_and_forces_and_virial(
146            model.variables, conformation
147        )
148        KBAR = au.KBAR/model_energy_unit
149        Pvir = -(np.trace(vir_t[0]) * KBAR) / ((3.0 / au.BOHR**3) * volume)
150        vstep = volume * 0.000001
151        scalep = ((volume + vstep) / volume) ** (1.0 / 3.0)
152        cellp = cell * scalep
153        reciprocal_cell = np.linalg.inv(cellp)
154        sysp = model.preprocess(
155            **{
156                **conformation,
157                "coordinates": coordinates * scalep,
158                "cells": cellp[None, :, :],
159                "reciprocal_cells": reciprocal_cell[None, :, :],
160            }
161        )
162        ep, _ = model._total_energy(model.variables, sysp)
163        scalem = ((volume - vstep) / volume) ** (1.0 / 3.0)
164        cellm = cell * scalem
165        reciprocal_cell = np.linalg.inv(cellm)
166        sysm = model.preprocess(
167            **{
168                **conformation,
169                "coordinates": coordinates * scalem,
170                "cells": cellm[None, :, :],
171                "reciprocal_cells": reciprocal_cell[None, :, :],
172            }
173        )
174        em, _ = model._total_energy(model.variables, sysm)
175        Pvir_fd = -(ep[0] * KBAR - em[0] * KBAR) / (2.0 * vstep / au.BOHR**3)
176        print(
177            f"# Initial pressure: {Pkin+Pvir:.3f} (virial); {Pkin+Pvir_fd:.3f} (finite difference) ; Pkin: {Pkin:.3f} ; Pvir: {Pvir:.3f} ; Pvir_fd: {Pvir_fd:.3f}"
178        )
179
180    @jax.jit
181    def check_nan(system):
182        return jnp.any(jnp.isnan(system["vel"])) | jnp.any(
183            jnp.isnan(system["coordinates"])
184        )
185
186    if system_data["pbc"] is not None:
187        cell = system_data["pbc"]["cell"]
188        reciprocal_cell = system_data["pbc"]["reciprocal_cell"]
189        do_wrap_box = simulation_parameters.get("wrap_box", False)
190    else:
191        cell = None
192        reciprocal_cell = None
193        do_wrap_box = False
194
195    ### Energy units and print initial energy
196    per_atom_energy = simulation_parameters.get("per_atom_energy", True)
197    energy_unit_str = simulation_parameters.get("energy_unit", "kcal/mol")
198    print("# Energy unit: ", energy_unit_str)
199    energy_unit = au.get_multiplier(energy_unit_str)
200    atom_energy_unit = energy_unit
201    atom_energy_unit_str = energy_unit_str
202    if per_atom_energy:
203        atom_energy_unit /= nat
204        atom_energy_unit_str = f"{energy_unit_str}/atom"
205        print("# Printing Energy per atom")
206    print(
207        f"# Initial potential energy: {system['epot']*atom_energy_unit}; kinetic energy: {system['ek']*atom_energy_unit}"
208    )
209    f = system["forces"]
210    minmaxone(jnp.abs(f * energy_unit), "# forces min/max/rms:")
211
212    ## printing options
213    print_timings = simulation_parameters.get("print_timings", False)
214    nprint = int(simulation_parameters.get("nprint", 10))
215    assert nprint > 0, "nprint must be > 0"
216    nsummary = simulation_parameters.get("nsummary", 100 * nprint)
217    assert nsummary > nprint, "nsummary must be > nprint"
218
219    ### Print header
220    include_thermostat_energy = "thermostat_energy" in system["thermostat"]
221    thermostat_name = dyn_state["thermostat_name"]
222    pimd = dyn_state["pimd"]
223    variable_cell = dyn_state["variable_cell"]
224    nbeads = system_data.get("nbeads", 1)
225    dyn_name = "PIMD" if pimd else "MD"
226    print("#" * 84)
227    print(
228        f"# Running {nsteps:_} steps of {thermostat_name} {dyn_name} simulation on {device}"
229    )
230    header = "#     Step   Time[ps]        Etot        Epot        Ekin    Temp[K]"
231    if pimd:
232        header += "  Temp_c[K]"
233    if include_thermostat_energy:
234        header += "      Etherm"
235    if estimate_pressure:
236        print_aniso_pressure = simulation_parameters.get("print_aniso_pressure", False)
237        pressure_unit_str = simulation_parameters.get("pressure_unit", "atm")
238        pressure_unit = au.get_multiplier(pressure_unit_str)*au.BOHR**3
239        header += f"  Press[{pressure_unit_str}]"
240    if variable_cell:
241        header += "   Density"
242    print(header)
243
244    ### Open trajectory file
245    traj_format = simulation_parameters.get("traj_format", "arc").lower()
246    if traj_format == "xyz":
247        traj_ext = ".traj.xyz"
248        write_frame = write_xyz_frame
249    elif traj_format == "extxyz":
250        traj_ext =  ".traj.extxyz"
251        write_frame = write_extxyz_frame
252    elif traj_format == "arc":
253        traj_ext =  ".arc"
254        write_frame = write_arc_frame
255    else:
256        raise ValueError(
257            f"Unknown trajectory format '{traj_format}'. Supported formats are 'arc' and 'xyz'"
258        )
259    
260    write_all_beads = simulation_parameters.get("write_all_beads", False) and pimd
261    
262    if write_all_beads:
263        fout = [open(f"{system_name}_bead{i+1:03d}"+traj_ext, "w") for i in range(nbeads)]
264    else:
265        fout = open(system_name+traj_ext, "a+")
266
267    ensemble_key = simulation_parameters.get("etot_ensemble_key", None)
268    if ensemble_key is not None:
269        fens = open(f"{system_name}.ensemble_weights.traj", "a+")
270
271
272    ### initialize proprerty trajectories
273    properties_traj = defaultdict(list)
274    if print_timings:
275        timings = defaultdict(lambda: 0.0)
276
277    ### initialize counters and timers
278    t0 = time.time()
279    t0dump = t0
280    istep = 0
281    t0full = time.time()
282    force_preprocess = False
283
284    for istep in range(1, nsteps + 1):
285
286        ### update the system
287        dyn_state, system, conformation, preproc_state = step(
288            istep, dyn_state, system, conformation, preproc_state, force_preprocess
289        )
290
291        ### print properties
292        if istep % nprint == 0:
293            t1 = time.time()
294            tperstep = (t1 - t0) / nprint
295            t0 = t1
296            nsperday = (24 * 60 * 60 / tperstep) * dt / 1e6
297
298            ek = system["ek"]
299            epot = system["epot"]
300            etot = ek + epot
301            temper = 2 * ek / (3.0 * nat) * au.KELVIN
302
303            th_state = system["thermostat"]
304            if include_thermostat_energy:
305                etherm = th_state["thermostat_energy"]
306                etot = etot + etherm
307
308            properties_traj[f"Etot[{atom_energy_unit_str}]"].append(
309                etot * atom_energy_unit
310            )
311            properties_traj[f"Epot[{atom_energy_unit_str}]"].append(
312                epot * atom_energy_unit
313            )
314            properties_traj[f"Ekin[{atom_energy_unit_str}]"].append(
315                ek * atom_energy_unit
316            )
317            properties_traj["Temper[Kelvin]"].append(temper)
318            if pimd:
319                ek_c = system["ek_c"]
320                temper_c = 2 * ek_c / (3.0 * nat) * au.KELVIN
321                properties_traj["Temper_c[Kelvin]"].append(temper_c)
322
323            ### construct line of properties
324            line = f"{istep:10.6g} {(start_time+istep*dt)/1000: 10.3f}  {etot*atom_energy_unit: #10.4f}  {epot*atom_energy_unit: #10.4f}  {ek*atom_energy_unit: #10.4f} {temper: 10.2f}"
325            if pimd:
326                line += f" {temper_c: 10.2f}"
327            if include_thermostat_energy:
328                line += f"  {etherm*atom_energy_unit: #10.4f}"
329                properties_traj[f"Etherm[{atom_energy_unit_str}]"].append(
330                    etherm * atom_energy_unit
331                )
332            if estimate_pressure:
333                pres = system["pressure"]*pressure_unit
334                properties_traj[f"Pressure[{pressure_unit_str}]"].append(pres)
335                if print_aniso_pressure:
336                    pres_tensor = system["pressure_tensor"]*pressure_unit
337                    pres_tensor = 0.5*(pres_tensor + pres_tensor.T)
338                    properties_traj[f"Pressure_xx[{pressure_unit_str}]"].append(pres_tensor[0,0])
339                    properties_traj[f"Pressure_yy[{pressure_unit_str}]"].append(pres_tensor[1,1])
340                    properties_traj[f"Pressure_zz[{pressure_unit_str}]"].append(pres_tensor[2,2])
341                    properties_traj[f"Pressure_xy[{pressure_unit_str}]"].append(pres_tensor[0,1])
342                    properties_traj[f"Pressure_xz[{pressure_unit_str}]"].append(pres_tensor[0,2])
343                    properties_traj[f"Pressure_yz[{pressure_unit_str}]"].append(pres_tensor[1,2])
344                line += f" {pres:10.3f}"
345            if variable_cell:
346                density = system["density"]
347                properties_traj["Density[g/cm^3]"].append(density)
348                if print_aniso_pressure:
349                    cell = system["cell"]
350                    properties_traj[f"Cell_Ax[Angstrom]"].append(cell[0,0])
351                    properties_traj[f"Cell_Ay[Angstrom]"].append(cell[0,1])
352                    properties_traj[f"Cell_Az[Angstrom]"].append(cell[0,2])
353                    properties_traj[f"Cell_Bx[Angstrom]"].append(cell[1,0])
354                    properties_traj[f"Cell_By[Angstrom]"].append(cell[1,1])
355                    properties_traj[f"Cell_Bz[Angstrom]"].append(cell[1,2])
356                    properties_traj[f"Cell_Cx[Angstrom]"].append(cell[2,0])
357                    properties_traj[f"Cell_Cy[Angstrom]"].append(cell[2,1])
358                    properties_traj[f"Cell_Cz[Angstrom]"].append(cell[2,2])
359                line += f" {density:10.4f}"
360                if "piston_temperature" in system["barostat"]:
361                    piston_temperature = system["barostat"]["piston_temperature"]
362                    properties_traj["T_piston[Kelvin]"].append(piston_temperature)
363
364            print(line)
365
366        ### save frame
367        if istep % ndump == 0:
368            line = "# Write XYZ frame"
369            if variable_cell:
370                cell = np.array(system["cell"])
371                reciprocal_cell = np.linalg.inv(cell)
372            if do_wrap_box:
373                if pimd:
374                    centroid = wrapbox(system["coordinates"][0], cell, reciprocal_cell)
375                    system["coordinates"] = system["coordinates"].at[0].set(centroid)
376                else:
377                    system["coordinates"] = wrapbox(
378                        system["coordinates"], cell, reciprocal_cell
379                    )
380                conformation = update_conformation(conformation,system)
381                line += " (atoms have been wrapped into the box)"
382                force_preprocess = True
383            print(line)
384            properties = {
385                "energy": float(system["epot"]) * energy_unit,
386                "Time": start_time + istep * dt,
387                "energy_unit": energy_unit_str,
388            }
389            
390            if write_all_beads:
391                coords = np.asarray(conformation["coordinates"].reshape(nbeads, nat, 3))
392                for i,fb in enumerate(fout):
393                    write_frame(
394                        fb,
395                        system_data["symbols"],
396                        coords[i],
397                        cell=cell,
398                        properties=properties,
399                        forces=None,  # np.asarray(system["forces"].reshape(nbeads, nat, 3)[0]) * energy_unit,
400                    )
401            else:
402                write_frame(
403                    fout,
404                    system_data["symbols"],
405                    np.asarray(conformation["coordinates"].reshape(nbeads, nat, 3)[0]),
406                    cell=cell,
407                    properties=properties,
408                    forces=None,  # np.asarray(system["forces"].reshape(nbeads, nat, 3)[0]) * energy_unit,
409                )
410            if ensemble_key is not None:
411                weights = " ".join([f'{w:.6f}' for w in system["ensemble_weights"].tolist()])
412                fens.write(f"{weights}\n")
413                fens.flush()
414
415
416        ### summary over last nsummary steps
417        if istep % (nsummary) == 0:
418            if check_nan(system):
419                raise ValueError(f"dynamics crashed at step {istep}.")
420            tfull = time.time() - t0full
421            t0full = time.time()
422            tperstep = tfull / (nsummary)
423            nsperday = (24 * 60 * 60 / tperstep) * dt / 1e6
424            elapsed_time = time.time() - tstart_dyn
425            estimated_remaining_time = tperstep * (nsteps - istep)
426            estimated_total_time = elapsed_time + estimated_remaining_time
427
428            print("#" * 50)
429            print(f"# Step {istep:_} of {nsteps:_}  ({istep/nsteps*100:.5g} %)")
430            print(f"# Tot. elapsed time   : {human_time_duration(elapsed_time)}")
431            print(
432                f"# Est. total time     : {human_time_duration(estimated_total_time)}"
433            )
434            print(
435                f"# Est. remaining time : {human_time_duration(estimated_remaining_time)}"
436            )
437            print(f"# Time for {nsummary:_} steps : {human_time_duration(tfull)}")
438
439            if print_timings:
440                print(f"# Detailed per-step timings :")
441                dsteps = nsummary
442                tother = tfull - sum([t for t in timings.values()])
443                timings["Other"] = tother
444                # sort timings
445                timings = {
446                    k: v
447                    for k, v in sorted(
448                        timings.items(), key=lambda item: item[1], reverse=True
449                    )
450                }
451                for k, v in timings.items():
452                    print(
453                        f"#   {k:15} : {human_time_duration(v/dsteps):>12} ({v/tfull*100:5.3g} %)"
454                    )
455                print(f"#   {'Total':15} : {human_time_duration(tfull/dsteps):>12}")
456                ## reset timings
457                timings = defaultdict(lambda: 0.0)
458            
459            corr_kin = system["thermostat"].get("corr_kin",None)
460            if corr_kin is not None:
461                print(f"# QTB kin. correction : {100*(corr_kin-1.):.2f} %")
462            print(f"# Averages over last {nsummary:_} steps :")
463            for k, v in properties_traj.items():
464                if len(properties_traj[k]) == 0:
465                    continue
466                mu = np.mean(properties_traj[k])
467                sig = np.std(properties_traj[k])
468                ksplit = k.split("[")
469                name = ksplit[0].strip()
470                unit = ksplit[1].replace("]", "").strip() if len(ksplit) > 1 else ""
471                print(f"#   {name:10} : {mu: #10.5g}   +/- {sig: #9.3g}  {unit}")
472
473            print(f"# Perf.: {nsperday:.2f} ns/day  ( {1.0 / tperstep:.2f} step/s )")
474            print("#" * 50)
475            if istep < nsteps:
476                print(header)
477            ## reset property trajectories
478            properties_traj = defaultdict(list)
479
480    print(f"# Run done in {human_time_duration(time.time()-tstart_dyn)}")
481    ### close trajectory file
482    fout.close()
483    if ensemble_key is not None:
484        fens.close()