# fennol.md.initial

  1from pathlib import Path
  2
  3import numpy as np
  4import jax.numpy as jnp
  5
  6from flax.core import freeze, unfreeze
  7
  8from ..utils.io import last_xyz_frame
  9
 10
 11from ..models import FENNIX
 12
 13from ..utils.periodic_table import PERIODIC_TABLE_REV_IDX, ATOMIC_MASSES
 14from ..utils.atomic_units import AtomicUnits as au
 15from ..utils import detect_topology,parse_cell
 16
 17
 18def load_model(simulation_parameters):
 19    model_file = simulation_parameters.get("model_file")
 20    """@keyword[fennol_md] model_file
 21    Path to the machine learning model file (.fnx format). Required parameter.
 22    Type: str, Required
 23    """
 24    model_file = Path(str(model_file).strip())
 25    if not model_file.exists():
 26        raise FileNotFoundError(f"model file {model_file} not found")
 27    else:
 28        graph_config = simulation_parameters.get("graph_config", {})
 29        """@keyword[fennol_md] graph_config
 30        Advanced graph configuration for model initialization.
 31        Default: {}
 32        """
 33        model = FENNIX.load(model_file, graph_config=graph_config)  # \
 34        print(f"# model_file: {model_file}")
 35
 36    if "energy_terms" in simulation_parameters:
 37        energy_terms = simulation_parameters["energy_terms"]
 38        if isinstance(energy_terms, str):
 39            energy_terms = energy_terms.split()
 40        model.set_energy_terms(energy_terms)
 41        print("# energy terms:", model.energy_terms)
 42
 43    return model
 44
 45
 46def load_system_data(simulation_parameters, fprec):
 47    ## LOAD SYSTEM CONFORMATION FROM FILES
 48    system_name = str(simulation_parameters.get("system_name", "system")).strip()
 49    """@keyword[fennol_md] system_name
 50    Name prefix for output files. If not specified, uses the xyz filename stem.
 51    Default: "system"
 52    """
 53    indexed = simulation_parameters.get("xyz_input/indexed", False)
 54    """@keyword[fennol_md] xyz_input/indexed
 55    Whether first column contains atom indices (Tinker format).
 56    Default: False
 57    """
 58    has_comment_line = simulation_parameters.get("xyz_input/has_comment_line", True)
 59    """@keyword[fennol_md] xyz_input/has_comment_line
 60    Whether file contains comment lines.
 61    Default: True
 62    """
 63    xyzfile = Path(simulation_parameters.get("xyz_input/file", system_name + ".xyz"))
 64    """@keyword[fennol_md] xyz_input/file
 65    Path to xyz/arc coordinate file. Required parameter.
 66    Type: str, Required
 67    """
 68    if not xyzfile.exists():
 69        raise FileNotFoundError(f"xyz file {xyzfile} not found")
 70    system_name = str(simulation_parameters.get("system_name", xyzfile.stem)).strip()
 71    symbols, coordinates, _ = last_xyz_frame(
 72        xyzfile, indexed=indexed, has_comment_line=has_comment_line
 73    )
 74    coordinates = coordinates.astype(fprec)
 75    species = np.array([PERIODIC_TABLE_REV_IDX[s] for s in symbols], dtype=np.int32)
 76    nat = species.shape[0]
 77
 78    ## GET MASS
 79    mass_Da = np.array(ATOMIC_MASSES, dtype=fprec)[species]
 80    deuterate = simulation_parameters.get("deuterate", False)
 81    """@keyword[fennol_md] deuterate
 82    Replace hydrogen masses with deuterium masses.
 83    Default: False
 84    """
 85    if deuterate:
 86        print("# Replacing all hydrogens with deuteriums")
 87        mass_Da[species == 1] *= 2.0
 88
 89    totmass_Da = mass_Da.sum()
 90
 91    mass = mass_Da.copy()
 92    hmr = simulation_parameters.get("hmr", 0)
 93    """@keyword[fennol_md] hmr
 94    Hydrogen mass repartitioning factor. 0 = no repartitioning.
 95    Default: 0
 96    """
 97    if hmr > 0:
 98        print(f"# Adding {hmr} Da to H masses and repartitioning on others for total mass conservation.")
 99        Hmask = species == 1
