fennol.md.thermostats
import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
import math
import optax
import os
import pickle

from ..utils.atomic_units import AtomicUnits as au  # CM1,THZ,BOHR,MPROT
from ..utils import Counter
from ..utils.deconvolution import (
    deconvolute_spectrum,
    kernel_lorentz_pot,
    kernel_lorentz,
)


def get_thermostat(
    simulation_parameters, dt, system_data, fprec, rng_key=None, restart_data=None
):
    """Build the thermostat selected by ``simulation_parameters["thermostat"]``.

    Args:
        simulation_parameters: dict-like of user keywords (``thermostat``,
            ``gamma``, ``include_thermostat_energy``, ``trpmd_lambda``, ``qtb``,
            ``annealing``, ``nsteps``).
        dt: integration time step. Presumably in fs, consistent with the
            ``/ au.FS`` conversion applied to ``gamma``; confirm with the caller.
        system_data: dict with at least ``"mass"`` and ``"species"``.
            ``"kT"`` and ``"nbeads"`` are optional keys; ``"omk"`` is read
            when ``nbeads`` is set, and ``"name"`` is read by the QTB setup.
        fprec: floating-point dtype used for the generated arrays.
        rng_key: JAX PRNG key. It is required by the stochastic thermostats
            and for drawing the initial velocities.
        restart_data: unused; kept for interface compatibility.

    Returns:
        Tuple ``(thermostat, postprocess, state, vel, thermostat_name)``:

        - ``thermostat(vel, state) -> (vel, state)`` applies one thermostat
          update.
        - ``postprocess`` is ``None``, except for QTB/ADQTB where it is the
          ``(postprocess_fn, post_state)`` pair built by ``initialize_qtb``.
        - ``state`` is the initial thermostat state dict.
        - ``vel`` holds the initial velocities.

    Raises:
        ValueError: unknown thermostat name or annealing schedule type.
        AssertionError: a required input is missing or inconsistent.
    """
    state = {}
    postprocess = None

    thermostat_name = str(simulation_parameters.get("thermostat", "LGV")).upper()
    """@keyword[fennol_md] thermostat
    Thermostat type. Options: 'NVE', 'LGV', 'FFLGV', 'BUSSI', 'GD', 'NOSE', 'QTB', 'ADQTB', 'ANNEAL'.
    Default: "LGV"
    """
    compute_thermostat_energy = simulation_parameters.get(
        "include_thermostat_energy", False
    )
    """@keyword[fennol_md] include_thermostat_energy
    Include thermostat energy in total energy calculations.
    Default: False
    """

    kT = system_data.get("kT", None)
    nbeads = system_data.get("nbeads", None)
    mass = system_data["mass"]
    gamma = simulation_parameters.get("gamma", 1.0 / au.THZ) / au.FS
    """@keyword[fennol_md] gamma
    Friction coefficient for Langevin thermostat.
    Default: 1.0 ps^-1
    """
    species = system_data["species"]

    if nbeads is not None:
        trpmd_lambda = simulation_parameters.get("trpmd_lambda", 1.0)
        """@keyword[fennol_md] trpmd_lambda
        Lambda parameter for TRPMD (Thermostatted Ring Polymer MD).
        Default: 1.0
        """
        # Element-wise friction: at least trpmd_lambda * omk for each entry.
        # omk is presumably the array of ring-polymer normal-mode frequencies.
        gamma = np.maximum(trpmd_lambda * system_data["omk"], gamma)

    if thermostat_name in ["LGV", "LANGEVIN", "FFLGV"]:
        assert (
            rng_key is not None
        ), f"rng_key must be provided for {thermostat_name} thermostat"
        assert kT is not None, f"kT must be specified for {thermostat_name} thermostat"
        assert (
            gamma is not None
        ), f"gamma must be specified for {thermostat_name} thermostat"
        rng_key, v_key = jax.random.split(rng_key)
        if nbeads is None:
            # Exact Ornstein-Uhlenbeck update: v <- a1*v + a2*noise
            a1 = math.exp(-gamma * dt)
            a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)
            vel = (
                jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
                * (kT / mass[:, None]) ** 0.5
            )
        else:
            if isinstance(gamma, float):
                gamma = np.array([gamma] * nbeads)
            assert isinstance(
                gamma, np.ndarray
            ), "gamma must be a float or a numpy array"
            assert gamma.shape[0] == nbeads, "gamma must have the same length as nbeads"
            # One damping factor per bead, broadcast over (atoms, xyz).
            a1 = np.exp(-gamma * dt)[:, None, None]
            a2 = jnp.asarray(
                ((1 - a1 * a1) * kT / mass[None, :, None]) ** 0.5, dtype=fprec
            )
            vel = (
                jax.random.normal(v_key, (nbeads, mass.shape[0], 3), dtype=fprec)
                * (kT / mass[:, None]) ** 0.5
            )

        state["rng_key"] = rng_key
        if compute_thermostat_energy:
            state["thermostat_energy"] = 0.0
        if thermostat_name == "FFLGV":
            # "Fixed-direction" Langevin: the OU update changes only the norm of
            # each atomic velocity; its direction is kept.
            def thermostat(vel, state):
                rng_key, noise_key = jax.random.split(state["rng_key"])
                noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
                norm_vel = jnp.linalg.norm(vel, axis=-1, keepdims=True)
                dirvel = vel / norm_vel
                if compute_thermostat_energy:
                    v2 = (vel**2).sum(axis=-1)
                vel = a1 * vel + a2 * noise
                new_norm_vel = jnp.linalg.norm(vel, axis=-1, keepdims=True)
                vel = dirvel * new_norm_vel
                new_state = {**state, "rng_key": rng_key}
                if compute_thermostat_energy:
                    v2new = (vel**2).sum(axis=-1)
                    # Accumulate kinetic energy removed by the thermostat.
                    new_state["thermostat_energy"] = (
                        state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
                    )

                return vel, new_state

        else:

            def thermostat(vel, state):
                rng_key, noise_key = jax.random.split(state["rng_key"])
                noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
                if compute_thermostat_energy:
                    v2 = (vel**2).sum(axis=-1)
                vel = a1 * vel + a2 * noise
                new_state = {**state, "rng_key": rng_key}
                if compute_thermostat_energy:
                    v2new = (vel**2).sum(axis=-1)
                    # Accumulate kinetic energy removed by the thermostat.
                    new_state["thermostat_energy"] = (
                        state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
                    )
                return vel, new_state

    elif thermostat_name in ["BUSSI"]:
        assert rng_key is not None, "rng_key must be provided for BUSSI thermostat"
        assert kT is not None, "kT must be specified for BUSSI thermostat"
        assert gamma is not None, "gamma must be specified for BUSSI thermostat"
        assert nbeads is None, "Bussi thermostat is not compatible with PIMD"
        rng_key, v_key = jax.random.split(rng_key)

        a1 = math.exp(-gamma * dt)
        a2 = (1 - a1) * kT
        vel = (
            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
            * (kT / mass[:, None]) ** 0.5
        )

        state["rng_key"] = rng_key
        if compute_thermostat_energy:
            state["thermostat_energy"] = 0.0

        def thermostat(vel, state):
            # Global stochastic velocity rescaling. All velocities are
            # multiplied by one factor:
            #   scale^2 = a1 + c*sum(R^2) + 2*sqrt(a1*c)*R_1,
            #   c = (1 - a1) kT / sum(m v^2).
            rng_key, noise_key = jax.random.split(state["rng_key"])
            new_state = {**state, "rng_key": rng_key}
            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
            R2 = jnp.sum(noise**2)
            R1 = noise[0, 0]
            c = a2 / (mass[:, None] * vel**2).sum()
            d = (a1 * c) ** 0.5
            scale = (a1 + c * R2 + 2 * d * R1) ** 0.5
            if compute_thermostat_energy:
                dek = 0.5 * (mass[:, None] * vel**2).sum() * (scale**2 - 1)
                new_state["thermostat_energy"] = state["thermostat_energy"] + dek
            return scale * vel, new_state

    elif thermostat_name in [
        "GD",
        "DESCENT",
        "GRADIENT_DESCENT",
        "MIN",
        "MINIMIZE",
    ]:
        assert nbeads is None, "Gradient descent is not compatible with PIMD"
        a1 = math.exp(-gamma * dt)
        # Start from rest; each call damps the velocities by a1 (no noise).
        vel = jnp.zeros((mass.shape[0], 3), dtype=fprec)

        def thermostat(vel, state):
            return a1 * vel, state

    elif thermostat_name in ["NVE", "NONE"]:
        assert rng_key is not None, "rng_key must be provided for NVE initial velocities"
        assert kT is not None, "kT must be specified for NVE initial velocities"
        if nbeads is None:
            vel = (
                jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
                * (kT / mass[:, None]) ** 0.5
            )
            # Rescale so the initial instantaneous temperature is exactly kT.
            kTsys = jnp.sum(mass[:, None] * vel**2) / (mass.shape[0] * 3)
            vel = vel * (kT / kTsys) ** 0.5
        else:
            vel = (
                jax.random.normal(rng_key, (nbeads, mass.shape[0], 3), dtype=fprec)
                * (kT / mass[None, :, None]) ** 0.5
            )
            # Per-bead rescaling to exactly kT.
            kTsys = jnp.sum(mass[None, :, None] * vel**2, axis=(1, 2)) / (
                mass.shape[0] * 3
            )
            vel = vel * (kT / kTsys[:, None, None]) ** 0.5

        def thermostat(vel, state):
            # No thermostat: velocities and state are returned unchanged.
            return vel, state

    elif thermostat_name in ["NOSE", "NOSEHOOVER", "NOSE_HOOVER"]:
        assert nbeads is None, "Nose-Hoover is not compatible with PIMD"
        assert rng_key is not None, "rng_key must be provided for Nose-Hoover thermostat"
        assert kT is not None, "kT must be specified for Nose-Hoover thermostat"
        assert gamma is not None, "gamma must be specified for Nose-Hoover thermostat"
        ndof = mass.shape[0] * 3
        nkT = ndof * kT
        # Thermostat "mass" chosen so that its characteristic frequency is gamma.
        nose_mass = nkT / gamma**2
        state["nose_s"] = 0.0
        state["nose_v"] = 0.0
        if compute_thermostat_energy:
            state["thermostat_energy"] = 0.0
        print(
            "# WARNING: Nose-Hoover thermostat is not well tested yet. Energy conservation is not guaranteed."
        )
        vel = (
            jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
            * (kT / mass[:, None]) ** 0.5
        )

        def thermostat(vel, state):
            # Half kick of the thermostat velocity, full drift of nose_s and
            # exponential velocity scaling, then a second half kick.
            nose_s = state["nose_s"]
            nose_v = state["nose_v"]
            kTsys = jnp.sum(mass[:, None] * vel**2)
            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
            nose_s = nose_s + dt * nose_v
            vel = jnp.exp(-nose_v * dt) * vel
            kTsys = jnp.sum(mass[:, None] * vel**2)
            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
            new_state = {**state, "nose_s": nose_s, "nose_v": nose_v}

            if compute_thermostat_energy:
                new_state["thermostat_energy"] = (
                    nkT * nose_s + (0.5 * nose_mass) * nose_v**2
                )
            return vel, new_state

    elif thermostat_name in ["QTB", "ADQTB"]:
        assert nbeads is None, "QTB is not compatible with PIMD"
        qtb_parameters = simulation_parameters.get("qtb", None)
        """@keyword[fennol_md] qtb
        Parameters for Quantum Thermal Bath thermostat configuration.
        Required for QTB/ADQTB thermostats
        """
        assert (
            qtb_parameters is not None
        ), "qtb_parameters must be provided for QTB thermostat"
        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
        assert kT is not None, "kT must be specified for QTB thermostat"
        assert gamma is not None, "gamma must be specified for QTB thermostat"
        assert species is not None, "species must be provided for QTB thermostat"
        rng_key, v_key = jax.random.split(rng_key)
        vel = (
            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
            * (kT / mass[:, None]) ** 0.5
        )

        # Here "postprocess" is the (postprocess_fn, post_state) pair.
        thermostat, postprocess, qtb_state = initialize_qtb(
            qtb_parameters,
            system_data,
            fprec=fprec,
            dt=dt,
            mass=mass,
            gamma=gamma,
            kT=kT,
            species=species,
            rng_key=rng_key,
            adaptive=thermostat_name.startswith("AD"),
            compute_thermostat_energy=compute_thermostat_energy,
        )
        state = {**state, **qtb_state}

    elif thermostat_name in ["ANNEAL", "ANNEALING"]:
        assert rng_key is not None, "rng_key must be provided for ANNEAL thermostat"
        assert kT is not None, "kT must be specified for ANNEAL thermostat"
        assert gamma is not None, "gamma must be specified for ANNEAL thermostat"
        assert nbeads is None, "ANNEAL is not compatible with PIMD"
        a1 = math.exp(-gamma * dt)
        a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)

        anneal_parameters = simulation_parameters.get("annealing", {})
        """@keyword[fennol_md] annealing
        Parameters for simulated annealing schedule configuration.
        Required for ANNEAL/ANNEALING thermostat
        """
        init_factor = anneal_parameters.get("init_factor", 1.0 / 25.0)
        """@keyword[fennol_md] annealing/init_factor
        Initial temperature factor for annealing schedule.
        Default: 0.04 (1/25)
        """
        assert init_factor > 0.0, "init_factor must be positive"
        final_factor = anneal_parameters.get("final_factor", 1.0 / 10000.0)
        """@keyword[fennol_md] annealing/final_factor
        Final temperature factor for annealing schedule.
        Default: 0.0001 (1/10000)
        """
        assert final_factor > 0.0, "final_factor must be positive"
        nsteps = simulation_parameters.get("nsteps")
        """@keyword[fennol_md] nsteps
        Total number of simulation steps for annealing schedule calculation.
        Required parameter
        """
        assert nsteps is not None, "nsteps must be specified for ANNEAL thermostat"
        anneal_steps = anneal_parameters.get("anneal_steps", 1.0)
        """@keyword[fennol_md] annealing/anneal_steps
        Fraction of total steps for annealing process.
        Default: 1.0
        """
        # The default (1.0 = whole run) must pass the check, so 1.0 is allowed.
        assert (
            0.0 < anneal_steps <= 1.0
        ), "anneal_steps must be a fraction in (0, 1]"
        pct_start = anneal_parameters.get("warmup_steps", 0.3)
        """@keyword[fennol_md] annealing/warmup_steps
        Fraction of annealing steps for warmup phase.
        Default: 0.3
        """
        assert 0.0 < pct_start < 1.0, "warmup_steps must be a fraction in (0, 1)"

        anneal_type = str(anneal_parameters.get("type", "cosine")).lower()
        """@keyword[fennol_md] annealing/type
        Type of annealing schedule (linear, cosine/cosine_onecycle).
        Default: cosine
        """
        transition_steps = int(anneal_steps * nsteps)
        if anneal_type == "linear":
            schedule = optax.linear_onecycle_schedule(
                peak_value=1.0,
                div_factor=1.0 / init_factor,
                final_div_factor=1.0 / final_factor,
                transition_steps=transition_steps,
                pct_start=pct_start,
                pct_final=1.0,
            )
        elif anneal_type in ["cosine", "cosine_onecycle"]:
            # "cosine" is the documented default and maps to the one-cycle
            # cosine schedule.
            schedule = optax.cosine_onecycle_schedule(
                peak_value=1.0,
                div_factor=1.0 / init_factor,
                final_div_factor=1.0 / final_factor,
                transition_steps=transition_steps,
                pct_start=pct_start,
            )
        else:
            raise ValueError(f"Unknown anneal_type {anneal_type}")

        # Split before storing the key. Otherwise the first noise draw in
        # thermostat() would reuse the key used for the initial velocities.
        rng_key, v_key = jax.random.split(rng_key)
        state["rng_key"] = rng_key
        state["istep_anneal"] = 0

        Tscale = schedule(0)
        print(f"# ANNEAL: initial temperature = {Tscale*kT*au.KELVIN:.3e} K")
        vel = (
            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
            * (kT * Tscale / mass[:, None]) ** 0.5
        )

        def thermostat(vel, state):
            rng_key, noise_key = jax.random.split(state["rng_key"])
            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)

            # The schedule scales the target temperature; the noise amplitude
            # scales with its square root.
            Tscale = schedule(state["istep_anneal"]) ** 0.5
            vel = a1 * vel + a2 * Tscale * noise
            return vel, {
                **state,
                "rng_key": rng_key,
                "istep_anneal": state["istep_anneal"] + 1,
            }

    else:
        raise ValueError(f"Unknown thermostat {thermostat_name}")

    return thermostat, postprocess, state, vel, thermostat_name


