fennol.md.barostats

  1import numpy as np
  2import flax.linen as nn
  3import jax
  4import jax.numpy as jnp
  5import math
  6import optax
  7from enum import Enum
  8
  9from ..utils.atomic_units import AtomicUnits as au  # CM1,THZ,BOHR,MPROT
 10from ..utils import Counter
 11from ..utils.deconvolution import (
 12    deconvolute_spectrum,
 13    kernel_lorentz_pot,
 14    kernel_lorentz,
 15)
 16
 17
 18def get_barostat(
 19    thermostat, simulation_parameters, dt, system_data, fprec, rng_key=None, restart_data={}
 20):
 21    state = {}
 22
 23    barostat_name = str(simulation_parameters.get("barostat", "NONE")).upper()
 24
 25    kT = system_data.get("kT", None)
 26    assert kT is not None, "kT must be specified for NPT/NPH simulations"
 27    target_pressure = simulation_parameters.get("target_pressure")
 28    if barostat_name != "NONE":
 29        assert (
 30            target_pressure is not None
 31        ), "target_pressure must be specified for NPT/NPH simulations"
 32        target_pressure = target_pressure / au.BOHR**3
 33
 34    nbeads = system_data.get("nbeads", None)
 35    variable_cell = True
 36
 37    anisotropic = simulation_parameters.get("aniso_barostat", False)
 38    
 39    isotropic = not anisotropic
 40
 41    pbc_data = system_data["pbc"]
 42    start_barostat = simulation_parameters.get("start_barostat", 0.0)*au.FS
 43    start_time = restart_data.get("simulation_time_ps",0.) * 1e3
 44    start_barostat = max(0.,start_barostat-start_time)
 45    istart_barostat = int(round(start_barostat / dt))
 46    if istart_barostat > 0 and barostat_name not in ["NONE"]:
 47        print(
 48            f"# BAROSTAT will start at {start_barostat/1000:.3f} ps ({istart_barostat} steps)"
 49        )
 50    else:
 51        istart_barostat = 0
 52
 53    if barostat_name in ["LGV", "LANGEVIN"]:
 54        assert rng_key is not None, "rng_key must be provided for QTB barostat"
 55        gamma = simulation_parameters.get("gamma_piston", 20.0 / au.THZ) / au.FS
 56        tau_piston = simulation_parameters.get("tau_piston", 200.0 / au.FS) * au.FS
 57        nat = system_data["nat"]
 58        masspiston = 3 * nat * kT * tau_piston**2
 59        print(f"# LANGEVIN barostat with piston mass={masspiston:.1e} Ha.fs^2")
 60        a1 = math.exp(-gamma * dt)
 61        a2 = ((1 - a1 * a1) * kT / masspiston) ** 0.5
 62
 63
 64        rng_key, v_key = jax.random.split(rng_key)
 65        if anisotropic:
 66            extvol = pbc_data["cell"]
 67            vextvol = (
 68                jax.random.normal(v_key, (3, 3), dtype=extvol.dtype)
 69                * (kT / masspiston) ** 0.5
 70            )
 71            vextvol = 0.5 * (vextvol + vextvol.T)
 72
 73            aniso_mask = simulation_parameters.get(
 74                "aniso_mask", [True, True, True, True, True, True]
 75            )
 76            assert len(aniso_mask) == 6, "aniso_mask must have 6 elements"
 77            aniso_mask = np.array(aniso_mask, dtype=bool).astype(np.int32)
 78            ndof_piston = np.sum(aniso_mask)
 79            # xx   yy   zz    xy   xz   yz
 80            aniso_mask = np.array(
 81                [
 82                    [aniso_mask[0], aniso_mask[3], aniso_mask[4]],
 83                    [aniso_mask[3], aniso_mask[1], aniso_mask[5]],
 84                    [aniso_mask[4], aniso_mask[5], aniso_mask[2]],
 85                ]
 86            ,dtype=np.int32)
 87        else:
 88            extvol = jnp.asarray(pbc_data["volume"])
 89            vextvol = (
 90                jax.random.normal(v_key, (1,), dtype=extvol.dtype)
 91                * (kT / masspiston) ** 0.5
 92            )
 93            ndof_piston = 1.
 94
 95        state["extvol"] = extvol
 96        state["vextvol"] = vextvol
 97        state["rng_key"] = rng_key
 98        state["istep"] = 0
 99
