fennol.md.thermostats
import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
import math
import optax
import os
import pickle

from .utils import us
from ..utils import Counter, read_tinker_interval
from ..utils.deconvolution import (
    deconvolute_spectrum,
    kernel_lorentz_pot,
    kernel_lorentz,
)
from ..utils.periodic_table import PERIODIC_TABLE


def get_thermostat(
    simulation_parameters, dt, system_data, fprec, rng_key=None, restart_data=None
):
    """Build the thermostat selected by ``simulation_parameters["thermostat"]``.

    Parameters
    ----------
    simulation_parameters : dict
        User parameters. The keys read here are documented by the
        ``@keyword`` strings in the body.
    dt : float
        Integration time step (same unit system as ``gamma``).
    system_data : dict
        Must contain ``"mass"`` and ``"species"``. May contain ``"kT"`` and
        ``"nbeads"``. When ``"nbeads"`` is set, ``"omk"`` is read (TRPMD
        friction). The QTB branch also reads ``"name"`` for its restart file.
    fprec : dtype
        Floating point dtype of the returned velocities/coefficients.
    rng_key : jax PRNG key, optional
        Required by the stochastic thermostats and for velocity initialization.
    restart_data : dict, optional
        Unused. Kept for interface compatibility.

    Returns
    -------
    thermostat : callable
        ``thermostat(vel, state) -> (vel, state)``.
    postprocess : None or tuple
        ``None`` except for QTB/ADQTB, where it is
        ``(postprocess_fn, post_state)`` as returned by :func:`initialize_qtb`.
    state : dict
        Initial thermostat state.
    vel : array
        Initial velocities, shape ``(nat, 3)`` or ``(nbeads, nat, 3)``.
    thermostat_name : str
        Upper-cased thermostat name.
    """
    state = {}
    postprocess = None

    thermostat_name = str(simulation_parameters.get("thermostat", "LGV")).upper()
    """@keyword[fennol_md] thermostat
    Thermostat type. Options: 'LGV' (alias 'LANGEVIN'), 'FFLGV', 'BUSSI',
    'GD' (aliases 'DESCENT', 'GRADIENT_DESCENT', 'MIN', 'MINIMIZE'),
    'NVE' (alias 'NONE'), 'NOSE' (aliases 'NOSEHOOVER', 'NOSE_HOOVER'),
    'QTB', 'ADQTB', 'ANNEAL' (alias 'ANNEALING').
    Default: "LGV"
    """
    compute_thermostat_energy = simulation_parameters.get(
        "include_thermostat_energy", False
    )
    """@keyword[fennol_md] include_thermostat_energy
    Include thermostat energy in total energy calculations.
    Default: False
    """

    kT = system_data.get("kT", None)
    nbeads = system_data.get("nbeads", None)
    mass = system_data["mass"]
    gamma = simulation_parameters.get("gamma", 1.0 / us.THZ)
    """@keyword[fennol_md] gamma
    Friction coefficient for Langevin thermostat.
    Default: 1.0 ps^-1
    """
    species = system_data["species"]

    if nbeads is not None:
        trpmd_lambda = simulation_parameters.get("trpmd_lambda", 1.0)
        """@keyword[fennol_md] trpmd_lambda
        Lambda parameter for TRPMD (Thermostatted Ring Polymer MD).
        Default: 1.0
        """
        # per-bead friction: at least trpmd_lambda * omega_k for each normal mode
        gamma = np.maximum(trpmd_lambda * system_data["omk"], gamma)

    if thermostat_name in ["LGV", "LANGEVIN", "FFLGV"]:
        assert (
            rng_key is not None
        ), f"rng_key must be provided for {thermostat_name} thermostat"
        assert kT is not None, f"kT must be specified for {thermostat_name} thermostat"
        assert (
            gamma is not None
        ), f"gamma must be specified for {thermostat_name} thermostat"
        rng_key, v_key = jax.random.split(rng_key)
        if nbeads is None:
            a1 = math.exp(-gamma * dt)
            a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)
            vel = (
                jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
                * (kT / mass[:, None]) ** 0.5
            )
        else:
            if isinstance(gamma, float):
                gamma = np.array([gamma] * nbeads)
            assert isinstance(
                gamma, np.ndarray
            ), "gamma must be a float or a numpy array"
            assert gamma.shape[0] == nbeads, "gamma must have the same length as nbeads"
            a1 = np.exp(-gamma * dt)[:, None, None]
            a2 = jnp.asarray(
                ((1 - a1 * a1) * kT / mass[None, :, None]) ** 0.5, dtype=fprec
            )
            vel = (
                jax.random.normal(v_key, (nbeads, mass.shape[0], 3), dtype=fprec)
                * (kT / mass[:, None]) ** 0.5
            )

        state["rng_key"] = rng_key
        if compute_thermostat_energy:
            state["thermostat_energy"] = 0.0
        if thermostat_name == "FFLGV":

            def thermostat(vel, state):
                """Langevin step on the velocity norm only: each atom keeps
                its velocity direction, its magnitude follows the OU update."""
                rng_key, noise_key = jax.random.split(state["rng_key"])
                noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
                norm_vel = jnp.linalg.norm(vel, axis=-1, keepdims=True)
                dirvel = vel / norm_vel
                if compute_thermostat_energy:
                    v2 = (vel**2).sum(axis=-1)
                vel = a1 * vel + a2 * noise
                new_norm_vel = jnp.linalg.norm(vel, axis=-1, keepdims=True)
                vel = dirvel * new_norm_vel
                new_state = {**state, "rng_key": rng_key}
                if compute_thermostat_energy:
                    v2new = (vel**2).sum(axis=-1)
                    new_state["thermostat_energy"] = (
                        state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
                    )

                return vel, new_state

        else:

            def thermostat(vel, state):
                """Exact Ornstein-Uhlenbeck (Langevin) velocity update."""
                rng_key, noise_key = jax.random.split(state["rng_key"])
                noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
                if compute_thermostat_energy:
                    v2 = (vel**2).sum(axis=-1)
                vel = a1 * vel + a2 * noise
                new_state = {**state, "rng_key": rng_key}
                if compute_thermostat_energy:
                    v2new = (vel**2).sum(axis=-1)
                    new_state["thermostat_energy"] = (
                        state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
                    )
                return vel, new_state

    elif thermostat_name in ["BUSSI"]:
        assert rng_key is not None, "rng_key must be provided for BUSSI thermostat"
        assert kT is not None, "kT must be specified for BUSSI thermostat"
        assert gamma is not None, "gamma must be specified for BUSSI thermostat"
        assert nbeads is None, "Bussi thermostat is not compatible with PIMD"
        rng_key, v_key = jax.random.split(rng_key)

        a1 = math.exp(-gamma * dt)
        a2 = (1 - a1) * kT
        vel = (
            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
            * (kT / mass[:, None]) ** 0.5
        )

        state["rng_key"] = rng_key
        if compute_thermostat_energy:
            state["thermostat_energy"] = 0.0

        def thermostat(vel, state):
            """Stochastic velocity rescaling: a single global scale factor."""
            rng_key, noise_key = jax.random.split(state["rng_key"])
            new_state = {**state, "rng_key": rng_key}
            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
            R2 = jnp.sum(noise**2)
            R1 = noise[0, 0]
            # c = (1-a1) * kT / (2 * Ekin)
            c = a2 / (mass[:, None] * vel**2).sum()
            d = (a1 * c) ** 0.5
            scale = (a1 + c * R2 + 2 * d * R1) ** 0.5
            if compute_thermostat_energy:
                dek = 0.5 * (mass[:, None] * vel**2).sum() * (scale**2 - 1)
                new_state["thermostat_energy"] = state["thermostat_energy"] + dek
            return scale * vel, new_state

    elif thermostat_name in [
        "GD",
        "DESCENT",
        "GRADIENT_DESCENT",
        "MIN",
        "MINIMIZE",
    ]:
        assert nbeads is None, "Gradient descent is not compatible with PIMD"
        a1 = math.exp(-gamma * dt)

        # the else branch is only reachable when assertions are disabled (-O)
        if nbeads is None:
            vel = jnp.zeros((mass.shape[0], 3), dtype=fprec)
        else:
            vel = jnp.zeros((nbeads, mass.shape[0], 3), dtype=fprec)

        def thermostat(vel, state):
            """Damp velocities by exp(-gamma*dt) (no noise)."""
            return a1 * vel, state

    elif thermostat_name in ["NVE", "NONE"]:
        assert rng_key is not None, "rng_key must be provided for NVE initialization"
        assert kT is not None, "kT must be specified for NVE initialization"
        # initial velocities are rescaled so the instantaneous temperature is exactly kT
        if nbeads is None:
            vel = (
                jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
                * (kT / mass[:, None]) ** 0.5
            )
            kTsys = jnp.sum(mass[:, None] * vel**2) / (mass.shape[0] * 3)
            vel = vel * (kT / kTsys) ** 0.5
        else:
            vel = (
                jax.random.normal(rng_key, (nbeads, mass.shape[0], 3), dtype=fprec)
                * (kT / mass[None, :, None]) ** 0.5
            )
            kTsys = jnp.sum(mass[None, :, None] * vel**2, axis=(1, 2)) / (
                mass.shape[0] * 3
            )
            vel = vel * (kT / kTsys[:, None, None]) ** 0.5
        thermostat = lambda x, s: (x, s)

    elif thermostat_name in ["NOSE", "NOSEHOOVER", "NOSE_HOOVER"]:
        assert rng_key is not None, "rng_key must be provided for NOSE thermostat"
        assert kT is not None, "kT must be specified for NOSE thermostat"
        assert gamma is not None, "gamma must be specified for NOSE thermostat"
        ndof = mass.shape[0] * 3
        nkT = ndof * kT
        nose_mass = nkT / gamma**2
        assert nbeads is None, "Nose-Hoover is not compatible with PIMD"
        state["nose_s"] = 0.0
        state["nose_v"] = 0.0
        if compute_thermostat_energy:
            state["thermostat_energy"] = 0.0
        print(
            "# WARNING: Nose-Hoover thermostat is not well tested yet. Energy conservation is not guaranteed."
        )
        vel = (
            jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
            * (kT / mass[:, None]) ** 0.5
        )

        def thermostat(vel, state):
            """Half-kick / drift / half-kick update of the Nose-Hoover variable."""
            nose_s = state["nose_s"]
            nose_v = state["nose_v"]
            kTsys = jnp.sum(mass[:, None] * vel**2)
            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
            nose_s = nose_s + dt * nose_v
            vel = jnp.exp(-nose_v * dt) * vel
            kTsys = jnp.sum(mass[:, None] * vel**2)
            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
            new_state = {**state, "nose_s": nose_s, "nose_v": nose_v}

            if compute_thermostat_energy:
                new_state["thermostat_energy"] = (
                    nkT * nose_s + (0.5 * nose_mass) * nose_v**2
                )
            return vel, new_state

    elif thermostat_name in ["QTB", "ADQTB"]:
        assert nbeads is None, "QTB is not compatible with PIMD"
        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
        assert kT is not None, "kT must be specified for QTB thermostat"
        assert gamma is not None, "gamma must be specified for QTB thermostat"
        assert species is not None, "species must be provided for QTB thermostat"
        rng_key, v_key = jax.random.split(rng_key)
        vel = (
            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
            * (kT / mass[:, None]) ** 0.5
        )

        thermostat, postprocess, qtb_state = initialize_qtb(
            simulation_parameters,
            system_data,
            fprec=fprec,
            dt=dt,
            mass=mass,
            gamma=gamma,
            kT=kT,
            species=species,
            rng_key=rng_key,
            adaptive=thermostat_name.startswith("AD"),
            compute_thermostat_energy=compute_thermostat_energy,
        )
        state = {**state, **qtb_state}

    elif thermostat_name in ["ANNEAL", "ANNEALING"]:
        assert rng_key is not None, "rng_key must be provided for ANNEAL thermostat"
        assert kT is not None, "kT must be specified for ANNEAL thermostat"
        assert gamma is not None, "gamma must be specified for ANNEAL thermostat"
        assert nbeads is None, "ANNEAL is not compatible with PIMD"
        a1 = math.exp(-gamma * dt)
        a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)

        anneal_parameters = simulation_parameters.get("annealing", {})
        """@keyword[fennol_md] annealing
        Parameters for simulated annealing schedule configuration.
        Required for ANNEAL/ANNEALING thermostat
        """
        init_factor = anneal_parameters.get("init_factor", 1.0 / 25.0)
        """@keyword[fennol_md] annealing/init_factor
        Initial temperature factor for annealing schedule.
        Default: 0.04 (1/25)
        """
        assert init_factor > 0.0, "init_factor must be positive"
        final_factor = anneal_parameters.get("final_factor", 1.0 / 10000.0)
        """@keyword[fennol_md] annealing/final_factor
        Final temperature factor for annealing schedule.
        Default: 0.0001 (1/10000)
        """
        assert final_factor > 0.0, "final_factor must be positive"
        nsteps = simulation_parameters.get("nsteps")
        """@keyword[fennol_md] nsteps
        Total number of simulation steps for annealing schedule calculation.
        Required parameter
        """
        assert nsteps is not None, "nsteps must be specified for ANNEAL thermostat"
        anneal_steps = anneal_parameters.get("anneal_steps", 1.0)
        """@keyword[fennol_md] annealing/anneal_steps
        Fraction of total steps for annealing process.
        Default: 1.0
        """
        # 1.0 (anneal over the whole run) is the documented default and must be valid
        assert 0.0 < anneal_steps <= 1.0, "anneal_steps must be in (0, 1]"
        pct_start = anneal_parameters.get("warmup_steps", 0.3)
        """@keyword[fennol_md] annealing/warmup_steps
        Fraction of annealing steps for warmup phase.
        Default: 0.3
        """
        assert 0.0 < pct_start < 1.0, "warmup_steps must be between 0 and 1"

        anneal_type = anneal_parameters.get("type", "cosine").lower()
        """@keyword[fennol_md] annealing/type
        Type of annealing schedule (linear, cosine or cosine_onecycle).
        Default: cosine
        """
        if anneal_type == "linear":
            schedule = optax.linear_onecycle_schedule(
                peak_value=1.0,
                div_factor=1.0 / init_factor,
                final_div_factor=1.0 / final_factor,
                transition_steps=int(anneal_steps * nsteps),
                pct_start=pct_start,
                pct_final=1.0,
            )
        elif anneal_type in ["cosine", "cosine_onecycle"]:
            # "cosine" is the documented default and an alias for cosine_onecycle
            schedule = optax.cosine_onecycle_schedule(
                peak_value=1.0,
                div_factor=1.0 / init_factor,
                final_div_factor=1.0 / final_factor,
                transition_steps=int(anneal_steps * nsteps),
                pct_start=pct_start,
            )
        else:
            raise ValueError(f"Unknown anneal_type {anneal_type}")

        # Split BEFORE storing the key: otherwise the thermostat's first split of
        # state["rng_key"] reproduces v_key and the first noise equals the
        # initial-velocity draw.
        rng_key, v_key = jax.random.split(rng_key)
        state["rng_key"] = rng_key
        state["istep_anneal"] = 0

        Tscale = schedule(0)
        print(f"# ANNEAL: initial temperature = {Tscale*kT*us.KELVIN:.3e} K")
        vel = (
            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
            * (kT * Tscale / mass[:, None]) ** 0.5
        )

        def thermostat(vel, state):
            """Langevin step whose noise amplitude follows sqrt(schedule(step))."""
            rng_key, noise_key = jax.random.split(state["rng_key"])
            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)

            Tscale = schedule(state["istep_anneal"]) ** 0.5
            vel = a1 * vel + a2 * Tscale * noise
            return vel, {
                **state,
                "rng_key": rng_key,
                "istep_anneal": state["istep_anneal"] + 1,
            }

    else:
        raise ValueError(f"Unknown thermostat {thermostat_name}")

    return thermostat, postprocess, state, vel, thermostat_name


