fennol.utils
from .spherical_harmonics import CG_SO3, generate_spherical_harmonics
from .atomic_units import AtomicUnits, UnitSystem
from typing import Dict, Any, Sequence, Union
import jax
import jax.numpy as jnp
import numpy as np
from ase.geometry.cell import cellpar_to_cell
import numba


def minmaxone(x, name=""):
    """Print ``name`` followed by the min, max and root-mean-square of ``x``."""
    print(name, x.min(), x.max(), (x**2).mean() ** 0.5)


def minmaxone_jax(x, name=""):
    """Same report as `minmaxone`, emitted through ``jax.debug.print``.

    The last value is the root-mean-square of ``x`` so that both helpers
    report the same three statistics.
    """
    jax.debug.print(
        "{name}  {min}  {max}  {mean}",
        name=name,
        min=x.min(),
        max=x.max(),
        mean=(x**2).mean() ** 0.5,
    )


def cell_vectors_to_lengths_angles(cell):
    """Convert cell vectors to ``[a, b, c, alpha, beta, gamma]``.

    Args:
        cell: array-like with 9 elements; it is reshaped to (3, 3) and each
            row is treated as one cell vector.

    Returns:
        1-D array of six values: the three row norms followed by three angles
        in degrees (alpha between rows 1 and 2, beta between rows 0 and 2,
        gamma between rows 0 and 1). Floating-point inputs keep their dtype;
        any other dtype is promoted to float64.
    """
    cell = np.asarray(cell)
    # Casting the result back to an integer dtype would truncate it.
    out_dtype = cell.dtype if np.issubdtype(cell.dtype, np.floating) else np.float64
    cell = cell.reshape(3, 3).astype(out_dtype, copy=False)

    a = np.linalg.norm(cell[0])
    b = np.linalg.norm(cell[1])
    c = np.linalg.norm(cell[2])
    degree = 180.0 / np.pi

    def _angle(u, v, norm_product):
        # Rounding can push the cosine of (anti)parallel vectors slightly
        # outside [-1, 1], which would make arccos return NaN.
        cosine = np.clip(np.dot(u, v) / norm_product, -1.0, 1.0)
        return np.arccos(cosine) * degree

    alpha = _angle(cell[1], cell[2], b * c)
    beta = _angle(cell[0], cell[2], a * c)
    gamma = _angle(cell[0], cell[1], a * b)
    return np.array([a, b, c, alpha, beta, gamma], dtype=out_dtype)


def cell_lengths_angles_to_vectors(lengths_angles, ab_normal=(0, 0, 1), a_direction=None):
    """Thin wrapper forwarding all arguments to ASE's ``cellpar_to_cell``."""
    return cellpar_to_cell(lengths_angles, ab_normal=ab_normal, a_direction=a_direction)


def parse_cell(cell):
    """Normalise a cell specification to cell vectors.

    Args:
        cell: ``None`` or an array-like with 1, 3, 6 or 9 elements.

    Returns:
        ``None`` if ``cell`` is ``None``; the input reshaped to (3, 3) when it
        has 9 elements; otherwise the result of
        `cell_lengths_angles_to_vectors` applied to the flattened values.

    Raises:
        ValueError: if the number of elements is not 1, 3, 6 or 9.
    """
    if cell is None:
        return None
    cell = np.asarray(cell, dtype=float).flatten()
    # Explicit raise: an assert would be stripped when running with -O.
    if cell.size not in (1, 3, 6, 9):
        raise ValueError("Cell must be of size 1, 3, 6 or 9")
    if cell.size == 9:
        return cell.reshape(3, 3)

    return cell_lengths_angles_to_vectors(cell)


