fennol.md.initial
"""Initialisation helpers for MD runs: load the model, read the initial
structure, configure preprocessing and compute the initial system state."""
import sys, os, io
import argparse
import time
import math
from pathlib import Path
from typing import Optional, Callable
from collections import defaultdict
from functools import partial
from copy import deepcopy

import numpy as np
import jax
import jax.numpy as jnp
from flax.core import freeze, unfreeze

from ..utils.io import last_xyz_frame
from ..models import FENNIX
from ..utils.periodic_table import PERIODIC_TABLE_REV_IDX, ATOMIC_MASSES
from ..utils.atomic_units import AtomicUnits as au
from ..utils.input_parser import parse_input
from .thermostats import get_thermostat


def load_model(simulation_parameters):
    """Load the FENNIX model referenced by ``simulation_parameters``.

    Reads ``model_file`` (required), ``graph_config`` (optional, default
    ``{}``) and ``energy_terms`` (optional; a list or a whitespace-separated
    string).

    Args:
        simulation_parameters: parsed input parameters (accessed with
            ``.get`` / ``in`` / ``[]``).

    Returns:
        The loaded FENNIX model.

    Raises:
        FileNotFoundError: if ``model_file`` is missing, blank, or does not
            exist on disk.
    """
    model_file = simulation_parameters.get("model_file")
    # Without this check a missing key becomes Path("None") (confusing error)
    # and a blank value becomes Path("") == Path("."), which exists and would
    # be passed to FENNIX.load.
    if model_file is None or not str(model_file).strip():
        raise FileNotFoundError(
            "no model file specified: set 'model_file' in the simulation parameters"
        )
    model_file = Path(str(model_file).strip())
    if not model_file.exists():
        raise FileNotFoundError(f"model file {model_file} not found")

    graph_config = simulation_parameters.get("graph_config", {})
    model = FENNIX.load(model_file, graph_config=graph_config)
    print(f"# model_file: {model_file}")

    if "energy_terms" in simulation_parameters:
        energy_terms = simulation_parameters["energy_terms"]
        if isinstance(energy_terms, str):
            # accept a whitespace-separated list of term names
            energy_terms = energy_terms.split()
        model.set_energy_terms(energy_terms)
        print("# energy terms:", model.energy_terms)

    return model


