fennol.md.integrate

  1import time
  2import math
  3import os
  4
  5import numpy as np
  6import jax
  7import jax.numpy as jnp
  8
  9from ..utils.atomic_units import AtomicUnits as au
 10from .thermostats import get_thermostat
 11from .barostats import get_barostat
 12from .colvars import setup_colvars
 13from .spectra import initialize_ir_spectrum
 14
 15from .utils import load_dynamics_restart, get_restart_file
 16from .initial import load_model, load_system_data, initialize_preprocessing
 17
 18
 19def initialize_dynamics(simulation_parameters, fprec, rng_key):
 20    ### LOAD MODEL
 21    model = load_model(simulation_parameters)
 22    model_energy_unit = au.get_multiplier(model.energy_unit)
 23
 24    ### Get the coordinates and species from the xyz file
 25    system_data, conformation = load_system_data(simulation_parameters, fprec)
 26    system_data["model_energy_unit"] = model_energy_unit
 27    system_data["model_energy_unit_str"] = model.energy_unit
 28
 29    ### FINISH BUILDING conformation
 30    if os.path.exists(get_restart_file(system_data)):
 31        ### RESTART FROM PREVIOUS DYNAMICS
 32        restart_data = load_dynamics_restart(system_data)
 33        print("# RESTARTING FROM PREVIOUS DYNAMICS")
 34        model.preproc_state = restart_data["preproc_state"]
 35        conformation["coordinates"] = restart_data["coordinates"]
 36    else:
 37        restart_data = {}
 38
 39    ### INITIALIZE PREPROCESSING
 40    preproc_state, conformation = initialize_preprocessing(
 41        simulation_parameters, model, conformation, system_data
 42    )
 43
 44    ### get dynamics parameters
 45    dt = simulation_parameters.get("dt") * au.FS
 46    dt2 = 0.5 * dt
 47    mass = system_data["mass"]
 48    densmass = system_data["totmass_Da"]*(au.MPROT*au.GCM3)
 49    nat = system_data["nat"]
 50    dtm = jnp.asarray(dt / mass[:, None], dtype=fprec)
 51
 52    nreplicas = system_data.get("nreplicas", 1)
 53    nbeads = system_data.get("nbeads", None)
 54    if nbeads is not None:
 55        nreplicas = nbeads
 56        dtm = dtm[None, :, :]
 57
 58    ### INITIALIZE DYNAMICS STATE
 59    system = {"coordinates": conformation["coordinates"]}
 60    dyn_state = {
 61        "istep": 0,
 62        "dt": dt,
 63        "pimd": nbeads is not None,
 64        "preproc_state": preproc_state,
 65        "start_time_ps": restart_data.get("simulation_time_ps", 0.),
 66    }
 67    gradient_keys = ["coordinates"]
 68    thermo_updates = []
 69
 70    ### INITIALIZE THERMOSTAT
 71    thermostat_rng, rng_key = jax.random.split(rng_key)
 72    (
 73        thermostat,
 74        thermostat_post,
 75        thermostat_state,
 76        initial_vel,
 77        dyn_state["thermostat_name"],
 78    ) = get_thermostat(simulation_parameters, dt, system_data, fprec, thermostat_rng,restart_data)
 79    do_thermostat_post = thermostat_post is not None
 80    if do_thermostat_post:
 81        thermostat_post, post_state = thermostat_post
 82        dyn_state["thermostat_post_state"] = post_state
 83
 84    system["thermostat"] = thermostat_state
 85    system["vel"] = restart_data.get("vel", initial_vel).astype(fprec)
 86
 87    ### PBC
 88    pbc_data = system_data.get("pbc", None)
 89    if pbc_data is not None:
 90        ### INITIALIZE BAROSTAT
 91        barostat_key, rng_key = jax.random.split(rng_key)
 92        thermo_update_ensemble, variable_cell, barostat_state = get_barostat(
 93            thermostat, simulation_parameters, dt, system_data, fprec, barostat_key,restart_data
 94        )
 95        estimate_pressure = variable_cell or pbc_data["estimate_pressure"]
 96        system["barostat"] = barostat_state
 97        system["cell"] = conformation["cells"][0]
 98        if estimate_pressure:
 99            pressure_o_weight = simulation_parameters.get("pressure_o_weight", 0.0)
