fennol.md.thermostats

  1import numpy as np
  2import flax.linen as nn
  3import jax
  4import jax.numpy as jnp
  5import math
  6import optax
  7import os
  8import pickle
  9
 10from ..utils.atomic_units import AtomicUnits as au  # CM1,THZ,BOHR,MPROT
 11from ..utils import Counter
 12from ..utils.deconvolution import (
 13    deconvolute_spectrum,
 14    kernel_lorentz_pot,
 15    kernel_lorentz,
 16)
 17
 18
 19def get_thermostat(simulation_parameters, dt, system_data, fprec, rng_key=None, restart_data={}):
 20    state = {}
 21    postprocess = None
 22    
 23
 24    thermostat_name = str(simulation_parameters.get("thermostat", "LGV")).upper()
 25    """@keyword[fennol_md] thermostat
 26    Thermostat type. Options: 'NVE', 'LGV', 'NOSE', 'ADQTB'.
 27    Default: "LGV"
 28    """
 29    compute_thermostat_energy = simulation_parameters.get(
 30        "include_thermostat_energy", False
 31    )
 32    """@keyword[fennol_md] include_thermostat_energy
 33    Include thermostat energy in total energy calculations.
 34    Default: False
 35    """
 36
 37    kT = system_data.get("kT", None)
 38    nbeads = system_data.get("nbeads", None)
 39    mass = system_data["mass"]
 40    gamma = simulation_parameters.get("gamma", 1.0 / au.THZ) / au.FS
 41    """@keyword[fennol_md] gamma
 42    Friction coefficient for Langevin thermostat.
 43    Default: 1.0 ps^-1
 44    """
 45    species = system_data["species"]
 46
 47    if nbeads is not None:
 48        trpmd_lambda = simulation_parameters.get("trpmd_lambda", 1.0)
 49        """@keyword[fennol_md] trpmd_lambda
 50        Lambda parameter for TRPMD (Thermostatted Ring Polymer MD).
 51        Default: 1.0
 52        """
 53        gamma = np.maximum(trpmd_lambda * system_data["omk"], gamma)
 54
 55    if thermostat_name in ["LGV", "LANGEVIN", "FFLGV"]:
 56        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
 57        assert kT is not None, "kT must be specified for QTB thermostat"
 58        assert gamma is not None, "gamma must be specified for QTB thermostat"
 59        rng_key, v_key = jax.random.split(rng_key)
 60        if nbeads is None:
 61            a1 = math.exp(-gamma * dt)
 62            a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)
 63            vel = (
 64                jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
 65                * (kT / mass[:, None]) ** 0.5
 66            )
 67        else:
 68            if isinstance(gamma, float):
 69                gamma = np.array([gamma] * nbeads)
 70            assert isinstance(
 71                gamma, np.ndarray
 72            ), "gamma must be a float or a numpy array"
 73            assert gamma.shape[0] == nbeads, "gamma must have the same length as nbeads"
 74            a1 = np.exp(-gamma * dt)[:, None, None]
 75            a2 = jnp.asarray(
 76                ((1 - a1 * a1) * kT / mass[None, :, None]) ** 0.5, dtype=fprec
 77            )
 78            vel = (
 79                jax.random.normal(v_key, (nbeads, mass.shape[0], 3), dtype=fprec)
 80                * (kT / mass[:, None]) ** 0.5
 81            )
 82
 83        state["rng_key"] = rng_key
 84        if compute_thermostat_energy:
 85            state["thermostat_energy"] = 0.0
 86        if thermostat_name == "FFLGV":
 87            def thermostat(vel, state):
 88                rng_key, noise_key = jax.random.split(state["rng_key"])
 89                noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
 90                norm_vel = jnp.linalg.norm(vel, axis=-1, keepdims=True)
 91                dirvel = vel / norm_vel
 92                if compute_thermostat_energy:
 93                    v2 = (vel**2).sum(axis=-1)
 94                vel = a1 * vel + a2 * noise
 95                new_norm_vel = jnp.linalg.norm(vel, axis=-1, keepdims=True)
 96                vel = dirvel * new_norm_vel
 97                new_state = {**state, "rng_key": rng_key}
 98                if compute_thermostat_energy:
 99                    v2new = (vel**2).sum(axis=-1)
