fennol.md.colvars

  1import jax
  2import jax.numpy as jnp
  3import numpy as np
  4from functools import partial
  5
  6from ..utils import read_tinker_interval
  7
  8
  9def build_colvar_distance(colvar_name, colvar_def, global_colvars={}):
 10    atoms = colvar_def["atoms"]
 11    assert len(atoms) == 2, f"Colvar {colvar_name} must have exactly 2 atoms defined"
 12    atom1 = int(atoms[0]) - 1
 13    atom2 = int(atoms[1]) - 1
 14    assert atom1 != atom2, f"atom1 and atom2 must be different for colvar {colvar_name}"
 15    assert (
 16        atom1 >= 0 and atom2 >= 0
 17    ), f"atom1 and atom2 must be > 0 for colvar {colvar_name}"
 18
 19    def colvar_distance(coordinates):
 20        return jnp.linalg.norm(coordinates[atom1] - coordinates[atom2])
 21
 22    return colvar_distance
 23
 24
 25def build_colvar_angle(colvar_name, colvar_def, global_colvars={}):
 26    atoms = colvar_def["atoms"]
 27    assert len(atoms) == 3, f"Colvar {colvar_name} must have exactly 3 atoms defined"
 28    atom1 = int(atoms[0]) - 1
 29    atom2 = int(atoms[1]) - 1
 30    atom3 = int(atoms[2]) - 1
 31    assert (
 32        atom1 != atom2 and atom2 != atom3 and atom1 != atom3
 33    ), f"atom1, atom2 and atom3 must be different for colvar {colvar_name}"
 34    assert (
 35        atom1 >= 0 and atom2 >= 0 and atom3 >= 0
 36    ), f"atom1, atom2 and atom3 must be > 0 for colvar {colvar_name}"
 37    use_radians = colvar_def.get("use_radians", False)
 38    fact = 1.0 if use_radians else 180.0 / np.pi
 39
 40    def colvar_angle(coordinates):
 41        v1 = coordinates[atom1] - coordinates[atom2]
 42        v2 = coordinates[atom3] - coordinates[atom2]
 43        v1 = v1 / jnp.linalg.norm(v1)
 44        v2 = v2 / jnp.linalg.norm(v2)
 45        return jnp.arccos(jnp.dot(v1, v2)) * fact
 46
 47    return colvar_angle
 48
 49
 50def build_colvar_dihedral(colvar_name, colvar_def, global_colvars={}):
 51    atoms = colvar_def["atoms"]
 52    assert len(atoms) == 4, f"Colvar {colvar_name} must have exactly 4 atoms defined"
 53    atom1 = int(atoms[0]) - 1
 54    atom2 = int(atoms[1]) - 1
 55    atom3 = int(atoms[2]) - 1
 56    atom4 = int(atoms[3]) - 1
 57    # Check that atoms are different and >= 0
 58    assert (
 59        atom1 != atom2 and atom2 != atom3 and atom3 != atom4 and atom1 != atom4
 60    ), f"atom1, atom2, atom3 and atom4 must be different for colvar {colvar_name}"
 61    assert (
 62        atom1 >= 0 and atom2 >= 0 and atom3 >= 0 and atom4 >= 0
 63    ), f"atom1, atom2, atom3 and atom4 must be > 0 for colvar {colvar_name}"
 64    use_radians = colvar_def.get("use_radians", False)
 65    fact = 1.0 if use_radians else 180.0 / np.pi
 66
 67    def colvar_dihedral(coordinates):
 68        v12 = coordinates[atom2] - coordinates[atom1]
 69        v23 = coordinates[atom3] - coordinates[atom2]
 70        v34 = coordinates[atom4] - coordinates[atom3]
 71        n1 = jnp.cross(v12, v23)
 72        n2 = jnp.cross(v23, v34)
 73        n1 = n1 / jnp.linalg.norm(n1)
 74        n2 = n2 / jnp.linalg.norm(n2)
 75        cos_phi = (n1*n2).sum()
 76        sin_phi = (n1*v34).sum()*jnp.linalg.norm(v23)
 77        phi = jnp.arctan2(sin_phi, cos_phi)
 78        return phi * fact
 79
 80    return colvar_dihedral
 81
 82def build_colvar_maxdcom(colvar_name, colvar_def, global_colvars={}):
 83    atoms = read_tinker_interval(colvar_def["atoms"])
 84    
 85    def colvar_maxdcom(coordinates):
 86        sel_coord = coordinates[atoms]
 87        com = jnp.mean(sel_coord, axis=0,keepdims=True)
 88        return jnp.max(jnp.linalg.norm(sel_coord - com, axis=1))
 89
 90    return colvar_maxdcom
 91
 92
 93__RAW_COLVAR = {
 94    "distance": build_colvar_distance,
 95    "angle": build_colvar_angle,
 96    "dihedral": build_colvar_dihedral,
 97    "maxdcom": build_colvar_maxdcom,
 98}
 99
