fennol.models.embeddings.charge_embeddings
import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Sequence, Dict, Union, ClassVar, Optional
import numpy as np
from ...utils.periodic_table import (
    VALENCE_ELECTRONS,
)


class ChargeHypothesis(nn.Module):
    """Embedding with total charge constraint via Multiple Neural Charge Equilibration.

    FID: CHARGE_HYPOTHESIS

    """

    embedding_key: str
    """Key for the embedding in the inputs that is used to predict an 'electron affinity' weight"""
    output_key: Union[str, None] = None
    """Key for the charges in the outputs"""
    total_charge_key: str = "total_charge"
    """Key for the total charge in the inputs"""
    ncharges: int = 10
    """The number of charge hypotheses"""
    mode: str = "qeq"
    """Charge distribution mode. Only 'qeq' available for now."""
    squeeze: bool = True

    FID: ClassVar[str] = "CHARGE_HYPOTHESIS"

    @nn.compact
    def __call__(self, inputs):
        embedding = inputs[self.embedding_key].astype(inputs["coordinates"].dtype)

        wi = jax.nn.softplus(
            nn.Dense(self.ncharges, use_bias=True, name="wi")(embedding)
        )

        batch_index = inputs["batch_index"]
        nsys = inputs["natoms"].shape[0]
        wtot = jax.ops.segment_sum(wi, batch_index, nsys)

        Qtot = (
            inputs[self.total_charge_key].astype(wi.dtype)
            if self.total_charge_key in inputs
            else jnp.zeros(nsys, dtype=wi.dtype)
        )
        if Qtot.ndim == 0:
            Qtot = Qtot * jnp.ones(nsys, dtype=wi.dtype)

        qtilde = nn.Dense(self.ncharges, use_bias=True, name="qi")(embedding)
        qtot = jax.ops.segment_sum(qtilde, batch_index, nsys)
        dq = Qtot[:, None] - qtot
        f = (dq / wtot)[batch_index]

        if "alch_group" in inputs:
            alch_group = inputs["alch_group"]
            lambda_e = inputs["alch_elambda"]
            Qligand = inputs.get("alch_ligand_charge", jnp.zeros(nsys))
            if Qligand.ndim == 0:
                Qligand = Qligand * jnp.ones(nsys, dtype=wi.dtype)

            alch_index = alch_group + 2 * batch_index
            qtot = jax.ops.segment_sum(qtilde, alch_index, 2 * nsys).reshape(
                nsys, 2, self.ncharges
            )
            wtot = jax.ops.segment_sum(wi, alch_index, 2 * nsys).reshape(
                nsys, 2, self.ncharges
            )
            # compute ligand alone
            dq = Qligand[:, None] - qtot[:, 1, :]
            f1 = (dq / wtot[:, 1, :])[batch_index]
            # q1 = alch_group*(qtilde + wi * f1[batch_index])

            # qtot = jax.ops.segment_sum(qtilde*(1-alch_group), batch_index, nsys)
            # wtot = jax.ops.segment_sum(wi*(1-alch_group), batch_index, nsys)
            # compute solvent alone
            dq = (Qtot - Qligand)[:, None] - qtot[:, 0, :]
            f0 = (dq / wtot[:, 0, :])[batch_index]
            # q0 = (1-alch_group)*(qtilde + wi * f0[batch_index])

            # alch_q = alch_group*q1 + (1-alch_group)*q0
            # q = lambda_e*q + (1-lambda_e)*alch_q

            alch_f = (1 - alch_group[:, None]) * f0 + alch_group[:, None] * f1
            f = lambda_e * f + (1 - lambda_e) * alch_f

        q = qtilde + wi * f

        if self.squeeze and self.ncharges == 1:
            q = jnp.squeeze(q, axis=-1)

        output_key = self.output_key if self.output_key is not None else self.name
        return {
            **inputs,
            output_key: q,
        }
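The listing above is the full module source. As a minimal usage sketch (not taken from FeNNol's documentation; the dummy embedding, atom counts, and charge values below are illustrative assumptions), ChargeHypothesis behaves like any other flax.linen module: init creates the two Dense heads ("wi" and "qi") and apply returns the inputs dictionary augmented with the predicted charges, whose per-system sums match the requested total charges.

import jax
import jax.numpy as jnp
from fennol.models.embeddings.charge_embeddings import ChargeHypothesis

# Two systems of three atoms each, with a made-up 8-dimensional per-atom embedding.
inputs = {
    "coordinates": jnp.zeros((6, 3)),               # only its dtype is used by this module
    "embedding": jnp.ones((6, 8)),                  # normally produced by an upstream embedding module
    "batch_index": jnp.array([0, 0, 0, 1, 1, 1]),   # system index of each atom
    "natoms": jnp.array([3, 3]),                    # one entry per system
    "total_charge": jnp.array([0.0, -1.0]),         # requested net charge of each system
}

module = ChargeHypothesis(embedding_key="embedding", output_key="charges", ncharges=1)
params = module.init(jax.random.PRNGKey(0), inputs)
charges = module.apply(params, inputs)["charges"]   # shape (6,) after the squeeze

# The per-system sums reproduce the requested total charges by construction.
print(jax.ops.segment_sum(charges, inputs["batch_index"], 2))   # ≈ [ 0. -1.]

In a real FeNNol model the "embedding" entry would be produced by the upstream module selected via embedding_key.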
class ChargeHypothesis(flax.linen.module.Module):
Embedding with total charge constraint via Multiple Neural Charge Equilibration.
FID: CHARGE_HYPOTHESIS
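Reading the __call__ body in the source above, the charge equilibration is a weighted redistribution of the charge residual within each system: the "qi" head predicts raw per-atom charges $\tilde{q}_i$, the "wi" head predicts positive weights $w_i$ through a softplus, and the returned charges are

$$ q_i = \tilde{q}_i + w_i \, \frac{Q_\mathrm{tot} - \sum_j \tilde{q}_j}{\sum_j w_j}, $$

where the sums run over the atoms of the same system, so $\sum_i q_i = Q_\mathrm{tot}$ holds exactly for each of the ncharges hypotheses. When an alch_group is present in the inputs, the same redistribution is computed separately for the ligand (against alch_ligand_charge) and the solvent (against the remaining charge) and blended with the full-system solution using alch_elambda.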
ChargeHypothesis(
    embedding_key: str,
    output_key: Optional[str] = None,
    total_charge_key: str = 'total_charge',
    ncharges: int = 10,
    mode: str = 'qeq',
    squeeze: bool = True,
    parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>,
    name: Optional[str] = None,
)
embedding_key: str
Key for the embedding in the inputs that is used to predict an 'electron affinity' weight
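To illustrate how ncharges and squeeze interact (again a hedged sketch with made-up inputs): each hypothesis is equilibrated against the same total charge independently, so for ncharges > 1 the output keeps a trailing hypothesis axis and every column sums to the system charge; the axis is only dropped when ncharges == 1 and squeeze is True.

import jax
import jax.numpy as jnp
from fennol.models.embeddings.charge_embeddings import ChargeHypothesis

# A single hypothetical system of 4 atoms with a dummy embedding.
inputs = {
    "coordinates": jnp.zeros((4, 3)),
    "embedding": jnp.ones((4, 8)),
    "batch_index": jnp.zeros(4, dtype=jnp.int32),
    "natoms": jnp.array([4]),
    "total_charge": jnp.array([1.0]),
}

multi = ChargeHypothesis(embedding_key="embedding", output_key="q_multi", ncharges=4)
params = multi.init(jax.random.PRNGKey(0), inputs)
q_multi = multi.apply(params, inputs)["q_multi"]   # shape (4, 4): atoms x hypotheses

# Each hypothesis column independently satisfies the total-charge constraint.
print(q_multi.sum(axis=0))                          # ≈ [1. 1. 1. 1.]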