fennol.md.utils

  1import numpy as np
  2import jax
  3import jax.numpy as jnp
  4import pickle
  5from ..utils import UnitSystem
  6from functools import partial
  7
# Unit system for fennol_md: lengths in Angstrom (L), time in picoseconds (T)
# and energies in kcal/mol (E). Module-level helpers below use it for unit
# conversions (e.g. ``us.PS``, ``us.KBAR``, ``us.get_multiplier``).
us = UnitSystem(L="ANGSTROM", T="PS", E="KCALPERMOL")
 10
 11def get_restart_file(system_data):
 12    restart_file = system_data["name"] + ".dyn.restart"
 13    return restart_file
 14
 15
 16def load_dynamics_restart(system_data):
 17    with open(get_restart_file(system_data), "rb") as f:
 18        restart_data = pickle.load(f)
 19
 20    assert (
 21        restart_data["nat"] == system_data["nat"]
 22    ), f"Restart file does not match system data 'nat'"
 23    assert restart_data.get("nbeads", 0) == system_data.get(
 24        "nbeads", 0
 25    ), f"Restart file does not match system data 'nbeads'"
 26    assert restart_data.get("nreplicas", 0) == system_data.get(
 27        "nreplicas", 0
 28    ), f"Restart file does not match system data 'nreplicas'"
 29    assert np.all(
 30        restart_data["species"] == system_data["species"]
 31    ), f"Restart file does not match system data 'species'"
 32
 33    return restart_data
 34
 35
 36def save_dynamics_restart(system_data, conformation, dyn_state, system):
 37    restart_data = {
 38        "nat": system_data["nat"],
 39        "nbeads": system_data.get("nbeads", 0),
 40        "nreplicas": system_data.get("nreplicas", 0),
 41        "species": system_data["species"],
 42        "coordinates": conformation["coordinates"],
 43        "vel": system["vel"],
 44        "preproc_state": dyn_state["preproc_state"],
 45        "simulation_time_ps": dyn_state["start_time_ps"] + (dyn_state["dt"]*us.PS)* dyn_state["istep"],
 46    }
 47    if "cells" in conformation:
 48        restart_data["cells"] = conformation["cells"]
 49
 50    
 51
 52    restart_file = get_restart_file(system_data)
 53    with open(restart_file, "wb") as f:
 54        pickle.dump(restart_data, f)
 55
 56
 57@partial(jax.jit,static_argnames=["wrap_groups"])
 58def wrapbox(x, cell, reciprocal_cell, wrap_groups = None):
 59    q = x @ reciprocal_cell
 60    if wrap_groups is not None:
 61        for (group, indices) in wrap_groups:
 62            if group == "__other":
 63                q = q.at[indices].add(-jnp.floor(q[indices]))
 64            else:
 65                com = jnp.mean(q[indices], axis=0)
 66                shift = -jnp.floor(com)[None, :]
 67                q = q.at[indices].add(shift)
 68    else:
 69        q = q - jnp.floor(q)
 70    return q @ cell
 71
 72
 73def test_pressure_fd(system_data, conformation, model, verbose=True):
 74    model_energy_unit = us.get_multiplier(model.energy_unit)
 75    nat = system_data["nat"]
 76    volume = system_data["pbc"]["volume"]
 77    coordinates = conformation["coordinates"]
 78    cell = conformation["cells"][0]
 79    ek = 1.5 * nat * system_data["kT"]
 80    Pkin = (2 * us.KBAR / 3.) * ek / volume
 81    e, f, vir_t, _ = model._energy_and_forces_and_virial(model.variables, conformation)
 82    KBAR = us.KBAR / model_energy_unit
 83    Pvir = -(np.trace(vir_t[0]) * KBAR) / (3.0 * volume)
 84    vstep = volume * 0.000001
 85    scalep = ((volume + vstep) / volume) ** (1.0 / 3.0)
 86    cellp = cell * scalep
 87    reciprocal_cell = np.linalg.inv(cellp)
 88    sysp = model.preprocess(
 89        **{
 90            **conformation,
 91            "coordinates": coordinates * scalep,
 92            "cells": cellp[None, :, :],
 93            "reciprocal_cells": reciprocal_cell[None, :, :],
 94        }
 95    )
 96    ep, _ = model._total_energy(model.variables, sysp)
 97    scalem = ((volume - vstep) / volume) ** (1.0 / 3.0)
 98    cellm = cell * scalem
 99    reciprocal_cell = np.linalg.inv(cellm)
