# Module: fennol.md.colvars
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial

from ..utils import read_tinker_interval

# Collective variables ("colvars"). Each builder takes the colvar name, its
# definition mapping and (optionally) the whole colvars section, and returns a
# function of the coordinates array.
# NOTE(review): coordinates are presumably shaped (n_atoms, 3) -- the angle
# and dihedral rely on 3-D cross products; confirm against callers.


def build_colvar_distance(colvar_name, colvar_def, global_colvars=None):
    """Build a colvar measuring the distance between two atoms.

    Args:
        colvar_name: Name of the colvar, only used in error messages.
        colvar_def: Definition mapping; ``colvar_def["atoms"]`` must contain
            exactly two distinct 1-based atom indices (ints or int-like strings).
        global_colvars: Unused; accepted so every builder shares one signature.

    Returns:
        ``colvar_distance(coordinates)``: Euclidean norm of
        ``coordinates[atom1] - coordinates[atom2]``.

    Raises:
        AssertionError: if the atom list has the wrong length, repeats an atom,
            or contains an index < 1.
    """
    atoms = colvar_def["atoms"]
    assert len(atoms) == 2, f"Colvar {colvar_name} must have exactly 2 atoms defined"
    # Indices are 1-based in the input; convert to 0-based.
    atom1 = int(atoms[0]) - 1
    atom2 = int(atoms[1]) - 1
    assert atom1 != atom2, f"atom1 and atom2 must be different for colvar {colvar_name}"
    assert (
        atom1 >= 0 and atom2 >= 0
    ), f"atom1 and atom2 must be > 0 for colvar {colvar_name}"

    def colvar_distance(coordinates):
        return jnp.linalg.norm(coordinates[atom1] - coordinates[atom2])

    return colvar_distance


def build_colvar_angle(colvar_name, colvar_def, global_colvars=None):
    """Build a colvar measuring the angle atom1-atom2-atom3 (vertex at atom2).

    Args:
        colvar_name: Name of the colvar, only used in error messages.
        colvar_def: Definition mapping; ``colvar_def["atoms"]`` must contain
            exactly three distinct 1-based atom indices. Optional
            ``"use_radians"`` (default False) selects radians over degrees.
        global_colvars: Unused; accepted so every builder shares one signature.

    Returns:
        ``colvar_angle(coordinates)``: the angle in [0, 180] degrees
        (or [0, pi] radians).

    Raises:
        AssertionError: if the atom list is malformed.
    """
    atoms = colvar_def["atoms"]
    assert len(atoms) == 3, f"Colvar {colvar_name} must have exactly 3 atoms defined"
    atom1 = int(atoms[0]) - 1
    atom2 = int(atoms[1]) - 1
    atom3 = int(atoms[2]) - 1
    assert (
        atom1 != atom2 and atom2 != atom3 and atom1 != atom3
    ), f"atom1, atom2 and atom3 must be different for colvar {colvar_name}"
    assert (
        atom1 >= 0 and atom2 >= 0 and atom3 >= 0
    ), f"atom1, atom2 and atom3 must be > 0 for colvar {colvar_name}"
    use_radians = colvar_def.get("use_radians", False)
    fact = 1.0 if use_radians else 180.0 / np.pi

    def colvar_angle(coordinates):
        v1 = coordinates[atom1] - coordinates[atom2]
        v2 = coordinates[atom3] - coordinates[atom2]
        # atan2(|v1 x v2|, v1 . v2) equals arccos of the normalised dot
        # product, but cannot leave its domain when rounding pushes the cosine
        # slightly past +/-1 (near-linear angles), where arccos returned NaN.
        cos_term = jnp.dot(v1, v2)
        sin_term = jnp.linalg.norm(jnp.cross(v1, v2))
        return jnp.arctan2(sin_term, cos_term) * fact

    return colvar_angle


