fennol.md.barostats

  1import numpy as np
  2import flax.linen as nn
  3import jax
  4import jax.numpy as jnp
  5import math
  6import optax
  7from enum import Enum
  8
  9from .utils import us
 10
 11def get_barostat(
 12    thermostat, simulation_parameters, dt, system_data, fprec, rng_key=None, restart_data={}
 13):
 14    state = {}
 15
 16    barostat_name = str(simulation_parameters.get("barostat", "NONE")).upper()
 17    """@keyword[fennol_md] barostat
 18    Type of barostat for pressure control (NONE, LGV, LANGEVIN).
 19    Default: NONE
 20    """
 21
 22    kT = system_data.get("kT", None)
 23    assert kT is not None, "kT must be specified for NPT/NPH simulations"
 24    target_pressure = simulation_parameters.get("target_pressure")
 25    """@keyword[fennol_md] target_pressure
 26    Target pressure for NPT ensemble simulations.
 27    Required for barostat != NONE
 28    """
 29    if barostat_name != "NONE":
 30        assert (
 31            target_pressure is not None
 32        ), "target_pressure must be specified for NPT/NPH simulations"
 33
 34    nbeads = system_data.get("nbeads", None)
 35    variable_cell = True
 36
 37    anisotropic = simulation_parameters.get("aniso_barostat", False)
 38    """@keyword[fennol_md] aniso_barostat
 39    Use anisotropic barostat allowing independent cell parameter scaling.
 40    Default: False
 41    """
 42    
 43    isotropic = not anisotropic
 44
 45    pbc_data = system_data["pbc"]
 46    start_barostat = simulation_parameters.get("start_barostat", 0.0)
 47    """@keyword[fennol_md] start_barostat
 48    Time delay before starting barostat pressure coupling.
 49    Default: 0.0
 50    """
 51    start_time = restart_data.get("simulation_time_ps",0.) / us.PS
 52    start_barostat = max(0.,start_barostat-start_time)
 53    istart_barostat = int(round(start_barostat / dt))
 54    if istart_barostat > 0 and barostat_name not in ["NONE"]:
 55        print(
 56            f"# BAROSTAT will start at {start_barostat*us.PS:.3f} ps ({istart_barostat} steps)"
 57        )
 58    else:
 59        istart_barostat = 0
 60
 61    if barostat_name in ["LGV", "LANGEVIN"]:
 62        assert rng_key is not None, "rng_key must be provided for QTB barostat"
 63        gamma = simulation_parameters.get("gamma_piston", 20.0 / us.THZ)
 64        """@keyword[fennol_md] gamma_piston
 65        Piston friction coefficient for Langevin barostat.
 66        Default: 20.0 ps^-1
 67        """
 68        tau_piston = simulation_parameters.get("tau_piston", 200.0 / us.FS)
 69        """@keyword[fennol_md] tau_piston
 70        Piston time constant for barostat coupling.
 71        Default: 200.0 fs
 72        """
 73        nat = system_data["nat"]
 74        masspiston = 3 * nat * kT * tau_piston**2
 75        print(f"# LANGEVIN barostat with piston mass={masspiston*us.get_multiplier('KCALPERMOL*PS^{2}'):.1e} kcal/mol.ps^2")
 76        a1 = math.exp(-gamma * dt)
 77        a2 = ((1 - a1 * a1) * kT / masspiston) ** 0.5
 78
 79
 80        rng_key, v_key = jax.random.split(rng_key)
 81        if anisotropic:
 82            extvol = pbc_data["cell"]
 83            vextvol = (
 84                jax.random.normal(v_key, (3, 3), dtype=extvol.dtype)
 85                * (kT / masspiston) ** 0.5
 86            )
 87            vextvol = 0.5 * (vextvol + vextvol.T)
 88
 89            aniso_mask = simulation_parameters.get(
 90                "aniso_mask", [True, True, True, True, True, True]
 91            )
 92            """@keyword[fennol_md] aniso_mask
 93            Mask for anisotropic barostat degrees of freedom [xx, yy, zz, xy, xz, yz].
 94            Default: [True, True, True, True, True, True]
 95            """
 96            assert len(aniso_mask) == 6, "aniso_mask must have 6 elements"
 97            aniso_mask = np.array(aniso_mask, dtype=bool).astype(np.int32)
 98            ndof_piston = np.sum(aniso_mask)
 99            # xx   yy   zz    xy   xz   yz
