fennol.md.dynamic
import sys
import os
import io
import argparse
import time
from pathlib import Path

import numpy as np
from collections import defaultdict
import jax
import jax.numpy as jnp
import pickle
import yaml

from ..utils.io import (
    write_arc_frame,
    write_xyz_frame,
    write_extxyz_frame,
    human_time_duration,
)
from .utils import wrapbox, save_dynamics_restart
from ..utils import minmaxone, AtomicUnits as au, read_tinker_interval
from ..utils.input_parser import parse_input, convert_dict_units, InputFile
from .integrate import initialize_dynamics


def main():
    """Command-line entry point: read the parameter-file argument and run.

    Returns:
        Whatever ``config_and_run_dynamic`` returns for the given file.
    """
    # Rebind stdout to an unbuffered binary stream wrapped in a write-through
    # text layer, so that every print is pushed to the output immediately.
    raw_stdout = open(sys.stdout.fileno(), "wb", 0)
    sys.stdout = io.TextIOWrapper(raw_stdout, write_through=True)

    ### Read the parameter file
    cli = argparse.ArgumentParser(prog="fennol_md")
    cli.add_argument("param_file", type=Path, help="Parameter file")
    return config_and_run_dynamic(cli.parse_args().param_file)


def config_and_run_dynamic(param_file: Path):
    """Run the dynamic simulation based on the provided parameter file.

    Loads the parameters, selects the JAX device and numerical precision,
    then hands over to ``dynamic``.

    Args:
        param_file: Path to a ``.yaml``/``.yml`` or ``.fnl`` parameter file.
            A plain string is accepted as well.

    Returns:
        The value returned by ``dynamic``.

    Raises:
        FileNotFoundError: If ``param_file`` is not an existing regular file.
        ValueError: If the file extension is not a supported format.
        AssertionError: If ``matmul_prec`` is not one of the accepted names.
    """
    param_file = Path(param_file)
    # is_file() is False both for a missing path and for a directory, so one
    # test rejects everything that cannot be opened as a parameter file.
    if not param_file.is_file():
        raise FileNotFoundError(f"Parameter file {param_file} not found")

    if param_file.suffix in [".yaml", ".yml"]:
        with open(param_file, "r") as f:
            simulation_parameters = convert_dict_units(yaml.safe_load(f))
        simulation_parameters = InputFile(**simulation_parameters)
    elif param_file.suffix == ".fnl":
        simulation_parameters = parse_input(param_file)
    else:
        raise ValueError(
            f"Unknown parameter file format '{param_file.suffix}'. Supported formats are '.yaml', '.yml' and '.fnl'"
        )

    ### Set the device
    # The FENNOL_DEVICE environment variable takes precedence over the file.
    if "FENNOL_DEVICE" in os.environ:
        device = os.environ["FENNOL_DEVICE"].lower()
        print(f"# Setting device from env FENNOL_DEVICE={device}")
    else:
        device = simulation_parameters.get("device", "cpu").lower()
    if device == "cpu":
        jax.config.update('jax_platforms', 'cpu')
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    elif device.startswith(("cuda", "gpu")):
        # "cuda:N" / "gpu:N" exposes only device N; the bare name exposes "0".
        if ":" in device:
            os.environ["CUDA_VISIBLE_DEVICES"] = device.split(":")[-1]
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        device = "gpu"

    _device = jax.devices(device)[0]
    jax.config.update("jax_default_device", _device)

    ### Set the precision
    enable_x64 = simulation_parameters.get("double_precision", False)
    jax.config.update("jax_enable_x64", enable_x64)
    fprec = "float64" if enable_x64 else "float32"

    matmul_precision = simulation_parameters.get("matmul_prec", "highest").lower()
    assert matmul_precision in [
        "default",
        "high",
        "highest",
    ], "matmul_prec must be one of 'default','high','highest'"
    if matmul_precision != "highest":
        print(f"# Setting matmul precision to '{matmul_precision}'")
    if matmul_precision == "default" and fprec == "float32":
        print(
            "# Warning: default matmul precision involves float16 operations which may lead to large numerical errors on energy and pressure estimations ! It is recommended to set matmul_prec to 'high' or 'highest'."
        )
    jax.config.update("jax_default_matmul_precision", matmul_precision)

    return dynamic(simulation_parameters, device, fprec)


