fennol.md.initial
from pathlib import Path

import numpy as np
import jax.numpy as jnp

from flax.core import freeze, unfreeze

from ..utils.io import last_xyz_frame


from ..models import FENNIX

from ..utils.periodic_table import PERIODIC_TABLE_REV_IDX, ATOMIC_MASSES
from ..utils.atomic_units import AtomicUnits as au
from ..utils import detect_topology,parse_cell


def load_model(simulation_parameters):
    """Load the FENNIX model selected by the simulation parameters.

    Reads the required ``model_file`` keyword and the optional
    ``graph_config`` and ``energy_terms`` keywords.

    Args:
        simulation_parameters: Mapping-like object supporting ``get``,
            ``in`` and item access for the simulation keywords.

    Returns:
        The object returned by ``FENNIX.load``; its energy terms are
        overridden when ``energy_terms`` is present in the parameters.

    Raises:
        FileNotFoundError: If ``model_file`` is not provided or does not
            point to an existing file.
    """
    model_file = simulation_parameters.get("model_file")
    """@keyword[fennol_md] model_file
    Path to the machine learning model file (.fnx format). Required parameter.
    Type: str, Required
    """
    if model_file is None:
        # Fail explicitly instead of searching for a file literally named "None".
        raise FileNotFoundError(
            "model_file is a required parameter but was not provided"
        )
    model_file = Path(str(model_file).strip())
    if not model_file.exists():
        raise FileNotFoundError(f"model file {model_file} not found")

    graph_config = simulation_parameters.get("graph_config", {})
    """@keyword[fennol_md] graph_config
    Advanced graph configuration for model initialization.
    Default: {}
    """
    model = FENNIX.load(model_file, graph_config=graph_config)
    print(f"# model_file: {model_file}")

    if "energy_terms" in simulation_parameters:
        energy_terms = simulation_parameters["energy_terms"]
        # A single string is interpreted as a whitespace-separated list.
        if isinstance(energy_terms, str):
            energy_terms = energy_terms.split()
        model.set_energy_terms(energy_terms)
        print("# energy terms:", model.energy_terms)

    return model


