fennol.md.colvars

  1import jax
  2import jax.numpy as jnp
  3import numpy as np
  4from functools import partial
  5
  6
  7def build_colvar_distance(colvar_name, colvar_def, global_colvars={}):
  8    atom1 = colvar_def["atom1"] - 1
  9    atom2 = colvar_def["atom2"] - 1
 10    assert atom1 != atom2, f"atom1 and atom2 must be different for colvar {colvar_name}"
 11    assert (
 12        atom1 >= 0 and atom2 >= 0
 13    ), f"atom1 and atom2 must be > 0 for colvar {colvar_name}"
 14
 15    def colvar_distance(coordinates):
 16        return jnp.linalg.norm(coordinates[atom1] - coordinates[atom2])
 17
 18    return colvar_distance
 19
 20
 21def build_colvar_angle(colvar_name, colvar_def, global_colvars={}):
 22    atom1 = colvar_def["atom1"] - 1
 23    atom2 = colvar_def["atom2"] - 1
 24    atom3 = colvar_def["atom3"] - 1
 25    assert (
 26        atom1 != atom2 and atom2 != atom3 and atom1 != atom3
 27    ), f"atom1, atom2 and atom3 must be different for colvar {colvar_name}"
 28    assert (
 29        atom1 >= 0 and atom2 >= 0 and atom3 >= 0
 30    ), f"atom1, atom2 and atom3 must be > 0 for colvar {colvar_name}"
 31    use_radians = colvar_def.get("use_radians", False)
 32    fact = 1.0 if use_radians else 180.0 / np.pi
 33
 34    def colvar_angle(coordinates):
 35        v1 = coordinates[atom1] - coordinates[atom2]
 36        v2 = coordinates[atom3] - coordinates[atom2]
 37        v1 = v1 / jnp.linalg.norm(v1)
 38        v2 = v2 / jnp.linalg.norm(v2)
 39        return jnp.arccos(jnp.dot(v1, v2)) * fact
 40
 41    return colvar_angle
 42
 43
 44def build_colvar_dihedral(colvar_name, colvar_def, global_colvars={}):
 45    atom1 = colvar_def["atom1"] - 1
 46    atom2 = colvar_def["atom2"] - 1
 47    atom3 = colvar_def["atom3"] - 1
 48    atom4 = colvar_def["atom4"] - 1
 49    assert (
 50        atom1 != atom2 and atom2 != atom3 and atom3 != atom4 and atom1 != atom4
 51    ), f"atom1, atom2, atom3 and atom4 must be different for colvar {colvar_name}"
 52    assert (
 53        atom1 >= 0 and atom2 >= 0 and atom3 >= 0 and atom4 >= 0
 54    ), f"atom1, atom2, atom3 and atom4 must be > 0 for colvar {colvar_name}"
 55    use_radians = colvar_def.get("use_radians", False)
 56    fact = 1.0 if use_radians else 180.0 / np.pi
 57
 58    def colvar_dihedral(coordinates):
 59        v1 = coordinates[atom1] - coordinates[atom2]
 60        v2 = coordinates[atom3] - coordinates[atom2]
 61        v3 = coordinates[atom4] - coordinates[atom3]
 62        n1 = jnp.cross(v1, v2)
 63        n2 = jnp.cross(v2, v3)
 64        n1 = n1 / jnp.linalg.norm(n1)
 65        n2 = n2 / jnp.linalg.norm(n2)
 66        return jnp.arccos(jnp.dot(n1, n2)) * fact
 67
 68    return colvar_dihedral
 69
 70
 71__RAW_COLVAR = {
 72    "distance": build_colvar_distance,
 73    "angle": build_colvar_angle,
 74    "dihedral": build_colvar_dihedral,
 75}
 76
 77
 78def build_colvar_function(colvar_name, colvar_def, global_colvars={}):
 79    func_def = colvar_def["lambda"]
 80    assert isinstance(
 81        func_def, str
 82    ), f"'lambda' field of colvar '{colvar_name}' must be a string (try using quotation marks)"
 83    arguments = [s.strip() for s in func_def.split(":")[0].split(",")]
 84    _cvs = []
 85    for cv_name in arguments:
 86        if cv_name in colvar_def:
 87            cv_def = colvar_def[cv_name]
 88        elif cv_name in global_colvars:
 89            cv_def = global_colvars[cv_name]
 90        else:
 91            raise ValueError(
 92                f"Colvar '{colvar_name}' references colvar '{cv_name}' which is not defined in the colvars section"
 93            )
 94        cv_type = str(cv_def.get("type", "distance")).lower()
 95        assert (
 96            cv_type in __RAW_COLVAR
 97        ), f"Unknown colvar type '{cv_type}' for colvar '{colvar_name}/{cv_name}'. Available colvars are {list(__RAW_COLVAR.keys())}"
 98        _cvs.append(__RAW_COLVAR[cv_type](cv_name, cv_def))
 99