100    sysm = model.preprocess(
101        **{
102            **conformation,
103            "coordinates": coordinates * scalem,
104            "cells": cellm[None, :, :],
105            "reciprocal_cells": reciprocal_cell[None, :, :],
106        }
107    )
108    em, _ = model._total_energy(model.variables, sysm)
109    Pvir_fd = -(ep[0] * KBAR - em[0] * KBAR) / (2.0 * vstep)
110    if verbose:
111        print(
112            f"# Initial pressure: {Pkin+Pvir:.3f} (virial); {Pkin+Pvir_fd:.3f} (finite difference) ; Pkin: {Pkin:.3f} ; Pvir: {Pvir:.3f} ; Pvir_fd: {Pvir_fd:.3f}"
113        )
114    return Pkin, Pvir, Pvir_fd
115
116
def optimize_fire2(
    x0,
    ef_func,
    atol=1e-4,
    dt=0.002,
    logoutput=True,
    Nmax=10000,
    keep_every=-1,
    max_disp=None,
):
    """Fast inertial relaxation engine (FIRE)
    adapted from https://github.com/elvissoares/PyFIRE

    Args:
        x0: initial positions (array); not modified.
        ef_func: callable ``x -> (energy, forces)`` with forces shaped like x.
        atol: convergence threshold on ``sqrt(3 * mean(F**2))``.
        dt: initial time step; adapted within ``[0.02 * dt, 10 * dt]``.
        logoutput: print a per-step table (energy, RMS/max force, dt, power).
        Nmax: maximum number of iterations.
        keep_every: if > 0, also record positions every ``keep_every`` steps.
        max_disp: if given, clip each velocity component so one step moves a
            coordinate by at most ``max_disp``.

    Returns:
        ``(x, success)``, or ``(x, success, x_keep)`` when ``keep_every > 0``.
        ``x_keep`` starts with the initial and ends with the final positions.
        The same shape is returned when ``x0`` is already converged.

    Raises:
        AssertionError: if ``max_disp`` is given and not positive.
    """
    # Validate before the first (possibly expensive) energy evaluation.
    if max_disp is not None:
        assert max_disp > 0, "max_disp must be positive"

    alpha0 = 0.1  # velocity-mixing coefficient at start and after a reset
    Ndelay = 5  # positive-power steps required before dt may grow
    finc = 1.1  # dt growth factor
    fdec = 0.5  # dt reduction factor
    fa = 0.99  # alpha decay factor
    Nnegmax = Nmax // 5  # stop after more consecutive uphill steps than this

    dtmax = dt * 10.0
    dtmin = 0.02 * dt
    alpha = alpha0
    Npos = 0
    Nneg = 0

    x = x0.copy()
    e, F = ef_func(x)
    V = -0.5 * dt * F  # initial velocity
    maxF = np.max(np.abs(F))
    rmsF = np.sqrt(3 * np.mean(F**2))
    x_keep = [x.copy()] if keep_every > 0 else None

    if logoutput:
        print(
            f"{'#Step':>10} {'Energy':>15} {'RMS Force':>15} {'MAX Force':>15} {'dt':>15} {'Power':>15}"
        )
        print(f"{0:10d} {e:15.5f} {rmsF:15.5f} {maxF:15.5f} {dt:15.5f} {0:15.5f}")

    # Same "<=" criterion as inside the loop. An already converged x0 skips the
    # loop but goes through the common exit, so the return shape is uniform.
    success = bool(rmsF <= atol)
    nsteps = 0 if success else Nmax
    for i in range(nsteps):
        V = V + dt * F
        # FIRE mixing: bend the velocity towards the force direction.
        V = (1 - alpha) * V + alpha * F * np.linalg.norm(V) / np.linalg.norm(F)
        if max_disp is not None:
            V = np.clip(V, -max_disp / dt, max_disp / dt)
        x = x + dt * V
        e, F = ef_func(x)

        maxF = np.max(np.abs(F))
        rmsF = np.sqrt(3 * np.mean(F**2))
        P = (F * V).sum()  # power of this step, logged on the same row

        if logoutput:
            print(f"{i+1:10d} {e:15.5f} {rmsF:15.5f} {maxF:15.5f} {dt:15.5f} {P:15.5f}")

        if rmsF <= atol:
            success = True
            break

        if P > 0:
            Npos = Npos + 1
            Nneg = 0
            if Npos > Ndelay:
                dt = min(dt * finc, dtmax)
                alpha = alpha * fa
        else:
            # Moving uphill: shrink dt, back off half a step and stop.
            Npos = 0
            Nneg = Nneg + 1
            if Nneg > Nnegmax:
                break
            if i >= Ndelay:
                dt = max(dt * fdec, dtmin)
                alpha = alpha0
            x = x - 0.5 * dt * V
            V = np.zeros(x.shape)

        if keep_every > 0 and (i + 1) % keep_every == 0:
            x_keep.append(x.copy())

    if keep_every > 0:
        x_keep.append(x.copy())
        return x, success, x_keep
    return x, success
