fennol.md.integrate

  1import time
  2import math
  3import os
  4
  5import numpy as np
  6import jax
  7import jax.numpy as jnp
  8
  9from .thermostats import get_thermostat
 10from .barostats import get_barostat
 11from .colvars import setup_colvars
 12from .spectra import initialize_ir_spectrum
 13
 14from .utils import load_dynamics_restart, get_restart_file,optimize_fire2, us
 15from .initial import load_model, load_system_data, initialize_preprocessing
 16
 17
 18def initialize_dynamics(simulation_parameters, fprec, rng_key):
 19    ### LOAD MODEL
 20    model = load_model(simulation_parameters)
 21    model_energy_unit = us.get_multiplier(model.energy_unit)
 22
 23    ### Get the coordinates and species from the xyz file
 24    system_data, conformation = load_system_data(simulation_parameters, fprec)
 25    system_data["model_energy_unit"] = model_energy_unit
 26    system_data["model_energy_unit_str"] = model.energy_unit
 27
 28    ### FINISH BUILDING conformation
 29    do_restart = os.path.exists(get_restart_file(system_data))
 30    if do_restart:
 31        ### RESTART FROM PREVIOUS DYNAMICS
 32        restart_data = load_dynamics_restart(system_data)
 33        print("# RESTARTING FROM PREVIOUS DYNAMICS")
 34        model.preproc_state = restart_data["preproc_state"]
 35        conformation["coordinates"] = restart_data["coordinates"]
 36    else:
 37        restart_data = {}
 38
 39    ### INITIALIZE PREPROCESSING
 40    preproc_state, conformation = initialize_preprocessing(
 41        simulation_parameters, model, conformation, system_data
 42    )
 43
 44    minimize = simulation_parameters.get("xyz_input/minimize", False)
 45    """@keyword[fennol_md] xyz_input/minimize
 46    Perform energy minimization before dynamics.
 47    Default: False
 48    """
 49    if minimize and not do_restart:
 50        assert system_data["nreplicas"] == 1, "Minimization is only supported for single replica systems"
 51        model.preproc_state = preproc_state
 52        convert = us.KCALPERMOL / model_energy_unit
 53        nat = system_data["nat"]
 54        def energy_force_fn(coordinates):
 55            inputs = {**conformation, "coordinates": coordinates}
 56            e, f, _ = model.energy_and_forces(
 57                **inputs, gpu_preprocessing=True
 58            )
 59            e = float(e[0]) * convert / nat
 60            f = np.array(f) * convert
 61            return e, f
 62        tol = simulation_parameters.get("xyz_input/minimize_ftol", 1e-1/us.KCALPERMOL)*us.KCALPERMOL
 63        """@keyword[fennol_md] xyz_input/minimize_ftol
 64        Force tolerance for minimization.
 65        Default: 0.1 kcal/mol/Å
 66        """
 67        print(f"# Minimizing initial configuration with RMS force tolerance = {tol:.1e} kcal/mol/A")
 68        conformation["coordinates"], success = optimize_fire2(
 69            conformation["coordinates"],
 70            energy_force_fn,
 71            atol=tol,
 72            max_disp=0.02,
 73        )
 74        if success:
 75            print("# Minimization successful")
 76        else:
 77            print("# Warning: Minimization failed, continuing with last configuration")
 78        # write the minimized coordinates as an xyz file
 79        from ..utils.io import write_xyz_frame
 80        with open(system_data["name"]+".opt.xyz", "w") as f:
 81            write_xyz_frame(f, system_data["symbols"],np.array(conformation["coordinates"]),cell=conformation.get("cells", None))
 82        print("# Minimized configuration written to", system_data["name"]+".opt.xyz")
 83        preproc_state = model.preproc_state
 84        conformation = model.preprocessing.process(preproc_state, conformation)
 85        system_data["initial_coordinates"] = np.array(conformation["coordinates"]).copy()
 86
 87    ### get dynamics parameters
 88    dt = simulation_parameters.get("dt")
 89    """@keyword[fennol_md] dt
 90    Integration time step. Required parameter.
 91    Type: float, Required
 92    """
 93    dt2 = 0.5 * dt
 94    mass = system_data["mass"]
 95    densmass = system_data["totmass_Da"] * (us.MOL/us.CM**3)
 96    nat = system_data["nat"]
 97    dtm = jnp.asarray(dt / mass[:, None], dtype=fprec)
 98    ek_avg = 0.5 * nat * system_data["kT"] * np.eye(3)
 99