100        added_mass = hmr * Hmask.sum()
101        mass[Hmask] += hmr
102        wmass = mass[~Hmask]
103        mass[~Hmask] -= added_mass * wmass/ wmass.sum()
104
105        assert np.isclose(mass.sum(), totmass_Da), "Mass conservation failed"
106
107    mass = mass * (au.MPROT * (au.FS / au.BOHR) ** 2)
108
109    ### GET TEMPERATURE
110    temperature = np.clip(simulation_parameters.get("temperature", 300.0), 1.0e-6, None)
111    """@keyword[fennol_md] temperature
112    Target temperature in Kelvin.
113    Default: 300.0
114    """
115    kT = temperature / au.KELVIN
116
117    ### GET TOTAL CHARGE
118    total_charge = simulation_parameters.get("total_charge", None)
119    """@keyword[fennol_md] total_charge
120    Total system charge for charged systems.
121    Default: None (interpreted as 0)
122    """
123    if total_charge is None:
124        total_charge = 0
125    else:
126        total_charge = int(total_charge)
127        print("# total charge: ", total_charge,"e")
128
129    ### ENERGY UNIT
130    energy_unit_str = simulation_parameters.get("energy_unit", "kcal/mol")
131    """@keyword[fennol_md] energy_unit
132    Energy unit for output. Common options: 'kcal/mol', 'eV', 'Ha', 'kJ/mol'.
133    Default: "kcal/mol"
134    """
135    energy_unit = au.get_multiplier(energy_unit_str)
136
137    ## SYSTEM DATA
138    system_data = {
139        "name": system_name,
140        "nat": nat,
141        "symbols": symbols,
142        "species": species,
143        "mass": mass,
144        "mass_Da": mass_Da,
145        "totmass_Da": totmass_Da,
146        "temperature": temperature,
147        "kT": kT,
148        "total_charge": total_charge,
149        "energy_unit": energy_unit,
150        "energy_unit_str": energy_unit_str,
151    }
152    input_flags = simulation_parameters.get("model_flags", [])
153    """@keyword[fennol_md] model_flags
154    Additional flags to pass to the model.
155    Default: []
156    """
157    flags = {f:None for f in input_flags}
158
159    ### Set boundary conditions
160    cell = simulation_parameters.get("cell", None)
161    """@keyword[fennol_md] cell
162    Unit cell vectors. Required for PBC. It is a sequence of floats:
163    - 9 floats: components of cell vectors [ax, ay, az, bx, by, bz, cx, cy, cz]
164    - 6 floats: lengths and angles [a, b, c, alpha, beta, gamma]
165    - 3 floats: lengths of cell vectors [a, b, c] (orthorhombic)
166    - 1 float: length of cell vectors (cubic cell)
167    Lengths are in Angstroms, angles in degrees.
168    Default: None
169    """
170    if cell is not None:
171        cell = parse_cell(cell).astype(fprec)
172        # cell = np.array(cell, dtype=fprec).reshape(3, 3)
173        reciprocal_cell = np.linalg.inv(cell)
174        volume = np.abs(np.linalg.det(cell))
175        print("# cell matrix:")
176        for l in cell:
177            print("# ", l)
178        # print(cell)
179        dens = totmass_Da * (au.MPROT*au.GCM3) / volume
180        print("# density: ", dens.item(), " g/cm^3")
181        minimum_image = simulation_parameters.get("minimum_image", True)
182        """@keyword[fennol_md] minimum_image
183        Use minimum image convention for neighbor lists in periodic systems.
184        Default: True
185        """
186        estimate_pressure = simulation_parameters.get("estimate_pressure", False)
187        """@keyword[fennol_md] estimate_pressure
188        Calculate and print pressure during simulation.
189        Default: False
190        """
191        print("# minimum_image: ", minimum_image)
192
193        crystal_input = simulation_parameters.get("xyz_input/crystal", False)
194        """@keyword[fennol_md] xyz_input/crystal
195        Use crystal coordinates.
196        Default: False
197        """
198        if crystal_input:
199            coordinates = coordinates @ cell
200
201        pbc_data = {
202            "cell": cell,
203            "reciprocal_cell": reciprocal_cell,
204            "volume": volume,
205            "minimum_image": minimum_image,
206            "estimate_pressure": estimate_pressure,
207        }
208        if minimum_image:
209            flags["minimum_image"] = None
210    else:
211        pbc_data = None
212    system_data["pbc"] = pbc_data
213    system_data["initial_coordinates"] = coordinates.copy()
214
215    ### TOPOLOGY
216    topology_key = simulation_parameters.get("topology", None)
217    """@keyword[fennol_md] topology
218    Topology specification for molecular systems. Use "detect" for automatic detection.
219    Default: None
220    """
221    if topology_key is not None:
222        topology_key = str(topology_key).strip()
223        if topology_key.lower() == "detect":
224            topology = detect_topology(species,coordinates,cell=cell)
225            np.savetxt(system_name +".topo", topology+1, fmt="%d")
226            print("# Detected topology saved to", system_name + ".topo")
227        else:
228            assert Path(topology_key).exists(), f"Topology file {topology_key} not found"
229            topology = np.loadtxt(topology_key, dtype=np.int32)-1
230            assert topology.shape[1] == 2, "Topology file must have two columns (source, target)"
231            print("# Topology loaded from", topology_key)
232    else:
233        topology = None
234    
235    system_data["topology"] = topology
236
237    ### PIMD
238    nbeads = simulation_parameters.get("nbeads", None)
239    """@keyword[fennol_md] nbeads
240    Number of beads for Path Integral MD.
241    Default: None
242    """
243    nreplicas = simulation_parameters.get("nreplicas", None)
244    """@keyword[fennol_md] nreplicas
245    Number of replicas for independent replica simulations.
246    Default: None
247    """
248    if nbeads is not None:
249        nbeads = int(nbeads)
250        print("# nbeads: ", nbeads)
251        system_data["nbeads"] = nbeads
252        coordinates = np.repeat(coordinates[None, :, :], nbeads, axis=0).reshape(-1, 3)
253        species = np.repeat(species[None, :], nbeads, axis=0).reshape(-1)
254        bead_index = np.arange(nbeads, dtype=np.int32).repeat(nat)
255        natoms = np.array([nat] * nbeads, dtype=np.int32)
256
257        eigmat = np.zeros((nbeads, nbeads))
258        for i in range(nbeads - 1):
259            eigmat[i, i] = 2.0
260            eigmat[i, i + 1] = -1.0
261            eigmat[i + 1, i] = -1.0
262        eigmat[-1, -1] = 2.0
263        eigmat[0, -1] = -1.0
264        eigmat[-1, 0] = -1.0
265        omk, eigmat = np.linalg.eigh(eigmat)
266        omk[0] = 0.0
267        omk = nbeads * kT * omk**0.5 / au.FS
268        for i in range(nbeads):
269            if eigmat[i, 0] < 0:
270                eigmat[i] *= -1.0
271        eigmat = jnp.asarray(eigmat, dtype=fprec)
272        system_data["omk"] = omk
273        system_data["eigmat"] = eigmat
274        nreplicas = None
275    elif nreplicas is not None:
276        nreplicas = int(nreplicas)
277        print("# nreplicas: ", nreplicas)
278        system_data["nreplicas"] = nreplicas
279        system_data["mass"] = np.repeat(mass[None, :], nreplicas, axis=0).reshape(-1)
280        system_data["species"] = np.repeat(species[None, :], nreplicas, axis=0).reshape(
281            -1
282        )
283        coordinates = np.repeat(coordinates[None, :, :], nreplicas, axis=0).reshape(
284            -1, 3
285        )
286        species = np.repeat(species[None, :], nreplicas, axis=0).reshape(-1)
287        bead_index = np.arange(nreplicas, dtype=np.int32).repeat(nat)
288        natoms = np.array([nat] * nreplicas, dtype=np.int32)
289    else:
290        system_data["nreplicas"] = 1
291        bead_index = np.array([0] * nat, dtype=np.int32)
292        natoms = np.array([nat], dtype=np.int32)
293
294    conformation = {
295        "species": species,
296        "coordinates": coordinates,
297        "batch_index": bead_index,
298        "natoms": natoms,
299        "total_charge": total_charge,
300    }
301    if cell is not None:
302        cell = cell[None, :, :]
303        reciprocal_cell = reciprocal_cell[None, :, :]
304        if nbeads is not None:
305            cell = np.repeat(cell, nbeads, axis=0)
306            reciprocal_cell = np.repeat(reciprocal_cell, nbeads, axis=0)
307        elif nreplicas is not None:
308            cell = np.repeat(cell, nreplicas, axis=0)
309            reciprocal_cell = np.repeat(reciprocal_cell, nreplicas, axis=0)
310        conformation["cells"] = cell
311        conformation["reciprocal_cells"] = reciprocal_cell
312
313    additional_keys = simulation_parameters.get("additional_keys", {})
314    """@keyword[fennol_md] additional_keys
315    Additional custom keys for model input.
316    Default: {}
317    """
318    for key, value in additional_keys.items():
319        conformation[key] = value
320    
321    conformation["flags"] = flags
322
323    return system_data, conformation
324
325
def initialize_preprocessing(simulation_parameters, model, conformation, system_data):
    """Configure the model's neighbour-list preprocessing and run it once.

    The optional ``nblist_skin``, ``nblist_mult_size`` and ``nblist_add_neigh``
    keywords are written into every per-layer preprocessing state. Then
    ``model.preprocessing`` is applied to ``conformation`` with input checking
    enabled, and checking is switched off in the returned state.

    Parameters
    ----------
    simulation_parameters : mapping-like
        Parsed input parameters.
    model : FENNIX
        Model whose ``preproc_state`` is used as the template.
    conformation : dict
        Model input to preprocess.
    system_data : dict
        Not used here; kept for interface compatibility.

    Returns
    -------
    tuple
        ``(preproc_state, conformation)``: the frozen preprocessing state
        (with ``check_input`` set to False) and the preprocessed conformation.
    """
    nblist_verbose = simulation_parameters.get("nblist_verbose", False)
    """@keyword[fennol_md] nblist_verbose
    Print detailed neighbor list information.
    Default: False
    """
    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
    """@keyword[fennol_md] nblist_skin
    Neighbor list skin distance in Angstroms.
    Default: -1.0 (automatic)
    """

    def _configured_layer(layer):
        """Return a frozen copy of one layer state with the user overrides applied."""
        updated = unfreeze(layer)
        # Non-positive skin means "keep the model's own setting".
        if nblist_skin > 0:
            updated["nblist_skin"] = nblist_skin
        if "nblist_mult_size" in simulation_parameters:
            updated["nblist_mult_size"] = simulation_parameters["nblist_mult_size"]
            """@keyword[fennol_md] nblist_mult_size
            Multiplier for neighbor list size.
            Default: None
            """
        if "nblist_add_neigh" in simulation_parameters:
            updated["add_neigh"] = simulation_parameters["nblist_add_neigh"]
            """@keyword[fennol_md] nblist_add_neigh
            Additional neighbors to include in lists.
            Default: None
            """
        return freeze(updated)

    ### CONFIGURE PREPROCESSING
    state = unfreeze(model.preproc_state)
    state["layers_state"] = [_configured_layer(layer) for layer in state["layers_state"]]
    state = freeze(state)

    ## initial preprocessing: validate the input once, then disable checks
    state, conformation = model.preprocessing(
        state.copy({"check_input": True}), conformation
    )
    state = state.copy({"check_input": False})

    if nblist_verbose:
        print("# graphs_keys: ", list(model._graphs_properties.keys()))
        print("# nblist state:", state)

    ### print model
    print_model = simulation_parameters.get("print_model", False)
    """@keyword[fennol_md] print_model
    Print detailed model information at startup.
    Default: False
    """
    if print_model:
        print(model.summarize(example_data=conformation))

    return state, conformation