def get_restart_file(system_data):
    """Return the dynamics restart file path for a system.

    The path is ``system_data["name"]`` with the ``.dyn.restart`` suffix.
    """
    return system_data["name"] + ".dyn.restart"
def load_dynamics_restart(system_data):
    """Load the dynamics restart file of a system and check it matches.

    The file path comes from ``get_restart_file(system_data)``.

    Args:
        system_data: system description; ``"nat"`` and ``"species"`` are
            required, ``"nbeads"`` and ``"nreplicas"`` default to 0 when absent
            (same defaults are applied to the restart data).

    Returns:
        The unpickled restart dictionary.

    Raises:
        AssertionError: if the restart data does not match ``system_data`` on
            ``nat``, ``nbeads``, ``nreplicas`` or ``species`` (checked in that
            order).
    """
    with open(get_restart_file(system_data), "rb") as f:
        # SECURITY: pickle.load can execute arbitrary code; only load restart
        # files produced by a trusted run.
        restart_data = pickle.load(f)

    def _require(ok, key):
        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O``; AssertionError is kept so callers see the same type.
        if not ok:
            raise AssertionError(f"Restart file does not match system data '{key}'")

    _require(restart_data["nat"] == system_data["nat"], "nat")
    _require(
        restart_data.get("nbeads", 0) == system_data.get("nbeads", 0), "nbeads"
    )
    _require(
        restart_data.get("nreplicas", 0) == system_data.get("nreplicas", 0),
        "nreplicas",
    )
    # array_equal reports a shape mismatch as "not equal" instead of
    # broadcasting (which could spuriously match) or failing.
    _require(
        np.array_equal(restart_data["species"], system_data["species"]), "species"
    )

    return restart_data
def save_dynamics_restart(system_data, conformation, dyn_state, system):
    """Pickle the current dynamics state to the system's restart file.

    Args:
        system_data: reads ``"nat"``, ``"species"`` and the name used by
            ``get_restart_file``; ``"nbeads"``/``"nreplicas"`` default to 0.
        conformation: reads ``"coordinates"``; ``"cells"`` is saved if present.
        dyn_state: reads ``"preproc_state"``, ``"start_time_ps"``, ``"dt"``
            and ``"istep"``.
        system: reads ``"vel"``.

    The data is first written to ``<restart_file>.tmp`` and then moved over the
    restart file with ``os.replace``. If the write is interrupted, the previous
    restart file (if any) is left intact instead of being truncated.
    """
    import os

    restart_data = {
        "nat": system_data["nat"],
        "nbeads": system_data.get("nbeads", 0),
        "nreplicas": system_data.get("nreplicas", 0),
        "species": system_data["species"],
        "coordinates": conformation["coordinates"],
        "vel": system["vel"],
        "preproc_state": dyn_state["preproc_state"],
        # elapsed time = start time + dt * steps; ``us.PS`` presumably converts
        # dt from internal time units to ps -- confirm against UnitSystem.
        "simulation_time_ps": dyn_state["start_time_ps"]
        + (dyn_state["dt"] * us.PS) * dyn_state["istep"],
    }
    if "cells" in conformation:
        restart_data["cells"] = conformation["cells"]

    restart_file = get_restart_file(system_data)
    tmp_file = restart_file + ".tmp"
    try:
        with open(tmp_file, "wb") as f:
            pickle.dump(restart_data, f)
        # Same directory, so the swap does not expose a partial file.
        os.replace(tmp_file, restart_file)
    except BaseException:
        # Do not leave a half-written temporary file behind.
        if os.path.exists(tmp_file):
            os.remove(tmp_file)
        raise
