# fennol.md.utils
import os
import pickle

import jax
import jax.numpy as jnp
import numpy as np

from ..utils import AtomicUnits as au


def get_restart_file(system_data):
    """Return the dynamics restart file path for ``system_data``.

    The path is ``system_data["name"]`` with ``".dyn.restart"`` appended.
    """
    return system_data["name"] + ".dyn.restart"


def _assert_restart_match(matches, field):
    """Raise ``AssertionError`` naming ``field`` when ``matches`` is false.

    The check is raised explicitly instead of with an ``assert`` statement so
    that it is not stripped when Python runs with ``-O``; ``AssertionError``
    is kept as the exception type for compatibility with callers.
    """
    if not matches:
        raise AssertionError(f"Restart file does not match system data '{field}'")


def load_dynamics_restart(system_data):
    """Load the pickled dynamics restart and check it matches ``system_data``.

    The file is located with :func:`get_restart_file`. Its ``nat``, ``nbeads``,
    ``nreplicas`` (missing counts default to 0 on both sides) and ``species``
    entries must match those of ``system_data``.

    Args:
        system_data: system description dict; must contain ``"name"``,
            ``"nat"`` and ``"species"``.

    Returns:
        The unpickled restart dict.

    Raises:
        AssertionError: if a checked field does not match ``system_data``.
    """
    # NOTE(security): pickle.load can execute arbitrary code; only load
    # restart files written by a trusted run.
    with open(get_restart_file(system_data), "rb") as f:
        restart_data = pickle.load(f)

    _assert_restart_match(restart_data["nat"] == system_data["nat"], "nat")
    _assert_restart_match(
        restart_data.get("nbeads", 0) == system_data.get("nbeads", 0), "nbeads"
    )
    _assert_restart_match(
        restart_data.get("nreplicas", 0) == system_data.get("nreplicas", 0),
        "nreplicas",
    )
    # np.array_equal returns False on a shape mismatch, whereas an elementwise
    # ``==`` would raise (or warn) instead of reporting the mismatch.
    _assert_restart_match(
        np.array_equal(restart_data["species"], system_data["species"]), "species"
    )

    return restart_data


def save_dynamics_restart(system_data, conformation, dyn_state, system):
    """Pickle the current dynamics state to the restart file.

    The data is first written to ``<restart file>.tmp`` and then moved over the
    restart path with :func:`os.replace`. If the run is interrupted during the
    write, the previous restart file is left intact.

    Args:
        system_data: dict providing ``"name"``, ``"nat"``, ``"species"`` and
            optionally ``"nbeads"`` / ``"nreplicas"`` (default 0).
        conformation: dict providing ``"coordinates"`` and optionally
            ``"cells"``.
        dyn_state: dict providing ``"preproc_state"``, ``"start_time_ps"``,
            ``"dt"`` and ``"istep"``.
        system: dict providing ``"vel"``.
    """
    restart_data = {
        "nat": system_data["nat"],
        "nbeads": system_data.get("nbeads", 0),
        "nreplicas": system_data.get("nreplicas", 0),
        "species": system_data["species"],
        "coordinates": conformation["coordinates"],
        "vel": system["vel"],
        "preproc_state": dyn_state["preproc_state"],
        # dt is scaled by 1e-3 to obtain picoseconds, i.e. it is taken in fs.
        "simulation_time_ps": dyn_state["start_time_ps"]
        + (dyn_state["dt"] * 1e-3) * dyn_state["istep"],
    }
    if "cells" in conformation:
        restart_data["cells"] = conformation["cells"]

    restart_file = get_restart_file(system_data)
    tmp_file = restart_file + ".tmp"
    with open(tmp_file, "wb") as f:
        pickle.dump(restart_data, f)
    # Atomic rename: readers never observe a half-written restart file.
    os.replace(tmp_file, restart_file)