def initialize_qtb(
    qtb_parameters,
    system_data,
    fprec,
    dt,
    mass,
    gamma,
    kT,
    species,
    rng_key,
    adaptive,
    compute_thermostat_energy=False,
):
    """Set up the (adaptive) Quantum Thermal Bath thermostat.

    Random forces are regenerated once per segment of ``nseg`` steps. Each
    segment's forces are built by filtering a rolling buffer of white noise
    in Fourier space with a per-species kernel (see ``ff_kernel``). When
    ``adaptive`` is true, the per-species friction ratios ``gammar`` are
    updated from the velocity/force spectra after each segment.

    Files read/written in the working directory:
    ``QTB_spectra_<species>.out``, ``corr_pot.dat`` and
    ``<system_data["name"]>.qtb.restart``.

    Args:
        qtb_parameters: dict of ``qtb/...`` keywords.
        system_data: dict; only ``"name"`` is read here (restart file name).
        fprec: floating-point dtype for generated arrays.
        dt: time step (same units as used by ``get_thermostat``).
        mass: per-atom masses.
        gamma: Langevin friction coefficient (scalar).
        kT: target thermal energy.
        species: per-atom species array.
        rng_key: JAX PRNG key (consumed).
        adaptive: enable adQTB adaptation of ``gammar``.
        compute_thermostat_energy: track the energy exchanged with the bath.

    Returns:
        ``(thermostat, (postprocess, post_state), state)``:

        - ``thermostat(vel, state)`` applies one QTB step.
        - ``postprocess(state, post_state)`` must be called every step; it
          refreshes forces at segment ends.
    """
    state = {}
    post_state = {}
    verbose = qtb_parameters.get("verbose", False)
    """@keyword[fennol_md] qtb/verbose
    Print verbose QTB thermostat information.
    Default: False
    """
    if compute_thermostat_energy:
        state["thermostat_energy"] = 0.0

    mass = jnp.asarray(mass, dtype=fprec)

    nat = species.shape[0]
    # define type indices (species_set is iterated again below in the same order)
    species_set = set(species)
    nspecies = len(species_set)
    idx = {sp: i for i, sp in enumerate(species_set)}
    type_idx = np.array([idx[sp] for sp in species], dtype=np.int32)

    # number of atoms of each type, and average mass per type
    n_of_type = jnp.asarray(np.bincount(type_idx, minlength=nspecies), dtype=fprec)
    mass_idx = jax.ops.segment_sum(mass, type_idx, nspecies) / n_of_type

    niter_deconv_kin = qtb_parameters.get("niter_deconv_kin", 7)
    """@keyword[fennol_md] qtb/niter_deconv_kin
    Number of iterations for kinetic energy deconvolution.
    Default: 7
    """
    niter_deconv_pot = qtb_parameters.get("niter_deconv_pot", 20)
    """@keyword[fennol_md] qtb/niter_deconv_pot
    Number of iterations for potential energy deconvolution.
    Default: 20
    """
    corr_kin = qtb_parameters.get("corr_kin", -1)
    """@keyword[fennol_md] qtb/corr_kin
    Kinetic energy correction factor for QTB (-1 for automatic).
    Default: -1
    """
    # A non-positive value requests automatic estimation, starting from 1.
    do_corr_kin = corr_kin <= 0
    if do_corr_kin:
        corr_kin = 1.0
    state["corr_kin"] = corr_kin
    post_state["corr_kin_prev"] = corr_kin
    post_state["do_corr_kin"] = do_corr_kin
    post_state["isame_kin"] = 0

    # spectra parameters
    omegasmear = np.pi / dt / 100.0
    Tseg = qtb_parameters.get("tseg", 1.0 / au.PS) * au.FS
    """@keyword[fennol_md] qtb/tseg
    Time segment length for QTB spectrum calculation.
    Default: 1.0 ps
    """
    nseg = int(Tseg / dt)
    Tseg = nseg * dt
    # The noise buffer spans 3 segments, hence the 3*Tseg frequency resolution.
    dom = 2 * np.pi / (3 * Tseg)
    omegacut = qtb_parameters.get("omegacut", 15000.0 / au.CM1) / au.FS
    """@keyword[fennol_md] qtb/omegacut
    Cutoff frequency for QTB spectrum.
    Default: 15000.0 cm⁻¹
    """
    nom = int(omegacut / dom)
    omega = dom * np.arange((3 * nseg) // 2 + 1)
    # smooth Fermi-like cutoff around omegacut
    cutoff = jnp.asarray(
        1.0 / (1.0 + np.exp((omega - omegacut) / omegasmear)), dtype=fprec
    )
    # omega is in 1/fs: convert with au.FS*au.CM1 to report cm-1
    assert (
        omegacut < omega[-1]
    ), f"omegacut must be smaller than {omega[-1]*au.FS*au.CM1} CM-1"

    # initialize gammar
    assert (
        gamma < 0.5 * omegacut
    ), "gamma must be much smaller than omegacut (at most 0.5*omegacut)"
    gammar_min = qtb_parameters.get("gammar_min", 0.1)
    """@keyword[fennol_md] qtb/gammar_min
    Minimum value for QTB gamma ratio coefficients.
    Default: 0.1
    """
    gammar = np.ones((nspecies, nom), dtype=float)
    # Best effort: restore gammar from spectra files of a previous run
    # (column 4 holds gammar*gamma in THz); fall back to 1 on any failure.
    try:
        for i, sp in enumerate(species_set):
            if not os.path.exists(f"QTB_spectra_{sp}.out"):
                continue
            data = np.loadtxt(f"QTB_spectra_{sp}.out")
            gammar[i] = data[:, 4] / (gamma * au.FS * au.THZ)
            print(f"# Restored gammar for species {sp} from QTB_spectra_{sp}.out")
    except Exception as e:
        print(
            f"# Could not restore gammar for all species with exception {e}. Starting from scratch."
        )
        gammar[:, :] = 1.0
    post_state["gammar"] = jnp.asarray(gammar, dtype=fprec)

    # Ornstein-Uhlenbeck correction for colored noise
    a1 = np.exp(-gamma * dt)
    OUcorr = jnp.asarray(
        (1.0 - 2.0 * a1 * np.cos(omega * dt) + a1**2) / (dt**2 * (gamma**2 + omega**2)),
        dtype=fprec,
    )

    # hbar schedule
    classical_kernel = qtb_parameters.get("classical_kernel", False)
    """@keyword[fennol_md] qtb/classical_kernel
    Use classical instead of quantum kernel for QTB.
    Default: False
    """
    hbar = qtb_parameters.get("hbar", 1.0) * au.FS
    """@keyword[fennol_md] qtb/hbar
    Reduced Planck constant scaling factor for quantum effects.
    Default: 1.0 a.u.
    """
    # theta(omega) = kT * u / tanh(u), with u = hbar*|omega| / (2 kT);
    # theta equals kT at omega = 0 and whenever hbar == 0.
    u = 0.5 * hbar * np.abs(omega) / kT
    theta = kT * np.ones_like(omega)
    if hbar > 0:
        theta[1:] *= u[1:] / np.tanh(u[1:])
    theta = jnp.asarray(theta, dtype=fprec)

    noise_key, post_state["rng_key"] = jax.random.split(rng_key)
    del rng_key
    post_state["white_noise"] = jax.random.normal(
        noise_key, (3 * nseg, nat, 3), dtype=jnp.float32
    )

    startsave = qtb_parameters.get("startsave", 1)
    """@keyword[fennol_md] qtb/startsave
    Start saving QTB statistics after this many segments.
    Default: 1
    """
    counter = Counter(nseg, startsave=startsave)
    state["istep"] = 0
    post_state["nadapt"] = 0
    post_state["nsample"] = 0

    write_spectra = qtb_parameters.get("write_spectra", True)
    """@keyword[fennol_md] qtb/write_spectra
    Write QTB spectral analysis output files.
    Default: True
    """
    # The automatic kinetic correction needs the averaged velocity spectrum,
    # so spectra are also accumulated when it is active.
    do_compute_spectra = write_spectra or adaptive or do_corr_kin

    if do_compute_spectra:
        state["vel"] = jnp.zeros((nseg, nat, 3), dtype=fprec)

        post_state["dFDT"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["mCvv"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cvf"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cff"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["dFDT_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["mCvv_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cvfg_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cff_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)

    if not adaptive:
        update_gammar = lambda x: x
    else:
        # adaptation parameters
        skipseg = qtb_parameters.get("skipseg", 1)
        """@keyword[fennol_md] qtb/skipseg
        Number of segments to skip before starting adaptive QTB.
        Default: 1
        """

        adaptation_method = (
            str(qtb_parameters.get("adaptation_method", "ADABELIEF")).upper().strip()
        )
        """@keyword[fennol_md] qtb/adaptation_method
        Method for adaptive QTB (SIMPLE, RATIO, ADABELIEF).
        Default: ADABELIEF
        """
        authorized_methods = ["SIMPLE", "RATIO", "ADABELIEF"]
        assert (
            adaptation_method in authorized_methods
        ), f"adaptation_method must be one of {authorized_methods}"
        if adaptation_method == "SIMPLE":
            agamma = qtb_parameters.get("agamma", 1.0e-3) / au.FS
            """@keyword[fennol_md] qtb/agamma
            Learning rate for adaptive QTB gamma update.
            Default: 1.0e-3 (SIMPLE), 0.1 (ADABELIEF)
            """
            assert agamma > 0, "agamma must be positive"
            a1_ad = agamma * Tseg  # * gamma
            print(f"# ADQTB SIMPLE: agamma = {agamma*au.FS:.3f}")

            def update_gammar(post_state):
                # plain gradient step on the FDT deviation
                g = post_state["dFDT"]
                gammar = post_state["gammar"] - a1_ad * g
                gammar = jnp.maximum(gammar_min, gammar)
                return {**post_state, "gammar": gammar}

        elif adaptation_method == "RATIO":
            tau_ad = qtb_parameters.get("tau_ad", 5.0 / au.PS) * au.FS
            """@keyword[fennol_md] qtb/tau_ad
            Adaptation time constant for momentum averaging.
            Default: 5.0 ps (RATIO), 1.0 ps (ADABELIEF)
            """
            tau_s = qtb_parameters.get("tau_s", None)
            """@keyword[fennol_md] qtb/tau_s
            Second moment time constant for variance averaging.
            Default: 10*tau_ad (RATIO), 100*tau_ad (ADABELIEF)
            """
            # tau_ad is already converted, so the tau_ad-based default must not
            # be multiplied by au.FS again; only a user value is converted.
            tau_s = 10 * tau_ad if tau_s is None else tau_s * au.FS
            assert tau_ad > 0, "tau_ad must be positive"
            assert tau_s > 0, "tau_s must be positive"
            print(
                f"# ADQTB RATIO: tau_ad = {tau_ad*1e-3:.2f} ps, tau_s = {tau_s*1e-3:.2f} ps"
            )
            b1 = np.exp(-Tseg / tau_ad)
            b2 = np.exp(-Tseg / tau_s)
            post_state["mCvv_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["Cvf_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["n_adabelief"] = 0

            def update_gammar(post_state):
                # bias-corrected moving averages; gammar = <Cvf>/<mCvv>
                n_adabelief = post_state["n_adabelief"] + 1
                mCvv_m = post_state["mCvv_m"] * b1 + post_state["mCvv"] * (1.0 - b1)
                Cvf_m = post_state["Cvf_m"] * b2 + post_state["Cvf"] * (1.0 - b2)
                mCvv = mCvv_m / (1.0 - b1**n_adabelief)
                Cvf = Cvf_m / (1.0 - b2**n_adabelief)
                gammar = Cvf / (mCvv + 1.0e-8)
                gammar = jnp.maximum(gammar_min, gammar)
                return {
                    **post_state,
                    "gammar": gammar,
                    "mCvv_m": mCvv_m,
                    "Cvf_m": Cvf_m,
                    "n_adabelief": n_adabelief,
                }

        elif adaptation_method == "ADABELIEF":
            agamma = qtb_parameters.get("agamma", 0.1)
            tau_ad = qtb_parameters.get("tau_ad", 1.0 / au.PS) * au.FS
            tau_s = qtb_parameters.get("tau_s", None)
            # same unit handling as RATIO: tau_ad is already converted
            tau_s = 100 * tau_ad if tau_s is None else tau_s * au.FS
            assert tau_ad > 0, "tau_ad must be positive"
            assert tau_s > 0, "tau_s must be positive"
            assert agamma > 0, "agamma must be positive"
            print(
                f"# ADQTB ADABELIEF: agamma = {agamma:.3f}, tau_ad = {tau_ad*1.e-3:.2f} ps, tau_s = {tau_s*1.e-3:.2f} ps"
            )

            a1_ad = agamma * gamma  # * Tseg #* gamma
            b1 = np.exp(-Tseg / tau_ad)
            b2 = np.exp(-Tseg / tau_s)
            post_state["dFDT_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["dFDT_s"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["n_adabelief"] = 0

            def update_gammar(post_state):
                # AdaBelief-style step: first moment of dFDT divided by the
                # square root of its (bias-corrected) deviation variance.
                n_adabelief = post_state["n_adabelief"] + 1
                dFDT = post_state["dFDT"]
                dFDT_m = post_state["dFDT_m"] * b1 + dFDT * (1.0 - b1)
                dFDT_s = (
                    post_state["dFDT_s"] * b2
                    + (dFDT - dFDT_m) ** 2 * (1.0 - b2)
                    + 1.0e-8
                )
                # bias correction
                mt = dFDT_m / (1.0 - b1**n_adabelief)
                st = dFDT_s / (1.0 - b2**n_adabelief)
                gammar = post_state["gammar"] - a1_ad * mt / (st**0.5 + 1.0e-8)
                gammar = jnp.maximum(gammar_min, gammar)
                return {
                    **post_state,
                    "gammar": gammar,
                    "dFDT_m": dFDT_m,
                    "n_adabelief": n_adabelief,
                    "dFDT_s": dFDT_s,
                }

    #####################
    # RESTART
    restart_file = system_data["name"] + ".qtb.restart"
    if os.path.exists(restart_file):
        # NOTE(review): pickle.load executes arbitrary code from the file; only
        # load restart files written by this program.
        with open(restart_file, "rb") as f:
            data = pickle.load(f)
        state["corr_kin"] = data["corr_kin"]
        post_state["corr_kin_prev"] = data["corr_kin"]
        post_state["isame_kin"] = data["isame_kin"]
        post_state["do_corr_kin"] = data["do_corr_kin"]
        print(f"# Restored QTB state from {restart_file}")

    def write_qtb_restart(state, post_state):
        """Persist the kinetic-correction bookkeeping to ``restart_file``."""
        with open(restart_file, "wb") as f:
            pickle.dump(
                {
                    "corr_kin": state["corr_kin"],
                    "corr_kin_prev": post_state["corr_kin_prev"],
                    "isame_kin": post_state["isame_kin"],
                    "do_corr_kin": post_state["do_corr_kin"],
                },
                f,
            )

    ######################

    def compute_corr_pot(niter=20, verbose=False):
        """Return the per-frequency potential-energy correction (ones if classical).

        Also writes ``corr_pot.dat`` for inspection.
        """
        if classical_kernel or hbar == 0:
            return np.ones(nom)

        s_0 = np.array((theta / kT * cutoff)[:nom])
        s_out, s_rec, _ = deconvolute_spectrum(
            s_0,
            omega[:nom],
            gamma,
            niter,
            kernel=kernel_lorentz_pot,
            trans=True,
            symmetrize=True,
            verbose=verbose,
        )
        corr_pot = 1.0 + (s_out - s_0) / s_0
        # omega is in 1/fs: au.FS*au.CM1 converts it to cm-1 (as in the spectra files)
        columns = np.column_stack(
            (omega[:nom] * (au.FS * au.CM1), corr_pot - 1.0, s_0, s_out, s_rec)
        )
        np.savetxt(
            "corr_pot.dat", columns, header="omega(cm-1) corr_pot s_0 s_out s_rec"
        )
        return corr_pot

    def compute_corr_kin(post_state, niter=7, verbose=False):
        """Estimate the kinetic-energy correction from the averaged velocity spectrum.

        Returns ``(corr_kin, post_state)``. The previous value is kept when the
        correction is disabled or converged, when spectra are unavailable, or
        when the reconvolution error is above 5%.
        """
        if not post_state["do_corr_kin"]:
            return post_state["corr_kin_prev"], post_state
        if classical_kernel or hbar == 0:
            return 1.0, post_state
        if "mCvv_avg" not in post_state:
            # no velocity spectrum accumulated: nothing to deconvolute
            return post_state["corr_kin_prev"], post_state

        K_D = post_state.get("K_D", None)
        mCvv = (post_state["mCvv_avg"][:, :nom] * n_of_type[:, None]).sum(axis=0) / nat
        s_0 = np.array(mCvv * kT / theta[:nom] / post_state["corr_pot"])
        s_out, s_rec, K_D = deconvolute_spectrum(
            s_0,
            omega[:nom],
            gamma,
            niter,
            kernel=kernel_lorentz,
            trans=False,
            symmetrize=True,
            verbose=verbose,
            K_D=K_D,
        )
        s_out = s_out * theta[:nom] / kT
        s_rec = s_rec * theta[:nom] / kT * post_state["corr_pot"]
        mCvvsum = mCvv.sum()
        rec_ratio = mCvvsum / s_rec.sum()
        if rec_ratio < 0.95 or rec_ratio > 1.05:
            print(
                f"# WARNING: reconvolution error {rec_ratio} is too high, corr_kin was not updated"
            )
            return post_state["corr_kin_prev"], post_state

        corr_kin = mCvvsum / s_out.sum()
        if np.abs(corr_kin - post_state["corr_kin_prev"]) < 1.0e-4:
            isame_kin = post_state["isame_kin"] + 1
        else:
            isame_kin = 0

        do_corr_kin = post_state["do_corr_kin"]
        if isame_kin > 10:
            print(
                "# INFO: corr_kin is converged (it did not change for 10 consecutive segments)"
            )
            do_corr_kin = False

        return corr_kin, {
            **post_state,
            "corr_kin_prev": corr_kin,
            "isame_kin": isame_kin,
            "do_corr_kin": do_corr_kin,
            "K_D": K_D,
        }

    @jax.jit
    def ff_kernel(post_state):
        """Random-force power spectrum, shape (n_frequencies, nspecies)."""
        if classical_kernel:
            kernel = cutoff * (2 * gamma * kT / dt)
        else:
            kernel = theta * cutoff * OUcorr * (2 * gamma / dt)
        # gammar (and corr_pot) only act below omegacut (first nom frequencies)
        gamma_ratio = jnp.concatenate(
            (
                post_state["gammar"].T * post_state["corr_pot"][:, None],
                jnp.ones(
                    (kernel.shape[0] - nom, nspecies), dtype=post_state["gammar"].dtype
                ),
            ),
            axis=0,
        )
        return kernel[:, None] * gamma_ratio * mass_idx[None, :]

    @jax.jit
    def refresh_force(post_state):
        """Shift the 3-segment noise buffer and return the next segment of colored forces."""
        rng_key, noise_key = jax.random.split(post_state["rng_key"])
        # drop the oldest segment and append a fresh one
        white_noise = jnp.concatenate(
            (
                post_state["white_noise"][nseg:],
                jax.random.normal(
                    noise_key, (nseg, nat, 3), dtype=post_state["white_noise"].dtype
                ),
            ),
            axis=0,
        )
        amplitude = ff_kernel(post_state) ** 0.5
        s = jnp.fft.rfft(white_noise, 3 * nseg, axis=0) * amplitude[:, type_idx, None]
        # keep the middle segment of the filtered signal
        force = jnp.fft.irfft(s, 3 * nseg, axis=0)[nseg : 2 * nseg]
        return force, {**post_state, "rng_key": rng_key, "white_noise": white_noise}

    @jax.jit
    def compute_spectra(force, vel, post_state):
        """Per-species velocity/force spectra of the last segment and their running averages."""
        sf = jnp.fft.rfft(force / gamma, 3 * nseg, axis=0, norm="ortho")
        sv = jnp.fft.rfft(vel, 3 * nseg, axis=0, norm="ortho")
        Cvv = jnp.sum(jnp.abs(sv[:nom]) ** 2, axis=-1).T
        Cff = jnp.sum(jnp.abs(sf[:nom]) ** 2, axis=-1).T
        Cvf = jnp.sum(jnp.real(sv[:nom] * jnp.conj(sf[:nom])), axis=-1).T

        mCvv = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["mCvv"]).at[type_idx].add(Cvv)
            * mass_idx[:, None]
            / n_of_type[:, None]
        )
        Cvf = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["Cvf"]).at[type_idx].add(Cvf)
            / n_of_type[:, None]
        )
        Cff = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["Cff"]).at[type_idx].add(Cff)
            / n_of_type[:, None]
        )
        # deviation from the fluctuation-dissipation relation
        dFDT = mCvv * post_state["gammar"] - Cvf

        # cumulative average over nsample segments
        nsinv = 1.0 / post_state["nsample"]
        b1 = 1.0 - nsinv
        dFDT_avg = post_state["dFDT_avg"] * b1 + dFDT * nsinv
        mCvv_avg = post_state["mCvv_avg"] * b1 + mCvv * nsinv
        Cvfg_avg = post_state["Cvfg_avg"] * b1 + Cvf / post_state["gammar"] * nsinv
        Cff_avg = post_state["Cff_avg"] * b1 + Cff * nsinv

        return {
            **post_state,
            "mCvv": mCvv,
            "Cvf": Cvf,
            "Cff": Cff,
            "dFDT": dFDT,
            "dFDT_avg": dFDT_avg,
            "mCvv_avg": mCvv_avg,
            "Cvfg_avg": Cvfg_avg,
            "Cff_avg": Cff_avg,
        }

    def write_spectra_to_file(post_state):
        """Write one ``QTB_spectra_<species>.out`` file per species (7 columns)."""
        mCvv_avg = np.array(post_state["mCvv_avg"])
        Cvfg_avg = np.array(post_state["Cvfg_avg"])
        Cff_avg = np.array(post_state["Cff_avg"]) * 3.0 / dt / (gamma**2)
        dFDT_avg = np.array(post_state["dFDT_avg"])
        gammar = np.array(post_state["gammar"])
        Cff_theo = np.array(ff_kernel(post_state))[:nom].T
        for i, sp in enumerate(species_set):
            ff_scale = au.KELVIN / ((2 * gamma / dt) * mass_idx[i])
            columns = np.column_stack(
                (
                    omega[:nom] * (au.FS * au.CM1),
                    mCvv_avg[i],
                    Cvfg_avg[i],
                    dFDT_avg[i],
                    gammar[i] * gamma * (au.FS * au.THZ),
                    Cff_avg[i] * ff_scale,
                    Cff_theo[i] * ff_scale,
                )
            )
            # column 4 (gammar*gamma) is read back at startup to restore gammar
            np.savetxt(
                f"QTB_spectra_{sp}.out",
                columns,
                fmt="%12.6f",
                header="omega mCvv Cvf dFDT gammar Cff Cff_theo",
            )
        if verbose:
            print("# QTB spectra written.")

    if compute_thermostat_energy:
        state["qtb_energy_flux"] = 0.0

    @jax.jit
    def thermostat(vel, state):
        """One QTB step: friction plus the precomputed random force for this step."""
        istep = state["istep"]
        dvel = dt * state["force"][istep] / mass[:, None]
        new_vel = vel * a1 + dvel
        new_state = {**state, "istep": istep + 1}
        if do_compute_spectra:
            # store the velocity used for the spectra of this segment
            vel2 = state["vel"].at[istep].set(vel * a1**0.5 + 0.5 * dvel)
            new_state["vel"] = vel2
        if compute_thermostat_energy:
            dek = 0.5 * (mass[:, None] * (vel**2 - new_vel**2)).sum()
            ekcorr = (
                0.5
                * (mass[:, None] * new_vel**2).sum()
                * (1.0 - 1.0 / state.get("corr_kin", 1.0))
            )
            new_state["qtb_energy_flux"] = state["qtb_energy_flux"] + dek
            new_state["thermostat_energy"] = new_state["qtb_energy_flux"] + ekcorr
        return new_vel, new_state

    @jax.jit
    def postprocess_work(state, post_state):
        """Segment-end work: spectra, optional adaptation, new random forces."""
        if do_compute_spectra:
            post_state = compute_spectra(state["force"], state["vel"], post_state)
        if adaptive:
            post_state = jax.lax.cond(
                post_state["nadapt"] > skipseg, update_gammar, lambda x: x, post_state
            )
        new_force, post_state = refresh_force(post_state)
        return {**state, "force": new_force}, post_state

    def postprocess(state, post_state):
        """Advance the step counter; at segment boundaries refresh forces and outputs.

        Note: mutates the ``post_state`` dict passed in (``nadapt``/``nsample``).
        """
        counter.increment()
        if not counter.is_reset_step:
            return state, post_state
        post_state["nadapt"] += 1
        post_state["nsample"] = max(post_state["nadapt"] - startsave + 1, 1)
        if verbose:
            print("# Refreshing QTB forces.")
        state, post_state = postprocess_work(state, post_state)
        state["corr_kin"], post_state = compute_corr_kin(
            post_state, niter=niter_deconv_kin
        )
        state["istep"] = 0
        if write_spectra:
            write_spectra_to_file(post_state)
        write_qtb_restart(state, post_state)
        return state, post_state

    post_state["corr_pot"] = jnp.asarray(
        compute_corr_pot(niter=niter_deconv_pot), dtype=fprec
    )

    state["force"], post_state = refresh_force(post_state)
    return thermostat, (postprocess, post_state), state
