fennol.md.integrate

  1import sys, os, io
  2import argparse
  3import time
  4from pathlib import Path
  5import math
  6
  7import numpy as np
  8from typing import Optional, Callable
  9from collections import defaultdict
 10from functools import partial
 11import jax
 12import jax.numpy as jnp
 13
 14from flax.core import freeze, unfreeze
 15
 16from ..utils.io import last_xyz_frame
 17
 18
 19from ..models import FENNIX
 20
 21from ..utils.periodic_table import PERIODIC_TABLE_REV_IDX, ATOMIC_MASSES
 22from ..utils.atomic_units import AtomicUnits as au
 23from ..utils.input_parser import parse_input
 24from .thermostats import get_thermostat
 25from .barostats import get_barostat
 26
 27from copy import deepcopy
 28from .initial import initialize_system
 29
 30
 31def initialize_dynamics(
 32    simulation_parameters, system_data, conformation, model, fprec, rng_key
 33):
 34    step, update_conformation, dyn_state, thermo_state, vel = initialize_integrator(
 35        simulation_parameters, system_data, model, fprec, rng_key
 36    )
 37    ### initialize system
 38    system = initialize_system(
 39        conformation,
 40        vel,
 41        model,
 42        system_data,
 43        fprec,
 44    )
 45    return step, update_conformation, dyn_state, {**system, **thermo_state}
 46
 47
 48def initialize_integrator(simulation_parameters, system_data, model, fprec, rng_key):
 49    dt = simulation_parameters.get("dt") * au.FS
 50    dt2 = 0.5 * dt
 51    nbeads = system_data.get("nbeads", None)
 52
 53    mass = system_data["mass"]
 54    totmass_amu = system_data["totmass_amu"]
 55    nat = system_data["nat"]
 56    dt2m = jnp.asarray(dt2 / mass[:, None], dtype=fprec)
 57    if nbeads is not None:
 58        dt2m = dt2m[None, :, :]
 59
 60    dyn_state = {
 61        "istep": 0,
 62        "dt": dt,
 63        "pimd": nbeads is not None,
 64    }
 65
 66    model_energy_unit = au.get_multiplier(model.energy_unit)
 67
 68    # initialize thermostat
 69    thermostat_rng, rng_key = jax.random.split(rng_key)
 70    thermostat, thermostat_post, thermostat_state, vel, dyn_state["thermostat_name"] = (
 71        get_thermostat(simulation_parameters, dt, system_data, fprec, thermostat_rng)
 72    )
 73
 74    do_thermostat_post = thermostat_post is not None
 75    if do_thermostat_post:
 76        thermostat_post, post_state = thermostat_post
 77        dyn_state["thermostat_post_state"] = post_state
 78
 79    pbc_data = system_data.get("pbc", None)
 80    if pbc_data is not None:
 81        thermo_update, variable_cell, barostat_state = get_barostat(
 82            thermostat, simulation_parameters, dt, system_data, fprec, rng_key
 83        )
 84        estimate_pressure = variable_cell or pbc_data["estimate_pressure"]
 85        thermo_state = {"thermostat": thermostat_state, "barostat": barostat_state}
 86
 87    else:
 88        estimate_pressure = False
 89        variable_cell = False
 90
 91        def thermo_update(x, v, system):
 92            v, thermostat_state = thermostat(v, system["thermostat"])
 93            return x, v, {**system, "thermostat": thermostat_state}
 94
 95        thermo_state = {"thermostat": thermostat_state}
 96
 97    print("# Estimate pressure: ", estimate_pressure)
 98
 99    dyn_state["estimate_pressure"] = estimate_pressure
