fennol.md.integrate
import time
import math
import os

import numpy as np
import jax
import jax.numpy as jnp

from .thermostats import get_thermostat
from .barostats import get_barostat
from .colvars import setup_colvars
from .spectra import initialize_ir_spectrum

from .utils import load_dynamics_restart, get_restart_file, optimize_fire2, us
from .initial import load_model, load_system_data, initialize_preprocessing


def initialize_dynamics(simulation_parameters, fprec, rng_key):
    """Build everything needed to run (path-integral) molecular dynamics.

    Loads the model and the system, restarts from a previous run when the
    file returned by ``get_restart_file`` exists, optionally minimizes the
    initial structure, then sets up the thermostat, the barostat (periodic
    systems only), optional collective variables and IR-spectrum
    accumulation, and finally computes the initial energy and forces.

    Parameters
    ----------
    simulation_parameters
        Input parameters, queried with ``.get(key, default)``.
    fprec
        Floating-point precision used for velocities and for ``dt / mass``.
    rng_key
        JAX PRNG key; it is split to seed the thermostat and the barostat.

    Returns
    -------
    tuple
        ``(step, update_conformation, system_data, dyn_state, conformation, system)``
        where ``step(istep, dyn_state, system, conformation, force_preprocess=False)``
        advances the simulation by one time step and returns
        ``(dyn_state, system, conformation, out)``.
    """
    ### LOAD MODEL
    model = load_model(simulation_parameters)
    model_energy_unit = us.get_multiplier(model.energy_unit)

    ### Get the coordinates and species from the xyz file
    system_data, conformation = load_system_data(simulation_parameters, fprec)
    system_data["model_energy_unit"] = model_energy_unit
    system_data["model_energy_unit_str"] = model.energy_unit

    ### FINISH BUILDING conformation
    do_restart = os.path.exists(get_restart_file(system_data))
    if do_restart:
        ### RESTART FROM PREVIOUS DYNAMICS
        restart_data = load_dynamics_restart(system_data)
        print("# RESTARTING FROM PREVIOUS DYNAMICS")
        model.preproc_state = restart_data["preproc_state"]
        conformation["coordinates"] = restart_data["coordinates"]
    else:
        restart_data = {}

    ### INITIALIZE PREPROCESSING
    preproc_state, conformation = initialize_preprocessing(
        simulation_parameters, model, conformation, system_data
    )

    minimize = simulation_parameters.get("xyz_input/minimize", False)
    """@keyword[fennol_md] xyz_input/minimize
    Perform energy minimization before dynamics.
    Default: False
    """
    # A restarted run continues from its saved coordinates and is never minimized.
    if minimize and not do_restart:
        assert system_data["nreplicas"] == 1, "Minimization is only supported for single replica systems"
        model.preproc_state = preproc_state
        convert = us.KCALPERMOL / model_energy_unit
        nat = system_data["nat"]

        def energy_force_fn(coordinates):
            # Energy per atom and forces, both converted from model units to kcal/mol.
            inputs = {**conformation, "coordinates": coordinates}
            e, f, _ = model.energy_and_forces(
                **inputs, gpu_preprocessing=True
            )
            e = float(e[0]) * convert / nat
            f = np.array(f) * convert
            return e, f

        tol = simulation_parameters.get("xyz_input/minimize_ftol", 1e-1/us.KCALPERMOL)*us.KCALPERMOL
        """@keyword[fennol_md] xyz_input/minimize_ftol
        Force tolerance for minimization.
        Default: 0.1 kcal/mol/Å
        """
        print(f"# Minimizing initial configuration with RMS force tolerance = {tol:.1e} kcal/mol/A")
        conformation["coordinates"], success = optimize_fire2(
            conformation["coordinates"],
            energy_force_fn,
            atol=tol,
            max_disp=0.02,
        )
        if success:
            print("# Minimization successful")
        else:
            print("# Warning: Minimization failed, continuing with last configuration")
        # write the minimized coordinates as an xyz file
        from ..utils.io import write_xyz_frame
        with open(system_data["name"]+".opt.xyz", "w") as f:
            write_xyz_frame(f, system_data["symbols"],np.array(conformation["coordinates"]),cell=conformation.get("cells", None))
        print("# Minimized configuration written to", system_data["name"]+".opt.xyz")
        # energy_and_forces may have updated the model's preprocessing state;
        # rebuild the conformation graphs for the minimized coordinates.
        preproc_state = model.preproc_state
        conformation = model.preprocessing.process(preproc_state, conformation)
        # Refresh the stored reference coordinates with the minimized structure.
        system_data["initial_coordinates"] = np.array(conformation["coordinates"]).copy()

    ### get dynamics parameters
    dt = simulation_parameters.get("dt")
    """@keyword[fennol_md] dt
    Integration time step. Required parameter.
    Type: float, Required
    """
    dt2 = 0.5 * dt
    mass = system_data["mass"]
    densmass = system_data["totmass_Da"] * (us.MOL/us.CM**3)
    nat = system_data["nat"]
    # Full-step velocity increment per unit force (dt / m).
    dtm = jnp.asarray(dt / mass[:, None], dtype=fprec)
    # Equipartition kinetic-energy tensor, used when use_average_Pkin is set.
    ek_avg = 0.5 * nat * system_data["kT"] * np.eye(3)

    nreplicas = system_data.get("nreplicas", 1)
    nbeads = system_data.get("nbeads", None)
    if nbeads is not None:
        # PIMD: each bead acts as a replica, and dtm broadcasts over the bead axis.
        nreplicas = nbeads
        dtm = dtm[None, :, :]

    ### INITIALIZE DYNAMICS STATE
    system = {"coordinates": conformation["coordinates"]}
    dyn_state = {
        "istep": 0,
        "dt": dt,
        "pimd": nbeads is not None,
        "preproc_state": preproc_state,
        "start_time_ps": restart_data.get("simulation_time_ps", 0.),
    }
    gradient_keys = ["coordinates"]
    thermo_updates = []

    ### INITIALIZE THERMOSTAT
    thermostat_rng, rng_key = jax.random.split(rng_key)
    (
        thermostat,
        thermostat_post,
        thermostat_state,
        initial_vel,
        dyn_state["thermostat_name"],
    ) = get_thermostat(simulation_parameters, dt, system_data, fprec, thermostat_rng, restart_data)
    do_thermostat_post = thermostat_post is not None
    if do_thermostat_post:
        # get_thermostat returns the post-step hook paired with its own state.
        thermostat_post, post_state = thermostat_post
        dyn_state["thermostat_post_state"] = post_state

    system["thermostat"] = thermostat_state
    system["vel"] = restart_data.get("vel", initial_vel).astype(fprec)

    ### PBC
    pbc_data = system_data.get("pbc", None)
    if pbc_data is not None:
        ### INITIALIZE BAROSTAT
        barostat_key, rng_key = jax.random.split(rng_key)
        thermo_update_ensemble, variable_cell, barostat_state = get_barostat(
            thermostat, simulation_parameters, dt, system_data, fprec, barostat_key, restart_data
        )
        estimate_pressure = variable_cell or pbc_data["estimate_pressure"]
        system["barostat"] = barostat_state
        system["cell"] = conformation["cells"][0]
        if estimate_pressure:
            pressure_o_weight = simulation_parameters.get("pressure_o_weight", 1.0)
            """@keyword[fennol_md] pressure_o_weight
            Weight factor for mixing middle (O) and outer time step kinetic energies in pressure estimator.
            Default: 1.0
            """
            assert (
                0.0 <= pressure_o_weight <= 1.0
            ), "pressure_o_weight must be between 0 and 1"
            # The virial is obtained from the energy gradient w.r.t. the strain.
            gradient_keys.append("strain")
        print("# Estimate pressure: ", estimate_pressure)
    else:
        estimate_pressure = False
        variable_cell = False

        def thermo_update_ensemble(x, v, system):
            # Without PBC, the ensemble update is just the thermostat acting on velocities.
            v, thermostat_state = thermostat(v, system["thermostat"])
            return x, v, {**system, "thermostat": thermostat_state}

    dyn_state["estimate_pressure"] = estimate_pressure
    dyn_state["variable_cell"] = variable_cell
    thermo_updates.append(thermo_update_ensemble)

    if estimate_pressure:
        use_average_Pkin = simulation_parameters.get("use_average_Pkin", False)
        """@keyword[fennol_md] use_average_Pkin
        Use time-averaged kinetic energy for pressure estimation instead of instantaneous values.
        Default: False
        """
        is_qtb = dyn_state["thermostat_name"].endswith("QTB")
        if is_qtb and use_average_Pkin:
            raise ValueError(
                "use_average_Pkin is not compatible with QTB thermostat, please set use_average_Pkin to False"
            )

    ### ENERGY ENSEMBLE
    ensemble_key = simulation_parameters.get("etot_ensemble_key", None)
    """@keyword[fennol_md] etot_ensemble_key
    Key for energy ensemble calculation. Enables computation of ensemble weights.
    Default: None
    """

    ### COLVARS
    colvars_definitions = simulation_parameters.get("colvars", None)
    """@keyword[fennol_md] colvars
    Collective variables definitions for enhanced sampling or monitoring.
    Default: None
    """
    use_colvars = colvars_definitions is not None
    if use_colvars:
        colvars_calculators, colvars_names = setup_colvars(colvars_definitions)
        dyn_state["colvars"] = colvars_names

    ### IR SPECTRUM
    do_ir_spectrum = simulation_parameters.get("ir_spectrum", False)
    """@keyword[fennol_md] ir_spectrum
    Calculate infrared spectrum from molecular dipole moment time series.
    Default: False
    """
    assert isinstance(do_ir_spectrum, bool), "ir_spectrum must be a boolean"
    if do_ir_spectrum:
        is_qtb = dyn_state["thermostat_name"].endswith("QTB")
        model_ir, ir_state, save_dipole, ir_post = initialize_ir_spectrum(
            simulation_parameters, system_data, fprec, dt, is_qtb
        )
        dyn_state["ir_spectrum"] = ir_state

    ### BUILD GRADIENT FUNCTION
    energy_and_gradient = model.get_gradient_function(
        *gradient_keys, jit=True, variables_as_input=True
    )

    ### COLLECT THERMO UPDATES
    if len(thermo_updates) == 1:
        thermo_update = thermo_updates[0]
    else:

        def thermo_update(x, v, system):
            # Apply every registered update in order.
            for update in thermo_updates:
                x, v, system = update(x, v, system)
            return x, v, system

    ### RING POLYMER INITIALIZATION
    if nbeads is not None:
        cay_correction = simulation_parameters.get("cay_correction", True)
        """@keyword[fennol_md] cay_correction
        Use Cayley propagator for ring polymer molecular dynamics instead of standard propagation.
        Default: True
        """
        omk = system_data["omk"]
        eigmat = system_data["eigmat"]
        # Half-step free ring-polymer propagator coefficients for the
        # non-centroid normal modes (index 0, the centroid, is a free particle).
        cayfact = 1.0 / (4.0 + (dt * omk[1:, None, None]) ** 2) ** 0.5
        if cay_correction:
            axx = jnp.asarray(2 * cayfact)
            axv = jnp.asarray(dt * cayfact)
            avx = jnp.asarray(-dt * cayfact * omk[1:, None, None] ** 2)
        else:
            # exact harmonic propagation over dt/2
            axx = jnp.asarray(np.cos(omk[1:, None, None] * dt2))
            axv = jnp.asarray(np.sin(omk[1:, None, None] * dt2) / omk[1:, None, None])
            avx = jnp.asarray(-omk[1:, None, None] * np.sin(omk[1:, None, None] * dt2))

        # Start in normal-mode coordinates: only the centroid mode is set
        # (from bead 0); all internal modes start at zero.
        coordinates = conformation["coordinates"].reshape(nbeads, -1, 3)
        eigx = jnp.zeros_like(coordinates).at[0].set(coordinates[0])
        system["coordinates"] = eigx

    ###############################################
    ### DEFINE UPDATE FUNCTION
    @jax.jit
    def update_conformation(conformation, system):
        """Copy the dynamical coordinates (and cell) back into the model inputs."""
        x = system["coordinates"]
        if nbeads is not None:
            # normal modes -> bead Cartesian coordinates
            x = jnp.einsum("in,n...->i...", eigmat, x).reshape(nbeads * nat, 3) * (
                nbeads**0.5
            )
        conformation = {**conformation, "coordinates": x}
        if variable_cell:
            # The barostat changes the cell, so the reciprocal cell must be
            # recomputed as well; otherwise PBC code would use a stale inverse.
            cell = system["cell"]
            conformation["cells"] = cell[None, :, :].repeat(nreplicas, axis=0)
            conformation["reciprocal_cells"] = jnp.linalg.inv(cell)[
                None, :, :
            ].repeat(nreplicas, axis=0)
        return conformation

    ###############################################
    ### DEFINE INTEGRATION FUNCTIONS
    def integrate_A_half(x0, v0):
        """Free propagation (drift) of positions and velocities over dt/2."""
        if nbeads is None:
            return x0 + dt2 * v0, v0

        # update coordinates and velocities of a free ring polymer for a half time step
        eigx_c = x0[0] + dt2 * v0[0]
        eigv_c = v0[0]
        eigx = x0[1:] * axx + v0[1:] * axv
        eigv = x0[1:] * avx + v0[1:] * axx

        return (
            jnp.concatenate((eigx_c[None], eigx), axis=0),
            jnp.concatenate((eigv_c[None], eigv), axis=0),
        )

    @jax.jit
    def integrate(system):
        """Advance positions and velocities by one time step.

        Sequence: kick with the stored forces over a full dt, half drift,
        thermostat/barostat update, half drift. The full-dt kick presumably
        merges the trailing half-kick of the previous step with the leading
        half-kick of this one (leapfrog form of a BAOAB-like splitting).
        """
        x = system["coordinates"]
        v = system["vel"] + dtm * system["forces"]
        x, v = integrate_A_half(x, v)
        x, v, system = thermo_update(x, v, system)
        x, v = integrate_A_half(x, v)

        return {**system, "coordinates": x, "vel": v}

    ###############################################
    ### DEFINE OBSERVABLE FUNCTION
    @jax.jit
    def update_observables(system, conformation):
        """Compute energy, forces, kinetic energy, pressure and optional observables.

        Returns the updated ``system`` dict and the raw model output ``out``.
        """
        ### POTENTIAL ENERGY AND FORCES
        epot, de, out = energy_and_gradient(model.variables, conformation)
        out["forces"] = -de["coordinates"]
        # convert from model energy units to internal units
        epot = epot / model_energy_unit
        de = {k: v / model_energy_unit for k, v in de.items()}
        forces = -de["coordinates"]

        if nbeads is not None:
            ### PROJECT FORCES ONTO POLYMER NORMAL MODES
            forces = jnp.einsum(
                "in,i...->n...", eigmat, forces.reshape(nbeads, nat, 3)
            ) * (1.0 / nbeads**0.5)

        system = {
            **system,
            "epot": jnp.mean(epot),
            "forces": forces,
            "energy_gradients": de,
        }

        ### KINETIC ENERGY
        v = system["vel"]
        if nbeads is None:
            corr_kin = system["thermostat"].get("corr_kin", 1.0)
            # kinetic-energy tensor averaged over replicas
            ek = (0.5 / nreplicas / corr_kin) * jnp.sum(
                mass[:, None, None] * v[:, :, None] * v[:, None, :], axis=0
            )
        else:
            # centroid kinetic energy plus centroid-virial correction of the internal modes
            ek_c = 0.5 * jnp.sum(
                mass[:, None, None] * v[0, :, :, None] * v[0, :, None, :], axis=0
            )
            ek = ek_c - 0.5 * jnp.sum(
                system["coordinates"][1:, :, :, None] * forces[1:, :, None, :],
                axis=(0, 1),
            )
            system["ek_c"] = jnp.trace(ek_c)

        system["ek"] = jnp.trace(ek)
        system["ek_tensor"] = ek

        if estimate_pressure:
            if use_average_Pkin:
                ek = ek_avg
            elif pressure_o_weight != 1.0:
                # kinetic energy from velocities half-kicked with the new forces,
                # then mixed with the stored-velocity estimate
                v = system["vel"] + 0.5 * dtm * system["forces"]
                if nbeads is None:
                    corr_kin = system["thermostat"].get("corr_kin", 1.0)
                    ek = (0.5 / nreplicas / corr_kin) * jnp.sum(
                        mass[:, None, None] * v[:, :, None] * v[:, None, :], axis=0
                    )
                else:
                    ek_c = 0.5 * jnp.sum(
                        mass[:, None, None] * v[0, :, :, None] * v[0, :, None, :],
                        axis=0,
                    )
                    ek = ek_c - 0.5 * jnp.sum(
                        system["coordinates"][1:, :, :, None] * forces[1:, :, None, :],
                        axis=(0, 1),
                    )
                b = pressure_o_weight
                ek = (1.0 - b) * ek + b * system["ek_tensor"]

            vir = jnp.mean(de["strain"], axis=0)
            system["virial"] = vir
            out["virial_tensor"] = vir * model_energy_unit

            volume = jnp.abs(jnp.linalg.det(system["cell"]))
            Pres = ek*(2./volume) - vir/volume
            system["pressure_tensor"] = Pres
            system["pressure"] = jnp.trace(Pres) * (1.0 / 3.0)
            if variable_cell:
                density = densmass / volume
                system["density"] = density
                system["volume"] = volume

        if ensemble_key is not None:
            kT = system_data["kT"]
            dE = (
                jnp.mean(out[ensemble_key], axis=0) / model_energy_unit - system["epot"]
            )
            system["ensemble_weights"] = -dE / kT

        if "total_dipole" in out:
            if nbeads is None:
                system["total_dipole"] = out["total_dipole"][0]
            else:
                system["total_dipole"] = jnp.mean(out["total_dipole"], axis=0)

        if use_colvars:
            # colvars are evaluated on the first replica (the centroid mode for PIMD)
            coords = system["coordinates"].reshape(-1, nat, 3)[0]
            colvars = {}
            for colvar_name, colvar_calc in colvars_calculators.items():
                colvars[colvar_name] = colvar_calc(coords)
            system["colvars"] = colvars

        return system, out

    ###############################################
    ### IR SPECTRUM
    if do_ir_spectrum:

        @jax.jit
        def update_conformation_ir(conformation, system):
            """Build single-replica inputs (first replica / centroid) for the IR model."""
            conformation = {
                **conformation,
                "coordinates": system["coordinates"].reshape(-1, nat, 3)[0],
                "natoms": jnp.asarray([nat]),
                "batch_index": jnp.asarray([0] * nat),
                "species": jnp.asarray(system_data["species"].reshape(-1, nat)[0]),
            }
            if variable_cell:
                conformation["cells"] = system["cell"][None, :, :]
                conformation["reciprocal_cells"] = jnp.linalg.inv(system["cell"])[
                    None, :, :
                ]
            return conformation

        @jax.jit
        def update_dipole(ir_state, system, conformation):
            """Feed charges, dipoles, velocities and positions to the IR accumulator."""
            if model_ir is not None:
                out = model_ir._apply(model_ir.variables, conformation)
                q = out.get("charges", jnp.zeros(nat)).reshape((-1, nat))
                dip = out.get("dipoles", jnp.zeros((nat, 3))).reshape((-1, nat, 3))
            else:
                # NOTE(review): update_observables never stores "charges"/"dipoles"
                # in system, so this branch falls back to zeros — confirm intent.
                q = system.get("charges", jnp.zeros(nat)).reshape((-1, nat))
                dip = system.get("dipoles", jnp.zeros((nat, 3))).reshape((-1, nat, 3))
            if nbeads is not None:
                q = jnp.mean(q, axis=0)
                dip = jnp.mean(dip, axis=0)
                vel = system["vel"][0]
                pos = system["coordinates"][0]
            else:
                q = q[0]
                dip = dip[0]
                vel = system["vel"].reshape(-1, nat, 3)[0]
                pos = system["coordinates"].reshape(-1, nat, 3)[0]

            if pbc_data is not None:
                cell_reciprocal = (
                    conformation["cells"][0],
                    conformation["reciprocal_cells"][0],
                )
            else:
                cell_reciprocal = None

            ir_state = save_dipole(
                q, vel, pos, dip.sum(axis=0), cell_reciprocal, ir_state
            )
            return ir_state

    ###############################################
    ### GRAPH UPDATES

    nblist_verbose = simulation_parameters.get("nblist_verbose", False)
    """@keyword[fennol_md] nblist_verbose
    Print verbose information about neighbor list updates and reallocations.
    Default: False
    """
    nblist_stride = int(simulation_parameters.get("nblist_stride", -1))
    """@keyword[fennol_md] nblist_stride
    Number of steps between full neighbor list rebuilds. Auto-calculated from skin if <= 0.
    Default: -1
    """
    nblist_warmup_time = simulation_parameters.get("nblist_warmup_time", -1.0)
    """@keyword[fennol_md] nblist_warmup_time
    Time period for neighbor list warmup before using skin updates.
    Default: -1.0
    """
    nblist_warmup = int(nblist_warmup_time / dt) if nblist_warmup_time > 0 else 0
    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
    """@keyword[fennol_md] nblist_skin
    Neighbor list skin distance for efficient updates (in Angstroms).
    Default: -1.0
    """
    if nblist_skin > 0:
        if nblist_stride <= 0:
            ## reference skin parameters at 300K (from Tinker-HP)
            ## => skin of 2 A gives you 40 fs without complete rebuild
            t_ref = 40.0 / us.FS  # FS
            nblist_skin_ref = 2.0  # A
            # Clamp to at least one step: a stride of 0 already meant
            # "rebuild every step", so this only fixes the reported value.
            nblist_stride = max(
                1, int(math.floor(nblist_skin / nblist_skin_ref * t_ref / dt))
            )
        print(
            f"# nblist_skin: {nblist_skin:.2f} A, nblist_stride: {nblist_stride} steps, nblist_warmup: {nblist_warmup} steps"
        )

    if nblist_skin <= 0:
        nblist_stride = 1

    dyn_state["nblist_countdown"] = 0
    dyn_state["print_skin_activation"] = nblist_warmup > 0

    def update_graphs(istep, dyn_state, system, conformation, force_preprocess=False):
        """Refresh model inputs: full neighbor-list rebuild or cheap skin update.

        A full rebuild happens when the countdown expires, when forced, or
        during the warmup phase. Mutates and returns ``dyn_state``.
        """
        nblist_countdown = dyn_state["nblist_countdown"]
        if nblist_countdown <= 0 or force_preprocess or (istep < nblist_warmup):
            ### FULL NBLIST REBUILD
            dyn_state["nblist_countdown"] = nblist_stride - 1
            preproc_state = dyn_state["preproc_state"]
            conformation = model.preprocessing.process(
                preproc_state, update_conformation(conformation, system)
            )
            preproc_state, state_up, conformation, overflow = (
                model.preprocessing.check_reallocate(preproc_state, conformation)
            )
            dyn_state["preproc_state"] = preproc_state
            if nblist_verbose and overflow:
                print("step", istep, ", nblist overflow => reallocating nblist")
                print("size updates:", state_up)

            if do_ir_spectrum and model_ir is not None:
                # the IR model keeps its own preprocessing state and graphs
                conformation_ir = model_ir.preprocessing.process(
                    dyn_state["preproc_state_ir"],
                    update_conformation_ir(dyn_state["conformation_ir"], system),
                )
                (
                    dyn_state["preproc_state_ir"],
                    _,
                    dyn_state["conformation_ir"],
                    overflow,
                ) = model_ir.preprocessing.check_reallocate(
                    dyn_state["preproc_state_ir"], conformation_ir
                )

        else:
            ### SKIN UPDATE
            if dyn_state["print_skin_activation"]:
                if nblist_verbose:
                    print(
                        "step",
                        istep,
                        ", end of nblist warmup phase => activating skin updates",
                    )
                dyn_state["print_skin_activation"] = False

            dyn_state["nblist_countdown"] = nblist_countdown - 1
            conformation = model.preprocessing.update_skin(
                update_conformation(conformation, system)
            )
            if do_ir_spectrum and model_ir is not None:
                dyn_state["conformation_ir"] = model_ir.preprocessing.update_skin(
                    update_conformation_ir(dyn_state["conformation_ir"], system)
                )

        return conformation, dyn_state

    ################################################
    ### DEFINE STEP FUNCTION
    def step(istep, dyn_state, system, conformation, force_preprocess=False):
        """Perform one MD step; returns ``(dyn_state, system, conformation, out)``."""
        # shallow copy so update_graphs does not mutate the caller's dict
        dyn_state = {
            **dyn_state,
            "istep": dyn_state["istep"] + 1,
        }

        ### INTEGRATE EQUATIONS OF MOTION
        system = integrate(system)

        ### UPDATE CONFORMATION AND GRAPHS
        conformation, dyn_state = update_graphs(
            istep, dyn_state, system, conformation, force_preprocess
        )

        ## COMPUTE FORCES AND OBSERVABLES
        system, out = update_observables(system, conformation)

        ## END OF STEP UPDATES
        if do_thermostat_post:
            system["thermostat"], dyn_state["thermostat_post_state"] = thermostat_post(
                system["thermostat"], dyn_state["thermostat_post_state"]
            )

        if do_ir_spectrum:
            # NOTE(review): "conformation_ir" is only created when model_ir is not
            # None; confirm initialize_ir_spectrum never returns model_ir=None here.
            ir_state = update_dipole(
                dyn_state["ir_spectrum"], system, dyn_state["conformation_ir"]
            )
            dyn_state["ir_spectrum"] = ir_post(ir_state)

        return dyn_state, system, conformation, out

    ###########################################################

    print("# Computing initial energy and forces")

    conformation = update_conformation(conformation, system)
    # initialize IR conformation
    if do_ir_spectrum and model_ir is not None:
        dyn_state["preproc_state_ir"], dyn_state["conformation_ir"] = (
            model_ir.preprocessing(
                model_ir.preproc_state,
                update_conformation_ir(conformation, system),
            )
        )

    system, _ = update_observables(system, conformation)

    return step, update_conformation, system_data, dyn_state, conformation, system