def build_colvar_dihedral(colvar_name, colvar_def, global_colvars=None):
    """Build a colvar measuring the signed dihedral atom1-atom2-atom3-atom4.

    Args:
        colvar_name: Name of the colvar, only used in error messages.
        colvar_def: Definition mapping; ``colvar_def["atoms"]`` must contain
            exactly four distinct 1-based atom indices. Optional
            ``"use_radians"`` (default False) selects radians over degrees.
        global_colvars: Unused; accepted so every builder shares one signature.

    Returns:
        ``colvar_dihedral(coordinates)``: the dihedral in (-180, 180] degrees
        (or (-pi, pi] radians).

    Raises:
        AssertionError: if the atom list is malformed.
    """
    atoms = colvar_def["atoms"]
    assert len(atoms) == 4, f"Colvar {colvar_name} must have exactly 4 atoms defined"
    atom1 = int(atoms[0]) - 1
    atom2 = int(atoms[1]) - 1
    atom3 = int(atoms[2]) - 1
    atom4 = int(atoms[3]) - 1
    # All four atoms must be pairwise distinct: e.g. atom1 == atom3 makes the
    # first plane degenerate (zero normal) and the result NaN.
    assert (
        len({atom1, atom2, atom3, atom4}) == 4
    ), f"atom1, atom2, atom3 and atom4 must be different for colvar {colvar_name}"
    assert (
        atom1 >= 0 and atom2 >= 0 and atom3 >= 0 and atom4 >= 0
    ), f"atom1, atom2, atom3 and atom4 must be > 0 for colvar {colvar_name}"
    use_radians = colvar_def.get("use_radians", False)
    fact = 1.0 if use_radians else 180.0 / np.pi

    def colvar_dihedral(coordinates):
        b1 = coordinates[atom2] - coordinates[atom1]
        b2 = coordinates[atom3] - coordinates[atom2]
        b3 = coordinates[atom4] - coordinates[atom3]
        n1 = jnp.cross(b1, b2)
        n2 = jnp.cross(b2, b3)
        # phi = atan2(|b2| * b1.(b2 x b3), (b1 x b2).(b2 x b3)).
        # Both arguments carry the same factor |n1|*|n2|, so atan2 recovers
        # phi exactly. Previously the cosine was divided by |n1|*|n2| but the
        # sine only by |n1|, which skewed phi whenever |b2 x b3| != 1.
        cos_term = jnp.dot(n1, n2)
        sin_term = jnp.linalg.norm(b2) * jnp.dot(b1, n2)
        return jnp.arctan2(sin_term, cos_term) * fact

    return colvar_dihedral


def build_colvar_maxdcom(colvar_name, colvar_def, global_colvars=None):
    """Build a colvar: largest distance of a group of atoms from their centroid.

    Args:
        colvar_name: Name of the colvar (unused here).
        colvar_def: Definition mapping; ``colvar_def["atoms"]`` is parsed by
            ``read_tinker_interval``.
            NOTE(review): presumably a Tinker-style list/interval of atoms
            yielding indices usable as ``coordinates[atoms]``; confirm.
        global_colvars: Unused; accepted so every builder shares one signature.

    Returns:
        ``colvar_maxdcom(coordinates)``: max over the selected atoms of the
        distance to the group's unweighted mean position.
    """
    atoms = read_tinker_interval(colvar_def["atoms"])

    def colvar_maxdcom(coordinates):
        sel_coord = coordinates[atoms]
        # Plain mean of positions (geometric center); no masses are used.
        com = jnp.mean(sel_coord, axis=0, keepdims=True)
        return jnp.max(jnp.linalg.norm(sel_coord - com, axis=1))

    return colvar_maxdcom


# Builders for "raw" colvars computed directly from coordinates; these are
# the only types a "function" colvar may reference.
__RAW_COLVAR = {
    "distance": build_colvar_distance,
    "angle": build_colvar_angle,
    "dihedral": build_colvar_dihedral,
    "maxdcom": build_colvar_maxdcom,
}


