fennol.md.dynamic
import sys
import os
import io
import argparse
import time
from pathlib import Path
import math

import numpy as np
from typing import Optional, Callable
from collections import defaultdict
from contextlib import ExitStack
from functools import partial
import jax
import jax.numpy as jnp

from flax.core import freeze, unfreeze


from ..models import FENNIX
from ..utils.io import (
    write_arc_frame,
    last_xyz_frame,
    write_xyz_frame,
    write_extxyz_frame,
    human_time_duration,
)
from ..utils.periodic_table import PERIODIC_TABLE_REV_IDX, ATOMIC_MASSES
from ..utils.atomic_units import AtomicUnits as au
from ..utils.input_parser import parse_input
from .initial import load_model, load_system_data, initialize_preprocessing
from .integrate import initialize_dynamics

from copy import deepcopy


def minmaxone(x, name=""):
    """Print ``name`` followed by the min, max and root-mean-square of array ``x``."""
    print(name, x.min(), x.max(), (x**2).mean() ** 0.5)


@jax.jit
def wrapbox(x, cell, reciprocal_cell):
    """Wrap Cartesian positions ``x`` back into the periodic cell.

    Positions are mapped to fractional coordinates with ``reciprocal_cell``
    (row-vector convention: ``x @ reciprocal_cell``), reduced to [0, 1) with a
    floor, and mapped back to Cartesian coordinates with ``cell``.
    """
    q = x @ reciprocal_cell
    q = q - jnp.floor(q)
    return q @ cell


def main():
    """Command-line entry point of ``fennol_md``.

    Reads the parameter file given on the command line, selects the JAX device
    and the floating-point / matmul precision, then runs :func:`dynamic`.

    Raises:
        ValueError: if ``matmul_prec`` is not 'default', 'high' or 'highest'.
    """
    # Replace stdout by an unbuffered (buffering=0), write-through stream so
    # that progress lines are emitted immediately.
    sys.stdout = io.TextIOWrapper(
        open(sys.stdout.fileno(), "wb", 0), write_through=True
    )

    ### Read the parameter file
    parser = argparse.ArgumentParser(prog="fennol_md")
    parser.add_argument("param_file", type=Path, help="Parameter file")
    args = parser.parse_args()
    simulation_parameters = parse_input(args.param_file)

    ### Set the device
    device: str = simulation_parameters.get("device", "cpu").lower()
    if device == "cpu":
        # empty CUDA_VISIBLE_DEVICES: presumably hides GPUs from the CUDA backend
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    elif device.startswith(("cuda", "gpu")):
        # "cuda:N" / "gpu:N" selects GPU N; bare "cuda" / "gpu" selects GPU 0.
        os.environ["CUDA_VISIBLE_DEVICES"] = (
            device.split(":")[-1] if ":" in device else "0"
        )
        device = "gpu"

    _device = jax.devices(device)[0]
    jax.config.update("jax_default_device", _device)

    ### Set the precision
    enable_x64 = simulation_parameters.get("double_precision", False)
    jax.config.update("jax_enable_x64", enable_x64)
    fprec = "float64" if enable_x64 else "float32"

    matmul_precision = simulation_parameters.get("matmul_prec", "highest").lower()
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if matmul_precision not in ("default", "high", "highest"):
        raise ValueError("matmul_prec must be one of 'default','high','highest'")
    if matmul_precision != "highest":
        print(f"# Setting matmul precision to '{matmul_precision}'")
    if matmul_precision == "default" and fprec == "float32":
        print(
            "# Warning: default matmul precision involves float16 operations which may lead to large numerical errors on energy and pressure estimations ! It is recommended to set matmul_prec to 'high' or 'highest'."
        )
    jax.config.update("jax_default_matmul_precision", matmul_precision)

    dynamic(simulation_parameters, device, fprec)


# (label, row, column) of the symmetrized pressure-tensor entries that are
# recorded when print_aniso_pressure is enabled.
_PRESSURE_COMPONENTS = (
    ("xx", 0, 0),
    ("yy", 1, 1),
    ("zz", 2, 2),
    ("xy", 0, 1),
    ("xz", 0, 2),
    ("yz", 1, 2),
)
# (label, row, column) of the cell-matrix entries recorded for variable-cell
# runs: row i holds lattice vector A, B or C, column j its x, y or z component.
_CELL_COMPONENTS = tuple(
    (f"{vector}{axis}", i, j)
    for i, vector in enumerate("ABC")
    for j, axis in enumerate("xyz")
)