def load_system_data(simulation_parameters, fprec):
    """Read the input structure and assemble the system and model inputs.

    Loads the last frame of the xyz file, builds masses (with optional
    deuteration and hydrogen mass repartitioning), thermodynamic settings,
    periodic boundary data, topology, and replicates the system for
    path-integral beads or independent replicas.

    Args:
        simulation_parameters: Mapping-like object supporting ``get`` for the
            simulation keywords (including ``"xyz_input/..."`` keys).
        fprec: Floating-point dtype applied to coordinates, masses and cell.

    Returns:
        A tuple ``(system_data, conformation)``. ``system_data`` is a dict of
        per-system quantities (name, ``nat``, symbols, species, masses,
        temperature, ``kT``, total charge, energy unit, ``pbc``,
        ``initial_coordinates``, ``topology`` and bead/replica data).
        ``conformation`` is the dict passed to the model (``species``,
        ``coordinates``, ``batch_index``, ``natoms``, ``total_charge``,
        optional ``cells``/``reciprocal_cells``, any ``additional_keys`` and
        ``flags``).

    Raises:
        FileNotFoundError: If the xyz file or the topology file is missing.
        ValueError: If ``hmr`` is requested for a system with hydrogens but
            no heavy atom, if the topology file does not have two columns,
            or if ``nbeads``/``nreplicas`` is smaller than 1.
    """
    ## LOAD SYSTEM CONFORMATION FROM FILES
    system_name = str(simulation_parameters.get("system_name", "system")).strip()
    """@keyword[fennol_md] system_name
    Name prefix for output files. If not specified, uses the xyz filename stem.
    Default: "system"
    """
    indexed = simulation_parameters.get("xyz_input/indexed", False)
    """@keyword[fennol_md] xyz_input/indexed
    Whether first column contains atom indices (Tinker format).
    Default: False
    """
    has_comment_line = simulation_parameters.get("xyz_input/has_comment_line", True)
    """@keyword[fennol_md] xyz_input/has_comment_line
    Whether file contains comment lines.
    Default: True
    """
    xyzfile = Path(simulation_parameters.get("xyz_input/file", system_name + ".xyz"))
    """@keyword[fennol_md] xyz_input/file
    Path to xyz/arc coordinate file. Required parameter.
    Type: str, Required
    """
    if not xyzfile.exists():
        raise FileNotFoundError(f"xyz file {xyzfile} not found")
    # When no explicit name is given, fall back to the xyz file stem.
    system_name = str(simulation_parameters.get("system_name", xyzfile.stem)).strip()
    symbols, coordinates, _ = last_xyz_frame(
        xyzfile, indexed=indexed, has_comment_line=has_comment_line
    )
    coordinates = coordinates.astype(fprec)
    species = np.array([PERIODIC_TABLE_REV_IDX[s] for s in symbols], dtype=np.int32)
    nat = species.shape[0]

    ## GET MASS
    mass_Da = np.array(ATOMIC_MASSES, dtype=fprec)[species]
    is_hydrogen = species == 1
    deuterate = simulation_parameters.get("deuterate", False)
    """@keyword[fennol_md] deuterate
    Replace hydrogen masses with deuterium masses.
    Default: False
    """
    if deuterate:
        print("# Replacing all hydrogens with deuteriums")
        mass_Da[is_hydrogen] *= 2.0

    totmass_Da = mass_Da.sum()

    mass = mass_Da.copy()
    hmr = simulation_parameters.get("hmr", 0)
    """@keyword[fennol_md] hmr
    Hydrogen mass repartitioning factor. 0 = no repartitioning.
    Default: 0
    """
    if hmr > 0:
        print(f"# Adding {hmr} Da to H masses and repartitioning on others for total mass conservation.")
        is_heavy = ~is_hydrogen
        if is_hydrogen.any() and not is_heavy.any():
            # Without heavy atoms the compensation below divides by zero.
            raise ValueError(
                "hmr requires at least one non-hydrogen atom to take the mass from"
            )
        added_mass = hmr * is_hydrogen.sum()
        mass[is_hydrogen] += hmr
        # Remove the added mass from the other atoms, proportionally to their mass.
        wmass = mass[is_heavy]
        mass[is_heavy] -= added_mass * wmass / wmass.sum()

        assert np.isclose(mass.sum(), totmass_Da), "Mass conservation failed"

    # Unit conversion of the dynamical masses (mass_Da stays in Da).
    mass = mass * (au.MPROT * (au.FS / au.BOHR) ** 2)

    ### GET TEMPERATURE
    temperature = np.clip(simulation_parameters.get("temperature", 300.0), 1.0e-6, None)
    """@keyword[fennol_md] temperature
    Target temperature in Kelvin.
    Default: 300.0
    """
    kT = temperature / au.KELVIN

    ### GET TOTAL CHARGE
    total_charge = simulation_parameters.get("total_charge", None)
    """@keyword[fennol_md] total_charge
    Total system charge for charged systems.
    Default: None (interpreted as 0)
    """
    if total_charge is None:
        total_charge = 0
    else:
        total_charge = int(total_charge)
        print("# total charge: ", total_charge,"e")

    ### ENERGY UNIT
    energy_unit_str = simulation_parameters.get("energy_unit", "kcal/mol")
    """@keyword[fennol_md] energy_unit
    Energy unit for output. Common options: 'kcal/mol', 'eV', 'Ha', 'kJ/mol'.
    Default: "kcal/mol"
    """
    energy_unit = au.get_multiplier(energy_unit_str)

    ## SYSTEM DATA
    system_data = {
        "name": system_name,
        "nat": nat,
        "symbols": symbols,
        "species": species,
        "mass": mass,
        "mass_Da": mass_Da,
        "totmass_Da": totmass_Da,
        "temperature": temperature,
        "kT": kT,
        "total_charge": total_charge,
        "energy_unit": energy_unit,
        "energy_unit_str": energy_unit_str,
    }
    input_flags = simulation_parameters.get("model_flags", [])
    """@keyword[fennol_md] model_flags
    Additional flags to pass to the model.
    Default: []
    """
    if isinstance(input_flags, str):
        # A single string is a whitespace-separated list, not a list of characters.
        input_flags = input_flags.split()
    flags = {f: None for f in input_flags}

    ### Set boundary conditions
    cell = simulation_parameters.get("cell", None)
    """@keyword[fennol_md] cell
    Unit cell vectors. Required for PBC. It is a sequence of floats:
    - 9 floats: components of cell vectors [ax, ay, az, bx, by, bz, cx, cy, cz]
    - 6 floats: lengths and angles [a, b, c, alpha, beta, gamma]
    - 3 floats: lengths of cell vectors [a, b, c] (orthorhombic)
    - 1 float: length of cell vectors (cubic cell)
    Lengths are in Angstroms, angles in degrees.
    Default: None
    """
    if cell is not None:
        cell = parse_cell(cell).astype(fprec)
        reciprocal_cell = np.linalg.inv(cell)
        volume = np.abs(np.linalg.det(cell))
        print("# cell matrix:")
        for row in cell:
            print("# ", row)
        dens = totmass_Da * (au.MPROT*au.GCM3) / volume
        print("# density: ", dens.item(), " g/cm^3")
        minimum_image = simulation_parameters.get("minimum_image", True)
        """@keyword[fennol_md] minimum_image
        Use minimum image convention for neighbor lists in periodic systems.
        Default: True
        """
        estimate_pressure = simulation_parameters.get("estimate_pressure", False)
        """@keyword[fennol_md] estimate_pressure
        Calculate and print pressure during simulation.
        Default: False
        """
        print("# minimum_image: ", minimum_image)

        crystal_input = simulation_parameters.get("xyz_input/crystal", False)
        """@keyword[fennol_md] xyz_input/crystal
        Use crystal coordinates.
        Default: False
        """
        if crystal_input:
            # Input coordinates are expressed in the cell basis.
            coordinates = coordinates @ cell

        pbc_data = {
            "cell": cell,
            "reciprocal_cell": reciprocal_cell,
            "volume": volume,
            "minimum_image": minimum_image,
            "estimate_pressure": estimate_pressure,
        }
        if minimum_image:
            flags["minimum_image"] = None
    else:
        pbc_data = None
    system_data["pbc"] = pbc_data
    system_data["initial_coordinates"] = coordinates.copy()

    ### TOPOLOGY
    topology_key = simulation_parameters.get("topology", None)
    """@keyword[fennol_md] topology
    Topology specification for molecular systems. Use "detect" for automatic detection.
    Default: None
    """
    if topology_key is not None:
        topology_key = str(topology_key).strip()
        if topology_key.lower() == "detect":
            topology = detect_topology(species,coordinates,cell=cell)
            # Indices are written 1-based in the file.
            np.savetxt(system_name +".topo", topology+1, fmt="%d")
            print("# Detected topology saved to", system_name + ".topo")
        else:
            if not Path(topology_key).exists():
                raise FileNotFoundError(f"Topology file {topology_key} not found")
            # ndmin=2 keeps a one-line file as a (1, 2) array instead of (2,).
            topology = np.loadtxt(topology_key, dtype=np.int32, ndmin=2) - 1
            if topology.shape[1] != 2:
                raise ValueError("Topology file must have two columns (source, target)")
            print("# Topology loaded from", topology_key)
    else:
        topology = None

    system_data["topology"] = topology

    ### PIMD
    nbeads = simulation_parameters.get("nbeads", None)
    """@keyword[fennol_md] nbeads
    Number of beads for Path Integral MD.
    Default: None
    """
    nreplicas = simulation_parameters.get("nreplicas", None)
    """@keyword[fennol_md] nreplicas
    Number of replicas for independent replica simulations.
    Default: None
    """
    if nbeads is not None:
        nbeads = int(nbeads)
        if nbeads < 1:
            raise ValueError(f"nbeads must be >= 1, got {nbeads}")
        print("# nbeads: ", nbeads)
        system_data["nbeads"] = nbeads
        coordinates = np.tile(coordinates, (nbeads, 1))
        species = np.tile(species, nbeads)
        bead_index = np.arange(nbeads, dtype=np.int32).repeat(nat)
        natoms = np.array([nat] * nbeads, dtype=np.int32)

        # Cyclic nearest-neighbour coupling matrix between beads. Entries are
        # accumulated so that nbeads == 2 (both neighbours are the same bead)
        # yields off-diagonal -2 and nbeads == 1 yields a zero matrix.
        eigmat = np.zeros((nbeads, nbeads))
        for i in range(nbeads):
            j = (i + 1) % nbeads
            eigmat[i, i] += 2.0
            eigmat[i, j] -= 1.0
            eigmat[j, i] -= 1.0
        omk, eigmat = np.linalg.eigh(eigmat)
        # The lowest mode is set to exactly zero to remove round-off noise.
        omk[0] = 0.0
        omk = nbeads * kT * omk**0.5 / au.FS
        # Flip the rows whose first-column entry is negative.
        eigmat[eigmat[:, 0] < 0] *= -1.0
        eigmat = jnp.asarray(eigmat, dtype=fprec)
        system_data["omk"] = omk
        system_data["eigmat"] = eigmat
        # Beads take precedence over replicas.
        nreplicas = None
        ncopies = nbeads
    elif nreplicas is not None:
        nreplicas = int(nreplicas)
        if nreplicas < 1:
            raise ValueError(f"nreplicas must be >= 1, got {nreplicas}")
        print("# nreplicas: ", nreplicas)
        system_data["nreplicas"] = nreplicas
        system_data["mass"] = np.tile(mass, nreplicas)
        system_data["species"] = np.tile(species, nreplicas)
        coordinates = np.tile(coordinates, (nreplicas, 1))
        species = np.tile(species, nreplicas)
        bead_index = np.arange(nreplicas, dtype=np.int32).repeat(nat)
        natoms = np.array([nat] * nreplicas, dtype=np.int32)
        ncopies = nreplicas
    else:
        system_data["nreplicas"] = 1
        bead_index = np.array([0] * nat, dtype=np.int32)
        natoms = np.array([nat], dtype=np.int32)
        ncopies = None

    conformation = {
        "species": species,
        "coordinates": coordinates,
        "batch_index": bead_index,
        "natoms": natoms,
        "total_charge": total_charge,
    }
    if cell is not None:
        # One cell per bead/replica, with a leading batch axis.
        cell = cell[None, :, :]
        reciprocal_cell = reciprocal_cell[None, :, :]
        if ncopies is not None:
            cell = np.repeat(cell, ncopies, axis=0)
            reciprocal_cell = np.repeat(reciprocal_cell, ncopies, axis=0)
        conformation["cells"] = cell
        conformation["reciprocal_cells"] = reciprocal_cell

    additional_keys = simulation_parameters.get("additional_keys", {})
    """@keyword[fennol_md] additional_keys
    Additional custom keys for model input.
    Default: {}
    """
    for key, value in additional_keys.items():
        conformation[key] = value

    conformation["flags"] = flags

    return system_data, conformation