def mask_filter_1d(mask, max_size, *values_fill):
    """Pack the entries selected by ``mask`` at the front of fixed-size arrays.

    Args:
        mask: 1-D boolean array selecting entries along the first axis.
        max_size: length of the first axis of every output array.
        *values_fill: ``(value, fill)`` pairs; ``value`` is scattered and
            ``fill`` is used for the output slots that receive nothing.

    Returns:
        ``(outputs, scatter_idx, count)`` where ``outputs`` is the list of
        packed arrays, ``scatter_idx`` gives each input entry's destination
        (``max_size`` for unselected entries, which are dropped by the
        scatter), and ``count`` is the number of selected entries
        (the Python int 0 for an empty mask).
    """
    running_count = jnp.cumsum(mask, dtype=jnp.int32)
    scatter_idx = jnp.where(mask, running_count - 1, max_size)

    outputs = []
    for value, fill in values_fill:
        target_shape = list(value.shape)
        target_shape[0] = max_size
        padded = jnp.full(target_shape, fill, dtype=value.dtype)
        outputs.append(padded.at[scatter_idx].set(value, mode="drop"))

    n_selected = running_count[-1] if running_count.size > 0 else 0
    return outputs, scatter_idx, n_selected


def deep_update(
    mapping: Dict[Any, Any], *updating_mappings: Dict[Any, Any]
) -> Dict[Any, Any]:
    """Return a copy of ``mapping`` recursively updated by the other mappings.

    When a key holds a dict on both sides the two dicts are merged
    recursively; in every other case the updating value replaces the existing
    one. ``mapping`` itself is not modified (its top level is copied).
    """
    merged = mapping.copy()
    for overrides in updating_mappings:
        for key, new_value in overrides.items():
            current = merged.get(key)
            if isinstance(current, dict) and isinstance(new_value, dict):
                merged[key] = deep_update(current, new_value)
            else:
                merged[key] = new_value
    return merged


class Counter:
    """Two-level counter: ``i`` cycles in ``[0, nseg)``, ``i_avg`` counts wraps."""

    def __init__(self, nseg, startsave=1):
        self.nseg = nseg  # value at which ``i`` wraps back to 0
        self.startsave = startsave  # offset subtracted in ``nsample``
        self.i = 0
        self.i_avg = 0

    @property
    def count(self):
        """Current position inside the segment."""
        return self.i

    @property
    def count_avg(self):
        """Number of wraps since the last reset of ``i_avg``."""
        return self.i_avg

    @property
    def nsample(self):
        """``count_avg - startsave + 1``, never smaller than 1."""
        return max(self.count_avg - self.startsave + 1, 1)

    @property
    def is_reset_step(self):
        """True when the in-segment counter is at 0."""
        return self.count == 0

    def reset_avg(self):
        """Reset only the wrap counter."""
        self.i_avg = 0

    def reset_all(self):
        """Reset both counters."""
        self.i = 0
        self.i_avg = 0

    def increment(self):
        """Advance ``i``; when it reaches ``nseg`` wrap to 0 and bump ``i_avg``."""
        self.i += 1
        if self.i < self.nseg:
            return
        self.i = 0
        self.i_avg += 1


### TOPOLOGY DETECTION
@numba.njit
def _detect_bonds_pbc(radii, coordinates, cell):
    """Find candidate bonds between all atom pairs using the minimum image.

    A pair (i, j), i < j, is kept when its distance is above 0.4 and below
    ``radii[i] + radii[j] + 0.4``. Returns three lists: first indices,
    second indices and distances.
    """
    # NOTE(review): the transposes assume cell vectors are stored as rows of
    # ``cell`` -- confirm against callers.
    reciprocal_cell = np.linalg.inv(cell).T
    cell = cell.T
    nat = len(radii)
    bond1 = []
    bond2 = []
    distances = []
    for i in range(nat):
        for j in range(i + 1, nat):
            vec = coordinates[i] - coordinates[j]
            # Wrap the separation into the nearest periodic image.
            vecpbc = reciprocal_cell @ vec
            vecpbc -= np.round(vecpbc)
            vec = cell @ vecpbc
            dist = np.linalg.norm(vec)
            if dist < radii[i] + radii[j] + 0.4 and dist > 0.4:
                bond1.append(i)
                bond2.append(j)
                distances.append(dist)
    return bond1, bond2, distances


