fennol.md.utils
import numpy as np
import jax
import jax.numpy as jnp
import pickle
from ..utils import UnitSystem
from functools import partial

# Define the unit system for fennol_md
us = UnitSystem(L="ANGSTROM", T="PS", E="KCALPERMOL")


def get_restart_file(system_data):
    """Return the restart file name: ``system_data["name"]`` + ".dyn.restart"."""
    return system_data["name"] + ".dyn.restart"


def load_dynamics_restart(system_data):
    """Load the pickled restart data of a system and check it against ``system_data``.

    The file name comes from ``get_restart_file(system_data)``. The 'nat' and
    'species' entries must exist on both sides; 'nbeads' and 'nreplicas'
    default to 0 on either side when absent.

    Returns:
        The unpickled restart data.

    Raises:
        AssertionError: if one of the checked entries differs.
    """
    restart_path = get_restart_file(system_data)
    # NOTE: pickle can run arbitrary code while loading; only open trusted files.
    with open(restart_path, "rb") as handle:
        restart_data = pickle.load(handle)

    assert (
        restart_data["nat"] == system_data["nat"]
    ), "Restart file does not match system data 'nat'"
    for key in ("nbeads", "nreplicas"):
        assert restart_data.get(key, 0) == system_data.get(
            key, 0
        ), f"Restart file does not match system data '{key}'"
    species_match = np.all(restart_data["species"] == system_data["species"])
    assert species_match, "Restart file does not match system data 'species'"

    return restart_data


def save_dynamics_restart(system_data, conformation, dyn_state, system):
    """Pickle the data needed to restart a dynamics run.

    The stored dictionary holds the system description ('nat', 'nbeads',
    'nreplicas', 'species'), the coordinates, the velocities (``system["vel"]``),
    the preprocessing state and the simulation time reached so far. The cell
    is stored only when ``conformation`` has a "cells" entry. The file name
    comes from ``get_restart_file(system_data)``.
    """
    restart_data = {}
    restart_data["nat"] = system_data["nat"]
    for optional_key in ("nbeads", "nreplicas"):
        restart_data[optional_key] = system_data.get(optional_key, 0)
    restart_data["species"] = system_data["species"]
    restart_data["coordinates"] = conformation["coordinates"]
    restart_data["vel"] = system["vel"]
    restart_data["preproc_state"] = dyn_state["preproc_state"]
    # Time reached = start time + (time step scaled by us.PS) * number of steps.
    restart_data["simulation_time_ps"] = (
        dyn_state["start_time_ps"] + (dyn_state["dt"] * us.PS) * dyn_state["istep"]
    )
    if "cells" in conformation:
        restart_data["cells"] = conformation["cells"]

    with open(get_restart_file(system_data), "wb") as handle:
        pickle.dump(restart_data, handle)


@partial(jax.jit, static_argnames=["wrap_groups"])
def wrapbox(x, cell, reciprocal_cell, wrap_groups=None):
    """Wrap coordinates back into the cell.

    Coordinates are mapped with ``x @ reciprocal_cell``, shifted by integer
    amounts, and mapped back with ``@ cell``.

    Args:
        x: coordinates to wrap.
        cell: cell matrix.
        reciprocal_cell: matrix applied before flooring.
        wrap_groups: optional iterable of ``(name, indices)`` pairs. It is a
            static jit argument, so it must be hashable. Entries named
            "__other" are wrapped element by element; every other group is
            shifted as a whole by the floor of the mean of its members.
            With ``None`` every coordinate is wrapped independently.
    """
    frac = x @ reciprocal_cell
    if wrap_groups is None:
        return (frac - jnp.floor(frac)) @ cell

    for label, members in wrap_groups:
        selected = frac[members]
        if label == "__other":
            offset = -jnp.floor(selected)
        else:
            # One common shift for the whole group, taken from its mean position.
            offset = -jnp.floor(jnp.mean(selected, axis=0))[None, :]
        frac = frac.at[members].add(offset)
    return frac @ cell