def initialize_preprocessing(simulation_parameters, model, conformation, system_data):
    """Configure the model preprocessing state and run it once on the input.

    Applies the user neighbor-list settings to every layer state of
    ``model.preproc_state``, then calls ``model.preprocessing`` a first time
    with ``check_input`` enabled and disables it afterwards.

    Args:
        simulation_parameters: Mapping-like object with the simulation keywords.
        model: Model exposing ``preproc_state``, ``preprocessing`` and
            ``summarize``.
        conformation: Model input dict to preprocess.
        system_data: Not used by this function.

    Returns:
        A tuple ``(preproc_state, conformation)`` as returned by the first
        preprocessing call, with ``check_input`` set to False in the state.
    """
    nblist_verbose = simulation_parameters.get("nblist_verbose", False)
    """@keyword[fennol_md] nblist_verbose
    Print detailed neighbor list information.
    Default: False
    """
    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
    """@keyword[fennol_md] nblist_skin
    Neighbor list skin distance in Angstroms.
    Default: -1.0 (automatic)
    """

    # Return a frozen copy of one layer state with the user overrides applied.
    def _with_user_settings(layer):
        updated = unfreeze(layer)
        if nblist_skin > 0:
            updated["nblist_skin"] = nblist_skin
        if "nblist_mult_size" in simulation_parameters:
            updated["nblist_mult_size"] = simulation_parameters["nblist_mult_size"]
        """@keyword[fennol_md] nblist_mult_size
        Multiplier for neighbor list size.
        Default: None
        """
        if "nblist_add_neigh" in simulation_parameters:
            updated["add_neigh"] = simulation_parameters["nblist_add_neigh"]
        """@keyword[fennol_md] nblist_add_neigh
        Additional neighbors to include in lists.
        Default: None
        """
        return freeze(updated)

    ### CONFIGURE PREPROCESSING
    state = unfreeze(model.preproc_state)
    state["layers_state"] = [
        _with_user_settings(layer) for layer in state["layers_state"]
    ]
    state = freeze(state)

    ## initial preprocessing
    state = state.copy({"check_input": True})
    state, conformation = model.preprocessing(state, conformation)

    state = state.copy({"check_input": False})

    if nblist_verbose:
        print("# graphs_keys: ", list(model._graphs_properties.keys()))
        print("# nblist state:", state)

    ### print model
    if simulation_parameters.get("print_model", False):
        """@keyword[fennol_md] print_model
        Print detailed model information at startup.
        Default: False
        """
        print(model.summarize(example_data=conformation))

    return state, conformation