def load_system_data(simulation_parameters, fprec):
    """Read the initial structure and build the system description.

    The structure is the last frame of an xyz file (``xyz_input/file``,
    default ``<system>.xyz``). Masses, temperature, optional periodic data
    (``cell``) and optional path-integral data (``nbeads``) are derived from
    ``simulation_parameters``.

    Args:
        simulation_parameters: parsed input parameters, looked up with
            ``.get`` (including slash-separated keys such as
            ``"xyz_input/indexed"``).
        fprec: floating-point dtype for coordinates, masses and cells.

    Returns:
        ``(system_data, conformation)``. ``system_data`` holds name, nat,
        symbols, species, mass, temperature, kT, totmass_amu, pbc and, with
        ``nbeads``, nbeads/omk/eigmat. ``conformation`` is the model input
        (species, coordinates, batch_index, natoms, optional
        cells/reciprocal_cells, plus any ``additional_keys``).

    Raises:
        FileNotFoundError: if the xyz file does not exist.
    """
    ## LOAD SYSTEM CONFORMATION FROM FILES
    system_name = str(simulation_parameters.get("system", "system")).strip()
    indexed = simulation_parameters.get("xyz_input/indexed", True)
    has_comment_line = simulation_parameters.get("xyz_input/has_comment_line", False)
    xyzfile = Path(simulation_parameters.get("xyz_input/file", system_name + ".xyz"))
    if not xyzfile.exists():
        raise FileNotFoundError(f"xyz file {xyzfile} not found")
    system_name = str(simulation_parameters.get("system", xyzfile.stem)).strip()
    symbols, coordinates, _ = last_xyz_frame(
        xyzfile, indexed=indexed, has_comment_line=has_comment_line
    )
    coordinates = coordinates.astype(fprec)
    species = np.array([PERIODIC_TABLE_REV_IDX[s] for s in symbols], dtype=np.int32)
    nat = species.shape[0]

    ## GET MASS
    mass_amu = np.array(ATOMIC_MASSES, dtype=fprec)[species]
    deuterate = simulation_parameters.get("deuterate", False)
    if deuterate:
        print("# Replacing all hydrogens with deuteriums")
        mass_amu[species == 1] *= 2.0
    mass = mass_amu * (au.MPROT * (au.FS / au.BOHR) ** 2)

    ### GET TEMPERATURE
    temperature = np.clip(simulation_parameters.get("temperature", 300.0), 1.0e-6, None)
    kT = temperature / au.KELVIN
    # Despite its name this is the total mass in amu divided by
    # N_A * 1e-24 = 0.602214129, so that totmass_amu / volume is printed
    # below as a density in g/cm^3 (assuming the cell is in Angstrom).
    # The "totmass_amu" key is kept for compatibility.
    totmass_amu = mass_amu.sum() / 6.02214129e-1

    ## SYSTEM DATA
    system_data = {
        "name": system_name,
        "nat": nat,
        "symbols": symbols,
        "species": species,
        "mass": mass,
        "temperature": temperature,
        "kT": kT,
        "totmass_amu": totmass_amu,
    }

    ### Set boundary conditions
    cell = simulation_parameters.get("cell", None)
    if cell is not None:
        cell = np.array(cell, dtype=fprec).reshape(3, 3)
        reciprocal_cell = np.linalg.inv(cell)
        volume = np.abs(np.linalg.det(cell))
        print("# cell matrix:")
        for row in cell:
            print("# ", row)
        dens = totmass_amu / volume
        print("# density: ", dens.item(), " g/cm^3")
        minimum_image = simulation_parameters.get("minimum_image", True)
        estimate_pressure = simulation_parameters.get("estimate_pressure", False)
        print("# minimum_image: ", minimum_image)

        crystal_input = simulation_parameters.get("xyz_input/crystal", False)
        if crystal_input:
            # crystal (fractional) input: x_cart = x_frac @ cell, i.e. the
            # rows of `cell` are used as lattice vectors
            coordinates = coordinates @ cell

        pbc_data = {
            "cell": cell,
            "reciprocal_cell": reciprocal_cell,
            "volume": volume,
            "minimum_image": minimum_image,
            "estimate_pressure": estimate_pressure,
        }
    else:
        pbc_data = None
    system_data["pbc"] = pbc_data

    ### PIMD
    nbeads = simulation_parameters.get("nbeads", None)
    if nbeads is not None:
        nbeads = int(nbeads)
        print("# nbeads: ", nbeads)
        system_data["nbeads"] = nbeads
        # replicate the structure once per bead (bead-major ordering)
        coordinates = np.repeat(coordinates[None, :, :], nbeads, axis=0).reshape(-1, 3)
        species = np.repeat(species[None, :], nbeads, axis=0).reshape(-1)
        bead_index = np.arange(nbeads, dtype=np.int32).repeat(nat)
        natoms = np.array([nat] * nbeads, dtype=np.int32)

        # Laplacian of the closed ring of beads: 2 on the diagonal and -1 for
        # each neighbour (i - 1 and i + 1 modulo nbeads). Built by
        # accumulation so that nbeads == 2, where both neighbours are the same
        # bead, gets -2 (eigenvalues 0 and 4); nbeads == 1 gives [[0]].
        identity = np.eye(nbeads)
        eigmat = (
            2.0 * identity
            - np.roll(identity, 1, axis=1)
            - np.roll(identity, -1, axis=1)
        )
        omk, eigmat = np.linalg.eigh(eigmat)
        # eigh sorts ascending: the first eigenvalue is the centroid mode
        # (exactly 0 in theory); clamp rounding noise before the sqrt
        omk[0] = 0.0
        # sqrt(eigenvalue) = 2 sin(pi k / nbeads); presumably
        # omega_P = nbeads * kT (hbar = 1, atomic units) converted with
        # 1 / au.FS -- TODO confirm units
        omk = nbeads * kT * omk**0.5 / au.FS
        # Flip rows whose centroid-column entry is negative. All entries of
        # column 0 share one sign (+-1/sqrt(nbeads)), so this flips all rows
        # or none, making the centroid eigenvector positive.
        eigmat[eigmat[:, 0] < 0] *= -1.0
        eigmat = jnp.asarray(eigmat, dtype=fprec)
        system_data["omk"] = omk
        system_data["eigmat"] = eigmat
    else:
        bead_index = np.array([0] * nat, dtype=np.int32)
        natoms = np.array([nat], dtype=np.int32)

    conformation = {
        "species": species,
        "coordinates": coordinates,
        "batch_index": bead_index,
        "natoms": natoms,
    }
    if cell is not None:
        cell = cell[None, :, :]
        reciprocal_cell = reciprocal_cell[None, :, :]
        if nbeads is not None:
            cell = np.repeat(cell, nbeads, axis=0)
            reciprocal_cell = np.repeat(reciprocal_cell, nbeads, axis=0)
        conformation["cells"] = cell
        conformation["reciprocal_cells"] = reciprocal_cell

    additional_keys = simulation_parameters.get("additional_keys", {})
    for key, value in additional_keys.items():
        conformation[key] = value

    return system_data, conformation


