fennol.md.initial
"""Initialisation helpers for FeNNol molecular dynamics runs.

Provides the routines that load the model, read the initial structure and
run the first preprocessing pass before a simulation starts.
"""

from pathlib import Path

import numpy as np
import jax.numpy as jnp

from flax.core import freeze, unfreeze

from ..utils.io import last_xyz_frame


from ..models import FENNIX

from ..utils.periodic_table import PERIODIC_TABLE_REV_IDX, ATOMIC_MASSES
from .utils import us
from ..utils import detect_topology, parse_cell


def load_model(simulation_parameters):
    """Load the FENNIX model named in the simulation parameters.

    Reads ``model_file`` and the optional ``graph_config`` and
    ``energy_terms`` entries from *simulation_parameters*, loads the model
    with ``FENNIX.load`` and, when ``energy_terms`` is present, passes the
    requested terms to ``model.set_energy_terms``.

    Raises:
        FileNotFoundError: if the model file does not exist.

    Returns:
        The loaded model object.
    """
    model_path = Path(str(simulation_parameters.get("model_file")).strip())
    """@keyword[fennol_md] model_file
    Path to the machine learning model file (.fnx format). Required parameter.
    Type: str, Required
    """
    if not model_path.exists():
        raise FileNotFoundError(f"model file {model_path} not found")

    graph_config = simulation_parameters.get("graph_config", {})
    """@keyword[fennol_md] graph_config
    Advanced graph configuration for model initialization.
    Default: {}
    """
    model = FENNIX.load(model_path, graph_config=graph_config)
    print(f"# model_file: {model_path}")

    if "energy_terms" in simulation_parameters:
        terms = simulation_parameters["energy_terms"]
        # A whitespace-separated string is accepted as a list of terms.
        if isinstance(terms, str):
            terms = terms.split()
        model.set_energy_terms(terms)
        print("# energy terms:", model.energy_terms)

    return model