def load_model(simulation_parameters):
    """Load the FENNIX model selected by the simulation parameters.

    Reads the required ``model_file`` keyword and the optional
    ``graph_config`` and ``energy_terms`` keywords.

    Args:
        simulation_parameters: Mapping-like object supporting ``get``,
            ``in`` and item access for the simulation keywords.

    Returns:
        The object returned by ``FENNIX.load``; its energy terms are
        overridden when ``energy_terms`` is present in the parameters.

    Raises:
        FileNotFoundError: If ``model_file`` is not provided or does not
            point to an existing file.
    """
    model_file = simulation_parameters.get("model_file")
    """@keyword[fennol_md] model_file
    Path to the machine learning model file (.fnx format). Required parameter.
    Type: str, Required
    """
    if model_file is None:
        # Fail explicitly instead of searching for a file literally named "None".
        raise FileNotFoundError(
            "model_file is a required parameter but was not provided"
        )
    model_file = Path(str(model_file).strip())
    if not model_file.exists():
        raise FileNotFoundError(f"model file {model_file} not found")

    graph_config = simulation_parameters.get("graph_config", {})
    """@keyword[fennol_md] graph_config
    Advanced graph configuration for model initialization.
    Default: {}
    """
    model = FENNIX.load(model_file, graph_config=graph_config)
    print(f"# model_file: {model_file}")

    if "energy_terms" in simulation_parameters:
        energy_terms = simulation_parameters["energy_terms"]
        # A single string is interpreted as a whitespace-separated list.
        if isinstance(energy_terms, str):
            energy_terms = energy_terms.split()
        model.set_energy_terms(energy_terms)
        print("# energy terms:", model.energy_terms)

    return model