def initialize_preprocessing(simulation_parameters, model, conformation, system_data):
    """Configure the model's preprocessing state and run it once.

    Every entry of ``model.preproc_state["layers_state"]`` is updated from
    the simulation parameters: ``minimum_image`` (when ``system_data["pbc"]``
    is set), ``nblist_skin`` (when positive), ``nblist_mult_size`` and
    ``nblist_add_neigh`` (stored as ``add_neigh``). The conformation is then
    preprocessed once with ``check_input`` enabled; the returned state has
    ``check_input`` disabled.

    Returns:
        ``(preproc_state, conformation)`` after the initial preprocessing.
    """
    verbose = simulation_parameters.get("nblist_verbose", False)
    skin = simulation_parameters.get("nblist_skin", -1.0)
    pbc = system_data.get("pbc", None)

    def configured(layer):
        # edit a mutable copy of one layer state, then freeze it again
        layer = unfreeze(layer)
        if pbc is not None:
            layer["minimum_image"] = pbc["minimum_image"]
        if skin > 0:
            layer["nblist_skin"] = skin
        if "nblist_mult_size" in simulation_parameters:
            layer["nblist_mult_size"] = simulation_parameters["nblist_mult_size"]
        if "nblist_add_neigh" in simulation_parameters:
            layer["add_neigh"] = simulation_parameters["nblist_add_neigh"]
        return freeze(layer)

    state = unfreeze(model.preproc_state)
    state["layers_state"] = [configured(layer) for layer in state["layers_state"]]
    state = freeze(state)

    # check_input is enabled only for this first call
    state, conformation = model.preprocessing(
        state.copy({"check_input": True}), conformation
    )
    state = state.copy({"check_input": False})

    if verbose:
        print("# graphs_keys: ", list(model._graphs_properties.keys()))
        print("# nblist state:", state)

    if simulation_parameters.get("print_model", False):
        print(model.summarize(example_data=conformation))

    return state, conformation


def initialize_system(conformation, vel, model, system_data, fprec):
    """Evaluate the model on the initial conformation and build the MD state.

    Energies, forces and virial are divided by ``model.Ha_to_model_energy``.
    With PIMD (``"nbeads"`` in ``system_data``) only ``vel[0]`` enters the
    kinetic-energy tensor. Forces are projected on the columns of
    ``system_data["eigmat"]`` and scaled by ``1/sqrt(nbeads)``. Coordinates
    keep bead 0's positions in slot 0 and zeros elsewhere.

    Returns:
        dict with ``ek_tensor``, ``ek``, ``epot``, ``vel``, then ``virial``
        and ``cell`` when ``"cells"`` is in the conformation, then
        ``coordinates`` and ``forces``.
    """
    print("# Computing initial energy and forces")
    energies, forces, virial, _ = model._energy_and_forces_and_virial(
        model.variables, conformation
    )
    ha_to_model = model.Ha_to_model_energy
    forces = np.array(forces) / ha_to_model
    epot = np.mean(energies) / ha_to_model
    virial = np.mean(virial, axis=0) / ha_to_model

    is_pimd = "nbeads" in system_data
    # kinetic-energy tensor 0.5 * sum_i m_i v_i v_i^T
    v = vel[0] if is_pimd else vel
    weights = system_data["mass"][:, None, None]
    ek = 0.5 * jnp.sum(weights * v[:, :, None] * v[:, None, :], axis=0)

    system = {
        "ek_tensor": ek,
        "ek": jnp.trace(ek),
        "epot": epot,
        "vel": vel.astype(fprec),
    }
    if "cells" in conformation:
        system["virial"] = virial
        system["cell"] = conformation["cells"][0]

    if not is_pimd:
        system["coordinates"] = conformation["coordinates"]
        system["forces"] = forces
        return system

    nbeads = system_data["nbeads"]
    beads = conformation["coordinates"].reshape(nbeads, -1, 3)
    system["coordinates"] = jnp.zeros_like(beads).at[0].set(beads[0])
    system["forces"] = jnp.einsum(
        "in,i...->n...", system_data["eigmat"], forces.reshape(nbeads, -1, 3)
    ) * (1.0 / nbeads**0.5)
    return system
