fennol.md.thermostats

  1import numpy as np
  2import flax.linen as nn
  3import jax
  4import jax.numpy as jnp
  5import math
  6import optax
  7import os
  8import pickle
  9
 10from .utils import us 
 11from ..utils import Counter,read_tinker_interval
 12from ..utils.deconvolution import (
 13    deconvolute_spectrum,
 14    kernel_lorentz_pot,
 15    kernel_lorentz,
 16)
 17from ..utils.periodic_table import PERIODIC_TABLE
 18
 19
 20def get_thermostat(simulation_parameters, dt, system_data, fprec, rng_key=None, restart_data={}):
 21    state = {}
 22    postprocess = None
 23    
 24
 25    thermostat_name = str(simulation_parameters.get("thermostat", "LGV")).upper()
 26    """@keyword[fennol_md] thermostat
 27    Thermostat type. Options: 'NVE', 'LGV', 'NOSE', 'ADQTB'.
 28    Default: "LGV"
 29    """
 30    compute_thermostat_energy = simulation_parameters.get(
 31        "include_thermostat_energy", False
 32    )
 33    """@keyword[fennol_md] include_thermostat_energy
 34    Include thermostat energy in total energy calculations.
 35    Default: False
 36    """
 37
 38    kT = system_data.get("kT", None)
 39    nbeads = system_data.get("nbeads", None)
 40    mass = system_data["mass"]
 41    gamma = simulation_parameters.get("gamma", 1.0 / us.THZ)
 42    """@keyword[fennol_md] gamma
 43    Friction coefficient for Langevin thermostat.
 44    Default: 1.0 ps^-1
 45    """
 46    species = system_data["species"]
 47
 48    if nbeads is not None:
 49        trpmd_lambda = simulation_parameters.get("trpmd_lambda", 1.0)
 50        """@keyword[fennol_md] trpmd_lambda
 51        Lambda parameter for TRPMD (Thermostatted Ring Polymer MD).
 52        Default: 1.0
 53        """
 54        gamma = np.maximum(trpmd_lambda * system_data["omk"], gamma)
 55
 56    if thermostat_name in ["LGV", "LANGEVIN", "FFLGV"]:
 57        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
 58        assert kT is not None, "kT must be specified for QTB thermostat"
 59        assert gamma is not None, "gamma must be specified for QTB thermostat"
 60        rng_key, v_key = jax.random.split(rng_key)
 61        if nbeads is None:
 62            a1 = math.exp(-gamma * dt)
 63            a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)
 64            vel = (
 65                jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
 66                * (kT / mass[:, None]) ** 0.5
 67            )
 68        else:
 69            if isinstance(gamma, float):
 70                gamma = np.array([gamma] * nbeads)
 71            assert isinstance(
 72                gamma, np.ndarray
 73            ), "gamma must be a float or a numpy array"
 74            assert gamma.shape[0] == nbeads, "gamma must have the same length as nbeads"
 75            a1 = np.exp(-gamma * dt)[:, None, None]
 76            a2 = jnp.asarray(
 77                ((1 - a1 * a1) * kT / mass[None, :, None]) ** 0.5, dtype=fprec
 78            )
 79            vel = (
 80                jax.random.normal(v_key, (nbeads, mass.shape[0], 3), dtype=fprec)
 81                * (kT / mass[:, None]) ** 0.5
 82            )
 83
 84        state["rng_key"] = rng_key
 85        if compute_thermostat_energy:
 86            state["thermostat_energy"] = 0.0
 87        if thermostat_name == "FFLGV":
 88            def thermostat(vel, state):
 89                rng_key, noise_key = jax.random.split(state["rng_key"])
 90                noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
 91                norm_vel = jnp.linalg.norm(vel, axis=-1, keepdims=True)
 92                dirvel = vel / norm_vel
 93                if compute_thermostat_energy:
 94                    v2 = (vel**2).sum(axis=-1)
 95                vel = a1 * vel + a2 * noise
 96                new_norm_vel = jnp.linalg.norm(vel, axis=-1, keepdims=True)
 97                vel = dirvel * new_norm_vel
 98                new_state = {**state, "rng_key": rng_key}
 99                if compute_thermostat_energy:
