fennol.md.integrate
import time
import math
import os

import numpy as np
import jax
import jax.numpy as jnp

from ..utils.atomic_units import AtomicUnits as au
from .thermostats import get_thermostat
from .barostats import get_barostat
from .colvars import setup_colvars
from .spectra import initialize_ir_spectrum

from .utils import load_dynamics_restart, get_restart_file
from .initial import load_model, load_system_data, initialize_preprocessing


def initialize_dynamics(simulation_parameters, fprec, rng_key):
    """Set up a molecular-dynamics run and return its single-step propagator.

    The function loads the model and the system, resumes from a restart file
    when one exists, builds the thermostat (and the barostat when periodic
    boundary conditions are present), and evaluates the initial energy and
    forces.

    Args:
        simulation_parameters: mapping of user options, read with ``.get``
            (``"dt"``, ``"colvars"``, ``"ir_spectrum"``, ``"nblist_*"``, ...).
        fprec: floating-point dtype used for velocities and ``dt / mass``.
        rng_key: JAX PRNG key; it is split to seed the thermostat and the
            barostat.

    Returns:
        A tuple ``(step, update_conformation, system_data, dyn_state,
        conformation, system)`` where ``step(istep, dyn_state, system,
        conformation, force_preprocess=False)`` advances the dynamics by one
        time step and returns ``(dyn_state, system, conformation, out)``.

    Raises:
        AssertionError: if ``pressure_o_weight`` is outside ``[0, 1]`` or
            ``ir_spectrum`` is not a boolean.
    """
    ### LOAD MODEL
    model = load_model(simulation_parameters)
    model_energy_unit = au.get_multiplier(model.energy_unit)

    ### Get the coordinates and species from the xyz file
    system_data, conformation = load_system_data(simulation_parameters, fprec)

    ### FINISH BUILDING conformation
    if os.path.exists(get_restart_file(system_data)):
        ### RESTART FROM PREVIOUS DYNAMICS
        restart_data = load_dynamics_restart(system_data)
        print("# RESTARTING FROM PREVIOUS DYNAMICS")
        model.preproc_state = restart_data["preproc_state"]
        conformation["coordinates"] = restart_data["coordinates"]
    else:
        restart_data = {}

    ### INITIALIZE PREPROCESSING
    preproc_state, conformation = initialize_preprocessing(
        simulation_parameters, model, conformation, system_data
    )

    ### get dynamics parameters
    dt = simulation_parameters.get("dt") * au.FS
    dt2 = 0.5 * dt
    mass = system_data["mass"]
    totmass_amu = system_data["totmass_amu"]
    nat = system_data["nat"]
    dtm = jnp.asarray(dt / mass[:, None], dtype=fprec)

    nreplicas = system_data.get("nreplicas", 1)
    nbeads = system_data.get("nbeads", None)
    if nbeads is not None:
        # path-integral run: one replica per bead, extra leading bead axis
        nreplicas = nbeads
        dtm = dtm[None, :, :]

    ### INITIALIZE DYNAMICS STATE
    system = {"coordinates": conformation["coordinates"]}
    dyn_state = {
        "istep": 0,
        "dt": dt,
        "pimd": nbeads is not None,
        "preproc_state": preproc_state,
        "start_time_ps": restart_data.get("simulation_time_ps", 0.0),
    }
    gradient_keys = ["coordinates"]
    thermo_updates = []

    ### INITIALIZE THERMOSTAT
    thermostat_rng, rng_key = jax.random.split(rng_key)
    (
        thermostat,
        thermostat_post,
        thermostat_state,
        initial_vel,
        dyn_state["thermostat_name"],
    ) = get_thermostat(
        simulation_parameters, dt, system_data, fprec, thermostat_rng, restart_data
    )
    do_thermostat_post = thermostat_post is not None
    if do_thermostat_post:
        thermostat_post, post_state = thermostat_post
        dyn_state["thermostat_post_state"] = post_state

    system["thermostat"] = thermostat_state
    system["vel"] = restart_data.get("vel", initial_vel).astype(fprec)

    ### PBC
    pbc_data = system_data.get("pbc", None)
    if pbc_data is not None:
        ### INITIALIZE BAROSTAT
        barostat_key, rng_key = jax.random.split(rng_key)
        thermo_update_ensemble, variable_cell, barostat_state = get_barostat(
            thermostat,
            simulation_parameters,
            dt,
            system_data,
            fprec,
            barostat_key,
            restart_data,
        )
        estimate_pressure = variable_cell or pbc_data["estimate_pressure"]
        system["barostat"] = barostat_state
        system["cell"] = conformation["cells"][0]
        if estimate_pressure:
            pressure_o_weight = simulation_parameters.get("pressure_o_weight", 0.0)
            assert (
                0.0 <= pressure_o_weight <= 1.0
            ), "pressure_o_weight must be between 0 and 1"
            # the strain gradient is needed for the virial
            gradient_keys.append("strain")
        print("# Estimate pressure: ", estimate_pressure)
    else:
        estimate_pressure = False
        variable_cell = False

        def thermo_update_ensemble(x, v, system):
            # no cell: only the thermostat acts on the velocities
            v, thermostat_state = thermostat(v, system["thermostat"])
            return x, v, {**system, "thermostat": thermostat_state}

    dyn_state["estimate_pressure"] = estimate_pressure
    dyn_state["variable_cell"] = variable_cell
    thermo_updates.append(thermo_update_ensemble)

    ### ENERGY ENSEMBLE
    ensemble_key = simulation_parameters.get("etot_ensemble_key", None)

    ### COLVARS
    colvars_definitions = simulation_parameters.get("colvars", None)
    use_colvars = colvars_definitions is not None
    if use_colvars:
        colvars_calculators, colvars_names = setup_colvars(colvars_definitions)
        dyn_state["colvars"] = colvars_names

    ### IR SPECTRUM
    do_ir_spectrum = simulation_parameters.get("ir_spectrum", False)
    assert isinstance(do_ir_spectrum, bool), "ir_spectrum must be a boolean"
    if do_ir_spectrum:
        is_qtb = dyn_state["thermostat_name"].endswith("QTB")
        model_ir, ir_state, save_dipole, ir_post = initialize_ir_spectrum(
            simulation_parameters, system_data, fprec, dt, is_qtb
        )
        dyn_state["ir_spectrum"] = ir_state

    ### BUILD GRADIENT FUNCTION
    energy_and_gradient = model.get_gradient_function(
        *gradient_keys, jit=True, variables_as_input=True
    )

    ### COLLECT THERMO UPDATES
    if len(thermo_updates) == 1:
        thermo_update = thermo_updates[0]
    else:

        def thermo_update(x, v, system):
            for update in thermo_updates:
                x, v, system = update(x, v, system)
            return x, v, system

    ### RING POLYMER INITIALIZATION
    if nbeads is not None:
        cay_correction = simulation_parameters.get("cay_correction", True)
        omk = system_data["omk"]
        eigmat = system_data["eigmat"]
        omk_nc = omk[1:, None, None]  # all modes except the first (index 0)
        cayfact = 1.0 / (4.0 + (dt * omk_nc) ** 2) ** 0.5
        if cay_correction:
            axx = jnp.asarray(2 * cayfact)
            axv = jnp.asarray(dt * cayfact)
            avx = jnp.asarray(-dt * cayfact * omk_nc**2)
        else:
            axx = jnp.asarray(np.cos(omk_nc * dt2))
            axv = jnp.asarray(np.sin(omk_nc * dt2) / omk_nc)
            avx = jnp.asarray(-omk_nc * np.sin(omk_nc * dt2))

        # only mode 0 is populated (from the first bead); other modes start at zero
        coordinates = conformation["coordinates"].reshape(nbeads, -1, 3)
        eigx = jnp.zeros_like(coordinates).at[0].set(coordinates[0])
        system["coordinates"] = eigx

    ###############################################
    ### DEFINE UPDATE FUNCTION
    @jax.jit
    def update_conformation(conformation, system):
        """Copy the current coordinates (and cell) of ``system`` into ``conformation``."""
        x = system["coordinates"]
        if nbeads is not None:
            # back-transform from normal modes to a flat (nbeads * nat, 3) array
            x = jnp.einsum("in,n...->i...", eigmat, x).reshape(nbeads * nat, 3) * (
                nbeads**0.5
            )
        conformation = {**conformation, "coordinates": x}
        if variable_cell:
            conformation["cells"] = system["cell"][None, :, :].repeat(nreplicas, axis=0)

        return conformation

    ###############################################
    ### DEFINE INTEGRATION FUNCTIONS
    def integrate_A_half(x0, v0):
        """Propagate positions (and bead modes) over half a time step."""
        if nbeads is None:
            return x0 + dt2 * v0, v0

        # update coordinates and velocities of a free ring polymer for a half time step
        eigx_c = x0[0] + dt2 * v0[0]
        eigv_c = v0[0]
        eigx = x0[1:] * axx + v0[1:] * axv
        eigv = x0[1:] * avx + v0[1:] * axx

        return (
            jnp.concatenate((eigx_c[None], eigx), axis=0),
            jnp.concatenate((eigv_c[None], eigv), axis=0),
        )

    @jax.jit
    def integrate(system):
        """Velocity kick, half drift, thermostat/barostat update, half drift."""
        x = system["coordinates"]
        v = system["vel"] + dtm * system["forces"]
        x, v = integrate_A_half(x, v)
        x, v, system = thermo_update(x, v, system)
        x, v = integrate_A_half(x, v)

        return {**system, "coordinates": x, "vel": v}

    ###############################################
    ### DEFINE OBSERVABLE FUNCTION
    def kinetic_energy_tensor(v, system, forces):
        """Return ``(ek, ek_c)`` for velocities ``v``.

        ``ek`` is the 3x3 kinetic-energy tensor. With beads it is built from
        the mode-0 term ``ek_c`` plus a coordinate-force term over the other
        modes; without beads ``ek_c`` is ``None``.
        """
        if nbeads is None:
            corr_kin = system["thermostat"].get("corr_kin", 1.0)
            ek = (0.5 / nreplicas / corr_kin) * jnp.sum(
                mass[:, None, None] * v[:, :, None] * v[:, None, :], axis=0
            )
            return ek, None

        ek_c = 0.5 * jnp.sum(
            mass[:, None, None] * v[0, :, :, None] * v[0, :, None, :], axis=0
        )
        ek = ek_c - 0.5 * jnp.sum(
            system["coordinates"][1:, :, :, None] * forces[1:, :, None, :],
            axis=(0, 1),
        )
        return ek, ek_c

    @jax.jit
    def update_observables(system, conformation):
        """Evaluate the model and refresh energies, forces and derived observables."""
        ### POTENTIAL ENERGY AND FORCES
        epot, de, out = energy_and_gradient(model.variables, conformation)
        epot = epot / model_energy_unit
        de = {key: grad / model_energy_unit for key, grad in de.items()}
        forces = -de["coordinates"]

        if nbeads is not None:
            ### PROJECT FORCES ONTO POLYMER NORMAL MODES
            forces = jnp.einsum(
                "in,i...->n...", eigmat, forces.reshape(nbeads, nat, 3)
            ) * (1.0 / nbeads**0.5)

        system = {
            **system,
            "epot": jnp.mean(epot),
            "forces": forces,
            "energy_gradients": de,
        }

        ### KINETIC ENERGY
        ek, ek_c = kinetic_energy_tensor(system["vel"], system, forces)
        if ek_c is not None:
            system["ek_c"] = jnp.trace(ek_c)
        system["ek"] = jnp.trace(ek)
        system["ek_tensor"] = ek

        if estimate_pressure:
            if pressure_o_weight != 1.0:
                # blend the half-kicked kinetic tensor with the full-step one
                v_half = system["vel"] + 0.5 * dtm * system["forces"]
                ek_half, _ = kinetic_energy_tensor(v_half, system, forces)
                b = pressure_o_weight
                ek = (1.0 - b) * ek_half + b * system["ek_tensor"]

            vir = jnp.mean(de["strain"], axis=0)
            system["virial"] = vir

            pV = 2 * ek - vir
            system["PV_tensor"] = pV
            volume = jnp.abs(jnp.linalg.det(system["cell"]))
            Pres = pV / volume
            system["pressure_tensor"] = Pres
            system["pressure"] = jnp.trace(Pres) * (1.0 / 3.0)
            if variable_cell:
                density = totmass_amu / volume
                system["density"] = density
                system["volume"] = volume

        if ensemble_key is not None:
            kT = system_data["kT"]
            dE = (
                jnp.mean(out[ensemble_key], axis=0) / model_energy_unit - system["epot"]
            )
            system["ensemble_weights"] = -dE / kT

        if "total_dipole" in out:
            if nbeads is None:
                system["total_dipole"] = out["total_dipole"][0]
            else:
                system["total_dipole"] = jnp.mean(out["total_dipole"], axis=0)

        if use_colvars:
            coords = system["coordinates"].reshape(-1, nat, 3)[0]
            colvars = {}
            for colvar_name, colvar_calc in colvars_calculators.items():
                colvars[colvar_name] = colvar_calc(coords)
            system["colvars"] = colvars

        return system, out

    ###############################################
    ### IR SPECTRUM
    if do_ir_spectrum:

        @jax.jit
        def update_conformation_ir(conformation, system):
            """Build the single-replica conformation fed to the IR model."""
            conformation = {
                **conformation,
                "coordinates": system["coordinates"].reshape(-1, nat, 3)[0],
                "natoms": jnp.asarray([nat]),
                "batch_index": jnp.asarray([0] * nat),
                "species": jnp.asarray(system_data["species"].reshape(-1, nat)[0]),
            }
            if variable_cell:
                conformation["cells"] = system["cell"][None, :, :]
                conformation["reciprocal_cells"] = jnp.linalg.inv(system["cell"])[
                    None, :, :
                ]
            return conformation

        @jax.jit
        def update_dipole(ir_state, system, conformation):
            """Collect charges, dipoles, velocities and positions for the IR state."""
            if model_ir is not None:
                out = model_ir._apply(model_ir.variables, conformation)
                q = out.get("charges", jnp.zeros(nat)).reshape((-1, nat))
                dip = out.get("dipoles", jnp.zeros((nat, 3))).reshape((-1, nat, 3))
            else:
                q = system.get("charges", jnp.zeros(nat)).reshape((-1, nat))
                dip = system.get("dipoles", jnp.zeros((nat, 3))).reshape((-1, nat, 3))
            if nbeads is not None:
                q = jnp.mean(q, axis=0)
                dip = jnp.mean(dip, axis=0)
                vel = system["vel"][0]
                pos = system["coordinates"][0]
            else:
                q = q[0]
                dip = dip[0]
                vel = system["vel"].reshape(-1, nat, 3)[0]
                pos = system["coordinates"].reshape(-1, nat, 3)[0]

            if pbc_data is not None:
                cell_reciprocal = (
                    conformation["cells"][0],
                    conformation["reciprocal_cells"][0],
                )
            else:
                cell_reciprocal = None

            ir_state = save_dipole(
                q, vel, pos, dip.sum(axis=0), cell_reciprocal, ir_state
            )
            return ir_state

    ###############################################
    ### GRAPH UPDATES

    nblist_verbose = simulation_parameters.get("nblist_verbose", False)
    nblist_stride = int(simulation_parameters.get("nblist_stride", -1))
    nblist_warmup_time = simulation_parameters.get("nblist_warmup_time", -1.0) * au.FS
    nblist_warmup = int(nblist_warmup_time / dt) if nblist_warmup_time > 0 else 0
    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
    if nblist_skin > 0:
        if nblist_stride <= 0:
            ## reference skin parameters at 300K (from Tinker-HP)
            ## => skin of 2 A gives you 40 fs without complete rebuild
            # t_ref is scaled like dt and nblist_warmup_time so that the
            # ratio t_ref / dt is a step count
            t_ref = 40.0 * au.FS  # 40 fs, same units as dt
            nblist_skin_ref = 2.0  # A
            nblist_stride = int(math.floor(nblist_skin / nblist_skin_ref * t_ref / dt))
        print(
            f"# nblist_skin: {nblist_skin:.2f} A, nblist_stride: {nblist_stride} steps, nblist_warmup: {nblist_warmup} steps"
        )

    if nblist_skin <= 0:
        # no skin: rebuild the neighbour list at every step
        nblist_stride = 1

    dyn_state["nblist_countdown"] = 0
    dyn_state["print_skin_activation"] = nblist_warmup > 0

    def update_graphs(istep, dyn_state, system, conformation, force_preprocess=False):
        """Refresh neighbour lists: full rebuild or cheap skin update.

        ``dyn_state`` is modified in place and also returned.
        """
        nblist_countdown = dyn_state["nblist_countdown"]
        if nblist_countdown <= 0 or force_preprocess or (istep < nblist_warmup):
            ### FULL NBLIST REBUILD
            dyn_state["nblist_countdown"] = nblist_stride - 1
            preproc_state = dyn_state["preproc_state"]
            conformation = model.preprocessing.process(
                preproc_state, update_conformation(conformation, system)
            )
            preproc_state, state_up, conformation, overflow = (
                model.preprocessing.check_reallocate(preproc_state, conformation)
            )
            dyn_state["preproc_state"] = preproc_state
            if nblist_verbose and overflow:
                print("step", istep, ", nblist overflow => reallocating nblist")
                print("size updates:", state_up)

            if do_ir_spectrum and model_ir is not None:
                conformation_ir = model_ir.preprocessing.process(
                    dyn_state["preproc_state_ir"],
                    update_conformation_ir(dyn_state["conformation_ir"], system),
                )
                (
                    dyn_state["preproc_state_ir"],
                    _,
                    dyn_state["conformation_ir"],
                    overflow,
                ) = model_ir.preprocessing.check_reallocate(
                    dyn_state["preproc_state_ir"], conformation_ir
                )

        else:
            ### SKIN UPDATE
            if dyn_state["print_skin_activation"]:
                if nblist_verbose:
                    print(
                        "step",
                        istep,
                        ", end of nblist warmup phase => activating skin updates",
                    )
                dyn_state["print_skin_activation"] = False

            dyn_state["nblist_countdown"] = nblist_countdown - 1
            conformation = model.preprocessing.update_skin(
                update_conformation(conformation, system)
            )
            if do_ir_spectrum and model_ir is not None:
                dyn_state["conformation_ir"] = model_ir.preprocessing.update_skin(
                    update_conformation_ir(dyn_state["conformation_ir"], system)
                )

        return conformation, dyn_state

    ################################################
    ### DEFINE STEP FUNCTION
    def step(istep, dyn_state, system, conformation, force_preprocess=False):
        """Advance the dynamics by one time step.

        Returns ``(dyn_state, system, conformation, out)`` where ``out`` is
        the raw model output of the force evaluation.
        """
        # shallow copy so the caller's dyn_state dict is not modified
        dyn_state = {
            **dyn_state,
            "istep": dyn_state["istep"] + 1,
        }

        ### INTEGRATE EQUATIONS OF MOTION
        system = integrate(system)

        ### UPDATE CONFORMATION AND GRAPHS
        conformation, dyn_state = update_graphs(
            istep, dyn_state, system, conformation, force_preprocess
        )

        ## COMPUTE FORCES AND OBSERVABLES
        system, out = update_observables(system, conformation)

        ## END OF STEP UPDATES
        if do_thermostat_post:
            system["thermostat"], dyn_state["thermostat_post_state"] = thermostat_post(
                system["thermostat"], dyn_state["thermostat_post_state"]
            )

        if do_ir_spectrum:
            # "conformation_ir" only exists when a dedicated IR model is in
            # use; otherwise fall back to the main conformation instead of
            # raising a KeyError.
            conformation_ir = dyn_state.get("conformation_ir", conformation)
            ir_state = update_dipole(dyn_state["ir_spectrum"], system, conformation_ir)
            dyn_state["ir_spectrum"] = ir_post(ir_state)

        return dyn_state, system, conformation, out

    ###########################################################

    print("# Computing initial energy and forces")

    conformation = update_conformation(conformation, system)
    # initialize IR conformation
    if do_ir_spectrum and model_ir is not None:
        dyn_state["preproc_state_ir"], dyn_state["conformation_ir"] = (
            model_ir.preprocessing(
                model_ir.preproc_state,
                update_conformation_ir(conformation, system),
            )
        )

    system, _ = update_observables(system, conformation)

    return step, update_conformation, system_data, dyn_state, conformation, system