def load_model(simulation_parameters):
    """Load the FENNIX model selected by the simulation parameters.

    Reads the required ``model_file`` keyword plus the optional
    ``graph_config`` and ``energy_terms`` keywords, loads the model with
    ``FENNIX.load`` and, when ``energy_terms`` is given, restricts the model
    to those terms.

    Parameters
    ----------
    simulation_parameters : mapping-like
        Parsed input parameters (``.get``, ``in`` and ``[]`` are used).

    Returns
    -------
    FENNIX
        The loaded model.

    Raises
    ------
    FileNotFoundError
        If ``model_file`` is missing or blank, or the path does not exist.
    """
    model_file = simulation_parameters.get("model_file")
    """@keyword[fennol_md] model_file
    Path to the machine learning model file (.fnx format). Required parameter.
    Type: str, Required
    """
    model_file_str = "" if model_file is None else str(model_file).strip()
    # Without this check a missing keyword became the path "None", and a blank
    # one became Path("") == Path("."), which exists and was then handed to
    # FENNIX.load as if it were a model file.
    if not model_file_str:
        raise FileNotFoundError(
            "model_file is required but was not provided (or is empty)"
        )
    model_file = Path(model_file_str)
    if not model_file.exists():
        raise FileNotFoundError(f"model file {model_file} not found")

    graph_config = simulation_parameters.get("graph_config", {})
    """@keyword[fennol_md] graph_config
    Advanced graph configuration for model initialization.
    Default: {}
    """
    model = FENNIX.load(model_file, graph_config=graph_config)
    print(f"# model_file: {model_file}")

    if "energy_terms" in simulation_parameters:
        energy_terms = simulation_parameters["energy_terms"]
        # Accept either a whitespace-separated string or a sequence of names.
        if isinstance(energy_terms, str):
            energy_terms = energy_terms.split()
        model.set_energy_terms(energy_terms)
        print("# energy terms:", model.energy_terms)

    return model
