fennol.md.thermostats

  1import numpy as np
  2import flax.linen as nn
  3import jax
  4import jax.numpy as jnp
  5import math
  6import optax
  7import os
  8import pickle
  9
 10from ..utils.atomic_units import AtomicUnits as au  # CM1,THZ,BOHR,MPROT
 11from ..utils import Counter
 12from ..utils.deconvolution import (
 13    deconvolute_spectrum,
 14    kernel_lorentz_pot,
 15    kernel_lorentz,
 16)
 17
 18
 19def get_thermostat(simulation_parameters, dt, system_data, fprec, rng_key=None, restart_data={}):
 20    state = {}
 21    postprocess = None
 22    
 23
 24    thermostat_name = str(simulation_parameters.get("thermostat", "LGV")).upper()
 25    compute_thermostat_energy = simulation_parameters.get(
 26        "include_thermostat_energy", False
 27    )
 28
 29    kT = system_data.get("kT", None)
 30    nbeads = system_data.get("nbeads", None)
 31    mass = system_data["mass"]
 32    gamma = simulation_parameters.get("gamma", 1.0 / au.THZ) / au.FS
 33    species = system_data["species"]
 34
 35    if nbeads is not None:
 36        trpmd_lambda = simulation_parameters.get("trpmd_lambda", 1.0)
 37        gamma = np.maximum(trpmd_lambda * system_data["omk"], gamma)
 38
 39    if thermostat_name in ["LGV", "LANGEVIN", "FFLGV"]:
 40        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
 41        assert kT is not None, "kT must be specified for QTB thermostat"
 42        assert gamma is not None, "gamma must be specified for QTB thermostat"
 43        rng_key, v_key = jax.random.split(rng_key)
 44        if nbeads is None:
 45            a1 = math.exp(-gamma * dt)
 46            a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)
 47            vel = (
 48                jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
 49                * (kT / mass[:, None]) ** 0.5
 50            )
 51        else:
 52            if isinstance(gamma, float):
 53                gamma = np.array([gamma] * nbeads)
 54            assert isinstance(
 55                gamma, np.ndarray
 56            ), "gamma must be a float or a numpy array"
 57            assert gamma.shape[0] == nbeads, "gamma must have the same length as nbeads"
 58            a1 = np.exp(-gamma * dt)[:, None, None]
 59            a2 = jnp.asarray(
 60                ((1 - a1 * a1) * kT / mass[None, :, None]) ** 0.5, dtype=fprec
 61            )
 62            vel = (
 63                jax.random.normal(v_key, (nbeads, mass.shape[0], 3), dtype=fprec)
 64                * (kT / mass[:, None]) ** 0.5
 65            )
 66
 67        state["rng_key"] = rng_key
 68        if compute_thermostat_energy:
 69            state["thermostat_energy"] = 0.0
 70        if thermostat_name == "FFLGV":
 71            def thermostat(vel, state):
 72                rng_key, noise_key = jax.random.split(state["rng_key"])
 73                noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
 74                norm_vel = jnp.linalg.norm(vel, axis=-1, keepdims=True)
 75                dirvel = vel / norm_vel
 76                if compute_thermostat_energy:
 77                    v2 = (vel**2).sum(axis=-1)
 78                vel = a1 * vel + a2 * noise
 79                new_norm_vel = jnp.linalg.norm(vel, axis=-1, keepdims=True)
 80                vel = dirvel * new_norm_vel
 81                new_state = {**state, "rng_key": rng_key}
 82                if compute_thermostat_energy:
 83                    v2new = (vel**2).sum(axis=-1)
 84                    new_state["thermostat_energy"] = (
 85                        state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
 86                    )
 87
 88                return vel, new_state
 89
 90        else:
 91            def thermostat(vel, state):
 92                rng_key, noise_key = jax.random.split(state["rng_key"])
 93                noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
 94                if compute_thermostat_energy:
 95                    v2 = (vel**2).sum(axis=-1)
 96                vel = a1 * vel + a2 * noise
 97                new_state = {**state, "rng_key": rng_key}
 98                if compute_thermostat_energy:
 99                    v2new = (vel**2).sum(axis=-1)
