fennol.md.barostats
import math
from enum import Enum

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax

from .utils import us


def get_barostat(
    thermostat,
    simulation_parameters,
    dt,
    system_data,
    fprec,
    rng_key=None,
    restart_data=None,
):
    """Build the barostat integration step for an MD simulation.

    With ``barostat = NONE`` the returned step only applies ``thermostat`` and
    the cell stays fixed. With ``LGV``/``LANGEVIN`` it returns a Langevin-piston
    barostat, either isotropic (one piston velocity acting on the volume) or
    anisotropic (a symmetric 3x3 piston velocity acting on the cell matrix,
    when ``aniso_barostat`` is set). Its step wraps the thermostat call.

    Args:
        thermostat: callable ``(vel, thermostat_state) -> (vel, thermostat_state)``,
            applied once per step.
        simulation_parameters: mapping of user parameters. Keys read here:
            ``barostat``, ``target_pressure``, ``aniso_barostat``,
            ``start_barostat``, ``gamma_piston``, ``tau_piston``, ``aniso_mask``.
        dt: integration time step (internal units).
        system_data: mapping providing ``kT``, ``nat``, ``pbc`` (``cell`` is read
            for the anisotropic barostat, ``volume`` for the isotropic one) and,
            for path-integral runs, ``nbeads``. Only read when a barostat is used.
        fprec: not used by this function; kept for interface compatibility.
        rng_key: JAX PRNG key; required by the Langevin barostat.
        restart_data: optional restart mapping. Only ``simulation_time_ps`` is
            read, to shift the barostat start time. ``None`` means no restart.

    Returns:
        ``(barostat, variable_cell, state)``:
        - ``barostat`` is the step ``(x, vel, system) -> (x, vel, system)``.
        - ``variable_cell`` tells whether the cell is propagated.
        - ``state`` is the initial barostat state (empty without a barostat).

    Raises:
        ValueError: unknown barostat name, or an ``aniso_mask`` that disables
            every cell degree of freedom.
        AssertionError: missing ``kT``, ``target_pressure`` or ``rng_key`` for
            the Langevin barostat, or an ``aniso_mask`` of the wrong length.
    """
    if restart_data is None:
        restart_data = {}
    state = {}

    barostat_name = str(simulation_parameters.get("barostat", "NONE")).upper()
    """@keyword[fennol_md] barostat
    Type of barostat for pressure control (NONE, LGV, LANGEVIN).
    Default: NONE
    """
    if barostat_name not in ("NONE", "LGV", "LANGEVIN"):
        # Report the bad name itself rather than a later, unrelated assertion.
        raise ValueError(f"Unknown barostat {barostat_name}")

    if barostat_name == "NONE":
        # Fixed cell: the step reduces to the thermostat alone, so none of the
        # NPT inputs (kT, target_pressure, pbc, ...) are required.

        def barostat(x, vel, system):
            vel, thermostat_state = thermostat(vel, system["thermostat"])
            return x, vel, {**system, "thermostat": thermostat_state}

        return barostat, False, state

    kT = system_data.get("kT", None)
    assert kT is not None, "kT must be specified for NPT/NPH simulations"
    target_pressure = simulation_parameters.get("target_pressure")
    """@keyword[fennol_md] target_pressure
    Target pressure for NPT ensemble simulations.
    Required for barostat != NONE
    """
    assert (
        target_pressure is not None
    ), "target_pressure must be specified for NPT/NPH simulations"

    nbeads = system_data.get("nbeads", None)

    anisotropic = simulation_parameters.get("aniso_barostat", False)
    """@keyword[fennol_md] aniso_barostat
    Use anisotropic barostat allowing independent cell parameter scaling.
    Default: False
    """
    isotropic = not anisotropic

    pbc_data = system_data["pbc"]
    start_barostat = simulation_parameters.get("start_barostat", 0.0)
    """@keyword[fennol_md] start_barostat
    Time delay before starting barostat pressure coupling.
    Default: 0.0
    """
    # On restart, the delay is counted from the beginning of the whole run.
    start_time = restart_data.get("simulation_time_ps", 0.0) / us.PS
    start_barostat = max(0.0, start_barostat - start_time)
    istart_barostat = int(round(start_barostat / dt))
    if istart_barostat > 0:
        print(
            f"# BAROSTAT will start at {start_barostat*us.PS:.3f} ps ({istart_barostat} steps)"
        )

    assert rng_key is not None, "rng_key must be provided for LANGEVIN barostat"
    gamma = simulation_parameters.get("gamma_piston", 20.0 / us.THZ)
    """@keyword[fennol_md] gamma_piston
    Piston friction coefficient for Langevin barostat.
    Default: 20.0 ps^-1
    """
    tau_piston = simulation_parameters.get("tau_piston", 200.0 / us.FS)
    """@keyword[fennol_md] tau_piston
    Piston time constant for barostat coupling.
    Default: 200.0 fs
    """
    nat = system_data["nat"]
    # Piston mass derived from the coupling time: M = 3 * nat * kT * tau^2.
    masspiston = 3 * nat * kT * tau_piston**2
    print(f"# LANGEVIN barostat with piston mass={masspiston*us.get_multiplier('KCALPERMOL*PS^{2}'):.1e} kcal/mol.ps^2")
    # Ornstein-Uhlenbeck update of the piston velocity over dt:
    # v <- a1 * v + a2 * N(0, 1), whose stationary variance is kT / masspiston.
    a1 = math.exp(-gamma * dt)
    a2 = ((1 - a1 * a1) * kT / masspiston) ** 0.5

    rng_key, v_key = jax.random.split(rng_key)
    if anisotropic:
        # Symmetric 3x3 piston velocity acting on the cell matrix.
        extvol = pbc_data["cell"]
        vextvol = (
            jax.random.normal(v_key, (3, 3), dtype=extvol.dtype)
            * (kT / masspiston) ** 0.5
        )
        vextvol = 0.5 * (vextvol + vextvol.T)

        aniso_mask = simulation_parameters.get(
            "aniso_mask", [True, True, True, True, True, True]
        )
        """@keyword[fennol_md] aniso_mask
        Mask for anisotropic barostat degrees of freedom [xx, yy, zz, xy, xz, yz].
        Default: [True, True, True, True, True, True]
        """
        assert len(aniso_mask) == 6, "aniso_mask must have 6 elements"
        aniso_mask = np.array(aniso_mask, dtype=bool).astype(np.int32)
        ndof_piston = np.sum(aniso_mask)
        if ndof_piston == 0:
            # piston_temperature is normalized by ndof_piston.
            raise ValueError(
                "aniso_mask must enable at least one cell degree of freedom"
            )
        # Expand [xx, yy, zz, xy, xz, yz] into a symmetric 3x3 mask.
        aniso_mask = np.array(
            [
                [aniso_mask[0], aniso_mask[3], aniso_mask[4]],
                [aniso_mask[3], aniso_mask[1], aniso_mask[5]],
                [aniso_mask[4], aniso_mask[5], aniso_mask[2]],
            ],
            dtype=np.int32,
        )
        # Frozen components start at rest, like the masked velocities that
        # the step function stores.
        vextvol = aniso_mask * vextvol
    else:
        # Isotropic: one piston velocity (shape (1,)) for the volume.
        extvol = jnp.asarray(pbc_data["volume"])
        vextvol = (
            jax.random.normal(v_key, (1,), dtype=extvol.dtype)
            * (kT / masspiston) ** 0.5
        )
        ndof_piston = 1.0

    state["extvol"] = extvol
    state["vextvol"] = vextvol
    state["rng_key"] = rng_key
    state["istep"] = 0
    # Present from the start so the state keeps the same keys across steps.
    state["piston_temperature"] = (us.KELVIN * masspiston / ndof_piston) * jnp.sum(
        vextvol**2
    )

    def _half_step_scaling(vextvol, dt_bar, n):
        """Half-step (A) scaling factors for a piston velocity.

        Returns ``(vextvol, scalev, scalex)``. In the anisotropic case
        ``vextvol`` is masked and both factors are symmetric 3x3 matrices. In
        the isotropic case they are element-wise factors and ``vextvol`` is
        returned unchanged.
        """
        if isotropic:
            vdt2 = 0.5 * dt_bar * vextvol
            scalev = jnp.exp(-vdt2 * (1 + 1.0 / n))
            scalex = jnp.exp(vdt2)
            return vextvol, scalev, scalex
        vextvol = aniso_mask * vextvol
        vdt2 = 0.5 * dt_bar * vextvol
        # Matrix exponentials of the symmetric vdt2 through its eigenbasis.
        eigvals, eigvecs = jnp.linalg.eigh(vdt2)
        lcorr = jnp.trace(vdt2) / (3 * n)
        scalev = eigvecs @ jnp.diag(jnp.exp(-(eigvals + lcorr))) @ eigvecs.T
        scalex = eigvecs @ jnp.diag(jnp.exp(eigvals)) @ eigvecs.T
        return vextvol, scalev, scalex

    def barostat(x, vel, system):
        # Path integrals: the barostat acts on the first (centroid-like) entry,
        # the remaining bead modes are carried through.
        if nbeads is not None:
            x, eigx = x[0], x[1:]
            vel, eigv = vel[0], vel[1:]
        barostat_state = system["barostat"]
        extvol = barostat_state["extvol"]
        vextvol = barostat_state["vextvol"]
        cell = system["cell"]
        volume = jnp.abs(jnp.linalg.det(cell))
        n = x.shape[0]

        istep = barostat_state["istep"] + 1
        # dt_bar is 0 before istart_barostat: the cell stays fixed, while the
        # piston velocity is still thermalized by the O step below.
        dt_bar = dt * (istep >= istart_barostat)

        # B: kick the piston velocity with volume * (pressure - target).
        # NOTE(review): the 2*ek/(3*n*V) diagonal term presumably pairs with
        # the (1 + 1/n) velocity scaling in the A steps - confirm.
        ek = system["ek_c"] if nbeads is not None else system["ek"]
        pres = system["pressure_tensor"] + ek * jnp.array(
            np.eye(3) * (2 / (3 * n))
        ) / volume
        if isotropic:
            dpres = jnp.trace(pres) - 3 * target_pressure
        else:
            dpres = 0.5 * (pres + pres.T) - jnp.array(target_pressure * np.eye(3))
        vextvol = vextvol + ((dt_bar / masspiston) * volume) * dpres

        # A (first half): scale velocities; the position factor is kept.
        vextvol, scalev, scale1 = _half_step_scaling(vextvol, dt_bar, n)
        vel = vel * scalev if isotropic else vel @ scalev

        # O: particle thermostat, then Ornstein-Uhlenbeck piston update.
        if nbeads is not None:
            vel_all, thermostat_state = thermostat(
                jnp.concatenate((vel[None], eigv), axis=0), system["thermostat"]
            )
            vel, eigv = vel_all[0], vel_all[1:]
        else:
            vel, thermostat_state = thermostat(vel, system["thermostat"])
        rng_key, noise_key = jax.random.split(barostat_state["rng_key"])
        if isotropic:
            noise = jax.random.normal(noise_key, (1,), dtype=vextvol.dtype)
        else:
            noise = jax.random.normal(noise_key, (3, 3), dtype=vextvol.dtype)
            noise = 0.5 * (noise + noise.T)
        vextvol = a1 * vextvol + a2 * noise

        # A (second half): scale velocities, then move positions and cell
        # with the composed factor of both half steps.
        vextvol, scalev, scale2 = _half_step_scaling(vextvol, dt_bar, n)
        if isotropic:
            vel = vel * scalev
            scale = scale1 * scale2
            x = x * scale
            # scale has shape (1,): keep extvol's own shape so the stored
            # state does not change shape between steps.
            extvol = jnp.reshape(extvol * scale**3, jnp.shape(extvol))
            cell = cell * scale
        else:
            vel = vel @ scalev
            scale = scale1 @ scale2
            x = x @ scale
            extvol = extvol @ scale
            cell = extvol

        if nbeads is not None:
            x = jnp.concatenate((x[None], eigx), axis=0)
            vel = jnp.concatenate((vel[None], eigv), axis=0)

        piston_temperature = (us.KELVIN * masspiston / ndof_piston) * jnp.sum(
            vextvol**2
        )
        barostat_state = {
            **barostat_state,
            "istep": istep,
            "rng_key": rng_key,
            "vextvol": vextvol,
            "extvol": extvol,
            "piston_temperature": piston_temperature,
        }
        return (
            x,
            vel,
            {
                **system,
                "barostat": barostat_state,
                "cell": cell,
                "thermostat": thermostat_state,
            },
        )

    return barostat, True, state