def
load_model(simulation_parameters):
def load_model(simulation_parameters):
    """Load the FENNIX model referenced by ``simulation_parameters``.

    Reads ``model_file`` (required), ``graph_config`` (optional, default
    ``{}``) and ``energy_terms`` (optional; a list or a whitespace-separated
    string).

    Args:
        simulation_parameters: parsed input parameters (accessed with
            ``.get`` / ``in`` / ``[]``).

    Returns:
        The loaded FENNIX model.

    Raises:
        FileNotFoundError: if ``model_file`` is missing, blank, or does not
            exist on disk.
    """
    model_file = simulation_parameters.get("model_file")
    # Without this check a missing key becomes Path("None") (confusing error)
    # and a blank value becomes Path("") == Path("."), which exists and would
    # be passed to FENNIX.load.
    if model_file is None or not str(model_file).strip():
        raise FileNotFoundError(
            "no model file specified: set 'model_file' in the simulation parameters"
        )
    model_file = Path(str(model_file).strip())
    if not model_file.exists():
        raise FileNotFoundError(f"model file {model_file} not found")

    graph_config = simulation_parameters.get("graph_config", {})
    model = FENNIX.load(model_file, graph_config=graph_config)
    print(f"# model_file: {model_file}")

    if "energy_terms" in simulation_parameters:
        energy_terms = simulation_parameters["energy_terms"]
        if isinstance(energy_terms, str):
            # accept a whitespace-separated list of term names
            energy_terms = energy_terms.split()
        model.set_energy_terms(energy_terms)
        print("# energy terms:", model.energy_terms)

    return model