def test_pressure_fd(system_data, conformation, model, verbose=True):
    """Compare the virial pressure of ``model`` with a finite-difference estimate.

    The finite-difference value is a central difference of the total energy
    for an isotropic scaling of the cell and coordinates by a relative volume
    change of +/- 1e-6.

    Returns:
        ``(Pkin, Pvir, Pvir_fd)``: kinetic term, virial term and
        finite-difference term. All are scaled with ``us.KBAR``.
    """
    energy_unit_factor = us.get_multiplier(model.energy_unit)
    n_atoms = system_data["nat"]
    volume = system_data["pbc"]["volume"]
    positions = conformation["coordinates"]
    cell = conformation["cells"][0]

    # Kinetic term built from 1.5 * nat * kT.
    kinetic_energy = 1.5 * n_atoms * system_data["kT"]
    Pkin = (2 * us.KBAR / 3.0) * kinetic_energy / volume

    _, _, virial, _ = model._energy_and_forces_and_virial(model.variables, conformation)
    kbar_factor = us.KBAR / energy_unit_factor
    Pvir = -(np.trace(virial[0]) * kbar_factor) / (3.0 * volume)

    vstep = volume * 0.000001

    def energy_at_volume(new_volume):
        # Total energy after isotropic scaling of cell and coordinates.
        scale = (new_volume / volume) ** (1.0 / 3.0)
        scaled_cell = cell * scale
        inverse_cell = np.linalg.inv(scaled_cell)
        scaled_inputs = {
            **conformation,
            "coordinates": positions * scale,
            "cells": scaled_cell[None, :, :],
            "reciprocal_cells": inverse_cell[None, :, :],
        }
        scaled_system = model.preprocess(**scaled_inputs)
        total_energy, _ = model._total_energy(model.variables, scaled_system)
        return total_energy

    e_plus = energy_at_volume(volume + vstep)
    e_minus = energy_at_volume(volume - vstep)
    Pvir_fd = -(e_plus[0] * kbar_factor - e_minus[0] * kbar_factor) / (2.0 * vstep)

    if verbose:
        print(
            f"# Initial pressure: {Pkin+Pvir:.3f} (virial); {Pkin+Pvir_fd:.3f} (finite difference) ; Pkin: {Pkin:.3f} ; Pvir: {Pvir:.3f} ; Pvir_fd: {Pvir_fd:.3f}"
        )
    return Pkin, Pvir, Pvir_fd


def optimize_fire2(
    x0,
    ef_func,
    atol=1e-4,
    dt=0.002,
    logoutput=True,
    Nmax=10000,
    keep_every=-1,
    max_disp=None,
):
    """Fast inertial relaxation engine (FIRE)
    adapted from https://github.com/elvissoares/PyFIRE

    Args:
        x0: initial coordinates (array with a ``copy`` method).
        ef_func: callable returning ``(energy, forces)`` for given coordinates.
        atol: convergence threshold on ``sqrt(3 * mean(F**2))``.
        dt: initial time step; it is kept within ``[0.02 * dt, 10 * dt]``.
        logoutput: print one line per step when true.
        Nmax: maximum number of steps.
        keep_every: when > 0, store the coordinates every ``keep_every`` steps.
        max_disp: when given, clip each velocity component so that a single
            step moves a coordinate by at most ``max_disp``.

    Returns:
        ``(x, success)``, or ``(x, success, x_keep)`` when ``keep_every > 0``.
        The same shape is returned when ``x0`` is already converged.
    """
    # FIRE parameters
    alpha0 = 0.1
    Ndelay = 5
    finc = 1.1
    fdec = 0.5
    fa = 0.99
    Nnegmax = Nmax // 5

    dtmax = dt * 10.0
    dtmin = 0.02 * dt
    alpha = alpha0
    Npos = 0
    Nneg = 0

    x = x0.copy()
    e, F = ef_func(x)
    V = -0.5 * dt * F  # initial velocity
    P = 0.0
    maxF = np.max(np.abs(F))
    rmsF = np.sqrt(3 * np.mean(F**2))
    if rmsF <= atol:
        # Already converged: return the same tuple shape as every other path.
        # "<=" matches the in-loop test and avoids dividing by norm(F) == 0.
        if keep_every > 0:
            return x, True, [x.copy()]
        return x, True

    if logoutput:
        print(
            f"{'#Step':>10} {'Energy':>15} {'RMS Force':>15} {'MAX Force':>15} {'dt':>15} {'Power':>15}"
        )
        print(f"{0:10d} {e:15.5f} {rmsF:15.5f} {maxF:15.5f} {dt:15.5f} {0:15.5f}")

    if keep_every > 0:
        x_keep = [x.copy()]

    if max_disp is not None:
        assert max_disp > 0, "max_disp must be positive"

    success = False
    for i in range(Nmax):

        V = V + dt * F
        # Mix the velocity with the force direction.
        V = (1 - alpha) * V + alpha * F * np.linalg.norm(V) / np.linalg.norm(F)
        if max_disp is not None:
            V = np.clip(V, -max_disp / dt, max_disp / dt)
        x = x + dt * V
        e, F = ef_func(x)

        maxF = np.max(np.abs(F))
        rmsF = np.sqrt(3 * np.mean(F**2))
        # Computed before logging so the printed power belongs to this step.
        P = (F * V).sum()  # dissipated power

        if logoutput:
            print(f"{i+1:10d} {e:15.5f} {rmsF:15.5f} {maxF:15.5f} {dt:15.5f} {P:15.5f}")

        if rmsF <= atol:
            success = True
            break

        if P > 0:
            Npos = Npos + 1
            Nneg = 0
            if Npos > Ndelay:
                dt = min(dt * finc, dtmax)
                alpha = alpha * fa
        else:
            Npos = 0
            Nneg = Nneg + 1
            if Nneg > Nnegmax:
                break
            if i >= Ndelay:
                dt = max(dt * fdec, dtmin)
                alpha = alpha0
            # Step back half a step and restart from rest.
            x = x - 0.5 * dt * V
            V = np.zeros(x.shape)

        if keep_every > 0 and (i + 1) % keep_every == 0:
            x_keep.append(x.copy())

    if keep_every > 0:
        x_keep.append(x.copy())
        return x, success, x_keep
    return x, success
