fennol.md.thermostats
import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
import math
import optax

from ..utils.atomic_units import AtomicUnits as au  # CM1,THZ,BOHR,MPROT
from ..utils import Counter
from ..utils.deconvolution import (
    deconvolute_spectrum,
    kernel_lorentz_pot,
    kernel_lorentz,
)


def get_thermostat(simulation_parameters, dt, system_data, fprec, rng_key=None):
    """Build the thermostat requested in ``simulation_parameters``.

    Supported names (key ``"thermostat"``, case-insensitive, default ``"LGV"``):
    ``LGV``/``LANGEVIN``, ``FFLGV``, ``BUSSI``,
    ``GD``/``DESCENT``/``GRADIENT_DESCENT``/``MIN``/``MINIMIZE``,
    ``NVE``/``NONE``, ``NOSE``/``NOSEHOOVER``/``NOSE_HOOVER``,
    ``QTB``/``ADQTB`` and ``ANNEAL``/``ANNEALING``.

    Parameters
    ----------
    simulation_parameters : dict
        Keys read here: ``thermostat``, ``include_thermostat_energy``,
        ``gamma``, ``trpmd_lambda`` (PIMD only), ``qtb`` (QTB/ADQTB),
        ``annealing`` and ``nsteps`` (ANNEAL).
    dt : float
        Time step, in the time unit obtained after the ``au.FS`` conversions
        used below (TODO confirm the unit expected by the integrator).
    system_data : dict
        Keys read here: ``mass``, ``species`` and optionally ``kT``,
        ``nbeads`` and, when ``nbeads`` is set, ``omk``.
    fprec : dtype
        Floating point dtype of the generated velocities and noise.
    rng_key : jax.random.PRNGKey, optional
        Random key. It is required by the stochastic thermostats and used to
        draw the initial velocities.

    Returns
    -------
    tuple
        ``(thermostat, postprocess, state, vel, thermostat_name)``.
        ``thermostat(vel, state) -> (vel, state)`` applies one step.
        ``postprocess`` is ``None`` except for QTB/ADQTB.
        ``state`` is the initial thermostat state dict.
        ``vel`` holds the initial velocities.
        ``thermostat_name`` is the upper-cased name.

    Raises
    ------
    ValueError
        If the thermostat name or the annealing schedule type is unknown.
    """
    state = {}
    postprocess = None

    thermostat_name = str(simulation_parameters.get("thermostat", "LGV")).upper()
    compute_thermostat_energy = simulation_parameters.get(
        "include_thermostat_energy", False
    )

    kT = system_data.get("kT", None)
    nbeads = system_data.get("nbeads", None)
    mass = system_data["mass"]
    # Input value converted with au.FS (the default is presumably 1 THz).
    gamma = simulation_parameters.get("gamma", 1.0 / au.THZ) / au.FS
    species = system_data["species"]

    if nbeads is not None:
        # PIMD: one friction per mode, at least trpmd_lambda * omk.
        trpmd_lambda = simulation_parameters.get("trpmd_lambda", 1.0)
        gamma = np.maximum(trpmd_lambda * system_data["omk"], gamma)

    if thermostat_name in ["LGV", "LANGEVIN", "FFLGV"]:
        assert rng_key is not None, "rng_key must be provided for Langevin thermostat"
        assert kT is not None, "kT must be specified for Langevin thermostat"
        rng_key, v_key = jax.random.split(rng_key)
        if nbeads is None:
            a1 = math.exp(-gamma * dt)
            a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)
            vel = (
                jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
                * (kT / mass[:, None]) ** 0.5
            )
        else:
            # Accept a scalar (same friction for every bead) or one value per bead.
            gamma = np.asarray(gamma, dtype=float)
            if gamma.ndim == 0:
                gamma = np.full(nbeads, float(gamma))
            assert gamma.shape[0] == nbeads, "gamma must have the same length as nbeads"
            a1 = np.exp(-gamma * dt)[:, None, None]
            a2 = jnp.asarray(
                ((1 - a1 * a1) * kT / mass[None, :, None]) ** 0.5, dtype=fprec
            )
            vel = (
                jax.random.normal(v_key, (nbeads, mass.shape[0], 3), dtype=fprec)
                * (kT / mass[None, :, None]) ** 0.5
            )

        state["rng_key"] = rng_key
        if compute_thermostat_energy:
            state["thermostat_energy"] = 0.0

        fixed_direction = thermostat_name == "FFLGV"

        def thermostat(vel, state):
            rng_key, noise_key = jax.random.split(state["rng_key"])
            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
            new_vel = a1 * vel + a2 * noise
            if fixed_direction:
                # FFLGV: keep each atom's velocity direction, take only the
                # norm from the Langevin update.
                dirvel = vel / jnp.linalg.norm(vel, axis=-1, keepdims=True)
                new_vel = dirvel * jnp.linalg.norm(new_vel, axis=-1, keepdims=True)
            new_state = {**state, "rng_key": rng_key}
            if compute_thermostat_energy:
                # Accumulate the kinetic energy removed by the thermostat.
                v2 = (vel**2).sum(axis=-1)
                v2new = (new_vel**2).sum(axis=-1)
                new_state["thermostat_energy"] = (
                    state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
                )
            return new_vel, new_state

    elif thermostat_name in ["BUSSI"]:
        assert rng_key is not None, "rng_key must be provided for Bussi thermostat"
        assert kT is not None, "kT must be specified for Bussi thermostat"
        assert nbeads is None, "Bussi thermostat is not compatible with PIMD"
        rng_key, v_key = jax.random.split(rng_key)

        a1 = math.exp(-gamma * dt)
        a2 = (1 - a1) * kT
        vel = (
            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
            * (kT / mass[:, None]) ** 0.5
        )

        state["rng_key"] = rng_key
        if compute_thermostat_energy:
            state["thermostat_energy"] = 0.0

        def thermostat(vel, state):
            rng_key, noise_key = jax.random.split(state["rng_key"])
            new_state = {**state, "rng_key": rng_key}
            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
            # Global velocity rescaling: R1 is one Gaussian number, R2 the sum
            # of squares of all 3N Gaussian numbers (R1 included).
            R2 = jnp.sum(noise**2)
            R1 = noise[0, 0]
            two_ekin = (mass[:, None] * vel**2).sum()
            c = a2 / two_ekin
            d = (a1 * c) ** 0.5
            scale = (a1 + c * R2 + 2 * d * R1) ** 0.5
            if compute_thermostat_energy:
                # Same convention as the Langevin/QTB thermostats:
                # accumulate K_old - K_new.
                dek = 0.5 * two_ekin * (1.0 - scale**2)
                new_state["thermostat_energy"] = state["thermostat_energy"] + dek
            return scale * vel, new_state

    elif thermostat_name in [
        "GD",
        "DESCENT",
        "GRADIENT_DESCENT",
        "MIN",
        "MINIMIZE",
    ]:
        assert nbeads is None, "Gradient descent is not compatible with PIMD"
        a1 = math.exp(-gamma * dt)
        vel = jnp.zeros((mass.shape[0], 3), dtype=fprec)

        def thermostat(vel, state):
            # Pure friction: velocities are damped every step.
            return a1 * vel, state

    elif thermostat_name in ["NVE", "NONE"]:
        assert rng_key is not None, "rng_key must be provided to initialize velocities"
        assert kT is not None, "kT must be specified to initialize velocities"
        if nbeads is None:
            vel = (
                jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
                * (kT / mass[:, None]) ** 0.5
            )
            # Rescale so the initial kinetic temperature is exactly kT.
            kTsys = jnp.sum(mass[:, None] * vel**2) / (mass.shape[0] * 3)
            vel = vel * (kT / kTsys) ** 0.5
        else:
            vel = (
                jax.random.normal(rng_key, (nbeads, mass.shape[0], 3), dtype=fprec)
                * (kT / mass[None, :, None]) ** 0.5
            )
            # Rescale each bead separately to the target kinetic temperature.
            kTsys = jnp.sum(mass[None, :, None] * vel**2, axis=(1, 2)) / (
                mass.shape[0] * 3
            )
            vel = vel * (kT / kTsys[:, None, None]) ** 0.5
        thermostat = lambda x, s: (x, s)

    elif thermostat_name in ["NOSE", "NOSEHOOVER", "NOSE_HOOVER"]:
        assert rng_key is not None, "rng_key must be provided for Nose-Hoover thermostat"
        assert kT is not None, "kT must be specified for Nose-Hoover thermostat"
        assert nbeads is None, "Nose-Hoover is not compatible with PIMD"
        ndof = mass.shape[0] * 3
        nkT = ndof * kT
        nose_mass = nkT / gamma**2
        state["nose_s"] = 0.0
        state["nose_v"] = 0.0
        if compute_thermostat_energy:
            state["thermostat_energy"] = 0.0
        print(
            "# WARNING: Nose-Hoover thermostat is not well tested yet. Energy conservation is not guaranteed."
        )
        vel = (
            jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
            * (kT / mass[:, None]) ** 0.5
        )

        def thermostat(vel, state):
            # Symmetric splitting: half step on nose_v, full step on nose_s and
            # on the velocity scaling, then a second half step on nose_v.
            nose_s = state["nose_s"]
            nose_v = state["nose_v"]
            kTsys = jnp.sum(mass[:, None] * vel**2)
            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
            nose_s = nose_s + dt * nose_v
            vel = jnp.exp(-nose_v * dt) * vel
            kTsys = jnp.sum(mass[:, None] * vel**2)
            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
            new_state = {**state, "nose_s": nose_s, "nose_v": nose_v}

            if compute_thermostat_energy:
                new_state["thermostat_energy"] = (
                    nkT * nose_s + (0.5 * nose_mass) * nose_v**2
                )
            return vel, new_state

    elif thermostat_name in ["QTB", "ADQTB"]:
        assert nbeads is None, "QTB is not compatible with PIMD"
        qtb_parameters = simulation_parameters.get("qtb", None)
        assert (
            qtb_parameters is not None
        ), "qtb_parameters must be provided for QTB thermostat"
        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
        assert kT is not None, "kT must be specified for QTB thermostat"
        assert species is not None, "species must be provided for QTB thermostat"
        rng_key, v_key = jax.random.split(rng_key)
        vel = (
            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
            * (kT / mass[:, None]) ** 0.5
        )

        # postprocess is the (postprocess_fn, post_state) pair from initialize_qtb.
        thermostat, postprocess, qtb_state = initialize_qtb(
            qtb_parameters,
            fprec=fprec,
            dt=dt,
            mass=mass,
            gamma=gamma,
            kT=kT,
            species=species,
            rng_key=rng_key,
            adaptive=thermostat_name.startswith("AD"),
            compute_thermostat_energy=compute_thermostat_energy,
        )
        state = {**state, **qtb_state}

    elif thermostat_name in ["ANNEAL", "ANNEALING"]:
        assert rng_key is not None, "rng_key must be provided for ANNEAL thermostat"
        assert kT is not None, "kT must be specified for ANNEAL thermostat"
        assert nbeads is None, "ANNEAL is not compatible with PIMD"
        a1 = math.exp(-gamma * dt)
        a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)

        anneal_parameters = simulation_parameters.get("annealing", {})
        init_factor = anneal_parameters.get("init_factor", 1.0 / 25.0)
        assert init_factor > 0.0, "init_factor must be positive"
        final_factor = anneal_parameters.get("final_factor", 1.0 / 10000.0)
        assert final_factor > 0.0, "final_factor must be positive"
        nsteps = simulation_parameters.get("nsteps")
        assert nsteps is not None, "nsteps must be specified for ANNEAL thermostat"
        # Fraction of nsteps over which the schedule runs (default: all of them).
        anneal_steps = anneal_parameters.get("anneal_steps", 1.0)
        assert (
            0.0 < anneal_steps <= 1.0
        ), "anneal_steps must be between 0 and 1 (fraction of nsteps)"
        # Fraction of the schedule spent warming up to the peak value.
        pct_start = anneal_parameters.get("warmup_steps", 0.3)
        assert (
            0.0 < pct_start < 1.0
        ), "warmup_steps must be between 0 and 1 (fraction of anneal_steps)"

        transition_steps = int(anneal_steps * nsteps)
        anneal_type = str(anneal_parameters.get("type", "cosine")).lower()
        if anneal_type == "linear":
            schedule = optax.linear_onecycle_schedule(
                peak_value=1.0,
                div_factor=1.0 / init_factor,
                final_div_factor=1.0 / final_factor,
                transition_steps=transition_steps,
                pct_start=pct_start,
                pct_final=1.0,
            )
        elif anneal_type in ["cosine", "cosine_onecycle"]:
            # "cosine" is the documented default and maps to the one-cycle cosine.
            schedule = optax.cosine_onecycle_schedule(
                peak_value=1.0,
                div_factor=1.0 / init_factor,
                final_div_factor=1.0 / final_factor,
                transition_steps=transition_steps,
                pct_start=pct_start,
            )
        else:
            raise ValueError(f"Unknown anneal_type {anneal_type}")

        # Split before storing, so the initial velocities and the first
        # thermostat noise are drawn from independent keys.
        rng_key, v_key = jax.random.split(rng_key)
        state["rng_key"] = rng_key
        state["istep_anneal"] = 0

        Tscale = schedule(0)
        print(f"# ANNEAL: initial temperature = {Tscale*kT*au.KELVIN:.3e} K")
        vel = (
            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
            * (kT * Tscale / mass[:, None]) ** 0.5
        )

        def thermostat(vel, state):
            rng_key, noise_key = jax.random.split(state["rng_key"])
            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
            # Noise amplitude follows sqrt of the scheduled temperature factor.
            noise_scale = schedule(state["istep_anneal"]) ** 0.5
            vel = a1 * vel + a2 * noise_scale * noise
            return vel, {
                **state,
                "rng_key": rng_key,
                "istep_anneal": state["istep_anneal"] + 1,
            }

    else:
        raise ValueError(f"Unknown thermostat {thermostat_name}")

    return thermostat, postprocess, state, vel, thermostat_name


