fennol.md.initial

  1import sys, os, io
  2import argparse
  3import time
  4from pathlib import Path
  5import math
  6
  7import numpy as np
  8from typing import Optional, Callable
  9from collections import defaultdict
 10from functools import partial
 11import jax
 12import jax.numpy as jnp
 13
 14from flax.core import freeze, unfreeze
 15
 16from ..utils.io import last_xyz_frame
 17
 18
 19from ..models import FENNIX
 20
 21from ..utils.periodic_table import PERIODIC_TABLE_REV_IDX, ATOMIC_MASSES
 22from ..utils.atomic_units import AtomicUnits as au
 23from ..utils.input_parser import parse_input
 24from .thermostats import get_thermostat
 25
 26from copy import deepcopy
 27
 28
 29def load_model(simulation_parameters):
 30    model_file = simulation_parameters.get("model_file")
 31    model_file = Path(str(model_file).strip())
 32    if not model_file.exists():
 33        raise FileNotFoundError(f"model file {model_file} not found")
 34    else:
 35        graph_config = simulation_parameters.get("graph_config", {})
 36        model = FENNIX.load(model_file, graph_config=graph_config)  # \
 37        print(f"# model_file: {model_file}")
 38
 39    if "energy_terms" in simulation_parameters:
 40        energy_terms = simulation_parameters["energy_terms"]
 41        if isinstance(energy_terms, str):
 42            energy_terms = energy_terms.split()
 43        model.set_energy_terms(energy_terms)
 44        print("# energy terms:", model.energy_terms)
 45
 46    return model
 47
 48
 49def load_system_data(simulation_parameters, fprec):
 50    ## LOAD SYSTEM CONFORMATION FROM FILES
 51    system_name = str(simulation_parameters.get("system", "system")).strip()
 52    indexed = simulation_parameters.get("xyz_input/indexed", True)
 53    has_comment_line = simulation_parameters.get("xyz_input/has_comment_line", False)
 54    xyzfile = Path(simulation_parameters.get("xyz_input/file", system_name + ".xyz"))
 55    if not xyzfile.exists():
 56        raise FileNotFoundError(f"xyz file {xyzfile} not found")
 57    system_name = str(simulation_parameters.get("system", xyzfile.stem)).strip()
 58    symbols, coordinates, _ = last_xyz_frame(
 59        xyzfile, indexed=indexed, has_comment_line=has_comment_line
 60    )
 61    coordinates = coordinates.astype(fprec)
 62    species = np.array([PERIODIC_TABLE_REV_IDX[s] for s in symbols], dtype=np.int32)
 63    nat = species.shape[0]
 64
 65    ## GET MASS
 66    mass_amu = np.array(ATOMIC_MASSES, dtype=fprec)[species]
 67    deuterate = simulation_parameters.get("deuterate", False)
 68    if deuterate:
 69        print("# Replacing all hydrogens with deuteriums")
 70        mass_amu[species == 1] *= 2.0
 71    mass = mass_amu * (au.MPROT * (au.FS / au.BOHR) ** 2)
 72
 73    ### GET TEMPERATURE
 74    temperature = np.clip(simulation_parameters.get("temperature", 300.0), 1.0e-6, None)
 75    kT = temperature / au.KELVIN
 76    totmass_amu = mass_amu.sum()/6.02214129e-1
 77
 78    ## SYSTEM DATA
 79    system_data = {
 80        "name": system_name,
 81        "nat": nat,
 82        "symbols": symbols,
 83        "species": species,
 84        "mass": mass,
 85        "temperature": temperature,
 86        "kT": kT,
 87        "totmass_amu": totmass_amu,
 88    }
 89
 90    ### Set boundary conditions
 91    cell = simulation_parameters.get("cell", None)
 92    if cell is not None:
 93        cell = np.array(cell, dtype=fprec).reshape(3, 3)
 94        reciprocal_cell = np.linalg.inv(cell)
 95        volume = np.abs(np.linalg.det(cell))
 96        print("# cell matrix:")
 97        for l in cell:
 98            print("# ", l)
 99        # print(cell)
