fennol.models.embeddings.charge_embeddings

 1import jax
 2import jax.numpy as jnp
 3import flax.linen as nn
 4from typing import Sequence, Dict, Union, ClassVar, Optional
 5import numpy as np
 6from ...utils.periodic_table import (
 7    VALENCE_ELECTRONS,
 8)
 9
class ChargeHypothesis(nn.Module):
    """Embedding with total charge constraint via Multiple Neural Charge Equilibration.

    FID: CHARGE_HYPOTHESIS

    Two Dense heads on the atomic embedding predict, for each of the
    ``ncharges`` hypotheses, an unconstrained charge ``qtilde`` and a positive
    weight ``wi`` (softplus). For every system, the mismatch between the target
    total charge and ``sum(qtilde)`` is redistributed proportionally to ``wi``,
    so each hypothesis sums to the system's total charge.

    If ``alch_group`` is present in the inputs, the ligand atoms
    (``alch_group == 1``) and the remaining atoms (``alch_group == 0``) are also
    equilibrated separately, to ``alch_ligand_charge`` (default 0) and
    ``total_charge - alch_ligand_charge``. The resulting correction is mixed with
    the global one using ``alch_elambda``.
    """

    embedding_key: str
    """Key for the embedding in the inputs, used to predict the 'electron affinity' weights and the unconstrained charges"""
    output_key: Union[str, None] = None
    """Key for the charges in the outputs"""
    total_charge_key: str = "total_charge"
    """Key for the total charge in the inputs"""
    ncharges: int = 10
    """The number of charge hypotheses"""
    mode: str = "qeq"
    """Charge distribution mode. Only 'qeq' available for now."""
    squeeze: bool = True
    # If True and ncharges == 1, the trailing hypothesis axis is removed from the output.

    FID: ClassVar[str] = "CHARGE_HYPOTHESIS"

    @nn.compact
    def __call__(self, inputs):
        """Return ``inputs`` extended with the constrained charges.

        Reads ``embedding_key``, ``coordinates`` (for the dtype),
        ``batch_index``, ``natoms`` (only its length is used) and, optionally,
        ``total_charge_key`` (default 0), ``alch_group``, ``alch_elambda`` and
        ``alch_ligand_charge``. The charges of shape ``(natoms, ncharges)``
        (squeezed to ``(natoms,)`` when ``squeeze`` and ``ncharges == 1``) are
        stored under ``output_key``, or under the module name if it is None.
        """
        # NOTE(review): self.mode is not read; only the 'qeq' scheme is implemented.
        embedding = inputs[self.embedding_key].astype(inputs["coordinates"].dtype)

        # strictly positive redistribution weights, one column per hypothesis
        wi = jax.nn.softplus(
            nn.Dense(self.ncharges, use_bias=True, name="wi")(embedding)
        )

        batch_index = inputs["batch_index"]
        nsys = inputs["natoms"].shape[0]

        def safe_div(num, den):
            # A segment without atoms has a zero weight sum. Its ratio is never
            # selected, but a 0/0 would leak NaN through the multiplicative
            # blends below (0 * NaN = NaN) and into gradients, so divide by 1.
            return num / jnp.where(den > 0, den, 1)

        # jnp.asarray also accepts plain Python scalars for the total charge
        Qtot = jnp.asarray(inputs.get(self.total_charge_key, 0.0), dtype=wi.dtype)
        if Qtot.ndim == 0:
            # same total charge for every system in the batch
            Qtot = Qtot * jnp.ones(nsys, dtype=wi.dtype)

        # unconstrained charges, shifted per system onto the target total charge
        qtilde = nn.Dense(self.ncharges, use_bias=True, name="qi")(embedding)
        qtot = jax.ops.segment_sum(qtilde, batch_index, nsys)
        wtot = jax.ops.segment_sum(wi, batch_index, nsys)
        f = safe_div(Qtot[:, None] - qtot, wtot)[batch_index]

        if "alch_group" in inputs:
            alch_group = inputs["alch_group"]
            lambda_e = inputs["alch_elambda"]
            Qligand = jnp.asarray(
                inputs.get("alch_ligand_charge", 0.0), dtype=wi.dtype
            )
            if Qligand.ndim == 0:
                Qligand = Qligand * jnp.ones(nsys, dtype=wi.dtype)

            # segment 2*s + g gathers atoms of system s in group g
            # (assumes alch_group holds integer 0/1 labels — 1 is the ligand)
            alch_index = alch_group + 2 * batch_index
            qtot_g = jax.ops.segment_sum(qtilde, alch_index, 2 * nsys).reshape(
                nsys, 2, self.ncharges
            )
            wtot_g = jax.ops.segment_sum(wi, alch_index, 2 * nsys).reshape(
                nsys, 2, self.ncharges
            )
            # rest of the system alone must carry Qtot - Qligand
            f0 = safe_div(
                (Qtot - Qligand)[:, None] - qtot_g[:, 0, :], wtot_g[:, 0, :]
            )[batch_index]
            # ligand alone must carry Qligand
            f1 = safe_div(Qligand[:, None] - qtot_g[:, 1, :], wtot_g[:, 1, :])[
                batch_index
            ]

            alch_f = (1 - alch_group[:, None]) * f0 + alch_group[:, None] * f1
            # lambda_e = 1 keeps the global correction, lambda_e = 0 the per-group one
            f = lambda_e * f + (1 - lambda_e) * alch_f

        q = qtilde + wi * f

        if self.squeeze and self.ncharges == 1:
            q = jnp.squeeze(q, axis=-1)

        output_key = self.output_key if self.output_key is not None else self.name
        return {
            **inputs,
            output_key: q,
        }