100    func = eval(
101        "lambda " + func_def,
102        {
103            "__builtins__": None,
104            **jax.nn.__dict__,
105            **jax.numpy.__dict__,
106            **jax.__dict__,
107        },
108    )
109
110    def colvar_function(coordinates):
111        cv_values = []
112        for cv in _cvs:
113            cv_values.append(cv(coordinates))
114        return func(*cv_values)
115
116    return colvar_function
117
118
# Every colvar type accepted at the top level of a colvars section: the
# elementary geometric types plus the composite "function" type.
__BUILD_COLVAR = dict(__RAW_COLVAR, function=build_colvar_function)
123
124
def setup_colvars(colvars_definitions):
    """Build all collective variables declared in a colvars section.

    Args:
        colvars_definitions: Mapping from colvar name to its definition. The
            optional "type" key (case-insensitive, default "distance")
            selects the builder in ``__BUILD_COLVAR``. The whole mapping is
            forwarded so that "function" colvars can reference other entries.

    Returns:
        A tuple ``(colvars, names)``. ``colvars`` maps each name to its
        built callable. ``names`` lists those names in definition order.

    Raises:
        AssertionError: if a colvar has an unknown type.
    """

    def _build_one(colvar_name, colvar_def):
        # Resolve the builder for one entry and construct its callable.
        colvar_type = str(colvar_def.get("type", "distance")).lower()
        assert (
            colvar_type in __BUILD_COLVAR
        ), f"Unknown colvar type '{colvar_type}' for colvar '{colvar_name}'. Available colvars are {list(__BUILD_COLVAR.keys())}"
        builder = __BUILD_COLVAR[colvar_type]
        return builder(colvar_name, colvar_def, colvars_definitions)

    colvars = {
        colvar_name: _build_one(colvar_name, colvar_def)
        for colvar_name, colvar_def in colvars_definitions.items()
    }
    return colvars, [*colvars]
def build_colvar_distance(colvar_name, colvar_def, global_colvars=None):
    """Build a colvar returning the Euclidean distance between two atoms.

    Args:
        colvar_name: Name of the colvar, used only in error messages.
        colvar_def: Mapping with 1-based atom indices under "atom1" and
            "atom2".
        global_colvars: Unused; accepted so that every colvar builder shares
            the same call signature.

    Returns:
        A function ``colvar_distance(coordinates)`` returning
        ``norm(coordinates[atom1] - coordinates[atom2])`` with the indices
        converted to 0-based.

    Raises:
        AssertionError: if both atoms are the same or an index is < 1.
    """
    # Input indices are 1-based; convert to 0-based for array indexing.
    atom1 = colvar_def["atom1"] - 1
    atom2 = colvar_def["atom2"] - 1
    assert atom1 != atom2, f"atom1 and atom2 must be different for colvar {colvar_name}"
    assert (
        atom1 >= 0 and atom2 >= 0
    ), f"atom1 and atom2 must be > 0 for colvar {colvar_name}"

    def colvar_distance(coordinates):
        return jnp.linalg.norm(coordinates[atom1] - coordinates[atom2])

    return colvar_distance
