fennol.md.utils

  1import numpy as np
  2import jax
  3import jax.numpy as jnp
  4import pickle
  5from ..utils import AtomicUnits as au
  6
  7
  8def get_restart_file(system_data):
  9    restart_file = system_data["name"] + ".dyn.restart"
 10    return restart_file
 11
 12
 13def load_dynamics_restart(system_data):
 14    with open(get_restart_file(system_data), "rb") as f:
 15        restart_data = pickle.load(f)
 16
 17    assert (
 18        restart_data["nat"] == system_data["nat"]
 19    ), f"Restart file does not match system data 'nat'"
 20    assert restart_data.get("nbeads", 0) == system_data.get(
 21        "nbeads", 0
 22    ), f"Restart file does not match system data 'nbeads'"
 23    assert restart_data.get("nreplicas", 0) == system_data.get(
 24        "nreplicas", 0
 25    ), f"Restart file does not match system data 'nreplicas'"
 26    assert np.all(
 27        restart_data["species"] == system_data["species"]
 28    ), f"Restart file does not match system data 'species'"
 29
 30    return restart_data
 31
 32
 33def save_dynamics_restart(system_data, conformation, dyn_state, system):
 34    restart_data = {
 35        "nat": system_data["nat"],
 36        "nbeads": system_data.get("nbeads", 0),
 37        "nreplicas": system_data.get("nreplicas", 0),
 38        "species": system_data["species"],
 39        "coordinates": conformation["coordinates"],
 40        "vel": system["vel"],
 41        "preproc_state": dyn_state["preproc_state"],
 42        "simulation_time_ps": dyn_state["start_time_ps"] + (dyn_state["dt"]*1e-3)* dyn_state["istep"],
 43    }
 44    if "cells" in conformation:
 45        restart_data["cells"] = conformation["cells"]
 46
 47    
 48
 49    restart_file = get_restart_file(system_data)
 50    with open(restart_file, "wb") as f:
 51        pickle.dump(restart_data, f)
 52
 53
 54@jax.jit
 55def wrapbox(x, cell, reciprocal_cell):
 56    # q = jnp.einsum("ji,sj->si", reciprocal_cell, x)
 57    q = x @ reciprocal_cell
 58    q = q - jnp.floor(q)
 59    # return jnp.einsum("ji,sj->si", cell, q)
 60    return q @ cell
 61
 62
 63def test_pressure_fd(system_data, conformation, model, verbose=True):
 64    model_energy_unit = au.get_multiplier(model.energy_unit)
 65    nat = system_data["nat"]
 66    volume = system_data["pbc"]["volume"]
 67    coordinates = conformation["coordinates"]
 68    cell = conformation["cells"][0]
 69    # temper = 2 * ek / (3.0 * nat) * au.KELVIN
 70    ek = 1.5 * nat * system_data["kT"]
 71    Pkin = (2 * au.KBAR) * ek / ((3.0 / au.BOHR**3) * volume)
 72    e, f, vir_t, _ = model._energy_and_forces_and_virial(model.variables, conformation)
 73    KBAR = au.KBAR / model_energy_unit
 74    Pvir = -(np.trace(vir_t[0]) * KBAR) / ((3.0 / au.BOHR**3) * volume)
 75    vstep = volume * 0.000001
 76    scalep = ((volume + vstep) / volume) ** (1.0 / 3.0)
 77    cellp = cell * scalep
 78    reciprocal_cell = np.linalg.inv(cellp)
 79    sysp = model.preprocess(
 80        **{
 81            **conformation,
 82            "coordinates": coordinates * scalep,
 83            "cells": cellp[None, :, :],
 84            "reciprocal_cells": reciprocal_cell[None, :, :],
 85        }
 86    )
 87    ep, _ = model._total_energy(model.variables, sysp)
 88    scalem = ((volume - vstep) / volume) ** (1.0 / 3.0)
 89    cellm = cell * scalem
 90    reciprocal_cell = np.linalg.inv(cellm)
 91    sysm = model.preprocess(
 92        **{
 93            **conformation,
 94            "coordinates": coordinates * scalem,
 95            "cells": cellm[None, :, :],
 96            "reciprocal_cells": reciprocal_cell[None, :, :],
 97        }
 98    )
 99    em, _ = model._total_energy(model.variables, sysm)
