fennol.md.thermostats

  1import numpy as np
  2import flax.linen as nn
  3import jax
  4import jax.numpy as jnp
  5import math
  6import optax
  7
  8from ..utils.atomic_units import AtomicUnits as au  # CM1,THZ,BOHR,MPROT
  9from ..utils import Counter
 10from ..utils.deconvolution import (
 11    deconvolute_spectrum,
 12    kernel_lorentz_pot,
 13    kernel_lorentz,
 14)
 15
 16
 17def get_thermostat(simulation_parameters, dt, system_data, fprec, rng_key=None):
 18    state = {}
 19    postprocess = None
 20    
 21
 22    thermostat_name = str(simulation_parameters.get("thermostat", "LGV")).upper()
 23    compute_thermostat_energy = simulation_parameters.get(
 24        "include_thermostat_energy", False
 25    )
 26
 27    kT = system_data.get("kT", None)
 28    nbeads = system_data.get("nbeads", None)
 29    mass = system_data["mass"]
 30    gamma = simulation_parameters.get("gamma", 1.0 / au.THZ) / au.FS
 31    species = system_data["species"]
 32
 33    if nbeads is not None:
 34        trpmd_lambda = simulation_parameters.get("trpmd_lambda", 1.0)
 35        gamma = np.maximum(trpmd_lambda * system_data["omk"], gamma)
 36
 37    if thermostat_name in ["LGV", "LANGEVIN", "FFLGV"]:
 38        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
 39        assert kT is not None, "kT must be specified for QTB thermostat"
 40        assert gamma is not None, "gamma must be specified for QTB thermostat"
 41        rng_key, v_key = jax.random.split(rng_key)
 42        if nbeads is None:
 43            a1 = math.exp(-gamma * dt)
 44            a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)
 45            vel = (
 46                jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
 47                * (kT / mass[:, None]) ** 0.5
 48            )
 49        else:
 50            if isinstance(gamma, float):
 51                gamma = np.array([gamma] * nbeads)
 52            assert isinstance(
 53                gamma, np.ndarray
 54            ), "gamma must be a float or a numpy array"
 55            assert gamma.shape[0] == nbeads, "gamma must have the same length as nbeads"
 56            a1 = np.exp(-gamma * dt)[:, None, None]
 57            a2 = jnp.asarray(
 58                ((1 - a1 * a1) * kT / mass[None, :, None]) ** 0.5, dtype=fprec
 59            )
 60            vel = (
 61                jax.random.normal(v_key, (nbeads, mass.shape[0], 3), dtype=fprec)
 62                * (kT / mass[:, None]) ** 0.5
 63            )
 64
 65        state["rng_key"] = rng_key
 66        if compute_thermostat_energy:
 67            state["thermostat_energy"] = 0.0
 68        if thermostat_name == "FFLGV":
 69            def thermostat(vel, state):
 70                rng_key, noise_key = jax.random.split(state["rng_key"])
 71                noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
 72                norm_vel = jnp.linalg.norm(vel, axis=-1, keepdims=True)
 73                dirvel = vel / norm_vel
 74                if compute_thermostat_energy:
 75                    v2 = (vel**2).sum(axis=-1)
 76                vel = a1 * vel + a2 * noise
 77                new_norm_vel = jnp.linalg.norm(vel, axis=-1, keepdims=True)
 78                vel = dirvel * new_norm_vel
 79                new_state = {**state, "rng_key": rng_key}
 80                if compute_thermostat_energy:
 81                    v2new = (vel**2).sum(axis=-1)
 82                    new_state["thermostat_energy"] = (
 83                        state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
 84                    )
 85
 86                return vel, new_state
 87
 88        else:
 89            def thermostat(vel, state):
 90                rng_key, noise_key = jax.random.split(state["rng_key"])
 91                noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
 92                if compute_thermostat_energy:
 93                    v2 = (vel**2).sum(axis=-1)
 94                vel = a1 * vel + a2 * noise
 95                new_state = {**state, "rng_key": rng_key}
 96                if compute_thermostat_energy:
 97                    v2new = (vel**2).sum(axis=-1)
 98                    new_state["thermostat_energy"] = (
 99                        state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
100                    )
101                return vel, new_state
102
103    elif thermostat_name in ["BUSSI"]:
104        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
105        assert kT is not None, "kT must be specified for QTB thermostat"
106        assert gamma is not None, "gamma must be specified for QTB thermostat"
107        assert nbeads is None, "Bussi thermostat is not compatible with PIMD"
108        rng_key, v_key = jax.random.split(rng_key)
109
110        a1 = math.exp(-gamma * dt)
111        a2 = (1 - a1) * kT
112        vel = (
113            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
114            * (kT / mass[:, None]) ** 0.5
115        )
116
117        state["rng_key"] = rng_key
118        if compute_thermostat_energy:
119            state["thermostat_energy"] = 0.0
120
121        def thermostat(vel, state):
122            rng_key, noise_key = jax.random.split(state["rng_key"])
123            new_state = {**state, "rng_key": rng_key}
124            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
125            R2 = jnp.sum(noise**2)
126            R1 = noise[0, 0]
127            c = a2 / (mass[:, None] * vel**2).sum()
128            d = (a1 * c) ** 0.5
129            scale = (a1 + c * R2 + 2 * d * R1) ** 0.5
130            if compute_thermostat_energy:
131                dek = 0.5 * (mass[:, None] * vel**2).sum() * (scale**2 - 1)
132                new_state["thermostat_energy"] = state["thermostat_energy"] + dek
133            return scale * vel, new_state
134
135    elif thermostat_name in [
136        "GD",
137        "DESCENT",
138        "GRADIENT_DESCENT",
139        "MIN",
140        "MINIMIZE",
141    ]:
142        assert nbeads is None, "Gradient descent is not compatible with PIMD"
143        a1 = math.exp(-gamma * dt)
144
145        if nbeads is None:
146            vel = jnp.zeros((mass.shape[0], 3), dtype=fprec)
147        else:
148            vel = jnp.zeros((nbeads, mass.shape[0], 3), dtype=fprec)
149
150        def thermostat(vel, state):
151            return a1 * vel, state
152
153    elif thermostat_name in ["NVE", "NONE"]:
154        if nbeads is None:
155            vel = (
156                jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
157                * (kT / mass[:, None]) ** 0.5
158            )
159            kTsys = jnp.sum(mass[:, None] * vel**2) / (mass.shape[0] * 3)
160            vel = vel * (kT / kTsys) ** 0.5
161        else:
162            vel = (
163                jax.random.normal(rng_key, (nbeads, mass.shape[0], 3), dtype=fprec)
164                * (kT / mass[None, :, None]) ** 0.5
165            )
166            kTsys = jnp.sum(mass[None, :, None] * vel**2, axis=(1, 2)) / (
167                mass.shape[0] * 3
168            )
169            vel = vel * (kT / kTsys[:, None, None]) ** 0.5
170        thermostat = lambda x, s: (x, s)
171
172    elif thermostat_name in ["NOSE", "NOSEHOOVER", "NOSE_HOOVER"]:
173        assert gamma is not None, "gamma must be specified for QTB thermostat"
174        ndof = mass.shape[0] * 3
175        nkT = ndof * kT
176        nose_mass = nkT / gamma**2
177        assert nbeads is None, "Nose-Hoover is not compatible with PIMD"
178        state["nose_s"] = 0.0
179        state["nose_v"] = 0.0
180        if compute_thermostat_energy:
181            state["thermostat_energy"] = 0.0
182        print(
183            "# WARNING: Nose-Hoover thermostat is not well tested yet. Energy conservation is not guaranteed."
184        )
185        vel = (
186            jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
187            * (kT / mass[:, None]) ** 0.5
188        )
189
190        def thermostat(vel, state):
191            nose_s = state["nose_s"]
192            nose_v = state["nose_v"]
193            kTsys = jnp.sum(mass[:, None] * vel**2)
194            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
195            nose_s = nose_s + dt * nose_v
196            vel = jnp.exp(-nose_v * dt) * vel
197            kTsys = jnp.sum(mass[:, None] * vel**2)
198            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
199            new_state = {**state, "nose_s": nose_s, "nose_v": nose_v}
200
201            if compute_thermostat_energy:
202                new_state["thermostat_energy"] = (
203                    nkT * nose_s + (0.5 * nose_mass) * nose_v**2
204                )
205            return vel, new_state
206
207    elif thermostat_name in ["QTB", "ADQTB"]:
208        assert nbeads is None, "QTB is not compatible with PIMD"
209        qtb_parameters = simulation_parameters.get("qtb", None)
210        assert (
211            qtb_parameters is not None
212        ), "qtb_parameters must be provided for QTB thermostat"
213        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
214        assert kT is not None, "kT must be specified for QTB thermostat"
215        assert gamma is not None, "gamma must be specified for QTB thermostat"
216        assert species is not None, "species must be provided for QTB thermostat"
217        rng_key, v_key = jax.random.split(rng_key)
218        vel = (
219            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
220            * (kT / mass[:, None]) ** 0.5
221        )
222
223        thermostat, postprocess, qtb_state = initialize_qtb(
224            qtb_parameters,
225            fprec=fprec,
226            dt=dt,
227            mass=mass,
228            gamma=gamma,
229            kT=kT,
230            species=species,
231            rng_key=rng_key,
232            adaptive=thermostat_name.startswith("AD"),
233            compute_thermostat_energy=compute_thermostat_energy,
234        )
235        state = {**state, **qtb_state}
236
237    elif thermostat_name in ["ANNEAL", "ANNEALING"]:
238        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
239        assert kT is not None, "kT must be specified for QTB thermostat"
240        assert gamma is not None, "gamma must be specified for QTB thermostat"
241        assert nbeads is None, "ANNEAL is not compatible with PIMD"
242        a1 = math.exp(-gamma * dt)
243        a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)
244
245        anneal_parameters = simulation_parameters.get("annealing", {})
246        init_factor = anneal_parameters.get("init_factor", 1.0 / 25.0)
247        assert init_factor > 0.0, "init_factor must be positive"
248        final_factor = anneal_parameters.get("final_factor", 1.0 / 10000.0)
249        assert final_factor > 0.0, "final_factor must be positive"
250        nsteps = simulation_parameters.get("nsteps")
251        anneal_steps = anneal_parameters.get("anneal_steps", 1.0)
252        assert (
253            anneal_steps < 1.0 and anneal_steps > 0.0
254        ), "warmup_steps must be between 0 and nsteps"
255        pct_start = anneal_parameters.get("warmup_steps", 0.3)
256        assert (
257            pct_start < 1.0 and pct_start > 0.0
258        ), "warmup_steps must be between 0 and nsteps"
259
260        anneal_type = anneal_parameters.get("type", "cosine").lower()
261        if anneal_type == "linear":
262            schedule = optax.linear_onecycle_schedule(
263                peak_value=1.0,
264                div_factor=1.0 / init_factor,
265                final_div_factor=1.0 / final_factor,
266                transition_steps=int(anneal_steps * nsteps),
267                pct_start=pct_start,
268                pct_final=1.0,
269            )
270        elif anneal_type == "cosine_onecycle":
271            schedule = optax.cosine_onecycle_schedule(
272                peak_value=1.0,
273                div_factor=1.0 / init_factor,
274                final_div_factor=1.0 / final_factor,
275                transition_steps=int(anneal_steps * nsteps),
276                pct_start=pct_start,
277            )
278        else:
279            raise ValueError(f"Unknown anneal_type {anneal_type}")
280
281        state["rng_key"] = rng_key
282        state["istep_anneal"] = 0
283
284        rng_key, v_key = jax.random.split(rng_key)
285        Tscale = schedule(0)
286        print(f"# ANNEAL: initial temperature = {Tscale*kT*au.KELVIN:.3e} K")
287        vel = (
288            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
289            * (kT * Tscale / mass[:, None]) ** 0.5
290        )
291
292        def thermostat(vel, state):
293            rng_key, noise_key = jax.random.split(state["rng_key"])
294            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
295
296            Tscale = schedule(state["istep_anneal"]) ** 0.5
297            vel = a1 * vel + a2 * Tscale * noise
298            return vel, {
299                **state,
300                "rng_key": rng_key,
301                "istep_anneal": state["istep_anneal"] + 1,
302            }
303
304    else:
305        raise ValueError(f"Unknown thermostat {thermostat_name}")
306
307    return thermostat, postprocess, state, vel,thermostat_name
308
309
def initialize_qtb(
    qtb_parameters,
    fprec,
    dt,
    mass,
    gamma,
    kT,
    species,
    rng_key,
    adaptive,
    compute_thermostat_energy=False,
):
    """Set up the (adaptive) Quantum Thermal Bath thermostat.

    Random forces are generated segment by segment (``nseg`` steps each) by
    coloring white noise in Fourier space with the QTB kernel
    (``ff_kernel``). In adaptive mode the per-species, per-frequency friction
    ratio ``gammar`` is updated from the measured velocity/force spectra.

    Parameters
    ----------
    qtb_parameters : dict
        QTB options (``tseg``, ``omegacut``, ``hbar``, ``corr_kin``,
        ``classical_kernel``, ``write_spectra``, ``adaptation_method``, ...).
    fprec : dtype
        Floating-point dtype of the thermostat arrays.
    dt : float
        Time step (internal units, same as ``gamma``).
    mass : array
        Atomic masses, one per atom.
    gamma : float
        Friction coefficient (internal units).
    kT : float
        Target temperature in energy units.
    species : np.ndarray
        Species of each atom; atoms are grouped by species for the spectra.
    rng_key : jax PRNG key
        Key used for the white noise.
    adaptive : bool
        Enable adQTB adaptation of ``gammar``.
    compute_thermostat_energy : bool
        Track the energy exchanged with the bath in ``state``.

    Returns
    -------
    thermostat : callable
        Jitted ``thermostat(vel, state) -> (vel, state)``: friction plus the
        current pre-generated random force.
    (postprocess, post_state) : tuple
        ``postprocess(state, post_state) -> (state, post_state)`` is meant to
        be called every step; when the ``Counter`` signals a reset step
        (presumably every ``nseg`` steps) it updates spectra, adapts
        ``gammar`` (ADQTB), draws a new force segment and optionally writes
        spectra files. ``post_state`` is its initial state.
    state : dict
        Initial thermostat state.
    """
    state = {}
    post_state = {}
    verbose = qtb_parameters.get("verbose", False)
    if compute_thermostat_energy:
        state["thermostat_energy"] = 0.0

    mass = jnp.asarray(mass, dtype=fprec)

    nat = species.shape[0]
    # contiguous type index per species; the same set object is iterated again
    # in write_spectra_to_file, so index <-> species mapping stays consistent
    species_set = set(species)
    nspecies = len(species_set)
    idx = {sp: i for i, sp in enumerate(species_set)}
    type_idx = np.array([idx[sp] for sp in species], dtype=np.int32)

    # number of atoms of each type and average mass per type
    n_of_type = jnp.asarray(np.bincount(type_idx, minlength=nspecies), dtype=fprec)
    mass_idx = jax.ops.segment_sum(mass, type_idx, nspecies) / n_of_type

    # kinetic energy correction factor: a non-positive input requests an
    # on-the-fly estimate (compute_corr_kin), starting from 1
    corr_kin = qtb_parameters.get("corr_kin", -1)
    do_corr_kin = corr_kin <= 0
    if do_corr_kin:
        corr_kin = 1.0
    state["corr_kin"] = corr_kin
    post_state["corr_kin_prev"] = corr_kin
    post_state["do_corr_kin"] = do_corr_kin
    post_state["isame_kin"] = 0

    # spectra parameters. Frequencies are divided by au.FS relative to the
    # input units (see the omegacut default), so omega * au.FS * au.CM1 is cm-1.
    omegasmear = np.pi / dt / 100.0
    Tseg = qtb_parameters.get("tseg", 1.0 / au.PS) * au.FS
    nseg = int(Tseg / dt)
    Tseg = nseg * dt
    # forces are generated on a window of 3 segments (see refresh_force)
    dom = 2 * np.pi / (3 * Tseg)
    omegacut = qtb_parameters.get("omegacut", 15000.0 / au.CM1) / au.FS
    nom = int(omegacut / dom)
    omega = dom * np.arange((3 * nseg) // 2 + 1)
    # smooth Fermi-like cutoff of the noise spectrum around omegacut
    cutoff = jnp.asarray(
        1.0 / (1.0 + np.exp((omega - omegacut) / omegasmear)), dtype=fprec
    )
    assert (
        omegacut < omega[-1]
    ), f"omegacut must be smaller than {omega[-1]*au.FS*au.CM1} CM-1"

    # initialize gammar (friction ratio per species and frequency)
    assert (
        gamma < 0.5 * omegacut
    ), "gamma must be much smaller than omegacut (at most 0.5*omegacut)"
    gammar_min = qtb_parameters.get("gammar_min", 0.1)
    post_state["gammar"] = jnp.asarray(np.ones((nspecies, nom)), dtype=fprec)

    # Ornstein-Uhlenbeck correction for colored noise
    a1 = np.exp(-gamma * dt)
    OUcorr = jnp.asarray(
        (1.0 - 2.0 * a1 * np.cos(omega * dt) + a1**2) / (dt**2 * (gamma**2 + omega**2)),
        dtype=fprec,
    )

    # energy spectrum: theta = kT * u / tanh(u), u = hbar*|omega|/(2kT),
    # with theta = kT at omega = 0 (and everywhere when hbar <= 0)
    classical_kernel = qtb_parameters.get("classical_kernel", False)
    hbar = qtb_parameters.get("hbar", 1.0) * au.FS
    u = 0.5 * hbar * np.abs(omega) / kT
    theta = kT * np.ones_like(omega)
    if hbar > 0:
        theta[1:] *= u[1:] / np.tanh(u[1:])
    theta = jnp.asarray(theta, dtype=fprec)

    noise_key, post_state["rng_key"] = jax.random.split(rng_key)
    del rng_key
    # white noise for the 3-segment window (stored in float32)
    post_state["white_noise"] = jax.random.normal(
        noise_key, (3 * nseg, nat, 3), dtype=jnp.float32
    )

    startsave = qtb_parameters.get("startsave", 1)
    counter = Counter(nseg, startsave=startsave)
    state["istep"] = 0
    post_state["nadapt"] = 0
    post_state["nsample"] = 0

    write_spectra = qtb_parameters.get("write_spectra", True)
    # compute_corr_kin reads post_state["mCvv_avg"], so spectra must also be
    # accumulated whenever corr_kin is estimated on the fly
    estimate_corr_kin = do_corr_kin and not (classical_kernel or hbar == 0)
    do_compute_spectra = write_spectra or adaptive or estimate_corr_kin

    if do_compute_spectra:
        state["vel"] = jnp.zeros((nseg, nat, 3), dtype=fprec)

        post_state["dFDT"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["mCvv"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cvf"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cff"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["dFDT_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["mCvv_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cvfg_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cff_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)

    if not adaptive:
        update_gammar = lambda x: x
    else:
        # adaptation parameters
        skipseg = qtb_parameters.get("skipseg", 1)

        adaptation_method = (
            str(qtb_parameters.get("adaptation_method", "ADABELIEF")).upper().strip()
        )
        authorized_methods = ["SIMPLE", "RATIO", "ADABELIEF"]
        assert (
            adaptation_method in authorized_methods
        ), f"adaptation_method must be one of {authorized_methods}"
        if adaptation_method == "SIMPLE":
            agamma = qtb_parameters.get("agamma", 1.0e-3) / au.FS
            assert agamma > 0, "agamma must be positive"
            a1_ad = agamma * Tseg
            print(f"# ADQTB SIMPLE: agamma = {agamma*au.FS:.3f}")

            def update_gammar(post_state):
                # gradient step on gammar driven by the FDT violation dFDT
                g = post_state["dFDT"]
                gammar = post_state["gammar"] - a1_ad * g
                gammar = jnp.maximum(gammar_min, gammar)
                return {**post_state, "gammar": gammar}

        elif adaptation_method == "RATIO":
            tau_ad = qtb_parameters.get("tau_ad", 5.0 / au.PS) * au.FS
            # the default is derived from tau_ad, which is already converted:
            # only a user-supplied tau_s needs the unit conversion
            tau_s = qtb_parameters.get("tau_s", None)
            tau_s = 10 * tau_ad if tau_s is None else tau_s * au.FS
            assert tau_ad > 0, "tau_ad must be positive"
            print(
                f"# ADQTB RATIO: tau_ad = {tau_ad*1e-3:.2f} ps, tau_s = {tau_s*1e-3:.2f} ps"
            )
            b1 = np.exp(-Tseg / tau_ad)
            b2 = np.exp(-Tseg / tau_s)
            post_state["mCvv_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["Cvf_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["n_adabelief"] = 0

            def update_gammar(post_state):
                # gammar = <Cvf> / <m Cvv> with bias-corrected moving averages
                n_adabelief = post_state["n_adabelief"] + 1
                mCvv_m = post_state["mCvv_m"] * b1 + post_state["mCvv"] * (1.0 - b1)
                Cvf_m = post_state["Cvf_m"] * b2 + post_state["Cvf"] * (1.0 - b2)
                mCvv = mCvv_m / (1.0 - b1**n_adabelief)
                Cvf = Cvf_m / (1.0 - b2**n_adabelief)
                gammar = Cvf / (mCvv + 1.0e-8)
                gammar = jnp.maximum(gammar_min, gammar)
                return {
                    **post_state,
                    "gammar": gammar,
                    "mCvv_m": mCvv_m,
                    "Cvf_m": Cvf_m,
                    "n_adabelief": n_adabelief,
                }

        elif adaptation_method == "ADABELIEF":
            agamma = qtb_parameters.get("agamma", 0.1)
            tau_ad = qtb_parameters.get("tau_ad", 1.0 / au.PS) * au.FS
            # the default is derived from tau_ad, which is already converted:
            # only a user-supplied tau_s needs the unit conversion
            tau_s = qtb_parameters.get("tau_s", None)
            tau_s = 100 * tau_ad if tau_s is None else tau_s * au.FS
            assert tau_ad > 0, "tau_ad must be positive"
            assert tau_s > 0, "tau_s must be positive"
            assert agamma > 0, "agamma must be positive"
            print(
                f"# ADQTB ADABELIEF: agamma = {agamma:.3f}, tau_ad = {tau_ad*1.e-3:.2f} ps, tau_s = {tau_s*1.e-3:.2f} ps"
            )

            a1_ad = agamma * gamma
            b1 = np.exp(-Tseg / tau_ad)
            b2 = np.exp(-Tseg / tau_s)
            post_state["dFDT_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["dFDT_s"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["n_adabelief"] = 0

            def update_gammar(post_state):
                # AdaBelief-style step: first moment of dFDT normalized by the
                # square root of its "belief" (variance around the first moment)
                n_adabelief = post_state["n_adabelief"] + 1
                dFDT = post_state["dFDT"]
                dFDT_m = post_state["dFDT_m"] * b1 + dFDT * (1.0 - b1)
                dFDT_s = (
                    post_state["dFDT_s"] * b2
                    + (dFDT - dFDT_m) ** 2 * (1.0 - b2)
                    + 1.0e-8
                )
                # bias correction
                mt = dFDT_m / (1.0 - b1**n_adabelief)
                st = dFDT_s / (1.0 - b2**n_adabelief)
                gammar = post_state["gammar"] - a1_ad * mt / (st**0.5 + 1.0e-8)
                gammar = jnp.maximum(gammar_min, gammar)
                return {
                    **post_state,
                    "gammar": gammar,
                    "dFDT_m": dFDT_m,
                    "n_adabelief": n_adabelief,
                    "dFDT_s": dFDT_s,
                }

    def compute_corr_pot(niter=20, verbose=False):
        """Return the potential-energy correction of the kernel per frequency
        (ones for a classical kernel) and write it to ``corr_pot.dat``."""
        if classical_kernel or hbar == 0:
            return np.ones(nom)

        s_0 = np.array((theta / kT * cutoff)[:nom])
        s_out, s_rec, _ = deconvolute_spectrum(
            s_0,
            omega[:nom],
            gamma,
            niter,
            kernel=kernel_lorentz_pot,
            trans=True,
            symmetrize=True,
            verbose=verbose,
        )
        corr_pot = 1.0 + (s_out - s_0) / s_0
        columns = np.column_stack(
            (omega[:nom] * (au.FS * au.CM1), corr_pot - 1.0, s_0, s_out, s_rec)
        )
        np.savetxt(
            "corr_pot.dat", columns, header="omega(cm-1) corr_pot s_0 s_out s_rec"
        )
        return corr_pot

    def compute_corr_kin(post_state, niter=7, verbose=False):
        """Return ``(corr_kin, post_state)``.

        While ``do_corr_kin`` is set, corr_kin is re-estimated by deconvolving
        the averaged velocity spectrum; otherwise (or if the reconvolution check
        fails) the previous value is kept.
        """
        if not post_state["do_corr_kin"]:
            return post_state["corr_kin_prev"], post_state
        if classical_kernel or hbar == 0:
            return 1.0, post_state

        K_D = post_state.get("K_D", None)
        mCvv = (post_state["mCvv_avg"][:, :nom] * n_of_type[:, None]).sum(axis=0) / nat
        s_0 = np.array(mCvv * kT / theta[:nom] / post_state["corr_pot"])
        s_out, s_rec, K_D = deconvolute_spectrum(
            s_0,
            omega[:nom],
            gamma,
            niter,
            kernel=kernel_lorentz,
            trans=False,
            symmetrize=True,
            verbose=verbose,
            K_D=K_D,
        )
        s_out = s_out * theta[:nom] / kT
        s_rec = s_rec * theta[:nom] / kT * post_state["corr_pot"]
        mCvvsum = mCvv.sum()
        rec_ratio = mCvvsum / s_rec.sum()
        if rec_ratio < 0.95 or rec_ratio > 1.05:
            print(
                "# WARNING: reconvolution error is too high, corr_kin was not updated"
            )
            # keep the previous estimate; the caller always unpacks a 2-tuple
            return post_state["corr_kin_prev"], post_state

        corr_kin = mCvvsum / s_out.sum()
        if np.abs(corr_kin - post_state["corr_kin_prev"]) < 1.0e-4:
            isame_kin = post_state["isame_kin"] + 1
        else:
            isame_kin = 0

        do_corr_kin = post_state["do_corr_kin"]
        if isame_kin > 10:
            print(
                "# INFO: corr_kin is converged (it did not change for 10 consecutive segments)"
            )
            do_corr_kin = False

        return corr_kin, {
            **post_state,
            "corr_kin_prev": corr_kin,
            "isame_kin": isame_kin,
            "do_corr_kin": do_corr_kin,
            "K_D": K_D,
        }

    @jax.jit
    def ff_kernel(post_state):
        """Target random-force power spectrum, shape (n_frequencies, nspecies)."""
        if classical_kernel:
            kernel = cutoff * (2 * gamma * kT / dt)
        else:
            kernel = theta * cutoff * OUcorr * (2 * gamma / dt)
        # gammar (times corr_pot) applies below nom; ratio 1 above
        gamma_ratio = jnp.concatenate(
            (
                post_state["gammar"].T * post_state["corr_pot"][:, None],
                jnp.ones(
                    (kernel.shape[0] - nom, nspecies), dtype=post_state["gammar"].dtype
                ),
            ),
            axis=0,
        )
        return kernel[:, None] * gamma_ratio * mass_idx[None, :]

    @jax.jit
    def refresh_force(post_state):
        """Shift the white-noise window by one segment, color it with
        sqrt(ff_kernel) in Fourier space and return the middle segment."""
        rng_key, noise_key = jax.random.split(post_state["rng_key"])
        white_noise = jnp.concatenate(
            (
                post_state["white_noise"][nseg:],
                jax.random.normal(
                    noise_key, (nseg, nat, 3), dtype=post_state["white_noise"].dtype
                ),
            ),
            axis=0,
        )
        amplitude = ff_kernel(post_state) ** 0.5
        s = jnp.fft.rfft(white_noise, 3 * nseg, axis=0) * amplitude[:, type_idx, None]
        force = jnp.fft.irfft(s, 3 * nseg, axis=0)[nseg : 2 * nseg]
        return force, {**post_state, "rng_key": rng_key, "white_noise": white_noise}

    @jax.jit
    def compute_spectra(force, vel, post_state):
        """Per-species velocity/force spectra of the last segment, the FDT
        deviation dFDT and their running averages over ``nsample`` segments."""
        sf = jnp.fft.rfft(force / gamma, 3 * nseg, axis=0, norm="ortho")
        sv = jnp.fft.rfft(vel, 3 * nseg, axis=0, norm="ortho")
        Cvv = jnp.sum(jnp.abs(sv[:nom]) ** 2, axis=-1).T
        Cff = jnp.sum(jnp.abs(sf[:nom]) ** 2, axis=-1).T
        Cvf = jnp.sum(jnp.real(sv[:nom] * jnp.conj(sf[:nom])), axis=-1).T

        mCvv = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["mCvv"]).at[type_idx].add(Cvv)
            * mass_idx[:, None]
            / n_of_type[:, None]
        )
        Cvf = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["Cvf"]).at[type_idx].add(Cvf)
            / n_of_type[:, None]
        )
        Cff = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["Cff"]).at[type_idx].add(Cff)
            / n_of_type[:, None]
        )
        dFDT = mCvv * post_state["gammar"] - Cvf

        nsinv = 1.0 / post_state["nsample"]
        b1 = 1.0 - nsinv
        dFDT_avg = post_state["dFDT_avg"] * b1 + dFDT * nsinv
        mCvv_avg = post_state["mCvv_avg"] * b1 + mCvv * nsinv
        Cvfg_avg = post_state["Cvfg_avg"] * b1 + Cvf / post_state["gammar"] * nsinv
        Cff_avg = post_state["Cff_avg"] * b1 + Cff * nsinv

        return {
            **post_state,
            "mCvv": mCvv,
            "Cvf": Cvf,
            "Cff": Cff,
            "dFDT": dFDT,
            "dFDT_avg": dFDT_avg,
            "mCvv_avg": mCvv_avg,
            "Cvfg_avg": Cvfg_avg,
            "Cff_avg": Cff_avg,
        }

    def write_spectra_to_file(post_state):
        """Write one ``QTB_spectra_<species>.out`` file per species."""
        mCvv_avg = np.array(post_state["mCvv_avg"])
        Cvfg_avg = np.array(post_state["Cvfg_avg"])
        Cff_avg = np.array(post_state["Cff_avg"]) * 3.0 / dt / (gamma**2)
        dFDT_avg = np.array(post_state["dFDT_avg"])
        gammar = np.array(post_state["gammar"])
        Cff_theo = np.array(ff_kernel(post_state))[:nom].T
        for i, sp in enumerate(species_set):
            ff_scale = au.KELVIN / ((2 * gamma / dt) * mass_idx[i])
            columns = np.column_stack(
                (
                    omega[:nom] * (au.FS * au.CM1),
                    mCvv_avg[i],
                    Cvfg_avg[i],
                    dFDT_avg[i],
                    gammar[i] * gamma * (au.FS * au.THZ),
                    Cff_avg[i] * ff_scale,
                    Cff_theo[i] * ff_scale,
                )
            )
            np.savetxt(
                f"QTB_spectra_{sp}.out",
                columns,
                fmt="%12.6f",
                header="#omega mCvv Cvf dFDT gammar Cff Cff_theo",
            )
        if verbose:
            print("# QTB spectra written.")

    if compute_thermostat_energy:
        state["qtb_energy_flux"] = 0.0

    @jax.jit
    def thermostat(vel, state):
        """Friction plus the random force of the current step of the segment."""
        istep = state["istep"]
        dvel = dt * state["force"][istep] / mass[:, None]
        new_vel = vel * a1 + dvel
        new_state = {**state, "istep": istep + 1}
        if do_compute_spectra:
            # velocity sample stored for the spectra of this segment
            vel2 = state["vel"].at[istep].set(vel * a1**0.5 + 0.5 * dvel)
            new_state["vel"] = vel2
        if compute_thermostat_energy:
            dek = 0.5 * (mass[:, None] * (vel**2 - new_vel**2)).sum()
            ekcorr = (
                0.5
                * (mass[:, None] * new_vel**2).sum()
                * (1.0 - 1.0 / state.get("corr_kin", 1.0))
            )
            new_state["qtb_energy_flux"] = state["qtb_energy_flux"] + dek
            new_state["thermostat_energy"] = new_state["qtb_energy_flux"] + ekcorr
        return new_vel, new_state

    @jax.jit
    def postprocess_work(state, post_state):
        """End-of-segment work: spectra, gammar adaptation, new random forces."""
        if do_compute_spectra:
            post_state = compute_spectra(state["force"], state["vel"], post_state)
        if adaptive:
            post_state = jax.lax.cond(
                post_state["nadapt"] > skipseg, update_gammar, lambda x: x, post_state
            )
        new_force, post_state = refresh_force(post_state)
        return {**state, "force": new_force}, post_state

    def postprocess(state, post_state):
        counter.increment()
        if not counter.is_reset_step:
            return state, post_state
        post_state["nadapt"] += 1
        post_state["nsample"] = max(post_state["nadapt"] - startsave + 1, 1)
        if verbose:
            print("# Refreshing QTB forces.")
        state, post_state = postprocess_work(state, post_state)
        state["corr_kin"], post_state = compute_corr_kin(post_state)
        state["istep"] = 0
        if write_spectra:
            write_spectra_to_file(post_state)
        return state, post_state

    post_state["corr_pot"] = jnp.asarray(compute_corr_pot(), dtype=fprec)

    state["force"], post_state = refresh_force(post_state)
    return thermostat, (postprocess, post_state), state
