fennol.md.initial

  1from pathlib import Path
  2
  3import numpy as np
  4import jax.numpy as jnp
  5
  6from flax.core import freeze, unfreeze
  7
  8from ..utils.io import last_xyz_frame
  9
 10
 11from ..models import FENNIX
 12
 13from ..utils.periodic_table import PERIODIC_TABLE_REV_IDX, ATOMIC_MASSES
 14from .utils import us
 15from ..utils import detect_topology,parse_cell
 16
 17
 18def load_model(simulation_parameters):
 19    model_file = simulation_parameters.get("model_file")
 20    """@keyword[fennol_md] model_file
 21    Path to the machine learning model file (.fnx format). Required parameter.
 22    Type: str, Required
 23    """
 24    model_file = Path(str(model_file).strip())
 25    if not model_file.exists():
 26        raise FileNotFoundError(f"model file {model_file} not found")
 27    else:
 28        graph_config = simulation_parameters.get("graph_config", {})
 29        """@keyword[fennol_md] graph_config
 30        Advanced graph configuration for model initialization.
 31        Default: {}
 32        """
 33        model = FENNIX.load(model_file, graph_config=graph_config)  # \
 34        print(f"# model_file: {model_file}")
 35
 36    if "energy_terms" in simulation_parameters:
 37        energy_terms = simulation_parameters["energy_terms"]
 38        if isinstance(energy_terms, str):
 39            energy_terms = energy_terms.split()
 40        model.set_energy_terms(energy_terms)
 41        print("# energy terms:", model.energy_terms)
 42
 43    return model
 44
 45
 46def load_system_data(simulation_parameters, fprec):
 47    ## LOAD SYSTEM CONFORMATION FROM FILES
 48    system_name = str(simulation_parameters.get("system_name", "system")).strip()
 49    """@keyword[fennol_md] system_name
 50    Name prefix for output files. If not specified, uses the xyz filename stem.
 51    Default: "system"
 52    """
 53    indexed = simulation_parameters.get("xyz_input/indexed", False)
 54    """@keyword[fennol_md] xyz_input/indexed
 55    Whether first column contains atom indices (Tinker format).
 56    Default: False
 57    """
 58    has_comment_line = simulation_parameters.get("xyz_input/has_comment_line", True)
 59    """@keyword[fennol_md] xyz_input/has_comment_line
 60    Whether file contains comment lines.
 61    Default: True
 62    """
 63    xyzfile = Path(simulation_parameters.get("xyz_input/file", system_name + ".xyz"))
 64    """@keyword[fennol_md] xyz_input/file
 65    Path to xyz/arc coordinate file. Required parameter.
 66    Type: str, Required
 67    """
 68    if not xyzfile.exists():
 69        raise FileNotFoundError(f"xyz file {xyzfile} not found")
 70    system_name = str(simulation_parameters.get("system_name", xyzfile.stem)).strip()
 71    symbols, coordinates, _ = last_xyz_frame(
 72        xyzfile, indexed=indexed, has_comment_line=has_comment_line
 73    )
 74    coordinates = coordinates.astype(fprec)
 75    species = np.array([PERIODIC_TABLE_REV_IDX[s] for s in symbols], dtype=np.int32)
 76    nat = species.shape[0]
 77
 78    ## GET MASS
 79    mass_Da = np.array(ATOMIC_MASSES, dtype=fprec)[species]
 80    deuterate = simulation_parameters.get("deuterate", False)
 81    """@keyword[fennol_md] deuterate
 82    Replace hydrogen masses with deuterium masses.
 83    Default: False
 84    """
 85    if deuterate:
 86        print("# Replacing all hydrogens with deuteriums")
 87        mass_Da[species == 1] *= 2.0
 88
 89    totmass_Da = mass_Da.sum()
 90
 91    mass = mass_Da.copy()
 92    hmr = simulation_parameters.get("hmr", 0)
 93    """@keyword[fennol_md] hmr
 94    Hydrogen mass repartitioning factor. 0 = no repartitioning.
 95    Default: 0
 96    """
 97    if hmr > 0:
 98        print(f"# Adding {hmr} Da to H masses and repartitioning on others for total mass conservation.")
 99        Hmask = species == 1