def _print_initial_pressure(model, conformation, system_data, nat, model_energy_unit):
    """Print the initial pressure estimated from the virial and by finite differences.

    The kinetic part uses ``ek = 1.5 * nat * kT`` with ``kT`` read from
    ``system_data``. The finite-difference estimate rescales cell and coordinates
    isotropically so that the volume changes by +/- 1e-6 of its value, and takes
    the centered difference of the model total energy.
    """
    volume = system_data["pbc"]["volume"]
    coordinates = conformation["coordinates"]
    cell = conformation["cells"][0]
    ek = 1.5 * nat * system_data["kT"]
    Pkin = (2 * au.KBAR) * ek / ((3.0 / au.BOHR**3) * volume)
    _, _, vir_t, _ = model._energy_and_forces_and_virial(
        model.variables, conformation
    )
    KBAR = au.KBAR / model_energy_unit
    Pvir = -(np.trace(vir_t[0]) * KBAR) / ((3.0 / au.BOHR**3) * volume)

    vstep = volume * 0.000001

    def total_energy_at_scale(scale):
        # Model total energy with both the cell and the coordinates scaled by `scale`.
        scaled_cell = cell * scale
        reciprocal_cell = np.linalg.inv(scaled_cell)
        scaled_system = model.preprocess(
            **{
                **conformation,
                "coordinates": coordinates * scale,
                "cells": scaled_cell[None, :, :],
                "reciprocal_cells": reciprocal_cell[None, :, :],
            }
        )
        energy, _ = model._total_energy(model.variables, scaled_system)
        return energy[0]

    ep = total_energy_at_scale(((volume + vstep) / volume) ** (1.0 / 3.0))
    em = total_energy_at_scale(((volume - vstep) / volume) ** (1.0 / 3.0))
    Pvir_fd = -(ep * KBAR - em * KBAR) / (2.0 * vstep / au.BOHR**3)
    print(
        f"# Initial pressure: {Pkin+Pvir:.3f} (virial); {Pkin+Pvir_fd:.3f} (finite difference) ; Pkin: {Pkin:.3f} ; Pvir: {Pvir:.3f} ; Pvir_fd: {Pvir_fd:.3f}"
    )


def _trajectory_writer(traj_format):
    """Return ``(file_extension, frame_writer)`` for a lowercase trajectory format name.

    Raises:
        ValueError: if the format is not 'xyz', 'extxyz' or 'arc'.
    """
    writers = {
        "xyz": (".traj.xyz", write_xyz_frame),
        "extxyz": (".traj.extxyz", write_extxyz_frame),
        "arc": (".arc", write_arc_frame),
    }
    try:
        return writers[traj_format]
    except KeyError:
        raise ValueError(
            f"Unknown trajectory format '{traj_format}'. Supported formats are 'arc', 'xyz' and 'extxyz'"
        ) from None


def _print_property_averages(properties_traj):
    """Print mean +/- standard deviation of every recorded property.

    Keys have the form ``"Name[unit]"`` (unit optional); empty series are skipped.
    """
    for key, values in properties_traj.items():
        if len(values) == 0:
            continue
        mu = np.mean(values)
        sig = np.std(values)
        ksplit = key.split("[")
        name = ksplit[0].strip()
        unit = ksplit[1].replace("]", "").strip() if len(ksplit) > 1 else ""
        print(f"# {name:10} : {mu: #10.5g} +/- {sig: #9.3g} {unit}")


