fennol.models.physics.electrostatics

  1import jax
  2import jax.numpy as jnp
  3import flax.linen as nn
  4import numpy as np
  5
  6# from jaxopt.linear_solve import solve_cg, solve_iterative_refinement, solve_gmres
  7from typing import Any, Dict, Union, Callable, Sequence, Optional, ClassVar
  8from ...utils import AtomicUnits as au
  9import dataclasses
 10from ...utils.periodic_table import (
 11    D3_ELECTRONEGATIVITIES,
 12    D3_HARDNESSES,
 13    D3_VDW_RADII,
 14    D3_COV_RADII,
 15    D3_KAPPA,
 16    VDW_RADII,
 17    POLARIZABILITIES,
 18    VALENCE_ELECTRONS,
 19)
 20import math
 21
 22
 23def prepare_reciprocal_space(
 24    cells, reciprocal_cells, coordinates, batch_index, k_points, bewald
 25):
 26    """Prepare variables for Ewald summation in reciprocal space"""
 27    A = reciprocal_cells
 28    if A.shape[0] == 1:
 29        s = coordinates @ A[0]
 30        ks = 2j * jnp.pi * jnp.einsum("ai,ki-> ak", s, k_points[0])  # nat x nk
 31    else:
 32        s = jnp.einsum("aj,aji->ai", coordinates,A[batch_index])
 33        ks = (
 34            2j * jnp.pi * jnp.einsum("ai,aki-> ak", s, k_points[batch_index])
 35        )  # nat x nk
 36
 37    m2 = jnp.sum(
 38        jnp.einsum("ski,sji->skj", k_points, A) ** 2,
 39        axis=-1,
 40    )  # nsys x nk
 41    a2 = (jnp.pi / bewald) ** 2
 42    expfac = jnp.exp(-a2 * m2) / m2  # nsys x nk
 43
 44    volume = jnp.abs(jnp.linalg.det(cells))  # nsys
 45    phiscale = (au.BOHR / jnp.pi) / volume
 46    selfscale = bewald * (2 * au.BOHR / jnp.pi**0.5)
 47    return batch_index, k_points, phiscale, selfscale, expfac, ks
 48
 49
 50def ewald_reciprocal(q, batch_index, k_points, phiscale, selfscale, expfac, ks):
 51    """Compute Coulomb interactions in reciprocal space using Ewald summation"""
 52    if phiscale.shape[0] == 1:
 53        Sm = (q[:, None] * jnp.exp(ks)).sum(axis=0)[None, :]  # nys x nk
 54    else:
 55        Sm = jax.ops.segment_sum(
 56            q[:, None] * jnp.exp(ks), batch_index, k_points.shape[0]
 57        )  # nsys x nk
 58
 59    ### compute reciprocal Coulomb potential (https://arxiv.org/abs/1805.10363)
 60    phi = (
 61        jnp.real(((Sm * expfac)[batch_index] * jnp.exp(-ks)).sum(axis=-1))
 62        * phiscale[batch_index]
 63        - q * selfscale
 64    )
 65
 66    return 0.5 * q * phi, phi
 67
 68
 69# def ewald_reciprocal(
 70#     q, cells, reciprocal_cells, coordinates, batch_index, k_points, bewald
 71# ):
 72#     A = reciprocal_cells
 73#     ### Ewald reciprocal space
 74
 75#     if A.shape[0] == 1:
 76#         s = jnp.einsum("ij,aj->ai", A[0], coordinates)
 77#         ks = 2j * jnp.pi * jnp.einsum("ai,ki-> ak", s, k_points[0])  # nat x nk
 78#         Sm = (q[:, None] * jnp.exp(ks)).sum(axis=0)[None, :]  # nsys x nk
 79#     else:
 80#         s = jnp.einsum("aij,aj->ai", A[batch_index], coordinates)
 81#         ks = (
 82#             2j * jnp.pi * jnp.einsum("ai,aki-> ak", s, k_points[batch_index])
 83#         )  # nat x nk
 84#         Sm = jax.ops.segment_sum(
 85#             q[:, None] * jnp.exp(ks), batch_index, k_points.shape[0]
 86#         )  # nsys x nk
 87
 88#     m2 = jnp.sum(
 89#         jnp.einsum("sij,ski->skj", A, k_points) ** 2,
 90#         axis=-1,
 91#     )  # nsys x nk
 92#     a2 = (jnp.pi / bewald) ** 2
 93#     expfac = Sm * jnp.exp(-a2 * m2) / m2  # nsys x nk
 94#     volume = jnp.linalg.det(cells)  # nsys
 95
 96#     ### compute reciprocal Coulomb potential (https://arxiv.org/abs/1805.10363)
 97#     phi = jnp.real((expfac[batch_index] * jnp.exp(-ks)).sum(axis=-1)) * (
 98#         (au.BOHR / jnp.pi) / volume[batch_index]
 99#     ) - q * (bewald * (2 * au.BOHR / jnp.pi**0.5))