def load_system_data(simulation_parameters, fprec):
    """Read the initial structure and assemble the system and model-input dicts.

    Args:
        simulation_parameters: mapping-like object queried with ``.get`` and
            ``in`` for the keywords documented inline below.
        fprec: NumPy dtype applied to coordinates, masses and the cell.

    Returns:
        ``(system_data, conformation)``. ``system_data`` gathers the
        per-system quantities built here (name, species, masses, temperature,
        charge, energy unit, ``pbc``, topology and the bead/replica
        settings). ``conformation`` is the dictionary later handed to the
        model preprocessing (species, coordinates, batch index, atom counts,
        total charge, optional cells, additional keys and flags).

    Raises:
        FileNotFoundError: if the xyz file or the topology file is missing.
        ValueError: if the topology file does not have exactly two columns.
    """
    ## LOAD SYSTEM CONFORMATION FROM FILES
    system_name = str(simulation_parameters.get("system_name", "system")).strip()
    """@keyword[fennol_md] system_name
    Name prefix for output files. If not specified, uses the xyz filename stem.
    Default: "system"
    """
    indexed = simulation_parameters.get("xyz_input/indexed", False)
    """@keyword[fennol_md] xyz_input/indexed
    Whether first column contains atom indices (Tinker format).
    Default: False
    """
    has_comment_line = simulation_parameters.get("xyz_input/has_comment_line", True)
    """@keyword[fennol_md] xyz_input/has_comment_line
    Whether file contains comment lines.
    Default: True
    """
    xyzfile = Path(simulation_parameters.get("xyz_input/file", system_name + ".xyz"))
    """@keyword[fennol_md] xyz_input/file
    Path to xyz/arc coordinate file. Required parameter.
    Type: str, Required
    """
    if not xyzfile.exists():
        raise FileNotFoundError(f"xyz file {xyzfile} not found")
    # Without an explicit system_name, fall back to the xyz file stem.
    system_name = str(simulation_parameters.get("system_name", xyzfile.stem)).strip()
    symbols, coordinates, _ = last_xyz_frame(
        xyzfile, indexed=indexed, has_comment_line=has_comment_line
    )
    coordinates = coordinates.astype(fprec)
    species = np.array([PERIODIC_TABLE_REV_IDX[s] for s in symbols], dtype=np.int32)
    nat = species.shape[0]

    ## GET MASS
    mass_Da = np.array(ATOMIC_MASSES, dtype=fprec)[species]
    deuterate = simulation_parameters.get("deuterate", False)
    """@keyword[fennol_md] deuterate
    Replace hydrogen masses with deuterium masses.
    Default: False
    """
    if deuterate:
        print("# Replacing all hydrogens with deuteriums")
        mass_Da[species == 1] *= 2.0

    totmass_Da = mass_Da.sum()

    mass = mass_Da.copy()
    hmr = simulation_parameters.get("hmr", 0)
    """@keyword[fennol_md] hmr
    Hydrogen mass repartitioning factor. 0 = no repartitioning.
    Default: 0
    """
    if hmr > 0:
        print(f"# Adding {hmr} Da to H masses and repartitioning on others for total mass conservation.")
        Hmask = species == 1
        added_mass = hmr * Hmask.sum()
        mass[Hmask] += hmr
        # Remove the added mass from the other atoms, proportionally to
        # their own mass, so the total is unchanged.
        wmass = mass[~Hmask]
        mass[~Hmask] -= added_mass * wmass / wmass.sum()

        assert np.isclose(mass.sum(), totmass_Da), "Mass conservation failed"

    # convert to internal units
    mass = mass / us.DA

    ### GET TEMPERATURE
    temperature = np.clip(simulation_parameters.get("temperature", 300.0), 1.0e-6, None)
    """@keyword[fennol_md] temperature
    Target temperature in Kelvin.
    Default: 300.0
    """
    kT = us.K_B * temperature

    ### GET TOTAL CHARGE
    total_charge = simulation_parameters.get("total_charge", None)
    """@keyword[fennol_md] total_charge
    Total system charge for charged systems.
    Default: None (interpreted as 0)
    """
    if total_charge is None:
        total_charge = 0
    else:
        total_charge = int(total_charge)
        print("# total charge: ", total_charge, "e")

    ### ENERGY UNIT
    energy_unit_str = simulation_parameters.get("energy_unit", "kcal/mol")
    """@keyword[fennol_md] energy_unit
    Energy unit for output. Common options: 'kcal/mol', 'eV', 'Ha', 'kJ/mol'.
    Default: "kcal/mol"
    """
    energy_unit = us.get_multiplier(energy_unit_str)

    ## SYSTEM DATA
    system_data = {
        "name": system_name,
        "nat": nat,
        "symbols": symbols,
        "species": species,
        "mass": mass,
        "mass_Da": mass_Da,
        "totmass_Da": totmass_Da,
        "temperature": temperature,
        "kT": kT,
        "total_charge": total_charge,
        "energy_unit": energy_unit,
        "energy_unit_str": energy_unit_str,
    }
    input_flags = simulation_parameters.get("model_flags", [])
    """@keyword[fennol_md] model_flags
    Additional flags to pass to the model.
    Default: []
    """
    flags = {f: None for f in input_flags}

    ### Set boundary conditions
    cell = simulation_parameters.get("cell", None)
    """@keyword[fennol_md] cell
    Unit cell vectors. Required for PBC. It is a sequence of floats:
    - 9 floats: components of cell vectors [ax, ay, az, bx, by, bz, cx, cy, cz]
    - 6 floats: lengths and angles [a, b, c, alpha, beta, gamma]
    - 3 floats: lengths of cell vectors [a, b, c] (orthorhombic)
    - 1 float: length of cell vectors (cubic cell)
    Lengths are in Angstroms, angles in degrees.
    Default: None
    """
    if cell is not None:
        cell = parse_cell(cell).astype(fprec)
        reciprocal_cell = np.linalg.inv(cell)
        volume = np.abs(np.linalg.det(cell))
        print("# cell matrix:")
        for l in cell:
            print("# ", l)
        dens = (totmass_Da / volume) * (us.MOL / us.CM**3)
        print("# density: ", dens.item(), " g/cm^3")
        minimum_image = simulation_parameters.get("minimum_image", True)
        """@keyword[fennol_md] minimum_image
        Use minimum image convention for neighbor lists in periodic systems.
        Default: True
        """
        estimate_pressure = simulation_parameters.get("estimate_pressure", False)
        """@keyword[fennol_md] estimate_pressure
        Calculate and print pressure during simulation.
        Default: False
        """
        print("# minimum_image: ", minimum_image)

        crystal_input = simulation_parameters.get("xyz_input/crystal", False)
        """@keyword[fennol_md] xyz_input/crystal
        Use crystal coordinates.
        Default: False
        """
        if crystal_input:
            coordinates = coordinates @ cell

        pbc_data = {
            "cell": cell,
            "reciprocal_cell": reciprocal_cell,
            "volume": volume,
            "minimum_image": minimum_image,
            "estimate_pressure": estimate_pressure,
        }
        if minimum_image:
            flags["minimum_image"] = None
    else:
        pbc_data = None
    system_data["pbc"] = pbc_data
    system_data["initial_coordinates"] = coordinates.copy()

    ### TOPOLOGY
    topology_key = simulation_parameters.get("topology", None)
    """@keyword[fennol_md] topology
    Topology specification for molecular systems. Use "detect" for automatic detection.
    Default: None
    """
    if topology_key is not None:
        topology_key = str(topology_key).strip()
        if topology_key.lower() == "detect":
            topology = detect_topology(species, coordinates, cell=cell)
            # Saved with 1-based indices; loading subtracts 1 again.
            np.savetxt(system_name + ".topo", topology + 1, fmt="%d")
            print("# Detected topology saved to", system_name + ".topo")
        else:
            # Explicit raises (not asserts) so the checks survive ``python -O``.
            if not Path(topology_key).exists():
                raise FileNotFoundError(f"Topology file {topology_key} not found")
            # ndmin=2 keeps a file holding a single pair two-dimensional.
            topology = np.loadtxt(topology_key, dtype=np.int32, ndmin=2) - 1
            if topology.shape[1] != 2:
                raise ValueError("Topology file must have two columns (source, target)")
            print("# Topology loaded from", topology_key)
    else:
        topology = None

    system_data["topology"] = topology

    ### PIMD
    nbeads = simulation_parameters.get("nbeads", None)
    """@keyword[fennol_md] nbeads
    Number of beads for Path Integral MD.
    Default: None
    """
    nreplicas = simulation_parameters.get("nreplicas", None)
    """@keyword[fennol_md] nreplicas
    Number of replicas for independent replica simulations.
    Default: None
    """
    if nbeads is not None:
        nbeads = int(nbeads)
        print("# nbeads: ", nbeads)
        system_data["nbeads"] = nbeads
        coordinates = np.tile(coordinates, (nbeads, 1))
        species = np.tile(species, nbeads)
        bead_index = np.arange(nbeads, dtype=np.int32).repeat(nat)
        natoms = np.full(nbeads, nat, dtype=np.int32)

        # Cyclic nearest-neighbour coupling matrix: 2 on the diagonal and -1
        # accumulated for each of the two ring neighbours. Accumulating
        # (rather than assigning) matters for nbeads == 2, where both
        # neighbours are the same bead and the off-diagonal must be -2.
        identity = np.eye(nbeads)
        eigmat = (
            2.0 * identity
            - np.roll(identity, 1, axis=0)
            - np.roll(identity, -1, axis=0)
        )
        omk, eigmat = np.linalg.eigh(eigmat)
        # The lowest eigenvalue is zero analytically; set it exactly so that
        # round-off cannot produce a NaN in the square root.
        omk[0] = 0.0
        omk = (nbeads * kT / us.HBAR) * omk**0.5
        # Flip every row whose first-column entry is negative, making the
        # first column non-negative.
        eigmat[eigmat[:, 0] < 0] *= -1.0
        eigmat = jnp.asarray(eigmat, dtype=fprec)
        system_data["omk"] = omk
        system_data["eigmat"] = eigmat
        # Beads take precedence over replicas.
        nreplicas = None
    elif nreplicas is not None:
        nreplicas = int(nreplicas)
        print("# nreplicas: ", nreplicas)
        system_data["nreplicas"] = nreplicas
        system_data["mass"] = np.tile(mass, nreplicas)
        system_data["species"] = np.tile(species, nreplicas)
        coordinates = np.tile(coordinates, (nreplicas, 1))
        species = np.tile(species, nreplicas)
        bead_index = np.arange(nreplicas, dtype=np.int32).repeat(nat)
        natoms = np.full(nreplicas, nat, dtype=np.int32)
    else:
        system_data["nreplicas"] = 1
        bead_index = np.zeros(nat, dtype=np.int32)
        natoms = np.array([nat], dtype=np.int32)

    conformation = {
        "species": species,
        "coordinates": coordinates,
        "batch_index": bead_index,
        "natoms": natoms,
        "total_charge": total_charge,
    }
    if cell is not None:
        # One cell per bead/replica, stacked along a new leading axis.
        cell = cell[None, :, :]
        reciprocal_cell = reciprocal_cell[None, :, :]
        if nbeads is not None:
            cell = np.repeat(cell, nbeads, axis=0)
            reciprocal_cell = np.repeat(reciprocal_cell, nbeads, axis=0)
        elif nreplicas is not None:
            cell = np.repeat(cell, nreplicas, axis=0)
            reciprocal_cell = np.repeat(reciprocal_cell, nreplicas, axis=0)
        conformation["cells"] = cell
        conformation["reciprocal_cells"] = reciprocal_cell

    additional_keys = simulation_parameters.get("additional_keys", {})
    """@keyword[fennol_md] additional_keys
    Additional custom keys for model input.
    Default: {}
    """
    for key, value in additional_keys.items():
        conformation[key] = value

    conformation["flags"] = flags

    return system_data, conformation


