fennol.md.barostats

  1import numpy as np
  2import flax.linen as nn
  3import jax
  4import jax.numpy as jnp
  5import math
  6import optax
  7from enum import Enum
  8
  9from ..utils.atomic_units import AtomicUnits as au  # CM1,THZ,BOHR,MPROT
 10from ..utils import Counter
 11from ..utils.deconvolution import (
 12    deconvolute_spectrum,
 13    kernel_lorentz_pot,
 14    kernel_lorentz,
 15)
 16
 17
 18def get_barostat(
 19    thermostat, simulation_parameters, dt, system_data, fprec, rng_key=None
 20):
 21    state = {}
 22
 23    barostat_name = str(simulation_parameters.get("barostat", "NONE")).upper()
 24
 25    kT = system_data.get("kT", None)
 26    assert kT is not None, "kT must be specified for NPT/NPH simulations"
 27    target_pressure = simulation_parameters.get("target_pressure")
 28    if barostat_name != "NONE":
 29        assert (
 30            target_pressure is not None
 31        ), "target_pressure must be specified for NPT/NPH simulations"
 32        target_pressure = target_pressure / au.BOHR**3
 33
 34    nbeads = system_data.get("nbeads", None)
 35    variable_cell = True
 36
 37    anisotropic = simulation_parameters.get("aniso_barostat", False)
 38    
 39    isotropic = not anisotropic
 40
 41    pbc_data = system_data["pbc"]
 42
 43    if barostat_name in ["LGV", "LANGEVIN"]:
 44        assert rng_key is not None, "rng_key must be provided for QTB barostat"
 45        gamma = simulation_parameters.get("gamma_piston", 20.0 / au.THZ) / au.FS
 46        tau_piston = simulation_parameters.get("tau_piston", 200.0 / au.FS) * au.FS
 47        nat = len(system_data["species"])
 48        masspiston = 3 * nat * kT * tau_piston**2
 49        print(f"# LANGEVIN barostat with piston mass={masspiston:.1e} Ha.fs^2")
 50        a1 = math.exp(-gamma * dt)
 51        a2 = ((1 - a1 * a1) * kT / masspiston) ** 0.5
 52
 53        rng_key, v_key = jax.random.split(rng_key)
 54        if anisotropic:
 55            extvol = pbc_data["cell"]
 56            vextvol = (
 57                jax.random.normal(v_key, (3, 3), dtype=extvol.dtype)
 58                * (kT / masspiston) ** 0.5
 59            )
 60            vextvol = 0.5 * (vextvol + vextvol.T)
 61
 62            aniso_mask = simulation_parameters.get(
 63                "aniso_mask", [True, True, True, True, True, True]
 64            )
 65            assert len(aniso_mask) == 6, "aniso_mask must have 6 elements"
 66            aniso_mask = np.array(aniso_mask, dtype=bool).astype(np.int32)
 67            ndof_piston = np.sum(aniso_mask)
 68            # xx   yy   zz    xy   xz   yz
 69            aniso_mask = np.array(
 70                [
 71                    [aniso_mask[0], aniso_mask[3], aniso_mask[4]],
 72                    [aniso_mask[3], aniso_mask[1], aniso_mask[5]],
 73                    [aniso_mask[4], aniso_mask[5], aniso_mask[2]],
 74                ]
 75            ,dtype=np.int32)
 76        else:
 77            extvol = jnp.asarray(pbc_data["volume"])
 78            vextvol = (
 79                jax.random.normal(v_key, (1,), dtype=extvol.dtype)
 80                * (kT / masspiston) ** 0.5
 81            )
 82            ndof_piston = 1.
 83
 84        state["extvol"] = extvol
 85        state["vextvol"] = vextvol
 86        state["rng_key"] = rng_key
 87
 88        def barostat(x, vel, system):
 89            if nbeads is not None:
 90                x, eigx = x[0], x[1:]
 91                vel, eigv = vel[0], vel[1:]
 92            barostat_state = system["barostat"]
 93            extvol = barostat_state["extvol"]
 94            vextvol = barostat_state["vextvol"]
 95            cell = system["cell"]
 96            volume = jnp.abs(jnp.linalg.det(cell))
 97
 98            # apply B
 99            pV = 2 * (system["ek_tensor"] + jnp.trace(system["ek_tensor"])*jnp.eye(3)/(3*x.shape[0])) - system["virial"]
