fennol.md.dynamic
"""Command-line driver (``fennol_md``) for FeNNol molecular-dynamics runs."""

import argparse
import io
import os
import pickle
import sys
import time
from collections import defaultdict
from pathlib import Path

import jax
import jax.numpy as jnp
import numpy as np
import yaml

from ..utils.io import (
    write_arc_frame,
    write_xyz_frame,
    write_extxyz_frame,
    human_time_duration,
)
from .utils import wrapbox, save_dynamics_restart
from ..utils import minmaxone, AtomicUnits as au, read_tinker_interval
from ..utils.input_parser import parse_input, convert_dict_units, InputFile
from .integrate import initialize_dynamics


def main():
    """Command-line entry point of ``fennol_md``.

    Parses one positional argument (the parameter file, ``.yaml``/``.yml`` or
    ``.fnl``), selects the JAX device and numerical precision from the
    environment / parameters, then runs :func:`dynamic`.

    Side effects:
        Replaces ``sys.stdout`` with an unbuffered, write-through text wrapper
        so progress lines appear immediately; sets ``CUDA_VISIBLE_DEVICES`` and
        the ``jax_platforms`` (CPU only), ``jax_default_device``,
        ``jax_enable_x64`` and ``jax_default_matmul_precision`` options.

    Raises:
        FileNotFoundError: if the parameter file is missing or not a regular
            file.
        ValueError: if the parameter file extension is not supported.
    """
    sys.stdout = io.TextIOWrapper(
        open(sys.stdout.fileno(), "wb", 0), write_through=True
    )
    ### Read the parameter file
    parser = argparse.ArgumentParser(prog="fennol_md")
    parser.add_argument("param_file", type=Path, help="Parameter file")
    args = parser.parse_args()

    param_file = args.param_file
    # is_file() is False both for missing paths and for directories
    if not param_file.is_file():
        raise FileNotFoundError(f"Parameter file {param_file} not found")

    if param_file.suffix in (".yaml", ".yml"):
        with open(param_file, "r") as f:
            simulation_parameters = convert_dict_units(yaml.safe_load(f))
        simulation_parameters = InputFile(**simulation_parameters)
    elif param_file.suffix == ".fnl":
        simulation_parameters = parse_input(param_file)
    else:
        raise ValueError(
            f"Unknown parameter file format '{param_file.suffix}'. Supported formats are '.yaml', '.yml' and '.fnl'"
        )

    ### Set the device (the environment variable overrides the input file)
    if "FENNOL_DEVICE" in os.environ:
        device = os.environ["FENNOL_DEVICE"].lower()
        print(f"# Setting device from env FENNOL_DEVICE={device}")
    else:
        device = simulation_parameters.get("device", "cpu").lower()
    if device == "cpu":
        jax.config.update("jax_platforms", "cpu")
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    elif device.startswith(("cuda", "gpu")):
        # "cuda:N" / "gpu:N" exposes only GPU N; plain "cuda"/"gpu" uses GPU 0
        if ":" in device:
            os.environ["CUDA_VISIBLE_DEVICES"] = device.split(":")[-1]
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        device = "gpu"

    _device = jax.devices(device)[0]
    jax.config.update("jax_default_device", _device)

    ### Set the precision
    enable_x64 = simulation_parameters.get("double_precision", False)
    jax.config.update("jax_enable_x64", enable_x64)
    fprec = "float64" if enable_x64 else "float32"

    matmul_precision = simulation_parameters.get("matmul_prec", "highest").lower()
    assert matmul_precision in [
        "default",
        "high",
        "highest",
    ], "matmul_prec must be one of 'default','high','highest'"
    if matmul_precision != "highest":
        print(f"# Setting matmul precision to '{matmul_precision}'")
    if matmul_precision == "default" and fprec == "float32":
        print(
            "# Warning: default matmul precision involves float16 operations which may lead to large numerical errors on energy and pressure estimations ! It is recommended to set matmul_prec to 'high' or 'highest'."
        )
    jax.config.update("jax_default_matmul_precision", matmul_precision)

    dynamic(simulation_parameters, device, fprec)