def initialize_qtb(
    simulation_parameters,
    system_data,
    fprec,
    dt,
    mass,
    gamma,
    kT,
    species,
    rng_key,
    adaptive,
    compute_thermostat_energy=False,
):
    """Set up the (adaptive) Quantum Thermal Bath thermostat.

    Random forces are generated segment by segment (``nseg`` steps). Each
    segment filters white noise in Fourier space with ``ff_kernel``. In
    adaptive mode, the per-type friction ratios ``gammar`` are updated from
    the measured fluctuation-dissipation violation ``dFDT``.

    Returns
    -------
    thermostat : callable
        Jitted ``thermostat(vel, state) -> (vel, state)``.
    (postprocess, post_state) : tuple
        ``postprocess(state, post_state) -> (state, post_state)`` is expected
        after each step. It increments a ``Counter`` and only does work
        (spectra, gammar update, new forces, restart/spectra files) when the
        counter signals a reset step. ``post_state`` is its initial state.
    state : dict
        Initial per-step state, including the first force segment.
    """
    state = {}
    post_state = {}
    qtb_parameters = simulation_parameters.get("qtb", {})
    verbose = qtb_parameters.get("verbose", False)
    """@keyword[fennol_md] qtb/verbose
    Print verbose QTB thermostat information.
    Default: False
    """
    if compute_thermostat_energy:
        state["thermostat_energy"] = 0.0

    mass = jnp.asarray(mass, dtype=fprec)

    nat = species.shape[0]
    # define type indices
    species_set = set(species)
    nspecies = len(species_set)
    idx = {sp: i for i, sp in enumerate(species_set)}
    type_idx = np.array([idx[sp] for sp in species], dtype=np.int32)
    type_labels = [PERIODIC_TABLE[sp] for sp in species_set]

    adapt_groups = qtb_parameters.get("adapt_groups", {})
    """@keyword[fennol_md] qtb/adapt_groups
    Groups of atoms with separate adQTB parameters (atoms of differents species within groups will be separated in subtypes).
    Default: {}
    """
    ntypes = nspecies
    allgroups = set()
    for groupname, interval in adapt_groups.items():
        indices = read_tinker_interval(interval)
        assert allgroups.isdisjoint(
            indices
        ), f"Indices in group {groupname} overlap with other groups"
        allgroups.update(indices)
        species_in_group = set(species[indices])
        idx = {sp: i for i, sp in enumerate(species_in_group)}
        for i in indices:
            type_idx[i] = ntypes + idx[species[i]]
        type_labels += [f"{PERIODIC_TABLE[sp]}_{groupname}" for sp in species_in_group]
        # next group's subtypes start after all labels defined so far
        ntypes = len(type_labels)

    counts = np.bincount(type_idx, minlength=ntypes)
    for label, count in zip(type_labels, counts):
        print(f"# QTB: n({label}) =", count)
    n_of_type = jnp.asarray(counts, dtype=fprec)
    # A type can be empty when an adapt group takes all atoms of a species.
    # Divide by max(n, 1) so that such types give zeros instead of NaNs, which
    # would otherwise leak into corr_kin through mCvv_avg * n_of_type.
    n_of_type_safe = jnp.maximum(n_of_type, 1.0)
    mass_idx = jax.ops.segment_sum(mass, type_idx, ntypes) / n_of_type_safe

    niter_deconv_kin = qtb_parameters.get("niter_deconv_kin", 7)
    """@keyword[fennol_md] qtb/niter_deconv_kin
    Number of iterations for kinetic energy deconvolution.
    Default: 7
    """
    niter_deconv_pot = qtb_parameters.get("niter_deconv_pot", 20)
    """@keyword[fennol_md] qtb/niter_deconv_pot
    Number of iterations for potential energy deconvolution.
    Default: 20
    """
    corr_kin = qtb_parameters.get("corr_kin", -1)
    """@keyword[fennol_md] qtb/corr_kin
    Kinetic energy correction factor for QTB (-1 for automatic).
    Default: -1
    """
    do_corr_kin = corr_kin <= 0
    if do_corr_kin:
        corr_kin = 1.0
    state["corr_kin"] = corr_kin
    post_state["corr_kin_prev"] = corr_kin
    post_state["do_corr_kin"] = do_corr_kin
    post_state["isame_kin"] = 0

    #####################
    # RESTART
    # Loaded here (before deciding whether spectra are needed) because it can
    # change do_corr_kin.
    restart_file = system_data["name"] + ".qtb.restart"
    if os.path.exists(restart_file):
        # NOTE: pickle executes arbitrary code on load; only use trusted restart files.
        with open(restart_file, "rb") as f:
            data = pickle.load(f)
        state["corr_kin"] = data["corr_kin"]
        post_state["corr_kin_prev"] = data["corr_kin"]
        post_state["isame_kin"] = data["isame_kin"]
        post_state["do_corr_kin"] = data["do_corr_kin"]
        print(f"# Restored QTB state from {restart_file}")

    def write_qtb_restart(state, post_state):
        """Pickle the corr_kin convergence state to ``restart_file``."""
        with open(restart_file, "wb") as f:
            pickle.dump(
                {
                    "corr_kin": state["corr_kin"],
                    "corr_kin_prev": post_state["corr_kin_prev"],
                    "isame_kin": post_state["isame_kin"],
                    "do_corr_kin": post_state["do_corr_kin"],
                },
                f,
            )

    ######################

    # spectra parameters
    omegasmear = np.pi / dt / 100.0
    Tseg = qtb_parameters.get("tseg", 1.0 / us.PS)
    """@keyword[fennol_md] qtb/tseg
    Time segment length for QTB spectrum calculation.
    Default: 1.0 ps
    """
    nseg = int(Tseg / dt)
    Tseg = nseg * dt
    # noise is generated on windows of 3 segments, hence the 3*Tseg resolution
    dom = 2 * np.pi / (3 * Tseg)
    omegacut = qtb_parameters.get("omegacut", 15000.0 / us.CM1)
    """@keyword[fennol_md] qtb/omegacut
    Cutoff frequency for QTB spectrum.
    Default: 15000.0 cm⁻¹
    """
    nom = int(omegacut / dom)
    omega = dom * np.arange((3 * nseg) // 2 + 1)
    cutoff = jnp.asarray(
        1.0 / (1.0 + np.exp((omega - omegacut) / omegasmear)), dtype=fprec
    )
    assert (
        omegacut < omega[-1]
    ), f"omegacut must be smaller than {omega[-1]*us.CM1} CM-1"

    # initialize gammar
    assert (
        gamma < 0.5 * omegacut
    ), "gamma must be much smaller than omegacut (at most 0.5*omegacut)"
    gammar_min = qtb_parameters.get("gammar_min", 0.1)
    """@keyword[fennol_md] qtb/gammar_min
    Minimum value for QTB gamma ratio coefficients.
    Default: 0.1
    """
    gammar = np.ones((ntypes, nom), dtype=float)
    try:
        # best effort: restore gammar from previously written spectra files
        for i, sp in enumerate(type_labels):
            if not os.path.exists(f"QTB_spectra_{sp}.out"):
                continue
            data = np.loadtxt(f"QTB_spectra_{sp}.out")
            gammar[i] = data[:, 4] / (gamma * us.THZ)
            print(f"# Restored gammar for species {sp} from QTB_spectra_{sp}.out")
    except Exception as e:
        print(
            f"# Could not restore gammar for all species with exception {e}. Starting from scratch."
        )
        gammar[:, :] = 1.0
    post_state["gammar"] = jnp.asarray(gammar, dtype=fprec)

    # Ornstein-Uhlenbeck correction for colored noise
    a1 = np.exp(-gamma * dt)
    OUcorr = jnp.asarray(
        (1.0 - 2.0 * a1 * np.cos(omega * dt) + a1**2) / (dt**2 * (gamma**2 + omega**2)),
        dtype=fprec,
    )

    # hbar schedule
    classical_kernel = qtb_parameters.get("classical_kernel", False)
    """@keyword[fennol_md] qtb/classical_kernel
    Use classical instead of quantum kernel for QTB.
    Default: False
    """
    hbar = qtb_parameters.get("hbar", 1.0) * us.HBAR
    """@keyword[fennol_md] qtb/hbar
    Reduced Planck constant scaling factor for quantum effects.
    Default: 1.0 a.u.
    """
    # theta(omega) = (hbar*omega/2) * coth(hbar*omega/(2kT)); theta(0) = kT
    u = 0.5 * hbar * np.abs(omega) / kT
    theta = kT * np.ones_like(omega)
    if hbar > 0:
        theta[1:] *= u[1:] / np.tanh(u[1:])
    theta = jnp.asarray(theta, dtype=fprec)

    noise_key, post_state["rng_key"] = jax.random.split(rng_key)
    del rng_key
    post_state["white_noise"] = jax.random.normal(
        noise_key, (3 * nseg, nat, 3), dtype=jnp.float32
    )

    startsave = qtb_parameters.get("startsave", 1)
    """@keyword[fennol_md] qtb/startsave
    Start saving QTB statistics after this many segments.
    Default: 1
    """
    counter = Counter(nseg, startsave=startsave)
    state["istep"] = 0
    post_state["nadapt"] = 0
    post_state["nsample"] = 0

    write_spectra = qtb_parameters.get("write_spectra", True)
    """@keyword[fennol_md] qtb/write_spectra
    Write QTB spectral analysis output files.
    Default: True
    """
    # Automatic corr_kin deconvolutes mCvv_avg, which only compute_spectra fills,
    # so spectra are also needed while corr_kin is still being fitted.
    do_compute_spectra = write_spectra or adaptive or bool(post_state["do_corr_kin"])

    if do_compute_spectra:
        state["vel"] = jnp.zeros((nseg, nat, 3), dtype=fprec)

    post_state["dFDT"] = jnp.zeros((ntypes, nom), dtype=fprec)
    post_state["mCvv"] = jnp.zeros((ntypes, nom), dtype=fprec)
    post_state["Cvf"] = jnp.zeros((ntypes, nom), dtype=fprec)
    post_state["Cff"] = jnp.zeros((ntypes, nom), dtype=fprec)
    post_state["dFDT_avg"] = jnp.zeros((ntypes, nom), dtype=fprec)
    post_state["mCvv_avg"] = jnp.zeros((ntypes, nom), dtype=fprec)
    post_state["Cvfg_avg"] = jnp.zeros((ntypes, nom), dtype=fprec)
    post_state["Cff_avg"] = jnp.zeros((ntypes, nom), dtype=fprec)

    if not adaptive:
        update_gammar = lambda x: x
    else:
        # adaptation parameters
        skipseg = qtb_parameters.get("skipseg", 1)
        """@keyword[fennol_md] qtb/skipseg
        Number of segments to skip before starting adaptive QTB.
        Default: 1
        """

        adaptation_method = (
            str(qtb_parameters.get("adaptation_method", "SIMPLE")).upper().strip()
        )
        """@keyword[fennol_md] qtb/adaptation_method
        Method for adaptive QTB (SIMPLE, RATIO, ADABELIEF).
        Default: SIMPLE
        """
        if adaptation_method == "SIMPLE":
            agamma = qtb_parameters.get("agamma", 0.1)
            """@keyword[fennol_md] qtb/agamma
            Learning rate for adaptive QTB gamma update.
            Default: 0.1
            """
            assert agamma > 0, "agamma must be positive"
            a1_ad = agamma * Tseg  # * gamma
            print(f"# ADQTB SIMPLE: agamma = {agamma:.3f}")

            def update_gammar(post_state):
                """Gradient step on gammar against the FDT violation."""
                g = post_state["dFDT"]
                gammar = post_state["gammar"] - a1_ad * g
                gammar = jnp.maximum(gammar_min, gammar)
                return {**post_state, "gammar": gammar}

        elif adaptation_method == "RATIO":
            tau_ad = qtb_parameters.get("tau_ad", 5.0 / us.PS)
            """@keyword[fennol_md] qtb/tau_ad
            Adaptation time constant for momentum averaging.
            Default: 5.0 ps (RATIO), 1.0 ps (ADABELIEF)
            """
            tau_s = qtb_parameters.get("tau_s", 10 * tau_ad)
            """@keyword[fennol_md] qtb/tau_s
            Second moment time constant for variance averaging.
            Default: 10*tau_ad (RATIO), 100*tau_ad (ADABELIEF)
            """
            assert tau_ad > 0, "tau_ad must be positive"
            print(
                f"# ADQTB RATIO: tau_ad = {tau_ad*us.PS:.2f} ps, tau_s = {tau_s*us.PS:.2f} ps"
            )
            b1 = np.exp(-Tseg / tau_ad)
            b2 = np.exp(-Tseg / tau_s)
            # one row per type (species + adapt_group subtypes), like mCvv/Cvf
            post_state["mCvv_m"] = jnp.zeros((ntypes, nom), dtype=fprec)
            post_state["Cvf_m"] = jnp.zeros((ntypes, nom), dtype=fprec)
            post_state["n_adabelief"] = 0

            def update_gammar(post_state):
                """Set gammar to the ratio of bias-corrected running averages Cvf/mCvv."""
                n_adabelief = post_state["n_adabelief"] + 1
                mCvv_m = post_state["mCvv_m"] * b1 + post_state["mCvv"] * (1.0 - b1)
                Cvf_m = post_state["Cvf_m"] * b2 + post_state["Cvf"] * (1.0 - b2)
                mCvv = mCvv_m / (1.0 - b1**n_adabelief)
                Cvf = Cvf_m / (1.0 - b2**n_adabelief)
                gammar = Cvf / (mCvv + 1.0e-8)
                gammar = jnp.maximum(gammar_min, gammar)
                return {
                    **post_state,
                    "gammar": gammar,
                    "mCvv_m": mCvv_m,
                    "Cvf_m": Cvf_m,
                    "n_adabelief": n_adabelief,
                }

        elif adaptation_method == "ADABELIEF":
            agamma = qtb_parameters.get("agamma", 0.1)
            tau_ad = qtb_parameters.get("tau_ad", 1.0 / us.PS)
            tau_s = qtb_parameters.get("tau_s", 100 * tau_ad)
            assert tau_ad > 0, "tau_ad must be positive"
            assert tau_s > 0, "tau_s must be positive"
            assert agamma > 0, "agamma must be positive"
            print(
                f"# ADQTB ADABELIEF: agamma = {agamma:.3f}, tau_ad = {tau_ad*us.PS:.2f} ps, tau_s = {tau_s*us.PS:.2f} ps"
            )

            a1_ad = agamma * gamma  # * Tseg #* gamma
            b1 = np.exp(-Tseg / tau_ad)
            b2 = np.exp(-Tseg / tau_s)
            post_state["dFDT_m"] = jnp.zeros((ntypes, nom), dtype=fprec)
            post_state["dFDT_s"] = jnp.zeros((ntypes, nom), dtype=fprec)
            post_state["n_adabelief"] = 0

            def update_gammar(post_state):
                """AdaBelief-style step on gammar using first/second moments of dFDT."""
                n_adabelief = post_state["n_adabelief"] + 1
                dFDT = post_state["dFDT"]
                dFDT_m = post_state["dFDT_m"] * b1 + dFDT * (1.0 - b1)
                dFDT_s = (
                    post_state["dFDT_s"] * b2
                    + (dFDT - dFDT_m) ** 2 * (1.0 - b2)
                    + 1.0e-8
                )
                # bias correction
                mt = dFDT_m / (1.0 - b1**n_adabelief)
                st = dFDT_s / (1.0 - b2**n_adabelief)
                gammar = post_state["gammar"] - a1_ad * mt / (st**0.5 + 1.0e-8)
                gammar = jnp.maximum(gammar_min, gammar)
                return {
                    **post_state,
                    "gammar": gammar,
                    "dFDT_m": dFDT_m,
                    "n_adabelief": n_adabelief,
                    "dFDT_s": dFDT_s,
                }

        else:
            raise ValueError(f"Unknown adaptation method {adaptation_method}.")

    def compute_corr_pot(niter=20, verbose=False):
        """Deconvolution correction of the potential-energy spectrum.

        Also writes ``corr_pot.dat``. Returns ones for the classical kernel
        or when hbar == 0.
        """
        if classical_kernel or hbar == 0:
            return np.ones(nom)

        s_0 = np.array((theta / kT * cutoff)[:nom])
        s_out, s_rec, _ = deconvolute_spectrum(
            s_0,
            omega[:nom],
            gamma,
            niter,
            kernel=kernel_lorentz_pot,
            trans=True,
            symmetrize=True,
            verbose=verbose,
        )
        corr_pot = 1.0 + (s_out - s_0) / s_0
        columns = np.column_stack(
            (omega[:nom] * us.CM1, corr_pot - 1.0, s_0, s_out, s_rec)
        )
        np.savetxt(
            "corr_pot.dat", columns, header="omega(cm-1) corr_pot s_0 s_out s_rec"
        )
        return corr_pot

    def compute_corr_kin(post_state, niter=7, verbose=False):
        """Update the kinetic energy correction factor from ``mCvv_avg``.

        Returns ``(corr_kin, post_state)``. The previous value is kept when
        the correction is disabled/converged or the reconvolution error
        exceeds 5 %.
        """
        if not post_state["do_corr_kin"]:
            return post_state["corr_kin_prev"], post_state
        if classical_kernel or hbar == 0:
            return 1.0, post_state

        K_D = post_state.get("K_D", None)
        mCvv = (post_state["mCvv_avg"][:, :nom] * n_of_type[:, None]).sum(axis=0) / nat
        s_0 = np.array(mCvv * kT / theta[:nom] / post_state["corr_pot"])
        s_out, s_rec, K_D = deconvolute_spectrum(
            s_0,
            omega[:nom],
            gamma,
            niter,
            kernel=kernel_lorentz,
            trans=False,
            symmetrize=True,
            verbose=verbose,
            K_D=K_D,
        )
        s_out = s_out * theta[:nom] / kT
        s_rec = s_rec * theta[:nom] / kT * post_state["corr_pot"]
        mCvvsum = mCvv.sum()
        rec_ratio = mCvvsum / s_rec.sum()
        if rec_ratio < 0.95 or rec_ratio > 1.05:
            print(
                f"# WARNING: reconvolution error {rec_ratio} is too high, corr_kin was not updated"
            )
            return post_state["corr_kin_prev"], post_state

        corr_kin = mCvvsum / s_out.sum()
        if np.abs(corr_kin - post_state["corr_kin_prev"]) < 1.0e-4:
            isame_kin = post_state["isame_kin"] + 1
        else:
            isame_kin = 0

        do_corr_kin = post_state["do_corr_kin"]
        if isame_kin > 10:
            print(
                "# INFO: corr_kin is converged (it did not change for 10 consecutive segments)"
            )
            do_corr_kin = False

        return corr_kin, {
            **post_state,
            "corr_kin_prev": corr_kin,
            "isame_kin": isame_kin,
            "do_corr_kin": do_corr_kin,
            "K_D": K_D,
        }

    @jax.jit
    def ff_kernel(post_state):
        """Random-force power spectrum per frequency and type, shape (nfreq, ntypes)."""
        if classical_kernel:
            kernel = cutoff * (2 * gamma * kT / dt)
        else:
            kernel = theta * cutoff * OUcorr * (2 * gamma / dt)
        # gammar only covers the first nom frequencies; ratio is 1 above
        gamma_ratio = jnp.concatenate(
            (
                post_state["gammar"].T * post_state["corr_pot"][:, None],
                jnp.ones(
                    (kernel.shape[0] - nom, ntypes), dtype=post_state["gammar"].dtype
                ),
            ),
            axis=0,
        )
        return kernel[:, None] * gamma_ratio * mass_idx[None, :]

    @jax.jit
    def refresh_force(post_state):
        """Generate the next nseg steps of colored random forces.

        The white-noise window (3 segments) is shifted by one segment. The
        window is filtered in Fourier space and the middle segment is kept.
        """
        rng_key, noise_key = jax.random.split(post_state["rng_key"])
        white_noise = jnp.concatenate(
            (
                post_state["white_noise"][nseg:],
                jax.random.normal(
                    noise_key, (nseg, nat, 3), dtype=post_state["white_noise"].dtype
                ),
            ),
            axis=0,
        )
        amplitude = ff_kernel(post_state) ** 0.5
        s = jnp.fft.rfft(white_noise, 3 * nseg, axis=0) * amplitude[:, type_idx, None]
        force = jnp.fft.irfft(s, 3 * nseg, axis=0)[nseg : 2 * nseg]
        return force, {**post_state, "rng_key": rng_key, "white_noise": white_noise}

    @jax.jit
    def compute_spectra(force, vel, post_state):
        """Per-type velocity/force spectra of the last segment and running averages."""
        sf = jnp.fft.rfft(force / gamma, 3 * nseg, axis=0, norm="ortho")
        sv = jnp.fft.rfft(vel, 3 * nseg, axis=0, norm="ortho")
        Cvv = jnp.sum(jnp.abs(sv[:nom]) ** 2, axis=-1).T
        Cff = jnp.sum(jnp.abs(sf[:nom]) ** 2, axis=-1).T
        Cvf = jnp.sum(jnp.real(sv[:nom] * jnp.conj(sf[:nom])), axis=-1).T

        mCvv = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["mCvv"]).at[type_idx].add(Cvv)
            * mass_idx[:, None]
            / n_of_type_safe[:, None]
        )
        Cvf = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["Cvf"]).at[type_idx].add(Cvf)
            / n_of_type_safe[:, None]
        )
        Cff = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["Cff"]).at[type_idx].add(Cff)
            / n_of_type_safe[:, None]
        )
        # fluctuation-dissipation violation: zero when the bath is balanced
        dFDT = mCvv * post_state["gammar"] - Cvf

        nsinv = 1.0 / post_state["nsample"]
        b1 = 1.0 - nsinv
        dFDT_avg = post_state["dFDT_avg"] * b1 + dFDT * nsinv
        mCvv_avg = post_state["mCvv_avg"] * b1 + mCvv * nsinv
        Cvfg_avg = post_state["Cvfg_avg"] * b1 + Cvf / post_state["gammar"] * nsinv
        Cff_avg = post_state["Cff_avg"] * b1 + Cff * nsinv

        return {
            **post_state,
            "mCvv": mCvv,
            "Cvf": Cvf,
            "Cff": Cff,
            "dFDT": dFDT,
            "dFDT_avg": dFDT_avg,
            "mCvv_avg": mCvv_avg,
            "Cvfg_avg": Cvfg_avg,
            "Cff_avg": Cff_avg,
        }

    def write_spectra_to_file(post_state):
        """Write ``QTB_spectra_<type>.out`` for every non-empty atom type."""
        mCvv_avg = np.array(post_state["mCvv_avg"])
        Cvfg_avg = np.array(post_state["Cvfg_avg"])
        Cff_avg = np.array(post_state["Cff_avg"]) * 3.0 / dt / (gamma**2)
        dFDT_avg = np.array(post_state["dFDT_avg"])
        gammar = np.array(post_state["gammar"])
        Cff_theo = np.array(ff_kernel(post_state))[:nom].T
        for i, sp in enumerate(type_labels):
            if counts[i] == 0:
                # no atom of this type: nothing to report and mass_idx[i] is 0
                continue
            ff_scale = us.KELVIN / ((2 * gamma / dt) * mass_idx[i])
            columns = np.column_stack(
                (
                    omega[:nom] * us.CM1,
                    mCvv_avg[i],
                    Cvfg_avg[i],
                    dFDT_avg[i],
                    gammar[i] * gamma * us.THZ,
                    Cff_avg[i] * ff_scale,
                    Cff_theo[i] * ff_scale,
                )
            )
            np.savetxt(
                f"QTB_spectra_{sp}.out",
                columns,
                fmt="%12.6f",
                header="#omega mCvv Cvf dFDT gammar Cff Cff_theo",
            )
        if verbose:
            print("# QTB spectra written.")

    if compute_thermostat_energy:
        state["qtb_energy_flux"] = 0.0

    @jax.jit
    def thermostat(vel, state):
        """Apply friction and the precomputed random force for the current step."""
        istep = state["istep"]
        dvel = dt * state["force"][istep] / mass[:, None]
        new_vel = vel * a1 + dvel
        new_state = {**state, "istep": istep + 1}
        if do_compute_spectra:
            # mid-step velocity estimate stored for the spectra
            vel2 = state["vel"].at[istep].set(vel * a1**0.5 + 0.5 * dvel)
            new_state["vel"] = vel2
        if compute_thermostat_energy:
            dek = 0.5 * (mass[:, None] * (vel**2 - new_vel**2)).sum()
            ekcorr = (
                0.5
                * (mass[:, None] * new_vel**2).sum()
                * (1.0 - 1.0 / state.get("corr_kin", 1.0))
            )
            new_state["qtb_energy_flux"] = state["qtb_energy_flux"] + dek
            new_state["thermostat_energy"] = new_state["qtb_energy_flux"] + ekcorr
        return new_vel, new_state

    @jax.jit
    def postprocess_work(state, post_state):
        """End-of-segment work: spectra, optional gammar update, new forces."""
        if do_compute_spectra:
            post_state = compute_spectra(state["force"], state["vel"], post_state)
        if adaptive:
            post_state = jax.lax.cond(
                post_state["nadapt"] > skipseg, update_gammar, lambda x: x, post_state
            )
        new_force, post_state = refresh_force(post_state)
        return {**state, "force": new_force}, post_state

    def postprocess(state, post_state):
        """Per-step hook. Only acts when the segment counter resets."""
        counter.increment()
        if not counter.is_reset_step:
            return state, post_state
        post_state["nadapt"] += 1
        post_state["nsample"] = max(post_state["nadapt"] - startsave + 1, 1)
        if verbose:
            print("# Refreshing QTB forces.")
        state, post_state = postprocess_work(state, post_state)
        state["corr_kin"], post_state = compute_corr_kin(
            post_state, niter=niter_deconv_kin
        )
        state["istep"] = 0
        if write_spectra:
            write_spectra_to_file(post_state)
        write_qtb_restart(state, post_state)
        return state, post_state

    post_state["corr_pot"] = jnp.asarray(
        compute_corr_pot(niter=niter_deconv_pot), dtype=fprec
    )

    state["force"], post_state = refresh_force(post_state)
    return thermostat, (postprocess, post_state), state
