# fennol.md.colvars
import jax
import jax.numpy as jnp
import numpy as np
from functools import partial

# Collective variables (colvars): each ``build_colvar_*`` factory validates a
# definition dict and returns a function ``coordinates -> scalar`` written with
# jax.numpy operations. Atom indices in definitions are 1-based.


def _atom_indices(colvar_name, colvar_def, n_atoms):
    """Return ``atom1``..``atom<n_atoms>`` of ``colvar_def`` as 0-based indices.

    Raises:
        AssertionError: if two indices coincide or a 1-based index is < 1.
    """
    keys = [f"atom{i}" for i in range(1, n_atoms + 1)]
    # Reproduces the existing wording, e.g. "atom1, atom2 and atom3".
    label = ", ".join(keys[:-1]) + " and " + keys[-1]
    atoms = tuple(colvar_def[key] - 1 for key in keys)
    # Every pair must differ: coincident atoms make the geometry degenerate.
    assert len(set(atoms)) == n_atoms, (
        f"{label} must be different for colvar {colvar_name}"
    )
    assert all(atom >= 0 for atom in atoms), (
        f"{label} must be > 0 for colvar {colvar_name}"
    )
    return atoms


def _angle_factor(colvar_def):
    """Return the output unit factor: 1 for radians, 180/pi for degrees."""
    return 1.0 if colvar_def.get("use_radians", False) else 180.0 / np.pi


def build_colvar_distance(colvar_name, colvar_def, global_colvars=None):
    """Build the distance colvar ``|r[atom1] - r[atom2]|``.

    Args:
        colvar_name: colvar name, used in error messages.
        colvar_def: dict with 1-based indices ``atom1`` and ``atom2``.
        global_colvars: unused; accepted for a uniform builder signature.

    Returns:
        A function mapping ``coordinates`` (indexed by atom) to the distance.
    """
    atom1, atom2 = _atom_indices(colvar_name, colvar_def, 2)

    def colvar_distance(coordinates):
        return jnp.linalg.norm(coordinates[atom1] - coordinates[atom2])

    return colvar_distance


def build_colvar_angle(colvar_name, colvar_def, global_colvars=None):
    """Build the bond-angle colvar atom1-atom2-atom3 (vertex at atom2).

    Args:
        colvar_name: colvar name, used in error messages.
        colvar_def: dict with 1-based ``atom1``, ``atom2``, ``atom3`` and an
            optional ``use_radians`` flag (default: degrees).
        global_colvars: unused; accepted for a uniform builder signature.

    Returns:
        A function mapping ``coordinates`` to the angle in [0, 180] degrees
        (or [0, pi] radians).
    """
    atom1, atom2, atom3 = _atom_indices(colvar_name, colvar_def, 3)
    fact = _angle_factor(colvar_def)

    def colvar_angle(coordinates):
        v1 = coordinates[atom1] - coordinates[atom2]
        v2 = coordinates[atom3] - coordinates[atom2]
        # atan2(|v1 x v2|, v1 . v2) cannot leave arccos' domain (rounding may
        # push a normalized dot product past +-1 -> NaN) and stays accurate
        # near 0 and 180 degrees.
        sin_part = jnp.linalg.norm(jnp.cross(v1, v2))
        cos_part = jnp.dot(v1, v2)
        return jnp.arctan2(sin_part, cos_part) * fact

    return colvar_angle