def dynamic(simulation_parameters, device, fprec):
    """Run the dynamics loop and write its outputs.

    The integrator and the system are built by ``initialize_dynamics``. The
    returned ``step`` function is then called ``nsteps`` times and, at the
    configured intervals, this function prints a line of properties, saves a
    restart plus a trajectory frame, and prints a summary with averages.

    Args:
        simulation_parameters: Mapping-like object queried with ``.get`` for
            the run options read below (``nsteps``, ``tdump``, ``nprint``,
            ``nsummary``, ``traj_format``, ``wrap_box``, ...).
        device: Device label; only used in the banner printed before the run.
        fprec: Floating-point precision name passed to ``initialize_dynamics``.

    Returns:
        0 once all steps have completed.

    Raises:
        ValueError: If two wrap groups share atoms, if ``traj_format`` is not
            supported, or if NaNs are found in the velocities or coordinates
            at a summary step.
    """
    tstart_dyn = time.time()

    random_seed = simulation_parameters.get(
        "random_seed", np.random.randint(0, 2**32 - 1)
    )
    print(f"# random_seed: {random_seed}")
    rng_key = jax.random.PRNGKey(random_seed)

    ## INITIALIZE INTEGRATOR AND SYSTEM
    rng_key, subkey = jax.random.split(rng_key)
    step, update_conformation, system_data, dyn_state, conformation, system = (
        initialize_dynamics(simulation_parameters, fprec, subkey)
    )

    nat = system_data["nat"]
    dt = dyn_state["dt"]
    ## get number of steps
    nsteps = int(simulation_parameters.get("nsteps"))
    start_time_ps = dyn_state.get("start_time_ps", 0.0)

    ### Set I/O parameters
    Tdump = simulation_parameters.get("tdump", 1.0 / au.PS) * au.FS
    ndump = int(Tdump / dt)
    if ndump < 1:
        # int() truncates to 0 when tdump is below one time step, which would
        # make `istep % ndump` raise ZeroDivisionError on the first step.
        print("# Warning: tdump is shorter than one time step; a frame is written at every step.")
        ndump = 1
    system_name = system_data["name"]
    estimate_pressure = dyn_state["estimate_pressure"]

    @jax.jit
    def check_nan(system):
        return jnp.any(jnp.isnan(system["vel"])) | jnp.any(
            jnp.isnan(system["coordinates"])
        )

    wrap_groups = None
    if system_data["pbc"] is not None:
        cell = system["cell"]
        reciprocal_cell = np.linalg.inv(cell)
        do_wrap_box = simulation_parameters.get("wrap_box", False)
        if do_wrap_box:
            wrap_groups_def = simulation_parameters.get("wrap_groups", None)
            if wrap_groups_def is not None:
                assert isinstance(wrap_groups_def, dict), "wrap_groups must be a dictionary"
                wrap_groups = {
                    k: read_tinker_interval(v) for k, v in wrap_groups_def.items()
                }
                # check that pairwise intersection of wrap groups is empty
                group_sets = [(k, set(v)) for k, v in wrap_groups.items()]
                for i, (i_key, w1) in enumerate(group_sets):
                    for j_key, w2 in group_sets[i + 1:]:
                        common = w1.intersection(w2)
                        if common:
                            raise ValueError(
                                f"Wrap groups {i_key} and {j_key} have common atoms: {common}"
                            )
                group_all = np.concatenate(list(wrap_groups.values()))
                # get all atoms that are not in any wrap group
                group_none = np.setdiff1d(np.arange(nat), group_all)
                print(f"# Wrap groups: {wrap_groups}")
                wrap_groups["__other"] = group_none
                # NOTE(review): this is a one-shot generator; if wrapbox
                # iterates it on every call it is empty after the first wrap.
                # Kept as in the original because wrapbox is not visible here
                # - confirm before switching to a tuple.
                wrap_groups = ((k, v) for k, v in wrap_groups.items())
    else:
        cell = None
        reciprocal_cell = None
        do_wrap_box = False

    ### Energy units and print initial energy
    model_energy_unit = system_data["model_energy_unit"]
    model_energy_unit_str = system_data["model_energy_unit_str"]
    per_atom_energy = simulation_parameters.get("per_atom_energy", True)
    energy_unit_str = system_data["energy_unit_str"]
    energy_unit = system_data["energy_unit"]
    print("# Energy unit: ", energy_unit_str)
    atom_energy_unit = energy_unit
    atom_energy_unit_str = energy_unit_str
    if per_atom_energy:
        atom_energy_unit = energy_unit / nat
        atom_energy_unit_str = f"{energy_unit_str}/atom"
        print("# Printing Energy per atom")
    print(
        f"# Initial potential energy: {system['epot']*atom_energy_unit}; kinetic energy: {system['ek']*atom_energy_unit}"
    )
    minmaxone(jnp.abs(system["forces"] * energy_unit), "# forces min/max/rms:")

    ## printing options
    nprint = int(simulation_parameters.get("nprint", 10))
    assert nprint > 0, "nprint must be > 0"
    nsummary = simulation_parameters.get("nsummary", 100 * nprint)
    assert nsummary > nprint, "nsummary must be > nprint"

    # Every output file is registered here so that all of them are closed
    # when the run ends, whether normally or through an exception.
    open_files = []

    def _open_tracked(path, mode):
        fh = open(path, mode)
        open_files.append(fh)
        return fh

    save_keys = simulation_parameters.get("save_keys", [])
    if save_keys:
        print(f"# Saving keys: {save_keys}")
        fkeys = _open_tracked(f"{system_name}.traj.pkl", "wb+")
    else:
        fkeys = None

    ### initialize colvars
    use_colvars = "colvars" in dyn_state
    if use_colvars:
        print(f"# Colvars: {dyn_state['colvars']}")
        colvars_names = dyn_state["colvars"]
        # open colvars file and print header
        fcolvars = _open_tracked(f"{system_name}.colvars.traj", "a")
        fcolvars.write("#time[ps] ")
        fcolvars.write(" ".join(colvars_names))
        fcolvars.write("\n")
        fcolvars.flush()

    ### Print header
    include_thermostat_energy = "thermostat_energy" in system["thermostat"]
    thermostat_name = dyn_state["thermostat_name"]
    pimd = dyn_state["pimd"]
    variable_cell = dyn_state["variable_cell"]
    nbeads = system_data.get("nbeads", 1)
    dyn_name = "PIMD" if pimd else "MD"
    print("#" * 84)
    print(
        f"# Running {nsteps:_} steps of {thermostat_name} {dyn_name} simulation on {device}"
    )
    header = f"#{'Step':>9} {'Time[ps]':>10} {'Etot':>10} {'Epot':>10} {'Ekin':>10} {'Temp[K]':>10}"
    if pimd:
        header += f" {'Temp_c[K]':>10}"
    if include_thermostat_energy:
        header += f" {'Etherm':>10}"
    # Defined unconditionally: the variable-cell branch of the loop reads it
    # even when the pressure is not estimated.
    print_aniso_pressure = False
    if estimate_pressure:
        print_aniso_pressure = simulation_parameters.get("print_aniso_pressure", False)
        pressure_unit_str = simulation_parameters.get("pressure_unit", "atm")
        pressure_unit = au.get_multiplier(pressure_unit_str) * au.BOHR**3
        p_str = f" Press[{pressure_unit_str}]"
        header += f" {p_str:>10}"
    if variable_cell:
        header += f" {'Density':>10}"
    print(header)

    ### Open trajectory file
    traj_writers = {
        "xyz": (".traj.xyz", write_xyz_frame),
        "extxyz": (".traj.extxyz", write_extxyz_frame),
        "arc": (".arc", write_arc_frame),
    }
    traj_format = simulation_parameters.get("traj_format", "arc").lower()
    if traj_format not in traj_writers:
        raise ValueError(
            f"Unknown trajectory format '{traj_format}'. Supported formats are 'arc', 'xyz' and 'extxyz'"
        )
    traj_ext, write_frame = traj_writers[traj_format]

    write_all_beads = simulation_parameters.get("write_all_beads", False) and pimd

    if write_all_beads:
        fout = [
            _open_tracked(f"{system_name}_bead{i+1:03d}" + traj_ext, "a")
            for i in range(nbeads)
        ]
    else:
        fout = _open_tracked(system_name + traj_ext, "a")

    ensemble_key = simulation_parameters.get("etot_ensemble_key", None)
    if ensemble_key is not None:
        fens = _open_tracked(f"{system_name}.ensemble_weights.traj", "a")

    write_centroid = simulation_parameters.get("write_centroid", False) and pimd
    if write_centroid:
        fcentroid = _open_tracked(f"{system_name}_centroid" + traj_ext, "a")

    ### initialize property trajectories
    properties_traj = defaultdict(list)
    tensor_components = (
        ("xx", 0, 0),
        ("yy", 1, 1),
        ("zz", 2, 2),
        ("xy", 0, 1),
        ("xz", 0, 2),
        ("yz", 1, 2),
    )

    ### initialize counters and timers
    t0full = time.time()
    # NOTE(review): set to True after the first box wrap and never reset, so
    # every later step is called with force_preprocess=True - confirm intent.
    force_preprocess = False

    try:
        for istep in range(1, nsteps + 1):

            ### update the system
            dyn_state, system, conformation, model_output = step(
                istep, dyn_state, system, conformation, force_preprocess
            )

            ### print properties
            if istep % nprint == 0:
                ek = system["ek"]
                epot = system["epot"]
                etot = ek + epot
                temper = 2 * ek / (3.0 * nat) * au.KELVIN

                if include_thermostat_energy:
                    etherm = system["thermostat"]["thermostat_energy"]
                    etot = etot + etherm

                properties_traj[f"Etot[{atom_energy_unit_str}]"].append(
                    etot * atom_energy_unit
                )
                properties_traj[f"Epot[{atom_energy_unit_str}]"].append(
                    epot * atom_energy_unit
                )
                properties_traj[f"Ekin[{atom_energy_unit_str}]"].append(
                    ek * atom_energy_unit
                )
                properties_traj["Temper[Kelvin]"].append(temper)
                if pimd:
                    ek_c = system["ek_c"]
                    temper_c = 2 * ek_c / (3.0 * nat) * au.KELVIN
                    properties_traj["Temper_c[Kelvin]"].append(temper_c)

                ### construct line of properties
                simulated_time = start_time_ps + istep * dt / 1000
                line = f"{istep:10.6g} {simulated_time:10.3f} {etot*atom_energy_unit:#10.4f} {epot*atom_energy_unit:#10.4f} {ek*atom_energy_unit:#10.4f} {temper:10.2f}"
                if pimd:
                    line += f" {temper_c:10.2f}"
                if include_thermostat_energy:
                    line += f" {etherm*atom_energy_unit:#10.4f}"
                    properties_traj[f"Etherm[{atom_energy_unit_str}]"].append(
                        etherm * atom_energy_unit
                    )
                if estimate_pressure:
                    pres = system["pressure"] * pressure_unit
                    properties_traj[f"Pressure[{pressure_unit_str}]"].append(pres)
                    if print_aniso_pressure:
                        pres_tensor = system["pressure_tensor"] * pressure_unit
                        # only the symmetric part of the tensor is recorded
                        pres_tensor = 0.5 * (pres_tensor + pres_tensor.T)
                        for label, a, b in tensor_components:
                            properties_traj[
                                f"Pressure_{label}[{pressure_unit_str}]"
                            ].append(pres_tensor[a, b])
                    line += f" {pres:10.3f}"
                if variable_cell:
                    density = system["density"]
                    properties_traj["Density[g/cm^3]"].append(density)
                    if print_aniso_pressure:
                        cell = system["cell"]
                        for a, vec in enumerate("ABC"):
                            for b, axis in enumerate("xyz"):
                                properties_traj[f"Cell_{vec}{axis}[Angstrom]"].append(
                                    cell[a, b]
                                )
                    line += f" {density:10.4f}"
                    if "piston_temperature" in system["barostat"]:
                        piston_temperature = system["barostat"]["piston_temperature"]
                        properties_traj["T_piston[Kelvin]"].append(piston_temperature)

                print(line)

                ### save colvars
                if use_colvars:
                    colvars = system["colvars"]
                    fcolvars.write(f"{simulated_time:.3f} ")
                    fcolvars.write(" ".join([f"{colvars[k]:.6f}" for k in colvars_names]))
                    fcolvars.write("\n")
                    fcolvars.flush()

                if save_keys:
                    data = {
                        k: (
                            np.asarray(model_output[k])
                            if isinstance(model_output[k], jnp.ndarray)
                            else model_output[k]
                        )
                        for k in save_keys
                    }
                    # header.split()[0] is the bare "#", so it is skipped to
                    # align the column names with the printed values.
                    data["properties"] = {
                        k: float(v) for k, v in zip(header.split()[1:], line.split())
                    }
                    data["properties"]["properties_energy_unit"] = (atom_energy_unit, atom_energy_unit_str)
                    data["properties"]["model_energy_unit"] = (model_energy_unit, model_energy_unit_str)

                    pickle.dump(data, fkeys)

            ### save frame
            if istep % ndump == 0:
                line = "# Write XYZ frame"
                if variable_cell:
                    cell = np.array(system["cell"])
                    reciprocal_cell = np.linalg.inv(cell)
                if do_wrap_box:
                    if pimd:
                        centroid = wrapbox(system["coordinates"][0], cell, reciprocal_cell, wrap_groups=wrap_groups)
                        system["coordinates"] = system["coordinates"].at[0].set(centroid)
                    else:
                        system["coordinates"] = wrapbox(
                            system["coordinates"], cell, reciprocal_cell, wrap_groups=wrap_groups
                        )
                    conformation = update_conformation(conformation, system)
                    line += " (atoms have been wrapped into the box)"
                    force_preprocess = True
                print(line)

                save_dynamics_restart(system_data, conformation, dyn_state, system)

                properties = {
                    "energy": float(system["epot"]) * energy_unit,
                    "Time_ps": start_time_ps + istep * dt / 1000,
                    "energy_unit": energy_unit_str,
                }

                if write_all_beads:
                    coords = np.asarray(conformation["coordinates"].reshape(-1, nat, 3))
                    for i, fb in enumerate(fout):
                        write_frame(
                            fb,
                            system_data["symbols"],
                            coords[i],
                            cell=cell,
                            properties=properties,
                            forces=None,
                        )
                else:
                    write_frame(
                        fout,
                        system_data["symbols"],
                        np.asarray(conformation["coordinates"].reshape(-1, nat, 3)[0]),
                        cell=cell,
                        properties=properties,
                        forces=None,
                    )
                if write_centroid:
                    centroid = np.asarray(system["coordinates"][0])
                    write_frame(
                        fcentroid,
                        system_data["symbols"],
                        centroid,
                        cell=cell,
                        properties=properties,
                        forces=np.asarray(system["forces"].reshape(nbeads, nat, 3)[0])
                        * energy_unit,
                    )
                if ensemble_key is not None:
                    weights = " ".join(
                        [f"{w:.6f}" for w in system["ensemble_weights"].tolist()]
                    )
                    fens.write(f"{weights}\n")
                    fens.flush()

            ### summary over last nsummary steps
            if istep % (nsummary) == 0:
                if check_nan(system):
                    raise ValueError(f"dynamics crashed at step {istep}.")
                tfull = time.time() - t0full
                t0full = time.time()
                tperstep = tfull / (nsummary)
                nsperday = (24 * 60 * 60 / tperstep) * dt / 1e6
                elapsed_time = time.time() - tstart_dyn
                estimated_remaining_time = tperstep * (nsteps - istep)
                estimated_total_time = elapsed_time + estimated_remaining_time

                print("#" * 50)
                print(f"# Step {istep:_} of {nsteps:_} ({istep/nsteps*100:.5g} %)")
                print(f"# Simulated time : {istep * dt*1.e-3:.3f} ps")
                print(f"# Tot. Simu. time : {start_time_ps + istep * dt*1.e-3:.3f} ps")
                print(f"# Tot. elapsed time : {human_time_duration(elapsed_time)}")
                print(
                    f"# Est. total duration : {human_time_duration(estimated_total_time)}"
                )
                print(
                    f"# Est. remaining time : {human_time_duration(estimated_remaining_time)}"
                )
                print(f"# Time for {nsummary:_} steps : {human_time_duration(tfull)}")

                corr_kin = system["thermostat"].get("corr_kin", None)
                if corr_kin is not None:
                    print(f"# QTB kin. correction : {100*(corr_kin-1.):.2f} %")
                print(f"# Averages over last {nsummary:_} steps :")
                for k, values in properties_traj.items():
                    if len(values) == 0:
                        continue
                    mu = np.mean(values)
                    sig = np.std(values)
                    # keys look like "Name[unit]"; the unit part is optional
                    ksplit = k.split("[")
                    name = ksplit[0].strip()
                    unit = ksplit[1].replace("]", "").strip() if len(ksplit) > 1 else ""
                    print(f"# {name:10} : {mu: #10.5g} +/- {sig: #9.3g} {unit}")

                print(f"# Perf.: {nsperday:.2f} ns/day ( {1.0 / tperstep:.2f} step/s )")
                print("#" * 50)
                if istep < nsteps:
                    print(header)
                ## reset property trajectories
                properties_traj = defaultdict(list)

        print(f"# Run done in {human_time_duration(time.time()-tstart_dyn)}")
    finally:
        ### close every output file, including the per-bead trajectory files
        for fh in open_files:
            fh.close()

    return 0


