fennol.md.thermostats
import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
import math
import optax
import os
import pickle

from ..utils.atomic_units import AtomicUnits as au  # CM1,THZ,BOHR,MPROT
from ..utils import Counter
from ..utils.deconvolution import (
    deconvolute_spectrum,
    kernel_lorentz_pot,
    kernel_lorentz,
)


def get_thermostat(
    simulation_parameters, dt, system_data, fprec, rng_key=None, restart_data=None
):
    """Build the velocity thermostat selected in ``simulation_parameters``.

    Parameters
    ----------
    simulation_parameters : mapping
        Read with ``.get``. Keys used here: ``"thermostat"`` (default
        ``"LGV"``), ``"include_thermostat_energy"``, ``"gamma"``,
        ``"trpmd_lambda"``, ``"qtb"``, ``"annealing"`` and ``"nsteps"``.
    dt : float
        Integration time step.
    system_data : mapping
        Must contain ``"mass"`` and ``"species"``; may contain ``"kT"`` and
        ``"nbeads"``. ``"omk"`` is read when ``nbeads`` is set and ``"name"``
        is read by the QTB restart logic.
    fprec : dtype
        Floating-point dtype of the generated arrays.
    rng_key : jax PRNG key, optional
        Required by every thermostat except gradient descent.
    restart_data : optional
        Currently unused; kept for interface compatibility.

    Returns
    -------
    tuple
        ``(thermostat, postprocess, state, vel, thermostat_name)`` where
        ``thermostat(vel, state) -> (vel, state)`` applies one thermostat
        step, ``postprocess`` is ``None`` except for QTB/ADQTB (where it is
        the ``(postprocess, post_state)`` pair built by ``initialize_qtb``),
        ``state`` is the initial thermostat state dict, ``vel`` the initial
        velocities and ``thermostat_name`` the upper-cased thermostat name.

    Raises
    ------
    AssertionError
        If a parameter required by the selected thermostat is missing or
        out of range.
    ValueError
        If the thermostat name or the annealing type is unknown.
    """
    state = {}
    postprocess = None

    thermostat_name = str(simulation_parameters.get("thermostat", "LGV")).upper()
    compute_thermostat_energy = simulation_parameters.get(
        "include_thermostat_energy", False
    )

    kT = system_data.get("kT", None)
    nbeads = system_data.get("nbeads", None)
    mass = system_data["mass"]
    gamma = simulation_parameters.get("gamma", 1.0 / au.THZ) / au.FS
    species = system_data["species"]

    if nbeads is not None:
        # Path-integral run: friction becomes the element-wise maximum of
        # trpmd_lambda * omk and the base gamma.
        trpmd_lambda = simulation_parameters.get("trpmd_lambda", 1.0)
        gamma = np.maximum(trpmd_lambda * system_data["omk"], gamma)

    if thermostat_name in ["LGV", "LANGEVIN", "FFLGV"]:
        assert rng_key is not None, "rng_key must be provided for Langevin thermostat"
        assert kT is not None, "kT must be specified for Langevin thermostat"
        assert gamma is not None, "gamma must be specified for Langevin thermostat"
        rng_key, v_key = jax.random.split(rng_key)
        if nbeads is None:
            a1 = math.exp(-gamma * dt)
            a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)
            vel = (
                jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
                * (kT / mass[:, None]) ** 0.5
            )
        else:
            if isinstance(gamma, float):
                gamma = np.array([gamma] * nbeads)
            assert isinstance(
                gamma, np.ndarray
            ), "gamma must be a float or a numpy array"
            assert gamma.shape[0] == nbeads, "gamma must have the same length as nbeads"
            a1 = np.exp(-gamma * dt)[:, None, None]
            a2 = jnp.asarray(
                ((1 - a1 * a1) * kT / mass[None, :, None]) ** 0.5, dtype=fprec
            )
            # mass[None, :, None] broadcasts exactly like the previous
            # mass[:, None]; written explicitly to match a2 above.
            vel = (
                jax.random.normal(v_key, (nbeads, mass.shape[0], 3), dtype=fprec)
                * (kT / mass[None, :, None]) ** 0.5
            )

        state["rng_key"] = rng_key
        if compute_thermostat_energy:
            state["thermostat_energy"] = 0.0

        if thermostat_name == "FFLGV":

            def thermostat(vel, state):
                """Langevin step that keeps each velocity direction and only
                updates its norm."""
                rng_key, noise_key = jax.random.split(state["rng_key"])
                noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
                norm_vel = jnp.linalg.norm(vel, axis=-1, keepdims=True)
                dirvel = vel / norm_vel
                if compute_thermostat_energy:
                    v2 = (vel**2).sum(axis=-1)
                vel = a1 * vel + a2 * noise
                new_norm_vel = jnp.linalg.norm(vel, axis=-1, keepdims=True)
                vel = dirvel * new_norm_vel
                new_state = {**state, "rng_key": rng_key}
                if compute_thermostat_energy:
                    v2new = (vel**2).sum(axis=-1)
                    new_state["thermostat_energy"] = (
                        state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
                    )

                return vel, new_state

        else:

            def thermostat(vel, state):
                """Langevin step: vel <- a1 * vel + a2 * gaussian noise."""
                rng_key, noise_key = jax.random.split(state["rng_key"])
                noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
                if compute_thermostat_energy:
                    v2 = (vel**2).sum(axis=-1)
                vel = a1 * vel + a2 * noise
                new_state = {**state, "rng_key": rng_key}
                if compute_thermostat_energy:
                    v2new = (vel**2).sum(axis=-1)
                    new_state["thermostat_energy"] = (
                        state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
                    )
                return vel, new_state

    elif thermostat_name in ["BUSSI"]:
        assert rng_key is not None, "rng_key must be provided for Bussi thermostat"
        assert kT is not None, "kT must be specified for Bussi thermostat"
        assert gamma is not None, "gamma must be specified for Bussi thermostat"
        assert nbeads is None, "Bussi thermostat is not compatible with PIMD"
        rng_key, v_key = jax.random.split(rng_key)

        a1 = math.exp(-gamma * dt)
        a2 = (1 - a1) * kT
        vel = (
            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
            * (kT / mass[:, None]) ** 0.5
        )

        state["rng_key"] = rng_key
        if compute_thermostat_energy:
            state["thermostat_energy"] = 0.0

        def thermostat(vel, state):
            """Rescale all velocities by one stochastic global factor."""
            rng_key, noise_key = jax.random.split(state["rng_key"])
            new_state = {**state, "rng_key": rng_key}
            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
            R2 = jnp.sum(noise**2)
            R1 = noise[0, 0]
            c = a2 / (mass[:, None] * vel**2).sum()
            d = (a1 * c) ** 0.5
            scale = (a1 + c * R2 + 2 * d * R1) ** 0.5
            if compute_thermostat_energy:
                dek = 0.5 * (mass[:, None] * vel**2).sum() * (scale**2 - 1)
                new_state["thermostat_energy"] = state["thermostat_energy"] + dek
            return scale * vel, new_state

    elif thermostat_name in [
        "GD",
        "DESCENT",
        "GRADIENT_DESCENT",
        "MIN",
        "MINIMIZE",
    ]:
        assert nbeads is None, "Gradient descent is not compatible with PIMD"
        a1 = math.exp(-gamma * dt)

        # nbeads is None here (asserted above), so velocities are (natoms, 3).
        vel = jnp.zeros((mass.shape[0], 3), dtype=fprec)

        def thermostat(vel, state):
            """Pure damping: vel <- a1 * vel."""
            return a1 * vel, state

    elif thermostat_name in ["NVE", "NONE"]:
        assert rng_key is not None, "rng_key must be provided to initialize velocities"
        assert kT is not None, "kT must be specified to initialize velocities"
        if nbeads is None:
            vel = (
                jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
                * (kT / mass[:, None]) ** 0.5
            )
            # Rescale so that sum(m v^2) / ndof equals kT exactly.
            kTsys = jnp.sum(mass[:, None] * vel**2) / (mass.shape[0] * 3)
            vel = vel * (kT / kTsys) ** 0.5
        else:
            vel = (
                jax.random.normal(rng_key, (nbeads, mass.shape[0], 3), dtype=fprec)
                * (kT / mass[None, :, None]) ** 0.5
            )
            # Same rescaling, applied independently along the first axis.
            kTsys = jnp.sum(mass[None, :, None] * vel**2, axis=(1, 2)) / (
                mass.shape[0] * 3
            )
            vel = vel * (kT / kTsys[:, None, None]) ** 0.5
        thermostat = lambda x, s: (x, s)

    elif thermostat_name in ["NOSE", "NOSEHOOVER", "NOSE_HOOVER"]:
        assert rng_key is not None, "rng_key must be provided for Nose-Hoover thermostat"
        assert kT is not None, "kT must be specified for Nose-Hoover thermostat"
        assert gamma is not None, "gamma must be specified for Nose-Hoover thermostat"
        assert nbeads is None, "Nose-Hoover is not compatible with PIMD"
        ndof = mass.shape[0] * 3
        nkT = ndof * kT
        nose_mass = nkT / gamma**2
        state["nose_s"] = 0.0
        state["nose_v"] = 0.0
        if compute_thermostat_energy:
            state["thermostat_energy"] = 0.0
        print(
            "# WARNING: Nose-Hoover thermostat is not well tested yet. Energy conservation is not guaranteed."
        )
        vel = (
            jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
            * (kT / mass[:, None]) ** 0.5
        )

        def thermostat(vel, state):
            """Nose-Hoover step: half update of nose_v, velocity scaling,
            then second half update of nose_v."""
            nose_s = state["nose_s"]
            nose_v = state["nose_v"]
            kTsys = jnp.sum(mass[:, None] * vel**2)
            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
            nose_s = nose_s + dt * nose_v
            vel = jnp.exp(-nose_v * dt) * vel
            kTsys = jnp.sum(mass[:, None] * vel**2)
            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
            new_state = {**state, "nose_s": nose_s, "nose_v": nose_v}

            if compute_thermostat_energy:
                new_state["thermostat_energy"] = (
                    nkT * nose_s + (0.5 * nose_mass) * nose_v**2
                )
            return vel, new_state

    elif thermostat_name in ["QTB", "ADQTB"]:
        assert nbeads is None, "QTB is not compatible with PIMD"
        qtb_parameters = simulation_parameters.get("qtb", None)
        assert (
            qtb_parameters is not None
        ), "qtb_parameters must be provided for QTB thermostat"
        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
        assert kT is not None, "kT must be specified for QTB thermostat"
        assert gamma is not None, "gamma must be specified for QTB thermostat"
        assert species is not None, "species must be provided for QTB thermostat"
        rng_key, v_key = jax.random.split(rng_key)
        vel = (
            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
            * (kT / mass[:, None]) ** 0.5
        )

        thermostat, postprocess, qtb_state = initialize_qtb(
            qtb_parameters,
            system_data,
            fprec=fprec,
            dt=dt,
            mass=mass,
            gamma=gamma,
            kT=kT,
            species=species,
            rng_key=rng_key,
            adaptive=thermostat_name.startswith("AD"),
            compute_thermostat_energy=compute_thermostat_energy,
        )
        state = {**state, **qtb_state}

    elif thermostat_name in ["ANNEAL", "ANNEALING"]:
        assert rng_key is not None, "rng_key must be provided for ANNEAL thermostat"
        assert kT is not None, "kT must be specified for ANNEAL thermostat"
        assert gamma is not None, "gamma must be specified for ANNEAL thermostat"
        assert nbeads is None, "ANNEAL is not compatible with PIMD"
        a1 = math.exp(-gamma * dt)
        a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)

        anneal_parameters = simulation_parameters.get("annealing", {})
        init_factor = anneal_parameters.get("init_factor", 1.0 / 25.0)
        assert init_factor > 0.0, "init_factor must be positive"
        final_factor = anneal_parameters.get("final_factor", 1.0 / 10000.0)
        assert final_factor > 0.0, "final_factor must be positive"
        nsteps = simulation_parameters.get("nsteps")
        assert nsteps is not None, "nsteps must be specified for ANNEAL thermostat"
        # anneal_steps is a fraction of nsteps. The upper bound is inclusive
        # so that the default (1.0, i.e. anneal over the whole run) is valid.
        anneal_steps = anneal_parameters.get("anneal_steps", 1.0)
        assert (
            0.0 < anneal_steps <= 1.0
        ), "anneal_steps must be a fraction of nsteps in (0, 1]"
        pct_start = anneal_parameters.get("warmup_steps", 0.3)
        assert (
            0.0 < pct_start < 1.0
        ), "warmup_steps must be a fraction strictly between 0 and 1"

        anneal_type = anneal_parameters.get("type", "cosine").lower()
        if anneal_type == "linear":
            schedule = optax.linear_onecycle_schedule(
                peak_value=1.0,
                div_factor=1.0 / init_factor,
                final_div_factor=1.0 / final_factor,
                transition_steps=int(anneal_steps * nsteps),
                pct_start=pct_start,
                pct_final=1.0,
            )
        elif anneal_type in ("cosine", "cosine_onecycle"):
            # "cosine" is the default value of the "type" option and is
            # accepted as an alias of "cosine_onecycle".
            schedule = optax.cosine_onecycle_schedule(
                peak_value=1.0,
                div_factor=1.0 / init_factor,
                final_div_factor=1.0 / final_factor,
                transition_steps=int(anneal_steps * nsteps),
                pct_start=pct_start,
            )
        else:
            raise ValueError(f"Unknown anneal_type {anneal_type}")

        # Split BEFORE storing the key: storing the unsplit key would make the
        # first thermostat step derive the same subkey as v_key and reproduce
        # the random numbers used for the initial velocities.
        rng_key, v_key = jax.random.split(rng_key)
        state["rng_key"] = rng_key
        state["istep_anneal"] = 0

        Tscale = schedule(0)
        print(f"# ANNEAL: initial temperature = {Tscale*kT*au.KELVIN:.3e} K")
        vel = (
            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
            * (kT * Tscale / mass[:, None]) ** 0.5
        )

        def thermostat(vel, state):
            """Langevin step whose noise amplitude follows the schedule."""
            rng_key, noise_key = jax.random.split(state["rng_key"])
            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)

            # The schedule scales the temperature; the noise scales as sqrt(T).
            Tscale = schedule(state["istep_anneal"]) ** 0.5
            vel = a1 * vel + a2 * Tscale * noise
            return vel, {
                **state,
                "rng_key": rng_key,
                "istep_anneal": state["istep_anneal"] + 1,
            }

    else:
        raise ValueError(f"Unknown thermostat {thermostat_name}")

    return thermostat, postprocess, state, vel, thermostat_name