100        added_mass = hmr * Hmask.sum()
101        mass[Hmask] += hmr
102        wmass = mass[~Hmask]
103        mass[~Hmask] -= added_mass * wmass/ wmass.sum()
104
105        assert np.isclose(mass.sum(), totmass_Da), "Mass conservation failed"
106
107    # convert to internal units
108    mass = mass / us.DA
109
110    ### GET TEMPERATURE
111    temperature = np.clip(simulation_parameters.get("temperature", 300.0), 1.0e-6, None)
112    """@keyword[fennol_md] temperature
113    Target temperature in Kelvin.
114    Default: 300.0
115    """
116    kT = us.K_B * temperature 
117
118    ### GET TOTAL CHARGE
119    total_charge = simulation_parameters.get("total_charge", None)
120    """@keyword[fennol_md] total_charge
121    Total system charge for charged systems.
122    Default: None (interpreted as 0)
123    """
124    if total_charge is None:
125        total_charge = 0
126    else:
127        total_charge = int(total_charge)
128        print("# total charge: ", total_charge,"e")
129
130    ### ENERGY UNIT
131    energy_unit_str = simulation_parameters.get("energy_unit", "kcal/mol")
132    """@keyword[fennol_md] energy_unit
133    Energy unit for output. Common options: 'kcal/mol', 'eV', 'Ha', 'kJ/mol'.
134    Default: "kcal/mol"
135    """
136    energy_unit = us.get_multiplier(energy_unit_str)
137
138    ## SYSTEM DATA
139    system_data = {
140        "name": system_name,
141        "nat": nat,
142        "symbols": symbols,
143        "species": species,
144        "mass": mass,
145        "mass_Da": mass_Da,
146        "totmass_Da": totmass_Da,
147        "temperature": temperature,
148        "kT": kT,
149        "total_charge": total_charge,
150        "energy_unit": energy_unit,
151        "energy_unit_str": energy_unit_str,
152    }
153    input_flags = simulation_parameters.get("model_flags", [])
154    """@keyword[fennol_md] model_flags
155    Additional flags to pass to the model.
156    Default: []
157    """
158    flags = {f:None for f in input_flags}
159
160    ### Set boundary conditions
161    cell = simulation_parameters.get("cell", None)
162    """@keyword[fennol_md] cell
163    Unit cell vectors. Required for PBC. It is a sequence of floats:
164    - 9 floats: components of cell vectors [ax, ay, az, bx, by, bz, cx, cy, cz]
165    - 6 floats: lengths and angles [a, b, c, alpha, beta, gamma]
166    - 3 floats: lengths of cell vectors [a, b, c] (orthorhombic)
167    - 1 float: length of cell vectors (cubic cell)
168    Lengths are in Angstroms, angles in degrees.
169    Default: None
170    """
171    if cell is not None:
172        cell = parse_cell(cell).astype(fprec)
173        # cell = np.array(cell, dtype=fprec).reshape(3, 3)
174        reciprocal_cell = np.linalg.inv(cell)
175        volume = np.abs(np.linalg.det(cell))
176        print("# cell matrix:")
177        for l in cell:
178            print("# ", l)
179        # print(cell)
180        dens = (totmass_Da/volume) * (us.MOL/us.CM**3)
181        print("# density: ", dens.item(), " g/cm^3")
182        minimum_image = simulation_parameters.get("minimum_image", True)
183        """@keyword[fennol_md] minimum_image
184        Use minimum image convention for neighbor lists in periodic systems.
185        Default: True
186        """
187        estimate_pressure = simulation_parameters.get("estimate_pressure", False)
188        """@keyword[fennol_md] estimate_pressure
189        Calculate and print pressure during simulation.
190        Default: False
191        """
192        print("# minimum_image: ", minimum_image)
193
194        crystal_input = simulation_parameters.get("xyz_input/crystal", False)
195        """@keyword[fennol_md] xyz_input/crystal
196        Use crystal coordinates.
197        Default: False
198        """
199        if crystal_input:
200            coordinates = coordinates @ cell
201
202        pbc_data = {
203            "cell": cell,
204            "reciprocal_cell": reciprocal_cell,
205            "volume": volume,
206            "minimum_image": minimum_image,
207            "estimate_pressure": estimate_pressure,
208        }
209        if minimum_image:
210            flags["minimum_image"] = None
211    else:
212        pbc_data = None
213    system_data["pbc"] = pbc_data
214    system_data["initial_coordinates"] = coordinates.copy()
215
216    ### TOPOLOGY
217    topology_key = simulation_parameters.get("topology", None)
218    """@keyword[fennol_md] topology
219    Topology specification for molecular systems. Use "detect" for automatic detection.
220    Default: None
221    """
222    if topology_key is not None:
223        topology_key = str(topology_key).strip()
224        if topology_key.lower() == "detect":
225            topology = detect_topology(species,coordinates,cell=cell)
226            np.savetxt(system_name +".topo", topology+1, fmt="%d")
227            print("# Detected topology saved to", system_name + ".topo")
228        else:
229            assert Path(topology_key).exists(), f"Topology file {topology_key} not found"
230            topology = np.loadtxt(topology_key, dtype=np.int32)-1
231            assert topology.shape[1] == 2, "Topology file must have two columns (source, target)"
232            print("# Topology loaded from", topology_key)
233    else:
234        topology = None
235    
236    system_data["topology"] = topology
237
238    ### PIMD
239    nbeads = simulation_parameters.get("nbeads", None)
240    """@keyword[fennol_md] nbeads
241    Number of beads for Path Integral MD.
242    Default: None
243    """
244    nreplicas = simulation_parameters.get("nreplicas", None)
245    """@keyword[fennol_md] nreplicas
246    Number of replicas for independent replica simulations.
247    Default: None
248    """
249    if nbeads is not None:
250        nbeads = int(nbeads)
251        print("# nbeads: ", nbeads)
252        system_data["nbeads"] = nbeads
253        coordinates = np.repeat(coordinates[None, :, :], nbeads, axis=0).reshape(-1, 3)
254        species = np.repeat(species[None, :], nbeads, axis=0).reshape(-1)
255        bead_index = np.arange(nbeads, dtype=np.int32).repeat(nat)
256        natoms = np.array([nat] * nbeads, dtype=np.int32)
257
258        eigmat = np.zeros((nbeads, nbeads))
259        for i in range(nbeads - 1):
260            eigmat[i, i] = 2.0
261            eigmat[i, i + 1] = -1.0
262            eigmat[i + 1, i] = -1.0
263        eigmat[-1, -1] = 2.0
264        eigmat[0, -1] = -1.0
265        eigmat[-1, 0] = -1.0
266        omk, eigmat = np.linalg.eigh(eigmat)
267        omk[0] = 0.0
268        omk = (nbeads * kT / us.HBAR) * omk**0.5
269        for i in range(nbeads):
270            if eigmat[i, 0] < 0:
271                eigmat[i] *= -1.0
272        eigmat = jnp.asarray(eigmat, dtype=fprec)
273        system_data["omk"] = omk
274        system_data["eigmat"] = eigmat
275        nreplicas = None
276    elif nreplicas is not None:
277        nreplicas = int(nreplicas)
278        print("# nreplicas: ", nreplicas)
279        system_data["nreplicas"] = nreplicas
280        system_data["mass"] = np.repeat(mass[None, :], nreplicas, axis=0).reshape(-1)
281        system_data["species"] = np.repeat(species[None, :], nreplicas, axis=0).reshape(
282            -1
283        )
284        coordinates = np.repeat(coordinates[None, :, :], nreplicas, axis=0).reshape(
285            -1, 3
286        )
287        species = np.repeat(species[None, :], nreplicas, axis=0).reshape(-1)
288        bead_index = np.arange(nreplicas, dtype=np.int32).repeat(nat)
289        natoms = np.array([nat] * nreplicas, dtype=np.int32)
290    else:
291        system_data["nreplicas"] = 1
292        bead_index = np.array([0] * nat, dtype=np.int32)
293        natoms = np.array([nat], dtype=np.int32)
294
295    conformation = {
296        "species": species,
297        "coordinates": coordinates,
298        "batch_index": bead_index,
299        "natoms": natoms,
300        "total_charge": total_charge,
301    }
302    if cell is not None:
303        cell = cell[None, :, :]
304        reciprocal_cell = reciprocal_cell[None, :, :]
305        if nbeads is not None:
306            cell = np.repeat(cell, nbeads, axis=0)
307            reciprocal_cell = np.repeat(reciprocal_cell, nbeads, axis=0)
308        elif nreplicas is not None:
309            cell = np.repeat(cell, nreplicas, axis=0)
310            reciprocal_cell = np.repeat(reciprocal_cell, nreplicas, axis=0)
311        conformation["cells"] = cell
312        conformation["reciprocal_cells"] = reciprocal_cell
313
314    additional_keys = simulation_parameters.get("additional_keys", {})
315    """@keyword[fennol_md] additional_keys
316    Additional custom keys for model input.
317    Default: {}
318    """
319    for key, value in additional_keys.items():
320        conformation[key] = value
321    
322    conformation["flags"] = flags
323
324    return system_data, conformation
325
326
def initialize_preprocessing(simulation_parameters, model, conformation, system_data):
    """Apply neighbour-list overrides and run the model's preprocessing once.

    The user keywords ``nblist_skin`` (only when > 0), ``nblist_mult_size``
    and ``nblist_add_neigh`` (stored as ``add_neigh``) are written into every
    entry of ``model.preproc_state["layers_state"]``. ``model.preprocessing``
    is then called with ``check_input`` set to True. The returned state has
    ``check_input`` switched back to False.

    Args:
        simulation_parameters: mapping of user keywords.
        model: the loaded model (uses ``preproc_state``, ``preprocessing``,
            ``_graphs_properties`` and ``summarize``).
        conformation: input dict passed to ``model.preprocessing``.
        system_data: not used here; kept for interface compatibility.

    Returns:
        (preproc_state, conformation) as produced by ``model.preprocessing``,
        with ``check_input`` reset to False in the state.
    """
    show_nblist = simulation_parameters.get("nblist_verbose", False)
    """@keyword[fennol_md] nblist_verbose
    Print detailed neighbor list information.
    Default: False
    """
    skin = simulation_parameters.get("nblist_skin", -1.0)
    """@keyword[fennol_md] nblist_skin
    Neighbor list skin distance in Angstroms.
    Default: -1.0 (automatic)
    """

    ### CONFIGURE PREPROCESSING
    # The overrides are the same for every layer, so collect them once.
    overrides = {}
    if skin > 0:
        overrides["nblist_skin"] = skin
    if "nblist_mult_size" in simulation_parameters:
        overrides["nblist_mult_size"] = simulation_parameters["nblist_mult_size"]
        """@keyword[fennol_md] nblist_mult_size
        Multiplier for neighbor list size.
        Default: None
        """
    if "nblist_add_neigh" in simulation_parameters:
        overrides["add_neigh"] = simulation_parameters["nblist_add_neigh"]
        """@keyword[fennol_md] nblist_add_neigh
        Additional neighbors to include in lists.
        Default: None
        """

    def _override_layer(layer):
        # Unfreeze, set the overrides in insertion order, and refreeze.
        editable = unfreeze(layer)
        for key, value in overrides.items():
            editable[key] = value
        return freeze(editable)

    state = unfreeze(model.preproc_state)
    state["layers_state"] = [_override_layer(layer) for layer in state["layers_state"]]
    state = freeze(state)

    ## initial preprocessing
    # Validate inputs only on this first call; later calls skip the checks.
    state = state.copy({"check_input": True})
    state, conformation = model.preprocessing(state, conformation)
    state = state.copy({"check_input": False})

    if show_nblist:
        print("# graphs_keys: ", list(model._graphs_properties.keys()))
        print("# nblist state:", state)

    ### print model
    if simulation_parameters.get("print_model", False):
        """@keyword[fennol_md] print_model
        Print detailed model information at startup.
        Default: False
        """
        print(model.summarize(example_data=conformation))

    return state, conformation