def load_system_data(simulation_parameters, fprec):
    """Build the system description and the initial model input.

    Reads the last frame of the xyz/arc input file and derives from the
    simulation parameters: per-atom masses (optionally deuterated and/or with
    hydrogen mass repartitioning), temperature and kT, total charge, energy
    unit, periodic boundary conditions, an optional bond topology, and the
    replication of the system for path-integral MD (``nbeads``) or for
    independent replicas (``nreplicas``).

    Parameters
    ----------
    simulation_parameters : mapping-like
        Parsed input parameters, queried with ``.get``.
    fprec : numpy dtype (or dtype name)
        Floating-point precision for coordinates, masses and cell.

    Returns
    -------
    system_data : dict
        System-level data (name, nat, symbols, species, mass, mass_Da,
        totmass_Da, temperature, kT, total_charge, energy_unit,
        energy_unit_str, pbc, initial_coordinates, topology, and either
        nbeads/omk/eigmat or nreplicas).
    conformation : dict
        Model input: species, coordinates, batch_index, natoms, total_charge,
        optional cells/reciprocal_cells, any ``additional_keys`` and ``flags``.

    Raises
    ------
    FileNotFoundError
        If the xyz file or an explicit topology file does not exist.
    ValueError
        If hydrogen mass repartitioning produces non-positive masses, or a
        topology file does not have exactly two columns.
    """
    ## LOAD SYSTEM CONFORMATION FROM FILES
    # Provisional name: only used to build the default xyz file name below.
    system_name = str(simulation_parameters.get("system_name", "system")).strip()
    """@keyword[fennol_md] system_name
    Name prefix for output files. If not specified, uses the xyz filename stem.
    Default: "system"
    """
    indexed = simulation_parameters.get("xyz_input/indexed", False)
    """@keyword[fennol_md] xyz_input/indexed
    Whether first column contains atom indices (Tinker format).
    Default: False
    """
    has_comment_line = simulation_parameters.get("xyz_input/has_comment_line", True)
    """@keyword[fennol_md] xyz_input/has_comment_line
    Whether file contains comment lines.
    Default: True
    """
    xyzfile = Path(simulation_parameters.get("xyz_input/file", system_name + ".xyz"))
    """@keyword[fennol_md] xyz_input/file
    Path to xyz/arc coordinate file. Required parameter.
    Type: str, Required
    """
    if not xyzfile.exists():
        raise FileNotFoundError(f"xyz file {xyzfile} not found")
    # Final name: falls back to the xyz file stem when system_name is not set.
    system_name = str(simulation_parameters.get("system_name", xyzfile.stem)).strip()
    symbols, coordinates, _ = last_xyz_frame(
        xyzfile, indexed=indexed, has_comment_line=has_comment_line
    )
    coordinates = coordinates.astype(fprec)
    species = np.array([PERIODIC_TABLE_REV_IDX[s] for s in symbols], dtype=np.int32)
    nat = species.shape[0]

    ## GET MASS
    mass_Da = np.array(ATOMIC_MASSES, dtype=fprec)[species]
    deuterate = simulation_parameters.get("deuterate", False)
    """@keyword[fennol_md] deuterate
    Replace hydrogen masses with deuterium masses.
    Default: False
    """
    if deuterate:
        print("# Replacing all hydrogens with deuteriums")
        # Deuterium mass approximated as twice the hydrogen mass.
        mass_Da[species == 1] *= 2.0

    totmass_Da = mass_Da.sum()

    mass = mass_Da.copy()
    hmr = simulation_parameters.get("hmr", 0)
    """@keyword[fennol_md] hmr
    Hydrogen mass repartitioning factor. 0 = no repartitioning.
    Default: 0
    """
    if hmr > 0:
        print(f"# Adding {hmr} Da to H masses and repartitioning on others for total mass conservation.")
        Hmask = species == 1
        added_mass = hmr * Hmask.sum()
        mass[Hmask] += hmr
        # Take the added mass back from the non-hydrogen atoms, in proportion
        # to their own mass, so that the total mass is unchanged.
        wmass = mass[~Hmask]
        mass[~Hmask] -= added_mass * wmass / wmass.sum()

        assert np.isclose(mass.sum(), totmass_Da), "Mass conservation failed"
        # A large hmr can drive heavy-atom masses to zero or below, which
        # would silently break the dynamics.
        if np.any(mass <= 0):
            raise ValueError(
                f"hmr={hmr} Da produces non-positive atomic masses; use a smaller value"
            )

    # Convert masses from Da with the au.MPROT * (au.FS / au.BOHR)**2 factor.
    mass = mass * (au.MPROT * (au.FS / au.BOHR) ** 2)

    ### GET TEMPERATURE
    # Clipped to a tiny positive value so that kT never vanishes.
    temperature = np.clip(simulation_parameters.get("temperature", 300.0), 1.0e-6, None)
    """@keyword[fennol_md] temperature
    Target temperature in Kelvin.
    Default: 300.0
    """
    kT = temperature / au.KELVIN

    ### GET TOTAL CHARGE
    total_charge = simulation_parameters.get("total_charge", None)
    """@keyword[fennol_md] total_charge
    Total system charge for charged systems.
    Default: None (interpreted as 0)
    """
    if total_charge is None:
        total_charge = 0
    else:
        total_charge = int(total_charge)
        print("# total charge: ", total_charge,"e")

    ### ENERGY UNIT
    energy_unit_str = simulation_parameters.get("energy_unit", "kcal/mol")
    """@keyword[fennol_md] energy_unit
    Energy unit for output. Common options: 'kcal/mol', 'eV', 'Ha', 'kJ/mol'.
    Default: "kcal/mol"
    """
    energy_unit = au.get_multiplier(energy_unit_str)

    ## SYSTEM DATA
    system_data = {
        "name": system_name,
        "nat": nat,
        "symbols": symbols,
        "species": species,
        "mass": mass,
        "mass_Da": mass_Da,
        "totmass_Da": totmass_Da,
        "temperature": temperature,
        "kT": kT,
        "total_charge": total_charge,
        "energy_unit": energy_unit,
        "energy_unit_str": energy_unit_str,
    }
    input_flags = simulation_parameters.get("model_flags", [])
    """@keyword[fennol_md] model_flags
    Additional flags to pass to the model.
    Default: []
    """
    # Accept a whitespace-separated string (as load_model does for
    # energy_terms); iterating a plain string would create one flag per char.
    if isinstance(input_flags, str):
        input_flags = input_flags.split()
    flags = {f: None for f in input_flags}

    ### Set boundary conditions
    cell = simulation_parameters.get("cell", None)
    """@keyword[fennol_md] cell
    Unit cell vectors. Required for PBC. It is a sequence of floats:
    - 9 floats: components of cell vectors [ax, ay, az, bx, by, bz, cx, cy, cz]
    - 6 floats: lengths and angles [a, b, c, alpha, beta, gamma]
    - 3 floats: lengths of cell vectors [a, b, c] (orthorhombic)
    - 1 float: length of cell vectors (cubic cell)
    Lengths are in Angstroms, angles in degrees.
    Default: None
    """
    if cell is not None:
        cell = parse_cell(cell).astype(fprec)
        reciprocal_cell = np.linalg.inv(cell)
        volume = np.abs(np.linalg.det(cell))
        print("# cell matrix:")
        for l in cell:
            print("# ", l)
        dens = totmass_Da * (au.MPROT*au.GCM3) / volume
        print("# density: ", dens.item(), " g/cm^3")
        minimum_image = simulation_parameters.get("minimum_image", True)
        """@keyword[fennol_md] minimum_image
        Use minimum image convention for neighbor lists in periodic systems.
        Default: True
        """
        estimate_pressure = simulation_parameters.get("estimate_pressure", False)
        """@keyword[fennol_md] estimate_pressure
        Calculate and print pressure during simulation.
        Default: False
        """
        print("# minimum_image: ", minimum_image)

        crystal_input = simulation_parameters.get("xyz_input/crystal", False)
        """@keyword[fennol_md] xyz_input/crystal
        Use crystal coordinates.
        Default: False
        """
        if crystal_input:
            # Fractional (row-vector) coordinates -> Cartesian via the cell.
            coordinates = coordinates @ cell

        pbc_data = {
            "cell": cell,
            "reciprocal_cell": reciprocal_cell,
            "volume": volume,
            "minimum_image": minimum_image,
            "estimate_pressure": estimate_pressure,
        }
        if minimum_image:
            flags["minimum_image"] = None
    else:
        pbc_data = None
    system_data["pbc"] = pbc_data
    system_data["initial_coordinates"] = coordinates.copy()

    ### TOPOLOGY
    topology_key = simulation_parameters.get("topology", None)
    """@keyword[fennol_md] topology
    Topology specification for molecular systems. Use "detect" for automatic detection.
    Default: None
    """
    if topology_key is not None:
        topology_key = str(topology_key).strip()
        if topology_key.lower() == "detect":
            topology = detect_topology(species,coordinates,cell=cell)
            # Written 1-based; loaded files are converted back to 0-based below.
            np.savetxt(system_name +".topo", topology+1, fmt="%d")
            print("# Detected topology saved to", system_name + ".topo")
        else:
            # Explicit raises instead of asserts: this validates user input and
            # must not disappear under `python -O`.
            if not Path(topology_key).exists():
                raise FileNotFoundError(f"Topology file {topology_key} not found")
            # ndmin=2 keeps a single-bond file as shape (1, 2) instead of (2,).
            topology = np.loadtxt(topology_key, dtype=np.int32, ndmin=2) - 1
            if topology.shape[1] != 2:
                raise ValueError("Topology file must have two columns (source, target)")
            print("# Topology loaded from", topology_key)
    else:
        topology = None

    system_data["topology"] = topology

    ### PIMD
    nbeads = simulation_parameters.get("nbeads", None)
    """@keyword[fennol_md] nbeads
    Number of beads for Path Integral MD.
    Default: None
    """
    nreplicas = simulation_parameters.get("nreplicas", None)
    """@keyword[fennol_md] nreplicas
    Number of replicas for independent replica simulations.
    Default: None
    """
    if nbeads is not None:
        nbeads = int(nbeads)
        print("# nbeads: ", nbeads)
        system_data["nbeads"] = nbeads
        coordinates = np.repeat(coordinates[None, :, :], nbeads, axis=0).reshape(-1, 3)
        species = np.repeat(species[None, :], nbeads, axis=0).reshape(-1)
        bead_index = np.arange(nbeads, dtype=np.int32).repeat(nat)
        natoms = np.array([nat] * nbeads, dtype=np.int32)

        # Cyclic ring-polymer spring matrix: 2 on the diagonal and -1 for each
        # neighbouring bead. Summing rolled identities accumulates both
        # neighbour couplings, which is required for nbeads == 2 (bead 0 is
        # bead 1's neighbour on both sides -> [[2, -2], [-2, 2]]); plain
        # element assignment used to overwrite them to [[2, -1], [-1, 2]].
        identity = np.eye(nbeads)
        eigmat = (
            2.0 * identity
            - np.roll(identity, 1, axis=1)
            - np.roll(identity, -1, axis=1)
        )
        omk, eigmat = np.linalg.eigh(eigmat)
        # The lowest eigenvalue is the zero-frequency centroid mode; force it
        # to exactly 0 so round-off cannot produce sqrt of a negative number.
        omk[0] = 0.0
        omk = nbeads * kT * omk**0.5 / au.FS
        # Sign convention: make the first (centroid) eigenvector positive.
        for i in range(nbeads):
            if eigmat[i, 0] < 0:
                eigmat[i] *= -1.0
        eigmat = jnp.asarray(eigmat, dtype=fprec)
        system_data["omk"] = omk
        system_data["eigmat"] = eigmat
        nreplicas = None
    elif nreplicas is not None:
        nreplicas = int(nreplicas)
        print("# nreplicas: ", nreplicas)
        system_data["nreplicas"] = nreplicas
        system_data["mass"] = np.repeat(mass[None, :], nreplicas, axis=0).reshape(-1)
        system_data["species"] = np.repeat(species[None, :], nreplicas, axis=0).reshape(-1)
        coordinates = np.repeat(coordinates[None, :, :], nreplicas, axis=0).reshape(-1, 3)
        species = np.repeat(species[None, :], nreplicas, axis=0).reshape(-1)
        bead_index = np.arange(nreplicas, dtype=np.int32).repeat(nat)
        natoms = np.array([nat] * nreplicas, dtype=np.int32)
    else:
        system_data["nreplicas"] = 1
        bead_index = np.array([0] * nat, dtype=np.int32)
        natoms = np.array([nat], dtype=np.int32)

    conformation = {
        "species": species,
        "coordinates": coordinates,
        "batch_index": bead_index,
        "natoms": natoms,
        "total_charge": total_charge,
    }
    if cell is not None:
        # One (identical) cell per bead/replica.
        cell = cell[None, :, :]
        reciprocal_cell = reciprocal_cell[None, :, :]
        if nbeads is not None:
            cell = np.repeat(cell, nbeads, axis=0)
            reciprocal_cell = np.repeat(reciprocal_cell, nbeads, axis=0)
        elif nreplicas is not None:
            cell = np.repeat(cell, nreplicas, axis=0)
            reciprocal_cell = np.repeat(reciprocal_cell, nreplicas, axis=0)
        conformation["cells"] = cell
        conformation["reciprocal_cells"] = reciprocal_cell

    additional_keys = simulation_parameters.get("additional_keys", {})
    """@keyword[fennol_md] additional_keys
    Additional custom keys for model input.
    Default: {}
    """
    for key, value in additional_keys.items():
        conformation[key] = value

    conformation["flags"] = flags

    return system_data, conformation