100        def barostat(x, vel, system):
101            if nbeads is not None:
102                x, eigx = x[0], x[1:]
103                vel, eigv = vel[0], vel[1:]
104            barostat_state = system["barostat"]
105            extvol = barostat_state["extvol"]
106            vextvol = barostat_state["vextvol"]
107            cell = system["cell"]
108            volume = jnp.abs(jnp.linalg.det(cell))
109
110            istep = barostat_state["istep"] + 1
111            dt_bar = dt * (istep >= istart_barostat)
112
113
114            # apply B
115            # pV = 2 * (system["ek_tensor"] + jnp.trace(system["ek_tensor"])*jnp.eye(3)/(3*x.shape[0])) - system["virial"]
116            ek = system["ek_c"] if nbeads is not None else system["ek"] 
117            pV = system["PV_tensor"] + ek*jnp.array(np.eye(3)*(2/(3*x.shape[0])))
118            if isotropic:
119                dpV = jnp.trace(pV) - 3*volume * target_pressure
120            else:
121                dpV = 0.5 * (pV + pV.T) - volume * jnp.array(target_pressure * np.eye(3))
122
123            vextvol = vextvol + (dt_bar/masspiston) * dpV
124
125            # apply A
126            if isotropic:
127                scalev = jnp.exp((-0.5 * dt_bar*(1+1./x.shape[0])) * vextvol)
128                vel = vel * scalev
129                scale1 = jnp.exp((0.5 * dt_bar) * vextvol)
130            else:
131                vextvol = aniso_mask * vextvol
132                l, O = jnp.linalg.eigh(vextvol) 
133                lcorr = jnp.trace(vextvol)/(3*x.shape[0])
134                Dv = jnp.diag(jnp.exp(-0.5 * dt_bar * (l+lcorr)))
135                Dx = jnp.diag(jnp.exp(0.5 * dt_bar * l))
136                scalev = O @ Dv @ O.T
137                scale1 = O @ Dx @ O.T
138                vel = vel @ scalev
139
140            # apply O
141            if nbeads is not None:
142                eigv, thermostat_state = thermostat(
143                    jnp.concatenate((vel[None], eigv), axis=0), system["thermostat"]
144                )
145                vel, eigv = eigv[0], eigv[1:]
146            else:
147                vel, thermostat_state = thermostat(vel, system["thermostat"])
148            rng_key, noise_key = jax.random.split(barostat_state["rng_key"])
149
150            if isotropic:
151                noise = jax.random.normal(noise_key, (1,), dtype=vextvol.dtype)
152            else:
153                noise = jax.random.normal(noise_key, (3, 3), dtype=vextvol.dtype)
154                noise = 0.5 * (noise + noise.T)
155
156            vextvol = a1 * vextvol + a2 * noise
157
158            # apply A
159            if isotropic:
160                scalev = jnp.exp((-0.5 * dt_bar*(1+1./x.shape[0])) * vextvol)
161                vel = vel * scalev
162                scale2 = jnp.exp((0.5 * dt_bar) * vextvol)
163                x = x * (scale1 * scale2)
164                extvol = extvol * (scale1 * scale2) ** 3
165                cell = cell * (scale1 * scale2)
166            else:
167                vextvol = aniso_mask * vextvol
168                l, O = jnp.linalg.eigh(vextvol) 
169                lcorr = jnp.trace(vextvol)/(3*x.shape[0])
170                Dv = jnp.diag(jnp.exp(-0.5 * dt_bar * (l+lcorr)))
171                Dx = jnp.diag(jnp.exp(0.5 * dt_bar * l))
172                scalev = O @ Dv @ O.T
173                scale = scale1 @ (O @ Dx @ O.T)
174                vel = vel @ scalev
175                x = x @ scale
176                extvol = extvol @ scale
177                cell = extvol
178
179            if nbeads is not None:
180                x = jnp.concatenate((x[None], eigx), axis=0)
181                vel = jnp.concatenate((vel[None], eigv), axis=0)
182
183            piston_temperature = (au.KELVIN * masspiston/ndof_piston) * jnp.sum(vextvol**2)
184            barostat_state = {
185                **barostat_state,
186                "istep": istep,
187                "rng_key": rng_key,
188                "vextvol": vextvol,
189                "extvol": extvol,
190                "piston_temperature": piston_temperature,
191            }
192            return (
193                x,
194                vel,
195                {
196                    **system,
197                    "barostat": barostat_state,
198                    "cell": cell,
199                    "thermostat": thermostat_state,
200                },
201            )
202
203    elif barostat_name in ["NONE"]:
204        variable_cell = False
205
206        def barostat(x, vel, system):
207            vel, thermostat_state = thermostat(vel, system["thermostat"])
208            return x, vel, {**system, "thermostat": thermostat_state}
209
210    else:
211        raise ValueError(f"Unknown barostat {barostat_name}")
212
213    return barostat, variable_cell, state
def get_barostat(
    thermostat,
    simulation_parameters,
    dt,
    system_data,
    fprec,
    rng_key=None,
    restart_data=None,
):
    """Build the barostat step function used by the MD integrator.

    Parameters
    ----------
    thermostat : callable
        Called as ``thermostat(vel, thermostat_state)`` and expected to return
        ``(vel, thermostat_state)``; it is applied inside the barostat step.
    simulation_parameters : dict
        Keys read: ``barostat`` (``"NONE"``, ``"LGV"`` or ``"LANGEVIN"``,
        case-insensitive, default ``"NONE"``), ``target_pressure``,
        ``aniso_barostat``, ``aniso_mask``, ``start_barostat``,
        ``gamma_piston`` and ``tau_piston``.
    dt : float
        Integration time step in internal time units (``start_barostat`` is
        converted with ``au.FS`` before being divided by ``dt``).
    system_data : dict
        Keys read: ``kT``, ``nbeads``, ``nat`` and ``pbc`` (``cell`` for the
        anisotropic barostat, ``volume`` for the isotropic one).
    fprec :
        Not used by this function; kept for interface compatibility.
    rng_key : jax PRNG key, optional
        Required by the Langevin barostat (piston noise).
    restart_data : dict, optional
        ``simulation_time_ps`` is read to shift the barostat start time on
        restart. Defaults to an empty dict.

    Returns
    -------
    tuple
        ``(barostat, variable_cell, state)`` where ``barostat(x, vel, system)``
        returns ``(x, vel, system)``, ``variable_cell`` is False only for
        ``barostat="NONE"``, and ``state`` is the initial barostat state
        (empty for ``"NONE"``).

    Raises
    ------
    AssertionError
        If ``kT``, ``target_pressure`` or ``rng_key`` is missing for a
        pressure-coupled run, or if ``aniso_mask`` is malformed.
    ValueError
        If the barostat name is unknown.
    """
    # avoid a shared mutable default argument
    if restart_data is None:
        restart_data = {}
    state = {}

    barostat_name = str(simulation_parameters.get("barostat", "NONE")).upper()

    kT = system_data.get("kT", None)
    target_pressure = simulation_parameters.get("target_pressure")
    if barostat_name != "NONE":
        # kT and the target pressure are only needed when the cell is coupled
        # to a pressure bath.
        assert kT is not None, "kT must be specified for NPT/NPH simulations"
        assert (
            target_pressure is not None
        ), "target_pressure must be specified for NPT/NPH simulations"
        target_pressure = target_pressure / au.BOHR**3

    nbeads = system_data.get("nbeads", None)
    variable_cell = True

    anisotropic = simulation_parameters.get("aniso_barostat", False)
    isotropic = not anisotropic

    # Delay before the barostat is switched on. On restart, the time already
    # simulated (stored in ps) is subtracted; it is converted with the same
    # au.FS factor as start_barostat so both are in the same time units.
    start_barostat = simulation_parameters.get("start_barostat", 0.0) * au.FS
    start_time = restart_data.get("simulation_time_ps", 0.0) * 1e3 * au.FS
    start_barostat = max(0.0, start_barostat - start_time)
    istart_barostat = int(round(start_barostat / dt))
    if istart_barostat > 0 and barostat_name != "NONE":
        print(
            f"# BAROSTAT will start at {start_barostat/(1000*au.FS):.3f} ps ({istart_barostat} steps)"
        )
    else:
        istart_barostat = 0

    if barostat_name in ["LGV", "LANGEVIN"]:
        assert rng_key is not None, "rng_key must be provided for LANGEVIN barostat"
        pbc_data = system_data["pbc"]
        gamma = simulation_parameters.get("gamma_piston", 20.0 / au.THZ) / au.FS
        tau_piston = simulation_parameters.get("tau_piston", 200.0 / au.FS) * au.FS
        nat = system_data["nat"]
        # piston mass derived from kT, the atom count and tau_piston
        masspiston = 3 * nat * kT * tau_piston**2
        print(f"# LANGEVIN barostat with piston mass={masspiston:.1e} Ha.fs^2")
        # exact Ornstein-Uhlenbeck update coefficients for the piston velocity
        a1 = math.exp(-gamma * dt)
        a2 = ((1 - a1 * a1) * kT / masspiston) ** 0.5

        rng_key, v_key = jax.random.split(rng_key)
        if anisotropic:
            extvol = pbc_data["cell"]
            # symmetric 3x3 initial piston velocity, scaled by sqrt(kT/masspiston)
            vextvol = (
                jax.random.normal(v_key, (3, 3), dtype=extvol.dtype)
                * (kT / masspiston) ** 0.5
            )
            vextvol = 0.5 * (vextvol + vextvol.T)

            aniso_mask = simulation_parameters.get(
                "aniso_mask", [True, True, True, True, True, True]
            )
            assert len(aniso_mask) == 6, "aniso_mask must have 6 elements"
            aniso_mask = np.array(aniso_mask, dtype=bool).astype(np.int32)
            ndof_piston = np.sum(aniso_mask)
            # piston_temperature divides by ndof_piston
            assert ndof_piston > 0, "aniso_mask must enable at least one component"
            # expand (xx, yy, zz, xy, xz, yz) into a symmetric 3x3 mask
            m = aniso_mask
            aniso_mask = np.array(
                [
                    [m[0], m[3], m[4]],
                    [m[3], m[1], m[5]],
                    [m[4], m[5], m[2]],
                ],
                dtype=np.int32,
            )
        else:
            extvol = jnp.asarray(pbc_data["volume"])
            vextvol = (
                jax.random.normal(v_key, (1,), dtype=extvol.dtype)
                * (kT / masspiston) ** 0.5
            )
            ndof_piston = 1.0

        state["extvol"] = extvol
        state["vextvol"] = vextvol
        state["rng_key"] = rng_key
        state["istep"] = 0

        def _half_step_propagators(vextvol, dt_bar, nat_x):
            """Return (vextvol, velocity propagator, position propagator) for a
            half "A" step; in anisotropic mode vextvol is masked first."""
            if isotropic:
                scalev = jnp.exp((-0.5 * dt_bar * (1 + 1.0 / nat_x)) * vextvol)
                scalex = jnp.exp((0.5 * dt_bar) * vextvol)
                return vextvol, scalev, scalex
            vextvol = aniso_mask * vextvol
            # exponentiate the symmetric piston velocity via its eigenbasis
            l, O = jnp.linalg.eigh(vextvol)
            lcorr = jnp.trace(vextvol) / (3 * nat_x)
            scalev = O @ jnp.diag(jnp.exp(-0.5 * dt_bar * (l + lcorr))) @ O.T
            scalex = O @ jnp.diag(jnp.exp(0.5 * dt_bar * l)) @ O.T
            return vextvol, scalev, scalex

        def barostat(x, vel, system):
            # With nbeads, only x[0]/vel[0] (presumably the centroid mode) are
            # coupled to the piston; the other rows are passed through and
            # re-attached at the end.
            if nbeads is not None:
                x, eigx = x[0], x[1:]
                vel, eigv = vel[0], vel[1:]
            barostat_state = system["barostat"]
            extvol = barostat_state["extvol"]
            vextvol = barostat_state["vextvol"]
            cell = system["cell"]
            volume = jnp.abs(jnp.linalg.det(cell))
            nat_x = x.shape[0]

            istep = barostat_state["istep"] + 1
            # dt_bar is zero (piston frozen) until istart_barostat is reached
            dt_bar = dt * (istep >= istart_barostat)

            # apply B: kick the piston velocity with the pressure mismatch
            ek = system["ek_c"] if nbeads is not None else system["ek"]
            pV = system["PV_tensor"] + ek * jnp.array(np.eye(3) * (2 / (3 * nat_x)))
            if isotropic:
                dpV = jnp.trace(pV) - 3 * volume * target_pressure
            else:
                dpV = 0.5 * (pV + pV.T) - volume * jnp.array(
                    target_pressure * np.eye(3)
                )
            vextvol = vextvol + (dt_bar / masspiston) * dpV

            # apply A (first half): scale velocities, keep the position factor
            vextvol, scalev, scale1 = _half_step_propagators(vextvol, dt_bar, nat_x)
            vel = vel * scalev if isotropic else vel @ scalev

            # apply O: particle thermostat, then Langevin noise on the piston
            if nbeads is not None:
                eigv, thermostat_state = thermostat(
                    jnp.concatenate((vel[None], eigv), axis=0), system["thermostat"]
                )
                vel, eigv = eigv[0], eigv[1:]
            else:
                vel, thermostat_state = thermostat(vel, system["thermostat"])
            rng_key, noise_key = jax.random.split(barostat_state["rng_key"])

            if isotropic:
                noise = jax.random.normal(noise_key, (1,), dtype=vextvol.dtype)
            else:
                noise = jax.random.normal(noise_key, (3, 3), dtype=vextvol.dtype)
                noise = 0.5 * (noise + noise.T)
            vextvol = a1 * vextvol + a2 * noise

            # apply A (second half): scale velocities, then positions and cell
            vextvol, scalev, scale2 = _half_step_propagators(vextvol, dt_bar, nat_x)
            if isotropic:
                vel = vel * scalev
                scale = scale1 * scale2
                x = x * scale
                # NOTE(review): scale has shape (1,), so a scalar extvol becomes
                # shape (1,) after the first step; confirm callers accept this.
                extvol = extvol * scale**3
                cell = cell * scale
            else:
                vel = vel @ scalev
                scale = scale1 @ scale2
                x = x @ scale
                extvol = extvol @ scale
                cell = extvol

            if nbeads is not None:
                x = jnp.concatenate((x[None], eigx), axis=0)
                vel = jnp.concatenate((vel[None], eigv), axis=0)

            piston_temperature = (au.KELVIN * masspiston / ndof_piston) * jnp.sum(
                vextvol**2
            )
            barostat_state = {
                **barostat_state,
                "istep": istep,
                "rng_key": rng_key,
                "vextvol": vextvol,
                "extvol": extvol,
                "piston_temperature": piston_temperature,
            }
            return (
                x,
                vel,
                {
                    **system,
                    "barostat": barostat_state,
                    "cell": cell,
                    "thermostat": thermostat_state,
                },
            )

    elif barostat_name == "NONE":
        variable_cell = False

        def barostat(x, vel, system):
            # fixed cell: only the thermostat acts
            vel, thermostat_state = thermostat(vel, system["thermostat"])
            return x, vel, {**system, "thermostat": thermostat_state}

    else:
        raise ValueError(f"Unknown barostat {barostat_name}")

    return barostat, variable_cell, state