100
101#     return 0.5 * q * phi, phi
102
103
class Coulomb(nn.Module):
    """Coulomb interaction between distributed point charges

    FID: COULOMB

    Computes a per-atom electrostatic energy from the charges found in the
    inputs and the pairs listed in the graph.  The bare 1/r kernel can be
    replaced by a Yukawa kernel (`bscreen > 0`), complemented by an Ewald
    reciprocal-space term (when the graph provides `k_points`) and damped at
    short range (`damp_style`).
    """

    _graphs_properties: Dict
    graph_key: str = "graph"
    """Key for the graph in the inputs"""
    charges_key: str = "charges"
    """Key for the charges in the inputs"""
    energy_key: Optional[str] = None
    """Key for the energy in the outputs"""
    # switch_fraction: float = 0.9
    scale: Optional[float] = None
    """Scaling factor for the energy"""
    charge_scale: Optional[float] = None
    """Scaling factor for the charges"""
    damp_style: Optional[str] = None
    """Damping style. Available options are: None, 'TS', 'OQDO', 'D3', 'SPOOKY', 'CP', 'KEY'"""
    damp_params: Dict = dataclasses.field(default_factory=dict)
    """Damping parameters"""
    bscreen: float = -1.0
    """Screening parameter. If >0, the Coulomb potential becomes a Yukawa potential and the reciprocal space is not computed"""
    trainable: bool = True
    """Whether the parameters are trainable"""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str] = "COULOMB"

    @nn.compact
    def __call__(self, inputs):
        """Compute the per-atom Coulomb energy.

        Reads `species`, the graph (`edge_src`, `edge_dst`, `distances`,
        `switch`, and optionally `k_points` / `b_ewald`) and the charges from
        `inputs`.  Returns a copy of `inputs` with the energy stored under
        `energy_key` (or the module name), plus `<energy_key>_reciprocal`
        when the Ewald reciprocal term is computed.

        Raises:
            NotImplementedError: for an unknown `damp_style` or `gamma_scheme`.
        """
        species = inputs["species"]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        distances = graph["distances"]
        switch = graph["switch"]
        nat = species.shape[0]

        def squeeze_last(x):
            # Drop a trailing singleton feature axis.  The ndim guard keeps a
            # 1-D array of length 1 (single atom) from collapsing to a scalar.
            if x.ndim > 1 and x.shape[-1] == 1:
                return jnp.squeeze(x, axis=-1)
            return x

        def get_ratiovol(default=None):
            # Optional per-atom volume ratio named by damp_params["ratiovol_key"].
            ratiovol_key = self.damp_params.get("ratiovol_key", None)
            if ratiovol_key is None:
                return default
            return squeeze_last(inputs[ratiovol_key])

        rij = distances / au.BOHR
        q = squeeze_last(inputs[self.charges_key])
        if self.charge_scale is not None:
            if self.trainable:
                # abs() keeps the learned scale positive.
                charge_scale = jnp.abs(
                    self.param(
                        "charge_scale", lambda key: jnp.asarray(self.charge_scale)
                    )
                )
            else:
                charge_scale = self.charge_scale
            q = q * charge_scale

        def charge_pair_energy(Aij):
            # Per-atom energy 0.5 * q_i * sum_j A_ij q_j for a pair kernel A_ij.
            return 0.5 * q * jax.ops.segment_sum(Aij * q[edge_dst], edge_src, nat)

        damp_style = self.damp_style.upper() if self.damp_style is not None else None

        # The reciprocal-space term is only meaningful for the unscreened kernel.
        do_recip = self.bscreen <= 0.0 and "k_points" in graph

        if self.bscreen > 0.0:
            # Yukawa screening of the direct-space kernel.
            # dirfact = jax.scipy.special.erfc(self.bscreen * distances)
            dirfact = jnp.exp(-self.bscreen * distances)
        elif do_recip:
            k_points = graph["k_points"]
            bewald = graph["b_ewald"]
            cells = inputs["cells"]
            reciprocal_cells = inputs["reciprocal_cells"]
            batch_index = inputs["batch_index"]
            erec, _ = ewald_reciprocal(
                q,
                *prepare_reciprocal_space(
                    cells,
                    reciprocal_cells,
                    inputs["coordinates"],
                    batch_index,
                    k_points,
                    bewald,
                ),
            )
            # Short-range complement of the Ewald split.
            dirfact = jax.scipy.special.erfc(bewald * distances)
        else:
            dirfact = 1.0

        if damp_style is None:
            eat = charge_pair_energy(switch * dirfact / rij)

        elif damp_style == "TS":
            cpB = self.damp_params.get("cpB", 3.5)
            s = self.damp_params.get("s", 2.4)

            if self.trainable:
                cpB = jnp.abs(self.param("cpB", lambda key: jnp.asarray(cpB)))
                s = jnp.abs(self.param("s", lambda key: jnp.asarray(s)))

            rvdw = jnp.asarray(VDW_RADII)[species]
            ratiovol = get_ratiovol()
            if ratiovol is not None:
                # Radii scale as the cube root of the volume ratio.
                rvdw = rvdw * ratiovol ** (1.0 / 3.0)
            Rij = rvdw[edge_src] + rvdw[edge_dst]
            Bij = cpB * (rij / Rij) ** s

            # Skip exp() where the damping term is negligible anyway.
            eBij = jnp.where(Bij < 20.0, jnp.exp(-Bij), 0.0)

            eat = charge_pair_energy((dirfact - eBij) / rij * switch)

        elif damp_style == "OQDO":
            alpha = jnp.asarray(POLARIZABILITIES)[species]
            ratiovol = get_ratiovol()
            if ratiovol is not None:
                alpha = alpha * ratiovol

            alphaij = 0.5 * (alpha[edge_src] + alpha[edge_dst])
            Re = (alphaij * (128.0 / au.FSC ** (4.0 / 3.0))) ** (1.0 / 7.0)
            Re2 = Re**2
            Re4 = Re**4
            muw = (
                3.66316787e01
                - 5.79579187 * Re
                + 3.02674813e-01 * Re2
                - 3.65461255e-04 * Re4
            ) / (-1.46169102e01 + 7.32461225 * Re)
            # muw = (
            #     4.83053463e-01
            #     - 3.76191669e-02 * Re
            #     + 1.27066988e-03 * Re2
            #     - 7.21940151e-07 * Re4
            # ) / (3.84212120e-02 - 3.16915319e-02 * Re + 2.37410890e-02 * Re2)
            Bij = 0.5 * muw * rij**2

            eBij = jnp.where(Bij < 20.0, jnp.exp(-Bij), 0.0)

            eat = charge_pair_energy((dirfact - eBij) / rij * switch)

        elif damp_style == "D3":
            ratiovol = get_ratiovol(default=1.0)

            gamma_scheme = self.damp_params.get("gamma_scheme", "D3")
            if gamma_scheme == "D3":
                if self.trainable:
                    rvdw = jnp.abs(
                        self.param("rvdw", lambda key: jnp.asarray(VDW_RADII))
                    )[species]
                else:
                    rvdw = jnp.asarray(VDW_RADII)[species]
                rvdw = rvdw * ratiovol ** (1.0 / 3.0)
                ai2 = rvdw**2
                # The small offset keeps the inverse square root finite.
                gamma_ij = (ai2[edge_src] + ai2[edge_dst] + 1.0e-3) ** (-0.5)

            elif gamma_scheme == "QDO":
                gscale = self.damp_params.get("gamma_scale", 2.0)
                if self.trainable:
                    gscale = jnp.abs(
                        self.param("gamma_scale", lambda key: jnp.asarray(gscale))
                    )
                    alpha = jnp.abs(
                        self.param("alpha", lambda key: jnp.asarray(POLARIZABILITIES))
                    )[species]
                else:
                    alpha = jnp.asarray(POLARIZABILITIES)[species]
                alpha = alpha * ratiovol
                alphaij = 0.5 * (alpha[edge_src] + alpha[edge_dst])
                rvdwij = (alphaij * (128.0 / au.FSC ** (4.0 / 3.0))) ** (1.0 / 7.0)
                gamma_ij = gscale / rvdwij
            else:
                raise NotImplementedError(
                    f"gamma_scheme {gamma_scheme} not implemented"
                )

            eat = charge_pair_energy(
                (dirfact - jax.scipy.special.erfc(gamma_ij * rij)) / rij * switch
            )

        elif damp_style == "SPOOKY":
            shortrange_cutoff = self.damp_params.get("shortrange_cutoff", 5.0)
            r_on = self.damp_params.get("r_on", 0.25) * shortrange_cutoff
            r_off = self.damp_params.get("r_off", 0.75) * shortrange_cutoff
            x1 = (distances - r_on) / (r_off - r_on)
            x2 = 1.0 - x1
            mask1 = x1 <= 1.0e-6
            mask2 = x2 <= 1.0e-6
            # Replace masked entries by a safe value before dividing, so that
            # neither branch of the where() below produces inf/nan.
            x1 = jnp.where(mask1, 1.0, x1)
            x2 = jnp.where(mask2, 1.0, x2)
            s1 = jnp.where(mask1, 0.0, jnp.exp(-1.0 / x1))
            s2 = jnp.where(mask2, 0.0, jnp.exp(-1.0 / x2))
            # Smooth switch: 1 below r_on, 0 above r_off.
            Bij = s2 / (s1 + s2)

            Aij = Bij / (rij**2 + 1) ** 0.5 + (dirfact - Bij) / rij * switch
            eat = charge_pair_energy(Aij)

        elif damp_style == "CP":
            cpA = self.damp_params.get("cpA", 4.42)
            cpB = self.damp_params.get("cpB", 4.12)
            gamma = self.damp_params.get("gamma", 0.5)

            if self.trainable:
                cpA = jnp.abs(self.param("cpA", lambda key: jnp.asarray(cpA)))
                cpB = jnp.abs(self.param("cpB", lambda key: jnp.asarray(cpB)))
                gamma = jnp.abs(self.param("gamma", lambda key: jnp.asarray(gamma)))

            rvdw = jnp.asarray(VDW_RADII)[species]
            ratiovol = get_ratiovol()
            if ratiovol is not None:
                rvdw = rvdw * ratiovol ** (1.0 / 3.0)

            Zv = jnp.asarray(VALENCE_ELECTRONS)[species]
            Zi, Zj = Zv[edge_src], Zv[edge_dst]
            qi, qj = q[edge_src], q[edge_dst]
            rvdwi, rvdwj = rvdw[edge_src], rvdw[edge_dst]

            eAi = jnp.exp(-cpA * rij / rvdwi)
            eAj = jnp.exp(-cpA * rij / rvdwj)
            eBi = jnp.exp(-cpB * rij / rvdwi)
            eBj = jnp.exp(-cpB * rij / rvdwj)
            eBij = eBi * eBj - eBi - eBj

            Bshort = jnp.exp(-gamma * distances**4)
            Dshort = 1.0 - Bshort
            ecp = Dshort * (
                Zi * Zj * (eAi + eAj + eBij)
                - qi * Zj * (eAi + eBij)
                - qj * Zi * (eAj + eBij)
            )

            # With dirfact == 1 this reduces to qi * qj * (1 + eBij) * (1 - Bshort).
            eqq = qi * qj * (dirfact - Bshort + eBij * Dshort)

            epair = (ecp + eqq) * switch / rij
            eat = 0.5 * jax.ops.segment_sum(epair, edge_src, nat)

        elif damp_style == "KEY":
            damp_key = self.damp_params["key"]
            damp = inputs[damp_key]
            # NOTE(review): unlike the other styles, the charges do not enter
            # this pair term -- confirm whether `damp` is expected to carry
            # them already or whether a q_i * q_j factor is missing.
            epair = (dirfact - damp) * switch / rij
            eat = 0.5 * jax.ops.segment_sum(epair, edge_src, nat)
        else:
            raise NotImplementedError(f"damp_style {self.damp_style} not implemented")

        if do_recip:
            eat = eat + erec

        if self.scale is not None:
            if self.trainable:
                scale = jnp.abs(
                    self.param("scale", lambda key: jnp.asarray(self.scale))
                )
            else:
                scale = self.scale
            eat = eat * scale

        energy_key = self.energy_key if self.energy_key is not None else self.name
        energy_unit = au.get_multiplier(self._energy_unit)
        out = {**inputs, energy_key: eat * energy_unit}
        if do_recip:
            # Note: the reciprocal part is reported without `scale` applied.
            out[energy_key + "_reciprocal"] = erec * energy_unit
        return out