100            if isotropic:
101                dpV = jnp.trace(pV) - 3*volume * target_pressure
102            else:
103                dpV = 0.5 * (pV + pV.T) - volume * target_pressure * jnp.eye(3) 
104
105            vextvol = vextvol + (dt/masspiston) * dpV
106
107            # apply A
108            if isotropic:
109                scale1 = jnp.exp((0.5 * dt*(1+1./x.shape[0])) * vextvol)
110                vel = vel / scale1
111            else:
112                vextvol = aniso_mask * vextvol
113                l, O = jnp.linalg.eigh(vextvol + jnp.trace(vextvol) * jnp.eye(3)/(3*x.shape[0]))
114                Dv = jnp.diag(jnp.exp(-0.5 * dt * l))
115                Dx = jnp.diag(jnp.exp(0.5 * dt * l))
116                scalev = O @ Dv @ O.T
117                scale1 = O @ Dx @ O.T
118                vel = vel @ scalev
119
120            # apply O
121            if nbeads is not None:
122                eigv, thermostat_state = thermostat(
123                    jnp.concatenate((vel[None], eigv), axis=0), system["thermostat"]
124                )
125                vel, eigv = eigv[0], eigv[1:]
126            else:
127                vel, thermostat_state = thermostat(vel, system["thermostat"])
128            rng_key, noise_key = jax.random.split(barostat_state["rng_key"])
129
130            if isotropic:
131                noise = jax.random.normal(noise_key, (1,), dtype=vextvol.dtype)
132            else:
133                noise = jax.random.normal(noise_key, (3, 3), dtype=vextvol.dtype)
134                noise = 0.5 * (noise + noise.T)
135
136            vextvol = a1 * vextvol + a2 * noise
137
138            # apply A
139            if isotropic:
140                scale2 = jnp.exp((0.5 * dt*(1+1./x.shape[0])) * vextvol)
141                vel = vel * scale1
142                x = x * (scale1 * scale2)
143                extvol = extvol * (scale1 * scale2) ** 3
144                cell = cell * (scale1 * scale2)
145            else:
146                vextvol = aniso_mask * vextvol
147                l, O = jnp.linalg.eigh(vextvol + jnp.trace(vextvol) * jnp.eye(3)/(3*x.shape[0]))
148                Dv = jnp.diag(jnp.exp(-0.5 * dt * l))
149                Dx = jnp.diag(jnp.exp(0.5 * dt * l))
150                scalev = O @ Dv @ O.T
151                scale = scale1 @ (O @ Dx @ O.T)
152                vel = vel @ scalev
153                x = x @ scale
154                extvol = extvol @ scale
155                cell = extvol
156
157            if nbeads is not None:
158                x = jnp.concatenate((x[None], eigx), axis=0)
159                vel = jnp.concatenate((vel[None], eigv), axis=0)
160
161            piston_temperature = (au.KELVIN * masspiston/ndof_piston) * jnp.sum(vextvol**2)
162            barostat_state = {
163                **barostat_state,
164                "rng_key": rng_key,
165                "vextvol": vextvol,
166                "extvol": extvol,
167                "piston_temperature": piston_temperature,
168            }
169            return (
170                x,
171                vel,
172                {
173                    "barostat": barostat_state,
174                    "cell": cell,
175                    "thermostat": thermostat_state,
176                },
177            )
178
179    elif barostat_name in ["NONE"]:
180        variable_cell = False
181
182        def barostat(x, vel, system):
183            vel, thermostat_state = thermostat(vel, system["thermostat"])
184            return x, vel, {**system, "thermostat": thermostat_state}
185
186    else:
187        raise ValueError(f"Unknown barostat {barostat_name}")
188
189    return barostat, variable_cell, state
def get_barostat(
    thermostat, simulation_parameters, dt, system_data, fprec, rng_key=None
):
    """Build the barostat step used by the MD integrator.

    The barostat is selected by ``simulation_parameters["barostat"]``
    (case-insensitive, default ``"NONE"``):

    * ``"NONE"``: fixed cell; the returned step only applies ``thermostat``
      to the velocities.
    * ``"LGV"`` / ``"LANGEVIN"``: Langevin piston coupled either to the volume
      (default) or to the full 3x3 cell (``aniso_barostat=True``). One call
      applies, in order: a piston kick from the pressure imbalance (B), a
      half drift rescaling the velocities (A), the particle thermostat plus
      piston friction/noise (O), and a second half drift (A) that rescales
      velocities again together with positions and cell.

    Args:
        thermostat: callable ``thermostat(vel, thermostat_state)`` returning
            ``(vel, thermostat_state)``.
        simulation_parameters: dict; reads ``barostat``, ``target_pressure``
            (converted with ``au.BOHR**3``), ``aniso_barostat``,
            ``aniso_mask`` (6 flags ordered xx, yy, zz, xy, xz, yz),
            ``gamma_piston`` and ``tau_piston``.
        dt: time step in the integrator's internal time unit (the piston
            mass is reported in Ha.fs^2, so presumably fs -- TODO confirm).
        system_data: dict; reads ``nbeads`` and, for the Langevin barostat
            only, ``kT``, ``species`` and ``pbc`` (``cell`` or ``volume``).
        fprec: unused here.
        rng_key: JAX PRNG key, required by the Langevin barostat.

    Returns:
        ``(barostat, variable_cell, state)``: the step function
        ``barostat(x, vel, system) -> (x, vel, update)``, whether the cell
        varies, and the initial barostat state (``extvol``, ``vextvol``,
        ``rng_key``; empty for ``"NONE"``).

    Raises:
        AssertionError: if a setting required by the Langevin barostat is
            missing or malformed.
        ValueError: if the barostat name is unknown.
    """
    state = {}

    barostat_name = str(simulation_parameters.get("barostat", "NONE")).upper()

    nbeads = system_data.get("nbeads", None)
    variable_cell = True

    anisotropic = simulation_parameters.get("aniso_barostat", False)
    isotropic = not anisotropic

    if barostat_name in ["LGV", "LANGEVIN"]:
        # kT, target pressure and cell data are only needed by a cell-coupled
        # barostat, so they are validated here rather than for "NONE".
        kT = system_data.get("kT", None)
        assert kT is not None, "kT must be specified for NPT/NPH simulations"
        target_pressure = simulation_parameters.get("target_pressure")
        assert (
            target_pressure is not None
        ), "target_pressure must be specified for NPT/NPH simulations"
        target_pressure = target_pressure / au.BOHR**3

        pbc_data = system_data["pbc"]

        assert rng_key is not None, "rng_key must be provided for LANGEVIN barostat"
        gamma = simulation_parameters.get("gamma_piston", 20.0 / au.THZ) / au.FS
        tau_piston = simulation_parameters.get("tau_piston", 200.0 / au.FS) * au.FS
        nat = len(system_data["species"])
        masspiston = 3 * nat * kT * tau_piston**2
        print(f"# LANGEVIN barostat with piston mass={masspiston:.1e} Ha.fs^2")
        # Exact Ornstein-Uhlenbeck update coefficients for the piston velocity.
        a1 = math.exp(-gamma * dt)
        a2 = ((1 - a1 * a1) * kT / masspiston) ** 0.5

        rng_key, v_key = jax.random.split(rng_key)
        if anisotropic:
            # The piston acts on the full cell through a symmetric 3x3 velocity.
            extvol = pbc_data["cell"]
            vextvol = (
                jax.random.normal(v_key, (3, 3), dtype=extvol.dtype)
                * (kT / masspiston) ** 0.5
            )
            vextvol = 0.5 * (vextvol + vextvol.T)

            aniso_mask = simulation_parameters.get(
                "aniso_mask", [True, True, True, True, True, True]
            )
            assert len(aniso_mask) == 6, "aniso_mask must have 6 elements"
            aniso_mask = np.array(aniso_mask, dtype=bool).astype(np.int32)
            ndof_piston = np.sum(aniso_mask)
            # Expand the flags (ordered xx yy zz xy xz yz) to a symmetric 3x3 mask.
            aniso_mask = np.array(
                [
                    [aniso_mask[0], aniso_mask[3], aniso_mask[4]],
                    [aniso_mask[3], aniso_mask[1], aniso_mask[5]],
                    [aniso_mask[4], aniso_mask[5], aniso_mask[2]],
                ],
                dtype=np.int32,
            )
        else:
            # The piston acts on the volume through a single scalar velocity.
            extvol = jnp.asarray(pbc_data["volume"])
            vextvol = (
                jax.random.normal(v_key, (1,), dtype=extvol.dtype)
                * (kT / masspiston) ** 0.5
            )
            ndof_piston = 1.0

        state["extvol"] = extvol
        state["vextvol"] = vextvol
        state["rng_key"] = rng_key

        def _aniso_drift(vextvol, natoms):
            """Mask the 3x3 piston velocity; return it with the dt/2 velocity and position scalings."""
            vextvol = aniso_mask * vextvol
            l, O = jnp.linalg.eigh(
                vextvol + jnp.trace(vextvol) * jnp.eye(3) / (3 * natoms)
            )
            scalev = O @ jnp.diag(jnp.exp(-0.5 * dt * l)) @ O.T
            scalex = O @ jnp.diag(jnp.exp(0.5 * dt * l)) @ O.T
            return vextvol, scalev, scalex

        def barostat(x, vel, system):
            """Advance the piston by one step and rescale particles and cell.

            Args:
                x, vel: positions and velocities. When ``nbeads`` is set they
                    carry a leading mode axis: only index 0 is coupled to the
                    piston, the other velocity modes only go through the
                    thermostat and the other position modes are unchanged.
                system: dict read for ``barostat`` (state), ``cell``,
                    ``ek_tensor``, ``virial`` and ``thermostat`` (state).

            Returns:
                ``(x, vel, update)`` where ``update`` holds the new
                ``barostat`` state, ``cell`` and ``thermostat`` state.
            """
            if nbeads is not None:
                x, eigx = x[0], x[1:]
                vel, eigv = vel[0], vel[1:]
            barostat_state = system["barostat"]
            extvol = barostat_state["extvol"]
            vextvol = barostat_state["vextvol"]
            cell = system["cell"]
            volume = jnp.abs(jnp.linalg.det(cell))
            natoms = x.shape[0]

            # apply B: kick the piston with the pressure imbalance. The
            # trace(K)/(3N) term matches the (1 + 1/N) velocity coupling below.
            pV = (
                2
                * (
                    system["ek_tensor"]
                    + jnp.trace(system["ek_tensor"]) * jnp.eye(3) / (3 * natoms)
                )
                - system["virial"]
            )
            if isotropic:
                dpV = jnp.trace(pV) - 3 * volume * target_pressure
            else:
                dpV = 0.5 * (pV + pV.T) - volume * target_pressure * jnp.eye(3)

            vextvol = vextvol + (dt / masspiston) * dpV

            # apply A (first half): damp velocities. The position/cell scaling
            # of this half is kept and applied after the second half.
            if isotropic:
                scale1 = jnp.exp((0.5 * dt * (1 + 1.0 / natoms)) * vextvol)
                vel = vel / scale1
            else:
                vextvol, scalev, scale1 = _aniso_drift(vextvol, natoms)
                vel = vel @ scalev

            # apply O: particle thermostat, then piston friction + noise.
            if nbeads is not None:
                eigv, thermostat_state = thermostat(
                    jnp.concatenate((vel[None], eigv), axis=0), system["thermostat"]
                )
                vel, eigv = eigv[0], eigv[1:]
            else:
                vel, thermostat_state = thermostat(vel, system["thermostat"])
            rng_key, noise_key = jax.random.split(barostat_state["rng_key"])

            if isotropic:
                noise = jax.random.normal(noise_key, (1,), dtype=vextvol.dtype)
            else:
                noise = jax.random.normal(noise_key, (3, 3), dtype=vextvol.dtype)
                noise = 0.5 * (noise + noise.T)

            vextvol = a1 * vextvol + a2 * noise

            # apply A (second half) with the updated piston velocity.
            # NOTE(review): positions and cell use the same (1 + 1/N)-corrected
            # rate as the velocities; MTK scales positions by exp(v*dt/2) only
            # -- confirm this is intended.
            if isotropic:
                scale2 = jnp.exp((0.5 * dt * (1 + 1.0 / natoms)) * vextvol)
                # Damp again (divide, as in the first half); multiplying by
                # scale1 would cancel the first half-step entirely.
                vel = vel / scale2
                scale = scale1 * scale2
                x = x * scale
                extvol = extvol * scale**3
                cell = cell * scale
            else:
                vextvol, scalev, scale2 = _aniso_drift(vextvol, natoms)
                scale = scale1 @ scale2
                vel = vel @ scalev
                x = x @ scale
                extvol = extvol @ scale
                cell = extvol

            if nbeads is not None:
                x = jnp.concatenate((x[None], eigx), axis=0)
                vel = jnp.concatenate((vel[None], eigv), axis=0)

            piston_temperature = (au.KELVIN * masspiston / ndof_piston) * jnp.sum(
                vextvol**2
            )
            barostat_state = {
                **barostat_state,
                "rng_key": rng_key,
                "vextvol": vextvol,
                "extvol": extvol,
                "piston_temperature": piston_temperature,
            }
            return (
                x,
                vel,
                {
                    "barostat": barostat_state,
                    "cell": cell,
                    "thermostat": thermostat_state,
                },
            )

    elif barostat_name in ["NONE"]:
        variable_cell = False

        def barostat(x, vel, system):
            """Fixed cell: only thermostat the velocities."""
            vel, thermostat_state = thermostat(vel, system["thermostat"])
            return x, vel, {**system, "thermostat": thermostat_state}

    else:
        raise ValueError(f"Unknown barostat {barostat_name}")

    return barostat, variable_cell, state