if __name__ == "__main__":
    main()
def
main():
def main():
    """Command-line entry point: read the parameter-file argument and run.

    Returns:
        Whatever ``config_and_run_dynamic`` returns for the given file.
    """
    # Rebind stdout to an unbuffered binary stream wrapped in a write-through
    # text layer, so that every print is pushed to the output immediately.
    raw_stdout = open(sys.stdout.fileno(), "wb", 0)
    sys.stdout = io.TextIOWrapper(raw_stdout, write_through=True)

    ### Read the parameter file
    cli = argparse.ArgumentParser(prog="fennol_md")
    cli.add_argument("param_file", type=Path, help="Parameter file")
    return config_and_run_dynamic(cli.parse_args().param_file)
def
config_and_run_dynamic(param_file: pathlib.Path):
def config_and_run_dynamic(param_file: Path):
    """Run the dynamic simulation based on the provided parameter file.

    Loads the parameters, selects the JAX device and numerical precision,
    then hands over to ``dynamic``.

    Args:
        param_file: Path to a ``.yaml``/``.yml`` or ``.fnl`` parameter file.
            A plain string is accepted as well.

    Returns:
        The value returned by ``dynamic``.

    Raises:
        FileNotFoundError: If ``param_file`` is not an existing regular file.
        ValueError: If the file extension is not a supported format.
        AssertionError: If ``matmul_prec`` is not one of the accepted names.
    """
    param_file = Path(param_file)
    # is_file() is False both for a missing path and for a directory, so one
    # test rejects everything that cannot be opened as a parameter file.
    if not param_file.is_file():
        raise FileNotFoundError(f"Parameter file {param_file} not found")

    if param_file.suffix in [".yaml", ".yml"]:
        with open(param_file, "r") as f:
            simulation_parameters = convert_dict_units(yaml.safe_load(f))
        simulation_parameters = InputFile(**simulation_parameters)
    elif param_file.suffix == ".fnl":
        simulation_parameters = parse_input(param_file)
    else:
        raise ValueError(
            f"Unknown parameter file format '{param_file.suffix}'. Supported formats are '.yaml', '.yml' and '.fnl'"
        )

    ### Set the device
    # The FENNOL_DEVICE environment variable takes precedence over the file.
    if "FENNOL_DEVICE" in os.environ:
        device = os.environ["FENNOL_DEVICE"].lower()
        print(f"# Setting device from env FENNOL_DEVICE={device}")
    else:
        device = simulation_parameters.get("device", "cpu").lower()
    if device == "cpu":
        jax.config.update('jax_platforms', 'cpu')
        os.environ["CUDA_VISIBLE_DEVICES"] = ""
    elif device.startswith(("cuda", "gpu")):
        # "cuda:N" / "gpu:N" exposes only device N; the bare name exposes "0".
        if ":" in device:
            os.environ["CUDA_VISIBLE_DEVICES"] = device.split(":")[-1]
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        device = "gpu"

    _device = jax.devices(device)[0]
    jax.config.update("jax_default_device", _device)

    ### Set the precision
    enable_x64 = simulation_parameters.get("double_precision", False)
    jax.config.update("jax_enable_x64", enable_x64)
    fprec = "float64" if enable_x64 else "float32"

    matmul_precision = simulation_parameters.get("matmul_prec", "highest").lower()
    assert matmul_precision in [
        "default",
        "high",
        "highest",
    ], "matmul_prec must be one of 'default','high','highest'"
    if matmul_precision != "highest":
        print(f"# Setting matmul precision to '{matmul_precision}'")
    if matmul_precision == "default" and fprec == "float32":
        print(
            "# Warning: default matmul precision involves float16 operations which may lead to large numerical errors on energy and pressure estimations ! It is recommended to set matmul_prec to 'high' or 'highest'."
        )
    jax.config.update("jax_default_matmul_precision", matmul_precision)

    return dynamic(simulation_parameters, device, fprec)
