fennol.md.dynamic
import sys
import os
import io
import argparse
import time
import pickle
from collections import defaultdict
from pathlib import Path

import numpy as np
import jax
import jax.numpy as jnp
import yaml

from ..utils.io import (
    write_arc_frame,
    write_xyz_frame,
    write_extxyz_frame,
    human_time_duration,
)
from .utils import wrapbox, save_dynamics_restart
from ..utils import minmaxone, AtomicUnits as au
from ..utils.input_parser import parse_input, convert_dict_units, InputFile
from .integrate import initialize_dynamics


def main():
    """Command-line entry point of ``fennol_md``.

    Reads the single positional ``param_file`` argument and loads the
    simulation parameters from it: YAML for ``.yaml``/``.yml``, and
    ``parse_input`` for ``.fnl``. It then selects the JAX device and
    numerical precision and runs :func:`dynamic`.

    Raises:
        FileNotFoundError: if ``param_file`` is not an existing regular file.
        ValueError: if the file extension is not supported or
            ``matmul_prec`` is not one of 'default', 'high', 'highest'.
    """
    # Unbuffered, write-through stdout so progress lines appear immediately
    # even when the output is redirected to a file.
    sys.stdout = io.TextIOWrapper(
        open(sys.stdout.fileno(), "wb", 0), write_through=True
    )

    ### Read the parameter file
    parser = argparse.ArgumentParser(prog="fennol_md")
    parser.add_argument("param_file", type=Path, help="Parameter file")
    args = parser.parse_args()

    param_file = args.param_file
    # is_file() is False both for missing paths and for directories; the
    # previous `not exists() and not is_file()` test let a directory through.
    if not param_file.is_file():
        raise FileNotFoundError(f"Parameter file {param_file} not found")

    if param_file.suffix in (".yaml", ".yml"):
        with open(param_file, "r") as f:
            simulation_parameters = convert_dict_units(yaml.safe_load(f))
        simulation_parameters = InputFile(**simulation_parameters)
    elif param_file.suffix == ".fnl":
        simulation_parameters = parse_input(param_file)
    else:
        raise ValueError(
            f"Unknown parameter file format '{param_file.suffix}'. Supported formats are '.yaml', '.yml' and '.fnl'"
        )

    ### Set the device
    # CUDA_VISIBLE_DEVICES is set before jax.devices() is called below
    # (assumes no JAX backend was initialized earlier in the process).
    device: str = simulation_parameters.get("device", "cpu").lower()
    if device == "cpu":
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    elif device.startswith(("cuda", "gpu")):
        # "cuda:N" / "gpu:N" exposes GPU N only; bare "cuda"/"gpu" means GPU 0.
        os.environ["CUDA_VISIBLE_DEVICES"] = (
            device.split(":")[-1] if ":" in device else "0"
        )
        device = "gpu"

    _device = jax.devices(device)[0]
    jax.config.update("jax_default_device", _device)

    ### Set the precision
    enable_x64 = simulation_parameters.get("double_precision", False)
    jax.config.update("jax_enable_x64", enable_x64)
    fprec = "float64" if enable_x64 else "float32"

    matmul_precision = simulation_parameters.get("matmul_prec", "highest").lower()
    if matmul_precision not in ("default", "high", "highest"):
        # explicit raise: an assert would be stripped under `python -O`
        raise ValueError("matmul_prec must be one of 'default','high','highest'")
    if matmul_precision != "highest":
        print(f"# Setting matmul precision to '{matmul_precision}'")
    if matmul_precision == "default" and fprec == "float32":
        print(
            "# Warning: default matmul precision involves float16 operations which may lead to large numerical errors on energy and pressure estimations ! It is recommended to set matmul_prec to 'high' or 'highest'."
        )
    jax.config.update("jax_default_matmul_precision", matmul_precision)

    dynamic(simulation_parameters, device, fprec)