100        dens = totmass_amu  / volume
101        print("# density: ", dens.item(), " g/cm^3")
102        minimum_image = simulation_parameters.get("minimum_image", True)
103        estimate_pressure = simulation_parameters.get("estimate_pressure", False)
104        print("# minimum_image: ", minimum_image)
105
106        crystal_input = simulation_parameters.get("xyz_input/crystal", False)
107        if crystal_input:
108            coordinates = coordinates @ cell
109
110        pbc_data = {
111            "cell": cell,
112            "reciprocal_cell": reciprocal_cell,
113            "volume": volume,
114            "minimum_image": minimum_image,
115            "estimate_pressure": estimate_pressure,
116        }
117    else:
118        pbc_data = None
119    system_data["pbc"] = pbc_data
120
121    ### PIMD
122    nbeads = simulation_parameters.get("nbeads", None)
123    if nbeads is not None:
124        nbeads = int(nbeads)
125        print("# nbeads: ", nbeads)
126        system_data["nbeads"] = nbeads
127        coordinates = np.repeat(coordinates[None, :, :], nbeads, axis=0).reshape(-1,3)
128        species = np.repeat(species[None, :], nbeads, axis=0).reshape(-1)
129        bead_index = np.arange(nbeads, dtype=np.int32).repeat(nat)
130        natoms = np.array([nat] * nbeads, dtype=np.int32)
131
132        eigmat = np.zeros((nbeads, nbeads))
133        for i in range(nbeads - 1):
134            eigmat[i, i] = 2.0
135            eigmat[i, i + 1] = -1.0
136            eigmat[i + 1, i] = -1.0
137        eigmat[-1, -1] = 2.0
138        eigmat[0, -1] = -1.0
139        eigmat[-1, 0] = -1.0
140        omk, eigmat = np.linalg.eigh(eigmat)
141        omk[0] = 0.0
142        omk = nbeads * kT * omk**0.5 / au.FS
143        for i in range(nbeads):
144            if eigmat[i, 0] < 0:
145                eigmat[i] *= -1.0
146        eigmat = jnp.asarray(eigmat, dtype=fprec)
147        system_data["omk"] = omk
148        system_data["eigmat"] = eigmat
149    else:
150        bead_index = np.array([0] * nat, dtype=np.int32)
151        natoms = np.array([nat], dtype=np.int32)
152
153    conformation = {
154        "species": species,
155        "coordinates": coordinates,
156        "batch_index": bead_index,
157        "natoms": natoms,
158    }
159    if cell is not None:
160        cell = cell[None, :, :]
161        reciprocal_cell = reciprocal_cell[None, :, :]
162        if nbeads is not None:
163            cell = np.repeat(cell, nbeads, axis=0)
164            reciprocal_cell = np.repeat(reciprocal_cell, nbeads, axis=0)
165        conformation["cells"] = cell
166        conformation["reciprocal_cells"] = reciprocal_cell
167    
168    additional_keys = simulation_parameters.get("additional_keys", {})
169    for key, value in additional_keys.items():
170        conformation[key] = value
171
172    return system_data, conformation
173
174
def initialize_preprocessing(simulation_parameters, model, conformation, system_data):
    """Apply user neighbor-list settings to the model preprocessing and run it once.

    Every entry of ``model.preproc_state["layers_state"]`` receives the overrides
    found in the input: ``minimum_image`` (from the pbc data, when present),
    ``nblist_skin`` (when strictly positive), ``nblist_mult_size`` and
    ``nblist_add_neigh`` (stored under ``add_neigh``). The preprocessing is then
    executed on ``conformation`` with ``check_input`` enabled, and
    ``check_input`` is switched off in the returned state.

    Returns:
        (preproc_state, conformation) as produced by ``model.preprocessing``,
        with ``check_input`` set to False in the state.
    """
    nblist_verbose = simulation_parameters.get("nblist_verbose", False)
    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
    pbc_data = system_data.get("pbc", None)

    def _apply_overrides(frozen_layer):
        # Return a frozen copy of one layer state with the user overrides applied.
        layer = unfreeze(frozen_layer)
        if pbc_data is not None:
            layer["minimum_image"] = pbc_data["minimum_image"]
        if nblist_skin > 0:
            layer["nblist_skin"] = nblist_skin
        if "nblist_mult_size" in simulation_parameters:
            layer["nblist_mult_size"] = simulation_parameters["nblist_mult_size"]
        if "nblist_add_neigh" in simulation_parameters:
            layer["add_neigh"] = simulation_parameters["nblist_add_neigh"]
        return freeze(layer)

    ### CONFIGURE PREPROCESSING
    state = unfreeze(model.preproc_state)
    state["layers_state"] = [_apply_overrides(layer) for layer in state["layers_state"]]
    state = freeze(state)

    ## initial preprocessing: run once with input checking enabled, then turn
    ## the check off (presumably so later calls skip input validation)
    state, conformation = model.preprocessing(
        state.copy({"check_input": True}), conformation
    )
    state = state.copy({"check_input": False})

    if nblist_verbose:
        print("# graphs_keys: ", list(model._graphs_properties.keys()))
        print("# nblist state:", state)

    ### print model
    if simulation_parameters.get("print_model", False):
        print(model.summarize(example_data=conformation))

    return state, conformation