100            assert (
101                0.0 <= pressure_o_weight <= 1.0
102            ), "pressure_o_weight must be between 0 and 1"
103            gradient_keys.append("strain")
104        print("# Estimate pressure: ", estimate_pressure)
105    else:
106        estimate_pressure = False
107        variable_cell = False
108
109        def thermo_update_ensemble(x, v, system):
110            v, thermostat_state = thermostat(v, system["thermostat"])
111            return x, v, {**system, "thermostat": thermostat_state}
112
113    dyn_state["estimate_pressure"] = estimate_pressure
114    dyn_state["variable_cell"] = variable_cell
115    thermo_updates.append(thermo_update_ensemble)
116
117    ### ENERGY ENSEMBLE
118    ensemble_key = simulation_parameters.get("etot_ensemble_key", None)
119
120    ### COLVARS
121    colvars_definitions = simulation_parameters.get("colvars", None)
122    use_colvars = colvars_definitions is not None
123    if use_colvars:
124        colvars_calculators, colvars_names = setup_colvars(colvars_definitions)
125        dyn_state["colvars"] = colvars_names
126
127    ### IR SPECTRUM
128    do_ir_spectrum = simulation_parameters.get("ir_spectrum", False)
129    assert isinstance(do_ir_spectrum, bool), "ir_spectrum must be a boolean"
130    if do_ir_spectrum:
131        is_qtb = dyn_state["thermostat_name"].endswith("QTB")
132        model_ir, ir_state, save_dipole, ir_post = initialize_ir_spectrum(
133            simulation_parameters, system_data, fprec, dt, is_qtb
134        )
135        dyn_state["ir_spectrum"] = ir_state
136
137    ### BUILD GRADIENT FUNCTION
138    energy_and_gradient = model.get_gradient_function(
139        *gradient_keys, jit=True, variables_as_input=True
140    )
141
142    ### COLLECT THERMO UPDATES
143    if len(thermo_updates) == 1:
144        thermo_update = thermo_updates[0]
145    else:
146
147        def thermo_update(x, v, system):
148            for update in thermo_updates:
149                x, v, system = update(x, v, system)
150            return x, v, system
151
152    ### RING POLYMER INITIALIZATION
153    if nbeads is not None:
154        cay_correction = simulation_parameters.get("cay_correction", True)
155        omk = system_data["omk"]
156        eigmat = system_data["eigmat"]
157        cayfact = 1.0 / (4.0 + (dt * omk[1:, None, None]) ** 2) ** 0.5
158        if cay_correction:
159            axx = jnp.asarray(2 * cayfact)
160            axv = jnp.asarray(dt * cayfact)
161            avx = jnp.asarray(-dt * cayfact * omk[1:, None, None] ** 2)
162        else:
163            axx = jnp.asarray(np.cos(omk[1:, None, None] * dt2))
164            axv = jnp.asarray(np.sin(omk[1:, None, None] * dt2) / omk[1:, None, None])
165            avx = jnp.asarray(-omk[1:, None, None] * np.sin(omk[1:, None, None] * dt2))
166
167        coordinates = conformation["coordinates"].reshape(nbeads, -1, 3)
168        eigx = jnp.zeros_like(coordinates).at[0].set(coordinates[0])
169        system["coordinates"] = eigx
170
171    ###############################################
172    ### DEFINE UPDATE FUNCTION
173    @jax.jit
174    def update_conformation(conformation, system):
175        x = system["coordinates"]
176        if nbeads is not None:
177            x = jnp.einsum("in,n...->i...", eigmat, x).reshape(nbeads * nat, 3) * (
178                nbeads**0.5
179            )
180        conformation = {**conformation, "coordinates": x}
181        if variable_cell:
182            conformation["cells"] = system["cell"][None, :, :].repeat(nreplicas, axis=0)
183
184        
185
186        return conformation
187
188    ###############################################
189    ### DEFINE INTEGRATION FUNCTIONS
190    def integrate_A_half(x0, v0):
191        if nbeads is None:
192            return x0 + dt2 * v0, v0
193
194        # update coordinates and velocities of a free ring polymer for a half time step
195        eigx_c = x0[0] + dt2 * v0[0]
196        eigv_c = v0[0]
197        eigx = x0[1:] * axx + v0[1:] * axv
198        eigv = x0[1:] * avx + v0[1:] * axx
199
200        return (
201            jnp.concatenate((eigx_c[None], eigx), axis=0),
202            jnp.concatenate((eigv_c[None], eigv), axis=0),
203        )
204
205    @jax.jit
206    def integrate(system):
207        x = system["coordinates"]
208        v = system["vel"] + dtm * system["forces"]
209        x, v = integrate_A_half(x, v)
210        x, v, system = thermo_update(x, v, system)
211        x, v = integrate_A_half(x, v)
212
213        return {**system, "coordinates": x, "vel": v}
214
215    ###############################################
216    ### DEFINE OBSERVABLE FUNCTION
217    @jax.jit
218    def update_observables(system, conformation):
219        ### POTENTIAL ENERGY AND FORCES
220        epot, de, out = energy_and_gradient(model.variables, conformation)
221        out["forces"] = -de["coordinates"]
222        epot = epot / model_energy_unit
223        de = {k: v / model_energy_unit for k, v in de.items()}
224        forces = -de["coordinates"]
225
226        if nbeads is not None:
227            ### PROJECT FORCES ONTO POLYMER NORMAL MODES
228            forces = jnp.einsum(
229                "in,i...->n...", eigmat, forces.reshape(nbeads, nat, 3)
230            ) * (1.0 / nbeads**0.5)
231
232        system = {
233            **system,
234            "epot": jnp.mean(epot),
235            "forces": forces,
236            "energy_gradients": de,
237        }
238
239        ### KINETIC ENERGY
240        v = system["vel"]
241        if nbeads is None:
242            corr_kin = system["thermostat"].get("corr_kin", 1.0)
243            # ek = 0.5 * jnp.sum(mass[:, None] * v**2) / state_th.get("corr_kin", 1.0)
244            ek = (0.5 / nreplicas / corr_kin) * jnp.sum(
245                mass[:, None, None] * v[:, :, None] * v[:, None, :], axis=0
246            )
247        else:
248            ek_c = 0.5 * jnp.sum(
249                mass[:, None, None] * v[0, :, :, None] * v[0, :, None, :], axis=0
250            )
251            ek = ek_c - 0.5 * jnp.sum(
252                system["coordinates"][1:, :, :, None] * forces[1:, :, None, :],
253                axis=(0, 1),
254            )
255            system["ek_c"] = jnp.trace(ek_c)
256
257        system["ek"] = jnp.trace(ek)
258        system["ek_tensor"] = ek
259
260        if estimate_pressure:
261            if pressure_o_weight != 1.0:
262                v = system["vel"] + 0.5 * dtm * system["forces"]
263                if nbeads is None:
264                    corr_kin = system["thermostat"].get("corr_kin", 1.0)
265                    # ek = 0.5 * jnp.sum(mass[:, None] * v**2) / state_th.get("corr_kin", 1.0)
266                    ek = (0.5 / nreplicas / corr_kin) * jnp.sum(
267                        mass[:, None, None] * v[:, :, None] * v[:, None, :], axis=0
268                    )
269                else:
270                    ek_c = 0.5 * jnp.sum(
271                        mass[:, None, None] * v[0, :, :, None] * v[0, :, None, :],
272                        axis=0,
273                    )
274                    ek = ek_c - 0.5 * jnp.sum(
275                        system["coordinates"][1:, :, :, None] * forces[1:, :, None, :],
276                        axis=(0, 1),
277                    )
278                b = pressure_o_weight
279                ek = (1.0 - b) * ek + b * system["ek_tensor"]
280
281            vir = jnp.mean(de["strain"], axis=0)
282            system["virial"] = vir
283            out["virial_tensor"] = vir * model_energy_unit
284            
285            pV =  2 * ek  - vir
286            system["PV_tensor"] = pV
287            volume = jnp.abs(jnp.linalg.det(system["cell"]))
288            Pres = pV / volume
289            system["pressure_tensor"] = Pres
290            system["pressure"] = jnp.trace(Pres) * (1.0 / 3.0)
291            if variable_cell:
292                density = densmass / volume
293                system["density"] = density
294                system["volume"] = volume
295
296        if ensemble_key is not None:
297            kT = system_data["kT"]
298            dE = (
299                jnp.mean(out[ensemble_key], axis=0) / model_energy_unit - system["epot"]
300            )
301            system["ensemble_weights"] = -dE / kT
302
303        if "total_dipole" in out:
304            if nbeads is None:
305                system["total_dipole"] = out["total_dipole"][0]
306            else:
307                system["total_dipole"] = jnp.mean(out["total_dipole"], axis=0)
308
309        if use_colvars:
310            coords = system["coordinates"].reshape(-1, nat, 3)[0]
311            colvars = {}
312            for colvar_name, colvar_calc in colvars_calculators.items():
313                colvars[colvar_name] = colvar_calc(coords)
314            system["colvars"] = colvars
315
316        return system, out
317
318    ###############################################
319    ### IR SPECTRUM
320    if do_ir_spectrum:
321        # @jax.jit
322        # def update_dipole(ir_state,system,conformation):
323        #     def mumodel(coords):
324        #         out = model_ir._apply(model_ir.variables,{**conformation,"coordinates":coords})
325        #         if nbeads is None:
326        #             return out["total_dipole"][0]
327        #         return out["total_dipole"].sum(axis=0)
328        #     dmudqmodel = jax.jacobian(mumodel)
329
330        #     dmudq = dmudqmodel(conformation["coordinates"])
331        #     # print(dmudq.shape)
332        #     if nbeads is None:
333        #         vel = system["vel"].reshape(-1,1,nat,3)[0]
334        #         mudot = (vel*dmudq).sum(axis=(1,2))
335        #     else:
336        #         dmudq = dmudq.reshape(3,nbeads,nat,3)#.mean(axis=1)
337        #         vel = (jnp.einsum("in,n...->i...", eigmat, system["vel"]) *  nbeads**0.5
338        #         )
339        #         # vel = system["vel"][0].reshape(1,nat,3)
340        #         mudot = (vel[None,...]*dmudq).sum(axis=(1,2,3))/nbeads
341
342        #     ir_state = save_dipole(mudot,ir_state)
343        #     return ir_state
344        @jax.jit
345        def update_conformation_ir(conformation, system):
346            conformation = {
347                **conformation,
348                "coordinates": system["coordinates"].reshape(-1, nat, 3)[0],
349                "natoms": jnp.asarray([nat]),
350                "batch_index": jnp.asarray([0] * nat),
351                "species": jnp.asarray(system_data["species"].reshape(-1, nat)[0]),
352            }
353            if variable_cell:
354                conformation["cells"] = system["cell"][None, :, :]
355                conformation["reciprocal_cells"] = jnp.linalg.inv(system["cell"])[
356                    None, :, :
357                ]
358            return conformation
359
360        @jax.jit
361        def update_dipole(ir_state, system, conformation):
362            if model_ir is not None:
363                out = model_ir._apply(model_ir.variables, conformation)
364                q = out.get("charges", jnp.zeros(nat)).reshape((-1, nat))
365                dip = out.get("dipoles", jnp.zeros((nat, 3))).reshape((-1, nat, 3))
366            else:
367                q = system.get("charges", jnp.zeros(nat)).reshape((-1, nat))
368                dip = system.get("dipoles", jnp.zeros((nat, 3))).reshape((-1, nat, 3))
369            if nbeads is not None:
370                q = jnp.mean(q, axis=0)
371                dip = jnp.mean(dip, axis=0)
372                vel = system["vel"][0]
373                pos = system["coordinates"][0]
374            else:
375                q = q[0]
376                dip = dip[0]
377                vel = system["vel"].reshape(-1, nat, 3)[0]
378                pos = system["coordinates"].reshape(-1, nat, 3)[0]
379
380            if pbc_data is not None:
381                cell_reciprocal = (
382                    conformation["cells"][0],
383                    conformation["reciprocal_cells"][0],
384                )
385            else:
386                cell_reciprocal = None
387
388            ir_state = save_dipole(
389                q, vel, pos, dip.sum(axis=0), cell_reciprocal, ir_state
390            )
391            return ir_state
392
393    ###############################################
394    ### GRAPH UPDATES
395
396    nblist_verbose = simulation_parameters.get("nblist_verbose", False)
397    nblist_stride = int(simulation_parameters.get("nblist_stride", -1))
398    nblist_warmup_time = simulation_parameters.get("nblist_warmup_time", -1.0) * au.FS
399    nblist_warmup = int(nblist_warmup_time / dt) if nblist_warmup_time > 0 else 0
400    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
401    if nblist_skin > 0:
402        if nblist_stride <= 0:
403            ## reference skin parameters at 300K (from Tinker-HP)
404            ##   => skin of 2 A gives you 40 fs without complete rebuild
405            t_ref = 40.0  # FS
406            nblist_skin_ref = 2.0  # A
407            nblist_stride = int(math.floor(nblist_skin / nblist_skin_ref * t_ref / dt))
408        print(
409            f"# nblist_skin: {nblist_skin:.2f} A, nblist_stride: {nblist_stride} steps, nblist_warmup: {nblist_warmup} steps"
410        )
411
412    if nblist_skin <= 0:
413        nblist_stride = 1
414
415    dyn_state["nblist_countdown"] = 0
416    dyn_state["print_skin_activation"] = nblist_warmup > 0
417
418    def update_graphs(istep, dyn_state, system, conformation, force_preprocess=False):
419        nblist_countdown = dyn_state["nblist_countdown"]
420        if nblist_countdown <= 0 or force_preprocess or (istep < nblist_warmup):
421            ### FULL NBLIST REBUILD
422            dyn_state["nblist_countdown"] = nblist_stride - 1
423            preproc_state = dyn_state["preproc_state"]
424            conformation = model.preprocessing.process(
425                preproc_state, update_conformation(conformation, system)
426            )
427            preproc_state, state_up, conformation, overflow = (
428                model.preprocessing.check_reallocate(preproc_state, conformation)
429            )
430            dyn_state["preproc_state"] = preproc_state
431            if nblist_verbose and overflow:
432                print("step", istep, ", nblist overflow => reallocating nblist")
433                print("size updates:", state_up)
434
435            if do_ir_spectrum and model_ir is not None:
436                conformation_ir = model_ir.preprocessing.process(
437                    dyn_state["preproc_state_ir"],
438                    update_conformation_ir(dyn_state["conformation_ir"], system),
439                )
440                (
441                    dyn_state["preproc_state_ir"],
442                    _,
443                    dyn_state["conformation_ir"],
444                    overflow,
445                ) = model_ir.preprocessing.check_reallocate(
446                    dyn_state["preproc_state_ir"], conformation_ir
447                )
448
449        else:
450            ### SKIN UPDATE
451            if dyn_state["print_skin_activation"]:
452                if nblist_verbose:
453                    print(
454                        "step",
455                        istep,
456                        ", end of nblist warmup phase => activating skin updates",
457                    )
458                dyn_state["print_skin_activation"] = False
459
460            dyn_state["nblist_countdown"] = nblist_countdown - 1
461            conformation = model.preprocessing.update_skin(
462                update_conformation(conformation, system)
463            )
464            if do_ir_spectrum and model_ir is not None:
465                dyn_state["conformation_ir"] = model_ir.preprocessing.update_skin(
466                    update_conformation_ir(dyn_state["conformation_ir"], system)
467                )
468
469        return conformation, dyn_state
470
471    ################################################
472    ### DEFINE STEP FUNCTION
473    def step(istep, dyn_state, system, conformation, force_preprocess=False):
474
475        dyn_state = {
476            **dyn_state,
477            "istep": dyn_state["istep"] + 1,
478        }
479
480        ### INTEGRATE EQUATIONS OF MOTION
481        system = integrate(system)
482
483        ### UPDATE CONFORMATION AND GRAPHS
484        conformation, dyn_state = update_graphs(
485            istep, dyn_state, system, conformation, force_preprocess
486        )
487
488        ## COMPUTE FORCES AND OBSERVABLES
489        system, out = update_observables(system, conformation)
490
491        ## END OF STEP UPDATES
492        if do_thermostat_post:
493            system["thermostat"], dyn_state["thermostat_post_state"] = thermostat_post(
494                system["thermostat"], dyn_state["thermostat_post_state"]
495            )
496        
497        if do_ir_spectrum:
498            ir_state = update_dipole(
499                dyn_state["ir_spectrum"], system, dyn_state["conformation_ir"]
500            )
501            dyn_state["ir_spectrum"] = ir_post(ir_state)
502
503        return dyn_state, system, conformation, out
504
505    ###########################################################
506
507    print("# Computing initial energy and forces")
508
509    conformation = update_conformation(conformation, system)
510    # initialize IR conformation
511    if do_ir_spectrum and model_ir is not None:
512        dyn_state["preproc_state_ir"], dyn_state["conformation_ir"] = (
513            model_ir.preprocessing(
514                model_ir.preproc_state,
515                update_conformation_ir(conformation, system),
516            )
517        )
518
519    system, _ = update_observables(system, conformation)
520
521    return step, update_conformation, system_data, dyn_state, conformation, system
def initialize_dynamics(simulation_parameters, fprec, rng_key):
 20def initialize_dynamics(simulation_parameters, fprec, rng_key):
 21    ### LOAD MODEL
 22    model = load_model(simulation_parameters)
 23    model_energy_unit = au.get_multiplier(model.energy_unit)
 24
 25    ### Get the coordinates and species from the xyz file
 26    system_data, conformation = load_system_data(simulation_parameters, fprec)
 27    system_data["model_energy_unit"] = model_energy_unit
 28    system_data["model_energy_unit_str"] = model.energy_unit
 29
 30    ### FINISH BUILDING conformation
 31    if os.path.exists(get_restart_file(system_data)):
 32        ### RESTART FROM PREVIOUS DYNAMICS
 33        restart_data = load_dynamics_restart(system_data)
 34        print("# RESTARTING FROM PREVIOUS DYNAMICS")
 35        model.preproc_state = restart_data["preproc_state"]
 36        conformation["coordinates"] = restart_data["coordinates"]
 37    else:
 38        restart_data = {}
 39
 40    ### INITIALIZE PREPROCESSING
 41    preproc_state, conformation = initialize_preprocessing(
 42        simulation_parameters, model, conformation, system_data
 43    )
 44
 45    ### get dynamics parameters
 46    dt = simulation_parameters.get("dt") * au.FS
 47    dt2 = 0.5 * dt
 48    mass = system_data["mass"]
 49    densmass = system_data["totmass_Da"]*(au.MPROT*au.GCM3)
 50    nat = system_data["nat"]
 51    dtm = jnp.asarray(dt / mass[:, None], dtype=fprec)
 52
 53    nreplicas = system_data.get("nreplicas", 1)
 54    nbeads = system_data.get("nbeads", None)
 55    if nbeads is not None:
 56        nreplicas = nbeads
 57        dtm = dtm[None, :, :]
 58
 59    ### INITIALIZE DYNAMICS STATE
 60    system = {"coordinates": conformation["coordinates"]}
 61    dyn_state = {
 62        "istep": 0,
 63        "dt": dt,
 64        "pimd": nbeads is not None,
 65        "preproc_state": preproc_state,
 66        "start_time_ps": restart_data.get("simulation_time_ps", 0.),
 67    }
 68    gradient_keys = ["coordinates"]
 69    thermo_updates = []
 70
 71    ### INITIALIZE THERMOSTAT
 72    thermostat_rng, rng_key = jax.random.split(rng_key)
 73    (
 74        thermostat,
 75        thermostat_post,
 76        thermostat_state,
 77        initial_vel,
 78        dyn_state["thermostat_name"],
 79    ) = get_thermostat(simulation_parameters, dt, system_data, fprec, thermostat_rng,restart_data)
 80    do_thermostat_post = thermostat_post is not None
 81    if do_thermostat_post:
 82        thermostat_post, post_state = thermostat_post
 83        dyn_state["thermostat_post_state"] = post_state
 84
 85    system["thermostat"] = thermostat_state
 86    system["vel"] = restart_data.get("vel", initial_vel).astype(fprec)
 87
 88    ### PBC
 89    pbc_data = system_data.get("pbc", None)
 90    if pbc_data is not None:
 91        ### INITIALIZE BAROSTAT
 92        barostat_key, rng_key = jax.random.split(rng_key)
 93        thermo_update_ensemble, variable_cell, barostat_state = get_barostat(
 94            thermostat, simulation_parameters, dt, system_data, fprec, barostat_key,restart_data
 95        )
 96        estimate_pressure = variable_cell or pbc_data["estimate_pressure"]
 97        system["barostat"] = barostat_state
 98        system["cell"] = conformation["cells"][0]
 99        if estimate_pressure:
