"""fennol.md.integrate: integrator setup for classical MD and ring-polymer PIMD."""

import sys, os, io
import argparse
import time
from pathlib import Path
import math

import numpy as np
from typing import Optional, Callable
from collections import defaultdict
from functools import partial
import jax
import jax.numpy as jnp

from flax.core import freeze, unfreeze

from ..utils.io import last_xyz_frame


from ..models import FENNIX

from ..utils.periodic_table import PERIODIC_TABLE_REV_IDX, ATOMIC_MASSES
from ..utils.atomic_units import AtomicUnits as au
from ..utils.input_parser import parse_input
from .thermostats import get_thermostat
from .barostats import get_barostat

from copy import deepcopy
from .initial import initialize_system


def initialize_dynamics(
    simulation_parameters, system_data, conformation, model, fprec, rng_key
):
    """Create the integrator and the initial dynamical state of the system.

    Args:
        simulation_parameters: Simulation input forwarded to
            ``initialize_integrator``.
        system_data: System description forwarded to ``initialize_integrator``
            and ``initialize_system``.
        conformation: Initial model input passed to ``initialize_system``.
        model: The potential model.
        fprec: Floating-point dtype.
        rng_key: JAX PRNG key used for the thermostat/barostat setup.

    Returns:
        Tuple ``(step, update_conformation, dyn_state, system)`` where
        ``system`` merges the state returned by ``initialize_system`` with the
        thermostat (and barostat) state.
    """
    step, update_conformation, dyn_state, thermo_state, vel = initialize_integrator(
        simulation_parameters, system_data, model, fprec, rng_key
    )
    ### initialize system
    system = initialize_system(
        conformation,
        vel,
        model,
        system_data,
        fprec,
    )
    return step, update_conformation, dyn_state, {**system, **thermo_state}