213
214
def initialize_system(conformation, vel, model, system_data, fprec):
    """Evaluate the model on the starting conformation and assemble the dynamic state.

    Args:
        conformation: preprocessed model input; must hold "coordinates" and,
            for periodic systems, "cells".
        vel: initial velocities. For PIMD runs (``"nbeads"`` in system_data)
            only ``vel[0]`` enters the kinetic-energy tensor.
        model: model exposing ``_energy_and_forces_and_virial``, ``variables``
            and ``Ha_to_model_energy``.
        system_data: static data (needs "mass"; "nbeads"/"eigmat" for PIMD).
        fprec: floating-point dtype for the stored velocities.

    Returns:
        dict with ek_tensor, ek, epot, vel, coordinates, forces and, when
        "cells" is in the conformation, virial and cell. Energies, forces and
        virial are divided by ``model.Ha_to_model_energy``.
    """
    print("# Computing initial energy and forces")
    energies, forces, virial, _ = model._energy_and_forces_and_virial(
        model.variables, conformation
    )
    unit = model.Ha_to_model_energy
    forces = np.array(forces) / unit
    epot = np.mean(energies) / unit
    virial = np.mean(virial, axis=0) / unit

    is_pimd = "nbeads" in system_data
    # velocities entering the kinetic-energy tensor: first slice for PIMD runs
    v = vel[0] if is_pimd else vel
    mass = system_data["mass"]
    ek = 0.5 * jnp.sum(mass[:, None, None] * v[:, :, None] * v[:, None, :], axis=0)

    ## build system (insertion order kept: energies, velocities, cell, positions)
    system = {
        "ek_tensor": ek,
        "ek": jnp.trace(ek),
        "epot": epot,
        "vel": vel.astype(fprec),
    }
    if "cells" in conformation:
        system["virial"] = virial
        system["cell"] = conformation["cells"][0]

    if not is_pimd:
        system["coordinates"] = conformation["coordinates"]
        system["forces"] = forces
        return system

    nbeads = system_data["nbeads"]
    bead_coords = conformation["coordinates"].reshape(nbeads, -1, 3)
    # only the first normal-mode slot is filled, with bead 0's coordinates
    system["coordinates"] = jnp.zeros_like(bead_coords).at[0].set(bead_coords[0])
    # project bead forces onto the normal modes, scaled by 1/sqrt(nbeads)
    system["forces"] = jnp.einsum(
        "in,i...->n...", system_data["eigmat"], forces.reshape(nbeads, -1, 3)
    ) * (1.0 / nbeads**0.5)
    return system
def load_model(simulation_parameters):
    """Load the FENNIX model referenced by the simulation parameters.

    Args:
        simulation_parameters: parsed input parameters (dict-like with ``get``).
            Reads ``model_file`` (required), ``graph_config`` (optional, passed
            to ``FENNIX.load``) and ``energy_terms`` (optional; either a
            whitespace-separated string or a list of term names).

    Returns:
        The loaded FENNIX model, with its energy terms restricted when
        ``energy_terms`` is given.

    Raises:
        FileNotFoundError: if ``model_file`` is missing/blank or does not point
            to an existing path.
    """
    model_file = simulation_parameters.get("model_file")
    # Without this check a missing key would be reported as a file literally named "None".
    if model_file is None or not str(model_file).strip():
        raise FileNotFoundError(
            "no model file given: set 'model_file' in the simulation parameters"
        )
    model_file = Path(str(model_file).strip())
    if not model_file.exists():
        raise FileNotFoundError(f"model file {model_file} not found")

    graph_config = simulation_parameters.get("graph_config", {})
    model = FENNIX.load(model_file, graph_config=graph_config)
    print(f"# model_file: {model_file}")

    if "energy_terms" in simulation_parameters:
        energy_terms = simulation_parameters["energy_terms"]
        if isinstance(energy_terms, str):
            # accept a single whitespace-separated string as a list of terms
            energy_terms = energy_terms.split()
        model.set_energy_terms(energy_terms)
        print("# energy terms:", model.energy_terms)

    return model