@numba.njit
def _detect_bonds(radii, coordinates):
    """Non-periodic variant of `_detect_bonds_pbc` (same distance criterion)."""
    nat = len(radii)
    bond1 = []
    bond2 = []
    distances = []
    for i in range(nat):
        for j in range(i + 1, nat):
            vec = coordinates[i] - coordinates[j]
            dist = np.linalg.norm(vec)
            if dist < radii[i] + radii[j] + 0.4 and dist > 0.4:
                bond1.append(i)
                bond2.append(j)
                distances.append(dist)
    return bond1, bond2, distances


def detect_topology(species, coordinates, cell=None):
    """
    Detects the topology of a system based on species and coordinates.
    Returns a np.ndarray of shape [nbonds,2] containing the two indices for each bond.
    Inspired by OpenBabel's ConnectTheDots in mol.cpp

    Candidate bonds come from a distance criterion on per-species radii
    (minimum-image distances when ``cell`` is given). If some atom then has
    more bonds than its ``UFF_MAX_COORDINATION`` entry, bonds are pruned
    until no kept bond involves an over-coordinated atom.
    """
    from .periodic_table import COV_RADII, UFF_MAX_COORDINATION

    # NOTE(review): the BOHR factor presumably converts tabulated radii to the
    # unit of ``coordinates`` -- confirm.
    radii = (COV_RADII * AtomicUnits.BOHR)[species]
    max_coord = UFF_MAX_COORDINATION[species]

    if cell is not None:
        bond1, bond2, distances = _detect_bonds_pbc(radii, coordinates, cell)
    else:
        bond1, bond2, distances = _detect_bonds(radii, coordinates)

    bond1 = np.array(bond1, dtype=np.int32)
    bond2 = np.array(bond2, dtype=np.int32)
    bonds = np.stack((bond1, bond2), axis=1)

    # Number of candidate bonds per atom.
    coord = np.zeros(len(species), dtype=np.int32)
    np.add.at(coord, bonds[:, 0], 1)
    np.add.at(coord, bonds[:, 1], 1)

    if np.all(coord <= max_coord):
        return bonds

    # Visit bonds by increasing smaller radius of the pair and, within equal
    # radius, by decreasing distance relative to the sum of the two radii.
    distances = np.array(distances, dtype=np.float32)
    radiibonds = radii[bonds]
    req = radiibonds.sum(axis=1)
    rminbonds = radiibonds.min(axis=1)
    sorted_indices = np.lexsort((-distances / req, rminbonds))
    bonds = bonds[sorted_indices, :]

    true_bonds = []
    for i, j in bonds:
        if coord[i] <= max_coord[i] and coord[j] <= max_coord[j]:
            true_bonds.append((i, j))
        else:
            # Drop the bond and update both coordination numbers.
            coord[i] -= 1
            coord[j] -= 1

    # reshape keeps the [nbonds, 2] layout even if every bond was removed.
    true_bonds = np.array(true_bonds, dtype=np.int32).reshape(-1, 2)
    sorted_indices = np.lexsort((true_bonds[:, 1], true_bonds[:, 0]))
    true_bonds = true_bonds[sorted_indices, :]

    return true_bonds