def load_system_data(simulation_parameters, fprec):
    """Read the input structure and assemble the system and model inputs.

    Loads the last frame of the xyz file, builds masses (with optional
    deuteration and hydrogen mass repartitioning), thermodynamic settings,
    periodic boundary data, topology, and replicates the system for
    path-integral beads or independent replicas.

    Args:
        simulation_parameters: Mapping-like object supporting ``get`` for the
            simulation keywords (including ``"xyz_input/..."`` keys).
        fprec: Floating-point dtype applied to coordinates, masses and cell.

    Returns:
        A tuple ``(system_data, conformation)``. ``system_data`` is a dict of
        per-system quantities (name, ``nat``, symbols, species, masses,
        temperature, ``kT``, total charge, energy unit, ``pbc``,
        ``initial_coordinates``, ``topology`` and bead/replica data).
        ``conformation`` is the dict passed to the model (``species``,
        ``coordinates``, ``batch_index``, ``natoms``, ``total_charge``,
        optional ``cells``/``reciprocal_cells``, any ``additional_keys`` and
        ``flags``).

    Raises:
        FileNotFoundError: If the xyz file or the topology file is missing.
        ValueError: If ``hmr`` is requested for a system with hydrogens but
            no heavy atom, if the topology file does not have two columns,
            or if ``nbeads``/``nreplicas`` is smaller than 1.
    """
    ## LOAD SYSTEM CONFORMATION FROM FILES
    system_name = str(simulation_parameters.get("system_name", "system")).strip()
    """@keyword[fennol_md] system_name
    Name prefix for output files. If not specified, uses the xyz filename stem.
    Default: "system"
    """
    indexed = simulation_parameters.get("xyz_input/indexed", False)
    """@keyword[fennol_md] xyz_input/indexed
    Whether first column contains atom indices (Tinker format).
    Default: False
    """
    has_comment_line = simulation_parameters.get("xyz_input/has_comment_line", True)
    """@keyword[fennol_md] xyz_input/has_comment_line
    Whether file contains comment lines.
    Default: True
    """
    xyzfile = Path(simulation_parameters.get("xyz_input/file", system_name + ".xyz"))
    """@keyword[fennol_md] xyz_input/file
    Path to xyz/arc coordinate file. Required parameter.
    Type: str, Required
    """
    if not xyzfile.exists():
        raise FileNotFoundError(f"xyz file {xyzfile} not found")
    # When no explicit name is given, fall back to the xyz file stem.
    system_name = str(simulation_parameters.get("system_name", xyzfile.stem)).strip()
    symbols, coordinates, _ = last_xyz_frame(
        xyzfile, indexed=indexed, has_comment_line=has_comment_line
    )
    coordinates = coordinates.astype(fprec)
    species = np.array([PERIODIC_TABLE_REV_IDX[s] for s in symbols], dtype=np.int32)
    nat = species.shape[0]

    ## GET MASS
    mass_Da = np.array(ATOMIC_MASSES, dtype=fprec)[species]
    is_hydrogen = species == 1
    deuterate = simulation_parameters.get("deuterate", False)
    """@keyword[fennol_md] deuterate
    Replace hydrogen masses with deuterium masses.
    Default: False
    """
    if deuterate:
        print("# Replacing all hydrogens with deuteriums")
        mass_Da[is_hydrogen] *= 2.0

    totmass_Da = mass_Da.sum()

    mass = mass_Da.copy()
    hmr = simulation_parameters.get("hmr", 0)
    """@keyword[fennol_md] hmr
    Hydrogen mass repartitioning factor. 0 = no repartitioning.
    Default: 0
    """
    if hmr > 0:
        print(f"# Adding {hmr} Da to H masses and repartitioning on others for total mass conservation.")
        is_heavy = ~is_hydrogen
        if is_hydrogen.any() and not is_heavy.any():
            # Without heavy atoms the compensation below divides by zero.
            raise ValueError(
                "hmr requires at least one non-hydrogen atom to take the mass from"
            )
        added_mass = hmr * is_hydrogen.sum()
        mass[is_hydrogen] += hmr
        # Remove the added mass from the other atoms, proportionally to their mass.
        wmass = mass[is_heavy]
        mass[is_heavy] -= added_mass * wmass / wmass.sum()

        assert np.isclose(mass.sum(), totmass_Da), "Mass conservation failed"

    # Unit conversion of the dynamical masses (mass_Da stays in Da).
    mass = mass * (au.MPROT * (au.FS / au.BOHR) ** 2)

    ### GET TEMPERATURE
    temperature = np.clip(simulation_parameters.get("temperature", 300.0), 1.0e-6, None)
    """@keyword[fennol_md] temperature
    Target temperature in Kelvin.
    Default: 300.0
    """
    kT = temperature / au.KELVIN

    ### GET TOTAL CHARGE
    total_charge = simulation_parameters.get("total_charge", None)
    """@keyword[fennol_md] total_charge
    Total system charge for charged systems.
    Default: None (interpreted as 0)
    """
    if total_charge is None:
        total_charge = 0
    else:
        total_charge = int(total_charge)
        print("# total charge: ", total_charge,"e")

    ### ENERGY UNIT
    energy_unit_str = simulation_parameters.get("energy_unit", "kcal/mol")
    """@keyword[fennol_md] energy_unit
    Energy unit for output. Common options: 'kcal/mol', 'eV', 'Ha', 'kJ/mol'.
    Default: "kcal/mol"
    """
    energy_unit = au.get_multiplier(energy_unit_str)

    ## SYSTEM DATA
    system_data = {
        "name": system_name,
        "nat": nat,
        "symbols": symbols,
        "species": species,
        "mass": mass,
        "mass_Da": mass_Da,
        "totmass_Da": totmass_Da,
        "temperature": temperature,
        "kT": kT,
        "total_charge": total_charge,
        "energy_unit": energy_unit,
        "energy_unit_str": energy_unit_str,
    }
    input_flags = simulation_parameters.get("model_flags", [])
    """@keyword[fennol_md] model_flags
    Additional flags to pass to the model.
    Default: []
    """
    if isinstance(input_flags, str):
        # A single string is a whitespace-separated list, not a list of characters.
        input_flags = input_flags.split()
    flags = {f: None for f in input_flags}

    ### Set boundary conditions
    cell = simulation_parameters.get("cell", None)
    """@keyword[fennol_md] cell
    Unit cell vectors. Required for PBC. It is a sequence of floats:
    - 9 floats: components of cell vectors [ax, ay, az, bx, by, bz, cx, cy, cz]
    - 6 floats: lengths and angles [a, b, c, alpha, beta, gamma]
    - 3 floats: lengths of cell vectors [a, b, c] (orthorhombic)
    - 1 float: length of cell vectors (cubic cell)
    Lengths are in Angstroms, angles in degrees.
    Default: None
    """
    if cell is not None:
        cell = parse_cell(cell).astype(fprec)
        reciprocal_cell = np.linalg.inv(cell)
        volume = np.abs(np.linalg.det(cell))
        print("# cell matrix:")
        for row in cell:
            print("# ", row)
        dens = totmass_Da * (au.MPROT*au.GCM3) / volume
        print("# density: ", dens.item(), " g/cm^3")
        minimum_image = simulation_parameters.get("minimum_image", True)
        """@keyword[fennol_md] minimum_image
        Use minimum image convention for neighbor lists in periodic systems.
        Default: True
        """
        estimate_pressure = simulation_parameters.get("estimate_pressure", False)
        """@keyword[fennol_md] estimate_pressure
        Calculate and print pressure during simulation.
        Default: False
        """
        print("# minimum_image: ", minimum_image)

        crystal_input = simulation_parameters.get("xyz_input/crystal", False)
        """@keyword[fennol_md] xyz_input/crystal
        Use crystal coordinates.
        Default: False
        """
        if crystal_input:
            # Input coordinates are expressed in the cell basis.
            coordinates = coordinates @ cell

        pbc_data = {
            "cell": cell,
            "reciprocal_cell": reciprocal_cell,
            "volume": volume,
            "minimum_image": minimum_image,
            "estimate_pressure": estimate_pressure,
        }
        if minimum_image:
            flags["minimum_image"] = None
    else:
        pbc_data = None
    system_data["pbc"] = pbc_data
    system_data["initial_coordinates"] = coordinates.copy()

    ### TOPOLOGY
    topology_key = simulation_parameters.get("topology", None)
    """@keyword[fennol_md] topology
    Topology specification for molecular systems. Use "detect" for automatic detection.
    Default: None
    """
    if topology_key is not None:
        topology_key = str(topology_key).strip()
        if topology_key.lower() == "detect":
            topology = detect_topology(species,coordinates,cell=cell)
            # Indices are written 1-based in the file.
            np.savetxt(system_name +".topo", topology+1, fmt="%d")
            print("# Detected topology saved to", system_name + ".topo")
        else:
            if not Path(topology_key).exists():
                raise FileNotFoundError(f"Topology file {topology_key} not found")
            # ndmin=2 keeps a one-line file as a (1, 2) array instead of (2,).
            topology = np.loadtxt(topology_key, dtype=np.int32, ndmin=2) - 1
            if topology.shape[1] != 2:
                raise ValueError("Topology file must have two columns (source, target)")
            print("# Topology loaded from", topology_key)
    else:
        topology = None

    system_data["topology"] = topology

    ### PIMD
    nbeads = simulation_parameters.get("nbeads", None)
    """@keyword[fennol_md] nbeads
    Number of beads for Path Integral MD.
    Default: None
    """
    nreplicas = simulation_parameters.get("nreplicas", None)
    """@keyword[fennol_md] nreplicas
    Number of replicas for independent replica simulations.
    Default: None
    """
    if nbeads is not None:
        nbeads = int(nbeads)
        if nbeads < 1:
            raise ValueError(f"nbeads must be >= 1, got {nbeads}")
        print("# nbeads: ", nbeads)
        system_data["nbeads"] = nbeads
        coordinates = np.tile(coordinates, (nbeads, 1))
        species = np.tile(species, nbeads)
        bead_index = np.arange(nbeads, dtype=np.int32).repeat(nat)
        natoms = np.array([nat] * nbeads, dtype=np.int32)

        # Cyclic nearest-neighbour coupling matrix between beads. Entries are
        # accumulated so that nbeads == 2 (both neighbours are the same bead)
        # yields off-diagonal -2 and nbeads == 1 yields a zero matrix.
        eigmat = np.zeros((nbeads, nbeads))
        for i in range(nbeads):
            j = (i + 1) % nbeads
            eigmat[i, i] += 2.0
            eigmat[i, j] -= 1.0
            eigmat[j, i] -= 1.0
        omk, eigmat = np.linalg.eigh(eigmat)
        # The lowest mode is set to exactly zero to remove round-off noise.
        omk[0] = 0.0
        omk = nbeads * kT * omk**0.5 / au.FS
        # Flip the rows whose first-column entry is negative.
        eigmat[eigmat[:, 0] < 0] *= -1.0
        eigmat = jnp.asarray(eigmat, dtype=fprec)
        system_data["omk"] = omk
        system_data["eigmat"] = eigmat
        # Beads take precedence over replicas.
        nreplicas = None
        ncopies = nbeads
    elif nreplicas is not None:
        nreplicas = int(nreplicas)
        if nreplicas < 1:
            raise ValueError(f"nreplicas must be >= 1, got {nreplicas}")
        print("# nreplicas: ", nreplicas)
        system_data["nreplicas"] = nreplicas
        system_data["mass"] = np.tile(mass, nreplicas)
        system_data["species"] = np.tile(species, nreplicas)
        coordinates = np.tile(coordinates, (nreplicas, 1))
        species = np.tile(species, nreplicas)
        bead_index = np.arange(nreplicas, dtype=np.int32).repeat(nat)
        natoms = np.array([nat] * nreplicas, dtype=np.int32)
        ncopies = nreplicas
    else:
        system_data["nreplicas"] = 1
        bead_index = np.array([0] * nat, dtype=np.int32)
        natoms = np.array([nat], dtype=np.int32)
        ncopies = None

    conformation = {
        "species": species,
        "coordinates": coordinates,
        "batch_index": bead_index,
        "natoms": natoms,
        "total_charge": total_charge,
    }
    if cell is not None:
        # One cell per bead/replica, with a leading batch axis.
        cell = cell[None, :, :]
        reciprocal_cell = reciprocal_cell[None, :, :]
        if ncopies is not None:
            cell = np.repeat(cell, ncopies, axis=0)
            reciprocal_cell = np.repeat(reciprocal_cell, ncopies, axis=0)
        conformation["cells"] = cell
        conformation["reciprocal_cells"] = reciprocal_cell

    additional_keys = simulation_parameters.get("additional_keys", {})
    """@keyword[fennol_md] additional_keys
    Additional custom keys for model input.
    Default: {}
    """
    for key, value in additional_keys.items():
        conformation[key] = value

    conformation["flags"] = flags

    return system_data, conformation