def dynamic(simulation_parameters, device, fprec):
    """Run an MD (or PIMD) simulation described by ``simulation_parameters``.

    Properties are printed every ``nprint`` steps, a frame is written to the
    trajectory file(s) every ``ndump`` steps (derived from ``tdump``), and a
    summary with averages and timings is printed every ``nsummary`` steps.

    Args:
        simulation_parameters: parsed parameter mapping (see ``parse_input``).
        device: device name, only used in the log header.
        fprec: "float32" or "float64".

    Raises:
        ValueError: on invalid ``nprint``/``nsummary``/``tdump``, on an unknown
            trajectory format, or when NaNs are found in velocities or
            coordinates at a summary step.
    """
    tstart_dyn = time.time()

    ### Initialize the model
    model = load_model(simulation_parameters)

    ### Get the coordinates and species from the xyz file
    system_data, conformation = load_system_data(simulation_parameters, fprec)
    nat = system_data["nat"]

    preproc_state, conformation = initialize_preprocessing(
        simulation_parameters, model, conformation, system_data
    )

    random_seed = simulation_parameters.get("random_seed", None)
    if random_seed is None:
        # Explicit int64: the default integer type is 32-bit on some platforms
        # (Windows with NumPy < 2), where the upper bound 2**32 - 1 overflows.
        random_seed = int(np.random.randint(0, 2**32 - 1, dtype=np.int64))
    print(f"# random_seed: {random_seed}")
    rng_key = jax.random.PRNGKey(random_seed)
    _, subkey = jax.random.split(rng_key)
    ## INITIALIZE INTEGRATOR AND SYSTEM
    step, update_conformation, dyn_state, system = initialize_dynamics(
        simulation_parameters, system_data, conformation, model, fprec, subkey
    )

    dt = dyn_state["dt"]
    ## get number of steps
    nsteps = int(simulation_parameters.get("nsteps"))
    start_time = 0.0

    ### Set I/O parameters
    Tdump = simulation_parameters.get("tdump", 1.0 / au.PS) * au.FS
    ndump = int(Tdump / dt)
    if ndump < 1:
        # would otherwise fail with ZeroDivisionError at `istep % ndump`
        raise ValueError("tdump must be at least one time step")
    system_name = system_data["name"]

    model_energy_unit = au.get_multiplier(model.energy_unit)
    ### Print initial pressure
    estimate_pressure = dyn_state["estimate_pressure"]
    if estimate_pressure and fprec == "float64":
        _print_initial_pressure(
            model, conformation, system_data, nat, model_energy_unit
        )

    @jax.jit
    def check_nan(system):
        return jnp.any(jnp.isnan(system["vel"])) | jnp.any(
            jnp.isnan(system["coordinates"])
        )

    if system_data["pbc"] is not None:
        cell = system_data["pbc"]["cell"]
        reciprocal_cell = system_data["pbc"]["reciprocal_cell"]
        do_wrap_box = simulation_parameters.get("wrap_box", False)
    else:
        cell = None
        reciprocal_cell = None
        do_wrap_box = False

    ### Energy units and print initial energy
    per_atom_energy = simulation_parameters.get("per_atom_energy", True)
    energy_unit_str = simulation_parameters.get("energy_unit", "kcal/mol")
    print("# Energy unit: ", energy_unit_str)
    energy_unit = au.get_multiplier(energy_unit_str)
    atom_energy_unit = energy_unit
    atom_energy_unit_str = energy_unit_str
    if per_atom_energy:
        atom_energy_unit /= nat
        atom_energy_unit_str = f"{energy_unit_str}/atom"
        print("# Printing Energy per atom")
    print(
        f"# Initial potential energy: {system['epot']*atom_energy_unit}; kinetic energy: {system['ek']*atom_energy_unit}"
    )
    forces = system["forces"]
    minmaxone(jnp.abs(forces * energy_unit), "# forces min/max/rms:")

    ## printing options
    print_timings = simulation_parameters.get("print_timings", False)
    nprint = int(simulation_parameters.get("nprint", 10))
    if nprint <= 0:
        raise ValueError("nprint must be > 0")
    nsummary = int(simulation_parameters.get("nsummary", 100 * nprint))
    if nsummary <= nprint:
        raise ValueError("nsummary must be > nprint")

    ### Print header
    include_thermostat_energy = "thermostat_energy" in system["thermostat"]
    thermostat_name = dyn_state["thermostat_name"]
    pimd = dyn_state["pimd"]
    variable_cell = dyn_state["variable_cell"]
    nbeads = system_data.get("nbeads", 1)
    dyn_name = "PIMD" if pimd else "MD"
    print("#" * 84)
    print(
        f"# Running {nsteps:_} steps of {thermostat_name} {dyn_name} simulation on {device}"
    )
    header = "# Step Time[ps] Etot Epot Ekin Temp[K]"
    if pimd:
        header += " Temp_c[K]"
    if include_thermostat_energy:
        header += " Etherm"
    # Defined unconditionally: it is also read in the variable_cell branch below.
    print_aniso_pressure = False
    if estimate_pressure:
        print_aniso_pressure = simulation_parameters.get("print_aniso_pressure", False)
        pressure_unit_str = simulation_parameters.get("pressure_unit", "atm")
        pressure_unit = au.get_multiplier(pressure_unit_str) * au.BOHR**3
        header += f" Press[{pressure_unit_str}]"
    if variable_cell:
        header += " Density"
    print(header)

    ### Select trajectory format
    traj_format = simulation_parameters.get("traj_format", "arc").lower()
    traj_ext, write_frame = _trajectory_writer(traj_format)

    write_all_beads = simulation_parameters.get("write_all_beads", False) and pimd
    ensemble_key = simulation_parameters.get("etot_ensemble_key", None)

    # ExitStack closes every output file on normal exit AND when an exception
    # (e.g. the NaN check) aborts the run; `fout` may be a list of files.
    with ExitStack() as open_files:
        if write_all_beads:
            fout = [
                open_files.enter_context(
                    open(f"{system_name}_bead{i+1:03d}" + traj_ext, "w")
                )
                for i in range(nbeads)
            ]
        else:
            fout = open_files.enter_context(open(system_name + traj_ext, "a+"))

        if ensemble_key is not None:
            fens = open_files.enter_context(
                open(f"{system_name}.ensemble_weights.traj", "a+")
            )

        ### initialize property trajectories
        properties_traj = defaultdict(list)
        if print_timings:
            timings = defaultdict(lambda: 0.0)

        ### initialize timers
        t0full = time.time()
        # NOTE(review): force_preprocess is never reset to False in this loop, so
        # after the first box wrap every step() call receives True — confirm intended.
        force_preprocess = False

        for istep in range(1, nsteps + 1):

            ### update the system
            dyn_state, system, conformation, preproc_state = step(
                istep, dyn_state, system, conformation, preproc_state, force_preprocess
            )

            ### print properties
            if istep % nprint == 0:
                ek = system["ek"]
                epot = system["epot"]
                etot = ek + epot
                temper = 2 * ek / (3.0 * nat) * au.KELVIN

                if include_thermostat_energy:
                    etherm = system["thermostat"]["thermostat_energy"]
                    etot = etot + etherm

                properties_traj[f"Etot[{atom_energy_unit_str}]"].append(
                    etot * atom_energy_unit
                )
                properties_traj[f"Epot[{atom_energy_unit_str}]"].append(
                    epot * atom_energy_unit
                )
                properties_traj[f"Ekin[{atom_energy_unit_str}]"].append(
                    ek * atom_energy_unit
                )
                properties_traj["Temper[Kelvin]"].append(temper)
                if pimd:
                    ek_c = system["ek_c"]
                    temper_c = 2 * ek_c / (3.0 * nat) * au.KELVIN
                    properties_traj["Temper_c[Kelvin]"].append(temper_c)

                ### construct line of properties
                line = f"{istep:10.6g} {(start_time+istep*dt)/1000: 10.3f} {etot*atom_energy_unit: #10.4f} {epot*atom_energy_unit: #10.4f} {ek*atom_energy_unit: #10.4f} {temper: 10.2f}"
                if pimd:
                    line += f" {temper_c: 10.2f}"
                if include_thermostat_energy:
                    line += f" {etherm*atom_energy_unit: #10.4f}"
                    properties_traj[f"Etherm[{atom_energy_unit_str}]"].append(
                        etherm * atom_energy_unit
                    )
                if estimate_pressure:
                    pres = system["pressure"] * pressure_unit
                    properties_traj[f"Pressure[{pressure_unit_str}]"].append(pres)
                    if print_aniso_pressure:
                        pres_tensor = system["pressure_tensor"] * pressure_unit
                        pres_tensor = 0.5 * (pres_tensor + pres_tensor.T)
                        for label, row, col in _PRESSURE_COMPONENTS:
                            properties_traj[
                                f"Pressure_{label}[{pressure_unit_str}]"
                            ].append(pres_tensor[row, col])
                    line += f" {pres:10.3f}"
                if variable_cell:
                    density = system["density"]
                    properties_traj["Density[g/cm^3]"].append(density)
                    if print_aniso_pressure:
                        current_cell = system["cell"]
                        for label, row, col in _CELL_COMPONENTS:
                            properties_traj[f"Cell_{label}[Angstrom]"].append(
                                current_cell[row, col]
                            )
                    line += f" {density:10.4f}"
                    if "piston_temperature" in system["barostat"]:
                        piston_temperature = system["barostat"]["piston_temperature"]
                        properties_traj["T_piston[Kelvin]"].append(piston_temperature)

                print(line)

            ### save frame
            if istep % ndump == 0:
                line = "# Write XYZ frame"
                if variable_cell:
                    cell = np.array(system["cell"])
                    reciprocal_cell = np.linalg.inv(cell)
                if do_wrap_box:
                    if pimd:
                        # only the first entry along the bead axis is wrapped
                        centroid = wrapbox(system["coordinates"][0], cell, reciprocal_cell)
                        system["coordinates"] = system["coordinates"].at[0].set(centroid)
                    else:
                        system["coordinates"] = wrapbox(
                            system["coordinates"], cell, reciprocal_cell
                        )
                    conformation = update_conformation(conformation, system)
                    line += " (atoms have been wrapped into the box)"
                    force_preprocess = True
                print(line)
                properties = {
                    "energy": float(system["epot"]) * energy_unit,
                    "Time": start_time + istep * dt,
                    "energy_unit": energy_unit_str,
                }

                if write_all_beads:
                    coords = np.asarray(conformation["coordinates"].reshape(nbeads, nat, 3))
                    for i, fb in enumerate(fout):
                        write_frame(
                            fb,
                            system_data["symbols"],
                            coords[i],
                            cell=cell,
                            properties=properties,
                            forces=None,
                        )
                else:
                    write_frame(
                        fout,
                        system_data["symbols"],
                        np.asarray(conformation["coordinates"].reshape(nbeads, nat, 3)[0]),
                        cell=cell,
                        properties=properties,
                        forces=None,
                    )
                if ensemble_key is not None:
                    weights = " ".join(
                        [f"{w:.6f}" for w in system["ensemble_weights"].tolist()]
                    )
                    fens.write(f"{weights}\n")
                    fens.flush()

            ### summary over last nsummary steps
            if istep % nsummary == 0:
                # NaNs are only checked at summary steps
                if check_nan(system):
                    raise ValueError(f"dynamics crashed at step {istep}.")
                tfull = time.time() - t0full
                t0full = time.time()
                tperstep = tfull / nsummary
                nsperday = (24 * 60 * 60 / tperstep) * dt / 1e6
                elapsed_time = time.time() - tstart_dyn
                estimated_remaining_time = tperstep * (nsteps - istep)
                estimated_total_time = elapsed_time + estimated_remaining_time

                print("#" * 50)
                print(f"# Step {istep:_} of {nsteps:_} ({istep/nsteps*100:.5g} %)")
                print(f"# Tot. elapsed time : {human_time_duration(elapsed_time)}")
                print(
                    f"# Est. total time : {human_time_duration(estimated_total_time)}"
                )
                print(
                    f"# Est. remaining time : {human_time_duration(estimated_remaining_time)}"
                )
                print(f"# Time for {nsummary:_} steps : {human_time_duration(tfull)}")

                if print_timings:
                    print("# Detailed per-step timings :")
                    dsteps = nsummary
                    timings["Other"] = tfull - sum(timings.values())
                    # sort timings, largest first
                    timings = dict(
                        sorted(timings.items(), key=lambda item: item[1], reverse=True)
                    )
                    for k, v in timings.items():
                        print(
                            f"# {k:15} : {human_time_duration(v/dsteps):>12} ({v/tfull*100:5.3g} %)"
                        )
                    print(f"# {'Total':15} : {human_time_duration(tfull/dsteps):>12}")
                    ## reset timings
                    timings = defaultdict(lambda: 0.0)

                corr_kin = system["thermostat"].get("corr_kin", None)
                if corr_kin is not None:
                    print(f"# QTB kin. correction : {100*(corr_kin-1.):.2f} %")
                print(f"# Averages over last {nsummary:_} steps :")
                _print_property_averages(properties_traj)

                print(f"# Perf.: {nsperday:.2f} ns/day ( {1.0 / tperstep:.2f} step/s )")
                print("#" * 50)
                if istep < nsteps:
                    print(header)
                ## reset property trajectories
                properties_traj = defaultdict(list)

        print(f"# Run done in {human_time_duration(time.time()-tstart_dyn)}")