def build_colvar_function(colvar_name, colvar_def, global_colvars=None):
    """Build a colvar defined as an expression of other (raw) colvars.

    ``colvar_def["lambda"]`` is a string ``"a, b: expression"`` evaluated as
    ``lambda a, b: expression``. Each argument name is looked up first inside
    ``colvar_def`` and then in ``global_colvars``; the referenced definition
    must have a raw type (distance, angle, dihedral or maxdcom). The
    expression may use names from ``jax.nn``, ``jax.numpy`` and ``jax``
    (later modules take precedence) but no Python builtins.

    Args:
        colvar_name: Name of this colvar, used in error messages.
        colvar_def: Definition mapping holding "lambda" and optionally inline
            definitions of its arguments.
        global_colvars: The whole colvars section (mapping name -> definition);
            None is treated as empty.

    Returns:
        ``colvar_function(coordinates)``: evaluates each referenced colvar on
        ``coordinates`` and passes the values, in argument order, to the lambda.

    Raises:
        AssertionError: if "lambda" is not a string or a referenced colvar has
            an unsupported type.
        ValueError: if an argument name cannot be resolved.
    """
    if global_colvars is None:
        global_colvars = {}
    func_def = colvar_def["lambda"]
    assert isinstance(
        func_def, str
    ), f"'lambda' field of colvar '{colvar_name}' must be a string (try using quotation marks)"
    # Argument names are the comma-separated identifiers before the first ":".
    arguments = [s.strip() for s in func_def.split(":")[0].split(",")]
    _cvs = []
    for cv_name in arguments:
        # Inline definitions take precedence over the global colvars section.
        if cv_name in colvar_def:
            cv_def = colvar_def[cv_name]
        elif cv_name in global_colvars:
            cv_def = global_colvars[cv_name]
        else:
            raise ValueError(
                f"Colvar '{colvar_name}' references colvar '{cv_name}' which is not defined in the colvars section"
            )
        cv_type = str(cv_def.get("type", "distance")).lower()
        assert (
            cv_type in __RAW_COLVAR
        ), f"Unknown colvar type '{cv_type}' for colvar '{colvar_name}/{cv_name}'. Available colvars are {list(__RAW_COLVAR.keys())}"
        _cvs.append(__RAW_COLVAR[cv_type](cv_name, cv_def))

    # Every module __dict__ carries its own "__builtins__" entry, so the
    # restriction must come *after* the spread namespaces. Placed first, it
    # was silently overwritten and all builtins stayed reachable. An empty
    # dict makes a stray builtin reference fail with a clear NameError.
    # SECURITY NOTE: this is NOT a sandbox -- eval() executes the "lambda"
    # text from the input file. Only run trusted input files.
    namespace = {
        **jax.nn.__dict__,
        **jax.numpy.__dict__,
        **jax.__dict__,
        "__builtins__": {},
    }
    func = eval("lambda " + func_def, namespace)

    def colvar_function(coordinates):
        return func(*[cv(coordinates) for cv in _cvs])

    return colvar_function


# All colvar types accepted in the colvars section.
__BUILD_COLVAR = {
    **__RAW_COLVAR,
    "function": build_colvar_function,
}


def setup_colvars(colvars_definitions):
    """Build every colvar declared in ``colvars_definitions``.

    A missing "type" defaults to "distance"; types are matched
    case-insensitively. The full section is passed to each builder so that
    "function" colvars can reference other entries.

    Returns:
        ``(colvars, names)``: a dict mapping each name to its function, and
        the list of names in iteration order of ``colvars_definitions``.

    Raises:
        AssertionError: on an unknown colvar type.
    """
    colvars = {}
    for colvar_name, colvar_def in colvars_definitions.items():
        colvar_type = str(colvar_def.get("type", "distance")).lower()
        assert (
            colvar_type in __BUILD_COLVAR
        ), f"Unknown colvar type '{colvar_type}' for colvar '{colvar_name}'. Available colvars are {list(__BUILD_COLVAR.keys())}"
        colvars[colvar_name] = __BUILD_COLVAR[colvar_type](
            colvar_name, colvar_def, colvars_definitions
        )

    return colvars, list(colvars.keys())
def build_colvar_distance(colvar_name, colvar_def, global_colvars=None):
    """Build a colvar measuring the distance between two atoms.

    Args:
        colvar_name: Name of the colvar, only used in error messages.
        colvar_def: Definition mapping; ``colvar_def["atoms"]`` must contain
            exactly two distinct 1-based atom indices (ints or int-like strings).
        global_colvars: Unused; accepted so every builder shares one signature.

    Returns:
        ``colvar_distance(coordinates)``: Euclidean norm of
        ``coordinates[atom1] - coordinates[atom2]``.

    Raises:
        AssertionError: if the atom list has the wrong length, repeats an atom,
            or contains an index < 1.
    """
    atoms = colvar_def["atoms"]
    assert len(atoms) == 2, f"Colvar {colvar_name} must have exactly 2 atoms defined"
    # Indices are 1-based in the input; convert to 0-based.
    atom1 = int(atoms[0]) - 1
    atom2 = int(atoms[1]) - 1
    assert atom1 != atom2, f"atom1 and atom2 must be different for colvar {colvar_name}"
    assert (
        atom1 >= 0 and atom2 >= 0
    ), f"atom1 and atom2 must be > 0 for colvar {colvar_name}"

    def colvar_distance(coordinates):
        return jnp.linalg.norm(coordinates[atom1] - coordinates[atom2])

    return colvar_distance