Run the dynamic simulation based on the provided parameter file.
def
dynamic(simulation_parameters, device, fprec):
def dynamic(simulation_parameters, device, fprec):
    """Run the dynamics loop and write its outputs.

    The integrator and the system are built by ``initialize_dynamics``. The
    returned ``step`` function is then called ``nsteps`` times and, at the
    configured intervals, this function prints a line of properties, saves a
    restart plus a trajectory frame, and prints a summary with averages.

    Args:
        simulation_parameters: Mapping-like object queried with ``.get`` for
            the run options read below (``nsteps``, ``tdump``, ``nprint``,
            ``nsummary``, ``traj_format``, ``wrap_box``, ...).
        device: Device label; only used in the banner printed before the run.
        fprec: Floating-point precision name passed to ``initialize_dynamics``.

    Returns:
        0 once all steps have completed.

    Raises:
        ValueError: If two wrap groups share atoms, if ``traj_format`` is not
            supported, or if NaNs are found in the velocities or coordinates
            at a summary step.
    """
    tstart_dyn = time.time()

    random_seed = simulation_parameters.get(
        "random_seed", np.random.randint(0, 2**32 - 1)
    )
    print(f"# random_seed: {random_seed}")
    rng_key = jax.random.PRNGKey(random_seed)

    ## INITIALIZE INTEGRATOR AND SYSTEM
    rng_key, subkey = jax.random.split(rng_key)
    step, update_conformation, system_data, dyn_state, conformation, system = (
        initialize_dynamics(simulation_parameters, fprec, subkey)
    )

    nat = system_data["nat"]
    dt = dyn_state["dt"]
    ## get number of steps
    nsteps = int(simulation_parameters.get("nsteps"))
    start_time_ps = dyn_state.get("start_time_ps", 0.0)

    ### Set I/O parameters
    Tdump = simulation_parameters.get("tdump", 1.0 / au.PS) * au.FS
    ndump = int(Tdump / dt)
    if ndump < 1:
        # int() truncates to 0 when tdump is below one time step, which would
        # make `istep % ndump` raise ZeroDivisionError on the first step.
        print("# Warning: tdump is shorter than one time step; a frame is written at every step.")
        ndump = 1
    system_name = system_data["name"]
    estimate_pressure = dyn_state["estimate_pressure"]

    @jax.jit
    def check_nan(system):
        return jnp.any(jnp.isnan(system["vel"])) | jnp.any(
            jnp.isnan(system["coordinates"])
        )

    wrap_groups = None
    if system_data["pbc"] is not None:
        cell = system["cell"]
        reciprocal_cell = np.linalg.inv(cell)
        do_wrap_box = simulation_parameters.get("wrap_box", False)
        if do_wrap_box:
            wrap_groups_def = simulation_parameters.get("wrap_groups", None)
            if wrap_groups_def is not None:
                assert isinstance(wrap_groups_def, dict), "wrap_groups must be a dictionary"
                wrap_groups = {
                    k: read_tinker_interval(v) for k, v in wrap_groups_def.items()
                }
                # check that pairwise intersection of wrap groups is empty
                group_sets = [(k, set(v)) for k, v in wrap_groups.items()]
                for i, (i_key, w1) in enumerate(group_sets):
                    for j_key, w2 in group_sets[i + 1:]:
                        common = w1.intersection(w2)
                        if common:
                            raise ValueError(
                                f"Wrap groups {i_key} and {j_key} have common atoms: {common}"
                            )
                group_all = np.concatenate(list(wrap_groups.values()))
                # get all atoms that are not in any wrap group
                group_none = np.setdiff1d(np.arange(nat), group_all)
                print(f"# Wrap groups: {wrap_groups}")
                wrap_groups["__other"] = group_none
                # NOTE(review): this is a one-shot generator; if wrapbox
                # iterates it on every call it is empty after the first wrap.
                # Kept as in the original because wrapbox is not visible here
                # - confirm before switching to a tuple.
                wrap_groups = ((k, v) for k, v in wrap_groups.items())
    else:
        cell = None
        reciprocal_cell = None
        do_wrap_box = False

    ### Energy units and print initial energy
    model_energy_unit = system_data["model_energy_unit"]
    model_energy_unit_str = system_data["model_energy_unit_str"]
    per_atom_energy = simulation_parameters.get("per_atom_energy", True)
    energy_unit_str = system_data["energy_unit_str"]
    energy_unit = system_data["energy_unit"]
    print("# Energy unit: ", energy_unit_str)
    atom_energy_unit = energy_unit
    atom_energy_unit_str = energy_unit_str
    if per_atom_energy:
        atom_energy_unit = energy_unit / nat
        atom_energy_unit_str = f"{energy_unit_str}/atom"
        print("# Printing Energy per atom")
    print(
        f"# Initial potential energy: {system['epot']*atom_energy_unit}; kinetic energy: {system['ek']*atom_energy_unit}"
    )
    minmaxone(jnp.abs(system["forces"] * energy_unit), "# forces min/max/rms:")

    ## printing options
    nprint = int(simulation_parameters.get("nprint", 10))
    assert nprint > 0, "nprint must be > 0"
    nsummary = simulation_parameters.get("nsummary", 100 * nprint)
    assert nsummary > nprint, "nsummary must be > nprint"

    # Every output file is registered here so that all of them are closed
    # when the run ends, whether normally or through an exception.
    open_files = []

    def _open_tracked(path, mode):
        fh = open(path, mode)
        open_files.append(fh)
        return fh

    save_keys = simulation_parameters.get("save_keys", [])
    if save_keys:
        print(f"# Saving keys: {save_keys}")
        fkeys = _open_tracked(f"{system_name}.traj.pkl", "wb+")
    else:
        fkeys = None

    ### initialize colvars
    use_colvars = "colvars" in dyn_state
    if use_colvars:
        print(f"# Colvars: {dyn_state['colvars']}")
        colvars_names = dyn_state["colvars"]
        # open colvars file and print header
        fcolvars = _open_tracked(f"{system_name}.colvars.traj", "a")
        fcolvars.write("#time[ps] ")
        fcolvars.write(" ".join(colvars_names))
        fcolvars.write("\n")
        fcolvars.flush()

    ### Print header
    include_thermostat_energy = "thermostat_energy" in system["thermostat"]
    thermostat_name = dyn_state["thermostat_name"]
    pimd = dyn_state["pimd"]
    variable_cell = dyn_state["variable_cell"]
    nbeads = system_data.get("nbeads", 1)
    dyn_name = "PIMD" if pimd else "MD"
    print("#" * 84)
    print(
        f"# Running {nsteps:_} steps of {thermostat_name} {dyn_name} simulation on {device}"
    )
    header = f"#{'Step':>9} {'Time[ps]':>10} {'Etot':>10} {'Epot':>10} {'Ekin':>10} {'Temp[K]':>10}"
    if pimd:
        header += f" {'Temp_c[K]':>10}"
    if include_thermostat_energy:
        header += f" {'Etherm':>10}"
    # Defined unconditionally: the variable-cell branch of the loop reads it
    # even when the pressure is not estimated.
    print_aniso_pressure = False
    if estimate_pressure:
        print_aniso_pressure = simulation_parameters.get("print_aniso_pressure", False)
        pressure_unit_str = simulation_parameters.get("pressure_unit", "atm")
        pressure_unit = au.get_multiplier(pressure_unit_str) * au.BOHR**3
        p_str = f" Press[{pressure_unit_str}]"
        header += f" {p_str:>10}"
    if variable_cell:
        header += f" {'Density':>10}"
    print(header)

    ### Open trajectory file
    traj_writers = {
        "xyz": (".traj.xyz", write_xyz_frame),
        "extxyz": (".traj.extxyz", write_extxyz_frame),
        "arc": (".arc", write_arc_frame),
    }
    traj_format = simulation_parameters.get("traj_format", "arc").lower()
    if traj_format not in traj_writers:
        raise ValueError(
            f"Unknown trajectory format '{traj_format}'. Supported formats are 'arc', 'xyz' and 'extxyz'"
        )
    traj_ext, write_frame = traj_writers[traj_format]

    write_all_beads = simulation_parameters.get("write_all_beads", False) and pimd

    if write_all_beads:
        fout = [
            _open_tracked(f"{system_name}_bead{i+1:03d}" + traj_ext, "a")
            for i in range(nbeads)
        ]
    else:
        fout = _open_tracked(system_name + traj_ext, "a")

    ensemble_key = simulation_parameters.get("etot_ensemble_key", None)
    if ensemble_key is not None:
        fens = _open_tracked(f"{system_name}.ensemble_weights.traj", "a")

    write_centroid = simulation_parameters.get("write_centroid", False) and pimd
    if write_centroid:
        fcentroid = _open_tracked(f"{system_name}_centroid" + traj_ext, "a")

    ### initialize property trajectories
    properties_traj = defaultdict(list)
    tensor_components = (
        ("xx", 0, 0),
        ("yy", 1, 1),
        ("zz", 2, 2),
        ("xy", 0, 1),
        ("xz", 0, 2),
        ("yz", 1, 2),
    )

    ### initialize counters and timers
    t0full = time.time()
    # NOTE(review): set to True after the first box wrap and never reset, so
    # every later step is called with force_preprocess=True - confirm intent.
    force_preprocess = False

    try:
        for istep in range(1, nsteps + 1):

            ### update the system
            dyn_state, system, conformation, model_output = step(
                istep, dyn_state, system, conformation, force_preprocess
            )

            ### print properties
            if istep % nprint == 0:
                ek = system["ek"]
                epot = system["epot"]
                etot = ek + epot
                temper = 2 * ek / (3.0 * nat) * au.KELVIN

                if include_thermostat_energy:
                    etherm = system["thermostat"]["thermostat_energy"]
                    etot = etot + etherm

                properties_traj[f"Etot[{atom_energy_unit_str}]"].append(
                    etot * atom_energy_unit
                )
                properties_traj[f"Epot[{atom_energy_unit_str}]"].append(
                    epot * atom_energy_unit
                )
                properties_traj[f"Ekin[{atom_energy_unit_str}]"].append(
                    ek * atom_energy_unit
                )
                properties_traj["Temper[Kelvin]"].append(temper)
                if pimd:
                    ek_c = system["ek_c"]
                    temper_c = 2 * ek_c / (3.0 * nat) * au.KELVIN
                    properties_traj["Temper_c[Kelvin]"].append(temper_c)

                ### construct line of properties
                simulated_time = start_time_ps + istep * dt / 1000
                line = f"{istep:10.6g} {simulated_time:10.3f} {etot*atom_energy_unit:#10.4f} {epot*atom_energy_unit:#10.4f} {ek*atom_energy_unit:#10.4f} {temper:10.2f}"
                if pimd:
                    line += f" {temper_c:10.2f}"
                if include_thermostat_energy:
                    line += f" {etherm*atom_energy_unit:#10.4f}"
                    properties_traj[f"Etherm[{atom_energy_unit_str}]"].append(
                        etherm * atom_energy_unit
                    )
                if estimate_pressure:
                    pres = system["pressure"] * pressure_unit
                    properties_traj[f"Pressure[{pressure_unit_str}]"].append(pres)
                    if print_aniso_pressure:
                        pres_tensor = system["pressure_tensor"] * pressure_unit
                        # only the symmetric part of the tensor is recorded
                        pres_tensor = 0.5 * (pres_tensor + pres_tensor.T)
                        for label, a, b in tensor_components:
                            properties_traj[
                                f"Pressure_{label}[{pressure_unit_str}]"
                            ].append(pres_tensor[a, b])
                    line += f" {pres:10.3f}"
                if variable_cell:
                    density = system["density"]
                    properties_traj["Density[g/cm^3]"].append(density)
                    if print_aniso_pressure:
                        cell = system["cell"]
                        for a, vec in enumerate("ABC"):
                            for b, axis in enumerate("xyz"):
                                properties_traj[f"Cell_{vec}{axis}[Angstrom]"].append(
                                    cell[a, b]
                                )
                    line += f" {density:10.4f}"
                    if "piston_temperature" in system["barostat"]:
                        piston_temperature = system["barostat"]["piston_temperature"]
                        properties_traj["T_piston[Kelvin]"].append(piston_temperature)

                print(line)

                ### save colvars
                if use_colvars:
                    colvars = system["colvars"]
                    fcolvars.write(f"{simulated_time:.3f} ")
                    fcolvars.write(" ".join([f"{colvars[k]:.6f}" for k in colvars_names]))
                    fcolvars.write("\n")
                    fcolvars.flush()

                if save_keys:
                    data = {
                        k: (
                            np.asarray(model_output[k])
                            if isinstance(model_output[k], jnp.ndarray)
                            else model_output[k]
                        )
                        for k in save_keys
                    }
                    # header.split()[0] is the bare "#", so it is skipped to
                    # align the column names with the printed values.
                    data["properties"] = {
                        k: float(v) for k, v in zip(header.split()[1:], line.split())
                    }
                    data["properties"]["properties_energy_unit"] = (atom_energy_unit, atom_energy_unit_str)
                    data["properties"]["model_energy_unit"] = (model_energy_unit, model_energy_unit_str)

                    pickle.dump(data, fkeys)

            ### save frame
            if istep % ndump == 0:
                line = "# Write XYZ frame"
                if variable_cell:
                    cell = np.array(system["cell"])
                    reciprocal_cell = np.linalg.inv(cell)
                if do_wrap_box:
                    if pimd:
                        centroid = wrapbox(system["coordinates"][0], cell, reciprocal_cell, wrap_groups=wrap_groups)
                        system["coordinates"] = system["coordinates"].at[0].set(centroid)
                    else:
                        system["coordinates"] = wrapbox(
                            system["coordinates"], cell, reciprocal_cell, wrap_groups=wrap_groups
                        )
                    conformation = update_conformation(conformation, system)
                    line += " (atoms have been wrapped into the box)"
                    force_preprocess = True
                print(line)

                save_dynamics_restart(system_data, conformation, dyn_state, system)

                properties = {
                    "energy": float(system["epot"]) * energy_unit,
                    "Time_ps": start_time_ps + istep * dt / 1000,
                    "energy_unit": energy_unit_str,
                }

                if write_all_beads:
                    coords = np.asarray(conformation["coordinates"].reshape(-1, nat, 3))
                    for i, fb in enumerate(fout):
                        write_frame(
                            fb,
                            system_data["symbols"],
                            coords[i],
                            cell=cell,
                            properties=properties,
                            forces=None,
                        )
                else:
                    write_frame(
                        fout,
                        system_data["symbols"],
                        np.asarray(conformation["coordinates"].reshape(-1, nat, 3)[0]),
                        cell=cell,
                        properties=properties,
                        forces=None,
                    )
                if write_centroid:
                    centroid = np.asarray(system["coordinates"][0])
                    write_frame(
                        fcentroid,
                        system_data["symbols"],
                        centroid,
                        cell=cell,
                        properties=properties,
                        forces=np.asarray(system["forces"].reshape(nbeads, nat, 3)[0])
                        * energy_unit,
                    )
                if ensemble_key is not None:
                    weights = " ".join(
                        [f"{w:.6f}" for w in system["ensemble_weights"].tolist()]
                    )
                    fens.write(f"{weights}\n")
                    fens.flush()

            ### summary over last nsummary steps
            if istep % (nsummary) == 0:
                if check_nan(system):
                    raise ValueError(f"dynamics crashed at step {istep}.")
                tfull = time.time() - t0full
                t0full = time.time()
                tperstep = tfull / (nsummary)
                nsperday = (24 * 60 * 60 / tperstep) * dt / 1e6
                elapsed_time = time.time() - tstart_dyn
                estimated_remaining_time = tperstep * (nsteps - istep)
                estimated_total_time = elapsed_time + estimated_remaining_time

                print("#" * 50)
                print(f"# Step {istep:_} of {nsteps:_} ({istep/nsteps*100:.5g} %)")
                print(f"# Simulated time : {istep * dt*1.e-3:.3f} ps")
                print(f"# Tot. Simu. time : {start_time_ps + istep * dt*1.e-3:.3f} ps")
                print(f"# Tot. elapsed time : {human_time_duration(elapsed_time)}")
                print(
                    f"# Est. total duration : {human_time_duration(estimated_total_time)}"
                )
                print(
                    f"# Est. remaining time : {human_time_duration(estimated_remaining_time)}"
                )
                print(f"# Time for {nsummary:_} steps : {human_time_duration(tfull)}")

                corr_kin = system["thermostat"].get("corr_kin", None)
                if corr_kin is not None:
                    print(f"# QTB kin. correction : {100*(corr_kin-1.):.2f} %")
                print(f"# Averages over last {nsummary:_} steps :")
                for k, values in properties_traj.items():
                    if len(values) == 0:
                        continue
                    mu = np.mean(values)
                    sig = np.std(values)
                    # keys look like "Name[unit]"; the unit part is optional
                    ksplit = k.split("[")
                    name = ksplit[0].strip()
                    unit = ksplit[1].replace("]", "").strip() if len(ksplit) > 1 else ""
                    print(f"# {name:10} : {mu: #10.5g} +/- {sig: #9.3g} {unit}")

                print(f"# Perf.: {nsperday:.2f} ns/day ( {1.0 / tperstep:.2f} step/s )")
                print("#" * 50)
                if istep < nsteps:
                    print(header)
                ## reset property trajectories
                properties_traj = defaultdict(list)

        print(f"# Run done in {human_time_duration(time.time()-tstart_dyn)}")
    finally:
        ### close every output file, including the per-bead trajectory files
        for fh in open_files:
            fh.close()

    return 0