def build_colvar_dihedral(colvar_name, colvar_def, global_colvars=None):
    """Build the dihedral colvar atom1-atom2-atom3-atom4.

    By default the result is unsigned, in [0, 180] degrees: 0 for cis,
    180 for trans. With ``signed: true`` in ``colvar_def`` the result is
    signed, in [-180, 180]. It is positive when the atom2->atom1 bond turns
    clockwise onto the atom3->atom4 bond, viewed along atom2->atom3.

    Args:
        colvar_name: colvar name, used in error messages.
        colvar_def: dict with 1-based ``atom1``..``atom4`` and optional
            ``use_radians`` (default: degrees) and ``signed`` (default False).
        global_colvars: unused; accepted for a uniform builder signature.

    Returns:
        A function mapping ``coordinates`` (3D positions) to the dihedral.
    """
    atom1, atom2, atom3, atom4 = _atom_indices(colvar_name, colvar_def, 4)
    fact = _angle_factor(colvar_def)
    signed = bool(colvar_def.get("signed", False))

    def colvar_dihedral(coordinates):
        # Consecutive bond vectors along the chain. b1 must point atom1->atom2;
        # reversing it flips n1 and would yield 180 - phi.
        b1 = coordinates[atom2] - coordinates[atom1]
        b2 = coordinates[atom3] - coordinates[atom2]
        b3 = coordinates[atom4] - coordinates[atom3]
        n1 = jnp.cross(b1, b2)
        n2 = jnp.cross(b2, b3)
        x = jnp.dot(n1, n2)
        y = jnp.dot(jnp.cross(n1, n2), b2) / jnp.linalg.norm(b2)
        phi = jnp.arctan2(y, x)
        if not signed:
            phi = jnp.abs(phi)
        return phi * fact

    return colvar_dihedral


__RAW_COLVAR = {
    "distance": build_colvar_distance,
    "angle": build_colvar_angle,
    "dihedral": build_colvar_dihedral,
}


def build_colvar_function(colvar_name, colvar_def, global_colvars=None):
    """Build a colvar that applies a user lambda to other (raw) colvars.

    ``colvar_def["lambda"]`` is the text after the ``lambda`` keyword, for
    example ``"d1, d2: d1 - d2"``. Each argument name is resolved first in
    ``colvar_def`` (inline definition), then in ``global_colvars``. It must
    have a raw type (distance/angle/dihedral). A lambda without arguments
    (``": 1.0"``) is allowed.

    Raises:
        AssertionError: if ``lambda`` is not a string or a referenced colvar
            has an unknown type.
        ValueError: if an argument names no known colvar.
    """
    if global_colvars is None:
        global_colvars = {}
    func_def = colvar_def["lambda"]
    assert isinstance(
        func_def, str
    ), f"'lambda' field of colvar '{colvar_name}' must be a string (try using quotation marks)"
    # Argument names precede the first ':'; drop empties so ": expr" works.
    arguments = [s.strip() for s in func_def.split(":")[0].split(",")]
    arguments = [name for name in arguments if name]
    _cvs = []
    for cv_name in arguments:
        if cv_name in colvar_def:
            cv_def = colvar_def[cv_name]
        elif cv_name in global_colvars:
            cv_def = global_colvars[cv_name]
        else:
            raise ValueError(
                f"Colvar '{colvar_name}' references colvar '{cv_name}' which is not defined in the colvars section"
            )
        cv_type = str(cv_def.get("type", "distance")).lower()
        assert (
            cv_type in __RAW_COLVAR
        ), f"Unknown colvar type '{cv_type}' for colvar '{colvar_name}/{cv_name}'. Available colvars are {list(__RAW_COLVAR.keys())}"
        _cvs.append(__RAW_COLVAR[cv_type](cv_name, cv_def))

    # Later modules override earlier ones, as before. Dunder keys are skipped
    # because every module dict carries its own ``__builtins__``, which would
    # otherwise silently replace the ``None`` set below.
    namespace = {
        key: value
        for module in (jax.nn, jax.numpy, jax)
        for key, value in vars(module).items()
        if not key.startswith("__")
    }
    namespace["__builtins__"] = None
    # SECURITY: eval is not a sandbox even without builtins; only pass
    # colvar definitions from a trusted source.
    func = eval("lambda " + func_def, namespace)

    def colvar_function(coordinates):
        return func(*[cv(coordinates) for cv in _cvs])

    return colvar_function


__BUILD_COLVAR = {
    **__RAW_COLVAR,
    "function": build_colvar_function,
}


def setup_colvars(colvars_definitions):
    """Build all colvars in ``colvars_definitions`` (name -> definition dict).

    The ``type`` key (default ``"distance"``) selects the builder. The full
    definitions dict is passed along so function colvars can reference others.

    Returns:
        ``(colvars, names)``: dict name -> callable, and the names in order.
    """
    colvars = {}
    for colvar_name, colvar_def in colvars_definitions.items():
        colvar_type = str(colvar_def.get("type", "distance")).lower()
        assert (
            colvar_type in __BUILD_COLVAR
        ), f"Unknown colvar type '{colvar_type}' for colvar '{colvar_name}'. Available colvars are {list(__BUILD_COLVAR.keys())}"
        colvars[colvar_name] = __BUILD_COLVAR[colvar_type](
            colvar_name, colvar_def, colvars_definitions
        )

    return colvars, list(colvars.keys())