def initialize_integrator(simulation_parameters, system_data, model, fprec, rng_key):
    """Build the time-stepping machinery for classical MD or ring-polymer PIMD.

    One time step is split into jitted pieces around the force evaluation:

    * ``stepA``: half kick (``v += dt/(2m) * f``), half drift, thermostat /
      barostat update, half drift. For PIMD the drifts are free ring-polymer
      propagations in normal-mode space.
    * ``update_forces``: model energy and forces, plus the virial when a
      pressure estimate is needed.
    * ``stepB``: second half kick, then kinetic-energy and pressure observables.

    PIMD is selected when ``system_data`` contains ``"nbeads"``. A barostat is
    built through ``get_barostat`` when ``system_data`` contains ``"pbc"``.

    Args:
        simulation_parameters: Simulation input (dict-like). ``"dt"`` (in fs)
            is required. Optional keys read here: ``"print_timings"``,
            ``"nblist_verbose"``, ``"nblist_stride"`` (steps),
            ``"nblist_warmup_time"`` (fs), ``"nblist_skin"`` (A),
            ``"etot_ensemble_key"`` and, for PIMD, ``"cay_correction"``.
        system_data: System description. Must contain ``"mass"``,
            ``"totmass_amu"`` and ``"nat"``. May contain ``"pbc"`` (with an
            ``"estimate_pressure"`` entry) and ``"kT"`` (used when
            ``"etot_ensemble_key"`` is set). PIMD additionally needs
            ``"nbeads"``, ``"omk"`` and ``"eigmat"``.
        model: Model exposing ``energy_unit``, ``variables``,
            ``_energy_and_forces``, ``_energy_and_forces_and_virial`` and a
            ``preprocessing`` object (``process``, ``check_reallocate``,
            ``update_skin``).
        fprec: Floating-point dtype used for integrator arrays.
        rng_key: JAX PRNG key, split between the thermostat and the barostat.

    Returns:
        Tuple ``(step, update_conformation, dyn_state, thermo_state, vel)``:
        the step function, the function mapping the integrator state to a
        model conformation, the initial bookkeeping dict, the initial
        thermostat (and barostat) state, and the initial velocities returned
        by ``get_thermostat``.
    """
    # dt is given in fs and converted to the internal time unit
    dt = simulation_parameters.get("dt") * au.FS
    dt2 = 0.5 * dt
    nbeads = system_data.get("nbeads", None)

    mass = system_data["mass"]
    totmass_amu = system_data["totmass_amu"]
    nat = system_data["nat"]
    # half-kick factor: v <- v + dt2m * f
    dt2m = jnp.asarray(dt2 / mass[:, None], dtype=fprec)
    if nbeads is not None:
        # broadcast over the leading ring-polymer normal-mode axis
        dt2m = dt2m[None, :, :]

    dyn_state = {
        "istep": 0,
        "dt": dt,
        "pimd": nbeads is not None,
    }

    # model outputs are divided by this factor to get internal energy units
    model_energy_unit = au.get_multiplier(model.energy_unit)

    # initialize thermostat
    thermostat_rng, rng_key = jax.random.split(rng_key)
    thermostat, thermostat_post, thermostat_state, vel, dyn_state["thermostat_name"] = (
        get_thermostat(simulation_parameters, dt, system_data, fprec, thermostat_rng)
    )

    # optional end-of-step thermostat hook, given as (function, initial_state)
    do_thermostat_post = thermostat_post is not None
    if do_thermostat_post:
        thermostat_post, post_state = thermostat_post
        dyn_state["thermostat_post_state"] = post_state

    pbc_data = system_data.get("pbc", None)
    if pbc_data is not None:
        thermo_update, variable_cell, barostat_state = get_barostat(
            thermostat, simulation_parameters, dt, system_data, fprec, rng_key
        )
        estimate_pressure = variable_cell or pbc_data["estimate_pressure"]
        thermo_state = {"thermostat": thermostat_state, "barostat": barostat_state}

    else:
        estimate_pressure = False
        variable_cell = False

        def thermo_update(x, v, system):
            # thermostat only: positions are left untouched
            v, thermostat_state = thermostat(v, system["thermostat"])
            return x, v, {**system, "thermostat": thermostat_state}

        thermo_state = {"thermostat": thermostat_state}

    print("# Estimate pressure: ", estimate_pressure)

    dyn_state["estimate_pressure"] = estimate_pressure
    dyn_state["variable_cell"] = variable_cell

    dyn_state["print_timings"] = simulation_parameters.get("print_timings", False)
    if dyn_state["print_timings"]:
        # the presence of this key is what enables timing inside `step`
        dyn_state["timings"] = defaultdict(lambda: 0.0)

    ### NBLIST
    nblist_verbose = simulation_parameters.get("nblist_verbose", False)
    nblist_stride = int(simulation_parameters.get("nblist_stride", -1))
    nblist_warmup_time = simulation_parameters.get("nblist_warmup_time", -1.0) * au.FS
    nblist_warmup = int(nblist_warmup_time / dt) if nblist_warmup_time > 0 else 0
    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
    if nblist_skin > 0:
        if nblist_stride <= 0:
            ## reference skin parameters at 300K (from Tinker-HP)
            ##   => skin of 2 A gives you 40 fs without complete rebuild
            # t_ref must use the same time unit as dt (converted with au.FS),
            # otherwise t_ref / dt is not a number of steps.
            t_ref = 40.0 * au.FS
            nblist_skin_ref = 2.0  # A
            # at least one step between full rebuilds
            nblist_stride = max(
                1, int(math.floor(nblist_skin / nblist_skin_ref * t_ref / dt))
            )
        print(
            f"# nblist_skin: {nblist_skin:.2f} A, nblist_stride: {nblist_stride} steps, nblist_warmup: {nblist_warmup} steps"
        )

    if nblist_skin <= 0:
        # without a skin the neighbor list must be rebuilt at every step
        nblist_stride = 1

    dyn_state["nblist_countdown"] = 0
    dyn_state["print_skin_activation"] = nblist_warmup > 0

    ### ENSEMBLE
    ensemble_key = simulation_parameters.get("etot_ensemble_key", None)

    ### DEFINE INTEGRATION FUNCTIONS

    if nbeads is not None:
        ### RING POLYMER INTEGRATOR
        cay_correction = simulation_parameters.get("cay_correction", True)
        omk = system_data["omk"]
        eigmat = system_data["eigmat"]
        # Propagation coefficients for the non-centroid normal modes (index 0,
        # the centroid, is drifted as a free particle). omk is presumably the
        # normal-mode angular frequencies, judging from the cos/sin propagator.
        cayfact = 1.0 / (4.0 + (dt * omk[1:, None, None]) ** 2) ** 0.5
        if cay_correction:
            # Cayley-type coefficients; the 2x2 propagator has unit determinant
            axx = jnp.asarray(2 * cayfact)
            axv = jnp.asarray(dt * cayfact)
            avx = jnp.asarray(-dt * cayfact * omk[1:, None, None] ** 2)
        else:
            # exact harmonic propagation over half a time step
            axx = jnp.asarray(np.cos(omk[1:, None, None] * dt2))
            axv = jnp.asarray(np.sin(omk[1:, None, None] * dt2) / omk[1:, None, None])
            avx = jnp.asarray(-omk[1:, None, None] * np.sin(omk[1:, None, None] * dt2))

        @jax.jit
        def update_conformation(conformation, system):
            """update bead coordinates from ring polymer normal modes"""
            eigx = system["coordinates"]
            x = jnp.einsum("in,n...->i...", eigmat, eigx).reshape(nbeads * nat, 3) * (
                nbeads**0.5
            )
            conformation = {**conformation, "coordinates": x}
            if variable_cell:
                # every bead shares the current cell
                conformation["cells"] = system["cell"][None, :, :].repeat(nbeads, axis=0)
                conformation["reciprocal_cells"] = jnp.linalg.inv(system["cell"])[
                    None, :, :
                ].repeat(nbeads, axis=0)
            return conformation

        @jax.jit
        def coords_to_eig(x):
            """update normal modes from bead coordinates"""
            return jnp.einsum("in,i...->n...", eigmat, x.reshape(nbeads, nat, 3)) * (
                1.0 / nbeads**0.5
            )

        def halfstep_free_polymer(eigx0, eigv0):
            """update coordinates and velocities of a free ring polymer for a half time step"""
            eigx_c = eigx0[0] + dt2 * eigv0[0]
            eigv_c = eigv0[0]
            eigx = eigx0[1:] * axx + eigv0[1:] * axv
            eigv = eigx0[1:] * avx + eigv0[1:] * axx

            return (
                jnp.concatenate((eigx_c[None], eigx), axis=0),
                jnp.concatenate((eigv_c[None], eigv), axis=0),
            )

        @jax.jit
        def stepA(system):
            """Half kick, half free-polymer step, thermostat, half free-polymer step."""
            eigx = system["coordinates"]
            eigv = system["vel"] + dt2m * system["forces"]
            eigx, eigv = halfstep_free_polymer(eigx, eigv)
            eigx, eigv, system = thermo_update(eigx, eigv, system)
            eigx, eigv = halfstep_free_polymer(eigx, eigv)

            return {
                **system,
                "coordinates": eigx,
                "vel": eigv,
            }

        @jax.jit
        def update_forces(system, conformation):
            """Evaluate the model on all beads; store normal-mode forces and bead-averaged energy/virial."""
            if estimate_pressure:
                epot, f, vir_t, out = model._energy_and_forces_and_virial(
                    model.variables, conformation
                )
                epot = epot / model_energy_unit
                f = f / model_energy_unit
                vir_t = vir_t / model_energy_unit

                new_sys = {
                    **system,
                    "forces": coords_to_eig(f),
                    "epot": jnp.mean(epot),
                    "virial": jnp.mean(vir_t, axis=0),
                }
            else:
                epot, f, out = model._energy_and_forces(model.variables, conformation)
                epot = epot / model_energy_unit
                f = f / model_energy_unit
                new_sys = {**system, "forces": coords_to_eig(f), "epot": jnp.mean(epot)}
            if ensemble_key is not None:
                kT = system_data["kT"]
                dE = (
                    jnp.mean(out[ensemble_key], axis=0) / model_energy_unit
                    - new_sys["epot"]
                )
                new_sys["ensemble_weights"] = -dE / kT
            return new_sys

        @jax.jit
        def stepB(system):
            """Second half kick, then kinetic-energy and pressure estimates."""
            eigv = system["vel"] + dt2m * system["forces"]

            # centroid kinetic tensor
            ek_c = 0.5 * jnp.sum(
                mass[:, None, None] * eigv[0, :, :, None] * eigv[0, :, None, :], axis=0
            )
            # centroid kinetic tensor minus half the sum of x (x) f over the
            # non-centroid modes (presumably a centroid-virial estimator)
            ek = ek_c - 0.5 * jnp.sum(
                system["coordinates"][1:, :, :, None]
                * system["forces"][1:, :, None, :],
                axis=(0, 1),
            )
            system = {
                **system,
                "vel": eigv,
                "ek_tensor": ek,
                "ek_c": jnp.trace(ek_c),
                "ek": jnp.trace(ek),
            }

            if estimate_pressure:
                vir = system["virial"]
                volume = jnp.abs(jnp.linalg.det(system["cell"]))
                Pres = (2 * ek - vir) / volume
                system["pressure_tensor"] = Pres
                system["pressure"] = jnp.trace(Pres) * (1.0 / 3.0)
                if variable_cell:
                    density = totmass_amu / volume
                    system["density"] = density
                    system["volume"] = volume

            return system

    else:
        ### CLASSICAL MD INTEGRATOR
        @jax.jit
        def update_conformation(conformation, system):
            """Copy integrator coordinates (and cell, if variable) into the conformation."""
            conformation = {**conformation, "coordinates": system["coordinates"]}
            if variable_cell:
                conformation["cells"] = system["cell"][None, :, :]
                conformation["reciprocal_cells"] = jnp.linalg.inv(system["cell"])[
                    None, :, :
                ]
            return conformation

        @jax.jit
        def stepA(system):
            """Half kick, half drift, thermostat/barostat update, half drift."""
            v = system["vel"]
            f = system["forces"]
            x = system["coordinates"]

            v = v + f * dt2m
            x = x + dt2 * v
            x, v, system = thermo_update(x, v, system)
            x = x + dt2 * v

            return {**system, "coordinates": x, "vel": v}

        @jax.jit
        def update_forces(system, conformation):
            """Evaluate the model; store forces, energy and (optionally) virial."""
            if estimate_pressure:
                epot, f, vir_t, out = model._energy_and_forces_and_virial(
                    model.variables, conformation
                )
                epot = epot / model_energy_unit
                f = f / model_energy_unit
                vir_t = vir_t / model_energy_unit
                new_sys = {
                    **system,
                    "forces": f,
                    "epot": epot[0],
                    "virial": vir_t[0],
                }
            else:
                epot, f, out = model._energy_and_forces(model.variables, conformation)
                epot = epot / model_energy_unit
                f = f / model_energy_unit
                new_sys = {**system, "forces": f, "epot": epot[0]}

            if ensemble_key is not None:
                kT = system_data["kT"]
                dE = out[ensemble_key][0, :] / model_energy_unit - new_sys["epot"]
                new_sys["ensemble_weights"] = -dE / kT
            return new_sys

        @jax.jit
        def stepB(system):
            """Second half kick, then kinetic-energy and pressure estimates."""
            v = system["vel"]
            f = system["forces"]
            state_th = system["thermostat"]

            v = v + f * dt2m
            # kinetic tensor, optionally rescaled by the thermostat's "corr_kin"
            ek_tensor = (
                0.5
                * jnp.sum(mass[:, None, None] * v[:, :, None] * v[:, None, :], axis=0)
                / state_th.get("corr_kin", 1.0)
            )
            system = {
                **system,
                "vel": v,
                "ek": jnp.trace(ek_tensor),
                "ek_tensor": ek_tensor,
            }

            if estimate_pressure:
                vir = system["virial"]
                volume = jnp.abs(jnp.linalg.det(system["cell"]))
                Pres = (2 * ek_tensor - vir) / volume
                system["pressure_tensor"] = Pres
                system["pressure"] = jnp.trace(Pres) * (1.0 / 3.0)
                if variable_cell:
                    density = totmass_amu / volume
                    system["density"] = density
                    system["volume"] = volume

            return system

    ### DEFINE STEP FUNCTION COMMON TO CLASSICAL AND PIMD
    def step(
        istep, dyn_state, system, conformation, preproc_state, force_preprocess=False
    ):
        """Advance the dynamics by one time step.

        Args:
            istep: Step index, compared with the nblist warmup length and used
                in log messages.
            dyn_state: Bookkeeping dict; a new dict is returned, not mutated.
            system: Integrator state (coordinates, velocities, forces, ...).
            conformation: Model input, refreshed from ``system``.
            preproc_state: Preprocessing (neighbor-list) state.
            force_preprocess: Force a full neighbor-list rebuild this step.

        Returns:
            Tuple ``(dyn_state, system, conformation, preproc_state)``.
        """
        tstep0 = time.time()
        print_timings = "timings" in dyn_state

        dyn_state = {
            **dyn_state,
            "istep": dyn_state["istep"] + 1,
        }
        if print_timings:
            prev_timings = dyn_state["timings"]
            timings = defaultdict(lambda: 0.0)
            timings.update(prev_timings)

        ## take a half step (update positions, nblist and half velocities)
        system = stepA(system)

        if print_timings:
            system["coordinates"].block_until_ready()
            timings["Integrator"] += time.time() - tstep0
            tstep0 = time.time()

        ### update conformation and graphs
        nblist_countdown = dyn_state["nblist_countdown"]
        if nblist_countdown <= 0 or force_preprocess or (istep < nblist_warmup):
            ### full nblist update
            dyn_state["nblist_countdown"] = nblist_stride - 1
            conformation = model.preprocessing.process(
                preproc_state, update_conformation(conformation, system)
            )
            preproc_state, state_up, conformation, overflow = (
                model.preprocessing.check_reallocate(preproc_state, conformation)
            )
            if nblist_verbose and overflow:
                print("step", istep, ", nblist overflow => reallocating nblist")
                print("size updates:", state_up)

            if print_timings:
                conformation["coordinates"].block_until_ready()
                timings["Preprocessing"] += time.time() - tstep0
                tstep0 = time.time()

        else:
            ### skin update
            if dyn_state["print_skin_activation"]:
                if nblist_verbose:
                    print(
                        "step",
                        istep,
                        ", end of nblist warmup phase => activating skin updates",
                    )
                dyn_state["print_skin_activation"] = False

            dyn_state["nblist_countdown"] = nblist_countdown - 1
            conformation = model.preprocessing.update_skin(
                update_conformation(conformation, system)
            )

            if print_timings:
                conformation["coordinates"].block_until_ready()
                timings["Skin update"] += time.time() - tstep0
                tstep0 = time.time()

        ## compute forces
        system = update_forces(system, conformation)
        if print_timings:
            system["coordinates"].block_until_ready()
            timings["Forces"] += time.time() - tstep0
            tstep0 = time.time()

        ## finish step
        system = stepB(system)

        ## end of step update (mostly for adQTB)
        if do_thermostat_post:
            system["thermostat"], dyn_state["thermostat_post_state"] = thermostat_post(
                system["thermostat"], dyn_state["thermostat_post_state"]
            )

        if print_timings:
            system["coordinates"].block_until_ready()
            timings["Integrator"] += time.time() - tstep0
            tstep0 = time.time()

            # store timings (only defined when timing is enabled)
            dyn_state["timings"] = timings

        return dyn_state, system, conformation, preproc_state

    return step, update_conformation, dyn_state, thermo_state, vel