def get_energy_gradient_function(
    energy_function,
    gradient_keys: Sequence[str],
    jit: bool = True,
):
    """Return a function that computes the gradient of the energy with respect to the keys in gradient_keys

    Args:
        energy_function: callable taking a data dict and returning
            ``(energy, out)``; ``energy`` is summed before differentiation and
            ``out`` must be a dict.
        gradient_keys: keys of the data dict to differentiate with respect to.
            The special key ``"strain"`` is created as one identity matrix per
            system (count taken from ``data["natoms"].shape[0]``) when absent;
            it right-multiplies the coordinates (via ``data["batch_index"]``)
            and the cells.
        jit: wrap the returned function in ``jax.jit``.

    Returns:
        A function ``data -> (de, out)`` where ``de`` maps each gradient key to
        its gradient and ``out`` is the energy function's auxiliary dict with
        extra ``"dEd_<key>"`` entries.
    """

    def energy_gradient(data):
        def _etot(inputs):
            if "strain" in inputs:
                scaling = inputs["strain"]
                batch_index = data["batch_index"]
                coordinates = (
                    inputs["coordinates"] if "coordinates" in inputs else data["coordinates"]
                )
                coordinates = jax.vmap(jnp.matmul)(coordinates, scaling[batch_index])
                inputs = {**inputs, "coordinates": coordinates}
                if "cells" in inputs or "cells" in data:
                    cells = inputs["cells"] if "cells" in inputs else data["cells"]
                    cells = jax.vmap(jnp.matmul)(cells, scaling)
                    inputs = {**inputs, "cells": cells}
            if "cells" in inputs:
                reciprocal_cells = jnp.linalg.inv(inputs["cells"])
                inputs = {**inputs, "reciprocal_cells": reciprocal_cells}
            # The energy function needs the whole system description, not only
            # the differentiated entries; the latter override the former.
            energy, out = energy_function({**data, **inputs})
            return energy.sum(), out

        if "strain" in gradient_keys and "strain" not in data:
            identity = np.eye(3)[None, :, :].repeat(data["natoms"].shape[0], axis=0)
            data = {**data, "strain": jnp.array(identity)}
        inputs = {k: data[k] for k in gradient_keys}
        # _etot has a single positional argument, so it is argument 0.
        de, out = jax.grad(_etot, argnums=0, has_aux=True)(inputs)

        return (
            de,
            {**out, **{"dEd_" + k: de[k] for k in gradient_keys}},
        )

    if jit:
        return jax.jit(energy_gradient)
    else:
        return energy_gradient


def read_tinker_interval(indices_interval: Sequence[Union[int, str]]) -> np.ndarray:
    """Expand a list of 1-based indices and ranges into sorted 0-based indices.

    A positive entry is a single index. A negative entry ``-s`` opens a range
    and must be followed by its end ``e`` (with ``e > s``); the pair stands
    for ``s, s+1, ..., e``.

    Args:
        indices_interval: integers or strings convertible with ``int``.

    Returns:
        Sorted, duplicate-free int32 array of zero-based indices.

    Raises:
        ValueError: on a zero entry, a range start without an end, or an end
            that is not greater than its start.
    """
    tokens = iter([int(i) for i in indices_interval])
    indices = []
    for i in tokens:
        if i > 0:
            indices.append(i)
        elif i < 0:
            start = -i
            end = next(tokens, None)
            if end is None:
                raise ValueError(
                    "Syntax error in ligand indices. Range start must be followed by an end index."
                )
            if end <= start:
                raise ValueError(
                    "Syntax error in ligand indices. End index must be greater than start index."
                )
            indices.extend(range(start, end + 1))
        else:
            raise ValueError("Syntax error in ligand indices. Indicing should be 1-based.")
    indices = np.unique(np.array(indices, dtype=np.int32))
    return indices - 1  # Convert to zero-based indexing
