fennol.md.barostats
import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
import math
import optax
from enum import Enum

from ..utils.atomic_units import AtomicUnits as au  # CM1,THZ,BOHR,MPROT
from ..utils import Counter
from ..utils.deconvolution import (
    deconvolute_spectrum,
    kernel_lorentz_pot,
    kernel_lorentz,
)


def _isotropic_half_step(vel, vextvol, dt, natoms):
    """Apply half of the isotropic piston drift to the velocities.

    Velocities are divided by ``exp(0.5 * dt * (1 + 1/natoms) * vextvol)``.
    That same factor is returned so the caller can accumulate the
    position/cell scaling over the two half-steps.

    Args:
        vel: velocities; must broadcast against ``vextvol``.
        vextvol: isotropic piston velocity (shape ``(1,)`` in ``get_barostat``).
        dt: time step.
        natoms: number of atoms (``x.shape[0]`` in the caller).

    Returns:
        ``(scaled_vel, scale)``.
    """
    scale = jnp.exp((0.5 * dt * (1 + 1.0 / natoms)) * vextvol)
    return vel / scale, scale


def _anisotropic_half_step(vel, vextvol, dt, natoms):
    """Apply half of the anisotropic piston drift to the velocities.

    The symmetric piston velocity matrix, with ``tr(vextvol) / (3 * natoms)``
    added on the diagonal, is diagonalized as ``O diag(l) O^T``. Velocities
    are right-multiplied by ``O diag(exp(-0.5 dt l)) O^T``. The matrix
    ``O diag(exp(+0.5 dt l)) O^T`` is returned for the caller to accumulate
    the position/cell scaling.

    Args:
        vel: velocities, right-multiplied by a 3x3 matrix.
        vextvol: symmetric 3x3 piston velocity matrix.
        dt: time step.
        natoms: number of atoms.

    Returns:
        ``(scaled_vel, position_scale)``.
    """
    eigvals, eigvecs = jnp.linalg.eigh(
        vextvol + jnp.trace(vextvol) * jnp.eye(3) / (3 * natoms)
    )
    scale_vel = eigvecs @ jnp.diag(jnp.exp(-0.5 * dt * eigvals)) @ eigvecs.T
    scale_pos = eigvecs @ jnp.diag(jnp.exp(0.5 * dt * eigvals)) @ eigvecs.T
    return vel @ scale_vel, scale_pos


