fennol.md.initial
from pathlib import Path

import numpy as np
import jax.numpy as jnp

from flax.core import freeze, unfreeze

from ..utils.io import last_xyz_frame


from ..models import FENNIX

from ..utils.periodic_table import PERIODIC_TABLE_REV_IDX, ATOMIC_MASSES
from ..utils.atomic_units import AtomicUnits as au


def load_model(simulation_parameters):
    """Load a FENNIX model from the file named by ``simulation_parameters["model_file"]``.

    Parameters
    ----------
    simulation_parameters : mapping
        Must provide ``"model_file"``. Optional keys: ``"graph_config"`` (forwarded
        to ``FENNIX.load``) and ``"energy_terms"`` (a list of term names, or a
        whitespace-separated string) forwarded to ``model.set_energy_terms``.

    Returns
    -------
    The model returned by ``FENNIX.load``.

    Raises
    ------
    FileNotFoundError
        If ``"model_file"`` is missing or does not point to an existing path.
    """
    model_file = simulation_parameters.get("model_file")
    if model_file is None:
        # Without this check, str(None) would be looked up as a file named "None".
        raise FileNotFoundError(
            "model file not specified (missing 'model_file' parameter)"
        )
    model_file = Path(str(model_file).strip())
    if not model_file.exists():
        raise FileNotFoundError(f"model file {model_file} not found")

    graph_config = simulation_parameters.get("graph_config", {})
    model = FENNIX.load(model_file, graph_config=graph_config)
    print(f"# model_file: {model_file}")

    if "energy_terms" in simulation_parameters:
        energy_terms = simulation_parameters["energy_terms"]
        if isinstance(energy_terms, str):
            # accept "term1 term2" as shorthand for ["term1", "term2"]
            energy_terms = energy_terms.split()
        model.set_energy_terms(energy_terms)
        print("# energy terms:", model.energy_terms)

    return model


