fennol.models.misc.nets

import flax.linen as nn
from typing import Sequence, Callable, Union, Optional, ClassVar, Tuple
from ...utils.periodic_table import PERIODIC_TABLE_REV_IDX, PERIODIC_TABLE
import jax.numpy as jnp
import jax
import numpy as np
from functools import partial
from ...utils.activations import activation_from_str, TrainableSiLU
from ...utils.initializers import initializer_from_str, scaled_orthogonal
from flax.core import FrozenDict


class FullyConnectedNet(nn.Module):
    """A fully connected neural network module.

    FID: NEURAL_NET
    """

    neurons: Sequence[int]
    """A sequence of integers representing the dimensions of the network."""
    activation: Union[Callable, str] = "silu"
    """The activation function to use."""
    use_bias: bool = True
    """Whether to use bias in the dense layers."""
    input_key: Optional[str] = None
    """The key of the input tensor."""
    output_key: Optional[str] = None
    """The key of the output tensor."""
    squeeze: bool = False
    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization method to use."""

    FID: ClassVar[str] = "NEURAL_NET"

    @nn.compact
    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
        """Applies the neural network to the given inputs.

        Args:
            inputs (Union[dict, jax.Array]): If input_key is None, a JAX array containing the inputs to the neural network. Else, a dictionary containing the inputs at the key input_key.

        Returns:
            Union[dict, jax.Array]: If output_key is None, the output tensor of the neural network. Else, a dictionary containing the original inputs and the output tensor at the key output_key.
        """
        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            x = inputs
        else:
            x = inputs[self.input_key]

        # activation = (
        #     activation_from_str(self.activation)
        #     if isinstance(self.activation, str)
        #     else self.activation
        # )
        kernel_init = (
            initializer_from_str(self.kernel_init)
            if isinstance(self.kernel_init, str)
            else self.kernel_init
        )
        ############################
        for i, d in enumerate(self.neurons[:-1]):
            x = nn.Dense(
                d,
                use_bias=self.use_bias,
                name=f"Layer_{i+1}",
                kernel_init=kernel_init,
                # precision=jax.lax.Precision.HIGH,
            )(x)
            x = activation_from_str(self.activation)(x)
        x = nn.Dense(
            self.neurons[-1],
            use_bias=self.use_bias,
            name=f"Layer_{len(self.neurons)}",
            kernel_init=kernel_init,
            # precision=jax.lax.Precision.HIGH,
        )(x)
        if self.squeeze and x.shape[-1] == 1:
            x = jnp.squeeze(x, axis=-1)
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: x} if output_key is not None else x
        return x

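# Illustrative usage sketch (documentation example, not part of the original
# module): FullyConnectedNet applied to a plain array and to a dictionary
# input. Shapes, keys and layer sizes below are arbitrary assumptions.
def _example_fully_connected_net():
    rng = jax.random.PRNGKey(0)
    # Array input: the module returns the output array directly.
    net = FullyConnectedNet(neurons=[32, 32, 1], activation="silu", squeeze=True)
    x = jnp.ones((8, 16))  # e.g. 8 atoms with 16 features each
    params = net.init(rng, x)
    y = net.apply(params, x)  # shape (8,): the trailing axis of size 1 is squeezed
    # Dictionary input: the result is stored under `output_key`.
    net_dict = FullyConnectedNet(
        neurons=[32, 32, 1], input_key="embedding", output_key="energies", squeeze=True
    )
    inputs = {"embedding": x}
    params = net_dict.init(rng, inputs)
    out = net_dict.apply(params, inputs)  # dict with both "embedding" and "energies"
    return y, out
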
class ResMLP(nn.Module):
    """Residual neural network as defined in the SpookyNet paper.

    FID: RES_MLP
    """

    use_bias: bool = True
    """Whether to include bias in the linear layers."""
    input_key: Optional[str] = None
    """The key of the input tensor."""
    output_key: Optional[str] = None
    """The key of the output tensor."""

    kernel_init: Union[str, Callable] = "scaled_orthogonal(mode='fan_avg')"
    """The kernel initialization method to use."""
    res_only: bool = False
    """Whether to only apply the residual connection without additional activation and linear layer."""

    FID: ClassVar[str] = "RES_MLP"

    @nn.compact
    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            x = inputs
        else:
            x = inputs[self.input_key]

        kernel_init = (
            initializer_from_str(self.kernel_init)
            if isinstance(self.kernel_init, str)
            else self.kernel_init
        )
        ############################
        out = nn.Dense(x.shape[-1], use_bias=self.use_bias, kernel_init=kernel_init)(
            TrainableSiLU()(x)
        )
        out = x + nn.Dense(
            x.shape[-1], use_bias=self.use_bias, kernel_init=nn.initializers.zeros
        )(TrainableSiLU()(out))

        if not self.res_only:
            out = nn.Dense(
                x.shape[-1], use_bias=self.use_bias, kernel_init=kernel_init
            )(TrainableSiLU()(out))
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: out} if output_key is not None else out
        return out

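# Illustrative usage sketch (documentation example, not part of the original
# module): ResMLP preserves the feature dimension of its input and adds a
# learned residual update. The sizes below are arbitrary assumptions.
def _example_res_mlp():
    block = ResMLP()
    x = jnp.ones((8, 32))
    params = block.init(jax.random.PRNGKey(0), x)
    y = block.apply(params, x)  # same shape as the input: (8, 32)
    return y
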
class FullyResidualNet(nn.Module):
    """A neural network with skip connections at each layer.

    FID: SKIP_NET
    """

    dim: int
    """The dimension of the hidden layers."""
    output_dim: int
    """The dimension of the output layer."""
    nlayers: int
    """The number of layers in the network."""
    activation: Union[Callable, str] = "silu"
    """The activation function to use."""
    use_bias: bool = True
    """Whether to include bias terms in the linear layers."""
    input_key: Optional[str] = None
    """The key of the input tensor."""
    output_key: Optional[str] = None
    """The key of the output tensor."""
    squeeze: bool = False
    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization method to use."""

    FID: ClassVar[str] = "SKIP_NET"

    @nn.compact
    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            x = inputs
        else:
            x = inputs[self.input_key]

        # activation = (
        #     activation_from_str(self.activation)
        #     if isinstance(self.activation, str)
        #     else self.activation
        # )
        kernel_init = (
            initializer_from_str(self.kernel_init)
            if isinstance(self.kernel_init, str)
            else self.kernel_init
        )
        ############################
        if x.shape[-1] != self.dim:
            x = nn.Dense(
                self.dim,
                use_bias=self.use_bias,
                name="Reshape",
                kernel_init=kernel_init,
            )(x)

        for i in range(self.nlayers - 1):
            x = x + activation_from_str(self.activation)(
                nn.Dense(
                    self.dim,
                    use_bias=self.use_bias,
                    name=f"Layer_{i+1}",
                    kernel_init=kernel_init,
                )(x)
            )
        x = nn.Dense(
            self.output_dim,
            use_bias=self.use_bias,
            name=f"Layer_{self.nlayers}",
            kernel_init=kernel_init,
        )(x)
        if self.squeeze and x.shape[-1] == 1:
            x = jnp.squeeze(x, axis=-1)
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: x} if output_key is not None else x
        return x

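# Illustrative usage sketch (documentation example, not part of the original
# module): FullyResidualNet first projects the input to `dim`, applies
# `nlayers - 1` skip-connected layers, then maps to `output_dim`.
def _example_fully_residual_net():
    net = FullyResidualNet(dim=32, output_dim=4, nlayers=3)
    x = jnp.ones((8, 16))  # 16 input features are first reshaped to dim=32
    params = net.init(jax.random.PRNGKey(0), x)
    y = net.apply(params, x)  # shape (8, 4)
    return y
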
class HierarchicalNet(nn.Module):
    """Neural network for a sequence of inputs (in axis=-2) with a decay factor.

    FID: HIERARCHICAL_NET
    """

    neurons: Sequence[int]
    """A sequence of integers representing the number of neurons in each layer."""
    activation: Union[Callable, str] = "silu"
    """The activation function to use."""
    use_bias: bool = True
    """Whether to include bias terms in the linear layers."""
    input_key: Optional[str] = None
    """The key of the input tensor."""
    output_key: Optional[str] = None
    """The key of the output tensor."""
    decay: float = 0.01
    """The decay factor used to scale each element of the sequence."""
    squeeze: bool = False
    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization method to use."""

    FID: ClassVar[str] = "HIERARCHICAL_NET"

    @nn.compact
    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            x = inputs
        else:
            x = inputs[self.input_key]

        ############################
        networks = nn.vmap(
            FullyConnectedNet,
            variable_axes={"params": 0},
            split_rngs={"params": True},
            in_axes=-2,
            out_axes=-2,
        )(self.neurons, self.activation, self.use_bias, kernel_init=self.kernel_init)

        out = networks(x)
        # scale each element of the sequence by a power of the decay factor
        decay = jnp.asarray([self.decay**i for i in range(out.shape[-2])])
        out = out * decay[..., :, None]

        if self.squeeze and out.shape[-1] == 1:
            out = jnp.squeeze(out, axis=-1)
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: out} if output_key is not None else out
        return out

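# Illustrative usage sketch (documentation example, not part of the original
# module): each element along axis=-2 gets its own FullyConnectedNet, and
# element i of the result is scaled by decay**i. Sizes below are arbitrary.
def _example_hierarchical_net():
    net = HierarchicalNet(neurons=[16, 1], decay=0.1)
    x = jnp.ones((8, 4, 32))  # 8 atoms, 4 hierarchy levels, 32 features per level
    params = net.init(jax.random.PRNGKey(0), x)
    y = net.apply(params, x)  # shape (8, 4, 1); level i is scaled by 0.1**i
    return y
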
class SpeciesIndexNet(nn.Module):
    """Chemical-species-specific neural network using precomputed species index.

    FID: SPECIES_INDEX_NET

    A neural network that applies a species-specific fully connected network to each atom embedding.
    A species index must be provided to filter the embeddings for each species and apply the corresponding network.
    This index can be obtained using the SPECIES_INDEXER preprocessing module from `fennol.models.preprocessing.SpeciesIndexer`.

    """

    output_dim: int
    """The dimension of the output of the fully connected networks."""
    hidden_neurons: Union[dict, FrozenDict, Sequence[int]]
    """The hidden dimensions of the fully connected networks.
        If a dictionary is provided, it should map species names to dimensions.
        If a sequence is provided, the same dimensions will be used for all species."""
    species_order: Optional[Union[str, Sequence[str]]] = None
    """The species for which to build a network. Only required if hidden_neurons is not a dictionary."""
    activation: Union[Callable, str] = "silu"
    """The activation function to use in the fully connected networks."""
    use_bias: bool = True
    """Whether to include bias terms in the fully connected networks."""
    input_key: Optional[str] = None
    """The key in the input dictionary that corresponds to the embeddings of the atoms."""
    species_index_key: str = "species_index"
    """The key in the input dictionary that corresponds to the species index of the atoms. See `fennol.models.preprocessing.SpeciesIndexer`."""
    output_key: Optional[str] = None
    """The key in the output dictionary that corresponds to the network's output."""

    squeeze: bool = False
    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization method for the fully connected networks."""
    check_unhandled: bool = True
    """Whether to raise an error if the species index contains a species without a dedicated network."""

    FID: ClassVar[str] = "SPECIES_INDEX_NET"

    def setup(self):
        if not (
            isinstance(self.hidden_neurons, dict)
            or isinstance(self.hidden_neurons, FrozenDict)
        ):
            assert (
                self.species_order is not None
            ), "species_order must be provided if hidden_neurons is not a dictionary"
            if isinstance(self.species_order, str):
                species_order = [el.strip() for el in self.species_order.split(",")]
            else:
                species_order = [el for el in self.species_order]
            neurons = {k: self.hidden_neurons for k in species_order}
        else:
            neurons = self.hidden_neurons
            species_order = list(neurons.keys())
        for species in species_order:
            assert (
                species in PERIODIC_TABLE
            ), f"species {species} not found in periodic table"

        self.networks = {
            k: FullyConnectedNet(
                [*neurons[k], self.output_dim],
                self.activation,
                self.use_bias,
                name=k,
                kernel_init=self.kernel_init,
            )
            for k in species_order
        }

    def __call__(
        self, inputs: Union[dict, Tuple[jax.Array, jax.Array]]
    ) -> Union[dict, jax.Array]:

        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            species, embedding, species_index = inputs
        else:
            species, embedding = inputs["species"], inputs[self.input_key]
            species_index = inputs[self.species_index_key]

        assert isinstance(
            species_index, dict
        ), "species_index must be a dictionary for SpeciesIndexNet"

        ############################
        # initialization => instantiate all networks
        if self.is_initializing():
            x = jnp.zeros((1, embedding.shape[-1]), dtype=embedding.dtype)
            [net(x) for net in self.networks.values()]

        if self.check_unhandled:
            for b in species_index.keys():
                if b not in self.networks.keys():
                    raise ValueError(
                        f"Species {b} not found in networks. Handled species are {self.networks.keys()}"
                    )

        ############################
        outputs = []
        indices = []
        for s, net in self.networks.items():
            if s not in species_index:
                continue
            idx = species_index[s]
            o = net(embedding[idx])
            outputs.append(o)
            indices.append(idx)

        o = jnp.concatenate(outputs, axis=0)
        idx = jnp.concatenate(indices, axis=0)

        out = (
            jnp.zeros((species.shape[0], *o.shape[1:]), dtype=o.dtype)
            .at[idx]
            .set(o, mode="drop")
        )

        if self.squeeze and out.shape[-1] == 1:
            out = jnp.squeeze(out, axis=-1)
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: out} if output_key is not None else out
        return out

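# Illustrative usage sketch (documentation example, not part of the original
# module): the species index (normally produced by the SPECIES_INDEXER
# preprocessing module) maps species symbols to atom indices. All values
# below are arbitrary assumptions.
def _example_species_index_net():
    net = SpeciesIndexNet(output_dim=1, hidden_neurons=[32, 32], species_order="H,O")
    species = jnp.array([1, 1, 8, 1])  # atomic numbers
    embedding = jnp.ones((4, 16))
    species_index = {"H": jnp.array([0, 1, 3]), "O": jnp.array([2])}
    inputs = (species, embedding, species_index)
    params = net.init(jax.random.PRNGKey(0), inputs)
    y = net.apply(params, inputs)  # shape (4, 1): one output per atom
    return y
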
class ChemicalNet(nn.Module):
    """Optimized chemical-species-specific neural network.

    FID: CHEMICAL_NET

    A neural network that applies a fully connected network to each atom embedding in a chemical system and selects the output corresponding to the atom's species.
    This is an optimized version of ChemicalNetHet that uses vmap to apply the networks to all atoms at once.
    The optimization is allowed because all networks have the same shape.

    """

    species_order: Union[str, Sequence[str]]
    """The species for which to build a network."""
    neurons: Sequence[int]
    """The dimensions of the fully connected networks."""
    activation: Union[Callable, str] = "silu"
    """The activation function to use in the fully connected networks."""
    use_bias: bool = True
    """Whether to include bias terms in the fully connected networks."""
    input_key: Optional[str] = None
    """The key in the input dictionary that corresponds to the embeddings of the atoms."""
    output_key: Optional[str] = None
    """The key in the output dictionary that corresponds to the network's output."""
    squeeze: bool = False
    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization method for the fully connected networks."""

    FID: ClassVar[str] = "CHEMICAL_NET"

    @nn.compact
    def __call__(
        self, inputs: Union[dict, Tuple[jax.Array, jax.Array]]
    ) -> Union[dict, jax.Array]:
        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            species, embedding = inputs
        else:
            species, embedding = inputs["species"], inputs[self.input_key]

        ############################
        # build species to network index mapping (static => fixed when jitted)
        rev_idx = PERIODIC_TABLE_REV_IDX
        maxidx = max(rev_idx.values())
        if isinstance(self.species_order, str):
            species_order = [el.strip() for el in self.species_order.split(",")]
        else:
            species_order = [el for el in self.species_order]
        nspecies = len(species_order)
        conv_tensor_ = np.full((maxidx + 2,), -1, dtype=np.int32)
        for i, s in enumerate(species_order):
            conv_tensor_[rev_idx[s]] = i
        conv_tensor = jnp.asarray(conv_tensor_)
        indices = conv_tensor[species]

        ############################
        # build shape-sharing networks using vmap
        networks = nn.vmap(
            FullyConnectedNet,
            variable_axes={"params": 0},
            split_rngs={"params": True},
            in_axes=0,
        )(self.neurons, self.activation, self.use_bias, kernel_init=self.kernel_init)
        # repeat the input along a new axis to compute all species at once
        x = jnp.broadcast_to(embedding[None, :, :], (nspecies, *embedding.shape))

        # apply the networks to the input and select the output corresponding to the species
        out = jnp.squeeze(
            jnp.take_along_axis(networks(x), indices[None, :, None], axis=0), axis=0
        )

        out = jnp.where((indices >= 0)[:, None], out, 0.0)
        if self.squeeze and out.shape[-1] == 1:
            out = jnp.squeeze(out, axis=-1)
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: out} if output_key is not None else out
        return out

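# Illustrative usage sketch (documentation example, not part of the original
# module): one network per species in `species_order`, evaluated for all atoms
# at once via vmap; each atom keeps the output of its own species' network.
def _example_chemical_net():
    net = ChemicalNet(species_order="H,O", neurons=[32, 1], squeeze=True)
    species = jnp.array([1, 1, 8])  # atomic numbers: H, H, O
    embedding = jnp.ones((3, 16))
    params = net.init(jax.random.PRNGKey(0), (species, embedding))
    y = net.apply(params, (species, embedding))  # shape (3,): one value per atom
    return y
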
class MOENet(nn.Module):
    """Mixture of Experts neural network.

    FID: MOE_NET

    This class represents a Mixture of Experts neural network. It applies a set of shape-sharing expert networks
    to the input and combines their outputs using weights computed by a router network.

    """

    neurons: Sequence[int]
    """A sequence of integers representing the number of neurons in each shape-sharing network."""
    num_networks: int
    """The number of shape-sharing networks to create."""
    activation: Union[Callable, str] = "silu"
    """The activation function to use in the shape-sharing networks."""
    use_bias: bool = True
    """Whether to include bias in the shape-sharing networks."""
    input_key: Optional[str] = None
    """The key of the input tensor."""
    output_key: Optional[str] = None
    """The key of the output tensor."""
    squeeze: bool = False
    """Whether to remove the last axis of the output tensor if it is of dimension 1."""

    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization method to use in the shape-sharing networks."""
    router_key: Optional[str] = None
    """The key of the router tensor. If None, the router is assumed to be the same as the input tensor."""

    FID: ClassVar[str] = "MOE_NET"

    @nn.compact
    def __call__(
        self, inputs: Union[dict, Tuple[jax.Array, jax.Array]]
    ) -> Union[dict, jax.Array]:
        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            if isinstance(inputs, tuple):
                embedding, router = inputs
            else:
                embedding = router = inputs
        else:
            embedding = inputs[self.input_key]
            router = (
                inputs[self.router_key] if self.router_key is not None else embedding
            )

        ############################
        # build shape-sharing networks using vmap
        networks = nn.vmap(
            FullyConnectedNet,
            variable_axes={"params": 0},
            split_rngs={"params": True},
            in_axes=0,
            out_axes=0,
        )(self.neurons, self.activation, self.use_bias, kernel_init=self.kernel_init)
        # repeat the input along a new axis to compute all experts at once
        x = jnp.repeat(embedding[None, :, :], self.num_networks, axis=0)

        w = nn.softmax(nn.Dense(self.num_networks, name="router")(router), axis=-1)

        out = (networks(x) * w.T[:, :, None]).sum(axis=0)

        if self.squeeze and out.shape[-1] == 1:
            out = jnp.squeeze(out, axis=-1)
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: out} if output_key is not None else out
        return out

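# Illustrative usage sketch (documentation example, not part of the original
# module): all experts are evaluated on the input and their outputs are mixed
# with softmax weights produced by the router layer. Sizes are arbitrary.
def _example_moe_net():
    net = MOENet(neurons=[32, 8], num_networks=4)
    x = jnp.ones((10, 16))
    params = net.init(jax.random.PRNGKey(0), x)
    y = net.apply(params, x)  # shape (10, 8): softmax-weighted sum of 4 expert outputs
    return y
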
class ChannelNet(nn.Module):
    """Apply a different neural network to each channel.

    FID: CHANNEL_NET
    """

    neurons: Sequence[int]
    """A sequence of integers representing the number of neurons in each shape-sharing network."""
    activation: Union[Callable, str] = "silu"
    """The activation function to use in the shape-sharing networks."""
    use_bias: bool = True
    """Whether to include bias in the shape-sharing networks."""
    input_key: Optional[str] = None
    """The key of the input tensor."""
    output_key: Optional[str] = None
    """The key of the output tensor."""
    squeeze: bool = False
    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization method to use in the shape-sharing networks."""
    channel_axis: int = -2
    """The axis to use as channel. Its length will be the number of shape-sharing networks."""

    FID: ClassVar[str] = "CHANNEL_NET"

    @nn.compact
    def __call__(
        self, inputs: Union[dict, Tuple[jax.Array, jax.Array]]
    ) -> Union[dict, jax.Array]:
        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            x = inputs
        else:
            x = inputs[self.input_key]

        ############################
        # build shape-sharing networks using vmap
        networks = nn.vmap(
            FullyConnectedNet,
            variable_axes={"params": 0},
            split_rngs={"params": True},
            in_axes=self.channel_axis,
            out_axes=self.channel_axis,
        )(self.neurons, self.activation, self.use_bias, kernel_init=self.kernel_init)

        out = networks(x)

        if self.squeeze and out.shape[-1] == 1:
            out = jnp.squeeze(out, axis=-1)
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: out} if output_key is not None else out
        return out

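# Illustrative usage sketch (documentation example, not part of the original
# module): one independent network per channel along `channel_axis`.
def _example_channel_net():
    net = ChannelNet(neurons=[16, 4], channel_axis=-2)
    x = jnp.ones((8, 3, 32))  # 8 atoms, 3 channels, 32 features per channel
    params = net.init(jax.random.PRNGKey(0), x)
    y = net.apply(params, x)  # shape (8, 3, 4)
    return y
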
class GatedPerceptron(nn.Module):
    """Gated Perceptron neural network.

    FID: GATED_PERCEPTRON

    This class represents a Gated Perceptron neural network model. It applies a gating mechanism
    to the input data and performs a linear transformation using a dense layer followed by an activation function.
    """

    dim: int
    """The dimensionality of the output space."""
    use_bias: bool = True
    """Whether to include a bias term in the dense layer."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization method to use."""
    activation: Union[Callable, str] = "silu"
    """The activation function to use."""

    input_key: Optional[str] = None
    """The key of the input tensor."""
    output_key: Optional[str] = None
    """The key of the output tensor."""
    squeeze: bool = False
    """Whether to remove the last axis of the output tensor if it is of dimension 1."""

    FID: ClassVar[str] = "GATED_PERCEPTRON"

    @nn.compact
    def __call__(self, inputs):
        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            x = inputs
        else:
            x = inputs[self.input_key]

        # activation = (
        #     activation_from_str(self.activation)
        #     if isinstance(self.activation, str)
        #     else self.activation
        # )
        kernel_init = (
            initializer_from_str(self.kernel_init)
            if isinstance(self.kernel_init, str)
            else self.kernel_init
        )
        ############################
        gate = jax.nn.sigmoid(
            nn.Dense(self.dim, use_bias=self.use_bias, kernel_init=kernel_init)(x)
        )
        x = gate * activation_from_str(self.activation)(
            nn.Dense(self.dim, use_bias=self.use_bias, kernel_init=kernel_init)(x)
        )

        if self.squeeze and x.shape[-1] == 1:
            x = jnp.squeeze(x, axis=-1)
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: x} if output_key is not None else x
        return x

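# Illustrative usage sketch (documentation example, not part of the original
# module): the output is sigmoid(W_g x) * activation(W x), i.e. a learned gate
# modulating a standard dense layer. Sizes below are arbitrary.
def _example_gated_perceptron():
    layer = GatedPerceptron(dim=8)
    x = jnp.ones((5, 16))
    params = layer.init(jax.random.PRNGKey(0), x)
    y = layer.apply(params, x)  # shape (5, 8)
    return y
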
class ZAcNet(nn.Module):
    """A fully connected neural network module with affine Z-dependent adjustments of activations.

    FID: ZACNET
    """

    neurons: Sequence[int]
    """A sequence of integers representing the dimensions of the network."""
    zmax: int = 86
    """The maximum atomic number to consider."""
    activation: Union[Callable, str] = "silu"
    """The activation function to use."""
    use_bias: bool = True
    """Whether to use bias in the dense layers."""
    input_key: Optional[str] = None
    """The key of the input tensor."""
    output_key: Optional[str] = None
    """The key of the output tensor."""
    squeeze: bool = False
    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization method to use."""
    species_key: str = "species"
    """The key of the species tensor."""

    FID: ClassVar[str] = "ZACNET"

    @nn.compact
    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            species, x = inputs
        else:
            species, x = inputs[self.species_key], inputs[self.input_key]

        # activation = (
        #     activation_from_str(self.activation)
        #     if isinstance(self.activation, str)
        #     else self.activation
        # )
        kernel_init = (
            initializer_from_str(self.kernel_init)
            if isinstance(self.kernel_init, str)
            else self.kernel_init
        )
        ############################
        for i, d in enumerate(self.neurons[:-1]):
            x = nn.Dense(
                d, use_bias=self.use_bias, name=f"Layer_{i+1}", kernel_init=kernel_init
            )(x)
            sig = self.param(
                f"sig_{i+1}",
                lambda key, shape: jnp.ones(shape, dtype=x.dtype),
                (self.zmax + 2, d),
            )[species]
            if self.use_bias:
                b = self.param(
                    f"b_{i+1}",
                    lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
                    (self.zmax + 2, d),
                )[species]
            else:
                b = 0
            x = activation_from_str(self.activation)(sig * x + b)
        x = nn.Dense(
            self.neurons[-1],
            use_bias=self.use_bias,
            name=f"Layer_{len(self.neurons)}",
            kernel_init=kernel_init,
        )(x)
        sig = self.param(
            f"sig_{len(self.neurons)}",
            lambda key, shape: jnp.ones(shape, dtype=x.dtype),
            (self.zmax + 2, self.neurons[-1]),
        )[species]
        if self.use_bias:
            b = self.param(
                f"b_{len(self.neurons)}",
                lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
                (self.zmax + 2, self.neurons[-1]),
            )[species]
        else:
            b = 0
        x = sig * x + b
        if self.squeeze and x.shape[-1] == 1:
            x = jnp.squeeze(x, axis=-1)
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: x} if output_key is not None else x
        return x

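# Illustrative usage sketch (documentation example, not part of the original
# module): after each dense layer, the activations are rescaled and shifted by
# per-species parameters sig[Z] and b[Z]. Values below are arbitrary.
def _example_zacnet():
    net = ZAcNet(neurons=[32, 1], squeeze=True)
    species = jnp.array([1, 6, 8])  # H, C, O atomic numbers
    x = jnp.ones((3, 16))
    params = net.init(jax.random.PRNGKey(0), (species, x))
    y = net.apply(params, (species, x))  # shape (3,)
    return y
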
class ZLoRANet(nn.Module):
    """A fully connected neural network module with Z-dependent low-rank adaptation.

    FID: ZLORANET
    """

    neurons: Sequence[int]
    """A sequence of integers representing the dimensions of the network."""
    ranks: Sequence[int]
    """A sequence of integers representing the ranks of the low-rank adaptation at each layer."""
    zmax: int = 86
    """The maximum atomic number to consider."""
    activation: Union[Callable, str] = "silu"
    """The activation function to use."""
    use_bias: bool = True
    """Whether to use bias in the dense layers."""
    input_key: Optional[str] = None
    """The key of the input tensor."""
    output_key: Optional[str] = None
    """The key of the output tensor."""
    squeeze: bool = False
    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization method to use."""
    species_key: str = "species"
    """The key of the species tensor."""

    FID: ClassVar[str] = "ZLORANET"

    @nn.compact
    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            species, x = inputs
        else:
            species, x = inputs[self.species_key], inputs[self.input_key]

        # activation = (
        #     activation_from_str(self.activation)
        #     if isinstance(self.activation, str)
        #     else self.activation
        # )
        kernel_init = (
            initializer_from_str(self.kernel_init)
            if isinstance(self.kernel_init, str)
            else self.kernel_init
        )
        ############################
        for i, d in enumerate(self.neurons[:-1]):
            xi = nn.Dense(
                d, use_bias=self.use_bias, name=f"Layer_{i+1}", kernel_init=kernel_init
            )(x)
            A = self.param(
                f"A_{i+1}",
                lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
                (self.zmax + 2, self.ranks[i], x.shape[-1]),
            )[species]
            B = self.param(
                f"B_{i+1}",
                lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
                (self.zmax + 2, d, self.ranks[i]),
            )[species]
            Ax = jnp.einsum("zrd,zd->zr", A, x)
            BAx = jnp.einsum("zrd,zd->zr", B, Ax)
            x = activation_from_str(self.activation)(xi + BAx)
        xi = nn.Dense(
            self.neurons[-1],
            use_bias=self.use_bias,
            name=f"Layer_{len(self.neurons)}",
            kernel_init=kernel_init,
        )(x)
        A = self.param(
            f"A_{len(self.neurons)}",
            lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
            (self.zmax + 2, self.ranks[-1], x.shape[-1]),
        )[species]
        B = self.param(
            f"B_{len(self.neurons)}",
            lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
            (self.zmax + 2, self.neurons[-1], self.ranks[-1]),
        )[species]
        Ax = jnp.einsum("zrd,zd->zr", A, x)
        BAx = jnp.einsum("zrd,zd->zr", B, Ax)
        x = xi + BAx
        if self.squeeze and x.shape[-1] == 1:
            x = jnp.squeeze(x, axis=-1)
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: x} if output_key is not None else x
        return x

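# Illustrative usage sketch (documentation example, not part of the original
# module): each layer computes Dense(x) + B[Z] @ (A[Z] @ x), a per-species
# low-rank correction of rank `ranks[i]`. Values below are arbitrary.
def _example_zloranet():
    net = ZLoRANet(neurons=[32, 1], ranks=[4, 4], squeeze=True)
    species = jnp.array([1, 6, 8])
    x = jnp.ones((3, 16))
    params = net.init(jax.random.PRNGKey(0), (species, x))
    y = net.apply(params, (species, x))  # shape (3,)
    return y
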
class BlockIndexNet(nn.Module):
    """Chemical-block-specific neural network using a precomputed block index.

    FID: BLOCK_INDEX_NET

    A neural network that applies a block-specific fully connected network to each atom embedding.
    A block index must be provided to filter the embeddings for each block and apply the corresponding network.
    This index can be obtained using the BLOCK_INDEXER preprocessing module from `fennol.models.preprocessing.BlockIndexer`.

    """

    output_dim: int
    """The dimension of the output of the fully connected networks."""
    hidden_neurons: Sequence[int]
    """The hidden dimensions of the fully connected networks. The same dimensions are used for all blocks."""
    used_blocks: Optional[Sequence[str]] = None
    """The blocks to use. If None, all blocks will be used."""
    activation: Union[Callable, str] = "silu"
    """The activation function to use in the fully connected networks."""
    use_bias: bool = True
    """Whether to include bias terms in the fully connected networks."""
    input_key: Optional[str] = None
    """The key in the input dictionary that corresponds to the embeddings of the atoms."""
    block_index_key: str = "block_index"
    """The key in the input dictionary that corresponds to the block index of the atoms. See `fennol.models.preprocessing.BlockIndexer`."""
    output_key: Optional[str] = None
    """The key in the output dictionary that corresponds to the network's output."""

    squeeze: bool = False
    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization method for the fully connected networks."""
    # check_unhandled: bool = True

    FID: ClassVar[str] = "BLOCK_INDEX_NET"

    # def setup(self):
    #     all_blocks = CHEMICAL_BLOCKS_NAMES
    #     if self.used_blocks is None:
    #         used_blocks = all_blocks
    #     else:
    #         used_blocks = []
    #         for b in self.used_blocks:
    #             b_ = str(b).strip().upper()
    #             if b_ not in all_blocks:
    #                 raise ValueError(f"Block {b} not found in {all_blocks}")
    #             used_blocks.append(b_)
    #     used_blocks = set(used_blocks)
    #     self._used_blocks = used_blocks

    #     if not (
    #         isinstance(self.hidden_neurons, dict)
    #         or isinstance(self.hidden_neurons, FrozenDict)
    #     ):
    #         neurons = {k: self.hidden_neurons for k in used_blocks}
    #     else:
    #         neurons = {}
    #         for b in self.hidden_neurons.keys():
    #             b_ = str(b).strip().upper()
    #             if b_ not in all_blocks:
    #                 raise ValueError(f"Block {b} does not exist. Available blocks are {all_blocks}")
    #             neurons[b_] = self.hidden_neurons[b]
    #         used_blocks = set(neurons.keys())
    #         if used_blocks != self._used_blocks and self.used_blocks is not None:
    #             print(
    #                 f"Warning: hidden neurons definitions do not match specified used_blocks {self.used_blocks}. Using blocks defined in hidden_neurons.")
    #         self._used_blocks = used_blocks

    #     self.networks = {
    #         k: FullyConnectedNet(
    #             [*neurons[k], self.output_dim],
    #             self.activation,
    #             self.use_bias,
    #             name=k,
    #             kernel_init=self.kernel_init,
    #         )
    #         for k in self._used_blocks
    #     }

    @nn.compact
    def __call__(
        self, inputs: Union[dict, Tuple[jax.Array, jax.Array]]
    ) -> Union[dict, jax.Array]:

        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            species, embedding, block_index = inputs
        else:
            species, embedding = inputs["species"], inputs[self.input_key]
            block_index = inputs[self.block_index_key]

        assert isinstance(
            block_index, dict
        ), "block_index must be a dictionary for BlockIndexNet"

        networks = {
            k: FullyConnectedNet(
                [*self.hidden_neurons, self.output_dim],
                self.activation,
                self.use_bias,
                name=k,
                kernel_init=self.kernel_init,
            )
            for k in block_index.keys()
        }

        ############################
        # initialization => instantiate all networks
        if self.is_initializing():
            x = jnp.zeros((1, embedding.shape[-1]), dtype=embedding.dtype)
            [net(x) for net in networks.values()]

        # if self.check_unhandled:
        #     for b in block_index.keys():
        #         if b not in networks.keys():
        #             raise ValueError(f"Block {b} not found in networks. Available blocks are {self.networks.keys()}")

        ############################
        outputs = []
        indices = []
        for s, net in networks.items():
            if s not in block_index:
                continue
            if block_index[s] is None:
                continue
            idx = block_index[s]
            o = net(embedding[idx])
            outputs.append(o)
            indices.append(idx)

        o = jnp.concatenate(outputs, axis=0)
        idx = jnp.concatenate(indices, axis=0)

        out = (
            jnp.zeros((species.shape[0], *o.shape[1:]), dtype=o.dtype)
            .at[idx]
            .set(o, mode="drop")
        )

        if self.squeeze and out.shape[-1] == 1:
            out = jnp.squeeze(out, axis=-1)
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: out} if output_key is not None else out
        return out

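# Illustrative usage sketch (documentation example, not part of the original
# module): the block index (normally produced by the BlockIndexer preprocessing
# module) maps block names to atom indices. Keys and values below are arbitrary.
def _example_block_index_net():
    net = BlockIndexNet(output_dim=1, hidden_neurons=[32, 32])
    species = jnp.array([1, 1, 8, 8])
    embedding = jnp.ones((4, 16))
    block_index = {"S": jnp.array([0, 1]), "P": jnp.array([2, 3])}
    inputs = (species, embedding, block_index)
    params = net.init(jax.random.PRNGKey(0), inputs)
    y = net.apply(params, inputs)  # shape (4, 1)
    return y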
class FullyConnectedNet(flax.linen.module.Module):
14class FullyConnectedNet(nn.Module):
15    """A fully connected neural network module.
16
17    FID: NEURAL_NET
18    """
19
20    neurons: Sequence[int]
21    """A sequence of integers representing the dimensions of the network."""
22    activation: Union[Callable, str] = "silu"
23    """The activation function to use."""
24    use_bias: bool = True
25    """Whether to use bias in the dense layers."""
26    input_key: Optional[str] = None
27    """The key of the input tensor."""
28    output_key: Optional[str] = None
29    """The key of the output tensor."""
30    squeeze: bool = False
31    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
32    kernel_init: Union[str, Callable] = "lecun_normal()"
33    """The kernel initialization method to use."""
34
35    FID: ClassVar[str] = "NEURAL_NET"
36
37    @nn.compact
38    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
39        """Applies the neural network to the given inputs.
40
41        Args:
42            inputs (Union[dict, jax.Array]): If input_key is None, a JAX array containing the inputs to the neural network. Else, a dictionary containing the inputs at the key input_key.
43
44        Returns:
45            Union[dict, jax.Array]: If output_key is None, the output tensor of the neural network. Else, a dictionary containing the original inputs and the output tensor at the key output_key.
46        """
47        if self.input_key is None:
48            assert not isinstance(
49                inputs, dict
50            ), "input key must be provided if inputs is a dictionary"
51            x = inputs
52        else:
53            x = inputs[self.input_key]
54
55        # activation = (
56        #     activation_from_str(self.activation)
57        #     if isinstance(self.activation, str)
58        #     else self.activation
59        # )
60        kernel_init = (
61            initializer_from_str(self.kernel_init)
62            if isinstance(self.kernel_init, str)
63            else self.kernel_init
64        )
65        ############################
66        for i, d in enumerate(self.neurons[:-1]):
67            x = nn.Dense(
68                d,
69                use_bias=self.use_bias,
70                name=f"Layer_{i+1}",
71                kernel_init=kernel_init,
72                # precision=jax.lax.Precision.HIGH,
73            )(x)
74            x = activation_from_str(self.activation)(x)
75        x = nn.Dense(
76            self.neurons[-1],
77            use_bias=self.use_bias,
78            name=f"Layer_{len(self.neurons)}",
79            kernel_init=kernel_init,
80            # precision=jax.lax.Precision.HIGH,
81        )(x)
82        if self.squeeze and x.shape[-1] == 1:
83            x = jnp.squeeze(x, axis=-1)
84        ############################
85
86        if self.input_key is not None:
87            output_key = self.name if self.output_key is None else self.output_key
88            return {**inputs, output_key: x} if output_key is not None else x
89        return x

A fully connected neural network module.

FID: NEURAL_NET

FullyConnectedNet( neurons: Sequence[int], activation: Union[Callable, str] = 'silu', use_bias: bool = True, input_key: Optional[str] = None, output_key: Optional[str] = None, squeeze: bool = False, kernel_init: Union[str, Callable] = 'lecun_normal()', parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
neurons: Sequence[int]

A sequence of integers representing the dimensions of the network.

activation: Union[Callable, str] = 'silu'

The activation function to use.

use_bias: bool = True

Whether to use bias in the dense layers.

input_key: Optional[str] = None

The key of the input tensor.

output_key: Optional[str] = None

The key of the output tensor.

squeeze: bool = False

Whether to remove the last axis of the output tensor if it is of dimension 1.

kernel_init: Union[str, Callable] = 'lecun_normal()'

The kernel initialization method to use.

FID: ClassVar[str] = 'NEURAL_NET'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class ResMLP(flax.linen.module.Module):
 92class ResMLP(nn.Module):
 93    """Residual neural network as defined in the SpookyNet paper.
 94    
 95    FID: RES_MLP
 96    """
 97
 98    use_bias: bool = True
 99    """Whether to include bias in the linear layers."""
100    input_key: Optional[str] = None
101    """The key of the input tensor."""
102    output_key: Optional[str] = None
103    """The key of the output tensor."""
104
105    kernel_init: Union[str, Callable] = "scaled_orthogonal(mode='fan_avg')"
106    """The kernel initialization method to use."""
107    res_only: bool = False
108    """Whether to only apply the residual connection without additional activation and linear layer."""
109
110    FID: ClassVar[str] = "RES_MLP"
111
112    @nn.compact
113    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
114        if self.input_key is None:
115            assert not isinstance(
116                inputs, dict
117            ), "input key must be provided if inputs is a dictionary"
118            x = inputs
119        else:
120            x = inputs[self.input_key]
121
122        kernel_init = (
123            initializer_from_str(self.kernel_init)
124            if isinstance(self.kernel_init, str)
125            else self.kernel_init
126        )
127        ############################
128        out = nn.Dense(x.shape[-1], use_bias=self.use_bias, kernel_init=kernel_init)(
129            TrainableSiLU()(x)
130        )
131        out = x + nn.Dense(
132            x.shape[-1], use_bias=self.use_bias, kernel_init=nn.initializers.zeros
133        )(TrainableSiLU()(out))
134
135        if not self.res_only:
136            out = nn.Dense(
137                x.shape[-1], use_bias=self.use_bias, kernel_init=kernel_init
138            )(TrainableSiLU()(out))
139        ############################
140
141        if self.input_key is not None:
142            output_key = self.name if self.output_key is None else self.output_key
143            return {**inputs, output_key: out} if output_key is not None else out
144        return out

Residual neural network as defined in the SpookyNet paper.

FID: RES_MLP

ResMLP( use_bias: bool = True, input_key: Optional[str] = None, output_key: Optional[str] = None, kernel_init: Union[str, Callable] = "scaled_orthogonal(mode='fan_avg')", res_only: bool = False, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
use_bias: bool = True

Whether to include bias in the linear layers.

input_key: Optional[str] = None

The key of the input tensor.

output_key: Optional[str] = None

The key of the output tensor.

kernel_init: Union[str, Callable] = "scaled_orthogonal(mode='fan_avg')"

The kernel initialization method to use.

res_only: bool = False

Whether to only apply the residual connection without additional activation and linear layer.

FID: ClassVar[str] = 'RES_MLP'
parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType]

Wraps parent module references in weak refs.

This prevents reference cycles from forming via parent links which can lead to accidental OOMs in eager mode due to slow garbage collection as well as spurious tracer leaks during jit compilation.

Note: "descriptors" are the underlying python mechanism for implementing dynamic @property decorators. We need to use a raw descriptor instead of the more common decorator in order to force that the appropriate getter/setter logic applies in subclasses even after various dataclass transforms.

name: Optional[str] = None
scope = None
class FullyResidualNet(flax.linen.module.Module):
147class FullyResidualNet(nn.Module):
148    """A neural network with skip connections at each layer.
149    
150    FID: SKIP_NET
151    """
152
153    dim: int
154    """The dimension of the hidden layers."""
155    output_dim: int
156    """The dimension of the output layer."""
157    nlayers: int
158    """The number of layers in the network."""
159    activation: Union[Callable, str] = "silu"
160    """The activation function to use."""
161    use_bias: bool = True
162    """Whether to include bias terms in the linear layers."""
163    input_key: Optional[str] = None
164    """The key of the input tensor."""
165    output_key: Optional[str] = None
166    """The key of the output tensor."""
167    squeeze: bool = False
168    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
169    kernel_init: Union[str, Callable] = "lecun_normal()"
170    """The kernel initialization method to use."""
171
172    FID: ClassVar[str] = "SKIP_NET"
173
174    @nn.compact
175    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
176        if self.input_key is None:
177            assert not isinstance(
178                inputs, dict
179            ), "input key must be provided if inputs is a dictionary"
180            x = inputs
181        else:
182            x = inputs[self.input_key]
183
184        # activation = (
185        #     activation_from_str(self.activation)
186        #     if isinstance(self.activation, str)
187        #     else self.activation
188        # )
189        kernel_init = (
190            initializer_from_str(self.kernel_init)
191            if isinstance(self.kernel_init, str)
192            else self.kernel_init
193        )
194        ############################
195        if x.shape[-1] != self.dim:
196            x = nn.Dense(
197                self.dim,
198                use_bias=self.use_bias,
199                name=f"Reshape",
200                kernel_init=kernel_init,
201            )(x)
202
203        for i in range(self.nlayers - 1):
204            x = x + activation_from_str(self.activation)(
205                nn.Dense(
206                    self.dim,
207                    use_bias=self.use_bias,
208                    name=f"Layer_{i+1}",
209                    kernel_init=kernel_init,
210                )(x)
211            )
212        x = nn.Dense(
213            self.output_dim,
214            use_bias=self.use_bias,
215            name=f"Layer_{self.nlayers}",
216            kernel_init=kernel_init,
217        )(x)
218        if self.squeeze and x.shape[-1] == 1:
219            x = jnp.squeeze(x, axis=-1)
220        ############################
221
222        if self.input_key is not None:
223            output_key = self.name if self.output_key is None else self.output_key
224            return {**inputs, output_key: x} if output_key is not None else x
225        return x

A neural network with skip connections at each layer.

FID: SKIP_NET

FullyResidualNet( dim: int, output_dim: int, nlayers: int, activation: Union[Callable, str] = 'silu', use_bias: bool = True, input_key: Optional[str] = None, output_key: Optional[str] = None, squeeze: bool = False, kernel_init: Union[str, Callable] = 'lecun_normal()', parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
dim: int

The dimension of the hidden layers.

output_dim: int

The dimension of the output layer.

nlayers: int

The number of layers in the network.

activation: Union[Callable, str] = 'silu'

The activation function to use.

use_bias: bool = True

Whether to include bias terms in the linear layers.

input_key: Optional[str] = None

The key of the input tensor.

output_key: Optional[str] = None

The key of the output tensor.

squeeze: bool = False

Whether to remove the last axis of the output tensor if it is of dimension 1.

kernel_init: Union[str, Callable] = 'lecun_normal()'

The kernel initialization method to use.

FID: ClassVar[str] = 'SKIP_NET'
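Below is a minimal usage sketch for FullyResidualNet; it is not taken from the library's documentation, and the shapes and hyper-parameters are illustrative only:

    import jax
    import jax.numpy as jnp
    from fennol.models.misc.nets import FullyResidualNet

    # input is first projected to dim=64, then passed through nlayers-1 = 2
    # residual hidden layers, then a linear output layer of size output_dim=1
    net = FullyResidualNet(dim=64, output_dim=1, nlayers=3, squeeze=True)
    x = jnp.ones((10, 32))                       # e.g. 10 atoms with 32 features
    params = net.init(jax.random.PRNGKey(0), x)
    y = net.apply(params, x)                     # shape (10,) after squeeze
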
class HierarchicalNet(flax.linen.module.Module):
228class HierarchicalNet(nn.Module):
229    """Neural network for a sequence of inputs (in axis=-2) with a decay factor
230    
231    FID: HIERARCHICAL_NET
232    """
233
234    neurons: Sequence[int]
235    """A sequence of integers representing the number of neurons in each layer."""
236    activation: Union[Callable, str] = "silu"
237    """The activation function to use."""
238    use_bias: bool = True
239    """Whether to include bias terms in the linear layers."""
240    input_key: Optional[str] = None
241    """The key of the input tensor."""
242    output_key: Optional[str] = None
243    """The key of the output tensor."""
244    decay: float = 0.01
245    """The decay factor to scale each element of the sequence."""
246    squeeze: bool = False
247    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
248    kernel_init: Union[str, Callable] = "lecun_normal()"
249    """The kernel initialization method to use."""
250
251    FID: ClassVar[str] = "HIERARCHICAL_NET"
252
253    @nn.compact
254    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
255        if self.input_key is None:
256            assert not isinstance(
257                inputs, dict
258            ), "input key must be provided if inputs is a dictionary"
259            x = inputs
260        else:
261            x = inputs[self.input_key]
262
263        ############################
264        networks = nn.vmap(
265            FullyConnectedNet,
266            variable_axes={"params": 0},
267            split_rngs={"params": True},
268            in_axes=-2,
269            out_axes=-2,
270            # note: kernel_init is forwarded to FullyConnectedNet, not to nn.vmap
271        )(self.neurons, self.activation, self.use_bias, kernel_init=self.kernel_init)
272
273        out = networks(x)
274        # scale each layer by a decay factor
275        decay = jnp.asarray([self.decay**i for i in range(out.shape[-2])])
276        out = out * decay[..., :, None]
277
278        if self.squeeze and out.shape[-1] == 1:
279            out = jnp.squeeze(out, axis=-1)
280        ############################
281
282        if self.input_key is not None:
283            output_key = self.name if self.output_key is None else self.output_key
284            return {**inputs, output_key: out} if output_key is not None else out
285        return out

Neural network for a sequence of inputs (in axis=-2) with a decay factor

FID: HIERARCHICAL_NET

HierarchicalNet( neurons: Sequence[int], activation: Union[Callable, str] = 'silu', use_bias: bool = True, input_key: Optional[str] = None, output_key: Optional[str] = None, decay: float = 0.01, squeeze: bool = False, kernel_init: Union[str, Callable] = 'lecun_normal()', parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
neurons: Sequence[int]

A sequence of integers representing the number of neurons in each layer.

activation: Union[Callable, str] = 'silu'

The activation function to use.

use_bias: bool = True

Whether to include bias terms in the linear layers.

input_key: Optional[str] = None

The key of the input tensor.

output_key: Optional[str] = None

The key of the output tensor.

decay: float = 0.01

The decay factor to scale each element of the sequence.

squeeze: bool = False

Whether to remove the last axis of the output tensor if it is of dimension 1.

kernel_init: Union[str, Callable] = 'lecun_normal()'

The kernel initialization method to use.

FID: ClassVar[str] = 'HIERARCHICAL_NET'
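A minimal usage sketch for HierarchicalNet, assuming an input that carries a sequence along axis=-2 (the shapes and decay value are illustrative):

    import jax
    import jax.numpy as jnp
    from fennol.models.misc.nets import HierarchicalNet

    net = HierarchicalNet(neurons=[32, 16, 1], decay=0.1)
    x = jnp.ones((10, 4, 8))                     # 10 atoms, sequence of 4 elements, 8 features each
    params = net.init(jax.random.PRNGKey(0), x)  # one FullyConnectedNet per sequence element
    out = net.apply(params, x)                   # shape (10, 4, 1); element i is scaled by 0.1**i
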
class SpeciesIndexNet(flax.linen.module.Module):
288class SpeciesIndexNet(nn.Module):
289    """Chemical-species-specific neural network using precomputed species index.
290
291    FID: SPECIES_INDEX_NET
292
293    A neural network that applies a species-specific fully connected network to each atom embedding.
294    A species index must be provided to filter the embeddings for each species and apply the corresponding network.
295    This index can be obtained using the SPECIES_INDEXER preprocessing module from `fennol.models.preprocessing.SpeciesIndexer`
296
297    """
298
299    output_dim: int
300    """The dimension of the output of the fully connected networks."""
301    hidden_neurons: Union[dict, FrozenDict, Sequence[int]]
302    """The hidden dimensions of the fully connected networks.
303        If a dictionary is provided, it should map species names to dimensions.
304        If a sequence is provided, the same dimensions will be used for all species."""
305    species_order: Optional[Union[str, Sequence[str]]] = None
306    """The species for which to build a network. Only required if neurons is not a dictionary."""
307    activation: Union[Callable, str] = "silu"
308    """The activation function to use in the fully connected networks."""
309    use_bias: bool = True
310    """Whether to include bias terms in the fully connected networks."""
311    input_key: Optional[str] = None
312    """The key in the input dictionary that corresponds to the embeddings of the atoms."""
313    species_index_key: str = "species_index"
314    """The key in the input dictionary that corresponds to the species index of the atoms. See `fennol.models.preprocessing.SpeciesIndexer`"""
315    output_key: Optional[str] = None
316    """The key in the output dictionary that corresponds to the network's output."""
317
318    squeeze: bool = False
319    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
320    kernel_init: Union[str, Callable] = "lecun_normal()"
321    """The kernel initialization method for the fully connected networks."""
322    check_unhandled: bool = True
323
324    FID: ClassVar[str] = "SPECIES_INDEX_NET"
325
326    def setup(self):
327        if not (
328            isinstance(self.hidden_neurons, dict)
329            or isinstance(self.hidden_neurons, FrozenDict)
330        ):
331            assert (
332                self.species_order is not None
333            ), "species_order must be provided if hidden_neurons is not a dictionary"
334            if isinstance(self.species_order, str):
335                species_order = [el.strip() for el in self.species_order.split(",")]
336            else:
337                species_order = [el for el in self.species_order]
338            neurons = {k: self.hidden_neurons for k in species_order}
339        else:
340            neurons = self.hidden_neurons
341            species_order = list(neurons.keys())
342        for species in species_order:
343            assert (
344                species in PERIODIC_TABLE
345            ), f"species {species} not found in periodic table"
346
347        self.networks = {
348            k: FullyConnectedNet(
349                [*neurons[k], self.output_dim],
350                self.activation,
351                self.use_bias,
352                name=k,
353                kernel_init=self.kernel_init,
354            )
355            for k in species_order
356        }
357
358    def __call__(
359        self, inputs: Union[dict, Tuple[jax.Array, jax.Array]]
360    ) -> Union[dict, jax.Array]:
361
362        if self.input_key is None:
363            assert not isinstance(
364                inputs, dict
365            ), "input key must be provided if inputs is a dictionary"
366            species, embedding, species_index = inputs
367        else:
368            species, embedding = inputs["species"], inputs[self.input_key]
369            species_index = inputs[self.species_index_key]
370
371        assert isinstance(
372            species_index, dict
373        ), "species_index must be a dictionary for SpeciesIndexNet"
374
375        ############################
376        # initialization => instantiate all networks
377        if self.is_initializing():
378            x = jnp.zeros((1, embedding.shape[-1]), dtype=embedding.dtype)
379            [net(x) for net in self.networks.values()]
380
381        if self.check_unhandled:
382            for b in species_index.keys():
383                if b not in self.networks.keys():
384                    raise ValueError(f"Species {b} not found in networks. Handled species are {self.networks.keys()}")
385
386        ############################
387        outputs = []
388        indices = []
389        for s, net in self.networks.items():
390            if s not in species_index:
391                continue
392            idx = species_index[s]
393            o = net(embedding[idx])
394            outputs.append(o)
395            indices.append(idx)
396
397        o = jnp.concatenate(outputs, axis=0)
398        idx = jnp.concatenate(indices, axis=0)
399
400        out = (
401            jnp.zeros((species.shape[0], *o.shape[1:]), dtype=o.dtype)
402            .at[idx]
403            .set(o, mode="drop")
404        )
405
406        if self.squeeze and out.shape[-1] == 1:
407            out = jnp.squeeze(out, axis=-1)
408        ############################
409
410        if self.input_key is not None:
411            output_key = self.name if self.output_key is None else self.output_key
412            return {**inputs, output_key: out} if output_key is not None else out
413        return out

Chemical-species-specific neural network using precomputed species index.

FID: SPECIES_INDEX_NET

A neural network that applies a species-specific fully connected network to each atom embedding. A species index must be provided to filter the embeddings for each species and apply the corresponding network. This index can be obtained using the SPECIES_INDEXER preprocessing module from fennol.models.preprocessing.SpeciesIndexer

SpeciesIndexNet( output_dim: int, hidden_neurons: Union[dict, flax.core.frozen_dict.FrozenDict, Sequence[int]], species_order: Union[str, Sequence[str], NoneType] = None, activation: Union[Callable, str] = 'silu', use_bias: bool = True, input_key: Optional[str] = None, species_index_key: str = 'species_index', output_key: Optional[str] = None, squeeze: bool = False, kernel_init: Union[str, Callable] = 'lecun_normal()', check_unhandled: bool = True, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
output_dim: int

The dimension of the output of the fully connected networks.

hidden_neurons: Union[dict, flax.core.frozen_dict.FrozenDict, Sequence[int]]

The hidden dimensions of the fully connected networks. If a dictionary is provided, it should map species names to dimensions. If a sequence is provided, the same dimensions will be used for all species.

species_order: Union[str, Sequence[str], NoneType] = None

The species for which to build a network. Only required if neurons is not a dictionary.

activation: Union[Callable, str] = 'silu'

The activation function to use in the fully connected networks.

use_bias: bool = True

Whether to include bias terms in the fully connected networks.

input_key: Optional[str] = None

The key in the input dictionary that corresponds to the embeddings of the atoms.

species_index_key: str = 'species_index'

The key in the input dictionary that corresponds to the species index of the atoms. See fennol.models.preprocessing.SpeciesIndexer

output_key: Optional[str] = None

The key in the output dictionary that corresponds to the network's output.

squeeze: bool = False

Whether to remove the last axis of the output tensor if it is of dimension 1.

kernel_init: Union[str, Callable] = 'lecun_normal()'

The kernel initialization method for the fully connected networks.

check_unhandled: bool = True
FID: ClassVar[str] = 'SPECIES_INDEX_NET'
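A minimal usage sketch for SpeciesIndexNet called directly with a (species, embedding, species_index) tuple; the species_index dictionary below is hand-built for illustration and would normally be produced by the SpeciesIndexer preprocessing module:

    import jax
    import jax.numpy as jnp
    from fennol.models.misc.nets import SpeciesIndexNet

    net = SpeciesIndexNet(output_dim=1, hidden_neurons=[32, 32],
                          species_order=["H", "O"], squeeze=True)
    species = jnp.array([1, 1, 8])                                 # atomic numbers
    embedding = jnp.ones((3, 16))                                  # one embedding per atom
    species_index = {"H": jnp.array([0, 1]), "O": jnp.array([2])}  # illustrative index
    inputs = (species, embedding, species_index)
    params = net.init(jax.random.PRNGKey(0), inputs)
    out = net.apply(params, inputs)                                # per-atom outputs, shape (3,)
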
class ChemicalNet(flax.linen.module.Module):
416class ChemicalNet(nn.Module):
417    """Optimized chemical-species-specific neural network.
418
419    FID: CHEMICAL_NET
420
421    A neural network that applies a fully connected network to each atom embedding in a chemical system and selects the output corresponding to the atom's species.
422    This is an optimized version of ChemicalNetHet that uses vmap to apply the networks to all atoms at once.
423    The optimization is allowed because all networks have the same shape.
424
425    """
426
427    species_order: Union[str, Sequence[str]]
428    """The species for which to build a network."""
429    neurons: Sequence[int]
430    """The dimensions of the fully connected networks."""
431    activation: Union[Callable, str] = "silu"
432    """The activation function to use in the fully connected networks."""
433    use_bias: bool = True
434    """Whether to include bias terms in the fully connected networks."""
435    input_key: Optional[str] = None
436    """The key in the input dictionary that corresponds to the embeddings of the atoms."""
437    output_key: Optional[str] = None
438    """The key in the output dictionary that corresponds to the network's output."""
439    squeeze: bool = False
440    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
441    kernel_init: Union[str, Callable] = "lecun_normal()"
442    """The kernel initialization method for the fully connected networks."""
443
444    FID: ClassVar[str] = "CHEMICAL_NET"
445
446    @nn.compact
447    def __call__(
448        self, inputs: Union[dict, Tuple[jax.Array, jax.Array]]
449    ) -> Union[dict, jax.Array]:
450        if self.input_key is None:
451            assert not isinstance(
452                inputs, dict
453            ), "input key must be provided if inputs is a dictionary"
454            species, embedding = inputs
455        else:
456            species, embedding = inputs["species"], inputs[self.input_key]
457
458        ############################
459        # build species to network index mapping (static => fixed when jitted)
460        rev_idx = PERIODIC_TABLE_REV_IDX
461        maxidx = max(rev_idx.values())
462        if isinstance(self.species_order, str):
463            species_order = [el.strip() for el in self.species_order.split(",")]
464        else:
465            species_order = [el for el in self.species_order]
466        nspecies = len(species_order)
467        conv_tensor_ = np.full((maxidx + 2,), -1, dtype=np.int32)
468        for i, s in enumerate(species_order):
469            conv_tensor_[rev_idx[s]] = i
470        conv_tensor = jnp.asarray(conv_tensor_)
471        indices = conv_tensor[species]
472
473        ############################
474        # build shape-sharing networks using vmap
475        networks = nn.vmap(
476            FullyConnectedNet,
477            variable_axes={"params": 0},
478            split_rngs={"params": True},
479            in_axes=0,
480        )(self.neurons, self.activation, self.use_bias, kernel_init=self.kernel_init)
481        # repeat input along a new axis to compute for all species at once
482        x = jnp.broadcast_to(
483            embedding[None, :, :], (nspecies, *embedding.shape)
484        )
485
486        # apply networks to input and select the output corresponding to the species
487        out = jnp.squeeze(
488            jnp.take_along_axis(networks(x), indices[None, :, None], axis=0), axis=0
489        )
490
491        out = jnp.where((indices >= 0)[:, None], out, 0.0)
492        if self.squeeze and out.shape[-1] == 1:
493            out = jnp.squeeze(out, axis=-1)
494        ############################
495
496        if self.input_key is not None:
497            output_key = self.name if self.output_key is None else self.output_key
498            return {**inputs, output_key: out} if output_key is not None else out
499        return out

Optimized chemical-species-specific neural network.

FID: CHEMICAL_NET

A neural network that applies a fully connected network to each atom embedding in a chemical system and selects the output corresponding to the atom's species. This is an optimized version of ChemicalNetHet that uses vmap to apply the networks to all atoms at once. The optimization is allowed because all networks have the same shape.

ChemicalNet( species_order: Union[str, Sequence[str]], neurons: Sequence[int], activation: Union[Callable, str] = 'silu', use_bias: bool = True, input_key: Optional[str] = None, output_key: Optional[str] = None, squeeze: bool = False, kernel_init: Union[str, Callable] = 'lecun_normal()', parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
species_order: Union[str, Sequence[str]]

The species for which to build a network.

neurons: Sequence[int]

The dimensions of the fully connected networks.

activation: Union[Callable, str] = 'silu'

The activation function to use in the fully connected networks.

use_bias: bool = True

Whether to include bias terms in the fully connected networks.

input_key: Optional[str] = None

The key in the input dictionary that corresponds to the embeddings of the atoms.

output_key: Optional[str] = None

The key in the output dictionary that corresponds to the network's output.

squeeze: bool = False

Whether to remove the last axis of the output tensor if it is of dimension 1.

kernel_init: Union[str, Callable] = 'lecun_normal()'

The kernel initialization method for the fully connected networks.

FID: ClassVar[str] = 'CHEMICAL_NET'
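A minimal usage sketch for ChemicalNet called directly with a (species, embedding) tuple (shapes are illustrative):

    import jax
    import jax.numpy as jnp
    from fennol.models.misc.nets import ChemicalNet

    net = ChemicalNet(species_order="H, O", neurons=[32, 1], squeeze=True)
    species = jnp.array([1, 8, 1])                 # atomic numbers of the atoms
    embedding = jnp.ones((3, 16))
    params = net.init(jax.random.PRNGKey(0), (species, embedding))
    out = net.apply(params, (species, embedding))  # shape (3,); one value per atom,
                                                   # taken from that atom's species network
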
class MOENet(flax.linen.module.Module):
502class MOENet(nn.Module):
503    """Mixture of Experts neural network.
504
505    FID: MOE_NET
506
507    This class represents a Mixture of Experts neural network. It applies a set of shape-sharing expert networks
508    to the input and combines their outputs using weights computed by a router network.
509
510    """
511
512    neurons: Sequence[int]
513    """A sequence of integers representing the number of neurons in each shape-sharing network."""
514    num_networks: int
515    """The number of shape-sharing networks to create."""
516    activation: Union[Callable, str] = "silu"
517    """The activation function to use in the shape-sharing networks."""
518    use_bias: bool = True
519    """Whether to include bias in the shape-sharing networks."""
520    input_key: Optional[str] = None
521    """The key of the input tensor."""
522    output_key: Optional[str] = None
523    """The key of the output tensor."""
524    squeeze: bool = False
525    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
526    
527    kernel_init: Union[str, Callable] = "lecun_normal()"
528    """The kernel initialization method to use in the shape-sharing networks."""
529    router_key: Optional[str] = None
530    """The key of the router tensor. If None, the router is assumed to be the same as the input tensor."""
531
532    FID: ClassVar[str] = "MOE_NET"
533
534    @nn.compact
535    def __call__(
536        self, inputs: Union[dict, Tuple[jax.Array, jax.Array]]
537    ) -> Union[dict, jax.Array]:
538        if self.input_key is None:
539            assert not isinstance(
540                inputs, dict
541            ), "input key must be provided if inputs is a dictionary"
542            if isinstance(inputs, tuple):
543                embedding, router = inputs
544            else:
545                embedding = router = inputs
546        else:
547            embedding = inputs[self.input_key]
548            router = (
549                inputs[self.router_key] if self.router_key is not None else embedding
550            )
551
552        ############################
553        # build shape-sharing networks using vmap
554        networks = nn.vmap(
555            FullyConnectedNet,
556            variable_axes={"params": 0},
557            split_rngs={"params": True},
558            in_axes=0,
559            out_axes=0,
560        )(self.neurons, self.activation, self.use_bias, kernel_init=self.kernel_init)
561        # repeat input along a new axis to compute for all networks at once
562        x = jnp.repeat(embedding[None, :, :], self.num_networks, axis=0)
563
564        w = nn.softmax(nn.Dense(self.num_networks, name="router")(router), axis=-1)
565
566        out = (networks(x) * w.T[:, :, None]).sum(axis=0)
567
568        if self.squeeze and out.shape[-1] == 1:
569            out = jnp.squeeze(out, axis=-1)
570        ############################
571
572        if self.input_key is not None:
573            output_key = self.name if self.output_key is None else self.output_key
574            return {**inputs, output_key: out} if output_key is not None else out
575        return out

Mixture of Experts neural network.

FID: MOE_NET

This class represents a Mixture of Experts neural network. It applies a set of shape-sharing expert networks to the input and combines their outputs using weights computed by a router network.

MOENet( neurons: Sequence[int], num_networks: int, activation: Union[Callable, str] = 'silu', use_bias: bool = True, input_key: Optional[str] = None, output_key: Optional[str] = None, squeeze: bool = False, kernel_init: Union[str, Callable] = 'lecun_normal()', router_key: Optional[str] = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
neurons: Sequence[int]

A sequence of integers representing the number of neurons in each shape-sharing network.

num_networks: int

The number of shape-sharing networks to create.

activation: Union[Callable, str] = 'silu'

The activation function to use in the shape-sharing networks.

use_bias: bool = True

Whether to include bias in the shape-sharing networks.

input_key: Optional[str] = None

The key of the input tensor.

output_key: Optional[str] = None

The key of the output tensor.

squeeze: bool = False

Whether to remove the last axis of the output tensor if it is of dimension 1.

kernel_init: Union[str, Callable] = 'lecun_normal()'

The kernel initialization method to use in the shape-sharing networks.

router_key: Optional[str] = None

The key of the router tensor. If None, the router is assumed to be the same as the input tensor.

FID: ClassVar[str] = 'MOE_NET'
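A minimal usage sketch for MOENet where the router input defaults to the input itself (shapes are illustrative):

    import jax
    import jax.numpy as jnp
    from fennol.models.misc.nets import MOENet

    net = MOENet(neurons=[32, 8], num_networks=4)
    x = jnp.ones((10, 16))                        # 10 atoms, 16 features
    params = net.init(jax.random.PRNGKey(0), x)
    out = net.apply(params, x)                    # shape (10, 8): the 4 expert outputs
                                                  # mixed by softmax router weights
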
class ChannelNet(flax.linen.module.Module):
577class ChannelNet(nn.Module):
578    """Apply a different neural network to each channel.
579    
580    FID: CHANNEL_NET
581    """
582
583    neurons: Sequence[int]
584    """A sequence of integers representing the number of neurons in each shape-sharing network."""
585    activation: Union[Callable, str] = "silu"
586    """The activation function to use in the shape-sharing networks."""
587    use_bias: bool = True
588    """Whether to include bias in the shape-sharing networks."""
589    input_key: Optional[str] = None
590    """The key of the input tensor."""
591    output_key: Optional[str] = None
592    """The key of the output tensor."""
593    squeeze: bool = False
594    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
595    kernel_init: Union[str, Callable] = "lecun_normal()"
596    """The kernel initialization method to use in the shape-sharing networks."""
597    channel_axis: int = -2
598    """The axis to use as channel. Its length will be the number of shape-sharing networks."""
599
600    FID: ClassVar[str] = "CHANNEL_NET"
601
602    @nn.compact
603    def __call__(
604        self, inputs: Union[dict, Tuple[jax.Array, jax.Array]]
605    ) -> Union[dict, jax.Array]:
606        if self.input_key is None:
607            assert not isinstance(
608                inputs, dict
609            ), "input key must be provided if inputs is a dictionary"
610            x = inputs
611        else:
612            x = inputs[self.input_key]
613
614        ############################
615        # build shape-sharing networks using vmap
616        networks = nn.vmap(
617            FullyConnectedNet,
618            variable_axes={"params": 0},
619            split_rngs={"params": True},
620            in_axes=self.channel_axis,
621            out_axes=self.channel_axis,
622        )(self.neurons, self.activation, self.use_bias, kernel_init=self.kernel_init)
623
624        out = networks(x)
625
626        if self.squeeze and out.shape[-1] == 1:
627            out = jnp.squeeze(out, axis=-1)
628        ############################
629
630        if self.input_key is not None:
631            output_key = self.name if self.output_key is None else self.output_key
632            return {**inputs, output_key: out} if output_key is not None else out
633        return out

Apply a different neural network to each channel.

FID: CHANNEL_NET

ChannelNet( neurons: Sequence[int], activation: Union[Callable, str] = 'silu', use_bias: bool = True, input_key: Optional[str] = None, output_key: Optional[str] = None, squeeze: bool = False, kernel_init: Union[str, Callable] = 'lecun_normal()', channel_axis: int = -2, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
neurons: Sequence[int]

A sequence of integers representing the number of neurons in each shape-sharing network.

activation: Union[Callable, str] = 'silu'

The activation function to use in the shape-sharing networks.

use_bias: bool = True

Whether to include bias in the shape-sharing networks.

input_key: Optional[str] = None

The key of the input tensor.

output_key: Optional[str] = None

The key of the output tensor.

squeeze: bool = False

Whether to remove the last axis of the output tensor if it is of dimension 1.

kernel_init: Union[str, Callable] = 'lecun_normal()'

The kernel initialization method to use in the shape-sharing networks.

channel_axis: int = -2

The axis to use as channel. Its length will be the number of shape-sharing networks.

FID: ClassVar[str] = 'CHANNEL_NET'
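A minimal usage sketch for ChannelNet with the default channel_axis=-2 (shapes are illustrative):

    import jax
    import jax.numpy as jnp
    from fennol.models.misc.nets import ChannelNet

    net = ChannelNet(neurons=[16, 4])
    x = jnp.ones((10, 3, 8))                      # 3 channels along axis=-2, 8 features each
    params = net.init(jax.random.PRNGKey(0), x)   # one independent network per channel
    out = net.apply(params, x)                    # shape (10, 3, 4)
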
class GatedPerceptron(flax.linen.module.Module):
636class GatedPerceptron(nn.Module):
637    """Gated Perceptron neural network.
638
639    FID: GATED_PERCEPTRON
640
641    This class represents a Gated Perceptron neural network model. It computes a sigmoid gate from a dense
642    layer applied to the input and multiplies it with the activated output of a second dense layer.
643    """
644
645    dim: int
646    """The dimensionality of the output space."""
647    use_bias: bool = True
648    """Whether to include a bias term in the dense layer."""
649    kernel_init: Union[str, Callable] = "lecun_normal()"
650    """The kernel initialization method to use."""
651    activation: Union[Callable, str] = "silu"
652    """The activation function to use."""
653
654    input_key: Optional[str] = None
655    """The key of the input tensor."""
656    output_key: Optional[str] = None
657    """The key of the output tensor."""
658    squeeze: bool = False
659    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
660
661    FID: ClassVar[str] = "GATED_PERCEPTRON"
662
663    @nn.compact
664    def __call__(self, inputs):
665        if self.input_key is None:
666            assert not isinstance(
667                inputs, dict
668            ), "input key must be provided if inputs is a dictionary"
669            x = inputs
670        else:
671            x = inputs[self.input_key]
672
673        # activation = (
674        #     activation_from_str(self.activation)
675        #     if isinstance(self.activation, str)
676        #     else self.activation
677        # )
678        kernel_init = (
679            initializer_from_str(self.kernel_init)
680            if isinstance(self.kernel_init, str)
681            else self.kernel_init
682        )
683        ############################
684        gate = jax.nn.sigmoid(
685            nn.Dense(self.dim, use_bias=self.use_bias, kernel_init=kernel_init)(x)
686        )
687        x = gate * activation_from_str(self.activation)(
688            nn.Dense(self.dim, use_bias=self.use_bias, kernel_init=kernel_init)(x)
689        )
690
691        if self.squeeze and x.shape[-1] == 1:
692            x = jnp.squeeze(x, axis=-1)
693        ############################
694
695        if self.input_key is not None:
696            output_key = self.name if self.output_key is None else self.output_key
697            return {**inputs, output_key: x} if output_key is not None else x
698        return x

Gated Perceptron neural network.

FID: GATED_PERCEPTRON

This class represents a Gated Perceptron neural network model. It computes a sigmoid gate from a dense layer applied to the input and multiplies it with the activated output of a second dense layer.

GatedPerceptron( dim: int, use_bias: bool = True, kernel_init: Union[str, Callable] = 'lecun_normal()', activation: Union[Callable, str] = 'silu', input_key: Optional[str] = None, output_key: Optional[str] = None, squeeze: bool = False, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
dim: int

The dimensionality of the output space.

use_bias: bool = True

Whether to include a bias term in the dense layer.

kernel_init: Union[str, Callable] = 'lecun_normal()'

The kernel initialization method to use.

activation: Union[Callable, str] = 'silu'

The activation function to use.

input_key: Optional[str] = None

The key of the input tensor.

output_key: Optional[str] = None

The key of the output tensor.

squeeze: bool = False

Whether to remove the last axis of the output tensor if it is of dimension 1.

FID: ClassVar[str] = 'GATED_PERCEPTRON'
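A minimal usage sketch for GatedPerceptron (shapes are illustrative):

    import jax
    import jax.numpy as jnp
    from fennol.models.misc.nets import GatedPerceptron

    net = GatedPerceptron(dim=32)
    x = jnp.ones((10, 16))
    params = net.init(jax.random.PRNGKey(0), x)
    out = net.apply(params, x)   # shape (10, 32): sigmoid gate times activated dense output
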
class ZAcNet(flax.linen.module.Module):
701class ZAcNet(nn.Module):
702    """A fully connected neural network module with affine Z-dependent adjustments of activations.
703    
704    FID: ZACNET
705    """
706
707    neurons: Sequence[int]
708    """A sequence of integers representing the dimensions of the network."""
709    zmax: int = 86
710    """The maximum atomic number to consider."""
711    activation: Union[Callable, str] = "silu"
712    """The activation function to use."""
713    use_bias: bool = True
714    """Whether to use bias in the dense layers."""
715    input_key: Optional[str] = None
716    """The key of the input tensor."""
717    output_key: Optional[str] = None
718    """The key of the output tensor."""
719    squeeze: bool = False
720    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
721    kernel_init: Union[str, Callable] = "lecun_normal()"
722    """The kernel initialization method to use."""
723    species_key: str = "species"
724    """The key of the species tensor."""
725
726    FID: ClassVar[str] = "ZACNET"
727
728    @nn.compact
729    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
730        if self.input_key is None:
731            assert not isinstance(
732                inputs, dict
733            ), "input key must be provided if inputs is a dictionary"
734            species, x = inputs
735        else:
736            species, x = inputs[self.species_key], inputs[self.input_key]
737
738        # activation = (
739        #     activation_from_str(self.activation)
740        #     if isinstance(self.activation, str)
741        #     else self.activation
742        # )
743        kernel_init = (
744            initializer_from_str(self.kernel_init)
745            if isinstance(self.kernel_init, str)
746            else self.kernel_init
747        )
748        ############################
749        for i, d in enumerate(self.neurons[:-1]):
750            x = nn.Dense(
751                d, use_bias=self.use_bias, name=f"Layer_{i+1}", kernel_init=kernel_init
752            )(x)
753            sig = self.param(
754                f"sig_{i+1}",
755                lambda key, shape: jnp.ones(shape, dtype=x.dtype),
756                (self.zmax + 2, d),
757            )[species]
758            if self.use_bias:
759                b = self.param(
760                    f"b_{i+1}",
761                    lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
762                    (self.zmax + 2, d),
763                )[species]
764            else:
765                b = 0
766            x = activation_from_str(self.activation)(sig * x + b)
767        x = nn.Dense(
768            self.neurons[-1],
769            use_bias=self.use_bias,
770            name=f"Layer_{len(self.neurons)}",
771            kernel_init=kernel_init,
772        )(x)
773        sig = self.param(
774            f"sig_{len(self.neurons)}",
775            lambda key, shape: jnp.ones(shape, dtype=x.dtype),
776            (self.zmax + 2, self.neurons[-1]),
777        )[species]
778        if self.use_bias:
779            b = self.param(
780                f"b_{len(self.neurons)}",
781                lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
782                (self.zmax + 2, self.neurons[-1]),
783            )[species]
784        else:
785            b = 0
786        x = sig * x + b
787        if self.squeeze and x.shape[-1] == 1:
788            x = jnp.squeeze(x, axis=-1)
789        ############################
790
791        if self.input_key is not None:
792            output_key = self.name if self.output_key is None else self.output_key
793            return {**inputs, output_key: x} if output_key is not None else x
794        return x

A fully connected neural network module with affine Z-dependent adjustments of activations.

FID: ZACNET

ZAcNet( neurons: Sequence[int], zmax: int = 86, activation: Union[Callable, str] = 'silu', use_bias: bool = True, input_key: Optional[str] = None, output_key: Optional[str] = None, squeeze: bool = False, kernel_init: Union[str, Callable] = 'lecun_normal()', species_key: str = 'species', parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
neurons: Sequence[int]

A sequence of integers representing the dimensions of the network.

zmax: int = 86

The maximum atomic number to consider.

activation: Union[Callable, str] = 'silu'

The activation function to use.

use_bias: bool = True

Whether to use bias in the dense layers.

input_key: Optional[str] = None

The key of the input tensor.

output_key: Optional[str] = None

The key of the output tensor.

squeeze: bool = False

Whether to remove the last axis of the output tensor if it is of dimension 1.

kernel_init: Union[str, Callable] = 'lecun_normal()'

The kernel initialization method to use.

species_key: str = 'species'

The key of the species tensor.

FID: ClassVar[str] = 'ZACNET'
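A minimal usage sketch for ZAcNet called directly with a (species, x) tuple; the species array holds atomic numbers, which index the per-element scale and offset parameters (shapes are illustrative):

    import jax
    import jax.numpy as jnp
    from fennol.models.misc.nets import ZAcNet

    net = ZAcNet(neurons=[32, 1], squeeze=True)
    species = jnp.array([1, 8, 6])                 # H, O, C
    x = jnp.ones((3, 16))
    params = net.init(jax.random.PRNGKey(0), (species, x))
    out = net.apply(params, (species, x))          # shape (3,)
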
class ZLoRANet(flax.linen.module.Module):
797class ZLoRANet(nn.Module):
798    """A fully connected neural network module with Z-dependent low-rank adaptation.
799    
800    FID: ZLORANET
801    """
802
803    neurons: Sequence[int]
804    """A sequence of integers representing the dimensions of the network."""
805    ranks: Sequence[int]
806    """A sequence of integers representing the ranks of the low-rank adaptation at each layer."""
807    zmax: int = 86
808    """The maximum atomic number to consider."""
809    activation: Union[Callable, str] = "silu"
810    """The activation function to use."""
811    use_bias: bool = True
812    """Whether to use bias in the dense layers."""
813    input_key: Optional[str] = None
814    """The key of the input tensor."""
815    output_key: Optional[str] = None
816    """The key of the output tensor."""
817    squeeze: bool = False
818    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
819    kernel_init: Union[str, Callable] = "lecun_normal()"
820    """The kernel initialization method to use."""
821    species_key: str = "species"
822    """The key of the species tensor."""
823
824    FID: ClassVar[str] = "ZLORANET"
825
826    @nn.compact
827    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
828        if self.input_key is None:
829            assert not isinstance(
830                inputs, dict
831            ), "input key must be provided if inputs is a dictionary"
832            species, x = inputs
833        else:
834            species, x = inputs[self.species_key], inputs[self.input_key]
835
836        # activation = (
837        #     activation_from_str(self.activation)
838        #     if isinstance(self.activation, str)
839        #     else self.activation
840        # )
841        kernel_init = (
842            initializer_from_str(self.kernel_init)
843            if isinstance(self.kernel_init, str)
844            else self.kernel_init
845        )
846        ############################
847        for i, d in enumerate(self.neurons[:-1]):
848            xi = nn.Dense(
849                d, use_bias=self.use_bias, name=f"Layer_{i+1}", kernel_init=kernel_init
850            )(x)
851            A = self.param(
852                f"A_{i+1}",
853                lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
854                (self.zmax + 2, self.ranks[i], x.shape[-1]),
855            )[species]
856            B = self.param(
857                f"B_{i+1}",
858                lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
859                (self.zmax + 2, d, self.ranks[i]),
860            )[species]
861            Ax = jnp.einsum("zrd,zd->zr", A, x)
862            BAx = jnp.einsum("zrd,zd->zr", B, Ax)
863            x = activation_from_str(self.activation)(xi + BAx)
864        xi = nn.Dense(
865            self.neurons[-1],
866            use_bias=self.use_bias,
867            name=f"Layer_{len(self.neurons)}",
868            kernel_init=kernel_init,
869        )(x)
870        A = self.param(
871            f"A_{len(self.neurons)}",
872            lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
873            (self.zmax + 2, self.ranks[-1], x.shape[-1]),
874        )[species]
875        B = self.param(
876            f"B_{len(self.neurons)}",
877            lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
878            (self.zmax + 2, self.neurons[-1], self.ranks[-1]),
879        )[species]
880        Ax = jnp.einsum("zrd,zd->zr", A, x)
881        BAx = jnp.einsum("zrd,zd->zr", B, Ax)
882        x = xi + BAx
883        if self.squeeze and x.shape[-1] == 1:
884            x = jnp.squeeze(x, axis=-1)
885        ############################
886
887        if self.input_key is not None:
888            output_key = self.name if self.output_key is None else self.output_key
889            return {**inputs, output_key: x} if output_key is not None else x
890        return x

A fully connected neural network module with Z-dependent low-rank adaptation.

FID: ZLORANET

ZLoRANet( neurons: Sequence[int], ranks: Sequence[int], zmax: int = 86, activation: Union[Callable, str] = 'silu', use_bias: bool = True, input_key: Optional[str] = None, output_key: Optional[str] = None, squeeze: bool = False, kernel_init: Union[str, Callable] = 'lecun_normal()', species_key: str = 'species', parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
neurons: Sequence[int]

A sequence of integers representing the dimensions of the network.

ranks: Sequence[int]

A sequence of integers representing the ranks of the low-rank adaptation at each layer.

zmax: int = 86

The maximum atomic number to consider.

activation: Union[Callable, str] = 'silu'

The activation function to use.

use_bias: bool = True

Whether to use bias in the dense layers.

input_key: Optional[str] = None

The key of the input tensor.

output_key: Optional[str] = None

The key of the output tensor.

squeeze: bool = False

Whether to remove the last axis of the output tensor if it is of dimension 1.

kernel_init: Union[str, Callable] = 'lecun_normal()'

The kernel initialization method to use.

species_key: str = 'species'

The key of the species tensor.

FID: ClassVar[str] = 'ZLORANET'
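A minimal usage sketch for ZLoRANet; ranks gives one low-rank dimension per layer, and since the A and B adapters are zero-initialized the initial output equals that of the underlying dense network (shapes are illustrative):

    import jax
    import jax.numpy as jnp
    from fennol.models.misc.nets import ZLoRANet

    net = ZLoRANet(neurons=[32, 1], ranks=[4, 4], squeeze=True)
    species = jnp.array([1, 8])
    x = jnp.ones((2, 16))
    params = net.init(jax.random.PRNGKey(0), (species, x))
    out = net.apply(params, (species, x))          # shape (2,)
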
class BlockIndexNet(flax.linen.module.Module):
 893class BlockIndexNet(nn.Module):
 894    """Chemical-block-specific neural network using a precomputed block index.
 895
 896    FID: BLOCK_INDEX_NET
 897
 898    A neural network that applies a block-specific fully connected network to each atom embedding.
 899    A block index must be provided to filter the embeddings for each block and apply the corresponding network.
 900    This index can be obtained using the preprocessing module `fennol.models.preprocessing.BlockIndexer`.
 901
 902    """
 903
 904    output_dim: int
 905    """The dimension of the output of the fully connected networks."""
 906    hidden_neurons: Sequence[int]
 907    """The hidden dimensions of the fully connected networks.
 908        The same hidden dimensions are used for all blocks
 909        (each block still gets its own network parameters)."""
 910    used_blocks: Optional[Sequence[str]] = None
 911    """The blocks to use. If None, all blocks will be used."""
 912    activation: Union[Callable, str] = "silu"
 913    """The activation function to use in the fully connected networks."""
 914    use_bias: bool = True
 915    """Whether to include bias terms in the fully connected networks."""
 916    input_key: Optional[str] = None
 917    """The key in the input dictionary that corresponds to the embeddings of the atoms."""
 918    block_index_key: str = "block_index"
 919    """The key in the input dictionary that corresponds to the block index of the atoms. See `fennol.models.preprocessing.BlockIndexer`"""
 920    output_key: Optional[str] = None
 921    """The key in the output dictionary that corresponds to the network's output."""
 922
 923    squeeze: bool = False
 924    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
 925    kernel_init: Union[str, Callable] = "lecun_normal()"
 926    """The kernel initialization method for the fully connected networks."""
 927    # check_unhandled: bool = True
 928
 929    FID: ClassVar[str] = "BLOCK_INDEX_NET"
 930
 931    # def setup(self):
 932    #     all_blocks = CHEMICAL_BLOCKS_NAMES
 933    #     if self.used_blocks is None:
 934    #         used_blocks = all_blocks
 935    #     else:
 936    #         used_blocks = []
 937    #         for b in self.used_blocks:
 938    #             b_=str(b).strip().upper()
 939    #             if b_ not in all_blocks:
 940    #                 raise ValueError(f"Block {b} not found in {all_blocks}")
 941    #             used_blocks.append(b_)
 942    #     used_blocks = set(used_blocks)
 943    #     self._used_blocks = used_blocks
 944
 945    #     if not (
 946    #         isinstance(self.hidden_neurons, dict)
 947    #         or isinstance(self.hidden_neurons, FrozenDict)
 948    #     ):
 949    #         neurons = {k: self.hidden_neurons for k in used_blocks}
 950    #     else:
 951    #         neurons = {}
 952    #         for b in self.hidden_neurons.keys():
 953    #             b_=str(b).strip().upper()
 954    #             if b_ not in all_blocks:
 955    #                 raise ValueError(f"Block {b} does not exist.  Available blocks are {all_blocks}")
 956    #             neurons[b_] = self.hidden_neurons[b]
 957    #         used_blocks = set(neurons.keys())
 958    #         if used_blocks != self._used_blocks and self.used_blocks is not None:
 959    #             print(
 960    #                 f"Warning: hidden neurons definitions do not match specified used_blocks {self.used_blocks}. Using blocks defined in hidden_neurons.")
 961    #         self._used_blocks = used_blocks
 962
 963    #     self.networks = {
 964    #         k: FullyConnectedNet(
 965    #             [*neurons[k], self.output_dim],
 966    #             self.activation,
 967    #             self.use_bias,
 968    #             name=k,
 969    #             kernel_init=self.kernel_init,
 970    #         )
 971    #         for k in self._used_blocks
 972    #     }
 973
 974    @nn.compact
 975    def __call__(
 976        self, inputs: Union[dict, Tuple[jax.Array, jax.Array]]
 977    ) -> Union[dict, jax.Array]:
 978
 979        if self.input_key is None:
 980            assert not isinstance(
 981                inputs, dict
 982            ), "input key must be provided if inputs is a dictionary"
 983            species, embedding, block_index = inputs
 984        else:
 985            species, embedding = inputs["species"], inputs[self.input_key]
 986            block_index = inputs[self.block_index_key]
 987
 988        assert isinstance(
 989            block_index, dict
 990        ), "block_index must be a dictionary for BlockIndexNet"
 991
 992        networks = {
 993            k: FullyConnectedNet(
 994                [*self.hidden_neurons, self.output_dim],
 995                self.activation,
 996                self.use_bias,
 997                name=k,
 998                kernel_init=self.kernel_init,
 999            )
1000            for k in block_index.keys()
1001        }
1002
1003        ############################
1004        # initialization => instantiate all networks
1005        if self.is_initializing():
1006            x = jnp.zeros((1, embedding.shape[-1]), dtype=embedding.dtype)
1007            [net(x) for net in networks.values()]
1008
1009        # if self.check_unhandled:
1010        #     for b in block_index.keys():
1011        #         if b not in networks.keys():
1012        #             raise ValueError(f"Block {b} not found in networks. Available blocks are {self.networks.keys()}")
1013
1014        ############################
1015        outputs = []
1016        indices = []
1017        for s, net in networks.items():
1018            if s not in block_index:
1019                continue
1020            if block_index[s] is None:
1021                continue
1022            idx = block_index[s]
1023            o = net(embedding[idx])
1024            outputs.append(o)
1025            indices.append(idx)
1026
1027        o = jnp.concatenate(outputs, axis=0)
1028        idx = jnp.concatenate(indices, axis=0)
1029
1030        out = (
1031            jnp.zeros((species.shape[0], *o.shape[1:]), dtype=o.dtype)
1032            .at[idx]
1033            .set(o, mode="drop")
1034        )
1035
1036        if self.squeeze and out.shape[-1] == 1:
1037            out = jnp.squeeze(out, axis=-1)
1038        ############################
1039
1040        if self.input_key is not None:
1041            output_key = self.name if self.output_key is None else self.output_key
1042            return {**inputs, output_key: out} if output_key is not None else out
1043        return out

Chemical-block-specific neural network using a precomputed block index.

FID: BLOCK_INDEX_NET

A neural network that applies a block-specific fully connected network to each atom embedding. A block index must be provided to filter the embeddings for each block and apply the corresponding network. This index can be obtained using the preprocessing module fennol.models.preprocessing.BlockIndexer.

BlockIndexNet( output_dim: int, hidden_neurons: Sequence[int], used_blocks: Optional[Sequence[str]] = None, activation: Union[Callable, str] = 'silu', use_bias: bool = True, input_key: Optional[str] = None, block_index_key: str = 'block_index', output_key: Optional[str] = None, squeeze: bool = False, kernel_init: Union[str, Callable] = 'lecun_normal()', parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
output_dim: int

The dimension of the output of the fully connected networks.

hidden_neurons: Sequence[int]

The hidden dimensions of the fully connected networks. The same hidden dimensions are used for all blocks (each block still gets its own network parameters).

used_blocks: Optional[Sequence[str]] = None

The blocks to use. If None, all blocks will be used.

activation: Union[Callable, str] = 'silu'

The activation function to use in the fully connected networks.

use_bias: bool = True

Whether to include bias terms in the fully connected networks.

input_key: Optional[str] = None

The key in the input dictionary that corresponds to the embeddings of the atoms.

block_index_key: str = 'block_index'

The key in the input dictionary that corresponds to the block index of the atoms. See fennol.models.preprocessing.BlockIndexer

output_key: Optional[str] = None

The key in the output dictionary that corresponds to the network's output.

squeeze: bool = False

Whether to remove the last axis of the output tensor if it is of dimension 1.

kernel_init: Union[str, Callable] = 'lecun_normal()'

The kernel initialization method for the fully connected networks.

FID: ClassVar[str] = 'BLOCK_INDEX_NET'
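A minimal usage sketch for BlockIndexNet called directly with a (species, embedding, block_index) tuple; the block names and index arrays below are hypothetical placeholders for what fennol.models.preprocessing.BlockIndexer would normally provide:

    import jax
    import jax.numpy as jnp
    from fennol.models.misc.nets import BlockIndexNet

    net = BlockIndexNet(output_dim=1, hidden_neurons=[32, 32], squeeze=True)
    species = jnp.array([1, 1, 8])
    embedding = jnp.ones((3, 16))
    block_index = {"BLOCK_A": jnp.array([0, 1]),   # hypothetical block names
                   "BLOCK_B": jnp.array([2])}
    inputs = (species, embedding, block_index)
    params = net.init(jax.random.PRNGKey(0), inputs)
    out = net.apply(params, inputs)                # per-atom outputs, shape (3,)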