def get_barostat(
    thermostat, simulation_parameters, dt, system_data, fprec, rng_key=None
):
    """Build the barostat step function used by the MD integrator.

    Args:
        thermostat: callable ``(vel, thermostat_state) -> (vel, thermostat_state)``.
            It is applied inside the barostat step.
        simulation_parameters: mapping of user parameters. Keys read here are
            ``barostat`` (``"NONE"``, ``"LGV"`` or ``"LANGEVIN"``,
            case-insensitive), ``target_pressure``, ``aniso_barostat``,
            ``aniso_mask``, ``gamma_piston`` and ``tau_piston``.
        dt: integration time step.
        system_data: mapping with ``kT``, ``pbc``, ``species`` and optionally
            ``nbeads``. These are only required for variable-cell barostats.
        fprec: floating-point precision (not used here).
        rng_key: JAX PRNG key, required for the Langevin barostat.

    Returns:
        ``(barostat, variable_cell, state)``: the step function
        ``barostat(x, vel, system) -> (x, vel, system_update)``, whether the
        cell changes, and the initial barostat state dict.

    Raises:
        ValueError: if the barostat name is unknown.
        AssertionError: if ``kT``, ``target_pressure`` or ``rng_key`` is
            missing, or ``aniso_mask`` is invalid.
    """
    state = {}

    barostat_name = str(simulation_parameters.get("barostat", "NONE")).upper()

    if barostat_name == "NONE":
        # Fixed cell: only the thermostat acts on the velocities.
        def barostat(x, vel, system):
            vel, thermostat_state = thermostat(vel, system["thermostat"])
            return x, vel, {**system, "thermostat": thermostat_state}

        return barostat, False, state

    if barostat_name not in ["LGV", "LANGEVIN"]:
        raise ValueError(f"Unknown barostat {barostat_name}")

    kT = system_data.get("kT", None)
    assert kT is not None, "kT must be specified for NPT/NPH simulations"
    target_pressure = simulation_parameters.get("target_pressure")
    assert (
        target_pressure is not None
    ), "target_pressure must be specified for NPT/NPH simulations"
    target_pressure = target_pressure / au.BOHR**3

    nbeads = system_data.get("nbeads", None)
    anisotropic = simulation_parameters.get("aniso_barostat", False)
    isotropic = not anisotropic
    pbc_data = system_data["pbc"]

    assert rng_key is not None, "rng_key must be provided for LANGEVIN barostat"
    gamma = simulation_parameters.get("gamma_piston", 20.0 / au.THZ) / au.FS
    tau_piston = simulation_parameters.get("tau_piston", 200.0 / au.FS) * au.FS
    nat = len(system_data["species"])
    masspiston = 3 * nat * kT * tau_piston**2
    print(f"# LANGEVIN barostat with piston mass={masspiston:.1e} Ha.fs^2")
    # Friction/noise coefficients of the piston "O" (Langevin) update.
    a1 = math.exp(-gamma * dt)
    a2 = ((1 - a1 * a1) * kT / masspiston) ** 0.5

    rng_key, v_key = jax.random.split(rng_key)
    if anisotropic:
        # The extended variable is the full cell matrix.
        extvol = pbc_data["cell"]
        vextvol = (
            jax.random.normal(v_key, (3, 3), dtype=extvol.dtype)
            * (kT / masspiston) ** 0.5
        )
        vextvol = 0.5 * (vextvol + vextvol.T)

        aniso_mask = simulation_parameters.get(
            "aniso_mask", [True, True, True, True, True, True]
        )
        assert len(aniso_mask) == 6, "aniso_mask must have 6 elements"
        aniso_mask = np.array(aniso_mask, dtype=bool).astype(np.int32)
        ndof_piston = np.sum(aniso_mask)
        # Guard the division in piston_temperature below.
        assert ndof_piston > 0, "aniso_mask must enable at least one component"
        # Mask order is xx yy zz xy xz yz -> symmetric 3x3 mask.
        aniso_mask = np.array(
            [
                [aniso_mask[0], aniso_mask[3], aniso_mask[4]],
                [aniso_mask[3], aniso_mask[1], aniso_mask[5]],
                [aniso_mask[4], aniso_mask[5], aniso_mask[2]],
            ],
            dtype=np.int32,
        )
    else:
        # The extended variable is the scalar volume.
        extvol = jnp.asarray(pbc_data["volume"])
        vextvol = (
            jax.random.normal(v_key, (1,), dtype=extvol.dtype)
            * (kT / masspiston) ** 0.5
        )
        ndof_piston = 1.0

    state["extvol"] = extvol
    state["vextvol"] = vextvol
    state["rng_key"] = rng_key

    def barostat(x, vel, system):
        """Advance cell, positions and velocities by one B-A-O-A barostat step.

        ``system`` must provide ``"barostat"`` (state dict), ``"cell"``,
        ``"ek_tensor"``, ``"virial"`` and ``"thermostat"``. Returns
        ``(x, vel, {"barostat": ..., "cell": ..., "thermostat": ...})``.
        """
        if nbeads is not None:
            # Only the first slice (presumably the centroid normal mode) is
            # coupled to the piston; the other modes are just thermostatted.
            x, eigx = x[0], x[1:]
            vel, eigv = vel[0], vel[1:]
        barostat_state = system["barostat"]
        extvol = barostat_state["extvol"]
        vextvol = barostat_state["vextvol"]
        cell = system["cell"]
        volume = jnp.abs(jnp.linalg.det(cell))
        natoms = x.shape[0]

        # apply B: kick the piston with the internal-minus-external pressure.
        pV = (
            2
            * (
                system["ek_tensor"]
                + jnp.trace(system["ek_tensor"]) * jnp.eye(3) / (3 * natoms)
            )
            - system["virial"]
        )
        if isotropic:
            dpV = jnp.trace(pV) - 3 * volume * target_pressure
        else:
            dpV = 0.5 * (pV + pV.T) - volume * target_pressure * jnp.eye(3)

        vextvol = vextvol + (dt / masspiston) * dpV

        # apply A (first half): contract velocities, remember position scale.
        if isotropic:
            vel, scale1 = _isotropic_half_step(vel, vextvol, dt, natoms)
        else:
            vextvol = aniso_mask * vextvol
            vel, scale1 = _anisotropic_half_step(vel, vextvol, dt, natoms)

        # apply O: thermostat on the particles, Langevin update on the piston.
        if nbeads is not None:
            eigv, thermostat_state = thermostat(
                jnp.concatenate((vel[None], eigv), axis=0), system["thermostat"]
            )
            vel, eigv = eigv[0], eigv[1:]
        else:
            vel, thermostat_state = thermostat(vel, system["thermostat"])
        rng_key, noise_key = jax.random.split(barostat_state["rng_key"])

        if isotropic:
            noise = jax.random.normal(noise_key, (1,), dtype=vextvol.dtype)
        else:
            noise = jax.random.normal(noise_key, (3, 3), dtype=vextvol.dtype)
            noise = 0.5 * (noise + noise.T)

        vextvol = a1 * vextvol + a2 * noise

        # apply A (second half): contract velocities again with the NEW piston
        # velocity, then rescale positions/cell by both half-step factors.
        # NOTE(review): the position scale carries the same (1 + 1/N) factor
        # as the velocity scale; confirm this is the intended scheme.
        if isotropic:
            vel, scale2 = _isotropic_half_step(vel, vextvol, dt, natoms)
            scale = scale1 * scale2
            x = x * scale
            extvol = extvol * scale**3
            cell = cell * scale
        else:
            vextvol = aniso_mask * vextvol
            vel, scale2 = _anisotropic_half_step(vel, vextvol, dt, natoms)
            scale = scale1 @ scale2
            x = x @ scale
            extvol = extvol @ scale
            cell = extvol

        if nbeads is not None:
            x = jnp.concatenate((x[None], eigx), axis=0)
            vel = jnp.concatenate((vel[None], eigv), axis=0)

        piston_temperature = (au.KELVIN * masspiston / ndof_piston) * jnp.sum(
            vextvol**2
        )
        barostat_state = {
            **barostat_state,
            "rng_key": rng_key,
            "vextvol": vextvol,
            "extvol": extvol,
            "piston_temperature": piston_temperature,
        }
        return (
            x,
            vel,
            {
                "barostat": barostat_state,
                "cell": cell,
                "thermostat": thermostat_state,
            },
        )

    return barostat, True, state