def
minmaxone(x, name=''):
def
minmaxone_jax(x, name=''):
def
cell_vectors_to_lengths_angles(cell):
def cell_vectors_to_lengths_angles(cell):
    """Convert cell vectors to ``[a, b, c, alpha, beta, gamma]``.

    Args:
        cell: array-like with 9 elements; it is reshaped to (3, 3) and each
            row is treated as one cell vector.

    Returns:
        1-D array of six values: the three row norms followed by three angles
        in degrees (alpha between rows 1 and 2, beta between rows 0 and 2,
        gamma between rows 0 and 1). Floating-point inputs keep their dtype;
        any other dtype is promoted to float64.
    """
    cell = np.asarray(cell)
    # Casting the result back to an integer dtype would truncate it.
    out_dtype = cell.dtype if np.issubdtype(cell.dtype, np.floating) else np.float64
    cell = cell.reshape(3, 3).astype(out_dtype, copy=False)

    a = np.linalg.norm(cell[0])
    b = np.linalg.norm(cell[1])
    c = np.linalg.norm(cell[2])
    degree = 180.0 / np.pi

    def _angle(u, v, norm_product):
        # Rounding can push the cosine of (anti)parallel vectors slightly
        # outside [-1, 1], which would make arccos return NaN.
        cosine = np.clip(np.dot(u, v) / norm_product, -1.0, 1.0)
        return np.arccos(cosine) * degree

    alpha = _angle(cell[1], cell[2], b * c)
    beta = _angle(cell[0], cell[2], a * c)
    gamma = _angle(cell[0], cell[1], a * b)
    return np.array([a, b, c, alpha, beta, gamma], dtype=out_dtype)
def
cell_lengths_angles_to_vectors(lengths_angles, ab_normal=(0, 0, 1), a_direction=None):
def
parse_cell(cell):
def
mask_filter_1d(mask, max_size, *values_fill):
def mask_filter_1d(mask, max_size, *values_fill):
    """Pack the entries selected by ``mask`` at the front of fixed-size arrays.

    Args:
        mask: 1-D boolean array selecting entries along the first axis.
        max_size: length of the first axis of every output array.
        *values_fill: ``(value, fill)`` pairs; ``value`` is scattered and
            ``fill`` is used for the output slots that receive nothing.

    Returns:
        ``(outputs, scatter_idx, count)`` where ``outputs`` is the list of
        packed arrays, ``scatter_idx`` gives each input entry's destination
        (``max_size`` for unselected entries, which are dropped by the
        scatter), and ``count`` is the number of selected entries
        (the Python int 0 for an empty mask).
    """
    running_count = jnp.cumsum(mask, dtype=jnp.int32)
    # Selected entries go to consecutive slots; the others are sent out of
    # bounds so that mode="drop" discards them.
    scatter_idx = jnp.where(mask, running_count - 1, max_size)

    def _pack(value, fill):
        target_shape = list(value.shape)
        target_shape[0] = max_size
        padded = jnp.full(target_shape, fill, dtype=value.dtype)
        return padded.at[scatter_idx].set(value, mode="drop")

    outputs = [_pack(value, fill) for value, fill in values_fill]
    n_selected = running_count[-1] if running_count.size != 0 else 0
    return outputs, scatter_idx, n_selected
def
deep_update( mapping: Dict[Any, Any], *updating_mappings: Dict[Any, Any]) -> Dict[Any, Any]:
def deep_update(
    mapping: Dict[Any, Any], *updating_mappings: Dict[Any, Any]
) -> Dict[Any, Any]:
    """Return a copy of ``mapping`` recursively updated by the other mappings.

    When a key holds a dict on both sides the two dicts are merged
    recursively; in every other case the updating value replaces the existing
    one. ``mapping`` itself is not modified (its top level is copied).
    """
    merged = mapping.copy()
    for overrides in updating_mappings:
        for key, new_value in overrides.items():
            current = merged.get(key)
            if isinstance(current, dict) and isinstance(new_value, dict):
                merged[key] = deep_update(current, new_value)
            else:
                merged[key] = new_value
    return merged