def dynamic(simulation_parameters, device, fprec):
    """Run an MD or PIMD simulation and write its trajectory and logs.

    The integrator and initial state come from ``initialize_dynamics``; this
    function drives the time loop, prints a property line every ``nprint``
    steps, writes a trajectory frame and a restart (``save_dynamics_restart``)
    every ``ndump`` steps, and prints statistics averaged over the collected
    property lines every ``nsummary`` steps.

    Args:
        simulation_parameters: parsed input parameters; only
            ``.get(key, default)`` access is used here (e.g. ``InputFile``).
        device: device name, used here only in the run header.
        fprec: floating-point precision string (``"float32"`` or
            ``"float64"``), forwarded to ``initialize_dynamics``.

    Output files (``<name>`` is ``system_data["name"]``), all opened in
    append mode except the pickle file:
        ``<name>.arc`` / ``<name>.traj.xyz`` / ``<name>.traj.extxyz``
        (one ``<name>_beadNNN<ext>`` file per bead when ``write_all_beads``),
        and optionally ``<name>_centroid<ext>``, ``<name>.traj.pkl``,
        ``<name>.colvars.traj`` and ``<name>.ensemble_weights.traj``.

    Raises:
        ValueError: on overlapping ``wrap_groups``, an unknown
            ``traj_format``, or NaN velocities/coordinates at a summary step.
    """
    tstart_dyn = time.time()

    random_seed = simulation_parameters.get(
        "random_seed", np.random.randint(0, 2**32 - 1)
    )
    print(f"# random_seed: {random_seed}")
    rng_key = jax.random.PRNGKey(random_seed)

    ## INITIALIZE INTEGRATOR AND SYSTEM
    rng_key, subkey = jax.random.split(rng_key)
    step, update_conformation, system_data, dyn_state, conformation, system = (
        initialize_dynamics(simulation_parameters, fprec, subkey)
    )

    nat = system_data["nat"]
    dt = dyn_state["dt"]
    ## get number of steps
    nsteps = int(simulation_parameters.get("nsteps"))
    start_time_ps = dyn_state.get("start_time_ps", 0.0)

    ### Set I/O parameters
    Tdump = simulation_parameters.get("tdump", 1.0 / au.PS) * au.FS
    # At least 1: a dump period shorter than dt would otherwise give ndump == 0
    # and a ZeroDivisionError in `istep % ndump` below.
    ndump = max(1, int(Tdump / dt))
    system_name = system_data["name"]
    estimate_pressure = dyn_state["estimate_pressure"]

    @jax.jit
    def check_nan(system):
        # True if any velocity or coordinate component is NaN
        return jnp.any(jnp.isnan(system["vel"])) | jnp.any(
            jnp.isnan(system["coordinates"])
        )

    ### Periodic cell and optional wrapping of atoms into the box
    wrap_groups = None
    if system_data["pbc"] is not None:
        cell = system["cell"]
        reciprocal_cell = np.linalg.inv(cell)
        do_wrap_box = simulation_parameters.get("wrap_box", False)
        if do_wrap_box:
            wrap_groups_def = simulation_parameters.get("wrap_groups", None)
            if wrap_groups_def is not None:
                assert isinstance(wrap_groups_def, dict), "wrap_groups must be a dictionary"
                wrap_groups = {
                    k: read_tinker_interval(v) for k, v in wrap_groups_def.items()
                }
                # check that pairwise intersection of wrap groups is empty
                wrap_groups_keys = list(wrap_groups.keys())
                for i, i_key in enumerate(wrap_groups_keys):
                    w1 = set(wrap_groups[i_key])
                    for j_key in wrap_groups_keys[i + 1 :]:
                        common = w1.intersection(set(wrap_groups[j_key]))
                        if common:
                            raise ValueError(
                                f"Wrap groups {i_key} and {j_key} have common atoms: {common}"
                            )
                group_all = np.concatenate(list(wrap_groups.values()))
                # get all atoms that are not in any wrap group
                group_none = np.setdiff1d(np.arange(nat), group_all)
                print(f"# Wrap groups: {wrap_groups}")
                wrap_groups["__other"] = group_none
                # NOTE(review): this is a single-use generator, yet it is passed
                # to every wrapbox() call below. That only works if wrapbox
                # consumes it once (e.g. while being traced by jax.jit with
                # wrap_groups as a static argument) -- TODO confirm before
                # replacing it by a tuple (a tuple holding ndarrays is not
                # hashable).
                wrap_groups = ((k, v) for k, v in wrap_groups.items())
    else:
        cell = None
        reciprocal_cell = None
        do_wrap_box = False

    ### Energy units and print initial energy
    model_energy_unit = system_data["model_energy_unit"]
    model_energy_unit_str = system_data["model_energy_unit_str"]
    per_atom_energy = simulation_parameters.get("per_atom_energy", True)
    energy_unit_str = system_data["energy_unit_str"]
    energy_unit = system_data["energy_unit"]
    print("# Energy unit: ", energy_unit_str)
    atom_energy_unit = energy_unit
    atom_energy_unit_str = energy_unit_str
    if per_atom_energy:
        # non-mutating division: never alter system_data["energy_unit"] in place
        atom_energy_unit = energy_unit / nat
        atom_energy_unit_str = f"{energy_unit_str}/atom"
        print("# Printing Energy per atom")
    print(
        f"# Initial potential energy: {system['epot']*atom_energy_unit}; kinetic energy: {system['ek']*atom_energy_unit}"
    )
    minmaxone(jnp.abs(system["forces"] * energy_unit), "# forces min/max/rms:")

    ## printing options
    nprint = int(simulation_parameters.get("nprint", 10))
    assert nprint > 0, "nprint must be > 0"
    nsummary = int(simulation_parameters.get("nsummary", 100 * nprint))
    assert nsummary > nprint, "nsummary must be > nprint"

    save_keys = simulation_parameters.get("save_keys", [])
    if save_keys:
        print(f"# Saving keys: {save_keys}")
        fkeys = open(f"{system_name}.traj.pkl", "wb+")
    else:
        fkeys = None

    ### initialize colvars
    use_colvars = "colvars" in dyn_state
    if use_colvars:
        print(f"# Colvars: {dyn_state['colvars']}")
        colvars_names = dyn_state["colvars"]
        # open colvars file and print header
        fcolvars = open(f"{system_name}.colvars.traj", "a")
        fcolvars.write("#time[ps] ")
        fcolvars.write(" ".join(colvars_names))
        fcolvars.write("\n")
        fcolvars.flush()

    ### Print header
    include_thermostat_energy = "thermostat_energy" in system["thermostat"]
    thermostat_name = dyn_state["thermostat_name"]
    pimd = dyn_state["pimd"]
    variable_cell = dyn_state["variable_cell"]
    nbeads = system_data.get("nbeads", 1)
    dyn_name = "PIMD" if pimd else "MD"
    print("#" * 84)
    print(
        f"# Running {nsteps:_} steps of {thermostat_name} {dyn_name} simulation on {device}"
    )
    header = "# Step Time[ps] Etot Epot Ekin Temp[K]"
    if pimd:
        header += " Temp_c[K]"
    if include_thermostat_energy:
        header += " Etherm"
    # also read by the variable-cell branch of the loop, so it must exist
    # even when the pressure is not estimated
    print_aniso_pressure = False
    if estimate_pressure:
        print_aniso_pressure = simulation_parameters.get("print_aniso_pressure", False)
        pressure_unit_str = simulation_parameters.get("pressure_unit", "atm")
        pressure_unit = au.get_multiplier(pressure_unit_str) * au.BOHR**3
        header += f" Press[{pressure_unit_str}]"
    if variable_cell:
        header += " Density"
    print(header)

    ### Open trajectory file
    traj_format = simulation_parameters.get("traj_format", "arc").lower()
    if traj_format == "xyz":
        traj_ext = ".traj.xyz"
        write_frame = write_xyz_frame
    elif traj_format == "extxyz":
        traj_ext = ".traj.extxyz"
        write_frame = write_extxyz_frame
    elif traj_format == "arc":
        traj_ext = ".arc"
        write_frame = write_arc_frame
    else:
        raise ValueError(
            f"Unknown trajectory format '{traj_format}'. Supported formats are 'arc', 'xyz' and 'extxyz'"
        )

    write_all_beads = simulation_parameters.get("write_all_beads", False) and pimd

    # fout is a list of per-bead files when write_all_beads, else a single file
    if write_all_beads:
        fout = [
            open(f"{system_name}_bead{i+1:03d}" + traj_ext, "a") for i in range(nbeads)
        ]
    else:
        fout = open(system_name + traj_ext, "a")

    ensemble_key = simulation_parameters.get("etot_ensemble_key", None)
    if ensemble_key is not None:
        fens = open(f"{system_name}.ensemble_weights.traj", "a")

    write_centroid = simulation_parameters.get("write_centroid", False) and pimd
    if write_centroid:
        fcentroid = open(f"{system_name}_centroid" + traj_ext, "a")

    ### initialize property trajectories (reset after every summary)
    properties_traj = defaultdict(list)

    ### initialize summary timer
    t0full = time.time()
    force_preprocess = False

    for istep in range(1, nsteps + 1):

        ### update the system
        dyn_state, system, conformation, model_output = step(
            istep, dyn_state, system, conformation, force_preprocess
        )

        ### print properties
        if istep % nprint == 0:
            ek = system["ek"]
            epot = system["epot"]
            etot = ek + epot
            temper = 2 * ek / (3.0 * nat) * au.KELVIN

            if include_thermostat_energy:
                etherm = system["thermostat"]["thermostat_energy"]
                etot = etot + etherm

            properties_traj[f"Etot[{atom_energy_unit_str}]"].append(
                etot * atom_energy_unit
            )
            properties_traj[f"Epot[{atom_energy_unit_str}]"].append(
                epot * atom_energy_unit
            )
            properties_traj[f"Ekin[{atom_energy_unit_str}]"].append(
                ek * atom_energy_unit
            )
            properties_traj["Temper[Kelvin]"].append(temper)
            if pimd:
                ek_c = system["ek_c"]
                temper_c = 2 * ek_c / (3.0 * nat) * au.KELVIN
                properties_traj["Temper_c[Kelvin]"].append(temper_c)

            ### construct line of properties
            # dt / 1000 converts to ps, so dt is presumably in fs
            simulated_time = start_time_ps + istep * dt / 1000
            line = f"{istep:10.6g} {simulated_time: 10.3f} {etot*atom_energy_unit: #10.4f} {epot*atom_energy_unit: #10.4f} {ek*atom_energy_unit: #10.4f} {temper: 10.2f}"
            if pimd:
                line += f" {temper_c: 10.2f}"
            if include_thermostat_energy:
                line += f" {etherm*atom_energy_unit: #10.4f}"
                properties_traj[f"Etherm[{atom_energy_unit_str}]"].append(
                    etherm * atom_energy_unit
                )
            if estimate_pressure:
                pres = system["pressure"] * pressure_unit
                properties_traj[f"Pressure[{pressure_unit_str}]"].append(pres)
                if print_aniso_pressure:
                    pres_tensor = system["pressure_tensor"] * pressure_unit
                    # symmetrize before reporting the six independent components
                    pres_tensor = 0.5 * (pres_tensor + pres_tensor.T)
                    properties_traj[f"Pressure_xx[{pressure_unit_str}]"].append(
                        pres_tensor[0, 0]
                    )
                    properties_traj[f"Pressure_yy[{pressure_unit_str}]"].append(
                        pres_tensor[1, 1]
                    )
                    properties_traj[f"Pressure_zz[{pressure_unit_str}]"].append(
                        pres_tensor[2, 2]
                    )
                    properties_traj[f"Pressure_xy[{pressure_unit_str}]"].append(
                        pres_tensor[0, 1]
                    )
                    properties_traj[f"Pressure_xz[{pressure_unit_str}]"].append(
                        pres_tensor[0, 2]
                    )
                    properties_traj[f"Pressure_yz[{pressure_unit_str}]"].append(
                        pres_tensor[1, 2]
                    )
                line += f" {pres:10.3f}"
            if variable_cell:
                density = system["density"]
                properties_traj["Density[g/cm^3]"].append(density)
                if print_aniso_pressure:
                    cell = system["cell"]
                    properties_traj[f"Cell_Ax[Angstrom]"].append(cell[0, 0])
                    properties_traj[f"Cell_Ay[Angstrom]"].append(cell[0, 1])
                    properties_traj[f"Cell_Az[Angstrom]"].append(cell[0, 2])
                    properties_traj[f"Cell_Bx[Angstrom]"].append(cell[1, 0])
                    properties_traj[f"Cell_By[Angstrom]"].append(cell[1, 1])
                    properties_traj[f"Cell_Bz[Angstrom]"].append(cell[1, 2])
                    properties_traj[f"Cell_Cx[Angstrom]"].append(cell[2, 0])
                    properties_traj[f"Cell_Cy[Angstrom]"].append(cell[2, 1])
                    properties_traj[f"Cell_Cz[Angstrom]"].append(cell[2, 2])
                line += f" {density:10.4f}"
                if "piston_temperature" in system["barostat"]:
                    piston_temperature = system["barostat"]["piston_temperature"]
                    properties_traj["T_piston[Kelvin]"].append(piston_temperature)

            print(line)

            ### save colvars
            if use_colvars:
                colvars = system["colvars"]
                fcolvars.write(f"{simulated_time:.3f} ")
                fcolvars.write(" ".join([f"{colvars[k]:.6f}" for k in colvars_names]))
                fcolvars.write("\n")
                fcolvars.flush()

            if save_keys:
                data = {
                    k: (
                        np.asarray(model_output[k])
                        if isinstance(model_output[k], jnp.ndarray)
                        else model_output[k]
                    )
                    for k in save_keys
                }
                # pair each header column name (minus the leading "#") with the
                # value just printed in `line`
                data["properties"] = {
                    k: float(v) for k, v in zip(header.split()[1:], line.split())
                }
                data["properties"]["properties_energy_unit"] = (atom_energy_unit, atom_energy_unit_str)
                data["properties"]["model_energy_unit"] = (model_energy_unit, model_energy_unit_str)

                pickle.dump(data, fkeys)

        ### save frame
        if istep % ndump == 0:
            line = "# Write XYZ frame"
            if variable_cell:
                cell = np.array(system["cell"])
                reciprocal_cell = np.linalg.inv(cell)
            if do_wrap_box:
                if pimd:
                    # only the first entry (coordinates[0]) is wrapped in PIMD
                    centroid = wrapbox(system["coordinates"][0], cell, reciprocal_cell, wrap_groups=wrap_groups)
                    system["coordinates"] = system["coordinates"].at[0].set(centroid)
                else:
                    system["coordinates"] = wrapbox(
                        system["coordinates"], cell, reciprocal_cell, wrap_groups=wrap_groups
                    )
                conformation = update_conformation(conformation, system)
                line += " (atoms have been wrapped into the box)"
                force_preprocess = True
            print(line)

            save_dynamics_restart(system_data, conformation, dyn_state, system)

            properties = {
                "energy": float(system["epot"]) * energy_unit,
                "Time_ps": start_time_ps + istep * dt / 1000,
                "energy_unit": energy_unit_str,
            }

            if write_all_beads:
                coords = np.asarray(conformation["coordinates"].reshape(-1, nat, 3))
                for i, fb in enumerate(fout):
                    write_frame(
                        fb,
                        system_data["symbols"],
                        coords[i],
                        cell=cell,
                        properties=properties,
                        forces=None,
                    )
            else:
                write_frame(
                    fout,
                    system_data["symbols"],
                    np.asarray(conformation["coordinates"].reshape(-1, nat, 3)[0]),
                    cell=cell,
                    properties=properties,
                    forces=None,
                )
            if write_centroid:
                centroid = np.asarray(system["coordinates"][0])
                write_frame(
                    fcentroid,
                    system_data["symbols"],
                    centroid,
                    cell=cell,
                    properties=properties,
                    forces=np.asarray(system["forces"].reshape(nbeads, nat, 3)[0])
                    * energy_unit,
                )
            if ensemble_key is not None:
                weights = " ".join(
                    [f"{w:.6f}" for w in system["ensemble_weights"].tolist()]
                )
                fens.write(f"{weights}\n")
                fens.flush()

        ### summary over last nsummary steps
        if istep % nsummary == 0:
            if check_nan(system):
                raise ValueError(f"dynamics crashed at step {istep}.")
            tfull = time.time() - t0full
            t0full = time.time()
            tperstep = tfull / nsummary
            nsperday = (24 * 60 * 60 / tperstep) * dt / 1e6
            elapsed_time = time.time() - tstart_dyn
            estimated_remaining_time = tperstep * (nsteps - istep)
            estimated_total_time = elapsed_time + estimated_remaining_time

            print("#" * 50)
            print(f"# Step {istep:_} of {nsteps:_} ({istep/nsteps*100:.5g} %)")
            print(f"# Simulated time      : {istep * dt*1.e-3:.3f} ps")
            print(f"# Tot. Simu. time     : {start_time_ps + istep * dt*1.e-3:.3f} ps")
            print(f"# Tot. elapsed time   : {human_time_duration(elapsed_time)}")
            print(
                f"# Est. total duration   : {human_time_duration(estimated_total_time)}"
            )
            print(
                f"# Est. remaining time : {human_time_duration(estimated_remaining_time)}"
            )
            print(f"# Time for {nsummary:_} steps : {human_time_duration(tfull)}")

            corr_kin = system["thermostat"].get("corr_kin", None)
            if corr_kin is not None:
                print(f"# QTB kin. correction : {100*(corr_kin-1.):.2f} %")
            print(f"# Averages over last {nsummary:_} steps :")
            for k, v in properties_traj.items():
                if not v:
                    continue
                mu = np.mean(v)
                sig = np.std(v)
                # keys look like "Name[unit]" or just "Name"
                ksplit = k.split("[")
                name = ksplit[0].strip()
                unit = ksplit[1].replace("]", "").strip() if len(ksplit) > 1 else ""
                print(f"#   {name:10} : {mu: #10.5g}   +/- {sig: #9.3g}  {unit}")

            print(f"# Perf.: {nsperday:.2f} ns/day  ( {1.0 / tperstep:.2f} step/s )")
            print("#" * 50)
            if istep < nsteps:
                print(header)
            ## reset property trajectories
            properties_traj = defaultdict(list)

    print(f"# Run done in {human_time_duration(time.time()-tstart_dyn)}")
    ### close output files (fout is a list of files when write_all_beads)
    for fb in (fout if write_all_beads else (fout,)):
        fb.close()
    if ensemble_key is not None:
        fens.close()
    if use_colvars:
        fcolvars.close()
    if fkeys is not None:
        fkeys.close()
    if write_centroid:
        fcentroid.close()