def get_thermostat(
    simulation_parameters,
    dt,
    system_data,
    fprec,
    rng_key=None,
    restart_data=None,
):
    """Build the thermostat selected by the ``thermostat`` keyword.

    Parameters
    ----------
    simulation_parameters : dict
        Simulation options. This function reads ``thermostat``,
        ``include_thermostat_energy``, ``gamma``, ``trpmd_lambda`` (only when
        ``system_data`` has ``nbeads``), and ``annealing`` / ``nsteps``
        (ANNEAL only). QTB/ADQTB options are read by ``initialize_qtb``.
    dt : float
        Integration time step.
    system_data : dict
        Must provide ``mass`` and ``species``. ``kT`` and ``nbeads`` are
        optional; ``omk`` is read when ``nbeads`` is set.
    fprec : dtype
        Floating-point dtype of generated velocities and noise amplitudes.
    rng_key : jax PRNG key, optional
        Needed by the stochastic thermostats and by the random initial
        velocities of NVE / NOSE.
    restart_data : dict, optional
        Not used by this function; kept for interface compatibility.

    Returns
    -------
    tuple
        ``(thermostat, postprocess, state, vel, thermostat_name)``:
        ``thermostat(vel, state)`` returns the updated ``(vel, state)``;
        ``postprocess`` is ``None`` except for QTB/ADQTB, where it is the
        ``(postprocess_fn, post_state)`` pair returned by ``initialize_qtb``;
        ``state`` is the initial thermostat state dict; ``vel`` holds the
        initial velocities; ``thermostat_name`` is the upper-cased keyword.

    Raises
    ------
    ValueError
        Unknown thermostat name or annealing schedule type.
    """
    state = {}
    postprocess = None

    thermostat_name = str(simulation_parameters.get("thermostat", "LGV")).upper()
    """@keyword[fennol_md] thermostat
    Thermostat type. Options: 'NVE', 'LGV', 'FFLGV', 'BUSSI', 'GD', 'NOSE', 'QTB', 'ADQTB', 'ANNEAL'.
    Default: "LGV"
    """
    compute_thermostat_energy = simulation_parameters.get(
        "include_thermostat_energy", False
    )
    """@keyword[fennol_md] include_thermostat_energy
    Include thermostat energy in total energy calculations.
    Default: False
    """

    kT = system_data.get("kT", None)
    nbeads = system_data.get("nbeads", None)
    mass = system_data["mass"]
    gamma = simulation_parameters.get("gamma", 1.0 / us.THZ)
    """@keyword[fennol_md] gamma
    Friction coefficient for Langevin thermostat.
    Default: 1.0 ps^-1
    """
    species = system_data["species"]

    if nbeads is not None:
        trpmd_lambda = simulation_parameters.get("trpmd_lambda", 1.0)
        """@keyword[fennol_md] trpmd_lambda
        Lambda parameter for TRPMD (Thermostatted Ring Polymer MD).
        Default: 1.0
        """
        # Per-bead friction: trpmd_lambda * omk, but never below gamma.
        gamma = np.maximum(trpmd_lambda * system_data["omk"], gamma)

    if thermostat_name in ["LGV", "LANGEVIN", "FFLGV"]:
        assert rng_key is not None, "rng_key must be provided for Langevin thermostat"
        assert kT is not None, "kT must be specified for Langevin thermostat"
        assert gamma is not None, "gamma must be specified for Langevin thermostat"
        rng_key, v_key = jax.random.split(rng_key)
        if nbeads is None:
            # Ornstein-Uhlenbeck step: v <- a1 * v + a2 * N(0, 1)
            a1 = math.exp(-gamma * dt)
            a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)
            vel = (
                jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
                * (kT / mass[:, None]) ** 0.5
            )
        else:
            if isinstance(gamma, float):
                gamma = np.array([gamma] * nbeads)
            assert isinstance(
                gamma, np.ndarray
            ), "gamma must be a float or a numpy array"
            assert gamma.shape[0] == nbeads, "gamma must have the same length as nbeads"
            # One friction per bead, broadcast over atoms and Cartesian components.
            a1 = np.exp(-gamma * dt)[:, None, None]
            a2 = jnp.asarray(
                ((1 - a1 * a1) * kT / mass[None, :, None]) ** 0.5, dtype=fprec
            )
            vel = (
                jax.random.normal(v_key, (nbeads, mass.shape[0], 3), dtype=fprec)
                * (kT / mass[:, None]) ** 0.5
            )

        state["rng_key"] = rng_key
        if compute_thermostat_energy:
            state["thermostat_energy"] = 0.0
        if thermostat_name == "FFLGV":
            # Only each atom's speed is thermalized; its velocity direction is kept.
            def thermostat(vel, state):
                rng_key, noise_key = jax.random.split(state["rng_key"])
                noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
                norm_vel = jnp.linalg.norm(vel, axis=-1, keepdims=True)
                dirvel = vel / norm_vel
                if compute_thermostat_energy:
                    v2 = (vel**2).sum(axis=-1)
                vel = a1 * vel + a2 * noise
                new_norm_vel = jnp.linalg.norm(vel, axis=-1, keepdims=True)
                vel = dirvel * new_norm_vel
                new_state = {**state, "rng_key": rng_key}
                if compute_thermostat_energy:
                    v2new = (vel**2).sum(axis=-1)
                    new_state["thermostat_energy"] = (
                        state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
                    )

                return vel, new_state

        else:

            def thermostat(vel, state):
                rng_key, noise_key = jax.random.split(state["rng_key"])
                noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
                if compute_thermostat_energy:
                    v2 = (vel**2).sum(axis=-1)
                vel = a1 * vel + a2 * noise
                new_state = {**state, "rng_key": rng_key}
                if compute_thermostat_energy:
                    v2new = (vel**2).sum(axis=-1)
                    # Energy given to the bath = kinetic energy removed by this step.
                    new_state["thermostat_energy"] = (
                        state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
                    )
                return vel, new_state

    elif thermostat_name in ["BUSSI"]:
        assert rng_key is not None, "rng_key must be provided for Bussi thermostat"
        assert kT is not None, "kT must be specified for Bussi thermostat"
        assert gamma is not None, "gamma must be specified for Bussi thermostat"
        assert nbeads is None, "Bussi thermostat is not compatible with PIMD"
        rng_key, v_key = jax.random.split(rng_key)

        a1 = math.exp(-gamma * dt)
        a2 = (1 - a1) * kT
        vel = (
            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
            * (kT / mass[:, None]) ** 0.5
        )

        state["rng_key"] = rng_key
        if compute_thermostat_energy:
            state["thermostat_energy"] = 0.0

        def thermostat(vel, state):
            # Global velocity rescaling: scale^2 = a1 + c*sum(R^2) + 2*sqrt(a1*c)*R1,
            # with c = (1 - a1) kT / sum(m v^2) and R1 one of the 3N noise draws.
            rng_key, noise_key = jax.random.split(state["rng_key"])
            new_state = {**state, "rng_key": rng_key}
            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
            R2 = jnp.sum(noise**2)
            R1 = noise[0, 0]
            c = a2 / (mass[:, None] * vel**2).sum()
            d = (a1 * c) ** 0.5
            scale = (a1 + c * R2 + 2 * d * R1) ** 0.5
            if compute_thermostat_energy:
                dek = 0.5 * (mass[:, None] * vel**2).sum() * (scale**2 - 1)
                new_state["thermostat_energy"] = state["thermostat_energy"] + dek
            return scale * vel, new_state

    elif thermostat_name in [
        "GD",
        "DESCENT",
        "GRADIENT_DESCENT",
        "MIN",
        "MINIMIZE",
    ]:
        # Damped dynamics without noise (v <- exp(-gamma*dt) v), starting at rest.
        assert nbeads is None, "Gradient descent is not compatible with PIMD"
        a1 = math.exp(-gamma * dt)

        if nbeads is None:
            vel = jnp.zeros((mass.shape[0], 3), dtype=fprec)
        else:
            vel = jnp.zeros((nbeads, mass.shape[0], 3), dtype=fprec)

        def thermostat(vel, state):
            return a1 * vel, state

    elif thermostat_name in ["NVE", "NONE"]:
        assert rng_key is not None, "rng_key must be provided to draw NVE initial velocities"
        assert kT is not None, "kT must be specified to draw NVE initial velocities"
        # Random initial velocities rescaled to exactly match kT (per bead for PIMD).
        if nbeads is None:
            vel = (
                jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
                * (kT / mass[:, None]) ** 0.5
            )
            kTsys = jnp.sum(mass[:, None] * vel**2) / (mass.shape[0] * 3)
            vel = vel * (kT / kTsys) ** 0.5
        else:
            vel = (
                jax.random.normal(rng_key, (nbeads, mass.shape[0], 3), dtype=fprec)
                * (kT / mass[None, :, None]) ** 0.5
            )
            kTsys = jnp.sum(mass[None, :, None] * vel**2, axis=(1, 2)) / (
                mass.shape[0] * 3
            )
            vel = vel * (kT / kTsys[:, None, None]) ** 0.5
        thermostat = lambda x, s: (x, s)

    elif thermostat_name in ["NOSE", "NOSEHOOVER", "NOSE_HOOVER"]:
        assert nbeads is None, "Nose-Hoover is not compatible with PIMD"
        assert gamma is not None, "gamma must be specified for Nose-Hoover thermostat"
        assert kT is not None, "kT must be specified for Nose-Hoover thermostat"
        assert rng_key is not None, "rng_key must be provided for Nose-Hoover thermostat"
        ndof = mass.shape[0] * 3
        nkT = ndof * kT
        nose_mass = nkT / gamma**2
        state["nose_s"] = 0.0
        state["nose_v"] = 0.0
        if compute_thermostat_energy:
            state["thermostat_energy"] = 0.0
        print(
            "# WARNING: Nose-Hoover thermostat is not well tested yet. Energy conservation is not guaranteed."
        )
        vel = (
            jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
            * (kT / mass[:, None]) ** 0.5
        )

        def thermostat(vel, state):
            # Half-step on the thermostat velocity, scale particle velocities,
            # then second half-step with the updated kinetic energy.
            nose_s = state["nose_s"]
            nose_v = state["nose_v"]
            kTsys = jnp.sum(mass[:, None] * vel**2)
            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
            nose_s = nose_s + dt * nose_v
            vel = jnp.exp(-nose_v * dt) * vel
            kTsys = jnp.sum(mass[:, None] * vel**2)
            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
            new_state = {**state, "nose_s": nose_s, "nose_v": nose_v}

            if compute_thermostat_energy:
                new_state["thermostat_energy"] = (
                    nkT * nose_s + (0.5 * nose_mass) * nose_v**2
                )
            return vel, new_state

    elif thermostat_name in ["QTB", "ADQTB"]:
        assert nbeads is None, "QTB is not compatible with PIMD"
        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
        assert kT is not None, "kT must be specified for QTB thermostat"
        assert gamma is not None, "gamma must be specified for QTB thermostat"
        assert species is not None, "species must be provided for QTB thermostat"
        rng_key, v_key = jax.random.split(rng_key)
        vel = (
            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
            * (kT / mass[:, None]) ** 0.5
        )

        thermostat, postprocess, qtb_state = initialize_qtb(
            simulation_parameters,
            system_data,
            fprec=fprec,
            dt=dt,
            mass=mass,
            gamma=gamma,
            kT=kT,
            species=species,
            rng_key=rng_key,
            adaptive=thermostat_name.startswith("AD"),
            compute_thermostat_energy=compute_thermostat_energy,
        )
        state = {**state, **qtb_state}

    elif thermostat_name in ["ANNEAL", "ANNEALING"]:
        assert rng_key is not None, "rng_key must be provided for ANNEAL thermostat"
        assert kT is not None, "kT must be specified for ANNEAL thermostat"
        assert gamma is not None, "gamma must be specified for ANNEAL thermostat"
        assert nbeads is None, "ANNEAL is not compatible with PIMD"
        a1 = math.exp(-gamma * dt)
        a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)

        anneal_parameters = simulation_parameters.get("annealing", {})
        """@keyword[fennol_md] annealing
        Parameters for simulated annealing schedule configuration.
        Required for ANNEAL/ANNEALING thermostat
        """
        init_factor = anneal_parameters.get("init_factor", 1.0 / 25.0)
        """@keyword[fennol_md] annealing/init_factor
        Initial temperature factor for annealing schedule.
        Default: 0.04 (1/25)
        """
        assert init_factor > 0.0, "init_factor must be positive"
        final_factor = anneal_parameters.get("final_factor", 1.0 / 10000.0)
        """@keyword[fennol_md] annealing/final_factor
        Final temperature factor for annealing schedule.
        Default: 0.0001 (1/10000)
        """
        assert final_factor > 0.0, "final_factor must be positive"
        nsteps = simulation_parameters.get("nsteps")
        """@keyword[fennol_md] nsteps
        Total number of simulation steps for annealing schedule calculation.
        Required parameter
        """
        assert nsteps is not None, "nsteps must be specified for ANNEAL thermostat"
        anneal_steps = anneal_parameters.get("anneal_steps", 1.0)
        """@keyword[fennol_md] annealing/anneal_steps
        Fraction of total steps for annealing process.
        Default: 1.0
        """
        # 1.0 (the default) means "anneal over the whole run" and must be accepted.
        assert 0.0 < anneal_steps <= 1.0, "anneal_steps must be in (0, 1]"
        pct_start = anneal_parameters.get("warmup_steps", 0.3)
        """@keyword[fennol_md] annealing/warmup_steps
        Fraction of annealing steps for warmup phase.
        Default: 0.3
        """
        assert 0.0 < pct_start < 1.0, "warmup_steps must be between 0 and 1"
        transition_steps = int(anneal_steps * nsteps)

        anneal_type = str(anneal_parameters.get("type", "cosine")).lower()
        """@keyword[fennol_md] annealing/type
        Type of annealing schedule (linear, cosine or cosine_onecycle).
        Default: cosine
        """
        if anneal_type == "linear":
            schedule = optax.linear_onecycle_schedule(
                peak_value=1.0,
                div_factor=1.0 / init_factor,
                final_div_factor=1.0 / final_factor,
                transition_steps=transition_steps,
                pct_start=pct_start,
                pct_final=1.0,
            )
        elif anneal_type in ["cosine", "cosine_onecycle"]:
            # "cosine" is the documented default and is an alias of cosine_onecycle.
            schedule = optax.cosine_onecycle_schedule(
                peak_value=1.0,
                div_factor=1.0 / init_factor,
                final_div_factor=1.0 / final_factor,
                transition_steps=transition_steps,
                pct_start=pct_start,
            )
        else:
            raise ValueError(f"Unknown anneal_type {anneal_type}")

        # Split before storing the key so that the first thermostat noise is not
        # drawn from the same key as the initial velocities.
        rng_key, v_key = jax.random.split(rng_key)
        state["rng_key"] = rng_key
        state["istep_anneal"] = 0

        Tscale = schedule(0)
        print(f"# ANNEAL: initial temperature = {Tscale*kT*us.KELVIN:.3e} K")
        vel = (
            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
            * (kT * Tscale / mass[:, None]) ** 0.5
        )

        def thermostat(vel, state):
            rng_key, noise_key = jax.random.split(state["rng_key"])
            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)

            # Langevin step at the scheduled temperature kT * schedule(istep).
            noise_scale = schedule(state["istep_anneal"]) ** 0.5
            vel = a1 * vel + a2 * noise_scale * noise
            return vel, {
                **state,
                "rng_key": rng_key,
                "istep_anneal": state["istep_anneal"] + 1,
            }

    else:
        raise ValueError(f"Unknown thermostat {thermostat_name}")

    return thermostat, postprocess, state, vel, thermostat_name