100
def build_colvar_function(colvar_name, colvar_def, global_colvars=None):
    """Build a colvar defined as a user-supplied lambda of raw colvars.

    ``colvar_def["lambda"]`` holds the text of a Python lambda without the
    ``lambda`` keyword, e.g. ``"d1, d2: d1 - d2"``. Each argument name is
    looked up first among the keys of ``colvar_def`` itself, then in
    ``global_colvars``. It must resolve to a definition whose type is one of
    the raw colvar types in ``__RAW_COLVAR``. The lambda body may use names
    exported by ``jax.nn``, ``jax.numpy`` and ``jax``; on name clashes the
    later module wins. Python builtins are disabled.

    Args:
        colvar_name: name of this colvar, used in error messages.
        colvar_def: definition mapping containing ``"lambda"`` and,
            optionally, inline definitions of the referenced colvars.
        global_colvars: mapping of all colvar definitions, searched when an
            argument is not defined inline. ``None`` means empty.

    Returns:
        A function ``coordinates -> value`` that evaluates every referenced
        colvar and passes the results, in argument order, to the lambda.

    Raises:
        AssertionError: if ``"lambda"`` is not a string, or a referenced
            colvar has an unknown or non-raw type.
        ValueError: if an argument name cannot be resolved.
    """
    if global_colvars is None:
        global_colvars = {}
    func_def = colvar_def["lambda"]
    assert isinstance(
        func_def, str
    ), f"'lambda' field of colvar '{colvar_name}' must be a string (try using quotation marks)"
    # Argument names are everything before the first ':' of the lambda text.
    arguments = [s.strip() for s in func_def.split(":")[0].split(",")]
    _cvs = []
    for cv_name in arguments:
        # Inline definitions take precedence over the global colvars section.
        if cv_name in colvar_def:
            cv_def = colvar_def[cv_name]
        elif cv_name in global_colvars:
            cv_def = global_colvars[cv_name]
        else:
            raise ValueError(
                f"Colvar '{colvar_name}' references colvar '{cv_name}' which is not defined in the colvars section"
            )
        cv_type = str(cv_def.get("type", "distance")).lower()
        assert (
            cv_type in __RAW_COLVAR
        ), f"Unknown colvar type '{cv_type}' for colvar '{colvar_name}/{cv_name}'. Available colvars are {list(__RAW_COLVAR.keys())}"
        _cvs.append(__RAW_COLVAR[cv_type](cv_name, cv_def))

    # Evaluation namespace: public names of jax.nn, jax.numpy and jax, with
    # later modules overriding earlier ones. Dunder entries are skipped and
    # "__builtins__" is assigned LAST. Previously each spread module __dict__
    # carried its own "__builtins__", which silently overrode the intended
    # None and re-enabled all builtins.
    namespace = {}
    for module in (jax.nn, jax.numpy, jax):
        namespace.update(
            (name, obj)
            for name, obj in vars(module).items()
            if not name.startswith("__")
        )
    namespace["__builtins__"] = None
    # NOTE(security): disabling builtins is NOT a sandbox. Any object reachable
    # from the namespace can be used to escape it. Only evaluate colvar
    # definitions from trusted input files.
    func = eval("lambda " + func_def, namespace)

    def colvar_function(coordinates):
        return func(*(cv(coordinates) for cv in _cvs))

    return colvar_function