def build_colvar_angle(colvar_name, colvar_def, global_colvars=None):
    """Build a colvar measuring the angle atom1-atom2-atom3 (vertex at atom2).

    Args:
        colvar_name: Name of the colvar, only used in error messages.
        colvar_def: Definition mapping; ``colvar_def["atoms"]`` must contain
            exactly three distinct 1-based atom indices. Optional
            ``"use_radians"`` (default False) selects radians over degrees.
        global_colvars: Unused; accepted so every builder shares one signature.

    Returns:
        ``colvar_angle(coordinates)``: the angle in [0, 180] degrees
        (or [0, pi] radians).

    Raises:
        AssertionError: if the atom list is malformed.
    """
    atoms = colvar_def["atoms"]
    assert len(atoms) == 3, f"Colvar {colvar_name} must have exactly 3 atoms defined"
    # Indices are 1-based in the input; convert to 0-based.
    atom1 = int(atoms[0]) - 1
    atom2 = int(atoms[1]) - 1
    atom3 = int(atoms[2]) - 1
    assert (
        atom1 != atom2 and atom2 != atom3 and atom1 != atom3
    ), f"atom1, atom2 and atom3 must be different for colvar {colvar_name}"
    assert (
        atom1 >= 0 and atom2 >= 0 and atom3 >= 0
    ), f"atom1, atom2 and atom3 must be > 0 for colvar {colvar_name}"
    use_radians = colvar_def.get("use_radians", False)
    fact = 1.0 if use_radians else 180.0 / np.pi

    def colvar_angle(coordinates):
        v1 = coordinates[atom1] - coordinates[atom2]
        v2 = coordinates[atom3] - coordinates[atom2]
        # atan2(|v1 x v2|, v1 . v2) equals arccos of the normalised dot
        # product, but cannot leave its domain when rounding pushes the cosine
        # slightly past +/-1 (near-linear angles), where arccos returned NaN.
        cos_term = jnp.dot(v1, v2)
        sin_term = jnp.linalg.norm(jnp.cross(v1, v2))
        return jnp.arctan2(sin_term, cos_term) * fact

    return colvar_angle
def build_colvar_dihedral(colvar_name, colvar_def, global_colvars=None):
    """Build a colvar measuring the signed dihedral atom1-atom2-atom3-atom4.

    Args:
        colvar_name: Name of the colvar, only used in error messages.
        colvar_def: Definition mapping; ``colvar_def["atoms"]`` must contain
            exactly four distinct 1-based atom indices. Optional
            ``"use_radians"`` (default False) selects radians over degrees.
        global_colvars: Unused; accepted so every builder shares one signature.

    Returns:
        ``colvar_dihedral(coordinates)``: the dihedral in (-180, 180] degrees
        (or (-pi, pi] radians).

    Raises:
        AssertionError: if the atom list is malformed.
    """
    atoms = colvar_def["atoms"]
    assert len(atoms) == 4, f"Colvar {colvar_name} must have exactly 4 atoms defined"
    # Indices are 1-based in the input; convert to 0-based.
    atom1 = int(atoms[0]) - 1
    atom2 = int(atoms[1]) - 1
    atom3 = int(atoms[2]) - 1
    atom4 = int(atoms[3]) - 1
    # All four atoms must be pairwise distinct: e.g. atom1 == atom3 makes the
    # first plane degenerate (zero normal) and the result NaN.
    assert (
        len({atom1, atom2, atom3, atom4}) == 4
    ), f"atom1, atom2, atom3 and atom4 must be different for colvar {colvar_name}"
    assert (
        atom1 >= 0 and atom2 >= 0 and atom3 >= 0 and atom4 >= 0
    ), f"atom1, atom2, atom3 and atom4 must be > 0 for colvar {colvar_name}"
    use_radians = colvar_def.get("use_radians", False)
    fact = 1.0 if use_radians else 180.0 / np.pi

    def colvar_dihedral(coordinates):
        b1 = coordinates[atom2] - coordinates[atom1]
        b2 = coordinates[atom3] - coordinates[atom2]
        b3 = coordinates[atom4] - coordinates[atom3]
        n1 = jnp.cross(b1, b2)
        n2 = jnp.cross(b2, b3)
        # phi = atan2(|b2| * b1.(b2 x b3), (b1 x b2).(b2 x b3)).
        # Both arguments carry the same factor |n1|*|n2|, so atan2 recovers
        # phi exactly. Previously the cosine was divided by |n1|*|n2| but the
        # sine only by |n1|, which skewed phi whenever |b2 x b3| != 1.
        cos_term = jnp.dot(n1, n2)
        sin_term = jnp.linalg.norm(b2) * jnp.dot(b1, n2)
        return jnp.arctan2(sin_term, cos_term) * fact

    return colvar_dihedral