@jax.jit
def wrapbox(x, cell, reciprocal_cell):
    """Wrap positions ``x`` back into the periodic cell.

    Positions are mapped to fractional coordinates (``x @ reciprocal_cell``),
    reduced to ``[0, 1)`` by subtracting their floor, and mapped back with
    ``@ cell``. This row-vector convention presumably requires
    ``reciprocal_cell == inv(cell)`` — confirm against callers.
    """
    q = x @ reciprocal_cell
    q = q - jnp.floor(q)
    return q @ cell


def _energy_at_volume_scale(model, conformation, coordinates, cell, scale):
    """Return the model energy after scaling coordinates and cell by ``scale``.

    The reciprocal cell is recomputed for the scaled cell, and the system is
    re-preprocessed before ``model._total_energy`` is evaluated.
    """
    scaled_cell = cell * scale
    reciprocal_cell = np.linalg.inv(scaled_cell)
    scaled_system = model.preprocess(
        **{
            **conformation,
            "coordinates": coordinates * scale,
            "cells": scaled_cell[None, :, :],
            "reciprocal_cells": reciprocal_cell[None, :, :],
        }
    )
    energy, _ = model._total_energy(model.variables, scaled_system)
    return energy


def test_pressure_fd(system_data, conformation, model, verbose=True, fd_rel_step=1e-6):
    """Compare the virial pressure with a finite-difference estimate.

    Args:
        system_data: dict providing ``"nat"``, ``"kT"`` and
            ``["pbc"]["volume"]``.
        conformation: model input dict providing ``"coordinates"`` and
            ``"cells"`` (the first cell is used).
        model: object exposing ``energy_unit``, ``variables``, ``preprocess``,
            ``_energy_and_forces_and_virial`` and ``_total_energy``.
        verbose: print a one-line summary when true.
        fd_rel_step: volume step of the central difference, relative to the
            current volume.

    Returns:
        ``(Pkin, Pvir, Pvir_fd)``: the kinetic pressure from ``1.5 * nat * kT``,
        the pressure from the model virial, and the pressure from a central
        finite difference of the energy with respect to volume. All three are
        converted with ``au.KBAR``.
    """
    model_energy_unit = au.get_multiplier(model.energy_unit)
    nat = system_data["nat"]
    volume = system_data["pbc"]["volume"]
    coordinates = conformation["coordinates"]
    cell = conformation["cells"][0]

    # Kinetic contribution: P = 2 Ek / (3 V) with Ek = 3/2 N kT.
    # Dividing by au.BOHR**3 suggests volume is in Angstrom^3 — confirm.
    ek = 1.5 * nat * system_data["kT"]
    Pkin = (2 * au.KBAR) * ek / ((3.0 / au.BOHR**3) * volume)

    # Virial contribution: -Tr(virial) / (3 V), energies in model units.
    _, _, vir_t, _ = model._energy_and_forces_and_virial(model.variables, conformation)
    KBAR = au.KBAR / model_energy_unit
    Pvir = -(np.trace(vir_t[0]) * KBAR) / ((3.0 / au.BOHR**3) * volume)

    # Central difference of P = -dE/dV using isotropic scaling to V +/- dV.
    vstep = volume * fd_rel_step
    scalep = ((volume + vstep) / volume) ** (1.0 / 3.0)
    scalem = ((volume - vstep) / volume) ** (1.0 / 3.0)
    ep = _energy_at_volume_scale(model, conformation, coordinates, cell, scalep)
    em = _energy_at_volume_scale(model, conformation, coordinates, cell, scalem)
    Pvir_fd = -(ep[0] * KBAR - em[0] * KBAR) / (2.0 * vstep / au.BOHR**3)

    if verbose:
        print(
            f"# Initial pressure: {Pkin+Pvir:.3f} (virial); {Pkin+Pvir_fd:.3f} (finite difference) ; Pkin: {Pkin:.3f} ; Pvir: {Pvir:.3f} ; Pvir_fd: {Pvir_fd:.3f}"
        )
    return Pkin, Pvir, Pvir_fd