def get_thermostat(simulation_parameters, dt, system_data, fprec, rng_key=None):
 18def get_thermostat(simulation_parameters, dt, system_data, fprec, rng_key=None):
 19    state = {}
 20    postprocess = None
 21    
 22
 23    thermostat_name = str(simulation_parameters.get("thermostat", "LGV")).upper()
 24    compute_thermostat_energy = simulation_parameters.get(
 25        "include_thermostat_energy", False
 26    )
 27
 28    kT = system_data.get("kT", None)
 29    nbeads = system_data.get("nbeads", None)
 30    mass = system_data["mass"]
 31    gamma = simulation_parameters.get("gamma", 1.0 / au.THZ) / au.FS
 32    species = system_data["species"]
 33
 34    if nbeads is not None:
 35        trpmd_lambda = simulation_parameters.get("trpmd_lambda", 1.0)
 36        gamma = np.maximum(trpmd_lambda * system_data["omk"], gamma)
 37
 38    if thermostat_name in ["LGV", "LANGEVIN", "FFLGV"]:
 39        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
 40        assert kT is not None, "kT must be specified for QTB thermostat"
 41        assert gamma is not None, "gamma must be specified for QTB thermostat"
 42        rng_key, v_key = jax.random.split(rng_key)
 43        if nbeads is None:
 44            a1 = math.exp(-gamma * dt)
 45            a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)
 46            vel = (
 47                jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
 48                * (kT / mass[:, None]) ** 0.5
 49            )
 50        else:
 51            if isinstance(gamma, float):
 52                gamma = np.array([gamma] * nbeads)
 53            assert isinstance(
 54                gamma, np.ndarray
 55            ), "gamma must be a float or a numpy array"
 56            assert gamma.shape[0] == nbeads, "gamma must have the same length as nbeads"
 57            a1 = np.exp(-gamma * dt)[:, None, None]
 58            a2 = jnp.asarray(
 59                ((1 - a1 * a1) * kT / mass[None, :, None]) ** 0.5, dtype=fprec
 60            )
 61            vel = (
 62                jax.random.normal(v_key, (nbeads, mass.shape[0], 3), dtype=fprec)
 63                * (kT / mass[:, None]) ** 0.5
 64            )
 65
 66        state["rng_key"] = rng_key
 67        if compute_thermostat_energy:
 68            state["thermostat_energy"] = 0.0
 69        if thermostat_name == "FFLGV":
 70            def thermostat(vel, state):
 71                rng_key, noise_key = jax.random.split(state["rng_key"])
 72                noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
 73                norm_vel = jnp.linalg.norm(vel, axis=-1, keepdims=True)
 74                dirvel = vel / norm_vel
 75                if compute_thermostat_energy:
 76                    v2 = (vel**2).sum(axis=-1)
 77                vel = a1 * vel + a2 * noise
 78                new_norm_vel = jnp.linalg.norm(vel, axis=-1, keepdims=True)
 79                vel = dirvel * new_norm_vel
 80                new_state = {**state, "rng_key": rng_key}
 81                if compute_thermostat_energy:
 82                    v2new = (vel**2).sum(axis=-1)
 83                    new_state["thermostat_energy"] = (
 84                        state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
 85                    )
 86
 87                return vel, new_state
 88
 89        else:
 90            def thermostat(vel, state):
 91                rng_key, noise_key = jax.random.split(state["rng_key"])
 92                noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
 93                if compute_thermostat_energy:
 94                    v2 = (vel**2).sum(axis=-1)
 95                vel = a1 * vel + a2 * noise
 96                new_state = {**state, "rng_key": rng_key}
 97                if compute_thermostat_energy:
 98                    v2new = (vel**2).sum(axis=-1)
 99                    new_state["thermostat_energy"] = (
100                        state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
101                    )
102                return vel, new_state
103
104    elif thermostat_name in ["BUSSI"]:
105        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
106        assert kT is not None, "kT must be specified for QTB thermostat"
107        assert gamma is not None, "gamma must be specified for QTB thermostat"
108        assert nbeads is None, "Bussi thermostat is not compatible with PIMD"
109        rng_key, v_key = jax.random.split(rng_key)
110
111        a1 = math.exp(-gamma * dt)
112        a2 = (1 - a1) * kT
113        vel = (
114            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
115            * (kT / mass[:, None]) ** 0.5
116        )
117
118        state["rng_key"] = rng_key
119        if compute_thermostat_energy:
120            state["thermostat_energy"] = 0.0
121
122        def thermostat(vel, state):
123            rng_key, noise_key = jax.random.split(state["rng_key"])
124            new_state = {**state, "rng_key": rng_key}
125            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
126            R2 = jnp.sum(noise**2)
127            R1 = noise[0, 0]
128            c = a2 / (mass[:, None] * vel**2).sum()
129            d = (a1 * c) ** 0.5
130            scale = (a1 + c * R2 + 2 * d * R1) ** 0.5
131            if compute_thermostat_energy:
132                dek = 0.5 * (mass[:, None] * vel**2).sum() * (scale**2 - 1)
133                new_state["thermostat_energy"] = state["thermostat_energy"] + dek
134            return scale * vel, new_state
135
136    elif thermostat_name in [
137        "GD",
138        "DESCENT",
139        "GRADIENT_DESCENT",
140        "MIN",
141        "MINIMIZE",
142    ]:
143        assert nbeads is None, "Gradient descent is not compatible with PIMD"
144        a1 = math.exp(-gamma * dt)
145
146        if nbeads is None:
147            vel = jnp.zeros((mass.shape[0], 3), dtype=fprec)
148        else:
149            vel = jnp.zeros((nbeads, mass.shape[0], 3), dtype=fprec)
150
151        def thermostat(vel, state):
152            return a1 * vel, state
153
154    elif thermostat_name in ["NVE", "NONE"]:
155        if nbeads is None:
156            vel = (
157                jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
158                * (kT / mass[:, None]) ** 0.5
159            )
160            kTsys = jnp.sum(mass[:, None] * vel**2) / (mass.shape[0] * 3)
161            vel = vel * (kT / kTsys) ** 0.5
162        else:
163            vel = (
164                jax.random.normal(rng_key, (nbeads, mass.shape[0], 3), dtype=fprec)
165                * (kT / mass[None, :, None]) ** 0.5
166            )
167            kTsys = jnp.sum(mass[None, :, None] * vel**2, axis=(1, 2)) / (
168                mass.shape[0] * 3
169            )
170            vel = vel * (kT / kTsys[:, None, None]) ** 0.5
171        thermostat = lambda x, s: (x, s)
172
173    elif thermostat_name in ["NOSE", "NOSEHOOVER", "NOSE_HOOVER"]:
174        assert gamma is not None, "gamma must be specified for QTB thermostat"
175        ndof = mass.shape[0] * 3
176        nkT = ndof * kT
177        nose_mass = nkT / gamma**2
178        assert nbeads is None, "Nose-Hoover is not compatible with PIMD"
179        state["nose_s"] = 0.0
180        state["nose_v"] = 0.0
181        if compute_thermostat_energy:
182            state["thermostat_energy"] = 0.0
183        print(
184            "# WARNING: Nose-Hoover thermostat is not well tested yet. Energy conservation is not guaranteed."
185        )
186        vel = (
187            jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
188            * (kT / mass[:, None]) ** 0.5
189        )
190
191        def thermostat(vel, state):
192            nose_s = state["nose_s"]
193            nose_v = state["nose_v"]
194            kTsys = jnp.sum(mass[:, None] * vel**2)
195            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
196            nose_s = nose_s + dt * nose_v
197            vel = jnp.exp(-nose_v * dt) * vel
198            kTsys = jnp.sum(mass[:, None] * vel**2)
199            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
200            new_state = {**state, "nose_s": nose_s, "nose_v": nose_v}
201
202            if compute_thermostat_energy:
203                new_state["thermostat_energy"] = (
204                    nkT * nose_s + (0.5 * nose_mass) * nose_v**2
205                )
206            return vel, new_state
207
208    elif thermostat_name in ["QTB", "ADQTB"]:
209        assert nbeads is None, "QTB is not compatible with PIMD"
210        qtb_parameters = simulation_parameters.get("qtb", None)
211        assert (
212            qtb_parameters is not None
213        ), "qtb_parameters must be provided for QTB thermostat"
214        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
215        assert kT is not None, "kT must be specified for QTB thermostat"
216        assert gamma is not None, "gamma must be specified for QTB thermostat"
217        assert species is not None, "species must be provided for QTB thermostat"
218        rng_key, v_key = jax.random.split(rng_key)
219        vel = (
220            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
221            * (kT / mass[:, None]) ** 0.5
222        )
223
224        thermostat, postprocess, qtb_state = initialize_qtb(
225            qtb_parameters,
226            fprec=fprec,
227            dt=dt,
228            mass=mass,
229            gamma=gamma,
230            kT=kT,
231            species=species,
232            rng_key=rng_key,
233            adaptive=thermostat_name.startswith("AD"),
234            compute_thermostat_energy=compute_thermostat_energy,
235        )
236        state = {**state, **qtb_state}
237
238    elif thermostat_name in ["ANNEAL", "ANNEALING"]:
239        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
240        assert kT is not None, "kT must be specified for QTB thermostat"
241        assert gamma is not None, "gamma must be specified for QTB thermostat"
242        assert nbeads is None, "ANNEAL is not compatible with PIMD"
243        a1 = math.exp(-gamma * dt)
244        a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)
245
246        anneal_parameters = simulation_parameters.get("annealing", {})
247        init_factor = anneal_parameters.get("init_factor", 1.0 / 25.0)
248        assert init_factor > 0.0, "init_factor must be positive"
249        final_factor = anneal_parameters.get("final_factor", 1.0 / 10000.0)
250        assert final_factor > 0.0, "final_factor must be positive"
251        nsteps = simulation_parameters.get("nsteps")
252        anneal_steps = anneal_parameters.get("anneal_steps", 1.0)
253        assert (
254            anneal_steps < 1.0 and anneal_steps > 0.0
255        ), "warmup_steps must be between 0 and nsteps"
256        pct_start = anneal_parameters.get("warmup_steps", 0.3)
257        assert (
258            pct_start < 1.0 and pct_start > 0.0
259        ), "warmup_steps must be between 0 and nsteps"
260
261        anneal_type = anneal_parameters.get("type", "cosine").lower()
262        if anneal_type == "linear":
263            schedule = optax.linear_onecycle_schedule(
264                peak_value=1.0,
265                div_factor=1.0 / init_factor,
266                final_div_factor=1.0 / final_factor,
267                transition_steps=int(anneal_steps * nsteps),
268                pct_start=pct_start,
269                pct_final=1.0,
270            )
271        elif anneal_type == "cosine_onecycle":
272            schedule = optax.cosine_onecycle_schedule(
273                peak_value=1.0,
274                div_factor=1.0 / init_factor,
275                final_div_factor=1.0 / final_factor,
276                transition_steps=int(anneal_steps * nsteps),
277                pct_start=pct_start,
278            )
279        else:
280            raise ValueError(f"Unknown anneal_type {anneal_type}")
281
282        state["rng_key"] = rng_key
283        state["istep_anneal"] = 0
284
285        rng_key, v_key = jax.random.split(rng_key)
286        Tscale = schedule(0)
287        print(f"# ANNEAL: initial temperature = {Tscale*kT*au.KELVIN:.3e} K")
288        vel = (
289            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
290            * (kT * Tscale / mass[:, None]) ** 0.5
291        )
292
293        def thermostat(vel, state):
294            rng_key, noise_key = jax.random.split(state["rng_key"])
295            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
296
297            Tscale = schedule(state["istep_anneal"]) ** 0.5
298            vel = a1 * vel + a2 * Tscale * noise
299            return vel, {
300                **state,
301                "rng_key": rng_key,
302                "istep_anneal": state["istep_anneal"] + 1,
303            }
304
305    else:
306        raise ValueError(f"Unknown thermostat {thermostat_name}")
307
308    return thermostat, postprocess, state, vel,thermostat_name
def initialize_qtb( qtb_parameters, fprec, dt, mass, gamma, kT, species, rng_key, adaptive, compute_thermostat_energy=False):
def initialize_qtb(
    qtb_parameters,
    fprec,
    dt,
    mass,
    gamma,
    kT,
    species,
    rng_key,
    adaptive,
    compute_thermostat_energy=False,
):
    """Set up a Quantum Thermal Bath (QTB) thermostat, optionally adaptive.

    Random forces are generated segment by segment (``nseg`` steps each) by
    filtering white noise in Fourier space with the kernel returned by
    ``ff_kernel``. Unless ``classical_kernel`` is set, this kernel contains
    ``theta(omega) = (hbar*omega/2) * coth(hbar*omega/(2*kT))``. When
    ``adaptive`` is true, the per-species, frequency-dependent ratios
    ``gammar`` are updated after each segment from the measured
    velocity/force spectra (adQTB).

    Args:
        qtb_parameters: dict of QTB options (``tseg``, ``omegacut``, ``hbar``,
            ``classical_kernel``, ``corr_kin``, ``startsave``,
            ``write_spectra``, ``verbose``, ``gammar_min``, ``skipseg``,
            ``adaptation_method``, ``agamma``, ``tau_ad``, ``tau_s``).
        fprec: floating-point dtype of the generated arrays.
        dt: time step, in the same time unit as ``1/gamma`` (fs, judging from
            the ``au.FS`` conversions -- confirm).
        mass: per-atom masses, shape ``(nat,)``.
        gamma: Langevin friction coefficient (scalar).
        kT: target temperature, in energy units.
        species: per-atom species, shape ``(nat,)``.
        rng_key: JAX PRNG key for the random forces.
        adaptive: if true, adapt ``gammar`` after each segment.
        compute_thermostat_energy: accumulate the energy exchanged with the
            bath in ``state["thermostat_energy"]``.

    Returns:
        ``(thermostat, (postprocess, post_state), state)`` where:

        - ``thermostat(vel, state)`` is a jitted step that applies the
          friction and the precomputed random force of the current step.
        - ``postprocess(state, post_state)`` is meant to be called once per
          step. It increments a ``Counter`` and, at the end of each segment,
          updates the spectra, ``gammar`` and ``corr_kin`` and refreshes the
          random forces.
        - ``state`` is the initial per-step state.
    """
    state = {}
    post_state = {}
    verbose = qtb_parameters.get("verbose", False)
    if compute_thermostat_energy:
        state["thermostat_energy"] = 0.0

    mass = jnp.asarray(mass, dtype=fprec)

    nat = species.shape[0]
    # define type indices (the same species_set order is used for the output files)
    species_set = set(species)
    nspecies = len(species_set)
    idx = {sp: i for i, sp in enumerate(species_set)}
    type_idx = np.array([idx[sp] for sp in species], dtype=np.int32)

    # number of atoms and mean mass of each type
    n_of_type = jnp.asarray(np.bincount(type_idx, minlength=nspecies), dtype=fprec)
    mass_idx = jax.ops.segment_sum(mass, type_idx, nspecies) / n_of_type

    # a non-positive corr_kin means "estimate it on the fly from the spectra"
    corr_kin = qtb_parameters.get("corr_kin", -1)
    do_corr_kin = corr_kin <= 0
    if do_corr_kin:
        corr_kin = 1.0
    state["corr_kin"] = corr_kin
    post_state["corr_kin_prev"] = corr_kin
    post_state["do_corr_kin"] = do_corr_kin
    post_state["isame_kin"] = 0

    # spectra parameters
    omegasmear = np.pi / dt / 100.0
    Tseg = qtb_parameters.get("tseg", 1.0 / au.PS) * au.FS
    nseg = int(Tseg / dt)
    Tseg = nseg * dt
    # noise is generated on windows of 3 segments, hence the 3*Tseg resolution
    dom = 2 * np.pi / (3 * Tseg)
    omegacut = qtb_parameters.get("omegacut", 15000.0 / au.CM1) / au.FS
    nom = int(omegacut / dom)
    omega = dom * np.arange((3 * nseg) // 2 + 1)
    # converts omega (inverse of the internal time unit) back to cm-1
    omega_to_cm1 = au.FS * au.CM1
    cutoff = jnp.asarray(
        1.0 / (1.0 + np.exp((omega - omegacut) / omegasmear)), dtype=fprec
    )
    assert (
        omegacut < omega[-1]
    ), f"omegacut must be smaller than {omega[-1]*omega_to_cm1} CM-1"

    # initialize gammar
    assert (
        gamma < 0.5 * omegacut
    ), "gamma must be much smaller than omegacut (at most 0.5*omegacut)"
    gammar_min = qtb_parameters.get("gammar_min", 0.1)
    post_state["gammar"] = jnp.asarray(np.ones((nspecies, nom)), dtype=fprec)

    # Ornstein-Uhlenbeck correction for colored noise
    a1 = np.exp(-gamma * dt)
    OUcorr = jnp.asarray(
        (1.0 - 2.0 * a1 * np.cos(omega * dt) + a1**2) / (dt**2 * (gamma**2 + omega**2)),
        dtype=fprec,
    )

    # hbar schedule
    classical_kernel = qtb_parameters.get("classical_kernel", False)
    hbar = qtb_parameters.get("hbar", 1.0) * au.FS
    u = 0.5 * hbar * np.abs(omega) / kT
    theta = kT * np.ones_like(omega)
    if hbar > 0:
        theta[1:] *= u[1:] / np.tanh(u[1:])
    theta = jnp.asarray(theta, dtype=fprec)

    noise_key, post_state["rng_key"] = jax.random.split(rng_key)
    del rng_key
    post_state["white_noise"] = jax.random.normal(
        noise_key, (3 * nseg, nat, 3), dtype=jnp.float32
    )

    startsave = qtb_parameters.get("startsave", 1)
    counter = Counter(nseg, startsave=startsave)
    state["istep"] = 0
    post_state["nadapt"] = 0
    post_state["nsample"] = 0

    write_spectra = qtb_parameters.get("write_spectra", True)
    # the on-the-fly corr_kin estimate reads mCvv_avg, so the spectra must be
    # accumulated whenever it can actually run (otherwise: KeyError)
    corr_kin_needs_spectra = do_corr_kin and not (classical_kernel or hbar == 0)
    do_compute_spectra = write_spectra or adaptive or corr_kin_needs_spectra

    if do_compute_spectra:
        state["vel"] = jnp.zeros((nseg, nat, 3), dtype=fprec)

        post_state["dFDT"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["mCvv"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cvf"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cff"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["dFDT_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["mCvv_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cvfg_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cff_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)

    if not adaptive:
        update_gammar = lambda x: x
    else:
        # adaptation parameters
        skipseg = qtb_parameters.get("skipseg", 1)

        adaptation_method = (
            str(qtb_parameters.get("adaptation_method", "ADABELIEF")).upper().strip()
        )
        authorized_methods = ["SIMPLE", "RATIO", "ADABELIEF"]
        assert (
            adaptation_method in authorized_methods
        ), f"adaptation_method must be one of {authorized_methods}"
        if adaptation_method == "SIMPLE":
            agamma = qtb_parameters.get("agamma", 1.0e-3) / au.FS
            assert agamma > 0, "agamma must be positive"
            a1_ad = agamma * Tseg  #  * gamma
            print(f"# ADQTB SIMPLE: agamma = {agamma*au.FS:.3f}")

            def update_gammar(post_state):
                # plain gradient step on the FDT deviation
                g = post_state["dFDT"]
                gammar = post_state["gammar"] - a1_ad * g
                gammar = jnp.maximum(gammar_min, gammar)
                return {**post_state, "gammar": gammar}

        elif adaptation_method == "RATIO":
            tau_ad = qtb_parameters.get("tau_ad", 5.0 / au.PS) * au.FS
            # a user value is converted like tau_ad; the default is derived
            # from tau_ad, which is already converted
            tau_s = qtb_parameters.get("tau_s", None)
            tau_s = 10 * tau_ad if tau_s is None else tau_s * au.FS
            assert tau_ad > 0, "tau_ad must be positive"
            print(
                f"# ADQTB RATIO: tau_ad = {tau_ad*1e-3:.2f} ps, tau_s = {tau_s*1e-3:.2f} ps"
            )
            b1 = np.exp(-Tseg / tau_ad)
            b2 = np.exp(-Tseg / tau_s)
            post_state["mCvv_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["Cvf_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["n_adabelief"] = 0

            def update_gammar(post_state):
                # gammar = <Cvf> / <mCvv> with bias-corrected moving averages
                n_adabelief = post_state["n_adabelief"] + 1
                mCvv_m = post_state["mCvv_m"] * b1 + post_state["mCvv"] * (1.0 - b1)
                Cvf_m = post_state["Cvf_m"] * b2 + post_state["Cvf"] * (1.0 - b2)
                mCvv = mCvv_m / (1.0 - b1**n_adabelief)
                Cvf = Cvf_m / (1.0 - b2**n_adabelief)
                gammar = Cvf / (mCvv + 1.0e-8)
                gammar = jnp.maximum(gammar_min, gammar)
                return {
                    **post_state,
                    "gammar": gammar,
                    "mCvv_m": mCvv_m,
                    "Cvf_m": Cvf_m,
                    "n_adabelief": n_adabelief,
                }

        elif adaptation_method == "ADABELIEF":
            agamma = qtb_parameters.get("agamma", 0.1)
            tau_ad = qtb_parameters.get("tau_ad", 1.0 / au.PS) * au.FS
            # a user value is converted like tau_ad; the default is derived
            # from tau_ad, which is already converted
            tau_s = qtb_parameters.get("tau_s", None)
            tau_s = 100 * tau_ad if tau_s is None else tau_s * au.FS
            assert tau_ad > 0, "tau_ad must be positive"
            assert tau_s > 0, "tau_s must be positive"
            assert agamma > 0, "agamma must be positive"
            print(
                f"# ADQTB ADABELIEF: agamma = {agamma:.3f}, tau_ad = {tau_ad*1.e-3:.2f} ps, tau_s = {tau_s*1.e-3:.2f} ps"
            )

            a1_ad = agamma * gamma  # * Tseg #* gamma
            b1 = np.exp(-Tseg / tau_ad)
            b2 = np.exp(-Tseg / tau_s)
            post_state["dFDT_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["dFDT_s"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["n_adabelief"] = 0

            def update_gammar(post_state):
                # AdaBelief-style step: first moment of dFDT divided by the
                # square root of its "belief" (deviation from the first moment)
                n_adabelief = post_state["n_adabelief"] + 1
                dFDT = post_state["dFDT"]
                dFDT_m = post_state["dFDT_m"] * b1 + dFDT * (1.0 - b1)
                dFDT_s = (
                    post_state["dFDT_s"] * b2
                    + (dFDT - dFDT_m) ** 2 * (1.0 - b2)
                    + 1.0e-8
                )
                # bias correction
                mt = dFDT_m / (1.0 - b1**n_adabelief)
                st = dFDT_s / (1.0 - b2**n_adabelief)
                gammar = post_state["gammar"] - a1_ad * mt / (st**0.5 + 1.0e-8)
                gammar = jnp.maximum(gammar_min, gammar)
                return {
                    **post_state,
                    "gammar": gammar,
                    "dFDT_m": dFDT_m,
                    "n_adabelief": n_adabelief,
                    "dFDT_s": dFDT_s,
                }

    def compute_corr_pot(niter=20, verbose=False):
        """Return the potential-energy kernel correction (ones if classical).

        Also writes ``corr_pot.dat`` when a correction is computed.
        """
        if classical_kernel or hbar == 0:
            return np.ones(nom)

        s_0 = np.array((theta / kT * cutoff)[:nom])
        s_out, s_rec, _ = deconvolute_spectrum(
            s_0,
            omega[:nom],
            gamma,
            niter,
            kernel=kernel_lorentz_pot,
            trans=True,
            symmetrize=True,
            verbose=verbose,
        )
        corr_pot = 1.0 + (s_out - s_0) / s_0
        columns = np.column_stack(
            (omega[:nom] * omega_to_cm1, corr_pot - 1.0, s_0, s_out, s_rec)
        )
        np.savetxt(
            "corr_pot.dat", columns, header="omega(cm-1) corr_pot s_0 s_out s_rec"
        )
        return corr_pot

    def compute_corr_kin(post_state, niter=7, verbose=False):
        """Return ``(corr_kin, post_state)``, re-estimating corr_kin if enabled.

        The previous value is kept when the estimate is disabled/converged or
        when the reconvolution check fails.
        """
        if not post_state["do_corr_kin"]:
            return post_state["corr_kin_prev"], post_state
        if classical_kernel or hbar == 0:
            return 1.0, post_state

        K_D = post_state.get("K_D", None)
        mCvv = (post_state["mCvv_avg"][:, :nom] * n_of_type[:, None]).sum(axis=0) / nat
        s_0 = np.array(mCvv * kT / theta[:nom] / post_state["corr_pot"])
        s_out, s_rec, K_D = deconvolute_spectrum(
            s_0,
            omega[:nom],
            gamma,
            niter,
            kernel=kernel_lorentz,
            trans=False,
            symmetrize=True,
            verbose=verbose,
            K_D=K_D,
        )
        s_out = s_out * theta[:nom] / kT
        s_rec = s_rec * theta[:nom] / kT * post_state["corr_pot"]
        mCvvsum = mCvv.sum()
        rec_ratio = mCvvsum / s_rec.sum()
        if rec_ratio < 0.95 or rec_ratio > 1.05:
            print(
                "# WARNING: reconvolution error is too high, corr_kin was not updated"
            )
            # keep the previous value; the caller unpacks a (corr_kin, post_state) pair
            return post_state["corr_kin_prev"], post_state

        corr_kin = mCvvsum / s_out.sum()
        if np.abs(corr_kin - post_state["corr_kin_prev"]) < 1.0e-4:
            isame_kin = post_state["isame_kin"] + 1
        else:
            isame_kin = 0

        do_corr_kin = post_state["do_corr_kin"]
        if isame_kin > 10:
            print(
                "# INFO: corr_kin is converged (it did not change for 10 consecutive segments)"
            )
            do_corr_kin = False

        return corr_kin, {
            **post_state,
            "corr_kin_prev": corr_kin,
            "isame_kin": isame_kin,
            "do_corr_kin": do_corr_kin,
            "K_D": K_D,
        }

    @jax.jit
    def ff_kernel(post_state):
        """Power spectrum of the random force per frequency and species."""
        if classical_kernel:
            kernel = cutoff * (2 * gamma * kT / dt)
        else:
            kernel = theta * cutoff * OUcorr * (2 * gamma / dt)
        # gammar only covers the first nom frequencies; use 1 above omegacut
        gamma_ratio = jnp.concatenate(
            (
                post_state["gammar"].T * post_state["corr_pot"][:, None],
                jnp.ones(
                    (kernel.shape[0] - nom, nspecies), dtype=post_state["gammar"].dtype
                ),
            ),
            axis=0,
        )
        return kernel[:, None] * gamma_ratio * mass_idx[None, :]

    @jax.jit
    def refresh_force(post_state):
        """Generate the random forces of the next segment.

        The white noise is a sliding 3-segment window: the oldest segment is
        dropped, a new one is drawn, and the filtered middle segment is kept.
        """
        rng_key, noise_key = jax.random.split(post_state["rng_key"])
        white_noise = jnp.concatenate(
            (
                post_state["white_noise"][nseg:],
                jax.random.normal(
                    noise_key, (nseg, nat, 3), dtype=post_state["white_noise"].dtype
                ),
            ),
            axis=0,
        )
        amplitude = ff_kernel(post_state) ** 0.5
        s = jnp.fft.rfft(white_noise, 3 * nseg, axis=0) * amplitude[:, type_idx, None]
        force = jnp.fft.irfft(s, 3 * nseg, axis=0)[nseg : 2 * nseg]
        return force, {**post_state, "rng_key": rng_key, "white_noise": white_noise}

    @jax.jit
    def compute_spectra(force, vel, post_state):
        """Accumulate per-species velocity/force spectra of the last segment."""
        sf = jnp.fft.rfft(force / gamma, 3 * nseg, axis=0, norm="ortho")
        sv = jnp.fft.rfft(vel, 3 * nseg, axis=0, norm="ortho")
        Cvv = jnp.sum(jnp.abs(sv[:nom]) ** 2, axis=-1).T
        Cff = jnp.sum(jnp.abs(sf[:nom]) ** 2, axis=-1).T
        Cvf = jnp.sum(jnp.real(sv[:nom] * jnp.conj(sf[:nom])), axis=-1).T

        mCvv = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["mCvv"]).at[type_idx].add(Cvv)
            * mass_idx[:, None]
            / n_of_type[:, None]
        )
        Cvf = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["Cvf"]).at[type_idx].add(Cvf)
            / n_of_type[:, None]
        )
        Cff = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["Cff"]).at[type_idx].add(Cff)
            / n_of_type[:, None]
        )
        # deviation from the fluctuation-dissipation theorem
        dFDT = mCvv * post_state["gammar"] - Cvf

        # running averages over the nsample recorded segments
        nsinv = 1.0 / post_state["nsample"]
        b1 = 1.0 - nsinv
        dFDT_avg = post_state["dFDT_avg"] * b1 + dFDT * nsinv
        mCvv_avg = post_state["mCvv_avg"] * b1 + mCvv * nsinv
        Cvfg_avg = post_state["Cvfg_avg"] * b1 + Cvf / post_state["gammar"] * nsinv
        Cff_avg = post_state["Cff_avg"] * b1 + Cff * nsinv

        return {
            **post_state,
            "mCvv": mCvv,
            "Cvf": Cvf,
            "Cff": Cff,
            "dFDT": dFDT,
            "dFDT_avg": dFDT_avg,
            "mCvv_avg": mCvv_avg,
            "Cvfg_avg": Cvfg_avg,
            "Cff_avg": Cff_avg,
        }

    def write_spectra_to_file(post_state):
        """Write one ``QTB_spectra_<species>.out`` file per species."""
        mCvv_avg = np.array(post_state["mCvv_avg"])
        Cvfg_avg = np.array(post_state["Cvfg_avg"])
        Cff_avg = np.array(post_state["Cff_avg"]) * 3.0 / dt / (gamma**2)
        dFDT_avg = np.array(post_state["dFDT_avg"])
        gammar = np.array(post_state["gammar"])
        Cff_theo = np.array(ff_kernel(post_state))[:nom].T
        for i, sp in enumerate(species_set):
            ff_scale = au.KELVIN / ((2 * gamma / dt) * mass_idx[i])
            columns = np.column_stack(
                (
                    omega[:nom] * omega_to_cm1,
                    mCvv_avg[i],
                    Cvfg_avg[i],
                    dFDT_avg[i],
                    gammar[i] * gamma * (au.FS * au.THZ),
                    Cff_avg[i] * ff_scale,
                    Cff_theo[i] * ff_scale,
                )
            )
            # one header name per column (savetxt adds the leading "# ")
            np.savetxt(
                f"QTB_spectra_{sp}.out",
                columns,
                fmt="%12.6f",
                header="omega mCvv Cvf dFDT gammar Cff Cff_theo",
            )
        if verbose:
            print("# QTB spectra written.")

    if compute_thermostat_energy:
        state["qtb_energy_flux"] = 0.0

    @jax.jit
    def thermostat(vel, state):
        """Apply friction and the precomputed random force of step ``istep``."""
        istep = state["istep"]
        dvel = dt * state["force"][istep] / mass[:, None]
        new_vel = vel * a1 + dvel
        new_state = {**state, "istep": istep + 1}
        if do_compute_spectra:
            vel2 = state["vel"].at[istep].set(vel * a1**0.5 + 0.5 * dvel)
            new_state["vel"] = vel2
        if compute_thermostat_energy:
            dek = 0.5 * (mass[:, None] * (vel**2 - new_vel**2)).sum()
            ekcorr = (
                0.5
                * (mass[:, None] * new_vel**2).sum()
                * (1.0 - 1.0 / state.get("corr_kin", 1.0))
            )
            new_state["qtb_energy_flux"] = state["qtb_energy_flux"] + dek
            new_state["thermostat_energy"] = new_state["qtb_energy_flux"] + ekcorr
        return new_vel, new_state

    @jax.jit
    def postprocess_work(state, post_state):
        """Jitted end-of-segment work: spectra, adaptation, new forces."""
        if do_compute_spectra:
            post_state = compute_spectra(state["force"], state["vel"], post_state)
        if adaptive:
            post_state = jax.lax.cond(
                post_state["nadapt"] > skipseg, update_gammar, lambda x: x, post_state
            )
        new_force, post_state = refresh_force(post_state)
        return {**state, "force": new_force}, post_state

    def postprocess(state, post_state):
        """Per-step hook; does the end-of-segment work on counter reset steps."""
        counter.increment()
        if not counter.is_reset_step:
            return state, post_state
        post_state["nadapt"] += 1
        post_state["nsample"] = max(post_state["nadapt"] - startsave + 1, 1)
        if verbose:
            print("# Refreshing QTB forces.")
        state, post_state = postprocess_work(state, post_state)
        state["corr_kin"], post_state = compute_corr_kin(post_state)
        state["istep"] = 0
        if write_spectra:
            write_spectra_to_file(post_state)
        return state, post_state

    post_state["corr_pot"] = jnp.asarray(compute_corr_pot(), dtype=fprec)

    state["force"], post_state = refresh_force(post_state)
    return thermostat, (postprocess, post_state), state