def build_colvar_maxdcom(colvar_name, colvar_def, global_colvars=None):
    """Build a colvar: largest distance of a group of atoms from their centroid.

    Args:
        colvar_name: Name of the colvar (unused here).
        colvar_def: Definition mapping; ``colvar_def["atoms"]`` is parsed by
            ``read_tinker_interval``.
            NOTE(review): presumably a Tinker-style list/interval of atoms
            yielding indices usable as ``coordinates[atoms]``; confirm.
        global_colvars: Unused; accepted so every builder shares one signature.

    Returns:
        ``colvar_maxdcom(coordinates)``: max over the selected atoms of the
        distance to the group's unweighted mean position.
    """
    atoms = read_tinker_interval(colvar_def["atoms"])

    def colvar_maxdcom(coordinates):
        sel_coord = coordinates[atoms]
        # Plain mean of positions (geometric center); no masses are used.
        com = jnp.mean(sel_coord, axis=0, keepdims=True)
        return jnp.max(jnp.linalg.norm(sel_coord - com, axis=1))

    return colvar_maxdcom
def build_colvar_function(colvar_name, colvar_def, global_colvars=None):
    """Build a colvar defined as an expression of other (raw) colvars.

    ``colvar_def["lambda"]`` is a string ``"a, b: expression"`` evaluated as
    ``lambda a, b: expression``. Each argument name is looked up first inside
    ``colvar_def`` and then in ``global_colvars``; the referenced definition
    must have a raw type (distance, angle, dihedral or maxdcom). The
    expression may use names from ``jax.nn``, ``jax.numpy`` and ``jax``
    (later modules take precedence) but no Python builtins.

    Args:
        colvar_name: Name of this colvar, used in error messages.
        colvar_def: Definition mapping holding "lambda" and optionally inline
            definitions of its arguments.
        global_colvars: The whole colvars section (mapping name -> definition);
            None is treated as empty.

    Returns:
        ``colvar_function(coordinates)``: evaluates each referenced colvar on
        ``coordinates`` and passes the values, in argument order, to the lambda.

    Raises:
        AssertionError: if "lambda" is not a string or a referenced colvar has
            an unsupported type.
        ValueError: if an argument name cannot be resolved.
    """
    if global_colvars is None:
        global_colvars = {}
    func_def = colvar_def["lambda"]
    assert isinstance(
        func_def, str
    ), f"'lambda' field of colvar '{colvar_name}' must be a string (try using quotation marks)"
    # Argument names are the comma-separated identifiers before the first ":".
    arguments = [s.strip() for s in func_def.split(":")[0].split(",")]
    _cvs = []
    for cv_name in arguments:
        # Inline definitions take precedence over the global colvars section.
        if cv_name in colvar_def:
            cv_def = colvar_def[cv_name]
        elif cv_name in global_colvars:
            cv_def = global_colvars[cv_name]
        else:
            raise ValueError(
                f"Colvar '{colvar_name}' references colvar '{cv_name}' which is not defined in the colvars section"
            )
        cv_type = str(cv_def.get("type", "distance")).lower()
        assert (
            cv_type in __RAW_COLVAR
        ), f"Unknown colvar type '{cv_type}' for colvar '{colvar_name}/{cv_name}'. Available colvars are {list(__RAW_COLVAR.keys())}"
        _cvs.append(__RAW_COLVAR[cv_type](cv_name, cv_def))

    # Every module __dict__ carries its own "__builtins__" entry, so the
    # restriction must come *after* the spread namespaces. Placed first, it
    # was silently overwritten and all builtins stayed reachable. An empty
    # dict makes a stray builtin reference fail with a clear NameError.
    # SECURITY NOTE: this is NOT a sandbox -- eval() executes the "lambda"
    # text from the input file. Only run trusted input files.
    namespace = {
        **jax.nn.__dict__,
        **jax.numpy.__dict__,
        **jax.__dict__,
        "__builtins__": {},
    }
    func = eval("lambda " + func_def, namespace)

    def colvar_function(coordinates):
        return func(*[cv(coordinates) for cv in _cvs])

    return colvar_function
def setup_colvars(colvars_definitions):
    """Instantiate every collective variable listed in a colvars section.

    Entries without a "type" are treated as "distance"; the type string is
    matched case-insensitively. The complete section is forwarded to each
    builder so that "function" colvars can refer to sibling entries.

    Returns:
        A pair ``(built, names)``: ``built`` maps each colvar name to its
        function, ``names`` lists the names in iteration order of the input.

    Raises:
        AssertionError: when an entry declares an unknown type.
    """
    built = {}
    for name, definition in colvars_definitions.items():
        kind = str(definition.get("type", "distance")).lower()
        assert (
            kind in __BUILD_COLVAR
        ), f"Unknown colvar type '{kind}' for colvar '{name}'. Available colvars are {list(__BUILD_COLVAR.keys())}"
        builder = __BUILD_COLVAR[kind]
        built[name] = builder(name, definition, colvars_definitions)
    return built, list(built)