if __name__ == "__main__":
    main()
def minmaxone(x, name=''):
@jax.jit
def wrapbox(x, cell, reciprocal_cell):
def main():
def main():
    """Command-line entry point of ``fennol_md``.

    Reads the parameter file given on the command line, selects the JAX device
    and the floating-point / matmul precision, then runs :func:`dynamic`.

    Raises:
        ValueError: if ``matmul_prec`` is not 'default', 'high' or 'highest'.
    """
    # Replace stdout by an unbuffered (buffering=0), write-through stream so
    # that progress lines are emitted immediately.
    sys.stdout = io.TextIOWrapper(
        open(sys.stdout.fileno(), "wb", 0), write_through=True
    )

    ### Read the parameter file
    parser = argparse.ArgumentParser(prog="fennol_md")
    parser.add_argument("param_file", type=Path, help="Parameter file")
    args = parser.parse_args()
    simulation_parameters = parse_input(args.param_file)

    ### Set the device
    device: str = simulation_parameters.get("device", "cpu").lower()
    if device == "cpu":
        # empty CUDA_VISIBLE_DEVICES: presumably hides GPUs from the CUDA backend
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    elif device.startswith(("cuda", "gpu")):
        # "cuda:N" / "gpu:N" selects GPU N; bare "cuda" / "gpu" selects GPU 0.
        os.environ["CUDA_VISIBLE_DEVICES"] = (
            device.split(":")[-1] if ":" in device else "0"
        )
        device = "gpu"

    _device = jax.devices(device)[0]
    jax.config.update("jax_default_device", _device)

    ### Set the precision
    enable_x64 = simulation_parameters.get("double_precision", False)
    jax.config.update("jax_enable_x64", enable_x64)
    fprec = "float64" if enable_x64 else "float32"

    matmul_precision = simulation_parameters.get("matmul_prec", "highest").lower()
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if matmul_precision not in ("default", "high", "highest"):
        raise ValueError("matmul_prec must be one of 'default','high','highest'")
    if matmul_precision != "highest":
        print(f"# Setting matmul precision to '{matmul_precision}'")
    if matmul_precision == "default" and fprec == "float32":
        print(
            "# Warning: default matmul precision involves float16 operations which may lead to large numerical errors on energy and pressure estimations ! It is recommended to set matmul_prec to 'high' or 'highest'."
        )
    jax.config.update("jax_default_matmul_precision", matmul_precision)

    dynamic(simulation_parameters, device, fprec)