def initialize_preprocessing(simulation_parameters, model, conformation, system_data):
    """Configure the model preprocessing state and run it once on *conformation*.

    The neighbor-list options found in *simulation_parameters* are written
    into every entry of ``model.preproc_state["layers_state"]``; the
    preprocessing is then executed once with ``check_input`` enabled and the
    flag is switched off again in the returned state. *system_data* is
    accepted but not used here.

    Returns:
        ``(preproc_state, conformation)`` as produced by
        ``model.preprocessing``, with ``check_input`` set to ``False``.
    """
    verbose = simulation_parameters.get("nblist_verbose", False)
    """@keyword[fennol_md] nblist_verbose
    Print detailed neighbor list information.
    Default: False
    """
    skin = simulation_parameters.get("nblist_skin", -1.0)
    """@keyword[fennol_md] nblist_skin
    Neighbor list skin distance in Angstroms.
    Default: -1.0 (automatic)
    """

    ### CONFIGURE PREPROCESSING
    # Collect the per-layer overrides once; they are the same for all layers.
    overrides = {}
    if skin > 0:
        overrides["nblist_skin"] = skin
    if "nblist_mult_size" in simulation_parameters:
        overrides["nblist_mult_size"] = simulation_parameters["nblist_mult_size"]
        """@keyword[fennol_md] nblist_mult_size
        Multiplier for neighbor list size.
        Default: None
        """
    if "nblist_add_neigh" in simulation_parameters:
        overrides["add_neigh"] = simulation_parameters["nblist_add_neigh"]
        """@keyword[fennol_md] nblist_add_neigh
        Additional neighbors to include in lists.
        Default: None
        """

    state = unfreeze(model.preproc_state)
    updated_layers = []
    for frozen_layer in state["layers_state"]:
        layer = unfreeze(frozen_layer)
        layer.update(overrides)
        updated_layers.append(freeze(layer))
    state["layers_state"] = updated_layers

    ## initial preprocessing
    state = freeze(state).copy({"check_input": True})
    state, conformation = model.preprocessing(state, conformation)
    state = state.copy({"check_input": False})

    if verbose:
        print("# graphs_keys: ", list(model._graphs_properties.keys()))
        print("# nblist state:", state)

    ### print model
    if simulation_parameters.get("print_model", False):
        """@keyword[fennol_md] print_model
        Print detailed model information at startup.
        Default: False
        """
        print(model.summarize(example_data=conformation))

    return state, conformation