409
410
class QeqD4(nn.Module):
    """QEq-D4 charge equilibration scheme

    FID: QEQ_D4

    If no charges are present in the inputs, they are obtained by solving the
    constrained linear charge-equilibration problem iteratively; otherwise the
    provided charges are used as they are and only the energy is evaluated.

    ### Reference
    E. Caldeweyher et al., A generally applicable atomic-charge dependent London dispersion correction,
    J Chem Phys. 2019 Apr 21;150(15):154122. (https://doi.org/10.1063/1.5090222)
    """

    graph_key: str = "graph"
    """Key for the graph in the inputs"""
    trainable: bool = False
    """Whether the parameters are trainable"""
    charges_key: str = "charges"
    """Key for the charges in the outputs. 
        If charges are provided in the inputs, they are not re-optimized and we only compute the energy"""
    energy_key: Optional[str] = None
    """Key for the energy in the outputs"""
    chi_key: Optional[str] = None
    """Key for additional electronegativity in the inputs"""
    c3_key: Optional[str] = None
    """Key for additional c3 in the inputs. Only used if charges are provided in the inputs"""
    c4_key: Optional[str] = None
    """Key for additional c4 in the inputs. Only used if charges are provided in the inputs"""
    total_charge_key: str = "total_charge"
    """Key for the total charge in the inputs"""
    non_interacting_guess: bool = False
    """Whether to use the non-interacting limit as an initial guess."""
    solver: str = "gmres"
    """Iterative linear solver used for the charges: 'gmres', 'bicg' or 'cg' (case-insensitive)"""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str] = "QEQ_D4"

    @nn.compact
    def __call__(self, inputs):
        """Equilibrate the charges (if needed) and compute the QEq energy.

        Returns a copy of `inputs` with the charges under `charges_key` and
        the per-atom energy under `energy_key` (or the module name).  With
        Ewald data in the graph, `<energy_key>_reciprocal` is added.  When
        charges were provided, the module is trainable and the "training"
        flag is set, `<energy_key>_regularization`, `<energy_key>_dedq` and
        `<energy_key>_etrain` are added as well.

        Raises:
            NotImplementedError: if `solver` is not one of the known solvers.
        """
        species = inputs["species"]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        switch = graph["switch"]
        training = "training" in inputs.get("flags", {})
        charges_provided = self.charges_key in inputs

        rij = graph["distances"] / au.BOHR

        do_recip = "k_points" in graph
        if do_recip:
            k_points = graph["k_points"]
            bewald = graph["b_ewald"]
            cells = inputs["cells"]
            reciprocal_cells = inputs["reciprocal_cells"]
            batch_index = inputs["batch_index"]
            dirfact = jax.scipy.special.erfc(bewald * graph["distances"])
            # Charge-independent Ewald quantities, reused at every solver step.
            ewald_params = prepare_reciprocal_space(
                cells,
                reciprocal_cells,
                inputs["coordinates"],
                batch_index,
                k_points,
                bewald,
            )
        else:
            dirfact = 1.0

        # Reference diagonal: hardness plus the self-interaction term that
        # depends on the atomic radius.
        Jii = D3_HARDNESSES
        ai = D3_VDW_RADII
        ETA = [jii + (2.0 / np.pi) ** 0.5 / aii for jii, aii in zip(Jii, ai)]

        # D3 parameters
        if self.trainable:
            ENi = self.param("EN", lambda key: jnp.asarray(D3_ELECTRONEGATIVITIES))[
                species
            ]
            # Jii = self.param("J", lambda key: jnp.asarray(D3_HARDNESSES))[species]
            eta = jnp.abs(self.param("eta", lambda key: jnp.asarray(ETA)))[species]
            ai = jnp.abs(self.param("a", lambda key: jnp.asarray(D3_VDW_RADII)))[
                species
            ]
            rci = jnp.abs(self.param("rc", lambda key: jnp.asarray(D3_COV_RADII)))[
                species
            ]
            c3 = self.param("c3", lambda key: jnp.zeros(len(D3_ELECTRONEGATIVITIES)))[
                species
            ]
            c4 = jnp.abs(
                self.param("c4", lambda key: jnp.zeros(len(D3_ELECTRONEGATIVITIES)))[
                    species
                ]
            )
            kappai = self.param("kappa", lambda key: jnp.asarray(D3_KAPPA))[species]
            k1 = jnp.abs(self.param("k1", lambda key: jnp.asarray(7.5)))
            if training:
                # Squared deviation of the learned parameters from their
                # reference values, per atom.
                regularization = (
                    (ENi - jnp.asarray(D3_ELECTRONEGATIVITIES)[species]) ** 2
                    + (eta - jnp.asarray(ETA)[species]) ** 2
                    + (ai - jnp.asarray(D3_VDW_RADII)[species]) ** 2
                    + (rci - jnp.asarray(D3_COV_RADII)[species]) ** 2
                    + (kappai - jnp.asarray(D3_KAPPA)[species]) ** 2
                    + (k1 - 7.5) ** 2
                )
        else:
            c3 = jnp.zeros_like(species, dtype=jnp.float32)
            c4 = jnp.zeros_like(species, dtype=jnp.float32)
            ENi = jnp.asarray(D3_ELECTRONEGATIVITIES)[species]
            # Jii = jnp.asarray(D3_HARDNESSES)[species]
            eta = jnp.asarray(ETA)[species]
            ai = jnp.asarray(D3_VDW_RADII)[species]
            rci = jnp.asarray(D3_COV_RADII)[species]
            kappai = jnp.asarray(D3_KAPPA)[species]
            k1 = 7.5

        ai2 = ai**2
        # Out-of-range edge indices read a harmless fill value instead of
        # being clamped to a real atom.
        rcij = (
            rci.at[edge_src].get(mode="fill", fill_value=1.0)
            + rci.at[edge_dst].get(mode="fill", fill_value=1.0)
            + 1.0e-3
        )
        # Smooth coordination number entering the effective electronegativity.
        mCNij = 1.0 + jax.scipy.special.erf(-k1 * (rij / rcij - 1))
        mCNi = 0.5 * jax.ops.segment_sum(mCNij * switch, edge_src, species.shape[0])
        chi = ENi - kappai * (mCNi + 1.0e-3) ** 0.5
        if self.chi_key is not None:
            chi = chi + inputs[self.chi_key]

        gamma_ij = (
            ai2.at[edge_src].get(mode="fill", fill_value=1.0)
            + ai2.at[edge_dst].get(mode="fill", fill_value=1.0)
            + 1.0e-3
        ) ** (-0.5)

        Aii = eta  # Jii + ((2.0 / np.pi) ** 0.5) / ai
        # Aij = jax.scipy.special.erf(gamma_ij * rij) / rij * switch
        Aij = (dirfact - jax.scipy.special.erfc(gamma_ij * rij)) / rij * switch

        if charges_provided:
            q = inputs[self.charges_key]
            q_ = q
        else:
            nsys = inputs["natoms"].shape[0]
            batch_index = inputs["batch_index"]

            def matvec(x):
                # Unknowns: one Lagrange multiplier per system followed by one
                # charge per atom.
                l, q = jnp.split(x, (nsys,))
                Aq_self = Aii * q
                qdest = q.at[edge_dst].get(mode="fill", fill_value=0.0)
                Aq_pair = jax.ops.segment_sum(Aij * qdest, edge_src, species.shape[0])
                Aq = (
                    Aq_self
                    + Aq_pair
                    + l.at[batch_index].get(mode="fill", fill_value=0.0)
                )
                if do_recip:
                    _, phirec = ewald_reciprocal(q, *ewald_params)
                    Aq = Aq + phirec
                # Constraint rows: total charge of each system.
                Al = jax.ops.segment_sum(q, batch_index, nsys)
                return jnp.concatenate((Al, Aq))

            Qtot = (
                inputs[self.total_charge_key].astype(chi.dtype).reshape(nsys)
                if self.total_charge_key in inputs
                else jnp.zeros(nsys, dtype=chi.dtype)
            )
            b = jnp.concatenate([Qtot, -chi])

            if self.non_interacting_guess:
                # Initial guess: solve the problem with the pair terms dropped.
                # Then Aii*q_i + l = -chi_i gives q_i = q0_i - si*l, and the
                # constraint sum_i q_i = Qtot gives l = (sum_i q0_i - Qtot) / sum_i si.
                si = 1.0 / Aii
                q0 = -chi * si
                qtot = jax.ops.segment_sum(q0, batch_index, nsys)
                sisum = jax.ops.segment_sum(si, batch_index, nsys)
                # Systems without atoms have sisum == 0: give them l = 0
                # rather than propagating 0/0 into the solver.
                has_atoms = sisum > 0
                l0 = jnp.where(
                    has_atoms,
                    (qtot - Qtot) / jnp.where(has_atoms, sisum, 1.0),
                    0.0,
                )
                q0 = q0 - si * l0[batch_index]
                x0 = jnp.concatenate((l0, q0))
            else:
                x0 = None

            solver = self.solver.lower()
            if solver == "bicg":
                x = jax.scipy.sparse.linalg.bicgstab(matvec, b, x0=x0)[0]
            elif solver == "gmres":
                x = jax.scipy.sparse.linalg.gmres(matvec, b, x0=x0)[0]
            elif solver == "cg":
                print("Warning: Use of cg solver for Qeq is not recommended")
                x = jax.scipy.sparse.linalg.cg(matvec, b, x0=x0)[0]
            else:
                raise NotImplementedError(f"solver '{solver}' is not implemented. Choose one of [bicg, gmres]")

            q = x[nsys:]
            # The energy below is evaluated with the charges held constant.
            q_ = jax.lax.stop_gradient(q)

        eself = 0.5 * Aii * q_**2 + chi * q_

        phi = jax.ops.segment_sum(Aij * q_[edge_dst], edge_src, species.shape[0])
        if do_recip:
            erec, _ = ewald_reciprocal(q_, *ewald_params)
        epair = 0.5 * q_ * phi

        if charges_provided:
            if self.c3_key is not None:
                c3 = c3 + inputs[self.c3_key]
            if self.c4_key is not None:
                c4 = c4 + inputs[self.c4_key]
            eself = eself + c3 * q_**3 + c4 * q_**4
            if self.trainable and training:
                # Auxiliary training outputs: the model parameters are frozen
                # here so that gradients only flow through the charges.
                Aii_ = jax.lax.stop_gradient(Aii)
                chi_ = jax.lax.stop_gradient(chi)
                c3_ = jax.lax.stop_gradient(c3)
                c4_ = jax.lax.stop_gradient(c4)
                switch_ = jax.lax.stop_gradient(switch)
                Aij_ = jax.lax.stop_gradient(Aij)
                phi_ = jax.ops.segment_sum(
                    Aij_ * q[edge_dst], edge_src, species.shape[0]
                )

                # Squared mismatch of dE/dq between neighbouring atoms.
                dedq = Aii_ * q + chi_ + phi_ + 3 * c3_ * q**2 + 4 * c4_ * q**3
                dedq = jax.ops.segment_sum(
                    switch_ * (dedq[edge_src] - dedq[edge_dst]) ** 2,
                    edge_src,
                    species.shape[0],
                )
                etrain = (
                    0.5 * Aii_ * q**2
                    + chi_ * q
                    + 0.5 * q * phi_
                    + c3_ * q**3
                    + c4_ * q**4
                )
                if do_recip:
                    etrain = etrain + erec

        energy = eself + epair
        if do_recip:
            energy = energy + erec

        energy_key = self.energy_key if self.energy_key is not None else self.name
        energy_unit = au.get_multiplier(self._energy_unit)
        output = {
            **inputs,
            self.charges_key: q,
            energy_key: energy * energy_unit,
        }
        if do_recip:
            output[energy_key + "_reciprocal"] = erec * energy_unit

        if charges_provided and self.trainable and training:
            output[energy_key + "_regularization"] = regularization
            output[energy_key + "_dedq"] = dedq * energy_unit
            output[energy_key + "_etrain"] = etrain * energy_unit
        return output