def load_system_data(simulation_parameters, fprec):
    """Build the simulated system from an xyz file and the simulation parameters.

    Reads the last frame of the xyz file (``"xyz_input/file"``, default
    ``<system>.xyz``), assigns atomic masses (optionally deuterating hydrogens),
    temperature, total charge, energy unit and, when ``"cell"`` is given,
    periodic-boundary data. With ``"nbeads"`` (PIMD) or ``"nreplicas"`` the
    coordinates and species are replicated once per bead/replica. ``"nbeads"``
    takes precedence when both are given.

    Parameters
    ----------
    simulation_parameters : mapping
        Simulation input. Keys read here: ``"system"``, ``"xyz_input/indexed"``,
        ``"xyz_input/has_comment_line"``, ``"xyz_input/file"``,
        ``"xyz_input/crystal"``, ``"deuterate"``, ``"temperature"``,
        ``"total_charge"``, ``"energy_unit"``, ``"cell"`` (9 values read row-major
        as a 3x3 matrix, or 3 values taken as orthorhombic box lengths),
        ``"minimum_image"``, ``"estimate_pressure"``, ``"nbeads"``,
        ``"nreplicas"`` and ``"additional_keys"``.
    fprec : dtype-like
        Floating-point dtype used for coordinates, masses and the cell.

    Returns
    -------
    system_data : dict
        System-wide data (name, nat, symbols, species, mass, temperature, kT,
        totmass_amu, total_charge, energy unit, ``"pbc"`` and bead/replica info).
    conformation : dict
        Model input with ``species``, ``coordinates``, ``batch_index``, ``natoms``,
        ``total_charge``. Periodic systems also get ``cells``/``reciprocal_cells``,
        followed by any ``additional_keys``.

    Raises
    ------
    FileNotFoundError
        If the xyz file does not exist.
    """
    ## LOAD SYSTEM CONFORMATION FROM FILES
    system_name = str(simulation_parameters.get("system", "system")).strip()
    indexed = simulation_parameters.get("xyz_input/indexed", True)
    has_comment_line = simulation_parameters.get("xyz_input/has_comment_line", False)
    xyzfile = Path(simulation_parameters.get("xyz_input/file", system_name + ".xyz"))
    if not xyzfile.exists():
        raise FileNotFoundError(f"xyz file {xyzfile} not found")
    # when no explicit "system" name is given, name the system after the xyz file
    system_name = str(simulation_parameters.get("system", xyzfile.stem)).strip()
    symbols, coordinates, _ = last_xyz_frame(
        xyzfile, indexed=indexed, has_comment_line=has_comment_line
    )
    coordinates = coordinates.astype(fprec)
    species = np.array([PERIODIC_TABLE_REV_IDX[s] for s in symbols], dtype=np.int32)
    nat = species.shape[0]

    ## GET MASS
    mass_amu = np.array(ATOMIC_MASSES, dtype=fprec)[species]
    deuterate = simulation_parameters.get("deuterate", False)
    if deuterate:
        print("# Replacing all hydrogens with deuteriums")
        mass_amu[species == 1] *= 2.0
    # rescale masses from amu with the factor au.MPROT * (au.FS / au.BOHR) ** 2
    mass = mass_amu * (au.MPROT * (au.FS / au.BOHR) ** 2)

    ### GET TEMPERATURE
    # clip so that temperature (and kT) stays strictly positive
    temperature = np.clip(simulation_parameters.get("temperature", 300.0), 1.0e-6, None)
    kT = temperature / au.KELVIN
    # dividing by 0.602214129 (Avogadro's number / 1e24) makes totmass_amu / volume
    # a density in g/cm^3, assuming the cell is given in Angstrom
    totmass_amu = mass_amu.sum() / 6.02214129e-1

    ### GET TOTAL CHARGE
    total_charge = simulation_parameters.get("total_charge", 0.0)

    ### ENERGY UNIT
    energy_unit_str = simulation_parameters.get("energy_unit", "kcal/mol")
    energy_unit = au.get_multiplier(energy_unit_str)

    ## SYSTEM DATA
    system_data = {
        "name": system_name,
        "nat": nat,
        "symbols": symbols,
        "species": species,
        "mass": mass,
        "temperature": temperature,
        "kT": kT,
        "totmass_amu": totmass_amu,
        "total_charge": total_charge,
        "energy_unit": energy_unit,
        "energy_unit_str": energy_unit_str,
    }

    ### Set boundary conditions
    cell = simulation_parameters.get("cell", None)
    if cell is not None:
        cell = np.array(cell, dtype=fprec)
        if cell.size == 3:
            # three values: lengths of an orthorhombic box
            cell = np.diag(cell.reshape(3))
        cell = cell.reshape(3, 3)
        reciprocal_cell = np.linalg.inv(cell)
        volume = np.abs(np.linalg.det(cell))
        print("# cell matrix:")
        for row in cell:
            print("# ", row)
        dens = totmass_amu / volume
        print("# density: ", dens.item(), " g/cm^3")
        minimum_image = simulation_parameters.get("minimum_image", True)
        estimate_pressure = simulation_parameters.get("estimate_pressure", False)
        print("# minimum_image: ", minimum_image)

        crystal_input = simulation_parameters.get("xyz_input/crystal", False)
        if crystal_input:
            # fractional -> Cartesian (rows of `cell` act as lattice vectors)
            coordinates = coordinates @ cell

        pbc_data = {
            "cell": cell,
            "reciprocal_cell": reciprocal_cell,
            "volume": volume,
            "minimum_image": minimum_image,
            "estimate_pressure": estimate_pressure,
        }
    else:
        pbc_data = None
    system_data["pbc"] = pbc_data

    ### PIMD
    nbeads = simulation_parameters.get("nbeads", None)
    nreplicas = simulation_parameters.get("nreplicas", None)
    if nbeads is not None:
        nbeads = int(nbeads)
        print("# nbeads: ", nbeads)
        system_data["nbeads"] = nbeads
        coordinates = np.repeat(coordinates[None, :, :], nbeads, axis=0).reshape(-1, 3)
        species = np.repeat(species[None, :], nbeads, axis=0).reshape(-1)
        bead_index = np.arange(nbeads, dtype=np.int32).repeat(nat)
        natoms = np.array([nat] * nbeads, dtype=np.int32)

        # Cyclic ring-polymer spring matrix: 2 on the diagonal and -1 for each of a
        # bead's two ring neighbours. Subtracting the two shifted identities makes the
        # couplings accumulate when both neighbours are the same bead (nbeads == 2
        # gives [[2, -2], [-2, 2]], eigenvalues 0 and 4). Element-wise assignment
        # would overwrite them and give eigenvalues 1 and 3.
        identity = np.eye(nbeads)
        eigmat = (
            2.0 * identity
            - np.roll(identity, 1, axis=1)
            - np.roll(identity, -1, axis=1)
        )
        omk, eigmat = np.linalg.eigh(eigmat)
        # the lowest eigenvalue is exactly 0 analytically (uniform eigenvector);
        # clamp it so round-off cannot give a negative value under the square root
        omk[0] = 0.0
        omk = nbeads * kT * omk**0.5 / au.FS
        # flip the rows where the first eigenvector (column 0) is negative; since that
        # eigenvector has entries of a single sign, this fixes its sign to positive
        eigmat[eigmat[:, 0] < 0] *= -1.0
        eigmat = jnp.asarray(eigmat, dtype=fprec)
        system_data["omk"] = omk
        system_data["eigmat"] = eigmat
        # nbeads takes precedence: any "nreplicas" value is ignored
        nreplicas = None
    elif nreplicas is not None:
        nreplicas = int(nreplicas)
        print("# nreplicas: ", nreplicas)
        system_data["nreplicas"] = nreplicas
        system_data["mass"] = np.repeat(mass[None, :], nreplicas, axis=0).reshape(-1)
        system_data["species"] = np.repeat(species[None, :], nreplicas, axis=0).reshape(
            -1
        )
        coordinates = np.repeat(coordinates[None, :, :], nreplicas, axis=0).reshape(
            -1, 3
        )
        species = np.repeat(species[None, :], nreplicas, axis=0).reshape(-1)
        bead_index = np.arange(nreplicas, dtype=np.int32).repeat(nat)
        natoms = np.array([nat] * nreplicas, dtype=np.int32)
    else:
        system_data["nreplicas"] = 1
        bead_index = np.array([0] * nat, dtype=np.int32)
        natoms = np.array([nat], dtype=np.int32)

    conformation = {
        "species": species,
        "coordinates": coordinates,
        "batch_index": bead_index,
        "natoms": natoms,
        "total_charge": total_charge,
    }
    if cell is not None:
        # one (cell, reciprocal cell) pair per bead/replica in the batch
        cell = cell[None, :, :]
        reciprocal_cell = reciprocal_cell[None, :, :]
        if nbeads is not None:
            cell = np.repeat(cell, nbeads, axis=0)
            reciprocal_cell = np.repeat(reciprocal_cell, nbeads, axis=0)
        elif nreplicas is not None:
            cell = np.repeat(cell, nreplicas, axis=0)
            reciprocal_cell = np.repeat(reciprocal_cell, nreplicas, axis=0)
        conformation["cells"] = cell
        conformation["reciprocal_cells"] = reciprocal_cell

    additional_keys = simulation_parameters.get("additional_keys", {})
    for key, value in additional_keys.items():
        conformation[key] = value

    return system_data, conformation