100    Pvir_fd = -(ep[0] * KBAR - em[0] * KBAR) / (2.0 * vstep / au.BOHR**3)
101    if verbose:
102        print(
103            f"# Initial pressure: {Pkin+Pvir:.3f} (virial); {Pkin+Pvir_fd:.3f} (finite difference) ; Pkin: {Pkin:.3f} ; Pvir: {Pvir:.3f} ; Pvir_fd: {Pvir_fd:.3f}"
104        )
105    return Pkin, Pvir, Pvir_fd
def get_restart_file(system_data):
    """Return the dynamics restart file name for a system.

    The name is the system's ``"name"`` entry followed by ``.dyn.restart``.
    """
    suffix = ".dyn.restart"
    return system_data["name"] + suffix
def load_dynamics_restart(system_data):
    """Load a dynamics restart file and check that it matches the current system.

    The file name is derived from ``system_data`` via ``get_restart_file``.
    Before being returned, the restart payload is compared against
    ``system_data`` on the number of atoms, beads, replicas and the species.

    Args:
        system_data: dict describing the current system. It must contain
            ``"name"``, ``"nat"`` and ``"species"``. It may contain ``"nbeads"``
            and ``"nreplicas"``; a missing entry counts as 0 on both sides.

    Returns:
        The unpickled restart dictionary.

    Raises:
        ValueError: if the restart file does not match ``system_data``.
    """
    # NOTE: pickle can execute arbitrary code while loading; only load restart
    # files written by this program (save_dynamics_restart) from trusted paths.
    with open(get_restart_file(system_data), "rb") as f:
        restart_data = pickle.load(f)

    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if restart_data["nat"] != system_data["nat"]:
        raise ValueError("Restart file does not match system data 'nat'")
    for key in ("nbeads", "nreplicas"):
        if restart_data.get(key, 0) != system_data.get(key, 0):
            raise ValueError(f"Restart file does not match system data '{key}'")
    # array_equal returns False on a shape mismatch instead of relying on
    # broadcasting inside `==`.
    if not np.array_equal(restart_data["species"], system_data["species"]):
        raise ValueError("Restart file does not match system data 'species'")

    return restart_data
def save_dynamics_restart(system_data, conformation, dyn_state, system):
    """Save the current dynamics state to the system's restart file.

    The restart dictionary stores two groups of data:

    * the system sizes (``nat``, ``nbeads``, ``nreplicas``) and ``species``,
      which ``load_dynamics_restart`` checks against the current system;
    * the coordinates (and cells, when present), velocities, preprocessing
      state and the elapsed simulation time.

    Args:
        system_data: system description. ``"name"`` selects the file name,
            ``"nat"`` and ``"species"`` are required, and ``"nbeads"`` /
            ``"nreplicas"`` default to 0.
        conformation: dict with ``"coordinates"`` and optionally ``"cells"``.
        dyn_state: dict with ``"preproc_state"``, ``"start_time_ps"``,
            ``"dt"`` and ``"istep"``.
        system: dict holding the velocities under ``"vel"``.
    """
    import os

    restart_data = {
        "nat": system_data["nat"],
        "nbeads": system_data.get("nbeads", 0),
        "nreplicas": system_data.get("nreplicas", 0),
        "species": system_data["species"],
        "coordinates": conformation["coordinates"],
        "vel": system["vel"],
        "preproc_state": dyn_state["preproc_state"],
        # The 1e-3 factor presumably converts dt from fs to ps -- confirm.
        "simulation_time_ps": dyn_state["start_time_ps"]
        + (dyn_state["dt"] * 1e-3) * dyn_state["istep"],
    }
    if "cells" in conformation:
        restart_data["cells"] = conformation["cells"]

    restart_file = get_restart_file(system_data)
    # Dump to a temporary file and atomically swap it in. A crash or kill
    # during the dump then leaves the previous restart file intact instead of
    # a truncated one.
    tmp_file = restart_file + ".tmp"
    with open(tmp_file, "wb") as f:
        pickle.dump(restart_data, f)
    os.replace(tmp_file, restart_file)
@jax.jit
def wrapbox(x, cell, reciprocal_cell):
    """Wrap Cartesian positions back into the periodic cell.

    The steps are:

    1. Map the rows of ``x`` to fractional coordinates by right-multiplying
       with ``reciprocal_cell``.
    2. Bring the fractional coordinates into the unit interval by subtracting
       their floor.
    3. Map them back to Cartesian coordinates by right-multiplying with
       ``cell``.

    Row-vector convention; ``cell`` rows are presumably the lattice vectors
    and ``reciprocal_cell`` its inverse -- confirm with callers.
    """
    fractional = jnp.matmul(x, reciprocal_cell)
    wrapped = fractional - jnp.floor(fractional)
    return jnp.matmul(wrapped, cell)