662
663
class ChargeCorrection(nn.Module):
    """Charge correction scheme

    FID: CHARGE_CORRECTION

    Used to correct the provided charges to sum to the total charge of the system.
    The charge excess or deficit of each system is spread over its atoms with
    weights inversely proportional to their hardness.
    """

    key: str = "charges"
    """Key for the charges in the inputs"""
    output_key: Optional[str] = None
    """Key for the corrected charges in the outputs. If None, it is the same as the input key"""
    dq_key: str = "delta_qtot"
    """Key for the deviation of the raw charge sum in the outputs"""
    ratioeta_key: Optional[str] = None
    """Key for the ratio of hardness between AIM and free atom in the inputs. Used to adjust charge redistribution."""
    trainable: bool = False
    """Whether the parameters are trainable"""
    cn_key: Optional[str] = None
    """Key for the coordination number in the inputs. Used to adjust charge redistribution."""
    total_charge_key: str = "total_charge"
    """Key for the total charge in the inputs"""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str] = "CHARGE_CORRECTION"

    @nn.compact
    def __call__(self, inputs) -> Any:
        """Redistribute the charge mismatch of each system over its atoms.

        Returns a copy of `inputs` with the corrected charges under
        `output_key` (or `key`), the per-system mismatch under `dq_key` and a
        per-atom quadratic penalty under "charge_correction_energy".
        """

        def squeeze_last(x):
            # Drop a trailing singleton feature axis.  The ndim guard keeps a
            # 1-D array of length 1 (single atom) from collapsing to a scalar.
            if x.ndim > 1 and x.shape[-1] == 1:
                return jnp.squeeze(x, axis=-1)
            return x

        species = inputs["species"]
        batch_index = inputs["batch_index"]
        nsys = inputs["natoms"].shape[0]
        q = squeeze_last(inputs[self.key])
        qtot = jax.ops.segment_sum(q, batch_index, nsys)
        # Systems are taken as neutral when no total charge is provided.
        Qtot = (
            inputs[self.total_charge_key].astype(q.dtype)
            if self.total_charge_key in inputs
            else jnp.zeros(qtot.shape[0], dtype=q.dtype)
        )
        dq = Qtot - qtot

        # Reference hardness: D3 hardness plus a radius-dependent self term.
        Jii = D3_HARDNESSES
        ai = D3_VDW_RADII
        eta_ref = [jii + (2.0 / np.pi) ** 0.5 / aii for jii, aii in zip(Jii, ai)]
        if self.trainable:
            eta = jnp.abs(self.param("eta", lambda key: jnp.asarray(eta_ref)))[species]
        else:
            eta = jnp.asarray(eta_ref)[species]

        if self.ratioeta_key is not None:
            eta = eta * squeeze_last(inputs[self.ratioeta_key])

        # Softness weights; the small offset avoids a division by zero.
        s = (1.0e-6 + 2 * jnp.abs(eta)) ** (-1)
        if self.cn_key is not None:
            s = s * squeeze_last(inputs[self.cn_key])

        # Per-system factor chosen so that the corrections sum to dq.
        f = dq / jax.ops.segment_sum(s, batch_index, nsys)

        qf = q + s * f[batch_index]

        energy_unit = au.get_multiplier(self._energy_unit)
        ecorr = (0.5 * energy_unit) * eta * (qf - q) ** 2
        output_key = self.output_key if self.output_key is not None else self.key
        return {
            **inputs,
            output_key: qf,
            self.dq_key: dq,
            "charge_correction_energy": ecorr,
        }
740
class DistributeElectrons(nn.Module):
    """Distribute valence electrons between the atoms

    FID: DISTRIBUTE_ELECTRONS

    Used to predict charges that sum to the total charge of the system.
    Each atom receives a share of its system's valence electrons proportional
    to a positive weight predicted from its embedding; the charge is the
    atom's own valence electron count minus that share.
    """

    embedding_key: str
    """Key for the embedding in the inputs that is used to predict an 'electron affinity' weight"""
    output_key: Union[str, None] = None
    """Key for the charges in the outputs"""
    total_charge_key: str = "total_charge"
    """Key for the total charge in the inputs"""

    FID: ClassVar[str] = "DISTRIBUTE_ELECTRONS"

    @nn.compact
    def __call__(self, inputs) -> Any:
        """Return a copy of `inputs` with the charges under `output_key` (or the module name)."""
        species = inputs["species"]
        batch_index = inputs["batch_index"]
        nsys = inputs["natoms"].shape[0]
        Nel = jnp.asarray(VALENCE_ELECTRONS)[species]

        # softplus keeps every weight strictly positive.
        ei = nn.Dense(1, use_bias=True, name="wi")(inputs[self.embedding_key]).squeeze(-1)
        wi = jax.nn.softplus(ei)
        wtot = jax.ops.segment_sum(wi, batch_index, nsys)

        # Systems are taken as neutral when no total charge is provided.
        Qtot = (
            inputs[self.total_charge_key].astype(ei.dtype)
            if self.total_charge_key in inputs
            else jnp.zeros(nsys, dtype=ei.dtype)
        )
        # Number of electrons to distribute in each system.
        Neltot = jax.ops.segment_sum(Nel, batch_index, nsys) - Qtot

        f = Neltot / wtot
        Ni = wi * f[batch_index]
        q = Nel - Ni

        output_key = self.output_key if self.output_key is not None else self.name
        return {
            **inputs,
            output_key: q,
        }
def prepare_reciprocal_space(cells, reciprocal_cells, coordinates, batch_index, k_points, bewald):
def prepare_reciprocal_space(
    cells, reciprocal_cells, coordinates, batch_index, k_points, bewald
):
    """Precompute the charge-independent terms of the reciprocal-space Ewald sum.

    The returned tuple `(batch_index, k_points, phiscale, selfscale, expfac, ks)`
    lists, in order, the arguments that `ewald_reciprocal` takes after the
    charges, so it can be unpacked directly into that call.

    `cells` is only used through its determinant; `reciprocal_cells` and
    `k_points` carry a leading per-system axis (a size of 1 selects the
    single-system code path).
    """
    recip = reciprocal_cells
    if recip.shape[0] != 1:
        # several systems: pick the reciprocal cell and k-points of each atom's system
        scaled = jnp.einsum("aj,aji->ai", coordinates, recip[batch_index])
        phase = jnp.einsum("ai,aki-> ak", scaled, k_points[batch_index])
    else:
        # single system: every atom shares the same reciprocal cell and k-points
        scaled = coordinates @ recip[0]
        phase = jnp.einsum("ai,ki-> ak", scaled, k_points[0])
    ks = 2j * jnp.pi * phase  # nat x nk

    # squared norm of the k-points contracted with the reciprocal cell
    kvec = jnp.einsum("ski,sji->skj", k_points, recip)
    m2 = jnp.sum(kvec**2, axis=-1)  # nsys x nk
    a2 = (jnp.pi / bewald) ** 2
    gaussian = jnp.exp(-a2 * m2)
    expfac = gaussian / m2  # nsys x nk

    volume = jnp.abs(jnp.linalg.det(cells))  # nsys
    phiscale = (au.BOHR / jnp.pi) / volume
    selfscale = bewald * (2 * au.BOHR / jnp.pi**0.5)
    return batch_index, k_points, phiscale, selfscale, expfac, ks

Prepare variables for Ewald summation in reciprocal space