100                    new_state["thermostat_energy"] = (
101                        state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
102                    )
103
104                return vel, new_state
105
106        else:
107            def thermostat(vel, state):
108                rng_key, noise_key = jax.random.split(state["rng_key"])
109                noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
110                if compute_thermostat_energy:
111                    v2 = (vel**2).sum(axis=-1)
112                vel = a1 * vel + a2 * noise
113                new_state = {**state, "rng_key": rng_key}
114                if compute_thermostat_energy:
115                    v2new = (vel**2).sum(axis=-1)
116                    new_state["thermostat_energy"] = (
117                        state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
118                    )
119                return vel, new_state
120
121    elif thermostat_name in ["BUSSI"]:
122        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
123        assert kT is not None, "kT must be specified for QTB thermostat"
124        assert gamma is not None, "gamma must be specified for QTB thermostat"
125        assert nbeads is None, "Bussi thermostat is not compatible with PIMD"
126        rng_key, v_key = jax.random.split(rng_key)
127
128        a1 = math.exp(-gamma * dt)
129        a2 = (1 - a1) * kT
130        vel = (
131            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
132            * (kT / mass[:, None]) ** 0.5
133        )
134
135        state["rng_key"] = rng_key
136        if compute_thermostat_energy:
137            state["thermostat_energy"] = 0.0
138
139        def thermostat(vel, state):
140            rng_key, noise_key = jax.random.split(state["rng_key"])
141            new_state = {**state, "rng_key": rng_key}
142            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
143            R2 = jnp.sum(noise**2)
144            R1 = noise[0, 0]
145            c = a2 / (mass[:, None] * vel**2).sum()
146            d = (a1 * c) ** 0.5
147            scale = (a1 + c * R2 + 2 * d * R1) ** 0.5
148            if compute_thermostat_energy:
149                dek = 0.5 * (mass[:, None] * vel**2).sum() * (scale**2 - 1)
150                new_state["thermostat_energy"] = state["thermostat_energy"] + dek
151            return scale * vel, new_state
152
153    elif thermostat_name in [
154        "GD",
155        "DESCENT",
156        "GRADIENT_DESCENT",
157        "MIN",
158        "MINIMIZE",
159    ]:
160        assert nbeads is None, "Gradient descent is not compatible with PIMD"
161        a1 = math.exp(-gamma * dt)
162
163        if nbeads is None:
164            vel = jnp.zeros((mass.shape[0], 3), dtype=fprec)
165        else:
166            vel = jnp.zeros((nbeads, mass.shape[0], 3), dtype=fprec)
167
168        def thermostat(vel, state):
169            return a1 * vel, state
170
171    elif thermostat_name in ["NVE", "NONE"]:
172        if nbeads is None:
173            vel = (
174                jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
175                * (kT / mass[:, None]) ** 0.5
176            )
177            kTsys = jnp.sum(mass[:, None] * vel**2) / (mass.shape[0] * 3)
178            vel = vel * (kT / kTsys) ** 0.5
179        else:
180            vel = (
181                jax.random.normal(rng_key, (nbeads, mass.shape[0], 3), dtype=fprec)
182                * (kT / mass[None, :, None]) ** 0.5
183            )
184            kTsys = jnp.sum(mass[None, :, None] * vel**2, axis=(1, 2)) / (
185                mass.shape[0] * 3
186            )
187            vel = vel * (kT / kTsys[:, None, None]) ** 0.5
188        thermostat = lambda x, s: (x, s)
189
190    elif thermostat_name in ["NOSE", "NOSEHOOVER", "NOSE_HOOVER"]:
191        assert gamma is not None, "gamma must be specified for QTB thermostat"
192        ndof = mass.shape[0] * 3
193        nkT = ndof * kT
194        nose_mass = nkT / gamma**2
195        assert nbeads is None, "Nose-Hoover is not compatible with PIMD"
196        state["nose_s"] = 0.0
197        state["nose_v"] = 0.0
198        if compute_thermostat_energy:
199            state["thermostat_energy"] = 0.0
200        print(
201            "# WARNING: Nose-Hoover thermostat is not well tested yet. Energy conservation is not guaranteed."
202        )
203        vel = (
204            jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
205            * (kT / mass[:, None]) ** 0.5
206        )
207
208        def thermostat(vel, state):
209            nose_s = state["nose_s"]
210            nose_v = state["nose_v"]
211            kTsys = jnp.sum(mass[:, None] * vel**2)
212            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
213            nose_s = nose_s + dt * nose_v
214            vel = jnp.exp(-nose_v * dt) * vel
215            kTsys = jnp.sum(mass[:, None] * vel**2)
216            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
217            new_state = {**state, "nose_s": nose_s, "nose_v": nose_v}
218
219            if compute_thermostat_energy:
220                new_state["thermostat_energy"] = (
221                    nkT * nose_s + (0.5 * nose_mass) * nose_v**2
222                )
223            return vel, new_state
224
225    elif thermostat_name in ["QTB", "ADQTB"]:
226        assert nbeads is None, "QTB is not compatible with PIMD"
227        qtb_parameters = simulation_parameters.get("qtb", None)
228        """@keyword[fennol_md] qtb
229        Parameters for Quantum Thermal Bath thermostat configuration.
230        Required for QTB/ADQTB thermostats
231        """
232        assert (
233            qtb_parameters is not None
234        ), "qtb_parameters must be provided for QTB thermostat"
235        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
236        assert kT is not None, "kT must be specified for QTB thermostat"
237        assert gamma is not None, "gamma must be specified for QTB thermostat"
238        assert species is not None, "species must be provided for QTB thermostat"
239        rng_key, v_key = jax.random.split(rng_key)
240        vel = (
241            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
242            * (kT / mass[:, None]) ** 0.5
243        )
244
245        thermostat, postprocess, qtb_state = initialize_qtb(
246            qtb_parameters,
247            system_data,
248            fprec=fprec,
249            dt=dt,
250            mass=mass,
251            gamma=gamma,
252            kT=kT,
253            species=species,
254            rng_key=rng_key,
255            adaptive=thermostat_name.startswith("AD"),
256            compute_thermostat_energy=compute_thermostat_energy,
257        )
258        state = {**state, **qtb_state}
259
260    elif thermostat_name in ["ANNEAL", "ANNEALING"]:
261        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
262        assert kT is not None, "kT must be specified for QTB thermostat"
263        assert gamma is not None, "gamma must be specified for QTB thermostat"
264        assert nbeads is None, "ANNEAL is not compatible with PIMD"
265        a1 = math.exp(-gamma * dt)
266        a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)
267
268        anneal_parameters = simulation_parameters.get("annealing", {})
269        """@keyword[fennol_md] annealing
270        Parameters for simulated annealing schedule configuration.
271        Required for ANNEAL/ANNEALING thermostat
272        """
273        init_factor = anneal_parameters.get("init_factor", 1.0 / 25.0)
274        """@keyword[fennol_md] annealing/init_factor
275        Initial temperature factor for annealing schedule.
276        Default: 0.04 (1/25)
277        """
278        assert init_factor > 0.0, "init_factor must be positive"
279        final_factor = anneal_parameters.get("final_factor", 1.0 / 10000.0)
280        """@keyword[fennol_md] annealing/final_factor
281        Final temperature factor for annealing schedule.
282        Default: 0.0001 (1/10000)
283        """
284        assert final_factor > 0.0, "final_factor must be positive"
285        nsteps = simulation_parameters.get("nsteps")
286        """@keyword[fennol_md] nsteps
287        Total number of simulation steps for annealing schedule calculation.
288        Required parameter
289        """
290        anneal_steps = anneal_parameters.get("anneal_steps", 1.0)
291        """@keyword[fennol_md] annealing/anneal_steps
292        Fraction of total steps for annealing process.
293        Default: 1.0
294        """
295        assert (
296            anneal_steps < 1.0 and anneal_steps > 0.0
297        ), "warmup_steps must be between 0 and nsteps"
298        pct_start = anneal_parameters.get("warmup_steps", 0.3)
299        """@keyword[fennol_md] annealing/warmup_steps
300        Fraction of annealing steps for warmup phase.
301        Default: 0.3
302        """
303        assert (
304            pct_start < 1.0 and pct_start > 0.0
305        ), "warmup_steps must be between 0 and nsteps"
306
307        anneal_type = anneal_parameters.get("type", "cosine").lower()
308        """@keyword[fennol_md] annealing/type
309        Type of annealing schedule (linear, cosine_onecycle).
310        Default: cosine
311        """
312        if anneal_type == "linear":
313            schedule = optax.linear_onecycle_schedule(
314                peak_value=1.0,
315                div_factor=1.0 / init_factor,
316                final_div_factor=1.0 / final_factor,
317                transition_steps=int(anneal_steps * nsteps),
318                pct_start=pct_start,
319                pct_final=1.0,
320            )
321        elif anneal_type == "cosine_onecycle":
322            schedule = optax.cosine_onecycle_schedule(
323                peak_value=1.0,
324                div_factor=1.0 / init_factor,
325                final_div_factor=1.0 / final_factor,
326                transition_steps=int(anneal_steps * nsteps),
327                pct_start=pct_start,
328            )
329        else:
330            raise ValueError(f"Unknown anneal_type {anneal_type}")
331
332        state["rng_key"] = rng_key
333        state["istep_anneal"] = 0
334
335        rng_key, v_key = jax.random.split(rng_key)
336        Tscale = schedule(0)
337        print(f"# ANNEAL: initial temperature = {Tscale*kT*au.KELVIN:.3e} K")
338        vel = (
339            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
340            * (kT * Tscale / mass[:, None]) ** 0.5
341        )
342
343        def thermostat(vel, state):
344            rng_key, noise_key = jax.random.split(state["rng_key"])
345            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
346
347            Tscale = schedule(state["istep_anneal"]) ** 0.5
348            vel = a1 * vel + a2 * Tscale * noise
349            return vel, {
350                **state,
351                "rng_key": rng_key,
352                "istep_anneal": state["istep_anneal"] + 1,
353            }
354
355    else:
356        raise ValueError(f"Unknown thermostat {thermostat_name}")
357
358    return thermostat, postprocess, state, vel,thermostat_name
359
360
361def initialize_qtb(
362    qtb_parameters,
363    system_data,
364    fprec,
365    dt,
366    mass,
367    gamma,
368    kT,
369    species,
370    rng_key,
371    adaptive,
372    compute_thermostat_energy=False,
373):
374    state = {}
375    post_state = {}
376    verbose = qtb_parameters.get("verbose", False)
377    """@keyword[fennol_md] qtb/verbose
378    Print verbose QTB thermostat information.
379    Default: False
380    """
381    if compute_thermostat_energy:
382        state["thermostat_energy"] = 0.0
383
384    mass = jnp.asarray(mass, dtype=fprec)
385
386    nat = species.shape[0]
387    # define type indices
388    species_set = set(species)
389    nspecies = len(species_set)
390    idx = {sp: i for i, sp in enumerate(species_set)}
391    type_idx = np.array([idx[sp] for sp in species], dtype=np.int32)
392
393    n_of_type = np.zeros(nspecies, dtype=np.int32)
394    for i in range(nspecies):
395        n_of_type[i] = (type_idx == i).nonzero()[0].shape[0]
396    n_of_type = jnp.asarray(n_of_type, dtype=fprec)
397    mass_idx = jax.ops.segment_sum(mass, type_idx, nspecies) / n_of_type
398
399    niter_deconv_kin = qtb_parameters.get("niter_deconv_kin", 7)
400    """@keyword[fennol_md] qtb/niter_deconv_kin
401    Number of iterations for kinetic energy deconvolution.
402    Default: 7
403    """
404    niter_deconv_pot = qtb_parameters.get("niter_deconv_pot", 20)
405    """@keyword[fennol_md] qtb/niter_deconv_pot
406    Number of iterations for potential energy deconvolution.
407    Default: 20
408    """
409    corr_kin = qtb_parameters.get("corr_kin", -1)
410    """@keyword[fennol_md] qtb/corr_kin
411    Kinetic energy correction factor for QTB (-1 for automatic).
412    Default: -1
413    """
414    do_corr_kin = corr_kin <= 0
415    if do_corr_kin:
416        corr_kin = 1.0
417    state["corr_kin"] = corr_kin
418    post_state["corr_kin_prev"] = corr_kin
419    post_state["do_corr_kin"] = do_corr_kin
420    post_state["isame_kin"] = 0
421
422    # spectra parameters
423    omegasmear = np.pi / dt / 100.0
424    Tseg = qtb_parameters.get("tseg", 1.0 / au.PS) * au.FS
425    """@keyword[fennol_md] qtb/tseg
426    Time segment length for QTB spectrum calculation.
427    Default: 1.0 ps
428    """
429    nseg = int(Tseg / dt)
430    Tseg = nseg * dt
431    dom = 2 * np.pi / (3 * Tseg)
432    omegacut = qtb_parameters.get("omegacut", 15000.0 / au.CM1) / au.FS
433    """@keyword[fennol_md] qtb/omegacut
434    Cutoff frequency for QTB spectrum.
435    Default: 15000.0 cm⁻¹
436    """
437    nom = int(omegacut / dom)
438    omega = dom * np.arange((3 * nseg) // 2 + 1)
439    cutoff = jnp.asarray(
440        1.0 / (1.0 + np.exp((omega - omegacut) / omegasmear)), dtype=fprec
441    )
442    assert (
443        omegacut < omega[-1]
444    ), f"omegacut must be smaller than {omega[-1]*au.CM1} CM-1"
445
446    # initialize gammar
447    assert (
448        gamma < 0.5 * omegacut
449    ), "gamma must be much smaller than omegacut (at most 0.5*omegacut)"
450    gammar_min = qtb_parameters.get("gammar_min", 0.1)
451    """@keyword[fennol_md] qtb/gammar_min
452    Minimum value for QTB gamma ratio coefficients.
453    Default: 0.1
454    """
455    # post_state["gammar"] = jnp.asarray(np.ones((nspecies, nom)), dtype=fprec)
456    gammar = np.ones((nspecies, nom), dtype=float)
457    try:
458        for i, sp in enumerate(species_set):
459            if not os.path.exists(f"QTB_spectra_{sp}.out"): continue
460            data = np.loadtxt(f"QTB_spectra_{sp}.out")
461            gammar[i] = data[:, 4]/(gamma*au.FS*au.THZ)
462            print(f"# Restored gammar for species {sp} from QTB_spectra_{sp}.out")
463    except Exception as e:
464        print(f"# Could not restore gammar for all species with exception {e}. Starting from scratch.")
465        gammar[:,:] = 1.0
466    post_state["gammar"] = jnp.asarray(gammar, dtype=fprec)
467
468    # Ornstein-Uhlenbeck correction for colored noise
469    a1 = np.exp(-gamma * dt)
470    OUcorr = jnp.asarray(
471        (1.0 - 2.0 * a1 * np.cos(omega * dt) + a1**2) / (dt**2 * (gamma**2 + omega**2)),
472        dtype=fprec,
473    )
474
475    # hbar schedule
476    classical_kernel = qtb_parameters.get("classical_kernel", False)
477    """@keyword[fennol_md] qtb/classical_kernel
478    Use classical instead of quantum kernel for QTB.
479    Default: False
480    """
481    hbar = qtb_parameters.get("hbar", 1.0) * au.FS
482    """@keyword[fennol_md] qtb/hbar
483    Reduced Planck constant scaling factor for quantum effects.
484    Default: 1.0 a.u.
485    """
486    u = 0.5 * hbar * np.abs(omega) / kT
487    theta = kT * np.ones_like(omega)
488    if hbar > 0:
489        theta[1:] *= u[1:] / np.tanh(u[1:])
490    theta = jnp.asarray(theta, dtype=fprec)
491
492    noise_key, post_state["rng_key"] = jax.random.split(rng_key)
493    del rng_key
494    post_state["white_noise"] = jax.random.normal(
495        noise_key, (3 * nseg, nat, 3), dtype=jnp.float32
496    )
497
498    startsave = qtb_parameters.get("startsave", 1)
499    """@keyword[fennol_md] qtb/startsave
500    Start saving QTB statistics after this many segments.
501    Default: 1
502    """
503    counter = Counter(nseg, startsave=startsave)
504    state["istep"] = 0
505    post_state["nadapt"] = 0
506    post_state["nsample"] = 0
507
508    write_spectra = qtb_parameters.get("write_spectra", True)
509    """@keyword[fennol_md] qtb/write_spectra
510    Write QTB spectral analysis output files.
511    Default: True
512    """
513    do_compute_spectra = write_spectra or adaptive
514
515    if do_compute_spectra:
516        state["vel"] = jnp.zeros((nseg, nat, 3), dtype=fprec)
517
518        post_state["dFDT"] = jnp.zeros((nspecies, nom), dtype=fprec)
519        post_state["mCvv"] = jnp.zeros((nspecies, nom), dtype=fprec)
520        post_state["Cvf"] = jnp.zeros((nspecies, nom), dtype=fprec)
521        post_state["Cff"] = jnp.zeros((nspecies, nom), dtype=fprec)
522        post_state["dFDT_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
523        post_state["mCvv_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
524        post_state["Cvfg_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
525        post_state["Cff_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
526
527    if not adaptive:
528        update_gammar = lambda x: x
529    else:
530        # adaptation parameters
531        skipseg = qtb_parameters.get("skipseg", 1)
532        """@keyword[fennol_md] qtb/skipseg
533        Number of segments to skip before starting adaptive QTB.
534        Default: 1
535        """
536
537        adaptation_method = (
538            str(qtb_parameters.get("adaptation_method", "ADABELIEF")).upper().strip()
539        )
540        """@keyword[fennol_md] qtb/adaptation_method
541        Method for adaptive QTB (SIMPLE, RATIO, ADABELIEF).
542        Default: ADABELIEF
543        """
544        authorized_methods = ["SIMPLE", "RATIO", "ADABELIEF"]
545        assert (
546            adaptation_method in authorized_methods
547        ), f"adaptation_method must be one of {authorized_methods}"
548        if adaptation_method == "SIMPLE":
549            agamma = qtb_parameters.get("agamma", 1.0e-3) / au.FS
550            """@keyword[fennol_md] qtb/agamma
551            Learning rate for adaptive QTB gamma update.
552            Default: 1.0e-3 (SIMPLE), 0.1 (ADABELIEF)
553            """
554            assert agamma > 0, "agamma must be positive"
555            a1_ad = agamma * Tseg  #  * gamma
556            print(f"# ADQTB SIMPLE: agamma = {agamma*au.FS:.3f}")
557
558            def update_gammar(post_state):
559                g = post_state["dFDT"]
560                gammar = post_state["gammar"] - a1_ad * g
561                gammar = jnp.maximum(gammar_min, gammar)
562                return {**post_state, "gammar": gammar}
563
564        elif adaptation_method == "RATIO":
565            tau_ad = qtb_parameters.get("tau_ad", 5.0 / au.PS) * au.FS
566            """@keyword[fennol_md] qtb/tau_ad
567            Adaptation time constant for momentum averaging.
568            Default: 5.0 ps (RATIO), 1.0 ps (ADABELIEF)
569            """
570            tau_s = qtb_parameters.get("tau_s", 10 * tau_ad) * au.FS
571            """@keyword[fennol_md] qtb/tau_s
572            Second moment time constant for variance averaging.
573            Default: 10*tau_ad (RATIO), 100*tau_ad (ADABELIEF)
574            """
575            assert tau_ad > 0, "tau_ad must be positive"
576            print(
577                f"# ADQTB RATIO: tau_ad = {tau_ad*1e-3:.2f} ps, tau_s = {tau_s*1e-3:.2f} ps"
578            )
579            b1 = np.exp(-Tseg / tau_ad)
580            b2 = np.exp(-Tseg / tau_s)
581            post_state["mCvv_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
582            post_state["Cvf_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
583            post_state["n_adabelief"] = 0
584
585            def update_gammar(post_state):
586                n_adabelief = post_state["n_adabelief"] + 1
587                mCvv_m = post_state["mCvv_m"] * b1 + post_state["mCvv"] * (1.0 - b1)
588                Cvf_m = post_state["Cvf_m"] * b2 + post_state["Cvf"] * (1.0 - b2)
589                mCvv = mCvv_m / (1.0 - b1**n_adabelief)
590                Cvf = Cvf_m / (1.0 - b2**n_adabelief)
591                # g = Cvf/(mCvv+1.e-8)-post_state["gammar"]
592                gammar = Cvf / (mCvv + 1.0e-8)
593                gammar = jnp.maximum(gammar_min, gammar)
594                return {
595                    **post_state,
596                    "gammar": gammar,
597                    "mCvv_m": mCvv_m,
598                    "Cvf_m": Cvf_m,
599                    "n_adabelief": n_adabelief,
600                }
601
602        elif adaptation_method == "ADABELIEF":
603            agamma = qtb_parameters.get("agamma", 0.1)
604            tau_ad = qtb_parameters.get("tau_ad", 1.0 / au.PS) * au.FS
605            tau_s = qtb_parameters.get("tau_s", 100 * tau_ad) * au.FS
606            assert tau_ad > 0, "tau_ad must be positive"
607            assert tau_s > 0, "tau_s must be positive"
608            assert agamma > 0, "agamma must be positive"
609            print(
610                f"# ADQTB ADABELIEF: agamma = {agamma:.3f}, tau_ad = {tau_ad*1.e-3:.2f} ps, tau_s = {tau_s*1.e-3:.2f} ps"
611            )
612
613            a1_ad = agamma * gamma  # * Tseg #* gamma
614            b1 = np.exp(-Tseg / tau_ad)
615            b2 = np.exp(-Tseg / tau_s)
616            post_state["dFDT_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
617            post_state["dFDT_s"] = jnp.zeros((nspecies, nom), dtype=fprec)
618            post_state["n_adabelief"] = 0
619
620            def update_gammar(post_state):
621                n_adabelief = post_state["n_adabelief"] + 1
622                dFDT = post_state["dFDT"]
623                dFDT_m = post_state["dFDT_m"] * b1 + dFDT * (1.0 - b1)
624                dFDT_s = (
625                    post_state["dFDT_s"] * b2
626                    + (dFDT - dFDT_m) ** 2 * (1.0 - b2)
627                    + 1.0e-8
628                )
629                # bias correction
630                mt = dFDT_m / (1.0 - b1**n_adabelief)
631                st = dFDT_s / (1.0 - b2**n_adabelief)
632                gammar = post_state["gammar"] - a1_ad * mt / (st**0.5 + 1.0e-8)
633                gammar = jnp.maximum(gammar_min, gammar)
634                return {
635                    **post_state,
636                    "gammar": gammar,
637                    "dFDT_m": dFDT_m,
638                    "n_adabelief": n_adabelief,
639                    "dFDT_s": dFDT_s,
640                }
641    
642    #####################
643    # RESTART
644    restart_file = system_data["name"]+".qtb.restart"
645    if os.path.exists(restart_file):
646        with open(restart_file, "rb") as f:
647            data = pickle.load(f)
648            state["corr_kin"] = data["corr_kin"]
649            post_state["corr_kin_prev"] = data["corr_kin"]
650            post_state["isame_kin"] = data["isame_kin"]
651            post_state["do_corr_kin"] = data["do_corr_kin"]
652            print(f"# Restored QTB state from {restart_file}")
653
654    def write_qtb_restart(state, post_state):
655        with open(restart_file, "wb") as f:
656            pickle.dump(
657                {
658                    "corr_kin": state["corr_kin"],
659                    "corr_kin_prev": post_state["corr_kin_prev"],
660                    "isame_kin": post_state["isame_kin"],
661                    "do_corr_kin": post_state["do_corr_kin"],
662                },
663                f,
664            )
665    ######################
666
667    def compute_corr_pot(niter=20, verbose=False):
668        if classical_kernel or hbar == 0:
669            return np.ones(nom)
670
671        s_0 = np.array((theta / kT * cutoff)[:nom])
672        s_out, s_rec, _ = deconvolute_spectrum(
673            s_0,
674            omega[:nom],
675            gamma,
676            niter,
677            kernel=kernel_lorentz_pot,
678            trans=True,
679            symmetrize=True,
680            verbose=verbose,
681        )
682        corr_pot = 1.0 + (s_out - s_0) / s_0
683        columns = np.column_stack(
684            (omega[:nom] * au.CM1, corr_pot - 1.0, s_0, s_out, s_rec)
685        )
686        np.savetxt(
687            "corr_pot.dat", columns, header="omega(cm-1) corr_pot s_0 s_out s_rec"
688        )
689        return corr_pot
690
691    def compute_corr_kin(post_state, niter=7, verbose=False):
692        if not post_state["do_corr_kin"]:
693            return post_state["corr_kin_prev"], post_state
694        if classical_kernel or hbar == 0:
695            return 1.0, post_state
696
697        K_D = post_state.get("K_D", None)
698        mCvv = (post_state["mCvv_avg"][:, :nom] * n_of_type[:, None]).sum(axis=0) / nat
699        s_0 = np.array(mCvv * kT / theta[:nom] / post_state["corr_pot"])
700        s_out, s_rec, K_D = deconvolute_spectrum(
701            s_0,
702            omega[:nom],
703            gamma,
704            niter,
705            kernel=kernel_lorentz,
706            trans=False,
707            symmetrize=True,
708            verbose=verbose,
709            K_D=K_D,
710        )
711        s_out = s_out * theta[:nom] / kT
712        s_rec = s_rec * theta[:nom] / kT * post_state["corr_pot"]
713        mCvvsum = mCvv.sum()
714        rec_ratio = mCvvsum / s_rec.sum()
715        if rec_ratio < 0.95 or rec_ratio > 1.05:
716            print(
717                f"# WARNING: reconvolution error {rec_ratio} is too high, corr_kin was not updated"
718            )
719            return post_state["corr_kin_prev"], post_state
720
721        corr_kin = mCvvsum / s_out.sum()
722        if np.abs(corr_kin - post_state["corr_kin_prev"]) < 1.0e-4:
723            isame_kin = post_state["isame_kin"] + 1
724        else:
725            isame_kin = 0
726
727        # print("# corr_kin: ", corr_kin)
728        do_corr_kin = post_state["do_corr_kin"]
729        if isame_kin > 10:
730            print(
731                "# INFO: corr_kin is converged (it did not change for 10 consecutive segments)"
732            )
733            do_corr_kin = False
734
735        return corr_kin, {
736            **post_state,
737            "corr_kin_prev": corr_kin,
738            "isame_kin": isame_kin,
739            "do_corr_kin": do_corr_kin,
740            "K_D": K_D,
741        }
742
743    @jax.jit
744    def ff_kernel(post_state):
745        if classical_kernel:
746            kernel = cutoff * (2 * gamma * kT / dt)
747        else:
748            kernel = theta * cutoff * OUcorr * (2 * gamma / dt)
749        gamma_ratio = jnp.concatenate(
750            (
751                post_state["gammar"].T * post_state["corr_pot"][:, None],
752                jnp.ones(
753                    (kernel.shape[0] - nom, nspecies), dtype=post_state["gammar"].dtype
754                ),
755            ),
756            axis=0,
757        )
758        return kernel[:, None] * gamma_ratio * mass_idx[None, :]
759
760    @jax.jit
761    def refresh_force(post_state):  # Draw the colored random force for the next nseg steps; returns (force, updated post_state).
762        rng_key, noise_key = jax.random.split(post_state["rng_key"])
763        white_noise = jnp.concatenate(  # slide the 3*nseg white-noise window: drop the oldest nseg samples, append nseg fresh ones
764            (
765                post_state["white_noise"][nseg:],
766                jax.random.normal(
767                    noise_key, (nseg, nat, 3), dtype=post_state["white_noise"].dtype
768                ),
769            ),
770            axis=0,
771        )
772        amplitude = ff_kernel(post_state) ** 0.5  # filter amplitude = sqrt of the target power spectrum
773        s = jnp.fft.rfft(white_noise, 3 * nseg, axis=0) * amplitude[:, type_idx, None]  # color the noise in frequency space, each atom using its species column
774        force = jnp.fft.irfft(s, 3 * nseg, axis=0)[nseg : 2 * nseg]  # keep only the middle nseg samples of the filtered window
775        return force, {**post_state, "rng_key": rng_key, "white_noise": white_noise}
776
777    @jax.jit
778    def compute_spectra(force, vel, post_state):  # Per-species velocity/force spectra of the last segment, plus their running averages.
779        sf = jnp.fft.rfft(force / gamma, 3 * nseg, axis=0, norm="ortho")  # FFT along time, zero-padded to 3*nseg points; force is divided by gamma
780        sv = jnp.fft.rfft(vel, 3 * nseg, axis=0, norm="ortho")
781        Cvv = jnp.sum(jnp.abs(sv[:nom]) ** 2, axis=-1).T  # summed over the last (xyz) axis, transposed to (atom, frequency)
782        Cff = jnp.sum(jnp.abs(sf[:nom]) ** 2, axis=-1).T
783        Cvf = jnp.sum(jnp.real(sv[:nom] * jnp.conj(sf[:nom])), axis=-1).T  # real part of the velocity/force cross-spectrum
784
785        mCvv = (  # scatter-add atoms into their species, average over the atoms of that species and the 3 components, mass-weighted
786            (dt / 3.0)
787            * jnp.zeros_like(post_state["mCvv"]).at[type_idx].add(Cvv)
788            * mass_idx[:, None]
789            / n_of_type[:, None]
790        )
791        Cvf = (
792            (dt / 3.0)
793            * jnp.zeros_like(post_state["Cvf"]).at[type_idx].add(Cvf)
794            / n_of_type[:, None]
795        )
796        Cff = (
797            (dt / 3.0)
798            * jnp.zeros_like(post_state["Cff"]).at[type_idx].add(Cff)
799            / n_of_type[:, None]
800        )
801        dFDT = mCvv * post_state["gammar"] - Cvf  # mismatch mCvv*gammar - Cvf (the quantity the adaptive gammar updates act on)
802
803        nsinv = 1.0 / post_state["nsample"]  # cumulative mean over nsample segments: avg = avg*(1-1/n) + new/n
804        b1 = 1.0 - nsinv
805        dFDT_avg = post_state["dFDT_avg"] * b1 + dFDT * nsinv
806        mCvv_avg = post_state["mCvv_avg"] * b1 + mCvv * nsinv
807        Cvfg_avg = post_state["Cvfg_avg"] * b1 + Cvf / post_state["gammar"] * nsinv  # average of Cvf / gammar
808        Cff_avg = post_state["Cff_avg"] * b1 + Cff * nsinv
809
810        return {
811            **post_state,
812            "mCvv": mCvv,
813            "Cvf": Cvf,
814            "Cff": Cff,
815            "dFDT": dFDT,
816            "dFDT_avg": dFDT_avg,
817            "mCvv_avg": mCvv_avg,
818            "Cvfg_avg": Cvfg_avg,
819            "Cff_avg": Cff_avg,
820        }
821
822    def write_spectra_to_file(post_state):
823        mCvv_avg = np.array(post_state["mCvv_avg"])
824        Cvfg_avg = np.array(post_state["Cvfg_avg"])
825        Cff_avg = np.array(post_state["Cff_avg"]) * 3.0 / dt / (gamma**2)
826        dFDT_avg = np.array(post_state["dFDT_avg"])
827        gammar = np.array(post_state["gammar"])
828        Cff_theo = np.array(ff_kernel(post_state))[:nom].T
829        for i, sp in enumerate(species_set):
830            ff_scale = au.KELVIN / ((2 * gamma / dt) * mass_idx[i])
831            columns = np.column_stack(
832                (
833                    omega[:nom] * (au.FS * au.CM1),
834                    mCvv_avg[i],
835                    Cvfg_avg[i],
836                    dFDT_avg[i],
837                    gammar[i] * gamma * (au.FS * au.THZ),
838                    Cff_avg[i] * ff_scale,
839                    Cff_theo[i] * ff_scale,
840                )
841            )
842            np.savetxt(
843                f"QTB_spectra_{sp}.out",
844                columns,
845                fmt="%12.6f",
846                header="#omega mCvv Cvf dFDT gammar Cff",
847            )
848        if verbose:
849            print("# QTB spectra written.")
850
851    if compute_thermostat_energy:
852        state["qtb_energy_flux"] = 0.0  # running sum of the kinetic-energy change caused by the bath
853
854    @jax.jit
855    def thermostat(vel, state):  # One QTB step: damp velocities by a1 and add the precomputed colored force of step istep.
856        istep = state["istep"]
857        dvel = dt * state["force"][istep] / mass[:, None]
858        new_vel = vel * a1 + dvel
859        new_state = {**state, "istep": istep + 1}
860        if do_compute_spectra:
861            vel2 = state["vel"].at[istep].set(vel * a1**0.5 + 0.5 * dvel)  # store an intermediate velocity estimate for the spectra (presumably a half-step value)
862            new_state["vel"] = vel2
863        if compute_thermostat_energy:
864            dek = 0.5 * (mass[:, None] * (vel**2 - new_vel**2)).sum()  # kinetic energy removed by this thermostat step
865            ekcorr = (  # extra term from the corr_kin factor; zero when corr_kin == 1
866                0.5
867                * (mass[:, None] * new_vel**2).sum()
868                * (1.0 - 1.0 / state.get("corr_kin", 1.0))
869            )
870            new_state["qtb_energy_flux"] = state["qtb_energy_flux"] + dek
871            new_state["thermostat_energy"] = new_state["qtb_energy_flux"] + ekcorr
872        return new_vel, new_state
873
874    @jax.jit
875    def postprocess_work(state, post_state):  # End-of-segment work: update spectra, optionally adapt gammar, draw the next random force.
876        if do_compute_spectra:
877            post_state = compute_spectra(state["force"], state["vel"], post_state)
878        if adaptive:
879            post_state = jax.lax.cond(  # adapt gammar only once more than skipseg segments have been processed
880                post_state["nadapt"] > skipseg, update_gammar, lambda x: x, post_state
881            )
882        new_force, post_state = refresh_force(post_state)
883        return {**state, "force": new_force}, post_state
884
885    def postprocess(state, post_state):  # Called every step (not jitted); works only on steps where the Counter(nseg, ...) reports a reset.
886        counter.increment()
887        if not counter.is_reset_step:
888            return state, post_state
889        post_state["nadapt"] += 1  # post_state is a plain dict and is mutated in place here
890        post_state["nsample"] = max(post_state["nadapt"] - startsave + 1, 1)  # stays 1 (averages replaced, not accumulated) until startsave segments are reached
891        if verbose:
892            print("# Refreshing QTB forces.")
893        state, post_state = postprocess_work(state, post_state)
894        state["corr_kin"], post_state = compute_corr_kin(post_state, niter=niter_deconv_kin)  # kinetic-energy correction factor from the updated spectra
895        state["istep"] = 0  # restart indexing into the freshly drawn force segment
896        if write_spectra:
897            write_spectra_to_file(post_state)
898        write_qtb_restart(state, post_state)  # restart data is rewritten after every segment
899        return state, post_state
900
901    post_state["corr_pot"] = jnp.asarray(compute_corr_pot(niter=niter_deconv_pot), dtype=fprec)  # potential-energy correction, computed once at initialization
902
903    state["force"], post_state = refresh_force(post_state)  # random force for the first segment
904    return thermostat, (postprocess, post_state), state
def get_thermostat(simulation_parameters, dt, system_data, fprec, rng_key=None, restart_data=None):
    """Build the thermostat selected by ``simulation_parameters["thermostat"]``.

    Parameters
    ----------
    simulation_parameters : dict-like
        Simulation input. Keys read here: ``thermostat``,
        ``include_thermostat_energy``, ``gamma``, ``trpmd_lambda`` (PIMD only),
        ``qtb`` (QTB/ADQTB), ``annealing`` and ``nsteps`` (ANNEAL).
    dt : float
        Integration time step, in the inverse unit of ``gamma``
        (``exp(-gamma * dt)`` is the damping factor).
    system_data : dict
        Must contain ``mass`` and ``species``. ``kT`` and ``nbeads`` are read
        when present; ``omk`` is read when ``nbeads`` is set.
    fprec : dtype
        Floating point dtype of the generated arrays.
    rng_key : jax PRNG key, optional
        Required by the stochastic thermostats and to draw initial velocities.
    restart_data : dict, optional
        Not used by this function; kept for interface compatibility
        (default changed from a shared mutable ``{}`` to ``None``).

    Returns
    -------
    tuple
        ``(thermostat, postprocess, state, vel, thermostat_name)``.
        ``thermostat(vel, state)`` returns ``(new_vel, new_state)``;
        ``postprocess`` is ``None`` except for QTB/ADQTB; ``state`` is the
        initial thermostat state; ``vel`` are the initial velocities.

    Raises
    ------
    ValueError
        Unknown thermostat name or unknown annealing schedule type.
    """
    state = {}
    postprocess = None

    thermostat_name = str(simulation_parameters.get("thermostat", "LGV")).upper()
    """@keyword[fennol_md] thermostat
    Thermostat type. Options: 'NVE', 'LGV', 'FFLGV', 'BUSSI', 'GD', 'NOSE', 'QTB', 'ADQTB', 'ANNEAL'.
    Default: "LGV"
    """
    compute_thermostat_energy = simulation_parameters.get(
        "include_thermostat_energy", False
    )
    """@keyword[fennol_md] include_thermostat_energy
    Include thermostat energy in total energy calculations.
    Default: False
    """

    kT = system_data.get("kT", None)
    nbeads = system_data.get("nbeads", None)
    mass = system_data["mass"]
    gamma = simulation_parameters.get("gamma", 1.0 / au.THZ) / au.FS
    """@keyword[fennol_md] gamma
    Friction coefficient for Langevin thermostat.
    Default: 1.0 ps^-1
    """
    species = system_data["species"]

    if nbeads is not None:
        trpmd_lambda = simulation_parameters.get("trpmd_lambda", 1.0)
        """@keyword[fennol_md] trpmd_lambda
        Lambda parameter for TRPMD (Thermostatted Ring Polymer MD).
        Default: 1.0
        """
        # One friction per bead entry, at least trpmd_lambda * omk
        # (omk presumably holds the ring-polymer normal-mode frequencies).
        gamma = np.maximum(trpmd_lambda * system_data["omk"], gamma)

    if thermostat_name in ["LGV", "LANGEVIN", "FFLGV"]:
        assert rng_key is not None, "rng_key must be provided for Langevin thermostat"
        assert kT is not None, "kT must be specified for Langevin thermostat"
        assert gamma is not None, "gamma must be specified for Langevin thermostat"
        rng_key, v_key = jax.random.split(rng_key)
        if nbeads is None:
            # Exact Ornstein-Uhlenbeck step: v <- a1 * v + a2 * xi
            a1 = math.exp(-gamma * dt)
            a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)
            vel = (
                jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
                * (kT / mass[:, None]) ** 0.5
            )
        else:
            if isinstance(gamma, float):
                gamma = np.array([gamma] * nbeads)
            assert isinstance(
                gamma, np.ndarray
            ), "gamma must be a float or a numpy array"
            assert gamma.shape[0] == nbeads, "gamma must have the same length as nbeads"
            # one damping factor per bead, broadcast over atoms and xyz
            a1 = np.exp(-gamma * dt)[:, None, None]
            a2 = jnp.asarray(
                ((1 - a1 * a1) * kT / mass[None, :, None]) ** 0.5, dtype=fprec
            )
            vel = (
                jax.random.normal(v_key, (nbeads, mass.shape[0], 3), dtype=fprec)
                * (kT / mass[:, None]) ** 0.5
            )

        state["rng_key"] = rng_key
        if compute_thermostat_energy:
            state["thermostat_energy"] = 0.0

        if thermostat_name == "FFLGV":

            def thermostat(vel, state):
                # Keep each atom's velocity direction; only its norm comes
                # from the Langevin update.
                # NOTE(review): an exactly-zero atomic velocity gives NaN here.
                rng_key, noise_key = jax.random.split(state["rng_key"])
                noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
                norm_vel = jnp.linalg.norm(vel, axis=-1, keepdims=True)
                dirvel = vel / norm_vel
                if compute_thermostat_energy:
                    v2 = (vel**2).sum(axis=-1)
                vel = a1 * vel + a2 * noise
                new_norm_vel = jnp.linalg.norm(vel, axis=-1, keepdims=True)
                vel = dirvel * new_norm_vel
                new_state = {**state, "rng_key": rng_key}
                if compute_thermostat_energy:
                    v2new = (vel**2).sum(axis=-1)
                    new_state["thermostat_energy"] = (
                        state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
                    )
                return vel, new_state

        else:

            def thermostat(vel, state):
                rng_key, noise_key = jax.random.split(state["rng_key"])
                noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
                if compute_thermostat_energy:
                    v2 = (vel**2).sum(axis=-1)
                vel = a1 * vel + a2 * noise
                new_state = {**state, "rng_key": rng_key}
                if compute_thermostat_energy:
                    v2new = (vel**2).sum(axis=-1)
                    new_state["thermostat_energy"] = (
                        state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
                    )
                return vel, new_state

    elif thermostat_name in ["BUSSI"]:
        assert rng_key is not None, "rng_key must be provided for Bussi thermostat"
        assert kT is not None, "kT must be specified for Bussi thermostat"
        assert gamma is not None, "gamma must be specified for Bussi thermostat"
        assert nbeads is None, "Bussi thermostat is not compatible with PIMD"
        rng_key, v_key = jax.random.split(rng_key)

        a1 = math.exp(-gamma * dt)
        a2 = (1 - a1) * kT
        vel = (
            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
            * (kT / mass[:, None]) ** 0.5
        )

        state["rng_key"] = rng_key
        if compute_thermostat_energy:
            state["thermostat_energy"] = 0.0

        def thermostat(vel, state):
            # Bussi stochastic rescaling: one global scale factor per step.
            rng_key, noise_key = jax.random.split(state["rng_key"])
            new_state = {**state, "rng_key": rng_key}
            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
            R2 = jnp.sum(noise**2)  # sum over all 3N gaussians (R1 included)
            R1 = noise[0, 0]
            c = a2 / (mass[:, None] * vel**2).sum()
            d = (a1 * c) ** 0.5
            scale = (a1 + c * R2 + 2 * d * R1) ** 0.5
            if compute_thermostat_energy:
                dek = 0.5 * (mass[:, None] * vel**2).sum() * (scale**2 - 1)
                new_state["thermostat_energy"] = state["thermostat_energy"] + dek
            return scale * vel, new_state

    elif thermostat_name in [
        "GD",
        "DESCENT",
        "GRADIENT_DESCENT",
        "MIN",
        "MINIMIZE",
    ]:
        assert nbeads is None, "Gradient descent is not compatible with PIMD"
        a1 = math.exp(-gamma * dt)
        vel = jnp.zeros((mass.shape[0], 3), dtype=fprec)

        def thermostat(vel, state):
            # pure damping, no noise
            return a1 * vel, state

    elif thermostat_name in ["NVE", "NONE"]:
        assert rng_key is not None, "rng_key must be provided to draw initial velocities"
        assert kT is not None, "kT must be specified to draw initial velocities"
        # Draw Maxwell-Boltzmann velocities, then rescale so the instantaneous
        # temperature is exactly kT (per bead for PIMD).
        if nbeads is None:
            vel = (
                jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
                * (kT / mass[:, None]) ** 0.5
            )
            kTsys = jnp.sum(mass[:, None] * vel**2) / (mass.shape[0] * 3)
            vel = vel * (kT / kTsys) ** 0.5
        else:
            vel = (
                jax.random.normal(rng_key, (nbeads, mass.shape[0], 3), dtype=fprec)
                * (kT / mass[None, :, None]) ** 0.5
            )
            kTsys = jnp.sum(mass[None, :, None] * vel**2, axis=(1, 2)) / (
                mass.shape[0] * 3
            )
            vel = vel * (kT / kTsys[:, None, None]) ** 0.5

        def thermostat(vel, state):
            # identity: no thermostat
            return vel, state

    elif thermostat_name in ["NOSE", "NOSEHOOVER", "NOSE_HOOVER"]:
        assert rng_key is not None, "rng_key must be provided for Nose-Hoover thermostat"
        assert kT is not None, "kT must be specified for Nose-Hoover thermostat"
        assert gamma is not None, "gamma must be specified for Nose-Hoover thermostat"
        assert nbeads is None, "Nose-Hoover is not compatible with PIMD"
        ndof = mass.shape[0] * 3
        nkT = ndof * kT
        nose_mass = nkT / gamma**2
        state["nose_s"] = 0.0
        state["nose_v"] = 0.0
        if compute_thermostat_energy:
            state["thermostat_energy"] = 0.0
        print(
            "# WARNING: Nose-Hoover thermostat is not well tested yet. Energy conservation is not guaranteed."
        )
        vel = (
            jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
            * (kT / mass[:, None]) ** 0.5
        )

        def thermostat(vel, state):
            # half kick of nose_v, full update of nose_s and velocity scaling,
            # second half kick of nose_v with the rescaled kinetic energy
            nose_s = state["nose_s"]
            nose_v = state["nose_v"]
            kTsys = jnp.sum(mass[:, None] * vel**2)
            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
            nose_s = nose_s + dt * nose_v
            vel = jnp.exp(-nose_v * dt) * vel
            kTsys = jnp.sum(mass[:, None] * vel**2)
            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
            new_state = {**state, "nose_s": nose_s, "nose_v": nose_v}

            if compute_thermostat_energy:
                new_state["thermostat_energy"] = (
                    nkT * nose_s + (0.5 * nose_mass) * nose_v**2
                )
            return vel, new_state

    elif thermostat_name in ["QTB", "ADQTB"]:
        assert nbeads is None, "QTB is not compatible with PIMD"
        qtb_parameters = simulation_parameters.get("qtb", None)
        """@keyword[fennol_md] qtb
        Parameters for Quantum Thermal Bath thermostat configuration.
        Required for QTB/ADQTB thermostats
        """
        assert (
            qtb_parameters is not None
        ), "qtb_parameters must be provided for QTB thermostat"
        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
        assert kT is not None, "kT must be specified for QTB thermostat"
        assert gamma is not None, "gamma must be specified for QTB thermostat"
        assert species is not None, "species must be provided for QTB thermostat"
        rng_key, v_key = jax.random.split(rng_key)
        vel = (
            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
            * (kT / mass[:, None]) ** 0.5
        )

        thermostat, postprocess, qtb_state = initialize_qtb(
            qtb_parameters,
            system_data,
            fprec=fprec,
            dt=dt,
            mass=mass,
            gamma=gamma,
            kT=kT,
            species=species,
            rng_key=rng_key,
            adaptive=thermostat_name.startswith("AD"),
            compute_thermostat_energy=compute_thermostat_energy,
        )
        state = {**state, **qtb_state}

    elif thermostat_name in ["ANNEAL", "ANNEALING"]:
        assert rng_key is not None, "rng_key must be provided for ANNEAL thermostat"
        assert kT is not None, "kT must be specified for ANNEAL thermostat"
        assert gamma is not None, "gamma must be specified for ANNEAL thermostat"
        assert nbeads is None, "ANNEAL is not compatible with PIMD"
        a1 = math.exp(-gamma * dt)
        a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)

        anneal_parameters = simulation_parameters.get("annealing", {})
        """@keyword[fennol_md] annealing
        Parameters for simulated annealing schedule configuration.
        Required for ANNEAL/ANNEALING thermostat
        """
        init_factor = anneal_parameters.get("init_factor", 1.0 / 25.0)
        """@keyword[fennol_md] annealing/init_factor
        Initial temperature factor for annealing schedule.
        Default: 0.04 (1/25)
        """
        assert init_factor > 0.0, "init_factor must be positive"
        final_factor = anneal_parameters.get("final_factor", 1.0 / 10000.0)
        """@keyword[fennol_md] annealing/final_factor
        Final temperature factor for annealing schedule.
        Default: 0.0001 (1/10000)
        """
        assert final_factor > 0.0, "final_factor must be positive"
        nsteps = simulation_parameters.get("nsteps")
        """@keyword[fennol_md] nsteps
        Total number of simulation steps for annealing schedule calculation.
        Required parameter
        """
        assert nsteps is not None, "nsteps must be specified for ANNEAL thermostat"
        anneal_steps = anneal_parameters.get("anneal_steps", 1.0)
        """@keyword[fennol_md] annealing/anneal_steps
        Fraction of total steps for annealing process.
        Default: 1.0
        """
        # 1.0 (anneal over the whole run) is the default and must be accepted.
        assert 0.0 < anneal_steps <= 1.0, "anneal_steps must be in (0, 1]"
        pct_start = anneal_parameters.get("warmup_steps", 0.3)
        """@keyword[fennol_md] annealing/warmup_steps
        Fraction of annealing steps for warmup phase.
        Default: 0.3
        """
        assert 0.0 < pct_start < 1.0, "warmup_steps must be between 0 and 1"

        anneal_type = anneal_parameters.get("type", "cosine").lower()
        """@keyword[fennol_md] annealing/type
        Type of annealing schedule (linear, cosine or cosine_onecycle).
        Default: cosine
        """
        transition_steps = int(anneal_steps * nsteps)
        if anneal_type == "linear":
            schedule = optax.linear_onecycle_schedule(
                peak_value=1.0,
                div_factor=1.0 / init_factor,
                final_div_factor=1.0 / final_factor,
                transition_steps=transition_steps,
                pct_start=pct_start,
                pct_final=1.0,
            )
        elif anneal_type in ("cosine", "cosine_onecycle"):
            # "cosine" (the documented default) is an alias of "cosine_onecycle"
            schedule = optax.cosine_onecycle_schedule(
                peak_value=1.0,
                div_factor=1.0 / init_factor,
                final_div_factor=1.0 / final_factor,
                transition_steps=transition_steps,
                pct_start=pct_start,
            )
        else:
            raise ValueError(f"Unknown anneal_type {anneal_type}")

        # Split before storing the key: otherwise the thermostat's first split
        # of state["rng_key"] regenerates v_key and reuses the exact normals
        # drawn for the initial velocities.
        rng_key, v_key = jax.random.split(rng_key)
        state["rng_key"] = rng_key
        state["istep_anneal"] = 0

        Tscale = schedule(0)
        print(f"# ANNEAL: initial temperature = {Tscale*kT*au.KELVIN:.3e} K")
        vel = (
            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
            * (kT * Tscale / mass[:, None]) ** 0.5
        )

        def thermostat(vel, state):
            rng_key, noise_key = jax.random.split(state["rng_key"])
            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)

            # schedule() scales the temperature, so the noise amplitude
            # scales with its square root
            Tscale = schedule(state["istep_anneal"]) ** 0.5
            vel = a1 * vel + a2 * Tscale * noise
            return vel, {
                **state,
                "rng_key": rng_key,
                "istep_anneal": state["istep_anneal"] + 1,
            }

    else:
        raise ValueError(f"Unknown thermostat {thermostat_name}")

    return thermostat, postprocess, state, vel, thermostat_name
def initialize_qtb( qtb_parameters, system_data, fprec, dt, mass, gamma, kT, species, rng_key, adaptive, compute_thermostat_energy=False):
362def initialize_qtb(
363    qtb_parameters,
364    system_data,
365    fprec,
366    dt,
367    mass,
368    gamma,
369    kT,
370    species,
371    rng_key,
372    adaptive,
373    compute_thermostat_energy=False,
374):
375    state = {}
376    post_state = {}
377    verbose = qtb_parameters.get("verbose", False)
378    """@keyword[fennol_md] qtb/verbose
379    Print verbose QTB thermostat information.
380    Default: False
381    """
382    if compute_thermostat_energy:
383        state["thermostat_energy"] = 0.0
384
385    mass = jnp.asarray(mass, dtype=fprec)
386
387    nat = species.shape[0]
388    # define type indices
389    species_set = set(species)
390    nspecies = len(species_set)
391    idx = {sp: i for i, sp in enumerate(species_set)}
392    type_idx = np.array([idx[sp] for sp in species], dtype=np.int32)
393
394    n_of_type = np.zeros(nspecies, dtype=np.int32)
395    for i in range(nspecies):
396        n_of_type[i] = (type_idx == i).nonzero()[0].shape[0]
397    n_of_type = jnp.asarray(n_of_type, dtype=fprec)
398    mass_idx = jax.ops.segment_sum(mass, type_idx, nspecies) / n_of_type
399
400    niter_deconv_kin = qtb_parameters.get("niter_deconv_kin", 7)
401    """@keyword[fennol_md] qtb/niter_deconv_kin
402    Number of iterations for kinetic energy deconvolution.
403    Default: 7
404    """
405    niter_deconv_pot = qtb_parameters.get("niter_deconv_pot", 20)
406    """@keyword[fennol_md] qtb/niter_deconv_pot
407    Number of iterations for potential energy deconvolution.
408    Default: 20
409    """
410    corr_kin = qtb_parameters.get("corr_kin", -1)
411    """@keyword[fennol_md] qtb/corr_kin
412    Kinetic energy correction factor for QTB (-1 for automatic).
413    Default: -1
414    """
415    do_corr_kin = corr_kin <= 0
416    if do_corr_kin:
417        corr_kin = 1.0
418    state["corr_kin"] = corr_kin
419    post_state["corr_kin_prev"] = corr_kin
420    post_state["do_corr_kin"] = do_corr_kin
421    post_state["isame_kin"] = 0
422
423    # spectra parameters
424    omegasmear = np.pi / dt / 100.0
425    Tseg = qtb_parameters.get("tseg", 1.0 / au.PS) * au.FS
426    """@keyword[fennol_md] qtb/tseg
427    Time segment length for QTB spectrum calculation.
428    Default: 1.0 ps
429    """
430    nseg = int(Tseg / dt)
431    Tseg = nseg * dt
432    dom = 2 * np.pi / (3 * Tseg)
433    omegacut = qtb_parameters.get("omegacut", 15000.0 / au.CM1) / au.FS
434    """@keyword[fennol_md] qtb/omegacut
435    Cutoff frequency for QTB spectrum.
436    Default: 15000.0 cm⁻¹
437    """
438    nom = int(omegacut / dom)
439    omega = dom * np.arange((3 * nseg) // 2 + 1)
440    cutoff = jnp.asarray(
441        1.0 / (1.0 + np.exp((omega - omegacut) / omegasmear)), dtype=fprec
442    )
443    assert (
444        omegacut < omega[-1]
445    ), f"omegacut must be smaller than {omega[-1]*au.CM1} CM-1"
446
447    # initialize gammar
448    assert (
449        gamma < 0.5 * omegacut
450    ), "gamma must be much smaller than omegacut (at most 0.5*omegacut)"
451    gammar_min = qtb_parameters.get("gammar_min", 0.1)
452    """@keyword[fennol_md] qtb/gammar_min
453    Minimum value for QTB gamma ratio coefficients.
454    Default: 0.1
455    """
456    # post_state["gammar"] = jnp.asarray(np.ones((nspecies, nom)), dtype=fprec)
457    gammar = np.ones((nspecies, nom), dtype=float)
458    try:
459        for i, sp in enumerate(species_set):
460            if not os.path.exists(f"QTB_spectra_{sp}.out"): continue
461            data = np.loadtxt(f"QTB_spectra_{sp}.out")
462            gammar[i] = data[:, 4]/(gamma*au.FS*au.THZ)
463            print(f"# Restored gammar for species {sp} from QTB_spectra_{sp}.out")
464    except Exception as e:
465        print(f"# Could not restore gammar for all species with exception {e}. Starting from scratch.")
466        gammar[:,:] = 1.0
467    post_state["gammar"] = jnp.asarray(gammar, dtype=fprec)
468
469    # Ornstein-Uhlenbeck correction for colored noise
470    a1 = np.exp(-gamma * dt)
471    OUcorr = jnp.asarray(
472        (1.0 - 2.0 * a1 * np.cos(omega * dt) + a1**2) / (dt**2 * (gamma**2 + omega**2)),
473        dtype=fprec,
474    )
475
476    # hbar schedule
477    classical_kernel = qtb_parameters.get("classical_kernel", False)
478    """@keyword[fennol_md] qtb/classical_kernel
479    Use classical instead of quantum kernel for QTB.
480    Default: False
481    """
482    hbar = qtb_parameters.get("hbar", 1.0) * au.FS
483    """@keyword[fennol_md] qtb/hbar
484    Reduced Planck constant scaling factor for quantum effects.
485    Default: 1.0 a.u.
486    """
487    u = 0.5 * hbar * np.abs(omega) / kT
488    theta = kT * np.ones_like(omega)
489    if hbar > 0:
490        theta[1:] *= u[1:] / np.tanh(u[1:])
491    theta = jnp.asarray(theta, dtype=fprec)
492
493    noise_key, post_state["rng_key"] = jax.random.split(rng_key)
494    del rng_key
495    post_state["white_noise"] = jax.random.normal(
496        noise_key, (3 * nseg, nat, 3), dtype=jnp.float32
497    )
498
499    startsave = qtb_parameters.get("startsave", 1)
500    """@keyword[fennol_md] qtb/startsave
501    Start saving QTB statistics after this many segments.
502    Default: 1
503    """
504    counter = Counter(nseg, startsave=startsave)
505    state["istep"] = 0
506    post_state["nadapt"] = 0
507    post_state["nsample"] = 0
508
509    write_spectra = qtb_parameters.get("write_spectra", True)
510    """@keyword[fennol_md] qtb/write_spectra
511    Write QTB spectral analysis output files.
512    Default: True
513    """
514    do_compute_spectra = write_spectra or adaptive
515
516    if do_compute_spectra:
517        state["vel"] = jnp.zeros((nseg, nat, 3), dtype=fprec)
518
519        post_state["dFDT"] = jnp.zeros((nspecies, nom), dtype=fprec)
520        post_state["mCvv"] = jnp.zeros((nspecies, nom), dtype=fprec)
521        post_state["Cvf"] = jnp.zeros((nspecies, nom), dtype=fprec)
522        post_state["Cff"] = jnp.zeros((nspecies, nom), dtype=fprec)
523        post_state["dFDT_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
524        post_state["mCvv_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
525        post_state["Cvfg_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
526        post_state["Cff_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
527
528    if not adaptive:
529        update_gammar = lambda x: x
530    else:
531        # adaptation parameters
532        skipseg = qtb_parameters.get("skipseg", 1)
533        """@keyword[fennol_md] qtb/skipseg
534        Number of segments to skip before starting adaptive QTB.
535        Default: 1
536        """
537
538        adaptation_method = (
539            str(qtb_parameters.get("adaptation_method", "ADABELIEF")).upper().strip()
540        )
541        """@keyword[fennol_md] qtb/adaptation_method
542        Method for adaptive QTB (SIMPLE, RATIO, ADABELIEF).
543        Default: ADABELIEF
544        """
545        authorized_methods = ["SIMPLE", "RATIO", "ADABELIEF"]
546        assert (
547            adaptation_method in authorized_methods
548        ), f"adaptation_method must be one of {authorized_methods}"
549        if adaptation_method == "SIMPLE":
550            agamma = qtb_parameters.get("agamma", 1.0e-3) / au.FS
551            """@keyword[fennol_md] qtb/agamma
552            Learning rate for adaptive QTB gamma update.
553            Default: 1.0e-3 (SIMPLE), 0.1 (ADABELIEF)
554            """
555            assert agamma > 0, "agamma must be positive"
556            a1_ad = agamma * Tseg  #  * gamma
557            print(f"# ADQTB SIMPLE: agamma = {agamma*au.FS:.3f}")
558
559            def update_gammar(post_state):
560                g = post_state["dFDT"]
561                gammar = post_state["gammar"] - a1_ad * g
562                gammar = jnp.maximum(gammar_min, gammar)
563                return {**post_state, "gammar": gammar}
564
565        elif adaptation_method == "RATIO":
566            tau_ad = qtb_parameters.get("tau_ad", 5.0 / au.PS) * au.FS
567            """@keyword[fennol_md] qtb/tau_ad
568            Adaptation time constant for momentum averaging.
569            Default: 5.0 ps (RATIO), 1.0 ps (ADABELIEF)
570            """
571            tau_s = qtb_parameters.get("tau_s", 10 * tau_ad) * au.FS
572            """@keyword[fennol_md] qtb/tau_s
573            Second moment time constant for variance averaging.
574            Default: 10*tau_ad (RATIO), 100*tau_ad (ADABELIEF)
575            """
576            assert tau_ad > 0, "tau_ad must be positive"
577            print(
578                f"# ADQTB RATIO: tau_ad = {tau_ad*1e-3:.2f} ps, tau_s = {tau_s*1e-3:.2f} ps"
579            )
580            b1 = np.exp(-Tseg / tau_ad)
581            b2 = np.exp(-Tseg / tau_s)
582            post_state["mCvv_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
583            post_state["Cvf_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
584            post_state["n_adabelief"] = 0
585
586            def update_gammar(post_state):
587                n_adabelief = post_state["n_adabelief"] + 1
588                mCvv_m = post_state["mCvv_m"] * b1 + post_state["mCvv"] * (1.0 - b1)
589                Cvf_m = post_state["Cvf_m"] * b2 + post_state["Cvf"] * (1.0 - b2)
590                mCvv = mCvv_m / (1.0 - b1**n_adabelief)
591                Cvf = Cvf_m / (1.0 - b2**n_adabelief)
592                # g = Cvf/(mCvv+1.e-8)-post_state["gammar"]
593                gammar = Cvf / (mCvv + 1.0e-8)
594                gammar = jnp.maximum(gammar_min, gammar)
595                return {
596                    **post_state,
597                    "gammar": gammar,
598                    "mCvv_m": mCvv_m,
599                    "Cvf_m": Cvf_m,
600                    "n_adabelief": n_adabelief,
601                }
602
603        elif adaptation_method == "ADABELIEF":
604            agamma = qtb_parameters.get("agamma", 0.1)
605            tau_ad = qtb_parameters.get("tau_ad", 1.0 / au.PS) * au.FS
606            tau_s = qtb_parameters.get("tau_s", 100 * tau_ad) * au.FS
607            assert tau_ad > 0, "tau_ad must be positive"
608            assert tau_s > 0, "tau_s must be positive"
609            assert agamma > 0, "agamma must be positive"
610            print(
611                f"# ADQTB ADABELIEF: agamma = {agamma:.3f}, tau_ad = {tau_ad*1.e-3:.2f} ps, tau_s = {tau_s*1.e-3:.2f} ps"
612            )
613
614            a1_ad = agamma * gamma  # * Tseg #* gamma
615            b1 = np.exp(-Tseg / tau_ad)
616            b2 = np.exp(-Tseg / tau_s)
617            post_state["dFDT_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
618            post_state["dFDT_s"] = jnp.zeros((nspecies, nom), dtype=fprec)
619            post_state["n_adabelief"] = 0
620
621            def update_gammar(post_state):
622                n_adabelief = post_state["n_adabelief"] + 1
623                dFDT = post_state["dFDT"]
624                dFDT_m = post_state["dFDT_m"] * b1 + dFDT * (1.0 - b1)
625                dFDT_s = (
626                    post_state["dFDT_s"] * b2
627                    + (dFDT - dFDT_m) ** 2 * (1.0 - b2)
628                    + 1.0e-8
629                )
630                # bias correction
631                mt = dFDT_m / (1.0 - b1**n_adabelief)
632                st = dFDT_s / (1.0 - b2**n_adabelief)
633                gammar = post_state["gammar"] - a1_ad * mt / (st**0.5 + 1.0e-8)
634                gammar = jnp.maximum(gammar_min, gammar)
635                return {
636                    **post_state,
637                    "gammar": gammar,
638                    "dFDT_m": dFDT_m,
639                    "n_adabelief": n_adabelief,
640                    "dFDT_s": dFDT_s,
641                }
642    
643    #####################
644    # RESTART
645    restart_file = system_data["name"]+".qtb.restart"
646    if os.path.exists(restart_file):
647        with open(restart_file, "rb") as f:
648            data = pickle.load(f)
649            state["corr_kin"] = data["corr_kin"]
650            post_state["corr_kin_prev"] = data["corr_kin"]
651            post_state["isame_kin"] = data["isame_kin"]
652            post_state["do_corr_kin"] = data["do_corr_kin"]
653            print(f"# Restored QTB state from {restart_file}")
654
655    def write_qtb_restart(state, post_state):
656        with open(restart_file, "wb") as f:
657            pickle.dump(
658                {
659                    "corr_kin": state["corr_kin"],
660                    "corr_kin_prev": post_state["corr_kin_prev"],
661                    "isame_kin": post_state["isame_kin"],
662                    "do_corr_kin": post_state["do_corr_kin"],
663                },
664                f,
665            )
666    ######################
667
668    def compute_corr_pot(niter=20, verbose=False):
669        if classical_kernel or hbar == 0:
670            return np.ones(nom)
671
672        s_0 = np.array((theta / kT * cutoff)[:nom])
673        s_out, s_rec, _ = deconvolute_spectrum(
674            s_0,
675            omega[:nom],
676            gamma,
677            niter,
678            kernel=kernel_lorentz_pot,
679            trans=True,
680            symmetrize=True,
681            verbose=verbose,
682        )
683        corr_pot = 1.0 + (s_out - s_0) / s_0
684        columns = np.column_stack(
685            (omega[:nom] * au.CM1, corr_pot - 1.0, s_0, s_out, s_rec)
686        )
687        np.savetxt(
688            "corr_pot.dat", columns, header="omega(cm-1) corr_pot s_0 s_out s_rec"
689        )
690        return corr_pot
691
692    def compute_corr_kin(post_state, niter=7, verbose=False):
693        if not post_state["do_corr_kin"]:
694            return post_state["corr_kin_prev"], post_state
695        if classical_kernel or hbar == 0:
696            return 1.0, post_state
697
698        K_D = post_state.get("K_D", None)
699        mCvv = (post_state["mCvv_avg"][:, :nom] * n_of_type[:, None]).sum(axis=0) / nat
700        s_0 = np.array(mCvv * kT / theta[:nom] / post_state["corr_pot"])
701        s_out, s_rec, K_D = deconvolute_spectrum(
702            s_0,
703            omega[:nom],
704            gamma,
705            niter,
706            kernel=kernel_lorentz,
707            trans=False,
708            symmetrize=True,
709            verbose=verbose,
710            K_D=K_D,
711        )
712        s_out = s_out * theta[:nom] / kT
713        s_rec = s_rec * theta[:nom] / kT * post_state["corr_pot"]
714        mCvvsum = mCvv.sum()
715        rec_ratio = mCvvsum / s_rec.sum()
716        if rec_ratio < 0.95 or rec_ratio > 1.05:
717            print(
718                f"# WARNING: reconvolution error {rec_ratio} is too high, corr_kin was not updated"
719            )
720            return post_state["corr_kin_prev"], post_state
721
722        corr_kin = mCvvsum / s_out.sum()
723        if np.abs(corr_kin - post_state["corr_kin_prev"]) < 1.0e-4:
724            isame_kin = post_state["isame_kin"] + 1
725        else:
726            isame_kin = 0
727
728        # print("# corr_kin: ", corr_kin)
729        do_corr_kin = post_state["do_corr_kin"]
730        if isame_kin > 10:
731            print(
732                "# INFO: corr_kin is converged (it did not change for 10 consecutive segments)"
733            )
734            do_corr_kin = False
735
736        return corr_kin, {
737            **post_state,
738            "corr_kin_prev": corr_kin,
739            "isame_kin": isame_kin,
740            "do_corr_kin": do_corr_kin,
741            "K_D": K_D,
742        }
743
744    @jax.jit
745    def ff_kernel(post_state):
746        if classical_kernel:
747            kernel = cutoff * (2 * gamma * kT / dt)
748        else:
749            kernel = theta * cutoff * OUcorr * (2 * gamma / dt)
750        gamma_ratio = jnp.concatenate(
751            (
752                post_state["gammar"].T * post_state["corr_pot"][:, None],
753                jnp.ones(
754                    (kernel.shape[0] - nom, nspecies), dtype=post_state["gammar"].dtype
755                ),
756            ),
757            axis=0,
758        )
759        return kernel[:, None] * gamma_ratio * mass_idx[None, :]
760
761    @jax.jit
762    def refresh_force(post_state):
763        rng_key, noise_key = jax.random.split(post_state["rng_key"])
764        white_noise = jnp.concatenate(
765            (
766                post_state["white_noise"][nseg:],
767                jax.random.normal(
768                    noise_key, (nseg, nat, 3), dtype=post_state["white_noise"].dtype
769                ),
770            ),
771            axis=0,
772        )
773        amplitude = ff_kernel(post_state) ** 0.5
774        s = jnp.fft.rfft(white_noise, 3 * nseg, axis=0) * amplitude[:, type_idx, None]
775        force = jnp.fft.irfft(s, 3 * nseg, axis=0)[nseg : 2 * nseg]
776        return force, {**post_state, "rng_key": rng_key, "white_noise": white_noise}
777
778    @jax.jit
779    def compute_spectra(force, vel, post_state):
780        sf = jnp.fft.rfft(force / gamma, 3 * nseg, axis=0, norm="ortho")
781        sv = jnp.fft.rfft(vel, 3 * nseg, axis=0, norm="ortho")
782        Cvv = jnp.sum(jnp.abs(sv[:nom]) ** 2, axis=-1).T
783        Cff = jnp.sum(jnp.abs(sf[:nom]) ** 2, axis=-1).T
784        Cvf = jnp.sum(jnp.real(sv[:nom] * jnp.conj(sf[:nom])), axis=-1).T
785
786        mCvv = (
787            (dt / 3.0)
788            * jnp.zeros_like(post_state["mCvv"]).at[type_idx].add(Cvv)
789            * mass_idx[:, None]
790            / n_of_type[:, None]
791        )
792        Cvf = (
793            (dt / 3.0)
794            * jnp.zeros_like(post_state["Cvf"]).at[type_idx].add(Cvf)
795            / n_of_type[:, None]
796        )
797        Cff = (
798            (dt / 3.0)
799            * jnp.zeros_like(post_state["Cff"]).at[type_idx].add(Cff)
800            / n_of_type[:, None]
801        )
802        dFDT = mCvv * post_state["gammar"] - Cvf
803
804        nsinv = 1.0 / post_state["nsample"]
805        b1 = 1.0 - nsinv
806        dFDT_avg = post_state["dFDT_avg"] * b1 + dFDT * nsinv
807        mCvv_avg = post_state["mCvv_avg"] * b1 + mCvv * nsinv
808        Cvfg_avg = post_state["Cvfg_avg"] * b1 + Cvf / post_state["gammar"] * nsinv
809        Cff_avg = post_state["Cff_avg"] * b1 + Cff * nsinv
810
811        return {
812            **post_state,
813            "mCvv": mCvv,
814            "Cvf": Cvf,
815            "Cff": Cff,
816            "dFDT": dFDT,
817            "dFDT_avg": dFDT_avg,
818            "mCvv_avg": mCvv_avg,
819            "Cvfg_avg": Cvfg_avg,
820            "Cff_avg": Cff_avg,
821        }
822
823    def write_spectra_to_file(post_state):
824        mCvv_avg = np.array(post_state["mCvv_avg"])
825        Cvfg_avg = np.array(post_state["Cvfg_avg"])
826        Cff_avg = np.array(post_state["Cff_avg"]) * 3.0 / dt / (gamma**2)
827        dFDT_avg = np.array(post_state["dFDT_avg"])
828        gammar = np.array(post_state["gammar"])
829        Cff_theo = np.array(ff_kernel(post_state))[:nom].T
830        for i, sp in enumerate(species_set):
831            ff_scale = au.KELVIN / ((2 * gamma / dt) * mass_idx[i])
832            columns = np.column_stack(
833                (
834                    omega[:nom] * (au.FS * au.CM1),
835                    mCvv_avg[i],
836                    Cvfg_avg[i],
837                    dFDT_avg[i],
838                    gammar[i] * gamma * (au.FS * au.THZ),
839                    Cff_avg[i] * ff_scale,
840                    Cff_theo[i] * ff_scale,
841                )
842            )
843            np.savetxt(
844                f"QTB_spectra_{sp}.out",
845                columns,
846                fmt="%12.6f",
847                header="#omega mCvv Cvf dFDT gammar Cff",
848            )
849        if verbose:
850            print("# QTB spectra written.")
851
852    if compute_thermostat_energy:
853        state["qtb_energy_flux"] = 0.0
854
855    @jax.jit
856    def thermostat(vel, state):
857        istep = state["istep"]
858        dvel = dt * state["force"][istep] / mass[:, None]
859        new_vel = vel * a1 + dvel
860        new_state = {**state, "istep": istep + 1}
861        if do_compute_spectra:
862            vel2 = state["vel"].at[istep].set(vel * a1**0.5 + 0.5 * dvel)
863            new_state["vel"] = vel2
864        if compute_thermostat_energy:
865            dek = 0.5 * (mass[:, None] * (vel**2 - new_vel**2)).sum()
866            ekcorr = (
867                0.5
868                * (mass[:, None] * new_vel**2).sum()
869                * (1.0 - 1.0 / state.get("corr_kin", 1.0))
870            )
871            new_state["qtb_energy_flux"] = state["qtb_energy_flux"] + dek
872            new_state["thermostat_energy"] = new_state["qtb_energy_flux"] + ekcorr
873        return new_vel, new_state
874
875    @jax.jit
876    def postprocess_work(state, post_state):
877        if do_compute_spectra:
878            post_state = compute_spectra(state["force"], state["vel"], post_state)
879        if adaptive:
880            post_state = jax.lax.cond(
881                post_state["nadapt"] > skipseg, update_gammar, lambda x: x, post_state
882            )
883        new_force, post_state = refresh_force(post_state)
884        return {**state, "force": new_force}, post_state
885
886    def postprocess(state, post_state):
887        counter.increment()
888        if not counter.is_reset_step:
889            return state, post_state
890        post_state["nadapt"] += 1
891        post_state["nsample"] = max(post_state["nadapt"] - startsave + 1, 1)
892        if verbose:
893            print("# Refreshing QTB forces.")
894        state, post_state = postprocess_work(state, post_state)
895        state["corr_kin"], post_state = compute_corr_kin(post_state, niter=niter_deconv_kin)
896        state["istep"] = 0
897        if write_spectra:
898            write_spectra_to_file(post_state)
899        write_qtb_restart(state, post_state)
900        return state, post_state
901
902    post_state["corr_pot"] = jnp.asarray(compute_corr_pot(niter=niter_deconv_pot), dtype=fprec)
903
904    state["force"], post_state = refresh_force(post_state)
905    return thermostat, (postprocess, post_state), state