class ChargeHypothesis(flax.linen.module.Module):
class ChargeHypothesis(nn.Module):
    """Embedding with total charge constraint via Multiple Neural Charge Equilibration.

    FID: CHARGE_HYPOTHESIS

    Two Dense heads on the atomic embedding predict, for each of the
    ``ncharges`` hypotheses, an unconstrained charge ``qtilde`` and a positive
    weight ``wi`` (softplus). For every system, the mismatch between the target
    total charge and ``sum(qtilde)`` is redistributed proportionally to ``wi``,
    so each hypothesis sums to the system's total charge.

    If ``alch_group`` is present in the inputs, the ligand atoms
    (``alch_group == 1``) and the remaining atoms (``alch_group == 0``) are also
    equilibrated separately, to ``alch_ligand_charge`` (default 0) and
    ``total_charge - alch_ligand_charge``. The resulting correction is mixed with
    the global one using ``alch_elambda``.
    """

    embedding_key: str
    """Key for the embedding in the inputs, used to predict the 'electron affinity' weights and the unconstrained charges"""
    output_key: Union[str, None] = None
    """Key for the charges in the outputs"""
    total_charge_key: str = "total_charge"
    """Key for the total charge in the inputs"""
    ncharges: int = 10
    """The number of charge hypotheses"""
    mode: str = "qeq"
    """Charge distribution mode. Only 'qeq' available for now."""
    squeeze: bool = True
    # If True and ncharges == 1, the trailing hypothesis axis is removed from the output.

    FID: ClassVar[str] = "CHARGE_HYPOTHESIS"

    @nn.compact
    def __call__(self, inputs):
        """Return ``inputs`` extended with the constrained charges.

        Reads ``embedding_key``, ``coordinates`` (for the dtype),
        ``batch_index``, ``natoms`` (only its length is used) and, optionally,
        ``total_charge_key`` (default 0), ``alch_group``, ``alch_elambda`` and
        ``alch_ligand_charge``. The charges of shape ``(natoms, ncharges)``
        (squeezed to ``(natoms,)`` when ``squeeze`` and ``ncharges == 1``) are
        stored under ``output_key``, or under the module name if it is None.
        """
        # NOTE(review): self.mode is not read; only the 'qeq' scheme is implemented.
        embedding = inputs[self.embedding_key].astype(inputs["coordinates"].dtype)

        # strictly positive redistribution weights, one column per hypothesis
        wi = jax.nn.softplus(
            nn.Dense(self.ncharges, use_bias=True, name="wi")(embedding)
        )

        batch_index = inputs["batch_index"]
        nsys = inputs["natoms"].shape[0]

        def safe_div(num, den):
            # A segment without atoms has a zero weight sum. Its ratio is never
            # selected, but a 0/0 would leak NaN through the multiplicative
            # blends below (0 * NaN = NaN) and into gradients, so divide by 1.
            return num / jnp.where(den > 0, den, 1)

        # jnp.asarray also accepts plain Python scalars for the total charge
        Qtot = jnp.asarray(inputs.get(self.total_charge_key, 0.0), dtype=wi.dtype)
        if Qtot.ndim == 0:
            # same total charge for every system in the batch
            Qtot = Qtot * jnp.ones(nsys, dtype=wi.dtype)

        # unconstrained charges, shifted per system onto the target total charge
        qtilde = nn.Dense(self.ncharges, use_bias=True, name="qi")(embedding)
        qtot = jax.ops.segment_sum(qtilde, batch_index, nsys)
        wtot = jax.ops.segment_sum(wi, batch_index, nsys)
        f = safe_div(Qtot[:, None] - qtot, wtot)[batch_index]

        if "alch_group" in inputs:
            alch_group = inputs["alch_group"]
            lambda_e = inputs["alch_elambda"]
            Qligand = jnp.asarray(
                inputs.get("alch_ligand_charge", 0.0), dtype=wi.dtype
            )
            if Qligand.ndim == 0:
                Qligand = Qligand * jnp.ones(nsys, dtype=wi.dtype)

            # segment 2*s + g gathers atoms of system s in group g
            # (assumes alch_group holds integer 0/1 labels — 1 is the ligand)
            alch_index = alch_group + 2 * batch_index
            qtot_g = jax.ops.segment_sum(qtilde, alch_index, 2 * nsys).reshape(
                nsys, 2, self.ncharges
            )
            wtot_g = jax.ops.segment_sum(wi, alch_index, 2 * nsys).reshape(
                nsys, 2, self.ncharges
            )
            # rest of the system alone must carry Qtot - Qligand
            f0 = safe_div(
                (Qtot - Qligand)[:, None] - qtot_g[:, 0, :], wtot_g[:, 0, :]
            )[batch_index]
            # ligand alone must carry Qligand
            f1 = safe_div(Qligand[:, None] - qtot_g[:, 1, :], wtot_g[:, 1, :])[
                batch_index
            ]

            alch_f = (1 - alch_group[:, None]) * f0 + alch_group[:, None] * f1
            # lambda_e = 1 keeps the global correction, lambda_e = 0 the per-group one
            f = lambda_e * f + (1 - lambda_e) * alch_f

        q = qtilde + wi * f

        if self.squeeze and self.ncharges == 1:
            q = jnp.squeeze(q, axis=-1)

        output_key = self.output_key if self.output_key is not None else self.name
        return {
            **inputs,
            output_key: q,
        }

Embedding with total charge constraint via Multiple Neural Charge Equilibration.

FID: CHARGE_HYPOTHESIS

ChargeHypothesis( embedding_key: str, output_key: Optional[str] = None, total_charge_key: str = 'total_charge', ncharges: int = 10, mode: str = 'qeq', squeeze: bool = True, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
embedding_key: str

Key for the embedding in the inputs, used to predict the 'electron affinity' weights and the unconstrained charges

output_key: Optional[str] = None

Key for the charges in the outputs

total_charge_key: str = 'total_charge'

Key for the total charge in the inputs

ncharges: int = 10

The number of charge hypotheses

mode: str = 'qeq'

Charge distribution mode. Only 'qeq' available for now.

squeeze: bool = True

If True and ncharges == 1, the trailing hypothesis axis is removed from the output charges.

FID: ClassVar[str] = 'CHARGE_HYPOTHESIS'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None