def initialize_qtb(
    qtb_parameters,
    fprec,
    dt,
    mass,
    gamma,
    kT,
    species,
    rng_key,
    adaptive,
    compute_thermostat_energy=False,
):
    """Set up the Quantum Thermal Bath (QTB) thermostat, optionally adaptive.

    The random force is generated in frequency space on windows of
    ``3 * nseg`` steps and the middle segment of ``nseg`` steps is used.
    At each segment boundary, the following happens:

    * velocity/force spectra are accumulated (when ``write_spectra`` or
      ``adaptive``);
    * ``gammar`` is adapted (ADQTB);
    * the force is refreshed;
    * ``corr_kin`` is re-estimated.

    Parameters
    ----------
    qtb_parameters : dict
        QTB options. Keys read: ``verbose``, ``corr_kin``, ``tseg``,
        ``omegacut``, ``gammar_min``, ``classical_kernel``, ``hbar``,
        ``startsave``, ``write_spectra`` and the adaptation keys
        (``skipseg``, ``adaptation_method``, ``agamma``, ``tau_ad``,
        ``tau_s``).
    fprec : dtype
        Floating point dtype of the arrays.
    dt, mass, gamma, kT : float / array
        Time step, atomic masses, friction and thermal energy.
    species : array
        Per-atom species (``species.shape[0]`` is the number of atoms).
    rng_key : jax.random.PRNGKey
        Random key for the white noise.
    adaptive : bool
        Enable ADQTB adaptation of ``gammar``.
    compute_thermostat_energy : bool
        Track the energy exchanged with the bath in ``thermostat_energy``.

    Returns
    -------
    tuple
        ``(thermostat, (postprocess, post_state), state)``.
        ``thermostat(vel, state)`` applies one step with the precomputed force.
        ``postprocess(state, post_state)`` must be called every step and
        acts only on the Counter's reset steps.
        ``post_state`` holds data not needed inside the jitted step.
        ``state`` is the initial thermostat state.
    """
    state = {}
    post_state = {}
    verbose = qtb_parameters.get("verbose", False)
    if compute_thermostat_energy:
        state["thermostat_energy"] = 0.0

    mass = jnp.asarray(mass, dtype=fprec)

    nat = species.shape[0]
    # define type indices (one spectrum / gammar row per species)
    species_set = set(species)
    nspecies = len(species_set)
    idx = {sp: i for i, sp in enumerate(species_set)}
    type_idx = np.array([idx[sp] for sp in species], dtype=np.int32)

    n_of_type = jnp.asarray(np.bincount(type_idx, minlength=nspecies), dtype=fprec)
    # average mass of each species
    mass_idx = jax.ops.segment_sum(mass, type_idx, nspecies) / n_of_type

    corr_kin = qtb_parameters.get("corr_kin", -1)
    # A non-positive corr_kin requests an on-the-fly estimation.
    do_corr_kin = corr_kin <= 0
    if do_corr_kin:
        corr_kin = 1.0
    state["corr_kin"] = corr_kin
    post_state["corr_kin_prev"] = corr_kin
    post_state["do_corr_kin"] = do_corr_kin
    post_state["isame_kin"] = 0

    # spectra parameters
    omegasmear = np.pi / dt / 100.0
    Tseg = qtb_parameters.get("tseg", 1.0 / au.PS) * au.FS
    nseg = int(Tseg / dt)
    assert nseg > 0, "tseg must be at least one time step"
    Tseg = nseg * dt
    # frequency grid of the 3-segment window
    dom = 2 * np.pi / (3 * Tseg)
    omegacut = qtb_parameters.get("omegacut", 15000.0 / au.CM1) / au.FS
    nom = int(omegacut / dom)
    omega = dom * np.arange((3 * nseg) // 2 + 1)
    # smooth Fermi-like cutoff above omegacut
    cutoff = jnp.asarray(
        1.0 / (1.0 + np.exp((omega - omegacut) / omegasmear)), dtype=fprec
    )
    assert (
        omegacut < omega[-1]
    ), f"omegacut must be smaller than {omega[-1]*au.CM1} CM-1"

    # initialize gammar
    assert (
        gamma < 0.5 * omegacut
    ), "gamma must be much smaller than omegacut (at most 0.5*omegacut)"
    gammar_min = qtb_parameters.get("gammar_min", 0.1)
    post_state["gammar"] = jnp.asarray(np.ones((nspecies, nom)), dtype=fprec)

    # Ornstein-Uhlenbeck correction for colored noise
    a1 = np.exp(-gamma * dt)
    OUcorr = jnp.asarray(
        (1.0 - 2.0 * a1 * np.cos(omega * dt) + a1**2) / (dt**2 * (gamma**2 + omega**2)),
        dtype=fprec,
    )

    # hbar schedule
    classical_kernel = qtb_parameters.get("classical_kernel", False)
    hbar = qtb_parameters.get("hbar", 1.0) * au.FS
    u = 0.5 * hbar * np.abs(omega) / kT
    theta = kT * np.ones_like(omega)
    if hbar > 0:
        # kT * u / tanh(u) = (hbar*omega/2) * coth(hbar*omega / 2kT)
        theta[1:] *= u[1:] / np.tanh(u[1:])
    theta = jnp.asarray(theta, dtype=fprec)

    noise_key, post_state["rng_key"] = jax.random.split(rng_key)
    del rng_key
    post_state["white_noise"] = jax.random.normal(
        noise_key, (3 * nseg, nat, 3), dtype=jnp.float32
    )

    startsave = qtb_parameters.get("startsave", 1)
    counter = Counter(nseg, startsave=startsave)
    state["istep"] = 0
    post_state["nadapt"] = 0
    post_state["nsample"] = 0

    write_spectra = qtb_parameters.get("write_spectra", True)
    do_compute_spectra = write_spectra or adaptive

    if do_compute_spectra:
        state["vel"] = jnp.zeros((nseg, nat, 3), dtype=fprec)
        for key in (
            "dFDT",
            "mCvv",
            "Cvf",
            "Cff",
            "dFDT_avg",
            "mCvv_avg",
            "Cvfg_avg",
            "Cff_avg",
        ):
            post_state[key] = jnp.zeros((nspecies, nom), dtype=fprec)

    if not adaptive:
        update_gammar = lambda x: x
    else:
        # adaptation parameters
        skipseg = qtb_parameters.get("skipseg", 1)

        adaptation_method = (
            str(qtb_parameters.get("adaptation_method", "ADABELIEF")).upper().strip()
        )
        authorized_methods = ["SIMPLE", "RATIO", "ADABELIEF"]
        assert (
            adaptation_method in authorized_methods
        ), f"adaptation_method must be one of {authorized_methods}"
        if adaptation_method == "SIMPLE":
            agamma = qtb_parameters.get("agamma", 1.0e-3) / au.FS
            assert agamma > 0, "agamma must be positive"
            a1_ad = agamma * Tseg
            print(f"# ADQTB SIMPLE: agamma = {agamma*au.FS:.3f}")

            def update_gammar(post_state):
                # plain gradient step on the FDT violation
                g = post_state["dFDT"]
                gammar = post_state["gammar"] - a1_ad * g
                gammar = jnp.maximum(gammar_min, gammar)
                return {**post_state, "gammar": gammar}

        elif adaptation_method == "RATIO":
            tau_ad = qtb_parameters.get("tau_ad", 5.0 / au.PS) * au.FS
            # The default is expressed in input units so it is converted once.
            tau_s = qtb_parameters.get("tau_s", 10 * tau_ad / au.FS) * au.FS
            assert tau_ad > 0, "tau_ad must be positive"
            assert tau_s > 0, "tau_s must be positive"
            print(
                f"# ADQTB RATIO: tau_ad = {tau_ad*1e-3:.2f} ps, tau_s = {tau_s*1e-3:.2f} ps"
            )
            b1 = np.exp(-Tseg / tau_ad)
            b2 = np.exp(-Tseg / tau_s)
            post_state["mCvv_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["Cvf_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["n_adabelief"] = 0

            def update_gammar(post_state):
                # gammar = <Cvf> / <mCvv> with bias-corrected moving averages
                n_adabelief = post_state["n_adabelief"] + 1
                mCvv_m = post_state["mCvv_m"] * b1 + post_state["mCvv"] * (1.0 - b1)
                Cvf_m = post_state["Cvf_m"] * b2 + post_state["Cvf"] * (1.0 - b2)
                mCvv = mCvv_m / (1.0 - b1**n_adabelief)
                Cvf = Cvf_m / (1.0 - b2**n_adabelief)
                gammar = Cvf / (mCvv + 1.0e-8)
                gammar = jnp.maximum(gammar_min, gammar)
                return {
                    **post_state,
                    "gammar": gammar,
                    "mCvv_m": mCvv_m,
                    "Cvf_m": Cvf_m,
                    "n_adabelief": n_adabelief,
                }

        elif adaptation_method == "ADABELIEF":
            agamma = qtb_parameters.get("agamma", 0.1)
            tau_ad = qtb_parameters.get("tau_ad", 1.0 / au.PS) * au.FS
            # The default is expressed in input units so it is converted once.
            tau_s = qtb_parameters.get("tau_s", 100 * tau_ad / au.FS) * au.FS
            assert tau_ad > 0, "tau_ad must be positive"
            assert tau_s > 0, "tau_s must be positive"
            assert agamma > 0, "agamma must be positive"
            print(
                f"# ADQTB ADABELIEF: agamma = {agamma:.3f}, tau_ad = {tau_ad*1.e-3:.2f} ps, tau_s = {tau_s*1.e-3:.2f} ps"
            )

            a1_ad = agamma * gamma
            b1 = np.exp(-Tseg / tau_ad)
            b2 = np.exp(-Tseg / tau_s)
            post_state["dFDT_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["dFDT_s"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["n_adabelief"] = 0

            def update_gammar(post_state):
                # AdaBelief-like update: first moment of dFDT normalized by
                # the moving variance of (dFDT - first moment).
                n_adabelief = post_state["n_adabelief"] + 1
                dFDT = post_state["dFDT"]
                dFDT_m = post_state["dFDT_m"] * b1 + dFDT * (1.0 - b1)
                dFDT_s = (
                    post_state["dFDT_s"] * b2
                    + (dFDT - dFDT_m) ** 2 * (1.0 - b2)
                    + 1.0e-8
                )
                # bias correction
                mt = dFDT_m / (1.0 - b1**n_adabelief)
                st = dFDT_s / (1.0 - b2**n_adabelief)
                gammar = post_state["gammar"] - a1_ad * mt / (st**0.5 + 1.0e-8)
                gammar = jnp.maximum(gammar_min, gammar)
                return {
                    **post_state,
                    "gammar": gammar,
                    "dFDT_m": dFDT_m,
                    "n_adabelief": n_adabelief,
                    "dFDT_s": dFDT_s,
                }

    def compute_corr_pot(niter=20, verbose=False):
        """Return the potential-energy correction spectrum (also saved to corr_pot.dat)."""
        if classical_kernel or hbar == 0:
            return np.ones(nom)

        s_0 = np.array((theta / kT * cutoff)[:nom])
        s_out, s_rec, _ = deconvolute_spectrum(
            s_0,
            omega[:nom],
            gamma,
            niter,
            kernel=kernel_lorentz_pot,
            trans=True,
            symmetrize=True,
            verbose=verbose,
        )
        corr_pot = 1.0 + (s_out - s_0) / s_0
        columns = np.column_stack(
            (omega[:nom] * au.CM1, corr_pot - 1.0, s_0, s_out, s_rec)
        )
        np.savetxt(
            "corr_pot.dat", columns, header="omega(cm-1) corr_pot s_0 s_out s_rec"
        )
        return corr_pot

    def compute_corr_kin(post_state, niter=7, verbose=False):
        """Estimate the kinetic-energy correction factor.

        Always returns ``(corr_kin, post_state)``. The previous value is kept
        in any of these cases:

        * estimation is disabled or has converged;
        * no spectra are accumulated;
        * the reconvolution check fails.
        """
        if not post_state["do_corr_kin"]:
            return post_state["corr_kin_prev"], post_state
        if classical_kernel or hbar == 0:
            return 1.0, post_state
        if not do_compute_spectra:
            # mCvv_avg only exists when spectra are computed.
            return post_state["corr_kin_prev"], post_state

        K_D = post_state.get("K_D", None)
        mCvv = (post_state["mCvv_avg"][:, :nom] * n_of_type[:, None]).sum(axis=0) / nat
        s_0 = np.array(mCvv * kT / theta[:nom] / post_state["corr_pot"])
        s_out, s_rec, K_D = deconvolute_spectrum(
            s_0,
            omega[:nom],
            gamma,
            niter,
            kernel=kernel_lorentz,
            trans=False,
            symmetrize=True,
            verbose=verbose,
            K_D=K_D,
        )
        s_out = s_out * theta[:nom] / kT
        s_rec = s_rec * theta[:nom] / kT * post_state["corr_pot"]
        mCvvsum = mCvv.sum()
        rec_ratio = mCvvsum / s_rec.sum()
        if rec_ratio < 0.95 or rec_ratio > 1.05:
            print(
                "# WARNING: reconvolution error is too high, corr_kin was not updated"
            )
            # Keep the previous value instead of returning None (which broke
            # the tuple-unpacking in postprocess).
            return post_state["corr_kin_prev"], post_state

        corr_kin = mCvvsum / s_out.sum()
        if np.abs(corr_kin - post_state["corr_kin_prev"]) < 1.0e-4:
            isame_kin = post_state["isame_kin"] + 1
        else:
            isame_kin = 0

        do_corr_kin = post_state["do_corr_kin"]
        if isame_kin > 10:
            print(
                "# INFO: corr_kin is converged (it did not change for 10 consecutive segments)"
            )
            do_corr_kin = False

        return corr_kin, {
            **post_state,
            "corr_kin_prev": corr_kin,
            "isame_kin": isame_kin,
            "do_corr_kin": do_corr_kin,
            "K_D": K_D,
        }

    @jax.jit
    def ff_kernel(post_state):
        """Return the random-force power spectrum, one column per species."""
        if classical_kernel:
            kernel = cutoff * (2 * gamma * kT / dt)
        else:
            kernel = theta * cutoff * OUcorr * (2 * gamma / dt)
        # gammar * corr_pot below omegacut, 1 above it
        gamma_ratio = jnp.concatenate(
            (
                post_state["gammar"].T * post_state["corr_pot"][:, None],
                jnp.ones(
                    (kernel.shape[0] - nom, nspecies), dtype=post_state["gammar"].dtype
                ),
            ),
            axis=0,
        )
        return kernel[:, None] * gamma_ratio * mass_idx[None, :]

    @jax.jit
    def refresh_force(post_state):
        """Shift the noise window by one segment and color it into a new force."""
        rng_key, noise_key = jax.random.split(post_state["rng_key"])
        white_noise = jnp.concatenate(
            (
                post_state["white_noise"][nseg:],
                jax.random.normal(
                    noise_key, (nseg, nat, 3), dtype=post_state["white_noise"].dtype
                ),
            ),
            axis=0,
        )
        amplitude = ff_kernel(post_state) ** 0.5
        s = jnp.fft.rfft(white_noise, 3 * nseg, axis=0) * amplitude[:, type_idx, None]
        # keep the middle segment of the 3-segment window
        force = jnp.fft.irfft(s, 3 * nseg, axis=0)[nseg : 2 * nseg]
        return force, {**post_state, "rng_key": rng_key, "white_noise": white_noise}

    @jax.jit
    def compute_spectra(force, vel, post_state):
        """Compute per-species spectra of the last segment and update running averages."""
        sf = jnp.fft.rfft(force / gamma, 3 * nseg, axis=0, norm="ortho")
        sv = jnp.fft.rfft(vel, 3 * nseg, axis=0, norm="ortho")
        Cvv = jnp.sum(jnp.abs(sv[:nom]) ** 2, axis=-1).T
        Cff = jnp.sum(jnp.abs(sf[:nom]) ** 2, axis=-1).T
        Cvf = jnp.sum(jnp.real(sv[:nom] * jnp.conj(sf[:nom])), axis=-1).T

        mCvv = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["mCvv"]).at[type_idx].add(Cvv)
            * mass_idx[:, None]
            / n_of_type[:, None]
        )
        Cvf = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["Cvf"]).at[type_idx].add(Cvf)
            / n_of_type[:, None]
        )
        Cff = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["Cff"]).at[type_idx].add(Cff)
            / n_of_type[:, None]
        )
        # deviation from the fluctuation-dissipation relation
        dFDT = mCvv * post_state["gammar"] - Cvf

        nsinv = 1.0 / post_state["nsample"]
        b1 = 1.0 - nsinv
        dFDT_avg = post_state["dFDT_avg"] * b1 + dFDT * nsinv
        mCvv_avg = post_state["mCvv_avg"] * b1 + mCvv * nsinv
        Cvfg_avg = post_state["Cvfg_avg"] * b1 + Cvf / post_state["gammar"] * nsinv
        Cff_avg = post_state["Cff_avg"] * b1 + Cff * nsinv

        return {
            **post_state,
            "mCvv": mCvv,
            "Cvf": Cvf,
            "Cff": Cff,
            "dFDT": dFDT,
            "dFDT_avg": dFDT_avg,
            "mCvv_avg": mCvv_avg,
            "Cvfg_avg": Cvfg_avg,
            "Cff_avg": Cff_avg,
        }

    def write_spectra_to_file(post_state):
        """Write one QTB_spectra_<species>.out file per species."""
        mCvv_avg = np.array(post_state["mCvv_avg"])
        Cvfg_avg = np.array(post_state["Cvfg_avg"])
        Cff_avg = np.array(post_state["Cff_avg"]) * 3.0 / dt / (gamma**2)
        dFDT_avg = np.array(post_state["dFDT_avg"])
        gammar = np.array(post_state["gammar"])
        Cff_theo = np.array(ff_kernel(post_state))[:nom].T
        for i, sp in enumerate(species_set):
            ff_scale = au.KELVIN / ((2 * gamma / dt) * mass_idx[i])
            columns = np.column_stack(
                (
                    omega[:nom] * (au.FS * au.CM1),
                    mCvv_avg[i],
                    Cvfg_avg[i],
                    dFDT_avg[i],
                    gammar[i] * gamma * (au.FS * au.THZ),
                    Cff_avg[i] * ff_scale,
                    Cff_theo[i] * ff_scale,
                )
            )
            np.savetxt(
                f"QTB_spectra_{sp}.out",
                columns,
                fmt="%12.6f",
                header="#omega mCvv Cvf dFDT gammar Cff Cff_theo",
            )
        if verbose:
            print("# QTB spectra written.")

    if compute_thermostat_energy:
        state["qtb_energy_flux"] = 0.0

    @jax.jit
    def thermostat(vel, state):
        istep = state["istep"]
        dvel = dt * state["force"][istep] / mass[:, None]
        new_vel = vel * a1 + dvel
        new_state = {**state, "istep": istep + 1}
        if do_compute_spectra:
            # store a mid-step velocity for the spectra of this segment
            vel2 = state["vel"].at[istep].set(vel * a1**0.5 + 0.5 * dvel)
            new_state["vel"] = vel2
        if compute_thermostat_energy:
            dek = 0.5 * (mass[:, None] * (vel**2 - new_vel**2)).sum()
            ekcorr = (
                0.5
                * (mass[:, None] * new_vel**2).sum()
                * (1.0 - 1.0 / state.get("corr_kin", 1.0))
            )
            new_state["qtb_energy_flux"] = state["qtb_energy_flux"] + dek
            new_state["thermostat_energy"] = new_state["qtb_energy_flux"] + ekcorr
        return new_vel, new_state

    @jax.jit
    def postprocess_work(state, post_state):
        if do_compute_spectra:
            post_state = compute_spectra(state["force"], state["vel"], post_state)
        if adaptive:
            post_state = jax.lax.cond(
                post_state["nadapt"] > skipseg, update_gammar, lambda x: x, post_state
            )
        new_force, post_state = refresh_force(post_state)
        return {**state, "force": new_force}, post_state

    def postprocess(state, post_state):
        """Advance the segment counter and run the end-of-segment work when due."""
        counter.increment()
        if not counter.is_reset_step:
            return state, post_state
        post_state["nadapt"] += 1
        post_state["nsample"] = max(post_state["nadapt"] - startsave + 1, 1)
        if verbose:
            print("# Refreshing QTB forces.")
        state, post_state = postprocess_work(state, post_state)
        state["corr_kin"], post_state = compute_corr_kin(post_state)
        state["istep"] = 0
        if write_spectra:
            write_spectra_to_file(post_state)
        return state, post_state

    post_state["corr_pot"] = jnp.asarray(compute_corr_pot(), dtype=fprec)

    state["force"], post_state = refresh_force(post_state)
    return thermostat, (postprocess, post_state), state
