fennol.md.spectra
import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
import math
import optax
import os
from pathlib import Path
from flax.core import freeze, unfreeze

from ..models.fennix import FENNIX
from .utils import us
from ..utils import Counter
from ..utils.deconvolution import (
    deconvolute_spectrum,
    kernel_lorentz_pot,
    kernel_lorentz,
)


def initialize_ir_spectrum(
    simulation_parameters, system_data, fprec, dt, apply_kubo_fact=False
):
    """Prepare the on-the-fly infrared (IR) spectrum calculation.

    Loads the dipole model named in ``ir_parameters/dipole_model``, applies
    the neighbor-list settings of the simulation to its preprocessing state,
    allocates the accumulation buffers and builds the two callbacks used
    while the dynamics runs.

    Args:
        simulation_parameters: Mapping of simulation keywords. Read here:
            ``ir_parameters`` (sub-mapping), ``nblist_skin``,
            ``nblist_mult_size``, ``nblist_add_neigh`` and ``gamma``.
        system_data: Mapping describing the system. Read here: ``kT``,
            ``temperature``, ``nat`` (only when ``use_qvel`` is false) and
            the optional ``pbc`` entry (its ``cell`` gives the volume).
        fprec: dtype of the arrays stored in the returned state.
        dt: Time step used for the finite differences and to convert
            ``ir_parameters/tseg`` into a number of samples.
        apply_kubo_fact: When true, the written spectrum is multiplied by
            ``tanh(u)/u`` with ``u = HBAR*omega/(2*kT)``.

    Returns:
        Tuple ``(dipole_model, state, save_dipole, postprocess)``. ``state``
        is the initial accumulation dictionary;
        ``save_dipole(q, vel, pos, dip, cell_reciprocal, state)`` records one
        sample and returns the new state; ``postprocess(state)`` updates and
        writes the spectrum each time the segment counter resets.

    Raises:
        KeyError: ``ir_parameters`` has no ``dipole_model`` entry.
        FileNotFoundError: The dipole model file does not exist.
        ValueError: ``tseg`` is shorter than one time step, or ``omegacut``
            is not below the largest frequency of the grid.
    """
    state = {}

    parameters = simulation_parameters.get("ir_parameters", {})
    """@keyword[fennol_md] ir_parameters
    Parameters for infrared spectrum calculation including dipole model settings.
    Required for ir_spectrum=True
    """
    dipole_model_path = Path(str(parameters["dipole_model"]).strip())
    if not dipole_model_path.exists():
        raise FileNotFoundError(f"Dipole model file {dipole_model_path} not found")
    print(f"# Using '{dipole_model_path}' as dipole model.")
    dipole_model = FENNIX.load(dipole_model_path)

    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
    """@keyword[fennol_md] nblist_skin
    Neighbor list skin distance for dipole model preprocessing (in Angstroms).
    Default: -1.0
    """

    ### CONFIGURE PREPROCESSING
    # Copy the simulation's neighbor-list settings into every preprocessing
    # layer of the dipole model.
    preproc_state = unfreeze(dipole_model.preproc_state)
    layer_state = []
    for st in preproc_state["layers_state"]:
        stnew = unfreeze(st)
        if nblist_skin > 0:
            stnew["nblist_skin"] = nblist_skin
        if "nblist_mult_size" in simulation_parameters:
            stnew["nblist_mult_size"] = simulation_parameters["nblist_mult_size"]
        if "nblist_add_neigh" in simulation_parameters:
            stnew["add_neigh"] = simulation_parameters["nblist_add_neigh"]
        layer_state.append(freeze(stnew))
    preproc_state["layers_state"] = layer_state
    dipole_model.preproc_state = freeze(preproc_state)

    Tseg = parameters.get("tseg", 1.0 / us.PS)
    """@keyword[fennol_md] ir_parameters/tseg
    Time segment length for IR spectrum calculation.
    Default: 1.0 ps
    """
    nseg = int(Tseg / dt)
    if nseg < 1:
        # A zero-length segment would give a division by zero below and
        # empty accumulation buffers.
        raise ValueError(
            "ir_parameters/tseg must span at least one time step "
            f"(tseg={Tseg}, dt={dt})"
        )
    # Round the segment length down to a whole number of time steps.
    Tseg = nseg * dt
    # Frequency step of the 3*nseg-point transform used in compute_spectra.
    dom = 2 * np.pi / (3 * Tseg)
    omegacut = parameters.get("omegacut", 15000.0 / us.CM1)
    """@keyword[fennol_md] ir_parameters/omegacut
    Cutoff frequency for IR spectrum.
    Default: 15000.0 cm⁻¹
    """
    nom = int(omegacut / dom)
    omega = dom * np.arange((3 * nseg) // 2 + 1)

    # Explicit exception instead of `assert`, which is removed under -O.
    if not omegacut < omega[-1]:
        raise ValueError(
            f"omegacut must be smaller than {omega[-1]*us.CM1} CM-1"
        )

    startsave = parameters.get("startsave", 1)
    """@keyword[fennol_md] ir_parameters/startsave
    Start saving IR statistics after this many segments.
    Default: 1
    """
    counter = Counter(nseg, startsave=startsave)
    state["istep"] = 0
    state["nsample"] = 0
    state["nadapt"] = 0

    use_qvel = parameters.get("use_qvel", False)
    """@keyword[fennol_md] ir_parameters/use_qvel
    Use quantum velocity correction for IR spectrum calculation.
    Default: False
    """
    if use_qvel:
        state["qvel"] = jnp.zeros((3,), dtype=fprec)
    else:
        # History buffers for the finite differences; save_dipole overwrites
        # them with real data on its first call.
        nat = system_data["nat"]
        state["musave-2"] = jnp.zeros((3,), dtype=fprec)
        state["musave-1"] = jnp.zeros((3,), dtype=fprec)
        state["qsave-2"] = jnp.zeros((nat, 1), dtype=fprec)
        state["qsave-1"] = jnp.zeros((nat, 1), dtype=fprec)
        state["pos_save-2"] = jnp.zeros((nat, 3), dtype=fprec)
        state["pos_save-1"] = jnp.zeros((nat, 3), dtype=fprec)
    state["mudot"] = jnp.zeros((nseg, 3), dtype=fprec)
    state["Cmumu"] = jnp.zeros((nom,), dtype=fprec)
    # Marker key: its presence tells save_dipole to seed the history buffers.
    state["first"] = True

    kT = system_data["kT"]
    kubo_fact = np.ones_like(omega)
    if apply_kubo_fact:
        # omega[0] == 0 is skipped (0/0); the factor stays 1 there.
        uu = 0.5 * us.HBAR * omega[1:] / kT
        kubo_fact[1:] = np.tanh(uu) / uu

    do_deconvolution = parameters.get("deconvolution", False)
    """@keyword[fennol_md] ir_parameters/deconvolution
    Apply deconvolution to IR spectrum for better resolution.
    Default: False
    """
    if do_deconvolution:
        gamma = simulation_parameters.get("gamma", 1.0 / us.THZ)
        """@keyword[fennol_md] gamma
        Friction coefficient for deconvolution of IR spectra.
        Default: 1.0 ps^-1
        """
        niter_deconv = parameters.get("niter_deconv", 20)
        """@keyword[fennol_md] ir_parameters/niter_deconv
        Number of iterations for IR spectrum deconvolution.
        Default: 20
        """
        print(
            "# Deconvolution of IR spectra with gamma=",
            gamma * (1.0 / us.PS),
            "ps^-1 and niter=",
            niter_deconv,
        )

    temp_K = system_data["temperature"]
    c = 2.99792458e-2  # speed of light in cm/ps
    # Prefactor applied to the averaged spectrum before writing.
    # NOTE(review): the numeric constants are kept verbatim; confirm their
    # unit conventions against the project's unit system.
    mufact = (
        1000 * 418.40 * 332.063714 * 2 * np.pi**2
        / (0.831446215 * temp_K * 3.0 * c)
    )
    pbc_data = system_data.get("pbc", None)
    if pbc_data is not None:
        # Periodic systems: normalise by the cell volume.
        cell = pbc_data["cell"]
        volume = np.abs(np.linalg.det(cell))
        mufact = mufact / volume

    @jax.jit
    def compute_spectra(state):
        """Fold the power spectrum of the finished segment into the average."""
        mudot = state["mudot"]
        # Transform length 3*nseg matches the `omega` grid built above.
        smu = jnp.fft.rfft(mudot, 3 * nseg, axis=0, norm="ortho")
        Cmumu = dt * jnp.sum(jnp.abs(smu[:nom]) ** 2, axis=-1)

        # Running mean over the nsample segments accumulated so far.
        nsinv = 1.0 / state["nsample"]
        b1 = 1.0 - nsinv
        Cmumu = state["Cmumu"] * b1 + Cmumu * nsinv

        return {
            **state,
            "Cmumu": Cmumu,
        }

    def write_spectra_to_file(state):
        """Write the averaged spectrum to ``IR_spectrum.out``; return state."""
        Cmumu_avg = np.array(state["Cmumu"]) * mufact * kubo_fact[:nom]

        if do_deconvolution:
            s_out, s_rec, K_D = deconvolute_spectrum(
                Cmumu_avg,
                omega[:nom],
                gamma,
                niter_deconv,
                kernel=kernel_lorentz,
                trans=False,
                symmetrize=True,
                verbose=False,
                K_D=state.get("K_D", None),
            )
            # Keep K_D in the state so it is passed back on the next call.
            state = {**state, "K_D": K_D}
            spectra = (s_out, Cmumu_avg, s_rec)
            # One name per written column.
            header = "#omega s_out Cmudot s_rec"
        else:
            spectra = (Cmumu_avg,)
            header = "#omega Cmudot"

        columns = np.column_stack(
            (
                omega[:nom] * us.CM1,
                *spectra,
            )
        )
        np.savetxt(
            "IR_spectrum.out",
            columns,
            fmt="%12.6f",
            header=header,
        )
        print("# IR spectrum written.")

        return state

    @jax.jit
    def save_dipole(q, vel, pos, dip, cell_reciprocal, state):
        """Store one sample of the dipole time derivative in ``state["mudot"]``.

        ``cell_reciprocal`` is either ``None`` or a ``(cell, reciprocal_cell)``
        pair; it is only used when ``use_qvel`` is false.
        """
        q = q.reshape(-1, 1)
        if use_qvel:
            qvel = (q * vel).sum(axis=0)

        if "first" in state:
            # First call: seed the history with the current values and drop
            # the marker key (on a copy, the caller's dict is untouched).
            state = {
                **state,
                "musave-2": dip,
                "musave-1": dip,
            }
            if use_qvel:
                state["qvel"] = qvel
            else:
                state["pos_save-2"] = pos
                state["pos_save-1"] = pos
                state["qsave-2"] = q
                state["qsave-1"] = q
            del state["first"]

        istep = state["istep"]
        new_state = {**state, "istep": istep + 1}
        new_state["musave-2"] = state["musave-1"]
        new_state["musave-1"] = dip

        # Difference between the current value and the one from two calls ago.
        dipdot = (dip - state["musave-2"]) / (2 * dt)
        if use_qvel:
            # Value stored at the previous call.
            qrdot = state["qvel"]
            new_state["qvel"] = qvel
        else:
            new_state["pos_save-2"] = state["pos_save-1"]
            new_state["pos_save-1"] = pos
            new_state["qsave-2"] = state["qsave-1"]
            new_state["qsave-1"] = q

            if cell_reciprocal is not None:
                # Shift pos by whole cell vectors relative to the positions
                # saved two calls ago.
                # NOTE(review): assumes reciprocal_cell is the inverse of
                # cell — confirm against the caller.
                cell, reciprocal_cell = cell_reciprocal
                vec = pos - state["pos_save-2"]
                shift = -jnp.round(jnp.dot(vec, reciprocal_cell))
                pos = pos + jnp.dot(shift, cell)

            qrdot = (
                q * pos - state["qsave-2"] * state["pos_save-2"]
            ).sum(axis=0) / (2 * dt)

        new_state["mudot"] = state["mudot"].at[istep].set(qrdot + dipdot)

        return new_state

    def postprocess(state):
        """Advance the counter; on a reset step update and write the spectrum."""
        counter.increment()
        if not counter.is_reset_step:
            return state
        state["nadapt"] += 1
        state["nsample"] = max(state["nadapt"] - startsave + 1, 1)
        state = compute_spectra(state)
        state["istep"] = 0
        state = write_spectra_to_file(state)
        return state

    return dipole_model, state, save_dipole, postprocess
def
initialize_ir_spectrum(simulation_parameters, system_data, fprec, dt, apply_kubo_fact=False):
def initialize_ir_spectrum(
    simulation_parameters, system_data, fprec, dt, apply_kubo_fact=False
):
    """Prepare the on-the-fly infrared (IR) spectrum calculation.

    Loads the dipole model named in ``ir_parameters/dipole_model``, applies
    the neighbor-list settings of the simulation to its preprocessing state,
    allocates the accumulation buffers and builds the two callbacks used
    while the dynamics runs.

    Args:
        simulation_parameters: Mapping of simulation keywords. Read here:
            ``ir_parameters`` (sub-mapping), ``nblist_skin``,
            ``nblist_mult_size``, ``nblist_add_neigh`` and ``gamma``.
        system_data: Mapping describing the system. Read here: ``kT``,
            ``temperature``, ``nat`` (only when ``use_qvel`` is false) and
            the optional ``pbc`` entry (its ``cell`` gives the volume).
        fprec: dtype of the arrays stored in the returned state.
        dt: Time step used for the finite differences and to convert
            ``ir_parameters/tseg`` into a number of samples.
        apply_kubo_fact: When true, the written spectrum is multiplied by
            ``tanh(u)/u`` with ``u = HBAR*omega/(2*kT)``.

    Returns:
        Tuple ``(dipole_model, state, save_dipole, postprocess)``. ``state``
        is the initial accumulation dictionary;
        ``save_dipole(q, vel, pos, dip, cell_reciprocal, state)`` records one
        sample and returns the new state; ``postprocess(state)`` updates and
        writes the spectrum each time the segment counter resets.

    Raises:
        KeyError: ``ir_parameters`` has no ``dipole_model`` entry.
        FileNotFoundError: The dipole model file does not exist.
        ValueError: ``tseg`` is shorter than one time step, or ``omegacut``
            is not below the largest frequency of the grid.
    """
    state = {}

    parameters = simulation_parameters.get("ir_parameters", {})
    """@keyword[fennol_md] ir_parameters
    Parameters for infrared spectrum calculation including dipole model settings.
    Required for ir_spectrum=True
    """
    dipole_model_path = Path(str(parameters["dipole_model"]).strip())
    if not dipole_model_path.exists():
        raise FileNotFoundError(f"Dipole model file {dipole_model_path} not found")
    print(f"# Using '{dipole_model_path}' as dipole model.")
    dipole_model = FENNIX.load(dipole_model_path)

    nblist_skin = simulation_parameters.get("nblist_skin", -1.0)
    """@keyword[fennol_md] nblist_skin
    Neighbor list skin distance for dipole model preprocessing (in Angstroms).
    Default: -1.0
    """

    ### CONFIGURE PREPROCESSING
    # Copy the simulation's neighbor-list settings into every preprocessing
    # layer of the dipole model.
    preproc_state = unfreeze(dipole_model.preproc_state)
    layer_state = []
    for st in preproc_state["layers_state"]:
        stnew = unfreeze(st)
        if nblist_skin > 0:
            stnew["nblist_skin"] = nblist_skin
        if "nblist_mult_size" in simulation_parameters:
            stnew["nblist_mult_size"] = simulation_parameters["nblist_mult_size"]
        if "nblist_add_neigh" in simulation_parameters:
            stnew["add_neigh"] = simulation_parameters["nblist_add_neigh"]
        layer_state.append(freeze(stnew))
    preproc_state["layers_state"] = layer_state
    dipole_model.preproc_state = freeze(preproc_state)

    Tseg = parameters.get("tseg", 1.0 / us.PS)
    """@keyword[fennol_md] ir_parameters/tseg
    Time segment length for IR spectrum calculation.
    Default: 1.0 ps
    """
    nseg = int(Tseg / dt)
    if nseg < 1:
        # A zero-length segment would give a division by zero below and
        # empty accumulation buffers.
        raise ValueError(
            "ir_parameters/tseg must span at least one time step "
            f"(tseg={Tseg}, dt={dt})"
        )
    # Round the segment length down to a whole number of time steps.
    Tseg = nseg * dt
    # Frequency step of the 3*nseg-point transform used in compute_spectra.
    dom = 2 * np.pi / (3 * Tseg)
    omegacut = parameters.get("omegacut", 15000.0 / us.CM1)
    """@keyword[fennol_md] ir_parameters/omegacut
    Cutoff frequency for IR spectrum.
    Default: 15000.0 cm⁻¹
    """
    nom = int(omegacut / dom)
    omega = dom * np.arange((3 * nseg) // 2 + 1)

    # Explicit exception instead of `assert`, which is removed under -O.
    if not omegacut < omega[-1]:
        raise ValueError(
            f"omegacut must be smaller than {omega[-1]*us.CM1} CM-1"
        )

    startsave = parameters.get("startsave", 1)
    """@keyword[fennol_md] ir_parameters/startsave
    Start saving IR statistics after this many segments.
    Default: 1
    """
    counter = Counter(nseg, startsave=startsave)
    state["istep"] = 0
    state["nsample"] = 0
    state["nadapt"] = 0

    use_qvel = parameters.get("use_qvel", False)
    """@keyword[fennol_md] ir_parameters/use_qvel
    Use quantum velocity correction for IR spectrum calculation.
    Default: False
    """
    if use_qvel:
        state["qvel"] = jnp.zeros((3,), dtype=fprec)
    else:
        # History buffers for the finite differences; save_dipole overwrites
        # them with real data on its first call.
        nat = system_data["nat"]
        state["musave-2"] = jnp.zeros((3,), dtype=fprec)
        state["musave-1"] = jnp.zeros((3,), dtype=fprec)
        state["qsave-2"] = jnp.zeros((nat, 1), dtype=fprec)
        state["qsave-1"] = jnp.zeros((nat, 1), dtype=fprec)
        state["pos_save-2"] = jnp.zeros((nat, 3), dtype=fprec)
        state["pos_save-1"] = jnp.zeros((nat, 3), dtype=fprec)
    state["mudot"] = jnp.zeros((nseg, 3), dtype=fprec)
    state["Cmumu"] = jnp.zeros((nom,), dtype=fprec)
    # Marker key: its presence tells save_dipole to seed the history buffers.
    state["first"] = True

    kT = system_data["kT"]
    kubo_fact = np.ones_like(omega)
    if apply_kubo_fact:
        # omega[0] == 0 is skipped (0/0); the factor stays 1 there.
        uu = 0.5 * us.HBAR * omega[1:] / kT
        kubo_fact[1:] = np.tanh(uu) / uu

    do_deconvolution = parameters.get("deconvolution", False)
    """@keyword[fennol_md] ir_parameters/deconvolution
    Apply deconvolution to IR spectrum for better resolution.
    Default: False
    """
    if do_deconvolution:
        gamma = simulation_parameters.get("gamma", 1.0 / us.THZ)
        """@keyword[fennol_md] gamma
        Friction coefficient for deconvolution of IR spectra.
        Default: 1.0 ps^-1
        """
        niter_deconv = parameters.get("niter_deconv", 20)
        """@keyword[fennol_md] ir_parameters/niter_deconv
        Number of iterations for IR spectrum deconvolution.
        Default: 20
        """
        print(
            "# Deconvolution of IR spectra with gamma=",
            gamma * (1.0 / us.PS),
            "ps^-1 and niter=",
            niter_deconv,
        )

    temp_K = system_data["temperature"]
    c = 2.99792458e-2  # speed of light in cm/ps
    # Prefactor applied to the averaged spectrum before writing.
    # NOTE(review): the numeric constants are kept verbatim; confirm their
    # unit conventions against the project's unit system.
    mufact = (
        1000 * 418.40 * 332.063714 * 2 * np.pi**2
        / (0.831446215 * temp_K * 3.0 * c)
    )
    pbc_data = system_data.get("pbc", None)
    if pbc_data is not None:
        # Periodic systems: normalise by the cell volume.
        cell = pbc_data["cell"]
        volume = np.abs(np.linalg.det(cell))
        mufact = mufact / volume

    @jax.jit
    def compute_spectra(state):
        """Fold the power spectrum of the finished segment into the average."""
        mudot = state["mudot"]
        # Transform length 3*nseg matches the `omega` grid built above.
        smu = jnp.fft.rfft(mudot, 3 * nseg, axis=0, norm="ortho")
        Cmumu = dt * jnp.sum(jnp.abs(smu[:nom]) ** 2, axis=-1)

        # Running mean over the nsample segments accumulated so far.
        nsinv = 1.0 / state["nsample"]
        b1 = 1.0 - nsinv
        Cmumu = state["Cmumu"] * b1 + Cmumu * nsinv

        return {
            **state,
            "Cmumu": Cmumu,
        }

    def write_spectra_to_file(state):
        """Write the averaged spectrum to ``IR_spectrum.out``; return state."""
        Cmumu_avg = np.array(state["Cmumu"]) * mufact * kubo_fact[:nom]

        if do_deconvolution:
            s_out, s_rec, K_D = deconvolute_spectrum(
                Cmumu_avg,
                omega[:nom],
                gamma,
                niter_deconv,
                kernel=kernel_lorentz,
                trans=False,
                symmetrize=True,
                verbose=False,
                K_D=state.get("K_D", None),
            )
            # Keep K_D in the state so it is passed back on the next call.
            state = {**state, "K_D": K_D}
            spectra = (s_out, Cmumu_avg, s_rec)
            # One name per written column.
            header = "#omega s_out Cmudot s_rec"
        else:
            spectra = (Cmumu_avg,)
            header = "#omega Cmudot"

        columns = np.column_stack(
            (
                omega[:nom] * us.CM1,
                *spectra,
            )
        )
        np.savetxt(
            "IR_spectrum.out",
            columns,
            fmt="%12.6f",
            header=header,
        )
        print("# IR spectrum written.")

        return state

    @jax.jit
    def save_dipole(q, vel, pos, dip, cell_reciprocal, state):
        """Store one sample of the dipole time derivative in ``state["mudot"]``.

        ``cell_reciprocal`` is either ``None`` or a ``(cell, reciprocal_cell)``
        pair; it is only used when ``use_qvel`` is false.
        """
        q = q.reshape(-1, 1)
        if use_qvel:
            qvel = (q * vel).sum(axis=0)

        if "first" in state:
            # First call: seed the history with the current values and drop
            # the marker key (on a copy, the caller's dict is untouched).
            state = {
                **state,
                "musave-2": dip,
                "musave-1": dip,
            }
            if use_qvel:
                state["qvel"] = qvel
            else:
                state["pos_save-2"] = pos
                state["pos_save-1"] = pos
                state["qsave-2"] = q
                state["qsave-1"] = q
            del state["first"]

        istep = state["istep"]
        new_state = {**state, "istep": istep + 1}
        new_state["musave-2"] = state["musave-1"]
        new_state["musave-1"] = dip

        # Difference between the current value and the one from two calls ago.
        dipdot = (dip - state["musave-2"]) / (2 * dt)
        if use_qvel:
            # Value stored at the previous call.
            qrdot = state["qvel"]
            new_state["qvel"] = qvel
        else:
            new_state["pos_save-2"] = state["pos_save-1"]
            new_state["pos_save-1"] = pos
            new_state["qsave-2"] = state["qsave-1"]
            new_state["qsave-1"] = q

            if cell_reciprocal is not None:
                # Shift pos by whole cell vectors relative to the positions
                # saved two calls ago.
                # NOTE(review): assumes reciprocal_cell is the inverse of
                # cell — confirm against the caller.
                cell, reciprocal_cell = cell_reciprocal
                vec = pos - state["pos_save-2"]
                shift = -jnp.round(jnp.dot(vec, reciprocal_cell))
                pos = pos + jnp.dot(shift, cell)

            qrdot = (
                q * pos - state["qsave-2"] * state["pos_save-2"]
            ).sum(axis=0) / (2 * dt)

        new_state["mudot"] = state["mudot"].at[istep].set(qrdot + dipdot)

        return new_state

    def postprocess(state):
        """Advance the counter; on a reset step update and write the spectrum."""
        counter.increment()
        if not counter.is_reset_step:
            return state
        state["nadapt"] += 1
        state["nsample"] = max(state["nadapt"] - startsave + 1, 1)
        state = compute_spectra(state)
        state["istep"] = 0
        state = write_spectra_to_file(state)
        return state

    return dipole_model, state, save_dipole, postprocess