100                    v2new = (vel**2).sum(axis=-1)
101                    new_state["thermostat_energy"] = (
102                        state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
103                    )
104
105                return vel, new_state
106
107        else:
108            def thermostat(vel, state):
109                rng_key, noise_key = jax.random.split(state["rng_key"])
110                noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
111                if compute_thermostat_energy:
112                    v2 = (vel**2).sum(axis=-1)
113                vel = a1 * vel + a2 * noise
114                new_state = {**state, "rng_key": rng_key}
115                if compute_thermostat_energy:
116                    v2new = (vel**2).sum(axis=-1)
117                    new_state["thermostat_energy"] = (
118                        state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
119                    )
120                return vel, new_state
121
122    elif thermostat_name in ["BUSSI"]:
123        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
124        assert kT is not None, "kT must be specified for QTB thermostat"
125        assert gamma is not None, "gamma must be specified for QTB thermostat"
126        assert nbeads is None, "Bussi thermostat is not compatible with PIMD"
127        rng_key, v_key = jax.random.split(rng_key)
128
129        a1 = math.exp(-gamma * dt)
130        a2 = (1 - a1) * kT
131        vel = (
132            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
133            * (kT / mass[:, None]) ** 0.5
134        )
135
136        state["rng_key"] = rng_key
137        if compute_thermostat_energy:
138            state["thermostat_energy"] = 0.0
139
140        def thermostat(vel, state):
141            rng_key, noise_key = jax.random.split(state["rng_key"])
142            new_state = {**state, "rng_key": rng_key}
143            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
144            R2 = jnp.sum(noise**2)
145            R1 = noise[0, 0]
146            c = a2 / (mass[:, None] * vel**2).sum()
147            d = (a1 * c) ** 0.5
148            scale = (a1 + c * R2 + 2 * d * R1) ** 0.5
149            if compute_thermostat_energy:
150                dek = 0.5 * (mass[:, None] * vel**2).sum() * (scale**2 - 1)
151                new_state["thermostat_energy"] = state["thermostat_energy"] + dek
152            return scale * vel, new_state
153
154    elif thermostat_name in [
155        "GD",
156        "DESCENT",
157        "GRADIENT_DESCENT",
158        "MIN",
159        "MINIMIZE",
160    ]:
161        assert nbeads is None, "Gradient descent is not compatible with PIMD"
162        a1 = math.exp(-gamma * dt)
163
164        if nbeads is None:
165            vel = jnp.zeros((mass.shape[0], 3), dtype=fprec)
166        else:
167            vel = jnp.zeros((nbeads, mass.shape[0], 3), dtype=fprec)
168
169        def thermostat(vel, state):
170            return a1 * vel, state
171
172    elif thermostat_name in ["NVE", "NONE"]:
173        if nbeads is None:
174            vel = (
175                jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
176                * (kT / mass[:, None]) ** 0.5
177            )
178            kTsys = jnp.sum(mass[:, None] * vel**2) / (mass.shape[0] * 3)
179            vel = vel * (kT / kTsys) ** 0.5
180        else:
181            vel = (
182                jax.random.normal(rng_key, (nbeads, mass.shape[0], 3), dtype=fprec)
183                * (kT / mass[None, :, None]) ** 0.5
184            )
185            kTsys = jnp.sum(mass[None, :, None] * vel**2, axis=(1, 2)) / (
186                mass.shape[0] * 3
187            )
188            vel = vel * (kT / kTsys[:, None, None]) ** 0.5
189        thermostat = lambda x, s: (x, s)
190
191    elif thermostat_name in ["NOSE", "NOSEHOOVER", "NOSE_HOOVER"]:
192        assert gamma is not None, "gamma must be specified for QTB thermostat"
193        ndof = mass.shape[0] * 3
194        nkT = ndof * kT
195        nose_mass = nkT / gamma**2
196        assert nbeads is None, "Nose-Hoover is not compatible with PIMD"
197        state["nose_s"] = 0.0
198        state["nose_v"] = 0.0
199        if compute_thermostat_energy:
200            state["thermostat_energy"] = 0.0
201        print(
202            "# WARNING: Nose-Hoover thermostat is not well tested yet. Energy conservation is not guaranteed."
203        )
204        vel = (
205            jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
206            * (kT / mass[:, None]) ** 0.5
207        )
208
209        def thermostat(vel, state):
210            nose_s = state["nose_s"]
211            nose_v = state["nose_v"]
212            kTsys = jnp.sum(mass[:, None] * vel**2)
213            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
214            nose_s = nose_s + dt * nose_v
215            vel = jnp.exp(-nose_v * dt) * vel
216            kTsys = jnp.sum(mass[:, None] * vel**2)
217            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
218            new_state = {**state, "nose_s": nose_s, "nose_v": nose_v}
219
220            if compute_thermostat_energy:
221                new_state["thermostat_energy"] = (
222                    nkT * nose_s + (0.5 * nose_mass) * nose_v**2
223                )
224            return vel, new_state
225
226    elif thermostat_name in ["QTB", "ADQTB"]:
227        assert nbeads is None, "QTB is not compatible with PIMD"
228        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
229        assert kT is not None, "kT must be specified for QTB thermostat"
230        assert gamma is not None, "gamma must be specified for QTB thermostat"
231        assert species is not None, "species must be provided for QTB thermostat"
232        rng_key, v_key = jax.random.split(rng_key)
233        vel = (
234            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
235            * (kT / mass[:, None]) ** 0.5
236        )
237
238        thermostat, postprocess, qtb_state = initialize_qtb(
239            simulation_parameters,
240            system_data,
241            fprec=fprec,
242            dt=dt,
243            mass=mass,
244            gamma=gamma,
245            kT=kT,
246            species=species,
247            rng_key=rng_key,
248            adaptive=thermostat_name.startswith("AD"),
249            compute_thermostat_energy=compute_thermostat_energy,
250        )
251        state = {**state, **qtb_state}
252
253    elif thermostat_name in ["ANNEAL", "ANNEALING"]:
254        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
255        assert kT is not None, "kT must be specified for QTB thermostat"
256        assert gamma is not None, "gamma must be specified for QTB thermostat"
257        assert nbeads is None, "ANNEAL is not compatible with PIMD"
258        a1 = math.exp(-gamma * dt)
259        a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)
260
261        anneal_parameters = simulation_parameters.get("annealing", {})
262        """@keyword[fennol_md] annealing
263        Parameters for simulated annealing schedule configuration.
264        Required for ANNEAL/ANNEALING thermostat
265        """
266        init_factor = anneal_parameters.get("init_factor", 1.0 / 25.0)
267        """@keyword[fennol_md] annealing/init_factor
268        Initial temperature factor for annealing schedule.
269        Default: 0.04 (1/25)
270        """
271        assert init_factor > 0.0, "init_factor must be positive"
272        final_factor = anneal_parameters.get("final_factor", 1.0 / 10000.0)
273        """@keyword[fennol_md] annealing/final_factor
274        Final temperature factor for annealing schedule.
275        Default: 0.0001 (1/10000)
276        """
277        assert final_factor > 0.0, "final_factor must be positive"
278        nsteps = simulation_parameters.get("nsteps")
279        """@keyword[fennol_md] nsteps
280        Total number of simulation steps for annealing schedule calculation.
281        Required parameter
282        """
283        anneal_steps = anneal_parameters.get("anneal_steps", 1.0)
284        """@keyword[fennol_md] annealing/anneal_steps
285        Fraction of total steps for annealing process.
286        Default: 1.0
287        """
288        assert (
289            anneal_steps < 1.0 and anneal_steps > 0.0
290        ), "warmup_steps must be between 0 and nsteps"
291        pct_start = anneal_parameters.get("warmup_steps", 0.3)
292        """@keyword[fennol_md] annealing/warmup_steps
293        Fraction of annealing steps for warmup phase.
294        Default: 0.3
295        """
296        assert (
297            pct_start < 1.0 and pct_start > 0.0
298        ), "warmup_steps must be between 0 and nsteps"
299
300        anneal_type = anneal_parameters.get("type", "cosine").lower()
301        """@keyword[fennol_md] annealing/type
302        Type of annealing schedule (linear, cosine_onecycle).
303        Default: cosine
304        """
305        if anneal_type == "linear":
306            schedule = optax.linear_onecycle_schedule(
307                peak_value=1.0,
308                div_factor=1.0 / init_factor,
309                final_div_factor=1.0 / final_factor,
310                transition_steps=int(anneal_steps * nsteps),
311                pct_start=pct_start,
312                pct_final=1.0,
313            )
314        elif anneal_type == "cosine_onecycle":
315            schedule = optax.cosine_onecycle_schedule(
316                peak_value=1.0,
317                div_factor=1.0 / init_factor,
318                final_div_factor=1.0 / final_factor,
319                transition_steps=int(anneal_steps * nsteps),
320                pct_start=pct_start,
321            )
322        else:
323            raise ValueError(f"Unknown anneal_type {anneal_type}")
324
325        state["rng_key"] = rng_key
326        state["istep_anneal"] = 0
327
328        rng_key, v_key = jax.random.split(rng_key)
329        Tscale = schedule(0)
330        print(f"# ANNEAL: initial temperature = {Tscale*kT*us.KELVIN:.3e} K")
331        vel = (
332            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
333            * (kT * Tscale / mass[:, None]) ** 0.5
334        )
335
336        def thermostat(vel, state):
337            rng_key, noise_key = jax.random.split(state["rng_key"])
338            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
339
340            Tscale = schedule(state["istep_anneal"]) ** 0.5
341            vel = a1 * vel + a2 * Tscale * noise
342            return vel, {
343                **state,
344                "rng_key": rng_key,
345                "istep_anneal": state["istep_anneal"] + 1,
346            }
347
348    else:
349        raise ValueError(f"Unknown thermostat {thermostat_name}")
350
351    return thermostat, postprocess, state, vel,thermostat_name
352
353
354def initialize_qtb(
355    simulation_parameters,
356    system_data,
357    fprec,
358    dt,
359    mass,
360    gamma,
361    kT,
362    species,
363    rng_key,
364    adaptive,
365    compute_thermostat_energy=False,
366):
367    state = {}
368    post_state = {}
369    qtb_parameters = simulation_parameters.get("qtb", {})
370    verbose = qtb_parameters.get("verbose", False)
371    """@keyword[fennol_md] qtb/verbose
372    Print verbose QTB thermostat information.
373    Default: False
374    """
375    if compute_thermostat_energy:
376        state["thermostat_energy"] = 0.0
377
378    mass = jnp.asarray(mass, dtype=fprec)
379
380    nat = species.shape[0]
381    # define type indices
382    species_set = set(species)
383    nspecies = len(species_set)
384    idx = {sp: i for i, sp in enumerate(species_set)}
385    type_idx = np.array([idx[sp] for sp in species], dtype=np.int32)
386    type_labels = [PERIODIC_TABLE[sp] for sp in species_set]
387
388    adapt_groups = qtb_parameters.get("adapt_groups", {})
389    """@keyword[fennol_md] qtb/adapt_groups
390    Groups of atoms with separate adQTB parameters (atoms of differents species within groups will be separated in subtypes).
391    Default: {}
392    """
393    ntypes = nspecies
394    allgroups = set()
395    for groupname, interval in adapt_groups.items():
396        indices = read_tinker_interval(interval)
397        assert allgroups.isdisjoint(indices), f"Indices in group {groupname} overlap with other groups"
398        allgroups.update(indices)
399        species_in_group = set(species[indices])
400        idx = {sp: i for i, sp in enumerate(species_in_group)}
401        for i in indices:
402            type_idx[i] = ntypes + idx[species[i]]
403        type_labels += [f"{PERIODIC_TABLE[sp]}_{groupname}" for sp in species_in_group]
404        ntypes = len(type_labels)
405
406    n_of_type = np.zeros(ntypes, dtype=np.int32)
407    for i in range(ntypes):
408        n_of_type[i] = (type_idx == i).nonzero()[0].shape[0]
409        print(f"# QTB: n({type_labels[i]}) =", n_of_type[i])
410    n_of_type = jnp.asarray(n_of_type, dtype=fprec)
411    mass_idx = jax.ops.segment_sum(mass, type_idx, ntypes) / n_of_type
412
413    niter_deconv_kin = qtb_parameters.get("niter_deconv_kin", 7)
414    """@keyword[fennol_md] qtb/niter_deconv_kin
415    Number of iterations for kinetic energy deconvolution.
416    Default: 7
417    """
418    niter_deconv_pot = qtb_parameters.get("niter_deconv_pot", 20)
419    """@keyword[fennol_md] qtb/niter_deconv_pot
420    Number of iterations for potential energy deconvolution.
421    Default: 20
422    """
423    corr_kin = qtb_parameters.get("corr_kin", -1)
424    """@keyword[fennol_md] qtb/corr_kin
425    Kinetic energy correction factor for QTB (-1 for automatic).
426    Default: -1
427    """
428    do_corr_kin = corr_kin <= 0
429    if do_corr_kin:
430        corr_kin = 1.0
431    state["corr_kin"] = corr_kin
432    post_state["corr_kin_prev"] = corr_kin
433    post_state["do_corr_kin"] = do_corr_kin
434    post_state["isame_kin"] = 0
435
436    # spectra parameters
437    omegasmear = np.pi / dt / 100.0
438    Tseg = qtb_parameters.get("tseg", 1.0 / us.PS)
439    """@keyword[fennol_md] qtb/tseg
440    Time segment length for QTB spectrum calculation.
441    Default: 1.0 ps
442    """
443    nseg = int(Tseg / dt)
444    Tseg = nseg * dt
445    dom = 2 * np.pi / (3 * Tseg)
446    omegacut = qtb_parameters.get("omegacut", 15000.0 / us.CM1)
447    """@keyword[fennol_md] qtb/omegacut
448    Cutoff frequency for QTB spectrum.
449    Default: 15000.0 cm⁻¹
450    """
451    nom = int(omegacut / dom)
452    omega = dom * np.arange((3 * nseg) // 2 + 1)
453    cutoff = jnp.asarray(
454        1.0 / (1.0 + np.exp((omega - omegacut) / omegasmear)), dtype=fprec
455    )
456    assert (
457        omegacut < omega[-1]
458    ), f"omegacut must be smaller than {omega[-1]*us.CM1} CM-1"
459
460    # initialize gammar
461    assert (
462        gamma < 0.5 * omegacut
463    ), "gamma must be much smaller than omegacut (at most 0.5*omegacut)"
464    gammar_min = qtb_parameters.get("gammar_min", 0.1)
465    """@keyword[fennol_md] qtb/gammar_min
466    Minimum value for QTB gamma ratio coefficients.
467    Default: 0.1
468    """
469    # post_state["gammar"] = jnp.asarray(np.ones((nspecies, nom)), dtype=fprec)
470    gammar = np.ones((ntypes, nom), dtype=float)
471    try:
472        for i, sp in enumerate(type_labels):
473            if not os.path.exists(f"QTB_spectra_{sp}.out"): continue
474            data = np.loadtxt(f"QTB_spectra_{sp}.out")
475            gammar[i] = data[:, 4]/(gamma*us.THZ)
476            print(f"# Restored gammar for species {sp} from QTB_spectra_{sp}.out")
477    except Exception as e:
478        print(f"# Could not restore gammar for all species with exception {e}. Starting from scratch.")
479        gammar[:,:] = 1.0
480    post_state["gammar"] = jnp.asarray(gammar, dtype=fprec)
481
482    # Ornstein-Uhlenbeck correction for colored noise
483    a1 = np.exp(-gamma * dt)
484    OUcorr = jnp.asarray(
485        (1.0 - 2.0 * a1 * np.cos(omega * dt) + a1**2) / (dt**2 * (gamma**2 + omega**2)),
486        dtype=fprec,
487    )
488
489    # hbar schedule
490    classical_kernel = qtb_parameters.get("classical_kernel", False)
491    """@keyword[fennol_md] qtb/classical_kernel
492    Use classical instead of quantum kernel for QTB.
493    Default: False
494    """
495    hbar = qtb_parameters.get("hbar", 1.0) * us.HBAR
496    """@keyword[fennol_md] qtb/hbar
497    Reduced Planck constant scaling factor for quantum effects.
498    Default: 1.0 a.u.
499    """
500    u = 0.5 * hbar * np.abs(omega) / kT
501    theta = kT * np.ones_like(omega)
502    if hbar > 0:
503        theta[1:] *= u[1:] / np.tanh(u[1:])
504    theta = jnp.asarray(theta, dtype=fprec)
505
506    noise_key, post_state["rng_key"] = jax.random.split(rng_key)
507    del rng_key
508    post_state["white_noise"] = jax.random.normal(
509        noise_key, (3 * nseg, nat, 3), dtype=jnp.float32
510    )
511
512    startsave = qtb_parameters.get("startsave", 1)
513    """@keyword[fennol_md] qtb/startsave
514    Start saving QTB statistics after this many segments.
515    Default: 1
516    """
517    counter = Counter(nseg, startsave=startsave)
518    state["istep"] = 0
519    post_state["nadapt"] = 0
520    post_state["nsample"] = 0
521
522    write_spectra = qtb_parameters.get("write_spectra", True)
523    """@keyword[fennol_md] qtb/write_spectra
524    Write QTB spectral analysis output files.
525    Default: True
526    """
527    do_compute_spectra = write_spectra or adaptive
528
529    if do_compute_spectra:
530        state["vel"] = jnp.zeros((nseg, nat, 3), dtype=fprec)
531
532        post_state["dFDT"] = jnp.zeros((ntypes, nom), dtype=fprec)
533        post_state["mCvv"] = jnp.zeros((ntypes, nom), dtype=fprec)
534        post_state["Cvf"] = jnp.zeros((ntypes, nom), dtype=fprec)
535        post_state["Cff"] = jnp.zeros((ntypes, nom), dtype=fprec)
536        post_state["dFDT_avg"] = jnp.zeros((ntypes, nom), dtype=fprec)
537        post_state["mCvv_avg"] = jnp.zeros((ntypes, nom), dtype=fprec)
538        post_state["Cvfg_avg"] = jnp.zeros((ntypes, nom), dtype=fprec)
539        post_state["Cff_avg"] = jnp.zeros((ntypes, nom), dtype=fprec)
540
541    if not adaptive:
542        update_gammar = lambda x: x
543    else:
544        # adaptation parameters
545        skipseg = qtb_parameters.get("skipseg", 1)
546        """@keyword[fennol_md] qtb/skipseg
547        Number of segments to skip before starting adaptive QTB.
548        Default: 1
549        """
550
551        adaptation_method = (
552            str(qtb_parameters.get("adaptation_method", "SIMPLE")).upper().strip()
553        )
554        """@keyword[fennol_md] qtb/adaptation_method
555        Method for adaptive QTB (SIMPLE, RATIO, ADABELIEF).
556        Default: SIMPLE
557        """
558        if adaptation_method == "SIMPLE":
559            agamma = qtb_parameters.get("agamma", 0.1)
560            """@keyword[fennol_md] qtb/agamma
561            Learning rate for adaptive QTB gamma update.
562            Default: 0.1
563            """
564            assert agamma > 0, "agamma must be positive"
565            a1_ad = agamma * Tseg  #  * gamma
566            print(f"# ADQTB SIMPLE: agamma = {agamma:.3f}")
567
568            def update_gammar(post_state):
569                g = post_state["dFDT"]
570                gammar = post_state["gammar"] - a1_ad * g
571                gammar = jnp.maximum(gammar_min, gammar)
572                return {**post_state, "gammar": gammar}
573
574        elif adaptation_method == "RATIO":
575            tau_ad = qtb_parameters.get("tau_ad", 5.0 / us.PS) 
576            """@keyword[fennol_md] qtb/tau_ad
577            Adaptation time constant for momentum averaging.
578            Default: 5.0 ps (RATIO), 1.0 ps (ADABELIEF)
579            """
580            tau_s = qtb_parameters.get("tau_s", 10 * tau_ad)
581            """@keyword[fennol_md] qtb/tau_s
582            Second moment time constant for variance averaging.
583            Default: 10*tau_ad (RATIO), 100*tau_ad (ADABELIEF)
584            """
585            assert tau_ad > 0, "tau_ad must be positive"
586            print(
587                f"# ADQTB RATIO: tau_ad = {tau_ad*us.PS:.2f} ps, tau_s = {tau_s*us.PS:.2f} ps"
588            )
589            b1 = np.exp(-Tseg / tau_ad)
590            b2 = np.exp(-Tseg / tau_s)
591            post_state["mCvv_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
592            post_state["Cvf_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
593            post_state["n_adabelief"] = 0
594
595            def update_gammar(post_state):
596                n_adabelief = post_state["n_adabelief"] + 1
597                mCvv_m = post_state["mCvv_m"] * b1 + post_state["mCvv"] * (1.0 - b1)
598                Cvf_m = post_state["Cvf_m"] * b2 + post_state["Cvf"] * (1.0 - b2)
599                mCvv = mCvv_m / (1.0 - b1**n_adabelief)
600                Cvf = Cvf_m / (1.0 - b2**n_adabelief)
601                # g = Cvf/(mCvv+1.e-8)-post_state["gammar"]
602                gammar = Cvf / (mCvv + 1.0e-8)
603                gammar = jnp.maximum(gammar_min, gammar)
604                return {
605                    **post_state,
606                    "gammar": gammar,
607                    "mCvv_m": mCvv_m,
608                    "Cvf_m": Cvf_m,
609                    "n_adabelief": n_adabelief,
610                }
611
612        elif adaptation_method == "ADABELIEF":
613            agamma = qtb_parameters.get("agamma", 0.1)
614            tau_ad = qtb_parameters.get("tau_ad", 1.0 / us.PS)
615            tau_s = qtb_parameters.get("tau_s", 100 * tau_ad)
616            assert tau_ad > 0, "tau_ad must be positive"
617            assert tau_s > 0, "tau_s must be positive"
618            assert agamma > 0, "agamma must be positive"
619            print(
620                f"# ADQTB ADABELIEF: agamma = {agamma:.3f}, tau_ad = {tau_ad*us.PS:.2f} ps, tau_s = {tau_s*us.PS:.2f} ps"
621            )
622
623            a1_ad = agamma * gamma  # * Tseg #* gamma
624            b1 = np.exp(-Tseg / tau_ad)
625            b2 = np.exp(-Tseg / tau_s)
626            post_state["dFDT_m"] = jnp.zeros((ntypes, nom), dtype=fprec)
627            post_state["dFDT_s"] = jnp.zeros((ntypes, nom), dtype=fprec)
628            post_state["n_adabelief"] = 0
629
630            def update_gammar(post_state):
631                n_adabelief = post_state["n_adabelief"] + 1
632                dFDT = post_state["dFDT"]
633                dFDT_m = post_state["dFDT_m"] * b1 + dFDT * (1.0 - b1)
634                dFDT_s = (
635                    post_state["dFDT_s"] * b2
636                    + (dFDT - dFDT_m) ** 2 * (1.0 - b2)
637                    + 1.0e-8
638                )
639                # bias correction
640                mt = dFDT_m / (1.0 - b1**n_adabelief)
641                st = dFDT_s / (1.0 - b2**n_adabelief)
642                gammar = post_state["gammar"] - a1_ad * mt / (st**0.5 + 1.0e-8)
643                gammar = jnp.maximum(gammar_min, gammar)
644                return {
645                    **post_state,
646                    "gammar": gammar,
647                    "dFDT_m": dFDT_m,
648                    "n_adabelief": n_adabelief,
649                    "dFDT_s": dFDT_s,
650                }
651        else: 
652            raise ValueError(
653                f"Unknown adaptation method {adaptation_method}."
654            )
655    
656    #####################
657    # RESTART
658    restart_file = system_data["name"]+".qtb.restart"
659    if os.path.exists(restart_file):
660        with open(restart_file, "rb") as f:
661            data = pickle.load(f)
662            state["corr_kin"] = data["corr_kin"]
663            post_state["corr_kin_prev"] = data["corr_kin"]
664            post_state["isame_kin"] = data["isame_kin"]
665            post_state["do_corr_kin"] = data["do_corr_kin"]
666            print(f"# Restored QTB state from {restart_file}")
667
668    def write_qtb_restart(state, post_state):
669        with open(restart_file, "wb") as f:
670            pickle.dump(
671                {
672                    "corr_kin": state["corr_kin"],
673                    "corr_kin_prev": post_state["corr_kin_prev"],
674                    "isame_kin": post_state["isame_kin"],
675                    "do_corr_kin": post_state["do_corr_kin"],
676                },
677                f,
678            )
679    ######################
680
681    def compute_corr_pot(niter=20, verbose=False):
682        if classical_kernel or hbar == 0:
683            return np.ones(nom)
684
685        s_0 = np.array((theta / kT * cutoff)[:nom])
686        s_out, s_rec, _ = deconvolute_spectrum(
687            s_0,
688            omega[:nom],
689            gamma,
690            niter,
691            kernel=kernel_lorentz_pot,
692            trans=True,
693            symmetrize=True,
694            verbose=verbose,
695        )
696        corr_pot = 1.0 + (s_out - s_0) / s_0
697        columns = np.column_stack(
698            (omega[:nom] * us.CM1, corr_pot - 1.0, s_0, s_out, s_rec)
699        )
700        np.savetxt(
701            "corr_pot.dat", columns, header="omega(cm-1) corr_pot s_0 s_out s_rec"
702        )
703        return corr_pot
704
705    def compute_corr_kin(post_state, niter=7, verbose=False):
706        if not post_state["do_corr_kin"]:
707            return post_state["corr_kin_prev"], post_state
708        if classical_kernel or hbar == 0:
709            return 1.0, post_state
710
711        K_D = post_state.get("K_D", None)
712        mCvv = (post_state["mCvv_avg"][:, :nom] * n_of_type[:, None]).sum(axis=0) / nat
713        s_0 = np.array(mCvv * kT / theta[:nom] / post_state["corr_pot"])
714        s_out, s_rec, K_D = deconvolute_spectrum(
715            s_0,
716            omega[:nom],
717            gamma,
718            niter,
719            kernel=kernel_lorentz,
720            trans=False,
721            symmetrize=True,
722            verbose=verbose,
723            K_D=K_D,
724        )
725        s_out = s_out * theta[:nom] / kT
726        s_rec = s_rec * theta[:nom] / kT * post_state["corr_pot"]
727        mCvvsum = mCvv.sum()
728        rec_ratio = mCvvsum / s_rec.sum()
729        if rec_ratio < 0.95 or rec_ratio > 1.05:
730            print(
731                f"# WARNING: reconvolution error {rec_ratio} is too high, corr_kin was not updated"
732            )
733            return post_state["corr_kin_prev"], post_state
734
735        corr_kin = mCvvsum / s_out.sum()
736        if np.abs(corr_kin - post_state["corr_kin_prev"]) < 1.0e-4:
737            isame_kin = post_state["isame_kin"] + 1
738        else:
739            isame_kin = 0
740
741        # print("# corr_kin: ", corr_kin)
742        do_corr_kin = post_state["do_corr_kin"]
743        if isame_kin > 10:
744            print(
745                "# INFO: corr_kin is converged (it did not change for 10 consecutive segments)"
746            )
747            do_corr_kin = False
748
749        return corr_kin, {
750            **post_state,
751            "corr_kin_prev": corr_kin,
752            "isame_kin": isame_kin,
753            "do_corr_kin": do_corr_kin,
754            "K_D": K_D,
755        }
756
757    @jax.jit
758    def ff_kernel(post_state):
759        if classical_kernel:
760            kernel = cutoff * (2 * gamma * kT / dt)
761        else:
762            kernel = theta * cutoff * OUcorr * (2 * gamma / dt)
763        gamma_ratio = jnp.concatenate(
764            (
765                post_state["gammar"].T * post_state["corr_pot"][:, None],
766                jnp.ones(
767                    (kernel.shape[0] - nom, ntypes), dtype=post_state["gammar"].dtype
768                ),
769            ),
770            axis=0,
771        )
772        return kernel[:, None] * gamma_ratio * mass_idx[None, :]
773
774    @jax.jit
775    def refresh_force(post_state):
776        rng_key, noise_key = jax.random.split(post_state["rng_key"])
777        white_noise = jnp.concatenate(
778            (
779                post_state["white_noise"][nseg:],
780                jax.random.normal(
781                    noise_key, (nseg, nat, 3), dtype=post_state["white_noise"].dtype
782                ),
783            ),
784            axis=0,
785        )
786        amplitude = ff_kernel(post_state) ** 0.5
787        s = jnp.fft.rfft(white_noise, 3 * nseg, axis=0) * amplitude[:, type_idx, None]
788        force = jnp.fft.irfft(s, 3 * nseg, axis=0)[nseg : 2 * nseg]
789        return force, {**post_state, "rng_key": rng_key, "white_noise": white_noise}
790
791    @jax.jit
792    def compute_spectra(force, vel, post_state):
793        sf = jnp.fft.rfft(force / gamma, 3 * nseg, axis=0, norm="ortho")
794        sv = jnp.fft.rfft(vel, 3 * nseg, axis=0, norm="ortho")
795        Cvv = jnp.sum(jnp.abs(sv[:nom]) ** 2, axis=-1).T
796        Cff = jnp.sum(jnp.abs(sf[:nom]) ** 2, axis=-1).T
797        Cvf = jnp.sum(jnp.real(sv[:nom] * jnp.conj(sf[:nom])), axis=-1).T
798
799        mCvv = (
800            (dt / 3.0)
801            * jnp.zeros_like(post_state["mCvv"]).at[type_idx].add(Cvv)
802            * mass_idx[:, None]
803            / n_of_type[:, None]
804        )
805        Cvf = (
806            (dt / 3.0)
807            * jnp.zeros_like(post_state["Cvf"]).at[type_idx].add(Cvf)
808            / n_of_type[:, None]
809        )
810        Cff = (
811            (dt / 3.0)
812            * jnp.zeros_like(post_state["Cff"]).at[type_idx].add(Cff)
813            / n_of_type[:, None]
814        )
815        dFDT = mCvv * post_state["gammar"] - Cvf
816
817        nsinv = 1.0 / post_state["nsample"]
818        b1 = 1.0 - nsinv
819        dFDT_avg = post_state["dFDT_avg"] * b1 + dFDT * nsinv
820        mCvv_avg = post_state["mCvv_avg"] * b1 + mCvv * nsinv
821        Cvfg_avg = post_state["Cvfg_avg"] * b1 + Cvf / post_state["gammar"] * nsinv
822        Cff_avg = post_state["Cff_avg"] * b1 + Cff * nsinv
823
824        return {
825            **post_state,
826            "mCvv": mCvv,
827            "Cvf": Cvf,
828            "Cff": Cff,
829            "dFDT": dFDT,
830            "dFDT_avg": dFDT_avg,
831            "mCvv_avg": mCvv_avg,
832            "Cvfg_avg": Cvfg_avg,
833            "Cff_avg": Cff_avg,
834        }
835
836    def write_spectra_to_file(post_state):
837        mCvv_avg = np.array(post_state["mCvv_avg"])
838        Cvfg_avg = np.array(post_state["Cvfg_avg"])
839        Cff_avg = np.array(post_state["Cff_avg"]) * 3.0 / dt / (gamma**2)
840        dFDT_avg = np.array(post_state["dFDT_avg"])
841        gammar = np.array(post_state["gammar"])
842        Cff_theo = np.array(ff_kernel(post_state))[:nom].T
843        for i, sp in enumerate(type_labels):
844            ff_scale = us.KELVIN / ((2 * gamma / dt) * mass_idx[i])
845            columns = np.column_stack(
846                (
847                    omega[:nom] * us.CM1,
848                    mCvv_avg[i],
849                    Cvfg_avg[i],
850                    dFDT_avg[i],
851                    gammar[i] * gamma * us.THZ,
852                    Cff_avg[i] * ff_scale,
853                    Cff_theo[i] * ff_scale,
854                )
855            )
856            np.savetxt(
857                f"QTB_spectra_{sp}.out",
858                columns,
859                fmt="%12.6f",
860                header="#omega mCvv Cvf dFDT gammar Cff",
861            )
862        if verbose:
863            print("# QTB spectra written.")
864
865    if compute_thermostat_energy:
866        state["qtb_energy_flux"] = 0.0
867
868    @jax.jit
869    def thermostat(vel, state):
870        istep = state["istep"]
871        dvel = dt * state["force"][istep] / mass[:, None]
872        new_vel = vel * a1 + dvel
873        new_state = {**state, "istep": istep + 1}
874        if do_compute_spectra:
875            vel2 = state["vel"].at[istep].set(vel * a1**0.5 + 0.5 * dvel)
876            new_state["vel"] = vel2
877        if compute_thermostat_energy:
878            dek = 0.5 * (mass[:, None] * (vel**2 - new_vel**2)).sum()
879            ekcorr = (
880                0.5
881                * (mass[:, None] * new_vel**2).sum()
882                * (1.0 - 1.0 / state.get("corr_kin", 1.0))
883            )
884            new_state["qtb_energy_flux"] = state["qtb_energy_flux"] + dek
885            new_state["thermostat_energy"] = new_state["qtb_energy_flux"] + ekcorr
886        return new_vel, new_state
887
888    @jax.jit
889    def postprocess_work(state, post_state):
890        if do_compute_spectra:
891            post_state = compute_spectra(state["force"], state["vel"], post_state)
892        if adaptive:
893            post_state = jax.lax.cond(
894                post_state["nadapt"] > skipseg, update_gammar, lambda x: x, post_state
895            )
896        new_force, post_state = refresh_force(post_state)
897        return {**state, "force": new_force}, post_state
898
899    def postprocess(state, post_state):
900        counter.increment()
901        if not counter.is_reset_step:
902            return state, post_state
903        post_state["nadapt"] += 1
904        post_state["nsample"] = max(post_state["nadapt"] - startsave + 1, 1)
905        if verbose:
906            print("# Refreshing QTB forces.")
907        state, post_state = postprocess_work(state, post_state)
908        state["corr_kin"], post_state = compute_corr_kin(post_state, niter=niter_deconv_kin)
909        state["istep"] = 0
910        if write_spectra:
911            write_spectra_to_file(post_state)
912        write_qtb_restart(state, post_state)
913        return state, post_state
914
915    post_state["corr_pot"] = jnp.asarray(compute_corr_pot(niter=niter_deconv_pot), dtype=fprec)
916
917    state["force"], post_state = refresh_force(post_state)
918    return thermostat, (postprocess, post_state), state
def get_thermostat( simulation_parameters, dt, system_data, fprec, rng_key=None, restart_data={}):
 21def get_thermostat(simulation_parameters, dt, system_data, fprec, rng_key=None, restart_data={}):
 22    state = {}
 23    postprocess = None
 24    
 25
 26    thermostat_name = str(simulation_parameters.get("thermostat", "LGV")).upper()
 27    """@keyword[fennol_md] thermostat
 28    Thermostat type. Options: 'NVE', 'LGV', 'NOSE', 'ADQTB'.
 29    Default: "LGV"
 30    """
 31    compute_thermostat_energy = simulation_parameters.get(
 32        "include_thermostat_energy", False
 33    )
 34    """@keyword[fennol_md] include_thermostat_energy
 35    Include thermostat energy in total energy calculations.
 36    Default: False
 37    """
 38
 39    kT = system_data.get("kT", None)
 40    nbeads = system_data.get("nbeads", None)
 41    mass = system_data["mass"]
 42    gamma = simulation_parameters.get("gamma", 1.0 / us.THZ)
 43    """@keyword[fennol_md] gamma
 44    Friction coefficient for Langevin thermostat.
 45    Default: 1.0 ps^-1
 46    """
 47    species = system_data["species"]
 48
 49    if nbeads is not None:
 50        trpmd_lambda = simulation_parameters.get("trpmd_lambda", 1.0)
 51        """@keyword[fennol_md] trpmd_lambda
 52        Lambda parameter for TRPMD (Thermostatted Ring Polymer MD).
 53        Default: 1.0
 54        """
 55        gamma = np.maximum(trpmd_lambda * system_data["omk"], gamma)
 56
 57    if thermostat_name in ["LGV", "LANGEVIN", "FFLGV"]:
 58        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
 59        assert kT is not None, "kT must be specified for QTB thermostat"
 60        assert gamma is not None, "gamma must be specified for QTB thermostat"
 61        rng_key, v_key = jax.random.split(rng_key)
 62        if nbeads is None:
 63            a1 = math.exp(-gamma * dt)
 64            a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)
 65            vel = (
 66                jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
 67                * (kT / mass[:, None]) ** 0.5
 68            )
 69        else:
 70            if isinstance(gamma, float):
 71                gamma = np.array([gamma] * nbeads)
 72            assert isinstance(
 73                gamma, np.ndarray
 74            ), "gamma must be a float or a numpy array"
 75            assert gamma.shape[0] == nbeads, "gamma must have the same length as nbeads"
 76            a1 = np.exp(-gamma * dt)[:, None, None]
 77            a2 = jnp.asarray(
 78                ((1 - a1 * a1) * kT / mass[None, :, None]) ** 0.5, dtype=fprec
 79            )
 80            vel = (
 81                jax.random.normal(v_key, (nbeads, mass.shape[0], 3), dtype=fprec)
 82                * (kT / mass[:, None]) ** 0.5
 83            )
 84
 85        state["rng_key"] = rng_key
 86        if compute_thermostat_energy:
 87            state["thermostat_energy"] = 0.0
 88        if thermostat_name == "FFLGV":
 89            def thermostat(vel, state):
 90                rng_key, noise_key = jax.random.split(state["rng_key"])
 91                noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
 92                norm_vel = jnp.linalg.norm(vel, axis=-1, keepdims=True)
 93                dirvel = vel / norm_vel
 94                if compute_thermostat_energy:
 95                    v2 = (vel**2).sum(axis=-1)
 96                vel = a1 * vel + a2 * noise
 97                new_norm_vel = jnp.linalg.norm(vel, axis=-1, keepdims=True)
 98                vel = dirvel * new_norm_vel
 99                new_state = {**state, "rng_key": rng_key}
100                if compute_thermostat_energy:
101                    v2new = (vel**2).sum(axis=-1)
102                    new_state["thermostat_energy"] = (
103                        state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
104                    )
105
106                return vel, new_state
107
108        else:
109            def thermostat(vel, state):
110                rng_key, noise_key = jax.random.split(state["rng_key"])
111                noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
112                if compute_thermostat_energy:
113                    v2 = (vel**2).sum(axis=-1)
114                vel = a1 * vel + a2 * noise
115                new_state = {**state, "rng_key": rng_key}
116                if compute_thermostat_energy:
117                    v2new = (vel**2).sum(axis=-1)
118                    new_state["thermostat_energy"] = (
119                        state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
120                    )
121                return vel, new_state
122
123    elif thermostat_name in ["BUSSI"]:
124        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
125        assert kT is not None, "kT must be specified for QTB thermostat"
126        assert gamma is not None, "gamma must be specified for QTB thermostat"
127        assert nbeads is None, "Bussi thermostat is not compatible with PIMD"
128        rng_key, v_key = jax.random.split(rng_key)
129
130        a1 = math.exp(-gamma * dt)
131        a2 = (1 - a1) * kT
132        vel = (
133            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
134            * (kT / mass[:, None]) ** 0.5
135        )
136
137        state["rng_key"] = rng_key
138        if compute_thermostat_energy:
139            state["thermostat_energy"] = 0.0
140
141        def thermostat(vel, state):
142            rng_key, noise_key = jax.random.split(state["rng_key"])
143            new_state = {**state, "rng_key": rng_key}
144            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
145            R2 = jnp.sum(noise**2)
146            R1 = noise[0, 0]
147            c = a2 / (mass[:, None] * vel**2).sum()
148            d = (a1 * c) ** 0.5
149            scale = (a1 + c * R2 + 2 * d * R1) ** 0.5
150            if compute_thermostat_energy:
151                dek = 0.5 * (mass[:, None] * vel**2).sum() * (scale**2 - 1)
152                new_state["thermostat_energy"] = state["thermostat_energy"] + dek
153            return scale * vel, new_state
154
155    elif thermostat_name in [
156        "GD",
157        "DESCENT",
158        "GRADIENT_DESCENT",
159        "MIN",
160        "MINIMIZE",
161    ]:
162        assert nbeads is None, "Gradient descent is not compatible with PIMD"
163        a1 = math.exp(-gamma * dt)
164
165        if nbeads is None:
166            vel = jnp.zeros((mass.shape[0], 3), dtype=fprec)
167        else:
168            vel = jnp.zeros((nbeads, mass.shape[0], 3), dtype=fprec)
169
170        def thermostat(vel, state):
171            return a1 * vel, state
172
173    elif thermostat_name in ["NVE", "NONE"]:
174        if nbeads is None:
175            vel = (
176                jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
177                * (kT / mass[:, None]) ** 0.5
178            )
179            kTsys = jnp.sum(mass[:, None] * vel**2) / (mass.shape[0] * 3)
180            vel = vel * (kT / kTsys) ** 0.5
181        else:
182            vel = (
183                jax.random.normal(rng_key, (nbeads, mass.shape[0], 3), dtype=fprec)
184                * (kT / mass[None, :, None]) ** 0.5
185            )
186            kTsys = jnp.sum(mass[None, :, None] * vel**2, axis=(1, 2)) / (
187                mass.shape[0] * 3
188            )
189            vel = vel * (kT / kTsys[:, None, None]) ** 0.5
190        thermostat = lambda x, s: (x, s)
191
192    elif thermostat_name in ["NOSE", "NOSEHOOVER", "NOSE_HOOVER"]:
193        assert gamma is not None, "gamma must be specified for QTB thermostat"
194        ndof = mass.shape[0] * 3
195        nkT = ndof * kT
196        nose_mass = nkT / gamma**2
197        assert nbeads is None, "Nose-Hoover is not compatible with PIMD"
198        state["nose_s"] = 0.0
199        state["nose_v"] = 0.0
200        if compute_thermostat_energy:
201            state["thermostat_energy"] = 0.0
202        print(
203            "# WARNING: Nose-Hoover thermostat is not well tested yet. Energy conservation is not guaranteed."
204        )
205        vel = (
206            jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
207            * (kT / mass[:, None]) ** 0.5
208        )
209
210        def thermostat(vel, state):
211            nose_s = state["nose_s"]
212            nose_v = state["nose_v"]
213            kTsys = jnp.sum(mass[:, None] * vel**2)
214            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
215            nose_s = nose_s + dt * nose_v
216            vel = jnp.exp(-nose_v * dt) * vel
217            kTsys = jnp.sum(mass[:, None] * vel**2)
218            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
219            new_state = {**state, "nose_s": nose_s, "nose_v": nose_v}
220
221            if compute_thermostat_energy:
222                new_state["thermostat_energy"] = (
223                    nkT * nose_s + (0.5 * nose_mass) * nose_v**2
224                )
225            return vel, new_state
226
227    elif thermostat_name in ["QTB", "ADQTB"]:
228        assert nbeads is None, "QTB is not compatible with PIMD"
229        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
230        assert kT is not None, "kT must be specified for QTB thermostat"
231        assert gamma is not None, "gamma must be specified for QTB thermostat"
232        assert species is not None, "species must be provided for QTB thermostat"
233        rng_key, v_key = jax.random.split(rng_key)
234        vel = (
235            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
236            * (kT / mass[:, None]) ** 0.5
237        )
238
239        thermostat, postprocess, qtb_state = initialize_qtb(
240            simulation_parameters,
241            system_data,
242            fprec=fprec,
243            dt=dt,
244            mass=mass,
245            gamma=gamma,
246            kT=kT,
247            species=species,
248            rng_key=rng_key,
249            adaptive=thermostat_name.startswith("AD"),
250            compute_thermostat_energy=compute_thermostat_energy,
251        )
252        state = {**state, **qtb_state}
253
254    elif thermostat_name in ["ANNEAL", "ANNEALING"]:
255        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
256        assert kT is not None, "kT must be specified for QTB thermostat"
257        assert gamma is not None, "gamma must be specified for QTB thermostat"
258        assert nbeads is None, "ANNEAL is not compatible with PIMD"
259        a1 = math.exp(-gamma * dt)
260        a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)
261
262        anneal_parameters = simulation_parameters.get("annealing", {})
263        """@keyword[fennol_md] annealing
264        Parameters for simulated annealing schedule configuration.
265        Required for ANNEAL/ANNEALING thermostat
266        """
267        init_factor = anneal_parameters.get("init_factor", 1.0 / 25.0)
268        """@keyword[fennol_md] annealing/init_factor
269        Initial temperature factor for annealing schedule.
270        Default: 0.04 (1/25)
271        """
272        assert init_factor > 0.0, "init_factor must be positive"
273        final_factor = anneal_parameters.get("final_factor", 1.0 / 10000.0)
274        """@keyword[fennol_md] annealing/final_factor
275        Final temperature factor for annealing schedule.
276        Default: 0.0001 (1/10000)
277        """
278        assert final_factor > 0.0, "final_factor must be positive"
279        nsteps = simulation_parameters.get("nsteps")
280        """@keyword[fennol_md] nsteps
281        Total number of simulation steps for annealing schedule calculation.
282        Required parameter
283        """
284        anneal_steps = anneal_parameters.get("anneal_steps", 1.0)
285        """@keyword[fennol_md] annealing/anneal_steps
286        Fraction of total steps for annealing process.
287        Default: 1.0
288        """
289        assert (
290            anneal_steps < 1.0 and anneal_steps > 0.0
291        ), "warmup_steps must be between 0 and nsteps"
292        pct_start = anneal_parameters.get("warmup_steps", 0.3)
293        """@keyword[fennol_md] annealing/warmup_steps
294        Fraction of annealing steps for warmup phase.
295        Default: 0.3
296        """
297        assert (
298            pct_start < 1.0 and pct_start > 0.0
299        ), "warmup_steps must be between 0 and nsteps"
300
301        anneal_type = anneal_parameters.get("type", "cosine").lower()
302        """@keyword[fennol_md] annealing/type
303        Type of annealing schedule (linear, cosine_onecycle).
304        Default: cosine
305        """
306        if anneal_type == "linear":
307            schedule = optax.linear_onecycle_schedule(
308                peak_value=1.0,
309                div_factor=1.0 / init_factor,
310                final_div_factor=1.0 / final_factor,
311                transition_steps=int(anneal_steps * nsteps),
312                pct_start=pct_start,
313                pct_final=1.0,
314            )
315        elif anneal_type == "cosine_onecycle":
316            schedule = optax.cosine_onecycle_schedule(
317                peak_value=1.0,
318                div_factor=1.0 / init_factor,
319                final_div_factor=1.0 / final_factor,
320                transition_steps=int(anneal_steps * nsteps),
321                pct_start=pct_start,
322            )
323        else:
324            raise ValueError(f"Unknown anneal_type {anneal_type}")
325
326        state["rng_key"] = rng_key
327        state["istep_anneal"] = 0
328
329        rng_key, v_key = jax.random.split(rng_key)
330        Tscale = schedule(0)
331        print(f"# ANNEAL: initial temperature = {Tscale*kT*us.KELVIN:.3e} K")
332        vel = (
333            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
334            * (kT * Tscale / mass[:, None]) ** 0.5
335        )
336
337        def thermostat(vel, state):
338            rng_key, noise_key = jax.random.split(state["rng_key"])
339            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
340
341            Tscale = schedule(state["istep_anneal"]) ** 0.5
342            vel = a1 * vel + a2 * Tscale * noise
343            return vel, {
344                **state,
345                "rng_key": rng_key,
346                "istep_anneal": state["istep_anneal"] + 1,
347            }
348
349    else:
350        raise ValueError(f"Unknown thermostat {thermostat_name}")
351
352    return thermostat, postprocess, state, vel,thermostat_name
def initialize_qtb( simulation_parameters, system_data, fprec, dt, mass, gamma, kT, species, rng_key, adaptive, compute_thermostat_energy=False):
355def initialize_qtb(
356    simulation_parameters,
357    system_data,
358    fprec,
359    dt,
360    mass,
361    gamma,
362    kT,
363    species,
364    rng_key,
365    adaptive,
366    compute_thermostat_energy=False,
367):
368    state = {}
369    post_state = {}
370    qtb_parameters = simulation_parameters.get("qtb", {})
371    verbose = qtb_parameters.get("verbose", False)
372    """@keyword[fennol_md] qtb/verbose
373    Print verbose QTB thermostat information.
374    Default: False
375    """
376    if compute_thermostat_energy:
377        state["thermostat_energy"] = 0.0
378
379    mass = jnp.asarray(mass, dtype=fprec)
380
381    nat = species.shape[0]
382    # define type indices
383    species_set = set(species)
384    nspecies = len(species_set)
385    idx = {sp: i for i, sp in enumerate(species_set)}
386    type_idx = np.array([idx[sp] for sp in species], dtype=np.int32)
387    type_labels = [PERIODIC_TABLE[sp] for sp in species_set]
388
389    adapt_groups = qtb_parameters.get("adapt_groups", {})
390    """@keyword[fennol_md] qtb/adapt_groups
391    Groups of atoms with separate adQTB parameters (atoms of differents species within groups will be separated in subtypes).
392    Default: {}
393    """
394    ntypes = nspecies
395    allgroups = set()
396    for groupname, interval in adapt_groups.items():
397        indices = read_tinker_interval(interval)
398        assert allgroups.isdisjoint(indices), f"Indices in group {groupname} overlap with other groups"
399        allgroups.update(indices)
400        species_in_group = set(species[indices])
401        idx = {sp: i for i, sp in enumerate(species_in_group)}
402        for i in indices:
403            type_idx[i] = ntypes + idx[species[i]]
404        type_labels += [f"{PERIODIC_TABLE[sp]}_{groupname}" for sp in species_in_group]
405        ntypes = len(type_labels)
406
407    n_of_type = np.zeros(ntypes, dtype=np.int32)
408    for i in range(ntypes):
409        n_of_type[i] = (type_idx == i).nonzero()[0].shape[0]
410        print(f"# QTB: n({type_labels[i]}) =", n_of_type[i])
411    n_of_type = jnp.asarray(n_of_type, dtype=fprec)
412    mass_idx = jax.ops.segment_sum(mass, type_idx, ntypes) / n_of_type
413
414    niter_deconv_kin = qtb_parameters.get("niter_deconv_kin", 7)
415    """@keyword[fennol_md] qtb/niter_deconv_kin
416    Number of iterations for kinetic energy deconvolution.
417    Default: 7
418    """
419    niter_deconv_pot = qtb_parameters.get("niter_deconv_pot", 20)
420    """@keyword[fennol_md] qtb/niter_deconv_pot
421    Number of iterations for potential energy deconvolution.
422    Default: 20
423    """
424    corr_kin = qtb_parameters.get("corr_kin", -1)
425    """@keyword[fennol_md] qtb/corr_kin
426    Kinetic energy correction factor for QTB (-1 for automatic).
427    Default: -1
428    """
429    do_corr_kin = corr_kin <= 0
430    if do_corr_kin:
431        corr_kin = 1.0
432    state["corr_kin"] = corr_kin
433    post_state["corr_kin_prev"] = corr_kin
434    post_state["do_corr_kin"] = do_corr_kin
435    post_state["isame_kin"] = 0
436
437    # spectra parameters
438    omegasmear = np.pi / dt / 100.0
439    Tseg = qtb_parameters.get("tseg", 1.0 / us.PS)
440    """@keyword[fennol_md] qtb/tseg
441    Time segment length for QTB spectrum calculation.
442    Default: 1.0 ps
443    """
444    nseg = int(Tseg / dt)
445    Tseg = nseg * dt
446    dom = 2 * np.pi / (3 * Tseg)
447    omegacut = qtb_parameters.get("omegacut", 15000.0 / us.CM1)
448    """@keyword[fennol_md] qtb/omegacut
449    Cutoff frequency for QTB spectrum.
450    Default: 15000.0 cm⁻¹
451    """
452    nom = int(omegacut / dom)
453    omega = dom * np.arange((3 * nseg) // 2 + 1)
454    cutoff = jnp.asarray(
455        1.0 / (1.0 + np.exp((omega - omegacut) / omegasmear)), dtype=fprec
456    )
457    assert (
458        omegacut < omega[-1]
459    ), f"omegacut must be smaller than {omega[-1]*us.CM1} CM-1"
460
461    # initialize gammar
462    assert (
463        gamma < 0.5 * omegacut
464    ), "gamma must be much smaller than omegacut (at most 0.5*omegacut)"
465    gammar_min = qtb_parameters.get("gammar_min", 0.1)
466    """@keyword[fennol_md] qtb/gammar_min
467    Minimum value for QTB gamma ratio coefficients.
468    Default: 0.1
469    """
470    # post_state["gammar"] = jnp.asarray(np.ones((nspecies, nom)), dtype=fprec)
471    gammar = np.ones((ntypes, nom), dtype=float)
472    try:
473        for i, sp in enumerate(type_labels):
474            if not os.path.exists(f"QTB_spectra_{sp}.out"): continue
475            data = np.loadtxt(f"QTB_spectra_{sp}.out")
476            gammar[i] = data[:, 4]/(gamma*us.THZ)
477            print(f"# Restored gammar for species {sp} from QTB_spectra_{sp}.out")
478    except Exception as e:
479        print(f"# Could not restore gammar for all species with exception {e}. Starting from scratch.")
480        gammar[:,:] = 1.0
481    post_state["gammar"] = jnp.asarray(gammar, dtype=fprec)
482
483    # Ornstein-Uhlenbeck correction for colored noise
484    a1 = np.exp(-gamma * dt)
485    OUcorr = jnp.asarray(
486        (1.0 - 2.0 * a1 * np.cos(omega * dt) + a1**2) / (dt**2 * (gamma**2 + omega**2)),
487        dtype=fprec,
488    )
489
490    # hbar schedule
491    classical_kernel = qtb_parameters.get("classical_kernel", False)
492    """@keyword[fennol_md] qtb/classical_kernel
493    Use classical instead of quantum kernel for QTB.
494    Default: False
495    """
496    hbar = qtb_parameters.get("hbar", 1.0) * us.HBAR
497    """@keyword[fennol_md] qtb/hbar
498    Reduced Planck constant scaling factor for quantum effects.
499    Default: 1.0 a.u.
500    """
501    u = 0.5 * hbar * np.abs(omega) / kT
502    theta = kT * np.ones_like(omega)
503    if hbar > 0:
504        theta[1:] *= u[1:] / np.tanh(u[1:])
505    theta = jnp.asarray(theta, dtype=fprec)
506
507    noise_key, post_state["rng_key"] = jax.random.split(rng_key)
508    del rng_key
509    post_state["white_noise"] = jax.random.normal(
510        noise_key, (3 * nseg, nat, 3), dtype=jnp.float32
511    )
512
513    startsave = qtb_parameters.get("startsave", 1)
514    """@keyword[fennol_md] qtb/startsave
515    Start saving QTB statistics after this many segments.
516    Default: 1
517    """
518    counter = Counter(nseg, startsave=startsave)
519    state["istep"] = 0
520    post_state["nadapt"] = 0
521    post_state["nsample"] = 0
522
523    write_spectra = qtb_parameters.get("write_spectra", True)
524    """@keyword[fennol_md] qtb/write_spectra
525    Write QTB spectral analysis output files.
526    Default: True
527    """
528    do_compute_spectra = write_spectra or adaptive
529
530    if do_compute_spectra:
531        state["vel"] = jnp.zeros((nseg, nat, 3), dtype=fprec)
532
533        post_state["dFDT"] = jnp.zeros((ntypes, nom), dtype=fprec)
534        post_state["mCvv"] = jnp.zeros((ntypes, nom), dtype=fprec)
535        post_state["Cvf"] = jnp.zeros((ntypes, nom), dtype=fprec)
536        post_state["Cff"] = jnp.zeros((ntypes, nom), dtype=fprec)
537        post_state["dFDT_avg"] = jnp.zeros((ntypes, nom), dtype=fprec)
538        post_state["mCvv_avg"] = jnp.zeros((ntypes, nom), dtype=fprec)
539        post_state["Cvfg_avg"] = jnp.zeros((ntypes, nom), dtype=fprec)
540        post_state["Cff_avg"] = jnp.zeros((ntypes, nom), dtype=fprec)
541
542    if not adaptive:
543        update_gammar = lambda x: x
544    else:
545        # adaptation parameters
546        skipseg = qtb_parameters.get("skipseg", 1)
547        """@keyword[fennol_md] qtb/skipseg
548        Number of segments to skip before starting adaptive QTB.
549        Default: 1
550        """
551
552        adaptation_method = (
553            str(qtb_parameters.get("adaptation_method", "SIMPLE")).upper().strip()
554        )
555        """@keyword[fennol_md] qtb/adaptation_method
556        Method for adaptive QTB (SIMPLE, RATIO, ADABELIEF).
557        Default: SIMPLE
558        """
559        if adaptation_method == "SIMPLE":
560            agamma = qtb_parameters.get("agamma", 0.1)
561            """@keyword[fennol_md] qtb/agamma
562            Learning rate for adaptive QTB gamma update.
563            Default: 0.1
564            """
565            assert agamma > 0, "agamma must be positive"
566            a1_ad = agamma * Tseg  #  * gamma
567            print(f"# ADQTB SIMPLE: agamma = {agamma:.3f}")
568
569            def update_gammar(post_state):
570                g = post_state["dFDT"]
571                gammar = post_state["gammar"] - a1_ad * g
572                gammar = jnp.maximum(gammar_min, gammar)
573                return {**post_state, "gammar": gammar}
574
575        elif adaptation_method == "RATIO":
576            tau_ad = qtb_parameters.get("tau_ad", 5.0 / us.PS) 
577            """@keyword[fennol_md] qtb/tau_ad
578            Adaptation time constant for momentum averaging.
579            Default: 5.0 ps (RATIO), 1.0 ps (ADABELIEF)
580            """
581            tau_s = qtb_parameters.get("tau_s", 10 * tau_ad)
582            """@keyword[fennol_md] qtb/tau_s
583            Second moment time constant for variance averaging.
584            Default: 10*tau_ad (RATIO), 100*tau_ad (ADABELIEF)
585            """
586            assert tau_ad > 0, "tau_ad must be positive"
587            print(
588                f"# ADQTB RATIO: tau_ad = {tau_ad*us.PS:.2f} ps, tau_s = {tau_s*us.PS:.2f} ps"
589            )
590            b1 = np.exp(-Tseg / tau_ad)
591            b2 = np.exp(-Tseg / tau_s)
592            post_state["mCvv_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
593            post_state["Cvf_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
594            post_state["n_adabelief"] = 0
595
596            def update_gammar(post_state):
597                n_adabelief = post_state["n_adabelief"] + 1
598                mCvv_m = post_state["mCvv_m"] * b1 + post_state["mCvv"] * (1.0 - b1)
599                Cvf_m = post_state["Cvf_m"] * b2 + post_state["Cvf"] * (1.0 - b2)
600                mCvv = mCvv_m / (1.0 - b1**n_adabelief)
601                Cvf = Cvf_m / (1.0 - b2**n_adabelief)
602                # g = Cvf/(mCvv+1.e-8)-post_state["gammar"]
603                gammar = Cvf / (mCvv + 1.0e-8)
604                gammar = jnp.maximum(gammar_min, gammar)
605                return {
606                    **post_state,
607                    "gammar": gammar,
608                    "mCvv_m": mCvv_m,
609                    "Cvf_m": Cvf_m,
610                    "n_adabelief": n_adabelief,
611                }
612
613        elif adaptation_method == "ADABELIEF":
614            agamma = qtb_parameters.get("agamma", 0.1)
615            tau_ad = qtb_parameters.get("tau_ad", 1.0 / us.PS)
616            tau_s = qtb_parameters.get("tau_s", 100 * tau_ad)
617            assert tau_ad > 0, "tau_ad must be positive"
618            assert tau_s > 0, "tau_s must be positive"
619            assert agamma > 0, "agamma must be positive"
620            print(
621                f"# ADQTB ADABELIEF: agamma = {agamma:.3f}, tau_ad = {tau_ad*us.PS:.2f} ps, tau_s = {tau_s*us.PS:.2f} ps"
622            )
623
624            a1_ad = agamma * gamma  # * Tseg #* gamma
625            b1 = np.exp(-Tseg / tau_ad)
626            b2 = np.exp(-Tseg / tau_s)
627            post_state["dFDT_m"] = jnp.zeros((ntypes, nom), dtype=fprec)
628            post_state["dFDT_s"] = jnp.zeros((ntypes, nom), dtype=fprec)
629            post_state["n_adabelief"] = 0
630
631            def update_gammar(post_state):
632                n_adabelief = post_state["n_adabelief"] + 1
633                dFDT = post_state["dFDT"]
634                dFDT_m = post_state["dFDT_m"] * b1 + dFDT * (1.0 - b1)
635                dFDT_s = (
636                    post_state["dFDT_s"] * b2
637                    + (dFDT - dFDT_m) ** 2 * (1.0 - b2)
638                    + 1.0e-8
639                )
640                # bias correction
641                mt = dFDT_m / (1.0 - b1**n_adabelief)
642                st = dFDT_s / (1.0 - b2**n_adabelief)
643                gammar = post_state["gammar"] - a1_ad * mt / (st**0.5 + 1.0e-8)
644                gammar = jnp.maximum(gammar_min, gammar)
645                return {
646                    **post_state,
647                    "gammar": gammar,
648                    "dFDT_m": dFDT_m,
649                    "n_adabelief": n_adabelief,
650                    "dFDT_s": dFDT_s,
651                }
652        else: 
653            raise ValueError(
654                f"Unknown adaptation method {adaptation_method}."
655            )
656    
657    #####################
658    # RESTART
659    restart_file = system_data["name"]+".qtb.restart"
660    if os.path.exists(restart_file):
661        with open(restart_file, "rb") as f:
662            data = pickle.load(f)
663            state["corr_kin"] = data["corr_kin"]
664            post_state["corr_kin_prev"] = data["corr_kin"]
665            post_state["isame_kin"] = data["isame_kin"]
666            post_state["do_corr_kin"] = data["do_corr_kin"]
667            print(f"# Restored QTB state from {restart_file}")
668
669    def write_qtb_restart(state, post_state):
670        with open(restart_file, "wb") as f:
671            pickle.dump(
672                {
673                    "corr_kin": state["corr_kin"],
674                    "corr_kin_prev": post_state["corr_kin_prev"],
675                    "isame_kin": post_state["isame_kin"],
676                    "do_corr_kin": post_state["do_corr_kin"],
677                },
678                f,
679            )
680    ######################
681
682    def compute_corr_pot(niter=20, verbose=False):
683        if classical_kernel or hbar == 0:
684            return np.ones(nom)
685
686        s_0 = np.array((theta / kT * cutoff)[:nom])
687        s_out, s_rec, _ = deconvolute_spectrum(
688            s_0,
689            omega[:nom],
690            gamma,
691            niter,
692            kernel=kernel_lorentz_pot,
693            trans=True,
694            symmetrize=True,
695            verbose=verbose,
696        )
697        corr_pot = 1.0 + (s_out - s_0) / s_0
698        columns = np.column_stack(
699            (omega[:nom] * us.CM1, corr_pot - 1.0, s_0, s_out, s_rec)
700        )
701        np.savetxt(
702            "corr_pot.dat", columns, header="omega(cm-1) corr_pot s_0 s_out s_rec"
703        )
704        return corr_pot
705
706    def compute_corr_kin(post_state, niter=7, verbose=False):
707        if not post_state["do_corr_kin"]:
708            return post_state["corr_kin_prev"], post_state
709        if classical_kernel or hbar == 0:
710            return 1.0, post_state
711
712        K_D = post_state.get("K_D", None)
713        mCvv = (post_state["mCvv_avg"][:, :nom] * n_of_type[:, None]).sum(axis=0) / nat
714        s_0 = np.array(mCvv * kT / theta[:nom] / post_state["corr_pot"])
715        s_out, s_rec, K_D = deconvolute_spectrum(
716            s_0,
717            omega[:nom],
718            gamma,
719            niter,
720            kernel=kernel_lorentz,
721            trans=False,
722            symmetrize=True,
723            verbose=verbose,
724            K_D=K_D,
725        )
726        s_out = s_out * theta[:nom] / kT
727        s_rec = s_rec * theta[:nom] / kT * post_state["corr_pot"]
728        mCvvsum = mCvv.sum()
729        rec_ratio = mCvvsum / s_rec.sum()
730        if rec_ratio < 0.95 or rec_ratio > 1.05:
731            print(
732                f"# WARNING: reconvolution error {rec_ratio} is too high, corr_kin was not updated"
733            )
734            return post_state["corr_kin_prev"], post_state
735
736        corr_kin = mCvvsum / s_out.sum()
737        if np.abs(corr_kin - post_state["corr_kin_prev"]) < 1.0e-4:
738            isame_kin = post_state["isame_kin"] + 1
739        else:
740            isame_kin = 0
741
742        # print("# corr_kin: ", corr_kin)
743        do_corr_kin = post_state["do_corr_kin"]
744        if isame_kin > 10:
745            print(
746                "# INFO: corr_kin is converged (it did not change for 10 consecutive segments)"
747            )
748            do_corr_kin = False
749
750        return corr_kin, {
751            **post_state,
752            "corr_kin_prev": corr_kin,
753            "isame_kin": isame_kin,
754            "do_corr_kin": do_corr_kin,
755            "K_D": K_D,
756        }
757
758    @jax.jit
759    def ff_kernel(post_state):
760        if classical_kernel:
761            kernel = cutoff * (2 * gamma * kT / dt)
762        else:
763            kernel = theta * cutoff * OUcorr * (2 * gamma / dt)
764        gamma_ratio = jnp.concatenate(
765            (
766                post_state["gammar"].T * post_state["corr_pot"][:, None],
767                jnp.ones(
768                    (kernel.shape[0] - nom, ntypes), dtype=post_state["gammar"].dtype
769                ),
770            ),
771            axis=0,
772        )
773        return kernel[:, None] * gamma_ratio * mass_idx[None, :]
774
775    @jax.jit
776    def refresh_force(post_state):
777        rng_key, noise_key = jax.random.split(post_state["rng_key"])
778        white_noise = jnp.concatenate(
779            (
780                post_state["white_noise"][nseg:],
781                jax.random.normal(
782                    noise_key, (nseg, nat, 3), dtype=post_state["white_noise"].dtype
783                ),
784            ),
785            axis=0,
786        )
787        amplitude = ff_kernel(post_state) ** 0.5
788        s = jnp.fft.rfft(white_noise, 3 * nseg, axis=0) * amplitude[:, type_idx, None]
789        force = jnp.fft.irfft(s, 3 * nseg, axis=0)[nseg : 2 * nseg]
790        return force, {**post_state, "rng_key": rng_key, "white_noise": white_noise}
791
792    @jax.jit
793    def compute_spectra(force, vel, post_state):
794        sf = jnp.fft.rfft(force / gamma, 3 * nseg, axis=0, norm="ortho")
795        sv = jnp.fft.rfft(vel, 3 * nseg, axis=0, norm="ortho")
796        Cvv = jnp.sum(jnp.abs(sv[:nom]) ** 2, axis=-1).T
797        Cff = jnp.sum(jnp.abs(sf[:nom]) ** 2, axis=-1).T
798        Cvf = jnp.sum(jnp.real(sv[:nom] * jnp.conj(sf[:nom])), axis=-1).T
799
800        mCvv = (
801            (dt / 3.0)
802            * jnp.zeros_like(post_state["mCvv"]).at[type_idx].add(Cvv)
803            * mass_idx[:, None]
804            / n_of_type[:, None]
805        )
806        Cvf = (
807            (dt / 3.0)
808            * jnp.zeros_like(post_state["Cvf"]).at[type_idx].add(Cvf)
809            / n_of_type[:, None]
810        )
811        Cff = (
812            (dt / 3.0)
813            * jnp.zeros_like(post_state["Cff"]).at[type_idx].add(Cff)
814            / n_of_type[:, None]
815        )
816        dFDT = mCvv * post_state["gammar"] - Cvf
817
818        nsinv = 1.0 / post_state["nsample"]
819        b1 = 1.0 - nsinv
820        dFDT_avg = post_state["dFDT_avg"] * b1 + dFDT * nsinv
821        mCvv_avg = post_state["mCvv_avg"] * b1 + mCvv * nsinv
822        Cvfg_avg = post_state["Cvfg_avg"] * b1 + Cvf / post_state["gammar"] * nsinv
823        Cff_avg = post_state["Cff_avg"] * b1 + Cff * nsinv
824
825        return {
826            **post_state,
827            "mCvv": mCvv,
828            "Cvf": Cvf,
829            "Cff": Cff,
830            "dFDT": dFDT,
831            "dFDT_avg": dFDT_avg,
832            "mCvv_avg": mCvv_avg,
833            "Cvfg_avg": Cvfg_avg,
834            "Cff_avg": Cff_avg,
835        }
836
837    def write_spectra_to_file(post_state):
838        mCvv_avg = np.array(post_state["mCvv_avg"])
839        Cvfg_avg = np.array(post_state["Cvfg_avg"])
840        Cff_avg = np.array(post_state["Cff_avg"]) * 3.0 / dt / (gamma**2)
841        dFDT_avg = np.array(post_state["dFDT_avg"])
842        gammar = np.array(post_state["gammar"])
843        Cff_theo = np.array(ff_kernel(post_state))[:nom].T
844        for i, sp in enumerate(type_labels):
845            ff_scale = us.KELVIN / ((2 * gamma / dt) * mass_idx[i])
846            columns = np.column_stack(
847                (
848                    omega[:nom] * us.CM1,
849                    mCvv_avg[i],
850                    Cvfg_avg[i],
851                    dFDT_avg[i],
852                    gammar[i] * gamma * us.THZ,
853                    Cff_avg[i] * ff_scale,
854                    Cff_theo[i] * ff_scale,
855                )
856            )
857            np.savetxt(
858                f"QTB_spectra_{sp}.out",
859                columns,
860                fmt="%12.6f",
861                header="#omega mCvv Cvf dFDT gammar Cff",
862            )
863        if verbose:
864            print("# QTB spectra written.")
865
866    if compute_thermostat_energy:
867        state["qtb_energy_flux"] = 0.0
868
869    @jax.jit
870    def thermostat(vel, state):
871        istep = state["istep"]
872        dvel = dt * state["force"][istep] / mass[:, None]
873        new_vel = vel * a1 + dvel
874        new_state = {**state, "istep": istep + 1}
875        if do_compute_spectra:
876            vel2 = state["vel"].at[istep].set(vel * a1**0.5 + 0.5 * dvel)
877            new_state["vel"] = vel2
878        if compute_thermostat_energy:
879            dek = 0.5 * (mass[:, None] * (vel**2 - new_vel**2)).sum()
880            ekcorr = (
881                0.5
882                * (mass[:, None] * new_vel**2).sum()
883                * (1.0 - 1.0 / state.get("corr_kin", 1.0))
884            )
885            new_state["qtb_energy_flux"] = state["qtb_energy_flux"] + dek
886            new_state["thermostat_energy"] = new_state["qtb_energy_flux"] + ekcorr
887        return new_vel, new_state
888
889    @jax.jit
890    def postprocess_work(state, post_state):
891        if do_compute_spectra:
892            post_state = compute_spectra(state["force"], state["vel"], post_state)
893        if adaptive:
894            post_state = jax.lax.cond(
895                post_state["nadapt"] > skipseg, update_gammar, lambda x: x, post_state
896            )
897        new_force, post_state = refresh_force(post_state)
898        return {**state, "force": new_force}, post_state
899
900    def postprocess(state, post_state):
901        counter.increment()
902        if not counter.is_reset_step:
903            return state, post_state
904        post_state["nadapt"] += 1
905        post_state["nsample"] = max(post_state["nadapt"] - startsave + 1, 1)
906        if verbose:
907            print("# Refreshing QTB forces.")
908        state, post_state = postprocess_work(state, post_state)
909        state["corr_kin"], post_state = compute_corr_kin(post_state, niter=niter_deconv_kin)
910        state["istep"] = 0
911        if write_spectra:
912            write_spectra_to_file(post_state)
913        write_qtb_restart(state, post_state)
914        return state, post_state
915
916    post_state["corr_pot"] = jnp.asarray(compute_corr_pot(niter=niter_deconv_pot), dtype=fprec)
917
918    state["force"], post_state = refresh_force(post_state)
919    return thermostat, (postprocess, post_state), state