us =
<fennol.utils.atomic_units.UnitSystem object>
def
get_restart_file(system_data):
def
load_dynamics_restart(system_data):
def load_dynamics_restart(system_data):
    """Load the pickled restart data of a system and check it against ``system_data``.

    The file name comes from ``get_restart_file(system_data)``. The 'nat' and
    'species' entries must exist on both sides; 'nbeads' and 'nreplicas'
    default to 0 on either side when absent.

    Returns:
        The unpickled restart data.

    Raises:
        AssertionError: if one of the checked entries differs.
    """
    restart_path = get_restart_file(system_data)
    # NOTE: pickle can run arbitrary code while loading; only open trusted files.
    with open(restart_path, "rb") as handle:
        restart_data = pickle.load(handle)

    assert (
        restart_data["nat"] == system_data["nat"]
    ), "Restart file does not match system data 'nat'"
    for key in ("nbeads", "nreplicas"):
        assert restart_data.get(key, 0) == system_data.get(
            key, 0
        ), f"Restart file does not match system data '{key}'"
    species_match = np.all(restart_data["species"] == system_data["species"])
    assert species_match, "Restart file does not match system data 'species'"

    return restart_data
def
save_dynamics_restart(system_data, conformation, dyn_state, system):
def save_dynamics_restart(system_data, conformation, dyn_state, system):
    """Pickle the data needed to restart a dynamics run.

    The stored dictionary holds the system description ('nat', 'nbeads',
    'nreplicas', 'species'), the coordinates, the velocities (``system["vel"]``),
    the preprocessing state and the simulation time reached so far. The cell
    is stored only when ``conformation`` has a "cells" entry. The file name
    comes from ``get_restart_file(system_data)``.
    """
    restart_data = {}
    restart_data["nat"] = system_data["nat"]
    for optional_key in ("nbeads", "nreplicas"):
        restart_data[optional_key] = system_data.get(optional_key, 0)
    restart_data["species"] = system_data["species"]
    restart_data["coordinates"] = conformation["coordinates"]
    restart_data["vel"] = system["vel"]
    restart_data["preproc_state"] = dyn_state["preproc_state"]
    # Time reached = start time + (time step scaled by us.PS) * number of steps.
    restart_data["simulation_time_ps"] = (
        dyn_state["start_time_ps"] + (dyn_state["dt"] * us.PS) * dyn_state["istep"]
    )
    if "cells" in conformation:
        restart_data["cells"] = conformation["cells"]

    with open(get_restart_file(system_data), "wb") as handle:
        pickle.dump(restart_data, handle)
