fennol.md.integrate

import time
import math
import os

import numpy as np
import jax
import jax.numpy as jnp

from ..utils.atomic_units import AtomicUnits as au
from .thermostats import get_thermostat
from .barostats import get_barostat
from .colvars import setup_colvars
from .spectra import initialize_ir_spectrum

from .utils import load_dynamics_restart, get_restart_file
from .initial import load_model, load_system_data, initialize_preprocessing
 17
 18
 19def initialize_dynamics(simulation_parameters, fprec, rng_key):
 20    ### LOAD MODEL
 21    model = load_model(simulation_parameters)
 22    model_energy_unit = au.get_multiplier(model.energy_unit)
 23
 24    ### Get the coordinates and species from the xyz file
 25    system_data, conformation = load_system_data(simulation_parameters, fprec)
 26
 27    ### FINISH BUILDING conformation
 28    if os.path.exists(get_restart_file(system_data)):
 29        ### RESTART FROM PREVIOUS DYNAMICS
 30        restart_data = load_dynamics_restart(system_data)
 31        print("# RESTARTING FROM PREVIOUS DYNAMICS")
 32        model.preproc_state = restart_data["preproc_state"]
 33        conformation["coordinates"] = restart_data["coordinates"]
 34    else:
 35        restart_data = {}
 36
 37    ### INITIALIZE PREPROCESSING
 38    preproc_state, conformation = initialize_preprocessing(
 39        simulation_parameters, model, conformation, system_data
 40    )
 41
 42    ### get dynamics parameters
 43    dt = simulation_parameters.get("dt") * au.FS
 44    dt2 = 0.5 * dt
 45    mass = system_data["mass"]
 46    totmass_amu = system_data["totmass_amu"]
 47    nat = system_data["nat"]
 48    dtm = jnp.asarray(dt / mass[:, None], dtype=fprec)
 49
 50    nreplicas = system_data.get("nreplicas", 1)
 51    nbeads = system_data.get("nbeads", None)
 52    if nbeads is not None:
 53        nreplicas = nbeads
 54        dtm = dtm[None, :, :]
 55
 56    ### INITIALIZE DYNAMICS STATE
 57    system = {"coordinates": conformation["coordinates"]}
 58    dyn_state = {
 59        "istep": 0,
 60        "dt": dt,
 61        "pimd": nbeads is not None,
 62        "preproc_state": preproc_state,
 63        "start_time_ps": restart_data.get("simulation_time_ps", 0.),
 64    }
 65    gradient_keys = ["coordinates"]
 66    thermo_updates = []
 67
 68    ### INITIALIZE THERMOSTAT
 69    thermostat_rng, rng_key = jax.random.split(rng_key)
 70    (
 71        thermostat,
 72        thermostat_post,
 73        thermostat_state,
 74        initial_vel,
 75        dyn_state["thermostat_name"],
 76    ) = get_thermostat(simulation_parameters, dt, system_data, fprec, thermostat_rng,restart_data)
 77    do_thermostat_post = thermostat_post is not None
 78    if do_thermostat_post:
 79        thermostat_post, post_state = thermostat_post
 80        dyn_state["thermostat_post_state"] = post_state
 81
 82    system["thermostat"] = thermostat_state
 83    system["vel"] = restart_data.get("vel", initial_vel).astype(fprec)
 84
 85    ### PBC
 86    pbc_data = system_data.get("pbc", None)
 87    if pbc_data is not None:
 88        ### INITIALIZE BAROSTAT
 89        barostat_key, rng_key = jax.random.split(rng_key)
 90        thermo_update_ensemble, variable_cell, barostat_state = get_barostat(
 91            thermostat, simulation_parameters, dt, system_data, fprec, barostat_key,restart_data
 92        )
 93        estimate_pressure = variable_cell or pbc_data["estimate_pressure"]
 94        system["barostat"] = barostat_state
 95        system["cell"] = conformation["cells"][0]
 96        if estimate_pressure:
 97            pressure_o_weight = simulation_parameters.get("pressure_o_weight", 0.0)
 98            assert (
 99                0.0 <= pressure_o_weight <= 1.0
100            ), "pressure_o_weight must be between 0 and 1"
101            gradient_keys.append("strain")
102        print("# Estimate pressure: ", estimate_pressure)
103    else:
104        estimate_pressure = False
105        variable_cell = False
106
107        def thermo_update_ensemble(x, v, system):
108            v, thermostat_state = thermostat(v, system["thermostat"])
109            return x, v, {**system, "thermostat": thermostat_state}
110
111    dyn_state["estimate_pressure"] = estimate_pressure
112    dyn_state["variable_cell"] = variable_cell
113    thermo_updates.append(thermo_update_ensemble)
114
115    ### ENERGY ENSEMBLE
116    ensemble_key = simulation_parameters.get("etot_ensemble_key", None)
117
118    ### COLVARS
119    colvars_definitions = simulation_parameters.get("colvars", None)
120    use_colvars = colvars_definitions is not None
121    if use_colvars:
122        colvars_calculators, colvars_names = setup_colvars(colvars_definitions)
123        dyn_state["colvars"] = colvars_names
124
125    ### IR SPECTRUM
126    do_ir_spectrum = simulation_parameters.get("ir_spectrum", False)
127    assert isinstance(do_ir_spectrum, bool), "ir_spectrum must be a boolean"
128    if do_ir_spectrum:
129        is_qtb = dyn_state["thermostat_name"].endswith("QTB")
130        model_ir, ir_state, save_dipole, ir_post = initialize_ir_spectrum(
131            simulation_parameters, system_data, fprec, dt, is_qtb
132        )
133        dyn_state["ir_spectrum"] = ir_state
134
135    ### BUILD GRADIENT FUNCTION
136    energy_and_gradient = model.get_gradient_function(
137        *gradient_keys, jit=True, variables_as_input=True
138    )
139
140    ### COLLECT THERMO UPDATES
141    if len(thermo_updates) == 1:
142        thermo_update = thermo_updates[0]
143    else:
144
145        def thermo_update(x, v, system):
146            for update in thermo_updates:
147                x, v, system = update(x, v, system)
148            return x, v, system
149
150    ### RING POLYMER INITIALIZATION
151    if nbeads is not None:
152        cay_correction = simulation_parameters.get("cay_correction", True)
153        omk = system_data["omk"]
154        eigmat = system_data["eigmat"]
155        cayfact = 1.0 / (4.0 + (dt * omk[1:, None, None]) ** 2) ** 0.5
156        if cay_correction:
157            axx = jnp.asarray(2 * cayfact)
158            axv = jnp.asarray(dt * cayfact)
159            avx = jnp.asarray(-dt * cayfact * omk[1:, None, None] ** 2)
160        else:
161            axx = jnp.asarray(np.cos(omk[1:, None, None] * dt2))
162            axv = jnp.asarray(np.sin(omk[1:, None, None] * dt2) / omk[1:, None, None])
163            avx = jnp.asarray(-omk[1:, None, None] * np.sin(omk[1:, None, None] * dt2))
164
165        coordinates = conformation["coordinates"].reshape(nbeads, -1, 3)
166        eigx = jnp.zeros_like(coordinates).at[0].set(coordinates[0])
167        system["coordinates"] = eigx
168
169    ###############################################
170    ### DEFINE UPDATE FUNCTION
171    @jax.jit
172    def update_conformation(conformation, system):
173        x = system["coordinates"]
174        if nbeads is not None:
175            x = jnp.einsum("in,n...->i...", eigmat, x).reshape(nbeads * nat, 3) * (
176                nbeads**0.5
177            )
178        conformation = {**conformation, "coordinates": x}
179        if variable_cell:
180            conformation["cells"] = system["cell"][None, :, :].repeat(nreplicas, axis=0)
181
182        
183
184        return conformation
185
186    ###############################################
187    ### DEFINE INTEGRATION FUNCTIONS
188    def integrate_A_half(x0, v0):
189        if nbeads is None:
190            return x0 + dt2 * v0, v0
191
192        # update coordinates and velocities of a free ring polymer for a half time step
193        eigx_c = x0[0] + dt2 * v0[0]
194        eigv_c = v0[0]
195        eigx = x0[1:] * axx + v0[1:] * axv
196        eigv = x0[1:] * avx + v0[1:] * axx
197
198        return (
199            jnp.concatenate((eigx_c[None], eigx), axis=0),
200            jnp.concatenate((eigv_c[None], eigv), axis=0),
201        )
202
203    @jax.jit
204    def integrate(system):
205        x = system["coordinates"]
206        v = system["vel"] + dtm * system["forces"]
207        x, v = integrate_A_half(x, v)
208        x, v, system = thermo_update(x, v, system)
209        x, v = integrate_A_half(x, v)
210
211        return {**system, "coordinates": x, "vel": v}
212
213    ###############################################
214    ### DEFINE OBSERVABLE FUNCTION
215    @jax.jit
216    def update_observables(system, conformation):
217        ### POTENTIAL ENERGY AND FORCES
218        epot, de, out = energy_and_gradient(model.variables, conformation)
219        epot = epot / model_energy_unit
220        de = {k: v / model_energy_unit for k, v in de.items()}
221        forces = -de["coordinates"]
222
223        if nbeads is not None:
224            ### PROJECT FORCES ONTO POLYMER NORMAL MODES
225            forces = jnp.einsum(
226                "in,i...->n...", eigmat, forces.reshape(nbeads, nat, 3)
227            ) * (1.0 / nbeads**0.5)
228
229        system = {
230            **system,
231            "epot": jnp.mean(epot),
232            "forces": forces,
233            "energy_gradients": de,
234        }
235
236        ### KINETIC ENERGY
237        v = system["vel"]
238        if nbeads is None:
239            corr_kin = system["thermostat"].get("corr_kin", 1.0)
240            # ek = 0.5 * jnp.sum(mass[:, None] * v**2) / state_th.get("corr_kin", 1.0)
241            ek = (0.5 / nreplicas / corr_kin) * jnp.sum(
242                mass[:, None, None] * v[:, :, None] * v[:, None, :], axis=0
243            )
244        else:
245            ek_c = 0.5 * jnp.sum(
246                mass[:, None, None] * v[0, :, :, None] * v[0, :, None, :], axis=0
247            )
248            ek = ek_c - 0.5 * jnp.sum(
249                system["coordinates"][1:, :, :, None] * forces[1:, :, None, :],
250                axis=(0, 1),
251            )
252            system["ek_c"] = jnp.trace(ek_c)
253
254        system["ek"] = jnp.trace(ek)
255        system["ek_tensor"] = ek
256
257        if estimate_pressure:
258            if pressure_o_weight != 1.0:
259                v = system["vel"] + 0.5 * dtm * system["forces"]
260                if nbeads is None:
261                    corr_kin = system["thermostat"].get("corr_kin", 1.0)
262                    # ek = 0.5 * jnp.sum(mass[:, None] * v**2) / state_th.get("corr_kin", 1.0)
263                    ek = (0.5 / nreplicas / corr_kin) * jnp.sum(
264                        mass[:, None, None] * v[:, :, None] * v[:, None, :], axis=0
265                    )
266                else:
267                    ek_c = 0.5 * jnp.sum(
268                        mass[:, None, None] * v[0, :, :, None] * v[0, :, None, :],
269                        axis=0,
270                    )
271                    ek = ek_c - 0.5 * jnp.sum(
272                        system["coordinates"][1:, :, :, None] * forces[1:, :, None, :],
273                        axis=(0, 1),
274                    )
275                b = pressure_o_weight
276                ek = (1.0 - b) * ek + b * system["ek_tensor"]
277
278            vir = jnp.mean(de["strain"], axis=0)
279            system["virial"] = vir
280            
281            pV =  2 * ek  - vir
282            system["PV_tensor"] = pV
283            volume = jnp.abs(jnp.linalg.det(system["cell"]))
284            Pres = pV / volume
285            system["pressure_tensor"] = Pres
286            system["pressure"] = jnp.trace(Pres) * (1.0 / 3.0)
287            if variable_cell:
288                density = totmass_amu / volume
289                system["density"] = density
290                system["volume"] = volume
291
292        if ensemble_key is not None:
293            kT = system_data["kT"]
294            dE = (
295                jnp.mean(out[ensemble_key], axis=0) / model_energy_unit - system["epot"]
296            )
297            system["ensemble_weights"] = -dE / kT
298
299        if "total_dipole" in out:
300            if nbeads is None:
301                system["total_dipole"] = out["total_dipole"][0]
302            else:
303                system["total_dipole"] = jnp.mean(out["total_dipole"], axis=0)
304
305        if use_colvars:
306            coords = system["coordinates"].reshape(-1, nat, 3)[0]
307            colvars = {}
308            for colvar_name, colvar_calc in colvars_calculators.items():
309                colvars[colvar_name] = colvar_calc(coords)
310            system["colvars"] = colvars
311
312        return system, out
313
314    ###############################################
315    ### IR SPECTRUM
316    if do_ir_spectrum:
317        # @jax.jit
318        # def update_dipole(ir_state,system,conformation):
319        #     def mumodel(coords):
320        #         out = model_ir._apply(model_ir.variables,{**conformation,"coordinates":coords})
321        #         if nbeads is None:
322        #             return out["total_dipole"][0]
323        #         return out["total_dipole"].sum(axis=0)
324        #     dmudqmodel = jax.jacobian(mumodel)
325
326        #     dmudq = dmudqmodel(conformation["coordinates"])
327        #     # print(dmudq.shape)
328        #     if nbeads is None:
329        #         vel = system["vel"].reshape(-1,1,nat,3)[0]
330        #         mudot = (vel*dmudq).sum(axis=(1,2))
331        #     else:
332        #         dmudq = dmudq.reshape(3,nbeads,nat,3)#.mean(axis=1)
333        #         vel = (jnp.einsum("in,n...->i...", eigmat, system["vel"]) *  nbeads**0.5
334        #         )
335        #         # vel = system["vel"][0].reshape(1,nat,3)
336        #         mudot = (vel[None,...]*dmudq).sum(axis=(1,2,3))/nbeads
337
338        #     ir_state = save_dipole(mudot,ir_state)
339        #     return ir_state
340        @jax.jit
341        def update_conformation_ir(conformation, system):
342            conformation = {
343                **conformation,
344                "coordinates": system["coordinates"].reshape(-1, nat, 3)[0],
345                "natoms": jnp.asarray([nat]),
346                "batch_index": jnp.asarray([0] * nat),
347                "species": jnp.asarray(system_data["species"].reshape(-1, nat)[0]),
348            }
349            if variable_cell:
350                conformation["cells"] = system["cell"][None, :, :]
351                conformation["reciprocal_cells"] = jnp.linalg.inv(system["cell"])[
352                    None, :, :
353                ]
354            return conformation
355
356        @jax.jit
357        def update_dipole(ir_state, system, conformation):
358            if model_ir is not None:
359                out = model_ir._apply(model_ir.variables, conformation)
360                q = out.get("charges", jnp.zeros(nat)).reshape((-1, nat))
361                dip = out.get("dipoles", jnp.zeros((nat, 3))).reshape((-1, nat, 3))
362            else:
363                q = system.get("charges", jnp.zeros(nat)).reshape((-1, nat))
364                dip = system.get("dipoles", jnp.zeros((nat, 3))).reshape((-1, nat, 3))
365            if nbeads is not None:
366                q = jnp.mean(q, axis=0)
367                dip = jnp.mean(dip, axis=0)
368                vel = system["vel"][0]
369                pos = system["coordinates"][0]
370            else:
371                q = q[0]
372                dip = dip[0]
373                vel = system["vel"].reshape(-1, nat, 3)[0]
374                pos = system["coordinates"].reshape(-1, nat, 3)[0]
375
376            if pbc_data is not None:
377                cell_reciprocal = (
378                    conformation["cells"][0],
379                    conformation["reciprocal_cells"][0],
380                )
381            else:
382                cell_reciprocal = None
383
384            ir_state = save_dipole(
385                q, vel, pos, dip.sum(axis=0), cell_reciprocal, ir_state
386            )
387            return ir_state
388
389    ###############################################
390    ### GRAPH UPDATES
391
392    nblist_verbose = simulation_parameters.get("nblist_verbose", False)
393    nblist_stride = int(simulation_parameters.get("nblist_stride", -1))
394    nblist_warmup_time = simulation_parameters.get("nblist_warmup_time", -1.0) * au.FS
395    nblist_warmup = int(nblist_warmup_time / dt) if nblist_warmup_time > 0 else 0
396    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
397    if nblist_skin > 0:
398        if nblist_stride <= 0:
399            ## reference skin parameters at 300K (from Tinker-HP)
400            ##   => skin of 2 A gives you 40 fs without complete rebuild
401            t_ref = 40.0  # FS
402            nblist_skin_ref = 2.0  # A
403            nblist_stride = int(math.floor(nblist_skin / nblist_skin_ref * t_ref / dt))
404        print(
405            f"# nblist_skin: {nblist_skin:.2f} A, nblist_stride: {nblist_stride} steps, nblist_warmup: {nblist_warmup} steps"
406        )
407
408    if nblist_skin <= 0:
409        nblist_stride = 1
410
411    dyn_state["nblist_countdown"] = 0
412    dyn_state["print_skin_activation"] = nblist_warmup > 0
413
414    def update_graphs(istep, dyn_state, system, conformation, force_preprocess=False):
415        nblist_countdown = dyn_state["nblist_countdown"]
416        if nblist_countdown <= 0 or force_preprocess or (istep < nblist_warmup):
417            ### FULL NBLIST REBUILD
418            dyn_state["nblist_countdown"] = nblist_stride - 1
419            preproc_state = dyn_state["preproc_state"]
420            conformation = model.preprocessing.process(
421                preproc_state, update_conformation(conformation, system)
422            )
423            preproc_state, state_up, conformation, overflow = (
424                model.preprocessing.check_reallocate(preproc_state, conformation)
425            )
426            dyn_state["preproc_state"] = preproc_state
427            if nblist_verbose and overflow:
428                print("step", istep, ", nblist overflow => reallocating nblist")
429                print("size updates:", state_up)
430
431            if do_ir_spectrum and model_ir is not None:
432                conformation_ir = model_ir.preprocessing.process(
433                    dyn_state["preproc_state_ir"],
434                    update_conformation_ir(dyn_state["conformation_ir"], system),
435                )
436                (
437                    dyn_state["preproc_state_ir"],
438                    _,
439                    dyn_state["conformation_ir"],
440                    overflow,
441                ) = model_ir.preprocessing.check_reallocate(
442                    dyn_state["preproc_state_ir"], conformation_ir
443                )
444
445        else:
446            ### SKIN UPDATE
447            if dyn_state["print_skin_activation"]:
448                if nblist_verbose:
449                    print(
450                        "step",
451                        istep,
452                        ", end of nblist warmup phase => activating skin updates",
453                    )
454                dyn_state["print_skin_activation"] = False
455
456            dyn_state["nblist_countdown"] = nblist_countdown - 1
457            conformation = model.preprocessing.update_skin(
458                update_conformation(conformation, system)
459            )
460            if do_ir_spectrum and model_ir is not None:
461                dyn_state["conformation_ir"] = model_ir.preprocessing.update_skin(
462                    update_conformation_ir(dyn_state["conformation_ir"], system)
463                )
464
465        return conformation, dyn_state
466
467    ################################################
468    ### DEFINE STEP FUNCTION
469    def step(istep, dyn_state, system, conformation, force_preprocess=False):
470
471        dyn_state = {
472            **dyn_state,
473            "istep": dyn_state["istep"] + 1,
474        }
475
476        ### INTEGRATE EQUATIONS OF MOTION
477        system = integrate(system)
478
479        ### UPDATE CONFORMATION AND GRAPHS
480        conformation, dyn_state = update_graphs(
481            istep, dyn_state, system, conformation, force_preprocess
482        )
483
484        ## COMPUTE FORCES AND OBSERVABLES
485        system, out = update_observables(system, conformation)
486
487        ## END OF STEP UPDATES
488        if do_thermostat_post:
489            system["thermostat"], dyn_state["thermostat_post_state"] = thermostat_post(
490                system["thermostat"], dyn_state["thermostat_post_state"]
491            )
492        
493        if do_ir_spectrum:
494            ir_state = update_dipole(
495                dyn_state["ir_spectrum"], system, dyn_state["conformation_ir"]
496            )
497            dyn_state["ir_spectrum"] = ir_post(ir_state)
498
499        return dyn_state, system, conformation, out
500
501    ###########################################################
502
503    print("# Computing initial energy and forces")
504
505    conformation = update_conformation(conformation, system)
506    # initialize IR conformation
507    if do_ir_spectrum and model_ir is not None:
508        dyn_state["preproc_state_ir"], dyn_state["conformation_ir"] = (
509            model_ir.preprocessing(
510                model_ir.preproc_state,
511                update_conformation_ir(conformation, system),
512            )
513        )
514
515    system, _ = update_observables(system, conformation)
516
517    return step, update_conformation, system_data, dyn_state, conformation, system
def initialize_dynamics(simulation_parameters, fprec, rng_key):
 20def initialize_dynamics(simulation_parameters, fprec, rng_key):
 21    ### LOAD MODEL
 22    model = load_model(simulation_parameters)
 23    model_energy_unit = au.get_multiplier(model.energy_unit)
 24
 25    ### Get the coordinates and species from the xyz file
 26    system_data, conformation = load_system_data(simulation_parameters, fprec)
 27
 28    ### FINISH BUILDING conformation
 29    if os.path.exists(get_restart_file(system_data)):
 30        ### RESTART FROM PREVIOUS DYNAMICS
 31        restart_data = load_dynamics_restart(system_data)
 32        print("# RESTARTING FROM PREVIOUS DYNAMICS")
 33        model.preproc_state = restart_data["preproc_state"]
 34        conformation["coordinates"] = restart_data["coordinates"]
 35    else:
 36        restart_data = {}
 37
 38    ### INITIALIZE PREPROCESSING
 39    preproc_state, conformation = initialize_preprocessing(
 40        simulation_parameters, model, conformation, system_data
 41    )
 42
 43    ### get dynamics parameters
 44    dt = simulation_parameters.get("dt") * au.FS
 45    dt2 = 0.5 * dt
 46    mass = system_data["mass"]
 47    totmass_amu = system_data["totmass_amu"]
 48    nat = system_data["nat"]
 49    dtm = jnp.asarray(dt / mass[:, None], dtype=fprec)
 50
 51    nreplicas = system_data.get("nreplicas", 1)
 52    nbeads = system_data.get("nbeads", None)
 53    if nbeads is not None:
 54        nreplicas = nbeads
 55        dtm = dtm[None, :, :]
 56
 57    ### INITIALIZE DYNAMICS STATE
 58    system = {"coordinates": conformation["coordinates"]}
 59    dyn_state = {
 60        "istep": 0,
 61        "dt": dt,
 62        "pimd": nbeads is not None,
 63        "preproc_state": preproc_state,
 64        "start_time_ps": restart_data.get("simulation_time_ps", 0.),
 65    }
 66    gradient_keys = ["coordinates"]
 67    thermo_updates = []
 68
 69    ### INITIALIZE THERMOSTAT
 70    thermostat_rng, rng_key = jax.random.split(rng_key)
 71    (
 72        thermostat,
 73        thermostat_post,
 74        thermostat_state,
 75        initial_vel,
 76        dyn_state["thermostat_name"],
 77    ) = get_thermostat(simulation_parameters, dt, system_data, fprec, thermostat_rng,restart_data)
 78    do_thermostat_post = thermostat_post is not None
 79    if do_thermostat_post:
 80        thermostat_post, post_state = thermostat_post
 81        dyn_state["thermostat_post_state"] = post_state
 82
 83    system["thermostat"] = thermostat_state
 84    system["vel"] = restart_data.get("vel", initial_vel).astype(fprec)
 85
 86    ### PBC
 87    pbc_data = system_data.get("pbc", None)
 88    if pbc_data is not None:
 89        ### INITIALIZE BAROSTAT
 90        barostat_key, rng_key = jax.random.split(rng_key)
 91        thermo_update_ensemble, variable_cell, barostat_state = get_barostat(
 92            thermostat, simulation_parameters, dt, system_data, fprec, barostat_key,restart_data
 93        )
 94        estimate_pressure = variable_cell or pbc_data["estimate_pressure"]
 95        system["barostat"] = barostat_state
 96        system["cell"] = conformation["cells"][0]
 97        if estimate_pressure:
 98            pressure_o_weight = simulation_parameters.get("pressure_o_weight", 0.0)
 99            assert (
100                0.0 <= pressure_o_weight <= 1.0
101            ), "pressure_o_weight must be between 0 and 1"
102            gradient_keys.append("strain")
103        print("# Estimate pressure: ", estimate_pressure)
104    else:
105        estimate_pressure = False
106        variable_cell = False
107
108        def thermo_update_ensemble(x, v, system):
109            v, thermostat_state = thermostat(v, system["thermostat"])
110            return x, v, {**system, "thermostat": thermostat_state}
111
112    dyn_state["estimate_pressure"] = estimate_pressure
113    dyn_state["variable_cell"] = variable_cell
114    thermo_updates.append(thermo_update_ensemble)
115
116    ### ENERGY ENSEMBLE
117    ensemble_key = simulation_parameters.get("etot_ensemble_key", None)
118
119    ### COLVARS
120    colvars_definitions = simulation_parameters.get("colvars", None)
121    use_colvars = colvars_definitions is not None
122    if use_colvars:
123        colvars_calculators, colvars_names = setup_colvars(colvars_definitions)
124        dyn_state["colvars"] = colvars_names
125
126    ### IR SPECTRUM
127    do_ir_spectrum = simulation_parameters.get("ir_spectrum", False)
128    assert isinstance(do_ir_spectrum, bool), "ir_spectrum must be a boolean"
129    if do_ir_spectrum:
130        is_qtb = dyn_state["thermostat_name"].endswith("QTB")
131        model_ir, ir_state, save_dipole, ir_post = initialize_ir_spectrum(
132            simulation_parameters, system_data, fprec, dt, is_qtb
133        )
134        dyn_state["ir_spectrum"] = ir_state
135
136    ### BUILD GRADIENT FUNCTION
137    energy_and_gradient = model.get_gradient_function(
138        *gradient_keys, jit=True, variables_as_input=True
139    )
140
141    ### COLLECT THERMO UPDATES
142    if len(thermo_updates) == 1:
143        thermo_update = thermo_updates[0]
144    else:
145
146        def thermo_update(x, v, system):
147            for update in thermo_updates:
148                x, v, system = update(x, v, system)
149            return x, v, system
150
151    ### RING POLYMER INITIALIZATION
152    if nbeads is not None:
153        cay_correction = simulation_parameters.get("cay_correction", True)
154        omk = system_data["omk"]
155        eigmat = system_data["eigmat"]
156        cayfact = 1.0 / (4.0 + (dt * omk[1:, None, None]) ** 2) ** 0.5
157        if cay_correction:
158            axx = jnp.asarray(2 * cayfact)
159            axv = jnp.asarray(dt * cayfact)
160            avx = jnp.asarray(-dt * cayfact * omk[1:, None, None] ** 2)
161        else:
162            axx = jnp.asarray(np.cos(omk[1:, None, None] * dt2))
163            axv = jnp.asarray(np.sin(omk[1:, None, None] * dt2) / omk[1:, None, None])
164            avx = jnp.asarray(-omk[1:, None, None] * np.sin(omk[1:, None, None] * dt2))
165
166        coordinates = conformation["coordinates"].reshape(nbeads, -1, 3)
167        eigx = jnp.zeros_like(coordinates).at[0].set(coordinates[0])
168        system["coordinates"] = eigx
169
170    ###############################################
171    ### DEFINE UPDATE FUNCTION
172    @jax.jit
173    def update_conformation(conformation, system):
174        x = system["coordinates"]
175        if nbeads is not None:
176            x = jnp.einsum("in,n...->i...", eigmat, x).reshape(nbeads * nat, 3) * (
177                nbeads**0.5
178            )
179        conformation = {**conformation, "coordinates": x}
180        if variable_cell:
181            conformation["cells"] = system["cell"][None, :, :].repeat(nreplicas, axis=0)
182
183        
184
185        return conformation
186
187    ###############################################
188    ### DEFINE INTEGRATION FUNCTIONS
189    def integrate_A_half(x0, v0):
190        if nbeads is None:
191            return x0 + dt2 * v0, v0
192
193        # update coordinates and velocities of a free ring polymer for a half time step
194        eigx_c = x0[0] + dt2 * v0[0]
195        eigv_c = v0[0]
196        eigx = x0[1:] * axx + v0[1:] * axv
197        eigv = x0[1:] * avx + v0[1:] * axx
198
199        return (
200            jnp.concatenate((eigx_c[None], eigx), axis=0),
201            jnp.concatenate((eigv_c[None], eigv), axis=0),
202        )
203
204    @jax.jit
205    def integrate(system):
206        x = system["coordinates"]
207        v = system["vel"] + dtm * system["forces"]
208        x, v = integrate_A_half(x, v)
209        x, v, system = thermo_update(x, v, system)
210        x, v = integrate_A_half(x, v)
211
212        return {**system, "coordinates": x, "vel": v}
213
214    ###############################################
215    ### DEFINE OBSERVABLE FUNCTION
216    @jax.jit
217    def update_observables(system, conformation):
218        ### POTENTIAL ENERGY AND FORCES
219        epot, de, out = energy_and_gradient(model.variables, conformation)
220        epot = epot / model_energy_unit
221        de = {k: v / model_energy_unit for k, v in de.items()}
222        forces = -de["coordinates"]
223
224        if nbeads is not None:
225            ### PROJECT FORCES ONTO POLYMER NORMAL MODES
226            forces = jnp.einsum(
227                "in,i...->n...", eigmat, forces.reshape(nbeads, nat, 3)
228            ) * (1.0 / nbeads**0.5)
229
230        system = {
231            **system,
232            "epot": jnp.mean(epot),
233            "forces": forces,
234            "energy_gradients": de,
235        }
236
237        ### KINETIC ENERGY
238        v = system["vel"]
239        if nbeads is None:
240            corr_kin = system["thermostat"].get("corr_kin", 1.0)
241            # ek = 0.5 * jnp.sum(mass[:, None] * v**2) / state_th.get("corr_kin", 1.0)
242            ek = (0.5 / nreplicas / corr_kin) * jnp.sum(
243                mass[:, None, None] * v[:, :, None] * v[:, None, :], axis=0
244            )
245        else:
246            ek_c = 0.5 * jnp.sum(
247                mass[:, None, None] * v[0, :, :, None] * v[0, :, None, :], axis=0
248            )
249            ek = ek_c - 0.5 * jnp.sum(
250                system["coordinates"][1:, :, :, None] * forces[1:, :, None, :],
251                axis=(0, 1),
252            )
253            system["ek_c"] = jnp.trace(ek_c)
254
255        system["ek"] = jnp.trace(ek)
256        system["ek_tensor"] = ek
257
258        if estimate_pressure:
259            if pressure_o_weight != 1.0:
260                v = system["vel"] + 0.5 * dtm * system["forces"]
261                if nbeads is None:
262                    corr_kin = system["thermostat"].get("corr_kin", 1.0)
263                    # ek = 0.5 * jnp.sum(mass[:, None] * v**2) / state_th.get("corr_kin", 1.0)
264                    ek = (0.5 / nreplicas / corr_kin) * jnp.sum(
265                        mass[:, None, None] * v[:, :, None] * v[:, None, :], axis=0
266                    )
267                else:
268                    ek_c = 0.5 * jnp.sum(
269                        mass[:, None, None] * v[0, :, :, None] * v[0, :, None, :],
270                        axis=0,
271                    )
272                    ek = ek_c - 0.5 * jnp.sum(
273                        system["coordinates"][1:, :, :, None] * forces[1:, :, None, :],
274                        axis=(0, 1),
275                    )
276                b = pressure_o_weight
277                ek = (1.0 - b) * ek + b * system["ek_tensor"]
278
279            vir = jnp.mean(de["strain"], axis=0)
280            system["virial"] = vir
281            
282            pV =  2 * ek  - vir
283            system["PV_tensor"] = pV
284            volume = jnp.abs(jnp.linalg.det(system["cell"]))
285            Pres = pV / volume
286            system["pressure_tensor"] = Pres
287            system["pressure"] = jnp.trace(Pres) * (1.0 / 3.0)
288            if variable_cell:
289                density = totmass_amu / volume
290                system["density"] = density
291                system["volume"] = volume
292
293        if ensemble_key is not None:
294            kT = system_data["kT"]
295            dE = (
296                jnp.mean(out[ensemble_key], axis=0) / model_energy_unit - system["epot"]
297            )
298            system["ensemble_weights"] = -dE / kT
299
300        if "total_dipole" in out:
301            if nbeads is None:
302                system["total_dipole"] = out["total_dipole"][0]
303            else:
304                system["total_dipole"] = jnp.mean(out["total_dipole"], axis=0)
305
306        if use_colvars:
307            coords = system["coordinates"].reshape(-1, nat, 3)[0]
308            colvars = {}
309            for colvar_name, colvar_calc in colvars_calculators.items():
310                colvars[colvar_name] = colvar_calc(coords)
311            system["colvars"] = colvars
312
313        return system, out
314
315    ###############################################
316    ### IR SPECTRUM
317    if do_ir_spectrum:
318        # @jax.jit
319        # def update_dipole(ir_state,system,conformation):
320        #     def mumodel(coords):
321        #         out = model_ir._apply(model_ir.variables,{**conformation,"coordinates":coords})
322        #         if nbeads is None:
323        #             return out["total_dipole"][0]
324        #         return out["total_dipole"].sum(axis=0)
325        #     dmudqmodel = jax.jacobian(mumodel)
326
327        #     dmudq = dmudqmodel(conformation["coordinates"])
328        #     # print(dmudq.shape)
329        #     if nbeads is None:
330        #         vel = system["vel"].reshape(-1,1,nat,3)[0]
331        #         mudot = (vel*dmudq).sum(axis=(1,2))
332        #     else:
333        #         dmudq = dmudq.reshape(3,nbeads,nat,3)#.mean(axis=1)
334        #         vel = (jnp.einsum("in,n...->i...", eigmat, system["vel"]) *  nbeads**0.5
335        #         )
336        #         # vel = system["vel"][0].reshape(1,nat,3)
337        #         mudot = (vel[None,...]*dmudq).sum(axis=(1,2,3))/nbeads
338
339        #     ir_state = save_dipole(mudot,ir_state)
340        #     return ir_state
341        @jax.jit
342        def update_conformation_ir(conformation, system):
343            conformation = {
344                **conformation,
345                "coordinates": system["coordinates"].reshape(-1, nat, 3)[0],
346                "natoms": jnp.asarray([nat]),
347                "batch_index": jnp.asarray([0] * nat),
348                "species": jnp.asarray(system_data["species"].reshape(-1, nat)[0]),
349            }
350            if variable_cell:
351                conformation["cells"] = system["cell"][None, :, :]
352                conformation["reciprocal_cells"] = jnp.linalg.inv(system["cell"])[
353                    None, :, :
354                ]
355            return conformation
356
357        @jax.jit
358        def update_dipole(ir_state, system, conformation):
359            if model_ir is not None:
360                out = model_ir._apply(model_ir.variables, conformation)
361                q = out.get("charges", jnp.zeros(nat)).reshape((-1, nat))
362                dip = out.get("dipoles", jnp.zeros((nat, 3))).reshape((-1, nat, 3))
363            else:
364                q = system.get("charges", jnp.zeros(nat)).reshape((-1, nat))
365                dip = system.get("dipoles", jnp.zeros((nat, 3))).reshape((-1, nat, 3))
366            if nbeads is not None:
367                q = jnp.mean(q, axis=0)
368                dip = jnp.mean(dip, axis=0)
369                vel = system["vel"][0]
370                pos = system["coordinates"][0]
371            else:
372                q = q[0]
373                dip = dip[0]
374                vel = system["vel"].reshape(-1, nat, 3)[0]
375                pos = system["coordinates"].reshape(-1, nat, 3)[0]
376
377            if pbc_data is not None:
378                cell_reciprocal = (
379                    conformation["cells"][0],
380                    conformation["reciprocal_cells"][0],
381                )
382            else:
383                cell_reciprocal = None
384
385            ir_state = save_dipole(
386                q, vel, pos, dip.sum(axis=0), cell_reciprocal, ir_state
387            )
388            return ir_state
389
390    ###############################################
391    ### GRAPH UPDATES
392
393    nblist_verbose = simulation_parameters.get("nblist_verbose", False)
394    nblist_stride = int(simulation_parameters.get("nblist_stride", -1))
395    nblist_warmup_time = simulation_parameters.get("nblist_warmup_time", -1.0) * au.FS
396    nblist_warmup = int(nblist_warmup_time / dt) if nblist_warmup_time > 0 else 0
397    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
398    if nblist_skin > 0:
399        if nblist_stride <= 0:
400            ## reference skin parameters at 300K (from Tinker-HP)
401            ##   => skin of 2 A gives you 40 fs without complete rebuild
402            t_ref = 40.0  # FS
403            nblist_skin_ref = 2.0  # A
404            nblist_stride = int(math.floor(nblist_skin / nblist_skin_ref * t_ref / dt))
405        print(
406            f"# nblist_skin: {nblist_skin:.2f} A, nblist_stride: {nblist_stride} steps, nblist_warmup: {nblist_warmup} steps"
407        )
408
409    if nblist_skin <= 0:
410        nblist_stride = 1
411
412    dyn_state["nblist_countdown"] = 0
413    dyn_state["print_skin_activation"] = nblist_warmup > 0
414
415    def update_graphs(istep, dyn_state, system, conformation, force_preprocess=False):
416        nblist_countdown = dyn_state["nblist_countdown"]
417        if nblist_countdown <= 0 or force_preprocess or (istep < nblist_warmup):
418            ### FULL NBLIST REBUILD
419            dyn_state["nblist_countdown"] = nblist_stride - 1
420            preproc_state = dyn_state["preproc_state"]
421            conformation = model.preprocessing.process(
422                preproc_state, update_conformation(conformation, system)
423            )
424            preproc_state, state_up, conformation, overflow = (
425                model.preprocessing.check_reallocate(preproc_state, conformation)
426            )
427            dyn_state["preproc_state"] = preproc_state
428            if nblist_verbose and overflow:
429                print("step", istep, ", nblist overflow => reallocating nblist")
430                print("size updates:", state_up)
431
432            if do_ir_spectrum and model_ir is not None:
433                conformation_ir = model_ir.preprocessing.process(
434                    dyn_state["preproc_state_ir"],
435                    update_conformation_ir(dyn_state["conformation_ir"], system),
436                )
437                (
438                    dyn_state["preproc_state_ir"],
439                    _,
440                    dyn_state["conformation_ir"],
441                    overflow,
442                ) = model_ir.preprocessing.check_reallocate(
443                    dyn_state["preproc_state_ir"], conformation_ir
444                )
445
446        else:
447            ### SKIN UPDATE
448            if dyn_state["print_skin_activation"]:
449                if nblist_verbose:
450                    print(
451                        "step",
452                        istep,
453                        ", end of nblist warmup phase => activating skin updates",
454                    )
455                dyn_state["print_skin_activation"] = False
456
457            dyn_state["nblist_countdown"] = nblist_countdown - 1
458            conformation = model.preprocessing.update_skin(
459                update_conformation(conformation, system)
460            )
461            if do_ir_spectrum and model_ir is not None:
462                dyn_state["conformation_ir"] = model_ir.preprocessing.update_skin(
463                    update_conformation_ir(dyn_state["conformation_ir"], system)
464                )
465
466        return conformation, dyn_state
467
468    ################################################
469    ### DEFINE STEP FUNCTION
470    def step(istep, dyn_state, system, conformation, force_preprocess=False):
471
472        dyn_state = {
473            **dyn_state,
474            "istep": dyn_state["istep"] + 1,
475        }
476
477        ### INTEGRATE EQUATIONS OF MOTION
478        system = integrate(system)
479
480        ### UPDATE CONFORMATION AND GRAPHS
481        conformation, dyn_state = update_graphs(
482            istep, dyn_state, system, conformation, force_preprocess
483        )
484
485        ## COMPUTE FORCES AND OBSERVABLES
486        system, out = update_observables(system, conformation)
487
488        ## END OF STEP UPDATES
489        if do_thermostat_post:
490            system["thermostat"], dyn_state["thermostat_post_state"] = thermostat_post(
491                system["thermostat"], dyn_state["thermostat_post_state"]
492            )
493        
494        if do_ir_spectrum:
495            ir_state = update_dipole(
496                dyn_state["ir_spectrum"], system, dyn_state["conformation_ir"]
497            )
498            dyn_state["ir_spectrum"] = ir_post(ir_state)
499
500        return dyn_state, system, conformation, out
501
502    ###########################################################
503
504    print("# Computing initial energy and forces")
505
506    conformation = update_conformation(conformation, system)
507    # initialize IR conformation
508    if do_ir_spectrum and model_ir is not None:
509        dyn_state["preproc_state_ir"], dyn_state["conformation_ir"] = (
510            model_ir.preprocessing(
511                model_ir.preproc_state,
512                update_conformation_ir(conformation, system),
513            )
514        )
515
516    system, _ = update_observables(system, conformation)
517
518    return step, update_conformation, system_data, dyn_state, conformation, system