def get_restart_file(system_data):
    """Return the dynamics restart file path for ``system_data``.

    The path is ``system_data["name"]`` with ``".dyn.restart"`` appended.
    """
    return system_data["name"] + ".dyn.restart"


def _assert_restart_match(matches, field):
    """Raise ``AssertionError`` naming ``field`` when ``matches`` is false.

    The check is raised explicitly instead of with an ``assert`` statement so
    that it is not stripped when Python runs with ``-O``; ``AssertionError``
    is kept as the exception type for compatibility with callers.
    """
    if not matches:
        raise AssertionError(f"Restart file does not match system data '{field}'")


def load_dynamics_restart(system_data):
    """Load the pickled dynamics restart and check it matches ``system_data``.

    The file is located with :func:`get_restart_file`. Its ``nat``, ``nbeads``,
    ``nreplicas`` (missing counts default to 0 on both sides) and ``species``
    entries must match those of ``system_data``.

    Args:
        system_data: system description dict; must contain ``"name"``,
            ``"nat"`` and ``"species"``.

    Returns:
        The unpickled restart dict.

    Raises:
        AssertionError: if a checked field does not match ``system_data``.
    """
    # NOTE(security): pickle.load can execute arbitrary code; only load
    # restart files written by a trusted run.
    with open(get_restart_file(system_data), "rb") as f:
        restart_data = pickle.load(f)

    _assert_restart_match(restart_data["nat"] == system_data["nat"], "nat")
    _assert_restart_match(
        restart_data.get("nbeads", 0) == system_data.get("nbeads", 0), "nbeads"
    )
    _assert_restart_match(
        restart_data.get("nreplicas", 0) == system_data.get("nreplicas", 0),
        "nreplicas",
    )
    # np.array_equal returns False on a shape mismatch, whereas an elementwise
    # ``==`` would raise (or warn) instead of reporting the mismatch.
    _assert_restart_match(
        np.array_equal(restart_data["species"], system_data["species"]), "species"
    )

    return restart_data
def save_dynamics_restart(system_data, conformation, dyn_state, system):
    """Pickle the current dynamics state to the restart file.

    The data is first written to ``<restart file>.tmp`` and then moved over the
    restart path with :func:`os.replace`. If the run is interrupted during the
    write, the previous restart file is left intact.

    Args:
        system_data: dict providing ``"name"``, ``"nat"``, ``"species"`` and
            optionally ``"nbeads"`` / ``"nreplicas"`` (default 0).
        conformation: dict providing ``"coordinates"`` and optionally
            ``"cells"``.
        dyn_state: dict providing ``"preproc_state"``, ``"start_time_ps"``,
            ``"dt"`` and ``"istep"``.
        system: dict providing ``"vel"``.
    """
    import os

    restart_data = {
        "nat": system_data["nat"],
        "nbeads": system_data.get("nbeads", 0),
        "nreplicas": system_data.get("nreplicas", 0),
        "species": system_data["species"],
        "coordinates": conformation["coordinates"],
        "vel": system["vel"],
        "preproc_state": dyn_state["preproc_state"],
        # dt is scaled by 1e-3 to obtain picoseconds, i.e. it is taken in fs.
        "simulation_time_ps": dyn_state["start_time_ps"]
        + (dyn_state["dt"] * 1e-3) * dyn_state["istep"],
    }
    if "cells" in conformation:
        restart_data["cells"] = conformation["cells"]

    restart_file = get_restart_file(system_data)
    tmp_file = restart_file + ".tmp"
    with open(tmp_file, "wb") as f:
        pickle.dump(restart_data, f)
    # Atomic rename: readers never observe a half-written restart file.
    os.replace(tmp_file, restart_file)