def
get_thermostat( simulation_parameters, dt, system_data, fprec, rng_key=None, restart_data={}):
def get_thermostat(
    simulation_parameters, dt, system_data, fprec, rng_key=None, restart_data=None
):
    """Build the thermostat selected by ``simulation_parameters["thermostat"]``.

    Args:
        simulation_parameters: dict of user parameters (``thermostat``, ``gamma``,
            ``include_thermostat_energy``, ``trpmd_lambda``, ``qtb``,
            ``annealing``, ``nsteps``, ...).
        dt: integration time step, in the same time unit as the friction ``gamma``
            (the damping factor is ``exp(-gamma * dt)``).
        system_data: dict with at least ``mass`` and ``species``; optionally
            ``kT``, ``nbeads`` and ``omk`` (read only when ``nbeads`` is set).
        fprec: floating point dtype of the generated arrays.
        rng_key: JAX PRNG key, required by the stochastic thermostats and for
            drawing the initial velocities.
        restart_data: currently unused; kept for interface compatibility.

    Returns:
        ``(thermostat, postprocess, state, vel, thermostat_name)`` where
        ``thermostat(vel, state) -> (vel, state)`` applies one thermostat step,
        ``postprocess`` is ``None`` except for QTB/ADQTB (then it is the
        ``(postprocess_fn, post_state)`` pair returned by ``initialize_qtb``),
        ``state`` is the initial thermostat state and ``vel`` holds the initial
        velocities.

    Raises:
        ValueError: unknown thermostat name or annealing schedule type.
    """
    state = {}
    postprocess = None

    thermostat_name = str(simulation_parameters.get("thermostat", "LGV")).upper()
    """@keyword[fennol_md] thermostat
    Thermostat type. Options: 'NVE', 'LGV', 'FFLGV', 'BUSSI', 'NOSE', 'GD', 'QTB', 'ADQTB', 'ANNEAL'.
    Default: "LGV"
    """
    compute_thermostat_energy = simulation_parameters.get(
        "include_thermostat_energy", False
    )
    """@keyword[fennol_md] include_thermostat_energy
    Include thermostat energy in total energy calculations.
    Default: False
    """

    kT = system_data.get("kT", None)
    nbeads = system_data.get("nbeads", None)
    mass = system_data["mass"]
    gamma = simulation_parameters.get("gamma", 1.0 / au.THZ) / au.FS
    """@keyword[fennol_md] gamma
    Friction coefficient for Langevin thermostat.
    Default: 1.0 ps^-1
    """
    species = system_data["species"]

    if nbeads is not None:
        trpmd_lambda = simulation_parameters.get("trpmd_lambda", 1.0)
        """@keyword[fennol_md] trpmd_lambda
        Lambda parameter for TRPMD (Thermostatted Ring Polymer MD).
        Default: 1.0
        """
        # each ring-polymer mode gets a friction of at least trpmd_lambda * omk
        gamma = np.maximum(trpmd_lambda * system_data["omk"], gamma)

    if thermostat_name in ["LGV", "LANGEVIN", "FFLGV"]:
        assert rng_key is not None, "rng_key must be provided for Langevin thermostat"
        assert kT is not None, "kT must be specified for Langevin thermostat"
        assert gamma is not None, "gamma must be specified for Langevin thermostat"
        rng_key, v_key = jax.random.split(rng_key)
        if nbeads is None:
            a1 = math.exp(-gamma * dt)
            a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)
            vel = (
                jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
                * (kT / mass[:, None]) ** 0.5
            )
        else:
            # accept a scalar friction (python/numpy scalar or 0-d array)
            # or one friction coefficient per bead
            gamma = np.asarray(gamma, dtype=float)
            if gamma.ndim == 0:
                gamma = np.full(nbeads, float(gamma))
            assert gamma.shape == (
                nbeads,
            ), "gamma must be a scalar or an array with the same length as nbeads"
            a1 = np.exp(-gamma * dt)[:, None, None]
            a2 = jnp.asarray(
                ((1 - a1 * a1) * kT / mass[None, :, None]) ** 0.5, dtype=fprec
            )
            vel = (
                jax.random.normal(v_key, (nbeads, mass.shape[0], 3), dtype=fprec)
                * (kT / mass[None, :, None]) ** 0.5
            )

        state["rng_key"] = rng_key
        if compute_thermostat_energy:
            state["thermostat_energy"] = 0.0

        if thermostat_name == "FFLGV":

            def thermostat(vel, state):
                # Langevin step applied to the velocity norm only: each atom keeps
                # its velocity direction and takes the norm of the Langevin update.
                rng_key, noise_key = jax.random.split(state["rng_key"])
                noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
                norm_vel = jnp.linalg.norm(vel, axis=-1, keepdims=True)
                # guard the 0/0 case of an atom at rest (would otherwise give NaN)
                dirvel = vel / jnp.where(norm_vel > 0, norm_vel, 1.0)
                if compute_thermostat_energy:
                    v2 = (vel**2).sum(axis=-1)
                new_vel = a1 * vel + a2 * noise
                new_norm_vel = jnp.linalg.norm(new_vel, axis=-1, keepdims=True)
                new_vel = dirvel * new_norm_vel
                new_state = {**state, "rng_key": rng_key}
                if compute_thermostat_energy:
                    v2new = (new_vel**2).sum(axis=-1)
                    # accumulate the kinetic energy removed from the system
                    new_state["thermostat_energy"] = (
                        state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
                    )
                return new_vel, new_state

        else:

            def thermostat(vel, state):
                # exact Ornstein-Uhlenbeck step: v <- a1 v + a2 * N(0, 1)
                rng_key, noise_key = jax.random.split(state["rng_key"])
                noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
                if compute_thermostat_energy:
                    v2 = (vel**2).sum(axis=-1)
                new_vel = a1 * vel + a2 * noise
                new_state = {**state, "rng_key": rng_key}
                if compute_thermostat_energy:
                    v2new = (new_vel**2).sum(axis=-1)
                    # accumulate the kinetic energy removed from the system
                    new_state["thermostat_energy"] = (
                        state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
                    )
                return new_vel, new_state

    elif thermostat_name in ["BUSSI"]:
        assert rng_key is not None, "rng_key must be provided for Bussi thermostat"
        assert kT is not None, "kT must be specified for Bussi thermostat"
        assert gamma is not None, "gamma must be specified for Bussi thermostat"
        assert nbeads is None, "Bussi thermostat is not compatible with PIMD"
        rng_key, v_key = jax.random.split(rng_key)

        a1 = math.exp(-gamma * dt)
        a2 = (1 - a1) * kT
        vel = (
            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
            * (kT / mass[:, None]) ** 0.5
        )

        state["rng_key"] = rng_key
        if compute_thermostat_energy:
            state["thermostat_energy"] = 0.0

        def thermostat(vel, state):
            # global velocity rescaling:
            # scale^2 = a1 + (1-a1) kT/sum(m v^2) * sum(R^2)
            #           + 2 R1 sqrt(a1 (1-a1) kT/sum(m v^2))
            rng_key, noise_key = jax.random.split(state["rng_key"])
            new_state = {**state, "rng_key": rng_key}
            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
            R2 = jnp.sum(noise**2)
            R1 = noise[0, 0]
            mv2 = (mass[:, None] * vel**2).sum()  # twice the kinetic energy
            c = a2 / mv2
            d = (a1 * c) ** 0.5
            scale = (a1 + c * R2 + 2 * d * R1) ** 0.5
            if compute_thermostat_energy:
                # kinetic energy removed from the system (KE_old - KE_new), the
                # same sign convention as the Langevin and QTB thermostats
                dek = 0.5 * mv2 * (1.0 - scale**2)
                new_state["thermostat_energy"] = state["thermostat_energy"] + dek
            return scale * vel, new_state

    elif thermostat_name in [
        "GD",
        "DESCENT",
        "GRADIENT_DESCENT",
        "MIN",
        "MINIMIZE",
    ]:
        assert nbeads is None, "Gradient descent is not compatible with PIMD"
        a1 = math.exp(-gamma * dt)
        vel = jnp.zeros((mass.shape[0], 3), dtype=fprec)

        def thermostat(vel, state):
            # pure damping: velocities decay by exp(-gamma*dt) at each step
            return a1 * vel, state

    elif thermostat_name in ["NVE", "NONE"]:
        assert rng_key is not None, "rng_key must be provided to initialize velocities"
        assert kT is not None, "kT must be specified to initialize velocities"
        if nbeads is None:
            vel = (
                jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
                * (kT / mass[:, None]) ** 0.5
            )
            # rescale so that the initial kinetic temperature is exactly kT
            kTsys = jnp.sum(mass[:, None] * vel**2) / (mass.shape[0] * 3)
            vel = vel * (kT / kTsys) ** 0.5
        else:
            vel = (
                jax.random.normal(rng_key, (nbeads, mass.shape[0], 3), dtype=fprec)
                * (kT / mass[None, :, None]) ** 0.5
            )
            # rescale each bead separately to the target temperature
            kTsys = jnp.sum(mass[None, :, None] * vel**2, axis=(1, 2)) / (
                mass.shape[0] * 3
            )
            vel = vel * (kT / kTsys[:, None, None]) ** 0.5

        def thermostat(vel, state):
            return vel, state

    elif thermostat_name in ["NOSE", "NOSEHOOVER", "NOSE_HOOVER"]:
        assert nbeads is None, "Nose-Hoover is not compatible with PIMD"
        assert rng_key is not None, "rng_key must be provided for Nose-Hoover thermostat"
        assert kT is not None, "kT must be specified for Nose-Hoover thermostat"
        assert gamma is not None, "gamma must be specified for Nose-Hoover thermostat"
        ndof = mass.shape[0] * 3
        nkT = ndof * kT
        # thermostat mass chosen so that its characteristic frequency is gamma
        nose_mass = nkT / gamma**2
        state["nose_s"] = 0.0
        state["nose_v"] = 0.0
        if compute_thermostat_energy:
            state["thermostat_energy"] = 0.0
        print(
            "# WARNING: Nose-Hoover thermostat is not well tested yet. Energy conservation is not guaranteed."
        )
        vel = (
            jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
            * (kT / mass[:, None]) ** 0.5
        )

        def thermostat(vel, state):
            # symmetric splitting: half-step on nose_v, full step on nose_s and
            # the velocity scaling, then another half-step on nose_v
            nose_s = state["nose_s"]
            nose_v = state["nose_v"]
            kTsys = jnp.sum(mass[:, None] * vel**2)
            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
            nose_s = nose_s + dt * nose_v
            vel = jnp.exp(-nose_v * dt) * vel
            kTsys = jnp.sum(mass[:, None] * vel**2)
            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
            new_state = {**state, "nose_s": nose_s, "nose_v": nose_v}

            if compute_thermostat_energy:
                new_state["thermostat_energy"] = (
                    nkT * nose_s + (0.5 * nose_mass) * nose_v**2
                )
            return vel, new_state

    elif thermostat_name in ["QTB", "ADQTB"]:
        assert nbeads is None, "QTB is not compatible with PIMD"
        qtb_parameters = simulation_parameters.get("qtb", None)
        """@keyword[fennol_md] qtb
        Parameters for Quantum Thermal Bath thermostat configuration.
        Required for QTB/ADQTB thermostats
        """
        assert (
            qtb_parameters is not None
        ), "qtb_parameters must be provided for QTB thermostat"
        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
        assert kT is not None, "kT must be specified for QTB thermostat"
        assert gamma is not None, "gamma must be specified for QTB thermostat"
        assert species is not None, "species must be provided for QTB thermostat"
        rng_key, v_key = jax.random.split(rng_key)
        vel = (
            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
            * (kT / mass[:, None]) ** 0.5
        )

        thermostat, postprocess, qtb_state = initialize_qtb(
            qtb_parameters,
            system_data,
            fprec=fprec,
            dt=dt,
            mass=mass,
            gamma=gamma,
            kT=kT,
            species=species,
            rng_key=rng_key,
            adaptive=thermostat_name.startswith("AD"),
            compute_thermostat_energy=compute_thermostat_energy,
        )
        state = {**state, **qtb_state}

    elif thermostat_name in ["ANNEAL", "ANNEALING"]:
        assert rng_key is not None, "rng_key must be provided for ANNEAL thermostat"
        assert kT is not None, "kT must be specified for ANNEAL thermostat"
        assert gamma is not None, "gamma must be specified for ANNEAL thermostat"
        assert nbeads is None, "ANNEAL is not compatible with PIMD"
        a1 = math.exp(-gamma * dt)
        a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)

        anneal_parameters = simulation_parameters.get("annealing", {})
        """@keyword[fennol_md] annealing
        Parameters for simulated annealing schedule configuration.
        Required for ANNEAL/ANNEALING thermostat
        """
        init_factor = anneal_parameters.get("init_factor", 1.0 / 25.0)
        """@keyword[fennol_md] annealing/init_factor
        Initial temperature factor for annealing schedule.
        Default: 0.04 (1/25)
        """
        assert init_factor > 0.0, "init_factor must be positive"
        final_factor = anneal_parameters.get("final_factor", 1.0 / 10000.0)
        """@keyword[fennol_md] annealing/final_factor
        Final temperature factor for annealing schedule.
        Default: 0.0001 (1/10000)
        """
        assert final_factor > 0.0, "final_factor must be positive"
        nsteps = simulation_parameters.get("nsteps")
        """@keyword[fennol_md] nsteps
        Total number of simulation steps for annealing schedule calculation.
        Required parameter
        """
        assert nsteps is not None, "nsteps must be specified for ANNEAL thermostat"
        anneal_steps = anneal_parameters.get("anneal_steps", 1.0)
        """@keyword[fennol_md] annealing/anneal_steps
        Fraction of total steps for annealing process.
        Default: 1.0
        """
        # the default (1.0 = anneal over the whole run) must be accepted
        assert 0.0 < anneal_steps <= 1.0, "anneal_steps must be between 0 and 1"
        pct_start = anneal_parameters.get("warmup_steps", 0.3)
        """@keyword[fennol_md] annealing/warmup_steps
        Fraction of annealing steps for warmup phase.
        Default: 0.3
        """
        assert 0.0 < pct_start < 1.0, "warmup_steps must be between 0 and 1"
        transition_steps = max(1, int(anneal_steps * nsteps))

        anneal_type = anneal_parameters.get("type", "cosine").lower()
        """@keyword[fennol_md] annealing/type
        Type of annealing schedule (linear, cosine or cosine_onecycle).
        Default: cosine
        """
        if anneal_type == "linear":
            schedule = optax.linear_onecycle_schedule(
                peak_value=1.0,
                div_factor=1.0 / init_factor,
                final_div_factor=1.0 / final_factor,
                transition_steps=transition_steps,
                pct_start=pct_start,
                pct_final=1.0,
            )
        elif anneal_type in ("cosine", "cosine_onecycle"):
            # "cosine" is the documented default and selects the one-cycle cosine
            schedule = optax.cosine_onecycle_schedule(
                peak_value=1.0,
                div_factor=1.0 / init_factor,
                final_div_factor=1.0 / final_factor,
                transition_steps=transition_steps,
                pct_start=pct_start,
            )
        else:
            raise ValueError(f"Unknown anneal_type {anneal_type}")

        # split before storing the key: storing it first would make the first
        # thermostat step reuse the key that drew the initial velocities
        rng_key, v_key = jax.random.split(rng_key)
        state["rng_key"] = rng_key
        state["istep_anneal"] = 0

        Tscale = schedule(0)
        print(f"# ANNEAL: initial temperature = {Tscale*kT*au.KELVIN:.3e} K")
        vel = (
            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
            * (kT * Tscale / mass[:, None]) ** 0.5
        )

        def thermostat(vel, state):
            rng_key, noise_key = jax.random.split(state["rng_key"])
            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)

            # the schedule scales the temperature; noise amplitude goes as its sqrt
            noise_scale = schedule(state["istep_anneal"]) ** 0.5
            vel = a1 * vel + a2 * noise_scale * noise
            return vel, {
                **state,
                "rng_key": rng_key,
                "istep_anneal": state["istep_anneal"] + 1,
            }

    else:
        raise ValueError(f"Unknown thermostat {thermostat_name}")

    return thermostat, postprocess, state, vel, thermostat_name