def dynamic(simulation_parameters, device, fprec):
# (label, row, column) of the symmetrized pressure-tensor entries that are
# recorded when print_aniso_pressure is enabled.
_PRESSURE_COMPONENTS = (
    ("xx", 0, 0),
    ("yy", 1, 1),
    ("zz", 2, 2),
    ("xy", 0, 1),
    ("xz", 0, 2),
    ("yz", 1, 2),
)
# (label, row, column) of the cell-matrix entries recorded for variable-cell
# runs: row i holds lattice vector A, B or C, column j its x, y or z component.
_CELL_COMPONENTS = tuple(
    (f"{vector}{axis}", i, j)
    for i, vector in enumerate("ABC")
    for j, axis in enumerate("xyz")
)


def _print_initial_pressure(model, conformation, system_data, nat, model_energy_unit):
    """Print the initial pressure estimated from the virial and by finite differences.

    The kinetic part uses ``ek = 1.5 * nat * kT`` with ``kT`` read from
    ``system_data``. The finite-difference estimate rescales cell and coordinates
    isotropically so that the volume changes by +/- 1e-6 of its value, and takes
    the centered difference of the model total energy.
    """
    volume = system_data["pbc"]["volume"]
    coordinates = conformation["coordinates"]
    cell = conformation["cells"][0]
    ek = 1.5 * nat * system_data["kT"]
    Pkin = (2 * au.KBAR) * ek / ((3.0 / au.BOHR**3) * volume)
    _, _, vir_t, _ = model._energy_and_forces_and_virial(
        model.variables, conformation
    )
    KBAR = au.KBAR / model_energy_unit
    Pvir = -(np.trace(vir_t[0]) * KBAR) / ((3.0 / au.BOHR**3) * volume)

    vstep = volume * 0.000001

    def total_energy_at_scale(scale):
        # Model total energy with both the cell and the coordinates scaled by `scale`.
        scaled_cell = cell * scale
        reciprocal_cell = np.linalg.inv(scaled_cell)
        scaled_system = model.preprocess(
            **{
                **conformation,
                "coordinates": coordinates * scale,
                "cells": scaled_cell[None, :, :],
                "reciprocal_cells": reciprocal_cell[None, :, :],
            }
        )
        energy, _ = model._total_energy(model.variables, scaled_system)
        return energy[0]

    ep = total_energy_at_scale(((volume + vstep) / volume) ** (1.0 / 3.0))
    em = total_energy_at_scale(((volume - vstep) / volume) ** (1.0 / 3.0))
    Pvir_fd = -(ep * KBAR - em * KBAR) / (2.0 * vstep / au.BOHR**3)
    print(
        f"# Initial pressure: {Pkin+Pvir:.3f} (virial); {Pkin+Pvir_fd:.3f} (finite difference) ; Pkin: {Pkin:.3f} ; Pvir: {Pvir:.3f} ; Pvir_fd: {Pvir_fd:.3f}"
    )