100            pressure_o_weight = simulation_parameters.get("pressure_o_weight", 0.0)
101            assert (
102                0.0 <= pressure_o_weight <= 1.0
103            ), "pressure_o_weight must be between 0 and 1"
104            gradient_keys.append("strain")
105        print("# Estimate pressure: ", estimate_pressure)
106    else:
107        estimate_pressure = False
108        variable_cell = False
109
110        def thermo_update_ensemble(x, v, system):
111            v, thermostat_state = thermostat(v, system["thermostat"])
112            return x, v, {**system, "thermostat": thermostat_state}
113
114    dyn_state["estimate_pressure"] = estimate_pressure
115    dyn_state["variable_cell"] = variable_cell
116    thermo_updates.append(thermo_update_ensemble)
117
118    ### ENERGY ENSEMBLE
119    ensemble_key = simulation_parameters.get("etot_ensemble_key", None)
120
121    ### COLVARS
122    colvars_definitions = simulation_parameters.get("colvars", None)
123    use_colvars = colvars_definitions is not None
124    if use_colvars:
125        colvars_calculators, colvars_names = setup_colvars(colvars_definitions)
126        dyn_state["colvars"] = colvars_names
127
128    ### IR SPECTRUM
129    do_ir_spectrum = simulation_parameters.get("ir_spectrum", False)
130    assert isinstance(do_ir_spectrum, bool), "ir_spectrum must be a boolean"
131    if do_ir_spectrum:
132        is_qtb = dyn_state["thermostat_name"].endswith("QTB")
133        model_ir, ir_state, save_dipole, ir_post = initialize_ir_spectrum(
134            simulation_parameters, system_data, fprec, dt, is_qtb
135        )
136        dyn_state["ir_spectrum"] = ir_state
137
138    ### BUILD GRADIENT FUNCTION
139    energy_and_gradient = model.get_gradient_function(
140        *gradient_keys, jit=True, variables_as_input=True
141    )
142
143    ### COLLECT THERMO UPDATES
144    if len(thermo_updates) == 1:
145        thermo_update = thermo_updates[0]
146    else:
147
148        def thermo_update(x, v, system):
149            for update in thermo_updates:
150                x, v, system = update(x, v, system)
151            return x, v, system
152
153    ### RING POLYMER INITIALIZATION
154    if nbeads is not None:
155        cay_correction = simulation_parameters.get("cay_correction", True)
156        omk = system_data["omk"]
157        eigmat = system_data["eigmat"]
158        cayfact = 1.0 / (4.0 + (dt * omk[1:, None, None]) ** 2) ** 0.5
159        if cay_correction:
160            axx = jnp.asarray(2 * cayfact)
161            axv = jnp.asarray(dt * cayfact)
162            avx = jnp.asarray(-dt * cayfact * omk[1:, None, None] ** 2)
163        else:
164            axx = jnp.asarray(np.cos(omk[1:, None, None] * dt2))
165            axv = jnp.asarray(np.sin(omk[1:, None, None] * dt2) / omk[1:, None, None])
166            avx = jnp.asarray(-omk[1:, None, None] * np.sin(omk[1:, None, None] * dt2))
167
168        coordinates = conformation["coordinates"].reshape(nbeads, -1, 3)
169        eigx = jnp.zeros_like(coordinates).at[0].set(coordinates[0])
170        system["coordinates"] = eigx
171
172    ###############################################
173    ### DEFINE UPDATE FUNCTION
174    @jax.jit
175    def update_conformation(conformation, system):
176        x = system["coordinates"]
177        if nbeads is not None:
178            x = jnp.einsum("in,n...->i...", eigmat, x).reshape(nbeads * nat, 3) * (
179                nbeads**0.5
180            )
181        conformation = {**conformation, "coordinates": x}
182        if variable_cell:
183            conformation["cells"] = system["cell"][None, :, :].repeat(nreplicas, axis=0)
184
185        
186
187        return conformation
188
189    ###############################################
190    ### DEFINE INTEGRATION FUNCTIONS
191    def integrate_A_half(x0, v0):
192        if nbeads is None:
193            return x0 + dt2 * v0, v0
194
195        # update coordinates and velocities of a free ring polymer for a half time step
196        eigx_c = x0[0] + dt2 * v0[0]
197        eigv_c = v0[0]
198        eigx = x0[1:] * axx + v0[1:] * axv
199        eigv = x0[1:] * avx + v0[1:] * axx
200
201        return (
202            jnp.concatenate((eigx_c[None], eigx), axis=0),
203            jnp.concatenate((eigv_c[None], eigv), axis=0),
204        )
205
206    @jax.jit
207    def integrate(system):
208        x = system["coordinates"]
209        v = system["vel"] + dtm * system["forces"]
210        x, v = integrate_A_half(x, v)
211        x, v, system = thermo_update(x, v, system)
212        x, v = integrate_A_half(x, v)
213
214        return {**system, "coordinates": x, "vel": v}
215
216    ###############################################
217    ### DEFINE OBSERVABLE FUNCTION
218    @jax.jit
219    def update_observables(system, conformation):
220        ### POTENTIAL ENERGY AND FORCES
221        epot, de, out = energy_and_gradient(model.variables, conformation)
222        out["forces"] = -de["coordinates"]
223        epot = epot / model_energy_unit
224        de = {k: v / model_energy_unit for k, v in de.items()}
225        forces = -de["coordinates"]
226
227        if nbeads is not None:
228            ### PROJECT FORCES ONTO POLYMER NORMAL MODES
229            forces = jnp.einsum(
230                "in,i...->n...", eigmat, forces.reshape(nbeads, nat, 3)
231            ) * (1.0 / nbeads**0.5)
232
233        system = {
234            **system,
235            "epot": jnp.mean(epot),
236            "forces": forces,
237            "energy_gradients": de,
238        }
239
240        ### KINETIC ENERGY
241        v = system["vel"]
242        if nbeads is None:
243            corr_kin = system["thermostat"].get("corr_kin", 1.0)
244            # ek = 0.5 * jnp.sum(mass[:, None] * v**2) / state_th.get("corr_kin", 1.0)
245            ek = (0.5 / nreplicas / corr_kin) * jnp.sum(
246                mass[:, None, None] * v[:, :, None] * v[:, None, :], axis=0
247            )
248        else:
249            ek_c = 0.5 * jnp.sum(
250                mass[:, None, None] * v[0, :, :, None] * v[0, :, None, :], axis=0
251            )
252            ek = ek_c - 0.5 * jnp.sum(
253                system["coordinates"][1:, :, :, None] * forces[1:, :, None, :],
254                axis=(0, 1),
255            )
256            system["ek_c"] = jnp.trace(ek_c)
257
258        system["ek"] = jnp.trace(ek)
259        system["ek_tensor"] = ek
260
261        if estimate_pressure:
262            if pressure_o_weight != 1.0:
263                v = system["vel"] + 0.5 * dtm * system["forces"]
264                if nbeads is None:
265                    corr_kin = system["thermostat"].get("corr_kin", 1.0)
266                    # ek = 0.5 * jnp.sum(mass[:, None] * v**2) / state_th.get("corr_kin", 1.0)
267                    ek = (0.5 / nreplicas / corr_kin) * jnp.sum(
268                        mass[:, None, None] * v[:, :, None] * v[:, None, :], axis=0
269                    )
270                else:
271                    ek_c = 0.5 * jnp.sum(
272                        mass[:, None, None] * v[0, :, :, None] * v[0, :, None, :],
273                        axis=0,
274                    )
275                    ek = ek_c - 0.5 * jnp.sum(
276                        system["coordinates"][1:, :, :, None] * forces[1:, :, None, :],
277                        axis=(0, 1),
278                    )
279                b = pressure_o_weight
280                ek = (1.0 - b) * ek + b * system["ek_tensor"]
281
282            vir = jnp.mean(de["strain"], axis=0)
283            system["virial"] = vir
284            out["virial_tensor"] = vir * model_energy_unit
285            
286            pV =  2 * ek  - vir
287            system["PV_tensor"] = pV
288            volume = jnp.abs(jnp.linalg.det(system["cell"]))
289            Pres = pV / volume
290            system["pressure_tensor"] = Pres
291            system["pressure"] = jnp.trace(Pres) * (1.0 / 3.0)
292            if variable_cell:
293                density = densmass / volume
294                system["density"] = density
295                system["volume"] = volume
296
297        if ensemble_key is not None:
298            kT = system_data["kT"]
299            dE = (
300                jnp.mean(out[ensemble_key], axis=0) / model_energy_unit - system["epot"]
301            )
302            system["ensemble_weights"] = -dE / kT
303
304        if "total_dipole" in out:
305            if nbeads is None:
306                system["total_dipole"] = out["total_dipole"][0]
307            else:
308                system["total_dipole"] = jnp.mean(out["total_dipole"], axis=0)
309
310        if use_colvars:
311            coords = system["coordinates"].reshape(-1, nat, 3)[0]
312            colvars = {}
313            for colvar_name, colvar_calc in colvars_calculators.items():
314                colvars[colvar_name] = colvar_calc(coords)
315            system["colvars"] = colvars
316
317        return system, out
318
319    ###############################################
320    ### IR SPECTRUM
321    if do_ir_spectrum:
322        # @jax.jit
323        # def update_dipole(ir_state,system,conformation):
324        #     def mumodel(coords):
325        #         out = model_ir._apply(model_ir.variables,{**conformation,"coordinates":coords})
326        #         if nbeads is None:
327        #             return out["total_dipole"][0]
328        #         return out["total_dipole"].sum(axis=0)
329        #     dmudqmodel = jax.jacobian(mumodel)
330
331        #     dmudq = dmudqmodel(conformation["coordinates"])
332        #     # print(dmudq.shape)
333        #     if nbeads is None:
334        #         vel = system["vel"].reshape(-1,1,nat,3)[0]
335        #         mudot = (vel*dmudq).sum(axis=(1,2))
336        #     else:
337        #         dmudq = dmudq.reshape(3,nbeads,nat,3)#.mean(axis=1)
338        #         vel = (jnp.einsum("in,n...->i...", eigmat, system["vel"]) *  nbeads**0.5
339        #         )
340        #         # vel = system["vel"][0].reshape(1,nat,3)
341        #         mudot = (vel[None,...]*dmudq).sum(axis=(1,2,3))/nbeads
342
343        #     ir_state = save_dipole(mudot,ir_state)
344        #     return ir_state
345        @jax.jit
346        def update_conformation_ir(conformation, system):
347            conformation = {
348                **conformation,
349                "coordinates": system["coordinates"].reshape(-1, nat, 3)[0],
350                "natoms": jnp.asarray([nat]),
351                "batch_index": jnp.asarray([0] * nat),
352                "species": jnp.asarray(system_data["species"].reshape(-1, nat)[0]),
353            }
354            if variable_cell:
355                conformation["cells"] = system["cell"][None, :, :]
356                conformation["reciprocal_cells"] = jnp.linalg.inv(system["cell"])[
357                    None, :, :
358                ]
359            return conformation
360
361        @jax.jit
362        def update_dipole(ir_state, system, conformation):
363            if model_ir is not None:
364                out = model_ir._apply(model_ir.variables, conformation)
365                q = out.get("charges", jnp.zeros(nat)).reshape((-1, nat))
366                dip = out.get("dipoles", jnp.zeros((nat, 3))).reshape((-1, nat, 3))
367            else:
368                q = system.get("charges", jnp.zeros(nat)).reshape((-1, nat))
369                dip = system.get("dipoles", jnp.zeros((nat, 3))).reshape((-1, nat, 3))
370            if nbeads is not None:
371                q = jnp.mean(q, axis=0)
372                dip = jnp.mean(dip, axis=0)
373                vel = system["vel"][0]
374                pos = system["coordinates"][0]
375            else:
376                q = q[0]
377                dip = dip[0]
378                vel = system["vel"].reshape(-1, nat, 3)[0]
379                pos = system["coordinates"].reshape(-1, nat, 3)[0]
380
381            if pbc_data is not None:
382                cell_reciprocal = (
383                    conformation["cells"][0],
384                    conformation["reciprocal_cells"][0],
385                )
386            else:
387                cell_reciprocal = None
388
389            ir_state = save_dipole(
390                q, vel, pos, dip.sum(axis=0), cell_reciprocal, ir_state
391            )
392            return ir_state
393
394    ###############################################
395    ### GRAPH UPDATES
396
397    nblist_verbose = simulation_parameters.get("nblist_verbose", False)
398    nblist_stride = int(simulation_parameters.get("nblist_stride", -1))
399    nblist_warmup_time = simulation_parameters.get("nblist_warmup_time", -1.0) * au.FS
400    nblist_warmup = int(nblist_warmup_time / dt) if nblist_warmup_time > 0 else 0
401    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
402    if nblist_skin > 0:
403        if nblist_stride <= 0:
404            ## reference skin parameters at 300K (from Tinker-HP)
405            ##   => skin of 2 A gives you 40 fs without complete rebuild
406            t_ref = 40.0  # FS
407            nblist_skin_ref = 2.0  # A
408            nblist_stride = int(math.floor(nblist_skin / nblist_skin_ref * t_ref / dt))
409        print(
410            f"# nblist_skin: {nblist_skin:.2f} A, nblist_stride: {nblist_stride} steps, nblist_warmup: {nblist_warmup} steps"
411        )
412
413    if nblist_skin <= 0:
414        nblist_stride = 1
415
416    dyn_state["nblist_countdown"] = 0
417    dyn_state["print_skin_activation"] = nblist_warmup > 0
418
419    def update_graphs(istep, dyn_state, system, conformation, force_preprocess=False):
420        nblist_countdown = dyn_state["nblist_countdown"]
421        if nblist_countdown <= 0 or force_preprocess or (istep < nblist_warmup):
422            ### FULL NBLIST REBUILD
423            dyn_state["nblist_countdown"] = nblist_stride - 1
424            preproc_state = dyn_state["preproc_state"]
425            conformation = model.preprocessing.process(
426                preproc_state, update_conformation(conformation, system)
427            )
428            preproc_state, state_up, conformation, overflow = (
429                model.preprocessing.check_reallocate(preproc_state, conformation)
430            )
431            dyn_state["preproc_state"] = preproc_state
432            if nblist_verbose and overflow:
433                print("step", istep, ", nblist overflow => reallocating nblist")
434                print("size updates:", state_up)
435
436            if do_ir_spectrum and model_ir is not None:
437                conformation_ir = model_ir.preprocessing.process(
438                    dyn_state["preproc_state_ir"],
439                    update_conformation_ir(dyn_state["conformation_ir"], system),
440                )
441                (
442                    dyn_state["preproc_state_ir"],
443                    _,
444                    dyn_state["conformation_ir"],
445                    overflow,
446                ) = model_ir.preprocessing.check_reallocate(
447                    dyn_state["preproc_state_ir"], conformation_ir
448                )
449
450        else:
451            ### SKIN UPDATE
452            if dyn_state["print_skin_activation"]:
453                if nblist_verbose:
454                    print(
455                        "step",
456                        istep,
457                        ", end of nblist warmup phase => activating skin updates",
458                    )
459                dyn_state["print_skin_activation"] = False
460
461            dyn_state["nblist_countdown"] = nblist_countdown - 1
462            conformation = model.preprocessing.update_skin(
463                update_conformation(conformation, system)
464            )
465            if do_ir_spectrum and model_ir is not None:
466                dyn_state["conformation_ir"] = model_ir.preprocessing.update_skin(
467                    update_conformation_ir(dyn_state["conformation_ir"], system)
468                )
469
470        return conformation, dyn_state
471
472    ################################################
473    ### DEFINE STEP FUNCTION
474    def step(istep, dyn_state, system, conformation, force_preprocess=False):
475
476        dyn_state = {
477            **dyn_state,
478            "istep": dyn_state["istep"] + 1,
479        }
480
481        ### INTEGRATE EQUATIONS OF MOTION
482        system = integrate(system)
483
484        ### UPDATE CONFORMATION AND GRAPHS
485        conformation, dyn_state = update_graphs(
486            istep, dyn_state, system, conformation, force_preprocess
487        )
488
489        ## COMPUTE FORCES AND OBSERVABLES
490        system, out = update_observables(system, conformation)
491
492        ## END OF STEP UPDATES
493        if do_thermostat_post:
494            system["thermostat"], dyn_state["thermostat_post_state"] = thermostat_post(
495                system["thermostat"], dyn_state["thermostat_post_state"]
496            )
497        
498        if do_ir_spectrum:
499            ir_state = update_dipole(
500                dyn_state["ir_spectrum"], system, dyn_state["conformation_ir"]
501            )
502            dyn_state["ir_spectrum"] = ir_post(ir_state)
503
504        return dyn_state, system, conformation, out
505
506    ###########################################################
507
508    print("# Computing initial energy and forces")
509
510    conformation = update_conformation(conformation, system)
511    # initialize IR conformation
512    if do_ir_spectrum and model_ir is not None:
513        dyn_state["preproc_state_ir"], dyn_state["conformation_ir"] = (
514            model_ir.preprocessing(
515                model_ir.preproc_state,
516                update_conformation_ir(conformation, system),
517            )
518        )
519
520    system, _ = update_observables(system, conformation)
521
522    return step, update_conformation, system_data, dyn_state, conformation, system