140
141
# Complete builder registry used by setup_colvars: every raw colvar type
# (same order) followed by the composite "function" type.
__BUILD_COLVAR = dict(__RAW_COLVAR, function=build_colvar_function)
146
147
def setup_colvars(colvars_definitions):
    """Build every colvar described in ``colvars_definitions``.

    Args:
        colvars_definitions: mapping from colvar name to its definition
            mapping. The definition's ``"type"`` (case-insensitive, default
            ``"distance"``) selects the builder. The whole mapping is passed
            to each builder so "function" colvars can reference the others.

    Returns:
        A tuple ``(colvars, names)``. ``colvars`` maps each name to its
        ``coordinates -> value`` function; ``names`` lists the names in
        iteration order of ``colvars_definitions``.

    Raises:
        AssertionError: if a colvar has an unknown type.
    """
    built = {}
    for name, definition in colvars_definitions.items():
        kind = str(definition.get("type", "distance")).lower()
        assert (
            kind in __BUILD_COLVAR
        ), f"Unknown colvar type '{kind}' for colvar '{name}'. Available colvars are {list(__BUILD_COLVAR.keys())}"
        builder = __BUILD_COLVAR[kind]
        built[name] = builder(name, definition, colvars_definitions)

    return built, list(built)
def build_colvar_distance(colvar_name, colvar_def, global_colvars={}):
    """Build a colvar measuring the distance between two atoms.

    Args:
        colvar_name: name of the colvar, used in error messages.
        colvar_def: definition mapping; ``colvar_def["atoms"]`` must hold
            exactly two distinct 1-based atom indices.
        global_colvars: unused; accepted so all builders share one signature.

    Returns:
        A function ``coordinates -> distance`` returning the Euclidean norm
        of the difference between the two selected rows of ``coordinates``.

    Raises:
        AssertionError: on a wrong number of atoms, identical atoms, or an
            index < 1.
    """
    atom_ids = colvar_def["atoms"]
    assert len(atom_ids) == 2, f"Colvar {colvar_name} must have exactly 2 atoms defined"
    # Indices in the definition are 1-based; shift to 0-based for indexing.
    first, second = int(atom_ids[0]) - 1, int(atom_ids[1]) - 1
    assert first != second, f"atom1 and atom2 must be different for colvar {colvar_name}"
    assert min(first, second) >= 0, f"atom1 and atom2 must be > 0 for colvar {colvar_name}"

    def colvar_distance(coordinates):
        separation = coordinates[first] - coordinates[second]
        return jnp.linalg.norm(separation)

    return colvar_distance
def build_colvar_angle(colvar_name, colvar_def, global_colvars={}):
    """Build a colvar measuring the angle atom1-atom2-atom3 (vertex at atom2).

    Args:
        colvar_name: name of the colvar, used in error messages.
        colvar_def: definition mapping. ``colvar_def["atoms"]`` must hold
            exactly three distinct 1-based atom indices; the optional
            ``"use_radians"`` flag (default False) returns radians instead
            of degrees.
        global_colvars: unused; accepted so all builders share one signature.

    Returns:
        A function ``coordinates -> angle`` giving the angle in [0, 180]
        degrees (or [0, pi] radians). Presumably ``coordinates`` is an
        (n_atoms, 3) array -- TODO confirm against callers.

    Raises:
        AssertionError: on a wrong number of atoms, repeated atoms, or an
            index < 1.
    """
    atoms = colvar_def["atoms"]
    assert len(atoms) == 3, f"Colvar {colvar_name} must have exactly 3 atoms defined"
    # Indices in the definition are 1-based; shift to 0-based for indexing.
    atom1 = int(atoms[0]) - 1
    atom2 = int(atoms[1]) - 1
    atom3 = int(atoms[2]) - 1
    assert (
        atom1 != atom2 and atom2 != atom3 and atom1 != atom3
    ), f"atom1, atom2 and atom3 must be different for colvar {colvar_name}"
    assert (
        atom1 >= 0 and atom2 >= 0 and atom3 >= 0
    ), f"atom1, atom2 and atom3 must be > 0 for colvar {colvar_name}"
    use_radians = colvar_def.get("use_radians", False)
    fact = 1.0 if use_radians else 180.0 / np.pi

    def colvar_angle(coordinates):
        v1 = coordinates[atom1] - coordinates[atom2]
        v2 = coordinates[atom3] - coordinates[atom2]
        # atan2(|v1 x v2|, v1 . v2) is well conditioned over the whole [0, pi]
        # range. arccos of the normalised dot product returns NaN when
        # rounding pushes the cosine slightly outside [-1, 1] (near-linear
        # geometries) and loses precision close to 0 and pi.
        sin_term = jnp.linalg.norm(jnp.cross(v1, v2))
        cos_term = jnp.dot(v1, v2)
        return jnp.arctan2(sin_term, cos_term) * fact

    return colvar_angle
