fennol.models.misc.nets

   1import flax.linen as nn
   2from typing import Sequence, Callable, Union, Optional, ClassVar, Tuple
   3from ...utils.periodic_table import PERIODIC_TABLE_REV_IDX, PERIODIC_TABLE
   4import jax.numpy as jnp
   5import jax
   6import numpy as np
   7from functools import partial
   8from ...utils.activations import activation_from_str, TrainableSiLU
   9from ...utils.initializers import initializer_from_str, scaled_orthogonal
  10from flax.core import FrozenDict
  11
  12
  13class FullyConnectedNet(nn.Module):
  14    """A fully connected neural network module.
  15
  16    FID: NEURAL_NET
  17    """
  18
  19    neurons: Sequence[int]
  20    """A sequence of integers representing the dimensions of the network."""
  21    activation: Union[Callable, str] = "silu"
  22    """The activation function to use."""
  23    use_bias: bool = True
  24    """Whether to use bias in the dense layers."""
  25    input_key: Optional[str] = None
  26    """The key of the input tensor."""
  27    output_key: Optional[str] = None
  28    """The key of the output tensor."""
  29    squeeze: bool = False
  30    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
  31    kernel_init: Union[str, Callable] = "lecun_normal()"
  32    """The kernel initialization method to use."""
  33
  34    FID: ClassVar[str] = "NEURAL_NET"
  35
  36    @nn.compact
  37    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
  38        """Applies the neural network to the given inputs.
  39
  40        Args:
  41            inputs (Union[dict, jax.Array]): If input_key is None, a JAX array containing the inputs to the neural network. Else, a dictionary containing the inputs at the key input_key.
  42
  43        Returns:
  44            Union[dict, jax.Array]: If output_key is None, the output tensor of the neural network. Else, a dictionary containing the original inputs and the output tensor at the key output_key.
  45        """
  46        if self.input_key is None:
  47            assert not isinstance(
  48                inputs, dict
  49            ), "input key must be provided if inputs is a dictionary"
  50            x = inputs
  51        else:
  52            x = inputs[self.input_key]
  53
  54        # activation = (
  55        #     activation_from_str(self.activation)
  56        #     if isinstance(self.activation, str)
  57        #     else self.activation
  58        # )
  59        kernel_init = (
  60            initializer_from_str(self.kernel_init)
  61            if isinstance(self.kernel_init, str)
  62            else self.kernel_init
  63        )
  64        ############################
  65        if isinstance(self.activation, str) and self.activation.lower() == "swiglu":
  66            for i, d in enumerate(self.neurons[:-1]):
  67                y = nn.Dense(
  68                    d,
  69                    use_bias=self.use_bias,
  70                    name=f"Layer_{i+1}",
  71                    kernel_init=kernel_init,
  72                )(x)
  73                z = jax.nn.swish(nn.Dense(
  74                    d,
  75                    use_bias=self.use_bias,
  76                    name=f"Mask_{i+1}",
  77                    kernel_init=kernel_init,
  78                )(x))
  79                x = y * z
  80        else:
  81            for i, d in enumerate(self.neurons[:-1]):
  82                x = nn.Dense(
  83                    d,
  84                    use_bias=self.use_bias,
  85                    name=f"Layer_{i+1}",
  86                    kernel_init=kernel_init,
  87                )(x)
  88                x = activation_from_str(self.activation)(x)
  89        x = nn.Dense(
  90            self.neurons[-1],
  91            use_bias=self.use_bias,
  92            name=f"Layer_{len(self.neurons)}",
  93            kernel_init=kernel_init,
  94        )(x)
  95        if self.squeeze and x.shape[-1] == 1:
  96            x = jnp.squeeze(x, axis=-1)
  97        ############################
  98
  99        if self.input_key is not None:
 100            output_key = self.name if self.output_key is None else self.output_key
 101            return {**inputs, output_key: x} if output_key is not None else x
 102        return x
 103
 104
 105class ResMLP(nn.Module):
 106    """Residual neural network as defined in the SpookyNet paper.
 107    
 108    FID: RES_MLP
 109    """
 110
 111    use_bias: bool = True
 112    """Whether to include bias in the linear layers."""
 113    input_key: Optional[str] = None
 114    """The key of the input tensor."""
 115    output_key: Optional[str] = None
 116    """The key of the output tensor."""
 117
 118    kernel_init: Union[str, Callable] = "scaled_orthogonal(mode='fan_avg')"
 119    """The kernel initialization method to use."""
 120    res_only: bool = False
 121    """Whether to only apply the residual connection without additional activation and linear layer."""
 122
 123    FID: ClassVar[str] = "RES_MLP"
 124
 125    @nn.compact
 126    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
 127        if self.input_key is None:
 128            assert not isinstance(
 129                inputs, dict
 130            ), "input key must be provided if inputs is a dictionary"
 131            x = inputs
 132        else:
 133            x = inputs[self.input_key]
 134
 135        kernel_init = (
 136            initializer_from_str(self.kernel_init)
 137            if isinstance(self.kernel_init, str)
 138            else self.kernel_init
 139        )
 140        ############################
 141        out = nn.Dense(x.shape[-1], use_bias=self.use_bias, kernel_init=kernel_init)(
 142            TrainableSiLU()(x)
 143        )
 144        out = x + nn.Dense(
 145            x.shape[-1], use_bias=self.use_bias, kernel_init=nn.initializers.zeros
 146        )(TrainableSiLU()(out))
 147
 148        if not self.res_only:
 149            out = nn.Dense(
 150                x.shape[-1], use_bias=self.use_bias, kernel_init=kernel_init
 151            )(TrainableSiLU()(out))
 152        ############################
 153
 154        if self.input_key is not None:
 155            output_key = self.name if self.output_key is None else self.output_key
 156            return {**inputs, output_key: out} if output_key is not None else out
 157        return out
 158
 159
 160class FullyResidualNet(nn.Module):
 161    """A neural network with skip connections at each layer.
 162    
 163    FID: SKIP_NET
 164    """
 165
 166    dim: int
 167    """The dimension of the hidden layers."""
 168    output_dim: int
 169    """The dimension of the output layer."""
 170    nlayers: int
 171    """The number of layers in the network."""
 172    activation: Union[Callable, str] = "silu"
 173    """The activation function to use."""
 174    use_bias: bool = True
 175    """Whether to include bias terms in the linear layers."""
 176    input_key: Optional[str] = None
 177    """The key of the input tensor."""
 178    output_key: Optional[str] = None
 179    """The key of the output tensor."""
 180    squeeze: bool = False
 181    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
 182    kernel_init: Union[str, Callable] = "lecun_normal()"
 183    """The kernel initialization method to use."""
 184
 185    FID: ClassVar[str] = "SKIP_NET"
 186
 187    @nn.compact
 188    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
 189        if self.input_key is None:
 190            assert not isinstance(
 191                inputs, dict
 192            ), "input key must be provided if inputs is a dictionary"
 193            x = inputs
 194        else:
 195            x = inputs[self.input_key]
 196
 197        # activation = (
 198        #     activation_from_str(self.activation)
 199        #     if isinstance(self.activation, str)
 200        #     else self.activation
 201        # )
 202        kernel_init = (
 203            initializer_from_str(self.kernel_init)
 204            if isinstance(self.kernel_init, str)
 205            else self.kernel_init
 206        )
 207        ############################
 208        if x.shape[-1] != self.dim:
 209            x = nn.Dense(
 210                self.dim,
 211                use_bias=self.use_bias,
 212                name=f"Reshape",
 213                kernel_init=kernel_init,
 214            )(x)
 215
 216        for i in range(self.nlayers - 1):
 217            x = x + activation_from_str(self.activation)(
 218                nn.Dense(
 219                    self.dim,
 220                    use_bias=self.use_bias,
 221                    name=f"Layer_{i+1}",
 222                    kernel_init=kernel_init,
 223                )(x)
 224            )
 225        x = nn.Dense(
 226            self.output_dim,
 227            use_bias=self.use_bias,
 228            name=f"Layer_{self.nlayers}",
 229            kernel_init=kernel_init,
 230        )(x)
 231        if self.squeeze and x.shape[-1] == 1:
 232            x = jnp.squeeze(x, axis=-1)
 233        ############################
 234
 235        if self.input_key is not None:
 236            output_key = self.name if self.output_key is None else self.output_key
 237            return {**inputs, output_key: x} if output_key is not None else x
 238        return x
 239
 240
 241class HierarchicalNet(nn.Module):
 242    """Neural network applied to a sequence of inputs (in axis=-2) with a decay factor.
 243    
 244    FID: HIERARCHICAL_NET
 245    """
 246
 247    neurons: Sequence[int]
 248    """A sequence of integers representing the number of neurons in each layer."""
 249    activation: Union[Callable, str] = "silu"
 250    """The activation function to use."""
 251    use_bias: bool = True
 252    """Whether to include bias terms in the linear layers."""
 253    input_key: Optional[str] = None
 254    """The key of the input tensor."""
 255    output_key: Optional[str] = None
 256    """The key of the output tensor."""
 257    decay: float = 0.01
 258    """The decay factor to scale each element of the sequence."""
 259    squeeze: bool = False
 260    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
 261    kernel_init: Union[str, Callable] = "lecun_normal()"
 262    """The kernel initialization method to use."""
 263
 264    FID: ClassVar[str] = "HIERARCHICAL_NET"
 265
 266    @nn.compact
 267    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
 268        if self.input_key is None:
 269            assert not isinstance(
 270                inputs, dict
 271            ), "input key must be provided if inputs is a dictionary"
 272            x = inputs
 273        else:
 274            x = inputs[self.input_key]
 275
 276        ############################
 277        networks = nn.vmap(
 278            FullyConnectedNet,
 279            variable_axes={"params": 0},
 280            split_rngs={"params": True},
 281            in_axes=-2,
 282            out_axes=-2,
 283            # kernel_init is an argument of the wrapped FullyConnectedNet, not of nn.vmap
 284        )(self.neurons, self.activation, self.use_bias, kernel_init=self.kernel_init)
 285
 286        out = networks(x)
 287        # scale each layer by a decay factor
 288        decay = jnp.asarray([self.decay**i for i in range(out.shape[-2])])
 289        out = out * decay[..., :, None]
 290
 291        if self.squeeze and out.shape[-1] == 1:
 292            out = jnp.squeeze(out, axis=-1)
 293        ############################
 294
 295        if self.input_key is not None:
 296            output_key = self.name if self.output_key is None else self.output_key
 297            return {**inputs, output_key: out} if output_key is not None else out
 298        return out
 299
 300
 301class SpeciesIndexNet(nn.Module):
 302    """Chemical-species-specific neural network using precomputed species index.
 303
 304    FID: SPECIES_INDEX_NET
 305
 306    A neural network that applies a species-specific fully connected network to each atom embedding.
 307    A species index must be provided to filter the embeddings for each species and apply the corresponding network.
 308    This index can be obtained using the SPECIES_INDEXER preprocessing module from `fennol.models.preprocessing.SpeciesIndexer`
 309
 310    """
 311
 312    output_dim: int
 313    """The dimension of the output of the fully connected networks."""
 314    hidden_neurons: Union[dict, FrozenDict, Sequence[int]]
 315    """The hidden dimensions of the fully connected networks.
 316        If a dictionary is provided, it should map species names to dimensions.
 317        If a sequence is provided, the same dimensions will be used for all species."""
 318    species_order: Optional[Union[str, Sequence[str]]] = None
 319    """The species for which to build a network. Only required if neurons is not a dictionary."""
 320    activation: Union[Callable, str] = "silu"
 321    """The activation function to use in the fully connected networks."""
 322    use_bias: bool = True
 323    """Whether to include bias terms in the fully connected networks."""
 324    input_key: Optional[str] = None
 325    """The key in the input dictionary that corresponds to the embeddings of the atoms."""
 326    species_index_key: str = "species_index"
 327    """The key in the input dictionary that corresponds to the species index of the atoms. See `fennol.models.preprocessing.SpeciesIndexer`"""
 328    output_key: Optional[str] = None
 329    """The key in the output dictionary that corresponds to the network's output."""
 330
 331    squeeze: bool = False
 332    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
 333    kernel_init: Union[str, Callable] = "lecun_normal()"
 334    """The kernel initialization method for the fully connected networks."""
 335    check_unhandled: bool = True
 336
 337    FID: ClassVar[str] = "SPECIES_INDEX_NET"
 338
 339    def setup(self):
 340        if not (
 341            isinstance(self.hidden_neurons, dict)
 342            or isinstance(self.hidden_neurons, FrozenDict)
 343        ):
 344            assert (
 345                self.species_order is not None
 346            ), "species_order must be provided if hidden_neurons is not a dictionary"
 347            if isinstance(self.species_order, str):
 348                species_order = [el.strip() for el in self.species_order.split(",")]
 349            else:
 350                species_order = [el for el in self.species_order]
 351            neurons = {k: self.hidden_neurons for k in species_order}
 352        else:
 353            neurons = self.hidden_neurons
 354            species_order = list(neurons.keys())
 355        for species in species_order:
 356            assert (
 357                species in PERIODIC_TABLE
 358            ), f"species {species} not found in periodic table"
 359
 360        self.networks = {
 361            k: FullyConnectedNet(
 362                [*neurons[k], self.output_dim],
 363                self.activation,
 364                self.use_bias,
 365                name=k,
 366                kernel_init=self.kernel_init,
 367            )
 368            for k in species_order
 369        }
 370
 371    def __call__(
 372        self, inputs: Union[dict, Tuple[jax.Array, jax.Array]]
 373    ) -> Union[dict, jax.Array]:
 374
 375        if self.input_key is None:
 376            assert not isinstance(
 377                inputs, dict
 378            ), "input key must be provided if inputs is a dictionary"
 379            species, embedding, species_index = inputs
 380        else:
 381            species, embedding = inputs["species"], inputs[self.input_key]
 382            species_index = inputs[self.species_index_key]
 383
 384        assert isinstance(
 385            species_index, dict
 386        ), "species_index must be a dictionary for SpeciesIndexNet"
 387
 388        ############################
 389        # initialization => instantiate all networks
 390        if self.is_initializing():
 391            x = jnp.zeros((1, embedding.shape[-1]), dtype=embedding.dtype)
 392            [net(x) for net in self.networks.values()]
 393
 394        if self.check_unhandled:
 395            for b in species_index.keys():
 396                if b not in self.networks.keys():
 397                    raise ValueError(f"Species {b} not found in networks. Handled species are {self.networks.keys()}")
 398
 399        ############################
 400        outputs = []
 401        indices = []
 402        for s, net in self.networks.items():
 403            if s not in species_index:
 404                continue
 405            idx = species_index[s]
 406            o = net(embedding[idx])
 407            outputs.append(o)
 408            indices.append(idx)
 409
 410        o = jnp.concatenate(outputs, axis=0)
 411        idx = jnp.concatenate(indices, axis=0)
 412
 413        out = (
 414            jnp.zeros((species.shape[0], *o.shape[1:]), dtype=o.dtype)
 415            .at[idx]
 416            .set(o, mode="drop")
 417        )
 418
 419        if self.squeeze and out.shape[-1] == 1:
 420            out = jnp.squeeze(out, axis=-1)
 421        ############################
 422
 423        if self.input_key is not None:
 424            output_key = self.name if self.output_key is None else self.output_key
 425            return {**inputs, output_key: out} if output_key is not None else out
 426        return out
 427
 428
 429class ChemicalNet(nn.Module):
 430    """Optimized chemical-species-specific neural network.
 431
 432    FID: CHEMICAL_NET
 433
 434    A neural network that applies a fully connected network to each atom embedding in a chemical system and selects the output corresponding to the atom's species.
 435    This is an optimized version of ChemicalNetHet that uses vmap to apply the networks to all atoms at once.
 436    The optimization is allowed because all networks have the same shape.
 437
 438    """
 439
 440    species_order: Union[str, Sequence[str]]
 441    """The species for which to build a network."""
 442    neurons: Sequence[int]
 443    """The dimensions of the fully connected networks."""
 444    activation: Union[Callable, str] = "silu"
 445    """The activation function to use in the fully connected networks."""
 446    use_bias: bool = True
 447    """Whether to include bias terms in the fully connected networks."""
 448    input_key: Optional[str] = None
 449    """The key in the input dictionary that corresponds to the embeddings of the atoms."""
 450    output_key: Optional[str] = None
 451    """The key in the output dictionary that corresponds to the network's output."""
 452    squeeze: bool = False
 453    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
 454    kernel_init: Union[str, Callable] = "lecun_normal()"
 455    """The kernel initialization method for the fully connected networks."""
 456
 457    FID: ClassVar[str] = "CHEMICAL_NET"
 458
 459    @nn.compact
 460    def __call__(
 461        self, inputs: Union[dict, Tuple[jax.Array, jax.Array]]
 462    ) -> Union[dict, jax.Array]:
 463        if self.input_key is None:
 464            assert not isinstance(
 465                inputs, dict
 466            ), "input key must be provided if inputs is a dictionary"
 467            species, embedding = inputs
 468        else:
 469            species, embedding = inputs["species"], inputs[self.input_key]
 470
 471        ############################
 472        # build species to network index mapping (static => fixed when jitted)
 473        rev_idx = PERIODIC_TABLE_REV_IDX
 474        maxidx = max(rev_idx.values())
 475        if isinstance(self.species_order, str):
 476            species_order = [el.strip() for el in self.species_order.split(",")]
 477        else:
 478            species_order = [el for el in self.species_order]
 479        nspecies = len(species_order)
 480        conv_tensor_ = np.full((maxidx + 2,), -1, dtype=np.int32)
 481        for i, s in enumerate(species_order):
 482            conv_tensor_[rev_idx[s]] = i
 483        conv_tensor = jnp.asarray(conv_tensor_)
 484        indices = conv_tensor[species]
 485
 486        ############################
 487        # build shape-sharing networks using vmap
 488        networks = nn.vmap(
 489            FullyConnectedNet,
 490            variable_axes={"params": 0},
 491            split_rngs={"params": True},
 492            in_axes=0,
 493        )(self.neurons, self.activation, self.use_bias, kernel_init=self.kernel_init)
 494        # repeat input along a new axis to compute for all species at once
 495        x = jnp.broadcast_to(
 496            embedding[None, :, :], (nspecies, *embedding.shape)
 497        )
 498
 499        # apply networks to input and select the output corresponding to the species
 500        out = jnp.squeeze(
 501            jnp.take_along_axis(networks(x), indices[None, :, None], axis=0), axis=0
 502        )
 503
 504        out = jnp.where((indices >= 0)[:, None], out, 0.0)
 505        if self.squeeze and out.shape[-1] == 1:
 506            out = jnp.squeeze(out, axis=-1)
 507        ############################
 508
 509        if self.input_key is not None:
 510            output_key = self.name if self.output_key is None else self.output_key
 511            return {**inputs, output_key: out} if output_key is not None else out
 512        return out
 513
 514
 515class MOENet(nn.Module):
 516    """Mixture of Experts neural network.
 517
 518    FID: MOE_NET
 519
 520    This class represents a Mixture of Experts neural network. It takes in an input and applies a set of shape-sharing networks
 521    to the input based on a router. The outputs of the shape-sharing networks are then combined using weights computed by the router.
 522
 523    """
 524
 525    neurons: Sequence[int]
 526    """A sequence of integers representing the number of neurons in each shape-sharing network."""
 527    num_networks: int
 528    """The number of shape-sharing networks to create."""
 529    activation: Union[Callable, str] = "silu"
 530    """The activation function to use in the shape-sharing networks."""
 531    use_bias: bool = True
 532    """Whether to include bias in the shape-sharing networks."""
 533    input_key: Optional[str] = None
 534    """The key of the input tensor."""
 535    output_key: Optional[str] = None
 536    """The key of the output tensor."""
 537    squeeze: bool = False
 538    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
 539    
 540    kernel_init: Union[str, Callable] = "lecun_normal()"
 541    """The kernel initialization method to use in the shape-sharing networks."""
 542    router_key: Optional[str] = None
 543    """The key of the router tensor. If None, the router is assumed to be the same as the input tensor."""
 544
 545    FID: ClassVar[str] = "MOE_NET"
 546
 547    @nn.compact
 548    def __call__(
 549        self, inputs: Union[dict, Tuple[jax.Array, jax.Array]]
 550    ) -> Union[dict, jax.Array]:
 551        if self.input_key is None:
 552            assert not isinstance(
 553                inputs, dict
 554            ), "input key must be provided if inputs is a dictionary"
 555            if isinstance(inputs, tuple):
 556                embedding, router = inputs
 557            else:
 558                embedding = router = inputs
 559        else:
 560            embedding = inputs[self.input_key]
 561            router = (
 562                inputs[self.router_key] if self.router_key is not None else embedding
 563            )
 564
 565        ############################
 566        # build shape-sharing networks using vmap
 567        networks = nn.vmap(
 568            FullyConnectedNet,
 569            variable_axes={"params": 0},
 570            split_rngs={"params": True},
 571            in_axes=0,
 572            out_axes=0,
 573        )(self.neurons, self.activation, self.use_bias, kernel_init=self.kernel_init)
 574        # repeat input along a new axis to compute for all networks at once
 575        x = jnp.repeat(embedding[None, :, :], self.num_networks, axis=0)
 576
 577        w = nn.softmax(nn.Dense(self.num_networks, name="router")(router), axis=-1)
 578
 579        out = (networks(x) * w.T[:, :, None]).sum(axis=0)
 580
 581        if self.squeeze and out.shape[-1] == 1:
 582            out = jnp.squeeze(out, axis=-1)
 583        ############################
 584
 585        if self.input_key is not None:
 586            output_key = self.name if self.output_key is None else self.output_key
 587            return {**inputs, output_key: out} if output_key is not None else out
 588        return out
 589
 590class ChannelNet(nn.Module):
 591    """Apply a different neural network to each channel.
 592    
 593    FID: CHANNEL_NET
 594    """
 595
 596    neurons: Sequence[int]
 597    """A sequence of integers representing the number of neurons in each shape-sharing network."""
 598    activation: Union[Callable, str] = "silu"
 599    """The activation function to use in the shape-sharing networks."""
 600    use_bias: bool = True
 601    """Whether to include bias in the shape-sharing networks."""
 602    input_key: Optional[str] = None
 603    """The key of the input tensor."""
 604    output_key: Optional[str] = None
 605    """The key of the output tensor."""
 606    squeeze: bool = False
 607    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
 608    kernel_init: Union[str, Callable] = "lecun_normal()"
 609    """The kernel initialization method to use in the shape-sharing networks."""
 610    channel_axis: int = -2
 611    """The axis to use as channel. Its length will be the number of shape-sharing networks."""
 612
 613    FID: ClassVar[str] = "CHANNEL_NET"
 614
 615    @nn.compact
 616    def __call__(
 617        self, inputs: Union[dict, Tuple[jax.Array, jax.Array]]
 618    ) -> Union[dict, jax.Array]:
 619        if self.input_key is None:
 620            assert not isinstance(
 621                inputs, dict
 622            ), "input key must be provided if inputs is a dictionary"
 623            x = inputs
 624        else:
 625            x = inputs[self.input_key]
 626
 627        ############################
 628        # build shape-sharing networks using vmap
 629        networks = nn.vmap(
 630            FullyConnectedNet,
 631            variable_axes={"params": 0},
 632            split_rngs={"params": True},
 633            in_axes=self.channel_axis,
 634            out_axes=self.channel_axis,
 635        )(self.neurons, self.activation, self.use_bias, kernel_init=self.kernel_init)
 636
 637        out = networks(x)
 638
 639        if self.squeeze and out.shape[-1] == 1:
 640            out = jnp.squeeze(out, axis=-1)
 641        ############################
 642
 643        if self.input_key is not None:
 644            output_key = self.name if self.output_key is None else self.output_key
 645            return {**inputs, output_key: out} if output_key is not None else out
 646        return out
 647
 648
 649class GatedPerceptron(nn.Module):
 650    """Gated Perceptron neural network.
 651
 652    FID: GATED_PERCEPTRON
 653
 654    This class represents a Gated Perceptron neural network model. It applies a gating mechanism
 655    to the input data and performs linear transformation using a dense layer followed by an activation function.
 656    """
 657
 658    dim: int
 659    """The dimensionality of the output space."""
 660    use_bias: bool = True
 661    """Whether to include a bias term in the dense layer."""
 662    kernel_init: Union[str, Callable] = "lecun_normal()"
 663    """The kernel initialization method to use."""
 664    activation: Union[Callable, str] = "silu"
 665    """The activation function to use."""
 666
 667    input_key: Optional[str] = None
 668    """The key of the input tensor."""
 669    output_key: Optional[str] = None
 670    """The key of the output tensor."""
 671    squeeze: bool = False
 672    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
 673
 674    FID: ClassVar[str] = "GATED_PERCEPTRON"
 675
 676    @nn.compact
 677    def __call__(self, inputs):
 678        if self.input_key is None:
 679            assert not isinstance(
 680                inputs, dict
 681            ), "input key must be provided if inputs is a dictionary"
 682            x = inputs
 683        else:
 684            x = inputs[self.input_key]
 685
 686        # activation = (
 687        #     activation_from_str(self.activation)
 688        #     if isinstance(self.activation, str)
 689        #     else self.activation
 690        # )
 691        kernel_init = (
 692            initializer_from_str(self.kernel_init)
 693            if isinstance(self.kernel_init, str)
 694            else self.kernel_init
 695        )
 696        ############################
 697        gate = jax.nn.sigmoid(
 698            nn.Dense(self.dim, use_bias=self.use_bias, kernel_init=kernel_init)(x)
 699        )
 700        x = gate * activation_from_str(self.activation)(
 701            nn.Dense(self.dim, use_bias=self.use_bias, kernel_init=kernel_init)(x)
 702        )
 703
 704        if self.squeeze and x.shape[-1] == 1:
 705            x = jnp.squeeze(x, axis=-1)
 706        ############################
 707
 708        if self.input_key is not None:
 709            output_key = self.name if self.output_key is None else self.output_key
 710            return {**inputs, output_key: x} if output_key is not None else x
 711        return x
 712
 713
 714class ZAcNet(nn.Module):
 715    """ A fully connected neural network module with affine Z-dependent adjustments of activations.
 716    
 717    FID: ZACNET
 718    """
 719
 720    neurons: Sequence[int]
 721    """A sequence of integers representing the dimensions of the network."""
 722    zmax: int = 86
 723    """The maximum atomic number to consider."""
 724    activation: Union[Callable, str] = "silu"
 725    """The activation function to use."""
 726    use_bias: bool = True
 727    """Whether to use bias in the dense layers."""
 728    input_key: Optional[str] = None
 729    """The key of the input tensor."""
 730    output_key: Optional[str] = None
 731    """The key of the output tensor."""
 732    squeeze: bool = False
 733    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
 734    kernel_init: Union[str, Callable] = "lecun_normal()"
 735    """The kernel initialization method to use."""
 736    species_key: str = "species"
 737    """The key of the species tensor."""
 738
 739    FID: ClassVar[str] = "ZACNET"
 740
 741    @nn.compact
 742    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
 743        if self.input_key is None:
 744            assert not isinstance(
 745                inputs, dict
 746            ), "input key must be provided if inputs is a dictionary"
 747            species, x = inputs
 748        else:
 749            species, x = inputs[self.species_key], inputs[self.input_key]
 750
 751        # activation = (
 752        #     activation_from_str(self.activation)
 753        #     if isinstance(self.activation, str)
 754        #     else self.activation
 755        # )
 756        kernel_init = (
 757            initializer_from_str(self.kernel_init)
 758            if isinstance(self.kernel_init, str)
 759            else self.kernel_init
 760        )
 761        ############################
 762        for i, d in enumerate(self.neurons[:-1]):
 763            x = nn.Dense(
 764                d, use_bias=self.use_bias, name=f"Layer_{i+1}", kernel_init=kernel_init
 765            )(x)
 766            sig = self.param(
 767                f"sig_{i+1}",
 768                lambda key, shape: jnp.ones(shape, dtype=x.dtype),
 769                (self.zmax + 2, d),
 770            )[species]
 771            if self.use_bias:
 772                b = self.param(
 773                    f"b_{i+1}",
 774                    lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
 775                    (self.zmax + 2, d),
 776                )[species]
 777            else:
 778                b = 0
 779            x = activation_from_str(self.activation)(sig * x + b)
 780        x = nn.Dense(
 781            self.neurons[-1],
 782            use_bias=self.use_bias,
 783            name=f"Layer_{len(self.neurons)}",
 784            kernel_init=kernel_init,
 785        )(x)
 786        sig = self.param(
 787            f"sig_{len(self.neurons)}",
 788            lambda key, shape: jnp.ones(shape, dtype=x.dtype),
 789            (self.zmax + 2, self.neurons[-1]),
 790        )[species]
 791        if self.use_bias:
 792            b = self.param(
 793                f"b_{len(self.neurons)}",
 794                lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
 795                (self.zmax + 2, self.neurons[-1]),
 796            )[species]
 797        else:
 798            b = 0
 799        x = sig * x + b
 800        if self.squeeze and x.shape[-1] == 1:
 801            x = jnp.squeeze(x, axis=-1)
 802        ############################
 803
 804        if self.input_key is not None:
 805            output_key = self.name if self.output_key is None else self.output_key
 806            return {**inputs, output_key: x} if output_key is not None else x
 807        return x
 808
 809
 810class ZLoRANet(nn.Module):
 811    """A fully connected neural network module with Z-dependent low-rank adaptation.
 812    
 813    FID: ZLORANET
 814    """
 815
 816    neurons: Sequence[int]
 817    """A sequence of integers representing the dimensions of the network."""
 818    ranks: Sequence[int]
 819    """A sequence of integers representing the ranks of the low-rank adaptation at each layer."""
 820    zmax: int = 86
 821    """The maximum atomic number to consider."""
 822    activation: Union[Callable, str] = "silu"
 823    """The activation function to use."""
 824    use_bias: bool = True
 825    """Whether to use bias in the dense layers."""
 826    input_key: Optional[str] = None
 827    """The key of the input tensor."""
 828    output_key: Optional[str] = None
 829    """The key of the output tensor."""
 830    squeeze: bool = False
 831    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
 832    kernel_init: Union[str, Callable] = "lecun_normal()"
 833    """The kernel initialization method to use."""
 834    species_key: str = "species"
 835    """The key of the species tensor."""
 836
 837    FID: ClassVar[str] = "ZLORANET"
 838
 839    @nn.compact
 840    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
 841        if self.input_key is None:
 842            assert not isinstance(
 843                inputs, dict
 844            ), "input key must be provided if inputs is a dictionary"
 845            species, x = inputs
 846        else:
 847            species, x = inputs[self.species_key], inputs[self.input_key]
 848
 849        # activation = (
 850        #     activation_from_str(self.activation)
 851        #     if isinstance(self.activation, str)
 852        #     else self.activation
 853        # )
 854        kernel_init = (
 855            initializer_from_str(self.kernel_init)
 856            if isinstance(self.kernel_init, str)
 857            else self.kernel_init
 858        )
 859        ############################
 860        for i, d in enumerate(self.neurons[:-1]):
 861            xi = nn.Dense(
 862                d, use_bias=self.use_bias, name=f"Layer_{i+1}", kernel_init=kernel_init
 863            )(x)
 864            A = self.param(
 865                f"A_{i+1}",
 866                lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
 867                (self.zmax + 2, self.ranks[i], x.shape[-1]),
 868            )[species]
 869            B = self.param(
 870                f"B_{i+1}",
 871                lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
 872                (self.zmax + 2, d, self.ranks[i]),
 873            )[species]
 874            Ax = jnp.einsum("zrd,zd->zr", A, x)
 875            BAx = jnp.einsum("zrd,zd->zr", B, Ax)
 876            x = activation_from_str(self.activation)(xi + BAx)
 877        xi = nn.Dense(
 878            self.neurons[-1],
 879            use_bias=self.use_bias,
 880            name=f"Layer_{len(self.neurons)}",
 881            kernel_init=kernel_init,
 882        )(x)
 883        A = self.param(
 884            f"A_{len(self.neurons)}",
 885            lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
 886            (self.zmax + 2, self.ranks[-1], x.shape[-1]),
 887        )[species]
 888        B = self.param(
 889            f"B_{len(self.neurons)}",
 890            lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
 891            (self.zmax + 2, self.neurons[-1], self.ranks[-1]),
 892        )[species]
 893        Ax = jnp.einsum("zrd,zd->zr", A, x)
 894        BAx = jnp.einsum("zrd,zd->zr", B, Ax)
 895        x = xi + BAx
 896        if self.squeeze and x.shape[-1] == 1:
 897            x = jnp.squeeze(x, axis=-1)
 898        ############################
 899
 900        if self.input_key is not None:
 901            output_key = self.name if self.output_key is None else self.output_key
 902            return {**inputs, output_key: x} if output_key is not None else x
 903        return x
 904
 905
 906class BlockIndexNet(nn.Module):
 907    """Chemical-species-specific neural network using precomputed species index.
 908
 909    FID: BLOCK_INDEX_NET
 910
 911    A neural network that applies a species-specific fully connected network to each atom embedding.
 912    A species index must be provided to filter the embeddings for each species and apply the corresponding network.
 913    This index can be obtained using the SPECIES_INDEXER preprocessing module from `fennol.models.preprocessing.SpeciesIndexer`
 914
 915    """
 916
 917    output_dim: int
 918    """The dimension of the output of the fully connected networks."""
 919    hidden_neurons: Sequence[int]
 920    """The hidden dimensions of the fully connected networks.
 921        If a dictionary is provided, it should map species names to dimensions.
 922        If a sequence is provided, the same dimensions will be used for all species."""
 923    used_blocks: Optional[Sequence[str]] = None
 924    """The blocks to use. If None, all blocks will be used."""
 925    activation: Union[Callable, str] = "silu"
 926    """The activation function to use in the fully connected networks."""
 927    use_bias: bool = True
 928    """Whether to include bias terms in the fully connected networks."""
 929    input_key: Optional[str] = None
 930    """The key in the input dictionary that corresponds to the embeddings of the atoms."""
 931    block_index_key: str = "block_index"
 932    """The key in the input dictionary that corresponds to the block index of the atoms. See `fennol.models.preprocessing.BlockIndexer`"""
 933    output_key: Optional[str] = None
 934    """The key in the output dictionary that corresponds to the network's output."""
 935
 936    squeeze: bool = False
 937    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
 938    kernel_init: Union[str, Callable] = "lecun_normal()"
 939    """The kernel initialization method for the fully connected networks."""
 940    # check_unhandled: bool = True
 941
 942    FID: ClassVar[str] = "BLOCK_INDEX_NET"
 943
 944    # def setup(self):
 945    #     all_blocks = CHEMICAL_BLOCKS_NAMES
 946    #     if self.used_blocks is None:
 947    #         used_blocks = all_blocks
 948    #     else:
 949    #         used_blocks = []
 950    #         for b in self.used_blocks:
 951    #             b_=str(b).strip().upper()
 952    #             if b_ not in all_blocks:
 953    #                 raise ValueError(f"Block {b} not found in {all_blocks}")
 954    #             used_blocks.append(b_)
 955    #     used_blocks = set(used_blocks)
 956    #     self._used_blocks = used_blocks
 957
 958    #     if not (
 959    #         isinstance(self.hidden_neurons, dict)
 960    #         or isinstance(self.hidden_neurons, FrozenDict)
 961    #     ):
 962    #         neurons = {k: self.hidden_neurons for k in used_blocks}
 963    #     else:
 964    #         neurons = {}
 965    #         for b in self.hidden_neurons.keys():
 966    #             b_=str(b).strip().upper()
 967    #             if b_ not in all_blocks:
 968    #                 raise ValueError(f"Block {b} does not exist.  Available blocks are {all_blocks}")
 969    #             neurons[b_] = self.hidden_neurons[b]
 970    #         used_blocks = set(neurons.keys())
 971    #         if used_blocks != self._used_blocks and self.used_blocks is not None:
 972    #             print(
 973    #                 f"Warning: hidden neurons definitions do not match specified used_blocks {self.used_blocks}. Using blocks defined in hidden_neurons.")
 974    #         self._used_blocks = used_blocks
 975
 976    #     self.networks = {
 977    #         k: FullyConnectedNet(
 978    #             [*neurons[k], self.output_dim],
 979    #             self.activation,
 980    #             self.use_bias,
 981    #             name=k,
 982    #             kernel_init=self.kernel_init,
 983    #         )
 984    #         for k in self._used_blocks
 985    #     }
 986
 987    @nn.compact
 988    def __call__(
 989        self, inputs: Union[dict, Tuple[jax.Array, jax.Array]]
 990    ) -> Union[dict, jax.Array]:
 991
 992        if self.input_key is None:
 993            assert not isinstance(
 994                inputs, dict
 995            ), "input key must be provided if inputs is a dictionary"
 996            species, embedding, block_index = inputs
 997        else:
 998            species, embedding = inputs["species"], inputs[self.input_key]
 999            block_index = inputs[self.block_index_key]
1000
1001        assert isinstance(
1002            block_index, dict
1003        ), "block_index must be a dictionary for BlockIndexNet"
1004
1005        networks = {
1006            k: FullyConnectedNet(
1007                [*self.hidden_neurons, self.output_dim],
1008                self.activation,
1009                self.use_bias,
1010                name=k,
1011                kernel_init=self.kernel_init,
1012            )
1013            for k in block_index.keys()
1014        }
1015
1016        ############################
1017        # initialization => instantiate all networks
1018        if self.is_initializing():
1019            x = jnp.zeros((1, embedding.shape[-1]), dtype=embedding.dtype)
1020            [net(x) for net in networks.values()]
1021
1022        # if self.check_unhandled:
1023        #     for b in block_index.keys():
1024        #         if b not in networks.keys():
1025        #             raise ValueError(f"Block {b} not found in networks. Available blocks are {self.networks.keys()}")
1026
1027        ############################
1028        outputs = []
1029        indices = []
1030        for s, net in networks.items():
1031            if s not in block_index:
1032                continue
1033            if block_index[s] is None:
1034                continue
1035            idx = block_index[s]
1036            o = net(embedding[idx])
1037            outputs.append(o)
1038            indices.append(idx)
1039
1040        o = jnp.concatenate(outputs, axis=0)
1041        idx = jnp.concatenate(indices, axis=0)
1042
1043        out = (
1044            jnp.zeros((species.shape[0], *o.shape[1:]), dtype=o.dtype)
1045            .at[idx]
1046            .set(o, mode="drop")
1047        )
1048
1049        if self.squeeze and out.shape[-1] == 1:
1050            out = jnp.squeeze(out, axis=-1)
1051        ############################
1052
1053        if self.input_key is not None:
1054            output_key = self.name if self.output_key is None else self.output_key
1055            return {**inputs, output_key: out} if output_key is not None else out
1056        return out
class FullyConnectedNet(flax.linen.module.Module):

A fully connected neural network module.

FID: NEURAL_NET

FullyConnectedNet( neurons: Sequence[int], activation: Union[Callable, str] = 'silu', use_bias: bool = True, input_key: Optional[str] = None, output_key: Optional[str] = None, squeeze: bool = False, kernel_init: Union[str, Callable] = 'lecun_normal()', parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
neurons: Sequence[int]

A sequence of integers representing the dimensions of the network.

activation: Union[Callable, str] = 'silu'

The activation function to use.

use_bias: bool = True

Whether to use bias in the dense layers.

input_key: Optional[str] = None

The key of the input tensor.

output_key: Optional[str] = None

The key of the output tensor.

squeeze: bool = False

Whether to remove the last axis of the output tensor if it is of dimension 1.

kernel_init: Union[str, Callable] = 'lecun_normal()'

The kernel initialization method to use.

FID: ClassVar[str] = 'NEURAL_NET'
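
A minimal usage sketch (illustrative only: the array shapes, RNG seeds and the "embedding"/"mlp_out" keys are invented for the example; the import path follows the module name above):

import jax
import jax.numpy as jnp
from fennol.models.misc.nets import FullyConnectedNet

# two hidden layers of 64 units and a scalar output, squeezed to shape (natoms,)
net = FullyConnectedNet(neurons=[64, 64, 1], activation="silu", squeeze=True)
x = jnp.ones((8, 16))                        # (natoms, embedding_dim)
params = net.init(jax.random.PRNGKey(0), x)
y = net.apply(params, x)                     # shape (8,)

# dictionary mode: read the input from input_key and add the result under output_key
net_d = FullyConnectedNet(neurons=[64, 1], input_key="embedding", output_key="mlp_out")
params_d = net_d.init(jax.random.PRNGKey(1), {"embedding": x})
out = net_d.apply(params_d, {"embedding": x})  # original dict plus an extra "mlp_out" entry
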
class ResMLP(flax.linen.module.Module):

Residual neural network as defined in the SpookyNet paper.

FID: RES_MLP

ResMLP( use_bias: bool = True, input_key: Optional[str] = None, output_key: Optional[str] = None, kernel_init: Union[str, Callable] = "scaled_orthogonal(mode='fan_avg')", res_only: bool = False, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
use_bias: bool = True

Whether to include bias in the linear layers.

input_key: Optional[str] = None

The key of the input tensor.

output_key: Optional[str] = None

The key of the output tensor.

kernel_init: Union[str, Callable] = "scaled_orthogonal(mode='fan_avg')"

The kernel initialization method to use.

res_only: bool = False

Whether to only apply the residual connection without additional activation and linear layer.

FID: ClassVar[str] = 'RES_MLP'
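
A minimal sketch of applying the block (shapes are illustrative). Because every internal Dense layer uses x.shape[-1], the output keeps the feature dimension of the input:

import jax
import jax.numpy as jnp
from fennol.models.misc.nets import ResMLP

block = ResMLP()                             # residual block; output dim == input dim
x = jnp.ones((8, 32))
params = block.init(jax.random.PRNGKey(0), x)
y = block.apply(params, x)                   # shape (8, 32)

# res_only=True stops after the residual update (no extra activation + Dense on top)
block_res = ResMLP(res_only=True)
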
class FullyResidualNet(flax.linen.module.Module):

A neural network with skip connections at each layer.

FID: SKIP_NET

FullyResidualNet( dim: int, output_dim: int, nlayers: int, activation: Union[Callable, str] = 'silu', use_bias: bool = True, input_key: Optional[str] = None, output_key: Optional[str] = None, squeeze: bool = False, kernel_init: Union[str, Callable] = 'lecun_normal()', parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
dim: int

The dimension of the hidden layers.

output_dim: int

The dimension of the output layer.

nlayers: int

The number of layers in the network.

activation: Union[Callable, str] = 'silu'

The activation function to use.

use_bias: bool = True

Whether to include bias terms in the linear layers.

input_key: Optional[str] = None

The key of the input tensor.

output_key: Optional[str] = None

The key of the output tensor.

squeeze: bool = False

Whether to remove the last axis of the output tensor if it is of dimension 1.

kernel_init: Union[str, Callable] = 'lecun_normal()'

The kernel initialization method to use.

FID: ClassVar[str] = 'SKIP_NET'
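
A minimal usage sketch (shapes, sizes and the PRNG seed below are illustrative, not taken from the source): the input is projected to dim if needed, passed through nlayers - 1 residual blocks, then mapped to output_dim.

    import jax
    import jax.numpy as jnp
    from fennol.models.misc.nets import FullyResidualNet

    # toy input: 5 "atoms" with 32 features each (illustrative shapes)
    x = jnp.ones((5, 32))
    net = FullyResidualNet(dim=64, output_dim=1, nlayers=3, squeeze=True)
    params = net.init(jax.random.PRNGKey(0), x)
    y = net.apply(params, x)  # shape (5,): output_dim=1 is squeezed away
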
class HierarchicalNet(flax.linen.module.Module):
242class HierarchicalNet(nn.Module):
243    """Neural network for a sequence of inputs (in axis=-2) with a decay factor
244    
245    FID: HIERARCHICAL_NET
246    """
247
248    neurons: Sequence[int]
249    """A sequence of integers representing the number of neurons in each layer."""
250    activation: Union[Callable, str] = "silu"
251    """The activation function to use."""
252    use_bias: bool = True
253    """Whether to include bias terms in the linear layers."""
254    input_key: Optional[str] = None
255    """The key of the input tensor."""
256    output_key: Optional[str] = None
257    """The key of the output tensor."""
258    decay: float = 0.01
259    """The decay factor to scale each element of the sequence."""
260    squeeze: bool = False
261    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
262    kernel_init: Union[str, Callable] = "lecun_normal()"
263    """The kernel initialization method to use."""
264
265    FID: ClassVar[str] = "HIERARCHICAL_NET"
266
267    @nn.compact
268    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
269        if self.input_key is None:
270            assert not isinstance(
271                inputs, dict
272            ), "input key must be provided if inputs is a dictionary"
273            x = inputs
274        else:
275            x = inputs[self.input_key]
276
277        ############################
278        networks = nn.vmap(
279            FullyConnectedNet,
280            variable_axes={"params": 0},
281            split_rngs={"params": True},
282            in_axes=-2,
283            out_axes=-2,
284        )(self.neurons, self.activation, self.use_bias,
285          kernel_init=self.kernel_init)
286
287        out = networks(x)
288        # scale each layer by a decay factor
289        decay = jnp.asarray([self.decay**i for i in range(out.shape[-2])])
290        out = out * decay[..., :, None]
291
292        if self.squeeze and out.shape[-1] == 1:
293            out = jnp.squeeze(out, axis=-1)
294        ############################
295
296        if self.input_key is not None:
297            output_key = self.name if self.output_key is None else self.output_key
298            return {**inputs, output_key: out} if output_key is not None else out
299        return out

Neural network for a sequence of inputs (in axis=-2) with a decay factor.

FID: HIERARCHICAL_NET

HierarchicalNet( neurons: Sequence[int], activation: Union[Callable, str] = 'silu', use_bias: bool = True, input_key: Optional[str] = None, output_key: Optional[str] = None, decay: float = 0.01, squeeze: bool = False, kernel_init: Union[str, Callable] = 'lecun_normal()', parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
neurons: Sequence[int]

A sequence of integers representing the number of neurons in each layer.

activation: Union[Callable, str] = 'silu'

The activation function to use.

use_bias: bool = True

Whether to include bias terms in the linear layers.

input_key: Optional[str] = None

The key of the input tensor.

output_key: Optional[str] = None

The key of the output tensor.

decay: float = 0.01

The decay factor to scale each element of the sequence.

squeeze: bool = False

Whether to remove the last axis of the output tensor if it is of dimension 1.

kernel_init: Union[str, Callable] = 'lecun_normal()'

The kernel initialization method to use.

FID: ClassVar[str] = 'HIERARCHICAL_NET'
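
A minimal sketch of the decay behaviour (illustrative shapes): the same FullyConnectedNet is applied to every element along axis -2, and element i of the output is scaled by decay**i.

    import jax
    import jax.numpy as jnp
    from fennol.models.misc.nets import HierarchicalNet

    # toy input: batch of 4, sequence of 3 elements, 16 features each
    x = jnp.ones((4, 3, 16))
    net = HierarchicalNet(neurons=[32, 1], decay=0.1, squeeze=True)
    params = net.init(jax.random.PRNGKey(0), x)
    y = net.apply(params, x)  # shape (4, 3); sequence element i is scaled by 0.1**i
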
class SpeciesIndexNet(flax.linen.module.Module):
302class SpeciesIndexNet(nn.Module):
303    """Chemical-species-specific neural network using precomputed species index.
304
305    FID: SPECIES_INDEX_NET
306
307    A neural network that applies a species-specific fully connected network to each atom embedding.
308    A species index must be provided to filter the embeddings for each species and apply the corresponding network.
309    This index can be obtained using the SPECIES_INDEXER preprocessing module from `fennol.models.preprocessing.SpeciesIndexer`
310
311    """
312
313    output_dim: int
314    """The dimension of the output of the fully connected networks."""
315    hidden_neurons: Union[dict, FrozenDict, Sequence[int]]
316    """The hidden dimensions of the fully connected networks.
317        If a dictionary is provided, it should map species names to dimensions.
318        If a sequence is provided, the same dimensions will be used for all species."""
319    species_order: Optional[Union[str, Sequence[str]]] = None
320    """The species for which to build a network. Only required if neurons is not a dictionary."""
321    activation: Union[Callable, str] = "silu"
322    """The activation function to use in the fully connected networks."""
323    use_bias: bool = True
324    """Whether to include bias terms in the fully connected networks."""
325    input_key: Optional[str] = None
326    """The key in the input dictionary that corresponds to the embeddings of the atoms."""
327    species_index_key: str = "species_index"
328    """The key in the input dictionary that corresponds to the species index of the atoms. See `fennol.models.preprocessing.SpeciesIndexer`"""
329    output_key: Optional[str] = None
330    """The key in the output dictionary that corresponds to the network's output."""
331
332    squeeze: bool = False
333    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
334    kernel_init: Union[str, Callable] = "lecun_normal()"
335    """The kernel initialization method for the fully connected networks."""
336    check_unhandled: bool = True
337
338    FID: ClassVar[str] = "SPECIES_INDEX_NET"
339
340    def setup(self):
341        if not (
342            isinstance(self.hidden_neurons, dict)
343            or isinstance(self.hidden_neurons, FrozenDict)
344        ):
345            assert (
346                self.species_order is not None
347            ), "species_order must be provided if hidden_neurons is not a dictionary"
348            if isinstance(self.species_order, str):
349                species_order = [el.strip() for el in self.species_order.split(",")]
350            else:
351                species_order = [el for el in self.species_order]
352            neurons = {k: self.hidden_neurons for k in species_order}
353        else:
354            neurons = self.hidden_neurons
355            species_order = list(neurons.keys())
356        for species in species_order:
357            assert (
358                species in PERIODIC_TABLE
359            ), f"species {species} not found in periodic table"
360
361        self.networks = {
362            k: FullyConnectedNet(
363                [*neurons[k], self.output_dim],
364                self.activation,
365                self.use_bias,
366                name=k,
367                kernel_init=self.kernel_init,
368            )
369            for k in species_order
370        }
371
372    def __call__(
373        self, inputs: Union[dict, Tuple[jax.Array, jax.Array]]
374    ) -> Union[dict, jax.Array]:
375
376        if self.input_key is None:
377            assert not isinstance(
378                inputs, dict
379            ), "input key must be provided if inputs is a dictionary"
380            species, embedding, species_index = inputs
381        else:
382            species, embedding = inputs["species"], inputs[self.input_key]
383            species_index = inputs[self.species_index_key]
384
385        assert isinstance(
386            species_index, dict
387        ), "species_index must be a dictionary for SpeciesIndexNetHet"
388
389        ############################
390        # initialization => instantiate all networks
391        if self.is_initializing():
392            x = jnp.zeros((1, embedding.shape[-1]), dtype=embedding.dtype)
393            [net(x) for net in self.networks.values()]
394
395        if self.check_unhandled:
396            for b in species_index.keys():
397                if b not in self.networks.keys():
398                    raise ValueError(f"Species {b} not found in networks. Handled species are {self.networks.keys()}")
399
400        ############################
401        outputs = []
402        indices = []
403        for s, net in self.networks.items():
404            if s not in species_index:
405                continue
406            idx = species_index[s]
407            o = net(embedding[idx])
408            outputs.append(o)
409            indices.append(idx)
410
411        o = jnp.concatenate(outputs, axis=0)
412        idx = jnp.concatenate(indices, axis=0)
413
414        out = (
415            jnp.zeros((species.shape[0], *o.shape[1:]), dtype=o.dtype)
416            .at[idx]
417            .set(o, mode="drop")
418        )
419
420        if self.squeeze and out.shape[-1] == 1:
421            out = jnp.squeeze(out, axis=-1)
422        ############################
423
424        if self.input_key is not None:
425            output_key = self.name if self.output_key is None else self.output_key
426            return {**inputs, output_key: out} if output_key is not None else out
427        return out

Chemical-species-specific neural network using precomputed species index.

FID: SPECIES_INDEX_NET

A neural network that applies a species-specific fully connected network to each atom embedding. A species index must be provided to filter the embeddings for each species and apply the corresponding network. This index can be obtained using the SPECIES_INDEXER preprocessing module from fennol.models.preprocessing.SpeciesIndexer

SpeciesIndexNet( output_dim: int, hidden_neurons: Union[dict, flax.core.frozen_dict.FrozenDict, Sequence[int]], species_order: Union[str, Sequence[str], NoneType] = None, activation: Union[Callable, str] = 'silu', use_bias: bool = True, input_key: Optional[str] = None, species_index_key: str = 'species_index', output_key: Optional[str] = None, squeeze: bool = False, kernel_init: Union[str, Callable] = 'lecun_normal()', check_unhandled: bool = True, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
output_dim: int

The dimension of the output of the fully connected networks.

hidden_neurons: Union[dict, flax.core.frozen_dict.FrozenDict, Sequence[int]]

The hidden dimensions of the fully connected networks. If a dictionary is provided, it should map species names to dimensions. If a sequence is provided, the same dimensions will be used for all species.

species_order: Union[str, Sequence[str], NoneType] = None

The species for which to build a network. Only required if neurons is not a dictionary.

activation: Union[Callable, str] = 'silu'

The activation function to use in the fully connected networks.

use_bias: bool = True

Whether to include bias terms in the fully connected networks.

input_key: Optional[str] = None

The key in the input dictionary that corresponds to the embeddings of the atoms.

species_index_key: str = 'species_index'

The key in the input dictionary that corresponds to the species index of the atoms. See fennol.models.preprocessing.SpeciesIndexer

output_key: Optional[str] = None

The key in the output dictionary that corresponds to the network's output.

squeeze: bool = False

Whether to remove the last axis of the output tensor if it is of dimension 1.

kernel_init: Union[str, Callable] = 'lecun_normal()'

The kernel initialization method for the fully connected networks.

check_unhandled: bool = True
FID: ClassVar[str] = 'SPECIES_INDEX_NET'
def setup(self):
340    def setup(self):
341        if not (
342            isinstance(self.hidden_neurons, dict)
343            or isinstance(self.hidden_neurons, FrozenDict)
344        ):
345            assert (
346                self.species_order is not None
347            ), "species_order must be provided if hidden_neurons is not a dictionary"
348            if isinstance(self.species_order, str):
349                species_order = [el.strip() for el in self.species_order.split(",")]
350            else:
351                species_order = [el for el in self.species_order]
352            neurons = {k: self.hidden_neurons for k in species_order}
353        else:
354            neurons = self.hidden_neurons
355            species_order = list(neurons.keys())
356        for species in species_order:
357            assert (
358                species in PERIODIC_TABLE
359            ), f"species {species} not found in periodic table"
360
361        self.networks = {
362            k: FullyConnectedNet(
363                [*neurons[k], self.output_dim],
364                self.activation,
365                self.use_bias,
366                name=k,
367                kernel_init=self.kernel_init,
368            )
369            for k in species_order
370        }
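
A minimal usage sketch for a three-atom system (the species_index dictionary would normally be produced by the SPECIES_INDEXER preprocessing module; here it is built by hand and all shapes are illustrative):

    import jax
    import jax.numpy as jnp
    from fennol.models.misc.nets import SpeciesIndexNet

    species = jnp.array([8, 1, 1])           # atomic numbers: O, H, H
    embedding = jnp.ones((3, 16))            # per-atom embeddings
    species_index = {"O": jnp.array([0]), "H": jnp.array([1, 2])}

    net = SpeciesIndexNet(
        output_dim=1,
        hidden_neurons=[32, 32],
        species_order="O,H",
        input_key="embedding",
        output_key="atom_energy",
    )
    inputs = {"species": species, "embedding": embedding, "species_index": species_index}
    params = net.init(jax.random.PRNGKey(0), inputs)
    out = net.apply(params, inputs)["atom_energy"]   # per-atom output, shape (3, 1)
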

class ChemicalNet(flax.linen.module.Module):
430class ChemicalNet(nn.Module):
431    """optimized Chemical-species-specific neural network.
432
433    FID: CHEMICAL_NET
434
435    A neural network that applies a fully connected network to each atom embedding in a chemical system and selects the output corresponding to the atom's species.
436    This is an optimized version of ChemicalNetHet that uses vmap to apply the networks to all atoms at once.
437    The optimization is allowed because all networks have the same shape.
438
439    """
440
441    species_order: Union[str, Sequence[str]]
442    """The species for which to build a network."""
443    neurons: Sequence[int]
444    """The dimensions of the fully connected networks."""
445    activation: Union[Callable, str] = "silu"
446    """The activation function to use in the fully connected networks."""
447    use_bias: bool = True
448    """Whether to include bias terms in the fully connected networks."""
449    input_key: Optional[str] = None
450    """The key in the input dictionary that corresponds to the embeddings of the atoms."""
451    output_key: Optional[str] = None
452    """The key in the output dictionary that corresponds to the network's output."""
453    squeeze: bool = False
454    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
455    kernel_init: Union[str, Callable] = "lecun_normal()"
456    """The kernel initialization method for the fully connected networks."""
457
458    FID: ClassVar[str] = "CHEMICAL_NET"
459
460    @nn.compact
461    def __call__(
462        self, inputs: Union[dict, Tuple[jax.Array, jax.Array]]
463    ) -> Union[dict, jax.Array]:
464        if self.input_key is None:
465            assert not isinstance(
466                inputs, dict
467            ), "input key must be provided if inputs is a dictionary"
468            species, embedding = inputs
469        else:
470            species, embedding = inputs["species"], inputs[self.input_key]
471
472        ############################
473        # build species to network index mapping (static => fixed when jitted)
474        rev_idx = PERIODIC_TABLE_REV_IDX
475        maxidx = max(rev_idx.values())
476        if isinstance(self.species_order, str):
477            species_order = [el.strip() for el in self.species_order.split(",")]
478        else:
479            species_order = [el for el in self.species_order]
480        nspecies = len(species_order)
481        conv_tensor_ = np.full((maxidx + 2,), -1, dtype=np.int32)
482        for i, s in enumerate(species_order):
483            conv_tensor_[rev_idx[s]] = i
484        conv_tensor = jnp.asarray(conv_tensor_)
485        indices = conv_tensor[species]
486
487        ############################
488        # build shape-sharing networks using vmap
489        networks = nn.vmap(
490            FullyConnectedNet,
491            variable_axes={"params": 0},
492            split_rngs={"params": True},
493            in_axes=0,
494        )(self.neurons, self.activation, self.use_bias, kernel_init=self.kernel_init)
495        # repeat input along a new axis to compute for all species at once
496        x = jnp.broadcast_to(
497            embedding[None, :, :], (nspecies, *embedding.shape)
498        )
499
500        # apply networks to input and select the output corresponding to the species
501        out = jnp.squeeze(
502            jnp.take_along_axis(networks(x), indices[None, :, None], axis=0), axis=0
503        )
504
505        out = jnp.where((indices >= 0)[:, None], out, 0.0)
506        if self.squeeze and out.shape[-1] == 1:
507            out = jnp.squeeze(out, axis=-1)
508        ############################
509
510        if self.input_key is not None:
511            output_key = self.name if self.output_key is None else self.output_key
512            return {**inputs, output_key: out} if output_key is not None else out
513        return out

Optimized chemical-species-specific neural network.

FID: CHEMICAL_NET

A neural network that applies a fully connected network to each atom embedding in a chemical system and selects the output corresponding to the atom's species. This is an optimized version of ChemicalNetHet that uses vmap to apply the networks to all atoms at once. The optimization is allowed because all networks have the same shape.

ChemicalNet( species_order: Union[str, Sequence[str]], neurons: Sequence[int], activation: Union[Callable, str] = 'silu', use_bias: bool = True, input_key: Optional[str] = None, output_key: Optional[str] = None, squeeze: bool = False, kernel_init: Union[str, Callable] = 'lecun_normal()', parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
species_order: Union[str, Sequence[str]]

The species for which to build a network.

neurons: Sequence[int]

The dimensions of the fully connected networks.

activation: Union[Callable, str] = 'silu'

The activation function to use in the fully connected networks.

use_bias: bool = True

Whether to include bias terms in the fully connected networks.

input_key: Optional[str] = None

The key in the input dictionary that corresponds to the embeddings of the atoms.

output_key: Optional[str] = None

The key in the output dictionary that corresponds to the network's output.

squeeze: bool = False

Whether to remove the last axis of the output tensor if it is of dimension 1.

kernel_init: Union[str, Callable] = 'lecun_normal()'

The kernel initialization method for the fully connected networks.

FID: ClassVar[str] = 'CHEMICAL_NET'
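
A minimal sketch (tuple input, illustrative shapes): one network per species in species_order is evaluated for every atom via vmap, and the row matching each atom's species is then selected.

    import jax
    import jax.numpy as jnp
    from fennol.models.misc.nets import ChemicalNet

    species = jnp.array([8, 1, 1])           # atomic numbers: O, H, H
    embedding = jnp.ones((3, 16))
    net = ChemicalNet(species_order="O,H", neurons=[32, 1], squeeze=True)
    params = net.init(jax.random.PRNGKey(0), (species, embedding))
    y = net.apply(params, (species, embedding))   # shape (3,): one value per atom
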
class MOENet(flax.linen.module.Module):
516class MOENet(nn.Module):
517    """Mixture of Experts neural network.
518
519    FID: MOE_NET
520
521    This class represents a Mixture of Experts neural network. It takes in an input and applies a set of shape-sharing networks
522    to the input based on a router. The outputs of the shape-sharing networks are then combined using weights computed by the router.
523
524    """
525
526    neurons: Sequence[int]
527    """A sequence of integers representing the number of neurons in each shape-sharing network."""
528    num_networks: int
529    """The number of shape-sharing networks to create."""
530    activation: Union[Callable, str] = "silu"
531    """The activation function to use in the shape-sharing networks."""
532    use_bias: bool = True
533    """Whether to include bias in the shape-sharing networks."""
534    input_key: Optional[str] = None
535    """The key of the input tensor."""
536    output_key: Optional[str] = None
537    """The key of the output tensor."""
538    squeeze: bool = False
539    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
540    
541    kernel_init: Union[str, Callable] = "lecun_normal()"
542    """The kernel initialization method to use in the shape-sharing networks."""
543    router_key: Optional[str] = None
544    """The key of the router tensor. If None, the router is assumed to be the same as the input tensor."""
545
546    FID: ClassVar[str] = "MOE_NET"
547
548    @nn.compact
549    def __call__(
550        self, inputs: Union[dict, Tuple[jax.Array, jax.Array]]
551    ) -> Union[dict, jax.Array]:
552        if self.input_key is None:
553            assert not isinstance(
554                inputs, dict
555            ), "input key must be provided if inputs is a dictionary"
556            if isinstance(inputs, tuple):
557                embedding, router = inputs
558            else:
559                embedding = router = inputs
560        else:
561            embedding = inputs[self.input_key]
562            router = (
563                inputs[self.router_key] if self.router_key is not None else embedding
564            )
565
566        ############################
567        # build shape-sharing networks using vmap
568        networks = nn.vmap(
569            FullyConnectedNet,
570            variable_axes={"params": 0},
571            split_rngs={"params": True},
572            in_axes=0,
573            out_axes=0,
574        )(self.neurons, self.activation, self.use_bias, kernel_init=self.kernel_init)
575        # repeat input along a new axis to compute for all networks at once
576        x = jnp.repeat(embedding[None, :, :], self.num_networks, axis=0)
577
578        w = nn.softmax(nn.Dense(self.num_networks, name="router")(router), axis=-1)
579
580        out = (networks(x) * w.T[:, :, None]).sum(axis=0)
581
582        if self.squeeze and out.shape[-1] == 1:
583            out = jnp.squeeze(out, axis=-1)
584        ############################
585
586        if self.input_key is not None:
587            output_key = self.name if self.output_key is None else self.output_key
588            return {**inputs, output_key: out} if output_key is not None else out
589        return out

Mixture of Experts neural network.

FID: MOE_NET

This class represents a Mixture of Experts neural network. It takes in an input and applies a set of shape-sharing networks to the input based on a router. The outputs of the shape-sharing networks are then combined using weights computed by the router.

MOENet( neurons: Sequence[int], num_networks: int, activation: Union[Callable, str] = 'silu', use_bias: bool = True, input_key: Optional[str] = None, output_key: Optional[str] = None, squeeze: bool = False, kernel_init: Union[str, Callable] = 'lecun_normal()', router_key: Optional[str] = None, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
neurons: Sequence[int]

A sequence of integers representing the number of neurons in each shape-sharing network.

num_networks: int

The number of shape-sharing networks to create.

activation: Union[Callable, str] = 'silu'

The activation function to use in the shape-sharing networks.

use_bias: bool = True

Whether to include bias in the shape-sharing networks.

input_key: Optional[str] = None

The key of the input tensor.

output_key: Optional[str] = None

The key of the output tensor.

squeeze: bool = False

Whether to remove the last axis of the output tensor if it is of dimension 1.

kernel_init: Union[str, Callable] = 'lecun_normal()'

The kernel initialization method to use in the shape-sharing networks.

router_key: Optional[str] = None

The key of the router tensor. If None, the router is assumed to be the same as the input tensor.

FID: ClassVar[str] = 'MOE_NET'
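
A minimal sketch (illustrative shapes): every expert is evaluated on all inputs and their outputs are combined with the softmax weights produced by the router layer.

    import jax
    import jax.numpy as jnp
    from fennol.models.misc.nets import MOENet

    x = jnp.ones((10, 16))                   # 10 samples, 16 features
    net = MOENet(neurons=[32, 8], num_networks=4)
    params = net.init(jax.random.PRNGKey(0), x)
    y = net.apply(params, x)                 # shape (10, 8): weighted mix of 4 experts
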
class ChannelNet(flax.linen.module.Module):
591class ChannelNet(nn.Module):
592    """Apply a different neural network to each channel.
593    
594    FID: CHANNEL_NET
595    """
596
597    neurons: Sequence[int]
598    """A sequence of integers representing the number of neurons in each shape-sharing network."""
599    activation: Union[Callable, str] = "silu"
600    """The activation function to use in the shape-sharing networks."""
601    use_bias: bool = True
602    """Whether to include bias in the shape-sharing networks."""
603    input_key: Optional[str] = None
604    """The key of the input tensor."""
605    output_key: Optional[str] = None
606    """The key of the output tensor."""
607    squeeze: bool = False
608    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
609    kernel_init: Union[str, Callable] = "lecun_normal()"
610    """The kernel initialization method to use in the shape-sharing networks."""
611    channel_axis: int = -2
612    """The axis to use as channel. Its length will be the number of shape-sharing networks."""
613
614    FID: ClassVar[str] = "CHANNEL_NET"
615
616    @nn.compact
617    def __call__(
618        self, inputs: Union[dict, Tuple[jax.Array, jax.Array]]
619    ) -> Union[dict, jax.Array]:
620        if self.input_key is None:
621            assert not isinstance(
622                inputs, dict
623            ), "input key must be provided if inputs is a dictionary"
624            x = inputs
625        else:
626            x = inputs[self.input_key]
627
628        ############################
629        # build shape-sharing networks using vmap
630        networks = nn.vmap(
631            FullyConnectedNet,
632            variable_axes={"params": 0},
633            split_rngs={"params": True},
634            in_axes=self.channel_axis,
635            out_axes=self.channel_axis,
636        )(self.neurons, self.activation, self.use_bias, kernel_init=self.kernel_init)
637
638        out = networks(x)
639
640        if self.squeeze and out.shape[-1] == 1:
641            out = jnp.squeeze(out, axis=-1)
642        ############################
643
644        if self.input_key is not None:
645            output_key = self.name if self.output_key is None else self.output_key
646            return {**inputs, output_key: out} if output_key is not None else out
647        return out

Apply a different neural network to each channel.

FID: CHANNEL_NET

ChannelNet( neurons: Sequence[int], activation: Union[Callable, str] = 'silu', use_bias: bool = True, input_key: Optional[str] = None, output_key: Optional[str] = None, squeeze: bool = False, kernel_init: Union[str, Callable] = 'lecun_normal()', channel_axis: int = -2, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
neurons: Sequence[int]

A sequence of integers representing the number of neurons in each shape-sharing network.

activation: Union[Callable, str] = 'silu'

The activation function to use in the shape-sharing networks.

use_bias: bool = True

Whether to include bias in the shape-sharing networks.

input_key: Optional[str] = None

The key of the input tensor.

output_key: Optional[str] = None

The key of the output tensor.

squeeze: bool = False

Whether to remove the last axis of the output tensor if it is of dimension 1.

kernel_init: Union[str, Callable] = 'lecun_normal()'

The kernel initialization method to use in the shape-sharing networks.

channel_axis: int = -2

The axis to use as channel. Its length will be the number of shape-sharing networks.

FID: ClassVar[str] = 'CHANNEL_NET'
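
A minimal sketch (illustrative shapes): each slice along channel_axis gets its own network, all sharing the same layer sizes.

    import jax
    import jax.numpy as jnp
    from fennol.models.misc.nets import ChannelNet

    x = jnp.ones((10, 3, 8))                 # 3 channels of 8 features each
    net = ChannelNet(neurons=[16, 4])        # channel_axis defaults to -2
    params = net.init(jax.random.PRNGKey(0), x)
    y = net.apply(params, x)                 # shape (10, 3, 4)
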
class GatedPerceptron(flax.linen.module.Module):
650class GatedPerceptron(nn.Module):
651    """Gated Perceptron neural network.
652
653    FID: GATED_PERCEPTRON
654
655    This class represents a Gated Perceptron neural network model. It applies a gating mechanism
656    to the input data and performs linear transformation using a dense layer followed by an activation function.
657    """
658
659    dim: int
660    """The dimensionality of the output space."""
661    use_bias: bool = True
662    """Whether to include a bias term in the dense layer."""
663    kernel_init: Union[str, Callable] = "lecun_normal()"
664    """The kernel initialization method to use."""
665    activation: Union[Callable, str] = "silu"
666    """The activation function to use."""
667
668    input_key: Optional[str] = None
669    """The key of the input tensor."""
670    output_key: Optional[str] = None
671    """The key of the output tensor."""
672    squeeze: bool = False
673    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
674
675    FID: ClassVar[str] = "GATED_PERCEPTRON"
676
677    @nn.compact
678    def __call__(self, inputs):
679        if self.input_key is None:
680            assert not isinstance(
681                inputs, dict
682            ), "input key must be provided if inputs is a dictionary"
683            x = inputs
684        else:
685            x = inputs[self.input_key]
686
687        # activation = (
688        #     activation_from_str(self.activation)
689        #     if isinstance(self.activation, str)
690        #     else self.activation
691        # )
692        kernel_init = (
693            initializer_from_str(self.kernel_init)
694            if isinstance(self.kernel_init, str)
695            else self.kernel_init
696        )
697        ############################
698        gate = jax.nn.sigmoid(
699            nn.Dense(self.dim, use_bias=self.use_bias, kernel_init=kernel_init)(x)
700        )
701        x = gate * activation_from_str(self.activation)(
702            nn.Dense(self.dim, use_bias=self.use_bias, kernel_init=kernel_init)(x)
703        )
704
705        if self.squeeze and x.shape[-1] == 1:
706            x = jnp.squeeze(x, axis=-1)
707        ############################
708
709        if self.input_key is not None:
710            output_key = self.name if self.output_key is None else self.output_key
711            return {**inputs, output_key: x} if output_key is not None else x
712        return x

Gated Perceptron neural network.

FID: GATED_PERCEPTRON

This class represents a Gated Perceptron neural network model. It applies a gating mechanism to the input data and performs linear transformation using a dense layer followed by an activation function.

GatedPerceptron( dim: int, use_bias: bool = True, kernel_init: Union[str, Callable] = 'lecun_normal()', activation: Union[Callable, str] = 'silu', input_key: Optional[str] = None, output_key: Optional[str] = None, squeeze: bool = False, parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
dim: int

The dimensionality of the output space.

use_bias: bool = True

Whether to include a bias term in the dense layer.

kernel_init: Union[str, Callable] = 'lecun_normal()'

The kernel initialization method to use.

activation: Union[Callable, str] = 'silu'

The activation function to use.

input_key: Optional[str] = None

The key of the input tensor.

output_key: Optional[str] = None

The key of the output tensor.

squeeze: bool = False

Whether to remove the last axis of the output tensor if it is of dimension 1.

FID: ClassVar[str] = 'GATED_PERCEPTRON'
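
A minimal sketch (illustrative shapes): a sigmoid gate multiplies an activated linear projection of the same input.

    import jax
    import jax.numpy as jnp
    from fennol.models.misc.nets import GatedPerceptron

    x = jnp.ones((5, 16))
    net = GatedPerceptron(dim=8)
    params = net.init(jax.random.PRNGKey(0), x)
    y = net.apply(params, x)                 # shape (5, 8)
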
class ZAcNet(flax.linen.module.Module):
715class ZAcNet(nn.Module):
716    """ A fully connected neural network module with affine Z-dependent adjustments of activations.
717    
718    FID: ZACNET
719    """
720
721    neurons: Sequence[int]
722    """A sequence of integers representing the dimensions of the network."""
723    zmax: int = 86
724    """The maximum atomic number to consider."""
725    activation: Union[Callable, str] = "silu"
726    """The activation function to use."""
727    use_bias: bool = True
728    """Whether to use bias in the dense layers."""
729    input_key: Optional[str] = None
730    """The key of the input tensor."""
731    output_key: Optional[str] = None
732    """The key of the output tensor."""
733    squeeze: bool = False
734    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
735    kernel_init: Union[str, Callable] = "lecun_normal()"
736    """The kernel initialization method to use."""
737    species_key: str = "species"
738    """The key of the species tensor."""
739
740    FID: ClassVar[str] = "ZACNET"
741
742    @nn.compact
743    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
744        if self.input_key is None:
745            assert not isinstance(
746                inputs, dict
747            ), "input key must be provided if inputs is a dictionary"
748            species, x = inputs
749        else:
750            species, x = inputs[self.species_key], inputs[self.input_key]
751
752        # activation = (
753        #     activation_from_str(self.activation)
754        #     if isinstance(self.activation, str)
755        #     else self.activation
756        # )
757        kernel_init = (
758            initializer_from_str(self.kernel_init)
759            if isinstance(self.kernel_init, str)
760            else self.kernel_init
761        )
762        ############################
763        for i, d in enumerate(self.neurons[:-1]):
764            x = nn.Dense(
765                d, use_bias=self.use_bias, name=f"Layer_{i+1}", kernel_init=kernel_init
766            )(x)
767            sig = self.param(
768                f"sig_{i+1}",
769                lambda key, shape: jnp.ones(shape, dtype=x.dtype),
770                (self.zmax + 2, d),
771            )[species]
772            if self.use_bias:
773                b = self.param(
774                    f"b_{i+1}",
775                    lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
776                    (self.zmax + 2, d),
777                )[species]
778            else:
779                b = 0
780            x = activation_from_str(self.activation)(sig * x + b)
781        x = nn.Dense(
782            self.neurons[-1],
783            use_bias=self.use_bias,
784            name=f"Layer_{len(self.neurons)}",
785            kernel_init=kernel_init,
786        )(x)
787        sig = self.param(
788            f"sig_{len(self.neurons)}",
789            lambda key, shape: jnp.ones(shape, dtype=x.dtype),
790            (self.zmax + 2, self.neurons[-1]),
791        )[species]
792        if self.use_bias:
793            b = self.param(
794                f"b_{len(self.neurons)}",
795                lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
796                (self.zmax + 2, self.neurons[-1]),
797            )[species]
798        else:
799            b = 0
800        x = sig * x + b
801        if self.squeeze and x.shape[-1] == 1:
802            x = jnp.squeeze(x, axis=-1)
803        ############################
804
805        if self.input_key is not None:
806            output_key = self.name if self.output_key is None else self.output_key
807            return {**inputs, output_key: x} if output_key is not None else x
808        return x

A fully connected neural network module with affine Z-dependent adjustments of activations.

FID: ZACNET

ZAcNet( neurons: Sequence[int], zmax: int = 86, activation: Union[Callable, str] = 'silu', use_bias: bool = True, input_key: Optional[str] = None, output_key: Optional[str] = None, squeeze: bool = False, kernel_init: Union[str, Callable] = 'lecun_normal()', species_key: str = 'species', parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
neurons: Sequence[int]

A sequence of integers representing the dimensions of the network.

zmax: int = 86

The maximum atomic number to consider.

activation: Union[Callable, str] = 'silu'

The activation function to use.

use_bias: bool = True

Whether to use bias in the dense layers.

input_key: Optional[str] = None

The key of the input tensor.

output_key: Optional[str] = None

The key of the output tensor.

squeeze: bool = False

Whether to remove the last axis of the output tensor if it is of dimension 1.

kernel_init: Union[str, Callable] = 'lecun_normal()'

The kernel initialization method to use.

species_key: str = 'species'

The key of the species tensor.

FID: ClassVar[str] = 'ZACNET'
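
A minimal sketch (tuple input, illustrative shapes): after each Dense layer the pre-activation is rescaled and shifted by per-species parameters indexed with the atomic numbers.

    import jax
    import jax.numpy as jnp
    from fennol.models.misc.nets import ZAcNet

    species = jnp.array([8, 1, 1])           # atomic numbers: O, H, H
    x = jnp.ones((3, 16))
    net = ZAcNet(neurons=[32, 1], squeeze=True)
    params = net.init(jax.random.PRNGKey(0), (species, x))
    y = net.apply(params, (species, x))      # shape (3,)
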
class ZLoRANet(flax.linen.module.Module):
811class ZLoRANet(nn.Module):
812    """A fully connected neural network module with Z-dependent low-rank adaptation.
813    
814    FID: ZLORANET
815    """
816
817    neurons: Sequence[int]
818    """A sequence of integers representing the dimensions of the network."""
819    ranks: Sequence[int]
820    """A sequence of integers representing the ranks of the low-rank adaptation at each layer."""
821    zmax: int = 86
822    """The maximum atomic number to consider."""
823    activation: Union[Callable, str] = "silu"
824    """The activation function to use."""
825    use_bias: bool = True
826    """Whether to use bias in the dense layers."""
827    input_key: Optional[str] = None
828    """The key of the input tensor."""
829    output_key: Optional[str] = None
830    """The key of the output tensor."""
831    squeeze: bool = False
832    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
833    kernel_init: Union[str, Callable] = "lecun_normal()"
834    """The kernel initialization method to use."""
835    species_key: str = "species"
836    """The key of the species tensor."""
837
838    FID: ClassVar[str] = "ZLORANET"
839
840    @nn.compact
841    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
842        if self.input_key is None:
843            assert not isinstance(
844                inputs, dict
845            ), "input key must be provided if inputs is a dictionary"
846            species, x = inputs
847        else:
848            species, x = inputs[self.species_key], inputs[self.input_key]
849
850        # activation = (
851        #     activation_from_str(self.activation)
852        #     if isinstance(self.activation, str)
853        #     else self.activation
854        # )
855        kernel_init = (
856            initializer_from_str(self.kernel_init)
857            if isinstance(self.kernel_init, str)
858            else self.kernel_init
859        )
860        ############################
861        for i, d in enumerate(self.neurons[:-1]):
862            xi = nn.Dense(
863                d, use_bias=self.use_bias, name=f"Layer_{i+1}", kernel_init=kernel_init
864            )(x)
865            A = self.param(
866                f"A_{i+1}",
867                lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
868                (self.zmax + 2, self.ranks[i], x.shape[-1]),
869            )[species]
870            B = self.param(
871                f"B_{i+1}",
872                lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
873                (self.zmax + 2, d, self.ranks[i]),
874            )[species]
875            Ax = jnp.einsum("zrd,zd->zr", A, x)
876            BAx = jnp.einsum("zrd,zd->zr", B, Ax)
877            x = activation_from_str(self.activation)(xi + BAx)
878        xi = nn.Dense(
879            self.neurons[-1],
880            use_bias=self.use_bias,
881            name=f"Layer_{len(self.neurons)}",
882            kernel_init=kernel_init,
883        )(x)
884        A = self.param(
885            f"A_{len(self.neurons)}",
886            lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
887            (self.zmax + 2, self.ranks[-1], x.shape[-1]),
888        )[species]
889        B = self.param(
890            f"B_{len(self.neurons)}",
891            lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
892            (self.zmax + 2, self.neurons[-1], self.ranks[-1]),
893        )[species]
894        Ax = jnp.einsum("zrd,zd->zr", A, x)
895        BAx = jnp.einsum("zrd,zd->zr", B, Ax)
896        x = xi + BAx
897        if self.squeeze and x.shape[-1] == 1:
898            x = jnp.squeeze(x, axis=-1)
899        ############################
900
901        if self.input_key is not None:
902            output_key = self.name if self.output_key is None else self.output_key
903            return {**inputs, output_key: x} if output_key is not None else x
904        return x

A fully connected neural network module with Z-dependent low-rank adaptation.

FID: ZLORANET

ZLoRANet( neurons: Sequence[int], ranks: Sequence[int], zmax: int = 86, activation: Union[Callable, str] = 'silu', use_bias: bool = True, input_key: Optional[str] = None, output_key: Optional[str] = None, squeeze: bool = False, kernel_init: Union[str, Callable] = 'lecun_normal()', species_key: str = 'species', parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
neurons: Sequence[int]

A sequence of integers representing the dimensions of the network.

ranks: Sequence[int]

A sequence of integers representing the ranks of the low-rank adaptation at each layer.

zmax: int = 86

The maximum atomic number to consider.

activation: Union[Callable, str] = 'silu'

The activation function to use.

use_bias: bool = True

Whether to use bias in the dense layers.

input_key: Optional[str] = None

The key of the input tensor.

output_key: Optional[str] = None

The key of the output tensor.

squeeze: bool = False

Whether to remove the last axis of the output tensor if it is of dimension 1.

kernel_init: Union[str, Callable] = 'lecun_normal()'

The kernel initialization method to use.

species_key: str = 'species'

The key of the species tensor.

FID: ClassVar[str] = 'ZLORANET'
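
A minimal sketch (tuple input, illustrative shapes): each layer adds a per-species low-rank correction B·(A·x) to the shared Dense output; ranks must provide one rank per layer.

    import jax
    import jax.numpy as jnp
    from fennol.models.misc.nets import ZLoRANet

    species = jnp.array([8, 1, 1])           # atomic numbers: O, H, H
    x = jnp.ones((3, 16))
    net = ZLoRANet(neurons=[32, 1], ranks=[4, 4], squeeze=True)
    params = net.init(jax.random.PRNGKey(0), (species, x))
    y = net.apply(params, (species, x))      # shape (3,)
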
class BlockIndexNet(flax.linen.module.Module):
 907class BlockIndexNet(nn.Module):
 908    """Chemical-species-specific neural network using precomputed species index.
 909
 910    FID: BLOCK_INDEX_NET
 911
 912    A neural network that applies a species-specific fully connected network to each atom embedding.
 913    A species index must be provided to filter the embeddings for each species and apply the corresponding network.
 914    This index can be obtained using the SPECIES_INDEXER preprocessing module from `fennol.models.preprocessing.SpeciesIndexer`
 915
 916    """
 917
 918    output_dim: int
 919    """The dimension of the output of the fully connected networks."""
 920    hidden_neurons: Sequence[int]
 921    """The hidden dimensions of the fully connected networks.
 922        If a dictionary is provided, it should map species names to dimensions.
 923        If a sequence is provided, the same dimensions will be used for all species."""
 924    used_blocks: Optional[Sequence[str]] = None
 925    """The blocks to use. If None, all blocks will be used."""
 926    activation: Union[Callable, str] = "silu"
 927    """The activation function to use in the fully connected networks."""
 928    use_bias: bool = True
 929    """Whether to include bias terms in the fully connected networks."""
 930    input_key: Optional[str] = None
 931    """The key in the input dictionary that corresponds to the embeddings of the atoms."""
 932    block_index_key: str = "block_index"
 933    """The key in the input dictionary that corresponds to the block index of the atoms. See `fennol.models.preprocessing.BlockIndexer`"""
 934    output_key: Optional[str] = None
 935    """The key in the output dictionary that corresponds to the network's output."""
 936
 937    squeeze: bool = False
 938    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
 939    kernel_init: Union[str, Callable] = "lecun_normal()"
 940    """The kernel initialization method for the fully connected networks."""
 941    # check_unhandled: bool = True
 942
 943    FID: ClassVar[str] = "BLOCK_INDEX_NET"
 944
 945    # def setup(self):
 946    #     all_blocks = CHEMICAL_BLOCKS_NAMES
 947    #     if self.used_blocks is None:
 948    #         used_blocks = all_blocks
 949    #     else:
 950    #         used_blocks = []
 951    #         for b in self.used_blocks:
 952    #             b_=str(b).strip().upper()
 953    #             if b_ not in all_blocks:
 954    #                 raise ValueError(f"Block {b} not found in {all_blocks}")
 955    #             used_blocks.append(b_)
 956    #     used_blocks = set(used_blocks)
 957    #     self._used_blocks = used_blocks
 958
 959    #     if not (
 960    #         isinstance(self.hidden_neurons, dict)
 961    #         or isinstance(self.hidden_neurons, FrozenDict)
 962    #     ):
 963    #         neurons = {k: self.hidden_neurons for k in used_blocks}
 964    #     else:
 965    #         neurons = {}
 966    #         for b in self.hidden_neurons.keys():
 967    #             b_=str(b).strip().upper()
 968    #             if b_ not in all_blocks:
 969    #                 raise ValueError(f"Block {b} does not exist.  Available blocks are {all_blocks}")
 970    #             neurons[b_] = self.hidden_neurons[b]
 971    #         used_blocks = set(neurons.keys())
 972    #         if used_blocks != self._used_blocks and self.used_blocks is not None:
 973    #             print(
 974    #                 f"Warning: hidden neurons definitions do not match specified used_blocks {self.used_blocks}. Using blocks defined in hidden_neurons.")
 975    #         self._used_blocks = used_blocks
 976
 977    #     self.networks = {
 978    #         k: FullyConnectedNet(
 979    #             [*neurons[k], self.output_dim],
 980    #             self.activation,
 981    #             self.use_bias,
 982    #             name=k,
 983    #             kernel_init=self.kernel_init,
 984    #         )
 985    #         for k in self._used_blocks
 986    #     }
 987
 988    @nn.compact
 989    def __call__(
 990        self, inputs: Union[dict, Tuple[jax.Array, jax.Array]]
 991    ) -> Union[dict, jax.Array]:
 992
 993        if self.input_key is None:
 994            assert not isinstance(
 995                inputs, dict
 996            ), "input key must be provided if inputs is a dictionary"
 997            species, embedding, block_index = inputs
 998        else:
 999            species, embedding = inputs["species"], inputs[self.input_key]
1000            block_index = inputs[self.block_index_key]
1001
1002        assert isinstance(
1003            block_index, dict
1004        ), "block_index must be a dictionary for BlockIndexNet"
1005
1006        networks = {
1007            k: FullyConnectedNet(
1008                [*self.hidden_neurons, self.output_dim],
1009                self.activation,
1010                self.use_bias,
1011                name=k,
1012                kernel_init=self.kernel_init,
1013            )
1014            for k in block_index.keys()
1015        }
1016
1017        ############################
1018        # initialization => instantiate all networks
1019        if self.is_initializing():
1020            x = jnp.zeros((1, embedding.shape[-1]), dtype=embedding.dtype)
1021            [net(x) for net in networks.values()]
1022
1023        # if self.check_unhandled:
1024        #     for b in block_index.keys():
1025        #         if b not in networks.keys():
1026        #             raise ValueError(f"Block {b} not found in networks. Available blocks are {self.networks.keys()}")
1027
1028        ############################
1029        outputs = []
1030        indices = []
1031        for s, net in networks.items():
1032            if s not in block_index:
1033                continue
1034            if block_index[s] is None:
1035                continue
1036            idx = block_index[s]
1037            o = net(embedding[idx])
1038            outputs.append(o)
1039            indices.append(idx)
1040
1041        o = jnp.concatenate(outputs, axis=0)
1042        idx = jnp.concatenate(indices, axis=0)
1043
1044        out = (
1045            jnp.zeros((species.shape[0], *o.shape[1:]), dtype=o.dtype)
1046            .at[idx]
1047            .set(o, mode="drop")
1048        )
1049
1050        if self.squeeze and out.shape[-1] == 1:
1051            out = jnp.squeeze(out, axis=-1)
1052        ############################
1053
1054        if self.input_key is not None:
1055            output_key = self.name if self.output_key is None else self.output_key
1056            return {**inputs, output_key: out} if output_key is not None else out
1057        return out

Chemical-block-specific neural network using a precomputed block index.

FID: BLOCK_INDEX_NET

A neural network that applies a block-specific fully connected network to each atom embedding. A block index must be provided to filter the embeddings for each block and apply the corresponding network. This index can be obtained using the BlockIndexer preprocessing module from fennol.models.preprocessing.BlockIndexer

BlockIndexNet( output_dim: int, hidden_neurons: Sequence[int], used_blocks: Optional[Sequence[str]] = None, activation: Union[Callable, str] = 'silu', use_bias: bool = True, input_key: Optional[str] = None, block_index_key: str = 'block_index', output_key: Optional[str] = None, squeeze: bool = False, kernel_init: Union[str, Callable] = 'lecun_normal()', parent: Union[flax.linen.module.Module, flax.core.scope.Scope, flax.linen.module._Sentinel, NoneType] = <flax.linen.module._Sentinel object>, name: Optional[str] = None)
output_dim: int

The dimension of the output of the fully connected networks.

hidden_neurons: Sequence[int]

The hidden dimensions of the fully connected networks. The same hidden dimensions are used for the networks of all blocks; only the final layer dimension is set by output_dim.

used_blocks: Optional[Sequence[str]] = None

The blocks to use. If None, all blocks will be used.

activation: Union[Callable, str] = 'silu'

The activation function to use in the fully connected networks.

use_bias: bool = True

Whether to include bias terms in the fully connected networks.

input_key: Optional[str] = None

The key in the input dictionary that corresponds to the embeddings of the atoms.

block_index_key: str = 'block_index'

The key in the input dictionary that corresponds to the block index of the atoms. See fennol.models.preprocessing.BlockIndexer

output_key: Optional[str] = None

The key in the output dictionary that corresponds to the network's output.

squeeze: bool = False

Whether to remove the last axis of the output tensor if it is of dimension 1.

kernel_init: Union[str, Callable] = 'lecun_normal()'

The kernel initialization method for the fully connected networks.

FID: ClassVar[str] = 'BLOCK_INDEX_NET'
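
A minimal sketch (the block_index dictionary would normally come from the BlockIndexer preprocessing module; the block names and shapes below are purely illustrative):

    import jax
    import jax.numpy as jnp
    from fennol.models.misc.nets import BlockIndexNet

    species = jnp.array([8, 1, 1])
    embedding = jnp.ones((3, 16))
    # hypothetical block names; real ones come from fennol.models.preprocessing.BlockIndexer
    block_index = {"BLOCK_A": jnp.array([0]), "BLOCK_B": jnp.array([1, 2])}

    net = BlockIndexNet(
        output_dim=1,
        hidden_neurons=[32, 32],
        input_key="embedding",
        output_key="atom_energy",
    )
    inputs = {"species": species, "embedding": embedding, "block_index": block_index}
    params = net.init(jax.random.PRNGKey(0), inputs)
    out = net.apply(params, inputs)["atom_energy"]   # per-atom output, shape (3, 1)
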