@partial(jax.jit, static_argnames=['wrap_groups'])
def
wrapbox(x, cell, reciprocal_cell, wrap_groups=None):
@partial(jax.jit, static_argnames=["wrap_groups"])
def wrapbox(x, cell, reciprocal_cell, wrap_groups=None):
    """Wrap coordinates back into the cell.

    Coordinates are mapped with ``x @ reciprocal_cell`` (presumably the
    inverse of ``cell`` - confirm with callers), shifted by integer amounts,
    and mapped back with ``@ cell``.

    Args:
        x: coordinates to wrap.
        cell: cell matrix.
        reciprocal_cell: matrix applied before flooring.
        wrap_groups: optional iterable of ``(name, indices)`` pairs. It is a
            static jit argument, so it must be hashable. Entries named
            "__other" are wrapped element by element; every other group is
            shifted as a whole by the floor of the mean of its members.
            With ``None`` every coordinate is wrapped independently.
    """
    frac = x @ reciprocal_cell
    if wrap_groups is None:
        return (frac - jnp.floor(frac)) @ cell

    for label, members in wrap_groups:
        selected = frac[members]
        if label == "__other":
            offset = -jnp.floor(selected)
        else:
            # One common shift for the whole group, taken from its mean position.
            offset = -jnp.floor(jnp.mean(selected, axis=0))[None, :]
        frac = frac.at[members].add(offset)
    return frac @ cell
def
test_pressure_fd(system_data, conformation, model, verbose=True):
def test_pressure_fd(system_data, conformation, model, verbose=True):
    """Compare the virial pressure of ``model`` with a finite-difference estimate.

    The finite-difference value is a central difference of the total energy
    for an isotropic scaling of the cell and coordinates by a relative volume
    change of +/- 1e-6.

    Returns:
        ``(Pkin, Pvir, Pvir_fd)``: kinetic term, virial term and
        finite-difference term. All are scaled with ``us.KBAR`` (presumably
        giving kbar - confirm against the unit system).
    """
    energy_unit_factor = us.get_multiplier(model.energy_unit)
    n_atoms = system_data["nat"]
    volume = system_data["pbc"]["volume"]
    positions = conformation["coordinates"]
    cell = conformation["cells"][0]

    # Kinetic term built from 1.5 * nat * kT.
    kinetic_energy = 1.5 * n_atoms * system_data["kT"]
    Pkin = (2 * us.KBAR / 3.0) * kinetic_energy / volume

    _, _, virial, _ = model._energy_and_forces_and_virial(model.variables, conformation)
    kbar_factor = us.KBAR / energy_unit_factor
    Pvir = -(np.trace(virial[0]) * kbar_factor) / (3.0 * volume)

    vstep = volume * 0.000001

    def energy_at_volume(new_volume):
        # Total energy after isotropic scaling of cell and coordinates.
        scale = (new_volume / volume) ** (1.0 / 3.0)
        scaled_cell = cell * scale
        inverse_cell = np.linalg.inv(scaled_cell)
        scaled_inputs = {
            **conformation,
            "coordinates": positions * scale,
            "cells": scaled_cell[None, :, :],
            "reciprocal_cells": inverse_cell[None, :, :],
        }
        scaled_system = model.preprocess(**scaled_inputs)
        total_energy, _ = model._total_energy(model.variables, scaled_system)
        return total_energy

    e_plus = energy_at_volume(volume + vstep)
    e_minus = energy_at_volume(volume - vstep)
    Pvir_fd = -(e_plus[0] * kbar_factor - e_minus[0] * kbar_factor) / (2.0 * vstep)

    if verbose:
        print(
            f"# Initial pressure: {Pkin+Pvir:.3f} (virial); {Pkin+Pvir_fd:.3f} (finite difference) ; Pkin: {Pkin:.3f} ; Pvir: {Pvir:.3f} ; Pvir_fd: {Pvir_fd:.3f}"
        )
    return Pkin, Pvir, Pvir_fd