def load_model(simulation_parameters):
    """Load the FENNIX model named by the ``model_file`` keyword.

    Optionally restricts the model to the energy terms listed under
    ``energy_terms``, which may be given as a list or as a whitespace-separated
    string.

    Args:
        simulation_parameters: mapping of user keywords.

    Returns:
        The loaded ``FENNIX`` model.

    Raises:
        FileNotFoundError: if the model file does not exist.
    """
    raw_path = simulation_parameters.get("model_file")
    """@keyword[fennol_md] model_file
    Path to the machine learning model file (.fnx format). Required parameter.
    Type: str, Required
    """
    model_path = Path(str(raw_path).strip())
    # Fail early with a clear message before touching the model loader.
    if not model_path.exists():
        raise FileNotFoundError(f"model file {model_path} not found")

    graph_config = simulation_parameters.get("graph_config", {})
    """@keyword[fennol_md] graph_config
    Advanced graph configuration for model initialization.
    Default: {}
    """
    model = FENNIX.load(model_path, graph_config=graph_config)
    print(f"# model_file: {model_path}")

    if "energy_terms" in simulation_parameters:
        terms = simulation_parameters["energy_terms"]
        # Accept "a b c" as shorthand for ["a", "b", "c"].
        if isinstance(terms, str):
            terms = terms.split()
        model.set_energy_terms(terms)
        print("# energy terms:", model.energy_terms)

    return model