def initialize_qtb(
    qtb_parameters,
    system_data,
    fprec,
    dt,
    mass,
    gamma,
    kT,
    species,
    rng_key,
    adaptive,
    compute_thermostat_energy=False,
):
    """Set up the (adaptive) QTB thermostat and its post-processing step.

    Parameters
    ----------
    qtb_parameters : mapping
        QTB options read with ``.get`` (``"tseg"``, ``"omegacut"``,
        ``"hbar"``, ``"classical_kernel"``, ``"corr_kin"``, ``"startsave"``,
        ``"write_spectra"``, ``"adaptation_method"``, ...).
    system_data : mapping
        Only ``system_data["name"]`` is read, to build the restart file name.
    fprec : dtype
        Floating-point dtype of the generated arrays.
    dt : float
        Integration time step.
    mass : array
        Atomic masses, one per atom.
    gamma : float
        Base friction coefficient.
    kT : float
        Thermal energy.
    species : array
        Per-atom species identifiers; atoms sharing a value share a
        per-species ``gammar`` spectrum.
    rng_key : jax PRNG key
        Consumed to seed the colored-noise generator.
    adaptive : bool
        When true, ``gammar`` is adapted after each segment (adQTB).
    compute_thermostat_energy : bool, optional
        When true, the thermostat tracks ``"qtb_energy_flux"`` and
        ``"thermostat_energy"`` in its state.

    Returns
    -------
    tuple
        ``(thermostat, (postprocess, post_state), state)`` where
        ``thermostat(vel, state) -> (vel, state)`` applies one step,
        ``postprocess(state, post_state) -> (state, post_state)`` is meant to
        be called after every step (it only does work on segment boundaries
        reported by the ``Counter``), and ``state`` / ``post_state`` are the
        initial dictionaries.

    Side effects: reads ``QTB_spectra_<species>.out`` and
    ``<name>.qtb.restart`` if present; writes ``corr_pot.dat`` and, from
    ``postprocess``, the spectra and restart files.
    """
    state = {}
    post_state = {}
    verbose = qtb_parameters.get("verbose", False)
    if compute_thermostat_energy:
        state["thermostat_energy"] = 0.0

    mass = jnp.asarray(mass, dtype=fprec)

    nat = species.shape[0]
    # define type indices
    species_set = set(species)
    nspecies = len(species_set)
    idx = {sp: i for i, sp in enumerate(species_set)}
    type_idx = np.array([idx[sp] for sp in species], dtype=np.int32)

    # number of atoms of each type and per-type average mass
    n_of_type = jnp.asarray(np.bincount(type_idx, minlength=nspecies), dtype=fprec)
    mass_idx = jax.ops.segment_sum(mass, type_idx, nspecies) / n_of_type

    niter_deconv_kin = qtb_parameters.get("niter_deconv_kin", 7)
    niter_deconv_pot = qtb_parameters.get("niter_deconv_pot", 20)
    # A non-positive corr_kin means "estimate it on the fly", starting from 1.
    corr_kin = qtb_parameters.get("corr_kin", -1)
    do_corr_kin = corr_kin <= 0
    if do_corr_kin:
        corr_kin = 1.0
    state["corr_kin"] = corr_kin
    post_state["corr_kin_prev"] = corr_kin
    post_state["do_corr_kin"] = do_corr_kin
    post_state["isame_kin"] = 0

    # spectra parameters
    omegasmear = np.pi / dt / 100.0
    Tseg = qtb_parameters.get("tseg", 1.0 / au.PS) * au.FS
    nseg = int(Tseg / dt)
    Tseg = nseg * dt
    dom = 2 * np.pi / (3 * Tseg)
    omegacut = qtb_parameters.get("omegacut", 15000.0 / au.CM1) / au.FS
    nom = int(omegacut / dom)
    omega = dom * np.arange((3 * nseg) // 2 + 1)
    cutoff = jnp.asarray(
        1.0 / (1.0 + np.exp((omega - omegacut) / omegasmear)), dtype=fprec
    )
    assert (
        omegacut < omega[-1]
    ), f"omegacut must be smaller than {omega[-1]*au.CM1} CM-1"

    # initialize gammar
    assert (
        gamma < 0.5 * omegacut
    ), "gamma must be much smaller than omegacut (at most 0.5*omegacut)"
    gammar_min = qtb_parameters.get("gammar_min", 0.1)
    gammar = np.ones((nspecies, nom), dtype=float)
    # Best-effort restore from previously written spectra files: any failure
    # (e.g. a file with a different number of frequencies) resets to ones.
    try:
        for i, sp in enumerate(species_set):
            if not os.path.exists(f"QTB_spectra_{sp}.out"):
                continue
            data = np.loadtxt(f"QTB_spectra_{sp}.out")
            gammar[i] = data[:, 4] / (gamma * au.FS * au.THZ)
            print(f"# Restored gammar for species {sp} from QTB_spectra_{sp}.out")
    except Exception as e:
        print(
            f"# Could not restore gammar for all species with exception {e}. Starting from scratch."
        )
        gammar[:, :] = 1.0
    post_state["gammar"] = jnp.asarray(gammar, dtype=fprec)

    # Ornstein-Uhlenbeck correction for colored noise
    a1 = np.exp(-gamma * dt)
    OUcorr = jnp.asarray(
        (1.0 - 2.0 * a1 * np.cos(omega * dt) + a1**2) / (dt**2 * (gamma**2 + omega**2)),
        dtype=fprec,
    )

    # hbar schedule
    classical_kernel = qtb_parameters.get("classical_kernel", False)
    hbar = qtb_parameters.get("hbar", 1.0) * au.FS
    u = 0.5 * hbar * np.abs(omega) / kT
    theta = kT * np.ones_like(omega)
    if hbar > 0:
        # omega[0] == 0 is skipped to avoid 0/0; theta[0] stays at kT.
        theta[1:] *= u[1:] / np.tanh(u[1:])
    theta = jnp.asarray(theta, dtype=fprec)

    noise_key, post_state["rng_key"] = jax.random.split(rng_key)
    del rng_key
    # Three segments of white noise are kept; refresh_force shifts this window
    # by one segment at a time.
    post_state["white_noise"] = jax.random.normal(
        noise_key, (3 * nseg, nat, 3), dtype=jnp.float32
    )

    startsave = qtb_parameters.get("startsave", 1)
    counter = Counter(nseg, startsave=startsave)
    state["istep"] = 0
    post_state["nadapt"] = 0
    post_state["nsample"] = 0

    write_spectra = qtb_parameters.get("write_spectra", True)
    do_compute_spectra = write_spectra or adaptive

    if do_compute_spectra:
        # NOTE(review): compute_corr_kin reads post_state["mCvv_avg"], which is
        # only created in this branch -- confirm the intended behaviour when
        # write_spectra is False and the thermostat is not adaptive.
        state["vel"] = jnp.zeros((nseg, nat, 3), dtype=fprec)

        post_state["dFDT"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["mCvv"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cvf"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cff"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["dFDT_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["mCvv_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cvfg_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cff_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)

    if not adaptive:
        update_gammar = lambda x: x
    else:
        # adaptation parameters
        skipseg = qtb_parameters.get("skipseg", 1)

        adaptation_method = (
            str(qtb_parameters.get("adaptation_method", "ADABELIEF")).upper().strip()
        )
        authorized_methods = ["SIMPLE", "RATIO", "ADABELIEF"]
        assert (
            adaptation_method in authorized_methods
        ), f"adaptation_method must be one of {authorized_methods}"
        if adaptation_method == "SIMPLE":
            agamma = qtb_parameters.get("agamma", 1.0e-3) / au.FS
            assert agamma > 0, "agamma must be positive"
            a1_ad = agamma * Tseg
            print(f"# ADQTB SIMPLE: agamma = {agamma*au.FS:.3f}")

            def update_gammar(post_state):
                """Plain gradient step on gammar using dFDT as the gradient."""
                g = post_state["dFDT"]
                gammar = post_state["gammar"] - a1_ad * g
                gammar = jnp.maximum(gammar_min, gammar)
                return {**post_state, "gammar": gammar}

        elif adaptation_method == "RATIO":
            tau_ad = qtb_parameters.get("tau_ad", 5.0 / au.PS) * au.FS
            # NOTE(review): the default 10 * tau_ad is already multiplied by
            # au.FS and gets multiplied again here -- confirm whether this
            # double conversion of the default is intended.
            tau_s = qtb_parameters.get("tau_s", 10 * tau_ad) * au.FS
            assert tau_ad > 0, "tau_ad must be positive"
            print(
                f"# ADQTB RATIO: tau_ad = {tau_ad*1e-3:.2f} ps, tau_s = {tau_s*1e-3:.2f} ps"
            )
            b1 = np.exp(-Tseg / tau_ad)
            b2 = np.exp(-Tseg / tau_s)
            post_state["mCvv_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["Cvf_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["n_adabelief"] = 0

            def update_gammar(post_state):
                """Set gammar to the ratio of the running averages of Cvf
                and mCvv."""
                n_adabelief = post_state["n_adabelief"] + 1
                mCvv_m = post_state["mCvv_m"] * b1 + post_state["mCvv"] * (1.0 - b1)
                Cvf_m = post_state["Cvf_m"] * b2 + post_state["Cvf"] * (1.0 - b2)
                # bias correction of the exponential moving averages
                mCvv = mCvv_m / (1.0 - b1**n_adabelief)
                Cvf = Cvf_m / (1.0 - b2**n_adabelief)
                gammar = Cvf / (mCvv + 1.0e-8)
                gammar = jnp.maximum(gammar_min, gammar)
                return {
                    **post_state,
                    "gammar": gammar,
                    "mCvv_m": mCvv_m,
                    "Cvf_m": Cvf_m,
                    "n_adabelief": n_adabelief,
                }

        elif adaptation_method == "ADABELIEF":
            agamma = qtb_parameters.get("agamma", 0.1)
            tau_ad = qtb_parameters.get("tau_ad", 1.0 / au.PS) * au.FS
            # NOTE(review): same double multiplication of the default by au.FS
            # as in the RATIO branch -- confirm.
            tau_s = qtb_parameters.get("tau_s", 100 * tau_ad) * au.FS
            assert tau_ad > 0, "tau_ad must be positive"
            assert tau_s > 0, "tau_s must be positive"
            assert agamma > 0, "agamma must be positive"
            print(
                f"# ADQTB ADABELIEF: agamma = {agamma:.3f}, tau_ad = {tau_ad*1.e-3:.2f} ps, tau_s = {tau_s*1.e-3:.2f} ps"
            )

            a1_ad = agamma * gamma
            b1 = np.exp(-Tseg / tau_ad)
            b2 = np.exp(-Tseg / tau_s)
            post_state["dFDT_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["dFDT_s"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["n_adabelief"] = 0

            def update_gammar(post_state):
                """AdaBelief-style update of gammar driven by dFDT."""
                n_adabelief = post_state["n_adabelief"] + 1
                dFDT = post_state["dFDT"]
                dFDT_m = post_state["dFDT_m"] * b1 + dFDT * (1.0 - b1)
                dFDT_s = (
                    post_state["dFDT_s"] * b2
                    + (dFDT - dFDT_m) ** 2 * (1.0 - b2)
                    + 1.0e-8
                )
                # bias correction
                mt = dFDT_m / (1.0 - b1**n_adabelief)
                st = dFDT_s / (1.0 - b2**n_adabelief)
                gammar = post_state["gammar"] - a1_ad * mt / (st**0.5 + 1.0e-8)
                gammar = jnp.maximum(gammar_min, gammar)
                return {
                    **post_state,
                    "gammar": gammar,
                    "dFDT_m": dFDT_m,
                    "n_adabelief": n_adabelief,
                    "dFDT_s": dFDT_s,
                }

    #####################
    # RESTART
    restart_file = system_data["name"] + ".qtb.restart"
    if os.path.exists(restart_file):
        # SECURITY: pickle.load executes arbitrary code from the file; only
        # restart files produced by this program should be placed here.
        with open(restart_file, "rb") as f:
            data = pickle.load(f)
            state["corr_kin"] = data["corr_kin"]
            post_state["corr_kin_prev"] = data["corr_kin"]
            post_state["isame_kin"] = data["isame_kin"]
            post_state["do_corr_kin"] = data["do_corr_kin"]
            print(f"# Restored QTB state from {restart_file}")

    def write_qtb_restart(state, post_state):
        """Pickle the kinetic-correction bookkeeping to the restart file."""
        with open(restart_file, "wb") as f:
            pickle.dump(
                {
                    "corr_kin": state["corr_kin"],
                    "corr_kin_prev": post_state["corr_kin_prev"],
                    "isame_kin": post_state["isame_kin"],
                    "do_corr_kin": post_state["do_corr_kin"],
                },
                f,
            )

    ######################

    def compute_corr_pot(niter=20, verbose=False):
        """Return the correction factor applied to the noise spectrum
        (all ones for a classical kernel); also writes ``corr_pot.dat``."""
        if classical_kernel or hbar == 0:
            return np.ones(nom)

        s_0 = np.array((theta / kT * cutoff)[:nom])
        s_out, s_rec, _ = deconvolute_spectrum(
            s_0,
            omega[:nom],
            gamma,
            niter,
            kernel=kernel_lorentz_pot,
            trans=True,
            symmetrize=True,
            verbose=verbose,
        )
        corr_pot = 1.0 + (s_out - s_0) / s_0
        columns = np.column_stack(
            (omega[:nom] * au.CM1, corr_pot - 1.0, s_0, s_out, s_rec)
        )
        np.savetxt(
            "corr_pot.dat", columns, header="omega(cm-1) corr_pot s_0 s_out s_rec"
        )
        return corr_pot

    def compute_corr_kin(post_state, niter=7, verbose=False):
        """Re-estimate ``corr_kin`` from the averaged velocity spectrum.

        Returns ``(corr_kin, post_state)``. The previous value is kept when
        the estimate is frozen or when the reconvolution check fails.
        """
        if not post_state["do_corr_kin"]:
            return post_state["corr_kin_prev"], post_state
        if classical_kernel or hbar == 0:
            return 1.0, post_state

        K_D = post_state.get("K_D", None)
        mCvv = (post_state["mCvv_avg"][:, :nom] * n_of_type[:, None]).sum(axis=0) / nat
        s_0 = np.array(mCvv * kT / theta[:nom] / post_state["corr_pot"])
        s_out, s_rec, K_D = deconvolute_spectrum(
            s_0,
            omega[:nom],
            gamma,
            niter,
            kernel=kernel_lorentz,
            trans=False,
            symmetrize=True,
            verbose=verbose,
            K_D=K_D,
        )
        s_out = s_out * theta[:nom] / kT
        s_rec = s_rec * theta[:nom] / kT * post_state["corr_pot"]
        mCvvsum = mCvv.sum()
        rec_ratio = mCvvsum / s_rec.sum()
        if rec_ratio < 0.95 or rec_ratio > 1.05:
            print(
                f"# WARNING: reconvolution error {rec_ratio} is too high, corr_kin was not updated"
            )
            return post_state["corr_kin_prev"], post_state

        corr_kin = mCvvsum / s_out.sum()
        if np.abs(corr_kin - post_state["corr_kin_prev"]) < 1.0e-4:
            isame_kin = post_state["isame_kin"] + 1
        else:
            isame_kin = 0

        do_corr_kin = post_state["do_corr_kin"]
        if isame_kin > 10:
            print(
                "# INFO: corr_kin is converged (it did not change for 10 consecutive segments)"
            )
            do_corr_kin = False

        return corr_kin, {
            **post_state,
            "corr_kin_prev": corr_kin,
            "isame_kin": isame_kin,
            "do_corr_kin": do_corr_kin,
            "K_D": K_D,
        }

    @jax.jit
    def ff_kernel(post_state):
        """Target random-force spectrum, one column per species."""
        if classical_kernel:
            kernel = cutoff * (2 * gamma * kT / dt)
        else:
            kernel = theta * cutoff * OUcorr * (2 * gamma / dt)
        # gammar and corr_pot only cover the first nom frequencies; the
        # remaining ones use a ratio of one.
        gamma_ratio = jnp.concatenate(
            (
                post_state["gammar"].T * post_state["corr_pot"][:, None],
                jnp.ones(
                    (kernel.shape[0] - nom, nspecies), dtype=post_state["gammar"].dtype
                ),
            ),
            axis=0,
        )
        return kernel[:, None] * gamma_ratio * mass_idx[None, :]

    @jax.jit
    def refresh_force(post_state):
        """Generate the colored random force for the next segment."""
        rng_key, noise_key = jax.random.split(post_state["rng_key"])
        # Drop the oldest segment of white noise and append a fresh one.
        white_noise = jnp.concatenate(
            (
                post_state["white_noise"][nseg:],
                jax.random.normal(
                    noise_key, (nseg, nat, 3), dtype=post_state["white_noise"].dtype
                ),
            ),
            axis=0,
        )
        amplitude = ff_kernel(post_state) ** 0.5
        s = jnp.fft.rfft(white_noise, 3 * nseg, axis=0) * amplitude[:, type_idx, None]
        # Only the middle segment of the filtered window is used.
        force = jnp.fft.irfft(s, 3 * nseg, axis=0)[nseg : 2 * nseg]
        return force, {**post_state, "rng_key": rng_key, "white_noise": white_noise}

    @jax.jit
    def compute_spectra(force, vel, post_state):
        """Compute per-species spectra for the last segment and update their
        running averages."""
        sf = jnp.fft.rfft(force / gamma, 3 * nseg, axis=0, norm="ortho")
        sv = jnp.fft.rfft(vel, 3 * nseg, axis=0, norm="ortho")
        Cvv = jnp.sum(jnp.abs(sv[:nom]) ** 2, axis=-1).T
        Cff = jnp.sum(jnp.abs(sf[:nom]) ** 2, axis=-1).T
        Cvf = jnp.sum(jnp.real(sv[:nom] * jnp.conj(sf[:nom])), axis=-1).T

        # Accumulate per-atom spectra into per-species averages.
        mCvv = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["mCvv"]).at[type_idx].add(Cvv)
            * mass_idx[:, None]
            / n_of_type[:, None]
        )
        Cvf = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["Cvf"]).at[type_idx].add(Cvf)
            / n_of_type[:, None]
        )
        Cff = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["Cff"]).at[type_idx].add(Cff)
            / n_of_type[:, None]
        )
        dFDT = mCvv * post_state["gammar"] - Cvf

        # Incremental mean over the nsample segments seen so far.
        nsinv = 1.0 / post_state["nsample"]
        b1 = 1.0 - nsinv
        dFDT_avg = post_state["dFDT_avg"] * b1 + dFDT * nsinv
        mCvv_avg = post_state["mCvv_avg"] * b1 + mCvv * nsinv
        Cvfg_avg = post_state["Cvfg_avg"] * b1 + Cvf / post_state["gammar"] * nsinv
        Cff_avg = post_state["Cff_avg"] * b1 + Cff * nsinv

        return {
            **post_state,
            "mCvv": mCvv,
            "Cvf": Cvf,
            "Cff": Cff,
            "dFDT": dFDT,
            "dFDT_avg": dFDT_avg,
            "mCvv_avg": mCvv_avg,
            "Cvfg_avg": Cvfg_avg,
            "Cff_avg": Cff_avg,
        }

    def write_spectra_to_file(post_state):
        """Write one ``QTB_spectra_<species>.out`` file per species.

        Column 4 (0-based) holds gammar * gamma and is the column read back
        when gammar is restored at start-up.
        """
        mCvv_avg = np.array(post_state["mCvv_avg"])
        Cvfg_avg = np.array(post_state["Cvfg_avg"])
        Cff_avg = np.array(post_state["Cff_avg"]) * 3.0 / dt / (gamma**2)
        dFDT_avg = np.array(post_state["dFDT_avg"])
        gammar = np.array(post_state["gammar"])
        Cff_theo = np.array(ff_kernel(post_state))[:nom].T
        for i, sp in enumerate(species_set):
            ff_scale = au.KELVIN / ((2 * gamma / dt) * mass_idx[i])
            columns = np.column_stack(
                (
                    omega[:nom] * (au.FS * au.CM1),
                    mCvv_avg[i],
                    Cvfg_avg[i],
                    dFDT_avg[i],
                    gammar[i] * gamma * (au.FS * au.THZ),
                    Cff_avg[i] * ff_scale,
                    Cff_theo[i] * ff_scale,
                )
            )
            np.savetxt(
                f"QTB_spectra_{sp}.out",
                columns,
                fmt="%12.6f",
                header="#omega mCvv Cvf dFDT gammar Cff",
            )
        if verbose:
            print("# QTB spectra written.")

    if compute_thermostat_energy:
        state["qtb_energy_flux"] = 0.0

    @jax.jit
    def thermostat(vel, state):
        """Apply friction and the pre-computed random force for this step."""
        istep = state["istep"]
        dvel = dt * state["force"][istep] / mass[:, None]
        new_vel = vel * a1 + dvel
        new_state = {**state, "istep": istep + 1}
        if do_compute_spectra:
            # Velocity sample stored for the spectra of the current segment.
            vel2 = state["vel"].at[istep].set(vel * a1**0.5 + 0.5 * dvel)
            new_state["vel"] = vel2
        if compute_thermostat_energy:
            dek = 0.5 * (mass[:, None] * (vel**2 - new_vel**2)).sum()
            ekcorr = (
                0.5
                * (mass[:, None] * new_vel**2).sum()
                * (1.0 - 1.0 / state.get("corr_kin", 1.0))
            )
            new_state["qtb_energy_flux"] = state["qtb_energy_flux"] + dek
            new_state["thermostat_energy"] = new_state["qtb_energy_flux"] + ekcorr
        return new_vel, new_state

    @jax.jit
    def postprocess_work(state, post_state):
        """Jitted part of the segment update: spectra, gammar adaptation and
        generation of the next random-force segment."""
        if do_compute_spectra:
            post_state = compute_spectra(state["force"], state["vel"], post_state)
        if adaptive:
            # Adaptation only starts after skipseg segments.
            post_state = jax.lax.cond(
                post_state["nadapt"] > skipseg, update_gammar, lambda x: x, post_state
            )
        new_force, post_state = refresh_force(post_state)
        return {**state, "force": new_force}, post_state

    def postprocess(state, post_state):
        """Advance the step counter and, on a reset step of the counter,
        refresh the forces, update corr_kin and write the output files."""
        counter.increment()
        if not counter.is_reset_step:
            return state, post_state
        post_state["nadapt"] += 1
        post_state["nsample"] = max(post_state["nadapt"] - startsave + 1, 1)
        if verbose:
            print("# Refreshing QTB forces.")
        state, post_state = postprocess_work(state, post_state)
        state["corr_kin"], post_state = compute_corr_kin(
            post_state, niter=niter_deconv_kin
        )
        state["istep"] = 0
        if write_spectra:
            write_spectra_to_file(post_state)
        write_qtb_restart(state, post_state)
        return state, post_state

    post_state["corr_pot"] = jnp.asarray(
        compute_corr_pot(niter=niter_deconv_pot), dtype=fprec
    )

    state["force"], post_state = refresh_force(post_state)
    return thermostat, (postprocess, post_state), state