def build_colvar_dihedral(colvar_name, colvar_def, global_colvars={}):
    """Build a colvar measuring the torsion angle atom1-atom2-atom3-atom4.

    Args:
        colvar_name: name of the colvar, used in error messages.
        colvar_def: definition mapping. ``colvar_def["atoms"]`` must hold
            exactly four distinct 1-based atom indices; the optional
            ``"use_radians"`` flag (default False) returns radians instead
            of degrees.
        global_colvars: unused; accepted so all builders share one signature.

    Returns:
        A function ``coordinates -> phi`` giving the signed dihedral in
        (-180, 180] degrees (or (-pi, pi] radians). Presumably
        ``coordinates`` is an (n_atoms, 3) array -- TODO confirm.

    Raises:
        AssertionError: on a wrong number of atoms, any repeated atom, or an
            index < 1.
    """
    atoms = colvar_def["atoms"]
    assert len(atoms) == 4, f"Colvar {colvar_name} must have exactly 4 atoms defined"
    # Indices in the definition are 1-based; shift to 0-based for indexing.
    atom1 = int(atoms[0]) - 1
    atom2 = int(atoms[1]) - 1
    atom3 = int(atoms[2]) - 1
    atom4 = int(atoms[3]) - 1
    # All four atoms must be pairwise distinct: a repeat such as
    # atom1 == atom3 makes a plane normal vanish and the angle undefined.
    assert (
        len({atom1, atom2, atom3, atom4}) == 4
    ), f"atom1, atom2, atom3 and atom4 must be different for colvar {colvar_name}"
    assert (
        atom1 >= 0 and atom2 >= 0 and atom3 >= 0 and atom4 >= 0
    ), f"atom1, atom2, atom3 and atom4 must be > 0 for colvar {colvar_name}"
    use_radians = colvar_def.get("use_radians", False)
    fact = 1.0 if use_radians else 180.0 / np.pi

    def colvar_dihedral(coordinates):
        v12 = coordinates[atom2] - coordinates[atom1]
        v23 = coordinates[atom3] - coordinates[atom2]
        v34 = coordinates[atom4] - coordinates[atom3]
        # Unnormalised plane normals.
        n1 = jnp.cross(v12, v23)
        n2 = jnp.cross(v23, v34)
        # phi = atan2(|b2| * (n1 . b3), n1 . n2). Both arguments carry the
        # same positive factor |n1| |n2|, so atan2 yields the true torsion.
        # Mixing normalised normals with the raw v34 (as before) scaled the
        # sine term by |n2| and gave wrong angles.
        cos_phi = jnp.dot(n1, n2)
        sin_phi = jnp.dot(n1, v34) * jnp.linalg.norm(v23)
        phi = jnp.arctan2(sin_phi, cos_phi)
        return phi * fact

    return colvar_dihedral
def build_colvar_maxdcom(colvar_name, colvar_def, global_colvars={}):
    """Build a colvar giving the largest distance of a group of atoms from its centroid.

    Args:
        colvar_name: name of the colvar (not used in the computation).
        colvar_def: definition mapping; ``colvar_def["atoms"]`` is parsed by
            ``read_tinker_interval``. Presumably the result is an index array
            usable on ``coordinates`` -- TODO confirm its base (0/1).
        global_colvars: unused; accepted so all builders share one signature.

    Returns:
        A function ``coordinates -> value``. It takes the unweighted mean
        position of the selected atoms and returns the maximum Euclidean
        distance of any selected atom from it.
    """
    selection = read_tinker_interval(colvar_def["atoms"])

    def colvar_maxdcom(coordinates):
        group = coordinates[selection]
        # Unweighted centroid, kept 2-D so it broadcasts against ``group``.
        centroid = jnp.mean(group, axis=0, keepdims=True)
        distances = jnp.linalg.norm(group - centroid, axis=1)
        return jnp.max(distances)

    return colvar_maxdcom