def load_system_data(simulation_parameters, fprec):
    """Read the initial structure and build the static data of the simulated system.

    Loads the last frame of the xyz input file, derives per-atom masses and the
    thermal energy, optional periodic-boundary data and, when ``nbeads`` is set,
    replicates the system for path-integral MD and builds the ring-polymer
    normal-mode transform.

    Args:
        simulation_parameters: parsed input parameters (dict-like with ``get``).
        fprec: floating-point dtype used for coordinates, masses and cell.

    Returns:
        (system_data, conformation):
            system_data: static information (name, nat, symbols, species, mass,
                temperature, kT, totmass_amu, pbc and, for PIMD, nbeads, omk,
                eigmat).
            conformation: model input dict (species, coordinates, batch_index,
                natoms and, with a cell, cells / reciprocal_cells), extended with
                any ``additional_keys`` entries.

    Raises:
        FileNotFoundError: if the xyz input file does not exist.
        ValueError: if ``nbeads`` is given and is smaller than 1.
    """
    ## LOAD SYSTEM CONFORMATION FROM FILES
    system_name = str(simulation_parameters.get("system", "system")).strip()
    indexed = simulation_parameters.get("xyz_input/indexed", True)
    has_comment_line = simulation_parameters.get("xyz_input/has_comment_line", False)
    xyzfile = Path(simulation_parameters.get("xyz_input/file", system_name + ".xyz"))
    if not xyzfile.exists():
        raise FileNotFoundError(f"xyz file {xyzfile} not found")
    # when "system" is not given, name the system after the xyz file
    system_name = str(simulation_parameters.get("system", xyzfile.stem)).strip()
    symbols, coordinates, _ = last_xyz_frame(
        xyzfile, indexed=indexed, has_comment_line=has_comment_line
    )
    coordinates = coordinates.astype(fprec)
    species = np.array([PERIODIC_TABLE_REV_IDX[s] for s in symbols], dtype=np.int32)
    nat = species.shape[0]

    ## GET MASS
    mass_amu = np.array(ATOMIC_MASSES, dtype=fprec)[species]
    deuterate = simulation_parameters.get("deuterate", False)
    if deuterate:
        print("# Replacing all hydrogens with deuteriums")
        mass_amu[species == 1] *= 2.0
    # NOTE(review): presumably converts amu to the mass unit used by the
    # integrators (atomic units rescaled by (fs/bohr)^2) — confirm.
    mass = mass_amu * (au.MPROT * (au.FS / au.BOHR) ** 2)

    ### GET TEMPERATURE
    # clip to a small positive value so that kT is never zero or negative
    temperature = np.clip(simulation_parameters.get("temperature", 300.0), 1.0e-6, None)
    kT = temperature / au.KELVIN
    # Despite its name (kept because it is stored in system_data), this is
    # sum(amu) / 0.602214129, i.e. the total mass in units of 1e-24 g: divided by
    # a volume in Angstrom^3 it yields a density in g/cm^3 (printed below).
    totmass_amu = mass_amu.sum() / 6.02214129e-1

    ## SYSTEM DATA
    system_data = {
        "name": system_name,
        "nat": nat,
        "symbols": symbols,
        "species": species,
        "mass": mass,
        "temperature": temperature,
        "kT": kT,
        "totmass_amu": totmass_amu,
    }

    ### Set boundary conditions
    cell = simulation_parameters.get("cell", None)
    if cell is not None:
        # rows are presumably the lattice vectors (fractional -> Cartesian is x @ cell)
        cell = np.array(cell, dtype=fprec).reshape(3, 3)
        reciprocal_cell = np.linalg.inv(cell)
        volume = np.abs(np.linalg.det(cell))
        print("# cell matrix:")
        for row in cell:
            print("# ", row)
        dens = totmass_amu / volume
        print("# density: ", dens.item(), " g/cm^3")
        minimum_image = simulation_parameters.get("minimum_image", True)
        estimate_pressure = simulation_parameters.get("estimate_pressure", False)
        print("# minimum_image: ", minimum_image)

        crystal_input = simulation_parameters.get("xyz_input/crystal", False)
        if crystal_input:
            # crystal input: coordinates are mapped to Cartesian through the cell
            coordinates = coordinates @ cell

        pbc_data = {
            "cell": cell,
            "reciprocal_cell": reciprocal_cell,
            "volume": volume,
            "minimum_image": minimum_image,
            "estimate_pressure": estimate_pressure,
        }
    else:
        pbc_data = None
    system_data["pbc"] = pbc_data

    ### PIMD
    nbeads = simulation_parameters.get("nbeads", None)
    if nbeads is not None:
        nbeads = int(nbeads)
        if nbeads < 1:
            raise ValueError(f"nbeads must be >= 1, got {nbeads}")
        print("# nbeads: ", nbeads)
        system_data["nbeads"] = nbeads
        # every bead starts from the same structure; beads are stored contiguously
        coordinates = np.repeat(coordinates[None, :, :], nbeads, axis=0).reshape(-1, 3)
        species = np.repeat(species[None, :], nbeads, axis=0).reshape(-1)
        bead_index = np.arange(nbeads, dtype=np.int32).repeat(nat)
        natoms = np.full(nbeads, nat, dtype=np.int32)

        # Laplacian of the closed ring of beads (each bead coupled to its two
        # neighbours). Subtracting shifted identities accumulates both couplings
        # when they fall on the same entry (nbeads == 2), so the non-zero mode
        # gets eigenvalue 4 (frequency 2*omega_P) instead of 3.
        identity = np.eye(nbeads)
        ring_laplacian = (
            2.0 * identity
            - np.roll(identity, 1, axis=1)
            - np.roll(identity, -1, axis=1)
        )
        omk, eigmat = np.linalg.eigh(ring_laplacian)
        # eigh sorts ascending: index 0 is the zero-frequency centroid mode;
        # pin it to exactly 0 so round-off cannot produce sqrt of a negative
        omk[0] = 0.0
        omk = nbeads * kT * omk**0.5 / au.FS
        # The centroid eigenvector has components of uniform sign, so flipping
        # the rows where it is negative makes it (and its sign convention) positive.
        for i in range(nbeads):
            if eigmat[i, 0] < 0:
                eigmat[i] *= -1.0
        system_data["omk"] = omk
        system_data["eigmat"] = jnp.asarray(eigmat, dtype=fprec)
    else:
        bead_index = np.zeros(nat, dtype=np.int32)
        natoms = np.array([nat], dtype=np.int32)

    conformation = {
        "species": species,
        "coordinates": coordinates,
        "batch_index": bead_index,
        "natoms": natoms,
    }
    if cell is not None:
        # one identical cell per bead (a single cell for classical runs)
        n_copies = 1 if nbeads is None else nbeads
        conformation["cells"] = np.repeat(cell[None, :, :], n_copies, axis=0)
        conformation["reciprocal_cells"] = np.repeat(
            reciprocal_cell[None, :, :], n_copies, axis=0
        )

    # user-provided extra model inputs (may override the entries above)
    additional_keys = simulation_parameters.get("additional_keys", {})
    for key, value in additional_keys.items():
        conformation[key] = value

    return system_data, conformation