class
Counter:
class Counter:
    """Two-level counter: ``i`` cycles in ``[0, nseg)``, ``i_avg`` counts wraps."""

    def __init__(self, nseg, startsave=1):
        self.nseg = nseg  # value at which ``i`` wraps back to 0
        self.startsave = startsave  # offset subtracted in ``nsample``
        self.i = 0
        self.i_avg = 0

    @property
    def count(self):
        """Current position inside the segment."""
        return self.i

    @property
    def count_avg(self):
        """Number of wraps since the last reset of ``i_avg``."""
        return self.i_avg

    @property
    def nsample(self):
        """``count_avg - startsave + 1``, never smaller than 1."""
        return max(self.count_avg - self.startsave + 1, 1)

    @property
    def is_reset_step(self):
        """True when the in-segment counter is at 0."""
        return self.count == 0

    def reset_avg(self):
        """Reset only the wrap counter."""
        self.i_avg = 0

    def reset_all(self):
        """Reset both counters."""
        self.i = 0
        self.i_avg = 0

    def increment(self):
        """Advance ``i``; when it reaches ``nseg`` wrap to 0 and bump ``i_avg``."""
        self.i += 1
        if self.i < self.nseg:
            return
        self.i = 0
        self.i_avg += 1
def
detect_topology(species, coordinates, cell=None):
def detect_topology(species, coordinates, cell=None):
    """
    Detects the topology of a system based on species and coordinates.
    Returns a np.ndarray of shape [nbonds,2] containing the two indices for each bond.
    Inspired by OpenBabel's ConnectTheDots in mol.cpp

    Candidate bonds come from a distance criterion on per-species radii
    (the periodic helper is used when ``cell`` is given). If some atom then
    has more bonds than its ``UFF_MAX_COORDINATION`` entry, bonds are pruned
    until no kept bond involves an over-coordinated atom.
    """
    from .periodic_table import COV_RADII, UFF_MAX_COORDINATION

    # NOTE(review): the BOHR factor presumably converts tabulated radii to the
    # unit of ``coordinates`` -- confirm.
    radii = (COV_RADII * AtomicUnits.BOHR)[species]
    max_coord = UFF_MAX_COORDINATION[species]

    if cell is not None:
        bond1, bond2, distances = _detect_bonds_pbc(radii, coordinates, cell)
    else:
        bond1, bond2, distances = _detect_bonds(radii, coordinates)

    bond1 = np.array(bond1, dtype=np.int32)
    bond2 = np.array(bond2, dtype=np.int32)
    bonds = np.stack((bond1, bond2), axis=1)

    # Number of candidate bonds per atom.
    coord = np.zeros(len(species), dtype=np.int32)
    np.add.at(coord, bonds[:, 0], 1)
    np.add.at(coord, bonds[:, 1], 1)

    if np.all(coord <= max_coord):
        return bonds

    # Visit bonds by increasing smaller radius of the pair and, within equal
    # radius, by decreasing distance relative to the sum of the two radii.
    distances = np.array(distances, dtype=np.float32)
    radiibonds = radii[bonds]
    req = radiibonds.sum(axis=1)
    rminbonds = radiibonds.min(axis=1)
    sorted_indices = np.lexsort((-distances / req, rminbonds))
    bonds = bonds[sorted_indices, :]

    true_bonds = []
    for i, j in bonds:
        if coord[i] <= max_coord[i] and coord[j] <= max_coord[j]:
            true_bonds.append((i, j))
        else:
            # Drop the bond and update both coordination numbers.
            coord[i] -= 1
            coord[j] -= 1

    # reshape keeps the [nbonds, 2] layout even if every bond was removed;
    # without it the column indexing below fails on an empty 1-D array.
    true_bonds = np.array(true_bonds, dtype=np.int32).reshape(-1, 2)
    sorted_indices = np.lexsort((true_bonds[:, 1], true_bonds[:, 0]))
    true_bonds = true_bonds[sorted_indices, :]

    return true_bonds