def get_barostat(thermostat, simulation_parameters, dt, system_data, fprec, rng_key=None):
19def get_barostat( 20 thermostat, simulation_parameters, dt, system_data, fprec, rng_key=None 21): 22 state = {} 23 24 barostat_name = str(simulation_parameters.get("barostat", "NONE")).upper() 25 26 kT = system_data.get("kT", None) 27 assert kT is not None, "kT must be specified for NPT/NPH simulations" 28 target_pressure = simulation_parameters.get("target_pressure") 29 if barostat_name != "NONE": 30 assert ( 31 target_pressure is not None 32 ), "target_pressure must be specified for NPT/NPH simulations" 33 target_pressure = target_pressure / au.BOHR**3 34 35 nbeads = system_data.get("nbeads", None) 36 variable_cell = True 37 38 anisotropic = simulation_parameters.get("aniso_barostat", False) 39 40 isotropic = not anisotropic 41 42 pbc_data = system_data["pbc"] 43 44 if barostat_name in ["LGV", "LANGEVIN"]: 45 assert rng_key is not None, "rng_key must be provided for QTB barostat" 46 gamma = simulation_parameters.get("gamma_piston", 20.0 / au.THZ) / au.FS 47 tau_piston = simulation_parameters.get("tau_piston", 200.0 / au.FS) * au.FS 48 nat = len(system_data["species"]) 49 masspiston = 3 * nat * kT * tau_piston**2 50 print(f"# LANGEVIN barostat with piston mass={masspiston:.1e} Ha.fs^2") 51 a1 = math.exp(-gamma * dt) 52 a2 = ((1 - a1 * a1) * kT / masspiston) ** 0.5 53 54 rng_key, v_key = jax.random.split(rng_key) 55 if anisotropic: 56 extvol = pbc_data["cell"] 57 vextvol = ( 58 jax.random.normal(v_key, (3, 3), dtype=extvol.dtype) 59 * (kT / masspiston) ** 0.5 60 ) 61 vextvol = 0.5 * (vextvol + vextvol.T) 62 63 aniso_mask = simulation_parameters.get( 64 "aniso_mask", [True, True, True, True, True, True] 65 ) 66 assert len(aniso_mask) == 6, "aniso_mask must have 6 elements" 67 aniso_mask = np.array(aniso_mask, dtype=bool).astype(np.int32) 68 ndof_piston = np.sum(aniso_mask) 69 # xx yy zz xy xz yz 70 aniso_mask = np.array( 71 [ 72 [aniso_mask[0], aniso_mask[3], aniso_mask[4]], 73 [aniso_mask[3], aniso_mask[1], aniso_mask[5]], 74 [aniso_mask[4], aniso_mask[5], 
aniso_mask[2]], 75 ] 76 ,dtype=np.int32) 77 else: 78 extvol = jnp.asarray(pbc_data["volume"]) 79 vextvol = ( 80 jax.random.normal(v_key, (1,), dtype=extvol.dtype) 81 * (kT / masspiston) ** 0.5 82 ) 83 ndof_piston = 1. 84 85 state["extvol"] = extvol 86 state["vextvol"] = vextvol 87 state["rng_key"] = rng_key 88 89 def barostat(x, vel, system): 90 if nbeads is not None: 91 x, eigx = x[0], x[1:] 92 vel, eigv = vel[0], vel[1:] 93 barostat_state = system["barostat"] 94 extvol = barostat_state["extvol"] 95 vextvol = barostat_state["vextvol"] 96 cell = system["cell"] 97 volume = jnp.abs(jnp.linalg.det(cell)) 98 99 # apply B 100 pV = 2 * (system["ek_tensor"] + jnp.trace(system["ek_tensor"])*jnp.eye(3)/(3*x.shape[0])) - system["virial"] 101 if isotropic: 102 dpV = jnp.trace(pV) - 3*volume * target_pressure 103 else: 104 dpV = 0.5 * (pV + pV.T) - volume * target_pressure * jnp.eye(3) 105 106 vextvol = vextvol + (dt/masspiston) * dpV 107 108 # apply A 109 if isotropic: 110 scale1 = jnp.exp((0.5 * dt*(1+1./x.shape[0])) * vextvol) 111 vel = vel / scale1 112 else: 113 vextvol = aniso_mask * vextvol 114 l, O = jnp.linalg.eigh(vextvol + jnp.trace(vextvol) * jnp.eye(3)/(3*x.shape[0])) 115 Dv = jnp.diag(jnp.exp(-0.5 * dt * l)) 116 Dx = jnp.diag(jnp.exp(0.5 * dt * l)) 117 scalev = O @ Dv @ O.T 118 scale1 = O @ Dx @ O.T 119 vel = vel @ scalev 120 121 # apply O 122 if nbeads is not None: 123 eigv, thermostat_state = thermostat( 124 jnp.concatenate((vel[None], eigv), axis=0), system["thermostat"] 125 ) 126 vel, eigv = eigv[0], eigv[1:] 127 else: 128 vel, thermostat_state = thermostat(vel, system["thermostat"]) 129 rng_key, noise_key = jax.random.split(barostat_state["rng_key"]) 130 131 if isotropic: 132 noise = jax.random.normal(noise_key, (1,), dtype=vextvol.dtype) 133 else: 134 noise = jax.random.normal(noise_key, (3, 3), dtype=vextvol.dtype) 135 noise = 0.5 * (noise + noise.T) 136 137 vextvol = a1 * vextvol + a2 * noise 138 139 # apply A 140 if isotropic: 141 scale2 = jnp.exp((0.5 * 
dt*(1+1./x.shape[0])) * vextvol) 142 vel = vel * scale1 143 x = x * (scale1 * scale2) 144 extvol = extvol * (scale1 * scale2) ** 3 145 cell = cell * (scale1 * scale2) 146 else: 147 vextvol = aniso_mask * vextvol 148 l, O = jnp.linalg.eigh(vextvol + jnp.trace(vextvol) * jnp.eye(3)/(3*x.shape[0])) 149 Dv = jnp.diag(jnp.exp(-0.5 * dt * l)) 150 Dx = jnp.diag(jnp.exp(0.5 * dt * l)) 151 scalev = O @ Dv @ O.T 152 scale = scale1 @ (O @ Dx @ O.T) 153 vel = vel @ scalev 154 x = x @ scale 155 extvol = extvol @ scale 156 cell = extvol 157 158 if nbeads is not None: 159 x = jnp.concatenate((x[None], eigx), axis=0) 160 vel = jnp.concatenate((vel[None], eigv), axis=0) 161 162 piston_temperature = (au.KELVIN * masspiston/ndof_piston) * jnp.sum(vextvol**2) 163 barostat_state = { 164 **barostat_state, 165 "rng_key": rng_key, 166 "vextvol": vextvol, 167 "extvol": extvol, 168 "piston_temperature": piston_temperature, 169 } 170 return ( 171 x, 172 vel, 173 { 174 "barostat": barostat_state, 175 "cell": cell, 176 "thermostat": thermostat_state, 177 }, 178 ) 179 180 elif barostat_name in ["NONE"]: 181 variable_cell = False 182 183 def barostat(x, vel, system): 184 vel, thermostat_state = thermostat(vel, system["thermostat"]) 185 return x, vel, {**system, "thermostat": thermostat_state} 186 187 else: 188 raise ValueError(f"Unknown barostat {barostat_name}") 189 190 return barostat, variable_cell, state