def dynamic(simulation_parameters, device, fprec):
    """Run an MD (or PIMD) simulation and write its outputs.

    The integrator, model and initial state come from
    ``initialize_dynamics``. This function then advances ``nsteps`` steps
    and handles all reporting:

    * Every ``nprint`` steps: one line of instantaneous properties on
      stdout. If requested, it also writes colvars
      (``<name>.colvars.traj``) and pickled ``save_keys`` model outputs
      (``<name>.traj.pkl``).
    * Every ``ndump`` steps (derived from ``tdump``): a restart file and a
      trajectory frame (one file per bead with ``write_all_beads``). If
      requested, it also writes a centroid frame and ensemble weights.
    * Every ``nsummary`` steps: a NaN check on velocities and coordinates,
      timing estimates, and the mean/std of the properties collected since
      the previous summary.

    Args:
        simulation_parameters: parsed input. It is only accessed through
            ``.get`` here and is also forwarded to ``initialize_dynamics``.
        device: device label, only used in the run banner.
        fprec: floating-point precision name forwarded to
            ``initialize_dynamics``. ``main`` passes "float32" or "float64".

    Raises:
        ValueError: for invalid ``nprint``/``nsummary``, an unknown
            ``traj_format``, or if NaNs are detected at a summary step.
    """
    from contextlib import ExitStack

    tstart_dyn = time.time()

    # Draw a seed only when none is configured. dtype=int64 keeps the upper
    # bound representable where numpy's default integer is 32-bit.
    random_seed = simulation_parameters.get("random_seed", None)
    if random_seed is None:
        random_seed = int(np.random.randint(0, 2**32 - 1, dtype=np.int64))
    print(f"# random_seed: {random_seed}")
    rng_key = jax.random.PRNGKey(random_seed)

    ## INITIALIZE INTEGRATOR AND SYSTEM
    rng_key, subkey = jax.random.split(rng_key)
    step, update_conformation, system_data, dyn_state, conformation, system = (
        initialize_dynamics(simulation_parameters, fprec, subkey)
    )

    nat = system_data["nat"]
    dt = dyn_state["dt"]
    nsteps = int(simulation_parameters.get("nsteps"))
    start_time_ps = dyn_state.get("start_time_ps", 0.0)

    ### Set I/O parameters
    Tdump = simulation_parameters.get("tdump", 1.0 / au.PS) * au.FS
    # Round instead of truncating, so a ratio like 199.9999 gives 200. Keep
    # at least 1, otherwise `istep % ndump` raises ZeroDivisionError.
    ndump = max(1, int(round(Tdump / dt)))
    system_name = system_data["name"]
    estimate_pressure = dyn_state["estimate_pressure"]

    @jax.jit
    def check_nan(system):
        return jnp.any(jnp.isnan(system["vel"])) | jnp.any(
            jnp.isnan(system["coordinates"])
        )

    if system_data["pbc"] is not None:
        cell = system["cell"]
        reciprocal_cell = np.linalg.inv(cell)
        do_wrap_box = simulation_parameters.get("wrap_box", False)
    else:
        cell = None
        reciprocal_cell = None
        do_wrap_box = False

    ### Energy units and print initial energy
    per_atom_energy = simulation_parameters.get("per_atom_energy", True)
    energy_unit_str = system_data["energy_unit_str"]
    energy_unit = system_data["energy_unit"]
    print("# Energy unit: ", energy_unit_str)
    atom_energy_unit = energy_unit
    atom_energy_unit_str = energy_unit_str
    if per_atom_energy:
        # out-of-place division: never mutate energy_unit through an alias
        atom_energy_unit = energy_unit / nat
        atom_energy_unit_str = f"{energy_unit_str}/atom"
        print("# Printing Energy per atom")
    print(
        f"# Initial potential energy: {system['epot']*atom_energy_unit}; kinetic energy: {system['ek']*atom_energy_unit}"
    )
    minmaxone(jnp.abs(system["forces"] * energy_unit), "# forces min/max/rms:")

    ## printing options
    nprint = int(simulation_parameters.get("nprint", 10))
    if nprint <= 0:
        raise ValueError("nprint must be > 0")
    nsummary = int(simulation_parameters.get("nsummary", 100 * nprint))
    if nsummary <= nprint:
        raise ValueError("nsummary must be > nprint")

    save_keys = simulation_parameters.get("save_keys", [])
    if save_keys:
        print(f"# Saving keys: {save_keys}")

    ### colvars
    use_colvars = "colvars" in dyn_state
    if use_colvars:
        print(f"# Colvars: {dyn_state['colvars']}")
        colvars_names = dyn_state["colvars"]

    ### Print header
    include_thermostat_energy = "thermostat_energy" in system["thermostat"]
    thermostat_name = dyn_state["thermostat_name"]
    pimd = dyn_state["pimd"]
    variable_cell = dyn_state["variable_cell"]
    nbeads = system_data.get("nbeads", 1)
    dyn_name = "PIMD" if pimd else "MD"
    print("#" * 84)
    print(
        f"# Running {nsteps:_} steps of {thermostat_name} {dyn_name} simulation on {device}"
    )
    header = "# Step Time[ps] Etot Epot Ekin Temp[K]"
    if pimd:
        header += " Temp_c[K]"
    if include_thermostat_energy:
        header += " Etherm"
    # Defined unconditionally: the variable-cell reporting below also reads
    # it, and it used to be a NameError when pressure was not estimated.
    print_aniso_pressure = False
    if estimate_pressure:
        print_aniso_pressure = simulation_parameters.get("print_aniso_pressure", False)
        pressure_unit_str = simulation_parameters.get("pressure_unit", "atm")
        pressure_unit = au.get_multiplier(pressure_unit_str) * au.BOHR**3
        header += f" Press[{pressure_unit_str}]"
    if variable_cell:
        header += " Density"
    print(header)

    ### Trajectory format
    traj_format = simulation_parameters.get("traj_format", "arc").lower()
    trajectory_writers = {
        "xyz": (".traj.xyz", write_xyz_frame),
        "extxyz": (".traj.extxyz", write_extxyz_frame),
        "arc": (".arc", write_arc_frame),
    }
    if traj_format not in trajectory_writers:
        raise ValueError(
            f"Unknown trajectory format '{traj_format}'. Supported formats are 'arc', 'xyz' and 'extxyz'"
        )
    traj_ext, write_frame = trajectory_writers[traj_format]

    write_all_beads = simulation_parameters.get("write_all_beads", False) and pimd
    ensemble_key = simulation_parameters.get("etot_ensemble_key", None)
    write_centroid = simulation_parameters.get("write_centroid", False) and pimd

    # (label, row, col) of the symmetrized pressure-tensor components reported
    pressure_components = (
        ("xx", 0, 0), ("yy", 1, 1), ("zz", 2, 2),
        ("xy", 0, 1), ("xz", 0, 2), ("yz", 1, 2),
    )

    # Every output file is registered on the stack. All of them (including
    # each per-bead trajectory) are closed at the end, even if the run raises.
    with ExitStack() as stack:

        def open_output(path, mode="a"):
            return stack.enter_context(open(path, mode))

        fkeys = open_output(f"{system_name}.traj.pkl", "wb+") if save_keys else None

        if use_colvars:
            # appended to across runs, so a header line is written per run
            fcolvars = open_output(f"{system_name}.colvars.traj")
            fcolvars.write("#time[ps] " + " ".join(colvars_names) + "\n")
            fcolvars.flush()

        if write_all_beads:
            fout = [
                open_output(f"{system_name}_bead{i+1:03d}" + traj_ext)
                for i in range(nbeads)
            ]
        else:
            fout = open_output(system_name + traj_ext)

        fens = (
            open_output(f"{system_name}.ensemble_weights.traj")
            if ensemble_key is not None
            else None
        )
        fcentroid = (
            open_output(f"{system_name}_centroid" + traj_ext) if write_centroid else None
        )

        ### property trajectories, reset after every summary
        properties_traj = defaultdict(list)

        t0full = time.time()
        # NOTE(review): never reset to False once a box wrap happened;
        # confirm whether `step` expects it only for the step after a wrap.
        force_preprocess = False

        for istep in range(1, nsteps + 1):

            ### update the system
            dyn_state, system, conformation, model_output = step(
                istep, dyn_state, system, conformation, force_preprocess
            )

            ### print properties
            if istep % nprint == 0:
                ek = system["ek"]
                epot = system["epot"]
                etot = ek + epot
                temper = 2 * ek / (3.0 * nat) * au.KELVIN

                if include_thermostat_energy:
                    etherm = system["thermostat"]["thermostat_energy"]
                    etot = etot + etherm

                properties_traj[f"Etot[{atom_energy_unit_str}]"].append(
                    etot * atom_energy_unit
                )
                properties_traj[f"Epot[{atom_energy_unit_str}]"].append(
                    epot * atom_energy_unit
                )
                properties_traj[f"Ekin[{atom_energy_unit_str}]"].append(
                    ek * atom_energy_unit
                )
                properties_traj["Temper[Kelvin]"].append(temper)
                if pimd:
                    ek_c = system["ek_c"]
                    temper_c = 2 * ek_c / (3.0 * nat) * au.KELVIN
                    properties_traj["Temper_c[Kelvin]"].append(temper_c)

                ### construct line of properties
                simulated_time = start_time_ps + istep * dt / 1000
                # integer format: 'g' printed step 1_000_000 as "1e+06"
                line = f"{istep:10d} {simulated_time: 10.3f} {etot*atom_energy_unit: #10.4f} {epot*atom_energy_unit: #10.4f} {ek*atom_energy_unit: #10.4f} {temper: 10.2f}"
                if pimd:
                    line += f" {temper_c: 10.2f}"
                if include_thermostat_energy:
                    line += f" {etherm*atom_energy_unit: #10.4f}"
                    properties_traj[f"Etherm[{atom_energy_unit_str}]"].append(
                        etherm * atom_energy_unit
                    )
                if estimate_pressure:
                    pres = system["pressure"] * pressure_unit
                    properties_traj[f"Pressure[{pressure_unit_str}]"].append(pres)
                    if print_aniso_pressure:
                        pres_tensor = system["pressure_tensor"] * pressure_unit
                        pres_tensor = 0.5 * (pres_tensor + pres_tensor.T)
                        for comp, i, j in pressure_components:
                            properties_traj[
                                f"Pressure_{comp}[{pressure_unit_str}]"
                            ].append(pres_tensor[i, j])
                    line += f" {pres:10.3f}"
                if variable_cell:
                    density = system["density"]
                    properties_traj["Density[g/cm^3]"].append(density)
                    if print_aniso_pressure:
                        current_cell = system["cell"]
                        for i, vec in enumerate("ABC"):
                            for j, axis in enumerate("xyz"):
                                properties_traj[
                                    f"Cell_{vec}{axis}[Angstrom]"
                                ].append(current_cell[i, j])
                    line += f" {density:10.4f}"
                    if "piston_temperature" in system["barostat"]:
                        piston_temperature = system["barostat"]["piston_temperature"]
                        properties_traj["T_piston[Kelvin]"].append(piston_temperature)

                print(line)

                ### save colvars
                if use_colvars:
                    colvars = system["colvars"]
                    fcolvars.write(f"{simulated_time:.3f} ")
                    fcolvars.write(" ".join([f"{colvars[k]:.6f}" for k in colvars_names]))
                    fcolvars.write("\n")
                    fcolvars.flush()

                if save_keys:
                    data = {
                        k: (
                            np.asarray(model_output[k])
                            if isinstance(model_output[k], jnp.ndarray)
                            else model_output[k]
                        )
                        for k in save_keys
                    }
                    # header words (minus the leading "#") paired with values
                    data["properties"] = {
                        k: float(v) for k, v in zip(header.split()[1:], line.split())
                    }
                    pickle.dump(data, fkeys)

            ### save frame
            if istep % ndump == 0:
                dump_msg = "# Write XYZ frame"
                if variable_cell:
                    cell = np.array(system["cell"])
                    reciprocal_cell = np.linalg.inv(cell)
                if do_wrap_box:
                    if pimd:
                        centroid = wrapbox(system["coordinates"][0], cell, reciprocal_cell)
                        system["coordinates"] = system["coordinates"].at[0].set(centroid)
                    else:
                        system["coordinates"] = wrapbox(
                            system["coordinates"], cell, reciprocal_cell
                        )
                    conformation = update_conformation(conformation, system)
                    dump_msg += " (atoms have been wrapped into the box)"
                    force_preprocess = True
                print(dump_msg)

                save_dynamics_restart(system_data, conformation, dyn_state, system)

                properties = {
                    "energy": float(system["epot"]) * energy_unit,
                    "Time_ps": start_time_ps + istep * dt / 1000,
                    "energy_unit": energy_unit_str,
                }

                if write_all_beads:
                    coords = np.asarray(conformation["coordinates"].reshape(-1, nat, 3))
                    for i, fb in enumerate(fout):
                        write_frame(
                            fb,
                            system_data["symbols"],
                            coords[i],
                            cell=cell,
                            properties=properties,
                            forces=None,
                        )
                else:
                    write_frame(
                        fout,
                        system_data["symbols"],
                        np.asarray(conformation["coordinates"].reshape(-1, nat, 3)[0]),
                        cell=cell,
                        properties=properties,
                        forces=None,
                    )
                if write_centroid:
                    centroid = np.asarray(system["coordinates"][0])
                    write_frame(
                        fcentroid,
                        system_data["symbols"],
                        centroid,
                        cell=cell,
                        properties=properties,
                        forces=np.asarray(system["forces"].reshape(nbeads, nat, 3)[0])
                        * energy_unit,
                    )
                if ensemble_key is not None:
                    weights = " ".join(
                        [f"{w:.6f}" for w in system["ensemble_weights"].tolist()]
                    )
                    fens.write(f"{weights}\n")
                    fens.flush()

            ### summary over last nsummary steps
            if istep % nsummary == 0:
                if check_nan(system):
                    raise ValueError(f"dynamics crashed at step {istep}.")
                tfull = time.time() - t0full
                t0full = time.time()
                tperstep = tfull / nsummary
                # steps/day * dt / 1e6 (dt is divided by 1000 to get ps above)
                nsperday = (24 * 60 * 60 / tperstep) * dt / 1e6
                elapsed_time = time.time() - tstart_dyn
                estimated_remaining_time = tperstep * (nsteps - istep)
                estimated_total_time = elapsed_time + estimated_remaining_time

                print("#" * 50)
                print(f"# Step {istep:_} of {nsteps:_} ({istep/nsteps*100:.5g} %)")
                print(f"# Simulated time      : {istep * dt*1.e-3:.3f} ps")
                print(f"# Tot. Simu. time     : {start_time_ps + istep * dt*1.e-3:.3f} ps")
                print(f"# Tot. elapsed time   : {human_time_duration(elapsed_time)}")
                print(
                    f"# Est. total duration   : {human_time_duration(estimated_total_time)}"
                )
                print(
                    f"# Est. remaining time : {human_time_duration(estimated_remaining_time)}"
                )
                print(f"# Time for {nsummary:_} steps : {human_time_duration(tfull)}")

                corr_kin = system["thermostat"].get("corr_kin", None)
                if corr_kin is not None:
                    print(f"# QTB kin. correction : {100*(corr_kin-1.):.2f} %")
                print(f"# Averages over last {nsummary:_} steps :")
                for key, values in properties_traj.items():
                    if len(values) == 0:
                        continue
                    mu = np.mean(values)
                    sig = np.std(values)
                    # keys look like "Name[unit]"; keys without a unit get ""
                    ksplit = key.split("[")
                    name = ksplit[0].strip()
                    unit = ksplit[1].replace("]", "").strip() if len(ksplit) > 1 else ""
                    print(f"#   {name:10} : {mu: #10.5g}   +/- {sig: #9.3g}  {unit}")

                print(f"# Perf.: {nsperday:.2f} ns/day ( {1.0 / tperstep:.2f} step/s )")
                print("#" * 50)
                if istep < nsteps:
                    print(header)
                ## reset property trajectories
                properties_traj = defaultdict(list)

        print(f"# Run done in {human_time_duration(time.time()-tstart_dyn)}")


