fennol.models.embeddings.charge_embeddings
import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Sequence, Dict, Union, ClassVar, Optional
import numpy as np
from ...utils.periodic_table import (
    VALENCE_ELECTRONS,
)
from ...utils.initializers import initializer_from_str


class ChargeHypothesis(nn.Module):
    """Embedding with total charge constraint via Multiple Neural Charge Equilibration.

    FID: CHARGE_HYPOTHESIS

    """

    embedding_key: str
    """Key for the embedding in the inputs that is used to predict an 'electron affinity' weight"""
    output_key: Union[str, None] = None
    """Key for the charges in the outputs"""
    total_charge_key: str = "total_charge"
    """Key for the total charge in the inputs"""
    ncharges: int = 10
    """The number of charge hypotheses"""
    mode: str = "qeq"
    """Charge distribution mode. Only 'qeq' available for now."""
    squeeze: bool = True
    """If True and ncharges == 1, squeeze the last (hypothesis) axis of the output."""
    kernel_init: Optional[str] = None
    """Name of the kernel initializer used for the Dense layers."""

    FID: ClassVar[str] = "CHARGE_HYPOTHESIS"

    @nn.compact
    def __call__(self, inputs):
        embedding = inputs[self.embedding_key].astype(inputs["coordinates"].dtype)

        kernel_init = initializer_from_str(self.kernel_init)
        # positive per-atom 'electron affinity' weights, one per charge hypothesis
        wi = jax.nn.softplus(
            nn.Dense(self.ncharges, use_bias=True, name="wi", kernel_init=kernel_init)(
                embedding
            )
        )

        batch_index = inputs["batch_index"]
        nsys = inputs["natoms"].shape[0]
        wtot = jax.ops.segment_sum(wi, batch_index, nsys)

        # target total charge per system (defaults to neutral systems)
        Qtot = (
            inputs[self.total_charge_key].astype(wi.dtype)
            if self.total_charge_key in inputs
            else jnp.zeros(nsys, dtype=wi.dtype)
        )
        if Qtot.ndim == 0:
            Qtot = Qtot * jnp.ones(nsys, dtype=wi.dtype)

        # unconstrained per-atom charges and the per-system redistribution factor
        qtilde = nn.Dense(self.ncharges, use_bias=True, name="qi", kernel_init=kernel_init)(
            embedding
        )
        qtot = jax.ops.segment_sum(qtilde, batch_index, nsys)
        dq = Qtot[:, None] - qtot
        f = (dq / wtot)[batch_index]

        if "alch_group" in inputs:
            alch_group = inputs["alch_group"]
            lambda_e = inputs["alch_elambda"]
            Qligand = inputs.get("alch_ligand_charge", jnp.zeros(nsys))
            if Qligand.ndim == 0:
                Qligand = Qligand * jnp.ones(nsys, dtype=wi.dtype)

            # separate segment sums for solvent (group 0) and ligand (group 1)
            alch_index = alch_group + 2 * batch_index
            qtot = jax.ops.segment_sum(qtilde, alch_index, 2 * nsys).reshape(
                nsys, 2, self.ncharges
            )
            wtot = jax.ops.segment_sum(wi, alch_index, 2 * nsys).reshape(
                nsys, 2, self.ncharges
            )
            # compute ligand alone
            dq = Qligand[:, None] - qtot[:, 1, :]
            f1 = (dq / wtot[:, 1, :])[batch_index]
            # compute solvent alone
            dq = (Qtot - Qligand)[:, None] - qtot[:, 0, :]
            f0 = (dq / wtot[:, 0, :])[batch_index]

            # blend the decoupled (group-wise) and fully coupled redistribution factors
            alch_f = (1 - alch_group[:, None]) * f0 + alch_group[:, None] * f1
            f = lambda_e * f + (1 - lambda_e) * alch_f

        q = qtilde + wi * f

        if self.squeeze and self.ncharges == 1:
            q = jnp.squeeze(q, axis=-1)

        output_key = self.output_key if self.output_key is not None else self.name
        return {
            **inputs,
            output_key: q,
        }
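In the 'qeq' mode implemented above, two Dense heads map each atomic embedding to an unconstrained charge q~_i and a positive weight w_i = softplus(Dense(embedding)), one value per hypothesis column. The missing charge of each system is then redistributed in proportion to the weights,

    q_i = q~_i + w_i * (Q_sys - sum_j q~_j) / sum_j w_j,

where the sums run over the atoms j of the same system, so the predicted charges sum exactly to the requested total charge for every hypothesis column. When an "alch_group" mask is provided, the same redistribution factor is also computed separately for the ligand subset (group 1, targeting "alch_ligand_charge") and the solvent subset (group 0, targeting the remaining charge), and the group-wise and full-system factors are linearly blended using the "alch_elambda" parameter.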
class ChargeHypothesis(flax.linen.module.Module):
Embedding with total charge constraint via Multiple Neural Charge Equilibration.
FID: CHARGE_HYPOTHESIS
ChargeHypothesis( embedding_key: str, output_key: Optional[str] = None, total_charge_key: str = 'total_charge', ncharges: int = 10, mode: str = 'qeq', squeeze: bool = True, kernel_init: Optional[str] = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
embedding_key: str
Key for the embedding in the inputs that is used to predict an 'electron affinity' weight
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]
Wraps parent module references in weak refs.
This prevents reference cycles from forming via parent links, which can lead to accidental OOMs in eager mode due to slow garbage collection, as well as spurious tracer leaks during jit compilation.
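A minimal usage sketch (not taken from the FeNNol documentation): the keys "coordinates", "batch_index", "natoms" and "total_charge" are the ones read by __call__ above, while the "embedding" key, the 8-dimensional feature size and the two toy systems are illustrative placeholders for what an upstream embedding module would normally provide.

import jax
import jax.numpy as jnp
from fennol.models.embeddings.charge_embeddings import ChargeHypothesis

# two systems with 3 and 2 atoms, and a dummy 8-dimensional per-atom embedding
batch_index = jnp.array([0, 0, 0, 1, 1])      # system index of each atom
inputs = {
    "coordinates": jnp.zeros((5, 3)),         # only used here to pick the working dtype
    "embedding": jax.random.normal(jax.random.PRNGKey(0), (5, 8)),
    "batch_index": batch_index,
    "natoms": jnp.array([3, 2]),
    "total_charge": jnp.array([0.0, -1.0]),   # per-system total charges
}

model = ChargeHypothesis(embedding_key="embedding", output_key="charges", ncharges=4)
params = model.init(jax.random.PRNGKey(1), inputs)
out = model.apply(params, inputs)
print(out["charges"].shape)                   # (5, 4): one column per charge hypothesis

Within each system, every column of out["charges"] sums to the corresponding entry of "total_charge".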