fennol.models.misc.nets
1import flax.linen as nn 2from typing import Sequence, Callable, Union, Optional, ClassVar, Tuple 3from ...utils.periodic_table import PERIODIC_TABLE_REV_IDX, PERIODIC_TABLE 4import jax.numpy as jnp 5import jax 6import numpy as np 7from functools import partial 8from ...utils.activations import activation_from_str, TrainableSiLU 9from ...utils.initializers import initializer_from_str, scaled_orthogonal 10from flax.core import FrozenDict 11 12 13class FullyConnectedNet(nn.Module): 14 """A fully connected neural network module. 15 16 FID: NEURAL_NET 17 """ 18 19 neurons: Sequence[int] 20 """A sequence of integers representing the dimensions of the network.""" 21 activation: Union[Callable, str] = "silu" 22 """The activation function to use.""" 23 use_bias: bool = True 24 """Whether to use bias in the dense layers.""" 25 input_key: Optional[str] = None 26 """The key of the input tensor.""" 27 output_key: Optional[str] = None 28 """The key of the output tensor.""" 29 squeeze: bool = False 30 """Whether to remove the last axis of the output tensor if it is of dimension 1.""" 31 kernel_init: Union[str, Callable] = "lecun_normal()" 32 """The kernel initialization method to use.""" 33 34 FID: ClassVar[str] = "NEURAL_NET" 35 36 @nn.compact 37 def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]: 38 """Applies the neural network to the given inputs. 39 40 Args: 41 inputs (Union[dict, jax.Array]): If input_key is None, a JAX array containing the inputs to the neural network. Else, a dictionary containing the inputs at the key input_key. 42 43 Returns: 44 Union[dict, jax.Array]: If output_key is None, the output tensor of the neural network. Else, a dictionary containing the original inputs and the output tensor at the key output_key. 45 """ 46 if self.input_key is None: 47 assert not isinstance( 48 inputs, dict 49 ), "input key must be provided if inputs is a dictionary" 50 x = inputs 51 else: 52 x = inputs[self.input_key] 53 54 # activation = ( 55 # activation_from_str(self.activation) 56 # if isinstance(self.activation, str) 57 # else self.activation 58 # ) 59 kernel_init = ( 60 initializer_from_str(self.kernel_init) 61 if isinstance(self.kernel_init, str) 62 else self.kernel_init 63 ) 64 ############################ 65 for i, d in enumerate(self.neurons[:-1]): 66 x = nn.Dense( 67 d, 68 use_bias=self.use_bias, 69 name=f"Layer_{i+1}", 70 kernel_init=kernel_init, 71 # precision=jax.lax.Precision.HIGH, 72 )(x) 73 x = activation_from_str(self.activation)(x) 74 x = nn.Dense( 75 self.neurons[-1], 76 use_bias=self.use_bias, 77 name=f"Layer_{len(self.neurons)}", 78 kernel_init=kernel_init, 79 # precision=jax.lax.Precision.HIGH, 80 )(x) 81 if self.squeeze and x.shape[-1] == 1: 82 x = jnp.squeeze(x, axis=-1) 83 ############################ 84 85 if self.input_key is not None: 86 output_key = self.name if self.output_key is None else self.output_key 87 return {**inputs, output_key: x} if output_key is not None else x 88 return x 89 90 91class ResMLP(nn.Module): 92 """Residual neural network as defined in the SpookyNet paper. 
93 94 FID: RES_MLP 95 """ 96 97 use_bias: bool = True 98 """Whether to include bias in the linear layers.""" 99 input_key: Optional[str] = None 100 """The key of the input tensor.""" 101 output_key: Optional[str] = None 102 """The key of the output tensor.""" 103 104 kernel_init: Union[str, Callable] = "scaled_orthogonal(mode='fan_avg')" 105 """The kernel initialization method to use.""" 106 res_only: bool = False 107 """Whether to only apply the residual connection without additional activation and linear layer.""" 108 109 FID: ClassVar[str] = "RES_MLP" 110 111 @nn.compact 112 def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]: 113 if self.input_key is None: 114 assert not isinstance( 115 inputs, dict 116 ), "input key must be provided if inputs is a dictionary" 117 x = inputs 118 else: 119 x = inputs[self.input_key] 120 121 kernel_init = ( 122 initializer_from_str(self.kernel_init) 123 if isinstance(self.kernel_init, str) 124 else self.kernel_init 125 ) 126 ############################ 127 out = nn.Dense(x.shape[-1], use_bias=self.use_bias, kernel_init=kernel_init)( 128 TrainableSiLU()(x) 129 ) 130 out = x + nn.Dense( 131 x.shape[-1], use_bias=self.use_bias, kernel_init=nn.initializers.zeros 132 )(TrainableSiLU()(out)) 133 134 if not self.res_only: 135 out = nn.Dense( 136 x.shape[-1], use_bias=self.use_bias, kernel_init=kernel_init 137 )(TrainableSiLU()(out)) 138 ############################ 139 140 if self.input_key is not None: 141 output_key = self.name if self.output_key is None else self.output_key 142 return {**inputs, output_key: out} if output_key is not None else out 143 return out 144 145 146class FullyResidualNet(nn.Module): 147 """A neural network with skip connections at each layer. 148 149 FID: SKIP_NET 150 """ 151 152 dim: int 153 """The dimension of the hidden layers.""" 154 output_dim: int 155 """The dimension of the output layer.""" 156 nlayers: int 157 """The number of layers in the network.""" 158 activation: Union[Callable, str] = "silu" 159 """The activation function to use.""" 160 use_bias: bool = True 161 """Whether to include bias terms in the linear layers.""" 162 input_key: Optional[str] = None 163 """The key of the input tensor.""" 164 output_key: Optional[str] = None 165 """The key of the output tensor.""" 166 squeeze: bool = False 167 """Whether to remove the last axis of the output tensor if it is of dimension 1.""" 168 kernel_init: Union[str, Callable] = "lecun_normal()" 169 """The kernel initialization method to use.""" 170 171 FID: ClassVar[str] = "SKIP_NET" 172 173 @nn.compact 174 def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]: 175 if self.input_key is None: 176 assert not isinstance( 177 inputs, dict 178 ), "input key must be provided if inputs is a dictionary" 179 x = inputs 180 else: 181 x = inputs[self.input_key] 182 183 # activation = ( 184 # activation_from_str(self.activation) 185 # if isinstance(self.activation, str) 186 # else self.activation 187 # ) 188 kernel_init = ( 189 initializer_from_str(self.kernel_init) 190 if isinstance(self.kernel_init, str) 191 else self.kernel_init 192 ) 193 ############################ 194 if x.shape[-1] != self.dim: 195 x = nn.Dense( 196 self.dim, 197 use_bias=self.use_bias, 198 name=f"Reshape", 199 kernel_init=kernel_init, 200 )(x) 201 202 for i in range(self.nlayers - 1): 203 x = x + activation_from_str(self.activation)( 204 nn.Dense( 205 self.dim, 206 use_bias=self.use_bias, 207 name=f"Layer_{i+1}", 208 kernel_init=kernel_init, 209 )(x) 210 ) 211 x = 
nn.Dense( 212 self.output_dim, 213 use_bias=self.use_bias, 214 name=f"Layer_{self.nlayers}", 215 kernel_init=kernel_init, 216 )(x) 217 if self.squeeze and x.shape[-1] == 1: 218 x = jnp.squeeze(x, axis=-1) 219 ############################ 220 221 if self.input_key is not None: 222 output_key = self.name if self.output_key is None else self.output_key 223 return {**inputs, output_key: x} if output_key is not None else x 224 return x 225 226 227class HierarchicalNet(nn.Module): 228 """Neural network for a sequence of inputs (in axis=-2) with a decay factor 229 230 FID: HIERARCHICAL_NET 231 """ 232 233 neurons: Sequence[int] 234 """A sequence of integers representing the number of neurons in each layer.""" 235 activation: Union[Callable, str] = "silu" 236 """The activation function to use.""" 237 use_bias: bool = True 238 """Whether to include bias terms in the linear layers.""" 239 input_key: Optional[str] = None 240 """The key of the input tensor.""" 241 output_key: Optional[str] = None 242 """The key of the output tensor.""" 243 decay: float = 0.01 244 """The decay factor to scale each element of the sequence.""" 245 squeeze: bool = False 246 """Whether to remove the last axis of the output tensor if it is of dimension 1.""" 247 kernel_init: Union[str, Callable] = "lecun_normal()" 248 """The kernel initialization method to use.""" 249 250 FID: ClassVar[str] = "HIERARCHICAL_NET" 251 252 @nn.compact 253 def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]: 254 if self.input_key is None: 255 assert not isinstance( 256 inputs, dict 257 ), "input key must be provided if inputs is a dictionary" 258 x = inputs 259 else: 260 x = inputs[self.input_key] 261 262 ############################ 263 networks = nn.vmap( 264 FullyConnectedNet, 265 variable_axes={"params": 0}, 266 split_rngs={"params": True}, 267 in_axes=-2, 268 out_axes=-2, 269 kenel_init=self.kernel_init, 270 )(self.neurons, self.activation, self.use_bias) 271 272 out = networks(x) 273 # scale each layer by a decay factor 274 decay = jnp.asarray([self.decay**i for i in range(out.shape[-2])]) 275 out = out * decay[..., :, None] 276 277 if self.squeeze and out.shape[-1] == 1: 278 out = jnp.squeeze(out, axis=-1) 279 ############################ 280 281 if self.input_key is not None: 282 output_key = self.name if self.output_key is None else self.output_key 283 return {**inputs, output_key: out} if output_key is not None else out 284 return out 285 286 287class SpeciesIndexNet(nn.Module): 288 """Chemical-species-specific neural network using precomputed species index. 289 290 FID: SPECIES_INDEX_NET 291 292 A neural network that applies a species-specific fully connected network to each atom embedding. 293 A species index must be provided to filter the embeddings for each species and apply the corresponding network. 294 This index can be obtained using the SPECIES_INDEXER preprocessing module from `fennol.models.preprocessing.SpeciesIndexer` 295 296 """ 297 298 output_dim: int 299 """The dimension of the output of the fully connected networks.""" 300 hidden_neurons: Union[dict, FrozenDict, Sequence[int]] 301 """The hidden dimensions of the fully connected networks. 302 If a dictionary is provided, it should map species names to dimensions. 303 If a sequence is provided, the same dimensions will be used for all species.""" 304 species_order: Optional[Union[str, Sequence[str]]] = None 305 """The species for which to build a network. 
Only required if neurons is not a dictionary.""" 306 activation: Union[Callable, str] = "silu" 307 """The activation function to use in the fully connected networks.""" 308 use_bias: bool = True 309 """Whether to include bias terms in the fully connected networks.""" 310 input_key: Optional[str] = None 311 """The key in the input dictionary that corresponds to the embeddings of the atoms.""" 312 species_index_key: str = "species_index" 313 """The key in the input dictionary that corresponds to the species index of the atoms. See `fennol.models.preprocessing.SpeciesIndexer`""" 314 output_key: Optional[str] = None 315 """The key in the output dictionary that corresponds to the network's output.""" 316 317 squeeze: bool = False 318 """Whether to remove the last axis of the output tensor if it is of dimension 1.""" 319 kernel_init: Union[str, Callable] = "lecun_normal()" 320 """The kernel initialization method for the fully connected networks.""" 321 check_unhandled: bool = True 322 323 FID: ClassVar[str] = "SPECIES_INDEX_NET" 324 325 def setup(self): 326 if not ( 327 isinstance(self.hidden_neurons, dict) 328 or isinstance(self.hidden_neurons, FrozenDict) 329 ): 330 assert ( 331 self.species_order is not None 332 ), "species_order must be provided if hidden_neurons is not a dictionary" 333 if isinstance(self.species_order, str): 334 species_order = [el.strip() for el in self.species_order.split(",")] 335 else: 336 species_order = [el for el in self.species_order] 337 neurons = {k: self.hidden_neurons for k in species_order} 338 else: 339 neurons = self.hidden_neurons 340 species_order = list(neurons.keys()) 341 for species in species_order: 342 assert ( 343 species in PERIODIC_TABLE 344 ), f"species {species} not found in periodic table" 345 346 self.networks = { 347 k: FullyConnectedNet( 348 [*neurons[k], self.output_dim], 349 self.activation, 350 self.use_bias, 351 name=k, 352 kernel_init=self.kernel_init, 353 ) 354 for k in species_order 355 } 356 357 def __call__( 358 self, inputs: Union[dict, Tuple[jax.Array, jax.Array]] 359 ) -> Union[dict, jax.Array]: 360 361 if self.input_key is None: 362 assert not isinstance( 363 inputs, dict 364 ), "input key must be provided if inputs is a dictionary" 365 species, embedding, species_index = inputs 366 else: 367 species, embedding = inputs["species"], inputs[self.input_key] 368 species_index = inputs[self.species_index_key] 369 370 assert isinstance( 371 species_index, dict 372 ), "species_index must be a dictionary for SpeciesIndexNetHet" 373 374 ############################ 375 # initialization => instantiate all networks 376 if self.is_initializing(): 377 x = jnp.zeros((1, embedding.shape[-1]), dtype=embedding.dtype) 378 [net(x) for net in self.networks.values()] 379 380 if self.check_unhandled: 381 for b in species_index.keys(): 382 if b not in self.networks.keys(): 383 raise ValueError(f"Species {b} not found in networks. 
Handled species are {self.networks.keys()}") 384 385 ############################ 386 outputs = [] 387 indices = [] 388 for s, net in self.networks.items(): 389 if s not in species_index: 390 continue 391 idx = species_index[s] 392 o = net(embedding[idx]) 393 outputs.append(o) 394 indices.append(idx) 395 396 o = jnp.concatenate(outputs, axis=0) 397 idx = jnp.concatenate(indices, axis=0) 398 399 out = ( 400 jnp.zeros((species.shape[0], *o.shape[1:]), dtype=o.dtype) 401 .at[idx] 402 .set(o, mode="drop") 403 ) 404 405 if self.squeeze and out.shape[-1] == 1: 406 out = jnp.squeeze(out, axis=-1) 407 ############################ 408 409 if self.input_key is not None: 410 output_key = self.name if self.output_key is None else self.output_key 411 return {**inputs, output_key: out} if output_key is not None else out 412 return out 413 414 415class ChemicalNet(nn.Module): 416 """optimized Chemical-species-specific neural network. 417 418 FID: CHEMICAL_NET 419 420 A neural network that applies a fully connected network to each atom embedding in a chemical system and selects the output corresponding to the atom's species. 421 This is an optimized version of ChemicalNetHet that uses vmap to apply the networks to all atoms at once. 422 The optimization is allowed because all networks have the same shape. 423 424 """ 425 426 species_order: Union[str, Sequence[str]] 427 """The species for which to build a network.""" 428 neurons: Sequence[int] 429 """The dimensions of the fully connected networks.""" 430 activation: Union[Callable, str] = "silu" 431 """The activation function to use in the fully connected networks.""" 432 use_bias: bool = True 433 """Whether to include bias terms in the fully connected networks.""" 434 input_key: Optional[str] = None 435 """The key in the input dictionary that corresponds to the embeddings of the atoms.""" 436 output_key: Optional[str] = None 437 """The key in the output dictionary that corresponds to the network's output.""" 438 squeeze: bool = False 439 """Whether to remove the last axis of the output tensor if it is of dimension 1.""" 440 kernel_init: Union[str, Callable] = "lecun_normal()" 441 """The kernel initialization method for the fully connected networks.""" 442 443 FID: ClassVar[str] = "CHEMICAL_NET" 444 445 @nn.compact 446 def __call__( 447 self, inputs: Union[dict, Tuple[jax.Array, jax.Array]] 448 ) -> Union[dict, jax.Array]: 449 if self.input_key is None: 450 assert not isinstance( 451 inputs, dict 452 ), "input key must be provided if inputs is a dictionary" 453 species, embedding = inputs 454 else: 455 species, embedding = inputs["species"], inputs[self.input_key] 456 457 ############################ 458 # build species to network index mapping (static => fixed when jitted) 459 rev_idx = PERIODIC_TABLE_REV_IDX 460 maxidx = max(rev_idx.values()) 461 if isinstance(self.species_order, str): 462 species_order = [el.strip() for el in self.species_order.split(",")] 463 else: 464 species_order = [el for el in self.species_order] 465 nspecies = len(species_order) 466 conv_tensor_ = np.full((maxidx + 2,), -1, dtype=np.int32) 467 for i, s in enumerate(species_order): 468 conv_tensor_[rev_idx[s]] = i 469 conv_tensor = jnp.asarray(conv_tensor_) 470 indices = conv_tensor[species] 471 472 ############################ 473 # build shape-sharing networks using vmap 474 networks = nn.vmap( 475 FullyConnectedNet, 476 variable_axes={"params": 0}, 477 split_rngs={"params": True}, 478 in_axes=0, 479 )(self.neurons, self.activation, self.use_bias, kernel_init=self.kernel_init) 
480 # repeat input along a new axis to compute for all species at once 481 x = jnp.broadcast_to( 482 embedding[None, :, :], (nspecies, *embedding.shape) 483 ) 484 485 # apply networks to input and select the output corresponding to the species 486 out = jnp.squeeze( 487 jnp.take_along_axis(networks(x), indices[None, :, None], axis=0), axis=0 488 ) 489 490 out = jnp.where((indices >= 0)[:, None], out, 0.0) 491 if self.squeeze and out.shape[-1] == 1: 492 out = jnp.squeeze(out, axis=-1) 493 ############################ 494 495 if self.input_key is not None: 496 output_key = self.name if self.output_key is None else self.output_key 497 return {**inputs, output_key: out} if output_key is not None else out 498 return out 499 500 501class MOENet(nn.Module): 502 """Mixture of Experts neural network. 503 504 FID: MOE_NET 505 506 This class represents a Mixture of Experts neural network. It takes in an input and applies a set of shape-sharing networks 507 to the input based on a router. The outputs of the shape-sharing networks are then combined using weights computed by the router. 508 509 """ 510 511 neurons: Sequence[int] 512 """A sequence of integers representing the number of neurons in each shape-sharing network.""" 513 num_networks: int 514 """The number of shape-sharing networks to create.""" 515 activation: Union[Callable, str] = "silu" 516 """The activation function to use in the shape-sharing networks.""" 517 use_bias: bool = True 518 """Whether to include bias in the shape-sharing networks.""" 519 input_key: Optional[str] = None 520 """The key of the input tensor.""" 521 output_key: Optional[str] = None 522 """The key of the output tensor.""" 523 squeeze: bool = False 524 """Whether to remove the last axis of the output tensor if it is of dimension 1.""" 525 526 kernel_init: Union[str, Callable] = "lecun_normal()" 527 """The kernel initialization method to use in the shape-sharing networks.""" 528 router_key: Optional[str] = None 529 """The key of the router tensor. 
If None, the router is assumed to be the same as the input tensor.""" 530 531 FID: ClassVar[str] = "MOE_NET" 532 533 @nn.compact 534 def __call__( 535 self, inputs: Union[dict, Tuple[jax.Array, jax.Array]] 536 ) -> Union[dict, jax.Array]: 537 if self.input_key is None: 538 assert not isinstance( 539 inputs, dict 540 ), "input key must be provided if inputs is a dictionary" 541 if isinstance(inputs, tuple): 542 embedding, router = inputs 543 else: 544 embedding = router = inputs 545 else: 546 embedding = inputs[self.input_key] 547 router = ( 548 inputs[self.router_key] if self.router_key is not None else embedding 549 ) 550 551 ############################ 552 # build shape-sharing networks using vmap 553 networks = nn.vmap( 554 FullyConnectedNet, 555 variable_axes={"params": 0}, 556 split_rngs={"params": True}, 557 in_axes=0, 558 out_axes=0, 559 )(self.neurons, self.activation, self.use_bias, kernel_init=self.kernel_init) 560 # repeat input along a new axis to compute for all networks at once 561 x = jnp.repeat(embedding[None, :, :], self.num_networks, axis=0) 562 563 w = nn.softmax(nn.Dense(self.num_networks, name="router")(router), axis=-1) 564 565 out = (networks(x) * w.T[:, :, None]).sum(axis=0) 566 567 if self.squeeze and out.shape[-1] == 1: 568 out = jnp.squeeze(out, axis=-1) 569 ############################ 570 571 if self.input_key is not None: 572 output_key = self.name if self.output_key is None else self.output_key 573 return {**inputs, output_key: out} if output_key is not None else out 574 return out 575 576class ChannelNet(nn.Module): 577 """Apply a different neural network to each channel. 578 579 FID: CHANNEL_NET 580 """ 581 582 neurons: Sequence[int] 583 """A sequence of integers representing the number of neurons in each shape-sharing network.""" 584 activation: Union[Callable, str] = "silu" 585 """The activation function to use in the shape-sharing networks.""" 586 use_bias: bool = True 587 """Whether to include bias in the shape-sharing networks.""" 588 input_key: Optional[str] = None 589 """The key of the input tensor.""" 590 output_key: Optional[str] = None 591 """The key of the output tensor.""" 592 squeeze: bool = False 593 """Whether to remove the last axis of the output tensor if it is of dimension 1.""" 594 kernel_init: Union[str, Callable] = "lecun_normal()" 595 """The kernel initialization method to use in the shape-sharing networks.""" 596 channel_axis: int = -2 597 """The axis to use as channel. 
Its length will be the number of shape-sharing networks.""" 598 599 FID: ClassVar[str] = "CHANNEL_NET" 600 601 @nn.compact 602 def __call__( 603 self, inputs: Union[dict, Tuple[jax.Array, jax.Array]] 604 ) -> Union[dict, jax.Array]: 605 if self.input_key is None: 606 assert not isinstance( 607 inputs, dict 608 ), "input key must be provided if inputs is a dictionary" 609 x = inputs 610 else: 611 x = inputs[self.input_key] 612 613 ############################ 614 # build shape-sharing networks using vmap 615 networks = nn.vmap( 616 FullyConnectedNet, 617 variable_axes={"params": 0}, 618 split_rngs={"params": True}, 619 in_axes=self.channel_axis, 620 out_axes=self.channel_axis, 621 )(self.neurons, self.activation, self.use_bias, kernel_init=self.kernel_init) 622 623 out = networks(x) 624 625 if self.squeeze and out.shape[-1] == 1: 626 out = jnp.squeeze(out, axis=-1) 627 ############################ 628 629 if self.input_key is not None: 630 output_key = self.name if self.output_key is None else self.output_key 631 return {**inputs, output_key: out} if output_key is not None else out 632 return out 633 634 635class GatedPerceptron(nn.Module): 636 """Gated Perceptron neural network. 637 638 FID: GATED_PERCEPTRON 639 640 This class represents a Gated Perceptron neural network model. It applies a gating mechanism 641 to the input data and performs linear transformation using a dense layer followed by an activation function. 642 """ 643 644 dim: int 645 """The dimensionality of the output space.""" 646 use_bias: bool = True 647 """Whether to include a bias term in the dense layer.""" 648 kernel_init: Union[str, Callable] = "lecun_normal()" 649 """The kernel initialization method to use.""" 650 activation: Union[Callable, str] = "silu" 651 """The activation function to use.""" 652 653 input_key: Optional[str] = None 654 """The key of the input tensor.""" 655 output_key: Optional[str] = None 656 """The key of the output tensor.""" 657 squeeze: bool = False 658 """Whether to remove the last axis of the output tensor if it is of dimension 1.""" 659 660 FID: ClassVar[str] = "GATED_PERCEPTRON" 661 662 @nn.compact 663 def __call__(self, inputs): 664 if self.input_key is None: 665 assert not isinstance( 666 inputs, dict 667 ), "input key must be provided if inputs is a dictionary" 668 x = inputs 669 else: 670 x = inputs[self.input_key] 671 672 # activation = ( 673 # activation_from_str(self.activation) 674 # if isinstance(self.activation, str) 675 # else self.activation 676 # ) 677 kernel_init = ( 678 initializer_from_str(self.kernel_init) 679 if isinstance(self.kernel_init, str) 680 else self.kernel_init 681 ) 682 ############################ 683 gate = jax.nn.sigmoid( 684 nn.Dense(self.dim, use_bias=self.use_bias, kernel_init=kernel_init)(x) 685 ) 686 x = gate * activation_from_str(self.activation)( 687 nn.Dense(self.dim, use_bias=self.use_bias, kernel_init=kernel_init)(x) 688 ) 689 690 if self.squeeze and out.shape[-1] == 1: 691 out = jnp.squeeze(out, axis=-1) 692 ############################ 693 694 if self.input_key is not None: 695 output_key = self.name if self.output_key is None else self.output_key 696 return {**inputs, output_key: x} if output_key is not None else x 697 return x 698 699 700class ZAcNet(nn.Module): 701 """ A fully connected neural network module with affine Z-dependent adjustments of activations. 
702 703 FID: ZACNET 704 """ 705 706 neurons: Sequence[int] 707 """A sequence of integers representing the dimensions of the network.""" 708 zmax: int = 86 709 """The maximum atomic number to consider.""" 710 activation: Union[Callable, str] = "silu" 711 """The activation function to use.""" 712 use_bias: bool = True 713 """Whether to use bias in the dense layers.""" 714 input_key: Optional[str] = None 715 """The key of the input tensor.""" 716 output_key: Optional[str] = None 717 """The key of the output tensor.""" 718 squeeze: bool = False 719 """Whether to remove the last axis of the output tensor if it is of dimension 1.""" 720 kernel_init: Union[str, Callable] = "lecun_normal()" 721 """The kernel initialization method to use.""" 722 species_key: str = "species" 723 """The key of the species tensor.""" 724 725 FID: ClassVar[str] = "ZACNET" 726 727 @nn.compact 728 def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]: 729 if self.input_key is None: 730 assert not isinstance( 731 inputs, dict 732 ), "input key must be provided if inputs is a dictionary" 733 species, x = inputs 734 else: 735 species, x = inputs[self.species_key], inputs[self.input_key] 736 737 # activation = ( 738 # activation_from_str(self.activation) 739 # if isinstance(self.activation, str) 740 # else self.activation 741 # ) 742 kernel_init = ( 743 initializer_from_str(self.kernel_init) 744 if isinstance(self.kernel_init, str) 745 else self.kernel_init 746 ) 747 ############################ 748 for i, d in enumerate(self.neurons[:-1]): 749 x = nn.Dense( 750 d, use_bias=self.use_bias, name=f"Layer_{i+1}", kernel_init=kernel_init 751 )(x) 752 sig = self.param( 753 f"sig_{i+1}", 754 lambda key, shape: jnp.ones(shape, dtype=x.dtype), 755 (self.zmax + 2, d), 756 )[species] 757 if self.use_bias: 758 b = self.param( 759 f"b_{i+1}", 760 lambda key, shape: jnp.zeros(shape, dtype=x.dtype), 761 (self.zmax + 2, d), 762 )[species] 763 else: 764 b = 0 765 x = activation_from_str(self.activation)(sig * x + b) 766 x = nn.Dense( 767 self.neurons[-1], 768 use_bias=self.use_bias, 769 name=f"Layer_{len(self.neurons)}", 770 kernel_init=kernel_init, 771 )(x) 772 sig = self.param( 773 f"sig_{len(self.neurons)}", 774 lambda key, shape: jnp.ones(shape, dtype=x.dtype), 775 (self.zmax + 2, self.neurons[-1]), 776 )[species] 777 if self.use_bias: 778 b = self.param( 779 f"b_{len(self.neurons)}", 780 lambda key, shape: jnp.zeros(shape, dtype=x.dtype), 781 (self.zmax + 2, self.neurons[-1]), 782 )[species] 783 else: 784 b = 0 785 x = sig * x + b 786 if self.squeeze and x.shape[-1] == 1: 787 x = jnp.squeeze(x, axis=-1) 788 ############################ 789 790 if self.input_key is not None: 791 output_key = self.name if self.output_key is None else self.output_key 792 return {**inputs, output_key: x} if output_key is not None else x 793 return x 794 795 796class ZLoRANet(nn.Module): 797 """A fully connected neural network module with Z-dependent low-rank adaptation. 
798 799 FID: ZLORANET 800 """ 801 802 neurons: Sequence[int] 803 """A sequence of integers representing the dimensions of the network.""" 804 ranks: Sequence[int] 805 """A sequence of integers representing the ranks of the low-rank adaptation at each layer.""" 806 zmax: int = 86 807 """The maximum atomic number to consider.""" 808 activation: Union[Callable, str] = "silu" 809 """The activation function to use.""" 810 use_bias: bool = True 811 """Whether to use bias in the dense layers.""" 812 input_key: Optional[str] = None 813 """The key of the input tensor.""" 814 output_key: Optional[str] = None 815 """The key of the output tensor.""" 816 squeeze: bool = False 817 """Whether to remove the last axis of the output tensor if it is of dimension 1.""" 818 kernel_init: Union[str, Callable] = "lecun_normal()" 819 """The kernel initialization method to use.""" 820 species_key: str = "species" 821 """The key of the species tensor.""" 822 823 FID: ClassVar[str] = "ZLORANET" 824 825 @nn.compact 826 def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]: 827 if self.input_key is None: 828 assert not isinstance( 829 inputs, dict 830 ), "input key must be provided if inputs is a dictionary" 831 species, x = inputs 832 else: 833 species, x = inputs[self.species_key], inputs[self.input_key] 834 835 # activation = ( 836 # activation_from_str(self.activation) 837 # if isinstance(self.activation, str) 838 # else self.activation 839 # ) 840 kernel_init = ( 841 initializer_from_str(self.kernel_init) 842 if isinstance(self.kernel_init, str) 843 else self.kernel_init 844 ) 845 ############################ 846 for i, d in enumerate(self.neurons[:-1]): 847 xi = nn.Dense( 848 d, use_bias=self.use_bias, name=f"Layer_{i+1}", kernel_init=kernel_init 849 )(x) 850 A = self.param( 851 f"A_{i+1}", 852 lambda key, shape: jnp.zeros(shape, dtype=x.dtype), 853 (self.zmax + 2, self.ranks[i], x.shape[-1]), 854 )[species] 855 B = self.param( 856 f"B_{i+1}", 857 lambda key, shape: jnp.zeros(shape, dtype=x.dtype), 858 (self.zmax + 2, d, self.ranks[i]), 859 )[species] 860 Ax = jnp.einsum("zrd,zd->zr", A, x) 861 BAx = jnp.einsum("zrd,zd->zr", B, Ax) 862 x = activation_from_str(self.activation)(xi + BAx) 863 xi = nn.Dense( 864 self.neurons[-1], 865 use_bias=self.use_bias, 866 name=f"Layer_{len(self.neurons)}", 867 kernel_init=kernel_init, 868 )(x) 869 A = self.param( 870 f"A_{len(self.neurons)}", 871 lambda key, shape: jnp.zeros(shape, dtype=x.dtype), 872 (self.zmax + 2, self.ranks[-1], x.shape[-1]), 873 )[species] 874 B = self.param( 875 f"B_{len(self.neurons)}", 876 lambda key, shape: jnp.zeros(shape, dtype=x.dtype), 877 (self.zmax + 2, self.neurons[-1], self.ranks[-1]), 878 )[species] 879 Ax = jnp.einsum("zrd,zd->zr", A, x) 880 BAx = jnp.einsum("zrd,zd->zr", B, Ax) 881 x = xi + BAx 882 if self.squeeze and x.shape[-1] == 1: 883 x = jnp.squeeze(x, axis=-1) 884 ############################ 885 886 if self.input_key is not None: 887 output_key = self.name if self.output_key is None else self.output_key 888 return {**inputs, output_key: x} if output_key is not None else x 889 return x 890 891 892class BlockIndexNet(nn.Module): 893 """Chemical-species-specific neural network using precomputed species index. 894 895 FID: BLOCK_INDEX_NET 896 897 A neural network that applies a species-specific fully connected network to each atom embedding. 898 A species index must be provided to filter the embeddings for each species and apply the corresponding network. 
899 This index can be obtained using the SPECIES_INDEXER preprocessing module from `fennol.models.preprocessing.SpeciesIndexer` 900 901 """ 902 903 output_dim: int 904 """The dimension of the output of the fully connected networks.""" 905 hidden_neurons: Sequence[int] 906 """The hidden dimensions of the fully connected networks. 907 If a dictionary is provided, it should map species names to dimensions. 908 If a sequence is provided, the same dimensions will be used for all species.""" 909 used_blocks: Optional[Sequence[str]] = None 910 """The blocks to use. If None, all blocks will be used.""" 911 activation: Union[Callable, str] = "silu" 912 """The activation function to use in the fully connected networks.""" 913 use_bias: bool = True 914 """Whether to include bias terms in the fully connected networks.""" 915 input_key: Optional[str] = None 916 """The key in the input dictionary that corresponds to the embeddings of the atoms.""" 917 block_index_key: str = "block_index" 918 """The key in the input dictionary that corresponds to the block index of the atoms. See `fennol.models.preprocessing.BlockIndexer`""" 919 output_key: Optional[str] = None 920 """The key in the output dictionary that corresponds to the network's output.""" 921 922 squeeze: bool = False 923 """Whether to remove the last axis of the output tensor if it is of dimension 1.""" 924 kernel_init: Union[str, Callable] = "lecun_normal()" 925 """The kernel initialization method for the fully connected networks.""" 926 # check_unhandled: bool = True 927 928 FID: ClassVar[str] = "BLOCK_INDEX_NET" 929 930 # def setup(self): 931 # all_blocks = CHEMICAL_BLOCKS_NAMES 932 # if self.used_blocks is None: 933 # used_blocks = all_blocks 934 # else: 935 # used_blocks = [] 936 # for b in self.used_blocks: 937 # b_=str(b).strip().upper() 938 # if b_ not in all_blocks: 939 # raise ValueError(f"Block {b} not found in {all_blocks}") 940 # used_blocks.append(b_) 941 # used_blocks = set(used_blocks) 942 # self._used_blocks = used_blocks 943 944 # if not ( 945 # isinstance(self.hidden_neurons, dict) 946 # or isinstance(self.hidden_neurons, FrozenDict) 947 # ): 948 # neurons = {k: self.hidden_neurons for k in used_blocks} 949 # else: 950 # neurons = {} 951 # for b in self.hidden_neurons.keys(): 952 # b_=str(b).strip().upper() 953 # if b_ not in all_blocks: 954 # raise ValueError(f"Block {b} does not exist. Available blocks are {all_blocks}") 955 # neurons[b_] = self.hidden_neurons[b] 956 # used_blocks = set(neurons.keys()) 957 # if used_blocks != self._used_blocks and self.used_blocks is not None: 958 # print( 959 # f"Warning: hidden neurons definitions do not match specified used_blocks {self.used_blocks}. 
Using blocks defined in hidden_neurons.") 960 # self._used_blocks = used_blocks 961 962 # self.networks = { 963 # k: FullyConnectedNet( 964 # [*neurons[k], self.output_dim], 965 # self.activation, 966 # self.use_bias, 967 # name=k, 968 # kernel_init=self.kernel_init, 969 # ) 970 # for k in self._used_blocks 971 # } 972 973 @nn.compact 974 def __call__( 975 self, inputs: Union[dict, Tuple[jax.Array, jax.Array]] 976 ) -> Union[dict, jax.Array]: 977 978 if self.input_key is None: 979 assert not isinstance( 980 inputs, dict 981 ), "input key must be provided if inputs is a dictionary" 982 species, embedding, block_index = inputs 983 else: 984 species, embedding = inputs["species"], inputs[self.input_key] 985 block_index = inputs[self.block_index_key] 986 987 assert isinstance( 988 block_index, dict 989 ), "block_index must be a dictionary for BlockIndexNet" 990 991 networks = { 992 k: FullyConnectedNet( 993 [*self.hidden_neurons, self.output_dim], 994 self.activation, 995 self.use_bias, 996 name=k, 997 kernel_init=self.kernel_init, 998 ) 999 for k in block_index.keys() 1000 } 1001 1002 ############################ 1003 # initialization => instantiate all networks 1004 if self.is_initializing(): 1005 x = jnp.zeros((1, embedding.shape[-1]), dtype=embedding.dtype) 1006 [net(x) for net in networks.values()] 1007 1008 # if self.check_unhandled: 1009 # for b in block_index.keys(): 1010 # if b not in networks.keys(): 1011 # raise ValueError(f"Block {b} not found in networks. Available blocks are {self.networks.keys()}") 1012 1013 ############################ 1014 outputs = [] 1015 indices = [] 1016 for s, net in networks.items(): 1017 if s not in block_index: 1018 continue 1019 if block_index[s] is None: 1020 continue 1021 idx = block_index[s] 1022 o = net(embedding[idx]) 1023 outputs.append(o) 1024 indices.append(idx) 1025 1026 o = jnp.concatenate(outputs, axis=0) 1027 idx = jnp.concatenate(indices, axis=0) 1028 1029 out = ( 1030 jnp.zeros((species.shape[0], *o.shape[1:]), dtype=o.dtype) 1031 .at[idx] 1032 .set(o, mode="drop") 1033 ) 1034 1035 if self.squeeze and out.shape[-1] == 1: 1036 out = jnp.squeeze(out, axis=-1) 1037 ############################ 1038 1039 if self.input_key is not None: 1040 output_key = self.name if self.output_key is None else self.output_key 1041 return {**inputs, output_key: out} if output_key is not None else out 1042 return out
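All of the modules in this file follow the same calling convention: when input_key is None they act directly on an array (or a tuple of arrays for the species-dependent networks); otherwise they read their input from a dictionary and return the same dictionary augmented with their output, stored under output_key (or under the module name if output_key is not set). Below is a minimal sketch of the array mode for FullyConnectedNet; it assumes FeNNol and its JAX/Flax dependencies are installed, and the shapes and values are purely illustrative.

import jax
import jax.numpy as jnp
from fennol.models.misc.nets import FullyConnectedNet

# two hidden layers of 64 units and a scalar output per atom
net = FullyConnectedNet(neurons=[64, 64, 1], activation="silu", squeeze=True)

x = jnp.ones((5, 16))                        # e.g. 5 atoms with 16 features each
params = net.init(jax.random.PRNGKey(0), x)  # standard Flax initialization
y = net.apply(params, x)                     # shape (5,) since squeeze=True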
FullyConnectedNet (FID: NEURAL_NET)
A fully connected neural network module. The layer widths are set by the neurons sequence, with the chosen activation applied between layers.
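Inside a FeNNol model graph the same module usually operates on the input dictionary instead. A sketch of this mode, continuing the snippet above (the key names "embedding" and "atom_energy" are hypothetical):

net = FullyConnectedNet(
    neurons=[64, 64, 1],
    input_key="embedding",     # where to read the per-atom features
    output_key="atom_energy",  # where to store the result
    squeeze=True,
)
inputs = {"embedding": jnp.ones((5, 16))}
params = net.init(jax.random.PRNGKey(0), inputs)
out = net.apply(params, inputs)  # {**inputs, "atom_energy": array of shape (5,)}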
ResMLP (FID: RES_MLP)
Residual neural network as defined in the SpookyNet paper. The kernel_init attribute sets the kernel initialization method (default scaled_orthogonal(mode='fan_avg')), and res_only controls whether to only apply the residual connection without the additional activation and linear layer.
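A sketch of ResMLP used in the same dictionary style; the residual block keeps the feature width of its input (key names and shapes are illustrative):

from fennol.models.misc.nets import ResMLP

res = ResMLP(input_key="embedding", output_key="embedding_refined")
inputs = {"embedding": jnp.ones((5, 32))}
params = res.init(jax.random.PRNGKey(1), inputs)
out = res.apply(params, inputs)
out["embedding_refined"].shape  # (5, 32): same width as the input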
FullyResidualNet (FID: SKIP_NET)
A neural network with skip connections at each layer. The input is projected to dim if its width differs, passed through nlayers - 1 residually connected hidden layers, and mapped to output_dim by a final linear layer.
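A sketch for FullyResidualNet with arbitrary dimensions:

from fennol.models.misc.nets import FullyResidualNet

skip_net = FullyResidualNet(dim=64, output_dim=1, nlayers=4, squeeze=True)
x = jnp.ones((10, 16))                            # width 16 != dim, so the "Reshape" layer is inserted
params = skip_net.init(jax.random.PRNGKey(0), x)
y = skip_net.apply(params, x)                     # shape (10,)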
HierarchicalNet (FID: HIERARCHICAL_NET)
Neural network for a sequence of inputs (stacked along axis=-2) with a decay factor: a separate FullyConnectedNet is applied to each element of the sequence, and the i-th output is scaled by decay**i.
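A sketch of the intended usage (note that the kenel_init keyword passed to nn.vmap in the listing above looks like a typo for kernel_init, so custom kernel initializers may not be forwarded to the per-element networks); shapes are illustrative:

from fennol.models.misc.nets import HierarchicalNet

hnet = HierarchicalNet(neurons=[32, 8], decay=0.1)
x = jnp.ones((10, 4, 16))                    # 10 atoms, a sequence of 4 blocks, 16 features each
params = hnet.init(jax.random.PRNGKey(0), x)
y = hnet.apply(params, x)                    # shape (10, 4, 8); element i is scaled by decay**i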
SpeciesIndexNet (FID: SPECIES_INDEX_NET)
Chemical-species-specific neural network using a precomputed species index. It applies a species-specific fully connected network to each atom embedding; the species index filters the embeddings for each species and routes them to the corresponding network. The index can be obtained with the SPECIES_INDEXER preprocessing module (fennol.models.preprocessing.SpeciesIndexer).
Key attributes: species_order (the species for which to build a network; only required if hidden_neurons is not given as a dictionary), activation (the activation function of the fully connected networks), input_key (the key in the input dictionary holding the atom embeddings), species_index_key (the key holding the species index, "species_index" by default), output_key (the key under which the network's output is stored), and kernel_init (the kernel initialization method for the fully connected networks).
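A sketch with a hand-built species index; in practice the index is produced by the SpeciesIndexer preprocessing module, and the key names, shapes and index values below are illustrative:

from fennol.models.misc.nets import SpeciesIndexNet

net = SpeciesIndexNet(
    output_dim=1,
    hidden_neurons=[32, 32],
    species_order="O, H",
    input_key="embedding",
    output_key="atom_energy",
    squeeze=True,
)
inputs = {
    "species": jnp.array([8, 1, 1]),        # one water molecule, as atomic numbers
    "embedding": jnp.ones((3, 16)),
    "species_index": {"O": jnp.array([0]), "H": jnp.array([1, 2])},
}
params = net.init(jax.random.PRNGKey(0), inputs)
out = net.apply(params, inputs)             # out["atom_energy"] has shape (3,)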
optimized Chemical-species-specific neural network.
FID: CHEMICAL_NET
A neural network that applies a fully connected network to each atom embedding in a chemical system and selects the output corresponding to the atom's species. This is an optimized version of ChemicalNetHet that uses vmap to apply the networks to all atoms at once. The optimization is allowed because all networks have the same shape.
activation: The activation function to use in the fully connected networks.
input_key: The key in the input dictionary that corresponds to the embeddings of the atoms.
output_key: The key in the output dictionary that corresponds to the network's output.
kernel_init: The kernel initialization method for the fully connected networks.
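A minimal usage sketch: each atom embedding is pushed through every species network at once, and the row matching the atom's species is selected. The element list, array shapes, and random key below are illustrative assumptions, not part of the module's API.

import jax
import jax.numpy as jnp
from fennol.models.misc.nets import ChemicalNet

# hypothetical toy system: 4 atoms (C, H, H, O) with 8-dimensional embeddings
species = jnp.array([6, 1, 1, 8])      # atomic numbers
embedding = jnp.ones((4, 8))

net = ChemicalNet(species_order="C,H,O", neurons=[16, 1], squeeze=True)
params = net.init(jax.random.PRNGKey(0), (species, embedding))
out = net.apply(params, (species, embedding))   # shape (4,): one value per atom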
class MOENet(nn.Module):
    """Mixture of Experts neural network.

    FID: MOE_NET

    This class represents a Mixture of Experts neural network. It takes in an input and applies a set of shape-sharing networks
    to the input based on a router. The outputs of the shape-sharing networks are then combined using weights computed by the router.

    """

    neurons: Sequence[int]
    """A sequence of integers representing the number of neurons in each shape-sharing network."""
    num_networks: int
    """The number of shape-sharing networks to create."""
    activation: Union[Callable, str] = "silu"
    """The activation function to use in the shape-sharing networks."""
    use_bias: bool = True
    """Whether to include bias in the shape-sharing networks."""
    input_key: Optional[str] = None
    """The key of the input tensor."""
    output_key: Optional[str] = None
    """The key of the output tensor."""
    squeeze: bool = False
    """Whether to remove the last axis of the output tensor if it is of dimension 1."""

    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization method to use in the shape-sharing networks."""
    router_key: Optional[str] = None
    """The key of the router tensor. If None, the router is assumed to be the same as the input tensor."""

    FID: ClassVar[str] = "MOE_NET"

    @nn.compact
    def __call__(
        self, inputs: Union[dict, Tuple[jax.Array, jax.Array]]
    ) -> Union[dict, jax.Array]:
        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            if isinstance(inputs, tuple):
                embedding, router = inputs
            else:
                embedding = router = inputs
        else:
            embedding = inputs[self.input_key]
            router = (
                inputs[self.router_key] if self.router_key is not None else embedding
            )

        ############################
        # build shape-sharing networks using vmap
        networks = nn.vmap(
            FullyConnectedNet,
            variable_axes={"params": 0},
            split_rngs={"params": True},
            in_axes=0,
            out_axes=0,
        )(self.neurons, self.activation, self.use_bias, kernel_init=self.kernel_init)
        # repeat input along a new axis to compute for all networks at once
        x = jnp.repeat(embedding[None, :, :], self.num_networks, axis=0)

        w = nn.softmax(nn.Dense(self.num_networks, name="router")(router), axis=-1)

        out = (networks(x) * w.T[:, :, None]).sum(axis=0)

        if self.squeeze and out.shape[-1] == 1:
            out = jnp.squeeze(out, axis=-1)
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: out} if output_key is not None else out
        return out
Mixture of Experts neural network.
FID: MOE_NET
This class represents a Mixture of Experts neural network. It takes in an input and applies a set of shape-sharing networks to the input based on a router. The outputs of the shape-sharing networks are then combined using weights computed by the router.
neurons: A sequence of integers representing the number of neurons in each shape-sharing network.
activation: The activation function to use in the shape-sharing networks.
kernel_init: The kernel initialization method to use in the shape-sharing networks.
router_key: The key of the router tensor. If None, the router is assumed to be the same as the input tensor.
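A minimal sketch of how the experts are combined, under assumed shapes (5 samples, 8 features, 3 experts). With input_key=None and a plain array input, the router defaults to the input itself.

import jax
import jax.numpy as jnp
from fennol.models.misc.nets import MOENet

x = jnp.ones((5, 8))                          # 5 samples, 8 features (illustrative)
moe = MOENet(neurons=[16, 4], num_networks=3)
params = moe.init(jax.random.PRNGKey(0), x)
y = moe.apply(params, x)                      # shape (5, 4): softmax-weighted sum of the 3 expert outputs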
class ChannelNet(nn.Module):
    """Apply a different neural network to each channel.

    FID: CHANNEL_NET
    """

    neurons: Sequence[int]
    """A sequence of integers representing the number of neurons in each shape-sharing network."""
    activation: Union[Callable, str] = "silu"
    """The activation function to use in the shape-sharing networks."""
    use_bias: bool = True
    """Whether to include bias in the shape-sharing networks."""
    input_key: Optional[str] = None
    """The key of the input tensor."""
    output_key: Optional[str] = None
    """The key of the output tensor."""
    squeeze: bool = False
    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization method to use in the shape-sharing networks."""
    channel_axis: int = -2
    """The axis to use as channel. Its length will be the number of shape-sharing networks."""

    FID: ClassVar[str] = "CHANNEL_NET"

    @nn.compact
    def __call__(
        self, inputs: Union[dict, Tuple[jax.Array, jax.Array]]
    ) -> Union[dict, jax.Array]:
        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            x = inputs
        else:
            x = inputs[self.input_key]

        ############################
        # build shape-sharing networks using vmap
        networks = nn.vmap(
            FullyConnectedNet,
            variable_axes={"params": 0},
            split_rngs={"params": True},
            in_axes=self.channel_axis,
            out_axes=self.channel_axis,
        )(self.neurons, self.activation, self.use_bias, kernel_init=self.kernel_init)

        out = networks(x)

        if self.squeeze and out.shape[-1] == 1:
            out = jnp.squeeze(out, axis=-1)
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: out} if output_key is not None else out
        return out
Apply a different neural network to each channel.
FID: CHANNEL_NET
neurons: A sequence of integers representing the number of neurons in each shape-sharing network.
activation: The activation function to use in the shape-sharing networks.
kernel_init: The kernel initialization method to use in the shape-sharing networks.
channel_axis: The axis to use as channel. Its length will be the number of shape-sharing networks.
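A minimal sketch under assumed shapes: with the default channel_axis=-2, an input of shape (batch, channels, features) is mapped by one independent network per channel.

import jax
import jax.numpy as jnp
from fennol.models.misc.nets import ChannelNet

x = jnp.ones((10, 3, 8))                  # 3 channels of 8 features each (illustrative)
net = ChannelNet(neurons=[16, 4])
params = net.init(jax.random.PRNGKey(0), x)
y = net.apply(params, x)                  # shape (10, 3, 4): a distinct network per channel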
class GatedPerceptron(nn.Module):
    """Gated Perceptron neural network.

    FID: GATED_PERCEPTRON

    This class represents a Gated Perceptron neural network model. It applies a gating mechanism
    to the input data and performs linear transformation using a dense layer followed by an activation function.
    """

    dim: int
    """The dimensionality of the output space."""
    use_bias: bool = True
    """Whether to include a bias term in the dense layer."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization method to use."""
    activation: Union[Callable, str] = "silu"
    """The activation function to use."""

    input_key: Optional[str] = None
    """The key of the input tensor."""
    output_key: Optional[str] = None
    """The key of the output tensor."""
    squeeze: bool = False
    """Whether to remove the last axis of the output tensor if it is of dimension 1."""

    FID: ClassVar[str] = "GATED_PERCEPTRON"

    @nn.compact
    def __call__(self, inputs):
        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            x = inputs
        else:
            x = inputs[self.input_key]

        # activation = (
        #     activation_from_str(self.activation)
        #     if isinstance(self.activation, str)
        #     else self.activation
        # )
        kernel_init = (
            initializer_from_str(self.kernel_init)
            if isinstance(self.kernel_init, str)
            else self.kernel_init
        )
        ############################
        # gate in [0, 1] computed from the input, applied to the activated linear map
        gate = jax.nn.sigmoid(
            nn.Dense(self.dim, use_bias=self.use_bias, kernel_init=kernel_init)(x)
        )
        x = gate * activation_from_str(self.activation)(
            nn.Dense(self.dim, use_bias=self.use_bias, kernel_init=kernel_init)(x)
        )

        if self.squeeze and x.shape[-1] == 1:
            x = jnp.squeeze(x, axis=-1)
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: x} if output_key is not None else x
        return x
Gated Perceptron neural network.
FID: GATED_PERCEPTRON
This class represents a Gated Perceptron neural network model. It applies a gating mechanism to the input data and performs linear transformation using a dense layer followed by an activation function.
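In other words, the output is sigmoid(Dense(x)) * activation(Dense(x)). A minimal sketch with assumed shapes:

import jax
import jax.numpy as jnp
from fennol.models.misc.nets import GatedPerceptron

x = jnp.ones((5, 8))                      # illustrative input
gp = GatedPerceptron(dim=16)
params = gp.init(jax.random.PRNGKey(0), x)
y = gp.apply(params, x)                   # shape (5, 16): sigmoid gate times activated linear map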
class ZAcNet(nn.Module):
    """A fully connected neural network module with affine Z-dependent adjustments of activations.

    FID: ZACNET
    """

    neurons: Sequence[int]
    """A sequence of integers representing the dimensions of the network."""
    zmax: int = 86
    """The maximum atomic number to consider."""
    activation: Union[Callable, str] = "silu"
    """The activation function to use."""
    use_bias: bool = True
    """Whether to use bias in the dense layers."""
    input_key: Optional[str] = None
    """The key of the input tensor."""
    output_key: Optional[str] = None
    """The key of the output tensor."""
    squeeze: bool = False
    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization method to use."""
    species_key: str = "species"
    """The key of the species tensor."""

    FID: ClassVar[str] = "ZACNET"

    @nn.compact
    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            species, x = inputs
        else:
            species, x = inputs[self.species_key], inputs[self.input_key]

        # activation = (
        #     activation_from_str(self.activation)
        #     if isinstance(self.activation, str)
        #     else self.activation
        # )
        kernel_init = (
            initializer_from_str(self.kernel_init)
            if isinstance(self.kernel_init, str)
            else self.kernel_init
        )
        ############################
        for i, d in enumerate(self.neurons[:-1]):
            x = nn.Dense(
                d, use_bias=self.use_bias, name=f"Layer_{i+1}", kernel_init=kernel_init
            )(x)
            sig = self.param(
                f"sig_{i+1}",
                lambda key, shape: jnp.ones(shape, dtype=x.dtype),
                (self.zmax + 2, d),
            )[species]
            if self.use_bias:
                b = self.param(
                    f"b_{i+1}",
                    lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
                    (self.zmax + 2, d),
                )[species]
            else:
                b = 0
            x = activation_from_str(self.activation)(sig * x + b)
        x = nn.Dense(
            self.neurons[-1],
            use_bias=self.use_bias,
            name=f"Layer_{len(self.neurons)}",
            kernel_init=kernel_init,
        )(x)
        sig = self.param(
            f"sig_{len(self.neurons)}",
            lambda key, shape: jnp.ones(shape, dtype=x.dtype),
            (self.zmax + 2, self.neurons[-1]),
        )[species]
        if self.use_bias:
            b = self.param(
                f"b_{len(self.neurons)}",
                lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
                (self.zmax + 2, self.neurons[-1]),
            )[species]
        else:
            b = 0
        x = sig * x + b
        if self.squeeze and x.shape[-1] == 1:
            x = jnp.squeeze(x, axis=-1)
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: x} if output_key is not None else x
        return x
A fully connected neural network module with affine Z-dependent adjustments of activations.
FID: ZACNET
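Each layer's output is rescaled and shifted by per-species parameters before the activation, i.e. act(sig[Z] * Dense(x) + b[Z]). A minimal sketch, with the atomic numbers and shapes below as illustrative assumptions:

import jax
import jax.numpy as jnp
from fennol.models.misc.nets import ZAcNet

species = jnp.array([6, 1, 1, 8])         # atomic numbers, used to index the sig/b tables
x = jnp.ones((4, 8))
net = ZAcNet(neurons=[16, 1], squeeze=True)
params = net.init(jax.random.PRNGKey(0), (species, x))
y = net.apply(params, (species, x))       # shape (4,)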
class ZLoRANet(nn.Module):
    """A fully connected neural network module with Z-dependent low-rank adaptation.

    FID: ZLORANET
    """

    neurons: Sequence[int]
    """A sequence of integers representing the dimensions of the network."""
    ranks: Sequence[int]
    """A sequence of integers representing the ranks of the low-rank adaptation at each layer."""
    zmax: int = 86
    """The maximum atomic number to consider."""
    activation: Union[Callable, str] = "silu"
    """The activation function to use."""
    use_bias: bool = True
    """Whether to use bias in the dense layers."""
    input_key: Optional[str] = None
    """The key of the input tensor."""
    output_key: Optional[str] = None
    """The key of the output tensor."""
    squeeze: bool = False
    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization method to use."""
    species_key: str = "species"
    """The key of the species tensor."""

    FID: ClassVar[str] = "ZLORANET"

    @nn.compact
    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            species, x = inputs
        else:
            species, x = inputs[self.species_key], inputs[self.input_key]

        # activation = (
        #     activation_from_str(self.activation)
        #     if isinstance(self.activation, str)
        #     else self.activation
        # )
        kernel_init = (
            initializer_from_str(self.kernel_init)
            if isinstance(self.kernel_init, str)
            else self.kernel_init
        )
        ############################
        for i, d in enumerate(self.neurons[:-1]):
            xi = nn.Dense(
                d, use_bias=self.use_bias, name=f"Layer_{i+1}", kernel_init=kernel_init
            )(x)
            A = self.param(
                f"A_{i+1}",
                lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
                (self.zmax + 2, self.ranks[i], x.shape[-1]),
            )[species]
            B = self.param(
                f"B_{i+1}",
                lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
                (self.zmax + 2, d, self.ranks[i]),
            )[species]
            Ax = jnp.einsum("zrd,zd->zr", A, x)
            BAx = jnp.einsum("zrd,zd->zr", B, Ax)
            x = activation_from_str(self.activation)(xi + BAx)
        xi = nn.Dense(
            self.neurons[-1],
            use_bias=self.use_bias,
            name=f"Layer_{len(self.neurons)}",
            kernel_init=kernel_init,
        )(x)
        A = self.param(
            f"A_{len(self.neurons)}",
            lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
            (self.zmax + 2, self.ranks[-1], x.shape[-1]),
        )[species]
        B = self.param(
            f"B_{len(self.neurons)}",
            lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
            (self.zmax + 2, self.neurons[-1], self.ranks[-1]),
        )[species]
        Ax = jnp.einsum("zrd,zd->zr", A, x)
        BAx = jnp.einsum("zrd,zd->zr", B, Ax)
        x = xi + BAx
        if self.squeeze and x.shape[-1] == 1:
            x = jnp.squeeze(x, axis=-1)
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: x} if output_key is not None else x
        return x
A fully connected neural network module with Z-dependent low-rank adaptation.
FID: ZLORANET
ranks: A sequence of integers representing the ranks of the low-rank adaptation at each layer.
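Each layer computes Dense(x) plus a species-dependent low-rank correction B[Z] @ (A[Z] @ x), with A and B initialized to zero so the adaptation starts as the identity. A minimal sketch with assumed shapes and one rank entry per layer:

import jax
import jax.numpy as jnp
from fennol.models.misc.nets import ZLoRANet

species = jnp.array([6, 1, 1, 8])         # atomic numbers, used to index the A/B tables
x = jnp.ones((4, 8))
net = ZLoRANet(neurons=[16, 1], ranks=[4, 4], squeeze=True)
params = net.init(jax.random.PRNGKey(0), (species, x))
y = net.apply(params, (species, x))       # shape (4,)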
class BlockIndexNet(nn.Module):
    """Chemical-species-specific neural network using precomputed species index.

    FID: BLOCK_INDEX_NET

    A neural network that applies a species-specific fully connected network to each atom embedding.
    A species index must be provided to filter the embeddings for each species and apply the corresponding network.
    This index can be obtained using the SPECIES_INDEXER preprocessing module from `fennol.models.preprocessing.SpeciesIndexer`

    """

    output_dim: int
    """The dimension of the output of the fully connected networks."""
    hidden_neurons: Sequence[int]
    """The hidden dimensions of the fully connected networks.
    If a dictionary is provided, it should map species names to dimensions.
    If a sequence is provided, the same dimensions will be used for all species."""
    used_blocks: Optional[Sequence[str]] = None
    """The blocks to use. If None, all blocks will be used."""
    activation: Union[Callable, str] = "silu"
    """The activation function to use in the fully connected networks."""
    use_bias: bool = True
    """Whether to include bias terms in the fully connected networks."""
    input_key: Optional[str] = None
    """The key in the input dictionary that corresponds to the embeddings of the atoms."""
    block_index_key: str = "block_index"
    """The key in the input dictionary that corresponds to the block index of the atoms. See `fennol.models.preprocessing.BlockIndexer`"""
    output_key: Optional[str] = None
    """The key in the output dictionary that corresponds to the network's output."""

    squeeze: bool = False
    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization method for the fully connected networks."""
    # check_unhandled: bool = True

    FID: ClassVar[str] = "BLOCK_INDEX_NET"

    # def setup(self):
    #     all_blocks = CHEMICAL_BLOCKS_NAMES
    #     if self.used_blocks is None:
    #         used_blocks = all_blocks
    #     else:
    #         used_blocks = []
    #         for b in self.used_blocks:
    #             b_ = str(b).strip().upper()
    #             if b_ not in all_blocks:
    #                 raise ValueError(f"Block {b} not found in {all_blocks}")
    #             used_blocks.append(b_)
    #     used_blocks = set(used_blocks)
    #     self._used_blocks = used_blocks
    #
    #     if not (
    #         isinstance(self.hidden_neurons, dict)
    #         or isinstance(self.hidden_neurons, FrozenDict)
    #     ):
    #         neurons = {k: self.hidden_neurons for k in used_blocks}
    #     else:
    #         neurons = {}
    #         for b in self.hidden_neurons.keys():
    #             b_ = str(b).strip().upper()
    #             if b_ not in all_blocks:
    #                 raise ValueError(f"Block {b} does not exist. Available blocks are {all_blocks}")
    #             neurons[b_] = self.hidden_neurons[b]
    #         used_blocks = set(neurons.keys())
    #         if used_blocks != self._used_blocks and self.used_blocks is not None:
    #             print(
    #                 f"Warning: hidden neurons definitions do not match specified used_blocks {self.used_blocks}. Using blocks defined in hidden_neurons."
    #             )
    #         self._used_blocks = used_blocks
    #
    #     self.networks = {
    #         k: FullyConnectedNet(
    #             [*neurons[k], self.output_dim],
    #             self.activation,
    #             self.use_bias,
    #             name=k,
    #             kernel_init=self.kernel_init,
    #         )
    #         for k in self._used_blocks
    #     }

    @nn.compact
    def __call__(
        self, inputs: Union[dict, Tuple[jax.Array, jax.Array]]
    ) -> Union[dict, jax.Array]:

        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            species, embedding, block_index = inputs
        else:
            species, embedding = inputs["species"], inputs[self.input_key]
            block_index = inputs[self.block_index_key]

        assert isinstance(
            block_index, dict
        ), "block_index must be a dictionary for BlockIndexNet"

        networks = {
            k: FullyConnectedNet(
                [*self.hidden_neurons, self.output_dim],
                self.activation,
                self.use_bias,
                name=k,
                kernel_init=self.kernel_init,
            )
            for k in block_index.keys()
        }

        ############################
        # initialization => instantiate all networks
        if self.is_initializing():
            x = jnp.zeros((1, embedding.shape[-1]), dtype=embedding.dtype)
            [net(x) for net in networks.values()]

        # if self.check_unhandled:
        #     for b in block_index.keys():
        #         if b not in networks.keys():
        #             raise ValueError(f"Block {b} not found in networks. Available blocks are {self.networks.keys()}")

        ############################
        outputs = []
        indices = []
        for s, net in networks.items():
            if s not in block_index:
                continue
            if block_index[s] is None:
                continue
            idx = block_index[s]
            o = net(embedding[idx])
            outputs.append(o)
            indices.append(idx)

        o = jnp.concatenate(outputs, axis=0)
        idx = jnp.concatenate(indices, axis=0)

        out = (
            jnp.zeros((species.shape[0], *o.shape[1:]), dtype=o.dtype)
            .at[idx]
            .set(o, mode="drop")
        )

        if self.squeeze and out.shape[-1] == 1:
            out = jnp.squeeze(out, axis=-1)
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: out} if output_key is not None else out
        return out
Chemical-species-specific neural network using precomputed species index.
FID: BLOCK_INDEX_NET
A neural network that applies a species-specific fully connected network to each atom embedding.
A species index must be provided to filter the embeddings for each species and apply the corresponding network.
This index can be obtained using the SPECIES_INDEXER preprocessing module from fennol.models.preprocessing.SpeciesIndexer
activation: The activation function to use in the fully connected networks.
input_key: The key in the input dictionary that corresponds to the embeddings of the atoms.
block_index_key: The key in the input dictionary that corresponds to the block index of the atoms. See fennol.models.preprocessing.BlockIndexer
output_key: The key in the output dictionary that corresponds to the network's output.
kernel_init: The kernel initialization method for the fully connected networks.
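A minimal sketch with a hand-built block_index standing in for the output of the BlockIndexer preprocessing module; the block names, indices, and shapes below are placeholders, real block indices come from fennol.models.preprocessing.BlockIndexer.

import jax
import jax.numpy as jnp
from fennol.models.misc.nets import BlockIndexNet

species = jnp.array([6, 1, 1, 8])
embedding = jnp.ones((4, 8))
# hypothetical block index: atoms 0 and 3 in one block, atoms 1 and 2 in another
block_index = {"BLOCK_A": jnp.array([0, 3]), "BLOCK_B": jnp.array([1, 2])}

net = BlockIndexNet(output_dim=1, hidden_neurons=[16], squeeze=True)
params = net.init(jax.random.PRNGKey(0), (species, embedding, block_index))
out = net.apply(params, (species, embedding, block_index))   # shape (4,): per-atom outputs scattered back in order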