def initialize_dynamics(simulation_parameters, fprec, rng_key):
    """Set up a molecular-dynamics run and return its single-step propagator.

    The function loads the model and the system, resumes from a restart file
    when one exists, builds the thermostat (and the barostat when periodic
    boundary conditions are present), and evaluates the initial energy and
    forces.

    Args:
        simulation_parameters: mapping of user options, read with ``.get``
            (``"dt"``, ``"colvars"``, ``"ir_spectrum"``, ``"nblist_*"``, ...).
        fprec: floating-point dtype used for velocities and ``dt / mass``.
        rng_key: JAX PRNG key; it is split to seed the thermostat and the
            barostat.

    Returns:
        A tuple ``(step, update_conformation, system_data, dyn_state,
        conformation, system)`` where ``step(istep, dyn_state, system,
        conformation, force_preprocess=False)`` advances the dynamics by one
        time step and returns ``(dyn_state, system, conformation, out)``.

    Raises:
        AssertionError: if ``pressure_o_weight`` is outside ``[0, 1]`` or
            ``ir_spectrum`` is not a boolean.
    """
    ### LOAD MODEL
    model = load_model(simulation_parameters)
    model_energy_unit = au.get_multiplier(model.energy_unit)

    ### Get the coordinates and species from the xyz file
    system_data, conformation = load_system_data(simulation_parameters, fprec)

    ### FINISH BUILDING conformation
    if os.path.exists(get_restart_file(system_data)):
        ### RESTART FROM PREVIOUS DYNAMICS
        restart_data = load_dynamics_restart(system_data)
        print("# RESTARTING FROM PREVIOUS DYNAMICS")
        model.preproc_state = restart_data["preproc_state"]
        conformation["coordinates"] = restart_data["coordinates"]
    else:
        restart_data = {}

    ### INITIALIZE PREPROCESSING
    preproc_state, conformation = initialize_preprocessing(
        simulation_parameters, model, conformation, system_data
    )

    ### get dynamics parameters
    dt = simulation_parameters.get("dt") * au.FS
    dt2 = 0.5 * dt
    mass = system_data["mass"]
    totmass_amu = system_data["totmass_amu"]
    nat = system_data["nat"]
    dtm = jnp.asarray(dt / mass[:, None], dtype=fprec)

    nreplicas = system_data.get("nreplicas", 1)
    nbeads = system_data.get("nbeads", None)
    if nbeads is not None:
        # path-integral run: one replica per bead, extra leading bead axis
        nreplicas = nbeads
        dtm = dtm[None, :, :]

    ### INITIALIZE DYNAMICS STATE
    system = {"coordinates": conformation["coordinates"]}
    dyn_state = {
        "istep": 0,
        "dt": dt,
        "pimd": nbeads is not None,
        "preproc_state": preproc_state,
        "start_time_ps": restart_data.get("simulation_time_ps", 0.0),
    }
    gradient_keys = ["coordinates"]
    thermo_updates = []

    ### INITIALIZE THERMOSTAT
    thermostat_rng, rng_key = jax.random.split(rng_key)
    (
        thermostat,
        thermostat_post,
        thermostat_state,
        initial_vel,
        dyn_state["thermostat_name"],
    ) = get_thermostat(
        simulation_parameters, dt, system_data, fprec, thermostat_rng, restart_data
    )
    do_thermostat_post = thermostat_post is not None
    if do_thermostat_post:
        thermostat_post, post_state = thermostat_post
        dyn_state["thermostat_post_state"] = post_state

    system["thermostat"] = thermostat_state
    system["vel"] = restart_data.get("vel", initial_vel).astype(fprec)

    ### PBC
    pbc_data = system_data.get("pbc", None)
    if pbc_data is not None:
        ### INITIALIZE BAROSTAT
        barostat_key, rng_key = jax.random.split(rng_key)
        thermo_update_ensemble, variable_cell, barostat_state = get_barostat(
            thermostat,
            simulation_parameters,
            dt,
            system_data,
            fprec,
            barostat_key,
            restart_data,
        )
        estimate_pressure = variable_cell or pbc_data["estimate_pressure"]
        system["barostat"] = barostat_state
        system["cell"] = conformation["cells"][0]
        if estimate_pressure:
            pressure_o_weight = simulation_parameters.get("pressure_o_weight", 0.0)
            assert (
                0.0 <= pressure_o_weight <= 1.0
            ), "pressure_o_weight must be between 0 and 1"
            # the strain gradient is needed for the virial
            gradient_keys.append("strain")
        print("# Estimate pressure: ", estimate_pressure)
    else:
        estimate_pressure = False
        variable_cell = False

        def thermo_update_ensemble(x, v, system):
            # no cell: only the thermostat acts on the velocities
            v, thermostat_state = thermostat(v, system["thermostat"])
            return x, v, {**system, "thermostat": thermostat_state}

    dyn_state["estimate_pressure"] = estimate_pressure
    dyn_state["variable_cell"] = variable_cell
    thermo_updates.append(thermo_update_ensemble)

    ### ENERGY ENSEMBLE
    ensemble_key = simulation_parameters.get("etot_ensemble_key", None)

    ### COLVARS
    colvars_definitions = simulation_parameters.get("colvars", None)
    use_colvars = colvars_definitions is not None
    if use_colvars:
        colvars_calculators, colvars_names = setup_colvars(colvars_definitions)
        dyn_state["colvars"] = colvars_names

    ### IR SPECTRUM
    do_ir_spectrum = simulation_parameters.get("ir_spectrum", False)
    assert isinstance(do_ir_spectrum, bool), "ir_spectrum must be a boolean"
    if do_ir_spectrum:
        is_qtb = dyn_state["thermostat_name"].endswith("QTB")
        model_ir, ir_state, save_dipole, ir_post = initialize_ir_spectrum(
            simulation_parameters, system_data, fprec, dt, is_qtb
        )
        dyn_state["ir_spectrum"] = ir_state

    ### BUILD GRADIENT FUNCTION
    energy_and_gradient = model.get_gradient_function(
        *gradient_keys, jit=True, variables_as_input=True
    )

    ### COLLECT THERMO UPDATES
    if len(thermo_updates) == 1:
        thermo_update = thermo_updates[0]
    else:

        def thermo_update(x, v, system):
            for update in thermo_updates:
                x, v, system = update(x, v, system)
            return x, v, system

    ### RING POLYMER INITIALIZATION
    if nbeads is not None:
        cay_correction = simulation_parameters.get("cay_correction", True)
        omk = system_data["omk"]
        eigmat = system_data["eigmat"]
        omk_nc = omk[1:, None, None]  # all modes except the first (index 0)
        cayfact = 1.0 / (4.0 + (dt * omk_nc) ** 2) ** 0.5
        if cay_correction:
            axx = jnp.asarray(2 * cayfact)
            axv = jnp.asarray(dt * cayfact)
            avx = jnp.asarray(-dt * cayfact * omk_nc**2)
        else:
            axx = jnp.asarray(np.cos(omk_nc * dt2))
            axv = jnp.asarray(np.sin(omk_nc * dt2) / omk_nc)
            avx = jnp.asarray(-omk_nc * np.sin(omk_nc * dt2))

        # only mode 0 is populated (from the first bead); other modes start at zero
        coordinates = conformation["coordinates"].reshape(nbeads, -1, 3)
        eigx = jnp.zeros_like(coordinates).at[0].set(coordinates[0])
        system["coordinates"] = eigx

    ###############################################
    ### DEFINE UPDATE FUNCTION
    @jax.jit
    def update_conformation(conformation, system):
        """Copy the current coordinates (and cell) of ``system`` into ``conformation``."""
        x = system["coordinates"]
        if nbeads is not None:
            # back-transform from normal modes to a flat (nbeads * nat, 3) array
            x = jnp.einsum("in,n...->i...", eigmat, x).reshape(nbeads * nat, 3) * (
                nbeads**0.5
            )
        conformation = {**conformation, "coordinates": x}
        if variable_cell:
            conformation["cells"] = system["cell"][None, :, :].repeat(nreplicas, axis=0)

        return conformation

    ###############################################
    ### DEFINE INTEGRATION FUNCTIONS
    def integrate_A_half(x0, v0):
        """Propagate positions (and bead modes) over half a time step."""
        if nbeads is None:
            return x0 + dt2 * v0, v0

        # update coordinates and velocities of a free ring polymer for a half time step
        eigx_c = x0[0] + dt2 * v0[0]
        eigv_c = v0[0]
        eigx = x0[1:] * axx + v0[1:] * axv
        eigv = x0[1:] * avx + v0[1:] * axx

        return (
            jnp.concatenate((eigx_c[None], eigx), axis=0),
            jnp.concatenate((eigv_c[None], eigv), axis=0),
        )

    @jax.jit
    def integrate(system):
        """Velocity kick, half drift, thermostat/barostat update, half drift."""
        x = system["coordinates"]
        v = system["vel"] + dtm * system["forces"]
        x, v = integrate_A_half(x, v)
        x, v, system = thermo_update(x, v, system)
        x, v = integrate_A_half(x, v)

        return {**system, "coordinates": x, "vel": v}

    ###############################################
    ### DEFINE OBSERVABLE FUNCTION
    def kinetic_energy_tensor(v, system, forces):
        """Return ``(ek, ek_c)`` for velocities ``v``.

        ``ek`` is the 3x3 kinetic-energy tensor. With beads it is built from
        the mode-0 term ``ek_c`` plus a coordinate-force term over the other
        modes; without beads ``ek_c`` is ``None``.
        """
        if nbeads is None:
            corr_kin = system["thermostat"].get("corr_kin", 1.0)
            ek = (0.5 / nreplicas / corr_kin) * jnp.sum(
                mass[:, None, None] * v[:, :, None] * v[:, None, :], axis=0
            )
            return ek, None

        ek_c = 0.5 * jnp.sum(
            mass[:, None, None] * v[0, :, :, None] * v[0, :, None, :], axis=0
        )
        ek = ek_c - 0.5 * jnp.sum(
            system["coordinates"][1:, :, :, None] * forces[1:, :, None, :],
            axis=(0, 1),
        )
        return ek, ek_c

    @jax.jit
    def update_observables(system, conformation):
        """Evaluate the model and refresh energies, forces and derived observables."""
        ### POTENTIAL ENERGY AND FORCES
        epot, de, out = energy_and_gradient(model.variables, conformation)
        epot = epot / model_energy_unit
        de = {key: grad / model_energy_unit for key, grad in de.items()}
        forces = -de["coordinates"]

        if nbeads is not None:
            ### PROJECT FORCES ONTO POLYMER NORMAL MODES
            forces = jnp.einsum(
                "in,i...->n...", eigmat, forces.reshape(nbeads, nat, 3)
            ) * (1.0 / nbeads**0.5)

        system = {
            **system,
            "epot": jnp.mean(epot),
            "forces": forces,
            "energy_gradients": de,
        }

        ### KINETIC ENERGY
        ek, ek_c = kinetic_energy_tensor(system["vel"], system, forces)
        if ek_c is not None:
            system["ek_c"] = jnp.trace(ek_c)
        system["ek"] = jnp.trace(ek)
        system["ek_tensor"] = ek

        if estimate_pressure:
            if pressure_o_weight != 1.0:
                # blend the half-kicked kinetic tensor with the full-step one
                v_half = system["vel"] + 0.5 * dtm * system["forces"]
                ek_half, _ = kinetic_energy_tensor(v_half, system, forces)
                b = pressure_o_weight
                ek = (1.0 - b) * ek_half + b * system["ek_tensor"]

            vir = jnp.mean(de["strain"], axis=0)
            system["virial"] = vir

            pV = 2 * ek - vir
            system["PV_tensor"] = pV
            volume = jnp.abs(jnp.linalg.det(system["cell"]))
            Pres = pV / volume
            system["pressure_tensor"] = Pres
            system["pressure"] = jnp.trace(Pres) * (1.0 / 3.0)
            if variable_cell:
                density = totmass_amu / volume
                system["density"] = density
                system["volume"] = volume

        if ensemble_key is not None:
            kT = system_data["kT"]
            dE = (
                jnp.mean(out[ensemble_key], axis=0) / model_energy_unit - system["epot"]
            )
            system["ensemble_weights"] = -dE / kT

        if "total_dipole" in out:
            if nbeads is None:
                system["total_dipole"] = out["total_dipole"][0]
            else:
                system["total_dipole"] = jnp.mean(out["total_dipole"], axis=0)

        if use_colvars:
            coords = system["coordinates"].reshape(-1, nat, 3)[0]
            colvars = {}
            for colvar_name, colvar_calc in colvars_calculators.items():
                colvars[colvar_name] = colvar_calc(coords)
            system["colvars"] = colvars

        return system, out

    ###############################################
    ### IR SPECTRUM
    if do_ir_spectrum:

        @jax.jit
        def update_conformation_ir(conformation, system):
            """Build the single-replica conformation fed to the IR model."""
            conformation = {
                **conformation,
                "coordinates": system["coordinates"].reshape(-1, nat, 3)[0],
                "natoms": jnp.asarray([nat]),
                "batch_index": jnp.asarray([0] * nat),
                "species": jnp.asarray(system_data["species"].reshape(-1, nat)[0]),
            }
            if variable_cell:
                conformation["cells"] = system["cell"][None, :, :]
                conformation["reciprocal_cells"] = jnp.linalg.inv(system["cell"])[
                    None, :, :
                ]
            return conformation

        @jax.jit
        def update_dipole(ir_state, system, conformation):
            """Collect charges, dipoles, velocities and positions for the IR state."""
            if model_ir is not None:
                out = model_ir._apply(model_ir.variables, conformation)
                q = out.get("charges", jnp.zeros(nat)).reshape((-1, nat))
                dip = out.get("dipoles", jnp.zeros((nat, 3))).reshape((-1, nat, 3))
            else:
                q = system.get("charges", jnp.zeros(nat)).reshape((-1, nat))
                dip = system.get("dipoles", jnp.zeros((nat, 3))).reshape((-1, nat, 3))
            if nbeads is not None:
                q = jnp.mean(q, axis=0)
                dip = jnp.mean(dip, axis=0)
                vel = system["vel"][0]
                pos = system["coordinates"][0]
            else:
                q = q[0]
                dip = dip[0]
                vel = system["vel"].reshape(-1, nat, 3)[0]
                pos = system["coordinates"].reshape(-1, nat, 3)[0]

            if pbc_data is not None:
                cell_reciprocal = (
                    conformation["cells"][0],
                    conformation["reciprocal_cells"][0],
                )
            else:
                cell_reciprocal = None

            ir_state = save_dipole(
                q, vel, pos, dip.sum(axis=0), cell_reciprocal, ir_state
            )
            return ir_state

    ###############################################
    ### GRAPH UPDATES

    nblist_verbose = simulation_parameters.get("nblist_verbose", False)
    nblist_stride = int(simulation_parameters.get("nblist_stride", -1))
    nblist_warmup_time = simulation_parameters.get("nblist_warmup_time", -1.0) * au.FS
    nblist_warmup = int(nblist_warmup_time / dt) if nblist_warmup_time > 0 else 0
    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
    if nblist_skin > 0:
        if nblist_stride <= 0:
            ## reference skin parameters at 300K (from Tinker-HP)
            ## => skin of 2 A gives you 40 fs without complete rebuild
            # t_ref is scaled like dt and nblist_warmup_time so that the
            # ratio t_ref / dt is a step count
            t_ref = 40.0 * au.FS  # 40 fs, same units as dt
            nblist_skin_ref = 2.0  # A
            nblist_stride = int(math.floor(nblist_skin / nblist_skin_ref * t_ref / dt))
        print(
            f"# nblist_skin: {nblist_skin:.2f} A, nblist_stride: {nblist_stride} steps, nblist_warmup: {nblist_warmup} steps"
        )

    if nblist_skin <= 0:
        # no skin: rebuild the neighbour list at every step
        nblist_stride = 1

    dyn_state["nblist_countdown"] = 0
    dyn_state["print_skin_activation"] = nblist_warmup > 0

    def update_graphs(istep, dyn_state, system, conformation, force_preprocess=False):
        """Refresh neighbour lists: full rebuild or cheap skin update.

        ``dyn_state`` is modified in place and also returned.
        """
        nblist_countdown = dyn_state["nblist_countdown"]
        if nblist_countdown <= 0 or force_preprocess or (istep < nblist_warmup):
            ### FULL NBLIST REBUILD
            dyn_state["nblist_countdown"] = nblist_stride - 1
            preproc_state = dyn_state["preproc_state"]
            conformation = model.preprocessing.process(
                preproc_state, update_conformation(conformation, system)
            )
            preproc_state, state_up, conformation, overflow = (
                model.preprocessing.check_reallocate(preproc_state, conformation)
            )
            dyn_state["preproc_state"] = preproc_state
            if nblist_verbose and overflow:
                print("step", istep, ", nblist overflow => reallocating nblist")
                print("size updates:", state_up)

            if do_ir_spectrum and model_ir is not None:
                conformation_ir = model_ir.preprocessing.process(
                    dyn_state["preproc_state_ir"],
                    update_conformation_ir(dyn_state["conformation_ir"], system),
                )
                (
                    dyn_state["preproc_state_ir"],
                    _,
                    dyn_state["conformation_ir"],
                    overflow,
                ) = model_ir.preprocessing.check_reallocate(
                    dyn_state["preproc_state_ir"], conformation_ir
                )

        else:
            ### SKIN UPDATE
            if dyn_state["print_skin_activation"]:
                if nblist_verbose:
                    print(
                        "step",
                        istep,
                        ", end of nblist warmup phase => activating skin updates",
                    )
                dyn_state["print_skin_activation"] = False

            dyn_state["nblist_countdown"] = nblist_countdown - 1
            conformation = model.preprocessing.update_skin(
                update_conformation(conformation, system)
            )
            if do_ir_spectrum and model_ir is not None:
                dyn_state["conformation_ir"] = model_ir.preprocessing.update_skin(
                    update_conformation_ir(dyn_state["conformation_ir"], system)
                )

        return conformation, dyn_state

    ################################################
    ### DEFINE STEP FUNCTION
    def step(istep, dyn_state, system, conformation, force_preprocess=False):
        """Advance the dynamics by one time step.

        Returns ``(dyn_state, system, conformation, out)`` where ``out`` is
        the raw model output of the force evaluation.
        """
        # shallow copy so the caller's dyn_state dict is not modified
        dyn_state = {
            **dyn_state,
            "istep": dyn_state["istep"] + 1,
        }

        ### INTEGRATE EQUATIONS OF MOTION
        system = integrate(system)

        ### UPDATE CONFORMATION AND GRAPHS
        conformation, dyn_state = update_graphs(
            istep, dyn_state, system, conformation, force_preprocess
        )

        ## COMPUTE FORCES AND OBSERVABLES
        system, out = update_observables(system, conformation)

        ## END OF STEP UPDATES
        if do_thermostat_post:
            system["thermostat"], dyn_state["thermostat_post_state"] = thermostat_post(
                system["thermostat"], dyn_state["thermostat_post_state"]
            )

        if do_ir_spectrum:
            # "conformation_ir" only exists when a dedicated IR model is in
            # use; otherwise fall back to the main conformation instead of
            # raising a KeyError.
            conformation_ir = dyn_state.get("conformation_ir", conformation)
            ir_state = update_dipole(dyn_state["ir_spectrum"], system, conformation_ir)
            dyn_state["ir_spectrum"] = ir_post(ir_state)

        return dyn_state, system, conformation, out

    ###########################################################

    print("# Computing initial energy and forces")

    conformation = update_conformation(conformation, system)
    # initialize IR conformation
    if do_ir_spectrum and model_ir is not None:
        dyn_state["preproc_state_ir"], dyn_state["conformation_ir"] = (
            model_ir.preprocessing(
                model_ir.preproc_state,
                update_conformation_ir(conformation, system),
            )
        )

    system, _ = update_observables(system, conformation)

    return step, update_conformation, system_data, dyn_state, conformation, system