100            aniso_mask = np.array(
101                [
102                    [aniso_mask[0], aniso_mask[3], aniso_mask[4]],
103                    [aniso_mask[3], aniso_mask[1], aniso_mask[5]],
104                    [aniso_mask[4], aniso_mask[5], aniso_mask[2]],
105                ]
106            ,dtype=np.int32)
107        else:
108            extvol = jnp.asarray(pbc_data["volume"])
109            vextvol = (
110                jax.random.normal(v_key, (1,), dtype=extvol.dtype)
111                * (kT / masspiston) ** 0.5
112            )
113            ndof_piston = 1.
114
115        state["extvol"] = extvol
116        state["vextvol"] = vextvol
117        state["rng_key"] = rng_key
118        state["istep"] = 0
119
120        def barostat(x, vel, system):
121            if nbeads is not None:
122                x, eigx = x[0], x[1:]
123                vel, eigv = vel[0], vel[1:]
124            barostat_state = system["barostat"]
125            extvol = barostat_state["extvol"]
126            vextvol = barostat_state["vextvol"]
127            cell = system["cell"]
128            volume = jnp.abs(jnp.linalg.det(cell))
129
130            istep = barostat_state["istep"] + 1
131            dt_bar = dt * (istep >= istart_barostat)
132
133
134            # apply B
135            # pV = 2 * (system["ek_tensor"] + jnp.trace(system["ek_tensor"])*jnp.eye(3)/(3*x.shape[0])) - system["virial"]
136            ek = system["ek_c"] if nbeads is not None else system["ek"] 
137            Pres = system["pressure_tensor"] + ek*jnp.array(np.eye(3)*(2/(3*x.shape[0])))/volume
138            if isotropic:
139                dPres = jnp.trace(Pres) - 3 * target_pressure
140            else:
141                dPres = 0.5 * (Pres + Pres.T) -  jnp.array(target_pressure * np.eye(3))
142
143            vextvol = vextvol + ((dt_bar/masspiston)*volume) * dPres
144
145            # apply A
146            if isotropic:
147                vdt2 = 0.5*dt_bar * vextvol
148                scalev = jnp.exp(-vdt2*(1+1./x.shape[0])) 
149                vel = vel * scalev
150                scale1 = jnp.exp(vdt2)
151            else:
152                vextvol = aniso_mask * vextvol
153                vdt2 = 0.5* dt_bar * vextvol
154                l, O = jnp.linalg.eigh(vdt2) 
155                lcorr = jnp.trace(vdt2)/(3*x.shape[0])
156                Dv = jnp.diag(jnp.exp(-(l+lcorr)))
157                Dx = jnp.diag(jnp.exp(l))
158                scalev = O @ Dv @ O.T
159                scale1 = O @ Dx @ O.T
160                vel = vel @ scalev
161
162            # apply O
163            if nbeads is not None:
164                eigv, thermostat_state = thermostat(
165                    jnp.concatenate((vel[None], eigv), axis=0), system["thermostat"]
166                )
167                vel, eigv = eigv[0], eigv[1:]
168            else:
169                vel, thermostat_state = thermostat(vel, system["thermostat"])
170            rng_key, noise_key = jax.random.split(barostat_state["rng_key"])
171
172            if isotropic:
173                noise = jax.random.normal(noise_key, (1,), dtype=vextvol.dtype)
174            else:
175                noise = jax.random.normal(noise_key, (3, 3), dtype=vextvol.dtype)
176                noise = 0.5 * (noise + noise.T)
177
178            vextvol = a1 * vextvol + a2 * noise
179
180            # apply A
181            if isotropic:
182                vdt2 = 0.5 * dt_bar * vextvol
183                scalev = jnp.exp(-vdt2*(1+1./x.shape[0]))
184                vel = vel * scalev
185                scale2 = jnp.exp(vdt2)
186                x = x * (scale1 * scale2)
187                extvol = extvol * (scale1 * scale2) ** 3
188                cell = cell * (scale1 * scale2)
189            else:
190                vextvol = aniso_mask * vextvol
191                vdt2 = 0.5 * dt_bar * vextvol
192                l, O = jnp.linalg.eigh(vdt2) 
193                lcorr = jnp.trace(vdt2)/(3*x.shape[0])
194                Dv = jnp.diag(jnp.exp(-(l+lcorr)))
195                Dx = jnp.diag(jnp.exp(l))
196                scalev = O @ Dv @ O.T
197                scale = scale1 @ (O @ Dx @ O.T)
198                vel = vel @ scalev
199                x = x @ scale
200                extvol = extvol @ scale
201                cell = extvol
202
203            if nbeads is not None:
204                x = jnp.concatenate((x[None], eigx), axis=0)
205                vel = jnp.concatenate((vel[None], eigv), axis=0)
206
207            piston_temperature = (us.KELVIN * masspiston/ndof_piston) * jnp.sum(vextvol**2)
208            barostat_state = {
209                **barostat_state,
210                "istep": istep,
211                "rng_key": rng_key,
212                "vextvol": vextvol,
213                "extvol": extvol,
214                "piston_temperature": piston_temperature,
215            }
216            return (
217                x,
218                vel,
219                {
220                    **system,
221                    "barostat": barostat_state,
222                    "cell": cell,
223                    "thermostat": thermostat_state,
224                },
225            )
226
227    elif barostat_name in ["NONE"]:
228        variable_cell = False
229
230        def barostat(x, vel, system):
231            vel, thermostat_state = thermostat(vel, system["thermostat"])
232            return x, vel, {**system, "thermostat": thermostat_state}
233
234    else:
235        raise ValueError(f"Unknown barostat {barostat_name}")
236
237    return barostat, variable_cell, state
def get_barostat( thermostat, simulation_parameters, dt, system_data, fprec, rng_key=None, restart_data={}):
 12def get_barostat(
 13    thermostat, simulation_parameters, dt, system_data, fprec, rng_key=None, restart_data={}
 14):
 15    state = {}
 16
 17    barostat_name = str(simulation_parameters.get("barostat", "NONE")).upper()
 18    """@keyword[fennol_md] barostat
 19    Type of barostat for pressure control (NONE, LGV, LANGEVIN).
 20    Default: NONE
 21    """
 22
 23    kT = system_data.get("kT", None)
 24    assert kT is not None, "kT must be specified for NPT/NPH simulations"
 25    target_pressure = simulation_parameters.get("target_pressure")
 26    """@keyword[fennol_md] target_pressure
 27    Target pressure for NPT ensemble simulations.
 28    Required for barostat != NONE
 29    """
 30    if barostat_name != "NONE":
 31        assert (
 32            target_pressure is not None
 33        ), "target_pressure must be specified for NPT/NPH simulations"
 34
 35    nbeads = system_data.get("nbeads", None)
 36    variable_cell = True
 37
 38    anisotropic = simulation_parameters.get("aniso_barostat", False)
 39    """@keyword[fennol_md] aniso_barostat
 40    Use anisotropic barostat allowing independent cell parameter scaling.
 41    Default: False
 42    """
 43    
 44    isotropic = not anisotropic
 45
 46    pbc_data = system_data["pbc"]
 47    start_barostat = simulation_parameters.get("start_barostat", 0.0)
 48    """@keyword[fennol_md] start_barostat
 49    Time delay before starting barostat pressure coupling.
 50    Default: 0.0
 51    """
 52    start_time = restart_data.get("simulation_time_ps",0.) / us.PS
 53    start_barostat = max(0.,start_barostat-start_time)
 54    istart_barostat = int(round(start_barostat / dt))
 55    if istart_barostat > 0 and barostat_name not in ["NONE"]:
 56        print(
 57            f"# BAROSTAT will start at {start_barostat*us.PS:.3f} ps ({istart_barostat} steps)"
 58        )
 59    else:
 60        istart_barostat = 0
 61
 62    if barostat_name in ["LGV", "LANGEVIN"]:
 63        assert rng_key is not None, "rng_key must be provided for QTB barostat"
 64        gamma = simulation_parameters.get("gamma_piston", 20.0 / us.THZ)
 65        """@keyword[fennol_md] gamma_piston
 66        Piston friction coefficient for Langevin barostat.
 67        Default: 20.0 ps^-1
 68        """
 69        tau_piston = simulation_parameters.get("tau_piston", 200.0 / us.FS)
 70        """@keyword[fennol_md] tau_piston
 71        Piston time constant for barostat coupling.
 72        Default: 200.0 fs
 73        """
 74        nat = system_data["nat"]
 75        masspiston = 3 * nat * kT * tau_piston**2
 76        print(f"# LANGEVIN barostat with piston mass={masspiston*us.get_multiplier('KCALPERMOL*PS^{2}'):.1e} kcal/mol.ps^2")
 77        a1 = math.exp(-gamma * dt)
 78        a2 = ((1 - a1 * a1) * kT / masspiston) ** 0.5
 79
 80
 81        rng_key, v_key = jax.random.split(rng_key)
 82        if anisotropic:
 83            extvol = pbc_data["cell"]
 84            vextvol = (
 85                jax.random.normal(v_key, (3, 3), dtype=extvol.dtype)
 86                * (kT / masspiston) ** 0.5
 87            )
 88            vextvol = 0.5 * (vextvol + vextvol.T)
 89
 90            aniso_mask = simulation_parameters.get(
 91                "aniso_mask", [True, True, True, True, True, True]
 92            )
 93            """@keyword[fennol_md] aniso_mask
 94            Mask for anisotropic barostat degrees of freedom [xx, yy, zz, xy, xz, yz].
 95            Default: [True, True, True, True, True, True]
 96            """
 97            assert len(aniso_mask) == 6, "aniso_mask must have 6 elements"
 98            aniso_mask = np.array(aniso_mask, dtype=bool).astype(np.int32)
 99            ndof_piston = np.sum(aniso_mask)