def ewald_reciprocal(q, batch_index, k_points, phiscale, selfscale, expfac, ks):
def ewald_reciprocal(q, batch_index, k_points, phiscale, selfscale, expfac, ks):
    """Evaluate the reciprocal-space part of the Ewald sum for the charges `q`.

    All arguments after `q` are the values returned by
    `prepare_reciprocal_space`. Returns the tuple `(0.5 * q * phi, phi)`, where
    `phi` is the per-atom reciprocal-space potential with the self-interaction
    term `q * selfscale` already subtracted.
    """
    weighted_phases = q[:, None] * jnp.exp(ks)  # nat x nk

    # accumulate the charge-weighted phases of each system
    if phiscale.shape[0] != 1:
        structure_factor = jax.ops.segment_sum(
            weighted_phases, batch_index, k_points.shape[0]
        )  # nsys x nk
    else:
        structure_factor = weighted_phases.sum(axis=0)[None, :]  # nsys x nk

    ### reciprocal Coulomb potential (https://arxiv.org/abs/1805.10363)
    per_atom_terms = (structure_factor * expfac)[batch_index] * jnp.exp(-ks)
    recip_sum = jnp.real(per_atom_terms.sum(axis=-1))
    phi = recip_sum * phiscale[batch_index] - q * selfscale

    energy = 0.5 * q * phi
    return energy, phi

Compute Coulomb interactions in reciprocal space using Ewald summation

class Coulomb(flax.linen.module.Module):
class Coulomb(nn.Module):
    """Coulomb interaction between distributed point charges

    FID: COULOMB

    Computes a per-atom electrostatic energy from the pairs listed in the
    neighbor graph. The bare `1/r` kernel can be turned into a Yukawa kernel
    (`bscreen > 0`), completed by an Ewald reciprocal-space term (when the
    graph provides `k_points` and `bscreen <= 0`) and damped at short range
    according to `damp_style`.
    """

    _graphs_properties: Dict
    graph_key: str = "graph"
    """Key for the graph in the inputs"""
    charges_key: str = "charges"
    """Key for the charges in the inputs"""
    energy_key: Optional[str] = None
    """Key for the energy in the outputs. If None, the module name is used"""
    # switch_fraction: float = 0.9
    scale: Optional[float] = None
    """Scaling factor for the energy"""
    charge_scale: Optional[float] = None
    """Scaling factor for the charges"""
    damp_style: Optional[str] = None
    """Damping style. Available options are: None, 'TS', 'OQDO', 'D3', 'SPOOKY', 'CP', 'KEY'"""
    damp_params: Dict = dataclasses.field(default_factory=dict)
    """Damping parameters"""
    bscreen: float = -1.0
    """Screening parameter. If >0, the Coulomb potential becomes a Yukawa potential and the reciprocal space is not computed"""
    trainable: bool = True
    """Whether the parameters are trainable"""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str] = "COULOMB"

    @nn.compact
    def __call__(self, inputs):
        """Compute the per-atom Coulomb energy.

        Reads `species`, the graph (`edge_src`, `edge_dst`, `distances`,
        `switch` and optionally `k_points`/`b_ewald`) and the charges from
        `inputs`. Returns a copy of `inputs` with the energy stored under
        `energy_key` and, when the reciprocal space is computed, the
        reciprocal contribution under `energy_key + "_reciprocal"`.

        Raises:
            NotImplementedError: for an unknown `damp_style` or `gamma_scheme`.
        """
        species = inputs["species"]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        distances = graph["distances"]
        switch = graph["switch"]
        natoms = species.shape[0]

        # pair distances divided by au.BOHR, used by all the 1/r kernels below
        rij = distances / au.BOHR
        q = inputs[self.charges_key]
        # drop a trailing singleton feature axis, but never the atom axis itself
        # (a 1-D array holding a single atom must stay 1-D to be indexed by edges)
        if q.ndim > 1 and q.shape[-1] == 1:
            q = jnp.squeeze(q, axis=-1)
        if self.charge_scale is not None:
            if self.trainable:
                charge_scale = jnp.abs(
                    self.param(
                        "charge_scale", lambda key: jnp.asarray(self.charge_scale)
                    )
                )
            else:
                charge_scale = self.charge_scale
            q = q * charge_scale

        def atomic_energy(Aij):
            """Per-atom energy 0.5 * q_i * sum_j Aij * q_j for the pair kernel `Aij`."""
            return 0.5 * q * jax.ops.segment_sum(Aij * q[edge_dst], edge_src, natoms)

        def volume_ratio():
            """Per-atom ratio read from `damp_params['ratiovol_key']`, or None if unset."""
            ratiovol_key = self.damp_params.get("ratiovol_key", None)
            if ratiovol_key is None:
                return None
            ratiovol = inputs[ratiovol_key]
            if ratiovol.ndim > 1 and ratiovol.shape[-1] == 1:
                ratiovol = jnp.squeeze(ratiovol, axis=-1)
            return ratiovol

        damp_style = self.damp_style.upper() if self.damp_style is not None else None

        do_recip = self.bscreen <= 0.0 and "k_points" in graph

        # dirfact is the long-range factor multiplying the direct-space 1/r kernel
        if self.bscreen > 0.0:
            # dirfact = jax.scipy.special.erfc(self.bscreen * distances)
            dirfact = jnp.exp(-self.bscreen * distances)
        elif do_recip:
            k_points = graph["k_points"]
            bewald = graph["b_ewald"]
            cells = inputs["cells"]
            reciprocal_cells = inputs["reciprocal_cells"]
            batch_index = inputs["batch_index"]
            erec, _ = ewald_reciprocal(
                q,
                *prepare_reciprocal_space(
                    cells,
                    reciprocal_cells,
                    inputs["coordinates"],
                    batch_index,
                    k_points,
                    bewald,
                ),
            )
            dirfact = jax.scipy.special.erfc(bewald * distances)
        else:
            dirfact = 1.0

        if damp_style is None:
            Aij = switch * dirfact / rij
            eat = atomic_energy(Aij)

        elif damp_style == "TS":
            cpB = self.damp_params.get("cpB", 3.5)
            s = self.damp_params.get("s", 2.4)

            if self.trainable:
                cpB = jnp.abs(self.param("cpB", lambda key: jnp.asarray(cpB)))
                s = jnp.abs(self.param("s", lambda key: jnp.asarray(s)))

            rvdw = jnp.asarray(VDW_RADII)[species]
            ratiovol = volume_ratio()
            if ratiovol is not None:
                rvdw = rvdw * ratiovol ** (1.0 / 3.0)
            Rij = rvdw[edge_src] + rvdw[edge_dst]
            Bij = cpB * (rij / Rij) ** s

            # exp(-Bij) is treated as exactly zero once Bij reaches 20
            eBij = jnp.where(Bij < 20.0, jnp.exp(-Bij), 0.0)

            Aij = (dirfact - eBij) / rij * switch
            eat = atomic_energy(Aij)

        elif damp_style == "OQDO":
            alpha = jnp.asarray(POLARIZABILITIES)[species]
            ratiovol = volume_ratio()
            if ratiovol is not None:
                alpha = alpha * ratiovol

            alphaij = 0.5 * (alpha[edge_src] + alpha[edge_dst])
            Re = (alphaij * (128.0 / au.FSC ** (4.0 / 3.0))) ** (1.0 / 7.0)
            Re2 = Re**2
            Re4 = Re**4
            muw = (
                3.66316787e01
                - 5.79579187 * Re
                + 3.02674813e-01 * Re2
                - 3.65461255e-04 * Re4
            ) / (-1.46169102e01 + 7.32461225 * Re)
            # alternative fit kept for reference:
            # muw = (
            #     4.83053463e-01
            #     - 3.76191669e-02 * Re
            #     + 1.27066988e-03 * Re2
            #     - 7.21940151e-07 * Re4
            # ) / (3.84212120e-02 - 3.16915319e-02 * Re + 2.37410890e-02 * Re2)
            Bij = 0.5 * muw * rij**2

            eBij = jnp.where(Bij < 20.0, jnp.exp(-Bij), 0.0)

            Aij = (dirfact - eBij) / rij * switch
            eat = atomic_energy(Aij)

        elif damp_style == "D3":
            ratiovol = volume_ratio()
            if ratiovol is None:
                ratiovol = 1.0

            gamma_scheme = self.damp_params.get("gamma_scheme", "D3")
            if gamma_scheme == "D3":
                if self.trainable:
                    rvdw = jnp.abs(
                        self.param("rvdw", lambda key: jnp.asarray(VDW_RADII))
                    )[species]
                else:
                    rvdw = jnp.asarray(VDW_RADII)[species]
                rvdw = rvdw * ratiovol ** (1.0 / 3.0)
                ai2 = rvdw**2
                # the small offset keeps the inverse square root finite
                gamma_ij = (ai2[edge_src] + ai2[edge_dst] + 1.0e-3) ** (-0.5)

            elif gamma_scheme == "QDO":
                gscale = self.damp_params.get("gamma_scale", 2.0)
                if self.trainable:
                    gscale = jnp.abs(
                        self.param("gamma_scale", lambda key: jnp.asarray(gscale))
                    )
                    alpha = jnp.abs(
                        self.param("alpha", lambda key: jnp.asarray(POLARIZABILITIES))
                    )[species]
                else:
                    alpha = jnp.asarray(POLARIZABILITIES)[species]
                alpha = alpha * ratiovol
                alphaij = 0.5 * (alpha[edge_src] + alpha[edge_dst])
                rvdwij = (alphaij * (128.0 / au.FSC ** (4.0 / 3.0))) ** (1.0 / 7.0)
                gamma_ij = gscale / rvdwij
            else:
                raise NotImplementedError(
                    f"gamma_scheme {gamma_scheme} not implemented"
                )

            Aij = (dirfact - jax.scipy.special.erfc(gamma_ij * rij)) / rij * switch
            eat = atomic_energy(Aij)

        elif damp_style == "SPOOKY":
            shortrange_cutoff = self.damp_params.get("shortrange_cutoff", 5.0)
            r_on = self.damp_params.get("r_on", 0.25) * shortrange_cutoff
            r_off = self.damp_params.get("r_off", 0.75) * shortrange_cutoff
            x1 = (distances - r_on) / (r_off - r_on)
            x2 = 1.0 - x1
            mask1 = x1 <= 1.0e-6
            mask2 = x2 <= 1.0e-6
            # replace masked entries by 1 before dividing so exp(-1/x) stays finite
            x1 = jnp.where(mask1, 1.0, x1)
            x2 = jnp.where(mask2, 1.0, x2)
            s1 = jnp.where(mask1, 0.0, jnp.exp(-1.0 / x1))
            s2 = jnp.where(mask2, 0.0, jnp.exp(-1.0 / x2))
            # Bij is 1 below r_on and 0 above r_off
            Bij = s2 / (s1 + s2)

            Aij = Bij / (rij**2 + 1) ** 0.5 + (dirfact - Bij) / rij * switch
            eat = atomic_energy(Aij)

        elif damp_style == "CP":
            cpA = self.damp_params.get("cpA", 4.42)
            cpB = self.damp_params.get("cpB", 4.12)
            gamma = self.damp_params.get("gamma", 0.5)

            if self.trainable:
                cpA = jnp.abs(self.param("cpA", lambda key: jnp.asarray(cpA)))
                cpB = jnp.abs(self.param("cpB", lambda key: jnp.asarray(cpB)))
                gamma = jnp.abs(self.param("gamma", lambda key: jnp.asarray(gamma)))

            rvdw = jnp.asarray(VDW_RADII)[species]
            ratiovol = volume_ratio()
            if ratiovol is not None:
                rvdw = rvdw * ratiovol ** (1.0 / 3.0)

            Zv = jnp.asarray(VALENCE_ELECTRONS)[species]
            Zi, Zj = Zv[edge_src], Zv[edge_dst]
            qi, qj = q[edge_src], q[edge_dst]
            rvdwi, rvdwj = rvdw[edge_src], rvdw[edge_dst]

            eAi = jnp.exp(-cpA * rij / rvdwi)
            eAj = jnp.exp(-cpA * rij / rvdwj)
            eBi = jnp.exp(-cpB * rij / rvdwi)
            eBj = jnp.exp(-cpB * rij / rvdwj)
            eBij = eBi * eBj - eBi - eBj

            Bshort = jnp.exp(-gamma * distances**4)
            Dshort = 1.0 - Bshort
            ecp = Dshort * (
                Zi * Zj * (eAi + eAj + eBij)
                - qi * Zj * (eAi + eBij)
                - qj * Zi * (eAj + eBij)
            )

            # equals qi * qj * (1 + eBij) * (1 - Bshort) when dirfact is 1
            eqq = qi * qj * (dirfact - Bshort + eBij * Dshort)

            epair = (ecp + eqq) * switch / rij
            eat = 0.5 * jax.ops.segment_sum(epair, edge_src, natoms)

        elif damp_style == "KEY":
            # the damping values are read from the inputs (indexed like the edges)
            damp_key = self.damp_params["key"]
            damp = inputs[damp_key]
            # the pair kernel must be weighted by both charges, as in every
            # other damping style; otherwise the energy ignores the charges
            Aij = (dirfact - damp) * switch / rij
            eat = atomic_energy(Aij)
        else:
            raise NotImplementedError(f"damp_style {self.damp_style} not implemented")

        if do_recip:
            eat = eat + erec

        if self.scale is not None:
            if self.trainable:
                scale = jnp.abs(
                    self.param("scale", lambda key: jnp.asarray(self.scale))
                )
            else:
                scale = self.scale
            eat = eat * scale

        energy_key = self.energy_key if self.energy_key is not None else self.name
        energy_unit = au.get_multiplier(self._energy_unit)
        out = {**inputs, energy_key: eat * energy_unit}
        if do_recip:
            out[energy_key + "_reciprocal"] = erec * energy_unit
        return out