def initialize_preprocessing(simulation_parameters, model, conformation, system_data):
    """Apply user neighbor-list settings to the model preprocessing and run it once.

    Every entry of ``model.preproc_state["layers_state"]`` receives the overrides
    found in the input: ``minimum_image`` (from the pbc data, when present),
    ``nblist_skin`` (when strictly positive), ``nblist_mult_size`` and
    ``nblist_add_neigh`` (stored under ``add_neigh``). The preprocessing is then
    executed on ``conformation`` with ``check_input`` enabled, and
    ``check_input`` is switched off in the returned state.

    Returns:
        (preproc_state, conformation) as produced by ``model.preprocessing``,
        with ``check_input`` set to False in the state.
    """
    nblist_verbose = simulation_parameters.get("nblist_verbose", False)
    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
    pbc_data = system_data.get("pbc", None)

    def _apply_overrides(frozen_layer):
        # Return a frozen copy of one layer state with the user overrides applied.
        layer = unfreeze(frozen_layer)
        if pbc_data is not None:
            layer["minimum_image"] = pbc_data["minimum_image"]
        if nblist_skin > 0:
            layer["nblist_skin"] = nblist_skin
        if "nblist_mult_size" in simulation_parameters:
            layer["nblist_mult_size"] = simulation_parameters["nblist_mult_size"]
        if "nblist_add_neigh" in simulation_parameters:
            layer["add_neigh"] = simulation_parameters["nblist_add_neigh"]
        return freeze(layer)

    ### CONFIGURE PREPROCESSING
    state = unfreeze(model.preproc_state)
    state["layers_state"] = [_apply_overrides(layer) for layer in state["layers_state"]]
    state = freeze(state)

    ## initial preprocessing: run once with input checking enabled, then turn
    ## the check off (presumably so later calls skip input validation)
    state, conformation = model.preprocessing(
        state.copy({"check_input": True}), conformation
    )
    state = state.copy({"check_input": False})

    if nblist_verbose:
        print("# graphs_keys: ", list(model._graphs_properties.keys()))
        print("# nblist state:", state)

    ### print model
    if simulation_parameters.get("print_model", False):
        print(model.summarize(example_data=conformation))

    return state, conformation