100    dyn_state["variable_cell"] = variable_cell
101
102    dyn_state["print_timings"] = simulation_parameters.get("print_timings", False)
103    if dyn_state["print_timings"]:
104        dyn_state["timings"] = defaultdict(lambda: 0.0)
105
106    ### NBLIST
107    nblist_verbose = simulation_parameters.get("nblist_verbose", False)
108    nblist_stride = int(simulation_parameters.get("nblist_stride", -1))
109    nblist_warmup_time = simulation_parameters.get("nblist_warmup_time", -1.0) * au.FS
110    nblist_warmup = int(nblist_warmup_time / dt) if nblist_warmup_time > 0 else 0
111    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
112    if nblist_skin > 0:
113        if nblist_stride <= 0:
114            ## reference skin parameters at 300K (from Tinker-HP)
115            ##   => skin of 2 A gives you 40 fs without complete rebuild
116            t_ref = 40.0  # FS
117            nblist_skin_ref = 2.0  # A
118            nblist_stride = int(math.floor(nblist_skin / nblist_skin_ref * t_ref / dt))
119        print(
120            f"# nblist_skin: {nblist_skin:.2f} A, nblist_stride: {nblist_stride} steps, nblist_warmup: {nblist_warmup} steps"
121        )
122
123    if nblist_skin <= 0:
124        nblist_stride = 1
125
126    dyn_state["nblist_countdown"] = 0
127    dyn_state["print_skin_activation"] = nblist_warmup > 0
128
129    ### ENSEMBLE
130    ensemble_key = simulation_parameters.get("etot_ensemble_key", None)
131
132    ### DEFINE INTEGRATION FUNCTIONS
133
134    if nbeads is not None:
135        ### RING POLYMER INTEGRATOR
136        cay_correction = simulation_parameters.get("cay_correction", True)
137        omk = system_data["omk"]
138        eigmat = system_data["eigmat"]
139        cayfact = 1.0 / (4.0 + (dt * omk[1:, None, None]) ** 2) ** 0.5
140        if cay_correction:
141            axx = jnp.asarray(2 * cayfact)
142            axv = jnp.asarray(dt * cayfact)
143            avx = jnp.asarray(-dt * cayfact * omk[1:, None, None] ** 2)
144        else:
145            axx = jnp.asarray(np.cos(omk[1:, None, None] * dt2))
146            axv = jnp.asarray(np.sin(omk[1:, None, None] * dt2) / omk[1:, None, None])
147            avx = jnp.asarray(-omk[1:, None, None] * np.sin(omk[1:, None, None] * dt2))
148
149        @jax.jit
150        def update_conformation(conformation, system):
151            eigx = system["coordinates"]
152            """update bead coordinates from ring polymer normal modes"""
153            x = jnp.einsum("in,n...->i...", eigmat, eigx).reshape(nbeads * nat, 3) * (
154                nbeads**0.5
155            )
156            conformation = {**conformation, "coordinates": x}
157            if variable_cell:
158                conformation["cells"] = system["cell"][None, :, :].repeat(nbeads, axis=0)
159                conformation["reciprocal_cells"] = jnp.linalg.inv(system["cell"])[
160                    None, :, :
161                ].repeat(nbeads, axis=0)
162            return conformation
163
164        @jax.jit
165        def coords_to_eig(x):
166            """update normal modes from bead coordinates"""
167            return jnp.einsum("in,i...->n...", eigmat, x.reshape(nbeads, nat, 3)) * (
168                1.0 / nbeads**0.5
169            )
170
171        def halfstep_free_polymer(eigx0, eigv0):
172            """update coordinates and velocities of a free ring polymer for a half time step"""
173            eigx_c = eigx0[0] + dt2 * eigv0[0]
174            eigv_c = eigv0[0]
175            eigx = eigx0[1:] * axx + eigv0[1:] * axv
176            eigv = eigx0[1:] * avx + eigv0[1:] * axx
177
178            return (
179                jnp.concatenate((eigx_c[None], eigx), axis=0),
180                jnp.concatenate((eigv_c[None], eigv), axis=0),
181            )
182
183        @jax.jit
184        def stepA(system):
185            eigx = system["coordinates"]
186            eigv = system["vel"] + dt2m * system["forces"]
187            eigx, eigv = halfstep_free_polymer(eigx, eigv)
188            eigx, eigv, system = thermo_update(eigx, eigv, system)
189            eigx, eigv = halfstep_free_polymer(eigx, eigv)
190
191            return {
192                **system,
193                "coordinates": eigx,
194                "vel": eigv,
195            }
196
197        @jax.jit
198        def update_forces(system, conformation):
199            if estimate_pressure:
200                epot, f, vir_t, out = model._energy_and_forces_and_virial(
201                    model.variables, conformation
202                )
203                epot = epot / model_energy_unit
204                f = f / model_energy_unit
205                vir_t = vir_t / model_energy_unit
206
207                new_sys = {
208                    **system,
209                    "forces": coords_to_eig(f),
210                    "epot": jnp.mean(epot),
211                    "virial": jnp.mean(vir_t, axis=0),
212                }
213            else:
214                epot, f, out = model._energy_and_forces(model.variables, conformation)
215                epot = epot / model_energy_unit
216                f = f / model_energy_unit
217                new_sys = {**system, "forces": coords_to_eig(f), "epot": jnp.mean(epot)}
218            if ensemble_key is not None:
219                kT = system_data["kT"]
220                dE = (
221                    jnp.mean(out[ensemble_key], axis=0) / model_energy_unit
222                    - new_sys["epot"]
223                )
224                new_sys["ensemble_weights"] = -dE / kT
225            return new_sys
226
227        @jax.jit
228        def stepB(system):
229            eigv = system["vel"] + dt2m * system["forces"]
230
231            ek_c = 0.5 * jnp.sum(
232                mass[:, None, None] * eigv[0, :, :, None] * eigv[0, :, None, :], axis=0
233            )
234            ek = ek_c - 0.5 * jnp.sum(
235                system["coordinates"][1:, :, :, None]
236                * system["forces"][1:, :, None, :],
237                axis=(0, 1),
238            )
239            system = {
240                **system,
241                "vel": eigv,
242                "ek_tensor": ek,
243                "ek_c": jnp.trace(ek_c),
244                "ek": jnp.trace(ek),
245            }
246
247            if estimate_pressure:
248                vir = system["virial"]
249                volume = jnp.abs(jnp.linalg.det(system["cell"]))
250                Pres = (2 * ek - vir) / volume
251                system["pressure_tensor"] = Pres
252                system["pressure"] = jnp.trace(Pres) * (1.0 / 3.0)
253                if variable_cell:
254                    density = totmass_amu / volume
255                    system["density"] = density
256                    system["volume"] = volume
257
258            return system
259
260    else:
261        ### CLASSICAL MD INTEGRATOR
262        @jax.jit
263        def update_conformation(conformation, system):
264            conformation = {**conformation, "coordinates": system["coordinates"]}
265            if variable_cell:
266                conformation["cells"] = system["cell"][None, :, :]
267                conformation["reciprocal_cells"] = jnp.linalg.inv(system["cell"])[
268                    None, :, :
269                ]
270            return conformation
271
272        @jax.jit
273        def stepA(system):
274            v = system["vel"]
275            f = system["forces"]
276            x = system["coordinates"]
277
278            v = v + f * dt2m
279            x = x + dt2 * v
280            x, v, system = thermo_update(x, v, system)
281            x = x + dt2 * v
282
283            return {**system, "coordinates": x, "vel": v}
284
285        @jax.jit
286        def update_forces(system, conformation):
287            if estimate_pressure:
288                epot, f, vir_t, out = model._energy_and_forces_and_virial(
289                    model.variables, conformation
290                )
291                epot = epot / model_energy_unit
292                f = f / model_energy_unit
293                vir_t = vir_t / model_energy_unit
294                new_sys = {
295                    **system,
296                    "forces": f,
297                    "epot": epot[0],
298                    "virial": vir_t[0],
299                }
300            else:
301                epot, f, out = model._energy_and_forces(model.variables, conformation)
302                epot = epot / model_energy_unit
303                f = f / model_energy_unit
304                new_sys = {**system, "forces": f, "epot": epot[0]}
305
306            if ensemble_key is not None:
307                kT = system_data["kT"]
308                dE = out[ensemble_key][0, :] / model_energy_unit - new_sys["epot"]
309                new_sys["ensemble_weights"] = -dE / kT
310            return new_sys
311
312        @jax.jit
313        def stepB(system):
314            v = system["vel"]
315            f = system["forces"]
316            state_th = system["thermostat"]
317
318            v = v + f * dt2m
319            # ek = 0.5 * jnp.sum(mass[:, None] * v**2) / state_th.get("corr_kin", 1.0)
320            ek_tensor = (
321                0.5
322                * jnp.sum(mass[:, None, None] * v[:, :, None] * v[:, None, :], axis=0)
323                / state_th.get("corr_kin", 1.0)
324            )
325            system = {
326                **system,
327                "vel": v,
328                "ek": jnp.trace(ek_tensor),
329                "ek_tensor": ek_tensor,
330            }
331
332            if estimate_pressure:
333                vir = system["virial"]
334                volume = jnp.abs(jnp.linalg.det(system["cell"]))
335                Pres = (2 * ek_tensor - vir) / volume
336                system["pressure_tensor"] = Pres
337                system["pressure"] = jnp.trace(Pres) * (1.0 / 3.0)
338                if variable_cell:
339                    density = totmass_amu / volume
340                    system["density"] = density
341                    system["volume"] = volume
342
343            return system
344
345    ### DEFINE STEP FUNCTION COMMON TO CLASSICAL AND PIMD
346    def step(
347        istep, dyn_state, system, conformation, preproc_state, force_preprocess=False
348    ):
349        tstep0 = time.time()
350        print_timings = "timings" in dyn_state
351
352        dyn_state = {
353            **dyn_state,
354            "istep": dyn_state["istep"] + 1,
355        }
356        if print_timings:
357            prev_timings = dyn_state["timings"]
358            timings = defaultdict(lambda: 0.0)
359            timings.update(prev_timings)
360
361        ## take a half step (update positions, nblist and half velocities)
362        system = stepA(system)
363
364        if print_timings:
365            system["coordinates"].block_until_ready()
366            timings["Integrator"] += time.time() - tstep0
367            tstep0 = time.time()
368
369        ### update conformation and graphs
370        nblist_countdown = dyn_state["nblist_countdown"]
371        if nblist_countdown <= 0 or force_preprocess or (istep < nblist_warmup):
372            ### full nblist update
373            dyn_state["nblist_countdown"] = nblist_stride - 1
374            conformation = model.preprocessing.process(
375                preproc_state, update_conformation(conformation, system)
376            )
377            preproc_state, state_up, conformation, overflow = (
378                model.preprocessing.check_reallocate(preproc_state, conformation)
379            )
380            if nblist_verbose and overflow:
381                print("step", istep, ", nblist overflow => reallocating nblist")
382                print("size updates:", state_up)
383
384            if print_timings:
385                conformation["coordinates"].block_until_ready()
386                timings["Preprocessing"] += time.time() - tstep0
387                tstep0 = time.time()
388
389        else:
390            ### skin update
391            if dyn_state["print_skin_activation"]:
392                if nblist_verbose:
393                    print(
394                        "step",
395                        istep,
396                        ", end of nblist warmup phase => activating skin updates",
397                    )
398                dyn_state["print_skin_activation"] = False
399
400            dyn_state["nblist_countdown"] = nblist_countdown - 1
401            conformation = model.preprocessing.update_skin(
402                update_conformation(conformation, system)
403            )
404
405            if print_timings:
406                conformation["coordinates"].block_until_ready()
407                timings["Skin update"] += time.time() - tstep0
408                tstep0 = time.time()
409
410        ## compute forces
411        system = update_forces(system, conformation)
412        if print_timings:
413            system["coordinates"].block_until_ready()
414            timings["Forces"] += time.time() - tstep0
415            tstep0 = time.time()
416
417        ## finish step
418        system = stepB(system)
419
420        ## end of step update (mostly for adQTB)
421        if do_thermostat_post:
422            system["thermostat"], dyn_state["thermostat_post_state"] = thermostat_post(
423                system["thermostat"], dyn_state["thermostat_post_state"]
424            )
425
426        if print_timings:
427            system["coordinates"].block_until_ready()
428            timings["Integrator"] += time.time() - tstep0
429            tstep0 = time.time()
430
431            # store timings
432            dyn_state["timings"] = timings
433
434        return dyn_state, system, conformation, preproc_state
435
436    return step, update_conformation, dyn_state, thermo_state, vel
def initialize_dynamics(
    simulation_parameters, system_data, conformation, model, fprec, rng_key
):
    """Create the integrator and assemble the initial dynamical state.

    Thin wrapper: ``initialize_integrator`` provides the step function, the
    conformation updater, the dynamics bookkeeping state, the thermostat
    (and barostat) state and the initial velocities; the velocities are then
    passed to ``initialize_system`` to build the system state, which is merged
    with the thermostat/barostat state.

    Returns:
        ``(step, update_conformation, dyn_state, system)`` where ``system`` is
        the output of ``initialize_system`` with the thermostat/barostat
        entries added (those entries win on key clashes).
    """
    integrator = initialize_integrator(
        simulation_parameters, system_data, model, fprec, rng_key
    )
    step_fn, conformation_updater, dynamics_state, integrator_state, velocities = (
        integrator
    )

    ### build the initial system state from the starting velocities
    initial_system = initialize_system(
        conformation, velocities, model, system_data, fprec
    )
    merged_state = dict(initial_system)
    merged_state.update(integrator_state)
    return step_fn, conformation_updater, dynamics_state, merged_state