def load_model(simulation_parameters):
    """Load the FENNIX model named in the simulation parameters.

    Reads ``model_file`` and the optional ``graph_config`` and
    ``energy_terms`` entries from *simulation_parameters*, loads the model
    with ``FENNIX.load`` and, when ``energy_terms`` is present, passes the
    requested terms to ``model.set_energy_terms``.

    Raises:
        FileNotFoundError: if the model file does not exist.

    Returns:
        The loaded model object.
    """
    model_path = Path(str(simulation_parameters.get("model_file")).strip())
    """@keyword[fennol_md] model_file
    Path to the machine learning model file (.fnx format). Required parameter.
    Type: str, Required
    """
    # Guard clause: nothing else can proceed without the file.
    if not model_path.exists():
        raise FileNotFoundError(f"model file {model_path} not found")

    graph_config = simulation_parameters.get("graph_config", {})
    """@keyword[fennol_md] graph_config
    Advanced graph configuration for model initialization.
    Default: {}
    """
    model = FENNIX.load(model_path, graph_config=graph_config)
    print(f"# model_file: {model_path}")

    if "energy_terms" in simulation_parameters:
        terms = simulation_parameters["energy_terms"]
        # A whitespace-separated string is accepted as a list of terms.
        if isinstance(terms, str):
            terms = terms.split()
        model.set_energy_terms(terms)
        print("# energy terms:", model.energy_terms)

    return model