100                    new_state["thermostat_energy"] = (
101                        state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
102                    )
103                return vel, new_state
104
105    elif thermostat_name in ["BUSSI"]:
106        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
107        assert kT is not None, "kT must be specified for QTB thermostat"
108        assert gamma is not None, "gamma must be specified for QTB thermostat"
109        assert nbeads is None, "Bussi thermostat is not compatible with PIMD"
110        rng_key, v_key = jax.random.split(rng_key)
111
112        a1 = math.exp(-gamma * dt)
113        a2 = (1 - a1) * kT
114        vel = (
115            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
116            * (kT / mass[:, None]) ** 0.5
117        )
118
119        state["rng_key"] = rng_key
120        if compute_thermostat_energy:
121            state["thermostat_energy"] = 0.0
122
123        def thermostat(vel, state):
124            rng_key, noise_key = jax.random.split(state["rng_key"])
125            new_state = {**state, "rng_key": rng_key}
126            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
127            R2 = jnp.sum(noise**2)
128            R1 = noise[0, 0]
129            c = a2 / (mass[:, None] * vel**2).sum()
130            d = (a1 * c) ** 0.5
131            scale = (a1 + c * R2 + 2 * d * R1) ** 0.5
132            if compute_thermostat_energy:
133                dek = 0.5 * (mass[:, None] * vel**2).sum() * (scale**2 - 1)
134                new_state["thermostat_energy"] = state["thermostat_energy"] + dek
135            return scale * vel, new_state
136
137    elif thermostat_name in [
138        "GD",
139        "DESCENT",
140        "GRADIENT_DESCENT",
141        "MIN",
142        "MINIMIZE",
143    ]:
144        assert nbeads is None, "Gradient descent is not compatible with PIMD"
145        a1 = math.exp(-gamma * dt)
146
147        if nbeads is None:
148            vel = jnp.zeros((mass.shape[0], 3), dtype=fprec)
149        else:
150            vel = jnp.zeros((nbeads, mass.shape[0], 3), dtype=fprec)
151
152        def thermostat(vel, state):
153            return a1 * vel, state
154
155    elif thermostat_name in ["NVE", "NONE"]:
156        if nbeads is None:
157            vel = (
158                jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
159                * (kT / mass[:, None]) ** 0.5
160            )
161            kTsys = jnp.sum(mass[:, None] * vel**2) / (mass.shape[0] * 3)
162            vel = vel * (kT / kTsys) ** 0.5
163        else:
164            vel = (
165                jax.random.normal(rng_key, (nbeads, mass.shape[0], 3), dtype=fprec)
166                * (kT / mass[None, :, None]) ** 0.5
167            )
168            kTsys = jnp.sum(mass[None, :, None] * vel**2, axis=(1, 2)) / (
169                mass.shape[0] * 3
170            )
171            vel = vel * (kT / kTsys[:, None, None]) ** 0.5
172        thermostat = lambda x, s: (x, s)
173
174    elif thermostat_name in ["NOSE", "NOSEHOOVER", "NOSE_HOOVER"]:
175        assert gamma is not None, "gamma must be specified for QTB thermostat"
176        ndof = mass.shape[0] * 3
177        nkT = ndof * kT
178        nose_mass = nkT / gamma**2
179        assert nbeads is None, "Nose-Hoover is not compatible with PIMD"
180        state["nose_s"] = 0.0
181        state["nose_v"] = 0.0
182        if compute_thermostat_energy:
183            state["thermostat_energy"] = 0.0
184        print(
185            "# WARNING: Nose-Hoover thermostat is not well tested yet. Energy conservation is not guaranteed."
186        )
187        vel = (
188            jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
189            * (kT / mass[:, None]) ** 0.5
190        )
191
192        def thermostat(vel, state):
193            nose_s = state["nose_s"]
194            nose_v = state["nose_v"]
195            kTsys = jnp.sum(mass[:, None] * vel**2)
196            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
197            nose_s = nose_s + dt * nose_v
198            vel = jnp.exp(-nose_v * dt) * vel
199            kTsys = jnp.sum(mass[:, None] * vel**2)
200            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
201            new_state = {**state, "nose_s": nose_s, "nose_v": nose_v}
202
203            if compute_thermostat_energy:
204                new_state["thermostat_energy"] = (
205                    nkT * nose_s + (0.5 * nose_mass) * nose_v**2
206                )
207            return vel, new_state
208
209    elif thermostat_name in ["QTB", "ADQTB"]:
210        assert nbeads is None, "QTB is not compatible with PIMD"
211        qtb_parameters = simulation_parameters.get("qtb", None)
212        assert (
213            qtb_parameters is not None
214        ), "qtb_parameters must be provided for QTB thermostat"
215        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
216        assert kT is not None, "kT must be specified for QTB thermostat"
217        assert gamma is not None, "gamma must be specified for QTB thermostat"
218        assert species is not None, "species must be provided for QTB thermostat"
219        rng_key, v_key = jax.random.split(rng_key)
220        vel = (
221            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
222            * (kT / mass[:, None]) ** 0.5
223        )
224
225        thermostat, postprocess, qtb_state = initialize_qtb(
226            qtb_parameters,
227            system_data,
228            fprec=fprec,
229            dt=dt,
230            mass=mass,
231            gamma=gamma,
232            kT=kT,
233            species=species,
234            rng_key=rng_key,
235            adaptive=thermostat_name.startswith("AD"),
236            compute_thermostat_energy=compute_thermostat_energy,
237        )
238        state = {**state, **qtb_state}
239
240    elif thermostat_name in ["ANNEAL", "ANNEALING"]:
241        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
242        assert kT is not None, "kT must be specified for QTB thermostat"
243        assert gamma is not None, "gamma must be specified for QTB thermostat"
244        assert nbeads is None, "ANNEAL is not compatible with PIMD"
245        a1 = math.exp(-gamma * dt)
246        a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)
247
248        anneal_parameters = simulation_parameters.get("annealing", {})
249        init_factor = anneal_parameters.get("init_factor", 1.0 / 25.0)
250        assert init_factor > 0.0, "init_factor must be positive"
251        final_factor = anneal_parameters.get("final_factor", 1.0 / 10000.0)
252        assert final_factor > 0.0, "final_factor must be positive"
253        nsteps = simulation_parameters.get("nsteps")
254        anneal_steps = anneal_parameters.get("anneal_steps", 1.0)
255        assert (
256            anneal_steps < 1.0 and anneal_steps > 0.0
257        ), "warmup_steps must be between 0 and nsteps"
258        pct_start = anneal_parameters.get("warmup_steps", 0.3)
259        assert (
260            pct_start < 1.0 and pct_start > 0.0
261        ), "warmup_steps must be between 0 and nsteps"
262
263        anneal_type = anneal_parameters.get("type", "cosine").lower()
264        if anneal_type == "linear":
265            schedule = optax.linear_onecycle_schedule(
266                peak_value=1.0,
267                div_factor=1.0 / init_factor,
268                final_div_factor=1.0 / final_factor,
269                transition_steps=int(anneal_steps * nsteps),
270                pct_start=pct_start,
271                pct_final=1.0,
272            )
273        elif anneal_type == "cosine_onecycle":
274            schedule = optax.cosine_onecycle_schedule(
275                peak_value=1.0,
276                div_factor=1.0 / init_factor,
277                final_div_factor=1.0 / final_factor,
278                transition_steps=int(anneal_steps * nsteps),
279                pct_start=pct_start,
280            )
281        else:
282            raise ValueError(f"Unknown anneal_type {anneal_type}")
283
284        state["rng_key"] = rng_key
285        state["istep_anneal"] = 0
286
287        rng_key, v_key = jax.random.split(rng_key)
288        Tscale = schedule(0)
289        print(f"# ANNEAL: initial temperature = {Tscale*kT*au.KELVIN:.3e} K")
290        vel = (
291            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
292            * (kT * Tscale / mass[:, None]) ** 0.5
293        )
294
295        def thermostat(vel, state):
296            rng_key, noise_key = jax.random.split(state["rng_key"])
297            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
298
299            Tscale = schedule(state["istep_anneal"]) ** 0.5
300            vel = a1 * vel + a2 * Tscale * noise
301            return vel, {
302                **state,
303                "rng_key": rng_key,
304                "istep_anneal": state["istep_anneal"] + 1,
305            }
306
307    else:
308        raise ValueError(f"Unknown thermostat {thermostat_name}")
309
310    return thermostat, postprocess, state, vel,thermostat_name
311
312
def initialize_qtb(
    qtb_parameters,
    system_data,
    fprec,
    dt,
    mass,
    gamma,
    kT,
    species,
    rng_key,
    adaptive,
    compute_thermostat_energy=False,
):
    """Set up the (adaptive) quantum thermal bath thermostat.

    Parameters
    ----------
    qtb_parameters : dict-like
        QTB options. Read keys: ``verbose``, ``niter_deconv_kin``,
        ``niter_deconv_pot``, ``corr_kin``, ``tseg``, ``omegacut``,
        ``gammar_min``, ``classical_kernel``, ``hbar``, ``startsave``,
        ``write_spectra`` and, when ``adaptive`` is true, ``skipseg``,
        ``adaptation_method``, ``agamma``, ``tau_ad``, ``tau_s``.
    system_data : dict-like
        Only ``system_data["name"]`` is read (prefix of the restart file).
    fprec
        Floating point dtype of the generated arrays.
    dt : float
        Time step.
    mass
        Per-atom masses (converted to a JAX array of dtype ``fprec``).
    gamma : float
        Friction coefficient; must be below ``0.5 * omegacut``.
    kT : float
        Thermal energy.
    species
        Per-atom species labels; atoms sharing a label share spectra and
        friction ratios.
    rng_key
        JAX PRNG key used for the colored-noise generation.
    adaptive : bool
        Enable the adaptation of ``gammar`` (adQTB).
    compute_thermostat_energy : bool, optional
        Track the thermostat energy in the state.

    Returns
    -------
    tuple
        ``(thermostat, (postprocess, post_state), state)`` where
        ``thermostat(vel, state)`` is the jitted per-step update,
        ``postprocess(state, post_state)`` must be called after every step
        (it only does work at the end of each segment) and ``state`` /
        ``post_state`` are the initial state dictionaries.

    Side effects
    ------------
    May read ``QTB_spectra_<species>.out`` and ``<name>.qtb.restart``; may
    write ``corr_pot.dat`` at initialization. The returned ``postprocess``
    writes the spectra files (if enabled) and the restart file.
    """
    state = {}
    post_state = {}
    verbose = qtb_parameters.get("verbose", False)
    if compute_thermostat_energy:
        state["thermostat_energy"] = 0.0

    mass = jnp.asarray(mass, dtype=fprec)

    nat = species.shape[0]
    # define type indices
    species_set = set(species)
    nspecies = len(species_set)
    idx = {sp: i for i, sp in enumerate(species_set)}
    type_idx = np.array([idx[sp] for sp in species], dtype=np.int32)

    # number of atoms of each type and mean mass per type
    n_of_type = jnp.asarray(np.bincount(type_idx, minlength=nspecies), dtype=fprec)
    mass_idx = jax.ops.segment_sum(mass, type_idx, nspecies) / n_of_type

    niter_deconv_kin = qtb_parameters.get("niter_deconv_kin", 7)
    niter_deconv_pot = qtb_parameters.get("niter_deconv_pot", 20)
    # a non-positive corr_kin requests an on-the-fly estimate, starting from 1
    corr_kin = qtb_parameters.get("corr_kin", -1)
    do_corr_kin = corr_kin <= 0
    if do_corr_kin:
        corr_kin = 1.0
    state["corr_kin"] = corr_kin
    post_state["corr_kin_prev"] = corr_kin
    post_state["do_corr_kin"] = do_corr_kin
    post_state["isame_kin"] = 0

    # spectra parameters
    omegasmear = np.pi / dt / 100.0
    Tseg = qtb_parameters.get("tseg", 1.0 / au.PS) * au.FS
    nseg = int(Tseg / dt)
    Tseg = nseg * dt
    # frequency grid of a transform over 3 consecutive segments
    dom = 2 * np.pi / (3 * Tseg)
    omegacut = qtb_parameters.get("omegacut", 15000.0 / au.CM1) / au.FS
    nom = int(omegacut / dom)
    omega = dom * np.arange((3 * nseg) // 2 + 1)
    # smooth (Fermi-like) cutoff of the noise spectrum around omegacut
    cutoff = jnp.asarray(
        1.0 / (1.0 + np.exp((omega - omegacut) / omegasmear)), dtype=fprec
    )
    assert (
        omegacut < omega[-1]
    ), f"omegacut must be smaller than {omega[-1]*au.CM1} CM-1"

    # initialize gammar
    assert (
        gamma < 0.5 * omegacut
    ), "gamma must be much smaller than omegacut (at most 0.5*omegacut)"
    gammar_min = qtb_parameters.get("gammar_min", 0.1)
    gammar = np.ones((nspecies, nom), dtype=float)
    # Best-effort restore from previously written spectra (column 4 holds
    # gammar * gamma in the units used by write_spectra_to_file); any failure
    # falls back to gammar = 1 for all species.
    try:
        for i, sp in enumerate(species_set):
            if not os.path.exists(f"QTB_spectra_{sp}.out"):
                continue
            data = np.loadtxt(f"QTB_spectra_{sp}.out")
            gammar[i] = data[:, 4] / (gamma * au.FS * au.THZ)
            print(f"# Restored gammar for species {sp} from QTB_spectra_{sp}.out")
    except Exception as e:
        print(f"# Could not restore gammar for all species with exception {e}. Starting from scratch.")
        gammar[:, :] = 1.0
    post_state["gammar"] = jnp.asarray(gammar, dtype=fprec)

    # Ornstein-Uhlenbeck correction for colored noise
    a1 = np.exp(-gamma * dt)
    OUcorr = jnp.asarray(
        (1.0 - 2.0 * a1 * np.cos(omega * dt) + a1**2) / (dt**2 * (gamma**2 + omega**2)),
        dtype=fprec,
    )

    # hbar schedule
    classical_kernel = qtb_parameters.get("classical_kernel", False)
    hbar = qtb_parameters.get("hbar", 1.0) * au.FS
    u = 0.5 * hbar * np.abs(omega) / kT
    theta = kT * np.ones_like(omega)
    if hbar > 0:
        # omega = 0 is skipped: u / tanh(u) is 0/0 there and theta stays kT
        theta[1:] *= u[1:] / np.tanh(u[1:])
    theta = jnp.asarray(theta, dtype=fprec)

    noise_key, post_state["rng_key"] = jax.random.split(rng_key)
    del rng_key
    post_state["white_noise"] = jax.random.normal(
        noise_key, (3 * nseg, nat, 3), dtype=jnp.float32
    )

    startsave = qtb_parameters.get("startsave", 1)
    counter = Counter(nseg, startsave=startsave)
    state["istep"] = 0
    post_state["nadapt"] = 0
    post_state["nsample"] = 0

    write_spectra = qtb_parameters.get("write_spectra", True)
    # The kinetic correction estimate reads the averaged spectra (mCvv_avg),
    # so spectra must also be accumulated when it is active, even if they are
    # neither written nor used for adaptation.
    corr_kin_needs_spectra = do_corr_kin and not (classical_kernel or hbar == 0)
    do_compute_spectra = write_spectra or adaptive or corr_kin_needs_spectra

    if do_compute_spectra:
        state["vel"] = jnp.zeros((nseg, nat, 3), dtype=fprec)

        post_state["dFDT"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["mCvv"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cvf"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cff"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["dFDT_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["mCvv_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cvfg_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cff_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)

    if not adaptive:
        update_gammar = lambda x: x
    else:
        # adaptation parameters
        skipseg = qtb_parameters.get("skipseg", 1)

        adaptation_method = (
            str(qtb_parameters.get("adaptation_method", "ADABELIEF")).upper().strip()
        )
        authorized_methods = ["SIMPLE", "RATIO", "ADABELIEF"]
        assert (
            adaptation_method in authorized_methods
        ), f"adaptation_method must be one of {authorized_methods}"
        if adaptation_method == "SIMPLE":
            agamma = qtb_parameters.get("agamma", 1.0e-3) / au.FS
            assert agamma > 0, "agamma must be positive"
            a1_ad = agamma * Tseg  #  * gamma
            print(f"# ADQTB SIMPLE: agamma = {agamma*au.FS:.3f}")

            def update_gammar(post_state):
                """Plain gradient step on gammar using the latest dFDT."""
                g = post_state["dFDT"]
                gammar = post_state["gammar"] - a1_ad * g
                gammar = jnp.maximum(gammar_min, gammar)
                return {**post_state, "gammar": gammar}

        elif adaptation_method == "RATIO":
            tau_ad = qtb_parameters.get("tau_ad", 5.0 / au.PS) * au.FS
            # NOTE(review): the default 10 * tau_ad is built from the already
            # converted tau_ad and then multiplied by au.FS again; this looks
            # like a double unit conversion - confirm the intended default.
            tau_s = qtb_parameters.get("tau_s", 10 * tau_ad) * au.FS
            assert tau_ad > 0, "tau_ad must be positive"
            print(
                f"# ADQTB RATIO: tau_ad = {tau_ad*1e-3:.2f} ps, tau_s = {tau_s*1e-3:.2f} ps"
            )
            b1 = np.exp(-Tseg / tau_ad)
            b2 = np.exp(-Tseg / tau_s)
            post_state["mCvv_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["Cvf_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["n_adabelief"] = 0

            def update_gammar(post_state):
                """Set gammar to the ratio of bias-corrected running means Cvf / mCvv."""
                n_adabelief = post_state["n_adabelief"] + 1
                mCvv_m = post_state["mCvv_m"] * b1 + post_state["mCvv"] * (1.0 - b1)
                Cvf_m = post_state["Cvf_m"] * b2 + post_state["Cvf"] * (1.0 - b2)
                mCvv = mCvv_m / (1.0 - b1**n_adabelief)
                Cvf = Cvf_m / (1.0 - b2**n_adabelief)
                gammar = Cvf / (mCvv + 1.0e-8)
                gammar = jnp.maximum(gammar_min, gammar)
                return {
                    **post_state,
                    "gammar": gammar,
                    "mCvv_m": mCvv_m,
                    "Cvf_m": Cvf_m,
                    "n_adabelief": n_adabelief,
                }

        elif adaptation_method == "ADABELIEF":
            agamma = qtb_parameters.get("agamma", 0.1)
            tau_ad = qtb_parameters.get("tau_ad", 1.0 / au.PS) * au.FS
            # NOTE(review): same apparent double conversion of the default as
            # in the RATIO branch (100 * tau_ad is multiplied by au.FS again).
            tau_s = qtb_parameters.get("tau_s", 100 * tau_ad) * au.FS
            assert tau_ad > 0, "tau_ad must be positive"
            assert tau_s > 0, "tau_s must be positive"
            assert agamma > 0, "agamma must be positive"
            print(
                f"# ADQTB ADABELIEF: agamma = {agamma:.3f}, tau_ad = {tau_ad*1.e-3:.2f} ps, tau_s = {tau_s*1.e-3:.2f} ps"
            )

            a1_ad = agamma * gamma  # * Tseg #* gamma
            b1 = np.exp(-Tseg / tau_ad)
            b2 = np.exp(-Tseg / tau_s)
            post_state["dFDT_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["dFDT_s"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["n_adabelief"] = 0

            def update_gammar(post_state):
                """AdaBelief-style step on gammar driven by dFDT."""
                n_adabelief = post_state["n_adabelief"] + 1
                dFDT = post_state["dFDT"]
                dFDT_m = post_state["dFDT_m"] * b1 + dFDT * (1.0 - b1)
                dFDT_s = (
                    post_state["dFDT_s"] * b2
                    + (dFDT - dFDT_m) ** 2 * (1.0 - b2)
                    + 1.0e-8
                )
                # bias correction
                mt = dFDT_m / (1.0 - b1**n_adabelief)
                st = dFDT_s / (1.0 - b2**n_adabelief)
                gammar = post_state["gammar"] - a1_ad * mt / (st**0.5 + 1.0e-8)
                gammar = jnp.maximum(gammar_min, gammar)
                return {
                    **post_state,
                    "gammar": gammar,
                    "dFDT_m": dFDT_m,
                    "n_adabelief": n_adabelief,
                    "dFDT_s": dFDT_s,
                }

    #####################
    # RESTART
    restart_file = system_data["name"] + ".qtb.restart"
    if os.path.exists(restart_file):
        # NOTE: pickle.load can execute arbitrary code; only use restart files
        # produced by a trusted run.
        with open(restart_file, "rb") as f:
            data = pickle.load(f)
            state["corr_kin"] = data["corr_kin"]
            post_state["corr_kin_prev"] = data["corr_kin"]
            post_state["isame_kin"] = data["isame_kin"]
            post_state["do_corr_kin"] = data["do_corr_kin"]
            print(f"# Restored QTB state from {restart_file}")

    def write_qtb_restart(state, post_state):
        """Dump the kinetic-correction bookkeeping to the restart file."""
        with open(restart_file, "wb") as f:
            pickle.dump(
                {
                    "corr_kin": state["corr_kin"],
                    "corr_kin_prev": post_state["corr_kin_prev"],
                    "isame_kin": post_state["isame_kin"],
                    "do_corr_kin": post_state["do_corr_kin"],
                },
                f,
            )
    ######################

    def compute_corr_pot(niter=20, verbose=False):
        """Return the correction applied to the noise kernel below ``omegacut``.

        Returns ones for a classical kernel; otherwise deconvolutes the target
        spectrum and also writes the result to ``corr_pot.dat``.
        """
        if classical_kernel or hbar == 0:
            return np.ones(nom)

        s_0 = np.array((theta / kT * cutoff)[:nom])
        s_out, s_rec, _ = deconvolute_spectrum(
            s_0,
            omega[:nom],
            gamma,
            niter,
            kernel=kernel_lorentz_pot,
            trans=True,
            symmetrize=True,
            verbose=verbose,
        )
        corr_pot = 1.0 + (s_out - s_0) / s_0
        columns = np.column_stack(
            (omega[:nom] * au.CM1, corr_pot - 1.0, s_0, s_out, s_rec)
        )
        np.savetxt(
            "corr_pot.dat", columns, header="omega(cm-1) corr_pot s_0 s_out s_rec"
        )
        return corr_pot

    def compute_corr_kin(post_state, niter=7, verbose=False):
        """Estimate the kinetic correction factor from the averaged spectra.

        Returns ``(corr_kin, post_state)``. The previous value is returned
        unchanged when the estimate is disabled, when no averaged spectrum is
        available, or when the reconvolution check fails.
        """
        if not post_state["do_corr_kin"]:
            return post_state["corr_kin_prev"], post_state
        if classical_kernel or hbar == 0:
            return 1.0, post_state
        if "mCvv_avg" not in post_state:
            # e.g. a restart file re-enabled the estimate while spectra are not
            # being accumulated in this run: nothing to estimate from.
            return post_state["corr_kin_prev"], post_state

        K_D = post_state.get("K_D", None)
        # species-averaged spectrum, weighted by the number of atoms of each type
        mCvv = (post_state["mCvv_avg"][:, :nom] * n_of_type[:, None]).sum(axis=0) / nat
        s_0 = np.array(mCvv * kT / theta[:nom] / post_state["corr_pot"])
        s_out, s_rec, K_D = deconvolute_spectrum(
            s_0,
            omega[:nom],
            gamma,
            niter,
            kernel=kernel_lorentz,
            trans=False,
            symmetrize=True,
            verbose=verbose,
            K_D=K_D,
        )
        s_out = s_out * theta[:nom] / kT
        s_rec = s_rec * theta[:nom] / kT * post_state["corr_pot"]
        mCvvsum = mCvv.sum()
        rec_ratio = mCvvsum / s_rec.sum()
        if rec_ratio < 0.95 or rec_ratio > 1.05:
            print(
                f"# WARNING: reconvolution error {rec_ratio} is too high, corr_kin was not updated"
            )
            return post_state["corr_kin_prev"], post_state

        corr_kin = mCvvsum / s_out.sum()
        if np.abs(corr_kin - post_state["corr_kin_prev"]) < 1.0e-4:
            isame_kin = post_state["isame_kin"] + 1
        else:
            isame_kin = 0

        do_corr_kin = post_state["do_corr_kin"]
        if isame_kin > 10:
            print(
                "# INFO: corr_kin is converged (it did not change for 10 consecutive segments)"
            )
            do_corr_kin = False

        return corr_kin, {
            **post_state,
            "corr_kin_prev": corr_kin,
            "isame_kin": isame_kin,
            "do_corr_kin": do_corr_kin,
            "K_D": K_D,
        }

    @jax.jit
    def ff_kernel(post_state):
        """Target random-force spectrum, shape (n_frequencies, nspecies)."""
        if classical_kernel:
            kernel = cutoff * (2 * gamma * kT / dt)
        else:
            kernel = theta * cutoff * OUcorr * (2 * gamma / dt)
        # gammar and corr_pot are only defined on the first nom frequencies;
        # the remaining ones use a ratio of 1
        gamma_ratio = jnp.concatenate(
            (
                post_state["gammar"].T * post_state["corr_pot"][:, None],
                jnp.ones(
                    (kernel.shape[0] - nom, nspecies), dtype=post_state["gammar"].dtype
                ),
            ),
            axis=0,
        )
        return kernel[:, None] * gamma_ratio * mass_idx[None, :]

    @jax.jit
    def refresh_force(post_state):
        """Generate the colored random force for the next segment.

        The white-noise buffer spans 3 segments: the oldest one is dropped, a
        fresh one is appended, and the filtered middle segment is returned.
        """
        rng_key, noise_key = jax.random.split(post_state["rng_key"])
        white_noise = jnp.concatenate(
            (
                post_state["white_noise"][nseg:],
                jax.random.normal(
                    noise_key, (nseg, nat, 3), dtype=post_state["white_noise"].dtype
                ),
            ),
            axis=0,
        )
        amplitude = ff_kernel(post_state) ** 0.5
        s = jnp.fft.rfft(white_noise, 3 * nseg, axis=0) * amplitude[:, type_idx, None]
        force = jnp.fft.irfft(s, 3 * nseg, axis=0)[nseg : 2 * nseg]
        return force, {**post_state, "rng_key": rng_key, "white_noise": white_noise}

    @jax.jit
    def compute_spectra(force, vel, post_state):
        """Compute per-species spectra for the last segment and update running averages."""
        sf = jnp.fft.rfft(force / gamma, 3 * nseg, axis=0, norm="ortho")
        sv = jnp.fft.rfft(vel, 3 * nseg, axis=0, norm="ortho")
        # sum over xyz, then transpose to (natoms, nom)
        Cvv = jnp.sum(jnp.abs(sv[:nom]) ** 2, axis=-1).T
        Cff = jnp.sum(jnp.abs(sf[:nom]) ** 2, axis=-1).T
        Cvf = jnp.sum(jnp.real(sv[:nom] * jnp.conj(sf[:nom])), axis=-1).T

        # scatter-add atoms into their species row and average per species
        mCvv = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["mCvv"]).at[type_idx].add(Cvv)
            * mass_idx[:, None]
            / n_of_type[:, None]
        )
        Cvf = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["Cvf"]).at[type_idx].add(Cvf)
            / n_of_type[:, None]
        )
        Cff = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["Cff"]).at[type_idx].add(Cff)
            / n_of_type[:, None]
        )
        dFDT = mCvv * post_state["gammar"] - Cvf

        # cumulative mean over the nsample segments seen so far
        nsinv = 1.0 / post_state["nsample"]
        b1 = 1.0 - nsinv
        dFDT_avg = post_state["dFDT_avg"] * b1 + dFDT * nsinv
        mCvv_avg = post_state["mCvv_avg"] * b1 + mCvv * nsinv
        Cvfg_avg = post_state["Cvfg_avg"] * b1 + Cvf / post_state["gammar"] * nsinv
        Cff_avg = post_state["Cff_avg"] * b1 + Cff * nsinv

        return {
            **post_state,
            "mCvv": mCvv,
            "Cvf": Cvf,
            "Cff": Cff,
            "dFDT": dFDT,
            "dFDT_avg": dFDT_avg,
            "mCvv_avg": mCvv_avg,
            "Cvfg_avg": Cvfg_avg,
            "Cff_avg": Cff_avg,
        }

    def write_spectra_to_file(post_state):
        """Write one ``QTB_spectra_<species>.out`` file per species.

        Columns: omega, mCvv_avg, Cvfg_avg, dFDT_avg, gammar * gamma,
        measured Cff, theoretical Cff (the header string lists only six names).
        """
        mCvv_avg = np.array(post_state["mCvv_avg"])
        Cvfg_avg = np.array(post_state["Cvfg_avg"])
        Cff_avg = np.array(post_state["Cff_avg"]) * 3.0 / dt / (gamma**2)
        dFDT_avg = np.array(post_state["dFDT_avg"])
        gammar = np.array(post_state["gammar"])
        Cff_theo = np.array(ff_kernel(post_state))[:nom].T
        for i, sp in enumerate(species_set):
            ff_scale = au.KELVIN / ((2 * gamma / dt) * mass_idx[i])
            columns = np.column_stack(
                (
                    omega[:nom] * (au.FS * au.CM1),
                    mCvv_avg[i],
                    Cvfg_avg[i],
                    dFDT_avg[i],
                    gammar[i] * gamma * (au.FS * au.THZ),
                    Cff_avg[i] * ff_scale,
                    Cff_theo[i] * ff_scale,
                )
            )
            np.savetxt(
                f"QTB_spectra_{sp}.out",
                columns,
                fmt="%12.6f",
                header="#omega mCvv Cvf dFDT gammar Cff",
            )
        if verbose:
            print("# QTB spectra written.")

    if compute_thermostat_energy:
        state["qtb_energy_flux"] = 0.0

    @jax.jit
    def thermostat(vel, state):
        """Apply friction and the precomputed colored force for the current step."""
        istep = state["istep"]
        dvel = dt * state["force"][istep] / mass[:, None]
        new_vel = vel * a1 + dvel
        new_state = {**state, "istep": istep + 1}
        if do_compute_spectra:
            # velocity stored for the spectra; presumably a mid-step estimate
            vel2 = state["vel"].at[istep].set(vel * a1**0.5 + 0.5 * dvel)
            new_state["vel"] = vel2
        if compute_thermostat_energy:
            dek = 0.5 * (mass[:, None] * (vel**2 - new_vel**2)).sum()
            ekcorr = (
                0.5
                * (mass[:, None] * new_vel**2).sum()
                * (1.0 - 1.0 / state.get("corr_kin", 1.0))
            )
            new_state["qtb_energy_flux"] = state["qtb_energy_flux"] + dek
            new_state["thermostat_energy"] = new_state["qtb_energy_flux"] + ekcorr
        return new_vel, new_state

    @jax.jit
    def postprocess_work(state, post_state):
        """Jitted end-of-segment work: spectra, gammar adaptation, new force."""
        if do_compute_spectra:
            post_state = compute_spectra(state["force"], state["vel"], post_state)
        if adaptive:
            # adaptation only starts after more than skipseg segments
            post_state = jax.lax.cond(
                post_state["nadapt"] > skipseg, update_gammar, lambda x: x, post_state
            )
        new_force, post_state = refresh_force(post_state)
        return {**state, "force": new_force}, post_state

    def postprocess(state, post_state):
        """Per-step hook; does the segment refresh when the counter resets.

        Note: ``post_state["nadapt"]`` and ``post_state["nsample"]`` are
        updated in place on the dict passed by the caller.
        """
        counter.increment()
        if not counter.is_reset_step:
            return state, post_state
        post_state["nadapt"] += 1
        post_state["nsample"] = max(post_state["nadapt"] - startsave + 1, 1)
        if verbose:
            print("# Refreshing QTB forces.")
        state, post_state = postprocess_work(state, post_state)
        state["corr_kin"], post_state = compute_corr_kin(post_state, niter=niter_deconv_kin)
        state["istep"] = 0
        if write_spectra:
            write_spectra_to_file(post_state)
        write_qtb_restart(state, post_state)
        return state, post_state

    post_state["corr_pot"] = jnp.asarray(compute_corr_pot(niter=niter_deconv_pot), dtype=fprec)

    state["force"], post_state = refresh_force(post_state)
    return thermostat, (postprocess, post_state), state