def _ring_polymer_normal_modes(nbeads, omega_p, fprec):
    """Return free ring-polymer normal-mode frequencies and eigenvectors.

    Diagonalizes the cyclic nearest-neighbour coupling matrix of a ring of
    ``nbeads`` beads: 2 on the diagonal and -1 for every bond between bead i
    and bead (i+1) mod nbeads. Entries are accumulated rather than assigned,
    so for ``nbeads == 2``, where both bonds join the same pair, the
    off-diagonal entries correctly become -2.

    Args:
        nbeads: number of beads (>= 1).
        omega_p: frequency scale multiplying the square root of each eigenvalue.
        fprec: floating-point dtype of the returned eigenvector matrix.

    Returns:
        (omk, eigmat): ``omk`` is a numpy array of mode frequencies in ascending
        eigenvalue order with ``omk[0]`` forced to 0. ``eigmat`` is a jax array
        whose columns are the eigenvectors from ``np.linalg.eigh``. Rows are
        sign-flipped so that the first column is non-negative.
    """
    coupling = 2.0 * np.eye(nbeads)
    for i in range(nbeads):
        j = (i + 1) % nbeads
        coupling[i, j] -= 1.0
        coupling[j, i] -= 1.0
    omk, eigmat = np.linalg.eigh(coupling)
    # The lowest eigenvalue is zero in exact arithmetic. Clamp it so rounding
    # noise cannot produce the square root of a negative number.
    omk[0] = 0.0
    omk = omega_p * omk**0.5
    # Make the first eigenvector column non-negative.
    flip = eigmat[:, 0] < 0
    eigmat[flip] *= -1.0
    return omk, jnp.asarray(eigmat, dtype=fprec)