def load_system_data(simulation_parameters, fprec):
    """Read the initial structure and assemble the system and model-input dicts.

    Args:
        simulation_parameters: mapping-like object queried with ``.get`` and
            ``in`` for the keywords documented inline below.
        fprec: NumPy dtype applied to coordinates, masses and the cell.

    Returns:
        ``(system_data, conformation)``. ``system_data`` gathers the
        per-system quantities built here (name, species, masses, temperature,
        charge, energy unit, ``pbc``, topology and the bead/replica
        settings). ``conformation`` is the dictionary later handed to the
        model preprocessing (species, coordinates, batch index, atom counts,
        total charge, optional cells, additional keys and flags).

    Raises:
        FileNotFoundError: if the xyz file or the topology file is missing.
        ValueError: if the topology file does not have exactly two columns.
    """
    ## LOAD SYSTEM CONFORMATION FROM FILES
    system_name = str(simulation_parameters.get("system_name", "system")).strip()
    """@keyword[fennol_md] system_name
    Name prefix for output files. If not specified, uses the xyz filename stem.
    Default: "system"
    """
    indexed = simulation_parameters.get("xyz_input/indexed", False)
    """@keyword[fennol_md] xyz_input/indexed
    Whether first column contains atom indices (Tinker format).
    Default: False
    """
    has_comment_line = simulation_parameters.get("xyz_input/has_comment_line", True)
    """@keyword[fennol_md] xyz_input/has_comment_line
    Whether file contains comment lines.
    Default: True
    """
    xyzfile = Path(simulation_parameters.get("xyz_input/file", system_name + ".xyz"))
    """@keyword[fennol_md] xyz_input/file
    Path to xyz/arc coordinate file. Required parameter.
    Type: str, Required
    """
    if not xyzfile.exists():
        raise FileNotFoundError(f"xyz file {xyzfile} not found")
    # Without an explicit system_name, fall back to the xyz file stem.
    system_name = str(simulation_parameters.get("system_name", xyzfile.stem)).strip()
    symbols, coordinates, _ = last_xyz_frame(
        xyzfile, indexed=indexed, has_comment_line=has_comment_line
    )
    coordinates = coordinates.astype(fprec)
    species = np.array([PERIODIC_TABLE_REV_IDX[s] for s in symbols], dtype=np.int32)
    nat = species.shape[0]

    ## GET MASS
    mass_Da = np.array(ATOMIC_MASSES, dtype=fprec)[species]
    deuterate = simulation_parameters.get("deuterate", False)
    """@keyword[fennol_md] deuterate
    Replace hydrogen masses with deuterium masses.
    Default: False
    """
    if deuterate:
        print("# Replacing all hydrogens with deuteriums")
        mass_Da[species == 1] *= 2.0

    totmass_Da = mass_Da.sum()

    mass = mass_Da.copy()
    hmr = simulation_parameters.get("hmr", 0)
    """@keyword[fennol_md] hmr
    Hydrogen mass repartitioning factor. 0 = no repartitioning.
    Default: 0
    """
    if hmr > 0:
        print(f"# Adding {hmr} Da to H masses and repartitioning on others for total mass conservation.")
        Hmask = species == 1
        added_mass = hmr * Hmask.sum()
        mass[Hmask] += hmr
        # Remove the added mass from the other atoms, proportionally to
        # their own mass, so the total is unchanged.
        wmass = mass[~Hmask]
        mass[~Hmask] -= added_mass * wmass / wmass.sum()

        assert np.isclose(mass.sum(), totmass_Da), "Mass conservation failed"

    # convert to internal units
    mass = mass / us.DA

    ### GET TEMPERATURE
    temperature = np.clip(simulation_parameters.get("temperature", 300.0), 1.0e-6, None)
    """@keyword[fennol_md] temperature
    Target temperature in Kelvin.
    Default: 300.0
    """
    kT = us.K_B * temperature

    ### GET TOTAL CHARGE
    total_charge = simulation_parameters.get("total_charge", None)
    """@keyword[fennol_md] total_charge
    Total system charge for charged systems.
    Default: None (interpreted as 0)
    """
    if total_charge is None:
        total_charge = 0
    else:
        total_charge = int(total_charge)
        print("# total charge: ", total_charge, "e")

    ### ENERGY UNIT
    energy_unit_str = simulation_parameters.get("energy_unit", "kcal/mol")
    """@keyword[fennol_md] energy_unit
    Energy unit for output. Common options: 'kcal/mol', 'eV', 'Ha', 'kJ/mol'.
    Default: "kcal/mol"
    """
    energy_unit = us.get_multiplier(energy_unit_str)

    ## SYSTEM DATA
    system_data = {
        "name": system_name,
        "nat": nat,
        "symbols": symbols,
        "species": species,
        "mass": mass,
        "mass_Da": mass_Da,
        "totmass_Da": totmass_Da,
        "temperature": temperature,
        "kT": kT,
        "total_charge": total_charge,
        "energy_unit": energy_unit,
        "energy_unit_str": energy_unit_str,
    }
    input_flags = simulation_parameters.get("model_flags", [])
    """@keyword[fennol_md] model_flags
    Additional flags to pass to the model.
    Default: []
    """
    flags = {f: None for f in input_flags}

    ### Set boundary conditions
    cell = simulation_parameters.get("cell", None)
    """@keyword[fennol_md] cell
    Unit cell vectors. Required for PBC. It is a sequence of floats:
    - 9 floats: components of cell vectors [ax, ay, az, bx, by, bz, cx, cy, cz]
    - 6 floats: lengths and angles [a, b, c, alpha, beta, gamma]
    - 3 floats: lengths of cell vectors [a, b, c] (orthorhombic)
    - 1 float: length of cell vectors (cubic cell)
    Lengths are in Angstroms, angles in degrees.
    Default: None
    """
    if cell is not None:
        cell = parse_cell(cell).astype(fprec)
        reciprocal_cell = np.linalg.inv(cell)
        volume = np.abs(np.linalg.det(cell))
        print("# cell matrix:")
        for l in cell:
            print("# ", l)
        dens = (totmass_Da / volume) * (us.MOL / us.CM**3)
        print("# density: ", dens.item(), " g/cm^3")
        minimum_image = simulation_parameters.get("minimum_image", True)
        """@keyword[fennol_md] minimum_image
        Use minimum image convention for neighbor lists in periodic systems.
        Default: True
        """
        estimate_pressure = simulation_parameters.get("estimate_pressure", False)
        """@keyword[fennol_md] estimate_pressure
        Calculate and print pressure during simulation.
        Default: False
        """
        print("# minimum_image: ", minimum_image)

        crystal_input = simulation_parameters.get("xyz_input/crystal", False)
        """@keyword[fennol_md] xyz_input/crystal
        Use crystal coordinates.
        Default: False
        """
        if crystal_input:
            coordinates = coordinates @ cell

        pbc_data = {
            "cell": cell,
            "reciprocal_cell": reciprocal_cell,
            "volume": volume,
            "minimum_image": minimum_image,
            "estimate_pressure": estimate_pressure,
        }
        if minimum_image:
            flags["minimum_image"] = None
    else:
        pbc_data = None
    system_data["pbc"] = pbc_data
    system_data["initial_coordinates"] = coordinates.copy()

    ### TOPOLOGY
    topology_key = simulation_parameters.get("topology", None)
    """@keyword[fennol_md] topology
    Topology specification for molecular systems. Use "detect" for automatic detection.
    Default: None
    """
    if topology_key is not None:
        topology_key = str(topology_key).strip()
        if topology_key.lower() == "detect":
            topology = detect_topology(species, coordinates, cell=cell)
            # Saved with 1-based indices; loading subtracts 1 again.
            np.savetxt(system_name + ".topo", topology + 1, fmt="%d")
            print("# Detected topology saved to", system_name + ".topo")
        else:
            # Explicit raises (not asserts) so the checks survive ``python -O``.
            if not Path(topology_key).exists():
                raise FileNotFoundError(f"Topology file {topology_key} not found")
            # ndmin=2 keeps a file holding a single pair two-dimensional.
            topology = np.loadtxt(topology_key, dtype=np.int32, ndmin=2) - 1
            if topology.shape[1] != 2:
                raise ValueError("Topology file must have two columns (source, target)")
            print("# Topology loaded from", topology_key)
    else:
        topology = None

    system_data["topology"] = topology

    ### PIMD
    nbeads = simulation_parameters.get("nbeads", None)
    """@keyword[fennol_md] nbeads
    Number of beads for Path Integral MD.
    Default: None
    """
    nreplicas = simulation_parameters.get("nreplicas", None)
    """@keyword[fennol_md] nreplicas
    Number of replicas for independent replica simulations.
    Default: None
    """
    if nbeads is not None:
        nbeads = int(nbeads)
        print("# nbeads: ", nbeads)
        system_data["nbeads"] = nbeads
        coordinates = np.tile(coordinates, (nbeads, 1))
        species = np.tile(species, nbeads)
        bead_index = np.arange(nbeads, dtype=np.int32).repeat(nat)
        natoms = np.full(nbeads, nat, dtype=np.int32)

        # Cyclic nearest-neighbour coupling matrix: 2 on the diagonal and -1
        # accumulated for each of the two ring neighbours. Accumulating
        # (rather than assigning) matters for nbeads == 2, where both
        # neighbours are the same bead and the off-diagonal must be -2.
        identity = np.eye(nbeads)
        eigmat = (
            2.0 * identity
            - np.roll(identity, 1, axis=0)
            - np.roll(identity, -1, axis=0)
        )
        omk, eigmat = np.linalg.eigh(eigmat)
        # The lowest eigenvalue is zero analytically; set it exactly so that
        # round-off cannot produce a NaN in the square root.
        omk[0] = 0.0
        omk = (nbeads * kT / us.HBAR) * omk**0.5
        # Flip every row whose first-column entry is negative, making the
        # first column non-negative.
        eigmat[eigmat[:, 0] < 0] *= -1.0
        eigmat = jnp.asarray(eigmat, dtype=fprec)
        system_data["omk"] = omk
        system_data["eigmat"] = eigmat
        # Beads take precedence over replicas.
        nreplicas = None
    elif nreplicas is not None:
        nreplicas = int(nreplicas)
        print("# nreplicas: ", nreplicas)
        system_data["nreplicas"] = nreplicas
        system_data["mass"] = np.tile(mass, nreplicas)
        system_data["species"] = np.tile(species, nreplicas)
        coordinates = np.tile(coordinates, (nreplicas, 1))
        species = np.tile(species, nreplicas)
        bead_index = np.arange(nreplicas, dtype=np.int32).repeat(nat)
        natoms = np.full(nreplicas, nat, dtype=np.int32)
    else:
        system_data["nreplicas"] = 1
        bead_index = np.zeros(nat, dtype=np.int32)
        natoms = np.array([nat], dtype=np.int32)

    conformation = {
        "species": species,
        "coordinates": coordinates,
        "batch_index": bead_index,
        "natoms": natoms,
        "total_charge": total_charge,
    }
    if cell is not None:
        # One cell per bead/replica, stacked along a new leading axis.
        cell = cell[None, :, :]
        reciprocal_cell = reciprocal_cell[None, :, :]
        if nbeads is not None:
            cell = np.repeat(cell, nbeads, axis=0)
            reciprocal_cell = np.repeat(reciprocal_cell, nbeads, axis=0)
        elif nreplicas is not None:
            cell = np.repeat(cell, nreplicas, axis=0)
            reciprocal_cell = np.repeat(reciprocal_cell, nreplicas, axis=0)
        conformation["cells"] = cell
        conformation["reciprocal_cells"] = reciprocal_cell

    additional_keys = simulation_parameters.get("additional_keys", {})
    """@keyword[fennol_md] additional_keys
    Additional custom keys for model input.
    Default: {}
    """
    for key, value in additional_keys.items():
        conformation[key] = value

    conformation["flags"] = flags

    return system_data, conformation