def initialize_preprocessing(simulation_parameters, model, conformation, system_data):
    """Configure the model's preprocessing state and run it once on ``conformation``.

    Every entry of ``model.preproc_state["layers_state"]`` is updated:

    - ``minimum_image`` is copied from ``system_data["pbc"]`` when it is not None;
    - ``nblist_skin`` is set when the ``"nblist_skin"`` parameter is positive;
    - ``nblist_mult_size`` / ``add_neigh`` are set from the ``"nblist_mult_size"`` /
      ``"nblist_add_neigh"`` parameters when present.

    The first ``model.preprocessing`` call runs with ``check_input=True``. The
    returned state has ``check_input`` reset to ``False``.

    Returns
    -------
    (preproc_state, conformation)
        The updated preprocessing state and the preprocessed conformation.
    """
    verbose = simulation_parameters.get("nblist_verbose", False)
    skin = simulation_parameters.get("nblist_skin", -1.0)
    pbc = system_data.get("pbc", None)

    def _configure(layer):
        # apply the user overrides to one frozen layer state
        new_layer = unfreeze(layer)
        if pbc is not None:
            new_layer["minimum_image"] = pbc["minimum_image"]
        if skin > 0:
            new_layer["nblist_skin"] = skin
        if "nblist_mult_size" in simulation_parameters:
            new_layer["nblist_mult_size"] = simulation_parameters["nblist_mult_size"]
        if "nblist_add_neigh" in simulation_parameters:
            new_layer["add_neigh"] = simulation_parameters["nblist_add_neigh"]
        return freeze(new_layer)

    state = unfreeze(model.preproc_state)
    state["layers_state"] = [_configure(layer) for layer in state["layers_state"]]
    state = freeze(state)

    # first pass with input checking enabled, then switch checking off
    state, conformation = model.preprocessing(
        state.copy({"check_input": True}), conformation
    )
    state = state.copy({"check_input": False})

    if verbose:
        print("# graphs_keys: ", list(model._graphs_properties.keys()))
        print("# nblist state:", state)

    if simulation_parameters.get("print_model", False):
        print(model.summarize(example_data=conformation))

    return state, conformation