def initialize_dynamics(
    simulation_parameters, system_data, conformation, model, fprec, rng_key
):
    """Set up the integrator, then build the initial dynamical state.

    Args:
        simulation_parameters: Simulation input forwarded to
            ``initialize_integrator``.
        system_data: System description forwarded to ``initialize_integrator``
            and ``initialize_system``.
        conformation: Initial model input handed to ``initialize_system``.
        model: The potential model.
        fprec: Floating-point dtype.
        rng_key: JAX PRNG key consumed by the integrator setup.

    Returns:
        ``(step, update_conformation, dyn_state, system)``, where ``system``
        combines the ``initialize_system`` output with the thermostat (and
        barostat) state; thermostat entries take precedence on key clashes.
    """
    integrator_parts = initialize_integrator(
        simulation_parameters, system_data, model, fprec, rng_key
    )
    step, update_conformation, dyn_state, thermo_state, vel = integrator_parts

    initial_system = initialize_system(conformation, vel, model, system_data, fprec)

    # thermostat/barostat state is layered on top of the system state
    merged_state = {**initial_system}
    merged_state.update(thermo_state)
    return step, update_conformation, dyn_state, merged_state
def initialize_integrator(simulation_parameters, system_data, model, fprec, rng_key):
    """Build the time-stepping machinery for classical MD or ring-polymer PIMD.

    One time step is split into jitted pieces around the force evaluation:

    * ``stepA``: half kick (``v += dt/(2m) * f``), half drift, thermostat /
      barostat update, half drift. For PIMD the drifts are free ring-polymer
      propagations in normal-mode space.
    * ``update_forces``: model energy and forces, plus the virial when a
      pressure estimate is needed.
    * ``stepB``: second half kick, then kinetic-energy and pressure observables.

    PIMD is selected when ``system_data`` contains ``"nbeads"``. A barostat is
    built through ``get_barostat`` when ``system_data`` contains ``"pbc"``.

    Args:
        simulation_parameters: Simulation input (dict-like). ``"dt"`` (in fs)
            is required. Optional keys read here: ``"print_timings"``,
            ``"nblist_verbose"``, ``"nblist_stride"`` (steps),
            ``"nblist_warmup_time"`` (fs), ``"nblist_skin"`` (A),
            ``"etot_ensemble_key"`` and, for PIMD, ``"cay_correction"``.
        system_data: System description. Must contain ``"mass"``,
            ``"totmass_amu"`` and ``"nat"``. May contain ``"pbc"`` (with an
            ``"estimate_pressure"`` entry) and ``"kT"`` (used when
            ``"etot_ensemble_key"`` is set). PIMD additionally needs
            ``"nbeads"``, ``"omk"`` and ``"eigmat"``.
        model: Model exposing ``energy_unit``, ``variables``,
            ``_energy_and_forces``, ``_energy_and_forces_and_virial`` and a
            ``preprocessing`` object (``process``, ``check_reallocate``,
            ``update_skin``).
        fprec: Floating-point dtype used for integrator arrays.
        rng_key: JAX PRNG key, split between the thermostat and the barostat.

    Returns:
        Tuple ``(step, update_conformation, dyn_state, thermo_state, vel)``:
        the step function, the function mapping the integrator state to a
        model conformation, the initial bookkeeping dict, the initial
        thermostat (and barostat) state, and the initial velocities returned
        by ``get_thermostat``.
    """
    # dt is given in fs and converted to the internal time unit
    dt = simulation_parameters.get("dt") * au.FS
    dt2 = 0.5 * dt
    nbeads = system_data.get("nbeads", None)

    mass = system_data["mass"]
    totmass_amu = system_data["totmass_amu"]
    nat = system_data["nat"]
    # half-kick factor: v <- v + dt2m * f
    dt2m = jnp.asarray(dt2 / mass[:, None], dtype=fprec)
    if nbeads is not None:
        # broadcast over the leading ring-polymer normal-mode axis
        dt2m = dt2m[None, :, :]

    dyn_state = {
        "istep": 0,
        "dt": dt,
        "pimd": nbeads is not None,
    }

    # model outputs are divided by this factor to get internal energy units
    model_energy_unit = au.get_multiplier(model.energy_unit)

    # initialize thermostat
    thermostat_rng, rng_key = jax.random.split(rng_key)
    thermostat, thermostat_post, thermostat_state, vel, dyn_state["thermostat_name"] = (
        get_thermostat(simulation_parameters, dt, system_data, fprec, thermostat_rng)
    )

    # optional end-of-step thermostat hook, given as (function, initial_state)
    do_thermostat_post = thermostat_post is not None
    if do_thermostat_post:
        thermostat_post, post_state = thermostat_post
        dyn_state["thermostat_post_state"] = post_state

    pbc_data = system_data.get("pbc", None)
    if pbc_data is not None:
        thermo_update, variable_cell, barostat_state = get_barostat(
            thermostat, simulation_parameters, dt, system_data, fprec, rng_key
        )
        estimate_pressure = variable_cell or pbc_data["estimate_pressure"]
        thermo_state = {"thermostat": thermostat_state, "barostat": barostat_state}

    else:
        estimate_pressure = False
        variable_cell = False

        def thermo_update(x, v, system):
            # thermostat only: positions are left untouched
            v, thermostat_state = thermostat(v, system["thermostat"])
            return x, v, {**system, "thermostat": thermostat_state}

        thermo_state = {"thermostat": thermostat_state}

    print("# Estimate pressure: ", estimate_pressure)

    dyn_state["estimate_pressure"] = estimate_pressure
    dyn_state["variable_cell"] = variable_cell

    dyn_state["print_timings"] = simulation_parameters.get("print_timings", False)
    if dyn_state["print_timings"]:
        # the presence of this key is what enables timing inside `step`
        dyn_state["timings"] = defaultdict(lambda: 0.0)

    ### NBLIST
    nblist_verbose = simulation_parameters.get("nblist_verbose", False)
    nblist_stride = int(simulation_parameters.get("nblist_stride", -1))
    nblist_warmup_time = simulation_parameters.get("nblist_warmup_time", -1.0) * au.FS
    nblist_warmup = int(nblist_warmup_time / dt) if nblist_warmup_time > 0 else 0
    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
    if nblist_skin > 0:
        if nblist_stride <= 0:
            ## reference skin parameters at 300K (from Tinker-HP)
            ##   => skin of 2 A gives you 40 fs without complete rebuild
            # t_ref must use the same time unit as dt (converted with au.FS),
            # otherwise t_ref / dt is not a number of steps.
            t_ref = 40.0 * au.FS
            nblist_skin_ref = 2.0  # A
            # at least one step between full rebuilds
            nblist_stride = max(
                1, int(math.floor(nblist_skin / nblist_skin_ref * t_ref / dt))
            )
        print(
            f"# nblist_skin: {nblist_skin:.2f} A, nblist_stride: {nblist_stride} steps, nblist_warmup: {nblist_warmup} steps"
        )

    if nblist_skin <= 0:
        # without a skin the neighbor list must be rebuilt at every step
        nblist_stride = 1

    dyn_state["nblist_countdown"] = 0
    dyn_state["print_skin_activation"] = nblist_warmup > 0

    ### ENSEMBLE
    ensemble_key = simulation_parameters.get("etot_ensemble_key", None)

    ### DEFINE INTEGRATION FUNCTIONS

    if nbeads is not None:
        ### RING POLYMER INTEGRATOR
        cay_correction = simulation_parameters.get("cay_correction", True)
        omk = system_data["omk"]
        eigmat = system_data["eigmat"]
        # Propagation coefficients for the non-centroid normal modes (index 0,
        # the centroid, is drifted as a free particle). omk is presumably the
        # normal-mode angular frequencies, judging from the cos/sin propagator.
        cayfact = 1.0 / (4.0 + (dt * omk[1:, None, None]) ** 2) ** 0.5
        if cay_correction:
            # Cayley-type coefficients; the 2x2 propagator has unit determinant
            axx = jnp.asarray(2 * cayfact)
            axv = jnp.asarray(dt * cayfact)
            avx = jnp.asarray(-dt * cayfact * omk[1:, None, None] ** 2)
        else:
            # exact harmonic propagation over half a time step
            axx = jnp.asarray(np.cos(omk[1:, None, None] * dt2))
            axv = jnp.asarray(np.sin(omk[1:, None, None] * dt2) / omk[1:, None, None])
            avx = jnp.asarray(-omk[1:, None, None] * np.sin(omk[1:, None, None] * dt2))

        @jax.jit
        def update_conformation(conformation, system):
            """update bead coordinates from ring polymer normal modes"""
            eigx = system["coordinates"]
            x = jnp.einsum("in,n...->i...", eigmat, eigx).reshape(nbeads * nat, 3) * (
                nbeads**0.5
            )
            conformation = {**conformation, "coordinates": x}
            if variable_cell:
                # every bead shares the current cell
                conformation["cells"] = system["cell"][None, :, :].repeat(nbeads, axis=0)
                conformation["reciprocal_cells"] = jnp.linalg.inv(system["cell"])[
                    None, :, :
                ].repeat(nbeads, axis=0)
            return conformation

        @jax.jit
        def coords_to_eig(x):
            """update normal modes from bead coordinates"""
            return jnp.einsum("in,i...->n...", eigmat, x.reshape(nbeads, nat, 3)) * (
                1.0 / nbeads**0.5
            )

        def halfstep_free_polymer(eigx0, eigv0):
            """update coordinates and velocities of a free ring polymer for a half time step"""
            eigx_c = eigx0[0] + dt2 * eigv0[0]
            eigv_c = eigv0[0]
            eigx = eigx0[1:] * axx + eigv0[1:] * axv
            eigv = eigx0[1:] * avx + eigv0[1:] * axx

            return (
                jnp.concatenate((eigx_c[None], eigx), axis=0),
                jnp.concatenate((eigv_c[None], eigv), axis=0),
            )

        @jax.jit
        def stepA(system):
            """Half kick, half free-polymer step, thermostat, half free-polymer step."""
            eigx = system["coordinates"]
            eigv = system["vel"] + dt2m * system["forces"]
            eigx, eigv = halfstep_free_polymer(eigx, eigv)
            eigx, eigv, system = thermo_update(eigx, eigv, system)
            eigx, eigv = halfstep_free_polymer(eigx, eigv)

            return {
                **system,
                "coordinates": eigx,
                "vel": eigv,
            }

        @jax.jit
        def update_forces(system, conformation):
            """Evaluate the model on all beads; store normal-mode forces and bead-averaged energy/virial."""
            if estimate_pressure:
                epot, f, vir_t, out = model._energy_and_forces_and_virial(
                    model.variables, conformation
                )
                epot = epot / model_energy_unit
                f = f / model_energy_unit
                vir_t = vir_t / model_energy_unit

                new_sys = {
                    **system,
                    "forces": coords_to_eig(f),
                    "epot": jnp.mean(epot),
                    "virial": jnp.mean(vir_t, axis=0),
                }
            else:
                epot, f, out = model._energy_and_forces(model.variables, conformation)
                epot = epot / model_energy_unit
                f = f / model_energy_unit
                new_sys = {**system, "forces": coords_to_eig(f), "epot": jnp.mean(epot)}
            if ensemble_key is not None:
                kT = system_data["kT"]
                dE = (
                    jnp.mean(out[ensemble_key], axis=0) / model_energy_unit
                    - new_sys["epot"]
                )
                new_sys["ensemble_weights"] = -dE / kT
            return new_sys

        @jax.jit
        def stepB(system):
            """Second half kick, then kinetic-energy and pressure estimates."""
            eigv = system["vel"] + dt2m * system["forces"]

            # centroid kinetic tensor
            ek_c = 0.5 * jnp.sum(
                mass[:, None, None] * eigv[0, :, :, None] * eigv[0, :, None, :], axis=0
            )
            # centroid kinetic tensor minus half the sum of x (x) f over the
            # non-centroid modes (presumably a centroid-virial estimator)
            ek = ek_c - 0.5 * jnp.sum(
                system["coordinates"][1:, :, :, None]
                * system["forces"][1:, :, None, :],
                axis=(0, 1),
            )
            system = {
                **system,
                "vel": eigv,
                "ek_tensor": ek,
                "ek_c": jnp.trace(ek_c),
                "ek": jnp.trace(ek),
            }

            if estimate_pressure:
                vir = system["virial"]
                volume = jnp.abs(jnp.linalg.det(system["cell"]))
                Pres = (2 * ek - vir) / volume
                system["pressure_tensor"] = Pres
                system["pressure"] = jnp.trace(Pres) * (1.0 / 3.0)
                if variable_cell:
                    density = totmass_amu / volume
                    system["density"] = density
                    system["volume"] = volume

            return system

    else:
        ### CLASSICAL MD INTEGRATOR
        @jax.jit
        def update_conformation(conformation, system):
            """Copy integrator coordinates (and cell, if variable) into the conformation."""
            conformation = {**conformation, "coordinates": system["coordinates"]}
            if variable_cell:
                conformation["cells"] = system["cell"][None, :, :]
                conformation["reciprocal_cells"] = jnp.linalg.inv(system["cell"])[
                    None, :, :
                ]
            return conformation

        @jax.jit
        def stepA(system):
            """Half kick, half drift, thermostat/barostat update, half drift."""
            v = system["vel"]
            f = system["forces"]
            x = system["coordinates"]

            v = v + f * dt2m
            x = x + dt2 * v
            x, v, system = thermo_update(x, v, system)
            x = x + dt2 * v

            return {**system, "coordinates": x, "vel": v}

        @jax.jit
        def update_forces(system, conformation):
            """Evaluate the model; store forces, energy and (optionally) virial."""
            if estimate_pressure:
                epot, f, vir_t, out = model._energy_and_forces_and_virial(
                    model.variables, conformation
                )
                epot = epot / model_energy_unit
                f = f / model_energy_unit
                vir_t = vir_t / model_energy_unit
                new_sys = {
                    **system,
                    "forces": f,
                    "epot": epot[0],
                    "virial": vir_t[0],
                }
            else:
                epot, f, out = model._energy_and_forces(model.variables, conformation)
                epot = epot / model_energy_unit
                f = f / model_energy_unit
                new_sys = {**system, "forces": f, "epot": epot[0]}

            if ensemble_key is not None:
                kT = system_data["kT"]
                dE = out[ensemble_key][0, :] / model_energy_unit - new_sys["epot"]
                new_sys["ensemble_weights"] = -dE / kT
            return new_sys

        @jax.jit
        def stepB(system):
            """Second half kick, then kinetic-energy and pressure estimates."""
            v = system["vel"]
            f = system["forces"]
            state_th = system["thermostat"]

            v = v + f * dt2m
            # kinetic tensor, optionally rescaled by the thermostat's "corr_kin"
            ek_tensor = (
                0.5
                * jnp.sum(mass[:, None, None] * v[:, :, None] * v[:, None, :], axis=0)
                / state_th.get("corr_kin", 1.0)
            )
            system = {
                **system,
                "vel": v,
                "ek": jnp.trace(ek_tensor),
                "ek_tensor": ek_tensor,
            }

            if estimate_pressure:
                vir = system["virial"]
                volume = jnp.abs(jnp.linalg.det(system["cell"]))
                Pres = (2 * ek_tensor - vir) / volume
                system["pressure_tensor"] = Pres
                system["pressure"] = jnp.trace(Pres) * (1.0 / 3.0)
                if variable_cell:
                    density = totmass_amu / volume
                    system["density"] = density
                    system["volume"] = volume

            return system

    ### DEFINE STEP FUNCTION COMMON TO CLASSICAL AND PIMD
    def step(
        istep, dyn_state, system, conformation, preproc_state, force_preprocess=False
    ):
        """Advance the dynamics by one time step.

        Args:
            istep: Step index, compared with the nblist warmup length and used
                in log messages.
            dyn_state: Bookkeeping dict; a new dict is returned, not mutated.
            system: Integrator state (coordinates, velocities, forces, ...).
            conformation: Model input, refreshed from ``system``.
            preproc_state: Preprocessing (neighbor-list) state.
            force_preprocess: Force a full neighbor-list rebuild this step.

        Returns:
            Tuple ``(dyn_state, system, conformation, preproc_state)``.
        """
        tstep0 = time.time()
        print_timings = "timings" in dyn_state

        dyn_state = {
            **dyn_state,
            "istep": dyn_state["istep"] + 1,
        }
        if print_timings:
            prev_timings = dyn_state["timings"]
            timings = defaultdict(lambda: 0.0)
            timings.update(prev_timings)

        ## take a half step (update positions, nblist and half velocities)
        system = stepA(system)

        if print_timings:
            system["coordinates"].block_until_ready()
            timings["Integrator"] += time.time() - tstep0
            tstep0 = time.time()

        ### update conformation and graphs
        nblist_countdown = dyn_state["nblist_countdown"]
        if nblist_countdown <= 0 or force_preprocess or (istep < nblist_warmup):
            ### full nblist update
            dyn_state["nblist_countdown"] = nblist_stride - 1
            conformation = model.preprocessing.process(
                preproc_state, update_conformation(conformation, system)
            )
            preproc_state, state_up, conformation, overflow = (
                model.preprocessing.check_reallocate(preproc_state, conformation)
            )
            if nblist_verbose and overflow:
                print("step", istep, ", nblist overflow => reallocating nblist")
                print("size updates:", state_up)

            if print_timings:
                conformation["coordinates"].block_until_ready()
                timings["Preprocessing"] += time.time() - tstep0
                tstep0 = time.time()

        else:
            ### skin update
            if dyn_state["print_skin_activation"]:
                if nblist_verbose:
                    print(
                        "step",
                        istep,
                        ", end of nblist warmup phase => activating skin updates",
                    )
                dyn_state["print_skin_activation"] = False

            dyn_state["nblist_countdown"] = nblist_countdown - 1
            conformation = model.preprocessing.update_skin(
                update_conformation(conformation, system)
            )

            if print_timings:
                conformation["coordinates"].block_until_ready()
                timings["Skin update"] += time.time() - tstep0
                tstep0 = time.time()

        ## compute forces
        system = update_forces(system, conformation)
        if print_timings:
            system["coordinates"].block_until_ready()
            timings["Forces"] += time.time() - tstep0
            tstep0 = time.time()

        ## finish step
        system = stepB(system)

        ## end of step update (mostly for adQTB)
        if do_thermostat_post:
            system["thermostat"], dyn_state["thermostat_post_state"] = thermostat_post(
                system["thermostat"], dyn_state["thermostat_post_state"]
            )

        if print_timings:
            system["coordinates"].block_until_ready()
            timings["Integrator"] += time.time() - tstep0
            tstep0 = time.time()

            # store timings (only defined when timing is enabled)
            dyn_state["timings"] = timings

        return dyn_state, system, conformation, preproc_state

    return step, update_conformation, dyn_state, thermo_state, vel