def get_thermostat(simulation_parameters, dt, system_data, fprec, rng_key=None):
    """Build the thermostat requested in ``simulation_parameters``.

    Supported names (key ``"thermostat"``, case-insensitive, default ``"LGV"``):
    ``LGV``/``LANGEVIN``, ``FFLGV``, ``BUSSI``,
    ``GD``/``DESCENT``/``GRADIENT_DESCENT``/``MIN``/``MINIMIZE``,
    ``NVE``/``NONE``, ``NOSE``/``NOSEHOOVER``/``NOSE_HOOVER``,
    ``QTB``/``ADQTB`` and ``ANNEAL``/``ANNEALING``.

    Parameters
    ----------
    simulation_parameters : dict
        Keys read here: ``thermostat``, ``include_thermostat_energy``,
        ``gamma``, ``trpmd_lambda`` (PIMD only), ``qtb`` (QTB/ADQTB),
        ``annealing`` and ``nsteps`` (ANNEAL).
    dt : float
        Time step, in the time unit obtained after the ``au.FS`` conversions
        used below (TODO confirm the unit expected by the integrator).
    system_data : dict
        Keys read here: ``mass``, ``species`` and optionally ``kT``,
        ``nbeads`` and, when ``nbeads`` is set, ``omk``.
    fprec : dtype
        Floating point dtype of the generated velocities and noise.
    rng_key : jax.random.PRNGKey, optional
        Random key. It is required by the stochastic thermostats and used to
        draw the initial velocities.

    Returns
    -------
    tuple
        ``(thermostat, postprocess, state, vel, thermostat_name)``.
        ``thermostat(vel, state) -> (vel, state)`` applies one step.
        ``postprocess`` is ``None`` except for QTB/ADQTB.
        ``state`` is the initial thermostat state dict.
        ``vel`` holds the initial velocities.
        ``thermostat_name`` is the upper-cased name.

    Raises
    ------
    ValueError
        If the thermostat name or the annealing schedule type is unknown.
    """
    state = {}
    postprocess = None

    thermostat_name = str(simulation_parameters.get("thermostat", "LGV")).upper()
    compute_thermostat_energy = simulation_parameters.get(
        "include_thermostat_energy", False
    )

    kT = system_data.get("kT", None)
    nbeads = system_data.get("nbeads", None)
    mass = system_data["mass"]
    # Input value converted with au.FS (the default is presumably 1 THz).
    gamma = simulation_parameters.get("gamma", 1.0 / au.THZ) / au.FS
    species = system_data["species"]

    if nbeads is not None:
        # PIMD: one friction per mode, at least trpmd_lambda * omk.
        trpmd_lambda = simulation_parameters.get("trpmd_lambda", 1.0)
        gamma = np.maximum(trpmd_lambda * system_data["omk"], gamma)

    if thermostat_name in ["LGV", "LANGEVIN", "FFLGV"]:
        assert rng_key is not None, "rng_key must be provided for Langevin thermostat"
        assert kT is not None, "kT must be specified for Langevin thermostat"
        rng_key, v_key = jax.random.split(rng_key)
        if nbeads is None:
            a1 = math.exp(-gamma * dt)
            a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)
            vel = (
                jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
                * (kT / mass[:, None]) ** 0.5
            )
        else:
            # Accept a scalar (same friction for every bead) or one value per bead.
            gamma = np.asarray(gamma, dtype=float)
            if gamma.ndim == 0:
                gamma = np.full(nbeads, float(gamma))
            assert gamma.shape[0] == nbeads, "gamma must have the same length as nbeads"
            a1 = np.exp(-gamma * dt)[:, None, None]
            a2 = jnp.asarray(
                ((1 - a1 * a1) * kT / mass[None, :, None]) ** 0.5, dtype=fprec
            )
            vel = (
                jax.random.normal(v_key, (nbeads, mass.shape[0], 3), dtype=fprec)
                * (kT / mass[None, :, None]) ** 0.5
            )

        state["rng_key"] = rng_key
        if compute_thermostat_energy:
            state["thermostat_energy"] = 0.0

        fixed_direction = thermostat_name == "FFLGV"

        def thermostat(vel, state):
            rng_key, noise_key = jax.random.split(state["rng_key"])
            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
            new_vel = a1 * vel + a2 * noise
            if fixed_direction:
                # FFLGV: keep each atom's velocity direction, take only the
                # norm from the Langevin update.
                dirvel = vel / jnp.linalg.norm(vel, axis=-1, keepdims=True)
                new_vel = dirvel * jnp.linalg.norm(new_vel, axis=-1, keepdims=True)
            new_state = {**state, "rng_key": rng_key}
            if compute_thermostat_energy:
                # Accumulate the kinetic energy removed by the thermostat.
                v2 = (vel**2).sum(axis=-1)
                v2new = (new_vel**2).sum(axis=-1)
                new_state["thermostat_energy"] = (
                    state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
                )
            return new_vel, new_state

    elif thermostat_name in ["BUSSI"]:
        assert rng_key is not None, "rng_key must be provided for Bussi thermostat"
        assert kT is not None, "kT must be specified for Bussi thermostat"
        assert nbeads is None, "Bussi thermostat is not compatible with PIMD"
        rng_key, v_key = jax.random.split(rng_key)

        a1 = math.exp(-gamma * dt)
        a2 = (1 - a1) * kT
        vel = (
            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
            * (kT / mass[:, None]) ** 0.5
        )

        state["rng_key"] = rng_key
        if compute_thermostat_energy:
            state["thermostat_energy"] = 0.0

        def thermostat(vel, state):
            rng_key, noise_key = jax.random.split(state["rng_key"])
            new_state = {**state, "rng_key": rng_key}
            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
            # Global velocity rescaling: R1 is one Gaussian number, R2 the sum
            # of squares of all 3N Gaussian numbers (R1 included).
            R2 = jnp.sum(noise**2)
            R1 = noise[0, 0]
            two_ekin = (mass[:, None] * vel**2).sum()
            c = a2 / two_ekin
            d = (a1 * c) ** 0.5
            scale = (a1 + c * R2 + 2 * d * R1) ** 0.5
            if compute_thermostat_energy:
                # Same convention as the Langevin/QTB thermostats:
                # accumulate K_old - K_new.
                dek = 0.5 * two_ekin * (1.0 - scale**2)
                new_state["thermostat_energy"] = state["thermostat_energy"] + dek
            return scale * vel, new_state

    elif thermostat_name in [
        "GD",
        "DESCENT",
        "GRADIENT_DESCENT",
        "MIN",
        "MINIMIZE",
    ]:
        assert nbeads is None, "Gradient descent is not compatible with PIMD"
        a1 = math.exp(-gamma * dt)
        vel = jnp.zeros((mass.shape[0], 3), dtype=fprec)

        def thermostat(vel, state):
            # Pure friction: velocities are damped every step.
            return a1 * vel, state

    elif thermostat_name in ["NVE", "NONE"]:
        assert rng_key is not None, "rng_key must be provided to initialize velocities"
        assert kT is not None, "kT must be specified to initialize velocities"
        if nbeads is None:
            vel = (
                jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
                * (kT / mass[:, None]) ** 0.5
            )
            # Rescale so the initial kinetic temperature is exactly kT.
            kTsys = jnp.sum(mass[:, None] * vel**2) / (mass.shape[0] * 3)
            vel = vel * (kT / kTsys) ** 0.5
        else:
            vel = (
                jax.random.normal(rng_key, (nbeads, mass.shape[0], 3), dtype=fprec)
                * (kT / mass[None, :, None]) ** 0.5
            )
            # Rescale each bead separately to the target kinetic temperature.
            kTsys = jnp.sum(mass[None, :, None] * vel**2, axis=(1, 2)) / (
                mass.shape[0] * 3
            )
            vel = vel * (kT / kTsys[:, None, None]) ** 0.5
        thermostat = lambda x, s: (x, s)

    elif thermostat_name in ["NOSE", "NOSEHOOVER", "NOSE_HOOVER"]:
        assert rng_key is not None, "rng_key must be provided for Nose-Hoover thermostat"
        assert kT is not None, "kT must be specified for Nose-Hoover thermostat"
        assert nbeads is None, "Nose-Hoover is not compatible with PIMD"
        ndof = mass.shape[0] * 3
        nkT = ndof * kT
        nose_mass = nkT / gamma**2
        state["nose_s"] = 0.0
        state["nose_v"] = 0.0
        if compute_thermostat_energy:
            state["thermostat_energy"] = 0.0
        print(
            "# WARNING: Nose-Hoover thermostat is not well tested yet. Energy conservation is not guaranteed."
        )
        vel = (
            jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
            * (kT / mass[:, None]) ** 0.5
        )

        def thermostat(vel, state):
            # Symmetric splitting: half step on nose_v, full step on nose_s and
            # on the velocity scaling, then a second half step on nose_v.
            nose_s = state["nose_s"]
            nose_v = state["nose_v"]
            kTsys = jnp.sum(mass[:, None] * vel**2)
            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
            nose_s = nose_s + dt * nose_v
            vel = jnp.exp(-nose_v * dt) * vel
            kTsys = jnp.sum(mass[:, None] * vel**2)
            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
            new_state = {**state, "nose_s": nose_s, "nose_v": nose_v}

            if compute_thermostat_energy:
                new_state["thermostat_energy"] = (
                    nkT * nose_s + (0.5 * nose_mass) * nose_v**2
                )
            return vel, new_state

    elif thermostat_name in ["QTB", "ADQTB"]:
        assert nbeads is None, "QTB is not compatible with PIMD"
        qtb_parameters = simulation_parameters.get("qtb", None)
        assert (
            qtb_parameters is not None
        ), "qtb_parameters must be provided for QTB thermostat"
        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
        assert kT is not None, "kT must be specified for QTB thermostat"
        assert species is not None, "species must be provided for QTB thermostat"
        rng_key, v_key = jax.random.split(rng_key)
        vel = (
            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
            * (kT / mass[:, None]) ** 0.5
        )

        # postprocess is the (postprocess_fn, post_state) pair from initialize_qtb.
        thermostat, postprocess, qtb_state = initialize_qtb(
            qtb_parameters,
            fprec=fprec,
            dt=dt,
            mass=mass,
            gamma=gamma,
            kT=kT,
            species=species,
            rng_key=rng_key,
            adaptive=thermostat_name.startswith("AD"),
            compute_thermostat_energy=compute_thermostat_energy,
        )
        state = {**state, **qtb_state}

    elif thermostat_name in ["ANNEAL", "ANNEALING"]:
        assert rng_key is not None, "rng_key must be provided for ANNEAL thermostat"
        assert kT is not None, "kT must be specified for ANNEAL thermostat"
        assert nbeads is None, "ANNEAL is not compatible with PIMD"
        a1 = math.exp(-gamma * dt)
        a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)

        anneal_parameters = simulation_parameters.get("annealing", {})
        init_factor = anneal_parameters.get("init_factor", 1.0 / 25.0)
        assert init_factor > 0.0, "init_factor must be positive"
        final_factor = anneal_parameters.get("final_factor", 1.0 / 10000.0)
        assert final_factor > 0.0, "final_factor must be positive"
        nsteps = simulation_parameters.get("nsteps")
        assert nsteps is not None, "nsteps must be specified for ANNEAL thermostat"
        # Fraction of nsteps over which the schedule runs (default: all of them).
        anneal_steps = anneal_parameters.get("anneal_steps", 1.0)
        assert (
            0.0 < anneal_steps <= 1.0
        ), "anneal_steps must be between 0 and 1 (fraction of nsteps)"
        # Fraction of the schedule spent warming up to the peak value.
        pct_start = anneal_parameters.get("warmup_steps", 0.3)
        assert (
            0.0 < pct_start < 1.0
        ), "warmup_steps must be between 0 and 1 (fraction of anneal_steps)"

        transition_steps = int(anneal_steps * nsteps)
        anneal_type = str(anneal_parameters.get("type", "cosine")).lower()
        if anneal_type == "linear":
            schedule = optax.linear_onecycle_schedule(
                peak_value=1.0,
                div_factor=1.0 / init_factor,
                final_div_factor=1.0 / final_factor,
                transition_steps=transition_steps,
                pct_start=pct_start,
                pct_final=1.0,
            )
        elif anneal_type in ["cosine", "cosine_onecycle"]:
            # "cosine" is the documented default and maps to the one-cycle cosine.
            schedule = optax.cosine_onecycle_schedule(
                peak_value=1.0,
                div_factor=1.0 / init_factor,
                final_div_factor=1.0 / final_factor,
                transition_steps=transition_steps,
                pct_start=pct_start,
            )
        else:
            raise ValueError(f"Unknown anneal_type {anneal_type}")

        # Split before storing, so the initial velocities and the first
        # thermostat noise are drawn from independent keys.
        rng_key, v_key = jax.random.split(rng_key)
        state["rng_key"] = rng_key
        state["istep_anneal"] = 0

        Tscale = schedule(0)
        print(f"# ANNEAL: initial temperature = {Tscale*kT*au.KELVIN:.3e} K")
        vel = (
            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
            * (kT * Tscale / mass[:, None]) ** 0.5
        )

        def thermostat(vel, state):
            rng_key, noise_key = jax.random.split(state["rng_key"])
            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
            # Noise amplitude follows sqrt of the scheduled temperature factor.
            noise_scale = schedule(state["istep_anneal"]) ** 0.5
            vel = a1 * vel + a2 * noise_scale * noise
            return vel, {
                **state,
                "rng_key": rng_key,
                "istep_anneal": state["istep_anneal"] + 1,
            }

    else:
        raise ValueError(f"Unknown thermostat {thermostat_name}")

    return thermostat, postprocess, state, vel, thermostat_name