def
load_model(simulation_parameters):
def load_model(simulation_parameters):
    """Load a FENNIX model from the file named by ``simulation_parameters["model_file"]``.

    Parameters
    ----------
    simulation_parameters : mapping
        Must provide ``"model_file"``. Optional keys: ``"graph_config"`` (forwarded
        to ``FENNIX.load``) and ``"energy_terms"`` (a list of term names, or a
        whitespace-separated string) forwarded to ``model.set_energy_terms``.

    Returns
    -------
    The model returned by ``FENNIX.load``.

    Raises
    ------
    FileNotFoundError
        If ``"model_file"`` is missing or does not point to an existing path.
    """
    model_file = simulation_parameters.get("model_file")
    if model_file is None:
        # Without this check, str(None) would be looked up as a file named "None".
        raise FileNotFoundError(
            "model file not specified (missing 'model_file' parameter)"
        )
    model_file = Path(str(model_file).strip())
    if not model_file.exists():
        raise FileNotFoundError(f"model file {model_file} not found")

    graph_config = simulation_parameters.get("graph_config", {})
    model = FENNIX.load(model_file, graph_config=graph_config)
    print(f"# model_file: {model_file}")

    if "energy_terms" in simulation_parameters:
        energy_terms = simulation_parameters["energy_terms"]
        if isinstance(energy_terms, str):
            # accept "term1 term2" as shorthand for ["term1", "term2"]
            energy_terms = energy_terms.split()
        model.set_energy_terms(energy_terms)
        print("# energy terms:", model.energy_terms)

    return model