def build_colvar_angle(colvar_name, colvar_def, global_colvars=None):
    """Build a colvar returning the angle atom1-atom2-atom3 (vertex at atom2).

    Args:
        colvar_name: Name of the colvar, used only in error messages.
        colvar_def: Mapping with 1-based atom indices "atom1", "atom2" and
            "atom3". Optional "use_radians" (default False) returns radians
            instead of degrees.
        global_colvars: Unused; accepted so that every colvar builder shares
            the same call signature.

    Returns:
        A function ``colvar_angle(coordinates)`` returning the angle between
        the vectors atom2->atom1 and atom2->atom3, in [0, 180] degrees or
        [0, pi] radians.

    Raises:
        AssertionError: if the atoms are not distinct or an index is < 1.
    """
    # Input indices are 1-based; convert to 0-based for array indexing.
    atom1 = colvar_def["atom1"] - 1
    atom2 = colvar_def["atom2"] - 1
    atom3 = colvar_def["atom3"] - 1
    assert (
        atom1 != atom2 and atom2 != atom3 and atom1 != atom3
    ), f"atom1, atom2 and atom3 must be different for colvar {colvar_name}"
    assert (
        atom1 >= 0 and atom2 >= 0 and atom3 >= 0
    ), f"atom1, atom2 and atom3 must be > 0 for colvar {colvar_name}"
    use_radians = colvar_def.get("use_radians", False)
    fact = 1.0 if use_radians else 180.0 / np.pi

    def colvar_angle(coordinates):
        v1 = coordinates[atom1] - coordinates[atom2]
        v2 = coordinates[atom3] - coordinates[atom2]
        v1 = v1 / jnp.linalg.norm(v1)
        v2 = v2 / jnp.linalg.norm(v2)
        # Rounding can push the cosine of (nearly) collinear unit vectors
        # slightly outside [-1, 1], where arccos would return NaN.
        cos_angle = jnp.clip(jnp.dot(v1, v2), -1.0, 1.0)
        return jnp.arccos(cos_angle) * fact

    return colvar_angle
def build_colvar_dihedral(colvar_name, colvar_def, global_colvars=None):
    """Build a colvar returning the dihedral angle atom1-atom2-atom3-atom4.

    With bond vectors b1 = r2 - r1, b2 = r3 - r2 and b3 = r4 - r3, the angle
    is ``atan2(|b2| * b1.(b2 x b3), (b1 x b2).(b2 x b3))``. With this formula
    a cis arrangement gives 0 and a trans arrangement gives 180 degrees.

    Args:
        colvar_name: Name of the colvar, used only in error messages.
        colvar_def: Mapping with 1-based atom indices "atom1" to "atom4".
            Optional keys:
            - "use_radians" (default False): return radians instead of
              degrees.
            - "signed" (default False): return the signed angle in
              [-180, 180] degrees. Otherwise the absolute value in
              [0, 180] is returned.
        global_colvars: Unused; accepted so that every colvar builder shares
            the same call signature.

    Returns:
        A function ``colvar_dihedral(coordinates)``. Per-atom coordinates are
        assumed to be 3D vectors, as required by ``jnp.cross``.

    Raises:
        AssertionError: if the four atoms are not pairwise distinct or an
            index is < 1.
    """
    # Input indices are 1-based; convert to 0-based for array indexing.
    atom1 = colvar_def["atom1"] - 1
    atom2 = colvar_def["atom2"] - 1
    atom3 = colvar_def["atom3"] - 1
    atom4 = colvar_def["atom4"] - 1
    # Every pair must differ: e.g. atom1 == atom3 makes b1 x b2 vanish and
    # leaves the angle undefined.
    assert (
        len({atom1, atom2, atom3, atom4}) == 4
    ), f"atom1, atom2, atom3 and atom4 must be different for colvar {colvar_name}"
    assert (
        atom1 >= 0 and atom2 >= 0 and atom3 >= 0 and atom4 >= 0
    ), f"atom1, atom2, atom3 and atom4 must be > 0 for colvar {colvar_name}"
    use_radians = colvar_def.get("use_radians", False)
    signed = colvar_def.get("signed", False)
    fact = 1.0 if use_radians else 180.0 / np.pi

    def colvar_dihedral(coordinates):
        b1 = coordinates[atom2] - coordinates[atom1]
        b2 = coordinates[atom3] - coordinates[atom2]
        b3 = coordinates[atom4] - coordinates[atom3]
        n1 = jnp.cross(b1, b2)
        n2 = jnp.cross(b2, b3)
        # The atan2 form avoids arccos and needs no normalisation: it has no
        # NaN from |cos| > 1, and it stays accurate near 0 and 180 degrees.
        x = jnp.dot(n1, n2)
        y = jnp.linalg.norm(b2) * jnp.dot(b1, n2)
        phi = jnp.arctan2(y, x)
        if not signed:
            phi = jnp.abs(phi)
        return phi * fact

    return colvar_dihedral