def
load_system_data(simulation_parameters, fprec):
def load_system_data(simulation_parameters, fprec):
    """Read the initial structure and build the system description.

    The structure is the last frame of an xyz file (``xyz_input/file``,
    default ``<system>.xyz``). Masses, temperature, optional periodic data
    (``cell``) and optional path-integral data (``nbeads``) are derived from
    ``simulation_parameters``.

    Args:
        simulation_parameters: parsed input parameters, looked up with
            ``.get`` (including slash-separated keys such as
            ``"xyz_input/indexed"``).
        fprec: floating-point dtype for coordinates, masses and cells.

    Returns:
        ``(system_data, conformation)``. ``system_data`` holds name, nat,
        symbols, species, mass, temperature, kT, totmass_amu, pbc and, with
        ``nbeads``, nbeads/omk/eigmat. ``conformation`` is the model input
        (species, coordinates, batch_index, natoms, optional
        cells/reciprocal_cells, plus any ``additional_keys``).

    Raises:
        FileNotFoundError: if the xyz file does not exist.
    """
    ## LOAD SYSTEM CONFORMATION FROM FILES
    system_name = str(simulation_parameters.get("system", "system")).strip()
    indexed = simulation_parameters.get("xyz_input/indexed", True)
    has_comment_line = simulation_parameters.get("xyz_input/has_comment_line", False)
    xyzfile = Path(simulation_parameters.get("xyz_input/file", system_name + ".xyz"))
    if not xyzfile.exists():
        raise FileNotFoundError(f"xyz file {xyzfile} not found")
    system_name = str(simulation_parameters.get("system", xyzfile.stem)).strip()
    symbols, coordinates, _ = last_xyz_frame(
        xyzfile, indexed=indexed, has_comment_line=has_comment_line
    )
    coordinates = coordinates.astype(fprec)
    species = np.array([PERIODIC_TABLE_REV_IDX[s] for s in symbols], dtype=np.int32)
    nat = species.shape[0]

    ## GET MASS
    mass_amu = np.array(ATOMIC_MASSES, dtype=fprec)[species]
    deuterate = simulation_parameters.get("deuterate", False)
    if deuterate:
        print("# Replacing all hydrogens with deuteriums")
        mass_amu[species == 1] *= 2.0
    mass = mass_amu * (au.MPROT * (au.FS / au.BOHR) ** 2)

    ### GET TEMPERATURE
    temperature = np.clip(simulation_parameters.get("temperature", 300.0), 1.0e-6, None)
    kT = temperature / au.KELVIN
    # Despite its name this is the total mass in amu divided by
    # N_A * 1e-24 = 0.602214129, so that totmass_amu / volume is printed
    # below as a density in g/cm^3 (assuming the cell is in Angstrom).
    # The "totmass_amu" key is kept for compatibility.
    totmass_amu = mass_amu.sum() / 6.02214129e-1

    ## SYSTEM DATA
    system_data = {
        "name": system_name,
        "nat": nat,
        "symbols": symbols,
        "species": species,
        "mass": mass,
        "temperature": temperature,
        "kT": kT,
        "totmass_amu": totmass_amu,
    }

    ### Set boundary conditions
    cell = simulation_parameters.get("cell", None)
    if cell is not None:
        cell = np.array(cell, dtype=fprec).reshape(3, 3)
        reciprocal_cell = np.linalg.inv(cell)
        volume = np.abs(np.linalg.det(cell))
        print("# cell matrix:")
        for row in cell:
            print("# ", row)
        dens = totmass_amu / volume
        print("# density: ", dens.item(), " g/cm^3")
        minimum_image = simulation_parameters.get("minimum_image", True)
        estimate_pressure = simulation_parameters.get("estimate_pressure", False)
        print("# minimum_image: ", minimum_image)

        crystal_input = simulation_parameters.get("xyz_input/crystal", False)
        if crystal_input:
            # crystal (fractional) input: x_cart = x_frac @ cell, i.e. the
            # rows of `cell` are used as lattice vectors
            coordinates = coordinates @ cell

        pbc_data = {
            "cell": cell,
            "reciprocal_cell": reciprocal_cell,
            "volume": volume,
            "minimum_image": minimum_image,
            "estimate_pressure": estimate_pressure,
        }
    else:
        pbc_data = None
    system_data["pbc"] = pbc_data

    ### PIMD
    nbeads = simulation_parameters.get("nbeads", None)
    if nbeads is not None:
        nbeads = int(nbeads)
        print("# nbeads: ", nbeads)
        system_data["nbeads"] = nbeads
        # replicate the structure once per bead (bead-major ordering)
        coordinates = np.repeat(coordinates[None, :, :], nbeads, axis=0).reshape(-1, 3)
        species = np.repeat(species[None, :], nbeads, axis=0).reshape(-1)
        bead_index = np.arange(nbeads, dtype=np.int32).repeat(nat)
        natoms = np.array([nat] * nbeads, dtype=np.int32)

        # Laplacian of the closed ring of beads: 2 on the diagonal and -1 for
        # each neighbour (i - 1 and i + 1 modulo nbeads). Built by
        # accumulation so that nbeads == 2, where both neighbours are the same
        # bead, gets -2 (eigenvalues 0 and 4); nbeads == 1 gives [[0]].
        identity = np.eye(nbeads)
        eigmat = (
            2.0 * identity
            - np.roll(identity, 1, axis=1)
            - np.roll(identity, -1, axis=1)
        )
        omk, eigmat = np.linalg.eigh(eigmat)
        # eigh sorts ascending: the first eigenvalue is the centroid mode
        # (exactly 0 in theory); clamp rounding noise before the sqrt
        omk[0] = 0.0
        # sqrt(eigenvalue) = 2 sin(pi k / nbeads); presumably
        # omega_P = nbeads * kT (hbar = 1, atomic units) converted with
        # 1 / au.FS -- TODO confirm units
        omk = nbeads * kT * omk**0.5 / au.FS
        # Flip rows whose centroid-column entry is negative. All entries of
        # column 0 share one sign (+-1/sqrt(nbeads)), so this flips all rows
        # or none, making the centroid eigenvector positive.
        eigmat[eigmat[:, 0] < 0] *= -1.0
        eigmat = jnp.asarray(eigmat, dtype=fprec)
        system_data["omk"] = omk
        system_data["eigmat"] = eigmat
    else:
        bead_index = np.array([0] * nat, dtype=np.int32)
        natoms = np.array([nat], dtype=np.int32)

    conformation = {
        "species": species,
        "coordinates": coordinates,
        "batch_index": bead_index,
        "natoms": natoms,
    }
    if cell is not None:
        cell = cell[None, :, :]
        reciprocal_cell = reciprocal_cell[None, :, :]
        if nbeads is not None:
            cell = np.repeat(cell, nbeads, axis=0)
            reciprocal_cell = np.repeat(reciprocal_cell, nbeads, axis=0)
        conformation["cells"] = cell
        conformation["reciprocal_cells"] = reciprocal_cell

    additional_keys = simulation_parameters.get("additional_keys", {})
    for key, value in additional_keys.items():
        conformation[key] = value

    return system_data, conformation