Coulomb interaction between distributed point charges

FID: COULOMB

Coulomb( _graphs_properties: Dict, graph_key: str = 'graph', charges_key: str = 'charges', energy_key: Optional[str] = None, scale: Optional[float] = None, charge_scale: Optional[float] = None, damp_style: Optional[str] = None, damp_params: Dict = <factory>, bscreen: float = -1.0, trainable: bool = True, _energy_unit: str = 'Ha', parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
graph_key: str = 'graph'

Key for the graph in the inputs

charges_key: str = 'charges'

Key for the charges in the inputs

energy_key: Optional[str] = None

Key for the energy in the outputs

scale: Optional[float] = None

Scaling factor for the energy

charge_scale: Optional[float] = None

Scaling factor for the charges

damp_style: Optional[str] = None

Damping style. Available options are: None, 'TS', 'OQDO', 'D3', 'SPOOKY', 'CP', 'KEY'

damp_params: Dict

Damping parameters

bscreen: float = -1.0

Screening parameter. If >0, the Coulomb potential becomes a Yukawa potential and the reciprocal space is not computed

trainable: bool = True

Whether the parameters are trainable

FID: ClassVar[str] = 'COULOMB'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class QeqD4(flax.linen.module.Module):
class QeqD4(nn.Module):
    """QEq-D4 charge equilibration scheme

    FID: QEQ_D4

    If no charges are present in the inputs, they are obtained by solving the
    constrained linear system (one Lagrange multiplier per system enforcing
    the total charge) with an iterative solver. If charges are provided, they
    are used as-is and only the energy is computed.

    ### Reference
    E. Caldeweyher et al., A generally applicable atomic-charge dependent London dispersion correction,
    J Chem Phys. 2019 Apr 21;150(15):154122. (https://doi.org/10.1063/1.5090222)
    """

    graph_key: str = "graph"
    """Key for the graph in the inputs"""
    trainable: bool = False
    """Whether the parameters are trainable"""
    charges_key: str = "charges"
    """Key for the charges in the outputs. 
        If charges are provided in the inputs, they are not re-optimized and we only compute the energy"""
    energy_key: Optional[str] = None
    """Key for the energy in the outputs"""
    chi_key: Optional[str] = None
    """Key for additional electronegativity in the inputs"""
    c3_key: Optional[str] = None
    """Key for additional c3 in the inputs. Only used if charges are provided in the inputs"""
    c4_key: Optional[str] = None
    """Key for additional c4 in the inputs. Only used if charges are provided in the inputs"""
    total_charge_key: str = "total_charge"
    """Key for the total charge in the inputs"""
    non_interacting_guess: bool = False
    """Whether to use the non-interacting limit as an initial guess."""
    solver: str = "gmres"
    """Iterative linear solver used for the charges: 'gmres', 'bicg' or 'cg' (case-insensitive)"""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str] = "QEQ_D4"

    @nn.compact
    def __call__(self, inputs):
        """Compute (or reuse) the equilibrated charges and the associated energy.

        Returns a copy of `inputs` with the charges under `charges_key` and the
        per-atom energy under `energy_key` (module name if None). The
        reciprocal-space energy is added under `energy_key + "_reciprocal"`
        when the graph provides `k_points`. When charges are provided, the
        module is trainable and the `training` flag is set, the
        `_regularization`, `_dedq` and `_etrain` entries are added as well.

        Raises:
            NotImplementedError: if `solver` is not one of the supported names.
        """
        species = inputs["species"]
        graph = inputs[self.graph_key]
        edge_src, edge_dst = graph["edge_src"], graph["edge_dst"]
        switch = graph["switch"]
        natoms = species.shape[0]
        training = "training" in inputs.get("flags", {})
        charges_provided = self.charges_key in inputs

        rij = graph["distances"] / au.BOHR

        do_recip = "k_points" in graph
        if do_recip:
            k_points = graph["k_points"]
            bewald = graph["b_ewald"]
            cells = inputs["cells"]
            reciprocal_cells = inputs["reciprocal_cells"]
            batch_index = inputs["batch_index"]
            dirfact = jax.scipy.special.erfc(bewald * graph["distances"])
            ewald_params = prepare_reciprocal_space(
                cells,
                reciprocal_cells,
                inputs["coordinates"],
                batch_index,
                k_points,
                bewald,
            )
        else:
            dirfact = 1.0

        # tabulated diagonal term: hardness plus a term inversely proportional to the radius
        Jii = D3_HARDNESSES
        ai = D3_VDW_RADII
        ETA = [jii + (2.0 / np.pi) ** 0.5 / aii for jii, aii in zip(Jii, ai)]

        # D3 parameters
        if self.trainable:
            ENi = self.param("EN", lambda key: jnp.asarray(D3_ELECTRONEGATIVITIES))[
                species
            ]
            # Jii = self.param("J", lambda key: jnp.asarray(D3_HARDNESSES))[species]
            eta = jnp.abs(self.param("eta", lambda key: jnp.asarray(ETA)))[species]
            ai = jnp.abs(self.param("a", lambda key: jnp.asarray(D3_VDW_RADII)))[
                species
            ]
            rci = jnp.abs(self.param("rc", lambda key: jnp.asarray(D3_COV_RADII)))[
                species
            ]
            c3 = self.param("c3", lambda key: jnp.zeros(len(D3_ELECTRONEGATIVITIES)))[
                species
            ]
            c4 = jnp.abs(
                self.param("c4", lambda key: jnp.zeros(len(D3_ELECTRONEGATIVITIES)))[
                    species
                ]
            )
            kappai = self.param("kappa", lambda key: jnp.asarray(D3_KAPPA))[species]
            k1 = jnp.abs(self.param("k1", lambda key: jnp.asarray(7.5)))
            if training:
                # squared deviation of the trainable parameters from their tabulated values
                regularization = (
                    (ENi - jnp.asarray(D3_ELECTRONEGATIVITIES)[species]) ** 2
                    + (eta - jnp.asarray(ETA)[species]) ** 2
                    + (ai - jnp.asarray(D3_VDW_RADII)[species]) ** 2
                    + (rci - jnp.asarray(D3_COV_RADII)[species]) ** 2
                    + (kappai - jnp.asarray(D3_KAPPA)[species]) ** 2
                    + (k1 - 7.5) ** 2
                )
        else:
            c3 = jnp.zeros_like(species, dtype=jnp.float32)
            c4 = jnp.zeros_like(species, dtype=jnp.float32)
            ENi = jnp.asarray(D3_ELECTRONEGATIVITIES)[species]
            # Jii = jnp.asarray(D3_HARDNESSES)[species]
            eta = jnp.asarray(ETA)[species]
            ai = jnp.asarray(D3_VDW_RADII)[species]
            rci = jnp.asarray(D3_COV_RADII)[species]
            kappai = jnp.asarray(D3_KAPPA)[species]
            k1 = 7.5

        ai2 = ai**2
        # out-of-range edge indices read a fill value instead of a real atom
        rcij = (
            rci.at[edge_src].get(mode="fill", fill_value=1.0)
            + rci.at[edge_dst].get(mode="fill", fill_value=1.0)
            + 1.0e-3
        )
        # smooth coordination number that shifts the electronegativity
        mCNij = 1.0 + jax.scipy.special.erf(-k1 * (rij / rcij - 1))
        mCNi = 0.5 * jax.ops.segment_sum(mCNij * switch, edge_src, natoms)
        chi = ENi - kappai * (mCNi + 1.0e-3) ** 0.5
        if self.chi_key is not None:
            chi = chi + inputs[self.chi_key]

        gamma_ij = (
            ai2.at[edge_src].get(mode="fill", fill_value=1.0)
            + ai2.at[edge_dst].get(mode="fill", fill_value=1.0)
            + 1.0e-3
        ) ** (-0.5)

        Aii = eta  # Jii + ((2.0 / np.pi) ** 0.5) / ai
        # Aij = jax.scipy.special.erf(gamma_ij * rij) / rij * switch
        Aij = (dirfact - jax.scipy.special.erfc(gamma_ij * rij)) / rij * switch

        if charges_provided:
            q = inputs[self.charges_key]
            q_ = q
        else:
            nsys = inputs["natoms"].shape[0]
            batch_index = inputs["batch_index"]

            def matvec(x):
                """Apply the QEq matrix to x = (Lagrange multipliers, charges)."""
                l, q = jnp.split(x, (nsys,))
                Aq_self = Aii * q
                qdest = q.at[edge_dst].get(mode="fill", fill_value=0.0)
                Aq_pair = jax.ops.segment_sum(Aij * qdest, edge_src, natoms)
                Aq = (
                    Aq_self
                    + Aq_pair
                    + l.at[batch_index].get(mode="fill", fill_value=0.0)
                )
                if do_recip:
                    _, phirec = ewald_reciprocal(q, *ewald_params)
                    Aq = Aq + phirec
                # constraint rows: total charge of each system
                Al = jax.ops.segment_sum(q, batch_index, nsys)
                return jnp.concatenate((Al, Aq))

            Qtot = (
                inputs[self.total_charge_key].astype(chi.dtype).reshape(nsys)
                if self.total_charge_key in inputs
                else jnp.zeros(nsys, dtype=chi.dtype)
            )
            b = jnp.concatenate([Qtot, -chi])

            if self.non_interacting_guess:
                # Initial guess from the non-interacting limit (pair terms dropped):
                #   Aii * q_i + l = -chi_i   and   sum_i q_i = Qtot
                # give q_i = q0_i - si * l with si = 1 / Aii, q0_i = -chi_i * si and
                #   l = (sum_i q0_i - Qtot) / sum_i si
                si = 1.0 / Aii
                q0 = -chi * si
                qtot = jax.ops.segment_sum(q0, batch_index, nsys)
                sisum = jax.ops.segment_sum(si, batch_index, nsys)
                # systems without any atom have sisum == 0: keep their multiplier at 0
                has_atoms = sisum > 0
                l0 = jnp.where(
                    has_atoms,
                    (qtot - Qtot) / jnp.where(has_atoms, sisum, 1.0),
                    0.0,
                )
                q0 = q0 - si * l0[batch_index]
                x0 = jnp.concatenate((l0, q0))
            else:
                x0 = None

            solver = self.solver.lower()
            if solver == "bicg":
                x = jax.scipy.sparse.linalg.bicgstab(matvec, b, x0=x0)[0]
            elif solver == "gmres":
                x = jax.scipy.sparse.linalg.gmres(matvec, b, x0=x0)[0]
            elif solver == "cg":
                print("Warning: Use of cg solver for Qeq is not recommended")
                x = jax.scipy.sparse.linalg.cg(matvec, b, x0=x0)[0]
            else:
                raise NotImplementedError(f"solver '{solver}' is not implemented. Choose one of [bicg, gmres]")

            q = x[nsys:]
            # the energy is evaluated with charges detached from the gradient
            q_ = jax.lax.stop_gradient(q)

        eself = 0.5 * Aii * q_**2 + chi * q_

        phi = jax.ops.segment_sum(Aij * q_[edge_dst], edge_src, natoms)
        if do_recip:
            erec, _ = ewald_reciprocal(q_, *ewald_params)
        epair = 0.5 * q_ * phi

        if charges_provided:
            if self.c3_key is not None:
                c3 = c3 + inputs[self.c3_key]
            if self.c4_key is not None:
                c4 = c4 + inputs[self.c4_key]
            eself = eself + c3 * q_**3 + c4 * q_**4
            if self.trainable and training:
                # training terms: gradients flow through the charges only,
                # every model parameter is detached
                Aii_ = jax.lax.stop_gradient(Aii)
                chi_ = jax.lax.stop_gradient(chi)
                c3_ = jax.lax.stop_gradient(c3)
                c4_ = jax.lax.stop_gradient(c4)
                switch_ = jax.lax.stop_gradient(switch)
                Aij_ = jax.lax.stop_gradient(Aij)
                phi_ = jax.ops.segment_sum(Aij_ * q[edge_dst], edge_src, natoms)

                # squared difference of dE/dq between neighboring atoms
                dedq = Aii_ * q + chi_ + phi_ + 3 * c3_ * q**2 + 4 * c4_ * q**3
                dedq = jax.ops.segment_sum(
                    switch_ * (dedq[edge_src] - dedq[edge_dst]) ** 2,
                    edge_src,
                    natoms,
                )
                etrain = (
                    0.5 * Aii_ * q**2
                    + chi_ * q
                    + 0.5 * q * phi_
                    + c3_ * q**3
                    + c4_ * q**4
                )
                if do_recip:
                    etrain = etrain + erec

        energy = eself + epair
        if do_recip:
            energy = energy + erec

        energy_key = self.energy_key if self.energy_key is not None else self.name
        energy_unit = au.get_multiplier(self._energy_unit)
        output = {
            **inputs,
            self.charges_key: q,
            energy_key: energy * energy_unit,
        }
        if do_recip:
            output[energy_key + "_reciprocal"] = erec * energy_unit

        if charges_provided and self.trainable and training:
            output[energy_key + "_regularization"] = regularization
            output[energy_key + "_dedq"] = dedq * energy_unit
            output[energy_key + "_etrain"] = etrain * energy_unit
        return output