def initialize_system(conformation, vel, model, system_data, fprec):
    """Evaluate the model on the starting conformation and assemble the dynamic state.

    Args:
        conformation: preprocessed model input; must hold "coordinates" and,
            for periodic systems, "cells".
        vel: initial velocities. For PIMD runs (``"nbeads"`` in system_data)
            only ``vel[0]`` enters the kinetic-energy tensor.
        model: model exposing ``_energy_and_forces_and_virial``, ``variables``
            and ``Ha_to_model_energy``.
        system_data: static data (needs "mass"; "nbeads"/"eigmat" for PIMD).
        fprec: floating-point dtype for the stored velocities.

    Returns:
        dict with ek_tensor, ek, epot, vel, coordinates, forces and, when
        "cells" is in the conformation, virial and cell. Energies, forces and
        virial are divided by ``model.Ha_to_model_energy``.
    """
    print("# Computing initial energy and forces")
    energies, forces, virial, _ = model._energy_and_forces_and_virial(
        model.variables, conformation
    )
    unit = model.Ha_to_model_energy
    forces = np.array(forces) / unit
    epot = np.mean(energies) / unit
    virial = np.mean(virial, axis=0) / unit

    is_pimd = "nbeads" in system_data
    # velocities entering the kinetic-energy tensor: first slice for PIMD runs
    v = vel[0] if is_pimd else vel
    mass = system_data["mass"]
    ek = 0.5 * jnp.sum(mass[:, None, None] * v[:, :, None] * v[:, None, :], axis=0)

    ## build system (insertion order kept: energies, velocities, cell, positions)
    system = {
        "ek_tensor": ek,
        "ek": jnp.trace(ek),
        "epot": epot,
        "vel": vel.astype(fprec),
    }
    if "cells" in conformation:
        system["virial"] = virial
        system["cell"] = conformation["cells"][0]

    if not is_pimd:
        system["coordinates"] = conformation["coordinates"]
        system["forces"] = forces
        return system

    nbeads = system_data["nbeads"]
    bead_coords = conformation["coordinates"].reshape(nbeads, -1, 3)
    # only the first normal-mode slot is filled, with bead 0's coordinates
    system["coordinates"] = jnp.zeros_like(bead_coords).at[0].set(bead_coords[0])
    # project bead forces onto the normal modes, scaled by 1/sqrt(nbeads)
    system["forces"] = jnp.einsum(
        "in,i...->n...", system_data["eigmat"], forces.reshape(nbeads, -1, 3)
    ) * (1.0 / nbeads**0.5)
    return system