@jax.jit
def wrapbox(x, cell, reciprocal_cell):
    """Wrap positions ``x`` back into the periodic cell.

    Positions are mapped to fractional coordinates (``x @ reciprocal_cell``),
    reduced to ``[0, 1)`` by subtracting their floor, and mapped back with
    ``@ cell``. This row-vector convention presumably requires
    ``reciprocal_cell == inv(cell)`` — confirm against callers.
    """
    fractional = x @ reciprocal_cell
    wrapped_fractional = fractional - jnp.floor(fractional)
    return wrapped_fractional @ cell
def _energy_at_volume_scale(model, conformation, coordinates, cell, scale):
    """Return the model energy after scaling coordinates and cell by ``scale``.

    The reciprocal cell is recomputed for the scaled cell, and the system is
    re-preprocessed before ``model._total_energy`` is evaluated.
    """
    scaled_cell = cell * scale
    reciprocal_cell = np.linalg.inv(scaled_cell)
    scaled_system = model.preprocess(
        **{
            **conformation,
            "coordinates": coordinates * scale,
            "cells": scaled_cell[None, :, :],
            "reciprocal_cells": reciprocal_cell[None, :, :],
        }
    )
    energy, _ = model._total_energy(model.variables, scaled_system)
    return energy


def test_pressure_fd(system_data, conformation, model, verbose=True, fd_rel_step=1e-6):
    """Compare the virial pressure with a finite-difference estimate.

    Args:
        system_data: dict providing ``"nat"``, ``"kT"`` and
            ``["pbc"]["volume"]``.
        conformation: model input dict providing ``"coordinates"`` and
            ``"cells"`` (the first cell is used).
        model: object exposing ``energy_unit``, ``variables``, ``preprocess``,
            ``_energy_and_forces_and_virial`` and ``_total_energy``.
        verbose: print a one-line summary when true.
        fd_rel_step: volume step of the central difference, relative to the
            current volume.

    Returns:
        ``(Pkin, Pvir, Pvir_fd)``: the kinetic pressure from ``1.5 * nat * kT``,
        the pressure from the model virial, and the pressure from a central
        finite difference of the energy with respect to volume. All three are
        converted with ``au.KBAR``.
    """
    model_energy_unit = au.get_multiplier(model.energy_unit)
    nat = system_data["nat"]
    volume = system_data["pbc"]["volume"]
    coordinates = conformation["coordinates"]
    cell = conformation["cells"][0]

    # Kinetic contribution: P = 2 Ek / (3 V) with Ek = 3/2 N kT.
    # Dividing by au.BOHR**3 suggests volume is in Angstrom^3 — confirm.
    ek = 1.5 * nat * system_data["kT"]
    Pkin = (2 * au.KBAR) * ek / ((3.0 / au.BOHR**3) * volume)

    # Virial contribution: -Tr(virial) / (3 V), energies in model units.
    _, _, vir_t, _ = model._energy_and_forces_and_virial(model.variables, conformation)
    KBAR = au.KBAR / model_energy_unit
    Pvir = -(np.trace(vir_t[0]) * KBAR) / ((3.0 / au.BOHR**3) * volume)

    # Central difference of P = -dE/dV using isotropic scaling to V +/- dV.
    vstep = volume * fd_rel_step
    scalep = ((volume + vstep) / volume) ** (1.0 / 3.0)
    scalem = ((volume - vstep) / volume) ** (1.0 / 3.0)
    ep = _energy_at_volume_scale(model, conformation, coordinates, cell, scalep)
    em = _energy_at_volume_scale(model, conformation, coordinates, cell, scalem)
    Pvir_fd = -(ep[0] * KBAR - em[0] * KBAR) / (2.0 * vstep / au.BOHR**3)

    if verbose:
        print(
            f"# Initial pressure: {Pkin+Pvir:.3f} (virial); {Pkin+Pvir_fd:.3f} (finite difference) ; Pkin: {Pkin:.3f} ; Pvir: {Pvir:.3f} ; Pvir_fd: {Pvir_fd:.3f}"
        )
    return Pkin, Pvir, Pvir_fd