def load_system_data(simulation_parameters, fprec):
    """Read the initial structure and build the system and conformation dicts.

    Loads the last frame of the xyz/arc file, derives masses (with optional
    deuteration and hydrogen mass repartitioning), temperature, charge, energy
    unit, periodic boundary data, topology, and PIMD/replica replication.

    Args:
        simulation_parameters: mapping of user keywords.
        fprec: floating-point dtype used for coordinates, masses and cell.

    Returns:
        (system_data, conformation):

        - ``system_data`` holds per-system metadata: name, nat, symbols,
          species, mass (converted with ``us.DA``), mass_Da, totmass_Da,
          temperature, kT, total_charge, energy_unit(_str), pbc,
          initial_coordinates and topology. It also holds nbeads/omk/eigmat
          for PIMD, or nreplicas otherwise.
        - ``conformation`` is the model input: species, coordinates,
          batch_index, natoms, total_charge, optional cells/reciprocal_cells,
          any additional_keys, and flags.

    Raises:
        FileNotFoundError: if the xyz file or a topology file is missing.
        ValueError: if a topology file does not have exactly two columns.
    """
    ## LOAD SYSTEM CONFORMATION FROM FILES
    system_name = str(simulation_parameters.get("system_name", "system")).strip()
    """@keyword[fennol_md] system_name
    Name prefix for output files. If not specified, uses the xyz filename stem.
    Default: "system"
    """
    indexed = simulation_parameters.get("xyz_input/indexed", False)
    """@keyword[fennol_md] xyz_input/indexed
    Whether first column contains atom indices (Tinker format).
    Default: False
    """
    has_comment_line = simulation_parameters.get("xyz_input/has_comment_line", True)
    """@keyword[fennol_md] xyz_input/has_comment_line
    Whether file contains comment lines.
    Default: True
    """
    xyzfile = Path(simulation_parameters.get("xyz_input/file", system_name + ".xyz"))
    """@keyword[fennol_md] xyz_input/file
    Path to xyz/arc coordinate file. Required parameter.
    Type: str, Required
    """
    if not xyzfile.exists():
        raise FileNotFoundError(f"xyz file {xyzfile} not found")
    # Recompute the name: without an explicit system_name, use the file stem.
    system_name = str(simulation_parameters.get("system_name", xyzfile.stem)).strip()
    symbols, coordinates, _ = last_xyz_frame(
        xyzfile, indexed=indexed, has_comment_line=has_comment_line
    )
    coordinates = coordinates.astype(fprec)
    species = np.array([PERIODIC_TABLE_REV_IDX[s] for s in symbols], dtype=np.int32)
    nat = species.shape[0]

    ## GET MASS
    mass_Da = np.array(ATOMIC_MASSES, dtype=fprec)[species]
    deuterate = simulation_parameters.get("deuterate", False)
    """@keyword[fennol_md] deuterate
    Replace hydrogen masses with deuterium masses.
    Default: False
    """
    if deuterate:
        print("# Replacing all hydrogens with deuteriums")
        mass_Da[species == 1] *= 2.0

    totmass_Da = mass_Da.sum()

    mass = mass_Da.copy()
    hmr = simulation_parameters.get("hmr", 0)
    """@keyword[fennol_md] hmr
    Hydrogen mass repartitioning factor. 0 = no repartitioning.
    Default: 0
    """
    if hmr > 0:
        print(f"# Adding {hmr} Da to H masses and repartitioning on others for total mass conservation.")
        Hmask = species == 1
        added_mass = hmr * Hmask.sum()
        mass[Hmask] += hmr
        # Remove the added mass from non-hydrogen atoms in proportion to their mass.
        wmass = mass[~Hmask]
        mass[~Hmask] -= added_mass * wmass / wmass.sum()

        assert np.isclose(mass.sum(), totmass_Da), "Mass conservation failed"

    # convert to internal units
    mass = mass / us.DA

    ### GET TEMPERATURE
    temperature = np.clip(simulation_parameters.get("temperature", 300.0), 1.0e-6, None)
    """@keyword[fennol_md] temperature
    Target temperature in Kelvin.
    Default: 300.0
    """
    kT = us.K_B * temperature

    ### GET TOTAL CHARGE
    total_charge = simulation_parameters.get("total_charge", None)
    """@keyword[fennol_md] total_charge
    Total system charge for charged systems.
    Default: None (interpreted as 0)
    """
    if total_charge is None:
        total_charge = 0
    else:
        total_charge = int(total_charge)
        print("# total charge: ", total_charge,"e")

    ### ENERGY UNIT
    energy_unit_str = simulation_parameters.get("energy_unit", "kcal/mol")
    """@keyword[fennol_md] energy_unit
    Energy unit for output. Common options: 'kcal/mol', 'eV', 'Ha', 'kJ/mol'.
    Default: "kcal/mol"
    """
    energy_unit = us.get_multiplier(energy_unit_str)

    ## SYSTEM DATA
    system_data = {
        "name": system_name,
        "nat": nat,
        "symbols": symbols,
        "species": species,
        "mass": mass,
        "mass_Da": mass_Da,
        "totmass_Da": totmass_Da,
        "temperature": temperature,
        "kT": kT,
        "total_charge": total_charge,
        "energy_unit": energy_unit,
        "energy_unit_str": energy_unit_str,
    }
    input_flags = simulation_parameters.get("model_flags", [])
    """@keyword[fennol_md] model_flags
    Additional flags to pass to the model.
    A whitespace-separated string is also accepted.
    Default: []
    """
    if isinstance(input_flags, str):
        # Without this, iterating a string would create one flag per character.
        input_flags = input_flags.split()
    flags = {f: None for f in input_flags}

    ### Set boundary conditions
    cell = simulation_parameters.get("cell", None)
    """@keyword[fennol_md] cell
    Unit cell vectors. Required for PBC. It is a sequence of floats:
    - 9 floats: components of cell vectors [ax, ay, az, bx, by, bz, cx, cy, cz]
    - 6 floats: lengths and angles [a, b, c, alpha, beta, gamma]
    - 3 floats: lengths of cell vectors [a, b, c] (orthorhombic)
    - 1 float: length of cell vectors (cubic cell)
    Lengths are in Angstroms, angles in degrees.
    Default: None
    """
    if cell is not None:
        cell = parse_cell(cell).astype(fprec)
        reciprocal_cell = np.linalg.inv(cell)
        volume = np.abs(np.linalg.det(cell))
        print("# cell matrix:")
        for row in cell:
            print("# ", row)
        dens = (totmass_Da/volume) * (us.MOL/us.CM**3)
        print("# density: ", dens.item(), " g/cm^3")
        minimum_image = simulation_parameters.get("minimum_image", True)
        """@keyword[fennol_md] minimum_image
        Use minimum image convention for neighbor lists in periodic systems.
        Default: True
        """
        estimate_pressure = simulation_parameters.get("estimate_pressure", False)
        """@keyword[fennol_md] estimate_pressure
        Calculate and print pressure during simulation.
        Default: False
        """
        print("# minimum_image: ", minimum_image)

        crystal_input = simulation_parameters.get("xyz_input/crystal", False)
        """@keyword[fennol_md] xyz_input/crystal
        Use crystal coordinates.
        Default: False
        """
        if crystal_input:
            # Fractional coordinates times cell (rows = lattice vectors).
            coordinates = coordinates @ cell

        pbc_data = {
            "cell": cell,
            "reciprocal_cell": reciprocal_cell,
            "volume": volume,
            "minimum_image": minimum_image,
            "estimate_pressure": estimate_pressure,
        }
        if minimum_image:
            flags["minimum_image"] = None
    else:
        pbc_data = None
    system_data["pbc"] = pbc_data
    system_data["initial_coordinates"] = coordinates.copy()

    ### TOPOLOGY
    topology_key = simulation_parameters.get("topology", None)
    """@keyword[fennol_md] topology
    Topology specification for molecular systems. Use "detect" for automatic detection.
    Default: None
    """
    if topology_key is not None:
        topology_key = str(topology_key).strip()
        if topology_key.lower() == "detect":
            topology = detect_topology(species,coordinates,cell=cell)
            # Saved 1-based; loading below subtracts 1 again.
            np.savetxt(system_name +".topo", topology+1, fmt="%d")
            print("# Detected topology saved to", system_name + ".topo")
        else:
            if not Path(topology_key).exists():
                raise FileNotFoundError(f"Topology file {topology_key} not found")
            # ndmin=2 keeps a file with a single bond as shape (1, 2) instead
            # of (2,), so the column check below is valid.
            topology = np.loadtxt(topology_key, dtype=np.int32, ndmin=2) - 1
            if topology.shape[1] != 2:
                raise ValueError("Topology file must have two columns (source, target)")
            print("# Topology loaded from", topology_key)
    else:
        topology = None

    system_data["topology"] = topology

    ### PIMD
    nbeads = simulation_parameters.get("nbeads", None)
    """@keyword[fennol_md] nbeads
    Number of beads for Path Integral MD.
    Default: None
    """
    nreplicas = simulation_parameters.get("nreplicas", None)
    """@keyword[fennol_md] nreplicas
    Number of replicas for independent replica simulations.
    Default: None
    """
    if nbeads is not None:
        nbeads = int(nbeads)
        print("# nbeads: ", nbeads)
        system_data["nbeads"] = nbeads
        coordinates = np.repeat(coordinates[None, :, :], nbeads, axis=0).reshape(-1, 3)
        species = np.repeat(species[None, :], nbeads, axis=0).reshape(-1)
        bead_index = np.arange(nbeads, dtype=np.int32).repeat(nat)
        natoms = np.array([nat] * nbeads, dtype=np.int32)

        omk, eigmat = _ring_polymer_normal_modes(nbeads, nbeads * kT / us.HBAR, fprec)
        system_data["omk"] = omk
        system_data["eigmat"] = eigmat
        # nbeads takes precedence: replicas are ignored in PIMD mode.
        nreplicas = None
    elif nreplicas is not None:
        nreplicas = int(nreplicas)
        print("# nreplicas: ", nreplicas)
        system_data["nreplicas"] = nreplicas
        system_data["mass"] = np.repeat(mass[None, :], nreplicas, axis=0).reshape(-1)
        system_data["species"] = np.repeat(species[None, :], nreplicas, axis=0).reshape(
            -1
        )
        coordinates = np.repeat(coordinates[None, :, :], nreplicas, axis=0).reshape(
            -1, 3
        )
        species = np.repeat(species[None, :], nreplicas, axis=0).reshape(-1)
        bead_index = np.arange(nreplicas, dtype=np.int32).repeat(nat)
        natoms = np.array([nat] * nreplicas, dtype=np.int32)
    else:
        system_data["nreplicas"] = 1
        bead_index = np.array([0] * nat, dtype=np.int32)
        natoms = np.array([nat], dtype=np.int32)

    conformation = {
        "species": species,
        "coordinates": coordinates,
        "batch_index": bead_index,
        "natoms": natoms,
        "total_charge": total_charge,
    }
    if cell is not None:
        # One (identical) cell per bead/replica.
        cell = cell[None, :, :]
        reciprocal_cell = reciprocal_cell[None, :, :]
        if nbeads is not None:
            cell = np.repeat(cell, nbeads, axis=0)
            reciprocal_cell = np.repeat(reciprocal_cell, nbeads, axis=0)
        elif nreplicas is not None:
            cell = np.repeat(cell, nreplicas, axis=0)
            reciprocal_cell = np.repeat(reciprocal_cell, nreplicas, axis=0)
        conformation["cells"] = cell
        conformation["reciprocal_cells"] = reciprocal_cell

    additional_keys = simulation_parameters.get("additional_keys", {})
    """@keyword[fennol_md] additional_keys
    Additional custom keys for model input.
    Default: {}
    """
    for key, value in additional_keys.items():
        conformation[key] = value

    conformation["flags"] = flags

    return system_data, conformation