@partial(jax.jit, static_argnames=["wrap_groups"])
def wrapbox(x, cell, reciprocal_cell, wrap_groups=None):
    """Wrap Cartesian positions back into the periodic cell.

    Positions are mapped to fractional coordinates (``x @ reciprocal_cell``,
    row-vector convention), shifted by whole cell vectors, and mapped back
    with ``@ cell``.

    Args:
        x: Cartesian positions, one atom per row.
        cell: cell matrix.
        reciprocal_cell: inverse of ``cell``.
        wrap_groups: ``None`` to wrap every atom independently into [0, 1) in
            fractional coordinates. Otherwise an iterable of ``(group, indices)``
            pairs:
              * ``"__other"`` atoms are wrapped individually;
              * any other group is shifted rigidly by the floor of its mean
                fractional position, so the group stays whole.
            Atoms in no listed group are left unshifted. The argument is
            static under ``jax.jit``, so it must be hashable and a new value
            triggers recompilation.

    Returns:
        Wrapped Cartesian positions.
    """
    frac = x @ reciprocal_cell
    if wrap_groups is None:
        return (frac - jnp.floor(frac)) @ cell

    for group_name, idx in wrap_groups:
        sub = frac[idx]
        if group_name == "__other":
            delta = -jnp.floor(sub)
        else:
            delta = -jnp.floor(jnp.mean(sub, axis=0))[None, :]
        frac = frac.at[idx].add(delta)
    return frac @ cell
def test_pressure_fd(system_data, conformation, model, verbose=True):
    """Compare the model's virial pressure with a finite-difference estimate.

    Despite the ``test_`` prefix this is a diagnostic helper, not a unit test:
    it needs a real ``model`` and a periodic ``conformation``.

    Three pressure terms are computed, each converted with ``us.KBAR``
    (presumably to kbar -- confirm against UnitSystem):

    * ``Pkin``: kinetic term from the kinetic energy ``1.5 * nat * kT``;
    * ``Pvir``: ``-trace(virial) / (3 V)`` from the model's virial tensor;
    * ``Pvir_fd``: ``-dE/dV`` by a central finite difference. Cell and
      coordinates are rescaled isotropically with a relative volume step of 1e-6.

    Args:
        system_data: reads ``"nat"``, ``"kT"`` and ``["pbc"]["volume"]``.
        conformation: model input; reads ``"coordinates"`` and ``"cells"``.
            Only ``cells[0]`` and the first energy/virial entry are used, so
            a single frame is assumed.
        model: provides ``energy_unit``, ``variables``, ``preprocess``,
            ``_energy_and_forces_and_virial`` and ``_total_energy``.
        verbose: print a one-line summary when True.

    Returns:
        ``(Pkin, Pvir, Pvir_fd)``.
    """
    # Conversion factor from model energy units.
    KBAR = us.KBAR / us.get_multiplier(model.energy_unit)
    nat = system_data["nat"]
    volume = system_data["pbc"]["volume"]
    coordinates = conformation["coordinates"]
    cell = conformation["cells"][0]

    ek = 1.5 * nat * system_data["kT"]
    Pkin = (2 * us.KBAR / 3.0) * ek / volume

    _, _, vir_t, _ = model._energy_and_forces_and_virial(
        model.variables, conformation
    )
    Pvir = -(np.trace(vir_t[0]) * KBAR) / (3.0 * volume)

    def energy_at_volume(new_volume):
        # Total model energy after isotropically rescaling cell and positions
        # so that the cell volume becomes ``new_volume``.
        scale = (new_volume / volume) ** (1.0 / 3.0)
        scaled_cell = cell * scale
        scaled_system = model.preprocess(
            **{
                **conformation,
                "coordinates": coordinates * scale,
                "cells": scaled_cell[None, :, :],
                "reciprocal_cells": np.linalg.inv(scaled_cell)[None, :, :],
            }
        )
        energy, _ = model._total_energy(model.variables, scaled_system)
        return energy[0]

    vstep = volume * 0.000001
    ep = energy_at_volume(volume + vstep)
    em = energy_at_volume(volume - vstep)
    Pvir_fd = -(ep * KBAR - em * KBAR) / (2.0 * vstep)

    if verbose:
        print(
            f"# Initial pressure: {Pkin+Pvir:.3f} (virial); {Pkin+Pvir_fd:.3f} (finite difference) ; Pkin: {Pkin:.3f} ; Pvir: {Pvir:.3f} ; Pvir_fd: {Pvir_fd:.3f}"
        )
    return Pkin, Pvir, Pvir_fd
