fennol.md.barostats
import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
import math
import optax
from enum import Enum

from ..utils.atomic_units import AtomicUnits as au  # CM1,THZ,BOHR,MPROT
from ..utils import Counter
from ..utils.deconvolution import (
    deconvolute_spectrum,
    kernel_lorentz_pot,
    kernel_lorentz,
)


def get_barostat(
    thermostat,
    simulation_parameters,
    dt,
    system_data,
    fprec,
    rng_key=None,
    restart_data=None,
):
    """Build the barostat step used by the MD integrator.

    Parameters
    ----------
    thermostat : callable
        ``thermostat(vel, thermostat_state) -> (vel, thermostat_state)``.
        It is called from inside the returned ``barostat`` function.
    simulation_parameters : dict
        Keys read here: ``barostat`` ("NONE" by default, or "LGV"/"LANGEVIN"),
        ``target_pressure``, ``aniso_barostat``, ``aniso_mask``,
        ``start_barostat``, ``gamma_piston`` and ``tau_piston``.
    dt : float
        Integration time step. NOTE(review): the ``au.FS`` conversions and the
        "Ha.fs^2" piston-mass printout suggest femtoseconds -- confirm.
    system_data : dict
        Keys read here: ``kT``, ``nbeads``, ``pbc`` (``cell`` or ``volume``)
        and ``nat``.
    fprec :
        Not used by this function.
    rng_key : jax PRNG key, optional
        Required by the Langevin barostat (initial piston velocity and noise).
    restart_data : dict, optional
        ``simulation_time_ps`` is read to shift ``start_barostat`` when
        restarting a simulation. Defaults to an empty dict.

    Returns
    -------
    barostat : callable
        ``barostat(x, vel, system) -> (x, vel, system)``.
    variable_cell : bool
        ``False`` when the barostat is "NONE", ``True`` otherwise.
    state : dict
        Initial barostat state (``extvol``, ``vextvol``, ``rng_key``,
        ``istep``); empty when the barostat is "NONE".

    Raises
    ------
    AssertionError
        If a barostat is requested without ``kT``, ``target_pressure`` or
        (Langevin) ``rng_key``, or if ``aniso_mask`` does not have 6 entries.
    ValueError
        If the barostat name is not recognized.
    """
    # Avoid a shared mutable default argument.
    if restart_data is None:
        restart_data = {}
    state = {}

    barostat_name = str(simulation_parameters.get("barostat", "NONE")).upper()

    kT = system_data.get("kT", None)
    target_pressure = simulation_parameters.get("target_pressure")
    if barostat_name != "NONE":
        # kT and the target pressure are only needed when the cell is coupled
        # to a barostat; the "NONE" branch below uses neither of them.
        assert kT is not None, "kT must be specified for NPT/NPH simulations"
        assert (
            target_pressure is not None
        ), "target_pressure must be specified for NPT/NPH simulations"
        # NOTE(review): presumably converts the pressure from per-bohr^3 to
        # per-Angstrom^3 units -- confirm against AtomicUnits.
        target_pressure = target_pressure / au.BOHR**3

    nbeads = system_data.get("nbeads", None)
    variable_cell = True

    anisotropic = simulation_parameters.get("aniso_barostat", False)
    isotropic = not anisotropic

    pbc_data = system_data["pbc"]

    # Optional delayed start. On restart, the time already simulated
    # (simulation_time_ps * 1e3) is subtracted from the requested delay.
    start_barostat = simulation_parameters.get("start_barostat", 0.0) * au.FS
    start_time = restart_data.get("simulation_time_ps", 0.0) * 1e3
    start_barostat = max(0.0, start_barostat - start_time)
    istart_barostat = int(round(start_barostat / dt))
    if istart_barostat > 0 and barostat_name not in ["NONE"]:
        print(
            f"# BAROSTAT will start at {start_barostat/1000:.3f} ps ({istart_barostat} steps)"
        )
    else:
        istart_barostat = 0

    if barostat_name in ["LGV", "LANGEVIN"]:
        assert rng_key is not None, "rng_key must be provided for LANGEVIN barostat"
        gamma = simulation_parameters.get("gamma_piston", 20.0 / au.THZ) / au.FS
        tau_piston = simulation_parameters.get("tau_piston", 200.0 / au.FS) * au.FS
        nat = system_data["nat"]
        masspiston = 3 * nat * kT * tau_piston**2
        print(f"# LANGEVIN barostat with piston mass={masspiston:.1e} Ha.fs^2")
        # Ornstein-Uhlenbeck update of the piston velocity:
        #   v <- a1 * v + a2 * xi,  a1 = exp(-gamma dt),  a2 = sqrt((1 - a1^2) kT / M)
        a1 = math.exp(-gamma * dt)
        a2 = ((1 - a1 * a1) * kT / masspiston) ** 0.5

        rng_key, v_key = jax.random.split(rng_key)
        if anisotropic:
            # The piston acts on the full 3x3 cell matrix; its velocity is
            # kept symmetric.
            extvol = pbc_data["cell"]
            vextvol = (
                jax.random.normal(v_key, (3, 3), dtype=extvol.dtype)
                * (kT / masspiston) ** 0.5
            )
            vextvol = 0.5 * (vextvol + vextvol.T)

            aniso_mask = simulation_parameters.get(
                "aniso_mask", [True, True, True, True, True, True]
            )
            assert len(aniso_mask) == 6, "aniso_mask must have 6 elements"
            aniso_mask = np.array(aniso_mask, dtype=bool).astype(np.int32)
            ndof_piston = np.sum(aniso_mask)
            # mask order: xx yy zz xy xz yz -> symmetric 3x3 mask
            aniso_mask = np.array(
                [
                    [aniso_mask[0], aniso_mask[3], aniso_mask[4]],
                    [aniso_mask[3], aniso_mask[1], aniso_mask[5]],
                    [aniso_mask[4], aniso_mask[5], aniso_mask[2]],
                ],
                dtype=np.int32,
            )
        else:
            # The piston acts on the volume only (a single degree of freedom).
            extvol = jnp.asarray(pbc_data["volume"])
            vextvol = (
                jax.random.normal(v_key, (1,), dtype=extvol.dtype)
                * (kT / masspiston) ** 0.5
            )
            ndof_piston = 1.0

        state["extvol"] = extvol
        state["vextvol"] = vextvol
        state["rng_key"] = rng_key
        state["istep"] = 0

        def barostat(x, vel, system):
            """Advance one step with the B / A / O / A piston splitting.

            Returns the updated ``(x, vel, system)``; ``system`` gets new
            ``barostat``, ``cell`` and ``thermostat`` entries. When ``nbeads``
            is set, only ``x[0]`` / ``vel[0]`` are rescaled; the remaining
            position entries are passed through unchanged.
            """
            if nbeads is not None:
                x, eigx = x[0], x[1:]
                vel, eigv = vel[0], vel[1:]
            natoms = x.shape[0]
            barostat_state = system["barostat"]
            extvol = barostat_state["extvol"]
            vextvol = barostat_state["vextvol"]
            cell = system["cell"]
            volume = jnp.abs(jnp.linalg.det(cell))

            istep = barostat_state["istep"] + 1
            # Before istart_barostat the effective step is zero, so the
            # B/A steps leave positions, velocities and cell unchanged; the O
            # step still updates the piston velocity.
            dt_bar = dt * (istep >= istart_barostat)

            # apply B: kick the piston velocity with the pressure imbalance.
            # pV = PV_tensor + (2 * ek / (3 * natoms)) * identity
            ek = system["ek_c"] if nbeads is not None else system["ek"]
            pV = system["PV_tensor"] + ek * jnp.array(np.eye(3) * (2 / (3 * natoms)))
            if isotropic:
                dpV = jnp.trace(pV) - 3 * volume * target_pressure
            else:
                dpV = 0.5 * (pV + pV.T) - volume * jnp.array(target_pressure * np.eye(3))

            vextvol = vextvol + (dt_bar / masspiston) * dpV

            # apply A (first half): rescale velocities and keep the position
            # scaling factor scale1; positions are scaled once, at the end.
            if isotropic:
                scalev = jnp.exp((-0.5 * dt_bar * (1 + 1.0 / natoms)) * vextvol)
                vel = vel * scalev
                scale1 = jnp.exp((0.5 * dt_bar) * vextvol)
            else:
                vextvol = aniso_mask * vextvol
                # vextvol is symmetric, so matrix exponentials go through eigh.
                l, O = jnp.linalg.eigh(vextvol)
                lcorr = jnp.trace(vextvol) / (3 * natoms)
                Dv = jnp.diag(jnp.exp(-0.5 * dt_bar * (l + lcorr)))
                Dx = jnp.diag(jnp.exp(0.5 * dt_bar * l))
                scalev = O @ Dv @ O.T
                scale1 = O @ Dx @ O.T
                vel = vel @ scalev

            # apply O: particle thermostat, then friction/noise on the piston.
            if nbeads is not None:
                eigv, thermostat_state = thermostat(
                    jnp.concatenate((vel[None], eigv), axis=0), system["thermostat"]
                )
                vel, eigv = eigv[0], eigv[1:]
            else:
                vel, thermostat_state = thermostat(vel, system["thermostat"])
            rng_key, noise_key = jax.random.split(barostat_state["rng_key"])

            if isotropic:
                noise = jax.random.normal(noise_key, (1,), dtype=vextvol.dtype)
            else:
                noise = jax.random.normal(noise_key, (3, 3), dtype=vextvol.dtype)
                # symmetrize so the piston velocity matrix stays symmetric
                noise = 0.5 * (noise + noise.T)

            vextvol = a1 * vextvol + a2 * noise

            # apply A (second half): rescale velocities, then scale positions
            # and cell by the combined factor of both half steps.
            if isotropic:
                scalev = jnp.exp((-0.5 * dt_bar * (1 + 1.0 / natoms)) * vextvol)
                vel = vel * scalev
                scale2 = jnp.exp((0.5 * dt_bar) * vextvol)
                x = x * (scale1 * scale2)
                # the volume scales as the cube of the linear factor
                extvol = extvol * (scale1 * scale2) ** 3
                cell = cell * (scale1 * scale2)
            else:
                vextvol = aniso_mask * vextvol
                l, O = jnp.linalg.eigh(vextvol)
                lcorr = jnp.trace(vextvol) / (3 * natoms)
                Dv = jnp.diag(jnp.exp(-0.5 * dt_bar * (l + lcorr)))
                Dx = jnp.diag(jnp.exp(0.5 * dt_bar * l))
                scalev = O @ Dv @ O.T
                scale = scale1 @ (O @ Dx @ O.T)
                vel = vel @ scalev
                x = x @ scale
                # in the anisotropic case extvol holds the cell matrix itself
                extvol = extvol @ scale
                cell = extvol

            if nbeads is not None:
                x = jnp.concatenate((x[None], eigx), axis=0)
                vel = jnp.concatenate((vel[None], eigv), axis=0)

            piston_temperature = (au.KELVIN * masspiston / ndof_piston) * jnp.sum(
                vextvol**2
            )
            barostat_state = {
                **barostat_state,
                "istep": istep,
                "rng_key": rng_key,
                "vextvol": vextvol,
                "extvol": extvol,
                "piston_temperature": piston_temperature,
            }
            return (
                x,
                vel,
                {
                    **system,
                    "barostat": barostat_state,
                    "cell": cell,
                    "thermostat": thermostat_state,
                },
            )

    elif barostat_name in ["NONE"]:
        variable_cell = False

        def barostat(x, vel, system):
            """Fixed-cell step: only the thermostat is applied to ``vel``."""
            vel, thermostat_state = thermostat(vel, system["thermostat"])
            return x, vel, {**system, "thermostat": thermostat_state}

    else:
        raise ValueError(f"Unknown barostat {barostat_name}")

    return barostat, variable_cell, state