def get_thermostat(simulation_parameters, dt, system_data, fprec, rng_key=None, restart_data=None):
    """Build the thermostat step function and the initial velocities.

    The thermostat is selected by ``simulation_parameters["thermostat"]``
    (case-insensitive, default ``"LGV"``). Recognized names are:

    * ``LGV`` / ``LANGEVIN`` / ``FFLGV``: Langevin update ``a1 * v + a2 * noise``
      (``FFLGV`` keeps the direction of each velocity and only updates its norm).
    * ``BUSSI``: global stochastic rescaling of all velocities.
    * ``GD`` / ``DESCENT`` / ``GRADIENT_DESCENT`` / ``MIN`` / ``MINIMIZE``:
      pure damping ``a1 * v`` starting from zero velocities.
    * ``NVE`` / ``NONE``: identity thermostat; initial velocities are rescaled
      so that the initial kinetic temperature equals ``kT``.
    * ``NOSE`` / ``NOSEHOOVER`` / ``NOSE_HOOVER``: single Nose-Hoover variable.
    * ``QTB`` / ``ADQTB``: delegated to ``initialize_qtb``.
    * ``ANNEAL`` / ``ANNEALING``: Langevin update whose noise amplitude follows
      an optax one-cycle schedule.

    Args:
        simulation_parameters: mapping of user parameters. Keys read here:
            ``thermostat``, ``include_thermostat_energy``, ``gamma``,
            ``trpmd_lambda``, ``qtb``, ``annealing`` and ``nsteps``.
        dt: integration time step.
        system_data: mapping providing ``mass`` and ``species`` and, optionally,
            ``kT``, ``nbeads`` and ``omk`` (``omk`` is required when ``nbeads``
            is set).
        fprec: dtype used for the generated arrays.
        rng_key: JAX PRNG key; required by every thermostat except gradient
            descent.
        restart_data: accepted for interface compatibility; not used by this
            function.

    Returns:
        A tuple ``(thermostat, postprocess, state, vel, thermostat_name)`` where
        ``thermostat(vel, state) -> (new_vel, new_state)``, ``postprocess`` is
        ``None`` except for QTB (where it is the ``(postprocess, post_state)``
        pair returned by ``initialize_qtb``), ``state`` is the initial
        thermostat state dict, ``vel`` the initial velocities and
        ``thermostat_name`` the upper-cased thermostat name.

    Raises:
        AssertionError: if a required input is missing or incompatible with the
            selected thermostat.
        ValueError: if the thermostat name or the annealing type is unknown.
    """
    state = {}
    postprocess = None

    thermostat_name = str(simulation_parameters.get("thermostat", "LGV")).upper()
    compute_thermostat_energy = simulation_parameters.get(
        "include_thermostat_energy", False
    )

    kT = system_data.get("kT", None)
    nbeads = system_data.get("nbeads", None)
    mass = system_data["mass"]
    gamma = simulation_parameters.get("gamma", 1.0 / au.THZ) / au.FS
    species = system_data["species"]

    if nbeads is not None:
        # With beads, the friction is at least trpmd_lambda * omk for each mode.
        trpmd_lambda = simulation_parameters.get("trpmd_lambda", 1.0)
        gamma = np.maximum(trpmd_lambda * system_data["omk"], gamma)

    if thermostat_name in ["LGV", "LANGEVIN", "FFLGV"]:
        assert rng_key is not None, "rng_key must be provided for Langevin thermostat"
        assert kT is not None, "kT must be specified for Langevin thermostat"
        assert gamma is not None, "gamma must be specified for Langevin thermostat"
        rng_key, v_key = jax.random.split(rng_key)
        if nbeads is None:
            a1 = math.exp(-gamma * dt)
            a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)
            vel = (
                jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
                * (kT / mass[:, None]) ** 0.5
            )
        else:
            if isinstance(gamma, float):
                gamma = np.array([gamma] * nbeads)
            assert isinstance(
                gamma, np.ndarray
            ), "gamma must be a float or a numpy array"
            assert gamma.shape[0] == nbeads, "gamma must have the same length as nbeads"
            # One damping factor per bead index, broadcast over atoms and xyz.
            a1 = np.exp(-gamma * dt)[:, None, None]
            a2 = jnp.asarray(
                ((1 - a1 * a1) * kT / mass[None, :, None]) ** 0.5, dtype=fprec
            )
            vel = (
                jax.random.normal(v_key, (nbeads, mass.shape[0], 3), dtype=fprec)
                * (kT / mass[:, None]) ** 0.5
            )

        state["rng_key"] = rng_key
        if compute_thermostat_energy:
            state["thermostat_energy"] = 0.0

        if thermostat_name == "FFLGV":

            def thermostat(vel, state):
                rng_key, noise_key = jax.random.split(state["rng_key"])
                noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
                # NOTE(review): an atom with exactly zero velocity gives 0/0 here.
                norm_vel = jnp.linalg.norm(vel, axis=-1, keepdims=True)
                dirvel = vel / norm_vel
                if compute_thermostat_energy:
                    v2 = (vel**2).sum(axis=-1)
                vel = a1 * vel + a2 * noise
                # Keep the original direction, only take the new norm.
                new_norm_vel = jnp.linalg.norm(vel, axis=-1, keepdims=True)
                vel = dirvel * new_norm_vel
                new_state = {**state, "rng_key": rng_key}
                if compute_thermostat_energy:
                    v2new = (vel**2).sum(axis=-1)
                    new_state["thermostat_energy"] = (
                        state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
                    )

                return vel, new_state

        else:

            def thermostat(vel, state):
                rng_key, noise_key = jax.random.split(state["rng_key"])
                noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
                if compute_thermostat_energy:
                    v2 = (vel**2).sum(axis=-1)
                vel = a1 * vel + a2 * noise
                new_state = {**state, "rng_key": rng_key}
                if compute_thermostat_energy:
                    v2new = (vel**2).sum(axis=-1)
                    new_state["thermostat_energy"] = (
                        state["thermostat_energy"] + 0.5 * (mass * (v2 - v2new)).sum()
                    )
                return vel, new_state

    elif thermostat_name in ["BUSSI"]:
        assert rng_key is not None, "rng_key must be provided for Bussi thermostat"
        assert kT is not None, "kT must be specified for Bussi thermostat"
        assert gamma is not None, "gamma must be specified for Bussi thermostat"
        assert nbeads is None, "Bussi thermostat is not compatible with PIMD"
        rng_key, v_key = jax.random.split(rng_key)

        a1 = math.exp(-gamma * dt)
        a2 = (1 - a1) * kT
        vel = (
            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
            * (kT / mass[:, None]) ** 0.5
        )

        state["rng_key"] = rng_key
        if compute_thermostat_energy:
            state["thermostat_energy"] = 0.0

        def thermostat(vel, state):
            rng_key, noise_key = jax.random.split(state["rng_key"])
            new_state = {**state, "rng_key": rng_key}
            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)
            R2 = jnp.sum(noise**2)
            R1 = noise[0, 0]
            c = a2 / (mass[:, None] * vel**2).sum()
            d = (a1 * c) ** 0.5
            # Single global factor applied to every velocity component.
            scale = (a1 + c * R2 + 2 * d * R1) ** 0.5
            if compute_thermostat_energy:
                dek = 0.5 * (mass[:, None] * vel**2).sum() * (scale**2 - 1)
                new_state["thermostat_energy"] = state["thermostat_energy"] + dek
            return scale * vel, new_state

    elif thermostat_name in [
        "GD",
        "DESCENT",
        "GRADIENT_DESCENT",
        "MIN",
        "MINIMIZE",
    ]:
        assert nbeads is None, "Gradient descent is not compatible with PIMD"
        a1 = math.exp(-gamma * dt)

        # Minimization starts at rest.
        vel = jnp.zeros((mass.shape[0], 3), dtype=fprec)

        def thermostat(vel, state):
            return a1 * vel, state

    elif thermostat_name in ["NVE", "NONE"]:
        assert rng_key is not None, "rng_key must be provided to draw NVE initial velocities"
        assert kT is not None, "kT must be specified to draw NVE initial velocities"
        if nbeads is None:
            vel = (
                jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
                * (kT / mass[:, None]) ** 0.5
            )
            # Rescale so that the instantaneous kinetic temperature is exactly kT.
            kTsys = jnp.sum(mass[:, None] * vel**2) / (mass.shape[0] * 3)
            vel = vel * (kT / kTsys) ** 0.5
        else:
            vel = (
                jax.random.normal(rng_key, (nbeads, mass.shape[0], 3), dtype=fprec)
                * (kT / mass[None, :, None]) ** 0.5
            )
            # Same rescaling, done independently for each bead index.
            kTsys = jnp.sum(mass[None, :, None] * vel**2, axis=(1, 2)) / (
                mass.shape[0] * 3
            )
            vel = vel * (kT / kTsys[:, None, None]) ** 0.5
        thermostat = lambda x, s: (x, s)

    elif thermostat_name in ["NOSE", "NOSEHOOVER", "NOSE_HOOVER"]:
        assert rng_key is not None, "rng_key must be provided for Nose-Hoover thermostat"
        assert kT is not None, "kT must be specified for Nose-Hoover thermostat"
        assert gamma is not None, "gamma must be specified for Nose-Hoover thermostat"
        assert nbeads is None, "Nose-Hoover is not compatible with PIMD"
        ndof = mass.shape[0] * 3
        nkT = ndof * kT
        nose_mass = nkT / gamma**2
        state["nose_s"] = 0.0
        state["nose_v"] = 0.0
        if compute_thermostat_energy:
            state["thermostat_energy"] = 0.0
        print(
            "# WARNING: Nose-Hoover thermostat is not well tested yet. Energy conservation is not guaranteed."
        )
        vel = (
            jax.random.normal(rng_key, (mass.shape[0], 3), dtype=fprec)
            * (kT / mass[:, None]) ** 0.5
        )

        def thermostat(vel, state):
            nose_s = state["nose_s"]
            nose_v = state["nose_v"]
            # Half update of nose_v, full update of nose_s and vel, half update
            # of nose_v with the rescaled velocities.
            kTsys = jnp.sum(mass[:, None] * vel**2)
            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
            nose_s = nose_s + dt * nose_v
            vel = jnp.exp(-nose_v * dt) * vel
            kTsys = jnp.sum(mass[:, None] * vel**2)
            nose_v = nose_v + (0.5 * dt / nose_mass) * (kTsys - nkT)
            new_state = {**state, "nose_s": nose_s, "nose_v": nose_v}

            if compute_thermostat_energy:
                new_state["thermostat_energy"] = (
                    nkT * nose_s + (0.5 * nose_mass) * nose_v**2
                )
            return vel, new_state

    elif thermostat_name in ["QTB", "ADQTB"]:
        assert nbeads is None, "QTB is not compatible with PIMD"
        qtb_parameters = simulation_parameters.get("qtb", None)
        assert (
            qtb_parameters is not None
        ), "qtb_parameters must be provided for QTB thermostat"
        assert rng_key is not None, "rng_key must be provided for QTB thermostat"
        assert kT is not None, "kT must be specified for QTB thermostat"
        assert gamma is not None, "gamma must be specified for QTB thermostat"
        assert species is not None, "species must be provided for QTB thermostat"
        rng_key, v_key = jax.random.split(rng_key)
        vel = (
            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
            * (kT / mass[:, None]) ** 0.5
        )

        # For QTB, `postprocess` is the (postprocess_fn, post_state) pair.
        thermostat, postprocess, qtb_state = initialize_qtb(
            qtb_parameters,
            system_data,
            fprec=fprec,
            dt=dt,
            mass=mass,
            gamma=gamma,
            kT=kT,
            species=species,
            rng_key=rng_key,
            adaptive=thermostat_name.startswith("AD"),
            compute_thermostat_energy=compute_thermostat_energy,
        )
        state = {**state, **qtb_state}

    elif thermostat_name in ["ANNEAL", "ANNEALING"]:
        assert rng_key is not None, "rng_key must be provided for ANNEAL thermostat"
        assert kT is not None, "kT must be specified for ANNEAL thermostat"
        assert gamma is not None, "gamma must be specified for ANNEAL thermostat"
        assert nbeads is None, "ANNEAL is not compatible with PIMD"
        a1 = math.exp(-gamma * dt)
        a2 = jnp.asarray(((1 - a1 * a1) * kT / mass[:, None]) ** 0.5, dtype=fprec)

        anneal_parameters = simulation_parameters.get("annealing", {})
        init_factor = anneal_parameters.get("init_factor", 1.0 / 25.0)
        assert init_factor > 0.0, "init_factor must be positive"
        final_factor = anneal_parameters.get("final_factor", 1.0 / 10000.0)
        assert final_factor > 0.0, "final_factor must be positive"
        nsteps = simulation_parameters.get("nsteps")
        assert nsteps is not None, "nsteps must be specified for ANNEAL thermostat"
        # anneal_steps is a fraction of nsteps; the default 1.0 (whole run)
        # must be accepted.
        anneal_steps = anneal_parameters.get("anneal_steps", 1.0)
        assert (
            0.0 < anneal_steps <= 1.0
        ), "anneal_steps must be a fraction of nsteps in (0, 1]"
        pct_start = anneal_parameters.get("warmup_steps", 0.3)
        assert (
            0.0 < pct_start < 1.0
        ), "warmup_steps must be a fraction strictly between 0 and 1"

        anneal_type = anneal_parameters.get("type", "cosine").lower()
        if anneal_type == "linear":
            schedule = optax.linear_onecycle_schedule(
                peak_value=1.0,
                div_factor=1.0 / init_factor,
                final_div_factor=1.0 / final_factor,
                transition_steps=int(anneal_steps * nsteps),
                pct_start=pct_start,
                pct_final=1.0,
            )
        elif anneal_type in ("cosine", "cosine_onecycle"):
            # "cosine" is the documented default and maps to the one-cycle
            # cosine schedule.
            schedule = optax.cosine_onecycle_schedule(
                peak_value=1.0,
                div_factor=1.0 / init_factor,
                final_div_factor=1.0 / final_factor,
                transition_steps=int(anneal_steps * nsteps),
                pct_start=pct_start,
            )
        else:
            raise ValueError(f"Unknown anneal_type {anneal_type}")

        # Split BEFORE storing the key: storing the parent key would make the
        # first thermostat call regenerate v_key and reuse the random numbers
        # that produced the initial velocities.
        rng_key, v_key = jax.random.split(rng_key)
        state["rng_key"] = rng_key
        state["istep_anneal"] = 0

        Tscale = schedule(0)
        print(f"# ANNEAL: initial temperature = {Tscale*kT*au.KELVIN:.3e} K")
        vel = (
            jax.random.normal(v_key, (mass.shape[0], 3), dtype=fprec)
            * (kT * Tscale / mass[:, None]) ** 0.5
        )

        def thermostat(vel, state):
            rng_key, noise_key = jax.random.split(state["rng_key"])
            noise = jax.random.normal(noise_key, vel.shape, dtype=vel.dtype)

            # The schedule scales the temperature, hence sqrt for the noise.
            Tscale = schedule(state["istep_anneal"]) ** 0.5
            vel = a1 * vel + a2 * Tscale * noise
            return vel, {
                **state,
                "rng_key": rng_key,
                "istep_anneal": state["istep_anneal"] + 1,
            }

    else:
        raise ValueError(f"Unknown thermostat {thermostat_name}")

    return thermostat, postprocess, state, vel, thermostat_name