def optimize_fire2(
    x0,
    ef_func,
    atol=1e-4,
    dt=0.002,
    logoutput=True,
    Nmax=10000,
    keep_every=-1,
    max_disp=None,
):
    """Fast inertial relaxation engine (FIRE)
    adapted from https://github.com/elvissoares/PyFIRE

    Args:
        x0: initial positions (array); not modified.
        ef_func: callable ``x -> (energy, forces)`` with forces shaped like x.
        atol: convergence threshold on ``sqrt(3 * mean(F**2))``.
        dt: initial time step; adapted within ``[0.02 * dt, 10 * dt]``.
        logoutput: print a per-step table (energy, RMS/max force, dt, power).
        Nmax: maximum number of iterations.
        keep_every: if > 0, also record positions every ``keep_every`` steps.
        max_disp: if given, clip each velocity component so one step moves a
            coordinate by at most ``max_disp``.

    Returns:
        ``(x, success)``, or ``(x, success, x_keep)`` when ``keep_every > 0``.
        ``x_keep`` starts with the initial and ends with the final positions.
        The same shape is returned when ``x0`` is already converged.

    Raises:
        AssertionError: if ``max_disp`` is given and not positive.
    """
    # Validate before the first (possibly expensive) energy evaluation.
    if max_disp is not None:
        assert max_disp > 0, "max_disp must be positive"

    alpha0 = 0.1  # velocity-mixing coefficient at start and after a reset
    Ndelay = 5  # positive-power steps required before dt may grow
    finc = 1.1  # dt growth factor
    fdec = 0.5  # dt reduction factor
    fa = 0.99  # alpha decay factor
    Nnegmax = Nmax // 5  # stop after more consecutive uphill steps than this

    dtmax = dt * 10.0
    dtmin = 0.02 * dt
    alpha = alpha0
    Npos = 0
    Nneg = 0

    x = x0.copy()
    e, F = ef_func(x)
    V = -0.5 * dt * F  # initial velocity
    maxF = np.max(np.abs(F))
    rmsF = np.sqrt(3 * np.mean(F**2))
    x_keep = [x.copy()] if keep_every > 0 else None

    if logoutput:
        print(
            f"{'#Step':>10} {'Energy':>15} {'RMS Force':>15} {'MAX Force':>15} {'dt':>15} {'Power':>15}"
        )
        print(f"{0:10d} {e:15.5f} {rmsF:15.5f} {maxF:15.5f} {dt:15.5f} {0:15.5f}")

    # Same "<=" criterion as inside the loop. An already converged x0 skips the
    # loop but goes through the common exit, so the return shape is uniform.
    success = bool(rmsF <= atol)
    nsteps = 0 if success else Nmax
    for i in range(nsteps):
        V = V + dt * F
        # FIRE mixing: bend the velocity towards the force direction.
        V = (1 - alpha) * V + alpha * F * np.linalg.norm(V) / np.linalg.norm(F)
        if max_disp is not None:
            V = np.clip(V, -max_disp / dt, max_disp / dt)
        x = x + dt * V
        e, F = ef_func(x)

        maxF = np.max(np.abs(F))
        rmsF = np.sqrt(3 * np.mean(F**2))
        P = (F * V).sum()  # power of this step, logged on the same row

        if logoutput:
            print(f"{i+1:10d} {e:15.5f} {rmsF:15.5f} {maxF:15.5f} {dt:15.5f} {P:15.5f}")

        if rmsF <= atol:
            success = True
            break

        if P > 0:
            Npos = Npos + 1
            Nneg = 0
            if Npos > Ndelay:
                dt = min(dt * finc, dtmax)
                alpha = alpha * fa
        else:
            # Moving uphill: shrink dt, back off half a step and stop.
            Npos = 0
            Nneg = Nneg + 1
            if Nneg > Nnegmax:
                break
            if i >= Ndelay:
                dt = max(dt * fdec, dtmin)
                alpha = alpha0
            x = x - 0.5 * dt * V
            V = np.zeros(x.shape)

        if keep_every > 0 and (i + 1) % keep_every == 0:
            x_keep.append(x.copy())

    if keep_every > 0:
        x_keep.append(x.copy())
        return x, success, x_keep
    return x, success

Fast inertial relaxation engine (FIRE) adapted from https://github.com/elvissoares/PyFIRE