def _scaled_box_energy(model, conformation, coordinates, cell, scale):
    """Return the model total energy with the cell and coordinates scaled by ``scale``.

    Scaling the cell and the Cartesian coordinates by the same factor keeps the
    fractional coordinates fixed, i.e. it is an isotropic deformation of the box.
    """
    scaled_cell = cell * scale
    reciprocal_cell = np.linalg.inv(scaled_cell)
    scaled_system = model.preprocess(
        **{
            **conformation,
            "coordinates": coordinates * scale,
            "cells": scaled_cell[None, :, :],
            "reciprocal_cells": reciprocal_cell[None, :, :],
        }
    )
    energy, _ = model._total_energy(model.variables, scaled_system)
    return energy


def test_pressure_fd(system_data, conformation, model, verbose=True, *, rel_vstep=1e-6):
    """Compare the virial pressure of the current configuration with a finite-difference estimate.

    Three pressure contributions are computed (and printed when ``verbose``):

    * ``Pkin``: kinetic term from the equipartition energy ``1.5 * nat * kT``;
    * ``Pvir``: ``-Tr(virial) / (3 V)`` from the model virial;
    * ``Pvir_fd``: ``-dE/dV`` by a central difference, scaling the box and the
      coordinates isotropically to volumes ``V * (1 +/- rel_vstep)``.

    All values are multiplied by ``au.KBAR`` and are presumably in kbar --
    confirm against AtomicUnits.

    Args:
        system_data: must provide ``"nat"``, ``"kT"`` and ``["pbc"]["volume"]``.
        conformation: model input dict. ``"coordinates"`` and ``"cells"`` are
            read; only the first cell is used.
        model: object providing ``energy_unit``, ``variables``, ``preprocess``,
            ``_total_energy`` and ``_energy_and_forces_and_virial``.
        verbose: print a one-line summary when True.
        rel_vstep: relative volume step of the central difference, in (0, 1).

    Returns:
        Tuple ``(Pkin, Pvir, Pvir_fd)``.

    Raises:
        ValueError: if ``rel_vstep`` is not strictly between 0 and 1.
    """
    if not 0.0 < rel_vstep < 1.0:
        raise ValueError(f"rel_vstep must be in (0, 1), got {rel_vstep}")

    model_energy_unit = au.get_multiplier(model.energy_unit)
    nat = system_data["nat"]
    volume = system_data["pbc"]["volume"]
    coordinates = conformation["coordinates"]
    cell = conformation["cells"][0]

    # Dividing by BOHR**3 presumably converts the volume to bohr^3 -- confirm.
    three_volume = (3.0 / au.BOHR**3) * volume

    # Kinetic contribution: P_kin = 2 Ek / (3 V).
    ek = 1.5 * nat * system_data["kT"]
    Pkin = (2 * au.KBAR) * ek / three_volume

    # Virial contribution: P_vir = -Tr(W) / (3 V), converted from model units.
    _, _, vir_t, _ = model._energy_and_forces_and_virial(model.variables, conformation)
    KBAR = au.KBAR / model_energy_unit
    Pvir = -(np.trace(vir_t[0]) * KBAR) / three_volume

    # Central finite difference of the energy with respect to volume.
    vstep = volume * rel_vstep
    scalep = ((volume + vstep) / volume) ** (1.0 / 3.0)
    scalem = ((volume - vstep) / volume) ** (1.0 / 3.0)
    ep = _scaled_box_energy(model, conformation, coordinates, cell, scalep)
    em = _scaled_box_energy(model, conformation, coordinates, cell, scalem)
    Pvir_fd = -(ep[0] * KBAR - em[0] * KBAR) / (2.0 * vstep / au.BOHR**3)

    if verbose:
        print(
            f"# Initial pressure: {Pkin+Pvir:.3f} (virial); {Pkin+Pvir_fd:.3f} (finite difference) ; Pkin: {Pkin:.3f} ; Pvir: {Pvir:.3f} ; Pvir_fd: {Pvir_fd:.3f}"
        )
    return Pkin, Pvir, Pvir_fd