def build_colvar_distance(colvar_name, colvar_def, global_colvars=None):
    """Build the distance colvar ``|r[atom1] - r[atom2]|``.

    Args:
        colvar_name: colvar name, used in error messages.
        colvar_def: dict with 1-based indices ``atom1`` and ``atom2``.
        global_colvars: unused; accepted for a uniform builder signature.
            Defaults to None instead of a shared mutable ``{}``.

    Returns:
        A function mapping ``coordinates`` (indexed by atom) to the distance.

    Raises:
        AssertionError: if the atoms coincide or an index is < 1.
    """
    # Definitions use 1-based atom numbers; convert to 0-based indices.
    atom1 = colvar_def["atom1"] - 1
    atom2 = colvar_def["atom2"] - 1
    assert atom1 != atom2, f"atom1 and atom2 must be different for colvar {colvar_name}"
    assert (
        atom1 >= 0 and atom2 >= 0
    ), f"atom1 and atom2 must be > 0 for colvar {colvar_name}"

    def colvar_distance(coordinates):
        return jnp.linalg.norm(coordinates[atom1] - coordinates[atom2])

    return colvar_distance
def build_colvar_angle(colvar_name, colvar_def, global_colvars=None):
    """Build the bond-angle colvar atom1-atom2-atom3 (vertex at atom2).

    Args:
        colvar_name: colvar name, used in error messages.
        colvar_def: dict with 1-based ``atom1``, ``atom2``, ``atom3`` and an
            optional ``use_radians`` flag (default: degrees).
        global_colvars: unused; accepted for a uniform builder signature.

    Returns:
        A function mapping ``coordinates`` (3D positions) to the angle in
        [0, 180] degrees (or [0, pi] radians).

    Raises:
        AssertionError: if two atoms coincide or an index is < 1.
    """
    atom1 = colvar_def["atom1"] - 1
    atom2 = colvar_def["atom2"] - 1
    atom3 = colvar_def["atom3"] - 1
    assert (
        atom1 != atom2 and atom2 != atom3 and atom1 != atom3
    ), f"atom1, atom2 and atom3 must be different for colvar {colvar_name}"
    assert (
        atom1 >= 0 and atom2 >= 0 and atom3 >= 0
    ), f"atom1, atom2 and atom3 must be > 0 for colvar {colvar_name}"
    use_radians = colvar_def.get("use_radians", False)
    fact = 1.0 if use_radians else 180.0 / np.pi

    def colvar_angle(coordinates):
        v1 = coordinates[atom1] - coordinates[atom2]
        v2 = coordinates[atom3] - coordinates[atom2]
        # atan2(|v1 x v2|, v1 . v2) cannot leave arccos' domain (rounding may
        # push a normalized dot product past +-1 -> NaN) and stays accurate
        # near 0 and 180 degrees.
        sin_part = jnp.linalg.norm(jnp.cross(v1, v2))
        cos_part = jnp.dot(v1, v2)
        return jnp.arctan2(sin_part, cos_part) * fact

    return colvar_angle