def get_barostat(
    thermostat,
    simulation_parameters,
    dt,
    system_data,
    fprec,
    rng_key=None,
    restart_data=None,
):
    """Build the barostat step used by the MD integrator.

    Parameters
    ----------
    thermostat : callable
        ``thermostat(vel, thermostat_state) -> (vel, thermostat_state)``.
        It is called from inside the returned ``barostat`` function.
    simulation_parameters : dict
        Keys read here: ``barostat`` ("NONE" by default, or "LGV"/"LANGEVIN"),
        ``target_pressure``, ``aniso_barostat``, ``aniso_mask``,
        ``start_barostat``, ``gamma_piston`` and ``tau_piston``.
    dt : float
        Integration time step. NOTE(review): the ``au.FS`` conversions and the
        "Ha.fs^2" piston-mass printout suggest femtoseconds -- confirm.
    system_data : dict
        Keys read here: ``kT``, ``nbeads``, ``pbc`` (``cell`` or ``volume``)
        and ``nat``.
    fprec :
        Not used by this function.
    rng_key : jax PRNG key, optional
        Required by the Langevin barostat (initial piston velocity and noise).
    restart_data : dict, optional
        ``simulation_time_ps`` is read to shift ``start_barostat`` when
        restarting a simulation. Defaults to an empty dict.

    Returns
    -------
    barostat : callable
        ``barostat(x, vel, system) -> (x, vel, system)``.
    variable_cell : bool
        ``False`` when the barostat is "NONE", ``True`` otherwise.
    state : dict
        Initial barostat state (``extvol``, ``vextvol``, ``rng_key``,
        ``istep``); empty when the barostat is "NONE".

    Raises
    ------
    AssertionError
        If a barostat is requested without ``kT``, ``target_pressure`` or
        (Langevin) ``rng_key``, or if ``aniso_mask`` does not have 6 entries.
    ValueError
        If the barostat name is not recognized.
    """
    # Avoid a shared mutable default argument.
    if restart_data is None:
        restart_data = {}
    state = {}

    barostat_name = str(simulation_parameters.get("barostat", "NONE")).upper()

    kT = system_data.get("kT", None)
    target_pressure = simulation_parameters.get("target_pressure")
    if barostat_name != "NONE":
        # kT and the target pressure are only needed when the cell is coupled
        # to a barostat; the "NONE" branch below uses neither of them.
        assert kT is not None, "kT must be specified for NPT/NPH simulations"
        assert (
            target_pressure is not None
        ), "target_pressure must be specified for NPT/NPH simulations"
        # NOTE(review): presumably converts the pressure from per-bohr^3 to
        # per-Angstrom^3 units -- confirm against AtomicUnits.
        target_pressure = target_pressure / au.BOHR**3

    nbeads = system_data.get("nbeads", None)
    variable_cell = True

    anisotropic = simulation_parameters.get("aniso_barostat", False)
    isotropic = not anisotropic

    pbc_data = system_data["pbc"]

    # Optional delayed start. On restart, the time already simulated
    # (simulation_time_ps * 1e3) is subtracted from the requested delay.
    start_barostat = simulation_parameters.get("start_barostat", 0.0) * au.FS
    start_time = restart_data.get("simulation_time_ps", 0.0) * 1e3
    start_barostat = max(0.0, start_barostat - start_time)
    istart_barostat = int(round(start_barostat / dt))
    if istart_barostat > 0 and barostat_name not in ["NONE"]:
        print(
            f"# BAROSTAT will start at {start_barostat/1000:.3f} ps ({istart_barostat} steps)"
        )
    else:
        istart_barostat = 0

    if barostat_name in ["LGV", "LANGEVIN"]:
        assert rng_key is not None, "rng_key must be provided for LANGEVIN barostat"
        gamma = simulation_parameters.get("gamma_piston", 20.0 / au.THZ) / au.FS
        tau_piston = simulation_parameters.get("tau_piston", 200.0 / au.FS) * au.FS
        nat = system_data["nat"]
        masspiston = 3 * nat * kT * tau_piston**2
        print(f"# LANGEVIN barostat with piston mass={masspiston:.1e} Ha.fs^2")
        # Ornstein-Uhlenbeck update of the piston velocity:
        #   v <- a1 * v + a2 * xi,  a1 = exp(-gamma dt),  a2 = sqrt((1 - a1^2) kT / M)
        a1 = math.exp(-gamma * dt)
        a2 = ((1 - a1 * a1) * kT / masspiston) ** 0.5

        rng_key, v_key = jax.random.split(rng_key)
        if anisotropic:
            # The piston acts on the full 3x3 cell matrix; its velocity is
            # kept symmetric.
            extvol = pbc_data["cell"]
            vextvol = (
                jax.random.normal(v_key, (3, 3), dtype=extvol.dtype)
                * (kT / masspiston) ** 0.5
            )
            vextvol = 0.5 * (vextvol + vextvol.T)

            aniso_mask = simulation_parameters.get(
                "aniso_mask", [True, True, True, True, True, True]
            )
            assert len(aniso_mask) == 6, "aniso_mask must have 6 elements"
            aniso_mask = np.array(aniso_mask, dtype=bool).astype(np.int32)
            ndof_piston = np.sum(aniso_mask)
            # mask order: xx yy zz xy xz yz -> symmetric 3x3 mask
            aniso_mask = np.array(
                [
                    [aniso_mask[0], aniso_mask[3], aniso_mask[4]],
                    [aniso_mask[3], aniso_mask[1], aniso_mask[5]],
                    [aniso_mask[4], aniso_mask[5], aniso_mask[2]],
                ],
                dtype=np.int32,
            )
        else:
            # The piston acts on the volume only (a single degree of freedom).
            extvol = jnp.asarray(pbc_data["volume"])
            vextvol = (
                jax.random.normal(v_key, (1,), dtype=extvol.dtype)
                * (kT / masspiston) ** 0.5
            )
            ndof_piston = 1.0

        state["extvol"] = extvol
        state["vextvol"] = vextvol
        state["rng_key"] = rng_key
        state["istep"] = 0

        def barostat(x, vel, system):
            """Advance one step with the B / A / O / A piston splitting.

            Returns the updated ``(x, vel, system)``; ``system`` gets new
            ``barostat``, ``cell`` and ``thermostat`` entries. When ``nbeads``
            is set, only ``x[0]`` / ``vel[0]`` are rescaled; the remaining
            position entries are passed through unchanged.
            """
            if nbeads is not None:
                x, eigx = x[0], x[1:]
                vel, eigv = vel[0], vel[1:]
            natoms = x.shape[0]
            barostat_state = system["barostat"]
            extvol = barostat_state["extvol"]
            vextvol = barostat_state["vextvol"]
            cell = system["cell"]
            volume = jnp.abs(jnp.linalg.det(cell))

            istep = barostat_state["istep"] + 1
            # Before istart_barostat the effective step is zero, so the
            # B/A steps leave positions, velocities and cell unchanged; the O
            # step still updates the piston velocity.
            dt_bar = dt * (istep >= istart_barostat)

            # apply B: kick the piston velocity with the pressure imbalance.
            # pV = PV_tensor + (2 * ek / (3 * natoms)) * identity
            ek = system["ek_c"] if nbeads is not None else system["ek"]
            pV = system["PV_tensor"] + ek * jnp.array(np.eye(3) * (2 / (3 * natoms)))
            if isotropic:
                dpV = jnp.trace(pV) - 3 * volume * target_pressure
            else:
                dpV = 0.5 * (pV + pV.T) - volume * jnp.array(target_pressure * np.eye(3))

            vextvol = vextvol + (dt_bar / masspiston) * dpV

            # apply A (first half): rescale velocities and keep the position
            # scaling factor scale1; positions are scaled once, at the end.
            if isotropic:
                scalev = jnp.exp((-0.5 * dt_bar * (1 + 1.0 / natoms)) * vextvol)
                vel = vel * scalev
                scale1 = jnp.exp((0.5 * dt_bar) * vextvol)
            else:
                vextvol = aniso_mask * vextvol
                # vextvol is symmetric, so matrix exponentials go through eigh.
                l, O = jnp.linalg.eigh(vextvol)
                lcorr = jnp.trace(vextvol) / (3 * natoms)
                Dv = jnp.diag(jnp.exp(-0.5 * dt_bar * (l + lcorr)))
                Dx = jnp.diag(jnp.exp(0.5 * dt_bar * l))
                scalev = O @ Dv @ O.T
                scale1 = O @ Dx @ O.T
                vel = vel @ scalev

            # apply O: particle thermostat, then friction/noise on the piston.
            if nbeads is not None:
                eigv, thermostat_state = thermostat(
                    jnp.concatenate((vel[None], eigv), axis=0), system["thermostat"]
                )
                vel, eigv = eigv[0], eigv[1:]
            else:
                vel, thermostat_state = thermostat(vel, system["thermostat"])
            rng_key, noise_key = jax.random.split(barostat_state["rng_key"])

            if isotropic:
                noise = jax.random.normal(noise_key, (1,), dtype=vextvol.dtype)
            else:
                noise = jax.random.normal(noise_key, (3, 3), dtype=vextvol.dtype)
                # symmetrize so the piston velocity matrix stays symmetric
                noise = 0.5 * (noise + noise.T)

            vextvol = a1 * vextvol + a2 * noise

            # apply A (second half): rescale velocities, then scale positions
            # and cell by the combined factor of both half steps.
            if isotropic:
                scalev = jnp.exp((-0.5 * dt_bar * (1 + 1.0 / natoms)) * vextvol)
                vel = vel * scalev
                scale2 = jnp.exp((0.5 * dt_bar) * vextvol)
                x = x * (scale1 * scale2)
                # the volume scales as the cube of the linear factor
                extvol = extvol * (scale1 * scale2) ** 3
                cell = cell * (scale1 * scale2)
            else:
                vextvol = aniso_mask * vextvol
                l, O = jnp.linalg.eigh(vextvol)
                lcorr = jnp.trace(vextvol) / (3 * natoms)
                Dv = jnp.diag(jnp.exp(-0.5 * dt_bar * (l + lcorr)))
                Dx = jnp.diag(jnp.exp(0.5 * dt_bar * l))
                scalev = O @ Dv @ O.T
                scale = scale1 @ (O @ Dx @ O.T)
                vel = vel @ scalev
                x = x @ scale
                # in the anisotropic case extvol holds the cell matrix itself
                extvol = extvol @ scale
                cell = extvol

            if nbeads is not None:
                x = jnp.concatenate((x[None], eigx), axis=0)
                vel = jnp.concatenate((vel[None], eigv), axis=0)

            piston_temperature = (au.KELVIN * masspiston / ndof_piston) * jnp.sum(
                vextvol**2
            )
            barostat_state = {
                **barostat_state,
                "istep": istep,
                "rng_key": rng_key,
                "vextvol": vextvol,
                "extvol": extvol,
                "piston_temperature": piston_temperature,
            }
            return (
                x,
                vel,
                {
                    **system,
                    "barostat": barostat_state,
                    "cell": cell,
                    "thermostat": thermostat_state,
                },
            )

    elif barostat_name in ["NONE"]:
        variable_cell = False

        def barostat(x, vel, system):
            """Fixed-cell step: only the thermostat is applied to ``vel``."""
            vel, thermostat_state = thermostat(vel, system["thermostat"])
            return x, vel, {**system, "thermostat": thermostat_state}

    else:
        raise ValueError(f"Unknown barostat {barostat_name}")

    return barostat, variable_cell, state