def initialize_preprocessing(simulation_parameters, model, conformation, system_data):
    """Configure the model's neighbour-list preprocessing and run it once.

    The optional ``nblist_skin``, ``nblist_mult_size`` and ``nblist_add_neigh``
    keywords are written into every per-layer preprocessing state. Then
    ``model.preprocessing`` is applied to ``conformation`` with input checking
    enabled, and checking is switched off in the returned state.

    Parameters
    ----------
    simulation_parameters : mapping-like
        Parsed input parameters.
    model : FENNIX
        Model whose ``preproc_state`` is used as the template.
    conformation : dict
        Model input to preprocess.
    system_data : dict
        Not used here; kept for interface compatibility.

    Returns
    -------
    tuple
        ``(preproc_state, conformation)``: the frozen preprocessing state
        (with ``check_input`` set to False) and the preprocessed conformation.
    """
    nblist_verbose = simulation_parameters.get("nblist_verbose", False)
    """@keyword[fennol_md] nblist_verbose
    Print detailed neighbor list information.
    Default: False
    """
    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
    """@keyword[fennol_md] nblist_skin
    Neighbor list skin distance in Angstroms.
    Default: -1.0 (automatic)
    """

    def _configured_layer(layer):
        """Return a frozen copy of one layer state with the user overrides applied."""
        updated = unfreeze(layer)
        # Non-positive skin means "keep the model's own setting".
        if nblist_skin > 0:
            updated["nblist_skin"] = nblist_skin
        if "nblist_mult_size" in simulation_parameters:
            updated["nblist_mult_size"] = simulation_parameters["nblist_mult_size"]
            """@keyword[fennol_md] nblist_mult_size
            Multiplier for neighbor list size.
            Default: None
            """
        if "nblist_add_neigh" in simulation_parameters:
            updated["add_neigh"] = simulation_parameters["nblist_add_neigh"]
            """@keyword[fennol_md] nblist_add_neigh
            Additional neighbors to include in lists.
            Default: None
            """
        return freeze(updated)

    ### CONFIGURE PREPROCESSING
    state = unfreeze(model.preproc_state)
    state["layers_state"] = [_configured_layer(layer) for layer in state["layers_state"]]
    state = freeze(state)

    ## initial preprocessing: validate the input once, then disable checks
    state, conformation = model.preprocessing(
        state.copy({"check_input": True}), conformation
    )
    state = state.copy({"check_input": False})

    if nblist_verbose:
        print("# graphs_keys: ", list(model._graphs_properties.keys()))
        print("# nblist state:", state)

    ### print model
    print_model = simulation_parameters.get("print_model", False)
    """@keyword[fennol_md] print_model
    Print detailed model information at startup.
    Default: False
    """
    if print_model:
        print(model.summarize(example_data=conformation))

    return state, conformation