if __name__ == "__main__":
    main()
def
main():
def main():
    """Command-line entry point of ``fennol_md``.

    Reads the single positional ``param_file`` argument and loads the
    simulation parameters from it: YAML for ``.yaml``/``.yml``, and
    ``parse_input`` for ``.fnl``. It then selects the JAX device and
    numerical precision and runs :func:`dynamic`.

    Raises:
        FileNotFoundError: if ``param_file`` is not an existing regular file.
        ValueError: if the file extension is not supported or
            ``matmul_prec`` is not one of 'default', 'high', 'highest'.
    """
    # Unbuffered, write-through stdout so progress lines appear immediately
    # even when the output is redirected to a file.
    sys.stdout = io.TextIOWrapper(
        open(sys.stdout.fileno(), "wb", 0), write_through=True
    )

    ### Read the parameter file
    parser = argparse.ArgumentParser(prog="fennol_md")
    parser.add_argument("param_file", type=Path, help="Parameter file")
    args = parser.parse_args()

    param_file = args.param_file
    # is_file() is False both for missing paths and for directories; the
    # previous `not exists() and not is_file()` test let a directory through.
    if not param_file.is_file():
        raise FileNotFoundError(f"Parameter file {param_file} not found")

    if param_file.suffix in (".yaml", ".yml"):
        with open(param_file, "r") as f:
            simulation_parameters = convert_dict_units(yaml.safe_load(f))
        simulation_parameters = InputFile(**simulation_parameters)
    elif param_file.suffix == ".fnl":
        simulation_parameters = parse_input(param_file)
    else:
        raise ValueError(
            f"Unknown parameter file format '{param_file.suffix}'. Supported formats are '.yaml', '.yml' and '.fnl'"
        )

    ### Set the device
    # CUDA_VISIBLE_DEVICES is set before jax.devices() is called below
    # (assumes no JAX backend was initialized earlier in the process).
    device: str = simulation_parameters.get("device", "cpu").lower()
    if device == "cpu":
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    elif device.startswith(("cuda", "gpu")):
        # "cuda:N" / "gpu:N" exposes GPU N only; bare "cuda"/"gpu" means GPU 0.
        os.environ["CUDA_VISIBLE_DEVICES"] = (
            device.split(":")[-1] if ":" in device else "0"
        )
        device = "gpu"

    _device = jax.devices(device)[0]
    jax.config.update("jax_default_device", _device)

    ### Set the precision
    enable_x64 = simulation_parameters.get("double_precision", False)
    jax.config.update("jax_enable_x64", enable_x64)
    fprec = "float64" if enable_x64 else "float32"

    matmul_precision = simulation_parameters.get("matmul_prec", "highest").lower()
    if matmul_precision not in ("default", "high", "highest"):
        # explicit raise: an assert would be stripped under `python -O`
        raise ValueError("matmul_prec must be one of 'default','high','highest'")
    if matmul_precision != "highest":
        print(f"# Setting matmul precision to '{matmul_precision}'")
    if matmul_precision == "default" and fprec == "float32":
        print(
            "# Warning: default matmul precision involves float16 operations which may lead to large numerical errors on energy and pressure estimations ! It is recommended to set matmul_prec to 'high' or 'highest'."
        )
    jax.config.update("jax_default_matmul_precision", matmul_precision)

    dynamic(simulation_parameters, device, fprec)