def build_colvar_function(colvar_name, colvar_def, global_colvars=None):
    """Build a colvar defined as a user-supplied lambda of raw colvars.

    ``colvar_def["lambda"]`` holds the text of a Python lambda without the
    ``lambda`` keyword, e.g. ``"d1, d2: d1 - d2"``. Each argument name is
    looked up first among the keys of ``colvar_def`` itself, then in
    ``global_colvars``. It must resolve to a definition whose type is one of
    the raw colvar types in ``__RAW_COLVAR``. The lambda body may use names
    exported by ``jax.nn``, ``jax.numpy`` and ``jax``; on name clashes the
    later module wins. Python builtins are disabled.

    Args:
        colvar_name: name of this colvar, used in error messages.
        colvar_def: definition mapping containing ``"lambda"`` and,
            optionally, inline definitions of the referenced colvars.
        global_colvars: mapping of all colvar definitions, searched when an
            argument is not defined inline. ``None`` means empty.

    Returns:
        A function ``coordinates -> value`` that evaluates every referenced
        colvar and passes the results, in argument order, to the lambda.

    Raises:
        AssertionError: if ``"lambda"`` is not a string, or a referenced
            colvar has an unknown or non-raw type.
        ValueError: if an argument name cannot be resolved.
    """
    if global_colvars is None:
        global_colvars = {}
    func_def = colvar_def["lambda"]
    assert isinstance(
        func_def, str
    ), f"'lambda' field of colvar '{colvar_name}' must be a string (try using quotation marks)"
    # Argument names are everything before the first ':' of the lambda text.
    arguments = [s.strip() for s in func_def.split(":")[0].split(",")]
    _cvs = []
    for cv_name in arguments:
        # Inline definitions take precedence over the global colvars section.
        if cv_name in colvar_def:
            cv_def = colvar_def[cv_name]
        elif cv_name in global_colvars:
            cv_def = global_colvars[cv_name]
        else:
            raise ValueError(
                f"Colvar '{colvar_name}' references colvar '{cv_name}' which is not defined in the colvars section"
            )
        cv_type = str(cv_def.get("type", "distance")).lower()
        assert (
            cv_type in __RAW_COLVAR
        ), f"Unknown colvar type '{cv_type}' for colvar '{colvar_name}/{cv_name}'. Available colvars are {list(__RAW_COLVAR.keys())}"
        _cvs.append(__RAW_COLVAR[cv_type](cv_name, cv_def))

    # Evaluation namespace: public names of jax.nn, jax.numpy and jax, with
    # later modules overriding earlier ones. Dunder entries are skipped and
    # "__builtins__" is assigned LAST. Previously each spread module __dict__
    # carried its own "__builtins__", which silently overrode the intended
    # None and re-enabled all builtins.
    namespace = {}
    for module in (jax.nn, jax.numpy, jax):
        namespace.update(
            (name, obj)
            for name, obj in vars(module).items()
            if not name.startswith("__")
        )
    namespace["__builtins__"] = None
    # NOTE(security): disabling builtins is NOT a sandbox. Any object reachable
    # from the namespace can be used to escape it. Only evaluate colvar
    # definitions from trusted input files.
    func = eval("lambda " + func_def, namespace)

    def colvar_function(coordinates):
        return func(*(cv(coordinates) for cv in _cvs))

    return colvar_function
def setup_colvars(colvars_definitions):
    """Build every colvar described in ``colvars_definitions``.

    Args:
        colvars_definitions: mapping from colvar name to its definition
            mapping. The definition's ``"type"`` (case-insensitive, default
            ``"distance"``) selects the builder. The whole mapping is passed
            to each builder so "function" colvars can reference the others.

    Returns:
        A tuple ``(colvars, names)``. ``colvars`` maps each name to its
        ``coordinates -> value`` function; ``names`` lists the names in
        iteration order of ``colvars_definitions``.

    Raises:
        AssertionError: if a colvar has an unknown type.
    """
    built = {}
    for name, definition in colvars_definitions.items():
        kind = str(definition.get("type", "distance")).lower()
        assert (
            kind in __BUILD_COLVAR
        ), f"Unknown colvar type '{kind}' for colvar '{name}'. Available colvars are {list(__BUILD_COLVAR.keys())}"
        builder = __BUILD_COLVAR[kind]
        built[name] = builder(name, definition, colvars_definitions)

    return built, list(built)