def initialize_qtb(
    qtb_parameters,
    system_data,
    fprec,
    dt,
    mass,
    gamma,
    kT,
    species,
    rng_key,
    adaptive,
    compute_thermostat_energy=False,
):
    """Set up a (possibly adaptive) Quantum Thermal Bath thermostat.

    The per-step update applies the friction ``exp(-gamma*dt)`` plus a colored
    random force whose power spectrum is ``theta(omega)`` (the quantum energy
    ``hbar*omega/2 * coth(hbar*omega/(2kT))``). With ``classical_kernel`` the
    spectrum is ``kT`` instead, and ``hbar == 0`` also gives ``theta == kT``.
    The force is generated by segments of ``nseg`` steps. The returned
    ``postprocess`` must be called after every step. When the segment counter
    resets, presumably every ``nseg`` steps (see ``Counter``), it does the
    following: refreshes the force, accumulates the velocity/force spectra,
    adapts the per-species ``gammar`` (ADQTB only) and updates the kinetic
    energy correction ``corr_kin``.

    Args:
        qtb_parameters: dict of QTB options (``qtb/*`` keywords below).
        system_data: dict; ``system_data["name"]`` names the restart file.
        fprec: floating point dtype.
        dt: time step (same time unit as ``gamma``).
        mass: per-atom masses, shape ``(nat,)``.
        gamma: scalar friction coefficient.
        kT: target temperature in energy units.
        species: per-atom species, shape ``(nat,)``.
        rng_key: JAX PRNG key for the random force.
        adaptive: if True, adapt ``gammar`` during the run (ADQTB).
        compute_thermostat_energy: if True, track the energy exchanged with the bath.

    Returns:
        ``(thermostat, (postprocess, post_state), state)``. Here ``thermostat``
        is the jitted per-step update ``(vel, state) -> (vel, state)``.
        ``postprocess`` is ``(state, post_state) -> (state, post_state)`` and
        comes with its initial ``post_state``. ``state`` is the initial
        thermostat state.
    """
    state = {}
    post_state = {}
    verbose = qtb_parameters.get("verbose", False)
    """@keyword[fennol_md] qtb/verbose
    Print verbose QTB thermostat information.
    Default: False
    """
    if compute_thermostat_energy:
        state["thermostat_energy"] = 0.0

    mass = jnp.asarray(mass, dtype=fprec)

    nat = species.shape[0]
    # define type indices (sorted so that the ordering is reproducible)
    species_set = sorted(set(species))
    nspecies = len(species_set)
    idx = {sp: i for i, sp in enumerate(species_set)}
    type_idx = np.array([idx[sp] for sp in species], dtype=np.int32)

    n_of_type = jnp.asarray(np.bincount(type_idx, minlength=nspecies), dtype=fprec)
    # average mass of each species
    mass_idx = jax.ops.segment_sum(mass, type_idx, nspecies) / n_of_type

    niter_deconv_kin = qtb_parameters.get("niter_deconv_kin", 7)
    """@keyword[fennol_md] qtb/niter_deconv_kin
    Number of iterations for kinetic energy deconvolution.
    Default: 7
    """
    niter_deconv_pot = qtb_parameters.get("niter_deconv_pot", 20)
    """@keyword[fennol_md] qtb/niter_deconv_pot
    Number of iterations for potential energy deconvolution.
    Default: 20
    """
    corr_kin = qtb_parameters.get("corr_kin", -1)
    """@keyword[fennol_md] qtb/corr_kin
    Kinetic energy correction factor for QTB (-1 for automatic).
    Default: -1
    """
    # non-positive value => estimate corr_kin on the fly from the spectra
    do_corr_kin = corr_kin <= 0
    if do_corr_kin:
        corr_kin = 1.0
    state["corr_kin"] = corr_kin
    post_state["corr_kin_prev"] = corr_kin
    post_state["do_corr_kin"] = do_corr_kin
    post_state["isame_kin"] = 0

    # spectra parameters
    omegasmear = np.pi / dt / 100.0
    Tseg = qtb_parameters.get("tseg", 1.0 / au.PS) * au.FS
    """@keyword[fennol_md] qtb/tseg
    Time segment length for QTB spectrum calculation.
    Default: 1.0 ps
    """
    nseg = int(Tseg / dt)
    Tseg = nseg * dt
    # frequency grid of an rfft over 3 segments
    dom = 2 * np.pi / (3 * Tseg)
    omegacut = qtb_parameters.get("omegacut", 15000.0 / au.CM1) / au.FS
    """@keyword[fennol_md] qtb/omegacut
    Cutoff frequency for QTB spectrum.
    Default: 15000.0 cm⁻¹
    """
    nom = int(omegacut / dom)
    omega = dom * np.arange((3 * nseg) // 2 + 1)
    # smooth (Fermi-like) cutoff around omegacut
    cutoff = jnp.asarray(
        1.0 / (1.0 + np.exp((omega - omegacut) / omegasmear)), dtype=fprec
    )
    # internal frequencies convert to cm-1 with au.FS * au.CM1 (inverse of omegacut)
    assert (
        omegacut < omega[-1]
    ), f"omegacut must be smaller than {omega[-1]*au.FS*au.CM1} CM-1"

    # initialize gammar
    assert (
        gamma < 0.5 * omegacut
    ), "gamma must be much smaller than omegacut (at most 0.5*omegacut)"
    gammar_min = qtb_parameters.get("gammar_min", 0.1)
    """@keyword[fennol_md] qtb/gammar_min
    Minimum value for QTB gamma ratio coefficients.
    Default: 0.1
    """
    gammar = np.ones((nspecies, nom), dtype=float)
    # best effort: restore gammar from previously written spectra files
    try:
        for i, sp in enumerate(species_set):
            if not os.path.exists(f"QTB_spectra_{sp}.out"):
                continue
            data = np.loadtxt(f"QTB_spectra_{sp}.out")
            gammar[i] = data[:, 4] / (gamma * au.FS * au.THZ)
            print(f"# Restored gammar for species {sp} from QTB_spectra_{sp}.out")
    except Exception as e:
        print(
            f"# Could not restore gammar for all species with exception {e}. Starting from scratch."
        )
        gammar[:, :] = 1.0
    post_state["gammar"] = jnp.asarray(gammar, dtype=fprec)

    # Ornstein-Uhlenbeck correction for colored noise: compensates the
    # finite-dt response of the discrete friction step a1 = exp(-gamma*dt)
    a1 = np.exp(-gamma * dt)
    OUcorr = jnp.asarray(
        (1.0 - 2.0 * a1 * np.cos(omega * dt) + a1**2) / (dt**2 * (gamma**2 + omega**2)),
        dtype=fprec,
    )

    # hbar schedule
    classical_kernel = qtb_parameters.get("classical_kernel", False)
    """@keyword[fennol_md] qtb/classical_kernel
    Use classical instead of quantum kernel for QTB.
    Default: False
    """
    hbar = qtb_parameters.get("hbar", 1.0) * au.FS
    """@keyword[fennol_md] qtb/hbar
    Reduced Planck constant scaling factor for quantum effects.
    Default: 1.0 a.u.
    """
    # theta(omega) = kT * u / tanh(u) with u = hbar*omega/(2kT); theta(0) = kT
    u = 0.5 * hbar * np.abs(omega) / kT
    theta = kT * np.ones_like(omega)
    if hbar > 0:
        theta[1:] *= u[1:] / np.tanh(u[1:])
    theta = jnp.asarray(theta, dtype=fprec)

    noise_key, post_state["rng_key"] = jax.random.split(rng_key)
    del rng_key
    # 3 segments of white noise; the middle one is used to build the force
    post_state["white_noise"] = jax.random.normal(
        noise_key, (3 * nseg, nat, 3), dtype=jnp.float32
    )

    startsave = qtb_parameters.get("startsave", 1)
    """@keyword[fennol_md] qtb/startsave
    Start saving QTB statistics after this many segments.
    Default: 1
    """
    counter = Counter(nseg, startsave=startsave)
    state["istep"] = 0
    post_state["nadapt"] = 0
    post_state["nsample"] = 0

    write_spectra = qtb_parameters.get("write_spectra", True)
    """@keyword[fennol_md] qtb/write_spectra
    Write QTB spectral analysis output files.
    Default: True
    """
    # the automatic corr_kin estimate also needs the accumulated spectra
    do_compute_spectra = write_spectra or adaptive or do_corr_kin

    if do_compute_spectra:
        state["vel"] = jnp.zeros((nseg, nat, 3), dtype=fprec)

        post_state["dFDT"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["mCvv"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cvf"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cff"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["dFDT_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["mCvv_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cvfg_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cff_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)

    if not adaptive:
        update_gammar = lambda x: x
    else:
        # adaptation parameters
        skipseg = qtb_parameters.get("skipseg", 1)
        """@keyword[fennol_md] qtb/skipseg
        Number of segments to skip before starting adaptive QTB.
        Default: 1
        """

        adaptation_method = (
            str(qtb_parameters.get("adaptation_method", "ADABELIEF")).upper().strip()
        )
        """@keyword[fennol_md] qtb/adaptation_method
        Method for adaptive QTB (SIMPLE, RATIO, ADABELIEF).
        Default: ADABELIEF
        """
        authorized_methods = ["SIMPLE", "RATIO", "ADABELIEF"]
        assert (
            adaptation_method in authorized_methods
        ), f"adaptation_method must be one of {authorized_methods}"
        if adaptation_method == "SIMPLE":
            agamma = qtb_parameters.get("agamma", 1.0e-3) / au.FS
            """@keyword[fennol_md] qtb/agamma
            Learning rate for adaptive QTB gamma update.
            Default: 1.0e-3 (SIMPLE), 0.1 (ADABELIEF)
            """
            assert agamma > 0, "agamma must be positive"
            a1_ad = agamma * Tseg
            print(f"# ADQTB SIMPLE: agamma = {agamma*au.FS:.3f}")

            def update_gammar(post_state):
                # plain gradient step on the FDT deviation
                g = post_state["dFDT"]
                gammar = post_state["gammar"] - a1_ad * g
                gammar = jnp.maximum(gammar_min, gammar)
                return {**post_state, "gammar": gammar}

        elif adaptation_method == "RATIO":
            tau_ad = qtb_parameters.get("tau_ad", 5.0 / au.PS) * au.FS
            """@keyword[fennol_md] qtb/tau_ad
            Adaptation time constant for momentum averaging.
            Default: 5.0 ps (RATIO), 1.0 ps (ADABELIEF)
            """
            tau_s = qtb_parameters.get("tau_s", None)
            """@keyword[fennol_md] qtb/tau_s
            Second moment time constant for variance averaging.
            Default: 10*tau_ad (RATIO), 100*tau_ad (ADABELIEF)
            """
            # tau_ad is already converted: only a user value needs conversion
            tau_s = 10 * tau_ad if tau_s is None else tau_s * au.FS
            assert tau_ad > 0, "tau_ad must be positive"
            assert tau_s > 0, "tau_s must be positive"
            print(
                f"# ADQTB RATIO: tau_ad = {tau_ad*1e-3:.2f} ps, tau_s = {tau_s*1e-3:.2f} ps"
            )
            b1 = np.exp(-Tseg / tau_ad)
            b2 = np.exp(-Tseg / tau_s)
            post_state["mCvv_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["Cvf_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["n_adabelief"] = 0

            def update_gammar(post_state):
                # gammar = <Cvf> / <m Cvv> with bias-corrected moving averages
                n_adabelief = post_state["n_adabelief"] + 1
                mCvv_m = post_state["mCvv_m"] * b1 + post_state["mCvv"] * (1.0 - b1)
                Cvf_m = post_state["Cvf_m"] * b2 + post_state["Cvf"] * (1.0 - b2)
                mCvv = mCvv_m / (1.0 - b1**n_adabelief)
                Cvf = Cvf_m / (1.0 - b2**n_adabelief)
                gammar = Cvf / (mCvv + 1.0e-8)
                gammar = jnp.maximum(gammar_min, gammar)
                return {
                    **post_state,
                    "gammar": gammar,
                    "mCvv_m": mCvv_m,
                    "Cvf_m": Cvf_m,
                    "n_adabelief": n_adabelief,
                }

        elif adaptation_method == "ADABELIEF":
            agamma = qtb_parameters.get("agamma", 0.1)
            tau_ad = qtb_parameters.get("tau_ad", 1.0 / au.PS) * au.FS
            tau_s = qtb_parameters.get("tau_s", None)
            # tau_ad is already converted: only a user value needs conversion
            tau_s = 100 * tau_ad if tau_s is None else tau_s * au.FS
            assert tau_ad > 0, "tau_ad must be positive"
            assert tau_s > 0, "tau_s must be positive"
            assert agamma > 0, "agamma must be positive"
            print(
                f"# ADQTB ADABELIEF: agamma = {agamma:.3f}, tau_ad = {tau_ad*1.e-3:.2f} ps, tau_s = {tau_s*1.e-3:.2f} ps"
            )

            a1_ad = agamma * gamma
            b1 = np.exp(-Tseg / tau_ad)
            b2 = np.exp(-Tseg / tau_s)
            post_state["dFDT_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["dFDT_s"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["n_adabelief"] = 0

            def update_gammar(post_state):
                # AdaBelief-like step on the FDT deviation dFDT
                n_adabelief = post_state["n_adabelief"] + 1
                dFDT = post_state["dFDT"]
                dFDT_m = post_state["dFDT_m"] * b1 + dFDT * (1.0 - b1)
                dFDT_s = (
                    post_state["dFDT_s"] * b2
                    + (dFDT - dFDT_m) ** 2 * (1.0 - b2)
                    + 1.0e-8
                )
                # bias correction
                mt = dFDT_m / (1.0 - b1**n_adabelief)
                st = dFDT_s / (1.0 - b2**n_adabelief)
                gammar = post_state["gammar"] - a1_ad * mt / (st**0.5 + 1.0e-8)
                gammar = jnp.maximum(gammar_min, gammar)
                return {
                    **post_state,
                    "gammar": gammar,
                    "dFDT_m": dFDT_m,
                    "n_adabelief": n_adabelief,
                    "dFDT_s": dFDT_s,
                }

    #####################
    # RESTART
    restart_file = system_data["name"] + ".qtb.restart"
    if os.path.exists(restart_file):
        # NOTE(review): pickle.load can execute arbitrary code; only restart
        # from files written by this program.
        with open(restart_file, "rb") as f:
            data = pickle.load(f)
        state["corr_kin"] = data["corr_kin"]
        post_state["corr_kin_prev"] = data["corr_kin"]
        post_state["isame_kin"] = data["isame_kin"]
        post_state["do_corr_kin"] = data["do_corr_kin"]
        print(f"# Restored QTB state from {restart_file}")

    def write_qtb_restart(state, post_state):
        with open(restart_file, "wb") as f:
            pickle.dump(
                {
                    "corr_kin": state["corr_kin"],
                    "corr_kin_prev": post_state["corr_kin_prev"],
                    "isame_kin": post_state["isame_kin"],
                    "do_corr_kin": post_state["do_corr_kin"],
                },
                f,
            )

    ######################

    def compute_corr_pot(niter=20, verbose=False):
        """Per-frequency potential energy correction from spectrum deconvolution."""
        if classical_kernel or hbar == 0:
            return np.ones(nom)

        s_0 = np.array((theta / kT * cutoff)[:nom])
        s_out, s_rec, _ = deconvolute_spectrum(
            s_0,
            omega[:nom],
            gamma,
            niter,
            kernel=kernel_lorentz_pot,
            trans=True,
            symmetrize=True,
            verbose=verbose,
        )
        corr_pot = 1.0 + (s_out - s_0) / s_0
        # omega is written in cm-1 (same conversion as in the spectra files)
        columns = np.column_stack(
            (omega[:nom] * (au.FS * au.CM1), corr_pot - 1.0, s_0, s_out, s_rec)
        )
        np.savetxt(
            "corr_pot.dat", columns, header="omega(cm-1) corr_pot s_0 s_out s_rec"
        )
        return corr_pot

    def compute_corr_kin(post_state, niter=7, verbose=False):
        """Estimate the kinetic energy correction from the averaged velocity spectrum.

        Returns ``(corr_kin, post_state)``; the previous value is kept when the
        estimate is disabled, converged, unavailable or unreliable.
        """
        if not post_state["do_corr_kin"]:
            return post_state["corr_kin_prev"], post_state
        if classical_kernel or hbar == 0:
            return 1.0, post_state
        if "mCvv_avg" not in post_state:
            # spectra were not accumulated: nothing to deconvolute
            return post_state["corr_kin_prev"], post_state

        K_D = post_state.get("K_D", None)
        mCvv = (post_state["mCvv_avg"][:, :nom] * n_of_type[:, None]).sum(axis=0) / nat
        s_0 = np.array(mCvv * kT / theta[:nom] / post_state["corr_pot"])
        s_out, s_rec, K_D = deconvolute_spectrum(
            s_0,
            omega[:nom],
            gamma,
            niter,
            kernel=kernel_lorentz,
            trans=False,
            symmetrize=True,
            verbose=verbose,
            K_D=K_D,
        )
        s_out = s_out * theta[:nom] / kT
        s_rec = s_rec * theta[:nom] / kT * post_state["corr_pot"]
        mCvvsum = mCvv.sum()
        rec_ratio = mCvvsum / s_rec.sum()
        if rec_ratio < 0.95 or rec_ratio > 1.05:
            print(
                f"# WARNING: reconvolution error {rec_ratio} is too high, corr_kin was not updated"
            )
            return post_state["corr_kin_prev"], post_state

        corr_kin = mCvvsum / s_out.sum()
        if np.abs(corr_kin - post_state["corr_kin_prev"]) < 1.0e-4:
            isame_kin = post_state["isame_kin"] + 1
        else:
            isame_kin = 0

        do_corr_kin = post_state["do_corr_kin"]
        if isame_kin > 10:
            print(
                "# INFO: corr_kin is converged (it did not change for 10 consecutive segments)"
            )
            do_corr_kin = False

        return corr_kin, {
            **post_state,
            "corr_kin_prev": corr_kin,
            "isame_kin": isame_kin,
            "do_corr_kin": do_corr_kin,
            "K_D": K_D,
        }

    @jax.jit
    def ff_kernel(post_state):
        """Target random-force power spectrum, shape (n_omega, nspecies)."""
        if classical_kernel:
            kernel = cutoff * (2 * gamma * kT / dt)
        else:
            kernel = theta * cutoff * OUcorr * (2 * gamma / dt)
        # gammar and corr_pot only act below the cutoff index nom
        gamma_ratio = jnp.concatenate(
            (
                post_state["gammar"].T * post_state["corr_pot"][:, None],
                jnp.ones(
                    (kernel.shape[0] - nom, nspecies), dtype=post_state["gammar"].dtype
                ),
            ),
            axis=0,
        )
        return kernel[:, None] * gamma_ratio * mass_idx[None, :]

    @jax.jit
    def refresh_force(post_state):
        """Shift the noise buffer by one segment and color it to get the next force."""
        rng_key, noise_key = jax.random.split(post_state["rng_key"])
        white_noise = jnp.concatenate(
            (
                post_state["white_noise"][nseg:],
                jax.random.normal(
                    noise_key, (nseg, nat, 3), dtype=post_state["white_noise"].dtype
                ),
            ),
            axis=0,
        )
        amplitude = ff_kernel(post_state) ** 0.5
        s = jnp.fft.rfft(white_noise, 3 * nseg, axis=0) * amplitude[:, type_idx, None]
        # keep the middle segment to limit circular-convolution edge effects
        force = jnp.fft.irfft(s, 3 * nseg, axis=0)[nseg : 2 * nseg]
        return force, {**post_state, "rng_key": rng_key, "white_noise": white_noise}

    @jax.jit
    def compute_spectra(force, vel, post_state):
        """Accumulate per-species velocity/force spectra of the last segment."""
        sf = jnp.fft.rfft(force / gamma, 3 * nseg, axis=0, norm="ortho")
        sv = jnp.fft.rfft(vel, 3 * nseg, axis=0, norm="ortho")
        Cvv = jnp.sum(jnp.abs(sv[:nom]) ** 2, axis=-1).T
        Cff = jnp.sum(jnp.abs(sf[:nom]) ** 2, axis=-1).T
        Cvf = jnp.sum(jnp.real(sv[:nom] * jnp.conj(sf[:nom])), axis=-1).T

        mCvv = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["mCvv"]).at[type_idx].add(Cvv)
            * mass_idx[:, None]
            / n_of_type[:, None]
        )
        Cvf = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["Cvf"]).at[type_idx].add(Cvf)
            / n_of_type[:, None]
        )
        Cff = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["Cff"]).at[type_idx].add(Cff)
            / n_of_type[:, None]
        )
        # deviation from the fluctuation-dissipation theorem
        dFDT = mCvv * post_state["gammar"] - Cvf

        # running averages over the nsample saved segments
        nsinv = 1.0 / post_state["nsample"]
        b1 = 1.0 - nsinv
        dFDT_avg = post_state["dFDT_avg"] * b1 + dFDT * nsinv
        mCvv_avg = post_state["mCvv_avg"] * b1 + mCvv * nsinv
        Cvfg_avg = post_state["Cvfg_avg"] * b1 + Cvf / post_state["gammar"] * nsinv
        Cff_avg = post_state["Cff_avg"] * b1 + Cff * nsinv

        return {
            **post_state,
            "mCvv": mCvv,
            "Cvf": Cvf,
            "Cff": Cff,
            "dFDT": dFDT,
            "dFDT_avg": dFDT_avg,
            "mCvv_avg": mCvv_avg,
            "Cvfg_avg": Cvfg_avg,
            "Cff_avg": Cff_avg,
        }

    def write_spectra_to_file(post_state):
        mCvv_avg = np.array(post_state["mCvv_avg"])
        Cvfg_avg = np.array(post_state["Cvfg_avg"])
        Cff_avg = np.array(post_state["Cff_avg"]) * 3.0 / dt / (gamma**2)
        dFDT_avg = np.array(post_state["dFDT_avg"])
        gammar = np.array(post_state["gammar"])
        Cff_theo = np.array(ff_kernel(post_state))[:nom].T
        for i, sp in enumerate(species_set):
            ff_scale = au.KELVIN / ((2 * gamma / dt) * mass_idx[i])
            columns = np.column_stack(
                (
                    omega[:nom] * (au.FS * au.CM1),
                    mCvv_avg[i],
                    Cvfg_avg[i],
                    dFDT_avg[i],
                    gammar[i] * gamma * (au.FS * au.THZ),
                    Cff_avg[i] * ff_scale,
                    Cff_theo[i] * ff_scale,
                )
            )
            np.savetxt(
                f"QTB_spectra_{sp}.out",
                columns,
                fmt="%12.6f",
                header="#omega mCvv Cvf dFDT gammar Cff Cff_theo",
            )
        if verbose:
            print("# QTB spectra written.")

    if compute_thermostat_energy:
        state["qtb_energy_flux"] = 0.0

    @jax.jit
    def thermostat(vel, state):
        istep = state["istep"]
        dvel = dt * state["force"][istep] / mass[:, None]
        new_vel = vel * a1 + dvel
        new_state = {**state, "istep": istep + 1}
        if do_compute_spectra:
            # store a mid-step velocity estimate for the spectra
            vel2 = state["vel"].at[istep].set(vel * a1**0.5 + 0.5 * dvel)
            new_state["vel"] = vel2
        if compute_thermostat_energy:
            dek = 0.5 * (mass[:, None] * (vel**2 - new_vel**2)).sum()
            ekcorr = (
                0.5
                * (mass[:, None] * new_vel**2).sum()
                * (1.0 - 1.0 / state.get("corr_kin", 1.0))
            )
            new_state["qtb_energy_flux"] = state["qtb_energy_flux"] + dek
            new_state["thermostat_energy"] = new_state["qtb_energy_flux"] + ekcorr
        return new_vel, new_state

    @jax.jit
    def postprocess_work(state, post_state):
        if do_compute_spectra:
            post_state = compute_spectra(state["force"], state["vel"], post_state)
        if adaptive:
            post_state = jax.lax.cond(
                post_state["nadapt"] > skipseg, update_gammar, lambda x: x, post_state
            )
        new_force, post_state = refresh_force(post_state)
        return {**state, "force": new_force}, post_state

    def postprocess(state, post_state):
        counter.increment()
        if not counter.is_reset_step:
            return state, post_state
        post_state["nadapt"] += 1
        post_state["nsample"] = max(post_state["nadapt"] - startsave + 1, 1)
        if verbose:
            print("# Refreshing QTB forces.")
        state, post_state = postprocess_work(state, post_state)
        state["corr_kin"], post_state = compute_corr_kin(
            post_state, niter=niter_deconv_kin
        )
        state["istep"] = 0
        if write_spectra:
            write_spectra_to_file(post_state)
        write_qtb_restart(state, post_state)
        return state, post_state

    post_state["corr_pot"] = jnp.asarray(
        compute_corr_pot(niter=niter_deconv_pot), dtype=fprec
    )

    state["force"], post_state = refresh_force(post_state)
    return thermostat, (postprocess, post_state), state