# fennol.md.integrate
import time
import math
import os

import numpy as np
import jax
import jax.numpy as jnp

from ..utils.atomic_units import AtomicUnits as au
from .thermostats import get_thermostat
from .barostats import get_barostat
from .colvars import setup_colvars
from .spectra import initialize_ir_spectrum

from .utils import load_dynamics_restart, get_restart_file
from .initial import load_model, load_system_data, initialize_preprocessing


def initialize_dynamics(simulation_parameters, fprec, rng_key):
    """Set up a molecular-dynamics run and return its step function and state.

    Loads the model and the system, optionally restarts from a previous run,
    builds the thermostat (and the barostat when periodic boundary conditions
    are present), the optional ring-polymer (PIMD) propagator, collective
    variables and IR-spectrum accumulators, then computes the initial energy
    and forces.

    Args:
        simulation_parameters: dict-like of simulation options. Keys read here
            include "dt" (multiplied by ``au.FS``, presumably given in fs),
            "pressure_o_weight", "etot_ensemble_key", "colvars",
            "ir_spectrum", "cay_correction" and the "nblist_*" options.
        fprec: floating-point dtype used for velocities and ``dt / mass``.
        rng_key: JAX PRNG key, split to seed the thermostat and the barostat.

    Returns:
        ``(step, update_conformation, system_data, dyn_state, conformation,
        system)`` where ``step(istep, dyn_state, system, conformation,
        force_preprocess=False)`` advances the dynamics by one time step and
        returns ``(dyn_state, system, conformation, out)``.
    """
    ### LOAD MODEL
    model = load_model(simulation_parameters)
    model_energy_unit = au.get_multiplier(model.energy_unit)

    ### Get the coordinates and species from the xyz file
    system_data, conformation = load_system_data(simulation_parameters, fprec)
    system_data["model_energy_unit"] = model_energy_unit
    system_data["model_energy_unit_str"] = model.energy_unit

    ### FINISH BUILDING conformation
    if os.path.exists(get_restart_file(system_data)):
        ### RESTART FROM PREVIOUS DYNAMICS
        restart_data = load_dynamics_restart(system_data)
        print("# RESTARTING FROM PREVIOUS DYNAMICS")
        model.preproc_state = restart_data["preproc_state"]
        conformation["coordinates"] = restart_data["coordinates"]
    else:
        restart_data = {}

    ### INITIALIZE PREPROCESSING
    preproc_state, conformation = initialize_preprocessing(
        simulation_parameters, model, conformation, system_data
    )

    ### get dynamics parameters
    # Every time quantity in this function is scaled by au.FS, so all of them
    # (dt, warmup time, nblist reference time) share the same unit.
    dt = simulation_parameters.get("dt") * au.FS
    dt2 = 0.5 * dt
    mass = system_data["mass"]
    densmass = system_data["totmass_Da"] * (au.MPROT * au.GCM3)
    nat = system_data["nat"]
    dtm = jnp.asarray(dt / mass[:, None], dtype=fprec)

    nreplicas = system_data.get("nreplicas", 1)
    nbeads = system_data.get("nbeads", None)
    if nbeads is not None:
        nreplicas = nbeads
        # PIMD arrays carry a leading normal-mode axis (see the reshape below)
        dtm = dtm[None, :, :]

    ### INITIALIZE DYNAMICS STATE
    system = {"coordinates": conformation["coordinates"]}
    dyn_state = {
        "istep": 0,
        "dt": dt,
        "pimd": nbeads is not None,
        "preproc_state": preproc_state,
        "start_time_ps": restart_data.get("simulation_time_ps", 0.0),
    }
    gradient_keys = ["coordinates"]
    thermo_updates = []

    ### INITIALIZE THERMOSTAT
    thermostat_rng, rng_key = jax.random.split(rng_key)
    (
        thermostat,
        thermostat_post,
        thermostat_state,
        initial_vel,
        dyn_state["thermostat_name"],
    ) = get_thermostat(
        simulation_parameters, dt, system_data, fprec, thermostat_rng, restart_data
    )
    do_thermostat_post = thermostat_post is not None
    if do_thermostat_post:
        thermostat_post, post_state = thermostat_post
        dyn_state["thermostat_post_state"] = post_state

    system["thermostat"] = thermostat_state
    system["vel"] = restart_data.get("vel", initial_vel).astype(fprec)

    ### PBC
    pbc_data = system_data.get("pbc", None)
    if pbc_data is not None:
        ### INITIALIZE BAROSTAT
        barostat_key, rng_key = jax.random.split(rng_key)
        thermo_update_ensemble, variable_cell, barostat_state = get_barostat(
            thermostat,
            simulation_parameters,
            dt,
            system_data,
            fprec,
            barostat_key,
            restart_data,
        )
        estimate_pressure = variable_cell or pbc_data["estimate_pressure"]
        system["barostat"] = barostat_state
        system["cell"] = conformation["cells"][0]
        if estimate_pressure:
            pressure_o_weight = simulation_parameters.get("pressure_o_weight", 0.0)
            assert (
                0.0 <= pressure_o_weight <= 1.0
            ), "pressure_o_weight must be between 0 and 1"
            # the virial below is computed from the energy gradient w.r.t. strain
            gradient_keys.append("strain")
        print("# Estimate pressure: ", estimate_pressure)
    else:
        estimate_pressure = False
        variable_cell = False

        def thermo_update_ensemble(x, v, system):
            v, thermostat_state = thermostat(v, system["thermostat"])
            return x, v, {**system, "thermostat": thermostat_state}

    dyn_state["estimate_pressure"] = estimate_pressure
    dyn_state["variable_cell"] = variable_cell
    thermo_updates.append(thermo_update_ensemble)

    ### ENERGY ENSEMBLE
    ensemble_key = simulation_parameters.get("etot_ensemble_key", None)

    ### COLVARS
    colvars_definitions = simulation_parameters.get("colvars", None)
    use_colvars = colvars_definitions is not None
    if use_colvars:
        colvars_calculators, colvars_names = setup_colvars(colvars_definitions)
        dyn_state["colvars"] = colvars_names

    ### IR SPECTRUM
    do_ir_spectrum = simulation_parameters.get("ir_spectrum", False)
    assert isinstance(do_ir_spectrum, bool), "ir_spectrum must be a boolean"
    if do_ir_spectrum:
        is_qtb = dyn_state["thermostat_name"].endswith("QTB")
        model_ir, ir_state, save_dipole, ir_post = initialize_ir_spectrum(
            simulation_parameters, system_data, fprec, dt, is_qtb
        )
        dyn_state["ir_spectrum"] = ir_state

    ### BUILD GRADIENT FUNCTION
    energy_and_gradient = model.get_gradient_function(
        *gradient_keys, jit=True, variables_as_input=True
    )

    ### COLLECT THERMO UPDATES
    if len(thermo_updates) == 1:
        thermo_update = thermo_updates[0]
    else:

        def thermo_update(x, v, system):
            for update in thermo_updates:
                x, v, system = update(x, v, system)
            return x, v, system

    ### RING POLYMER INITIALIZATION
    if nbeads is not None:
        cay_correction = simulation_parameters.get("cay_correction", True)
        omk = system_data["omk"]
        eigmat = system_data["eigmat"]
        # Half-step propagator coefficients for the non-centroid normal modes
        # (mode 0, the centroid, is drifted separately in integrate_A_half).
        cayfact = 1.0 / (4.0 + (dt * omk[1:, None, None]) ** 2) ** 0.5
        if cay_correction:
            # square root of the Cayley approximation of the full-step propagator
            axx = jnp.asarray(2 * cayfact)
            axv = jnp.asarray(dt * cayfact)
            avx = jnp.asarray(-dt * cayfact * omk[1:, None, None] ** 2)
        else:
            # exact harmonic propagator over dt/2
            axx = jnp.asarray(np.cos(omk[1:, None, None] * dt2))
            axv = jnp.asarray(np.sin(omk[1:, None, None] * dt2) / omk[1:, None, None])
            avx = jnp.asarray(-omk[1:, None, None] * np.sin(omk[1:, None, None] * dt2))

        # Only the centroid mode starts non-zero; internal modes start at zero.
        coordinates = conformation["coordinates"].reshape(nbeads, -1, 3)
        eigx = jnp.zeros_like(coordinates).at[0].set(coordinates[0])
        system["coordinates"] = eigx

    ###############################################
    ### DEFINE UPDATE FUNCTION
    @jax.jit
    def update_conformation(conformation, system):
        """Copy positions (and the cell, if it evolves) from system into conformation."""
        x = system["coordinates"]
        if nbeads is not None:
            # normal-mode coordinates -> bead coordinates, flattened over beads
            x = jnp.einsum("in,n...->i...", eigmat, x).reshape(nbeads * nat, 3) * (
                nbeads**0.5
            )
        conformation = {**conformation, "coordinates": x}
        if variable_cell:
            conformation["cells"] = system["cell"][None, :, :].repeat(nreplicas, axis=0)
            # Keep the reciprocal cell in sync with the evolving cell (same
            # convention as the IR conformation); otherwise it would keep its
            # initial value while the cell changes.
            conformation["reciprocal_cells"] = jnp.linalg.inv(system["cell"])[
                None, :, :
            ].repeat(nreplicas, axis=0)

        return conformation

    ###############################################
    ### DEFINE INTEGRATION FUNCTIONS
    def integrate_A_half(x0, v0):
        """Propagate freely over dt/2 (plain drift, or free ring polymer for PIMD)."""
        if nbeads is None:
            return x0 + dt2 * v0, v0

        # update coordinates and velocities of a free ring polymer for a half time step
        eigx_c = x0[0] + dt2 * v0[0]
        eigv_c = v0[0]
        eigx = x0[1:] * axx + v0[1:] * axv
        eigv = x0[1:] * avx + v0[1:] * axx

        return (
            jnp.concatenate((eigx_c[None], eigx), axis=0),
            jnp.concatenate((eigv_c[None], eigv), axis=0),
        )

    @jax.jit
    def integrate(system):
        """Force kick, half free step, thermostat/barostat update, half free step.

        The kick uses system["forces"], i.e. the forces stored by
        update_observables at the end of the previous step.
        """
        x = system["coordinates"]
        v = system["vel"] + dtm * system["forces"]
        x, v = integrate_A_half(x, v)
        x, v, system = thermo_update(x, v, system)
        x, v = integrate_A_half(x, v)

        return {**system, "coordinates": x, "vel": v}

    ###############################################
    ### DEFINE OBSERVABLE FUNCTION
    def kinetic_energy_tensor(v, system, forces):
        """Return ``(ek, ek_c)``: the kinetic-energy tensor and, for PIMD, its centroid part.

        For PIMD, ``ek`` is the centroid-virial estimator built from the
        non-centroid normal-mode coordinates and forces; ``ek_c`` is None
        otherwise.
        """
        if nbeads is None:
            corr_kin = system["thermostat"].get("corr_kin", 1.0)
            ek = (0.5 / nreplicas / corr_kin) * jnp.sum(
                mass[:, None, None] * v[:, :, None] * v[:, None, :], axis=0
            )
            return ek, None
        ek_c = 0.5 * jnp.sum(
            mass[:, None, None] * v[0, :, :, None] * v[0, :, None, :], axis=0
        )
        ek = ek_c - 0.5 * jnp.sum(
            system["coordinates"][1:, :, :, None] * forces[1:, :, None, :],
            axis=(0, 1),
        )
        return ek, ek_c

    @jax.jit
    def update_observables(system, conformation):
        """Evaluate the model and refresh energies, forces and derived observables.

        Returns the updated system dict and the raw model output ``out``.
        """
        ### POTENTIAL ENERGY AND FORCES
        epot, de, out = energy_and_gradient(model.variables, conformation)
        out["forces"] = -de["coordinates"]
        # divide by the multiplier of the model's energy unit
        epot = epot / model_energy_unit
        de = {k: v / model_energy_unit for k, v in de.items()}
        forces = -de["coordinates"]

        if nbeads is not None:
            ### PROJECT FORCES ONTO POLYMER NORMAL MODES
            forces = jnp.einsum(
                "in,i...->n...", eigmat, forces.reshape(nbeads, nat, 3)
            ) * (1.0 / nbeads**0.5)

        system = {
            **system,
            "epot": jnp.mean(epot),
            "forces": forces,
            "energy_gradients": de,
        }

        ### KINETIC ENERGY
        ek, ek_c = kinetic_energy_tensor(system["vel"], system, forces)
        if ek_c is not None:
            system["ek_c"] = jnp.trace(ek_c)
        system["ek"] = jnp.trace(ek)
        system["ek_tensor"] = ek

        if estimate_pressure:
            if pressure_o_weight != 1.0:
                # kinetic term from half-kicked velocities, mixed with the
                # on-step kinetic tensor using weight pressure_o_weight
                v = system["vel"] + 0.5 * dtm * system["forces"]
                ek, _ = kinetic_energy_tensor(v, system, forces)
                b = pressure_o_weight
                ek = (1.0 - b) * ek + b * system["ek_tensor"]

            vir = jnp.mean(de["strain"], axis=0)
            system["virial"] = vir
            out["virial_tensor"] = vir * model_energy_unit

            pV = 2 * ek - vir
            system["PV_tensor"] = pV
            volume = jnp.abs(jnp.linalg.det(system["cell"]))
            Pres = pV / volume
            system["pressure_tensor"] = Pres
            system["pressure"] = jnp.trace(Pres) * (1.0 / 3.0)
            if variable_cell:
                density = densmass / volume
                system["density"] = density
                system["volume"] = volume

        if ensemble_key is not None:
            kT = system_data["kT"]
            dE = (
                jnp.mean(out[ensemble_key], axis=0) / model_energy_unit - system["epot"]
            )
            system["ensemble_weights"] = -dE / kT

        if "total_dipole" in out:
            if nbeads is None:
                system["total_dipole"] = out["total_dipole"][0]
            else:
                system["total_dipole"] = jnp.mean(out["total_dipole"], axis=0)

        if use_colvars:
            coords = system["coordinates"].reshape(-1, nat, 3)[0]
            colvars = {}
            for colvar_name, colvar_calc in colvars_calculators.items():
                colvars[colvar_name] = colvar_calc(coords)
            system["colvars"] = colvars

        return system, out

    ###############################################
    ### IR SPECTRUM
    if do_ir_spectrum:

        @jax.jit
        def update_conformation_ir(conformation, system):
            """Build a single-replica conformation (first replica / centroid) for the IR model."""
            conformation = {
                **conformation,
                "coordinates": system["coordinates"].reshape(-1, nat, 3)[0],
                "natoms": jnp.asarray([nat]),
                "batch_index": jnp.asarray([0] * nat),
                "species": jnp.asarray(system_data["species"].reshape(-1, nat)[0]),
            }
            if variable_cell:
                conformation["cells"] = system["cell"][None, :, :]
                conformation["reciprocal_cells"] = jnp.linalg.inv(system["cell"])[
                    None, :, :
                ]
            return conformation

        @jax.jit
        def update_dipole(ir_state, system, conformation):
            """Accumulate charges, velocities, positions and dipoles into ir_state."""
            if model_ir is not None:
                out = model_ir._apply(model_ir.variables, conformation)
                q = out.get("charges", jnp.zeros(nat)).reshape((-1, nat))
                dip = out.get("dipoles", jnp.zeros((nat, 3))).reshape((-1, nat, 3))
            else:
                # NOTE(review): update_observables does not store "charges" or
                # "dipoles" in system, so these fall back to zeros unless they
                # are set elsewhere — confirm this is intended.
                q = system.get("charges", jnp.zeros(nat)).reshape((-1, nat))
                dip = system.get("dipoles", jnp.zeros((nat, 3))).reshape((-1, nat, 3))
            if nbeads is not None:
                q = jnp.mean(q, axis=0)
                dip = jnp.mean(dip, axis=0)
                vel = system["vel"][0]
                pos = system["coordinates"][0]
            else:
                q = q[0]
                dip = dip[0]
                vel = system["vel"].reshape(-1, nat, 3)[0]
                pos = system["coordinates"].reshape(-1, nat, 3)[0]

            if pbc_data is not None:
                cell_reciprocal = (
                    conformation["cells"][0],
                    conformation["reciprocal_cells"][0],
                )
            else:
                cell_reciprocal = None

            ir_state = save_dipole(
                q, vel, pos, dip.sum(axis=0), cell_reciprocal, ir_state
            )
            return ir_state

    ###############################################
    ### GRAPH UPDATES

    nblist_verbose = simulation_parameters.get("nblist_verbose", False)
    nblist_stride = int(simulation_parameters.get("nblist_stride", -1))
    nblist_warmup_time = simulation_parameters.get("nblist_warmup_time", -1.0) * au.FS
    nblist_warmup = int(nblist_warmup_time / dt) if nblist_warmup_time > 0 else 0
    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
    if nblist_skin > 0:
        if nblist_stride <= 0:
            ## reference skin parameters at 300K (from Tinker-HP)
            ## => skin of 2 A gives you 40 fs without complete rebuild
            # Scale by au.FS like dt so that t_ref / dt is a number of steps;
            # a bare 40.0 would mix fs with the unit of dt.
            t_ref = 40.0 * au.FS
            nblist_skin_ref = 2.0  # A
            # at least one step between full rebuilds
            nblist_stride = max(
                1, int(math.floor(nblist_skin / nblist_skin_ref * t_ref / dt))
            )
        print(
            f"# nblist_skin: {nblist_skin:.2f} A, nblist_stride: {nblist_stride} steps, nblist_warmup: {nblist_warmup} steps"
        )

    if nblist_skin <= 0:
        nblist_stride = 1

    dyn_state["nblist_countdown"] = 0
    dyn_state["print_skin_activation"] = nblist_warmup > 0

    def update_graphs(istep, dyn_state, system, conformation, force_preprocess=False):
        """Refresh the conformation and neighbor lists after the positions moved.

        A full rebuild (preprocessing.process + check_reallocate) runs when the
        countdown has expired, when forced, or during the warmup phase;
        otherwise preprocessing.update_skin is applied.
        """
        nblist_countdown = dyn_state["nblist_countdown"]
        if nblist_countdown <= 0 or force_preprocess or (istep < nblist_warmup):
            ### FULL NBLIST REBUILD
            dyn_state["nblist_countdown"] = nblist_stride - 1
            preproc_state = dyn_state["preproc_state"]
            conformation = model.preprocessing.process(
                preproc_state, update_conformation(conformation, system)
            )
            preproc_state, state_up, conformation, overflow = (
                model.preprocessing.check_reallocate(preproc_state, conformation)
            )
            dyn_state["preproc_state"] = preproc_state
            if nblist_verbose and overflow:
                print("step", istep, ", nblist overflow => reallocating nblist")
                print("size updates:", state_up)

            if do_ir_spectrum and model_ir is not None:
                conformation_ir = model_ir.preprocessing.process(
                    dyn_state["preproc_state_ir"],
                    update_conformation_ir(dyn_state["conformation_ir"], system),
                )
                (
                    dyn_state["preproc_state_ir"],
                    _,
                    dyn_state["conformation_ir"],
                    overflow,
                ) = model_ir.preprocessing.check_reallocate(
                    dyn_state["preproc_state_ir"], conformation_ir
                )

        else:
            ### SKIN UPDATE
            if dyn_state["print_skin_activation"]:
                if nblist_verbose:
                    print(
                        "step",
                        istep,
                        ", end of nblist warmup phase => activating skin updates",
                    )
                dyn_state["print_skin_activation"] = False

            dyn_state["nblist_countdown"] = nblist_countdown - 1
            conformation = model.preprocessing.update_skin(
                update_conformation(conformation, system)
            )
            if do_ir_spectrum and model_ir is not None:
                dyn_state["conformation_ir"] = model_ir.preprocessing.update_skin(
                    update_conformation_ir(dyn_state["conformation_ir"], system)
                )

        return conformation, dyn_state

    ################################################
    ### DEFINE STEP FUNCTION
    def step(istep, dyn_state, system, conformation, force_preprocess=False):
        """Advance the dynamics by one time step.

        Returns the updated ``(dyn_state, system, conformation, out)``.
        """
        dyn_state = {
            **dyn_state,
            "istep": dyn_state["istep"] + 1,
        }

        ### INTEGRATE EQUATIONS OF MOTION
        system = integrate(system)

        ### UPDATE CONFORMATION AND GRAPHS
        conformation, dyn_state = update_graphs(
            istep, dyn_state, system, conformation, force_preprocess
        )

        ## COMPUTE FORCES AND OBSERVABLES
        system, out = update_observables(system, conformation)

        ## END OF STEP UPDATES
        if do_thermostat_post:
            system["thermostat"], dyn_state["thermostat_post_state"] = thermostat_post(
                system["thermostat"], dyn_state["thermostat_post_state"]
            )

        if do_ir_spectrum:
            # "conformation_ir" only exists when a dedicated IR model is used;
            # without one, fall back to the main conformation instead of
            # raising a KeyError (assumed to carry "cells"/"reciprocal_cells"
            # under PBC — TODO confirm for fixed cells).
            if model_ir is not None:
                conformation_ir = dyn_state["conformation_ir"]
            else:
                conformation_ir = conformation
            ir_state = update_dipole(dyn_state["ir_spectrum"], system, conformation_ir)
            dyn_state["ir_spectrum"] = ir_post(ir_state)

        return dyn_state, system, conformation, out

    ###########################################################

    print("# Computing initial energy and forces")

    conformation = update_conformation(conformation, system)
    # initialize IR conformation
    if do_ir_spectrum and model_ir is not None:
        dyn_state["preproc_state_ir"], dyn_state["conformation_ir"] = (
            model_ir.preprocessing(
                model_ir.preproc_state,
                update_conformation_ir(conformation, system),
            )
        )

    system, _ = update_observables(system, conformation)

    return step, update_conformation, system_data, dyn_state, conformation, system