def initialize_preprocessing(simulation_parameters, model, conformation, system_data):
    """Configure the model preprocessing state and run it once on *conformation*.

    The neighbor-list options found in *simulation_parameters* are written
    into every entry of ``model.preproc_state["layers_state"]``; the
    preprocessing is then executed once with ``check_input`` enabled and the
    flag is switched off again in the returned state. *system_data* is
    accepted but not used here.

    Returns:
        ``(preproc_state, conformation)`` as produced by
        ``model.preprocessing``, with ``check_input`` set to ``False``.
    """
    verbose = simulation_parameters.get("nblist_verbose", False)
    """@keyword[fennol_md] nblist_verbose
    Print detailed neighbor list information.
    Default: False
    """
    skin = simulation_parameters.get("nblist_skin", -1.0)
    """@keyword[fennol_md] nblist_skin
    Neighbor list skin distance in Angstroms.
    Default: -1.0 (automatic)
    """

    ### CONFIGURE PREPROCESSING
    # Collect the per-layer overrides once; they are the same for all layers.
    overrides = {}
    if skin > 0:
        overrides["nblist_skin"] = skin
    if "nblist_mult_size" in simulation_parameters:
        overrides["nblist_mult_size"] = simulation_parameters["nblist_mult_size"]
        """@keyword[fennol_md] nblist_mult_size
        Multiplier for neighbor list size.
        Default: None
        """
    if "nblist_add_neigh" in simulation_parameters:
        overrides["add_neigh"] = simulation_parameters["nblist_add_neigh"]
        """@keyword[fennol_md] nblist_add_neigh
        Additional neighbors to include in lists.
        Default: None
        """

    state = unfreeze(model.preproc_state)
    updated_layers = []
    for frozen_layer in state["layers_state"]:
        layer = unfreeze(frozen_layer)
        layer.update(overrides)
        updated_layers.append(freeze(layer))
    state["layers_state"] = updated_layers

    ## initial preprocessing
    state = freeze(state).copy({"check_input": True})
    state, conformation = model.preprocessing(state, conformation)
    state = state.copy({"check_input": False})

    if verbose:
        print("# graphs_keys: ", list(model._graphs_properties.keys()))
        print("# nblist state:", state)

    ### print model
    if simulation_parameters.get("print_model", False):
        """@keyword[fennol_md] print_model
        Print detailed model information at startup.
        Default: False
        """
        print(model.summarize(example_data=conformation))

    return state, conformation