100    nreplicas = system_data.get("nreplicas", 1)
101    nbeads = system_data.get("nbeads", None)
102    if nbeads is not None:
103        nreplicas = nbeads
104        dtm = dtm[None, :, :]
105
106    ### INITIALIZE DYNAMICS STATE
107    system = {"coordinates": conformation["coordinates"]}
108    dyn_state = {
109        "istep": 0,
110        "dt": dt,
111        "pimd": nbeads is not None,
112        "preproc_state": preproc_state,
113        "start_time_ps": restart_data.get("simulation_time_ps", 0.),
114    }
115    gradient_keys = ["coordinates"]
116    thermo_updates = []
117
118    ### INITIALIZE THERMOSTAT
119    thermostat_rng, rng_key = jax.random.split(rng_key)
120    (
121        thermostat,
122        thermostat_post,
123        thermostat_state,
124        initial_vel,
125        dyn_state["thermostat_name"],
126    ) = get_thermostat(simulation_parameters, dt, system_data, fprec, thermostat_rng,restart_data)
127    do_thermostat_post = thermostat_post is not None
128    if do_thermostat_post:
129        thermostat_post, post_state = thermostat_post
130        dyn_state["thermostat_post_state"] = post_state
131
132    system["thermostat"] = thermostat_state
133    system["vel"] = restart_data.get("vel", initial_vel).astype(fprec)
134
135    ### PBC
136    pbc_data = system_data.get("pbc", None)
137    if pbc_data is not None:
138        ### INITIALIZE BAROSTAT
139        barostat_key, rng_key = jax.random.split(rng_key)
140        thermo_update_ensemble, variable_cell, barostat_state = get_barostat(
141            thermostat, simulation_parameters, dt, system_data, fprec, barostat_key,restart_data
142        )
143        estimate_pressure = variable_cell or pbc_data["estimate_pressure"]
144        system["barostat"] = barostat_state
145        system["cell"] = conformation["cells"][0]
146        if estimate_pressure:
147            pressure_o_weight = simulation_parameters.get("pressure_o_weight", 1.0)
148            """@keyword[fennol_md] pressure_o_weight
149            Weight factor for mixing middle (O) and outer time step kinetic energies in pressure estimator.
150            Default: 1.0
151            """
152            assert (
153                0.0 <= pressure_o_weight <= 1.0
154            ), "pressure_o_weight must be between 0 and 1"
155            gradient_keys.append("strain")
156        print("# Estimate pressure: ", estimate_pressure)
157    else:
158        estimate_pressure = False
159        variable_cell = False
160
161        def thermo_update_ensemble(x, v, system):
162            v, thermostat_state = thermostat(v, system["thermostat"])
163            return x, v, {**system, "thermostat": thermostat_state}
164
165    dyn_state["estimate_pressure"] = estimate_pressure
166    dyn_state["variable_cell"] = variable_cell
167    thermo_updates.append(thermo_update_ensemble)
168
169    if estimate_pressure:
170        use_average_Pkin = simulation_parameters.get("use_average_Pkin", False)
171        """@keyword[fennol_md] use_average_Pkin
172        Use time-averaged kinetic energy for pressure estimation instead of instantaneous values.
173        Default: False
174        """
175        is_qtb = dyn_state["thermostat_name"].endswith("QTB")
176        if is_qtb and use_average_Pkin:
177            raise ValueError(
178                "use_average_Pkin is not compatible with QTB thermostat, please set use_average_Pkin to False"
179            )
180
181
182    ### ENERGY ENSEMBLE
183    ensemble_key = simulation_parameters.get("etot_ensemble_key", None)
184    """@keyword[fennol_md] etot_ensemble_key
185    Key for energy ensemble calculation. Enables computation of ensemble weights.
186    Default: None
187    """
188
189    ### COLVARS
190    colvars_definitions = simulation_parameters.get("colvars", None)
191    """@keyword[fennol_md] colvars
192    Collective variables definitions for enhanced sampling or monitoring.
193    Default: None
194    """
195    use_colvars = colvars_definitions is not None
196    if use_colvars:
197        colvars_calculators, colvars_names = setup_colvars(colvars_definitions)
198        dyn_state["colvars"] = colvars_names
199
200    ### IR SPECTRUM
201    do_ir_spectrum = simulation_parameters.get("ir_spectrum", False)
202    """@keyword[fennol_md] ir_spectrum
203    Calculate infrared spectrum from molecular dipole moment time series.
204    Default: False
205    """
206    assert isinstance(do_ir_spectrum, bool), "ir_spectrum must be a boolean"
207    if do_ir_spectrum:
208        is_qtb = dyn_state["thermostat_name"].endswith("QTB")
209        model_ir, ir_state, save_dipole, ir_post = initialize_ir_spectrum(
210            simulation_parameters, system_data, fprec, dt, is_qtb
211        )
212        dyn_state["ir_spectrum"] = ir_state
213
214    ### BUILD GRADIENT FUNCTION
215    energy_and_gradient = model.get_gradient_function(
216        *gradient_keys, jit=True, variables_as_input=True
217    )
218
219    ### COLLECT THERMO UPDATES
220    if len(thermo_updates) == 1:
221        thermo_update = thermo_updates[0]
222    else:
223
224        def thermo_update(x, v, system):
225            for update in thermo_updates:
226                x, v, system = update(x, v, system)
227            return x, v, system
228
229    ### RING POLYMER INITIALIZATION
230    if nbeads is not None:
231        cay_correction = simulation_parameters.get("cay_correction", True)
232        """@keyword[fennol_md] cay_correction
233        Use Cayley propagator for ring polymer molecular dynamics instead of standard propagation.
234        Default: True
235        """
236        omk = system_data["omk"]
237        eigmat = system_data["eigmat"]
238        cayfact = 1.0 / (4.0 + (dt * omk[1:, None, None]) ** 2) ** 0.5
239        if cay_correction:
240            axx = jnp.asarray(2 * cayfact)
241            axv = jnp.asarray(dt * cayfact)
242            avx = jnp.asarray(-dt * cayfact * omk[1:, None, None] ** 2)
243        else:
244            axx = jnp.asarray(np.cos(omk[1:, None, None] * dt2))
245            axv = jnp.asarray(np.sin(omk[1:, None, None] * dt2) / omk[1:, None, None])
246            avx = jnp.asarray(-omk[1:, None, None] * np.sin(omk[1:, None, None] * dt2))
247
248        coordinates = conformation["coordinates"].reshape(nbeads, -1, 3)
249        eigx = jnp.zeros_like(coordinates).at[0].set(coordinates[0])
250        system["coordinates"] = eigx
251
252    ###############################################
253    ### DEFINE UPDATE FUNCTION
254    @jax.jit
255    def update_conformation(conformation, system):
256        x = system["coordinates"]
257        if nbeads is not None:
258            x = jnp.einsum("in,n...->i...", eigmat, x).reshape(nbeads * nat, 3) * (
259                nbeads**0.5
260            )
261        conformation = {**conformation, "coordinates": x}
262        if variable_cell:
263            conformation["cells"] = system["cell"][None, :, :].repeat(nreplicas, axis=0)
264
265        
266
267        return conformation
268
269    ###############################################
270    ### DEFINE INTEGRATION FUNCTIONS
271    def integrate_A_half(x0, v0):
272        if nbeads is None:
273            return x0 + dt2 * v0, v0
274
275        # update coordinates and velocities of a free ring polymer for a half time step
276        eigx_c = x0[0] + dt2 * v0[0]
277        eigv_c = v0[0]
278        eigx = x0[1:] * axx + v0[1:] * axv
279        eigv = x0[1:] * avx + v0[1:] * axx
280
281        return (
282            jnp.concatenate((eigx_c[None], eigx), axis=0),
283            jnp.concatenate((eigv_c[None], eigv), axis=0),
284        )
285
286    @jax.jit
287    def integrate(system):
288        x = system["coordinates"]
289        v = system["vel"] + dtm * system["forces"]
290        x, v = integrate_A_half(x, v)
291        x, v, system = thermo_update(x, v, system)
292        x, v = integrate_A_half(x, v)
293
294        return {**system, "coordinates": x, "vel": v}
295
296    ###############################################
297    ### DEFINE OBSERVABLE FUNCTION
298    @jax.jit
299    def update_observables(system, conformation):
300        ### POTENTIAL ENERGY AND FORCES
301        epot, de, out = energy_and_gradient(model.variables, conformation)
302        out["forces"] = -de["coordinates"]
303        epot = epot / model_energy_unit
304        de = {k: v / model_energy_unit for k, v in de.items()}
305        forces = -de["coordinates"]
306
307        if nbeads is not None:
308            ### PROJECT FORCES ONTO POLYMER NORMAL MODES
309            forces = jnp.einsum(
310                "in,i...->n...", eigmat, forces.reshape(nbeads, nat, 3)
311            ) * (1.0 / nbeads**0.5)
312
313        system = {
314            **system,
315            "epot": jnp.mean(epot),
316            "forces": forces,
317            "energy_gradients": de,
318        }
319
320        ### KINETIC ENERGY
321        v = system["vel"]
322        if nbeads is None:
323            corr_kin = system["thermostat"].get("corr_kin", 1.0)
324            # ek = 0.5 * jnp.sum(mass[:, None] * v**2) / state_th.get("corr_kin", 1.0)
325            ek = (0.5 / nreplicas / corr_kin) * jnp.sum(
326                mass[:, None, None] * v[:, :, None] * v[:, None, :], axis=0
327            )
328        else:
329            ek_c = 0.5 * jnp.sum(
330                mass[:, None, None] * v[0, :, :, None] * v[0, :, None, :], axis=0
331            )
332            ek = ek_c - 0.5 * jnp.sum(
333                system["coordinates"][1:, :, :, None] * forces[1:, :, None, :],
334                axis=(0, 1),
335            )
336            system["ek_c"] = jnp.trace(ek_c)
337
338        system["ek"] = jnp.trace(ek)
339        system["ek_tensor"] = ek
340
341        if estimate_pressure:
342            if use_average_Pkin:
343                ek = ek_avg
344            elif pressure_o_weight != 1.0:
345                v = system["vel"] + 0.5 * dtm * system["forces"]
346                if nbeads is None:
347                    corr_kin = system["thermostat"].get("corr_kin", 1.0)
348                    # ek = 0.5 * jnp.sum(mass[:, None] * v**2) / state_th.get("corr_kin", 1.0)
349                    ek = (0.5 / nreplicas / corr_kin) * jnp.sum(
350                        mass[:, None, None] * v[:, :, None] * v[:, None, :], axis=0
351                    )
352                else:
353                    ek_c = 0.5 * jnp.sum(
354                        mass[:, None, None] * v[0, :, :, None] * v[0, :, None, :],
355                        axis=0,
356                    )
357                    ek = ek_c - 0.5 * jnp.sum(
358                        system["coordinates"][1:, :, :, None] * forces[1:, :, None, :],
359                        axis=(0, 1),
360                    )
361                b = pressure_o_weight
362                ek = (1.0 - b) * ek + b * system["ek_tensor"]
363
364            vir = jnp.mean(de["strain"], axis=0)
365            system["virial"] = vir
366            out["virial_tensor"] = vir * model_energy_unit
367            
368            volume = jnp.abs(jnp.linalg.det(system["cell"]))
369            Pres =  ek*(2./volume)  - vir/volume
370            system["pressure_tensor"] = Pres
371            system["pressure"] = jnp.trace(Pres) * (1.0 / 3.0)
372            if variable_cell:
373                density = densmass / volume
374                system["density"] = density
375                system["volume"] = volume
376
377        if ensemble_key is not None:
378            kT = system_data["kT"]
379            dE = (
380                jnp.mean(out[ensemble_key], axis=0) / model_energy_unit - system["epot"]
381            )
382            system["ensemble_weights"] = -dE / kT
383
384        if "total_dipole" in out:
385            if nbeads is None:
386                system["total_dipole"] = out["total_dipole"][0]
387            else:
388                system["total_dipole"] = jnp.mean(out["total_dipole"], axis=0)
389
390        if use_colvars:
391            coords = system["coordinates"].reshape(-1, nat, 3)[0]
392            colvars = {}
393            for colvar_name, colvar_calc in colvars_calculators.items():
394                colvars[colvar_name] = colvar_calc(coords)
395            system["colvars"] = colvars
396
397        return system, out
398
399    ###############################################
400    ### IR SPECTRUM
401    if do_ir_spectrum:
402        # @jax.jit
403        # def update_dipole(ir_state,system,conformation):
404        #     def mumodel(coords):
405        #         out = model_ir._apply(model_ir.variables,{**conformation,"coordinates":coords})
406        #         if nbeads is None:
407        #             return out["total_dipole"][0]
408        #         return out["total_dipole"].sum(axis=0)
409        #     dmudqmodel = jax.jacobian(mumodel)
410
411        #     dmudq = dmudqmodel(conformation["coordinates"])
412        #     # print(dmudq.shape)
413        #     if nbeads is None:
414        #         vel = system["vel"].reshape(-1,1,nat,3)[0]
415        #         mudot = (vel*dmudq).sum(axis=(1,2))
416        #     else:
417        #         dmudq = dmudq.reshape(3,nbeads,nat,3)#.mean(axis=1)
418        #         vel = (jnp.einsum("in,n...->i...", eigmat, system["vel"]) *  nbeads**0.5
419        #         )
420        #         # vel = system["vel"][0].reshape(1,nat,3)
421        #         mudot = (vel[None,...]*dmudq).sum(axis=(1,2,3))/nbeads
422
423        #     ir_state = save_dipole(mudot,ir_state)
424        #     return ir_state
425        @jax.jit
426        def update_conformation_ir(conformation, system):
427            conformation = {
428                **conformation,
429                "coordinates": system["coordinates"].reshape(-1, nat, 3)[0],
430                "natoms": jnp.asarray([nat]),
431                "batch_index": jnp.asarray([0] * nat),
432                "species": jnp.asarray(system_data["species"].reshape(-1, nat)[0]),
433            }
434            if variable_cell:
435                conformation["cells"] = system["cell"][None, :, :]
436                conformation["reciprocal_cells"] = jnp.linalg.inv(system["cell"])[
437                    None, :, :
438                ]
439            return conformation
440
441        @jax.jit
442        def update_dipole(ir_state, system, conformation):
443            if model_ir is not None:
444                out = model_ir._apply(model_ir.variables, conformation)
445                q = out.get("charges", jnp.zeros(nat)).reshape((-1, nat))
446                dip = out.get("dipoles", jnp.zeros((nat, 3))).reshape((-1, nat, 3))
447            else:
448                q = system.get("charges", jnp.zeros(nat)).reshape((-1, nat))
449                dip = system.get("dipoles", jnp.zeros((nat, 3))).reshape((-1, nat, 3))
450            if nbeads is not None:
451                q = jnp.mean(q, axis=0)
452                dip = jnp.mean(dip, axis=0)
453                vel = system["vel"][0]
454                pos = system["coordinates"][0]
455            else:
456                q = q[0]
457                dip = dip[0]
458                vel = system["vel"].reshape(-1, nat, 3)[0]
459                pos = system["coordinates"].reshape(-1, nat, 3)[0]
460
461            if pbc_data is not None:
462                cell_reciprocal = (
463                    conformation["cells"][0],
464                    conformation["reciprocal_cells"][0],
465                )
466            else:
467                cell_reciprocal = None
468
469            ir_state = save_dipole(
470                q, vel, pos, dip.sum(axis=0), cell_reciprocal, ir_state
471            )
472            return ir_state
473
474    ###############################################
475    ### GRAPH UPDATES
476
477    nblist_verbose = simulation_parameters.get("nblist_verbose", False)
478    """@keyword[fennol_md] nblist_verbose
479    Print verbose information about neighbor list updates and reallocations.
480    Default: False
481    """
482    nblist_stride = int(simulation_parameters.get("nblist_stride", -1))
483    """@keyword[fennol_md] nblist_stride
484    Number of steps between full neighbor list rebuilds. Auto-calculated from skin if <= 0.
485    Default: -1
486    """
487    nblist_warmup_time = simulation_parameters.get("nblist_warmup_time", -1.0)
488    """@keyword[fennol_md] nblist_warmup_time
489    Time period for neighbor list warmup before using skin updates.
490    Default: -1.0
491    """
492    nblist_warmup = int(nblist_warmup_time / dt) if nblist_warmup_time > 0 else 0
493    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
494    """@keyword[fennol_md] nblist_skin
495    Neighbor list skin distance for efficient updates (in Angstroms).
496    Default: -1.0
497    """
498    if nblist_skin > 0:
499        if nblist_stride <= 0:
500            ## reference skin parameters at 300K (from Tinker-HP)
501            ##   => skin of 2 A gives you 40 fs without complete rebuild
502            t_ref = 40.0 /us.FS # FS
503            nblist_skin_ref = 2.0  # A
504            nblist_stride = int(math.floor(nblist_skin / nblist_skin_ref * t_ref / dt))
505        print(
506            f"# nblist_skin: {nblist_skin:.2f} A, nblist_stride: {nblist_stride} steps, nblist_warmup: {nblist_warmup} steps"
507        )
508
509    if nblist_skin <= 0:
510        nblist_stride = 1
511
512    dyn_state["nblist_countdown"] = 0
513    dyn_state["print_skin_activation"] = nblist_warmup > 0
514
515    def update_graphs(istep, dyn_state, system, conformation, force_preprocess=False):
516        nblist_countdown = dyn_state["nblist_countdown"]
517        if nblist_countdown <= 0 or force_preprocess or (istep < nblist_warmup):
518            ### FULL NBLIST REBUILD
519            dyn_state["nblist_countdown"] = nblist_stride - 1
520            preproc_state = dyn_state["preproc_state"]
521            conformation = model.preprocessing.process(
522                preproc_state, update_conformation(conformation, system)
523            )
524            preproc_state, state_up, conformation, overflow = (
525                model.preprocessing.check_reallocate(preproc_state, conformation)
526            )
527            dyn_state["preproc_state"] = preproc_state
528            if nblist_verbose and overflow:
529                print("step", istep, ", nblist overflow => reallocating nblist")
530                print("size updates:", state_up)
531
532            if do_ir_spectrum and model_ir is not None:
533                conformation_ir = model_ir.preprocessing.process(
534                    dyn_state["preproc_state_ir"],
535                    update_conformation_ir(dyn_state["conformation_ir"], system),
536                )
537                (
538                    dyn_state["preproc_state_ir"],
539                    _,
540                    dyn_state["conformation_ir"],
541                    overflow,
542                ) = model_ir.preprocessing.check_reallocate(
543                    dyn_state["preproc_state_ir"], conformation_ir
544                )
545
546        else:
547            ### SKIN UPDATE
548            if dyn_state["print_skin_activation"]:
549                if nblist_verbose:
550                    print(
551                        "step",
552                        istep,
553                        ", end of nblist warmup phase => activating skin updates",
554                    )
555                dyn_state["print_skin_activation"] = False
556
557            dyn_state["nblist_countdown"] = nblist_countdown - 1
558            conformation = model.preprocessing.update_skin(
559                update_conformation(conformation, system)
560            )
561            if do_ir_spectrum and model_ir is not None:
562                dyn_state["conformation_ir"] = model_ir.preprocessing.update_skin(
563                    update_conformation_ir(dyn_state["conformation_ir"], system)
564                )
565
566        return conformation, dyn_state
567
568    ################################################
569    ### DEFINE STEP FUNCTION
570    def step(istep, dyn_state, system, conformation, force_preprocess=False):
571
572        dyn_state = {
573            **dyn_state,
574            "istep": dyn_state["istep"] + 1,
575        }
576
577        ### INTEGRATE EQUATIONS OF MOTION
578        system = integrate(system)
579
580        ### UPDATE CONFORMATION AND GRAPHS
581        conformation, dyn_state = update_graphs(
582            istep, dyn_state, system, conformation, force_preprocess
583        )
584
585        ## COMPUTE FORCES AND OBSERVABLES
586        system, out = update_observables(system, conformation)
587
588        ## END OF STEP UPDATES
589        if do_thermostat_post:
590            system["thermostat"], dyn_state["thermostat_post_state"] = thermostat_post(
591                system["thermostat"], dyn_state["thermostat_post_state"]
592            )
593        
594        if do_ir_spectrum:
595            ir_state = update_dipole(
596                dyn_state["ir_spectrum"], system, dyn_state["conformation_ir"]
597            )
598            dyn_state["ir_spectrum"] = ir_post(ir_state)
599
600        return dyn_state, system, conformation, out
601
602    ###########################################################
603
604    print("# Computing initial energy and forces")
605
606    conformation = update_conformation(conformation, system)
607    # initialize IR conformation
608    if do_ir_spectrum and model_ir is not None:
609        dyn_state["preproc_state_ir"], dyn_state["conformation_ir"] = (
610            model_ir.preprocessing(
611                model_ir.preproc_state,
612                update_conformation_ir(conformation, system),
613            )
614        )
615
616    system, _ = update_observables(system, conformation)
617
618    return step, update_conformation, system_data, dyn_state, conformation, system
def initialize_dynamics(simulation_parameters, fprec, rng_key):
    """Build the MD / PIMD integrator and compute the initial dynamical state.

    Loads the model and the system, restarts from a previous run when a
    restart file exists (otherwise optionally minimizes the starting
    geometry), sets up the thermostat, the optional barostat / pressure
    estimator, collective variables, IR-spectrum accumulation and the
    neighbor-list update policy, then evaluates the initial energy and forces.

    Each call to the returned ``step`` function applies a velocity kick
    ``v += dt/m * F``, a half-step drift, the thermostat/barostat update, a
    second half-step drift, then rebuilds or skin-updates the neighbor lists
    and recomputes forces and observables.

    Parameters
    ----------
    simulation_parameters
        Parameter container queried with ``.get(key, default)``.
    fprec
        Floating-point dtype used for integrator arrays (``dtm``, velocities).
    rng_key
        JAX PRNG key. It is split for the thermostat and, with PBC, for the
        barostat.

    Returns
    -------
    tuple
        ``(step, update_conformation, system_data, dyn_state, conformation,
        system)``.

        - ``step(istep, dyn_state, system, conformation,
          force_preprocess=False)`` returns ``(dyn_state, system,
          conformation, out)``.
        - ``update_conformation`` is the jitted function that writes the
          dynamical coordinates (and the cell, for variable-cell runs) into a
          conformation dict.
        - ``dyn_state``, ``conformation`` and ``system`` are the initial state
          after the first force evaluation.
    """
    ### LOAD MODEL
    model = load_model(simulation_parameters)
    model_energy_unit = us.get_multiplier(model.energy_unit)

    ### Get the coordinates and species from the xyz file
    system_data, conformation = load_system_data(simulation_parameters, fprec)
    system_data["model_energy_unit"] = model_energy_unit
    system_data["model_energy_unit_str"] = model.energy_unit

    ### FINISH BUILDING conformation
    do_restart = os.path.exists(get_restart_file(system_data))
    if do_restart:
        ### RESTART FROM PREVIOUS DYNAMICS
        restart_data = load_dynamics_restart(system_data)
        print("# RESTARTING FROM PREVIOUS DYNAMICS")
        model.preproc_state = restart_data["preproc_state"]
        conformation["coordinates"] = restart_data["coordinates"]
    else:
        restart_data = {}

    ### INITIALIZE PREPROCESSING
    preproc_state, conformation = initialize_preprocessing(
        simulation_parameters, model, conformation, system_data
    )

    minimize = simulation_parameters.get("xyz_input/minimize", False)
    """@keyword[fennol_md] xyz_input/minimize
    Perform energy minimization before dynamics.
    Default: False
    """
    if minimize and not do_restart:
        assert system_data["nreplicas"] == 1, "Minimization is only supported for single replica systems"
        model.preproc_state = preproc_state
        # rescale model-unit energies/forces so they are comparable with the
        # kcal/mol/A tolerance used below
        convert = us.KCALPERMOL / model_energy_unit
        nat = system_data["nat"]

        def energy_force_fn(coordinates):
            # Optimizer callback: returns (energy per atom, forces) after the
            # unit conversion above.
            inputs = {**conformation, "coordinates": coordinates}
            e, f, _ = model.energy_and_forces(
                **inputs, gpu_preprocessing=True
            )
            e = float(e[0]) * convert / nat
            f = np.array(f) * convert
            return e, f

        tol = simulation_parameters.get("xyz_input/minimize_ftol", 1e-1/us.KCALPERMOL)*us.KCALPERMOL
        """@keyword[fennol_md] xyz_input/minimize_ftol
        Force tolerance for minimization.
        Default: 0.1 kcal/mol/Å
        """
        print(f"# Minimizing initial configuration with RMS force tolerance = {tol:.1e} kcal/mol/A")
        conformation["coordinates"], success = optimize_fire2(
            conformation["coordinates"],
            energy_force_fn,
            atol=tol,
            max_disp=0.02,
        )
        if success:
            print("# Minimization successful")
        else:
            print("# Warning: Minimization failed, continuing with last configuration")
        # write the minimized coordinates as an xyz file
        from ..utils.io import write_xyz_frame
        with open(system_data["name"]+".opt.xyz", "w") as f:
            write_xyz_frame(f, system_data["symbols"],np.array(conformation["coordinates"]),cell=conformation.get("cells", None))
        print("# Minimized configuration written to", system_data["name"]+".opt.xyz")
        # re-run preprocessing so graphs match the minimized geometry
        preproc_state = model.preproc_state
        conformation = model.preprocessing.process(preproc_state, conformation)
        system_data["initial_coordinates"] = np.array(conformation["coordinates"]).copy()

    ### get dynamics parameters
    dt = simulation_parameters.get("dt")
    """@keyword[fennol_md] dt
    Integration time step. Required parameter.
    Type: float, Required
    """
    dt2 = 0.5 * dt
    mass = system_data["mass"]
    densmass = system_data["totmass_Da"] * (us.MOL/us.CM**3)
    nat = system_data["nat"]
    # per-atom dt/m, broadcast over the xyz components
    dtm = jnp.asarray(dt / mass[:, None], dtype=fprec)
    # kinetic tensor used when use_average_Pkin is set: 0.5 * nat * kT * I
    ek_avg = 0.5 * nat * system_data["kT"] * np.eye(3)

    nreplicas = system_data.get("nreplicas", 1)
    nbeads = system_data.get("nbeads", None)
    if nbeads is not None:
        # PIMD: replicas are ring-polymer beads; dtm gains a leading mode axis
        nreplicas = nbeads
        dtm = dtm[None, :, :]

    ### INITIALIZE DYNAMICS STATE
    system = {"coordinates": conformation["coordinates"]}
    dyn_state = {
        "istep": 0,
        "dt": dt,
        "pimd": nbeads is not None,
        "preproc_state": preproc_state,
        "start_time_ps": restart_data.get("simulation_time_ps", 0.),
    }
    gradient_keys = ["coordinates"]
    thermo_updates = []

    ### INITIALIZE THERMOSTAT
    thermostat_rng, rng_key = jax.random.split(rng_key)
    (
        thermostat,
        thermostat_post,
        thermostat_state,
        initial_vel,
        dyn_state["thermostat_name"],
    ) = get_thermostat(simulation_parameters, dt, system_data, fprec, thermostat_rng,restart_data)
    do_thermostat_post = thermostat_post is not None
    if do_thermostat_post:
        thermostat_post, post_state = thermostat_post
        dyn_state["thermostat_post_state"] = post_state

    system["thermostat"] = thermostat_state
    system["vel"] = restart_data.get("vel", initial_vel).astype(fprec)

    ### PBC
    pbc_data = system_data.get("pbc", None)
    if pbc_data is not None:
        ### INITIALIZE BAROSTAT
        barostat_key, rng_key = jax.random.split(rng_key)
        thermo_update_ensemble, variable_cell, barostat_state = get_barostat(
            thermostat, simulation_parameters, dt, system_data, fprec, barostat_key,restart_data
        )
        estimate_pressure = variable_cell or pbc_data["estimate_pressure"]
        system["barostat"] = barostat_state
        system["cell"] = conformation["cells"][0]
        if estimate_pressure:
            pressure_o_weight = simulation_parameters.get("pressure_o_weight", 1.0)
            """@keyword[fennol_md] pressure_o_weight
            Weight factor for mixing middle (O) and outer time step kinetic energies in pressure estimator.
            Default: 1.0
            """
            assert (
                0.0 <= pressure_o_weight <= 1.0
            ), "pressure_o_weight must be between 0 and 1"
            # the virial is obtained as the gradient with respect to "strain"
            gradient_keys.append("strain")
        print("# Estimate pressure: ", estimate_pressure)
    else:
        estimate_pressure = False
        variable_cell = False

        def thermo_update_ensemble(x, v, system):
            # no barostat: only the thermostat acts on the velocities
            v, thermostat_state = thermostat(v, system["thermostat"])
            return x, v, {**system, "thermostat": thermostat_state}

    dyn_state["estimate_pressure"] = estimate_pressure
    dyn_state["variable_cell"] = variable_cell
    thermo_updates.append(thermo_update_ensemble)

    if estimate_pressure:
        use_average_Pkin = simulation_parameters.get("use_average_Pkin", False)
        """@keyword[fennol_md] use_average_Pkin
        Use time-averaged kinetic energy for pressure estimation instead of instantaneous values.
        Default: False
        """
        is_qtb = dyn_state["thermostat_name"].endswith("QTB")
        if is_qtb and use_average_Pkin:
            raise ValueError(
                "use_average_Pkin is not compatible with QTB thermostat, please set use_average_Pkin to False"
            )


    ### ENERGY ENSEMBLE
    ensemble_key = simulation_parameters.get("etot_ensemble_key", None)
    """@keyword[fennol_md] etot_ensemble_key
    Key for energy ensemble calculation. Enables computation of ensemble weights.
    Default: None
    """

    ### COLVARS
    colvars_definitions = simulation_parameters.get("colvars", None)
    """@keyword[fennol_md] colvars
    Collective variables definitions for enhanced sampling or monitoring.
    Default: None
    """
    use_colvars = colvars_definitions is not None
    if use_colvars:
        colvars_calculators, colvars_names = setup_colvars(colvars_definitions)
        dyn_state["colvars"] = colvars_names

    ### IR SPECTRUM
    do_ir_spectrum = simulation_parameters.get("ir_spectrum", False)
    """@keyword[fennol_md] ir_spectrum
    Calculate infrared spectrum from molecular dipole moment time series.
    Default: False
    """
    assert isinstance(do_ir_spectrum, bool), "ir_spectrum must be a boolean"
    if do_ir_spectrum:
        is_qtb = dyn_state["thermostat_name"].endswith("QTB")
        model_ir, ir_state, save_dipole, ir_post = initialize_ir_spectrum(
            simulation_parameters, system_data, fprec, dt, is_qtb
        )
        dyn_state["ir_spectrum"] = ir_state

    ### BUILD GRADIENT FUNCTION
    energy_and_gradient = model.get_gradient_function(
        *gradient_keys, jit=True, variables_as_input=True
    )

    ### COLLECT THERMO UPDATES
    if len(thermo_updates) == 1:
        thermo_update = thermo_updates[0]
    else:

        def thermo_update(x, v, system):
            # apply every registered update in order
            for update in thermo_updates:
                x, v, system = update(x, v, system)
            return x, v, system

    ### RING POLYMER INITIALIZATION
    if nbeads is not None:
        cay_correction = simulation_parameters.get("cay_correction", True)
        """@keyword[fennol_md] cay_correction
        Use Cayley propagator for ring polymer molecular dynamics instead of standard propagation.
        Default: True
        """
        omk = system_data["omk"]
        eigmat = system_data["eigmat"]
        # propagator coefficients for the non-centroid modes (k >= 1)
        cayfact = 1.0 / (4.0 + (dt * omk[1:, None, None]) ** 2) ** 0.5
        if cay_correction:
            axx = jnp.asarray(2 * cayfact)
            axv = jnp.asarray(dt * cayfact)
            avx = jnp.asarray(-dt * cayfact * omk[1:, None, None] ** 2)
        else:
            axx = jnp.asarray(np.cos(omk[1:, None, None] * dt2))
            axv = jnp.asarray(np.sin(omk[1:, None, None] * dt2) / omk[1:, None, None])
            avx = jnp.asarray(-omk[1:, None, None] * np.sin(omk[1:, None, None] * dt2))

        # start from a collapsed polymer: only the centroid mode holds the
        # first bead's coordinates, all other modes are zero
        coordinates = conformation["coordinates"].reshape(nbeads, -1, 3)
        eigx = jnp.zeros_like(coordinates).at[0].set(coordinates[0])
        system["coordinates"] = eigx

    ###############################################
    ### DEFINE UPDATE FUNCTION
    @jax.jit
    def update_conformation(conformation, system):
        """Write the current coordinates (and cell if variable) into conformation."""
        x = system["coordinates"]
        if nbeads is not None:
            # normal modes -> bead Cartesian coordinates
            x = jnp.einsum("in,n...->i...", eigmat, x).reshape(nbeads * nat, 3) * (
                nbeads**0.5
            )
        conformation = {**conformation, "coordinates": x}
        if variable_cell:
            conformation["cells"] = system["cell"][None, :, :].repeat(nreplicas, axis=0)

        return conformation

    ###############################################
    ### DEFINE INTEGRATION FUNCTIONS
    def integrate_A_half(x0, v0):
        """Free propagation (drift) over half a time step."""
        if nbeads is None:
            return x0 + dt2 * v0, v0

        # update coordinates and velocities of a free ring polymer for a half time step
        eigx_c = x0[0] + dt2 * v0[0]
        eigv_c = v0[0]
        eigx = x0[1:] * axx + v0[1:] * axv
        eigv = x0[1:] * avx + v0[1:] * axx

        return (
            jnp.concatenate((eigx_c[None], eigx), axis=0),
            jnp.concatenate((eigv_c[None], eigv), axis=0),
        )

    @jax.jit
    def integrate(system):
        """Kick (v += dt/m F), half drift, thermo/baro update, half drift."""
        x = system["coordinates"]
        v = system["vel"] + dtm * system["forces"]
        x, v = integrate_A_half(x, v)
        x, v, system = thermo_update(x, v, system)
        x, v = integrate_A_half(x, v)

        return {**system, "coordinates": x, "vel": v}

    ###############################################
    ### DEFINE OBSERVABLE FUNCTION
    def kinetic_energy_tensor(v, system, forces):
        """Return ``(ek, ek_c)``: the 3x3 kinetic energy tensor for velocities ``v``.

        Classical MD: ``ek`` is the mass-weighted outer product of ``v`` summed
        over atoms, divided by ``nreplicas`` and by the thermostat's
        ``corr_kin`` factor (1.0 when absent); ``ek_c`` is None.
        PIMD: ``ek_c`` is the same tensor for the centroid mode (index 0), and
        ``ek`` subtracts half the sum over non-centroid modes of the outer
        product of coordinates and (normal-mode) forces.
        """
        if nbeads is None:
            corr_kin = system["thermostat"].get("corr_kin", 1.0)
            ek = (0.5 / nreplicas / corr_kin) * jnp.sum(
                mass[:, None, None] * v[:, :, None] * v[:, None, :], axis=0
            )
            return ek, None
        ek_c = 0.5 * jnp.sum(
            mass[:, None, None] * v[0, :, :, None] * v[0, :, None, :], axis=0
        )
        ek = ek_c - 0.5 * jnp.sum(
            system["coordinates"][1:, :, :, None] * forces[1:, :, None, :],
            axis=(0, 1),
        )
        return ek, ek_c

    @jax.jit
    def update_observables(system, conformation):
        """Compute energy, forces and derived observables for the current state.

        Returns the updated ``system`` dict and the raw model output ``out``
        (``out["forces"]`` and ``out["virial_tensor"]`` in model energy units).
        """
        ### POTENTIAL ENERGY AND FORCES
        epot, de, out = energy_and_gradient(model.variables, conformation)
        out["forces"] = -de["coordinates"]
        # convert model energy units to internal units
        epot = epot / model_energy_unit
        de = {k: v / model_energy_unit for k, v in de.items()}
        forces = -de["coordinates"]

        if nbeads is not None:
            ### PROJECT FORCES ONTO POLYMER NORMAL MODES
            forces = jnp.einsum(
                "in,i...->n...", eigmat, forces.reshape(nbeads, nat, 3)
            ) * (1.0 / nbeads**0.5)

        system = {
            **system,
            "epot": jnp.mean(epot),
            "forces": forces,
            "energy_gradients": de,
        }

        ### KINETIC ENERGY
        ek, ek_c = kinetic_energy_tensor(system["vel"], system, forces)
        if ek_c is not None:
            system["ek_c"] = jnp.trace(ek_c)

        system["ek"] = jnp.trace(ek)
        system["ek_tensor"] = ek

        if estimate_pressure:
            if use_average_Pkin:
                ek = ek_avg
            elif pressure_o_weight != 1.0:
                # kinetic tensor for velocities advanced by half a kick, mixed
                # with the stored one according to pressure_o_weight
                v_kick = system["vel"] + 0.5 * dtm * system["forces"]
                ek_kick, _ = kinetic_energy_tensor(v_kick, system, forces)
                b = pressure_o_weight
                ek = (1.0 - b) * ek_kick + b * system["ek_tensor"]

            vir = jnp.mean(de["strain"], axis=0)
            system["virial"] = vir
            out["virial_tensor"] = vir * model_energy_unit

            volume = jnp.abs(jnp.linalg.det(system["cell"]))
            Pres =  ek*(2./volume)  - vir/volume
            system["pressure_tensor"] = Pres
            system["pressure"] = jnp.trace(Pres) * (1.0 / 3.0)
            if variable_cell:
                density = densmass / volume
                system["density"] = density
                system["volume"] = volume

        if ensemble_key is not None:
            kT = system_data["kT"]
            dE = (
                jnp.mean(out[ensemble_key], axis=0) / model_energy_unit - system["epot"]
            )
            system["ensemble_weights"] = -dE / kT

        if "total_dipole" in out:
            if nbeads is None:
                system["total_dipole"] = out["total_dipole"][0]
            else:
                system["total_dipole"] = jnp.mean(out["total_dipole"], axis=0)

        if use_colvars:
            # colvars are evaluated on the first replica / centroid mode only
            coords = system["coordinates"].reshape(-1, nat, 3)[0]
            colvars = {}
            for colvar_name, colvar_calc in colvars_calculators.items():
                colvars[colvar_name] = colvar_calc(coords)
            system["colvars"] = colvars

        return system, out

    ###############################################
    ### IR SPECTRUM
    if do_ir_spectrum:

        @jax.jit
        def update_conformation_ir(conformation, system):
            """Build a single-replica conformation for the IR model."""
            conformation = {
                **conformation,
                "coordinates": system["coordinates"].reshape(-1, nat, 3)[0],
                "natoms": jnp.asarray([nat]),
                "batch_index": jnp.asarray([0] * nat),
                "species": jnp.asarray(system_data["species"].reshape(-1, nat)[0]),
            }
            if variable_cell:
                conformation["cells"] = system["cell"][None, :, :]
                conformation["reciprocal_cells"] = jnp.linalg.inv(system["cell"])[
                    None, :, :
                ]
            return conformation

        @jax.jit
        def update_dipole(ir_state, system, conformation):
            """Accumulate charges/dipoles/velocities into the IR spectrum state."""
            if model_ir is not None:
                out = model_ir._apply(model_ir.variables, conformation)
                q = out.get("charges", jnp.zeros(nat)).reshape((-1, nat))
                dip = out.get("dipoles", jnp.zeros((nat, 3))).reshape((-1, nat, 3))
            else:
                # NOTE(review): update_observables above never stores "charges"
                # or "dipoles" in system, so these fall back to zeros unless
                # another component adds them -- confirm this is intended.
                q = system.get("charges", jnp.zeros(nat)).reshape((-1, nat))
                dip = system.get("dipoles", jnp.zeros((nat, 3))).reshape((-1, nat, 3))
            if nbeads is not None:
                q = jnp.mean(q, axis=0)
                dip = jnp.mean(dip, axis=0)
                vel = system["vel"][0]
                pos = system["coordinates"][0]
            else:
                q = q[0]
                dip = dip[0]
                vel = system["vel"].reshape(-1, nat, 3)[0]
                pos = system["coordinates"].reshape(-1, nat, 3)[0]

            if pbc_data is not None:
                cell_reciprocal = (
                    conformation["cells"][0],
                    conformation["reciprocal_cells"][0],
                )
            else:
                cell_reciprocal = None

            ir_state = save_dipole(
                q, vel, pos, dip.sum(axis=0), cell_reciprocal, ir_state
            )
            return ir_state

    ###############################################
    ### GRAPH UPDATES

    nblist_verbose = simulation_parameters.get("nblist_verbose", False)
    """@keyword[fennol_md] nblist_verbose
    Print verbose information about neighbor list updates and reallocations.
    Default: False
    """
    nblist_stride = int(simulation_parameters.get("nblist_stride", -1))
    """@keyword[fennol_md] nblist_stride
    Number of steps between full neighbor list rebuilds. Auto-calculated from skin if <= 0.
    Default: -1
    """
    nblist_warmup_time = simulation_parameters.get("nblist_warmup_time", -1.0)
    """@keyword[fennol_md] nblist_warmup_time
    Time period for neighbor list warmup before using skin updates.
    Default: -1.0
    """
    nblist_warmup = int(nblist_warmup_time / dt) if nblist_warmup_time > 0 else 0
    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
    """@keyword[fennol_md] nblist_skin
    Neighbor list skin distance for efficient updates (in Angstroms).
    Default: -1.0
    """
    if nblist_skin > 0:
        if nblist_stride <= 0:
            ## reference skin parameters at 300K (from Tinker-HP)
            ##   => skin of 2 A gives you 40 fs without complete rebuild
            t_ref = 40.0 /us.FS # 40 fs
            nblist_skin_ref = 2.0  # A
            # at least 1: a zero stride would also rebuild every step, but
            # would be reported misleadingly as "0 steps"
            nblist_stride = max(
                1, int(math.floor(nblist_skin / nblist_skin_ref * t_ref / dt))
            )
        print(
            f"# nblist_skin: {nblist_skin:.2f} A, nblist_stride: {nblist_stride} steps, nblist_warmup: {nblist_warmup} steps"
        )

    if nblist_skin <= 0:
        nblist_stride = 1

    dyn_state["nblist_countdown"] = 0
    dyn_state["print_skin_activation"] = nblist_warmup > 0

    def update_graphs(istep, dyn_state, system, conformation, force_preprocess=False):
        """Rebuild or skin-update the neighbor lists; mutates dyn_state in place."""
        nblist_countdown = dyn_state["nblist_countdown"]
        if nblist_countdown <= 0 or force_preprocess or (istep < nblist_warmup):
            ### FULL NBLIST REBUILD
            dyn_state["nblist_countdown"] = nblist_stride - 1
            preproc_state = dyn_state["preproc_state"]
            conformation = model.preprocessing.process(
                preproc_state, update_conformation(conformation, system)
            )
            preproc_state, state_up, conformation, overflow = (
                model.preprocessing.check_reallocate(preproc_state, conformation)
            )
            dyn_state["preproc_state"] = preproc_state
            if nblist_verbose and overflow:
                print("step", istep, ", nblist overflow => reallocating nblist")
                print("size updates:", state_up)

            if do_ir_spectrum and model_ir is not None:
                conformation_ir = model_ir.preprocessing.process(
                    dyn_state["preproc_state_ir"],
                    update_conformation_ir(dyn_state["conformation_ir"], system),
                )
                (
                    dyn_state["preproc_state_ir"],
                    _,
                    dyn_state["conformation_ir"],
                    _,
                ) = model_ir.preprocessing.check_reallocate(
                    dyn_state["preproc_state_ir"], conformation_ir
                )

        else:
            ### SKIN UPDATE
            if dyn_state["print_skin_activation"]:
                if nblist_verbose:
                    print(
                        "step",
                        istep,
                        ", end of nblist warmup phase => activating skin updates",
                    )
                dyn_state["print_skin_activation"] = False

            dyn_state["nblist_countdown"] = nblist_countdown - 1
            conformation = model.preprocessing.update_skin(
                update_conformation(conformation, system)
            )
            if do_ir_spectrum and model_ir is not None:
                dyn_state["conformation_ir"] = model_ir.preprocessing.update_skin(
                    update_conformation_ir(dyn_state["conformation_ir"], system)
                )

        return conformation, dyn_state

    ################################################
    ### DEFINE STEP FUNCTION
    def step(istep, dyn_state, system, conformation, force_preprocess=False):
        """Advance the simulation by one time step.

        Returns ``(dyn_state, system, conformation, out)`` where ``out`` is the
        raw model output of the new force evaluation.
        """

        # shallow copy so the caller's dict is not mutated by update_graphs
        dyn_state = {
            **dyn_state,
            "istep": dyn_state["istep"] + 1,
        }

        ### INTEGRATE EQUATIONS OF MOTION
        system = integrate(system)

        ### UPDATE CONFORMATION AND GRAPHS
        conformation, dyn_state = update_graphs(
            istep, dyn_state, system, conformation, force_preprocess
        )

        ## COMPUTE FORCES AND OBSERVABLES
        system, out = update_observables(system, conformation)

        ## END OF STEP UPDATES
        if do_thermostat_post:
            system["thermostat"], dyn_state["thermostat_post_state"] = thermostat_post(
                system["thermostat"], dyn_state["thermostat_post_state"]
            )

        if do_ir_spectrum:
            # "conformation_ir" only exists when a separate IR model is used;
            # otherwise update_dipole only needs cell data from the main
            # conformation (assumes it carries "reciprocal_cells" under PBC).
            conformation_ir = dyn_state.get("conformation_ir", conformation)
            ir_state = update_dipole(
                dyn_state["ir_spectrum"], system, conformation_ir
            )
            dyn_state["ir_spectrum"] = ir_post(ir_state)

        return dyn_state, system, conformation, out

    ###########################################################

    print("# Computing initial energy and forces")

    conformation = update_conformation(conformation, system)
    # initialize IR conformation
    if do_ir_spectrum and model_ir is not None:
        dyn_state["preproc_state_ir"], dyn_state["conformation_ir"] = (
            model_ir.preprocessing(
                model_ir.preproc_state,
                update_conformation_ir(conformation, system),
            )
        )

    system, _ = update_observables(system, conformation)

    return step, update_conformation, system_data, dyn_state, conformation, system