def build_colvar_dihedral(colvar_name, colvar_def, global_colvars=None):
    """Build the dihedral colvar atom1-atom2-atom3-atom4.

    By default the result is unsigned, in [0, 180] degrees: 0 for cis,
    180 for trans. With ``signed: true`` in ``colvar_def`` the result is
    signed, in [-180, 180]. It is positive when the atom2->atom1 bond turns
    clockwise onto the atom3->atom4 bond, viewed along atom2->atom3.

    Args:
        colvar_name: colvar name, used in error messages.
        colvar_def: dict with 1-based ``atom1``..``atom4`` and optional
            ``use_radians`` (default: degrees) and ``signed`` (default False).
        global_colvars: unused; accepted for a uniform builder signature.

    Returns:
        A function mapping ``coordinates`` (3D positions) to the dihedral.

    Raises:
        AssertionError: if any two atoms coincide or an index is < 1.
    """
    atom1 = colvar_def["atom1"] - 1
    atom2 = colvar_def["atom2"] - 1
    atom3 = colvar_def["atom3"] - 1
    atom4 = colvar_def["atom4"] - 1
    # All six pairs must differ (atom1==atom3 or atom2==atom4 would make a
    # plane normal vanish and the result NaN).
    assert (
        len({atom1, atom2, atom3, atom4}) == 4
    ), f"atom1, atom2, atom3 and atom4 must be different for colvar {colvar_name}"
    assert (
        atom1 >= 0 and atom2 >= 0 and atom3 >= 0 and atom4 >= 0
    ), f"atom1, atom2, atom3 and atom4 must be > 0 for colvar {colvar_name}"
    use_radians = colvar_def.get("use_radians", False)
    fact = 1.0 if use_radians else 180.0 / np.pi
    signed = bool(colvar_def.get("signed", False))

    def colvar_dihedral(coordinates):
        # Consecutive bond vectors along the chain. b1 must point atom1->atom2;
        # reversing it flips n1 and would yield 180 - phi.
        b1 = coordinates[atom2] - coordinates[atom1]
        b2 = coordinates[atom3] - coordinates[atom2]
        b3 = coordinates[atom4] - coordinates[atom3]
        n1 = jnp.cross(b1, b2)
        n2 = jnp.cross(b2, b3)
        x = jnp.dot(n1, n2)
        y = jnp.dot(jnp.cross(n1, n2), b2) / jnp.linalg.norm(b2)
        phi = jnp.arctan2(y, x)
        if not signed:
            phi = jnp.abs(phi)
        return phi * fact

    return colvar_dihedral
def
build_colvar_function(colvar_name, colvar_def, global_colvars={}):
79def build_colvar_function(colvar_name, colvar_def, global_colvars={}): 80 func_def = colvar_def["lambda"] 81 assert isinstance( 82 func_def, str 83 ), f"'lambda' field of colvar '{colvar_name}' must be a string (try using quotation marks)" 84 arguments = [s.strip() for s in func_def.split(":")[0].split(",")] 85 _cvs = [] 86 for cv_name in arguments: 87 if cv_name in colvar_def: 88 cv_def = colvar_def[cv_name] 89 elif cv_name in global_colvars: 90 cv_def = global_colvars[cv_name] 91 else: 92 raise ValueError( 93 f"Colvar '{colvar_name}' references colvar '{cv_name}' which is not defined in the colvars section" 94 ) 95 cv_type = str(cv_def.get("type", "distance")).lower() 96 assert ( 97 cv_type in __RAW_COLVAR 98 ), f"Unknown colvar type '{cv_type}' for colvar '{colvar_name}/{cv_name}'. Available colvars are {list(__RAW_COLVAR.keys())}" 99 _cvs.append(__RAW_COLVAR[cv_type](cv_name, cv_def)) 100 101 func = eval( 102 "lambda " + func_def, 103 { 104 "__builtins__": None, 105 **jax.nn.__dict__, 106 **jax.numpy.__dict__, 107 **jax.__dict__, 108 }, 109 ) 110 111 def colvar_function(coordinates): 112 cv_values = [] 113 for cv in _cvs: 114 cv_values.append(cv(coordinates)) 115 return func(*cv_values) 116 117 return colvar_function
def setup_colvars(colvars_definitions):
    """Instantiate every colvar described in ``colvars_definitions``.

    Each entry maps a colvar name to its definition dict. The ``type`` key
    is read case-insensitively and defaults to ``"distance"``; it selects the
    builder in ``__BUILD_COLVAR``. Every builder also receives the complete
    definitions mapping so that function colvars can refer to other entries.

    Returns:
        A pair ``(built, names)``: the name -> callable dict and its keys as
        a list, in definition order.

    Raises:
        AssertionError: if an entry has an unknown ``type``.
    """
    built = {}
    for name, definition in colvars_definitions.items():
        kind = str(definition.get("type", "distance")).lower()
        assert (
            kind in __BUILD_COLVAR
        ), f"Unknown colvar type '{kind}' for colvar '{name}'. Available colvars are {list(__BUILD_COLVAR.keys())}"
        builder = __BUILD_COLVAR[kind]
        built[name] = builder(name, definition, colvars_definitions)
    return built, list(built)