def
get_thermostat( simulation_parameters, dt, system_data, fprec, rng_key=None, restart_data={}):
20def get_thermostat(simulation_parameters, dt, system_data, fprec, rng_key=None, restart_data={}): 21 state = {} 22 postprocess = None 23 24 25 thermostat_name = str(simulation_parameters.get("thermostat", "LGV")).upper() 26 compute_thermostat_energy = simulation_parameters.get( 27 "include_thermostat_energy", False 28 ) 29 30 kT = system_data.get("kT", None) 31 nbeads = system_data.get("nbeads", None) 32 mass = system_data["mass"] 33 gamma = simulation_parameters.get("gamma", 1.0 / au.THZ) / au.FS 34 species = system_data["species"] 35 36 if nbeads is not None: 37 trpmd_lambda = simulation_parameters.get("trpmd_lambda", 1.0) 38 gamma = np.maximum(trpmd_lambda * system_data["omk"], gamma) 39 40 if thermostat_name in ["LGV", "LANGEVIN", "FFLGV"]: 41 assert rng_key is not None, "rng_key must be provided for QTB thermostat" 42 assert kT is not None, "kT must be specified for QTB thermostat" 43 assert gamma is not None, "gamma must be specified for QTB thermostat" 44 rng_key, v_key = jax.random.split(rng_key) 45 if nbeads is None: 46 a1 = math.exp(-gamma * dt) 47 a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec) 48 vel = ( 49 jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec) 50 * (kT / mass[:, None]) ** 0.5 51 ) 52 else: 53 if isinstance(gamma, float): 54 gamma = np.array([gamma] * nbeads) 55 assert isinstance( 56 gamma, np.ndarray 57 ), "gamma must be a float or a numpy array" 58 assert gamma.shape[0] == nbeads, "gamma must have the same length as nbeads" 59 a1 = np.exp(-gamma * dt)[:, None, None] 60 a2 = jnp.asarray( 61 ((1 - a1 * a1) * kT / mass[None, :, None]) ** 0.5, dtype=fprec 62 ) 63 vel = ( 64 jax.random.normal(v_key, (nbeads, mass.shape[0], 3), dtype=fprec) 65 * (kT / mass[:, None]) ** 0.5 66 ) 67 68 state["rng_key"] = rng_key 69 if compute_thermostat_energy: 70 state["thermostat_energy"] = 0.0 71 if thermostat_name == "FFLGV": 72 def thermostat(vel, state): 73 rng_key, noise_key = jax.random.split(state["rng_key"]) 74 noise 
= jax.random.normal(noise_key, vel.shape, dtype=vel.dtype) 75 norm_vel = jnp.linalg.norm(vel, axis=-1, keepdims=True) 76 dirvel = vel / norm_vel 77 if compute_thermostat_energy: 78 v2 = (vel**2).sum(axis=-1) 79 vel = a1 * vel + a2 * noise 80 new_norm_vel = jnp.linalg.norm(vel, axis=-1, keepdims=True) 81 vel = dirvel * new_norm_vel 82 new_state = {**state, "rng_key": rng_key} 83 if compute_thermostat_energy: 84 v2new = (vel**2).sum(axis=-1) 85 new_state["thermostat_energy"] = ( 86 state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum() 87 ) 88 89 return vel, new_state 90 91 else: 92 def thermostat(vel, state): 93 rng_key, noise_key = jax.random.split(state["rng_key"]) 94 noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype) 95 if compute_thermostat_energy: 96 v2 = (vel**2).sum(axis=-1) 97 vel = a1 * vel + a2 * noise 98 new_state = {**state, "rng_key": rng_key} 99 if compute_thermostat_energy: 100 v2new = (vel**2).sum(axis=-1) 101 new_state["thermostat_energy"] = ( 102 state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum() 103 ) 104 return vel, new_state 105 106 elif thermostat_name in ["BUSSI"]: 107 assert rng_key is not None, "rng_key must be provided for QTB thermostat" 108 assert kT is not None, "kT must be specified for QTB thermostat" 109 assert gamma is not None, "gamma must be specified for QTB thermostat" 110 assert nbeads is None, "Bussi thermostat is not compatible with PIMD" 111 rng_key, v_key = jax.random.split(rng_key) 112 113 a1 = math.exp(-gamma * dt) 114 a2 = (1 - a1) * kT 115 vel = ( 116 jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec) 117 * (kT / mass[:, None]) ** 0.5 118 ) 119 120 state["rng_key"] = rng_key 121 if compute_thermostat_energy: 122 state["thermostat_energy"] = 0.0 123 124 def thermostat(vel, state): 125 rng_key, noise_key = jax.random.split(state["rng_key"]) 126 new_state = {**state, "rng_key": rng_key} 127 noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype) 128 R2 = jnp.sum(noise**2) 
129 R1 = noise[0, 0] 130 c = a2 / (mass[:, None] * vel**2).sum() 131 d = (a1 * c) ** 0.5 132 scale = (a1 + c * R2 + 2 * d * R1) ** 0.5 133 if compute_thermostat_energy: 134 dek = 0.5 * (mass[:, None] * vel**2).sum() * (scale**2 - 1) 135 new_state["thermostat_energy"] = state["thermostat_energy"] + dek 136 return scale * vel, new_state 137 138 elif thermostat_name in [ 139 "GD", 140 "DESCENT", 141 "GRADIENT_DESCENT", 142 "MIN", 143 "MINIMIZE", 144 ]: 145 assert nbeads is None, "Gradient descent is not compatible with PIMD" 146 a1 = math.exp(-gamma * dt) 147 148 if nbeads is None: 149 vel = jnp.zeros((mass.shape[0], 3), dtype=fprec) 150 else: 151 vel = jnp.zeros((nbeads, mass.shape[0], 3), dtype=fprec) 152 153 def thermostat(vel, state): 154 return a1 * vel, state 155 156 elif thermostat_name in ["NVE", "NONE"]: 157 if nbeads is None: 158 vel = ( 159 jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec) 160 * (kT / mass[:, None]) ** 0.5 161 ) 162 kTsys = jnp.sum(mass[:, None] * vel**2) / (mass.shape[0] * 3) 163 vel = vel * (kT / kTsys) ** 0.5 164 else: 165 vel = ( 166 jax.random.normal(rng_key, (nbeads, mass.shape[0], 3), dtype=fprec) 167 * (kT / mass[None, :, None]) ** 0.5 168 ) 169 kTsys = jnp.sum(mass[None, :, None] * vel**2, axis=(1, 2)) / ( 170 mass.shape[0] * 3 171 ) 172 vel = vel * (kT / kTsys[:, None, None]) ** 0.5 173 thermostat = lambda x, s: (x, s) 174 175 elif thermostat_name in ["NOSE", "NOSEHOOVER", "NOSE_HOOVER"]: 176 assert gamma is not None, "gamma must be specified for QTB thermostat" 177 ndof = mass.shape[0] * 3 178 nkT = ndof * kT 179 nose_mass = nkT / gamma**2 180 assert nbeads is None, "Nose-Hoover is not compatible with PIMD" 181 state["nose_s"] = 0.0 182 state["nose_v"] = 0.0 183 if compute_thermostat_energy: 184 state["thermostat_energy"] = 0.0 185 print( 186 "# WARNING: Nose-Hoover thermostat is not well tested yet. Energy conservation is not guaranteed." 
187 ) 188 vel = ( 189 jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec) 190 * (kT / mass[:, None]) ** 0.5 191 ) 192 193 def thermostat(vel, state): 194 nose_s = state["nose_s"] 195 nose_v = state["nose_v"] 196 kTsys = jnp.sum(mass[:, None] * vel**2) 197 nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT) 198 nose_s = nose_s + dt * nose_v 199 vel = jnp.exp(-nose_v * dt) * vel 200 kTsys = jnp.sum(mass[:, None] * vel**2) 201 nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT) 202 new_state = {**state, "nose_s": nose_s, "nose_v": nose_v} 203 204 if compute_thermostat_energy: 205 new_state["thermostat_energy"] = ( 206 nkT * nose_s + (0.5 * nose_mass) * nose_v**2 207 ) 208 return vel, new_state 209 210 elif thermostat_name in ["QTB", "ADQTB"]: 211 assert nbeads is None, "QTB is not compatible with PIMD" 212 qtb_parameters = simulation_parameters.get("qtb", None) 213 assert ( 214 qtb_parameters is not None 215 ), "qtb_parameters must be provided for QTB thermostat" 216 assert rng_key is not None, "rng_key must be provided for QTB thermostat" 217 assert kT is not None, "kT must be specified for QTB thermostat" 218 assert gamma is not None, "gamma must be specified for QTB thermostat" 219 assert species is not None, "species must be provided for QTB thermostat" 220 rng_key, v_key = jax.random.split(rng_key) 221 vel = ( 222 jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec) 223 * (kT / mass[:, None]) ** 0.5 224 ) 225 226 thermostat, postprocess, qtb_state = initialize_qtb( 227 qtb_parameters, 228 system_data, 229 fprec=fprec, 230 dt=dt, 231 mass=mass, 232 gamma=gamma, 233 kT=kT, 234 species=species, 235 rng_key=rng_key, 236 adaptive=thermostat_name.startswith("AD"), 237 compute_thermostat_energy=compute_thermostat_energy, 238 ) 239 state = {**state, **qtb_state} 240 241 elif thermostat_name in ["ANNEAL", "ANNEALING"]: 242 assert rng_key is not None, "rng_key must be provided for QTB thermostat" 243 assert kT is not None, "kT must be specified for 
QTB thermostat" 244 assert gamma is not None, "gamma must be specified for QTB thermostat" 245 assert nbeads is None, "ANNEAL is not compatible with PIMD" 246 a1 = math.exp(-gamma * dt) 247 a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec) 248 249 anneal_parameters = simulation_parameters.get("annealing", {}) 250 init_factor = anneal_parameters.get("init_factor", 1.0 / 25.0) 251 assert init_factor > 0.0, "init_factor must be positive" 252 final_factor = anneal_parameters.get("final_factor", 1.0 / 10000.0) 253 assert final_factor > 0.0, "final_factor must be positive" 254 nsteps = simulation_parameters.get("nsteps") 255 anneal_steps = anneal_parameters.get("anneal_steps", 1.0) 256 assert ( 257 anneal_steps < 1.0 and anneal_steps > 0.0 258 ), "warmup_steps must be between 0 and nsteps" 259 pct_start = anneal_parameters.get("warmup_steps", 0.3) 260 assert ( 261 pct_start < 1.0 and pct_start > 0.0 262 ), "warmup_steps must be between 0 and nsteps" 263 264 anneal_type = anneal_parameters.get("type", "cosine").lower() 265 if anneal_type == "linear": 266 schedule = optax.linear_onecycle_schedule( 267 peak_value=1.0, 268 div_factor=1.0 / init_factor, 269 final_div_factor=1.0 / final_factor, 270 transition_steps=int(anneal_steps * nsteps), 271 pct_start=pct_start, 272 pct_final=1.0, 273 ) 274 elif anneal_type == "cosine_onecycle": 275 schedule = optax.cosine_onecycle_schedule( 276 peak_value=1.0, 277 div_factor=1.0 / init_factor, 278 final_div_factor=1.0 / final_factor, 279 transition_steps=int(anneal_steps * nsteps), 280 pct_start=pct_start, 281 ) 282 else: 283 raise ValueError(f"Unknown anneal_type {anneal_type}") 284 285 state["rng_key"] = rng_key 286 state["istep_anneal"] = 0 287 288 rng_key, v_key = jax.random.split(rng_key) 289 Tscale = schedule(0) 290 print(f"# ANNEAL: initial temperature = {Tscale*kT*au.KELVIN:.3e} K") 291 vel = ( 292 jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec) 293 * (kT * Tscale / mass[:, None]) ** 0.5 294 ) 
295 296 def thermostat(vel, state): 297 rng_key, noise_key = jax.random.split(state["rng_key"]) 298 noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype) 299 300 Tscale = schedule(state["istep_anneal"]) ** 0.5 301 vel = a1 * vel + a2 * Tscale * noise 302 return vel, { 303 **state, 304 "rng_key": rng_key, 305 "istep_anneal": state["istep_anneal"] + 1, 306 } 307 308 else: 309 raise ValueError(f"Unknown thermostat {thermostat_name}") 310 311 return thermostat, postprocess, state, vel,thermostat_name
def initialize_qtb(
    qtb_parameters,
    system_data,
    fprec,
    dt,
    mass,
    gamma,
    kT,
    species,
    rng_key,
    adaptive,
    compute_thermostat_energy=False,
):
    """Build the (adaptive) Quantum Thermal Bath thermostat closures and state.

    The colored random force is generated segment by segment: a window of
    white noise is filtered in Fourier space with a per-species kernel, the
    thermostat consumes one force sample per step, and the post-processing
    hook refreshes the force (and optionally adapts the per-species friction
    ratios ``gammar``) at the end of every segment.

    Args:
        qtb_parameters: mapping read through ``.get``. Keys used here:
            ``verbose``, ``niter_deconv_kin``, ``niter_deconv_pot``,
            ``corr_kin``, ``tseg``, ``omegacut``, ``gammar_min``,
            ``classical_kernel``, ``hbar``, ``startsave``, ``write_spectra``
            and, when ``adaptive`` is true, ``skipseg``,
            ``adaptation_method``, ``agamma``, ``tau_ad``, ``tau_s``.
        system_data: mapping; only ``system_data["name"]`` is read (prefix of
            the ``<name>.qtb.restart`` file).
        fprec: floating-point dtype used for the arrays created here.
        dt: integration time step (same time unit as ``1 / gamma``, since the
            code forms ``gamma * dt``).
        mass: per-atom masses, shape ``(nat,)``.
        gamma: scalar friction coefficient.
        kT: thermal energy.
        species: per-atom species identifiers, shape ``(nat,)``.
        rng_key: JAX PRNG key; it is split here and must not be reused.
        adaptive: if true, ``gammar`` is adapted after each segment (adQTB).
        compute_thermostat_energy: if true, the thermostat accumulates the
            energy exchanged with the bath in ``state``.

    Returns:
        ``(thermostat, (postprocess, post_state), state)`` where
        ``thermostat(vel, state) -> (vel, state)`` is applied every step and
        ``postprocess(state, post_state) -> (state, post_state)`` must be
        called every step as well (it only does work on segment boundaries).

    Side effects:
        Reads ``QTB_spectra_<species>.out`` and ``<name>.qtb.restart`` when
        they exist, writes ``corr_pot.dat`` (quantum kernel only), and the
        returned ``postprocess`` writes the spectra and restart files.
    """
    state = {}
    post_state = {}
    verbose = qtb_parameters.get("verbose", False)
    if compute_thermostat_energy:
        state["thermostat_energy"] = 0.0

    mass = jnp.asarray(mass, dtype=fprec)

    nat = species.shape[0]
    # Map every atom to the index of its species ("type").
    species_set = set(species)
    nspecies = len(species_set)
    idx = {sp: i for i, sp in enumerate(species_set)}
    type_idx = np.array([idx[sp] for sp in species], dtype=np.int32)

    # Number of atoms of each type and the per-type average mass.
    n_of_type = jnp.asarray(np.bincount(type_idx, minlength=nspecies), dtype=fprec)
    mass_idx = jax.ops.segment_sum(mass, type_idx, nspecies) / n_of_type

    niter_deconv_kin = qtb_parameters.get("niter_deconv_kin", 7)
    niter_deconv_pot = qtb_parameters.get("niter_deconv_pot", 20)
    # A non-positive corr_kin means "estimate it on the fly", starting from 1.
    corr_kin = qtb_parameters.get("corr_kin", -1)
    do_corr_kin = corr_kin <= 0
    if do_corr_kin:
        corr_kin = 1.0
    state["corr_kin"] = corr_kin
    post_state["corr_kin_prev"] = corr_kin
    post_state["do_corr_kin"] = do_corr_kin
    post_state["isame_kin"] = 0

    # spectra parameters
    omegasmear = np.pi / dt / 100.0
    Tseg = qtb_parameters.get("tseg", 1.0 / au.PS) * au.FS
    nseg = int(Tseg / dt)
    Tseg = nseg * dt  # segment length rounded to a whole number of steps
    dom = 2 * np.pi / (3 * Tseg)  # frequency resolution of the 3-segment window
    omegacut = qtb_parameters.get("omegacut", 15000.0 / au.CM1) / au.FS
    nom = int(omegacut / dom)  # number of frequencies kept below the cutoff
    omega = dom * np.arange((3 * nseg) // 2 + 1)
    # Smooth (Fermi-like) frequency cutoff centred on omegacut.
    cutoff = jnp.asarray(
        1.0 / (1.0 + np.exp((omega - omegacut) / omegasmear)), dtype=fprec
    )
    # omega shares the unit of omegacut (user value / au.FS), so it is scaled
    # by au.FS * au.CM1 for display, exactly as in write_spectra_to_file.
    assert (
        omegacut < omega[-1]
    ), f"omegacut must be smaller than {omega[-1]*au.FS*au.CM1} CM-1"

    # initialize gammar
    assert (
        gamma < 0.5 * omegacut
    ), "gamma must be much smaller than omegacut (at most 0.5*omegacut)"
    gammar_min = qtb_parameters.get("gammar_min", 0.1)
    gammar = np.ones((nspecies, nom), dtype=float)
    # Best-effort restore of gammar from a previous run's spectra files
    # (column 4 is the one written as gammar * gamma by write_spectra_to_file).
    # Any failure, e.g. a mismatching number of frequencies, resets to 1.
    try:
        for i, sp in enumerate(species_set):
            if not os.path.exists(f"QTB_spectra_{sp}.out"):
                continue
            data = np.loadtxt(f"QTB_spectra_{sp}.out")
            gammar[i] = data[:, 4] / (gamma * au.FS * au.THZ)
            print(f"# Restored gammar for species {sp} from QTB_spectra_{sp}.out")
    except Exception as e:
        print(f"# Could not restore gammar for all species with exception {e}. Starting from scratch.")
        gammar[:, :] = 1.0
    post_state["gammar"] = jnp.asarray(gammar, dtype=fprec)

    # Ornstein-Uhlenbeck correction for colored noise
    a1 = np.exp(-gamma * dt)
    OUcorr = jnp.asarray(
        (1.0 - 2.0 * a1 * np.cos(omega * dt) + a1**2)
        / (dt**2 * (gamma**2 + omega**2)),
        dtype=fprec,
    )

    # hbar schedule
    classical_kernel = qtb_parameters.get("classical_kernel", False)
    hbar = qtb_parameters.get("hbar", 1.0) * au.FS
    u = 0.5 * hbar * np.abs(omega) / kT
    # theta = kT * u / tanh(u); the omega = 0 entry keeps its u -> 0 limit, kT.
    theta = kT * np.ones_like(omega)
    if hbar > 0:
        theta[1:] *= u[1:] / np.tanh(u[1:])
    theta = jnp.asarray(theta, dtype=fprec)

    noise_key, post_state["rng_key"] = jax.random.split(rng_key)
    del rng_key  # guard against accidental reuse of the consumed key
    # White-noise window spanning three segments.
    post_state["white_noise"] = jax.random.normal(
        noise_key, (3 * nseg, nat, 3), dtype=jnp.float32
    )

    startsave = qtb_parameters.get("startsave", 1)
    counter = Counter(nseg, startsave=startsave)
    state["istep"] = 0
    post_state["nadapt"] = 0
    post_state["nsample"] = 0

    write_spectra = qtb_parameters.get("write_spectra", True)

    if not adaptive:
        update_gammar = lambda x: x
    else:
        # adaptation parameters
        skipseg = qtb_parameters.get("skipseg", 1)

        adaptation_method = (
            str(qtb_parameters.get("adaptation_method", "ADABELIEF")).upper().strip()
        )
        authorized_methods = ["SIMPLE", "RATIO", "ADABELIEF"]
        assert (
            adaptation_method in authorized_methods
        ), f"adaptation_method must be one of {authorized_methods}"
        if adaptation_method == "SIMPLE":
            agamma = qtb_parameters.get("agamma", 1.0e-3) / au.FS
            assert agamma > 0, "agamma must be positive"
            a1_ad = agamma * Tseg
            print(f"# ADQTB SIMPLE: agamma = {agamma*au.FS:.3f}")

            def update_gammar(post_state):
                # Plain gradient step on the FDT deviation.
                g = post_state["dFDT"]
                gammar = post_state["gammar"] - a1_ad * g
                gammar = jnp.maximum(gammar_min, gammar)
                return {**post_state, "gammar": gammar}

        elif adaptation_method == "RATIO":
            tau_ad = qtb_parameters.get("tau_ad", 5.0 / au.PS) * au.FS
            # tau_ad is already scaled by au.FS: the default (a multiple of
            # tau_ad) must not be scaled again, only a user-supplied value.
            tau_s = qtb_parameters.get("tau_s", None)
            tau_s = 10 * tau_ad if tau_s is None else tau_s * au.FS
            assert tau_ad > 0, "tau_ad must be positive"
            assert tau_s > 0, "tau_s must be positive"
            print(
                f"# ADQTB RATIO: tau_ad = {tau_ad*1e-3:.2f} ps, tau_s = {tau_s*1e-3:.2f} ps"
            )
            b1 = np.exp(-Tseg / tau_ad)
            b2 = np.exp(-Tseg / tau_s)
            post_state["mCvv_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["Cvf_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["n_adabelief"] = 0

            def update_gammar(post_state):
                # gammar is the ratio of bias-corrected exponential moving
                # averages of Cvf and mCvv.
                n_adabelief = post_state["n_adabelief"] + 1
                mCvv_m = post_state["mCvv_m"] * b1 + post_state["mCvv"] * (1.0 - b1)
                Cvf_m = post_state["Cvf_m"] * b2 + post_state["Cvf"] * (1.0 - b2)
                mCvv = mCvv_m / (1.0 - b1**n_adabelief)
                Cvf = Cvf_m / (1.0 - b2**n_adabelief)
                gammar = Cvf / (mCvv + 1.0e-8)
                gammar = jnp.maximum(gammar_min, gammar)
                return {
                    **post_state,
                    "gammar": gammar,
                    "mCvv_m": mCvv_m,
                    "Cvf_m": Cvf_m,
                    "n_adabelief": n_adabelief,
                }

        elif adaptation_method == "ADABELIEF":
            agamma = qtb_parameters.get("agamma", 0.1)
            tau_ad = qtb_parameters.get("tau_ad", 1.0 / au.PS) * au.FS
            # Same unit handling as in the RATIO branch: do not rescale the
            # tau_ad-derived default.
            tau_s = qtb_parameters.get("tau_s", None)
            tau_s = 100 * tau_ad if tau_s is None else tau_s * au.FS
            assert tau_ad > 0, "tau_ad must be positive"
            assert tau_s > 0, "tau_s must be positive"
            assert agamma > 0, "agamma must be positive"
            print(
                f"# ADQTB ADABELIEF: agamma = {agamma:.3f}, tau_ad = {tau_ad*1.e-3:.2f} ps, tau_s = {tau_s*1.e-3:.2f} ps"
            )

            a1_ad = agamma * gamma
            b1 = np.exp(-Tseg / tau_ad)
            b2 = np.exp(-Tseg / tau_s)
            post_state["dFDT_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["dFDT_s"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["n_adabelief"] = 0

            def update_gammar(post_state):
                n_adabelief = post_state["n_adabelief"] + 1
                dFDT = post_state["dFDT"]
                # First moment of the gradient and second moment of its
                # deviation from that first moment.
                dFDT_m = post_state["dFDT_m"] * b1 + dFDT * (1.0 - b1)
                dFDT_s = (
                    post_state["dFDT_s"] * b2
                    + (dFDT - dFDT_m) ** 2 * (1.0 - b2)
                    + 1.0e-8
                )
                # bias correction
                mt = dFDT_m / (1.0 - b1**n_adabelief)
                st = dFDT_s / (1.0 - b2**n_adabelief)
                gammar = post_state["gammar"] - a1_ad * mt / (st**0.5 + 1.0e-8)
                gammar = jnp.maximum(gammar_min, gammar)
                return {
                    **post_state,
                    "gammar": gammar,
                    "dFDT_m": dFDT_m,
                    "n_adabelief": n_adabelief,
                    "dFDT_s": dFDT_s,
                }

    #####################
    # RESTART
    restart_file = system_data["name"] + ".qtb.restart"
    if os.path.exists(restart_file):
        # NOTE: pickle.load can execute arbitrary code; the restart file must
        # come from a trusted previous run.
        with open(restart_file, "rb") as f:
            data = pickle.load(f)
            state["corr_kin"] = data["corr_kin"]
            post_state["corr_kin_prev"] = data["corr_kin"]
            post_state["isame_kin"] = data["isame_kin"]
            post_state["do_corr_kin"] = data["do_corr_kin"]
            print(f"# Restored QTB state from {restart_file}")

    def write_qtb_restart(state, post_state):
        """Persist the corr_kin bookkeeping to ``restart_file``."""
        with open(restart_file, "wb") as f:
            pickle.dump(
                {
                    "corr_kin": state["corr_kin"],
                    "corr_kin_prev": post_state["corr_kin_prev"],
                    "isame_kin": post_state["isame_kin"],
                    "do_corr_kin": post_state["do_corr_kin"],
                },
                f,
            )

    ######################

    # The corr_kin estimate is built from the averaged mCvv spectrum, so the
    # spectra must be accumulated whenever that estimate is requested, even
    # if they are neither written nor used for adaptation. This is decided
    # after the restart block because the restart file may re-enable it.
    estimate_corr_kin = bool(post_state["do_corr_kin"]) and not (
        classical_kernel or hbar == 0
    )
    do_compute_spectra = write_spectra or adaptive or estimate_corr_kin

    if do_compute_spectra:
        # Velocities recorded over the current segment.
        state["vel"] = jnp.zeros((nseg, nat, 3), dtype=fprec)

        # Per-segment spectra and their running averages.
        post_state["dFDT"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["mCvv"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cvf"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cff"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["dFDT_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["mCvv_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cvfg_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cff_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)

    def compute_corr_pot(niter=20, verbose=False):
        """Return the potential-energy kernel correction (ones if classical)."""
        if classical_kernel or hbar == 0:
            return np.ones(nom)

        s_0 = np.array((theta / kT * cutoff)[:nom])
        s_out, s_rec, _ = deconvolute_spectrum(
            s_0,
            omega[:nom],
            gamma,
            niter,
            kernel=kernel_lorentz_pot,
            trans=True,
            symmetrize=True,
            verbose=verbose,
        )
        corr_pot = 1.0 + (s_out - s_0) / s_0
        # First column in cm-1, using the same conversion as the spectra files.
        columns = np.column_stack(
            (omega[:nom] * (au.FS * au.CM1), corr_pot - 1.0, s_0, s_out, s_rec)
        )
        np.savetxt(
            "corr_pot.dat", columns, header="omega(cm-1) corr_pot s_0 s_out s_rec"
        )
        return corr_pot

    def compute_corr_kin(post_state, niter=7, verbose=False):
        """Re-estimate corr_kin from the averaged spectrum.

        Returns ``(corr_kin, post_state)``. The previous value is returned
        unchanged when estimation is disabled or the reconvolution check
        fails; 1.0 is returned for a classical kernel.
        """
        if not post_state["do_corr_kin"]:
            return post_state["corr_kin_prev"], post_state
        if classical_kernel or hbar == 0:
            return 1.0, post_state

        K_D = post_state.get("K_D", None)
        # Atom-weighted average of the per-species mCvv spectra.
        mCvv = (post_state["mCvv_avg"][:, :nom] * n_of_type[:, None]).sum(axis=0) / nat
        s_0 = np.array(mCvv * kT / theta[:nom] / post_state["corr_pot"])
        s_out, s_rec, K_D = deconvolute_spectrum(
            s_0,
            omega[:nom],
            gamma,
            niter,
            kernel=kernel_lorentz,
            trans=False,
            symmetrize=True,
            verbose=verbose,
            K_D=K_D,
        )
        s_out = s_out * theta[:nom] / kT
        s_rec = s_rec * theta[:nom] / kT * post_state["corr_pot"]
        mCvvsum = mCvv.sum()
        # Sanity check: the reconvoluted spectrum must reproduce the input
        # integral within 5 %.
        rec_ratio = mCvvsum / s_rec.sum()
        if rec_ratio < 0.95 or rec_ratio > 1.05:
            print(
                f"# WARNING: reconvolution error {rec_ratio} is too high, corr_kin was not updated"
            )
            return post_state["corr_kin_prev"], post_state

        corr_kin = mCvvsum / s_out.sum()
        if np.abs(corr_kin - post_state["corr_kin_prev"]) < 1.0e-4:
            isame_kin = post_state["isame_kin"] + 1
        else:
            isame_kin = 0

        # Stop re-estimating once the value has been stable long enough.
        do_corr_kin = post_state["do_corr_kin"]
        if isame_kin > 10:
            print(
                "# INFO: corr_kin is converged (it did not change for 10 consecutive segments)"
            )
            do_corr_kin = False

        return corr_kin, {
            **post_state,
            "corr_kin_prev": corr_kin,
            "isame_kin": isame_kin,
            "do_corr_kin": do_corr_kin,
            "K_D": K_D,
        }

    @jax.jit
    def ff_kernel(post_state):
        """Random-force power spectrum, shape (n_frequencies, nspecies)."""
        if classical_kernel:
            kernel = cutoff * (2 * gamma * kT / dt)
        else:
            kernel = theta * cutoff * OUcorr * (2 * gamma / dt)
        # gammar and corr_pot only cover the first nom frequencies; the
        # remaining ones use a ratio of 1.
        gamma_ratio = jnp.concatenate(
            (
                post_state["gammar"].T * post_state["corr_pot"][:, None],
                jnp.ones(
                    (kernel.shape[0] - nom, nspecies), dtype=post_state["gammar"].dtype
                ),
            ),
            axis=0,
        )
        return kernel[:, None] * gamma_ratio * mass_idx[None, :]

    @jax.jit
    def refresh_force(post_state):
        """Generate the colored force for the next segment."""
        rng_key, noise_key = jax.random.split(post_state["rng_key"])
        # Slide the window: drop the oldest segment, append a fresh one.
        white_noise = jnp.concatenate(
            (
                post_state["white_noise"][nseg:],
                jax.random.normal(
                    noise_key, (nseg, nat, 3), dtype=post_state["white_noise"].dtype
                ),
            ),
            axis=0,
        )
        amplitude = ff_kernel(post_state) ** 0.5
        # Filter in Fourier space and keep the middle segment of the window.
        s = jnp.fft.rfft(white_noise, 3 * nseg, axis=0) * amplitude[:, type_idx, None]
        force = jnp.fft.irfft(s, 3 * nseg, axis=0)[nseg : 2 * nseg]
        return force, {**post_state, "rng_key": rng_key, "white_noise": white_noise}

    @jax.jit
    def compute_spectra(force, vel, post_state):
        """Update per-species spectra and running averages for one segment."""
        # Transforms are taken over a window of length 3 * nseg.
        sf = jnp.fft.rfft(force / gamma, 3 * nseg, axis=0, norm="ortho")
        sv = jnp.fft.rfft(vel, 3 * nseg, axis=0, norm="ortho")
        # Per-atom spectra summed over Cartesian components: (nat, nom).
        Cvv = jnp.sum(jnp.abs(sv[:nom]) ** 2, axis=-1).T
        Cff = jnp.sum(jnp.abs(sf[:nom]) ** 2, axis=-1).T
        Cvf = jnp.sum(jnp.real(sv[:nom] * jnp.conj(sf[:nom])), axis=-1).T

        # Scatter-add atoms into their species row, then average per species.
        mCvv = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["mCvv"]).at[type_idx].add(Cvv)
            * mass_idx[:, None]
            / n_of_type[:, None]
        )
        Cvf = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["Cvf"]).at[type_idx].add(Cvf)
            / n_of_type[:, None]
        )
        Cff = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["Cff"]).at[type_idx].add(Cff)
            / n_of_type[:, None]
        )
        dFDT = mCvv * post_state["gammar"] - Cvf

        # Running mean over the nsample segments recorded so far.
        nsinv = 1.0 / post_state["nsample"]
        b1 = 1.0 - nsinv
        dFDT_avg = post_state["dFDT_avg"] * b1 + dFDT * nsinv
        mCvv_avg = post_state["mCvv_avg"] * b1 + mCvv * nsinv
        Cvfg_avg = post_state["Cvfg_avg"] * b1 + Cvf / post_state["gammar"] * nsinv
        Cff_avg = post_state["Cff_avg"] * b1 + Cff * nsinv

        return {
            **post_state,
            "mCvv": mCvv,
            "Cvf": Cvf,
            "Cff": Cff,
            "dFDT": dFDT,
            "dFDT_avg": dFDT_avg,
            "mCvv_avg": mCvv_avg,
            "Cvfg_avg": Cvfg_avg,
            "Cff_avg": Cff_avg,
        }

    def write_spectra_to_file(post_state):
        """Write one ``QTB_spectra_<species>.out`` file per species."""
        mCvv_avg = np.array(post_state["mCvv_avg"])
        Cvfg_avg = np.array(post_state["Cvfg_avg"])
        Cff_avg = np.array(post_state["Cff_avg"]) * 3.0 / dt / (gamma**2)
        dFDT_avg = np.array(post_state["dFDT_avg"])
        gammar = np.array(post_state["gammar"])
        Cff_theo = np.array(ff_kernel(post_state))[:nom].T
        for i, sp in enumerate(species_set):
            ff_scale = au.KELVIN / ((2 * gamma / dt) * mass_idx[i])
            columns = np.column_stack(
                (
                    omega[:nom] * (au.FS * au.CM1),
                    mCvv_avg[i],
                    Cvfg_avg[i],
                    dFDT_avg[i],
                    gammar[i] * gamma * (au.FS * au.THZ),
                    Cff_avg[i] * ff_scale,
                    Cff_theo[i] * ff_scale,
                )
            )
            # NOTE(review): the header names six columns while seven are
            # written (measured and theoretical Cff); left unchanged.
            np.savetxt(
                f"QTB_spectra_{sp}.out",
                columns,
                fmt="%12.6f",
                header="#omega mCvv Cvf dFDT gammar Cff",
            )
        if verbose:
            print("# QTB spectra written.")

    if compute_thermostat_energy:
        state["qtb_energy_flux"] = 0.0

    @jax.jit
    def thermostat(vel, state):
        """Apply friction and the current random-force sample to ``vel``."""
        istep = state["istep"]
        dvel = dt * state["force"][istep] / mass[:, None]
        new_vel = vel * a1 + dvel
        new_state = {**state, "istep": istep + 1}
        if do_compute_spectra:
            # Record a velocity halfway through the update for the spectra.
            vel2 = state["vel"].at[istep].set(vel * a1**0.5 + 0.5 * dvel)
            new_state["vel"] = vel2
        if compute_thermostat_energy:
            dek = 0.5 * (mass[:, None] * (vel**2 - new_vel**2)).sum()
            ekcorr = (
                0.5
                * (mass[:, None] * new_vel**2).sum()
                * (1.0 - 1.0 / state.get("corr_kin", 1.0))
            )
            new_state["qtb_energy_flux"] = state["qtb_energy_flux"] + dek
            new_state["thermostat_energy"] = new_state["qtb_energy_flux"] + ekcorr
        return new_vel, new_state

    @jax.jit
    def postprocess_work(state, post_state):
        """Jitted end-of-segment work: spectra, adaptation, new force."""
        if do_compute_spectra:
            post_state = compute_spectra(state["force"], state["vel"], post_state)
        if adaptive:
            # Only adapt once more than skipseg segments have elapsed.
            post_state = jax.lax.cond(
                post_state["nadapt"] > skipseg, update_gammar, lambda x: x, post_state
            )
        new_force, post_state = refresh_force(post_state)
        return {**state, "force": new_force}, post_state

    def postprocess(state, post_state):
        """Per-step hook; does the segment refresh when the counter resets."""
        counter.increment()
        if not counter.is_reset_step:
            return state, post_state
        post_state["nadapt"] += 1
        post_state["nsample"] = max(post_state["nadapt"] - startsave + 1, 1)
        if verbose:
            print("# Refreshing QTB forces.")
        state, post_state = postprocess_work(state, post_state)
        state["corr_kin"], post_state = compute_corr_kin(
            post_state, niter=niter_deconv_kin
        )
        state["istep"] = 0
        if write_spectra:
            write_spectra_to_file(post_state)
        write_qtb_restart(state, post_state)
        return state, post_state

    post_state["corr_pot"] = jnp.asarray(
        compute_corr_pot(niter=niter_deconv_pot), dtype=fprec
    )

    state["force"], post_state = refresh_force(post_state)
    return thermostat, (postprocess, post_state), state