def initialize_dynamics(simulation_parameters, fprec, rng_key):
    """Set up a molecular-dynamics run and return its step function and state.

    Loads the model and the system, optionally restarts from a previous run,
    builds the thermostat (and the barostat when periodic boundary conditions
    are present), the optional ring-polymer (PIMD) propagator, collective
    variables and IR-spectrum accumulators, then computes the initial energy
    and forces.

    Args:
        simulation_parameters: dict-like of simulation options. Keys read here
            include "dt" (multiplied by ``au.FS``, presumably given in fs),
            "pressure_o_weight", "etot_ensemble_key", "colvars",
            "ir_spectrum", "cay_correction" and the "nblist_*" options.
        fprec: floating-point dtype used for velocities and ``dt / mass``.
        rng_key: JAX PRNG key, split to seed the thermostat and the barostat.

    Returns:
        ``(step, update_conformation, system_data, dyn_state, conformation,
        system)`` where ``step(istep, dyn_state, system, conformation,
        force_preprocess=False)`` advances the dynamics by one time step and
        returns ``(dyn_state, system, conformation, out)``.
    """
    ### LOAD MODEL
    model = load_model(simulation_parameters)
    model_energy_unit = au.get_multiplier(model.energy_unit)

    ### Get the coordinates and species from the xyz file
    system_data, conformation = load_system_data(simulation_parameters, fprec)
    system_data["model_energy_unit"] = model_energy_unit
    system_data["model_energy_unit_str"] = model.energy_unit

    ### FINISH BUILDING conformation
    if os.path.exists(get_restart_file(system_data)):
        ### RESTART FROM PREVIOUS DYNAMICS
        restart_data = load_dynamics_restart(system_data)
        print("# RESTARTING FROM PREVIOUS DYNAMICS")
        model.preproc_state = restart_data["preproc_state"]
        conformation["coordinates"] = restart_data["coordinates"]
    else:
        restart_data = {}

    ### INITIALIZE PREPROCESSING
    preproc_state, conformation = initialize_preprocessing(
        simulation_parameters, model, conformation, system_data
    )

    ### get dynamics parameters
    # Every time quantity in this function is scaled by au.FS, so all of them
    # (dt, warmup time, nblist reference time) share the same unit.
    dt = simulation_parameters.get("dt") * au.FS
    dt2 = 0.5 * dt
    mass = system_data["mass"]
    densmass = system_data["totmass_Da"] * (au.MPROT * au.GCM3)
    nat = system_data["nat"]
    dtm = jnp.asarray(dt / mass[:, None], dtype=fprec)

    nreplicas = system_data.get("nreplicas", 1)
    nbeads = system_data.get("nbeads", None)
    if nbeads is not None:
        nreplicas = nbeads
        # PIMD arrays carry a leading normal-mode axis (see the reshape below)
        dtm = dtm[None, :, :]

    ### INITIALIZE DYNAMICS STATE
    system = {"coordinates": conformation["coordinates"]}
    dyn_state = {
        "istep": 0,
        "dt": dt,
        "pimd": nbeads is not None,
        "preproc_state": preproc_state,
        "start_time_ps": restart_data.get("simulation_time_ps", 0.0),
    }
    gradient_keys = ["coordinates"]
    thermo_updates = []

    ### INITIALIZE THERMOSTAT
    thermostat_rng, rng_key = jax.random.split(rng_key)
    (
        thermostat,
        thermostat_post,
        thermostat_state,
        initial_vel,
        dyn_state["thermostat_name"],
    ) = get_thermostat(
        simulation_parameters, dt, system_data, fprec, thermostat_rng, restart_data
    )
    do_thermostat_post = thermostat_post is not None
    if do_thermostat_post:
        thermostat_post, post_state = thermostat_post
        dyn_state["thermostat_post_state"] = post_state

    system["thermostat"] = thermostat_state
    system["vel"] = restart_data.get("vel", initial_vel).astype(fprec)

    ### PBC
    pbc_data = system_data.get("pbc", None)
    if pbc_data is not None:
        ### INITIALIZE BAROSTAT
        barostat_key, rng_key = jax.random.split(rng_key)
        thermo_update_ensemble, variable_cell, barostat_state = get_barostat(
            thermostat,
            simulation_parameters,
            dt,
            system_data,
            fprec,
            barostat_key,
            restart_data,
        )
        estimate_pressure = variable_cell or pbc_data["estimate_pressure"]
        system["barostat"] = barostat_state
        system["cell"] = conformation["cells"][0]
        if estimate_pressure:
            pressure_o_weight = simulation_parameters.get("pressure_o_weight", 0.0)
            assert (
                0.0 <= pressure_o_weight <= 1.0
            ), "pressure_o_weight must be between 0 and 1"
            # the virial below is computed from the energy gradient w.r.t. strain
            gradient_keys.append("strain")
        print("# Estimate pressure: ", estimate_pressure)
    else:
        estimate_pressure = False
        variable_cell = False

        def thermo_update_ensemble(x, v, system):
            v, thermostat_state = thermostat(v, system["thermostat"])
            return x, v, {**system, "thermostat": thermostat_state}

    dyn_state["estimate_pressure"] = estimate_pressure
    dyn_state["variable_cell"] = variable_cell
    thermo_updates.append(thermo_update_ensemble)

    ### ENERGY ENSEMBLE
    ensemble_key = simulation_parameters.get("etot_ensemble_key", None)

    ### COLVARS
    colvars_definitions = simulation_parameters.get("colvars", None)
    use_colvars = colvars_definitions is not None
    if use_colvars:
        colvars_calculators, colvars_names = setup_colvars(colvars_definitions)
        dyn_state["colvars"] = colvars_names

    ### IR SPECTRUM
    do_ir_spectrum = simulation_parameters.get("ir_spectrum", False)
    assert isinstance(do_ir_spectrum, bool), "ir_spectrum must be a boolean"
    if do_ir_spectrum:
        is_qtb = dyn_state["thermostat_name"].endswith("QTB")
        model_ir, ir_state, save_dipole, ir_post = initialize_ir_spectrum(
            simulation_parameters, system_data, fprec, dt, is_qtb
        )
        dyn_state["ir_spectrum"] = ir_state

    ### BUILD GRADIENT FUNCTION
    energy_and_gradient = model.get_gradient_function(
        *gradient_keys, jit=True, variables_as_input=True
    )

    ### COLLECT THERMO UPDATES
    if len(thermo_updates) == 1:
        thermo_update = thermo_updates[0]
    else:

        def thermo_update(x, v, system):
            for update in thermo_updates:
                x, v, system = update(x, v, system)
            return x, v, system

    ### RING POLYMER INITIALIZATION
    if nbeads is not None:
        cay_correction = simulation_parameters.get("cay_correction", True)
        omk = system_data["omk"]
        eigmat = system_data["eigmat"]
        # Half-step propagator coefficients for the non-centroid normal modes
        # (mode 0, the centroid, is drifted separately in integrate_A_half).
        cayfact = 1.0 / (4.0 + (dt * omk[1:, None, None]) ** 2) ** 0.5
        if cay_correction:
            # square root of the Cayley approximation of the full-step propagator
            axx = jnp.asarray(2 * cayfact)
            axv = jnp.asarray(dt * cayfact)
            avx = jnp.asarray(-dt * cayfact * omk[1:, None, None] ** 2)
        else:
            # exact harmonic propagator over dt/2
            axx = jnp.asarray(np.cos(omk[1:, None, None] * dt2))
            axv = jnp.asarray(np.sin(omk[1:, None, None] * dt2) / omk[1:, None, None])
            avx = jnp.asarray(-omk[1:, None, None] * np.sin(omk[1:, None, None] * dt2))

        # Only the centroid mode starts non-zero; internal modes start at zero.
        coordinates = conformation["coordinates"].reshape(nbeads, -1, 3)
        eigx = jnp.zeros_like(coordinates).at[0].set(coordinates[0])
        system["coordinates"] = eigx

    ###############################################
    ### DEFINE UPDATE FUNCTION
    @jax.jit
    def update_conformation(conformation, system):
        """Copy positions (and the cell, if it evolves) from system into conformation."""
        x = system["coordinates"]
        if nbeads is not None:
            # normal-mode coordinates -> bead coordinates, flattened over beads
            x = jnp.einsum("in,n...->i...", eigmat, x).reshape(nbeads * nat, 3) * (
                nbeads**0.5
            )
        conformation = {**conformation, "coordinates": x}
        if variable_cell:
            conformation["cells"] = system["cell"][None, :, :].repeat(nreplicas, axis=0)
            # Keep the reciprocal cell in sync with the evolving cell (same
            # convention as the IR conformation); otherwise it would keep its
            # initial value while the cell changes.
            conformation["reciprocal_cells"] = jnp.linalg.inv(system["cell"])[
                None, :, :
            ].repeat(nreplicas, axis=0)

        return conformation

    ###############################################
    ### DEFINE INTEGRATION FUNCTIONS
    def integrate_A_half(x0, v0):
        """Propagate freely over dt/2 (plain drift, or free ring polymer for PIMD)."""
        if nbeads is None:
            return x0 + dt2 * v0, v0

        # update coordinates and velocities of a free ring polymer for a half time step
        eigx_c = x0[0] + dt2 * v0[0]
        eigv_c = v0[0]
        eigx = x0[1:] * axx + v0[1:] * axv
        eigv = x0[1:] * avx + v0[1:] * axx

        return (
            jnp.concatenate((eigx_c[None], eigx), axis=0),
            jnp.concatenate((eigv_c[None], eigv), axis=0),
        )

    @jax.jit
    def integrate(system):
        """Force kick, half free step, thermostat/barostat update, half free step.

        The kick uses system["forces"], i.e. the forces stored by
        update_observables at the end of the previous step.
        """
        x = system["coordinates"]
        v = system["vel"] + dtm * system["forces"]
        x, v = integrate_A_half(x, v)
        x, v, system = thermo_update(x, v, system)
        x, v = integrate_A_half(x, v)

        return {**system, "coordinates": x, "vel": v}

    ###############################################
    ### DEFINE OBSERVABLE FUNCTION
    def kinetic_energy_tensor(v, system, forces):
        """Return ``(ek, ek_c)``: the kinetic-energy tensor and, for PIMD, its centroid part.

        For PIMD, ``ek`` is the centroid-virial estimator built from the
        non-centroid normal-mode coordinates and forces; ``ek_c`` is None
        otherwise.
        """
        if nbeads is None:
            corr_kin = system["thermostat"].get("corr_kin", 1.0)
            ek = (0.5 / nreplicas / corr_kin) * jnp.sum(
                mass[:, None, None] * v[:, :, None] * v[:, None, :], axis=0
            )
            return ek, None
        ek_c = 0.5 * jnp.sum(
            mass[:, None, None] * v[0, :, :, None] * v[0, :, None, :], axis=0
        )
        ek = ek_c - 0.5 * jnp.sum(
            system["coordinates"][1:, :, :, None] * forces[1:, :, None, :],
            axis=(0, 1),
        )
        return ek, ek_c

    @jax.jit
    def update_observables(system, conformation):
        """Evaluate the model and refresh energies, forces and derived observables.

        Returns the updated system dict and the raw model output ``out``.
        """
        ### POTENTIAL ENERGY AND FORCES
        epot, de, out = energy_and_gradient(model.variables, conformation)
        out["forces"] = -de["coordinates"]
        # divide by the multiplier of the model's energy unit
        epot = epot / model_energy_unit
        de = {k: v / model_energy_unit for k, v in de.items()}
        forces = -de["coordinates"]

        if nbeads is not None:
            ### PROJECT FORCES ONTO POLYMER NORMAL MODES
            forces = jnp.einsum(
                "in,i...->n...", eigmat, forces.reshape(nbeads, nat, 3)
            ) * (1.0 / nbeads**0.5)

        system = {
            **system,
            "epot": jnp.mean(epot),
            "forces": forces,
            "energy_gradients": de,
        }

        ### KINETIC ENERGY
        ek, ek_c = kinetic_energy_tensor(system["vel"], system, forces)
        if ek_c is not None:
            system["ek_c"] = jnp.trace(ek_c)
        system["ek"] = jnp.trace(ek)
        system["ek_tensor"] = ek

        if estimate_pressure:
            if pressure_o_weight != 1.0:
                # kinetic term from half-kicked velocities, mixed with the
                # on-step kinetic tensor using weight pressure_o_weight
                v = system["vel"] + 0.5 * dtm * system["forces"]
                ek, _ = kinetic_energy_tensor(v, system, forces)
                b = pressure_o_weight
                ek = (1.0 - b) * ek + b * system["ek_tensor"]

            vir = jnp.mean(de["strain"], axis=0)
            system["virial"] = vir
            out["virial_tensor"] = vir * model_energy_unit

            pV = 2 * ek - vir
            system["PV_tensor"] = pV
            volume = jnp.abs(jnp.linalg.det(system["cell"]))
            Pres = pV / volume
            system["pressure_tensor"] = Pres
            system["pressure"] = jnp.trace(Pres) * (1.0 / 3.0)
            if variable_cell:
                density = densmass / volume
                system["density"] = density
                system["volume"] = volume

        if ensemble_key is not None:
            kT = system_data["kT"]
            dE = (
                jnp.mean(out[ensemble_key], axis=0) / model_energy_unit - system["epot"]
            )
            system["ensemble_weights"] = -dE / kT

        if "total_dipole" in out:
            if nbeads is None:
                system["total_dipole"] = out["total_dipole"][0]
            else:
                system["total_dipole"] = jnp.mean(out["total_dipole"], axis=0)

        if use_colvars:
            coords = system["coordinates"].reshape(-1, nat, 3)[0]
            colvars = {}
            for colvar_name, colvar_calc in colvars_calculators.items():
                colvars[colvar_name] = colvar_calc(coords)
            system["colvars"] = colvars

        return system, out

    ###############################################
    ### IR SPECTRUM
    if do_ir_spectrum:

        @jax.jit
        def update_conformation_ir(conformation, system):
            """Build a single-replica conformation (first replica / centroid) for the IR model."""
            conformation = {
                **conformation,
                "coordinates": system["coordinates"].reshape(-1, nat, 3)[0],
                "natoms": jnp.asarray([nat]),
                "batch_index": jnp.asarray([0] * nat),
                "species": jnp.asarray(system_data["species"].reshape(-1, nat)[0]),
            }
            if variable_cell:
                conformation["cells"] = system["cell"][None, :, :]
                conformation["reciprocal_cells"] = jnp.linalg.inv(system["cell"])[
                    None, :, :
                ]
            return conformation

        @jax.jit
        def update_dipole(ir_state, system, conformation):
            """Accumulate charges, velocities, positions and dipoles into ir_state."""
            if model_ir is not None:
                out = model_ir._apply(model_ir.variables, conformation)
                q = out.get("charges", jnp.zeros(nat)).reshape((-1, nat))
                dip = out.get("dipoles", jnp.zeros((nat, 3))).reshape((-1, nat, 3))
            else:
                # NOTE(review): update_observables does not store "charges" or
                # "dipoles" in system, so these fall back to zeros unless they
                # are set elsewhere — confirm this is intended.
                q = system.get("charges", jnp.zeros(nat)).reshape((-1, nat))
                dip = system.get("dipoles", jnp.zeros((nat, 3))).reshape((-1, nat, 3))
            if nbeads is not None:
                q = jnp.mean(q, axis=0)
                dip = jnp.mean(dip, axis=0)
                vel = system["vel"][0]
                pos = system["coordinates"][0]
            else:
                q = q[0]
                dip = dip[0]
                vel = system["vel"].reshape(-1, nat, 3)[0]
                pos = system["coordinates"].reshape(-1, nat, 3)[0]

            if pbc_data is not None:
                cell_reciprocal = (
                    conformation["cells"][0],
                    conformation["reciprocal_cells"][0],
                )
            else:
                cell_reciprocal = None

            ir_state = save_dipole(
                q, vel, pos, dip.sum(axis=0), cell_reciprocal, ir_state
            )
            return ir_state

    ###############################################
    ### GRAPH UPDATES

    nblist_verbose = simulation_parameters.get("nblist_verbose", False)
    nblist_stride = int(simulation_parameters.get("nblist_stride", -1))
    nblist_warmup_time = simulation_parameters.get("nblist_warmup_time", -1.0) * au.FS
    nblist_warmup = int(nblist_warmup_time / dt) if nblist_warmup_time > 0 else 0
    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
    if nblist_skin > 0:
        if nblist_stride <= 0:
            ## reference skin parameters at 300K (from Tinker-HP)
            ## => skin of 2 A gives you 40 fs without complete rebuild
            # Scale by au.FS like dt so that t_ref / dt is a number of steps;
            # a bare 40.0 would mix fs with the unit of dt.
            t_ref = 40.0 * au.FS
            nblist_skin_ref = 2.0  # A
            # at least one step between full rebuilds
            nblist_stride = max(
                1, int(math.floor(nblist_skin / nblist_skin_ref * t_ref / dt))
            )
        print(
            f"# nblist_skin: {nblist_skin:.2f} A, nblist_stride: {nblist_stride} steps, nblist_warmup: {nblist_warmup} steps"
        )

    if nblist_skin <= 0:
        nblist_stride = 1

    dyn_state["nblist_countdown"] = 0
    dyn_state["print_skin_activation"] = nblist_warmup > 0

    def update_graphs(istep, dyn_state, system, conformation, force_preprocess=False):
        """Refresh the conformation and neighbor lists after the positions moved.

        A full rebuild (preprocessing.process + check_reallocate) runs when the
        countdown has expired, when forced, or during the warmup phase;
        otherwise preprocessing.update_skin is applied.
        """
        nblist_countdown = dyn_state["nblist_countdown"]
        if nblist_countdown <= 0 or force_preprocess or (istep < nblist_warmup):
            ### FULL NBLIST REBUILD
            dyn_state["nblist_countdown"] = nblist_stride - 1
            preproc_state = dyn_state["preproc_state"]
            conformation = model.preprocessing.process(
                preproc_state, update_conformation(conformation, system)
            )
            preproc_state, state_up, conformation, overflow = (
                model.preprocessing.check_reallocate(preproc_state, conformation)
            )
            dyn_state["preproc_state"] = preproc_state
            if nblist_verbose and overflow:
                print("step", istep, ", nblist overflow => reallocating nblist")
                print("size updates:", state_up)

            if do_ir_spectrum and model_ir is not None:
                conformation_ir = model_ir.preprocessing.process(
                    dyn_state["preproc_state_ir"],
                    update_conformation_ir(dyn_state["conformation_ir"], system),
                )
                (
                    dyn_state["preproc_state_ir"],
                    _,
                    dyn_state["conformation_ir"],
                    overflow,
                ) = model_ir.preprocessing.check_reallocate(
                    dyn_state["preproc_state_ir"], conformation_ir
                )

        else:
            ### SKIN UPDATE
            if dyn_state["print_skin_activation"]:
                if nblist_verbose:
                    print(
                        "step",
                        istep,
                        ", end of nblist warmup phase => activating skin updates",
                    )
                dyn_state["print_skin_activation"] = False

            dyn_state["nblist_countdown"] = nblist_countdown - 1
            conformation = model.preprocessing.update_skin(
                update_conformation(conformation, system)
            )
            if do_ir_spectrum and model_ir is not None:
                dyn_state["conformation_ir"] = model_ir.preprocessing.update_skin(
                    update_conformation_ir(dyn_state["conformation_ir"], system)
                )

        return conformation, dyn_state

    ################################################
    ### DEFINE STEP FUNCTION
    def step(istep, dyn_state, system, conformation, force_preprocess=False):
        """Advance the dynamics by one time step.

        Returns the updated ``(dyn_state, system, conformation, out)``.
        """
        dyn_state = {
            **dyn_state,
            "istep": dyn_state["istep"] + 1,
        }

        ### INTEGRATE EQUATIONS OF MOTION
        system = integrate(system)

        ### UPDATE CONFORMATION AND GRAPHS
        conformation, dyn_state = update_graphs(
            istep, dyn_state, system, conformation, force_preprocess
        )

        ## COMPUTE FORCES AND OBSERVABLES
        system, out = update_observables(system, conformation)

        ## END OF STEP UPDATES
        if do_thermostat_post:
            system["thermostat"], dyn_state["thermostat_post_state"] = thermostat_post(
                system["thermostat"], dyn_state["thermostat_post_state"]
            )

        if do_ir_spectrum:
            # "conformation_ir" only exists when a dedicated IR model is used;
            # without one, fall back to the main conformation instead of
            # raising a KeyError (assumed to carry "cells"/"reciprocal_cells"
            # under PBC — TODO confirm for fixed cells).
            if model_ir is not None:
                conformation_ir = dyn_state["conformation_ir"]
            else:
                conformation_ir = conformation
            ir_state = update_dipole(dyn_state["ir_spectrum"], system, conformation_ir)
            dyn_state["ir_spectrum"] = ir_post(ir_state)

        return dyn_state, system, conformation, out

    ###########################################################

    print("# Computing initial energy and forces")

    conformation = update_conformation(conformation, system)
    # initialize IR conformation
    if do_ir_spectrum and model_ir is not None:
        dyn_state["preproc_state_ir"], dyn_state["conformation_ir"] = (
            model_ir.preprocessing(
                model_ir.preproc_state,
                update_conformation_ir(conformation, system),
            )
        )

    system, _ = update_observables(system, conformation)

    return step, update_conformation, system_data, dyn_state, conformation, system