fennol.utils

from .spherical_harmonics import CG_SO3, generate_spherical_harmonics
from .atomic_units import AtomicUnits, UnitSystem
from typing import Dict, Any, Sequence, Union
import jax
import jax.numpy as jnp
import numpy as np
from ase.geometry.cell import cellpar_to_cell
import numba

def minmaxone(x, name=""):
    """Print name, min, max and root-mean-square of an array (debug helper)."""
    print(name, x.min(), x.max(), (x**2).mean() ** 0.5)

def minmaxone_jax(x, name=""):
    """JAX-friendly variant of minmaxone: prints name, min, max and mean of squares via jax.debug.print."""
    jax.debug.print(
        "{name}  {min}  {max}  {mean}",
        name=name,
        min=x.min(),
        max=x.max(),
        mean=(x**2).mean(),
    )

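For instance, with an illustrative array:

x = np.linspace(-1.0, 1.0, 5)
minmaxone(x, "x")                     # prints: x -1.0 1.0 <rms of x>
minmaxone_jax(jnp.asarray(x), "x")    # same idea, safe to call inside jitted code
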
def cell_vectors_to_lengths_angles(cell):
    """Convert a 3x3 cell matrix to cell parameters [a, b, c, alpha, beta, gamma] (angles in degrees)."""
    cell = cell.reshape(3, 3)
    a = np.linalg.norm(cell[0])
    b = np.linalg.norm(cell[1])
    c = np.linalg.norm(cell[2])
    degree = 180.0 / np.pi
    alpha = np.arccos(np.dot(cell[1], cell[2]) / (b * c))
    beta = np.arccos(np.dot(cell[0], cell[2]) / (a * c))
    gamma = np.arccos(np.dot(cell[0], cell[1]) / (a * b))
    return np.array(
        [a, b, c, alpha * degree, beta * degree, gamma * degree], dtype=cell.dtype
    )

def cell_lengths_angles_to_vectors(lengths_angles, ab_normal=(0, 0, 1), a_direction=None):
    """Convert cell parameters [a, b, c, alpha, beta, gamma] (degrees) to a 3x3 cell matrix using ASE's cellpar_to_cell."""
    return cellpar_to_cell(lengths_angles, ab_normal=ab_normal, a_direction=a_direction)

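A short round-trip sketch for the two conversions (values are illustrative):

cell = np.diag([10.0, 12.0, 14.0])
cell_vectors_to_lengths_angles(cell)
# -> array([10., 12., 14., 90., 90., 90.])
cell_lengths_angles_to_vectors([10.0, 12.0, 14.0, 90.0, 90.0, 90.0])
# -> a 3x3 cell matrix equivalent to the original, up to ASE's orientation convention
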
def parse_cell(cell):
    """Parse a cell given as a single length, three lengths, six cell parameters or a full matrix into a 3x3 matrix."""
    if cell is None:
        return None
    cell = np.asarray(cell, dtype=float).flatten()
    assert cell.size in [1, 3, 6, 9], "Cell must be of size 1, 3, 6 or 9"
    if cell.size == 9:
        return cell.reshape(3, 3)

    return cell_lengths_angles_to_vectors(cell)

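The accepted sizes correspond to a cubic length, three lengths, six cell parameters, or a full matrix; a quick sketch with illustrative values:

parse_cell(None)                                   # -> None
parse_cell(12.0)                                   # cubic box, a = b = c = 12
parse_cell([10.0, 12.0, 14.0])                     # orthorhombic box
parse_cell([10.0, 12.0, 14.0, 90.0, 90.0, 60.0])   # lengths + angles (degrees)
parse_cell(np.eye(3) * 10.0)                       # full matrix, returned as shape (3, 3)
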
def mask_filter_1d(mask, max_size, *values_fill):
    """Scatter masked entries of each array in values_fill into fixed-size padded arrays.

    Each element of values_fill is a (value, fill) pair; entries where mask is True are
    packed at the beginning of an array of leading dimension max_size and the remainder
    is filled with fill. Also returns the scatter indices and the number of True entries.
    """
    cumsum = jnp.cumsum(mask, dtype=jnp.int32)
    scatter_idx = jnp.where(mask, cumsum - 1, max_size)
    outputs = []
    for value, fill in values_fill:
        shape = list(value.shape)
        shape[0] = max_size
        output = (
            jnp.full(shape, fill, dtype=value.dtype)
            .at[scatter_idx]
            .set(value, mode="drop")
        )
        outputs.append(output)
    if cumsum.size == 0:
        return outputs, scatter_idx, 0
    return outputs, scatter_idx, cumsum[-1]

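A small sketch with made-up values, packing masked entries into a fixed-size padded array (useful under jit, where output shapes must be static):

mask = jnp.array([True, False, True, True, False])
values = jnp.arange(5, dtype=jnp.float32)
(packed,), scatter_idx, count = mask_filter_1d(mask, 4, (values, -1.0))
# packed == [0., 2., 3., -1.]   (masked entries first, padded with the fill value)
# count  == 3                   (number of True entries in the mask)
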

def deep_update(
    mapping: Dict[Any, Any], *updating_mappings: Dict[Any, Any]
) -> Dict[Any, Any]:
    """Recursively update a (possibly nested) dictionary with one or more other dictionaries."""
    updated_mapping = mapping.copy()
    for updating_mapping in updating_mappings:
        for k, v in updating_mapping.items():
            if (
                k in updated_mapping
                and isinstance(updated_mapping[k], dict)
                and isinstance(v, dict)
            ):
                updated_mapping[k] = deep_update(updated_mapping[k], v)
            else:
                updated_mapping[k] = v
    return updated_mapping

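For example, merging nested configuration dictionaries (illustrative keys):

base = {"model": {"cutoff": 5.0, "layers": 3}, "seed": 0}
override = {"model": {"layers": 4}}
deep_update(base, override)
# -> {"model": {"cutoff": 5.0, "layers": 4}, "seed": 0}
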

class Counter:
    """Two-level step counter: i counts steps within a segment of nseg steps and i_avg counts completed segments (used for running averages)."""

    def __init__(self, nseg, startsave=1):
        self.i = 0
        self.i_avg = 0
        self.nseg = nseg
        self.startsave = startsave

    @property
    def count(self):
        return self.i

    @property
    def count_avg(self):
        return self.i_avg

    @property
    def nsample(self):
        return max(self.count_avg - self.startsave + 1, 1)

    @property
    def is_reset_step(self):
        return self.count == 0

    def reset_avg(self):
        self.i_avg = 0

    def reset_all(self):
        self.i = 0
        self.i_avg = 0

    def increment(self):
        self.i = self.i + 1
        if self.i >= self.nseg:
            self.i = 0
            self.i_avg = self.i_avg + 1

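A hypothetical sketch of how such a counter can drive per-segment work in a long loop (compute_observable is a placeholder):

counter = Counter(nseg=100)
for step in range(1000):
    counter.increment()
    if counter.is_reset_step:
        # reached the end of a segment of nseg steps (count wrapped back to 0)
        value = compute_observable()   # placeholder for some per-segment quantity
        print(counter.count_avg, counter.nsample, value)
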
### TOPOLOGY DETECTION
@numba.njit
def _detect_bonds_pbc(radii, coordinates, cell):
    """Detect candidate bonds under periodic boundary conditions (minimum-image convention)."""
    reciprocal_cell = np.linalg.inv(cell).T
    cell = cell.T
    nat = len(radii)
    bond1 = []
    bond2 = []
    distances = []
    for i in range(nat):
        for j in range(i + 1, nat):
            vec = coordinates[i] - coordinates[j]
            # apply the minimum-image convention in fractional coordinates
            vecpbc = reciprocal_cell @ vec
            vecpbc -= np.round(vecpbc)
            vec = cell @ vecpbc
            dist = np.linalg.norm(vec)
            if dist < radii[i] + radii[j] + 0.4 and dist > 0.4:
                bond1.append(i)
                bond2.append(j)
                distances.append(dist)
    return bond1, bond2, distances

@numba.njit
def _detect_bonds(radii, coordinates):
    """Detect candidate bonds without periodic boundary conditions."""
    nat = len(radii)
    bond1 = []
    bond2 = []
    distances = []
    for i in range(nat):
        for j in range(i + 1, nat):
            vec = coordinates[i] - coordinates[j]
            dist = np.linalg.norm(vec)
            if dist < radii[i] + radii[j] + 0.4 and dist > 0.4:
                bond1.append(i)
                bond2.append(j)
                distances.append(dist)
    return bond1, bond2, distances

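Both helpers return plain Python lists of bond endpoints and distances; a toy call on a diatomic (illustrative radii and coordinates in Angstrom):

radii = np.array([0.71, 0.71])                            # e.g. two nitrogen-like atoms
coords = np.array([[0.0, 0.0, 0.0], [1.10, 0.0, 0.0]])
bond1, bond2, dists = _detect_bonds(radii, coords)
# bond1 == [0], bond2 == [1], dists[0] == 1.10   (1.10 < 0.71 + 0.71 + 0.4)
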
def detect_topology(species, coordinates, cell=None):
    """
    Detects the topology of a system based on species and coordinates.
    Returns a np.ndarray of shape [nbonds, 2] containing the two atom indices for each bond.
    Inspired by OpenBabel's ConnectTheDots in mol.cpp
    """
    from .periodic_table import COV_RADII, UFF_MAX_COORDINATION

    radii = (COV_RADII * AtomicUnits.BOHR)[species]
    max_coord = UFF_MAX_COORDINATION[species]

    if cell is not None:
        bond1, bond2, distances = _detect_bonds_pbc(radii, coordinates, cell)
    else:
        bond1, bond2, distances = _detect_bonds(radii, coordinates)

    bond1 = np.array(bond1, dtype=np.int32)
    bond2 = np.array(bond2, dtype=np.int32)
    bonds = np.stack((bond1, bond2), axis=1)

    # coordination number of each atom from the candidate bonds
    coord = np.zeros(len(species), dtype=np.int32)
    np.add.at(coord, bonds[:, 0], 1)
    np.add.at(coord, bonds[:, 1], 1)

    if np.all(coord <= max_coord):
        return bonds

    # some atoms exceed their maximum coordination: sort the candidate bonds so that
    # the least favorable ones are considered (and removed) first
    distances = np.array(distances, dtype=np.float32)
    radiibonds = radii[bonds]
    req = radiibonds.sum(axis=1)
    rminbonds = radiibonds.min(axis=1)
    sorted_indices = np.lexsort((-distances / req, rminbonds))

    bonds = bonds[sorted_indices, :]
    distances = distances[sorted_indices]

    true_bonds = []
    for ibond in range(bonds.shape[0]):
        i, j = bonds[ibond]
        ci, cj = coord[i], coord[j]
        mci, mcj = max_coord[i], max_coord[j]
        if ci <= mci and cj <= mcj:
            true_bonds.append((i, j))
        else:
            coord[i] -= 1
            coord[j] -= 1

    true_bonds = np.array(true_bonds, dtype=np.int32)
    sorted_indices = np.lexsort((true_bonds[:, 1], true_bonds[:, 0]))
    true_bonds = true_bonds[sorted_indices, :]

    return true_bonds

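A minimal sketch for a single water molecule, assuming species are given as atomic numbers and coordinates in Angstrom; the expected result is the two O-H bonds:

species = np.array([8, 1, 1])
coordinates = np.array([
    [ 0.000, 0.000, 0.0],
    [ 0.957, 0.000, 0.0],
    [-0.240, 0.927, 0.0],
])
detect_topology(species, coordinates)
# expected: array([[0, 1], [0, 2]], dtype=int32)
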
def get_energy_gradient_function(
    energy_function,
    gradient_keys: Sequence[str],
    jit: bool = True,
):
    """Return a function that computes the energy and the gradient of the energy with respect to the keys in gradient_keys"""

    def energy_gradient(data):
        def _etot(inputs):
            if "strain" in inputs:
                # apply the strain deformation to coordinates (and cells) before evaluating the energy
                scaling = inputs["strain"]
                batch_index = data["batch_index"]
                coordinates = (
                    inputs["coordinates"] if "coordinates" in inputs else data["coordinates"]
                )
                coordinates = jax.vmap(jnp.matmul)(coordinates, scaling[batch_index])
                inputs = {**inputs, "coordinates": coordinates}
                if "cells" in inputs or "cells" in data:
                    cells = inputs["cells"] if "cells" in inputs else data["cells"]
                    cells = jax.vmap(jnp.matmul)(cells, scaling)
                    inputs["cells"] = cells
            if "cells" in inputs:
                reciprocal_cells = jnp.linalg.inv(inputs["cells"])
                inputs = {**inputs, "reciprocal_cells": reciprocal_cells}
            energy, out = energy_function(inputs)
            return energy.sum(), out

        if "strain" in gradient_keys and "strain" not in data:
            # start from the identity strain (one 3x3 matrix per system in the batch)
            data = {
                **data,
                "strain": jnp.array(
                    np.eye(3)[None, :, :].repeat(data["natoms"].shape[0], axis=0)
                ),
            }
        inputs = {k: data[k] for k in gradient_keys}
        de, out = jax.grad(_etot, argnums=0, has_aux=True)(inputs)

        return (
            de,
            {**out, **{"dEd_" + k: de[k] for k in gradient_keys}},
        )

    if jit:
        return jax.jit(energy_gradient)
    else:
        return energy_gradient


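A minimal sketch with a hypothetical harmonic energy function (toy_energy is not part of the module), showing how gradients are returned under "dEd_" keys; forces are then the negative of dEd_coordinates:

def toy_energy(inputs):
    # hypothetical per-atom harmonic energy; returns (energies, outputs) as expected by energy_gradient
    e = 0.5 * jnp.sum(inputs["coordinates"] ** 2, axis=-1)
    return e, {"energies": e}

grad_fn = get_energy_gradient_function(toy_energy, gradient_keys=["coordinates"])
de, out = grad_fn({"coordinates": jnp.ones((4, 3))})
forces = -out["dEd_coordinates"]   # shape (4, 3)
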
def read_tinker_interval(indices_interval: Sequence[Union[int, str]]) -> np.ndarray:
    """Parse a Tinker-style atom selection (positive indices; a negative start followed by an end) into zero-based indices."""
    interval = [int(i) for i in indices_interval]
    indices = []
    while len(interval) > 0:
        i = interval.pop(0)
        if i > 0:
            indices.append(i)
        elif i < 0:
            # a negative number marks the start of a range; the next number is the (inclusive) end
            start = -i
            end = interval.pop(0)
            assert end > start, "Syntax error in ligand indices. End index must be greater than start index."
            indices.extend(range(start, end + 1))
        else:
            raise ValueError("Syntax error in ligand indices. Indexing should be 1-based.")
    indices = np.unique(np.array(indices, dtype=np.int32))
    return indices - 1  # Convert to zero-based indexing
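
For example, with the Tinker convention that a negative number opens a range closed by the next number:

read_tinker_interval(["1", "3", "-5", "8"])
# atoms 1, 3 and 5 through 8 (1-based) -> array([0, 2, 4, 5, 6, 7], dtype=int32)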