def initialize_integrator(simulation_parameters, system_data, model, fprec, rng_key):
    """Build the MD/PIMD integrator and the per-step driver function.

    Depending on whether ``system_data`` provides ``"nbeads"``, this sets up
    either a classical integrator or a ring-polymer (PIMD) integrator that
    propagates normal-mode coordinates. The thermostat update (and, for
    periodic systems, the barostat update returned by ``get_barostat``) is
    applied in the middle of ``stepA``. The returned ``step`` function also
    manages neighbor-list handling (full rebuild vs. skin update) and
    optional wall-clock timings.

    Args:
        simulation_parameters: dict-like of user options. Read here: ``"dt"``
            (required; multiplied by ``au.FS``, presumably given in fs),
            ``"print_timings"``, ``"nblist_verbose"``, ``"nblist_stride"``,
            ``"nblist_warmup_time"``, ``"nblist_skin"`` (A),
            ``"etot_ensemble_key"`` and, for PIMD, ``"cay_correction"``.
            Also forwarded to ``get_thermostat`` / ``get_barostat``.
        system_data: dict-like with ``"mass"``, ``"totmass_amu"``, ``"nat"``
            and optionally ``"nbeads"`` (then also ``"omk"`` and
            ``"eigmat"``), ``"pbc"`` and ``"kT"`` (needed when an ensemble
            key is set).
        model: model object providing ``energy_unit``, ``variables``,
            ``preprocessing`` and ``_energy_and_forces`` /
            ``_energy_and_forces_and_virial``.
        fprec: dtype of the ``dt / (2 m)`` prefactor array.
        rng_key: JAX PRNG key; the first split goes to the thermostat, the
            remainder to the barostat (periodic systems only).

    Returns:
        Tuple ``(step, update_conformation, dyn_state, thermo_state, vel)``:
        ``step(istep, dyn_state, system, conformation, preproc_state,
        force_preprocess=False)`` advances one time step and returns
        ``(dyn_state, system, conformation, preproc_state)``;
        ``update_conformation(conformation, system)`` copies the integrator
        coordinates (and the cell, for variable-cell runs) into the model
        input; ``dyn_state`` is the initial Python-side dynamics state;
        ``thermo_state`` holds the initial thermostat (and barostat) state;
        ``vel`` are the initial velocities returned by ``get_thermostat``.
    """
    dt = simulation_parameters.get("dt") * au.FS
    dt2 = 0.5 * dt
    nbeads = system_data.get("nbeads", None)

    mass = system_data["mass"]
    totmass_amu = system_data["totmass_amu"]
    nat = system_data["nat"]
    # half-step velocity prefactor dt / (2 m), one row per atom
    dt2m = jnp.asarray(dt2 / mass[:, None], dtype=fprec)
    if nbeads is not None:
        # extra leading axis so the prefactor broadcasts over normal modes
        dt2m = dt2m[None, :, :]

    dyn_state = {
        "istep": 0,
        "dt": dt,
        "pimd": nbeads is not None,
    }

    # model energies, forces and virials are divided by this factor
    model_energy_unit = au.get_multiplier(model.energy_unit)

    ### THERMOSTAT
    thermostat_rng, rng_key = jax.random.split(rng_key)
    thermostat, thermostat_post, thermostat_state, vel, dyn_state["thermostat_name"] = (
        get_thermostat(simulation_parameters, dt, system_data, fprec, thermostat_rng)
    )

    # optional end-of-step thermostat update (e.g. adQTB); its state lives in
    # dyn_state rather than in the system state
    do_thermostat_post = thermostat_post is not None
    if do_thermostat_post:
        thermostat_post, post_state = thermostat_post
        dyn_state["thermostat_post_state"] = post_state

    ### BAROSTAT / PRESSURE (periodic systems only)
    pbc_data = system_data.get("pbc", None)
    if pbc_data is not None:
        thermo_update, variable_cell, barostat_state = get_barostat(
            thermostat, simulation_parameters, dt, system_data, fprec, rng_key
        )
        estimate_pressure = variable_cell or pbc_data["estimate_pressure"]
        thermo_state = {"thermostat": thermostat_state, "barostat": barostat_state}

    else:
        estimate_pressure = False
        variable_cell = False

        def thermo_update(x, v, system):
            """Thermostat-only update: positions are returned unchanged."""
            v, new_thermostat_state = thermostat(v, system["thermostat"])
            return x, v, {**system, "thermostat": new_thermostat_state}

        thermo_state = {"thermostat": thermostat_state}

    print("# Estimate pressure: ", estimate_pressure)

    dyn_state["estimate_pressure"] = estimate_pressure
    dyn_state["variable_cell"] = variable_cell

    dyn_state["print_timings"] = simulation_parameters.get("print_timings", False)
    if dyn_state["print_timings"]:
        dyn_state["timings"] = defaultdict(lambda: 0.0)

    ### NBLIST
    nblist_verbose = simulation_parameters.get("nblist_verbose", False)
    nblist_stride = int(simulation_parameters.get("nblist_stride", -1))
    nblist_warmup_time = simulation_parameters.get("nblist_warmup_time", -1.0) * au.FS
    # number of initial steps during which the nblist is fully rebuilt each step
    nblist_warmup = int(nblist_warmup_time / dt) if nblist_warmup_time > 0 else 0
    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
    if nblist_skin > 0:
        if nblist_stride <= 0:
            ## reference skin parameters at 300K (from Tinker-HP)
            ##   => skin of 2 A gives you 40 fs without complete rebuild
            # t_ref must be in the same unit as dt (fs scaled by au.FS), as is
            # done for nblist_warmup_time above; dividing raw fs by dt would
            # shrink the stride by a factor au.FS.
            t_ref = 40.0 * au.FS
            nblist_skin_ref = 2.0  # A
            nblist_stride = int(math.floor(nblist_skin / nblist_skin_ref * t_ref / dt))
            # a stride below 1 already means "rebuild every step"; report it as 1
            nblist_stride = max(nblist_stride, 1)
        print(
            f"# nblist_skin: {nblist_skin:.2f} A, nblist_stride: {nblist_stride} steps, nblist_warmup: {nblist_warmup} steps"
        )

    if nblist_skin <= 0:
        # no skin: full nblist rebuild at every step
        nblist_stride = 1

    dyn_state["nblist_countdown"] = 0
    dyn_state["print_skin_activation"] = nblist_warmup > 0

    ### ENSEMBLE
    ensemble_key = simulation_parameters.get("etot_ensemble_key", None)

    ### DEFINE INTEGRATION FUNCTIONS

    if nbeads is not None:
        ### RING POLYMER INTEGRATOR
        cay_correction = simulation_parameters.get("cay_correction", True)
        omk = system_data["omk"]
        eigmat = system_data["eigmat"]
        # half-step propagator coefficients for the non-centroid modes (k >= 1):
        # with cay_correction, factors based on 1/sqrt(4 + (dt*omk)^2);
        # otherwise the exact harmonic cos/sin propagation over dt/2.
        cayfact = 1.0 / (4.0 + (dt * omk[1:, None, None]) ** 2) ** 0.5
        if cay_correction:
            axx = jnp.asarray(2 * cayfact)
            axv = jnp.asarray(dt * cayfact)
            avx = jnp.asarray(-dt * cayfact * omk[1:, None, None] ** 2)
        else:
            axx = jnp.asarray(np.cos(omk[1:, None, None] * dt2))
            axv = jnp.asarray(np.sin(omk[1:, None, None] * dt2) / omk[1:, None, None])
            avx = jnp.asarray(-omk[1:, None, None] * np.sin(omk[1:, None, None] * dt2))

        @jax.jit
        def update_conformation(conformation, system):
            """update bead coordinates from ring polymer normal modes"""
            eigx = system["coordinates"]
            x = jnp.einsum("in,n...->i...", eigmat, eigx).reshape(nbeads * nat, 3) * (
                nbeads**0.5
            )
            conformation = {**conformation, "coordinates": x}
            if variable_cell:
                # every bead shares the same (possibly updated) cell
                conformation["cells"] = system["cell"][None, :, :].repeat(nbeads, axis=0)
                conformation["reciprocal_cells"] = jnp.linalg.inv(system["cell"])[
                    None, :, :
                ].repeat(nbeads, axis=0)
            return conformation

        @jax.jit
        def coords_to_eig(x):
            """update normal modes from bead coordinates"""
            return jnp.einsum("in,i...->n...", eigmat, x.reshape(nbeads, nat, 3)) * (
                1.0 / nbeads**0.5
            )

        def halfstep_free_polymer(eigx0, eigv0):
            """update coordinates and velocities of a free ring polymer for a half time step"""
            # centroid (k = 0): free particle drift
            eigx_c = eigx0[0] + dt2 * eigv0[0]
            eigv_c = eigv0[0]
            # internal modes (k >= 1): harmonic propagation
            eigx = eigx0[1:] * axx + eigv0[1:] * axv
            eigv = eigx0[1:] * avx + eigv0[1:] * axx

            return (
                jnp.concatenate((eigx_c[None], eigx), axis=0),
                jnp.concatenate((eigv_c[None], eigv), axis=0),
            )

        @jax.jit
        def stepA(system):
            """Half kick, half free-polymer step, thermostat/barostat, half free-polymer step."""
            eigx = system["coordinates"]
            eigv = system["vel"] + dt2m * system["forces"]
            eigx, eigv = halfstep_free_polymer(eigx, eigv)
            eigx, eigv, system = thermo_update(eigx, eigv, system)
            eigx, eigv = halfstep_free_polymer(eigx, eigv)

            return {
                **system,
                "coordinates": eigx,
                "vel": eigv,
            }

        @jax.jit
        def update_forces(system, conformation):
            """Evaluate the model on all beads; store normal-mode forces and bead-averaged energy (and virial)."""
            if estimate_pressure:
                epot, f, vir_t, out = model._energy_and_forces_and_virial(
                    model.variables, conformation
                )
                epot = epot / model_energy_unit
                f = f / model_energy_unit
                vir_t = vir_t / model_energy_unit

                new_sys = {
                    **system,
                    "forces": coords_to_eig(f),
                    "epot": jnp.mean(epot),
                    "virial": jnp.mean(vir_t, axis=0),
                }
            else:
                epot, f, out = model._energy_and_forces(model.variables, conformation)
                epot = epot / model_energy_unit
                f = f / model_energy_unit
                new_sys = {**system, "forces": coords_to_eig(f), "epot": jnp.mean(epot)}
            if ensemble_key is not None:
                kT = system_data["kT"]
                dE = (
                    jnp.mean(out[ensemble_key], axis=0) / model_energy_unit
                    - new_sys["epot"]
                )
                # stored as -dE/kT (presumably log-weights for reweighting)
                new_sys["ensemble_weights"] = -dE / kT
            return new_sys

        @jax.jit
        def stepB(system):
            """Final half kick, kinetic-energy estimator and (optionally) pressure."""
            eigv = system["vel"] + dt2m * system["forces"]

            # centroid kinetic energy tensor
            ek_c = 0.5 * jnp.sum(
                mass[:, None, None] * eigv[0, :, :, None] * eigv[0, :, None, :], axis=0
            )
            # centroid-virial-type estimator: subtract half the x*f products of
            # the non-centroid normal modes
            ek = ek_c - 0.5 * jnp.sum(
                system["coordinates"][1:, :, :, None]
                * system["forces"][1:, :, None, :],
                axis=(0, 1),
            )
            system = {
                **system,
                "vel": eigv,
                "ek_tensor": ek,
                "ek_c": jnp.trace(ek_c),
                "ek": jnp.trace(ek),
            }

            if estimate_pressure:
                vir = system["virial"]
                volume = jnp.abs(jnp.linalg.det(system["cell"]))
                Pres = (2 * ek - vir) / volume
                system["pressure_tensor"] = Pres
                system["pressure"] = jnp.trace(Pres) * (1.0 / 3.0)
                if variable_cell:
                    density = totmass_amu / volume
                    system["density"] = density
                    system["volume"] = volume

            return system

    else:
        ### CLASSICAL MD INTEGRATOR
        @jax.jit
        def update_conformation(conformation, system):
            """Copy integrator coordinates (and cell if variable) into the model input."""
            conformation = {**conformation, "coordinates": system["coordinates"]}
            if variable_cell:
                conformation["cells"] = system["cell"][None, :, :]
                conformation["reciprocal_cells"] = jnp.linalg.inv(system["cell"])[
                    None, :, :
                ]
            return conformation

        @jax.jit
        def stepA(system):
            """Half kick, half drift, thermostat/barostat, half drift."""
            v = system["vel"]
            f = system["forces"]
            x = system["coordinates"]

            v = v + f * dt2m
            x = x + dt2 * v
            x, v, system = thermo_update(x, v, system)
            x = x + dt2 * v

            return {**system, "coordinates": x, "vel": v}

        @jax.jit
        def update_forces(system, conformation):
            """Evaluate the model; store forces, energy (and virial) in system units."""
            if estimate_pressure:
                epot, f, vir_t, out = model._energy_and_forces_and_virial(
                    model.variables, conformation
                )
                epot = epot / model_energy_unit
                f = f / model_energy_unit
                vir_t = vir_t / model_energy_unit
                new_sys = {
                    **system,
                    "forces": f,
                    "epot": epot[0],
                    "virial": vir_t[0],
                }
            else:
                epot, f, out = model._energy_and_forces(model.variables, conformation)
                epot = epot / model_energy_unit
                f = f / model_energy_unit
                new_sys = {**system, "forces": f, "epot": epot[0]}

            if ensemble_key is not None:
                kT = system_data["kT"]
                dE = out[ensemble_key][0, :] / model_energy_unit - new_sys["epot"]
                # stored as -dE/kT (presumably log-weights for reweighting)
                new_sys["ensemble_weights"] = -dE / kT
            return new_sys

        @jax.jit
        def stepB(system):
            """Final half kick, kinetic energy tensor and (optionally) pressure."""
            v = system["vel"]
            f = system["forces"]
            state_th = system["thermostat"]

            v = v + f * dt2m
            # the thermostat may provide a kinetic-energy correction "corr_kin"
            ek_tensor = (
                0.5
                * jnp.sum(mass[:, None, None] * v[:, :, None] * v[:, None, :], axis=0)
                / state_th.get("corr_kin", 1.0)
            )
            system = {
                **system,
                "vel": v,
                "ek": jnp.trace(ek_tensor),
                "ek_tensor": ek_tensor,
            }

            if estimate_pressure:
                vir = system["virial"]
                volume = jnp.abs(jnp.linalg.det(system["cell"]))
                Pres = (2 * ek_tensor - vir) / volume
                system["pressure_tensor"] = Pres
                system["pressure"] = jnp.trace(Pres) * (1.0 / 3.0)
                if variable_cell:
                    density = totmass_amu / volume
                    system["density"] = density
                    system["volume"] = volume

            return system

    ### DEFINE STEP FUNCTION COMMON TO CLASSICAL AND PIMD
    def step(
        istep, dyn_state, system, conformation, preproc_state, force_preprocess=False
    ):
        """Advance the system by one time step.

        Sequence: ``stepA`` (half kick, propagation with thermostat/barostat),
        neighbor-list full rebuild or skin update, force evaluation,
        ``stepB`` (final half kick, kinetic energy and pressure), then the
        optional end-of-step thermostat update.

        Args:
            istep: current step index; full rebuilds are forced while
                ``istep < nblist_warmup``.
            dyn_state: Python-side dynamics state (a shallow copy is returned).
            system: integrator state dict.
            conformation: model input dict.
            preproc_state: model preprocessing state.
            force_preprocess: force a full neighbor-list rebuild this step.

        Returns:
            ``(dyn_state, system, conformation, preproc_state)`` after the step.
        """
        tstep0 = time.perf_counter()
        print_timings = "timings" in dyn_state

        dyn_state = {
            **dyn_state,
            "istep": dyn_state["istep"] + 1,
        }
        if print_timings:
            prev_timings = dyn_state["timings"]
            timings = defaultdict(lambda: 0.0)
            timings.update(prev_timings)

        ## take a half step (update positions, nblist and half velocities)
        system = stepA(system)

        if print_timings:
            system["coordinates"].block_until_ready()
            timings["Integrator"] += time.perf_counter() - tstep0
            tstep0 = time.perf_counter()

        ### update conformation and graphs
        nblist_countdown = dyn_state["nblist_countdown"]
        if nblist_countdown <= 0 or force_preprocess or (istep < nblist_warmup):
            ### full nblist update
            dyn_state["nblist_countdown"] = nblist_stride - 1
            conformation = model.preprocessing.process(
                preproc_state, update_conformation(conformation, system)
            )
            preproc_state, state_up, conformation, overflow = (
                model.preprocessing.check_reallocate(preproc_state, conformation)
            )
            if nblist_verbose and overflow:
                print("step", istep, ", nblist overflow => reallocating nblist")
                print("size updates:", state_up)

            if print_timings:
                # wait on the whole conformation: coordinates alone may already
                # be ready while the neighbor lists are still being computed
                jax.block_until_ready(conformation)
                timings["Preprocessing"] += time.perf_counter() - tstep0
                tstep0 = time.perf_counter()

        else:
            ### skin update
            if dyn_state["print_skin_activation"]:
                if nblist_verbose:
                    print(
                        "step",
                        istep,
                        ", end of nblist warmup phase => activating skin updates",
                    )
                dyn_state["print_skin_activation"] = False

            dyn_state["nblist_countdown"] = nblist_countdown - 1
            conformation = model.preprocessing.update_skin(
                update_conformation(conformation, system)
            )

            if print_timings:
                jax.block_until_ready(conformation)
                timings["Skin update"] += time.perf_counter() - tstep0
                tstep0 = time.perf_counter()

        ## compute forces
        system = update_forces(system, conformation)
        if print_timings:
            # block on the freshly computed forces, not on pass-through arrays
            system["forces"].block_until_ready()
            timings["Forces"] += time.perf_counter() - tstep0
            tstep0 = time.perf_counter()

        ## finish step
        system = stepB(system)

        ## end of step update (mostly for adQTB)
        if do_thermostat_post:
            system["thermostat"], dyn_state["thermostat_post_state"] = thermostat_post(
                system["thermostat"], dyn_state["thermostat_post_state"]
            )

        if print_timings:
            jax.block_until_ready(system)
            timings["Integrator"] += time.perf_counter() - tstep0
            tstep0 = time.perf_counter()

            # store timings
            dyn_state["timings"] = timings

        return dyn_state, system, conformation, preproc_state

    return step, update_conformation, dyn_state, thermo_state, vel