fennol.md.spectra

  1import numpy as np
  2import flax.linen as nn
  3import jax
  4import jax.numpy as jnp
  5import math
  6import optax
  7import os
  8from pathlib import Path
  9from flax.core import freeze, unfreeze
 10
 11from ..models.fennix import FENNIX
 12from ..utils.atomic_units import AtomicUnits as au  # CM1,THZ,BOHR,MPROT
 13from ..utils import Counter
 14from ..utils.deconvolution import (
 15    deconvolute_spectrum,
 16    kernel_lorentz_pot,
 17    kernel_lorentz,
 18)
 19
 20def initialize_ir_spectrum(simulation_parameters,system_data,fprec,dt,apply_kubo_fact=False):
 21    state = {}
 22
 23    parameters = simulation_parameters.get("ir_parameters", {})
 24    dipole_model = parameters["dipole_model"]
 25    dipole_model = Path(str(dipole_model).strip())
 26    if not dipole_model.exists():
 27        raise FileNotFoundError(f"Dipole model file {dipole_model} not found")
 28    else:
 29        print(f"# Using '{dipole_model}' as dipole model.")
 30        dipole_model = FENNIX.load(dipole_model)
 31
 32        nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
 33        pbc_data = system_data.get("pbc", None)
 34
 35        ### CONFIGURE PREPROCESSING
 36        preproc_state = unfreeze(dipole_model.preproc_state)
 37        layer_state = []
 38        for st in preproc_state["layers_state"]:
 39            stnew = unfreeze(st)
 40            if pbc_data is not None:
 41                stnew["minimum_image"] = pbc_data["minimum_image"]
 42            if nblist_skin > 0:
 43                stnew["nblist_skin"] = nblist_skin
 44            if "nblist_mult_size" in simulation_parameters:
 45                stnew["nblist_mult_size"] = simulation_parameters["nblist_mult_size"]
 46            if "nblist_add_neigh" in simulation_parameters:
 47                stnew["add_neigh"] = simulation_parameters["nblist_add_neigh"]
 48            layer_state.append(freeze(stnew))
 49        preproc_state["layers_state"] = layer_state
 50        dipole_model.preproc_state = freeze(preproc_state)
 51
 52
 53    Tseg = parameters.get("tseg", 1.0 / au.PS) * au.FS
 54    nseg = int(Tseg / dt)
 55    Tseg = nseg * dt
 56    dom = 2 * np.pi / (3 * Tseg)
 57    omegacut = parameters.get("omegacut", 15000.0 / au.CM1) / au.FS
 58    nom = int(omegacut / dom)
 59    omega = dom * np.arange((3 * nseg) // 2 + 1)
 60
 61    assert (
 62        omegacut < omega[-1]
 63    ), f"omegacut must be smaller than {omega[-1]*au.CM1} CM-1"
 64
 65    startsave = parameters.get("startsave", 1)
 66    counter = Counter(nseg, startsave=startsave)
 67    state["istep"] = 0
 68    state["nsample"] = 0
 69    state["nadapt"] = 0
 70
 71    use_qvel = parameters.get("use_qvel", False)
 72    if use_qvel:
 73        state["qvel"] = jnp.zeros((3,), dtype=fprec)
 74    else:
 75        nat = system_data["nat"]
 76        state["musave-2"] = jnp.zeros((3,), dtype=fprec)
 77        state["musave-1"] = jnp.zeros((3,), dtype=fprec)
 78        state["qsave-2"] = jnp.zeros((nat,1), dtype=fprec)
 79        state["qsave-1"] = jnp.zeros((nat,1), dtype=fprec)
 80        state["pos_save-2"] = jnp.zeros((nat,3), dtype=fprec)
 81        state["pos_save-1"] = jnp.zeros((nat,3), dtype=fprec)
 82    state["mudot"] = jnp.zeros((nseg, 3), dtype=fprec)
 83    state["Cmumu"] = jnp.zeros((nom,), dtype=fprec)
 84    state["first"] = True
 85
 86    kT = system_data["kT"]
 87    kubo_fact = np.ones_like(omega)
 88    if apply_kubo_fact:
 89        uu = 0.5*omega[1:]*au.FS/kT
 90        kubo_fact[1:] = np.tanh(uu)/uu
 91
 92    do_deconvolution = parameters.get("deconvolution", False)
 93    if do_deconvolution:
 94        gamma = simulation_parameters.get("gamma", 1.0 / au.THZ) / au.FS
 95        niter_deconv = parameters.get("niter_deconv", 20)
 96        print("# Deconvolution of IR spectra with gamma=", gamma*1000,"ps-1 and niter=",niter_deconv)
 97    
 98
 99    kelvin = system_data["temperature"]
100    c=2.99792458e-2 # speed of light in cm/ps
101    # mufact = 1000*4*np.pi**2/(3*kT*c*au.BOHR**2)
102    mufact = 1000*418.40*332.063714*2*np.pi**2/(0.831446215*kelvin*3.*c)
103    pbc_data = system_data.get("pbc", None)
104    if pbc_data is not None:
105        cell = pbc_data["cell"]
106        volume = np.abs(np.linalg.det(cell)) #/au.BOHR**3
107        mufact = mufact/volume
108
109
110    @jax.jit
111    def compute_spectra(state):
112        mudot = state["mudot"]
113        smu = jnp.fft.rfft(mudot, 3 * nseg, axis=0, norm="ortho")
114        Cmumu = dt * jnp.sum(jnp.abs(smu[:nom]) ** 2, axis=-1)
115
116        nsinv = 1.0 / state["nsample"]
117        b1 = 1.0 - nsinv
118        Cmumu = state["Cmumu"] * b1 + Cmumu * nsinv
119
120        return {
121            **state,
122            "Cmumu": Cmumu,
123        }
124
125    def write_spectra_to_file(state):
126        Cmumu_avg = np.array(state["Cmumu"])*mufact*kubo_fact[:nom]
127
128        if do_deconvolution:
129            s_out, s_rec, K_D = deconvolute_spectrum(
130                Cmumu_avg,
131                omega[:nom],
132                gamma,
133                niter_deconv,
134                kernel=kernel_lorentz,
135                trans=False,
136                symmetrize=True,
137                verbose=False,
138                K_D=state.get("K_D", None),
139            )
140            state = {**state, "K_D": K_D}
141            spectra=(s_out,Cmumu_avg,s_rec)
142        else:
143            spectra=(Cmumu_avg,)
144
145        columns = np.column_stack(
146            (
147                omega[:nom] * (au.FS * au.CM1),
148                *spectra,
149            )
150        )
151        np.savetxt(
152            f"IR_spectrum.out",
153            columns,
154            fmt="%12.6f",
155            header="#omega Cmudot",
156        )
157        print("# IR spectrum written.")
158
159        return state
160    
161    @jax.jit
162    def save_dipole(q,vel,pos,dip,cell_reciprocal, state):
163
164        q=q.reshape(-1,1)
165        if use_qvel:
166            qvel = (q*vel).sum(axis=0)
167        if "first" in state:
168            state = {
169                **state,
170                "musave-2":dip,
171                "musave-1":dip,
172            }
173            if use_qvel:
174                state["qvel"] = qvel
175            else:
176                state["pos_save-2"] = pos
177                state["pos_save-1"] = pos
178                state["qsave-2"] = q
179                state["qsave-1"] = q
180            del state["first"]
181
182        istep = state["istep"]
183        new_state = {**state, "istep": istep + 1}
184        new_state["musave-2"] = state["musave-1"]
185        new_state["musave-1"] = dip
186        
187        dipdot = (dip - state["musave-2"])/(2*dt)
188        if use_qvel:
189            qrdot = state["qvel"]
190            new_state["qvel"] = qvel
191        else:
192            new_state["pos_save-2"] = state["pos_save-1"]
193            new_state["pos_save-1"] = pos
194            new_state["qsave-2"] = state["qsave-1"]
195            new_state["qsave-1"] = q
196
197            if cell_reciprocal is not None:
198                cell,reciprocal_cell = cell_reciprocal
199                vec = pos - state["pos_save-2"]
200                shift = -jnp.round(jnp.dot(vec, reciprocal_cell))
201                pos = pos + jnp.dot(shift, cell)
202
203            qrdot = (q*pos - state["qsave-2"]*state["pos_save-2"]).sum(axis=0)/(2*dt)
204
205        new_state["mudot"] = state["mudot"].at[istep].set(qrdot+dipdot)
206
207        return new_state
208
209    # @jax.jit
210    # def save_dipole(mudot, state):
211
212    #     istep = state["istep"]
213    #     new_state = {**state, "istep": istep + 1}
214    #     new_state["mudot"] = state["mudot"].at[istep].set(mudot)
215
216    #     return new_state
217    
218    def postprocess(state):
219        counter.increment()
220        if not counter.is_reset_step:
221            return state
222        state["nadapt"] += 1
223        state["nsample"] = max(state["nadapt"] - startsave + 1, 1)
224        state = compute_spectra(state)
225        state["istep"] = 0
226        state = write_spectra_to_file(state)
227        return state
228    
229    return dipole_model,state,save_dipole, postprocess
230        
231        
def initialize_ir_spectrum(simulation_parameters, system_data, fprec, dt, apply_kubo_fact=False):
    """Prepare on-the-fly accumulation of the IR spectrum from the dipole derivative.

    Loads the FENNIX dipole model given by ``ir_parameters["dipole_model"]``,
    overrides its neighbour-list preprocessing options with the simulation's
    settings, and builds the state dict and callbacks that record the time
    derivative of the total dipole over segments of ``nseg`` steps and keep a
    running average of its power spectrum over segments.

    Args:
        simulation_parameters: dict of simulation options. The ``"ir_parameters"``
            sub-dict supplies ``dipole_model`` (required), ``tseg``, ``omegacut``,
            ``startsave``, ``use_qvel``, ``deconvolution`` and ``niter_deconv``;
            ``nblist_skin``, ``nblist_mult_size``, ``nblist_add_neigh`` and
            ``gamma`` are read from the top level.
        system_data: dict providing ``"kT"``, ``"temperature"``, ``"nat"`` (only
            when ``use_qvel`` is false) and optionally ``"pbc"`` (with
            ``"minimum_image"`` and ``"cell"``).
        fprec: dtype of the JAX buffers stored in the state.
        dt: time step, in the same units as ``tseg * au.FS``.
        apply_kubo_fact: if True, scale the spectrum by ``tanh(u)/u`` with
            ``u = 0.5 * omega * au.FS / kT``.

    Returns:
        Tuple ``(dipole_model, state, save_dipole, postprocess)``: the loaded
        model, the initial state dict, a jitted callback that records one
        dipole-derivative sample, and a callback that, at the end of each
        segment, updates the averaged spectrum and writes ``IR_spectrum.out``.

    Raises:
        FileNotFoundError: if the dipole model file does not exist.
        ValueError: if ``tseg`` is shorter than one time step.
        AssertionError: if ``omegacut`` is not below the highest FFT frequency.
    """
    state = {}

    parameters = simulation_parameters.get("ir_parameters", {})
    dipole_model = Path(str(parameters["dipole_model"]).strip())
    if not dipole_model.exists():
        raise FileNotFoundError(f"Dipole model file {dipole_model} not found")
    print(f"# Using '{dipole_model}' as dipole model.")
    dipole_model = FENNIX.load(dipole_model)

    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
    # Read once: used both for the preprocessing overrides and the volume below.
    pbc_data = system_data.get("pbc", None)

    ### CONFIGURE PREPROCESSING
    # Replace the neighbour-list options stored in the model with the ones
    # requested for this simulation.
    preproc_state = unfreeze(dipole_model.preproc_state)
    layer_state = []
    for st in preproc_state["layers_state"]:
        stnew = unfreeze(st)
        if pbc_data is not None:
            stnew["minimum_image"] = pbc_data["minimum_image"]
        if nblist_skin > 0:
            stnew["nblist_skin"] = nblist_skin
        if "nblist_mult_size" in simulation_parameters:
            stnew["nblist_mult_size"] = simulation_parameters["nblist_mult_size"]
        if "nblist_add_neigh" in simulation_parameters:
            stnew["add_neigh"] = simulation_parameters["nblist_add_neigh"]
        layer_state.append(freeze(stnew))
    preproc_state["layers_state"] = layer_state
    dipole_model.preproc_state = freeze(preproc_state)

    ### FREQUENCY GRID
    Tseg = parameters.get("tseg", 1.0 / au.PS) * au.FS
    nseg = int(Tseg / dt)
    if nseg < 1:
        # Otherwise Tseg becomes 0 and the frequency spacing below divides by zero.
        raise ValueError(
            f"IR segment length tseg={Tseg} is shorter than the time step dt={dt}"
        )
    Tseg = nseg * dt
    # mudot is zero-padded to 3*nseg points in the FFT, hence the 3*Tseg.
    dom = 2 * np.pi / (3 * Tseg)
    omegacut = parameters.get("omegacut", 15000.0 / au.CM1) / au.FS
    nom = int(omegacut / dom)
    omega = dom * np.arange((3 * nseg) // 2 + 1)

    # omega shares omegacut's internal units; au.FS * au.CM1 converts to cm-1
    # (the same factor used for the output file).
    assert (
        omegacut < omega[-1]
    ), f"omegacut must be smaller than {omega[-1] * au.FS * au.CM1} CM-1"

    ### ACCUMULATION STATE
    startsave = parameters.get("startsave", 1)
    counter = Counter(nseg, startsave=startsave)
    state["istep"] = 0
    state["nsample"] = 0
    state["nadapt"] = 0

    use_qvel = parameters.get("use_qvel", False)
    # save_dipole needs the two previous dipoles in both modes.
    state["musave-2"] = jnp.zeros((3,), dtype=fprec)
    state["musave-1"] = jnp.zeros((3,), dtype=fprec)
    if use_qvel:
        state["qvel"] = jnp.zeros((3,), dtype=fprec)
    else:
        nat = system_data["nat"]
        state["qsave-2"] = jnp.zeros((nat, 1), dtype=fprec)
        state["qsave-1"] = jnp.zeros((nat, 1), dtype=fprec)
        state["pos_save-2"] = jnp.zeros((nat, 3), dtype=fprec)
        state["pos_save-1"] = jnp.zeros((nat, 3), dtype=fprec)
    state["mudot"] = jnp.zeros((nseg, 3), dtype=fprec)
    state["Cmumu"] = jnp.zeros((nom,), dtype=fprec)
    # Marks that save_dipole has not yet seeded the history buffers.
    state["first"] = True

    kT = system_data["kT"]
    kubo_fact = np.ones_like(omega)
    if apply_kubo_fact:
        # omega[0] == 0 is skipped: tanh(u)/u -> 1 there.
        uu = 0.5 * omega[1:] * au.FS / kT
        kubo_fact[1:] = np.tanh(uu) / uu

    do_deconvolution = parameters.get("deconvolution", False)
    if do_deconvolution:
        gamma = simulation_parameters.get("gamma", 1.0 / au.THZ) / au.FS
        niter_deconv = parameters.get("niter_deconv", 20)
        print("# Deconvolution of IR spectra with gamma=", gamma*1000,"ps-1 and niter=",niter_deconv)

    kelvin = system_data["temperature"]
    c = 2.99792458e-2  # speed of light in cm/ps
    # mufact = 1000*4*np.pi**2/(3*kT*c*au.BOHR**2)
    mufact = 1000 * 418.40 * 332.063714 * 2 * np.pi**2 / (0.831446215 * kelvin * 3.0 * c)
    if pbc_data is not None:
        volume = np.abs(np.linalg.det(pbc_data["cell"]))  # /au.BOHR**3
        mufact = mufact / volume

    @jax.jit
    def compute_spectra(state):
        """Fold the current segment's |FFT(mudot)|^2 into the running average.

        The new spectrum (summed over x, y, z and truncated to ``nom``
        frequencies) is weighted by ``1/nsample`` and the previous average
        ``Cmumu`` by ``1 - 1/nsample``.
        """
        mudot = state["mudot"]
        smu = jnp.fft.rfft(mudot, 3 * nseg, axis=0, norm="ortho")
        Cmumu = dt * jnp.sum(jnp.abs(smu[:nom]) ** 2, axis=-1)

        nsinv = 1.0 / state["nsample"]
        b1 = 1.0 - nsinv
        Cmumu = state["Cmumu"] * b1 + Cmumu * nsinv

        return {
            **state,
            "Cmumu": Cmumu,
        }

    def write_spectra_to_file(state):
        """Write the scaled (and optionally deconvolved) spectrum to IR_spectrum.out.

        Columns are omega in cm-1 followed by the spectra. With deconvolution
        they are: the first output of ``deconvolute_spectrum``, the raw
        spectrum, and its second output. The deconvolution kernel ``K_D`` is
        cached in the returned state so it is reused on later calls.
        """
        Cmumu_avg = np.array(state["Cmumu"]) * mufact * kubo_fact[:nom]

        if do_deconvolution:
            s_out, s_rec, K_D = deconvolute_spectrum(
                Cmumu_avg,
                omega[:nom],
                gamma,
                niter_deconv,
                kernel=kernel_lorentz,
                trans=False,
                symmetrize=True,
                verbose=False,
                K_D=state.get("K_D", None),
            )
            state = {**state, "K_D": K_D}
            spectra = (s_out, Cmumu_avg, s_rec)
            header = "#omega Cmudot_deconv Cmudot Cmudot_rec"
        else:
            spectra = (Cmumu_avg,)
            header = "#omega Cmudot"

        columns = np.column_stack(
            (
                omega[:nom] * (au.FS * au.CM1),
                *spectra,
            )
        )
        np.savetxt(
            "IR_spectrum.out",
            columns,
            fmt="%12.6f",
            header=header,
        )
        print("# IR spectrum written.")

        return state

    @jax.jit
    def save_dipole(q, vel, pos, dip, cell_reciprocal, state):
        """Record the dipole time derivative for the current step in ``mudot``.

        The derivative combines a centred finite difference of ``dip`` over two
        steps with the charge-transport term: either the stored ``sum(q*vel)``
        (``use_qvel``) or a centred difference of ``sum(q*pos)``.

        Args:
            q: per-atom values, reshaped to (nat, 1) (presumably partial charges).
            vel: per-atom velocities; only used when ``use_qvel`` is true.
            pos: (nat, 3) positions.
            dip: 3-component dipole returned by the dipole model.
            cell_reciprocal: ``None`` or a ``(cell, reciprocal_cell)`` pair used
                to unwrap ``pos`` against the position two steps back.
            state: accumulation state dict.

        Returns:
            A new state with ``istep`` incremented, the history buffers shifted,
            and ``mudot[istep]`` set.
        """
        q = q.reshape(-1, 1)
        if use_qvel:
            qvel = (q * vel).sum(axis=0)
        # Key presence is static under jit: on the very first call, seed the
        # history with the current values so the first finite differences are 0.
        if "first" in state:
            state = {
                **state,
                "musave-2": dip,
                "musave-1": dip,
            }
            if use_qvel:
                state["qvel"] = qvel
            else:
                state["pos_save-2"] = pos
                state["pos_save-1"] = pos
                state["qsave-2"] = q
                state["qsave-1"] = q
            del state["first"]

        istep = state["istep"]
        new_state = {**state, "istep": istep + 1}
        new_state["musave-2"] = state["musave-1"]
        new_state["musave-1"] = dip

        dipdot = (dip - state["musave-2"]) / (2 * dt)
        if use_qvel:
            # Value stored at the previous step, matching the centre of dipdot.
            qrdot = state["qvel"]
            new_state["qvel"] = qvel
        else:
            new_state["pos_save-2"] = state["pos_save-1"]
            new_state["pos_save-1"] = pos
            new_state["qsave-2"] = state["qsave-1"]
            new_state["qsave-1"] = q

            if cell_reciprocal is not None:
                cell, reciprocal_cell = cell_reciprocal
                vec = pos - state["pos_save-2"]
                shift = -jnp.round(jnp.dot(vec, reciprocal_cell))
                pos = pos + jnp.dot(shift, cell)

            qrdot = (q * pos - state["qsave-2"] * state["pos_save-2"]).sum(axis=0) / (2 * dt)

        new_state["mudot"] = state["mudot"].at[istep].set(qrdot + dipdot)

        return new_state

    def postprocess(state):
        """Advance the segment counter and, at each segment end, update the spectrum.

        Presumably called once per MD step. When the counter signals a reset,
        increments ``nadapt``, sets ``nsample``, averages in the new spectrum,
        resets ``istep`` and writes the output file.
        """
        counter.increment()
        if not counter.is_reset_step:
            return state
        state["nadapt"] += 1
        state["nsample"] = max(state["nadapt"] - startsave + 1, 1)
        state = compute_spectra(state)
        state["istep"] = 0
        state = write_spectra_to_file(state)
        return state

    return dipole_model, state, save_dipole, postprocess