def
dynamic(simulation_parameters, device, fprec):
def dynamic(simulation_parameters, device, fprec):
    """Run an MD (or PIMD) simulation and write its outputs.

    The integrator, model and initial state come from
    ``initialize_dynamics``. This function then advances ``nsteps`` steps
    and handles all reporting:

    * Every ``nprint`` steps: one line of instantaneous properties on
      stdout. If requested, it also writes colvars
      (``<name>.colvars.traj``) and pickled ``save_keys`` model outputs
      (``<name>.traj.pkl``).
    * Every ``ndump`` steps (derived from ``tdump``): a restart file and a
      trajectory frame (one file per bead with ``write_all_beads``). If
      requested, it also writes a centroid frame and ensemble weights.
    * Every ``nsummary`` steps: a NaN check on velocities and coordinates,
      timing estimates, and the mean/std of the properties collected since
      the previous summary.

    Args:
        simulation_parameters: parsed input. It is only accessed through
            ``.get`` here and is also forwarded to ``initialize_dynamics``.
        device: device label, only used in the run banner.
        fprec: floating-point precision name forwarded to
            ``initialize_dynamics``. ``main`` passes "float32" or "float64".

    Raises:
        ValueError: for invalid ``nprint``/``nsummary``, an unknown
            ``traj_format``, or if NaNs are detected at a summary step.
    """
    from contextlib import ExitStack

    tstart_dyn = time.time()

    # Draw a seed only when none is configured. dtype=int64 keeps the upper
    # bound representable where numpy's default integer is 32-bit.
    random_seed = simulation_parameters.get("random_seed", None)
    if random_seed is None:
        random_seed = int(np.random.randint(0, 2**32 - 1, dtype=np.int64))
    print(f"# random_seed: {random_seed}")
    rng_key = jax.random.PRNGKey(random_seed)

    ## INITIALIZE INTEGRATOR AND SYSTEM
    rng_key, subkey = jax.random.split(rng_key)
    step, update_conformation, system_data, dyn_state, conformation, system = (
        initialize_dynamics(simulation_parameters, fprec, subkey)
    )

    nat = system_data["nat"]
    dt = dyn_state["dt"]
    nsteps = int(simulation_parameters.get("nsteps"))
    start_time_ps = dyn_state.get("start_time_ps", 0.0)

    ### Set I/O parameters
    Tdump = simulation_parameters.get("tdump", 1.0 / au.PS) * au.FS
    # Round instead of truncating, so a ratio like 199.9999 gives 200. Keep
    # at least 1, otherwise `istep % ndump` raises ZeroDivisionError.
    ndump = max(1, int(round(Tdump / dt)))
    system_name = system_data["name"]
    estimate_pressure = dyn_state["estimate_pressure"]

    @jax.jit
    def check_nan(system):
        return jnp.any(jnp.isnan(system["vel"])) | jnp.any(
            jnp.isnan(system["coordinates"])
        )

    if system_data["pbc"] is not None:
        cell = system["cell"]
        reciprocal_cell = np.linalg.inv(cell)
        do_wrap_box = simulation_parameters.get("wrap_box", False)
    else:
        cell = None
        reciprocal_cell = None
        do_wrap_box = False

    ### Energy units and print initial energy
    per_atom_energy = simulation_parameters.get("per_atom_energy", True)
    energy_unit_str = system_data["energy_unit_str"]
    energy_unit = system_data["energy_unit"]
    print("# Energy unit: ", energy_unit_str)
    atom_energy_unit = energy_unit
    atom_energy_unit_str = energy_unit_str
    if per_atom_energy:
        # out-of-place division: never mutate energy_unit through an alias
        atom_energy_unit = energy_unit / nat
        atom_energy_unit_str = f"{energy_unit_str}/atom"
        print("# Printing Energy per atom")
    print(
        f"# Initial potential energy: {system['epot']*atom_energy_unit}; kinetic energy: {system['ek']*atom_energy_unit}"
    )
    minmaxone(jnp.abs(system["forces"] * energy_unit), "# forces min/max/rms:")

    ## printing options
    nprint = int(simulation_parameters.get("nprint", 10))
    if nprint <= 0:
        raise ValueError("nprint must be > 0")
    nsummary = int(simulation_parameters.get("nsummary", 100 * nprint))
    if nsummary <= nprint:
        raise ValueError("nsummary must be > nprint")

    save_keys = simulation_parameters.get("save_keys", [])
    if save_keys:
        print(f"# Saving keys: {save_keys}")

    ### colvars
    use_colvars = "colvars" in dyn_state
    if use_colvars:
        print(f"# Colvars: {dyn_state['colvars']}")
        colvars_names = dyn_state["colvars"]

    ### Print header
    include_thermostat_energy = "thermostat_energy" in system["thermostat"]
    thermostat_name = dyn_state["thermostat_name"]
    pimd = dyn_state["pimd"]
    variable_cell = dyn_state["variable_cell"]
    nbeads = system_data.get("nbeads", 1)
    dyn_name = "PIMD" if pimd else "MD"
    print("#" * 84)
    print(
        f"# Running {nsteps:_} steps of {thermostat_name} {dyn_name} simulation on {device}"
    )
    header = "# Step Time[ps] Etot Epot Ekin Temp[K]"
    if pimd:
        header += " Temp_c[K]"
    if include_thermostat_energy:
        header += " Etherm"
    # Defined unconditionally: the variable-cell reporting below also reads
    # it, and it used to be a NameError when pressure was not estimated.
    print_aniso_pressure = False
    if estimate_pressure:
        print_aniso_pressure = simulation_parameters.get("print_aniso_pressure", False)
        pressure_unit_str = simulation_parameters.get("pressure_unit", "atm")
        pressure_unit = au.get_multiplier(pressure_unit_str) * au.BOHR**3
        header += f" Press[{pressure_unit_str}]"
    if variable_cell:
        header += " Density"
    print(header)

    ### Trajectory format
    traj_format = simulation_parameters.get("traj_format", "arc").lower()
    trajectory_writers = {
        "xyz": (".traj.xyz", write_xyz_frame),
        "extxyz": (".traj.extxyz", write_extxyz_frame),
        "arc": (".arc", write_arc_frame),
    }
    if traj_format not in trajectory_writers:
        raise ValueError(
            f"Unknown trajectory format '{traj_format}'. Supported formats are 'arc', 'xyz' and 'extxyz'"
        )
    traj_ext, write_frame = trajectory_writers[traj_format]

    write_all_beads = simulation_parameters.get("write_all_beads", False) and pimd
    ensemble_key = simulation_parameters.get("etot_ensemble_key", None)
    write_centroid = simulation_parameters.get("write_centroid", False) and pimd

    # (label, row, col) of the symmetrized pressure-tensor components reported
    pressure_components = (
        ("xx", 0, 0), ("yy", 1, 1), ("zz", 2, 2),
        ("xy", 0, 1), ("xz", 0, 2), ("yz", 1, 2),
    )

    # Every output file is registered on the stack. All of them (including
    # each per-bead trajectory) are closed at the end, even if the run raises.
    with ExitStack() as stack:

        def open_output(path, mode="a"):
            return stack.enter_context(open(path, mode))

        fkeys = open_output(f"{system_name}.traj.pkl", "wb+") if save_keys else None

        if use_colvars:
            # appended to across runs, so a header line is written per run
            fcolvars = open_output(f"{system_name}.colvars.traj")
            fcolvars.write("#time[ps] " + " ".join(colvars_names) + "\n")
            fcolvars.flush()

        if write_all_beads:
            fout = [
                open_output(f"{system_name}_bead{i+1:03d}" + traj_ext)
                for i in range(nbeads)
            ]
        else:
            fout = open_output(system_name + traj_ext)

        fens = (
            open_output(f"{system_name}.ensemble_weights.traj")
            if ensemble_key is not None
            else None
        )
        fcentroid = (
            open_output(f"{system_name}_centroid" + traj_ext) if write_centroid else None
        )

        ### property trajectories, reset after every summary
        properties_traj = defaultdict(list)

        t0full = time.time()
        # NOTE(review): never reset to False once a box wrap happened;
        # confirm whether `step` expects it only for the step after a wrap.
        force_preprocess = False

        for istep in range(1, nsteps + 1):

            ### update the system
            dyn_state, system, conformation, model_output = step(
                istep, dyn_state, system, conformation, force_preprocess
            )

            ### print properties
            if istep % nprint == 0:
                ek = system["ek"]
                epot = system["epot"]
                etot = ek + epot
                temper = 2 * ek / (3.0 * nat) * au.KELVIN

                if include_thermostat_energy:
                    etherm = system["thermostat"]["thermostat_energy"]
                    etot = etot + etherm

                properties_traj[f"Etot[{atom_energy_unit_str}]"].append(
                    etot * atom_energy_unit
                )
                properties_traj[f"Epot[{atom_energy_unit_str}]"].append(
                    epot * atom_energy_unit
                )
                properties_traj[f"Ekin[{atom_energy_unit_str}]"].append(
                    ek * atom_energy_unit
                )
                properties_traj["Temper[Kelvin]"].append(temper)
                if pimd:
                    ek_c = system["ek_c"]
                    temper_c = 2 * ek_c / (3.0 * nat) * au.KELVIN
                    properties_traj["Temper_c[Kelvin]"].append(temper_c)

                ### construct line of properties
                simulated_time = start_time_ps + istep * dt / 1000
                # integer format: 'g' printed step 1_000_000 as "1e+06"
                line = f"{istep:10d} {simulated_time: 10.3f} {etot*atom_energy_unit: #10.4f} {epot*atom_energy_unit: #10.4f} {ek*atom_energy_unit: #10.4f} {temper: 10.2f}"
                if pimd:
                    line += f" {temper_c: 10.2f}"
                if include_thermostat_energy:
                    line += f" {etherm*atom_energy_unit: #10.4f}"
                    properties_traj[f"Etherm[{atom_energy_unit_str}]"].append(
                        etherm * atom_energy_unit
                    )
                if estimate_pressure:
                    pres = system["pressure"] * pressure_unit
                    properties_traj[f"Pressure[{pressure_unit_str}]"].append(pres)
                    if print_aniso_pressure:
                        pres_tensor = system["pressure_tensor"] * pressure_unit
                        pres_tensor = 0.5 * (pres_tensor + pres_tensor.T)
                        for comp, i, j in pressure_components:
                            properties_traj[
                                f"Pressure_{comp}[{pressure_unit_str}]"
                            ].append(pres_tensor[i, j])
                    line += f" {pres:10.3f}"
                if variable_cell:
                    density = system["density"]
                    properties_traj["Density[g/cm^3]"].append(density)
                    if print_aniso_pressure:
                        current_cell = system["cell"]
                        for i, vec in enumerate("ABC"):
                            for j, axis in enumerate("xyz"):
                                properties_traj[
                                    f"Cell_{vec}{axis}[Angstrom]"
                                ].append(current_cell[i, j])
                    line += f" {density:10.4f}"
                    if "piston_temperature" in system["barostat"]:
                        piston_temperature = system["barostat"]["piston_temperature"]
                        properties_traj["T_piston[Kelvin]"].append(piston_temperature)

                print(line)

                ### save colvars
                if use_colvars:
                    colvars = system["colvars"]
                    fcolvars.write(f"{simulated_time:.3f} ")
                    fcolvars.write(" ".join([f"{colvars[k]:.6f}" for k in colvars_names]))
                    fcolvars.write("\n")
                    fcolvars.flush()

                if save_keys:
                    data = {
                        k: (
                            np.asarray(model_output[k])
                            if isinstance(model_output[k], jnp.ndarray)
                            else model_output[k]
                        )
                        for k in save_keys
                    }
                    # header words (minus the leading "#") paired with values
                    data["properties"] = {
                        k: float(v) for k, v in zip(header.split()[1:], line.split())
                    }
                    pickle.dump(data, fkeys)

            ### save frame
            if istep % ndump == 0:
                dump_msg = "# Write XYZ frame"
                if variable_cell:
                    cell = np.array(system["cell"])
                    reciprocal_cell = np.linalg.inv(cell)
                if do_wrap_box:
                    if pimd:
                        centroid = wrapbox(system["coordinates"][0], cell, reciprocal_cell)
                        system["coordinates"] = system["coordinates"].at[0].set(centroid)
                    else:
                        system["coordinates"] = wrapbox(
                            system["coordinates"], cell, reciprocal_cell
                        )
                    conformation = update_conformation(conformation, system)
                    dump_msg += " (atoms have been wrapped into the box)"
                    force_preprocess = True
                print(dump_msg)

                save_dynamics_restart(system_data, conformation, dyn_state, system)

                properties = {
                    "energy": float(system["epot"]) * energy_unit,
                    "Time_ps": start_time_ps + istep * dt / 1000,
                    "energy_unit": energy_unit_str,
                }

                if write_all_beads:
                    coords = np.asarray(conformation["coordinates"].reshape(-1, nat, 3))
                    for i, fb in enumerate(fout):
                        write_frame(
                            fb,
                            system_data["symbols"],
                            coords[i],
                            cell=cell,
                            properties=properties,
                            forces=None,
                        )
                else:
                    write_frame(
                        fout,
                        system_data["symbols"],
                        np.asarray(conformation["coordinates"].reshape(-1, nat, 3)[0]),
                        cell=cell,
                        properties=properties,
                        forces=None,
                    )
                if write_centroid:
                    centroid = np.asarray(system["coordinates"][0])
                    write_frame(
                        fcentroid,
                        system_data["symbols"],
                        centroid,
                        cell=cell,
                        properties=properties,
                        forces=np.asarray(system["forces"].reshape(nbeads, nat, 3)[0])
                        * energy_unit,
                    )
                if ensemble_key is not None:
                    weights = " ".join(
                        [f"{w:.6f}" for w in system["ensemble_weights"].tolist()]
                    )
                    fens.write(f"{weights}\n")
                    fens.flush()

            ### summary over last nsummary steps
            if istep % nsummary == 0:
                if check_nan(system):
                    raise ValueError(f"dynamics crashed at step {istep}.")
                tfull = time.time() - t0full
                t0full = time.time()
                tperstep = tfull / nsummary
                # steps/day * dt / 1e6 (dt is divided by 1000 to get ps above)
                nsperday = (24 * 60 * 60 / tperstep) * dt / 1e6
                elapsed_time = time.time() - tstart_dyn
                estimated_remaining_time = tperstep * (nsteps - istep)
                estimated_total_time = elapsed_time + estimated_remaining_time

                print("#" * 50)
                print(f"# Step {istep:_} of {nsteps:_} ({istep/nsteps*100:.5g} %)")
                print(f"# Simulated time      : {istep * dt*1.e-3:.3f} ps")
                print(f"# Tot. Simu. time     : {start_time_ps + istep * dt*1.e-3:.3f} ps")
                print(f"# Tot. elapsed time   : {human_time_duration(elapsed_time)}")
                print(
                    f"# Est. total duration   : {human_time_duration(estimated_total_time)}"
                )
                print(
                    f"# Est. remaining time : {human_time_duration(estimated_remaining_time)}"
                )
                print(f"# Time for {nsummary:_} steps : {human_time_duration(tfull)}")

                corr_kin = system["thermostat"].get("corr_kin", None)
                if corr_kin is not None:
                    print(f"# QTB kin. correction : {100*(corr_kin-1.):.2f} %")
                print(f"# Averages over last {nsummary:_} steps :")
                for key, values in properties_traj.items():
                    if len(values) == 0:
                        continue
                    mu = np.mean(values)
                    sig = np.std(values)
                    # keys look like "Name[unit]"; keys without a unit get ""
                    ksplit = key.split("[")
                    name = ksplit[0].strip()
                    unit = ksplit[1].replace("]", "").strip() if len(ksplit) > 1 else ""
                    print(f"#   {name:10} : {mu: #10.5g}   +/- {sig: #9.3g}  {unit}")

                print(f"# Perf.: {nsperday:.2f} ns/day ( {1.0 / tperstep:.2f} step/s )")
                print("#" * 50)
                if istep < nsteps:
                    print(header)
                ## reset property trajectories
                properties_traj = defaultdict(list)

        print(f"# Run done in {human_time_duration(time.time()-tstart_dyn)}")