Detects the topology of a system based on species and coordinates. Returns a np.ndarray of shape [nbonds,2] containing the two indices for each bond. Inspired by OpenBabel's ConnectTheDots in mol.cpp
def
get_energy_gradient_function(energy_function, gradient_keys: Sequence[str], jit: bool = True):
def get_energy_gradient_function(
    energy_function,
    gradient_keys: Sequence[str],
    jit: bool = True,
):
    """Return a function that computes the gradient of the energy with respect to the keys in gradient_keys

    Args:
        energy_function: callable taking a data dict and returning
            ``(energy, out)``; ``energy`` is summed before differentiation and
            ``out`` must be a dict.
        gradient_keys: keys of the data dict to differentiate with respect to.
            The special key ``"strain"`` is created as one identity matrix per
            system (count taken from ``data["natoms"].shape[0]``) when absent;
            it right-multiplies the coordinates (via ``data["batch_index"]``)
            and the cells.
        jit: wrap the returned function in ``jax.jit``.

    Returns:
        A function ``data -> (de, out)`` where ``de`` maps each gradient key to
        its gradient and ``out`` is the energy function's auxiliary dict with
        extra ``"dEd_<key>"`` entries.
    """

    def energy_gradient(data):
        def _etot(inputs):
            if "strain" in inputs:
                scaling = inputs["strain"]
                batch_index = data["batch_index"]
                coordinates = (
                    inputs["coordinates"] if "coordinates" in inputs else data["coordinates"]
                )
                coordinates = jax.vmap(jnp.matmul)(coordinates, scaling[batch_index])
                inputs = {**inputs, "coordinates": coordinates}
                if "cells" in inputs or "cells" in data:
                    cells = inputs["cells"] if "cells" in inputs else data["cells"]
                    cells = jax.vmap(jnp.matmul)(cells, scaling)
                    inputs = {**inputs, "cells": cells}
            if "cells" in inputs:
                reciprocal_cells = jnp.linalg.inv(inputs["cells"])
                inputs = {**inputs, "reciprocal_cells": reciprocal_cells}
            # The energy function needs the whole system description, not only
            # the differentiated entries; the latter override the former.
            energy, out = energy_function({**data, **inputs})
            return energy.sum(), out

        if "strain" in gradient_keys and "strain" not in data:
            identity = np.eye(3)[None, :, :].repeat(data["natoms"].shape[0], axis=0)
            data = {**data, "strain": jnp.array(identity)}
        inputs = {k: data[k] for k in gradient_keys}
        # _etot has a single positional argument, so it is argument 0
        # (argnums=1 made jax.grad raise on every call).
        de, out = jax.grad(_etot, argnums=0, has_aux=True)(inputs)

        return (
            de,
            {**out, **{"dEd_" + k: de[k] for k in gradient_keys}},
        )

    if jit:
        return jax.jit(energy_gradient)
    else:
        return energy_gradient
Return a function that computes the energy and the gradient of the energy with respect to the keys in gradient_keys
def
read_tinker_interval(indices_interval: Sequence[Union[int, str]]) -> numpy.ndarray:
def read_tinker_interval(indices_interval: Sequence[Union[int, str]]) -> np.ndarray:
    """Expand a list of 1-based indices and ranges into sorted 0-based indices.

    A positive entry is a single index. A negative entry ``-s`` opens a range
    and must be followed by its end ``e`` (with ``e > s``); the pair stands
    for ``s, s+1, ..., e``.

    Args:
        indices_interval: integers or strings convertible with ``int``.

    Returns:
        Sorted, duplicate-free int32 array of zero-based indices.

    Raises:
        ValueError: on a zero entry, a range start without an end, or an end
            that is not greater than its start.
    """
    # Convert everything up front so malformed tokens fail before any parsing.
    tokens = iter([int(i) for i in indices_interval])
    indices = []
    for i in tokens:
        if i > 0:
            indices.append(i)
        elif i < 0:
            start = -i
            end = next(tokens, None)
            if end is None:
                raise ValueError(
                    "Syntax error in ligand indices. Range start must be followed by an end index."
                )
            # Explicit raise: an assert would be stripped when running with -O.
            if end <= start:
                raise ValueError(
                    "Syntax error in ligand indices. End index must be greater than start index."
                )
            indices.extend(range(start, end + 1))
        else:
            raise ValueError("Syntax error in ligand indices. Indicing should be 1-based.")
    indices = np.unique(np.array(indices, dtype=np.int32))
    return indices - 1  # Convert to zero-based indexing