def initialize_qtb(
    qtb_parameters,
    system_data,
    fprec,
    dt,
    mass,
    gamma,
    kT,
    species,
    rng_key,
    adaptive,
    compute_thermostat_energy=False,
):
    """Set up the (adaptive) quantum thermal bath thermostat.

    The random force is generated one segment of ``nseg`` steps at a time by
    filtering white noise in Fourier space (``refresh_force``). At the end of
    each segment ``postprocess`` optionally accumulates velocity/force spectra,
    adapts the per-species friction ratios ``gammar`` (adaptive mode), updates
    the kinetic-energy correction ``corr_kin`` and writes spectra and a restart
    file.

    Args:
        qtb_parameters: mapping of QTB options. Keys read here: ``verbose``,
            ``niter_deconv_kin``, ``niter_deconv_pot``, ``corr_kin``, ``tseg``,
            ``omegacut``, ``gammar_min``, ``classical_kernel``, ``hbar``,
            ``startsave``, ``write_spectra``, ``skipseg``,
            ``adaptation_method``, ``agamma``, ``tau_ad`` and ``tau_s``.
        system_data: mapping; only ``system_data["name"]`` is read (prefix of
            the restart file name).
        fprec: dtype used for the generated arrays.
        dt: integration time step.
        mass: per-atom masses.
        gamma: scalar friction coefficient.
        kT: thermal energy.
        species: per-atom species array (its first dimension is the number of
            atoms).
        rng_key: JAX PRNG key.
        adaptive: whether ``gammar`` is adapted at the end of each segment.
        compute_thermostat_energy: whether the thermostat energy is tracked in
            the state.

    Returns:
        ``(thermostat, (postprocess, post_state), state)`` where
        ``thermostat(vel, state) -> (new_vel, new_state)`` is the per-step
        update, ``postprocess(state, post_state) -> (state, post_state)`` must
        be called once per step and does work only on segment boundaries, and
        ``state`` / ``post_state`` are the initial state dicts.

    Raises:
        AssertionError: if ``omegacut``, ``gamma`` or the adaptation parameters
            are invalid.
    """
    state = {}
    post_state = {}
    verbose = qtb_parameters.get("verbose", False)
    if compute_thermostat_energy:
        state["thermostat_energy"] = 0.0

    mass = jnp.asarray(mass, dtype=fprec)

    nat = species.shape[0]
    # define type indices
    species_set = set(species)
    nspecies = len(species_set)
    idx = {sp: i for i, sp in enumerate(species_set)}
    type_idx = np.array([idx[sp] for sp in species], dtype=np.int32)

    # Number of atoms of each type and the average mass of each type.
    n_of_type = jnp.asarray(np.bincount(type_idx, minlength=nspecies), dtype=fprec)
    mass_idx = jax.ops.segment_sum(mass, type_idx, nspecies) / n_of_type

    niter_deconv_kin = qtb_parameters.get("niter_deconv_kin", 7)
    niter_deconv_pot = qtb_parameters.get("niter_deconv_pot", 20)
    # A non-positive corr_kin means "determine it automatically".
    corr_kin = qtb_parameters.get("corr_kin", -1)
    do_corr_kin = corr_kin <= 0
    if do_corr_kin:
        corr_kin = 1.0
    state["corr_kin"] = corr_kin
    post_state["corr_kin_prev"] = corr_kin
    post_state["do_corr_kin"] = do_corr_kin
    post_state["isame_kin"] = 0

    #####################
    # RESTART
    # Loaded here (before deciding whether spectra are needed) so that the
    # restored do_corr_kin flag is taken into account below.
    restart_file = system_data["name"] + ".qtb.restart"
    if os.path.exists(restart_file):
        with open(restart_file, "rb") as f:
            # NOTE: pickle executes arbitrary code on load; the restart file
            # must come from a trusted source.
            data = pickle.load(f)
            state["corr_kin"] = data["corr_kin"]
            post_state["corr_kin_prev"] = data["corr_kin"]
            post_state["isame_kin"] = data["isame_kin"]
            post_state["do_corr_kin"] = data["do_corr_kin"]
            print(f"# Restored QTB state from {restart_file}")

    def write_qtb_restart(state, post_state):
        with open(restart_file, "wb") as f:
            pickle.dump(
                {
                    "corr_kin": state["corr_kin"],
                    "corr_kin_prev": post_state["corr_kin_prev"],
                    "isame_kin": post_state["isame_kin"],
                    "do_corr_kin": post_state["do_corr_kin"],
                },
                f,
            )

    ######################

    # spectra parameters
    omegasmear = np.pi / dt / 100.0
    Tseg = qtb_parameters.get("tseg", 1.0 / au.PS) * au.FS
    nseg = int(Tseg / dt)
    Tseg = nseg * dt
    # The noise buffer spans three segments, hence the 3 * Tseg resolution.
    dom = 2 * np.pi / (3 * Tseg)
    omegacut = qtb_parameters.get("omegacut", 15000.0 / au.CM1) / au.FS
    nom = int(omegacut / dom)
    omega = dom * np.arange((3 * nseg) // 2 + 1)
    cutoff = jnp.asarray(
        1.0 / (1.0 + np.exp((omega - omegacut) / omegasmear)), dtype=fprec
    )
    assert (
        omegacut < omega[-1]
    ), f"omegacut must be smaller than {omega[-1]*au.CM1} CM-1"

    # initialize gammar
    assert (
        gamma < 0.5 * omegacut
    ), "gamma must be much smaller than omegacut (at most 0.5*omegacut)"
    gammar_min = qtb_parameters.get("gammar_min", 0.1)
    gammar = np.ones((nspecies, nom), dtype=float)
    # Best-effort restore from previously written spectra files; any failure
    # falls back to gammar = 1 for all species.
    try:
        for i, sp in enumerate(species_set):
            if not os.path.exists(f"QTB_spectra_{sp}.out"):
                continue
            data = np.loadtxt(f"QTB_spectra_{sp}.out")
            gammar[i] = data[:, 4] / (gamma * au.FS * au.THZ)
            print(f"# Restored gammar for species {sp} from QTB_spectra_{sp}.out")
    except Exception as e:
        print(f"# Could not restore gammar for all species with exception {e}. Starting from scratch.")
        gammar[:, :] = 1.0
    post_state["gammar"] = jnp.asarray(gammar, dtype=fprec)

    # Ornstein-Uhlenbeck correction for colored noise
    a1 = np.exp(-gamma * dt)
    OUcorr = jnp.asarray(
        (1.0 - 2.0 * a1 * np.cos(omega * dt) + a1**2) / (dt**2 * (gamma**2 + omega**2)),
        dtype=fprec,
    )

    # hbar schedule
    classical_kernel = qtb_parameters.get("classical_kernel", False)
    hbar = qtb_parameters.get("hbar", 1.0) * au.FS
    u = 0.5 * hbar * np.abs(omega) / kT
    theta = kT * np.ones_like(omega)
    if hbar > 0:
        # omega = 0 is skipped to avoid 0/0; theta stays kT there.
        theta[1:] *= u[1:] / np.tanh(u[1:])
    theta = jnp.asarray(theta, dtype=fprec)

    noise_key, post_state["rng_key"] = jax.random.split(rng_key)
    del rng_key
    post_state["white_noise"] = jax.random.normal(
        noise_key, (3 * nseg, nat, 3), dtype=jnp.float32
    )

    startsave = qtb_parameters.get("startsave", 1)
    counter = Counter(nseg, startsave=startsave)
    state["istep"] = 0
    post_state["nadapt"] = 0
    post_state["nsample"] = 0

    write_spectra = qtb_parameters.get("write_spectra", True)
    # compute_corr_kin reads post_state["mCvv_avg"], so spectra must also be
    # accumulated whenever the kinetic correction is evaluated from them.
    needs_kin_spectra = bool(post_state["do_corr_kin"]) and not (
        classical_kernel or hbar == 0
    )
    do_compute_spectra = write_spectra or adaptive or needs_kin_spectra

    if do_compute_spectra:
        state["vel"] = jnp.zeros((nseg, nat, 3), dtype=fprec)

        post_state["dFDT"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["mCvv"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cvf"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cff"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["dFDT_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["mCvv_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cvfg_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)
        post_state["Cff_avg"] = jnp.zeros((nspecies, nom), dtype=fprec)

    if not adaptive:
        update_gammar = lambda x: x
    else:
        # adaptation parameters
        skipseg = qtb_parameters.get("skipseg", 1)

        adaptation_method = (
            str(qtb_parameters.get("adaptation_method", "ADABELIEF")).upper().strip()
        )
        authorized_methods = ["SIMPLE", "RATIO", "ADABELIEF"]
        assert (
            adaptation_method in authorized_methods
        ), f"adaptation_method must be one of {authorized_methods}"
        if adaptation_method == "SIMPLE":
            agamma = qtb_parameters.get("agamma", 1.0e-3) / au.FS
            assert agamma > 0, "agamma must be positive"
            a1_ad = agamma * Tseg  #  * gamma
            print(f"# ADQTB SIMPLE: agamma = {agamma*au.FS:.3f}")

            def update_gammar(post_state):
                g = post_state["dFDT"]
                gammar = post_state["gammar"] - a1_ad * g
                gammar = jnp.maximum(gammar_min, gammar)
                return {**post_state, "gammar": gammar}

        elif adaptation_method == "RATIO":
            # The default tau_s is derived from the UNCONVERTED tau_ad so that
            # the au.FS factor is applied exactly once to each time constant.
            tau_ad_in = qtb_parameters.get("tau_ad", 5.0 / au.PS)
            tau_s_in = qtb_parameters.get("tau_s", 10 * tau_ad_in)
            tau_ad = tau_ad_in * au.FS
            tau_s = tau_s_in * au.FS
            assert tau_ad > 0, "tau_ad must be positive"
            assert tau_s > 0, "tau_s must be positive"
            print(
                f"# ADQTB RATIO: tau_ad = {tau_ad*1e-3:.2f} ps, tau_s = {tau_s*1e-3:.2f} ps"
            )
            b1 = np.exp(-Tseg / tau_ad)
            b2 = np.exp(-Tseg / tau_s)
            post_state["mCvv_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["Cvf_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["n_adabelief"] = 0

            def update_gammar(post_state):
                n_adabelief = post_state["n_adabelief"] + 1
                # Exponential moving averages with bias correction.
                mCvv_m = post_state["mCvv_m"] * b1 + post_state["mCvv"] * (1.0 - b1)
                Cvf_m = post_state["Cvf_m"] * b2 + post_state["Cvf"] * (1.0 - b2)
                mCvv = mCvv_m / (1.0 - b1**n_adabelief)
                Cvf = Cvf_m / (1.0 - b2**n_adabelief)
                gammar = Cvf / (mCvv + 1.0e-8)
                gammar = jnp.maximum(gammar_min, gammar)
                return {
                    **post_state,
                    "gammar": gammar,
                    "mCvv_m": mCvv_m,
                    "Cvf_m": Cvf_m,
                    "n_adabelief": n_adabelief,
                }

        elif adaptation_method == "ADABELIEF":
            agamma = qtb_parameters.get("agamma", 0.1)
            # Same unit handling as in RATIO: build the default from the
            # unconverted tau_ad, then convert both once.
            tau_ad_in = qtb_parameters.get("tau_ad", 1.0 / au.PS)
            tau_s_in = qtb_parameters.get("tau_s", 100 * tau_ad_in)
            tau_ad = tau_ad_in * au.FS
            tau_s = tau_s_in * au.FS
            assert tau_ad > 0, "tau_ad must be positive"
            assert tau_s > 0, "tau_s must be positive"
            assert agamma > 0, "agamma must be positive"
            print(
                f"# ADQTB ADABELIEF: agamma = {agamma:.3f}, tau_ad = {tau_ad*1.e-3:.2f} ps, tau_s = {tau_s*1.e-3:.2f} ps"
            )

            a1_ad = agamma * gamma  # * Tseg #* gamma
            b1 = np.exp(-Tseg / tau_ad)
            b2 = np.exp(-Tseg / tau_s)
            post_state["dFDT_m"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["dFDT_s"] = jnp.zeros((nspecies, nom), dtype=fprec)
            post_state["n_adabelief"] = 0

            def update_gammar(post_state):
                n_adabelief = post_state["n_adabelief"] + 1
                dFDT = post_state["dFDT"]
                dFDT_m = post_state["dFDT_m"] * b1 + dFDT * (1.0 - b1)
                dFDT_s = (
                    post_state["dFDT_s"] * b2
                    + (dFDT - dFDT_m) ** 2 * (1.0 - b2)
                    + 1.0e-8
                )
                # bias correction
                mt = dFDT_m / (1.0 - b1**n_adabelief)
                st = dFDT_s / (1.0 - b2**n_adabelief)
                gammar = post_state["gammar"] - a1_ad * mt / (st**0.5 + 1.0e-8)
                gammar = jnp.maximum(gammar_min, gammar)
                return {
                    **post_state,
                    "gammar": gammar,
                    "dFDT_m": dFDT_m,
                    "n_adabelief": n_adabelief,
                    "dFDT_s": dFDT_s,
                }

    def compute_corr_pot(niter=20, verbose=False):
        """Return the per-frequency correction applied to gammar in ff_kernel."""
        if classical_kernel or hbar == 0:
            return np.ones(nom)

        s_0 = np.array((theta / kT * cutoff)[:nom])
        s_out, s_rec, _ = deconvolute_spectrum(
            s_0,
            omega[:nom],
            gamma,
            niter,
            kernel=kernel_lorentz_pot,
            trans=True,
            symmetrize=True,
            verbose=verbose,
        )
        corr_pot = 1.0 + (s_out - s_0) / s_0
        columns = np.column_stack(
            (omega[:nom] * au.CM1, corr_pot - 1.0, s_0, s_out, s_rec)
        )
        np.savetxt(
            "corr_pot.dat", columns, header="omega(cm-1) corr_pot s_0 s_out s_rec"
        )
        return corr_pot

    def compute_corr_kin(post_state, niter=7, verbose=False):
        """Return ``(corr_kin, post_state)`` from the averaged mCvv spectrum."""
        if not post_state["do_corr_kin"]:
            return post_state["corr_kin_prev"], post_state
        if classical_kernel or hbar == 0:
            return 1.0, post_state

        K_D = post_state.get("K_D", None)
        # Atom-weighted average of the per-species spectra.
        mCvv = (post_state["mCvv_avg"][:, :nom] * n_of_type[:, None]).sum(axis=0) / nat
        s_0 = np.array(mCvv * kT / theta[:nom] / post_state["corr_pot"])
        s_out, s_rec, K_D = deconvolute_spectrum(
            s_0,
            omega[:nom],
            gamma,
            niter,
            kernel=kernel_lorentz,
            trans=False,
            symmetrize=True,
            verbose=verbose,
            K_D=K_D,
        )
        s_out = s_out * theta[:nom] / kT
        s_rec = s_rec * theta[:nom] / kT * post_state["corr_pot"]
        mCvvsum = mCvv.sum()
        rec_ratio = mCvvsum / s_rec.sum()
        # Reject the update when the reconvolved spectrum is off by more than 5%.
        if rec_ratio < 0.95 or rec_ratio > 1.05:
            print(
                f"# WARNING: reconvolution error {rec_ratio} is too high, corr_kin was not updated"
            )
            return post_state["corr_kin_prev"], post_state

        corr_kin = mCvvsum / s_out.sum()
        if np.abs(corr_kin - post_state["corr_kin_prev"]) < 1.0e-4:
            isame_kin = post_state["isame_kin"] + 1
        else:
            isame_kin = 0

        do_corr_kin = post_state["do_corr_kin"]
        if isame_kin > 10:
            print(
                "# INFO: corr_kin is converged (it did not change for 10 consecutive segments)"
            )
            do_corr_kin = False

        return corr_kin, {
            **post_state,
            "corr_kin_prev": corr_kin,
            "isame_kin": isame_kin,
            "do_corr_kin": do_corr_kin,
            "K_D": K_D,
        }

    @jax.jit
    def ff_kernel(post_state):
        """Return the target random-force spectrum, one column per species."""
        if classical_kernel:
            kernel = cutoff * (2 * gamma * kT / dt)
        else:
            kernel = theta * cutoff * OUcorr * (2 * gamma / dt)
        # gammar only covers the first nom frequencies; pad the rest with ones.
        gamma_ratio = jnp.concatenate(
            (
                post_state["gammar"].T * post_state["corr_pot"][:, None],
                jnp.ones(
                    (kernel.shape[0] - nom, nspecies), dtype=post_state["gammar"].dtype
                ),
            ),
            axis=0,
        )
        return kernel[:, None] * gamma_ratio * mass_idx[None, :]

    @jax.jit
    def refresh_force(post_state):
        """Generate the colored random force for the next segment."""
        rng_key, noise_key = jax.random.split(post_state["rng_key"])
        # Slide the 3-segment white-noise window by one segment.
        white_noise = jnp.concatenate(
            (
                post_state["white_noise"][nseg:],
                jax.random.normal(
                    noise_key, (nseg, nat, 3), dtype=post_state["white_noise"].dtype
                ),
            ),
            axis=0,
        )
        amplitude = ff_kernel(post_state) ** 0.5
        s = jnp.fft.rfft(white_noise, 3 * nseg, axis=0) * amplitude[:, type_idx, None]
        # Only the middle segment of the filtered signal is used.
        force = jnp.fft.irfft(s, 3 * nseg, axis=0)[nseg : 2 * nseg]
        return force, {**post_state, "rng_key": rng_key, "white_noise": white_noise}

    @jax.jit
    def compute_spectra(force, vel, post_state):
        """Update per-species spectra and their running averages."""
        sf = jnp.fft.rfft(force / gamma, 3 * nseg, axis=0, norm="ortho")
        sv = jnp.fft.rfft(vel, 3 * nseg, axis=0, norm="ortho")
        Cvv = jnp.sum(jnp.abs(sv[:nom]) ** 2, axis=-1).T
        Cff = jnp.sum(jnp.abs(sf[:nom]) ** 2, axis=-1).T
        Cvf = jnp.sum(jnp.real(sv[:nom] * jnp.conj(sf[:nom])), axis=-1).T

        # Scatter-add atoms into their species row, then average per species.
        mCvv = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["mCvv"]).at[type_idx].add(Cvv)
            * mass_idx[:, None]
            / n_of_type[:, None]
        )
        Cvf = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["Cvf"]).at[type_idx].add(Cvf)
            / n_of_type[:, None]
        )
        Cff = (
            (dt / 3.0)
            * jnp.zeros_like(post_state["Cff"]).at[type_idx].add(Cff)
            / n_of_type[:, None]
        )
        dFDT = mCvv * post_state["gammar"] - Cvf

        # Running mean over the nsample segments seen so far.
        nsinv = 1.0 / post_state["nsample"]
        b1 = 1.0 - nsinv
        dFDT_avg = post_state["dFDT_avg"] * b1 + dFDT * nsinv
        mCvv_avg = post_state["mCvv_avg"] * b1 + mCvv * nsinv
        Cvfg_avg = post_state["Cvfg_avg"] * b1 + Cvf / post_state["gammar"] * nsinv
        Cff_avg = post_state["Cff_avg"] * b1 + Cff * nsinv

        return {
            **post_state,
            "mCvv": mCvv,
            "Cvf": Cvf,
            "Cff": Cff,
            "dFDT": dFDT,
            "dFDT_avg": dFDT_avg,
            "mCvv_avg": mCvv_avg,
            "Cvfg_avg": Cvfg_avg,
            "Cff_avg": Cff_avg,
        }

    def write_spectra_to_file(post_state):
        """Write one ``QTB_spectra_<species>.out`` file per species."""
        mCvv_avg = np.array(post_state["mCvv_avg"])
        Cvfg_avg = np.array(post_state["Cvfg_avg"])
        Cff_avg = np.array(post_state["Cff_avg"]) * 3.0 / dt / (gamma**2)
        dFDT_avg = np.array(post_state["dFDT_avg"])
        gammar = np.array(post_state["gammar"])
        Cff_theo = np.array(ff_kernel(post_state))[:nom].T
        for i, sp in enumerate(species_set):
            ff_scale = au.KELVIN / ((2 * gamma / dt) * mass_idx[i])
            # Column 4 (gammar) is what the restore step above reads back.
            columns = np.column_stack(
                (
                    omega[:nom] * (au.FS * au.CM1),
                    mCvv_avg[i],
                    Cvfg_avg[i],
                    dFDT_avg[i],
                    gammar[i] * gamma * (au.FS * au.THZ),
                    Cff_avg[i] * ff_scale,
                    Cff_theo[i] * ff_scale,
                )
            )
            np.savetxt(
                f"QTB_spectra_{sp}.out",
                columns,
                fmt="%12.6f",
                header="#omega mCvv Cvf dFDT gammar Cff",
            )
        if verbose:
            print("# QTB spectra written.")

    if compute_thermostat_energy:
        state["qtb_energy_flux"] = 0.0

    @jax.jit
    def thermostat(vel, state):
        """Apply one friction + random-force step to the velocities."""
        istep = state["istep"]
        dvel = dt * state["force"][istep] / mass[:, None]
        new_vel = vel * a1 + dvel
        new_state = {**state, "istep": istep + 1}
        if do_compute_spectra:
            # Velocity recorded for the spectra of the current segment.
            vel2 = state["vel"].at[istep].set(vel * a1**0.5 + 0.5 * dvel)
            new_state["vel"] = vel2
        if compute_thermostat_energy:
            dek = 0.5 * (mass[:, None] * (vel**2 - new_vel**2)).sum()
            ekcorr = (
                0.5
                * (mass[:, None] * new_vel**2).sum()
                * (1.0 - 1.0 / state.get("corr_kin", 1.0))
            )
            new_state["qtb_energy_flux"] = state["qtb_energy_flux"] + dek
            new_state["thermostat_energy"] = new_state["qtb_energy_flux"] + ekcorr
        return new_vel, new_state

    @jax.jit
    def postprocess_work(state, post_state):
        """Jitted part of the end-of-segment work."""
        if do_compute_spectra:
            post_state = compute_spectra(state["force"], state["vel"], post_state)
        if adaptive:
            # Adaptation only starts after skipseg segments.
            post_state = jax.lax.cond(
                post_state["nadapt"] > skipseg, update_gammar, lambda x: x, post_state
            )
        new_force, post_state = refresh_force(post_state)
        return {**state, "force": new_force}, post_state

    def postprocess(state, post_state):
        """Advance the segment counter; do the segment work on reset steps."""
        counter.increment()
        if not counter.is_reset_step:
            return state, post_state
        # Build a new dict instead of mutating the caller's post_state.
        nadapt = post_state["nadapt"] + 1
        post_state = {
            **post_state,
            "nadapt": nadapt,
            "nsample": max(nadapt - startsave + 1, 1),
        }
        if verbose:
            print("# Refreshing QTB forces.")
        state, post_state = postprocess_work(state, post_state)
        corr_kin, post_state = compute_corr_kin(post_state, niter=niter_deconv_kin)
        state = {**state, "corr_kin": corr_kin, "istep": 0}
        if write_spectra:
            write_spectra_to_file(post_state)
        write_qtb_restart(state, post_state)
        return state, post_state

    post_state["corr_pot"] = jnp.asarray(compute_corr_pot(niter=niter_deconv_pot), dtype=fprec)

    state["force"], post_state = refresh_force(post_state)
    return thermostat, (postprocess, post_state), state
793    return thermostat, (postprocess, post_state), state