def _trajectory_writer(traj_format):
    """Return ``(file_extension, frame_writer)`` for a lowercase trajectory format name.

    Raises:
        ValueError: if the format is not 'xyz', 'extxyz' or 'arc'.
    """
    writers = {
        "xyz": (".traj.xyz", write_xyz_frame),
        "extxyz": (".traj.extxyz", write_extxyz_frame),
        "arc": (".arc", write_arc_frame),
    }
    try:
        return writers[traj_format]
    except KeyError:
        raise ValueError(
            f"Unknown trajectory format '{traj_format}'. Supported formats are 'arc', 'xyz' and 'extxyz'"
        ) from None


def _print_property_averages(properties_traj):
    """Print mean +/- standard deviation of every recorded property.

    Keys have the form ``"Name[unit]"`` (unit optional); empty series are skipped.
    """
    for key, values in properties_traj.items():
        if len(values) == 0:
            continue
        mu = np.mean(values)
        sig = np.std(values)
        ksplit = key.split("[")
        name = ksplit[0].strip()
        unit = ksplit[1].replace("]", "").strip() if len(ksplit) > 1 else ""
        print(f"# {name:10} : {mu: #10.5g} +/- {sig: #9.3g} {unit}")


def dynamic(simulation_parameters, device, fprec):
    """Run an MD (or PIMD) simulation described by ``simulation_parameters``.

    Properties are printed every ``nprint`` steps, a frame is written to the
    trajectory file(s) every ``ndump`` steps (derived from ``tdump``), and a
    summary with averages and timings is printed every ``nsummary`` steps.

    Args:
        simulation_parameters: parsed parameter mapping (see ``parse_input``).
        device: device name, only used in the log header.
        fprec: "float32" or "float64".

    Raises:
        ValueError: on invalid ``nprint``/``nsummary``/``tdump``, on an unknown
            trajectory format, or when NaNs are found in velocities or
            coordinates at a summary step.
    """
    from contextlib import ExitStack

    tstart_dyn = time.time()

    ### Initialize the model
    model = load_model(simulation_parameters)

    ### Get the coordinates and species from the xyz file
    system_data, conformation = load_system_data(simulation_parameters, fprec)
    nat = system_data["nat"]

    preproc_state, conformation = initialize_preprocessing(
        simulation_parameters, model, conformation, system_data
    )

    random_seed = simulation_parameters.get("random_seed", None)
    if random_seed is None:
        # Explicit int64: the default integer type is 32-bit on some platforms
        # (Windows with NumPy < 2), where the upper bound 2**32 - 1 overflows.
        random_seed = int(np.random.randint(0, 2**32 - 1, dtype=np.int64))
    print(f"# random_seed: {random_seed}")
    rng_key = jax.random.PRNGKey(random_seed)
    _, subkey = jax.random.split(rng_key)
    ## INITIALIZE INTEGRATOR AND SYSTEM
    step, update_conformation, dyn_state, system = initialize_dynamics(
        simulation_parameters, system_data, conformation, model, fprec, subkey
    )

    dt = dyn_state["dt"]
    ## get number of steps
    nsteps = int(simulation_parameters.get("nsteps"))
    start_time = 0.0

    ### Set I/O parameters
    Tdump = simulation_parameters.get("tdump", 1.0 / au.PS) * au.FS
    ndump = int(Tdump / dt)
    if ndump < 1:
        # would otherwise fail with ZeroDivisionError at `istep % ndump`
        raise ValueError("tdump must be at least one time step")
    system_name = system_data["name"]

    model_energy_unit = au.get_multiplier(model.energy_unit)
    ### Print initial pressure
    estimate_pressure = dyn_state["estimate_pressure"]
    if estimate_pressure and fprec == "float64":
        _print_initial_pressure(
            model, conformation, system_data, nat, model_energy_unit
        )

    @jax.jit
    def check_nan(system):
        return jnp.any(jnp.isnan(system["vel"])) | jnp.any(
            jnp.isnan(system["coordinates"])
        )

    if system_data["pbc"] is not None:
        cell = system_data["pbc"]["cell"]
        reciprocal_cell = system_data["pbc"]["reciprocal_cell"]
        do_wrap_box = simulation_parameters.get("wrap_box", False)
    else:
        cell = None
        reciprocal_cell = None
        do_wrap_box = False

    ### Energy units and print initial energy
    per_atom_energy = simulation_parameters.get("per_atom_energy", True)
    energy_unit_str = simulation_parameters.get("energy_unit", "kcal/mol")
    print("# Energy unit: ", energy_unit_str)
    energy_unit = au.get_multiplier(energy_unit_str)
    atom_energy_unit = energy_unit
    atom_energy_unit_str = energy_unit_str
    if per_atom_energy:
        atom_energy_unit /= nat
        atom_energy_unit_str = f"{energy_unit_str}/atom"
        print("# Printing Energy per atom")
    print(
        f"# Initial potential energy: {system['epot']*atom_energy_unit}; kinetic energy: {system['ek']*atom_energy_unit}"
    )
    forces = system["forces"]
    minmaxone(jnp.abs(forces * energy_unit), "# forces min/max/rms:")

    ## printing options
    print_timings = simulation_parameters.get("print_timings", False)
    nprint = int(simulation_parameters.get("nprint", 10))
    if nprint <= 0:
        raise ValueError("nprint must be > 0")
    nsummary = int(simulation_parameters.get("nsummary", 100 * nprint))
    if nsummary <= nprint:
        raise ValueError("nsummary must be > nprint")

    ### Print header
    include_thermostat_energy = "thermostat_energy" in system["thermostat"]
    thermostat_name = dyn_state["thermostat_name"]
    pimd = dyn_state["pimd"]
    variable_cell = dyn_state["variable_cell"]
    nbeads = system_data.get("nbeads", 1)
    dyn_name = "PIMD" if pimd else "MD"
    print("#" * 84)
    print(
        f"# Running {nsteps:_} steps of {thermostat_name} {dyn_name} simulation on {device}"
    )
    header = "# Step Time[ps] Etot Epot Ekin Temp[K]"
    if pimd:
        header += " Temp_c[K]"
    if include_thermostat_energy:
        header += " Etherm"
    # Defined unconditionally: it is also read in the variable_cell branch below.
    print_aniso_pressure = False
    if estimate_pressure:
        print_aniso_pressure = simulation_parameters.get("print_aniso_pressure", False)
        pressure_unit_str = simulation_parameters.get("pressure_unit", "atm")
        pressure_unit = au.get_multiplier(pressure_unit_str) * au.BOHR**3
        header += f" Press[{pressure_unit_str}]"
    if variable_cell:
        header += " Density"
    print(header)

    ### Select trajectory format
    traj_format = simulation_parameters.get("traj_format", "arc").lower()
    traj_ext, write_frame = _trajectory_writer(traj_format)

    write_all_beads = simulation_parameters.get("write_all_beads", False) and pimd
    ensemble_key = simulation_parameters.get("etot_ensemble_key", None)

    # ExitStack closes every output file on normal exit AND when an exception
    # (e.g. the NaN check) aborts the run; `fout` may be a list of files.
    with ExitStack() as open_files:
        if write_all_beads:
            fout = [
                open_files.enter_context(
                    open(f"{system_name}_bead{i+1:03d}" + traj_ext, "w")
                )
                for i in range(nbeads)
            ]
        else:
            fout = open_files.enter_context(open(system_name + traj_ext, "a+"))

        if ensemble_key is not None:
            fens = open_files.enter_context(
                open(f"{system_name}.ensemble_weights.traj", "a+")
            )

        ### initialize property trajectories
        properties_traj = defaultdict(list)
        if print_timings:
            timings = defaultdict(lambda: 0.0)

        ### initialize timers
        t0full = time.time()
        # NOTE(review): force_preprocess is never reset to False in this loop, so
        # after the first box wrap every step() call receives True — confirm intended.
        force_preprocess = False

        for istep in range(1, nsteps + 1):

            ### update the system
            dyn_state, system, conformation, preproc_state = step(
                istep, dyn_state, system, conformation, preproc_state, force_preprocess
            )

            ### print properties
            if istep % nprint == 0:
                ek = system["ek"]
                epot = system["epot"]
                etot = ek + epot
                temper = 2 * ek / (3.0 * nat) * au.KELVIN

                if include_thermostat_energy:
                    etherm = system["thermostat"]["thermostat_energy"]
                    etot = etot + etherm

                properties_traj[f"Etot[{atom_energy_unit_str}]"].append(
                    etot * atom_energy_unit
                )
                properties_traj[f"Epot[{atom_energy_unit_str}]"].append(
                    epot * atom_energy_unit
                )
                properties_traj[f"Ekin[{atom_energy_unit_str}]"].append(
                    ek * atom_energy_unit
                )
                properties_traj["Temper[Kelvin]"].append(temper)
                if pimd:
                    ek_c = system["ek_c"]
                    temper_c = 2 * ek_c / (3.0 * nat) * au.KELVIN
                    properties_traj["Temper_c[Kelvin]"].append(temper_c)

                ### construct line of properties
                line = f"{istep:10.6g} {(start_time+istep*dt)/1000: 10.3f} {etot*atom_energy_unit: #10.4f} {epot*atom_energy_unit: #10.4f} {ek*atom_energy_unit: #10.4f} {temper: 10.2f}"
                if pimd:
                    line += f" {temper_c: 10.2f}"
                if include_thermostat_energy:
                    line += f" {etherm*atom_energy_unit: #10.4f}"
                    properties_traj[f"Etherm[{atom_energy_unit_str}]"].append(
                        etherm * atom_energy_unit
                    )
                if estimate_pressure:
                    pres = system["pressure"] * pressure_unit
                    properties_traj[f"Pressure[{pressure_unit_str}]"].append(pres)
                    if print_aniso_pressure:
                        pres_tensor = system["pressure_tensor"] * pressure_unit
                        pres_tensor = 0.5 * (pres_tensor + pres_tensor.T)
                        for label, row, col in _PRESSURE_COMPONENTS:
                            properties_traj[
                                f"Pressure_{label}[{pressure_unit_str}]"
                            ].append(pres_tensor[row, col])
                    line += f" {pres:10.3f}"
                if variable_cell:
                    density = system["density"]
                    properties_traj["Density[g/cm^3]"].append(density)
                    if print_aniso_pressure:
                        current_cell = system["cell"]
                        for label, row, col in _CELL_COMPONENTS:
                            properties_traj[f"Cell_{label}[Angstrom]"].append(
                                current_cell[row, col]
                            )
                    line += f" {density:10.4f}"
                    if "piston_temperature" in system["barostat"]:
                        piston_temperature = system["barostat"]["piston_temperature"]
                        properties_traj["T_piston[Kelvin]"].append(piston_temperature)

                print(line)

            ### save frame
            if istep % ndump == 0:
                line = "# Write XYZ frame"
                if variable_cell:
                    cell = np.array(system["cell"])
                    reciprocal_cell = np.linalg.inv(cell)
                if do_wrap_box:
                    if pimd:
                        # only the first entry along the bead axis is wrapped
                        centroid = wrapbox(system["coordinates"][0], cell, reciprocal_cell)
                        system["coordinates"] = system["coordinates"].at[0].set(centroid)
                    else:
                        system["coordinates"] = wrapbox(
                            system["coordinates"], cell, reciprocal_cell
                        )
                    conformation = update_conformation(conformation, system)
                    line += " (atoms have been wrapped into the box)"
                    force_preprocess = True
                print(line)
                properties = {
                    "energy": float(system["epot"]) * energy_unit,
                    "Time": start_time + istep * dt,
                    "energy_unit": energy_unit_str,
                }

                if write_all_beads:
                    coords = np.asarray(conformation["coordinates"].reshape(nbeads, nat, 3))
                    for i, fb in enumerate(fout):
                        write_frame(
                            fb,
                            system_data["symbols"],
                            coords[i],
                            cell=cell,
                            properties=properties,
                            forces=None,
                        )
                else:
                    write_frame(
                        fout,
                        system_data["symbols"],
                        np.asarray(conformation["coordinates"].reshape(nbeads, nat, 3)[0]),
                        cell=cell,
                        properties=properties,
                        forces=None,
                    )
                if ensemble_key is not None:
                    weights = " ".join(
                        [f"{w:.6f}" for w in system["ensemble_weights"].tolist()]
                    )
                    fens.write(f"{weights}\n")
                    fens.flush()

            ### summary over last nsummary steps
            if istep % nsummary == 0:
                # NaNs are only checked at summary steps
                if check_nan(system):
                    raise ValueError(f"dynamics crashed at step {istep}.")
                tfull = time.time() - t0full
                t0full = time.time()
                tperstep = tfull / nsummary
                nsperday = (24 * 60 * 60 / tperstep) * dt / 1e6
                elapsed_time = time.time() - tstart_dyn
                estimated_remaining_time = tperstep * (nsteps - istep)
                estimated_total_time = elapsed_time + estimated_remaining_time

                print("#" * 50)
                print(f"# Step {istep:_} of {nsteps:_} ({istep/nsteps*100:.5g} %)")
                print(f"# Tot. elapsed time : {human_time_duration(elapsed_time)}")
                print(
                    f"# Est. total time : {human_time_duration(estimated_total_time)}"
                )
                print(
                    f"# Est. remaining time : {human_time_duration(estimated_remaining_time)}"
                )
                print(f"# Time for {nsummary:_} steps : {human_time_duration(tfull)}")

                if print_timings:
                    print("# Detailed per-step timings :")
                    dsteps = nsummary
                    timings["Other"] = tfull - sum(timings.values())
                    # sort timings, largest first
                    timings = dict(
                        sorted(timings.items(), key=lambda item: item[1], reverse=True)
                    )
                    for k, v in timings.items():
                        print(
                            f"# {k:15} : {human_time_duration(v/dsteps):>12} ({v/tfull*100:5.3g} %)"
                        )
                    print(f"# {'Total':15} : {human_time_duration(tfull/dsteps):>12}")
                    ## reset timings
                    timings = defaultdict(lambda: 0.0)

                corr_kin = system["thermostat"].get("corr_kin", None)
                if corr_kin is not None:
                    print(f"# QTB kin. correction : {100*(corr_kin-1.):.2f} %")
                print(f"# Averages over last {nsummary:_} steps :")
                _print_property_averages(properties_traj)

                print(f"# Perf.: {nsperday:.2f} ns/day ( {1.0 / tperstep:.2f} step/s )")
                print("#" * 50)
                if istep < nsteps:
                    print(header)
                ## reset property trajectories
                properties_traj = defaultdict(list)

        print(f"# Run done in {human_time_duration(time.time()-tstart_dyn)}")