if __name__ == "__main__":
    main()
def
main():
def main():
    """Command-line entry point of ``fennol_md``.

    Parses one positional argument (the parameter file, ``.yaml``/``.yml`` or
    ``.fnl``), selects the JAX device and numerical precision from the
    environment / parameters, then runs :func:`dynamic`.

    Side effects:
        Replaces ``sys.stdout`` with an unbuffered, write-through text wrapper
        so progress lines appear immediately; sets ``CUDA_VISIBLE_DEVICES`` and
        the ``jax_platforms`` (CPU only), ``jax_default_device``,
        ``jax_enable_x64`` and ``jax_default_matmul_precision`` options.

    Raises:
        FileNotFoundError: if the parameter file is missing or not a regular
            file.
        ValueError: if the parameter file extension is not supported.
    """
    sys.stdout = io.TextIOWrapper(
        open(sys.stdout.fileno(), "wb", 0), write_through=True
    )
    ### Read the parameter file
    parser = argparse.ArgumentParser(prog="fennol_md")
    parser.add_argument("param_file", type=Path, help="Parameter file")
    args = parser.parse_args()

    param_file = args.param_file
    # is_file() is False both for missing paths and for directories
    if not param_file.is_file():
        raise FileNotFoundError(f"Parameter file {param_file} not found")

    if param_file.suffix in (".yaml", ".yml"):
        with open(param_file, "r") as f:
            simulation_parameters = convert_dict_units(yaml.safe_load(f))
        simulation_parameters = InputFile(**simulation_parameters)
    elif param_file.suffix == ".fnl":
        simulation_parameters = parse_input(param_file)
    else:
        raise ValueError(
            f"Unknown parameter file format '{param_file.suffix}'. Supported formats are '.yaml', '.yml' and '.fnl'"
        )

    ### Set the device (the environment variable overrides the input file)
    if "FENNOL_DEVICE" in os.environ:
        device = os.environ["FENNOL_DEVICE"].lower()
        print(f"# Setting device from env FENNOL_DEVICE={device}")
    else:
        device = simulation_parameters.get("device", "cpu").lower()
    if device == "cpu":
        jax.config.update("jax_platforms", "cpu")
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    elif device.startswith(("cuda", "gpu")):
        # "cuda:N" / "gpu:N" exposes only GPU N; plain "cuda"/"gpu" uses GPU 0
        if ":" in device:
            os.environ["CUDA_VISIBLE_DEVICES"] = device.split(":")[-1]
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        device = "gpu"

    _device = jax.devices(device)[0]
    jax.config.update("jax_default_device", _device)

    ### Set the precision
    enable_x64 = simulation_parameters.get("double_precision", False)
    jax.config.update("jax_enable_x64", enable_x64)
    fprec = "float64" if enable_x64 else "float32"

    matmul_precision = simulation_parameters.get("matmul_prec", "highest").lower()
    assert matmul_precision in [
        "default",
        "high",
        "highest",
    ], "matmul_prec must be one of 'default','high','highest'"
    if matmul_precision != "highest":
        print(f"# Setting matmul precision to '{matmul_precision}'")
    if matmul_precision == "default" and fprec == "float32":
        print(
            "# Warning: default matmul precision involves float16 operations which may lead to large numerical errors on energy and pressure estimations ! It is recommended to set matmul_prec to 'high' or 'highest'."
        )
    jax.config.update("jax_default_matmul_precision", matmul_precision)

    dynamic(simulation_parameters, device, fprec)