def initialize_preprocessing(simulation_parameters, model, conformation, system_data):
    """Configure the model preprocessing state and run it once on the input.

    Applies the user neighbor-list settings to every layer state of
    ``model.preproc_state``, then calls ``model.preprocessing`` a first time
    with ``check_input`` enabled and disables it afterwards.

    Args:
        simulation_parameters: Mapping-like object with the simulation keywords.
        model: Model exposing ``preproc_state``, ``preprocessing`` and
            ``summarize``.
        conformation: Model input dict to preprocess.
        system_data: Not used by this function.

    Returns:
        A tuple ``(preproc_state, conformation)`` as returned by the first
        preprocessing call, with ``check_input`` set to False in the state.
    """
    nblist_verbose = simulation_parameters.get("nblist_verbose", False)
    """@keyword[fennol_md] nblist_verbose
    Print detailed neighbor list information.
    Default: False
    """
    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
    """@keyword[fennol_md] nblist_skin
    Neighbor list skin distance in Angstroms.
    Default: -1.0 (automatic)
    """

    # Return a frozen copy of one layer state with the user overrides applied.
    def _with_user_settings(layer):
        updated = unfreeze(layer)
        if nblist_skin > 0:
            updated["nblist_skin"] = nblist_skin
        if "nblist_mult_size" in simulation_parameters:
            updated["nblist_mult_size"] = simulation_parameters["nblist_mult_size"]
        """@keyword[fennol_md] nblist_mult_size
        Multiplier for neighbor list size.
        Default: None
        """
        if "nblist_add_neigh" in simulation_parameters:
            updated["add_neigh"] = simulation_parameters["nblist_add_neigh"]
        """@keyword[fennol_md] nblist_add_neigh
        Additional neighbors to include in lists.
        Default: None
        """
        return freeze(updated)

    ### CONFIGURE PREPROCESSING
    state = unfreeze(model.preproc_state)
    state["layers_state"] = [
        _with_user_settings(layer) for layer in state["layers_state"]
    ]
    state = freeze(state)

    ## initial preprocessing
    state = state.copy({"check_input": True})
    state, conformation = model.preprocessing(state, conformation)

    state = state.copy({"check_input": False})

    if nblist_verbose:
        print("# graphs_keys: ", list(model._graphs_properties.keys()))
        print("# nblist state:", state)

    ### print model
    if simulation_parameters.get("print_model", False):
        """@keyword[fennol_md] print_model
        Print detailed model information at startup.
        Default: False
        """
        print(model.summarize(example_data=conformation))

    return state, conformation