QEq-D4 charge equilibration scheme

FID: QEQ_D4

Reference

E. Caldeweyher et al., A generally applicable atomic-charge dependent London dispersion correction, J Chem Phys. 2019 Apr 21;150(15):154122. (https://doi.org/10.1063/1.5090222)

QeqD4( graph_key: str = 'graph', trainable: bool = False, charges_key: str = 'charges', energy_key: Optional[str] = None, chi_key: Optional[str] = None, c3_key: Optional[str] = None, c4_key: Optional[str] = None, total_charge_key: str = 'total_charge', non_interacting_guess: bool = False, solver: str = 'gmres', _energy_unit: str = 'Ha', parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
graph_key: str = 'graph'

Key for the graph in the inputs

trainable: bool = False

Whether the parameters are trainable

charges_key: str = 'charges'

Key for the charges in the outputs. If charges are provided in the inputs, they are not re-optimized and we only compute the energy

energy_key: Optional[str] = None

Key for the energy in the outputs

chi_key: Optional[str] = None

Key for additional electronegativity in the inputs

c3_key: Optional[str] = None

Key for additional c3 in the inputs. Only used if charges are provided in the inputs

c4_key: Optional[str] = None

Key for additional c4 in the inputs. Only used if charges are provided in the inputs

total_charge_key: str = 'total_charge'

Key for the total charge in the inputs

non_interacting_guess: bool = False

Whether to use the non-interacting limit as an initial guess.

solver: str = 'gmres'
FID: ClassVar[str] = 'QEQ_D4'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class ChargeCorrection(flax.linen.module.Module):
class ChargeCorrection(nn.Module):
    """Charge correction scheme

    FID: CHARGE_CORRECTION

    Used to correct the provided charges to sum to the total charge of the system.

    For each system, the deficit `dq = Qtot - sum(q)` is redistributed over its
    atoms proportionally to a per-atom weight `s = 1 / (1e-6 + 2*|eta|)`
    (optionally multiplied by a coordination number), so that the corrected
    charges of each system sum to `Qtot`.

    The returned dictionary contains all the inputs plus:
    - the corrected charges (under `output_key`, or `key` if `output_key` is None),
    - the per-system deficit `dq` (under `dq_key`),
    - the per-atom quantity `0.5 * eta * (q_corrected - q)**2`, scaled by the
      energy-unit multiplier, under `"charge_correction_energy"`.
    """

    key: str = "charges"
    """Key for the charges in the inputs"""
    output_key: Optional[str] = None
    """Key for the corrected charges in the outputs. If None, it is the same as the input key"""
    dq_key: str = "delta_qtot"
    """Key under which the per-system difference between the target total charge and the sum of the raw charges is stored in the outputs"""
    ratioeta_key: Optional[str] = None
    """Key for the ratio of hardness between AIM and free atom in the inputs. Used to adjust charge redistribution."""
    trainable: bool = False
    """Whether the parameters are trainable"""
    cn_key: Optional[str] = None
    """Key for the coordination number in the inputs. Used to adjust charge redistribution."""
    total_charge_key: str = "total_charge"
    """Key for the total charge in the inputs"""
    _energy_unit: str = "Ha"
    """The energy unit of the model. **Automatically set by FENNIX**"""

    FID: ClassVar[str] = "CHARGE_CORRECTION"

    @nn.compact
    def __call__(self, inputs) -> Any:
        def _drop_unit_axis(x):
            # Squeeze a trailing axis of size 1 (e.g. (nat, 1) -> (nat,)).
            # The ndim guard keeps a 1-D array holding a single atom from
            # being collapsed to a 0-d scalar, which would break the
            # per-system segment sums below.
            if x.ndim > 1 and x.shape[-1] == 1:
                return jnp.squeeze(x, axis=-1)
            return x

        species = inputs["species"]
        batch_index = inputs["batch_index"]
        nsys = inputs["natoms"].shape[0]

        q = _drop_unit_axis(inputs[self.key])

        # Per-system sum of the raw charges and its deviation from the target.
        qtot = jax.ops.segment_sum(q, batch_index, nsys)
        Qtot = (
            inputs[self.total_charge_key].astype(q.dtype)
            if self.total_charge_key in inputs
            else jnp.zeros(qtot.shape[0], dtype=q.dtype)  # default: neutral systems
        )
        dq = Qtot - qtot

        # Reference per-element hardness-like parameter built from the D3 tables:
        # eta = J_ii + sqrt(2/pi) / a_i
        eta0 = [
            jii + (2.0 / np.pi) ** 0.5 / aii
            for jii, aii in zip(D3_HARDNESSES, D3_VDW_RADII)
        ]
        if self.trainable:
            # Trainable per-element table; abs() keeps the values non-negative.
            eta = jnp.abs(self.param("eta", lambda key: jnp.asarray(eta0)))[species]
        else:
            eta = jnp.asarray(eta0)[species]

        if self.ratioeta_key is not None:
            eta = eta * _drop_unit_axis(inputs[self.ratioeta_key])

        # Redistribution weights; the 1e-6 offset avoids dividing by zero
        # when eta vanishes.
        s = (1.0e-6 + 2 * jnp.abs(eta)) ** (-1)
        if self.cn_key is not None:
            # NOTE(review): if every atom of a system has cn == 0, the
            # normalisation below divides by zero -- confirm with callers.
            s = s * _drop_unit_axis(inputs[self.cn_key])

        # Per-system scaling factor such that sum_i s_i * f == dq.
        f = dq / jax.ops.segment_sum(s, batch_index, nsys)

        qf = q + s * f[batch_index]

        energy_unit = au.get_multiplier(self._energy_unit)
        # Per-atom quadratic term in the applied charge shift.
        ecorr = (0.5 * energy_unit) * eta * (qf - q) ** 2
        output_key = self.output_key if self.output_key is not None else self.key
        return {
            **inputs,
            output_key: qf,
            self.dq_key: dq,
            "charge_correction_energy": ecorr,
        }

Charge correction scheme

FID: CHARGE_CORRECTION

Used to correct the provided charges to sum to the total charge of the system.

ChargeCorrection( key: str = 'charges', output_key: str = None, dq_key: str = 'delta_qtot', ratioeta_key: str = None, trainable: bool = False, cn_key: str = None, total_charge_key: str = 'total_charge', _energy_unit: str = 'Ha', parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
key: str = 'charges'

Key for the charges in the inputs

output_key: str = None

Key for the corrected charges in the outputs. If None, it is the same as the input key

dq_key: str = 'delta_qtot'

Key under which the per-system difference between the target total charge and the sum of the raw charges is stored in the outputs

ratioeta_key: str = None

Key for the ratio of hardness between AIM and free atom in the inputs. Used to adjust charge redistribution.

trainable: bool = False

Whether the parameters are trainable

cn_key: str = None

Key for the coordination number in the inputs. Used to adjust charge redistribution.

total_charge_key: str = 'total_charge'

Key for the total charge in the inputs

FID: ClassVar[str] = 'CHARGE_CORRECTION'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class DistributeElectrons(flax.linen.module.Module):
class DistributeElectrons(nn.Module):
    """Distribute valence electrons between the atoms

    FID: DISTRIBUTE_ELECTRONS

    Used to predict charges that sum to the total charge of the system.

    A positive weight is predicted for every atom from its embedding
    (dense layer followed by a softplus). The electrons of each system
    (sum of the tabulated valence electrons minus the total charge) are shared
    between its atoms proportionally to these weights, and the charge of an
    atom is its valence electron count minus its assigned population.
    """

    embedding_key: str
    """Key for the embedding in the inputs that is used to predict an 'electron affinity' weight"""
    output_key: Union[str, None] = None
    """Key for the charges in the outputs. If None, the module name is used"""
    total_charge_key: str = "total_charge"
    """Key for the total charge in the inputs"""

    FID: ClassVar[str] = "DISTRIBUTE_ELECTRONS"

    @nn.compact
    def __call__(self, inputs) -> Any:
        batch_index = inputs["batch_index"]
        nsys = inputs["natoms"].shape[0]

        # Tabulated number of valence electrons of each atom.
        valence = jnp.asarray(VALENCE_ELECTRONS)[inputs["species"]]

        # Strictly positive per-atom weight predicted from the embedding.
        raw_weight = nn.Dense(1, use_bias=True, name="wi")(
            inputs[self.embedding_key]
        )
        raw_weight = raw_weight.squeeze(-1)
        weight = jax.nn.softplus(raw_weight)
        weight_per_system = jax.ops.segment_sum(weight, batch_index, nsys)

        # Total charge of each system; systems are neutral when it is not provided.
        if self.total_charge_key in inputs:
            total_charge = inputs[self.total_charge_key].astype(raw_weight.dtype)
        else:
            total_charge = jnp.zeros(nsys, dtype=raw_weight.dtype)

        # Number of electrons to distribute in each system.
        electrons_per_system = (
            jax.ops.segment_sum(valence, batch_index, nsys) - total_charge
        )

        # Share the electrons proportionally to the weights.
        scale = electrons_per_system / weight_per_system
        population = weight * scale[batch_index]
        charges = valence - population

        if self.output_key is None:
            result_key = self.name
        else:
            result_key = self.output_key
        return {**inputs, result_key: charges}

Distribute valence electrons between the atoms

FID: DISTRIBUTE_ELECTRONS

Used to predict charges that sum to the total charge of the system.

DistributeElectrons( embedding_key: str, output_key: Optional[str] = None, total_charge_key: str = 'total_charge', parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
embedding_key: str

Key for the embedding in the inputs that is used to predict an 'electron affinity' weight

output_key: Optional[str] = None

Key for the charges in the outputs. If None, the module name is used

total_charge_key: str = 'total_charge'

Key for the total charge in the inputs

FID: ClassVar[str] = 'DISTRIBUTE_ELECTRONS'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None