def
dynamic(simulation_parameters, device, fprec):
def dynamic(simulation_parameters, device, fprec):
    """Run an MD or PIMD simulation and write its trajectory and logs.

    The integrator and initial state come from ``initialize_dynamics``; this
    function drives the time loop, prints a property line every ``nprint``
    steps, writes a trajectory frame and a restart (``save_dynamics_restart``)
    every ``ndump`` steps, and prints statistics averaged over the collected
    property lines every ``nsummary`` steps.

    Args:
        simulation_parameters: parsed input parameters; only
            ``.get(key, default)`` access is used here (e.g. ``InputFile``).
        device: device name, used here only in the run header.
        fprec: floating-point precision string (``"float32"`` or
            ``"float64"``), forwarded to ``initialize_dynamics``.

    Output files (``<name>`` is ``system_data["name"]``), all opened in
    append mode except the pickle file:
        ``<name>.arc`` / ``<name>.traj.xyz`` / ``<name>.traj.extxyz``
        (one ``<name>_beadNNN<ext>`` file per bead when ``write_all_beads``),
        and optionally ``<name>_centroid<ext>``, ``<name>.traj.pkl``,
        ``<name>.colvars.traj`` and ``<name>.ensemble_weights.traj``.

    Raises:
        ValueError: on overlapping ``wrap_groups``, an unknown
            ``traj_format``, or NaN velocities/coordinates at a summary step.
    """
    tstart_dyn = time.time()

    random_seed = simulation_parameters.get(
        "random_seed", np.random.randint(0, 2**32 - 1)
    )
    print(f"# random_seed: {random_seed}")
    rng_key = jax.random.PRNGKey(random_seed)

    ## INITIALIZE INTEGRATOR AND SYSTEM
    rng_key, subkey = jax.random.split(rng_key)
    step, update_conformation, system_data, dyn_state, conformation, system = (
        initialize_dynamics(simulation_parameters, fprec, subkey)
    )

    nat = system_data["nat"]
    dt = dyn_state["dt"]
    ## get number of steps
    nsteps = int(simulation_parameters.get("nsteps"))
    start_time_ps = dyn_state.get("start_time_ps", 0.0)

    ### Set I/O parameters
    Tdump = simulation_parameters.get("tdump", 1.0 / au.PS) * au.FS
    # At least 1: a dump period shorter than dt would otherwise give ndump == 0
    # and a ZeroDivisionError in `istep % ndump` below.
    ndump = max(1, int(Tdump / dt))
    system_name = system_data["name"]
    estimate_pressure = dyn_state["estimate_pressure"]

    @jax.jit
    def check_nan(system):
        # True if any velocity or coordinate component is NaN
        return jnp.any(jnp.isnan(system["vel"])) | jnp.any(
            jnp.isnan(system["coordinates"])
        )

    ### Periodic cell and optional wrapping of atoms into the box
    wrap_groups = None
    if system_data["pbc"] is not None:
        cell = system["cell"]
        reciprocal_cell = np.linalg.inv(cell)
        do_wrap_box = simulation_parameters.get("wrap_box", False)
        if do_wrap_box:
            wrap_groups_def = simulation_parameters.get("wrap_groups", None)
            if wrap_groups_def is not None:
                assert isinstance(wrap_groups_def, dict), "wrap_groups must be a dictionary"
                wrap_groups = {
                    k: read_tinker_interval(v) for k, v in wrap_groups_def.items()
                }
                # check that pairwise intersection of wrap groups is empty
                wrap_groups_keys = list(wrap_groups.keys())
                for i, i_key in enumerate(wrap_groups_keys):
                    w1 = set(wrap_groups[i_key])
                    for j_key in wrap_groups_keys[i + 1 :]:
                        common = w1.intersection(set(wrap_groups[j_key]))
                        if common:
                            raise ValueError(
                                f"Wrap groups {i_key} and {j_key} have common atoms: {common}"
                            )
                group_all = np.concatenate(list(wrap_groups.values()))
                # get all atoms that are not in any wrap group
                group_none = np.setdiff1d(np.arange(nat), group_all)
                print(f"# Wrap groups: {wrap_groups}")
                wrap_groups["__other"] = group_none
                # NOTE(review): this is a single-use generator, yet it is passed
                # to every wrapbox() call below. That only works if wrapbox
                # consumes it once (e.g. while being traced by jax.jit with
                # wrap_groups as a static argument) -- TODO confirm before
                # replacing it by a tuple (a tuple holding ndarrays is not
                # hashable).
                wrap_groups = ((k, v) for k, v in wrap_groups.items())
    else:
        cell = None
        reciprocal_cell = None
        do_wrap_box = False

    ### Energy units and print initial energy
    model_energy_unit = system_data["model_energy_unit"]
    model_energy_unit_str = system_data["model_energy_unit_str"]
    per_atom_energy = simulation_parameters.get("per_atom_energy", True)
    energy_unit_str = system_data["energy_unit_str"]
    energy_unit = system_data["energy_unit"]
    print("# Energy unit: ", energy_unit_str)
    atom_energy_unit = energy_unit
    atom_energy_unit_str = energy_unit_str
    if per_atom_energy:
        # non-mutating division: never alter system_data["energy_unit"] in place
        atom_energy_unit = energy_unit / nat
        atom_energy_unit_str = f"{energy_unit_str}/atom"
        print("# Printing Energy per atom")
    print(
        f"# Initial potential energy: {system['epot']*atom_energy_unit}; kinetic energy: {system['ek']*atom_energy_unit}"
    )
    minmaxone(jnp.abs(system["forces"] * energy_unit), "# forces min/max/rms:")

    ## printing options
    nprint = int(simulation_parameters.get("nprint", 10))
    assert nprint > 0, "nprint must be > 0"
    nsummary = int(simulation_parameters.get("nsummary", 100 * nprint))
    assert nsummary > nprint, "nsummary must be > nprint"

    save_keys = simulation_parameters.get("save_keys", [])
    if save_keys:
        print(f"# Saving keys: {save_keys}")
        fkeys = open(f"{system_name}.traj.pkl", "wb+")
    else:
        fkeys = None

    ### initialize colvars
    use_colvars = "colvars" in dyn_state
    if use_colvars:
        print(f"# Colvars: {dyn_state['colvars']}")
        colvars_names = dyn_state["colvars"]
        # open colvars file and print header
        fcolvars = open(f"{system_name}.colvars.traj", "a")
        fcolvars.write("#time[ps] ")
        fcolvars.write(" ".join(colvars_names))
        fcolvars.write("\n")
        fcolvars.flush()

    ### Print header
    include_thermostat_energy = "thermostat_energy" in system["thermostat"]
    thermostat_name = dyn_state["thermostat_name"]
    pimd = dyn_state["pimd"]
    variable_cell = dyn_state["variable_cell"]
    nbeads = system_data.get("nbeads", 1)
    dyn_name = "PIMD" if pimd else "MD"
    print("#" * 84)
    print(
        f"# Running {nsteps:_} steps of {thermostat_name} {dyn_name} simulation on {device}"
    )
    header = "# Step Time[ps] Etot Epot Ekin Temp[K]"
    if pimd:
        header += " Temp_c[K]"
    if include_thermostat_energy:
        header += " Etherm"
    # also read by the variable-cell branch of the loop, so it must exist
    # even when the pressure is not estimated
    print_aniso_pressure = False
    if estimate_pressure:
        print_aniso_pressure = simulation_parameters.get("print_aniso_pressure", False)
        pressure_unit_str = simulation_parameters.get("pressure_unit", "atm")
        pressure_unit = au.get_multiplier(pressure_unit_str) * au.BOHR**3
        header += f" Press[{pressure_unit_str}]"
    if variable_cell:
        header += " Density"
    print(header)

    ### Open trajectory file
    traj_format = simulation_parameters.get("traj_format", "arc").lower()
    if traj_format == "xyz":
        traj_ext = ".traj.xyz"
        write_frame = write_xyz_frame
    elif traj_format == "extxyz":
        traj_ext = ".traj.extxyz"
        write_frame = write_extxyz_frame
    elif traj_format == "arc":
        traj_ext = ".arc"
        write_frame = write_arc_frame
    else:
        raise ValueError(
            f"Unknown trajectory format '{traj_format}'. Supported formats are 'arc', 'xyz' and 'extxyz'"
        )

    write_all_beads = simulation_parameters.get("write_all_beads", False) and pimd

    # fout is a list of per-bead files when write_all_beads, else a single file
    if write_all_beads:
        fout = [
            open(f"{system_name}_bead{i+1:03d}" + traj_ext, "a") for i in range(nbeads)
        ]
    else:
        fout = open(system_name + traj_ext, "a")

    ensemble_key = simulation_parameters.get("etot_ensemble_key", None)
    if ensemble_key is not None:
        fens = open(f"{system_name}.ensemble_weights.traj", "a")

    write_centroid = simulation_parameters.get("write_centroid", False) and pimd
    if write_centroid:
        fcentroid = open(f"{system_name}_centroid" + traj_ext, "a")

    ### initialize property trajectories (reset after every summary)
    properties_traj = defaultdict(list)

    ### initialize summary timer
    t0full = time.time()
    force_preprocess = False

    for istep in range(1, nsteps + 1):

        ### update the system
        dyn_state, system, conformation, model_output = step(
            istep, dyn_state, system, conformation, force_preprocess
        )

        ### print properties
        if istep % nprint == 0:
            ek = system["ek"]
            epot = system["epot"]
            etot = ek + epot
            temper = 2 * ek / (3.0 * nat) * au.KELVIN

            if include_thermostat_energy:
                etherm = system["thermostat"]["thermostat_energy"]
                etot = etot + etherm

            properties_traj[f"Etot[{atom_energy_unit_str}]"].append(
                etot * atom_energy_unit
            )
            properties_traj[f"Epot[{atom_energy_unit_str}]"].append(
                epot * atom_energy_unit
            )
            properties_traj[f"Ekin[{atom_energy_unit_str}]"].append(
                ek * atom_energy_unit
            )
            properties_traj["Temper[Kelvin]"].append(temper)
            if pimd:
                ek_c = system["ek_c"]
                temper_c = 2 * ek_c / (3.0 * nat) * au.KELVIN
                properties_traj["Temper_c[Kelvin]"].append(temper_c)

            ### construct line of properties
            # dt / 1000 converts to ps, so dt is presumably in fs
            simulated_time = start_time_ps + istep * dt / 1000
            line = f"{istep:10.6g} {simulated_time: 10.3f} {etot*atom_energy_unit: #10.4f} {epot*atom_energy_unit: #10.4f} {ek*atom_energy_unit: #10.4f} {temper: 10.2f}"
            if pimd:
                line += f" {temper_c: 10.2f}"
            if include_thermostat_energy:
                line += f" {etherm*atom_energy_unit: #10.4f}"
                properties_traj[f"Etherm[{atom_energy_unit_str}]"].append(
                    etherm * atom_energy_unit
                )
            if estimate_pressure:
                pres = system["pressure"] * pressure_unit
                properties_traj[f"Pressure[{pressure_unit_str}]"].append(pres)
                if print_aniso_pressure:
                    pres_tensor = system["pressure_tensor"] * pressure_unit
                    # symmetrize before reporting the six independent components
                    pres_tensor = 0.5 * (pres_tensor + pres_tensor.T)
                    properties_traj[f"Pressure_xx[{pressure_unit_str}]"].append(
                        pres_tensor[0, 0]
                    )
                    properties_traj[f"Pressure_yy[{pressure_unit_str}]"].append(
                        pres_tensor[1, 1]
                    )
                    properties_traj[f"Pressure_zz[{pressure_unit_str}]"].append(
                        pres_tensor[2, 2]
                    )
                    properties_traj[f"Pressure_xy[{pressure_unit_str}]"].append(
                        pres_tensor[0, 1]
                    )
                    properties_traj[f"Pressure_xz[{pressure_unit_str}]"].append(
                        pres_tensor[0, 2]
                    )
                    properties_traj[f"Pressure_yz[{pressure_unit_str}]"].append(
                        pres_tensor[1, 2]
                    )
                line += f" {pres:10.3f}"
            if variable_cell:
                density = system["density"]
                properties_traj["Density[g/cm^3]"].append(density)
                if print_aniso_pressure:
                    cell = system["cell"]
                    properties_traj[f"Cell_Ax[Angstrom]"].append(cell[0, 0])
                    properties_traj[f"Cell_Ay[Angstrom]"].append(cell[0, 1])
                    properties_traj[f"Cell_Az[Angstrom]"].append(cell[0, 2])
                    properties_traj[f"Cell_Bx[Angstrom]"].append(cell[1, 0])
                    properties_traj[f"Cell_By[Angstrom]"].append(cell[1, 1])
                    properties_traj[f"Cell_Bz[Angstrom]"].append(cell[1, 2])
                    properties_traj[f"Cell_Cx[Angstrom]"].append(cell[2, 0])
                    properties_traj[f"Cell_Cy[Angstrom]"].append(cell[2, 1])
                    properties_traj[f"Cell_Cz[Angstrom]"].append(cell[2, 2])
                line += f" {density:10.4f}"
                if "piston_temperature" in system["barostat"]:
                    piston_temperature = system["barostat"]["piston_temperature"]
                    properties_traj["T_piston[Kelvin]"].append(piston_temperature)

            print(line)

            ### save colvars
            if use_colvars:
                colvars = system["colvars"]
                fcolvars.write(f"{simulated_time:.3f} ")
                fcolvars.write(" ".join([f"{colvars[k]:.6f}" for k in colvars_names]))
                fcolvars.write("\n")
                fcolvars.flush()

            if save_keys:
                data = {
                    k: (
                        np.asarray(model_output[k])
                        if isinstance(model_output[k], jnp.ndarray)
                        else model_output[k]
                    )
                    for k in save_keys
                }
                # pair each header column name (minus the leading "#") with the
                # value just printed in `line`
                data["properties"] = {
                    k: float(v) for k, v in zip(header.split()[1:], line.split())
                }
                data["properties"]["properties_energy_unit"] = (atom_energy_unit, atom_energy_unit_str)
                data["properties"]["model_energy_unit"] = (model_energy_unit, model_energy_unit_str)

                pickle.dump(data, fkeys)

        ### save frame
        if istep % ndump == 0:
            line = "# Write XYZ frame"
            if variable_cell:
                cell = np.array(system["cell"])
                reciprocal_cell = np.linalg.inv(cell)
            if do_wrap_box:
                if pimd:
                    # only the first entry (coordinates[0]) is wrapped in PIMD
                    centroid = wrapbox(system["coordinates"][0], cell, reciprocal_cell, wrap_groups=wrap_groups)
                    system["coordinates"] = system["coordinates"].at[0].set(centroid)
                else:
                    system["coordinates"] = wrapbox(
                        system["coordinates"], cell, reciprocal_cell, wrap_groups=wrap_groups
                    )
                conformation = update_conformation(conformation, system)
                line += " (atoms have been wrapped into the box)"
                force_preprocess = True
            print(line)

            save_dynamics_restart(system_data, conformation, dyn_state, system)

            properties = {
                "energy": float(system["epot"]) * energy_unit,
                "Time_ps": start_time_ps + istep * dt / 1000,
                "energy_unit": energy_unit_str,
            }

            if write_all_beads:
                coords = np.asarray(conformation["coordinates"].reshape(-1, nat, 3))
                for i, fb in enumerate(fout):
                    write_frame(
                        fb,
                        system_data["symbols"],
                        coords[i],
                        cell=cell,
                        properties=properties,
                        forces=None,
                    )
            else:
                write_frame(
                    fout,
                    system_data["symbols"],
                    np.asarray(conformation["coordinates"].reshape(-1, nat, 3)[0]),
                    cell=cell,
                    properties=properties,
                    forces=None,
                )
            if write_centroid:
                centroid = np.asarray(system["coordinates"][0])
                write_frame(
                    fcentroid,
                    system_data["symbols"],
                    centroid,
                    cell=cell,
                    properties=properties,
                    forces=np.asarray(system["forces"].reshape(nbeads, nat, 3)[0])
                    * energy_unit,
                )
            if ensemble_key is not None:
                weights = " ".join(
                    [f"{w:.6f}" for w in system["ensemble_weights"].tolist()]
                )
                fens.write(f"{weights}\n")
                fens.flush()

        ### summary over last nsummary steps
        if istep % nsummary == 0:
            if check_nan(system):
                raise ValueError(f"dynamics crashed at step {istep}.")
            tfull = time.time() - t0full
            t0full = time.time()
            tperstep = tfull / nsummary
            nsperday = (24 * 60 * 60 / tperstep) * dt / 1e6
            elapsed_time = time.time() - tstart_dyn
            estimated_remaining_time = tperstep * (nsteps - istep)
            estimated_total_time = elapsed_time + estimated_remaining_time

            print("#" * 50)
            print(f"# Step {istep:_} of {nsteps:_} ({istep/nsteps*100:.5g} %)")
            print(f"# Simulated time      : {istep * dt*1.e-3:.3f} ps")
            print(f"# Tot. Simu. time     : {start_time_ps + istep * dt*1.e-3:.3f} ps")
            print(f"# Tot. elapsed time   : {human_time_duration(elapsed_time)}")
            print(
                f"# Est. total duration   : {human_time_duration(estimated_total_time)}"
            )
            print(
                f"# Est. remaining time : {human_time_duration(estimated_remaining_time)}"
            )
            print(f"# Time for {nsummary:_} steps : {human_time_duration(tfull)}")

            corr_kin = system["thermostat"].get("corr_kin", None)
            if corr_kin is not None:
                print(f"# QTB kin. correction : {100*(corr_kin-1.):.2f} %")
            print(f"# Averages over last {nsummary:_} steps :")
            for k, v in properties_traj.items():
                if not v:
                    continue
                mu = np.mean(v)
                sig = np.std(v)
                # keys look like "Name[unit]" or just "Name"
                ksplit = k.split("[")
                name = ksplit[0].strip()
                unit = ksplit[1].replace("]", "").strip() if len(ksplit) > 1 else ""
                print(f"#   {name:10} : {mu: #10.5g}   +/- {sig: #9.3g}  {unit}")

            print(f"# Perf.: {nsperday:.2f} ns/day  ( {1.0 / tperstep:.2f} step/s )")
            print("#" * 50)
            if istep < nsteps:
                print(header)
            ## reset property trajectories
            properties_traj = defaultdict(list)

    print(f"# Run done in {human_time_duration(time.time()-tstart_dyn)}")
    ### close output files (fout is a list of files when write_all_beads)
    for fb in (fout if write_all_beads else (fout,)):
        fb.close()
    if ensemble_key is not None:
        fens.close()
    if use_colvars:
        fcolvars.close()
    if fkeys is not None:
        fkeys.close()
    if write_centroid:
        fcentroid.close()