fennol.models.embeddings.charge_embeddings

  1import jax
  2import jax.numpy as jnp
  3import flax.linen as nn
  4from typing import Sequence, Dict, Union, ClassVar, Optional
  5import numpy as np
  6from ...utils.periodic_table import (
  7    VALENCE_ELECTRONS,
  8)
  9from ...utils.initializers import initializer_from_str
 10
 11class ChargeHypothesis(nn.Module):
 12    """Embedding with total charge constraint vi Multiple Neural Charge Equilibration.
 13
 14    FID: CHARGE_HYPOTHESIS
 15
 16    """
 17
 18    embedding_key: str
 19    """Key for the embedding in the inputs that is used to predict an 'electron affinity' weight"""
 20    output_key: Union[str, None] = None
 21    """Key for the charges in the outputs"""
 22    total_charge_key: str = "total_charge"
 23    """Key for the total charge in the inputs"""
 24    ncharges: int = 10
 25    """The number of charge hypothesis"""
 26    mode: str = "qeq"
 27    """Charge distribution mode. Only 'qeq' available for now."""
 28    squeeze: bool = True
 29    kernel_init: Optional[str] = None
 30
 31    FID: ClassVar[str] = "CHARGE_HYPOTHESIS"
 32
 33    @nn.compact
 34    def __call__(self, inputs):
 35        embedding = inputs[self.embedding_key].astype(inputs["coordinates"].dtype)
 36
 37        kernel_init = initializer_from_str(self.kernel_init)
 38        wi = jax.nn.softplus(
 39            nn.Dense(self.ncharges, use_bias=True, name="wi",kernel_init=kernel_init)(embedding)
 40        )
 41
 42        batch_index = inputs["batch_index"]
 43        nsys = inputs["natoms"].shape[0]
 44        wtot = jax.ops.segment_sum(wi, batch_index, nsys)
 45
 46        Qtot = (
 47            inputs[self.total_charge_key].astype(wi.dtype)
 48            if self.total_charge_key in inputs
 49            else jnp.zeros(nsys, dtype=wi.dtype)
 50        )
 51        if Qtot.ndim == 0:
 52            Qtot = Qtot * jnp.ones(nsys, dtype=wi.dtype)
 53
 54        
 55        qtilde = nn.Dense(self.ncharges, use_bias=True, name="qi",kernel_init=kernel_init)(embedding)
 56        qtot = jax.ops.segment_sum(qtilde, batch_index, nsys)
 57        dq = Qtot[:, None] - qtot
 58        f = (dq / wtot)[batch_index]
 59
 60        if "alch_group" in inputs:
 61            alch_group = inputs["alch_group"]
 62            lambda_e = inputs["alch_elambda"]
 63            Qligand = inputs.get("alch_ligand_charge", jnp.zeros(nsys))
 64            if Qligand.ndim == 0:
 65                Qligand = Qligand * jnp.ones(nsys, dtype=wi.dtype)
 66
 67            alch_index = alch_group + 2 * batch_index
 68            qtot = jax.ops.segment_sum(qtilde, alch_index, 2 * nsys).reshape(
 69                nsys, 2, self.ncharges
 70            )
 71            wtot = jax.ops.segment_sum(wi, alch_index, 2 * nsys).reshape(
 72                nsys, 2, self.ncharges
 73            )
 74            # compute ligand alone
 75            dq = Qligand[:, None] - qtot[:, 1, :]
 76            f1 = (dq / wtot[:, 1, :])[batch_index]
 77            # q1 = alch_group*(qtilde + wi * f1[batch_index])
 78
 79            # qtot = jax.ops.segment_sum(qtilde*(1-alch_group), batch_index, nsys)
 80            # wtot = jax.ops.segment_sum(wi*(1-alch_group), batch_index, nsys)
 81            # compute solvent alone
 82            dq = (Qtot - Qligand)[:, None] - qtot[:, 0, :]
 83            f0 = (dq / wtot[:, 0, :])[batch_index]
 84            # q0 = (1-alch_group)*(qtilde + wi * f0[batch_index])
 85
 86            # alch_q = alch_group*q1 + (1-alch_group)*q0
 87            # q = lambda_e*q + (1-lambda_e)*alch_q
 88
 89            alch_f = (1 - alch_group[:, None]) * f0 + alch_group[:, None] * f1
 90            f = lambda_e * f + (1 - lambda_e) * alch_f
 91
 92        q = qtilde + wi * f
 93
 94        if self.squeeze and self.ncharges == 1:
 95            q = jnp.squeeze(q, axis=-1)
 96
 97        output_key = self.output_key if self.output_key is not None else self.name
 98        return {
 99            **inputs,
100            output_key: q,
101        }
class ChargeHypothesis(flax.linen.module.Module):
 12class ChargeHypothesis(nn.Module):
 13    """Embedding with total charge constraint vi Multiple Neural Charge Equilibration.
 14
 15    FID: CHARGE_HYPOTHESIS
 16
 17    """
 18
 19    embedding_key: str
 20    """Key for the embedding in the inputs that is used to predict an 'electron affinity' weight"""
 21    output_key: Union[str, None] = None
 22    """Key for the charges in the outputs"""
 23    total_charge_key: str = "total_charge"
 24    """Key for the total charge in the inputs"""
 25    ncharges: int = 10
 26    """The number of charge hypothesis"""
 27    mode: str = "qeq"
 28    """Charge distribution mode. Only 'qeq' available for now."""
 29    squeeze: bool = True
 30    kernel_init: Optional[str] = None
 31
 32    FID: ClassVar[str] = "CHARGE_HYPOTHESIS"
 33
 34    @nn.compact
 35    def __call__(self, inputs):
 36        embedding = inputs[self.embedding_key].astype(inputs["coordinates"].dtype)
 37
 38        kernel_init = initializer_from_str(self.kernel_init)
 39        wi = jax.nn.softplus(
 40            nn.Dense(self.ncharges, use_bias=True, name="wi",kernel_init=kernel_init)(embedding)
 41        )
 42
 43        batch_index = inputs["batch_index"]
 44        nsys = inputs["natoms"].shape[0]
 45        wtot = jax.ops.segment_sum(wi, batch_index, nsys)
 46
 47        Qtot = (
 48            inputs[self.total_charge_key].astype(wi.dtype)
 49            if self.total_charge_key in inputs
 50            else jnp.zeros(nsys, dtype=wi.dtype)
 51        )
 52        if Qtot.ndim == 0:
 53            Qtot = Qtot * jnp.ones(nsys, dtype=wi.dtype)
 54
 55        
 56        qtilde = nn.Dense(self.ncharges, use_bias=True, name="qi",kernel_init=kernel_init)(embedding)
 57        qtot = jax.ops.segment_sum(qtilde, batch_index, nsys)
 58        dq = Qtot[:, None] - qtot
 59        f = (dq / wtot)[batch_index]
 60
 61        if "alch_group" in inputs:
 62            alch_group = inputs["alch_group"]
 63            lambda_e = inputs["alch_elambda"]
 64            Qligand = inputs.get("alch_ligand_charge", jnp.zeros(nsys))
 65            if Qligand.ndim == 0:
 66                Qligand = Qligand * jnp.ones(nsys, dtype=wi.dtype)
 67
 68            alch_index = alch_group + 2 * batch_index
 69            qtot = jax.ops.segment_sum(qtilde, alch_index, 2 * nsys).reshape(
 70                nsys, 2, self.ncharges
 71            )
 72            wtot = jax.ops.segment_sum(wi, alch_index, 2 * nsys).reshape(
 73                nsys, 2, self.ncharges
 74            )
 75            # compute ligand alone
 76            dq = Qligand[:, None] - qtot[:, 1, :]
 77            f1 = (dq / wtot[:, 1, :])[batch_index]
 78            # q1 = alch_group*(qtilde + wi * f1[batch_index])
 79
 80            # qtot = jax.ops.segment_sum(qtilde*(1-alch_group), batch_index, nsys)
 81            # wtot = jax.ops.segment_sum(wi*(1-alch_group), batch_index, nsys)
 82            # compute solvent alone
 83            dq = (Qtot - Qligand)[:, None] - qtot[:, 0, :]
 84            f0 = (dq / wtot[:, 0, :])[batch_index]
 85            # q0 = (1-alch_group)*(qtilde + wi * f0[batch_index])
 86
 87            # alch_q = alch_group*q1 + (1-alch_group)*q0
 88            # q = lambda_e*q + (1-lambda_e)*alch_q
 89
 90            alch_f = (1 - alch_group[:, None]) * f0 + alch_group[:, None] * f1
 91            f = lambda_e * f + (1 - lambda_e) * alch_f
 92
 93        q = qtilde + wi * f
 94
 95        if self.squeeze and self.ncharges == 1:
 96            q = jnp.squeeze(q, axis=-1)
 97
 98        output_key = self.output_key if self.output_key is not None else self.name
 99        return {
100            **inputs,
101            output_key: q,
102        }

Embedding with total charge constraint via Multiple Neural Charge Equilibration.

FID: CHARGE_HYPOTHESIS

ChargeHypothesis( embedding_key: str, output_key: Optional[str] = None, total_charge_key: str = 'total_charge', ncharges: int = 10, mode: str = 'qeq', squeeze: bool = True, kernel_init: Optional[str] = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
embedding_key: str

Key for the embedding in the inputs that is used to predict an 'electron affinity' weight

output_key: Optional[str] = None

Key for the charges in the outputs

total_charge_key: str = 'total_charge'

Key for the total charge in the inputs

ncharges: int = 10

The number of charge hypotheses

mode: str = 'qeq'

Charge distribution mode. Only 'qeq' available for now.

squeeze: bool = True
kernel_init: Optional[str] = None
FID: ClassVar[str] = 'CHARGE_HYPOTHESIS'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None