def
initialize_preprocessing(simulation_parameters, model, conformation, system_data):
def initialize_preprocessing(simulation_parameters, model, conformation, system_data):
    """Configure the model's preprocessing state and run it once.

    Every entry of ``model.preproc_state["layers_state"]`` is updated from
    the simulation parameters: ``minimum_image`` (when ``system_data["pbc"]``
    is set), ``nblist_skin`` (when positive), ``nblist_mult_size`` and
    ``nblist_add_neigh`` (stored as ``add_neigh``). The conformation is then
    preprocessed once with ``check_input`` enabled; the returned state has
    ``check_input`` disabled.

    Returns:
        ``(preproc_state, conformation)`` after the initial preprocessing.
    """
    verbose = simulation_parameters.get("nblist_verbose", False)
    skin = simulation_parameters.get("nblist_skin", -1.0)
    pbc = system_data.get("pbc", None)

    def configured(layer):
        # edit a mutable copy of one layer state, then freeze it again
        layer = unfreeze(layer)
        if pbc is not None:
            layer["minimum_image"] = pbc["minimum_image"]
        if skin > 0:
            layer["nblist_skin"] = skin
        if "nblist_mult_size" in simulation_parameters:
            layer["nblist_mult_size"] = simulation_parameters["nblist_mult_size"]
        if "nblist_add_neigh" in simulation_parameters:
            layer["add_neigh"] = simulation_parameters["nblist_add_neigh"]
        return freeze(layer)

    state = unfreeze(model.preproc_state)
    state["layers_state"] = [configured(layer) for layer in state["layers_state"]]
    state = freeze(state)

    # check_input is enabled only for this first call
    state, conformation = model.preprocessing(
        state.copy({"check_input": True}), conformation
    )
    state = state.copy({"check_input": False})

    if verbose:
        print("# graphs_keys: ", list(model._graphs_properties.keys()))
        print("# nblist state:", state)

    if simulation_parameters.get("print_model", False):
        print(model.summarize(example_data=conformation))

    return state, conformation
def
initialize_system(conformation, vel, model, system_data, fprec):
def initialize_system(conformation, vel, model, system_data, fprec):
    """Evaluate the model on the initial conformation and build the MD state.

    Energies, forces and virial are divided by ``model.Ha_to_model_energy``.
    With PIMD (``"nbeads"`` in ``system_data``) only ``vel[0]`` enters the
    kinetic-energy tensor. Forces are projected on the columns of
    ``system_data["eigmat"]`` and scaled by ``1/sqrt(nbeads)``. Coordinates
    keep bead 0's positions in slot 0 and zeros elsewhere.

    Returns:
        dict with ``ek_tensor``, ``ek``, ``epot``, ``vel``, then ``virial``
        and ``cell`` when ``"cells"`` is in the conformation, then
        ``coordinates`` and ``forces``.
    """
    print("# Computing initial energy and forces")
    energies, forces, virial, _ = model._energy_and_forces_and_virial(
        model.variables, conformation
    )
    ha_to_model = model.Ha_to_model_energy
    forces = np.array(forces) / ha_to_model
    epot = np.mean(energies) / ha_to_model
    virial = np.mean(virial, axis=0) / ha_to_model

    is_pimd = "nbeads" in system_data
    # kinetic-energy tensor 0.5 * sum_i m_i v_i v_i^T
    v = vel[0] if is_pimd else vel
    weights = system_data["mass"][:, None, None]
    ek = 0.5 * jnp.sum(weights * v[:, :, None] * v[:, None, :], axis=0)

    system = {
        "ek_tensor": ek,
        "ek": jnp.trace(ek),
        "epot": epot,
        "vel": vel.astype(fprec),
    }
    if "cells" in conformation:
        system["virial"] = virial
        system["cell"] = conformation["cells"][0]

    if not is_pimd:
        system["coordinates"] = conformation["coordinates"]
        system["forces"] = forces
        return system

    nbeads = system_data["nbeads"]
    beads = conformation["coordinates"].reshape(nbeads, -1, 3)
    system["coordinates"] = jnp.zeros_like(beads).at[0].set(beads[0])
    system["forces"] = jnp.einsum(
        "in,i...->n...", system_data["eigmat"], forces.reshape(nbeads, -1, 3)
    ) * (1.0 / nbeads**0.5)
    return system