def
optimize_fire2( x0, ef_func, atol=0.0001, dt=0.002, logoutput=True, Nmax=10000, keep_every=-1, max_disp=None):
def optimize_fire2(
    x0,
    ef_func,
    atol=1e-4,
    dt=0.002,
    logoutput=True,
    Nmax=10000,
    keep_every=-1,
    max_disp=None,
):
    """Fast inertial relaxation engine (FIRE)
    adapted from https://github.com/elvissoares/PyFIRE

    Args:
        x0: initial coordinates (array with a ``copy`` method).
        ef_func: callable returning ``(energy, forces)`` for given coordinates.
        atol: convergence threshold on ``sqrt(3 * mean(F**2))``.
        dt: initial time step; it is kept within ``[0.02 * dt, 10 * dt]``.
        logoutput: print one line per step when true.
        Nmax: maximum number of steps.
        keep_every: when > 0, store the coordinates every ``keep_every`` steps.
        max_disp: when given, clip each velocity component so that a single
            step moves a coordinate by at most ``max_disp``.

    Returns:
        ``(x, success)``, or ``(x, success, x_keep)`` when ``keep_every > 0``.
        The same shape is returned when ``x0`` is already converged.
    """
    # FIRE parameters
    alpha0 = 0.1
    Ndelay = 5
    finc = 1.1
    fdec = 0.5
    fa = 0.99
    Nnegmax = Nmax // 5

    dtmax = dt * 10.0
    dtmin = 0.02 * dt
    alpha = alpha0
    Npos = 0
    Nneg = 0

    x = x0.copy()
    e, F = ef_func(x)
    V = -0.5 * dt * F  # initial velocity
    P = 0.0
    maxF = np.max(np.abs(F))
    rmsF = np.sqrt(3 * np.mean(F**2))
    if rmsF <= atol:
        # Already converged: return the same tuple shape as every other path.
        # "<=" matches the in-loop test and avoids dividing by norm(F) == 0.
        if keep_every > 0:
            return x, True, [x.copy()]
        return x, True

    if logoutput:
        print(
            f"{'#Step':>10} {'Energy':>15} {'RMS Force':>15} {'MAX Force':>15} {'dt':>15} {'Power':>15}"
        )
        print(f"{0:10d} {e:15.5f} {rmsF:15.5f} {maxF:15.5f} {dt:15.5f} {0:15.5f}")

    if keep_every > 0:
        x_keep = [x.copy()]

    if max_disp is not None:
        assert max_disp > 0, "max_disp must be positive"

    success = False
    for i in range(Nmax):

        V = V + dt * F
        # Mix the velocity with the force direction.
        V = (1 - alpha) * V + alpha * F * np.linalg.norm(V) / np.linalg.norm(F)
        if max_disp is not None:
            V = np.clip(V, -max_disp / dt, max_disp / dt)
        x = x + dt * V
        e, F = ef_func(x)

        maxF = np.max(np.abs(F))
        rmsF = np.sqrt(3 * np.mean(F**2))
        # Computed before logging so the printed power belongs to this step.
        P = (F * V).sum()  # dissipated power

        if logoutput:
            print(f"{i+1:10d} {e:15.5f} {rmsF:15.5f} {maxF:15.5f} {dt:15.5f} {P:15.5f}")

        if rmsF <= atol:
            success = True
            break

        if P > 0:
            Npos = Npos + 1
            Nneg = 0
            if Npos > Ndelay:
                dt = min(dt * finc, dtmax)
                alpha = alpha * fa
        else:
            Npos = 0
            Nneg = Nneg + 1
            if Nneg > Nnegmax:
                break
            if i >= Ndelay:
                dt = max(dt * fdec, dtmin)
                alpha = alpha0
            # Step back half a step and restart from rest.
            x = x - 0.5 * dt * V
            V = np.zeros(x.shape)

        if keep_every > 0 and (i + 1) % keep_every == 0:
            x_keep.append(x.copy())

    if keep_every > 0:
        x_keep.append(x.copy())
        return x, success, x_keep
    return x, success
Fast inertial relaxation engine (FIRE) adapted from https://github.com/elvissoares/PyFIRE