def build_colvar_function(colvar_name, colvar_def, global_colvars={}):
 79def build_colvar_function(colvar_name, colvar_def, global_colvars={}):
 80    func_def = colvar_def["lambda"]
 81    assert isinstance(
 82        func_def, str
 83    ), f"'lambda' field of colvar '{colvar_name}' must be a string (try using quotation marks)"
 84    arguments = [s.strip() for s in func_def.split(":")[0].split(",")]
 85    _cvs = []
 86    for cv_name in arguments:
 87        if cv_name in colvar_def:
 88            cv_def = colvar_def[cv_name]
 89        elif cv_name in global_colvars:
 90            cv_def = global_colvars[cv_name]
 91        else:
 92            raise ValueError(
 93                f"Colvar '{colvar_name}' references colvar '{cv_name}' which is not defined in the colvars section"
 94            )
 95        cv_type = str(cv_def.get("type", "distance")).lower()
 96        assert (
 97            cv_type in __RAW_COLVAR
 98        ), f"Unknown colvar type '{cv_type}' for colvar '{colvar_name}/{cv_name}'. Available colvars are {list(__RAW_COLVAR.keys())}"
 99        _cvs.append(__RAW_COLVAR[cv_type](cv_name, cv_def))
100
101    func = eval(
102        "lambda " + func_def,
103        {
104            "__builtins__": None,
105            **jax.nn.__dict__,
106            **jax.numpy.__dict__,
107            **jax.__dict__,
108        },
109    )
110
111    def colvar_function(coordinates):
112        cv_values = []
113        for cv in _cvs:
114            cv_values.append(cv(coordinates))
115        return func(*cv_values)
116
117    return colvar_function
def setup_colvars(colvars_definitions):
    """Build all collective variables declared in a colvars section.

    Args:
        colvars_definitions: Mapping from colvar name to its definition. The
            optional "type" key (case-insensitive, default "distance")
            selects the builder in ``__BUILD_COLVAR``. The whole mapping is
            forwarded so that "function" colvars can reference other entries.

    Returns:
        A tuple ``(colvars, names)``. ``colvars`` maps each name to its
        built callable. ``names`` lists those names in definition order.

    Raises:
        AssertionError: if a colvar has an unknown type.
    """

    def _build_one(colvar_name, colvar_def):
        # Resolve the builder for one entry and construct its callable.
        colvar_type = str(colvar_def.get("type", "distance")).lower()
        assert (
            colvar_type in __BUILD_COLVAR
        ), f"Unknown colvar type '{colvar_type}' for colvar '{colvar_name}'. Available colvars are {list(__BUILD_COLVAR.keys())}"
        builder = __BUILD_COLVAR[colvar_type]
        return builder(colvar_name, colvar_def, colvars_definitions)

    colvars = {
        colvar_name: _build_one(colvar_name, colvar_def)
        for colvar_name, colvar_def in colvars_definitions.items()
    }
    return colvars, [*colvars]