def
load_system_data(simulation_parameters, fprec):
def load_system_data(simulation_parameters, fprec):
    """Build the simulated system from an xyz file and the simulation parameters.

    Reads the last frame of the xyz file (``"xyz_input/file"``, default
    ``<system>.xyz``), assigns atomic masses (optionally deuterating hydrogens),
    temperature, total charge, energy unit and, when ``"cell"`` is given,
    periodic-boundary data. With ``"nbeads"`` (PIMD) or ``"nreplicas"`` the
    coordinates and species are replicated once per bead/replica. ``"nbeads"``
    takes precedence when both are given.

    Parameters
    ----------
    simulation_parameters : mapping
        Simulation input. Keys read here: ``"system"``, ``"xyz_input/indexed"``,
        ``"xyz_input/has_comment_line"``, ``"xyz_input/file"``,
        ``"xyz_input/crystal"``, ``"deuterate"``, ``"temperature"``,
        ``"total_charge"``, ``"energy_unit"``, ``"cell"`` (9 values read row-major
        as a 3x3 matrix, or 3 values taken as orthorhombic box lengths),
        ``"minimum_image"``, ``"estimate_pressure"``, ``"nbeads"``,
        ``"nreplicas"`` and ``"additional_keys"``.
    fprec : dtype-like
        Floating-point dtype used for coordinates, masses and the cell.

    Returns
    -------
    system_data : dict
        System-wide data (name, nat, symbols, species, mass, temperature, kT,
        totmass_amu, total_charge, energy unit, ``"pbc"`` and bead/replica info).
    conformation : dict
        Model input with ``species``, ``coordinates``, ``batch_index``, ``natoms``,
        ``total_charge``. Periodic systems also get ``cells``/``reciprocal_cells``,
        followed by any ``additional_keys``.

    Raises
    ------
    FileNotFoundError
        If the xyz file does not exist.
    """
    ## LOAD SYSTEM CONFORMATION FROM FILES
    system_name = str(simulation_parameters.get("system", "system")).strip()
    indexed = simulation_parameters.get("xyz_input/indexed", True)
    has_comment_line = simulation_parameters.get("xyz_input/has_comment_line", False)
    xyzfile = Path(simulation_parameters.get("xyz_input/file", system_name + ".xyz"))
    if not xyzfile.exists():
        raise FileNotFoundError(f"xyz file {xyzfile} not found")
    # when no explicit "system" name is given, name the system after the xyz file
    system_name = str(simulation_parameters.get("system", xyzfile.stem)).strip()
    symbols, coordinates, _ = last_xyz_frame(
        xyzfile, indexed=indexed, has_comment_line=has_comment_line
    )
    coordinates = coordinates.astype(fprec)
    species = np.array([PERIODIC_TABLE_REV_IDX[s] for s in symbols], dtype=np.int32)
    nat = species.shape[0]

    ## GET MASS
    mass_amu = np.array(ATOMIC_MASSES, dtype=fprec)[species]
    deuterate = simulation_parameters.get("deuterate", False)
    if deuterate:
        print("# Replacing all hydrogens with deuteriums")
        mass_amu[species == 1] *= 2.0
    # rescale masses from amu with the factor au.MPROT * (au.FS / au.BOHR) ** 2
    mass = mass_amu * (au.MPROT * (au.FS / au.BOHR) ** 2)

    ### GET TEMPERATURE
    # clip so that temperature (and kT) stays strictly positive
    temperature = np.clip(simulation_parameters.get("temperature", 300.0), 1.0e-6, None)
    kT = temperature / au.KELVIN
    # dividing by 0.602214129 (Avogadro's number / 1e24) makes totmass_amu / volume
    # a density in g/cm^3, assuming the cell is given in Angstrom
    totmass_amu = mass_amu.sum() / 6.02214129e-1

    ### GET TOTAL CHARGE
    total_charge = simulation_parameters.get("total_charge", 0.0)

    ### ENERGY UNIT
    energy_unit_str = simulation_parameters.get("energy_unit", "kcal/mol")
    energy_unit = au.get_multiplier(energy_unit_str)

    ## SYSTEM DATA
    system_data = {
        "name": system_name,
        "nat": nat,
        "symbols": symbols,
        "species": species,
        "mass": mass,
        "temperature": temperature,
        "kT": kT,
        "totmass_amu": totmass_amu,
        "total_charge": total_charge,
        "energy_unit": energy_unit,
        "energy_unit_str": energy_unit_str,
    }

    ### Set boundary conditions
    cell = simulation_parameters.get("cell", None)
    if cell is not None:
        cell = np.array(cell, dtype=fprec)
        if cell.size == 3:
            # three values: lengths of an orthorhombic box
            cell = np.diag(cell.reshape(3))
        cell = cell.reshape(3, 3)
        reciprocal_cell = np.linalg.inv(cell)
        volume = np.abs(np.linalg.det(cell))
        print("# cell matrix:")
        for row in cell:
            print("# ", row)
        dens = totmass_amu / volume
        print("# density: ", dens.item(), " g/cm^3")
        minimum_image = simulation_parameters.get("minimum_image", True)
        estimate_pressure = simulation_parameters.get("estimate_pressure", False)
        print("# minimum_image: ", minimum_image)

        crystal_input = simulation_parameters.get("xyz_input/crystal", False)
        if crystal_input:
            # fractional -> Cartesian (rows of `cell` act as lattice vectors)
            coordinates = coordinates @ cell

        pbc_data = {
            "cell": cell,
            "reciprocal_cell": reciprocal_cell,
            "volume": volume,
            "minimum_image": minimum_image,
            "estimate_pressure": estimate_pressure,
        }
    else:
        pbc_data = None
    system_data["pbc"] = pbc_data

    ### PIMD
    nbeads = simulation_parameters.get("nbeads", None)
    nreplicas = simulation_parameters.get("nreplicas", None)
    if nbeads is not None:
        nbeads = int(nbeads)
        print("# nbeads: ", nbeads)
        system_data["nbeads"] = nbeads
        coordinates = np.repeat(coordinates[None, :, :], nbeads, axis=0).reshape(-1, 3)
        species = np.repeat(species[None, :], nbeads, axis=0).reshape(-1)
        bead_index = np.arange(nbeads, dtype=np.int32).repeat(nat)
        natoms = np.array([nat] * nbeads, dtype=np.int32)

        # Cyclic ring-polymer spring matrix: 2 on the diagonal and -1 for each of a
        # bead's two ring neighbours. Subtracting the two shifted identities makes the
        # couplings accumulate when both neighbours are the same bead (nbeads == 2
        # gives [[2, -2], [-2, 2]], eigenvalues 0 and 4). Element-wise assignment
        # would overwrite them and give eigenvalues 1 and 3.
        identity = np.eye(nbeads)
        eigmat = (
            2.0 * identity
            - np.roll(identity, 1, axis=1)
            - np.roll(identity, -1, axis=1)
        )
        omk, eigmat = np.linalg.eigh(eigmat)
        # the lowest eigenvalue is exactly 0 analytically (uniform eigenvector);
        # clamp it so round-off cannot give a negative value under the square root
        omk[0] = 0.0
        omk = nbeads * kT * omk**0.5 / au.FS
        # flip the rows where the first eigenvector (column 0) is negative; since that
        # eigenvector has entries of a single sign, this fixes its sign to positive
        eigmat[eigmat[:, 0] < 0] *= -1.0
        eigmat = jnp.asarray(eigmat, dtype=fprec)
        system_data["omk"] = omk
        system_data["eigmat"] = eigmat
        # nbeads takes precedence: any "nreplicas" value is ignored
        nreplicas = None
    elif nreplicas is not None:
        nreplicas = int(nreplicas)
        print("# nreplicas: ", nreplicas)
        system_data["nreplicas"] = nreplicas
        system_data["mass"] = np.repeat(mass[None, :], nreplicas, axis=0).reshape(-1)
        system_data["species"] = np.repeat(species[None, :], nreplicas, axis=0).reshape(
            -1
        )
        coordinates = np.repeat(coordinates[None, :, :], nreplicas, axis=0).reshape(
            -1, 3
        )
        species = np.repeat(species[None, :], nreplicas, axis=0).reshape(-1)
        bead_index = np.arange(nreplicas, dtype=np.int32).repeat(nat)
        natoms = np.array([nat] * nreplicas, dtype=np.int32)
    else:
        system_data["nreplicas"] = 1
        bead_index = np.array([0] * nat, dtype=np.int32)
        natoms = np.array([nat], dtype=np.int32)

    conformation = {
        "species": species,
        "coordinates": coordinates,
        "batch_index": bead_index,
        "natoms": natoms,
        "total_charge": total_charge,
    }
    if cell is not None:
        # one (cell, reciprocal cell) pair per bead/replica in the batch
        cell = cell[None, :, :]
        reciprocal_cell = reciprocal_cell[None, :, :]
        if nbeads is not None:
            cell = np.repeat(cell, nbeads, axis=0)
            reciprocal_cell = np.repeat(reciprocal_cell, nbeads, axis=0)
        elif nreplicas is not None:
            cell = np.repeat(cell, nreplicas, axis=0)
            reciprocal_cell = np.repeat(reciprocal_cell, nreplicas, axis=0)
        conformation["cells"] = cell
        conformation["reciprocal_cells"] = reciprocal_cell

    additional_keys = simulation_parameters.get("additional_keys", {})
    for key, value in additional_keys.items():
        conformation[key] = value

    return system_data, conformation
