fennol.md.spectra
import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
import math
import optax
import os
from pathlib import Path
from flax.core import freeze, unfreeze

from ..models.fennix import FENNIX
from ..utils.atomic_units import AtomicUnits as au  # CM1,THZ,BOHR,MPROT
from ..utils import Counter
from ..utils.deconvolution import (
    deconvolute_spectrum,
    kernel_lorentz_pot,
    kernel_lorentz,
)


def initialize_ir_spectrum(simulation_parameters, system_data, fprec, dt, apply_kubo_fact=False):
    """Set up on-the-fly accumulation of an IR spectrum from dipole derivatives.

    Every MD step, ``save_dipole`` stores one sample of the total dipole time
    derivative in a buffer of ``nseg`` entries. Whenever the step counter
    resets (every ``nseg`` steps), ``postprocess`` folds the power spectrum
    of that buffer (zero-padded to ``3 * nseg`` points) into a running
    average and writes it to ``IR_spectrum.out``.

    Parameters
    ----------
    simulation_parameters : dict
        Global simulation parameters. The ``"ir_parameters"`` sub-dict is
        read for ``dipole_model`` (required path), ``tseg``, ``omegacut``,
        ``startsave``, ``use_qvel``, ``deconvolution`` and ``niter_deconv``.
        Top-level keys ``nblist_skin``, ``nblist_mult_size`` and
        ``nblist_add_neigh`` override the dipole model's neighbor-list
        settings; ``gamma`` is read only when deconvolution is enabled.
    system_data : dict
        Must contain ``"kT"`` and ``"temperature"``, plus ``"nat"`` when
        ``use_qvel`` is false. An optional ``"pbc"`` entry provides
        ``"minimum_image"`` and ``"cell"``.
    fprec : dtype
        Floating-point dtype of the accumulation buffers.
    dt : float
        Time step, in the same unit as ``tseg`` after its ``au.FS`` scaling.
    apply_kubo_fact : bool, default False
        If true, multiply the spectrum by ``tanh(u)/u`` with
        ``u = 0.5 * omega * au.FS / kT``.

    Returns
    -------
    tuple
        ``(dipole_model, state, save_dipole, postprocess)``, where
        ``dipole_model`` is the loaded FENNIX model, ``state`` the initial
        state dict, ``save_dipole(q, vel, pos, dip, cell_reciprocal, state)``
        a jitted per-step recorder and ``postprocess(state)`` the per-step
        hook that averages and writes the spectrum.

    Raises
    ------
    FileNotFoundError
        If the dipole model file does not exist.
    ValueError
        If ``tseg`` is shorter than one time step, or if ``omegacut`` is not
        below the largest frequency of the FFT grid.
    """
    state = {}

    parameters = simulation_parameters.get("ir_parameters", {})
    dipole_model = Path(str(parameters["dipole_model"]).strip())
    if not dipole_model.exists():
        raise FileNotFoundError(f"Dipole model file {dipole_model} not found")
    print(f"# Using '{dipole_model}' as dipole model.")
    dipole_model = FENNIX.load(dipole_model)

    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
    pbc_data = system_data.get("pbc", None)

    ### CONFIGURE PREPROCESSING
    # Propagate the simulation's PBC / neighbor-list settings to every
    # preprocessing layer of the dipole model.
    preproc_state = unfreeze(dipole_model.preproc_state)
    layer_state = []
    for st in preproc_state["layers_state"]:
        stnew = unfreeze(st)
        if pbc_data is not None:
            stnew["minimum_image"] = pbc_data["minimum_image"]
        if nblist_skin > 0:
            stnew["nblist_skin"] = nblist_skin
        if "nblist_mult_size" in simulation_parameters:
            stnew["nblist_mult_size"] = simulation_parameters["nblist_mult_size"]
        if "nblist_add_neigh" in simulation_parameters:
            stnew["add_neigh"] = simulation_parameters["nblist_add_neigh"]
        layer_state.append(freeze(stnew))
    preproc_state["layers_state"] = layer_state
    dipole_model.preproc_state = freeze(preproc_state)

    ### FREQUENCY GRID
    Tseg = parameters.get("tseg", 1.0 / au.PS) * au.FS
    nseg = int(Tseg / dt)
    if nseg < 1:
        raise ValueError(f"tseg ({Tseg}) must be at least one time step (dt={dt})")
    Tseg = nseg * dt
    # Segments are zero-padded to 3*nseg points before the FFT, hence the
    # frequency spacing 2*pi/(3*Tseg).
    dom = 2 * np.pi / (3 * Tseg)
    omegacut = parameters.get("omegacut", 15000.0 / au.CM1) / au.FS
    nom = int(omegacut / dom)
    omega = dom * np.arange((3 * nseg) // 2 + 1)

    if not (omegacut < omega[-1]):
        # Convert with au.FS * au.CM1, exactly as the output file does
        # (and as the omegacut default is defined).
        raise ValueError(
            f"omegacut must be smaller than {omega[-1] * au.FS * au.CM1} CM-1"
        )

    ### INITIAL STATE
    startsave = parameters.get("startsave", 1)
    counter = Counter(nseg, startsave=startsave)
    state["istep"] = 0
    state["nsample"] = 0
    state["nadapt"] = 0

    use_qvel = parameters.get("use_qvel", False)
    # Dipole history for the central difference: read by save_dipole in both
    # modes, so it is seeded unconditionally.
    state["musave-2"] = jnp.zeros((3,), dtype=fprec)
    state["musave-1"] = jnp.zeros((3,), dtype=fprec)
    if use_qvel:
        state["qvel"] = jnp.zeros((3,), dtype=fprec)
    else:
        nat = system_data["nat"]
        state["qsave-2"] = jnp.zeros((nat, 1), dtype=fprec)
        state["qsave-1"] = jnp.zeros((nat, 1), dtype=fprec)
        state["pos_save-2"] = jnp.zeros((nat, 3), dtype=fprec)
        state["pos_save-1"] = jnp.zeros((nat, 3), dtype=fprec)
    state["mudot"] = jnp.zeros((nseg, 3), dtype=fprec)
    state["Cmumu"] = jnp.zeros((nom,), dtype=fprec)
    # Tells save_dipole to seed the history with the first sample.
    state["first"] = True

    kT = system_data["kT"]
    kubo_fact = np.ones_like(omega)
    if apply_kubo_fact:
        uu = 0.5 * omega[1:] * au.FS / kT
        kubo_fact[1:] = np.tanh(uu) / uu

    do_deconvolution = parameters.get("deconvolution", False)
    if do_deconvolution:
        # NOTE: gamma is read from the top-level simulation parameters,
        # not from "ir_parameters".
        gamma = simulation_parameters.get("gamma", 1.0 / au.THZ) / au.FS
        niter_deconv = parameters.get("niter_deconv", 20)
        print("# Deconvolution of IR spectra with gamma=", gamma * 1000, "ps-1 and niter=", niter_deconv)

    kelvin = system_data["temperature"]
    c = 2.99792458e-2  # speed of light in cm/ps
    # NOTE(review): the numeric constants look like kcal/mol / amu-A^2-ps^-2
    # unit conversions; kept unchanged from the original implementation.
    mufact = 1000 * 418.40 * 332.063714 * 2 * np.pi**2 / (0.831446215 * kelvin * 3.0 * c)
    if pbc_data is not None:
        # Normalize by the (fixed, initial) cell volume.
        volume = np.abs(np.linalg.det(pbc_data["cell"]))
        mufact = mufact / volume

    @jax.jit
    def compute_spectra(state):
        """Fold the power spectrum of the current segment into the average.

        Uses weight ``1/nsample`` for the new segment, so ``nsample == 1``
        discards the previous average.
        """
        mudot = state["mudot"]
        smu = jnp.fft.rfft(mudot, 3 * nseg, axis=0, norm="ortho")
        Cmumu = dt * jnp.sum(jnp.abs(smu[:nom]) ** 2, axis=-1)

        nsinv = 1.0 / state["nsample"]
        b1 = 1.0 - nsinv
        Cmumu = state["Cmumu"] * b1 + Cmumu * nsinv

        return {
            **state,
            "Cmumu": Cmumu,
        }

    def write_spectra_to_file(state):
        """Scale the averaged spectrum and write it to IR_spectrum.out.

        The first column is omega converted with ``au.FS * au.CM1``. With
        deconvolution enabled, the deconvolution kernel matrix is cached in
        ``state["K_D"]`` and reused on the next call.
        """
        Cmumu_avg = np.array(state["Cmumu"]) * mufact * kubo_fact[:nom]

        if do_deconvolution:
            s_out, s_rec, K_D = deconvolute_spectrum(
                Cmumu_avg,
                omega[:nom],
                gamma,
                niter_deconv,
                kernel=kernel_lorentz,
                trans=False,
                symmetrize=True,
                verbose=False,
                K_D=state.get("K_D", None),
            )
            state = {**state, "K_D": K_D}
            spectra = (s_out, Cmumu_avg, s_rec)
            header = "omega Cmudot_deconv Cmudot Cmudot_rec"
        else:
            spectra = (Cmumu_avg,)
            header = "omega Cmudot"

        columns = np.column_stack(
            (
                omega[:nom] * (au.FS * au.CM1),
                *spectra,
            )
        )
        # np.savetxt prefixes the header with "# " itself.
        np.savetxt(
            "IR_spectrum.out",
            columns,
            fmt="%12.6f",
            header=header,
        )
        print("# IR spectrum written.")

        return state

    @jax.jit
    def save_dipole(q, vel, pos, dip, cell_reciprocal, state):
        """Store one sample of the total dipole derivative in state["mudot"].

        The derivative at the previous step is the central difference
        ``(x_t - x_{t-2}) / (2 dt)`` of the model dipole ``dip`` plus, unless
        ``use_qvel``, of ``sum_i q_i r_i``. With ``use_qvel``, ``sum_i q_i v_i``
        from the previous call is used for the charge term instead. On the
        first call the history is seeded with the current values.

        ``cell_reciprocal`` is ``None`` or ``(cell, reciprocal_cell)``. When it
        is given, positions are shifted by whole cell vectors towards their
        values two steps earlier before differencing (assumes
        ``pos @ reciprocal_cell`` yields fractional coordinates — TODO confirm).
        Returns a new state dict; the input dict is not modified.
        """
        q = q.reshape(-1, 1)
        if use_qvel:
            qvel = (q * vel).sum(axis=0)
        if "first" in state:
            state = {
                **state,
                "musave-2": dip,
                "musave-1": dip,
            }
            if use_qvel:
                state["qvel"] = qvel
            else:
                state["pos_save-2"] = pos
                state["pos_save-1"] = pos
                state["qsave-2"] = q
                state["qsave-1"] = q
            del state["first"]

        istep = state["istep"]
        new_state = {**state, "istep": istep + 1}
        new_state["musave-2"] = state["musave-1"]
        new_state["musave-1"] = dip

        dipdot = (dip - state["musave-2"]) / (2 * dt)
        if use_qvel:
            qrdot = state["qvel"]
            new_state["qvel"] = qvel
        else:
            # Store the raw (wrapped) positions; unwrapping below only
            # affects the difference.
            new_state["pos_save-2"] = state["pos_save-1"]
            new_state["pos_save-1"] = pos
            new_state["qsave-2"] = state["qsave-1"]
            new_state["qsave-1"] = q

            if cell_reciprocal is not None:
                cell, reciprocal_cell = cell_reciprocal
                vec = pos - state["pos_save-2"]
                shift = -jnp.round(jnp.dot(vec, reciprocal_cell))
                pos = pos + jnp.dot(shift, cell)

            qrdot = (q * pos - state["qsave-2"] * state["pos_save-2"]).sum(axis=0) / (2 * dt)

        new_state["mudot"] = state["mudot"].at[istep].set(qrdot + dipdot)

        return new_state

    def postprocess(state):
        """Advance the step counter; at each segment end, update and write.

        Returns the state unchanged between segment ends. Otherwise it returns
        a new state with the updated average and ``istep`` reset to 0.
        """
        counter.increment()
        if not counter.is_reset_step:
            return state
        nadapt = state["nadapt"] + 1
        state = {
            **state,
            "nadapt": nadapt,
            # Segments completed before startsave get nsample == 1, which
            # restarts the running average (see compute_spectra).
            "nsample": max(nadapt - startsave + 1, 1),
        }
        state = compute_spectra(state)
        state["istep"] = 0
        state = write_spectra_to_file(state)
        return state

    return dipole_model, state, save_dipole, postprocess
def
initialize_ir_spectrum(simulation_parameters, system_data, fprec, dt, apply_kubo_fact=False):
def initialize_ir_spectrum(simulation_parameters, system_data, fprec, dt, apply_kubo_fact=False):
    """Set up on-the-fly accumulation of an IR spectrum from dipole derivatives.

    Every MD step, ``save_dipole`` stores one sample of the total dipole time
    derivative in a buffer of ``nseg`` entries. Whenever the step counter
    resets (every ``nseg`` steps), ``postprocess`` folds the power spectrum
    of that buffer (zero-padded to ``3 * nseg`` points) into a running
    average and writes it to ``IR_spectrum.out``.

    Parameters
    ----------
    simulation_parameters : dict
        Global simulation parameters. The ``"ir_parameters"`` sub-dict is
        read for ``dipole_model`` (required path), ``tseg``, ``omegacut``,
        ``startsave``, ``use_qvel``, ``deconvolution`` and ``niter_deconv``.
        Top-level keys ``nblist_skin``, ``nblist_mult_size`` and
        ``nblist_add_neigh`` override the dipole model's neighbor-list
        settings; ``gamma`` is read only when deconvolution is enabled.
    system_data : dict
        Must contain ``"kT"`` and ``"temperature"``, plus ``"nat"`` when
        ``use_qvel`` is false. An optional ``"pbc"`` entry provides
        ``"minimum_image"`` and ``"cell"``.
    fprec : dtype
        Floating-point dtype of the accumulation buffers.
    dt : float
        Time step, in the same unit as ``tseg`` after its ``au.FS`` scaling.
    apply_kubo_fact : bool, default False
        If true, multiply the spectrum by ``tanh(u)/u`` with
        ``u = 0.5 * omega * au.FS / kT``.

    Returns
    -------
    tuple
        ``(dipole_model, state, save_dipole, postprocess)``, where
        ``dipole_model`` is the loaded FENNIX model, ``state`` the initial
        state dict, ``save_dipole(q, vel, pos, dip, cell_reciprocal, state)``
        a jitted per-step recorder and ``postprocess(state)`` the per-step
        hook that averages and writes the spectrum.

    Raises
    ------
    FileNotFoundError
        If the dipole model file does not exist.
    ValueError
        If ``tseg`` is shorter than one time step, or if ``omegacut`` is not
        below the largest frequency of the FFT grid.
    """
    state = {}

    parameters = simulation_parameters.get("ir_parameters", {})
    dipole_model = Path(str(parameters["dipole_model"]).strip())
    if not dipole_model.exists():
        raise FileNotFoundError(f"Dipole model file {dipole_model} not found")
    print(f"# Using '{dipole_model}' as dipole model.")
    dipole_model = FENNIX.load(dipole_model)

    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
    pbc_data = system_data.get("pbc", None)

    ### CONFIGURE PREPROCESSING
    # Propagate the simulation's PBC / neighbor-list settings to every
    # preprocessing layer of the dipole model.
    preproc_state = unfreeze(dipole_model.preproc_state)
    layer_state = []
    for st in preproc_state["layers_state"]:
        stnew = unfreeze(st)
        if pbc_data is not None:
            stnew["minimum_image"] = pbc_data["minimum_image"]
        if nblist_skin > 0:
            stnew["nblist_skin"] = nblist_skin
        if "nblist_mult_size" in simulation_parameters:
            stnew["nblist_mult_size"] = simulation_parameters["nblist_mult_size"]
        if "nblist_add_neigh" in simulation_parameters:
            stnew["add_neigh"] = simulation_parameters["nblist_add_neigh"]
        layer_state.append(freeze(stnew))
    preproc_state["layers_state"] = layer_state
    dipole_model.preproc_state = freeze(preproc_state)

    ### FREQUENCY GRID
    Tseg = parameters.get("tseg", 1.0 / au.PS) * au.FS
    nseg = int(Tseg / dt)
    if nseg < 1:
        raise ValueError(f"tseg ({Tseg}) must be at least one time step (dt={dt})")
    Tseg = nseg * dt
    # Segments are zero-padded to 3*nseg points before the FFT, hence the
    # frequency spacing 2*pi/(3*Tseg).
    dom = 2 * np.pi / (3 * Tseg)
    omegacut = parameters.get("omegacut", 15000.0 / au.CM1) / au.FS
    nom = int(omegacut / dom)
    omega = dom * np.arange((3 * nseg) // 2 + 1)

    if not (omegacut < omega[-1]):
        # Convert with au.FS * au.CM1, exactly as the output file does
        # (and as the omegacut default is defined).
        raise ValueError(
            f"omegacut must be smaller than {omega[-1] * au.FS * au.CM1} CM-1"
        )

    ### INITIAL STATE
    startsave = parameters.get("startsave", 1)
    counter = Counter(nseg, startsave=startsave)
    state["istep"] = 0
    state["nsample"] = 0
    state["nadapt"] = 0

    use_qvel = parameters.get("use_qvel", False)
    # Dipole history for the central difference: read by save_dipole in both
    # modes, so it is seeded unconditionally.
    state["musave-2"] = jnp.zeros((3,), dtype=fprec)
    state["musave-1"] = jnp.zeros((3,), dtype=fprec)
    if use_qvel:
        state["qvel"] = jnp.zeros((3,), dtype=fprec)
    else:
        nat = system_data["nat"]
        state["qsave-2"] = jnp.zeros((nat, 1), dtype=fprec)
        state["qsave-1"] = jnp.zeros((nat, 1), dtype=fprec)
        state["pos_save-2"] = jnp.zeros((nat, 3), dtype=fprec)
        state["pos_save-1"] = jnp.zeros((nat, 3), dtype=fprec)
    state["mudot"] = jnp.zeros((nseg, 3), dtype=fprec)
    state["Cmumu"] = jnp.zeros((nom,), dtype=fprec)
    # Tells save_dipole to seed the history with the first sample.
    state["first"] = True

    kT = system_data["kT"]
    kubo_fact = np.ones_like(omega)
    if apply_kubo_fact:
        uu = 0.5 * omega[1:] * au.FS / kT
        kubo_fact[1:] = np.tanh(uu) / uu

    do_deconvolution = parameters.get("deconvolution", False)
    if do_deconvolution:
        # NOTE: gamma is read from the top-level simulation parameters,
        # not from "ir_parameters".
        gamma = simulation_parameters.get("gamma", 1.0 / au.THZ) / au.FS
        niter_deconv = parameters.get("niter_deconv", 20)
        print("# Deconvolution of IR spectra with gamma=", gamma * 1000, "ps-1 and niter=", niter_deconv)

    kelvin = system_data["temperature"]
    c = 2.99792458e-2  # speed of light in cm/ps
    # NOTE(review): the numeric constants look like kcal/mol / amu-A^2-ps^-2
    # unit conversions; kept unchanged from the original implementation.
    mufact = 1000 * 418.40 * 332.063714 * 2 * np.pi**2 / (0.831446215 * kelvin * 3.0 * c)
    if pbc_data is not None:
        # Normalize by the (fixed, initial) cell volume.
        volume = np.abs(np.linalg.det(pbc_data["cell"]))
        mufact = mufact / volume

    @jax.jit
    def compute_spectra(state):
        """Fold the power spectrum of the current segment into the average.

        Uses weight ``1/nsample`` for the new segment, so ``nsample == 1``
        discards the previous average.
        """
        mudot = state["mudot"]
        smu = jnp.fft.rfft(mudot, 3 * nseg, axis=0, norm="ortho")
        Cmumu = dt * jnp.sum(jnp.abs(smu[:nom]) ** 2, axis=-1)

        nsinv = 1.0 / state["nsample"]
        b1 = 1.0 - nsinv
        Cmumu = state["Cmumu"] * b1 + Cmumu * nsinv

        return {
            **state,
            "Cmumu": Cmumu,
        }

    def write_spectra_to_file(state):
        """Scale the averaged spectrum and write it to IR_spectrum.out.

        The first column is omega converted with ``au.FS * au.CM1``. With
        deconvolution enabled, the deconvolution kernel matrix is cached in
        ``state["K_D"]`` and reused on the next call.
        """
        Cmumu_avg = np.array(state["Cmumu"]) * mufact * kubo_fact[:nom]

        if do_deconvolution:
            s_out, s_rec, K_D = deconvolute_spectrum(
                Cmumu_avg,
                omega[:nom],
                gamma,
                niter_deconv,
                kernel=kernel_lorentz,
                trans=False,
                symmetrize=True,
                verbose=False,
                K_D=state.get("K_D", None),
            )
            state = {**state, "K_D": K_D}
            spectra = (s_out, Cmumu_avg, s_rec)
            header = "omega Cmudot_deconv Cmudot Cmudot_rec"
        else:
            spectra = (Cmumu_avg,)
            header = "omega Cmudot"

        columns = np.column_stack(
            (
                omega[:nom] * (au.FS * au.CM1),
                *spectra,
            )
        )
        # np.savetxt prefixes the header with "# " itself.
        np.savetxt(
            "IR_spectrum.out",
            columns,
            fmt="%12.6f",
            header=header,
        )
        print("# IR spectrum written.")

        return state

    @jax.jit
    def save_dipole(q, vel, pos, dip, cell_reciprocal, state):
        """Store one sample of the total dipole derivative in state["mudot"].

        The derivative at the previous step is the central difference
        ``(x_t - x_{t-2}) / (2 dt)`` of the model dipole ``dip`` plus, unless
        ``use_qvel``, of ``sum_i q_i r_i``. With ``use_qvel``, ``sum_i q_i v_i``
        from the previous call is used for the charge term instead. On the
        first call the history is seeded with the current values.

        ``cell_reciprocal`` is ``None`` or ``(cell, reciprocal_cell)``. When it
        is given, positions are shifted by whole cell vectors towards their
        values two steps earlier before differencing (assumes
        ``pos @ reciprocal_cell`` yields fractional coordinates — TODO confirm).
        Returns a new state dict; the input dict is not modified.
        """
        q = q.reshape(-1, 1)
        if use_qvel:
            qvel = (q * vel).sum(axis=0)
        if "first" in state:
            state = {
                **state,
                "musave-2": dip,
                "musave-1": dip,
            }
            if use_qvel:
                state["qvel"] = qvel
            else:
                state["pos_save-2"] = pos
                state["pos_save-1"] = pos
                state["qsave-2"] = q
                state["qsave-1"] = q
            del state["first"]

        istep = state["istep"]
        new_state = {**state, "istep": istep + 1}
        new_state["musave-2"] = state["musave-1"]
        new_state["musave-1"] = dip

        dipdot = (dip - state["musave-2"]) / (2 * dt)
        if use_qvel:
            qrdot = state["qvel"]
            new_state["qvel"] = qvel
        else:
            # Store the raw (wrapped) positions; unwrapping below only
            # affects the difference.
            new_state["pos_save-2"] = state["pos_save-1"]
            new_state["pos_save-1"] = pos
            new_state["qsave-2"] = state["qsave-1"]
            new_state["qsave-1"] = q

            if cell_reciprocal is not None:
                cell, reciprocal_cell = cell_reciprocal
                vec = pos - state["pos_save-2"]
                shift = -jnp.round(jnp.dot(vec, reciprocal_cell))
                pos = pos + jnp.dot(shift, cell)

            qrdot = (q * pos - state["qsave-2"] * state["pos_save-2"]).sum(axis=0) / (2 * dt)

        new_state["mudot"] = state["mudot"].at[istep].set(qrdot + dipdot)

        return new_state

    def postprocess(state):
        """Advance the step counter; at each segment end, update and write.

        Returns the state unchanged between segment ends. Otherwise it returns
        a new state with the updated average and ``istep`` reset to 0.
        """
        counter.increment()
        if not counter.is_reset_step:
            return state
        nadapt = state["nadapt"] + 1
        state = {
            **state,
            "nadapt": nadapt,
            # Segments completed before startsave get nsample == 1, which
            # restarts the running average (see compute_spectra).
            "nsample": max(nadapt - startsave + 1, 1),
        }
        state = compute_spectra(state)
        state["istep"] = 0
        state = write_spectra_to_file(state)
        return state

    return dipole_model, state, save_dipole, postprocess