def initialize_preprocessing(simulation_parameters, model, conformation, system_data):
    """Apply neighbour-list overrides and run the model's preprocessing once.

    The user keywords ``nblist_skin`` (only when > 0), ``nblist_mult_size``
    and ``nblist_add_neigh`` (stored as ``add_neigh``) are written into every
    entry of ``model.preproc_state["layers_state"]``. ``model.preprocessing``
    is then called with ``check_input`` set to True. The returned state has
    ``check_input`` switched back to False.

    Args:
        simulation_parameters: mapping of user keywords.
        model: the loaded model (uses ``preproc_state``, ``preprocessing``,
            ``_graphs_properties`` and ``summarize``).
        conformation: input dict passed to ``model.preprocessing``.
        system_data: not used here; kept for interface compatibility.

    Returns:
        (preproc_state, conformation) as produced by ``model.preprocessing``,
        with ``check_input`` reset to False in the state.
    """
    show_nblist = simulation_parameters.get("nblist_verbose", False)
    """@keyword[fennol_md] nblist_verbose
    Print detailed neighbor list information.
    Default: False
    """
    skin = simulation_parameters.get("nblist_skin", -1.0)
    """@keyword[fennol_md] nblist_skin
    Neighbor list skin distance in Angstroms.
    Default: -1.0 (automatic)
    """

    ### CONFIGURE PREPROCESSING
    # The overrides are the same for every layer, so collect them once.
    overrides = {}
    if skin > 0:
        overrides["nblist_skin"] = skin
    if "nblist_mult_size" in simulation_parameters:
        overrides["nblist_mult_size"] = simulation_parameters["nblist_mult_size"]
        """@keyword[fennol_md] nblist_mult_size
        Multiplier for neighbor list size.
        Default: None
        """
    if "nblist_add_neigh" in simulation_parameters:
        overrides["add_neigh"] = simulation_parameters["nblist_add_neigh"]
        """@keyword[fennol_md] nblist_add_neigh
        Additional neighbors to include in lists.
        Default: None
        """

    def _override_layer(layer):
        # Unfreeze, set the overrides in insertion order, and refreeze.
        editable = unfreeze(layer)
        for key, value in overrides.items():
            editable[key] = value
        return freeze(editable)

    state = unfreeze(model.preproc_state)
    state["layers_state"] = [_override_layer(layer) for layer in state["layers_state"]]
    state = freeze(state)

    ## initial preprocessing
    # Validate inputs only on this first call; later calls skip the checks.
    state = state.copy({"check_input": True})
    state, conformation = model.preprocessing(state, conformation)
    state = state.copy({"check_input": False})

    if show_nblist:
        print("# graphs_keys: ", list(model._graphs_properties.keys()))
        print("# nblist state:", state)

    ### print model
    if simulation_parameters.get("print_model", False):
        """@keyword[fennol_md] print_model
        Print detailed model information at startup.
        Default: False
        """
        print(model.summarize(example_data=conformation))

    return state, conformation