100            # xx   yy   zz    xy   xz   yz
101            aniso_mask = np.array(
102                [
103                    [aniso_mask[0], aniso_mask[3], aniso_mask[4]],
104                    [aniso_mask[3], aniso_mask[1], aniso_mask[5]],
105                    [aniso_mask[4], aniso_mask[5], aniso_mask[2]],
106                ]
107            ,dtype=np.int32)
108        else:
109            extvol = jnp.asarray(pbc_data["volume"])
110            vextvol = (
111                jax.random.normal(v_key, (1,), dtype=extvol.dtype)
112                * (kT / masspiston) ** 0.5
113            )
114            ndof_piston = 1.
115
116        state["extvol"] = extvol
117        state["vextvol"] = vextvol
118        state["rng_key"] = rng_key
119        state["istep"] = 0
120
121        def barostat(x, vel, system):
122            if nbeads is not None:
123                x, eigx = x[0], x[1:]
124                vel, eigv = vel[0], vel[1:]
125            barostat_state = system["barostat"]
126            extvol = barostat_state["extvol"]
127            vextvol = barostat_state["vextvol"]
128            cell = system["cell"]
129            volume = jnp.abs(jnp.linalg.det(cell))
130
131            istep = barostat_state["istep"] + 1
132            dt_bar = dt * (istep >= istart_barostat)
133
134
135            # apply B
136            # pV = 2 * (system["ek_tensor"] + jnp.trace(system["ek_tensor"])*jnp.eye(3)/(3*x.shape[0])) - system["virial"]
137            ek = system["ek_c"] if nbeads is not None else system["ek"] 
138            Pres = system["pressure_tensor"] + ek*jnp.array(np.eye(3)*(2/(3*x.shape[0])))/volume
139            if isotropic:
140                dPres = jnp.trace(Pres) - 3 * target_pressure
141            else:
142                dPres = 0.5 * (Pres + Pres.T) -  jnp.array(target_pressure * np.eye(3))
143
144            vextvol = vextvol + ((dt_bar/masspiston)*volume) * dPres
145
146            # apply A
147            if isotropic:
148                vdt2 = 0.5*dt_bar * vextvol
149                scalev = jnp.exp(-vdt2*(1+1./x.shape[0])) 
150                vel = vel * scalev
151                scale1 = jnp.exp(vdt2)
152            else:
153                vextvol = aniso_mask * vextvol
154                vdt2 = 0.5* dt_bar * vextvol
155                l, O = jnp.linalg.eigh(vdt2) 
156                lcorr = jnp.trace(vdt2)/(3*x.shape[0])
157                Dv = jnp.diag(jnp.exp(-(l+lcorr)))
158                Dx = jnp.diag(jnp.exp(l))
159                scalev = O @ Dv @ O.T
160                scale1 = O @ Dx @ O.T
161                vel = vel @ scalev
162
163            # apply O
164            if nbeads is not None:
165                eigv, thermostat_state = thermostat(
166                    jnp.concatenate((vel[None], eigv), axis=0), system["thermostat"]
167                )
168                vel, eigv = eigv[0], eigv[1:]
169            else:
170                vel, thermostat_state = thermostat(vel, system["thermostat"])
171            rng_key, noise_key = jax.random.split(barostat_state["rng_key"])
172
173            if isotropic:
174                noise = jax.random.normal(noise_key, (1,), dtype=vextvol.dtype)
175            else:
176                noise = jax.random.normal(noise_key, (3, 3), dtype=vextvol.dtype)
177                noise = 0.5 * (noise + noise.T)
178
179            vextvol = a1 * vextvol + a2 * noise
180
181            # apply A
182            if isotropic:
183                vdt2 = 0.5 * dt_bar * vextvol
184                scalev = jnp.exp(-vdt2*(1+1./x.shape[0]))
185                vel = vel * scalev
186                scale2 = jnp.exp(vdt2)
187                x = x * (scale1 * scale2)
188                extvol = extvol * (scale1 * scale2) ** 3
189                cell = cell * (scale1 * scale2)
190            else:
191                vextvol = aniso_mask * vextvol
192                vdt2 = 0.5 * dt_bar * vextvol
193                l, O = jnp.linalg.eigh(vdt2) 
194                lcorr = jnp.trace(vdt2)/(3*x.shape[0])
195                Dv = jnp.diag(jnp.exp(-(l+lcorr)))
196                Dx = jnp.diag(jnp.exp(l))
197                scalev = O @ Dv @ O.T
198                scale = scale1 @ (O @ Dx @ O.T)
199                vel = vel @ scalev
200                x = x @ scale
201                extvol = extvol @ scale
202                cell = extvol
203
204            if nbeads is not None:
205                x = jnp.concatenate((x[None], eigx), axis=0)
206                vel = jnp.concatenate((vel[None], eigv), axis=0)
207
208            piston_temperature = (us.KELVIN * masspiston/ndof_piston) * jnp.sum(vextvol**2)
209            barostat_state = {
210                **barostat_state,
211                "istep": istep,
212                "rng_key": rng_key,
213                "vextvol": vextvol,
214                "extvol": extvol,
215                "piston_temperature": piston_temperature,
216            }
217            return (
218                x,
219                vel,
220                {
221                    **system,
222                    "barostat": barostat_state,
223                    "cell": cell,
224                    "thermostat": thermostat_state,
225                },
226            )
227
228    elif barostat_name in ["NONE"]:
229        variable_cell = False
230
231        def barostat(x, vel, system):
232            vel, thermostat_state = thermostat(vel, system["thermostat"])
233            return x, vel, {**system, "thermostat": thermostat_state}
234
235    else:
236        raise ValueError(f"Unknown barostat {barostat_name}")
237
238    return barostat, variable_cell, state