def initialize_qtb(
    simulation_parameters,
    system_data,
    fprec,
    dt,
    mass,
    gamma,
    kT,
    species,
    rng_key,
    adaptive,
    compute_thermostat_energy=False,
):
    """Initialize a Quantum Thermal Bath (QTB) or adaptive QTB thermostat.

    The random force is generated segment by segment (``nseg = tseg / dt``
    steps). White noise is filtered in Fourier space with a per-atom-type
    kernel. When spectra are computed (``write_spectra`` or ``adaptive``), the
    velocity/force spectra are accumulated per atom type. For adQTB they are
    used to update the per-type friction ratios ``gammar``.

    Parameters
    ----------
    simulation_parameters : dict
        Simulation options; QTB options are read from its ``qtb`` sub-dict.
    system_data : dict
        Must provide ``name`` (prefix of the ``.qtb.restart`` file).
    fprec : dtype
        Floating-point dtype of the thermostat arrays.
    dt : float
        Integration time step.
    mass : array of shape (nat,)
        Atomic masses.
    gamma : float
        Langevin friction coefficient.
    kT : float
        Target thermal energy.
    species : numpy array of shape (nat,)
        Per-atom species, used as keys into ``PERIODIC_TABLE``.
    rng_key : jax PRNG key
        Key used to draw the white noise.
    adaptive : bool
        Enable adQTB adaptation of ``gammar``.
    compute_thermostat_energy : bool
        Track the energy exchanged with the bath in ``state``.

    Returns
    -------
    tuple
        ``(thermostat, (postprocess, post_state), state)``.
        ``thermostat(vel, state)`` applies one QTB step.
        ``postprocess(state, post_state)`` advances a ``Counter`` and, on
        reset steps, refreshes the random force and writes spectra/restart
        files. ``state`` and ``post_state`` are the initial state dicts.
    """
    state = {}
    post_state = {}
    qtb_parameters = simulation_parameters.get("qtb", {})
    verbose = qtb_parameters.get("verbose", False)
    """@keyword[fennol_md] qtb/verbose
    Print verbose QTB thermostat information.
    Default: False
    """
    if compute_thermostat_energy:
        state["thermostat_energy"] = 0.0

    mass = jnp.asarray(mass, dtype=fprec)

    nat = species.shape[0]
    # One atom type per species present (sorted for a reproducible order).
    species_set = sorted(set(species))
    nspecies = len(species_set)
    species_to_type = {sp: i for i, sp in enumerate(species_set)}
    type_idx = np.array([species_to_type[sp] for sp in species], dtype=np.int32)
    type_labels = [PERIODIC_TABLE[sp] for sp in species_set]

    adapt_groups = qtb_parameters.get("adapt_groups", {})
    """@keyword[fennol_md] qtb/adapt_groups
    Groups of atoms with separate adQTB parameters (atoms of different species within groups will be separated in subtypes).
    Default: {}
    """
    # Each group gets one extra type per species it contains, appended after
    # the base species types.
    ntypes = nspecies
    allgroups = set()
    for groupname, interval in adapt_groups.items():
        indices = read_tinker_interval(interval)
        assert allgroups.isdisjoint(indices), f"Indices in group {groupname} overlap with other groups"
        allgroups.update(indices)
        species_in_group = sorted(set(species[indices]))
        group_type = {sp: ntypes + i for i, sp in enumerate(species_in_group)}
        for i in indices:
            type_idx[i] = group_type[species[i]]
        type_labels += [f"{PERIODIC_TABLE[sp]}_{groupname}" for sp in species_in_group]
        ntypes = len(type_labels)

    count_of_type = np.array(
        [np.count_nonzero(type_idx == i) for i in range(ntypes)], dtype=np.int32
    )
    for label, count in zip(type_labels, count_of_type):
        print(f"# QTB: n({label}) =", count)
    n_of_type = jnp.asarray(count_of_type, dtype=fprec)
    # A type is empty when adapt_groups take every atom of a species. Dividing
    # by max(n, 1) keeps its per-type quantities at 0 instead of NaN. Otherwise
    # NaN would leak into corr_kin through the n_of_type-weighted average.
    n_of_type_div = jnp.maximum(n_of_type, 1.0)
    mass_idx = jax.ops.segment_sum(mass, type_idx, ntypes) / n_of_type_div

    niter_deconv_kin = qtb_parameters.get("niter_deconv_kin", 7)
    """@keyword[fennol_md] qtb/niter_deconv_kin
    Number of iterations for kinetic energy deconvolution.
    Default: 7
    """
    niter_deconv_pot = qtb_parameters.get("niter_deconv_pot", 20)
    """@keyword[fennol_md] qtb/niter_deconv_pot
    Number of iterations for potential energy deconvolution.
    Default: 20
    """
    corr_kin = qtb_parameters.get("corr_kin", -1)
    """@keyword[fennol_md] qtb/corr_kin
    Kinetic energy correction factor for QTB (-1 for automatic).
    Default: -1
    """
    do_corr_kin = corr_kin <= 0
    if do_corr_kin:
        corr_kin = 1.0
    state["corr_kin"] = corr_kin
    post_state["corr_kin_prev"] = corr_kin
    post_state["do_corr_kin"] = do_corr_kin
    post_state["isame_kin"] = 0

    # spectra parameters
    omegasmear = np.pi / dt / 100.0
    Tseg = qtb_parameters.get("tseg", 1.0 / us.PS)
    """@keyword[fennol_md] qtb/tseg
    Time segment length for QTB spectrum calculation.
    Default: 1.0 ps
    """
    nseg = int(Tseg / dt)
    Tseg = nseg * dt
    # Frequency grid of an rfft over 3 * nseg points.
    dom = 2 * np.pi / (3 * Tseg)
    omegacut = qtb_parameters.get("omegacut", 15000.0 / us.CM1)
    """@keyword[fennol_md] qtb/omegacut
    Cutoff frequency for QTB spectrum.
    Default: 15000.0 cm⁻¹
    """
    nom = int(omegacut / dom)
    omega = dom * np.arange((3 * nseg) // 2 + 1)
    # Smooth (Fermi-like) cutoff of the noise kernel around omegacut.
    cutoff = jnp.asarray(
        1.0 / (1.0 + np.exp((omega - omegacut) / omegasmear)), dtype=fprec
    )
    assert (
        omegacut < omega[-1]
    ), f"omegacut must be smaller than {omega[-1]*us.CM1} CM-1"

    # initialize gammar
    assert (
        gamma < 0.5 * omegacut
    ), "gamma must be much smaller than omegacut (at most 0.5*omegacut)"
    gammar_min = qtb_parameters.get("gammar_min", 0.1)
    """@keyword[fennol_md] qtb/gammar_min
    Minimum value for QTB gamma ratio coefficients.
    Default: 0.1
    """
    gammar = np.ones((ntypes, nom), dtype=float)
    try:
        for i, sp in enumerate(type_labels):
            if not os.path.exists(f"QTB_spectra_{sp}.out"):
                continue
            data = np.loadtxt(f"QTB_spectra_{sp}.out")
            # Column 4 holds gammar * gamma * us.THZ (see write_spectra_to_file).
            gammar[i] = data[:, 4] / (gamma * us.THZ)
            print(f"# Restored gammar for species {sp} from QTB_spectra_{sp}.out")
    except Exception as e:
        # Best effort: any unreadable or mismatched file resets every type.
        print(f"# Could not restore gammar for all species with exception {e}. Starting from scratch.")
        gammar[:, :] = 1.0
    post_state["gammar"] = jnp.asarray(gammar, dtype=fprec)

    # Ornstein-Uhlenbeck correction for colored noise
    a1 = np.exp(-gamma * dt)
    OUcorr = jnp.asarray(
        (1.0 - 2.0 * a1 * np.cos(omega * dt) + a1**2) / (dt**2 * (gamma**2 + omega**2)),
        dtype=fprec,
    )

    # hbar schedule
    classical_kernel = qtb_parameters.get("classical_kernel", False)
    """@keyword[fennol_md] qtb/classical_kernel
    Use classical instead of quantum kernel for QTB.
    Default: False
    """
    hbar = qtb_parameters.get("hbar", 1.0) * us.HBAR
    """@keyword[fennol_md] qtb/hbar
    Reduced Planck constant scaling factor for quantum effects.
    Default: 1.0 a.u.
    """
    # theta(omega) = kT * u / tanh(u), u = hbar*omega / (2 kT); theta(0) = kT.
    u = 0.5 * hbar * np.abs(omega) / kT
    theta = kT * np.ones_like(omega)
    if hbar > 0:
        theta[1:] *= u[1:] / np.tanh(u[1:])
    theta = jnp.asarray(theta, dtype=fprec)

    noise_key, post_state["rng_key"] = jax.random.split(rng_key)
    del rng_key
    post_state["white_noise"] = jax.random.normal(
        noise_key, (3 * nseg, nat, 3), dtype=jnp.float32
    )

    startsave = qtb_parameters.get("startsave", 1)
    """@keyword[fennol_md] qtb/startsave
    Start saving QTB statistics after this many segments.
    Default: 1
    """
    counter = Counter(nseg, startsave=startsave)
    state["istep"] = 0
    post_state["nadapt"] = 0
    post_state["nsample"] = 0

    write_spectra = qtb_parameters.get("write_spectra", True)
    """@keyword[fennol_md] qtb/write_spectra
    Write QTB spectral analysis output files.
    Default: True
    """
    do_compute_spectra = write_spectra or adaptive

    if do_compute_spectra:
        state["vel"] = jnp.zeros((nseg, nat, 3), dtype=fprec)

        post_state["dFDT"] = jnp.zeros((ntypes, nom), dtype=fprec)
        post_state["mCvv"] = jnp.zeros((ntypes, nom), dtype=fprec)
        post_state["Cvf"] = jnp.zeros((ntypes, nom), dtype=fprec)
        post_state["Cff"] = jnp.zeros((ntypes, nom), dtype=fprec)
        post_state["dFDT_avg"] = jnp.zeros((ntypes, nom), dtype=fprec)
        post_state["mCvv_avg"] = jnp.zeros((ntypes, nom), dtype=fprec)
        post_state["Cvfg_avg"] = jnp.zeros((ntypes, nom), dtype=fprec)
        post_state["Cff_avg"] = jnp.zeros((ntypes, nom), dtype=fprec)

    if not adaptive:
        update_gammar = lambda x: x
    else:
        # adaptation parameters
        skipseg = qtb_parameters.get("skipseg", 1)
        """@keyword[fennol_md] qtb/skipseg
        Number of segments to skip before starting adaptive QTB.
        Default: 1
        """

        adaptation_method = (
            str(qtb_parameters.get("adaptation_method", "SIMPLE")).upper().strip()
        )
        """@keyword[fennol_md] qtb/adaptation_method
        Method for adaptive QTB (SIMPLE, RATIO, ADABELIEF).
        Default: SIMPLE
        """
        if adaptation_method == "SIMPLE":
            agamma = qtb_parameters.get("agamma", 0.1)
            """@keyword[fennol_md] qtb/agamma
            Learning rate for adaptive QTB gamma update.
            Default: 0.1
            """
            assert agamma > 0, "agamma must be positive"
            a1_ad = agamma * Tseg  # * gamma
            print(f"# ADQTB SIMPLE: agamma = {agamma:.3f}")

            def update_gammar(post_state):
                # Gradient-like step on the fluctuation-dissipation mismatch.
                g = post_state["dFDT"]
                gammar = post_state["gammar"] - a1_ad * g
                gammar = jnp.maximum(gammar_min, gammar)
                return {**post_state, "gammar": gammar}

        elif adaptation_method == "RATIO":
            tau_ad = qtb_parameters.get("tau_ad", 5.0 / us.PS)
            """@keyword[fennol_md] qtb/tau_ad
            Adaptation time constant for momentum averaging.
            Default: 5.0 ps (RATIO), 1.0 ps (ADABELIEF)
            """
            tau_s = qtb_parameters.get("tau_s", 10 * tau_ad)
            """@keyword[fennol_md] qtb/tau_s
            Second moment time constant for variance averaging.
            Default: 10*tau_ad (RATIO), 100*tau_ad (ADABELIEF)
            """
            assert tau_ad > 0, "tau_ad must be positive"
            print(
                f"# ADQTB RATIO: tau_ad = {tau_ad*us.PS:.2f} ps, tau_s = {tau_s*us.PS:.2f} ps"
            )
            b1 = np.exp(-Tseg / tau_ad)
            b2 = np.exp(-Tseg / tau_s)
            # Same shape as mCvv / Cvf: one row per atom type (ntypes, not
            # nspecies, since adapt_groups can add types).
            post_state["mCvv_m"] = jnp.zeros((ntypes, nom), dtype=fprec)
            post_state["Cvf_m"] = jnp.zeros((ntypes, nom), dtype=fprec)
            post_state["n_adabelief"] = 0

            def update_gammar(post_state):
                # gammar = <Cvf> / <mCvv> with bias-corrected moving averages.
                n_adabelief = post_state["n_adabelief"] + 1
                mCvv_m = post_state["mCvv_m"] * b1 + post_state["mCvv"] * (1.0 - b1)
                Cvf_m = post_state["Cvf_m"] * b2 + post_state["Cvf"] * (1.0 - b2)
                mCvv = mCvv_m / (1.0 - b1**n_adabelief)
                Cvf = Cvf_m / (1.0 - b2**n_adabelief)
                gammar = Cvf / (mCvv + 1.0e-8)
                gammar = jnp.maximum(gammar_min, gammar)
                return {
                    **post_state,
                    "gammar": gammar,
                    "mCvv_m": mCvv_m,
                    "Cvf_m": Cvf_m,
                    "n_adabelief": n_adabelief,
                }

        elif adaptation_method == "ADABELIEF":
            agamma = qtb_parameters.get("agamma", 0.1)
            tau_ad = qtb_parameters.get("tau_ad", 1.0 / us.PS)
            tau_s = qtb_parameters.get("tau_s", 100 * tau_ad)
            assert tau_ad > 0, "tau_ad must be positive"
            assert tau_s > 0, "tau_s must be positive"
            assert agamma > 0, "agamma must be positive"
            print(
                f"# ADQTB ADABELIEF: agamma = {agamma:.3f}, tau_ad = {tau_ad*us.PS:.2f} ps, tau_s = {tau_s*us.PS:.2f} ps"
            )

            a1_ad = agamma * gamma  # * Tseg #* gamma
            b1 = np.exp(-Tseg / tau_ad)
            b2 = np.exp(-Tseg / tau_s)
            post_state["dFDT_m"] = jnp.zeros((ntypes, nom), dtype=fprec)
            post_state["dFDT_s"] = jnp.zeros((ntypes, nom), dtype=fprec)
            post_state["n_adabelief"] = 0

            def update_gammar(post_state):
                # AdaBelief-style step: first moment of dFDT divided by the
                # square root of its deviation variance, both bias-corrected.
                n_adabelief = post_state["n_adabelief"] + 1
                dFDT = post_state["dFDT"]
                dFDT_m = post_state["dFDT_m"] * b1 + dFDT * (1.0 - b1)
                dFDT_s = (
                    post_state["dFDT_s"] * b2
                    + (dFDT - dFDT_m) ** 2 * (1.0 - b2)
                    + 1.0e-8
                )
                # bias correction
                mt = dFDT_m / (1.0 - b1**n_adabelief)
                st = dFDT_s / (1.0 - b2**n_adabelief)
                gammar = post_state["gammar"] - a1_ad * mt / (st**0.5 + 1.0e-8)
                gammar = jnp.maximum(gammar_min, gammar)
                return {
                    **post_state,
                    "gammar": gammar,
                    "dFDT_m": dFDT_m,
                    "n_adabelief": n_adabelief,
                    "dFDT_s": dFDT_s,
                }
        else:
            raise ValueError(
                f"Unknown adaptation method {adaptation_method}."
            )

    #####################
    # RESTART
    restart_file = system_data["name"] + ".qtb.restart"
    if os.path.exists(restart_file):
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # restart files produced by trusted runs.
        with open(restart_file, "rb") as f:
            data = pickle.load(f)
        state["corr_kin"] = data["corr_kin"]
        post_state["corr_kin_prev"] = data["corr_kin"]
        post_state["isame_kin"] = data["isame_kin"]
        post_state["do_corr_kin"] = data["do_corr_kin"]
        print(f"# Restored QTB state from {restart_file}")

    def write_qtb_restart(state, post_state):
        """Persist the corr_kin bookkeeping to ``restart_file``."""
        with open(restart_file, "wb") as f:
            pickle.dump(
                {
                    "corr_kin": state["corr_kin"],
                    "corr_kin_prev": post_state["corr_kin_prev"],
                    "isame_kin": post_state["isame_kin"],
                    "do_corr_kin": post_state["do_corr_kin"],
                },
                f,
            )
    ######################

    def compute_corr_pot(niter=20, verbose=False):
        """Return the potential-energy correction on the first ``nom`` frequencies.

        Also writes ``corr_pot.dat``. Returns ones for a classical kernel.
        """
        if classical_kernel or hbar == 0:
            return np.ones(nom)

        s_0 = np.array((theta / kT * cutoff)[:nom])
        s_out, s_rec, _ = deconvolute_spectrum(
            s_0,
            omega[:nom],
            gamma,
            niter,
            kernel=kernel_lorentz_pot,
            trans=True,
            symmetrize=True,
            verbose=verbose,
        )
        corr_pot = 1.0 + (s_out - s_0) / s_0
        columns = np.column_stack(
            (omega[:nom] * us.CM1, corr_pot - 1.0, s_0, s_out, s_rec)
        )
        np.savetxt(
            "corr_pot.dat", columns, header="omega(cm-1) corr_pot s_0 s_out s_rec"
        )
        return corr_pot

    def compute_corr_kin(post_state, niter=7, verbose=False):
        """Estimate the kinetic-energy correction from the averaged mCvv spectrum.

        Returns ``(corr_kin, post_state)``. The previous value is kept when the
        estimate is disabled, unavailable or its reconvolution is inaccurate.
        """
        if not post_state["do_corr_kin"]:
            return post_state["corr_kin_prev"], post_state
        if classical_kernel or hbar == 0:
            return 1.0, post_state
        if not do_compute_spectra:
            # mCvv_avg is never accumulated, so corr_kin cannot be estimated.
            return post_state["corr_kin_prev"], post_state

        K_D = post_state.get("K_D", None)
        mCvv = (post_state["mCvv_avg"][:, :nom] * n_of_type[:, None]).sum(axis=0) / nat
        s_0 = np.array(mCvv * kT / theta[:nom] / post_state["corr_pot"])
        s_out, s_rec, K_D = deconvolute_spectrum(
            s_0,
            omega[:nom],
            gamma,
            niter,
            kernel=kernel_lorentz,
            trans=False,
            symmetrize=True,
            verbose=verbose,
            K_D=K_D,
        )
        s_out = s_out * theta[:nom] / kT
        s_rec = s_rec * theta[:nom] / kT * post_state["corr_pot"]
        mCvvsum = mCvv.sum()
        rec_ratio = mCvvsum / s_rec.sum()
        # Written so that a NaN ratio (e.g. 0/0) is rejected as well.
        if not (0.95 <= rec_ratio <= 1.05):
            print(
                f"# WARNING: reconvolution error {rec_ratio} is too high, corr_kin was not updated"
            )
            return post_state["corr_kin_prev"], post_state

        corr_kin = mCvvsum / s_out.sum()
        if np.abs(corr_kin - post_state["corr_kin_prev"]) < 1.0e-4:
            isame_kin = post_state["isame_kin"] + 1
        else:
            isame_kin = 0

        do_corr_kin = post_state["do_corr_kin"]
        if isame_kin > 10:
            print(
                "# INFO: corr_kin is converged (it did not change for 10 consecutive segments)"
            )
            do_corr_kin = False

        return corr_kin, {
            **post_state,
            "corr_kin_prev": corr_kin,
            "isame_kin": isame_kin,
            "do_corr_kin": do_corr_kin,
            "K_D": K_D,
        }

    @jax.jit
    def ff_kernel(post_state):
        """Random-force power spectrum per frequency and atom type."""
        if classical_kernel:
            kernel = cutoff * (2 * gamma * kT / dt)
        else:
            kernel = theta * cutoff * OUcorr * (2 * gamma / dt)
        # Frequencies above nom keep gamma_ratio = 1.
        gamma_ratio = jnp.concatenate(
            (
                post_state["gammar"].T * post_state["corr_pot"][:, None],
                jnp.ones(
                    (kernel.shape[0] - nom, ntypes), dtype=post_state["gammar"].dtype
                ),
            ),
            axis=0,
        )
        return kernel[:, None] * gamma_ratio * mass_idx[None, :]

    @jax.jit
    def refresh_force(post_state):
        """Shift the noise window by one segment and build the next segment's force."""
        rng_key, noise_key = jax.random.split(post_state["rng_key"])
        white_noise = jnp.concatenate(
            (
                post_state["white_noise"][nseg:],
                jax.random.normal(
                    noise_key, (nseg, nat, 3), dtype=post_state["white_noise"].dtype
                ),
            ),
            axis=0,
        )
        amplitude = ff_kernel(post_state) ** 0.5
        s = jnp.fft.rfft(white_noise, 3 * nseg, axis=0) * amplitude[:, type_idx, None]
        # Keep the middle segment of the filtered 3*nseg window.
        force = jnp.fft.irfft(s, 3 * nseg, axis=0)[nseg : 2 * nseg]
        return force, {**post_state, "rng_key": rng_key, "white_noise": white_noise}

    @jax.jit
    def compute_spectra(force, vel, post_state):
        """Per-type spectra of the last segment and their running averages."""
        sf = jnp.fft.rfft(force / gamma, 3 * nseg, axis=0, norm="ortho")
        sv = jnp.fft.rfft(vel, 3 * nseg, axis=0, norm="ortho")
        Cvv = jnp.sum(jnp.abs(sv[:nom]) ** 2, axis=-1).T
        Cff = jnp.sum(jnp.abs(sf[:nom]) ** 2, axis=-1).T
        Cvf = jnp.sum(jnp.real(sv[:nom] * jnp.conj(sf[:nom])), axis=-1).T

        mCvv = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["mCvv"]).at[type_idx].add(Cvv)
            * mass_idx[:, None]
            / n_of_type_div[:, None]
        )
        Cvf = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["Cvf"]).at[type_idx].add(Cvf)
            / n_of_type_div[:, None]
        )
        Cff = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["Cff"]).at[type_idx].add(Cff)
            / n_of_type_div[:, None]
        )
        dFDT = mCvv * post_state["gammar"] - Cvf

        # Cumulative average over nsample segments.
        nsinv = 1.0 / post_state["nsample"]
        b1 = 1.0 - nsinv
        dFDT_avg = post_state["dFDT_avg"] * b1 + dFDT * nsinv
        mCvv_avg = post_state["mCvv_avg"] * b1 + mCvv * nsinv
        Cvfg_avg = post_state["Cvfg_avg"] * b1 + Cvf / post_state["gammar"] * nsinv
        Cff_avg = post_state["Cff_avg"] * b1 + Cff * nsinv

        return {
            **post_state,
            "mCvv": mCvv,
            "Cvf": Cvf,
            "Cff": Cff,
            "dFDT": dFDT,
            "dFDT_avg": dFDT_avg,
            "mCvv_avg": mCvv_avg,
            "Cvfg_avg": Cvfg_avg,
            "Cff_avg": Cff_avg,
        }

    def write_spectra_to_file(post_state):
        """Write one ``QTB_spectra_<type>.out`` file per non-empty atom type."""
        mCvv_avg = np.array(post_state["mCvv_avg"])
        Cvfg_avg = np.array(post_state["Cvfg_avg"])
        Cff_avg = np.array(post_state["Cff_avg"]) * 3.0 / dt / (gamma**2)
        dFDT_avg = np.array(post_state["dFDT_avg"])
        gammar = np.array(post_state["gammar"])
        Cff_theo = np.array(ff_kernel(post_state))[:nom].T
        for i, sp in enumerate(type_labels):
            if count_of_type[i] == 0:
                continue  # no atom of this type: ff_scale would divide by zero
            ff_scale = us.KELVIN / ((2 * gamma / dt) * mass_idx[i])
            columns = np.column_stack(
                (
                    omega[:nom] * us.CM1,
                    mCvv_avg[i],
                    Cvfg_avg[i],
                    dFDT_avg[i],
                    gammar[i] * gamma * us.THZ,
                    Cff_avg[i] * ff_scale,
                    Cff_theo[i] * ff_scale,
                )
            )
            np.savetxt(
                f"QTB_spectra_{sp}.out",
                columns,
                fmt="%12.6f",
                header="#omega mCvv Cvf dFDT gammar Cff Cff_theo",
            )
        if verbose:
            print("# QTB spectra written.")

    if compute_thermostat_energy:
        state["qtb_energy_flux"] = 0.0

    @jax.jit
    def thermostat(vel, state):
        """One QTB step: damping by a1 plus the precomputed random force."""
        istep = state["istep"]
        dvel = dt * state["force"][istep] / mass[:, None]
        new_vel = vel * a1 + dvel
        new_state = {**state, "istep": istep + 1}
        if do_compute_spectra:
            vel2 = state["vel"].at[istep].set(vel * a1**0.5 + 0.5 * dvel)
            new_state["vel"] = vel2
        if compute_thermostat_energy:
            dek = 0.5 * (mass[:, None] * (vel**2 - new_vel**2)).sum()
            ekcorr = (
                0.5
                * (mass[:, None] * new_vel**2).sum()
                * (1.0 - 1.0 / state.get("corr_kin", 1.0))
            )
            new_state["qtb_energy_flux"] = state["qtb_energy_flux"] + dek
            new_state["thermostat_energy"] = new_state["qtb_energy_flux"] + ekcorr
        return new_vel, new_state

    @jax.jit
    def postprocess_work(state, post_state):
        if do_compute_spectra:
            post_state = compute_spectra(state["force"], state["vel"], post_state)
        if adaptive:
            post_state = jax.lax.cond(
                post_state["nadapt"] > skipseg, update_gammar, lambda x: x, post_state
            )
        new_force, post_state = refresh_force(post_state)
        return {**state, "force": new_force}, post_state

    def postprocess(state, post_state):
        """Advance the segment counter; at segment end refresh forces and outputs."""
        counter.increment()
        if not counter.is_reset_step:
            return state, post_state
        post_state["nadapt"] += 1
        post_state["nsample"] = max(post_state["nadapt"] - startsave + 1, 1)
        if verbose:
            print("# Refreshing QTB forces.")
        state, post_state = postprocess_work(state, post_state)
        state["corr_kin"], post_state = compute_corr_kin(post_state, niter=niter_deconv_kin)
        state["istep"] = 0
        if write_spectra:
            write_spectra_to_file(post_state)
        write_qtb_restart(state, post_state)
        return state, post_state

    post_state["corr_pot"] = jnp.asarray(compute_corr_pot(niter=niter_deconv_pot), dtype=fprec)

    state["force"], post_state = refresh_force(post_state)
    return thermostat, (postprocess, post_state), state