fennol.md.spectra

  1import numpy as np
  2import flax.linen as nn
  3import jax
  4import jax.numpy as jnp
  5import math
  6import optax
  7import os
  8from pathlib import Path
  9from flax.core import freeze, unfreeze
 10
 11from ..models.fennix import FENNIX
 12from .utils import us 
 13from ..utils import Counter
 14from ..utils.deconvolution import (
 15    deconvolute_spectrum,
 16    kernel_lorentz_pot,
 17    kernel_lorentz,
 18)
 19
 20def initialize_ir_spectrum(simulation_parameters,system_data,fprec,dt,apply_kubo_fact=False):
 21    state = {}
 22
 23    parameters = simulation_parameters.get("ir_parameters", {})
 24    """@keyword[fennol_md] ir_parameters
 25    Parameters for infrared spectrum calculation including dipole model settings.
 26    Required for ir_spectrum=True
 27    """
 28    dipole_model = parameters["dipole_model"]
 29    dipole_model = Path(str(dipole_model).strip())
 30    if not dipole_model.exists():
 31        raise FileNotFoundError(f"Dipole model file {dipole_model} not found")
 32    else:
 33        print(f"# Using '{dipole_model}' as dipole model.")
 34        dipole_model = FENNIX.load(dipole_model)
 35
 36        nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
 37        """@keyword[fennol_md] nblist_skin
 38        Neighbor list skin distance for dipole model preprocessing (in Angstroms).
 39        Default: -1.0
 40        """
 41        pbc_data = system_data.get("pbc", None)
 42
 43        ### CONFIGURE PREPROCESSING
 44        preproc_state = unfreeze(dipole_model.preproc_state)
 45        layer_state = []
 46        for st in preproc_state["layers_state"]:
 47            stnew = unfreeze(st)
 48            if nblist_skin > 0:
 49                stnew["nblist_skin"] = nblist_skin
 50            if "nblist_mult_size" in simulation_parameters:
 51                stnew["nblist_mult_size"] = simulation_parameters["nblist_mult_size"]
 52            if "nblist_add_neigh" in simulation_parameters:
 53                stnew["add_neigh"] = simulation_parameters["nblist_add_neigh"]
 54            layer_state.append(freeze(stnew))
 55        preproc_state["layers_state"] = layer_state
 56        dipole_model.preproc_state = freeze(preproc_state)
 57
 58
 59    Tseg = parameters.get("tseg", 1.0 / us.PS)
 60    """@keyword[fennol_md] ir_parameters/tseg
 61    Time segment length for IR spectrum calculation.
 62    Default: 1.0 ps
 63    """
 64    nseg = int(Tseg / dt)
 65    Tseg = nseg * dt
 66    dom = 2 * np.pi / (3 * Tseg)
 67    omegacut = parameters.get("omegacut", 15000.0 / us.CM1)
 68    """@keyword[fennol_md] ir_parameters/omegacut
 69    Cutoff frequency for IR spectrum.
 70    Default: 15000.0 cm⁻¹
 71    """
 72    nom = int(omegacut / dom)
 73    omega = dom * np.arange((3 * nseg) // 2 + 1)
 74
 75    assert (
 76        omegacut < omega[-1]
 77    ), f"omegacut must be smaller than {omega[-1]*us.CM1} CM-1"
 78
 79    startsave = parameters.get("startsave", 1)
 80    """@keyword[fennol_md] ir_parameters/startsave
 81    Start saving IR statistics after this many segments.
 82    Default: 1
 83    """
 84    counter = Counter(nseg, startsave=startsave)
 85    state["istep"] = 0
 86    state["nsample"] = 0
 87    state["nadapt"] = 0
 88
 89    use_qvel = parameters.get("use_qvel", False)
 90    """@keyword[fennol_md] ir_parameters/use_qvel
 91    Use quantum velocity correction for IR spectrum calculation.
 92    Default: False
 93    """
 94    if use_qvel:
 95        state["qvel"] = jnp.zeros((3,), dtype=fprec)
 96    else:
 97        nat = system_data["nat"]
 98        state["musave-2"] = jnp.zeros((3,), dtype=fprec)
 99        state["musave-1"] = jnp.zeros((3,), dtype=fprec)
100        state["qsave-2"] = jnp.zeros((nat,1), dtype=fprec)
101        state["qsave-1"] = jnp.zeros((nat,1), dtype=fprec)
102        state["pos_save-2"] = jnp.zeros((nat,3), dtype=fprec)
103        state["pos_save-1"] = jnp.zeros((nat,3), dtype=fprec)
104    state["mudot"] = jnp.zeros((nseg, 3), dtype=fprec)
105    state["Cmumu"] = jnp.zeros((nom,), dtype=fprec)
106    state["first"] = True
107
108    kT = system_data["kT"]
109    kubo_fact = np.ones_like(omega)
110    if apply_kubo_fact:
111        uu = 0.5*us.HBAR*omega[1:]/kT
112        kubo_fact[1:] = np.tanh(uu)/uu
113
114    do_deconvolution = parameters.get("deconvolution", False)
115    """@keyword[fennol_md] ir_parameters/deconvolution
116    Apply deconvolution to IR spectrum for better resolution.
117    Default: False
118    """
119    if do_deconvolution:
120        gamma = simulation_parameters.get("gamma", 1.0 / us.THZ)
121        """@keyword[fennol_md] gamma
122        Friction coefficient for deconvolution of IR spectra.
123        Default: 1.0 ps^-1
124        """
125        niter_deconv = parameters.get("niter_deconv", 20)
126        """@keyword[fennol_md] ir_parameters/niter_deconv
127        Number of iterations for IR spectrum deconvolution.
128        Default: 20
129        """
130        print("# Deconvolution of IR spectra with gamma=", gamma*(1./us.PS),"ps^-1 and niter=",niter_deconv)
131    
132
133    temp_K = system_data["temperature"]
134    c=2.99792458e-2 # speed of light in cm/ps
135    mufact = 1000*418.40*332.063714*2*np.pi**2/(0.831446215*temp_K*3.*c)
136    pbc_data = system_data.get("pbc", None)
137    if pbc_data is not None:
138        cell = pbc_data["cell"]
139        volume = np.abs(np.linalg.det(cell))
140        mufact = mufact/volume
141
142
143    @jax.jit
144    def compute_spectra(state):
145        mudot = state["mudot"]
146        smu = jnp.fft.rfft(mudot, 3 * nseg, axis=0, norm="ortho")
147        Cmumu = dt * jnp.sum(jnp.abs(smu[:nom]) ** 2, axis=-1)
148
149        nsinv = 1.0 / state["nsample"]
150        b1 = 1.0 - nsinv
151        Cmumu = state["Cmumu"] * b1 + Cmumu * nsinv
152
153        return {
154            **state,
155            "Cmumu": Cmumu,
156        }
157
158    def write_spectra_to_file(state):
159        Cmumu_avg = np.array(state["Cmumu"])*mufact*kubo_fact[:nom]
160
161        if do_deconvolution:
162            s_out, s_rec, K_D = deconvolute_spectrum(
163                Cmumu_avg,
164                omega[:nom],
165                gamma,
166                niter_deconv,
167                kernel=kernel_lorentz,
168                trans=False,
169                symmetrize=True,
170                verbose=False,
171                K_D=state.get("K_D", None),
172            )
173            state = {**state, "K_D": K_D}
174            spectra=(s_out,Cmumu_avg,s_rec)
175        else:
176            spectra=(Cmumu_avg,)
177
178        columns = np.column_stack(
179            (
180                omega[:nom] * us.CM1,
181                *spectra,
182            )
183        )
184        np.savetxt(
185            f"IR_spectrum.out",
186            columns,
187            fmt="%12.6f",
188            header="#omega Cmudot",
189        )
190        print("# IR spectrum written.")
191
192        return state
193    
194    @jax.jit
195    def save_dipole(q,vel,pos,dip,cell_reciprocal, state):
196
197        q=q.reshape(-1,1)
198        if use_qvel:
199            qvel = (q*vel).sum(axis=0)
200        if "first" in state:
201            state = {
202                **state,
203                "musave-2":dip,
204                "musave-1":dip,
205            }
206            if use_qvel:
207                state["qvel"] = qvel
208            else:
209                state["pos_save-2"] = pos
210                state["pos_save-1"] = pos
211                state["qsave-2"] = q
212                state["qsave-1"] = q
213            del state["first"]
214
215        istep = state["istep"]
216        new_state = {**state, "istep": istep + 1}
217        new_state["musave-2"] = state["musave-1"]
218        new_state["musave-1"] = dip
219        
220        dipdot = (dip - state["musave-2"])/(2*dt)
221        if use_qvel:
222            qrdot = state["qvel"]
223            new_state["qvel"] = qvel
224        else:
225            new_state["pos_save-2"] = state["pos_save-1"]
226            new_state["pos_save-1"] = pos
227            new_state["qsave-2"] = state["qsave-1"]
228            new_state["qsave-1"] = q
229
230            if cell_reciprocal is not None:
231                cell,reciprocal_cell = cell_reciprocal
232                vec = pos - state["pos_save-2"]
233                shift = -jnp.round(jnp.dot(vec, reciprocal_cell))
234                pos = pos + jnp.dot(shift, cell)
235
236            qrdot = (q*pos - state["qsave-2"]*state["pos_save-2"]).sum(axis=0)/(2*dt)
237
238        new_state["mudot"] = state["mudot"].at[istep].set(qrdot+dipdot)
239
240        return new_state
241
242    # @jax.jit
243    # def save_dipole(mudot, state):
244
245    #     istep = state["istep"]
246    #     new_state = {**state, "istep": istep + 1}
247    #     new_state["mudot"] = state["mudot"].at[istep].set(mudot)
248
249    #     return new_state
250    
251    def postprocess(state):
252        counter.increment()
253        if not counter.is_reset_step:
254            return state
255        state["nadapt"] += 1
256        state["nsample"] = max(state["nadapt"] - startsave + 1, 1)
257        state = compute_spectra(state)
258        state["istep"] = 0
259        state = write_spectra_to_file(state)
260        return state
261    
262    return dipole_model,state,save_dipole, postprocess
263        
264        
def initialize_ir_spectrum(
    simulation_parameters, system_data, fprec, dt, apply_kubo_fact=False
):
    """Set up on-the-fly accumulation of an IR spectrum during an MD run.

    Loads the dipole model named in ``ir_parameters``, applies the
    simulation's neighbor-list overrides to its preprocessing state, and
    builds the closures that record the time derivative of the total dipole
    at every step and turn each completed segment into a spectrum estimate
    written to ``IR_spectrum.out``.

    Args:
        simulation_parameters: Mapping of simulation settings. Reads the
            ``ir_parameters`` sub-mapping (see the ``@keyword`` strings in
            the body) and the top-level ``nblist_skin``,
            ``nblist_mult_size``, ``nblist_add_neigh`` and ``gamma`` entries.
        system_data: Mapping providing ``kT``, ``temperature``, optionally
            ``pbc`` (with a ``cell`` entry) and, unless ``use_qvel`` is set,
            ``nat``.
        fprec: dtype of the arrays stored in the returned state.
        dt: MD time step, in the same unit as ``tseg``
            (presumably the internal units of ``us`` -- TODO confirm).
        apply_kubo_fact: If True, multiply the spectrum by ``tanh(u)/u``
            with ``u = hbar*omega/(2*kT)``.

    Returns:
        Tuple ``(dipole_model, state, save_dipole, postprocess)`` where
        ``state`` is the initial accumulator dict,
        ``save_dipole(q, vel, pos, dip, cell_reciprocal, state)`` stores one
        sample and returns the new state, and ``postprocess(state)`` must be
        called once per step; at the end of each segment it updates the
        averaged spectrum, writes it to disk and returns the new state.

    Raises:
        KeyError: ``ir_parameters`` has no ``dipole_model`` entry.
        FileNotFoundError: The dipole model file does not exist.
        ValueError: ``tseg`` is shorter than one time step, or ``omegacut``
            is not below the largest frequency of the transform.
    """
    parameters = simulation_parameters.get("ir_parameters", {})
    """@keyword[fennol_md] ir_parameters
    Parameters for infrared spectrum calculation including dipole model settings.
    Required for ir_spectrum=True
    """
    try:
        dipole_model_path = parameters["dipole_model"]
    except KeyError:
        # Same exception type as before, but say which setting is missing.
        raise KeyError(
            "ir_parameters/dipole_model is required to compute the IR spectrum"
        ) from None
    dipole_model_path = Path(str(dipole_model_path).strip())
    if not dipole_model_path.exists():
        raise FileNotFoundError(f"Dipole model file {dipole_model_path} not found")
    print(f"# Using '{dipole_model_path}' as dipole model.")
    dipole_model = FENNIX.load(dipole_model_path)

    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
    """@keyword[fennol_md] nblist_skin
    Neighbor list skin distance for dipole model preprocessing (in Angstroms).
    Default: -1.0
    """

    ### CONFIGURE PREPROCESSING
    # Copy the simulation-level neighbor-list settings into every
    # preprocessing layer of the dipole model; a non-positive skin leaves
    # the layers' own value untouched.
    preproc_state = unfreeze(dipole_model.preproc_state)
    layer_state = []
    for st in preproc_state["layers_state"]:
        stnew = unfreeze(st)
        if nblist_skin > 0:
            stnew["nblist_skin"] = nblist_skin
        if "nblist_mult_size" in simulation_parameters:
            stnew["nblist_mult_size"] = simulation_parameters["nblist_mult_size"]
        if "nblist_add_neigh" in simulation_parameters:
            stnew["add_neigh"] = simulation_parameters["nblist_add_neigh"]
        layer_state.append(freeze(stnew))
    preproc_state["layers_state"] = layer_state
    dipole_model.preproc_state = freeze(preproc_state)

    Tseg = parameters.get("tseg", 1.0 / us.PS)
    """@keyword[fennol_md] ir_parameters/tseg
    Time segment length for IR spectrum calculation.
    Default: 1.0 ps
    """
    nseg = int(Tseg / dt)
    if nseg < 1:
        # Otherwise Tseg becomes 0 below and the frequency step is undefined.
        raise ValueError(
            f"ir_parameters/tseg ({Tseg}) must be at least one time step ({dt})"
        )
    # Round the segment length down to a whole number of steps.
    Tseg = nseg * dt
    # Each segment is zero-padded to 3*nseg samples before the FFT, hence
    # the factor 3 in the frequency resolution.
    dom = 2 * np.pi / (3 * Tseg)
    omegacut = parameters.get("omegacut", 15000.0 / us.CM1)
    """@keyword[fennol_md] ir_parameters/omegacut
    Cutoff frequency for IR spectrum.
    Default: 15000.0 cm⁻¹
    """
    nom = int(omegacut / dom)
    omega = dom * np.arange((3 * nseg) // 2 + 1)

    # Explicit check instead of `assert`, which is removed under `python -O`.
    if not omegacut < omega[-1]:
        raise ValueError(f"omegacut must be smaller than {omega[-1]*us.CM1} CM-1")

    startsave = parameters.get("startsave", 1)
    """@keyword[fennol_md] ir_parameters/startsave
    Start saving IR statistics after this many segments.
    Default: 1
    """
    counter = Counter(nseg, startsave=startsave)
    state = {"istep": 0, "nsample": 0, "nadapt": 0}

    use_qvel = parameters.get("use_qvel", False)
    """@keyword[fennol_md] ir_parameters/use_qvel
    Use the charge-weighted atomic velocities (sum of q*vel) for the
    point-charge part of the dipole derivative instead of a finite
    difference of q*pos.
    Default: False
    """
    # save_dipole reads the two previous dipoles in both modes, so this
    # history is always part of the state.
    state["musave-2"] = jnp.zeros((3,), dtype=fprec)
    state["musave-1"] = jnp.zeros((3,), dtype=fprec)
    if use_qvel:
        state["qvel"] = jnp.zeros((3,), dtype=fprec)
    else:
        nat = system_data["nat"]
        state["qsave-2"] = jnp.zeros((nat, 1), dtype=fprec)
        state["qsave-1"] = jnp.zeros((nat, 1), dtype=fprec)
        state["pos_save-2"] = jnp.zeros((nat, 3), dtype=fprec)
        state["pos_save-1"] = jnp.zeros((nat, 3), dtype=fprec)
    state["mudot"] = jnp.zeros((nseg, 3), dtype=fprec)
    state["Cmumu"] = jnp.zeros((nom,), dtype=fprec)
    # Marker removed by the first save_dipole call, which seeds the history.
    state["first"] = True

    kT = system_data["kT"]
    kubo_fact = np.ones_like(omega)
    if apply_kubo_fact:
        # omega[0] == 0 is skipped (0/0); its factor stays at 1.
        uu = 0.5 * us.HBAR * omega[1:] / kT
        kubo_fact[1:] = np.tanh(uu) / uu

    do_deconvolution = parameters.get("deconvolution", False)
    """@keyword[fennol_md] ir_parameters/deconvolution
    Apply deconvolution to IR spectrum for better resolution.
    Default: False
    """
    if do_deconvolution:
        gamma = simulation_parameters.get("gamma", 1.0 / us.THZ)
        """@keyword[fennol_md] gamma
        Friction coefficient for deconvolution of IR spectra.
        Default: 1.0 ps^-1
        """
        niter_deconv = parameters.get("niter_deconv", 20)
        """@keyword[fennol_md] ir_parameters/niter_deconv
        Number of iterations for IR spectrum deconvolution.
        Default: 20
        """
        print("# Deconvolution of IR spectra with gamma=", gamma*(1./us.PS),"ps^-1 and niter=",niter_deconv)

    # Prefactor applied to the averaged spectrum before writing; it scales
    # as 1/temperature and, under periodic boundaries, as 1/cell volume.
    # NOTE(review): the numeric constants look like unit conversions --
    # confirm their origin before changing them.
    temp_K = system_data["temperature"]
    c = 2.99792458e-2  # speed of light in cm/ps
    mufact = 1000*418.40*332.063714*2*np.pi**2/(0.831446215*temp_K*3.*c)
    pbc_data = system_data.get("pbc", None)
    if pbc_data is not None:
        cell = pbc_data["cell"]
        volume = np.abs(np.linalg.det(cell))
        mufact = mufact / volume

    @jax.jit
    def compute_spectra(state):
        """Fold the finished segment into the running average ``Cmumu``."""
        mudot = state["mudot"]
        # n=3*nseg zero-pads the nseg stored samples along the time axis.
        smu = jnp.fft.rfft(mudot, 3 * nseg, axis=0, norm="ortho")
        # Power spectrum summed over the three Cartesian components.
        Cmumu = dt * jnp.sum(jnp.abs(smu[:nom]) ** 2, axis=-1)

        # Incremental mean over nsample segments; nsample == 1 simply
        # overwrites the previous estimate.
        nsinv = 1.0 / state["nsample"]
        b1 = 1.0 - nsinv
        Cmumu = state["Cmumu"] * b1 + Cmumu * nsinv

        return {**state, "Cmumu": Cmumu}

    def write_spectra_to_file(state):
        """Write the current spectrum to ``IR_spectrum.out``; return the state."""
        Cmumu_avg = np.array(state["Cmumu"]) * mufact * kubo_fact[:nom]

        if do_deconvolution:
            # K_D is cached in the state and handed back on later calls.
            s_out, s_rec, K_D = deconvolute_spectrum(
                Cmumu_avg,
                omega[:nom],
                gamma,
                niter_deconv,
                kernel=kernel_lorentz,
                trans=False,
                symmetrize=True,
                verbose=False,
                K_D=state.get("K_D", None),
            )
            state = {**state, "K_D": K_D}
            spectra = (s_out, Cmumu_avg, s_rec)
        else:
            spectra = (Cmumu_avg,)

        columns = np.column_stack((omega[:nom] * us.CM1, *spectra))
        # NOTE(review): np.savetxt prepends its own "# " to the header, and
        # the header names two columns even when deconvolution writes four.
        # The text is kept unchanged so existing output files stay identical.
        np.savetxt(
            "IR_spectrum.out",
            columns,
            fmt="%12.6f",
            header="#omega Cmudot",
        )
        print("# IR spectrum written.")

        return state

    @jax.jit
    def save_dipole(q, vel, pos, dip, cell_reciprocal, state):
        """Store one sample of the dipole time derivative in ``state["mudot"]``.

        The sample is the sum of a point-charge term and the centered
        finite difference of ``dip`` over two time steps. ``cell_reciprocal``
        is either None or a ``(cell, reciprocal_cell)`` pair.
        """
        q = q.reshape(-1, 1)
        if use_qvel:
            qvel = (q * vel).sum(axis=0)
        if "first" in state:
            # First call: seed the history with the current values so the
            # first finite differences are zero, and drop the marker.
            state = {key: val for key, val in state.items() if key != "first"}
            state["musave-2"] = dip
            state["musave-1"] = dip
            if use_qvel:
                state["qvel"] = qvel
            else:
                state["pos_save-2"] = pos
                state["pos_save-1"] = pos
                state["qsave-2"] = q
                state["qsave-1"] = q

        istep = state["istep"]
        new_state = {**state, "istep": istep + 1}
        new_state["musave-2"] = state["musave-1"]
        new_state["musave-1"] = dip

        # Centered difference: the derivative belongs to the previous step.
        dipdot = (dip - state["musave-2"]) / (2 * dt)
        if use_qvel:
            # Use the value stored on the previous call so that it refers
            # to the same step as dipdot.
            qrdot = state["qvel"]
            new_state["qvel"] = qvel
        else:
            new_state["pos_save-2"] = state["pos_save-1"]
            new_state["pos_save-1"] = pos
            new_state["qsave-2"] = state["qsave-1"]
            new_state["qsave-1"] = q

            if cell_reciprocal is not None:
                # Remove whole-cell jumps relative to the positions stored
                # two steps back (the raw positions were saved above).
                cell, reciprocal_cell = cell_reciprocal
                vec = pos - state["pos_save-2"]
                shift = -jnp.round(jnp.dot(vec, reciprocal_cell))
                pos = pos + jnp.dot(shift, cell)

            qrdot = (q * pos - state["qsave-2"] * state["pos_save-2"]).sum(axis=0) / (2 * dt)

        new_state["mudot"] = state["mudot"].at[istep].set(qrdot + dipdot)

        return new_state

    def postprocess(state):
        """Advance the segment counter; finalize the segment when it is full."""
        counter.increment()
        if not counter.is_reset_step:
            return state
        nadapt = state["nadapt"] + 1
        # Averaging starts at segment `startsave`; before that nsample is 1.
        nsample = max(nadapt - startsave + 1, 1)
        # Work on a new dict instead of modifying the caller's state.
        state = compute_spectra({**state, "nadapt": nadapt, "nsample": nsample})
        state["istep"] = 0
        state = write_spectra_to_file(state)
        return state

    return dipole_model, state, save_dipole, postprocess
263    return dipole_model,state,save_dipole, postprocess