def initialize_qtb(
    qtb_parameters,
    fprec,
    dt,
    mass,
    gamma,
    kT,
    species,
    rng_key,
    adaptive,
    compute_thermostat_energy=False,
):
    """Build the (adaptive) Quantum Thermal Bath thermostat.

    The random force is generated segment by segment (``nseg`` steps each):
    white noise is filtered in frequency space by the square root of the
    kernel returned by ``ff_kernel``. That kernel combines the quantum energy
    ``theta(omega)``, a smooth cutoff, an Ornstein-Uhlenbeck correction and
    the per-species ratio ``gammar``. When ``adaptive`` is True, ``gammar``
    is updated from the fluctuation-dissipation deviation ``dFDT``, which is
    measured on the velocity/force spectra of the previous segment.

    Args:
        qtb_parameters: dict of QTB options (``tseg``, ``omegacut``, ``hbar``,
            ``corr_kin``, ``adaptation_method``, ``write_spectra``, ...).
            Time and frequency options are converted with ``au.FS`` below.
        fprec: floating point dtype used for the thermostat arrays.
        dt: integration time step.
        mass: per-atom masses (indexed per atom, e.g. ``mass[:, None]``).
        gamma: friction coefficient; compared to ``omegacut`` as a scalar.
        kT: target thermal energy.
        species: per-atom species identifiers (``species.shape[0]`` atoms).
        rng_key: JAX PRNG key used for the white noise.
        adaptive: if True, ``gammar`` is adapted (adQTB).
        compute_thermostat_energy: if True, the energy exchanged with the bath
            is accumulated in ``state["thermostat_energy"]``.

    Returns:
        ``(thermostat, (postprocess, post_state), state)``.

        - ``thermostat(vel, state) -> (vel, state)`` applies one friction +
          random-force step.
        - ``postprocess(state, post_state) -> (state, post_state)`` refreshes
          the random forces (and spectra/adaptation) once per segment.
    """
    state = {}
    post_state = {}
    verbose = qtb_parameters.get("verbose", False)
    if compute_thermostat_energy:
        state["thermostat_energy"] = 0.0

    mass = jnp.asarray(mass, dtype=fprec)

    nat = species.shape[0]
    # define type indices (each species gets a contiguous index 0..nspecies-1)
    species_set = set(species)
    nspecies = len(species_set)
    idx = {sp: i for i, sp in enumerate(species_set)}
    type_idx = np.array([idx[sp] for sp in species], dtype=np.int32)

    # number of atoms of each type and average mass per type
    n_of_type = jnp.asarray(np.bincount(type_idx, minlength=nspecies), dtype=fprec)
    mass_idx = jax.ops.segment_sum(mass, type_idx, nspecies) / n_of_type

    # kinetic energy correction: a non-positive value requests an automatic
    # estimation (see compute_corr_kin), starting from 1.
    corr_kin = qtb_parameters.get("corr_kin", -1)
    do_corr_kin = corr_kin <= 0
    if do_corr_kin:
        corr_kin = 1.0
    state["corr_kin"] = corr_kin
    post_state["corr_kin_prev"] = corr_kin
    post_state["do_corr_kin"] = do_corr_kin
    post_state["isame_kin"] = 0

    # spectra parameters
    omegasmear = np.pi / dt / 100.0
    Tseg = qtb_parameters.get("tseg", 1.0 / au.PS) * au.FS
    nseg = int(Tseg / dt)
    assert nseg > 0, "tseg must be larger than dt"
    Tseg = nseg * dt
    # noise is generated on windows of 3 segments, hence the 3*Tseg resolution
    dom = 2 * np.pi / (3 * Tseg)
    omegacut = qtb_parameters.get("omegacut", 15000.0 / au.CM1) / au.FS
    nom = int(omegacut / dom)
    omega = dom * np.arange((3 * nseg) // 2 + 1)
    cutoff = jnp.asarray(
        1.0 / (1.0 + np.exp((omega - omegacut) / omegasmear)), dtype=fprec
    )
    # omega is in the same (inverse) time unit as omegacut: convert back with
    # au.FS before converting to cm-1 (consistent with write_spectra_to_file).
    assert (
        omegacut < omega[-1]
    ), f"omegacut must be smaller than {omega[-1]*au.FS*au.CM1} CM-1"

    # initialize gammar
    assert (
        gamma < 0.5 * omegacut
    ), "gamma must be much smaller than omegacut (at most 0.5*omegacut)"
    gammar_min = qtb_parameters.get("gammar_min", 0.1)
    post_state["gammar"] = jnp.asarray(np.ones((nspecies, nom)), dtype=fprec)

    # Ornstein-Uhlenbeck correction for colored noise
    a1 = np.exp(-gamma * dt)
    OUcorr = jnp.asarray(
        (1.0 - 2.0 * a1 * np.cos(omega * dt) + a1**2) / (dt**2 * (gamma**2 + omega**2)),
        dtype=fprec,
    )

    # hbar schedule: theta(omega) = kT * u / tanh(u) with u = hbar*|omega|/(2kT)
    classical_kernel = qtb_parameters.get("classical_kernel", False)
    hbar = qtb_parameters.get("hbar", 1.0) * au.FS
    u = 0.5 * hbar * np.abs(omega) / kT
    theta = kT * np.ones_like(omega)
    if hbar > 0:
        # skip omega=0 where u/tanh(u) -> 1 (theta already equals kT there)
        theta[1:] *= u[1:] / np.tanh(u[1:])
    theta = jnp.asarray(theta, dtype=fprec)

    noise_key, post_state["rng_key"] = jax.random.split(rng_key)
    del rng_key
    post_state["white_noise"] = jax.random.normal(
        noise_key, (3 * nseg, nat, 3), dtype=jnp.float32
    )

    startsave = qtb_parameters.get("startsave", 1)
    counter = Counter(nseg, startsave=startsave)
    state["istep"] = 0
    post_state["nadapt"] = 0
    post_state["nsample"] = 0

    write_spectra = qtb_parameters.get("write_spectra", True)
    do_compute_spectra = write_spectra or adaptive

    if do_compute_spectra:
        state["vel"] = jnp.zeros((nseg, nat, 3), dtype=fprec)

        post_state["dFDT"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["mCvv"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cvf"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cff"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["dFDT_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["mCvv_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cvfg_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cff_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)

    if not adaptive:
        update_gammar = lambda x: x
    else:
        # adaptation parameters
        skipseg = qtb_parameters.get("skipseg", 1)

        adaptation_method = (
            str(qtb_parameters.get("adaptation_method", "ADABELIEF")).upper().strip()
        )
        authorized_methods = ["SIMPLE", "RATIO", "ADABELIEF"]
        assert (
            adaptation_method in authorized_methods
        ), f"adaptation_method must be one of {authorized_methods}"
        if adaptation_method == "SIMPLE":
            agamma = qtb_parameters.get("agamma", 1.0e-3) / au.FS
            assert agamma > 0, "agamma must be positive"
            a1_ad = agamma * Tseg  # * gamma
            print(f"# ADQTB SIMPLE: agamma = {agamma*au.FS:.3f}")

            def update_gammar(post_state):
                # plain gradient step on gammar along -dFDT
                g = post_state["dFDT"]
                gammar = post_state["gammar"] - a1_ad * g
                gammar = jnp.maximum(gammar_min, gammar)
                return {**post_state, "gammar": gammar}

        elif adaptation_method == "RATIO":
            tau_ad = qtb_parameters.get("tau_ad", 5.0 / au.PS) * au.FS
            # tau_ad is already converted: only a user-provided tau_s needs
            # the au.FS conversion (the default must not be converted twice).
            tau_s = qtb_parameters.get("tau_s", None)
            tau_s = 10 * tau_ad if tau_s is None else tau_s * au.FS
            assert tau_ad > 0, "tau_ad must be positive"
            assert tau_s > 0, "tau_s must be positive"
            print(
                f"# ADQTB RATIO: tau_ad = {tau_ad*1e-3:.2f} ps, tau_s = {tau_s*1e-3:.2f} ps"
            )
            b1 = np.exp(-Tseg / tau_ad)
            b2 = np.exp(-Tseg / tau_s)
            post_state["mCvv_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["Cvf_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["n_adabelief"] = 0

            def update_gammar(post_state):
                # gammar = <Cvf> / <m Cvv> with bias-corrected moving averages
                n_adabelief = post_state["n_adabelief"] + 1
                mCvv_m = post_state["mCvv_m"] * b1 + post_state["mCvv"] * (1.0 - b1)
                Cvf_m = post_state["Cvf_m"] * b2 + post_state["Cvf"] * (1.0 - b2)
                mCvv = mCvv_m / (1.0 - b1**n_adabelief)
                Cvf = Cvf_m / (1.0 - b2**n_adabelief)
                gammar = Cvf / (mCvv + 1.0e-8)
                gammar = jnp.maximum(gammar_min, gammar)
                return {
                    **post_state,
                    "gammar": gammar,
                    "mCvv_m": mCvv_m,
                    "Cvf_m": Cvf_m,
                    "n_adabelief": n_adabelief,
                }

        elif adaptation_method == "ADABELIEF":
            agamma = qtb_parameters.get("agamma", 0.1)
            tau_ad = qtb_parameters.get("tau_ad", 1.0 / au.PS) * au.FS
            # tau_ad is already converted: only a user-provided tau_s needs
            # the au.FS conversion (the default must not be converted twice).
            tau_s = qtb_parameters.get("tau_s", None)
            tau_s = 100 * tau_ad if tau_s is None else tau_s * au.FS
            assert tau_ad > 0, "tau_ad must be positive"
            assert tau_s > 0, "tau_s must be positive"
            assert agamma > 0, "agamma must be positive"
            print(
                f"# ADQTB ADABELIEF: agamma = {agamma:.3f}, tau_ad = {tau_ad*1.e-3:.2f} ps, tau_s = {tau_s*1.e-3:.2f} ps"
            )

            a1_ad = agamma * gamma  # * Tseg #* gamma
            b1 = np.exp(-Tseg / tau_ad)
            b2 = np.exp(-Tseg / tau_s)
            post_state["dFDT_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["dFDT_s"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["n_adabelief"] = 0

            def update_gammar(post_state):
                # AdaBelief-like step: first moment of dFDT normalized by the
                # moving variance of its deviation from that moment
                n_adabelief = post_state["n_adabelief"] + 1
                dFDT = post_state["dFDT"]
                dFDT_m = post_state["dFDT_m"] * b1 + dFDT * (1.0 - b1)
                dFDT_s = (
                    post_state["dFDT_s"] * b2
                    + (dFDT - dFDT_m) ** 2 * (1.0 - b2)
                    + 1.0e-8
                )
                # bias correction
                mt = dFDT_m / (1.0 - b1**n_adabelief)
                st = dFDT_s / (1.0 - b2**n_adabelief)
                gammar = post_state["gammar"] - a1_ad * mt / (st**0.5 + 1.0e-8)
                gammar = jnp.maximum(gammar_min, gammar)
                return {
                    **post_state,
                    "gammar": gammar,
                    "dFDT_m": dFDT_m,
                    "n_adabelief": n_adabelief,
                    "dFDT_s": dFDT_s,
                }

    def compute_corr_pot(niter=20, verbose=False):
        """Return the potential-energy correction factor for each of the first ``nom`` frequencies."""
        if classical_kernel or hbar == 0:
            return np.ones(nom)

        s_0 = np.array((theta / kT * cutoff)[:nom])
        s_out, s_rec, _ = deconvolute_spectrum(
            s_0,
            omega[:nom],
            gamma,
            niter,
            kernel=kernel_lorentz_pot,
            trans=True,
            symmetrize=True,
            verbose=verbose,
        )
        corr_pot = 1.0 + (s_out - s_0) / s_0
        # same frequency conversion as write_spectra_to_file
        columns = np.column_stack(
            (omega[:nom] * (au.FS * au.CM1), corr_pot - 1.0, s_0, s_out, s_rec)
        )
        np.savetxt(
            "corr_pot.dat", columns, header="omega(cm-1) corr_pot s_0 s_out s_rec"
        )
        return corr_pot

    def compute_corr_kin(post_state, niter=7, verbose=False):
        """Estimate the kinetic energy correction from the averaged velocity spectrum.

        Always returns a ``(corr_kin, post_state)`` pair. The previous value is
        kept when the estimation is disabled, when spectra are not accumulated,
        or when the reconvolution check fails.
        """
        # the estimation needs mCvv_avg, which only exists (and is only
        # updated) when spectra are computed
        if not (do_compute_spectra and post_state["do_corr_kin"]):
            return post_state["corr_kin_prev"], post_state
        if classical_kernel or hbar == 0:
            return 1.0, post_state

        K_D = post_state.get("K_D", None)
        mCvv = (post_state["mCvv_avg"][:, :nom] * n_of_type[:, None]).sum(axis=0) / nat
        s_0 = np.array(mCvv * kT / theta[:nom] / post_state["corr_pot"])
        s_out, s_rec, K_D = deconvolute_spectrum(
            s_0,
            omega[:nom],
            gamma,
            niter,
            kernel=kernel_lorentz,
            trans=False,
            symmetrize=True,
            verbose=verbose,
            K_D=K_D,
        )
        s_out = s_out * theta[:nom] / kT
        s_rec = s_rec * theta[:nom] / kT * post_state["corr_pot"]
        mCvvsum = mCvv.sum()
        rec_ratio = mCvvsum / s_rec.sum()
        if rec_ratio < 0.95 or rec_ratio > 1.05:
            print(
                "# WARNING: reconvolution error is too high, corr_kin was not updated"
            )
            # keep the previous value (callers unpack a 2-tuple)
            return post_state["corr_kin_prev"], post_state

        corr_kin = mCvvsum / s_out.sum()
        if np.abs(corr_kin - post_state["corr_kin_prev"]) < 1.0e-4:
            isame_kin = post_state["isame_kin"] + 1
        else:
            isame_kin = 0

        do_corr_kin = post_state["do_corr_kin"]
        if isame_kin > 10:
            print(
                "# INFO: corr_kin is converged (it did not change for 10 consecutive segments)"
            )
            do_corr_kin = False

        return corr_kin, {
            **post_state,
            "corr_kin_prev": corr_kin,
            "isame_kin": isame_kin,
            "do_corr_kin": do_corr_kin,
            "K_D": K_D,
        }

    @jax.jit
    def ff_kernel(post_state):
        """Random-force power spectrum, shape (n_frequencies, nspecies)."""
        if classical_kernel:
            kernel = cutoff * (2 * gamma * kT / dt)
        else:
            kernel = theta * cutoff * OUcorr * (2 * gamma / dt)
        # gammar (and corr_pot) only act on the first nom frequencies
        gamma_ratio = jnp.concatenate(
            (
                post_state["gammar"].T * post_state["corr_pot"][:, None],
                jnp.ones(
                    (kernel.shape[0] - nom, nspecies), dtype=post_state["gammar"].dtype
                ),
            ),
            axis=0,
        )
        return kernel[:, None] * gamma_ratio * mass_idx[None, :]

    @jax.jit
    def refresh_force(post_state):
        """Shift the 3-segment white-noise window by one segment and color it."""
        rng_key, noise_key = jax.random.split(post_state["rng_key"])
        white_noise = jnp.concatenate(
            (
                post_state["white_noise"][nseg:],
                jax.random.normal(
                    noise_key, (nseg, nat, 3), dtype=post_state["white_noise"].dtype
                ),
            ),
            axis=0,
        )
        amplitude = ff_kernel(post_state) ** 0.5
        s = jnp.fft.rfft(white_noise, 3 * nseg, axis=0) * amplitude[:, type_idx, None]
        # keep only the middle segment of the filtered window
        force = jnp.fft.irfft(s, 3 * nseg, axis=0)[nseg : 2 * nseg]
        return force, {**post_state, "rng_key": rng_key, "white_noise": white_noise}

    @jax.jit
    def compute_spectra(force, vel, post_state):
        """Per-species velocity/force spectra of the last segment and running averages."""
        sf = jnp.fft.rfft(force / gamma, 3 * nseg, axis=0, norm="ortho")
        sv = jnp.fft.rfft(vel, 3 * nseg, axis=0, norm="ortho")
        Cvv = jnp.sum(jnp.abs(sv[:nom]) ** 2, axis=-1).T
        Cff = jnp.sum(jnp.abs(sf[:nom]) ** 2, axis=-1).T
        Cvf = jnp.sum(jnp.real(sv[:nom] * jnp.conj(sf[:nom])), axis=-1).T

        mCvv = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["mCvv"]).at[type_idx].add(Cvv)
            * mass_idx[:, None]
            / n_of_type[:, None]
        )
        Cvf = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["Cvf"]).at[type_idx].add(Cvf)
            / n_of_type[:, None]
        )
        Cff = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["Cff"]).at[type_idx].add(Cff)
            / n_of_type[:, None]
        )
        dFDT = mCvv * post_state["gammar"] - Cvf

        # running averages over the nsample saved segments
        nsinv = 1.0 / post_state["nsample"]
        b1 = 1.0 - nsinv
        dFDT_avg = post_state["dFDT_avg"] * b1 + dFDT * nsinv
        mCvv_avg = post_state["mCvv_avg"] * b1 + mCvv * nsinv
        Cvfg_avg = post_state["Cvfg_avg"] * b1 + Cvf / post_state["gammar"] * nsinv
        Cff_avg = post_state["Cff_avg"] * b1 + Cff * nsinv

        return {
            **post_state,
            "mCvv": mCvv,
            "Cvf": Cvf,
            "Cff": Cff,
            "dFDT": dFDT,
            "dFDT_avg": dFDT_avg,
            "mCvv_avg": mCvv_avg,
            "Cvfg_avg": Cvfg_avg,
            "Cff_avg": Cff_avg,
        }

    def write_spectra_to_file(post_state):
        """Write one ``QTB_spectra_{species}.out`` file per species."""
        mCvv_avg = np.array(post_state["mCvv_avg"])
        Cvfg_avg = np.array(post_state["Cvfg_avg"])
        Cff_avg = np.array(post_state["Cff_avg"]) * 3.0 / dt / (gamma**2)
        dFDT_avg = np.array(post_state["dFDT_avg"])
        gammar = np.array(post_state["gammar"])
        Cff_theo = np.array(ff_kernel(post_state))[:nom].T
        for i, sp in enumerate(species_set):
            ff_scale = au.KELVIN / ((2 * gamma / dt) * mass_idx[i])
            columns = np.column_stack(
                (
                    omega[:nom] * (au.FS * au.CM1),
                    mCvv_avg[i],
                    Cvfg_avg[i],
                    dFDT_avg[i],
                    gammar[i] * gamma * (au.FS * au.THZ),
                    Cff_avg[i] * ff_scale,
                    Cff_theo[i] * ff_scale,
                )
            )
            np.savetxt(
                f"QTB_spectra_{sp}.out",
                columns,
                fmt="%12.6f",
                header="#omega mCvv Cvf dFDT gammar Cff Cff_theo",
            )
        if verbose:
            print("# QTB spectra written.")

    if compute_thermostat_energy:
        state["qtb_energy_flux"] = 0.0

    @jax.jit
    def thermostat(vel, state):
        """One QTB step: friction (factor a1) plus the precomputed random force."""
        istep = state["istep"]
        dvel = dt * state["force"][istep] / mass[:, None]
        new_vel = vel * a1 + dvel
        new_state = {**state, "istep": istep + 1}
        if do_compute_spectra:
            # velocity sampled at the middle of the thermostat step
            vel2 = state["vel"].at[istep].set(vel * a1**0.5 + 0.5 * dvel)
            new_state["vel"] = vel2
        if compute_thermostat_energy:
            dek = 0.5 * (mass[:, None] * (vel**2 - new_vel**2)).sum()
            ekcorr = (
                0.5
                * (mass[:, None] * new_vel**2).sum()
                * (1.0 - 1.0 / state.get("corr_kin", 1.0))
            )
            new_state["qtb_energy_flux"] = state["qtb_energy_flux"] + dek
            new_state["thermostat_energy"] = new_state["qtb_energy_flux"] + ekcorr
        return new_vel, new_state

    @jax.jit
    def postprocess_work(state, post_state):
        """Update spectra, adapt gammar (after skipseg segments) and draw new forces."""
        if do_compute_spectra:
            post_state = compute_spectra(state["force"], state["vel"], post_state)
        if adaptive:
            post_state = jax.lax.cond(
                post_state["nadapt"] > skipseg, update_gammar, lambda x: x, post_state
            )
        new_force, post_state = refresh_force(post_state)
        return {**state, "force": new_force}, post_state

    def postprocess(state, post_state):
        """Called every step; only acts when the segment counter resets."""
        counter.increment()
        if not counter.is_reset_step:
            return state, post_state
        post_state["nadapt"] += 1
        post_state["nsample"] = max(post_state["nadapt"] - startsave + 1, 1)
        if verbose:
            print("# Refreshing QTB forces.")
        state, post_state = postprocess_work(state, post_state)
        state["corr_kin"], post_state = compute_corr_kin(post_state)
        state["istep"] = 0
        if write_spectra:
            write_spectra_to_file(post_state)
        return state, post_state

    post_state["corr_pot"] = jnp.asarray(compute_corr_pot(), dtype=fprec)

    state["force"], post_state = refresh_force(post_state)
    return thermostat, (postprocess, post_state), state