def
initialize_dynamics(simulation_parameters, fprec, rng_key):
19def initialize_dynamics(simulation_parameters, fprec, rng_key): 20 ### LOAD MODEL 21 model = load_model(simulation_parameters) 22 model_energy_unit = us.get_multiplier(model.energy_unit) 23 24 ### Get the coordinates and species from the xyz file 25 system_data, conformation = load_system_data(simulation_parameters, fprec) 26 system_data["model_energy_unit"] = model_energy_unit 27 system_data["model_energy_unit_str"] = model.energy_unit 28 29 ### FINISH BUILDING conformation 30 do_restart = os.path.exists(get_restart_file(system_data)) 31 if do_restart: 32 ### RESTART FROM PREVIOUS DYNAMICS 33 restart_data = load_dynamics_restart(system_data) 34 print("# RESTARTING FROM PREVIOUS DYNAMICS") 35 model.preproc_state = restart_data["preproc_state"] 36 conformation["coordinates"] = restart_data["coordinates"] 37 else: 38 restart_data = {} 39 40 ### INITIALIZE PREPROCESSING 41 preproc_state, conformation = initialize_preprocessing( 42 simulation_parameters, model, conformation, system_data 43 ) 44 45 minimize = simulation_parameters.get("xyz_input/minimize", False) 46 """@keyword[fennol_md] xyz_input/minimize 47 Perform energy minimization before dynamics. 48 Default: False 49 """ 50 if minimize and not do_restart: 51 assert system_data["nreplicas"] == 1, "Minimization is only supported for single replica systems" 52 model.preproc_state = preproc_state 53 convert = us.KCALPERMOL / model_energy_unit 54 nat = system_data["nat"] 55 def energy_force_fn(coordinates): 56 inputs = {**conformation, "coordinates": coordinates} 57 e, f, _ = model.energy_and_forces( 58 **inputs, gpu_preprocessing=True 59 ) 60 e = float(e[0]) * convert / nat 61 f = np.array(f) * convert 62 return e, f 63 tol = simulation_parameters.get("xyz_input/minimize_ftol", 1e-1/us.KCALPERMOL)*us.KCALPERMOL 64 """@keyword[fennol_md] xyz_input/minimize_ftol 65 Force tolerance for minimization. 
66 Default: 0.1 kcal/mol/Å 67 """ 68 print(f"# Minimizing initial configuration with RMS force tolerance = {tol:.1e} kcal/mol/A") 69 conformation["coordinates"], success = optimize_fire2( 70 conformation["coordinates"], 71 energy_force_fn, 72 atol=tol, 73 max_disp=0.02, 74 ) 75 if success: 76 print("# Minimization successful") 77 else: 78 print("# Warning: Minimization failed, continuing with last configuration") 79 # write the minimized coordinates as an xyz file 80 from ..utils.io import write_xyz_frame 81 with open(system_data["name"]+".opt.xyz", "w") as f: 82 write_xyz_frame(f, system_data["symbols"],np.array(conformation["coordinates"]),cell=conformation.get("cells", None)) 83 print("# Minimized configuration written to", system_data["name"]+".opt.xyz") 84 preproc_state = model.preproc_state 85 conformation = model.preprocessing.process(preproc_state, conformation) 86 system_data["initial_coordinates"] = np.array(conformation["coordinates"]).copy() 87 88 ### get dynamics parameters 89 dt = simulation_parameters.get("dt") 90 """@keyword[fennol_md] dt 91 Integration time step. Required parameter. 
92 Type: float, Required 93 """ 94 dt2 = 0.5 * dt 95 mass = system_data["mass"] 96 densmass = system_data["totmass_Da"] * (us.MOL/us.CM**3) 97 nat = system_data["nat"] 98 dtm = jnp.asarray(dt / mass[:, None], dtype=fprec) 99 ek_avg = 0.5 * nat * system_data["kT"] * np.eye(3) 100 101 nreplicas = system_data.get("nreplicas", 1) 102 nbeads = system_data.get("nbeads", None) 103 if nbeads is not None: 104 nreplicas = nbeads 105 dtm = dtm[None, :, :] 106 107 ### INITIALIZE DYNAMICS STATE 108 system = {"coordinates": conformation["coordinates"]} 109 dyn_state = { 110 "istep": 0, 111 "dt": dt, 112 "pimd": nbeads is not None, 113 "preproc_state": preproc_state, 114 "start_time_ps": restart_data.get("simulation_time_ps", 0.), 115 } 116 gradient_keys = ["coordinates"] 117 thermo_updates = [] 118 119 ### INITIALIZE THERMOSTAT 120 thermostat_rng, rng_key = jax.random.split(rng_key) 121 ( 122 thermostat, 123 thermostat_post, 124 thermostat_state, 125 initial_vel, 126 dyn_state["thermostat_name"], 127 ) = get_thermostat(simulation_parameters, dt, system_data, fprec, thermostat_rng,restart_data) 128 do_thermostat_post = thermostat_post is not None 129 if do_thermostat_post: 130 thermostat_post, post_state = thermostat_post 131 dyn_state["thermostat_post_state"] = post_state 132 133 system["thermostat"] = thermostat_state 134 system["vel"] = restart_data.get("vel", initial_vel).astype(fprec) 135 136 ### PBC 137 pbc_data = system_data.get("pbc", None) 138 if pbc_data is not None: 139 ### INITIALIZE BAROSTAT 140 barostat_key, rng_key = jax.random.split(rng_key) 141 thermo_update_ensemble, variable_cell, barostat_state = get_barostat( 142 thermostat, simulation_parameters, dt, system_data, fprec, barostat_key,restart_data 143 ) 144 estimate_pressure = variable_cell or pbc_data["estimate_pressure"] 145 system["barostat"] = barostat_state 146 system["cell"] = conformation["cells"][0] 147 if estimate_pressure: 148 pressure_o_weight = simulation_parameters.get("pressure_o_weight", 1.0) 
149 """@keyword[fennol_md] pressure_o_weight 150 Weight factor for mixing middle (O) and outer time step kinetic energies in pressure estimator. 151 Default: 1.0 152 """ 153 assert ( 154 0.0 <= pressure_o_weight <= 1.0 155 ), "pressure_o_weight must be between 0 and 1" 156 gradient_keys.append("strain") 157 print("# Estimate pressure: ", estimate_pressure) 158 else: 159 estimate_pressure = False 160 variable_cell = False 161 162 def thermo_update_ensemble(x, v, system): 163 v, thermostat_state = thermostat(v, system["thermostat"]) 164 return x, v, {**system, "thermostat": thermostat_state} 165 166 dyn_state["estimate_pressure"] = estimate_pressure 167 dyn_state["variable_cell"] = variable_cell 168 thermo_updates.append(thermo_update_ensemble) 169 170 if estimate_pressure: 171 use_average_Pkin = simulation_parameters.get("use_average_Pkin", False) 172 """@keyword[fennol_md] use_average_Pkin 173 Use time-averaged kinetic energy for pressure estimation instead of instantaneous values. 174 Default: False 175 """ 176 is_qtb = dyn_state["thermostat_name"].endswith("QTB") 177 if is_qtb and use_average_Pkin: 178 raise ValueError( 179 "use_average_Pkin is not compatible with QTB thermostat, please set use_average_Pkin to False" 180 ) 181 182 183 ### ENERGY ENSEMBLE 184 ensemble_key = simulation_parameters.get("etot_ensemble_key", None) 185 """@keyword[fennol_md] etot_ensemble_key 186 Key for energy ensemble calculation. Enables computation of ensemble weights. 187 Default: None 188 """ 189 190 ### COLVARS 191 colvars_definitions = simulation_parameters.get("colvars", None) 192 """@keyword[fennol_md] colvars 193 Collective variables definitions for enhanced sampling or monitoring. 
194 Default: None 195 """ 196 use_colvars = colvars_definitions is not None 197 if use_colvars: 198 colvars_calculators, colvars_names = setup_colvars(colvars_definitions) 199 dyn_state["colvars"] = colvars_names 200 201 ### IR SPECTRUM 202 do_ir_spectrum = simulation_parameters.get("ir_spectrum", False) 203 """@keyword[fennol_md] ir_spectrum 204 Calculate infrared spectrum from molecular dipole moment time series. 205 Default: False 206 """ 207 assert isinstance(do_ir_spectrum, bool), "ir_spectrum must be a boolean" 208 if do_ir_spectrum: 209 is_qtb = dyn_state["thermostat_name"].endswith("QTB") 210 model_ir, ir_state, save_dipole, ir_post = initialize_ir_spectrum( 211 simulation_parameters, system_data, fprec, dt, is_qtb 212 ) 213 dyn_state["ir_spectrum"] = ir_state 214 215 ### BUILD GRADIENT FUNCTION 216 energy_and_gradient = model.get_gradient_function( 217 *gradient_keys, jit=True, variables_as_input=True 218 ) 219 220 ### COLLECT THERMO UPDATES 221 if len(thermo_updates) == 1: 222 thermo_update = thermo_updates[0] 223 else: 224 225 def thermo_update(x, v, system): 226 for update in thermo_updates: 227 x, v, system = update(x, v, system) 228 return x, v, system 229 230 ### RING POLYMER INITIALIZATION 231 if nbeads is not None: 232 cay_correction = simulation_parameters.get("cay_correction", True) 233 """@keyword[fennol_md] cay_correction 234 Use Cayley propagator for ring polymer molecular dynamics instead of standard propagation. 
235 Default: True 236 """ 237 omk = system_data["omk"] 238 eigmat = system_data["eigmat"] 239 cayfact = 1.0 / (4.0 + (dt * omk[1:, None, None]) ** 2) ** 0.5 240 if cay_correction: 241 axx = jnp.asarray(2 * cayfact) 242 axv = jnp.asarray(dt * cayfact) 243 avx = jnp.asarray(-dt * cayfact * omk[1:, None, None] ** 2) 244 else: 245 axx = jnp.asarray(np.cos(omk[1:, None, None] * dt2)) 246 axv = jnp.asarray(np.sin(omk[1:, None, None] * dt2) / omk[1:, None, None]) 247 avx = jnp.asarray(-omk[1:, None, None] * np.sin(omk[1:, None, None] * dt2)) 248 249 coordinates = conformation["coordinates"].reshape(nbeads, -1, 3) 250 eigx = jnp.zeros_like(coordinates).at[0].set(coordinates[0]) 251 system["coordinates"] = eigx 252 253 ############################################### 254 ### DEFINE UPDATE FUNCTION 255 @jax.jit 256 def update_conformation(conformation, system): 257 x = system["coordinates"] 258 if nbeads is not None: 259 x = jnp.einsum("in,n...->i...", eigmat, x).reshape(nbeads * nat, 3) * ( 260 nbeads**0.5 261 ) 262 conformation = {**conformation, "coordinates": x} 263 if variable_cell: 264 conformation["cells"] = system["cell"][None, :, :].repeat(nreplicas, axis=0) 265 266 267 268 return conformation 269 270 ############################################### 271 ### DEFINE INTEGRATION FUNCTIONS 272 def integrate_A_half(x0, v0): 273 if nbeads is None: 274 return x0 + dt2 * v0, v0 275 276 # update coordinates and velocities of a free ring polymer for a half time step 277 eigx_c = x0[0] + dt2 * v0[0] 278 eigv_c = v0[0] 279 eigx = x0[1:] * axx + v0[1:] * axv 280 eigv = x0[1:] * avx + v0[1:] * axx 281 282 return ( 283 jnp.concatenate((eigx_c[None], eigx), axis=0), 284 jnp.concatenate((eigv_c[None], eigv), axis=0), 285 ) 286 287 @jax.jit 288 def integrate(system): 289 x = system["coordinates"] 290 v = system["vel"] + dtm * system["forces"] 291 x, v = integrate_A_half(x, v) 292 x, v, system = thermo_update(x, v, system) 293 x, v = integrate_A_half(x, v) 294 295 return {**system, 
"coordinates": x, "vel": v} 296 297 ############################################### 298 ### DEFINE OBSERVABLE FUNCTION 299 @jax.jit 300 def update_observables(system, conformation): 301 ### POTENTIAL ENERGY AND FORCES 302 epot, de, out = energy_and_gradient(model.variables, conformation) 303 out["forces"] = -de["coordinates"] 304 epot = epot / model_energy_unit 305 de = {k: v / model_energy_unit for k, v in de.items()} 306 forces = -de["coordinates"] 307 308 if nbeads is not None: 309 ### PROJECT FORCES ONTO POLYMER NORMAL MODES 310 forces = jnp.einsum( 311 "in,i...->n...", eigmat, forces.reshape(nbeads, nat, 3) 312 ) * (1.0 / nbeads**0.5) 313 314 system = { 315 **system, 316 "epot": jnp.mean(epot), 317 "forces": forces, 318 "energy_gradients": de, 319 } 320 321 ### KINETIC ENERGY 322 v = system["vel"] 323 if nbeads is None: 324 corr_kin = system["thermostat"].get("corr_kin", 1.0) 325 # ek = 0.5 * jnp.sum(mass[:, None] * v**2) / state_th.get("corr_kin", 1.0) 326 ek = (0.5 / nreplicas / corr_kin) * jnp.sum( 327 mass[:, None, None] * v[:, :, None] * v[:, None, :], axis=0 328 ) 329 else: 330 ek_c = 0.5 * jnp.sum( 331 mass[:, None, None] * v[0, :, :, None] * v[0, :, None, :], axis=0 332 ) 333 ek = ek_c - 0.5 * jnp.sum( 334 system["coordinates"][1:, :, :, None] * forces[1:, :, None, :], 335 axis=(0, 1), 336 ) 337 system["ek_c"] = jnp.trace(ek_c) 338 339 system["ek"] = jnp.trace(ek) 340 system["ek_tensor"] = ek 341 342 if estimate_pressure: 343 if use_average_Pkin: 344 ek = ek_avg 345 elif pressure_o_weight != 1.0: 346 v = system["vel"] + 0.5 * dtm * system["forces"] 347 if nbeads is None: 348 corr_kin = system["thermostat"].get("corr_kin", 1.0) 349 # ek = 0.5 * jnp.sum(mass[:, None] * v**2) / state_th.get("corr_kin", 1.0) 350 ek = (0.5 / nreplicas / corr_kin) * jnp.sum( 351 mass[:, None, None] * v[:, :, None] * v[:, None, :], axis=0 352 ) 353 else: 354 ek_c = 0.5 * jnp.sum( 355 mass[:, None, None] * v[0, :, :, None] * v[0, :, None, :], 356 axis=0, 357 ) 358 ek = ek_c - 
0.5 * jnp.sum( 359 system["coordinates"][1:, :, :, None] * forces[1:, :, None, :], 360 axis=(0, 1), 361 ) 362 b = pressure_o_weight 363 ek = (1.0 - b) * ek + b * system["ek_tensor"] 364 365 vir = jnp.mean(de["strain"], axis=0) 366 system["virial"] = vir 367 out["virial_tensor"] = vir * model_energy_unit 368 369 volume = jnp.abs(jnp.linalg.det(system["cell"])) 370 Pres = ek*(2./volume) - vir/volume 371 system["pressure_tensor"] = Pres 372 system["pressure"] = jnp.trace(Pres) * (1.0 / 3.0) 373 if variable_cell: 374 density = densmass / volume 375 system["density"] = density 376 system["volume"] = volume 377 378 if ensemble_key is not None: 379 kT = system_data["kT"] 380 dE = ( 381 jnp.mean(out[ensemble_key], axis=0) / model_energy_unit - system["epot"] 382 ) 383 system["ensemble_weights"] = -dE / kT 384 385 if "total_dipole" in out: 386 if nbeads is None: 387 system["total_dipole"] = out["total_dipole"][0] 388 else: 389 system["total_dipole"] = jnp.mean(out["total_dipole"], axis=0) 390 391 if use_colvars: 392 coords = system["coordinates"].reshape(-1, nat, 3)[0] 393 colvars = {} 394 for colvar_name, colvar_calc in colvars_calculators.items(): 395 colvars[colvar_name] = colvar_calc(coords) 396 system["colvars"] = colvars 397 398 return system, out 399 400 ############################################### 401 ### IR SPECTRUM 402 if do_ir_spectrum: 403 # @jax.jit 404 # def update_dipole(ir_state,system,conformation): 405 # def mumodel(coords): 406 # out = model_ir._apply(model_ir.variables,{**conformation,"coordinates":coords}) 407 # if nbeads is None: 408 # return out["total_dipole"][0] 409 # return out["total_dipole"].sum(axis=0) 410 # dmudqmodel = jax.jacobian(mumodel) 411 412 # dmudq = dmudqmodel(conformation["coordinates"]) 413 # # print(dmudq.shape) 414 # if nbeads is None: 415 # vel = system["vel"].reshape(-1,1,nat,3)[0] 416 # mudot = (vel*dmudq).sum(axis=(1,2)) 417 # else: 418 # dmudq = dmudq.reshape(3,nbeads,nat,3)#.mean(axis=1) 419 # vel = 
(jnp.einsum("in,n...->i...", eigmat, system["vel"]) * nbeads**0.5 420 # ) 421 # # vel = system["vel"][0].reshape(1,nat,3) 422 # mudot = (vel[None,...]*dmudq).sum(axis=(1,2,3))/nbeads 423 424 # ir_state = save_dipole(mudot,ir_state) 425 # return ir_state 426 @jax.jit 427 def update_conformation_ir(conformation, system): 428 conformation = { 429 **conformation, 430 "coordinates": system["coordinates"].reshape(-1, nat, 3)[0], 431 "natoms": jnp.asarray([nat]), 432 "batch_index": jnp.asarray([0] * nat), 433 "species": jnp.asarray(system_data["species"].reshape(-1, nat)[0]), 434 } 435 if variable_cell: 436 conformation["cells"] = system["cell"][None, :, :] 437 conformation["reciprocal_cells"] = jnp.linalg.inv(system["cell"])[ 438 None, :, : 439 ] 440 return conformation 441 442 @jax.jit 443 def update_dipole(ir_state, system, conformation): 444 if model_ir is not None: 445 out = model_ir._apply(model_ir.variables, conformation) 446 q = out.get("charges", jnp.zeros(nat)).reshape((-1, nat)) 447 dip = out.get("dipoles", jnp.zeros((nat, 3))).reshape((-1, nat, 3)) 448 else: 449 q = system.get("charges", jnp.zeros(nat)).reshape((-1, nat)) 450 dip = system.get("dipoles", jnp.zeros((nat, 3))).reshape((-1, nat, 3)) 451 if nbeads is not None: 452 q = jnp.mean(q, axis=0) 453 dip = jnp.mean(dip, axis=0) 454 vel = system["vel"][0] 455 pos = system["coordinates"][0] 456 else: 457 q = q[0] 458 dip = dip[0] 459 vel = system["vel"].reshape(-1, nat, 3)[0] 460 pos = system["coordinates"].reshape(-1, nat, 3)[0] 461 462 if pbc_data is not None: 463 cell_reciprocal = ( 464 conformation["cells"][0], 465 conformation["reciprocal_cells"][0], 466 ) 467 else: 468 cell_reciprocal = None 469 470 ir_state = save_dipole( 471 q, vel, pos, dip.sum(axis=0), cell_reciprocal, ir_state 472 ) 473 return ir_state 474 475 ############################################### 476 ### GRAPH UPDATES 477 478 nblist_verbose = simulation_parameters.get("nblist_verbose", False) 479 """@keyword[fennol_md] nblist_verbose 480 
Print verbose information about neighbor list updates and reallocations. 481 Default: False 482 """ 483 nblist_stride = int(simulation_parameters.get("nblist_stride", -1)) 484 """@keyword[fennol_md] nblist_stride 485 Number of steps between full neighbor list rebuilds. Auto-calculated from skin if <= 0. 486 Default: -1 487 """ 488 nblist_warmup_time = simulation_parameters.get("nblist_warmup_time", -1.0) 489 """@keyword[fennol_md] nblist_warmup_time 490 Time period for neighbor list warmup before using skin updates. 491 Default: -1.0 492 """ 493 nblist_warmup = int(nblist_warmup_time / dt) if nblist_warmup_time > 0 else 0 494 nblist_skin = simulation_parameters.get("nblist_skin", -1.0) 495 """@keyword[fennol_md] nblist_skin 496 Neighbor list skin distance for efficient updates (in Angstroms). 497 Default: -1.0 498 """ 499 if nblist_skin > 0: 500 if nblist_stride <= 0: 501 ## reference skin parameters at 300K (from Tinker-HP) 502 ## => skin of 2 A gives you 40 fs without complete rebuild 503 t_ref = 40.0 /us.FS # FS 504 nblist_skin_ref = 2.0 # A 505 nblist_stride = int(math.floor(nblist_skin / nblist_skin_ref * t_ref / dt)) 506 print( 507 f"# nblist_skin: {nblist_skin:.2f} A, nblist_stride: {nblist_stride} steps, nblist_warmup: {nblist_warmup} steps" 508 ) 509 510 if nblist_skin <= 0: 511 nblist_stride = 1 512 513 dyn_state["nblist_countdown"] = 0 514 dyn_state["print_skin_activation"] = nblist_warmup > 0 515 516 def update_graphs(istep, dyn_state, system, conformation, force_preprocess=False): 517 nblist_countdown = dyn_state["nblist_countdown"] 518 if nblist_countdown <= 0 or force_preprocess or (istep < nblist_warmup): 519 ### FULL NBLIST REBUILD 520 dyn_state["nblist_countdown"] = nblist_stride - 1 521 preproc_state = dyn_state["preproc_state"] 522 conformation = model.preprocessing.process( 523 preproc_state, update_conformation(conformation, system) 524 ) 525 preproc_state, state_up, conformation, overflow = ( 526 
model.preprocessing.check_reallocate(preproc_state, conformation) 527 ) 528 dyn_state["preproc_state"] = preproc_state 529 if nblist_verbose and overflow: 530 print("step", istep, ", nblist overflow => reallocating nblist") 531 print("size updates:", state_up) 532 533 if do_ir_spectrum and model_ir is not None: 534 conformation_ir = model_ir.preprocessing.process( 535 dyn_state["preproc_state_ir"], 536 update_conformation_ir(dyn_state["conformation_ir"], system), 537 ) 538 ( 539 dyn_state["preproc_state_ir"], 540 _, 541 dyn_state["conformation_ir"], 542 overflow, 543 ) = model_ir.preprocessing.check_reallocate( 544 dyn_state["preproc_state_ir"], conformation_ir 545 ) 546 547 else: 548 ### SKIN UPDATE 549 if dyn_state["print_skin_activation"]: 550 if nblist_verbose: 551 print( 552 "step", 553 istep, 554 ", end of nblist warmup phase => activating skin updates", 555 ) 556 dyn_state["print_skin_activation"] = False 557 558 dyn_state["nblist_countdown"] = nblist_countdown - 1 559 conformation = model.preprocessing.update_skin( 560 update_conformation(conformation, system) 561 ) 562 if do_ir_spectrum and model_ir is not None: 563 dyn_state["conformation_ir"] = model_ir.preprocessing.update_skin( 564 update_conformation_ir(dyn_state["conformation_ir"], system) 565 ) 566 567 return conformation, dyn_state 568 569 ################################################ 570 ### DEFINE STEP FUNCTION 571 def step(istep, dyn_state, system, conformation, force_preprocess=False): 572 573 dyn_state = { 574 **dyn_state, 575 "istep": dyn_state["istep"] + 1, 576 } 577 578 ### INTEGRATE EQUATIONS OF MOTION 579 system = integrate(system) 580 581 ### UPDATE CONFORMATION AND GRAPHS 582 conformation, dyn_state = update_graphs( 583 istep, dyn_state, system, conformation, force_preprocess 584 ) 585 586 ## COMPUTE FORCES AND OBSERVABLES 587 system, out = update_observables(system, conformation) 588 589 ## END OF STEP UPDATES 590 if do_thermostat_post: 591 system["thermostat"], 
dyn_state["thermostat_post_state"] = thermostat_post( 592 system["thermostat"], dyn_state["thermostat_post_state"] 593 ) 594 595 if do_ir_spectrum: 596 ir_state = update_dipole( 597 dyn_state["ir_spectrum"], system, dyn_state["conformation_ir"] 598 ) 599 dyn_state["ir_spectrum"] = ir_post(ir_state) 600 601 return dyn_state, system, conformation, out 602 603 ########################################################### 604 605 print("# Computing initial energy and forces") 606 607 conformation = update_conformation(conformation, system) 608 # initialize IR conformation 609 if do_ir_spectrum and model_ir is not None: 610 dyn_state["preproc_state_ir"], dyn_state["conformation_ir"] = ( 611 model_ir.preprocessing( 612 model_ir.preproc_state, 613 update_conformation_ir(conformation, system), 614 ) 615 ) 616 617 system, _ = update_observables(system, conformation) 618 619 return step, update_conformation, system_data, dyn_state, conformation, system