def
initialize_preprocessing(simulation_parameters, model, conformation, system_data):
def initialize_preprocessing(simulation_parameters, model, conformation, system_data):
    """Configure the model's preprocessing state and run it once on ``conformation``.

    Every entry of ``model.preproc_state["layers_state"]`` is updated:

    - ``minimum_image`` is copied from ``system_data["pbc"]`` when it is not None;
    - ``nblist_skin`` is set when the ``"nblist_skin"`` parameter is positive;
    - ``nblist_mult_size`` / ``add_neigh`` are set from the ``"nblist_mult_size"`` /
      ``"nblist_add_neigh"`` parameters when present.

    The first ``model.preprocessing`` call runs with ``check_input=True``. The
    returned state has ``check_input`` reset to ``False``.

    Returns
    -------
    (preproc_state, conformation)
        The updated preprocessing state and the preprocessed conformation.
    """
    verbose = simulation_parameters.get("nblist_verbose", False)
    skin = simulation_parameters.get("nblist_skin", -1.0)
    pbc = system_data.get("pbc", None)

    def _configure(layer):
        # apply the user overrides to one frozen layer state
        new_layer = unfreeze(layer)
        if pbc is not None:
            new_layer["minimum_image"] = pbc["minimum_image"]
        if skin > 0:
            new_layer["nblist_skin"] = skin
        if "nblist_mult_size" in simulation_parameters:
            new_layer["nblist_mult_size"] = simulation_parameters["nblist_mult_size"]
        if "nblist_add_neigh" in simulation_parameters:
            new_layer["add_neigh"] = simulation_parameters["nblist_add_neigh"]
        return freeze(new_layer)

    state = unfreeze(model.preproc_state)
    state["layers_state"] = [_configure(layer) for layer in state["layers_state"]]
    state = freeze(state)

    # first pass with input checking enabled, then switch checking off
    state, conformation = model.preprocessing(
        state.copy({"check_input": True}), conformation
    )
    state = state.copy({"check_input": False})

    if verbose:
        print("# graphs_keys: ", list(model._graphs_properties.keys()))
        print("# nblist state:", state)

    if simulation_parameters.get("print_model", False):
        print(model.summarize(example_data=conformation))

    return state, conformation