def get_barostat(
    thermostat,
    simulation_parameters,
    dt,
    system_data,
    fprec,
    rng_key=None,
    restart_data=None,
):
    """Build the barostat integration step for an MD simulation.

    With ``barostat = NONE`` the returned step only applies ``thermostat`` and
    the cell stays fixed. With ``LGV``/``LANGEVIN`` it returns a Langevin-piston
    barostat, either isotropic (one piston velocity acting on the volume) or
    anisotropic (a symmetric 3x3 piston velocity acting on the cell matrix,
    when ``aniso_barostat`` is set). Its step wraps the thermostat call.

    Args:
        thermostat: callable ``(vel, thermostat_state) -> (vel, thermostat_state)``,
            applied once per step.
        simulation_parameters: mapping of user parameters. Keys read here:
            ``barostat``, ``target_pressure``, ``aniso_barostat``,
            ``start_barostat``, ``gamma_piston``, ``tau_piston``, ``aniso_mask``.
        dt: integration time step (internal units).
        system_data: mapping providing ``kT``, ``nat``, ``pbc`` (``cell`` is read
            for the anisotropic barostat, ``volume`` for the isotropic one) and,
            for path-integral runs, ``nbeads``. Only read when a barostat is used.
        fprec: not used by this function; kept for interface compatibility.
        rng_key: JAX PRNG key; required by the Langevin barostat.
        restart_data: optional restart mapping. Only ``simulation_time_ps`` is
            read, to shift the barostat start time. ``None`` means no restart.

    Returns:
        ``(barostat, variable_cell, state)``:
        - ``barostat`` is the step ``(x, vel, system) -> (x, vel, system)``.
        - ``variable_cell`` tells whether the cell is propagated.
        - ``state`` is the initial barostat state (empty without a barostat).

    Raises:
        ValueError: unknown barostat name, or an ``aniso_mask`` that disables
            every cell degree of freedom.
        AssertionError: missing ``kT``, ``target_pressure`` or ``rng_key`` for
            the Langevin barostat, or an ``aniso_mask`` of the wrong length.
    """
    if restart_data is None:
        restart_data = {}
    state = {}

    barostat_name = str(simulation_parameters.get("barostat", "NONE")).upper()
    """@keyword[fennol_md] barostat
    Type of barostat for pressure control (NONE, LGV, LANGEVIN).
    Default: NONE
    """
    if barostat_name not in ("NONE", "LGV", "LANGEVIN"):
        # Report the bad name itself rather than a later, unrelated assertion.
        raise ValueError(f"Unknown barostat {barostat_name}")

    if barostat_name == "NONE":
        # Fixed cell: the step reduces to the thermostat alone, so none of the
        # NPT inputs (kT, target_pressure, pbc, ...) are required.

        def barostat(x, vel, system):
            vel, thermostat_state = thermostat(vel, system["thermostat"])
            return x, vel, {**system, "thermostat": thermostat_state}

        return barostat, False, state

    kT = system_data.get("kT", None)
    assert kT is not None, "kT must be specified for NPT/NPH simulations"
    target_pressure = simulation_parameters.get("target_pressure")
    """@keyword[fennol_md] target_pressure
    Target pressure for NPT ensemble simulations.
    Required for barostat != NONE
    """
    assert (
        target_pressure is not None
    ), "target_pressure must be specified for NPT/NPH simulations"

    nbeads = system_data.get("nbeads", None)

    anisotropic = simulation_parameters.get("aniso_barostat", False)
    """@keyword[fennol_md] aniso_barostat
    Use anisotropic barostat allowing independent cell parameter scaling.
    Default: False
    """
    isotropic = not anisotropic

    pbc_data = system_data["pbc"]
    start_barostat = simulation_parameters.get("start_barostat", 0.0)
    """@keyword[fennol_md] start_barostat
    Time delay before starting barostat pressure coupling.
    Default: 0.0
    """
    # On restart, the delay is counted from the beginning of the whole run.
    start_time = restart_data.get("simulation_time_ps", 0.0) / us.PS
    start_barostat = max(0.0, start_barostat - start_time)
    istart_barostat = int(round(start_barostat / dt))
    if istart_barostat > 0:
        print(
            f"# BAROSTAT will start at {start_barostat*us.PS:.3f} ps ({istart_barostat} steps)"
        )

    assert rng_key is not None, "rng_key must be provided for LANGEVIN barostat"
    gamma = simulation_parameters.get("gamma_piston", 20.0 / us.THZ)
    """@keyword[fennol_md] gamma_piston
    Piston friction coefficient for Langevin barostat.
    Default: 20.0 ps^-1
    """
    tau_piston = simulation_parameters.get("tau_piston", 200.0 / us.FS)
    """@keyword[fennol_md] tau_piston
    Piston time constant for barostat coupling.
    Default: 200.0 fs
    """
    nat = system_data["nat"]
    # Piston mass derived from the coupling time: M = 3 * nat * kT * tau^2.
    masspiston = 3 * nat * kT * tau_piston**2
    print(f"# LANGEVIN barostat with piston mass={masspiston*us.get_multiplier('KCALPERMOL*PS^{2}'):.1e} kcal/mol.ps^2")
    # Ornstein-Uhlenbeck update of the piston velocity over dt:
    # v <- a1 * v + a2 * N(0, 1), whose stationary variance is kT / masspiston.
    a1 = math.exp(-gamma * dt)
    a2 = ((1 - a1 * a1) * kT / masspiston) ** 0.5

    rng_key, v_key = jax.random.split(rng_key)
    if anisotropic:
        # Symmetric 3x3 piston velocity acting on the cell matrix.
        extvol = pbc_data["cell"]
        vextvol = (
            jax.random.normal(v_key, (3, 3), dtype=extvol.dtype)
            * (kT / masspiston) ** 0.5
        )
        vextvol = 0.5 * (vextvol + vextvol.T)

        aniso_mask = simulation_parameters.get(
            "aniso_mask", [True, True, True, True, True, True]
        )
        """@keyword[fennol_md] aniso_mask
        Mask for anisotropic barostat degrees of freedom [xx, yy, zz, xy, xz, yz].
        Default: [True, True, True, True, True, True]
        """
        assert len(aniso_mask) == 6, "aniso_mask must have 6 elements"
        aniso_mask = np.array(aniso_mask, dtype=bool).astype(np.int32)
        ndof_piston = np.sum(aniso_mask)
        if ndof_piston == 0:
            # piston_temperature is normalized by ndof_piston.
            raise ValueError(
                "aniso_mask must enable at least one cell degree of freedom"
            )
        # Expand [xx, yy, zz, xy, xz, yz] into a symmetric 3x3 mask.
        aniso_mask = np.array(
            [
                [aniso_mask[0], aniso_mask[3], aniso_mask[4]],
                [aniso_mask[3], aniso_mask[1], aniso_mask[5]],
                [aniso_mask[4], aniso_mask[5], aniso_mask[2]],
            ],
            dtype=np.int32,
        )
        # Frozen components start at rest, like the masked velocities that
        # the step function stores.
        vextvol = aniso_mask * vextvol
    else:
        # Isotropic: one piston velocity (shape (1,)) for the volume.
        extvol = jnp.asarray(pbc_data["volume"])
        vextvol = (
            jax.random.normal(v_key, (1,), dtype=extvol.dtype)
            * (kT / masspiston) ** 0.5
        )
        ndof_piston = 1.0

    state["extvol"] = extvol
    state["vextvol"] = vextvol
    state["rng_key"] = rng_key
    state["istep"] = 0
    # Present from the start so the state keeps the same keys across steps.
    state["piston_temperature"] = (us.KELVIN * masspiston / ndof_piston) * jnp.sum(
        vextvol**2
    )

    def _half_step_scaling(vextvol, dt_bar, n):
        """Half-step (A) scaling factors for a piston velocity.

        Returns ``(vextvol, scalev, scalex)``. In the anisotropic case
        ``vextvol`` is masked and both factors are symmetric 3x3 matrices. In
        the isotropic case they are element-wise factors and ``vextvol`` is
        returned unchanged.
        """
        if isotropic:
            vdt2 = 0.5 * dt_bar * vextvol
            scalev = jnp.exp(-vdt2 * (1 + 1.0 / n))
            scalex = jnp.exp(vdt2)
            return vextvol, scalev, scalex
        vextvol = aniso_mask * vextvol
        vdt2 = 0.5 * dt_bar * vextvol
        # Matrix exponentials of the symmetric vdt2 through its eigenbasis.
        eigvals, eigvecs = jnp.linalg.eigh(vdt2)
        lcorr = jnp.trace(vdt2) / (3 * n)
        scalev = eigvecs @ jnp.diag(jnp.exp(-(eigvals + lcorr))) @ eigvecs.T
        scalex = eigvecs @ jnp.diag(jnp.exp(eigvals)) @ eigvecs.T
        return vextvol, scalev, scalex

    def barostat(x, vel, system):
        # Path integrals: the barostat acts on the first (centroid-like) entry,
        # the remaining bead modes are carried through.
        if nbeads is not None:
            x, eigx = x[0], x[1:]
            vel, eigv = vel[0], vel[1:]
        barostat_state = system["barostat"]
        extvol = barostat_state["extvol"]
        vextvol = barostat_state["vextvol"]
        cell = system["cell"]
        volume = jnp.abs(jnp.linalg.det(cell))
        n = x.shape[0]

        istep = barostat_state["istep"] + 1
        # dt_bar is 0 before istart_barostat: the cell stays fixed, while the
        # piston velocity is still thermalized by the O step below.
        dt_bar = dt * (istep >= istart_barostat)

        # B: kick the piston velocity with volume * (pressure - target).
        # NOTE(review): the 2*ek/(3*n*V) diagonal term presumably pairs with
        # the (1 + 1/n) velocity scaling in the A steps - confirm.
        ek = system["ek_c"] if nbeads is not None else system["ek"]
        pres = system["pressure_tensor"] + ek * jnp.array(
            np.eye(3) * (2 / (3 * n))
        ) / volume
        if isotropic:
            dpres = jnp.trace(pres) - 3 * target_pressure
        else:
            dpres = 0.5 * (pres + pres.T) - jnp.array(target_pressure * np.eye(3))
        vextvol = vextvol + ((dt_bar / masspiston) * volume) * dpres

        # A (first half): scale velocities; the position factor is kept.
        vextvol, scalev, scale1 = _half_step_scaling(vextvol, dt_bar, n)
        vel = vel * scalev if isotropic else vel @ scalev

        # O: particle thermostat, then Ornstein-Uhlenbeck piston update.
        if nbeads is not None:
            vel_all, thermostat_state = thermostat(
                jnp.concatenate((vel[None], eigv), axis=0), system["thermostat"]
            )
            vel, eigv = vel_all[0], vel_all[1:]
        else:
            vel, thermostat_state = thermostat(vel, system["thermostat"])
        rng_key, noise_key = jax.random.split(barostat_state["rng_key"])
        if isotropic:
            noise = jax.random.normal(noise_key, (1,), dtype=vextvol.dtype)
        else:
            noise = jax.random.normal(noise_key, (3, 3), dtype=vextvol.dtype)
            noise = 0.5 * (noise + noise.T)
        vextvol = a1 * vextvol + a2 * noise

        # A (second half): scale velocities, then move positions and cell
        # with the composed factor of both half steps.
        vextvol, scalev, scale2 = _half_step_scaling(vextvol, dt_bar, n)
        if isotropic:
            vel = vel * scalev
            scale = scale1 * scale2
            x = x * scale
            # scale has shape (1,): keep extvol's own shape so the stored
            # state does not change shape between steps.
            extvol = jnp.reshape(extvol * scale**3, jnp.shape(extvol))
            cell = cell * scale
        else:
            vel = vel @ scalev
            scale = scale1 @ scale2
            x = x @ scale
            extvol = extvol @ scale
            cell = extvol

        if nbeads is not None:
            x = jnp.concatenate((x[None], eigx), axis=0)
            vel = jnp.concatenate((vel[None], eigv), axis=0)

        piston_temperature = (us.KELVIN * masspiston / ndof_piston) * jnp.sum(
            vextvol**2
        )
        barostat_state = {
            **barostat_state,
            "istep": istep,
            "rng_key": rng_key,
            "vextvol": vextvol,
            "extvol": extvol,
            "piston_temperature": piston_temperature,
        }
        return (
            x,
            vel,
            {
                **system,
                "barostat": barostat_state,
                "cell": cell,
                "thermostat": thermostat_state,
            },
        )

    return barostat, True, state