fennol.models.misc.nets
import flax.linen as nn
from typing import Sequence, Callable, Union, Optional, ClassVar, Tuple
from ...utils.periodic_table import PERIODIC_TABLE_REV_IDX, PERIODIC_TABLE
import jax.numpy as jnp
import jax
import numpy as np
from functools import partial
from ...utils.activations import activation_from_str, TrainableSiLU
from ...utils.initializers import initializer_from_str, scaled_orthogonal
from flax.core import FrozenDict


class FullyConnectedNet(nn.Module):
    """A fully connected neural network module.

    FID: NEURAL_NET
    """

    neurons: Sequence[int]
    """A sequence of integers representing the dimensions of the network."""
    activation: Union[Callable, str] = "silu"
    """The activation function to use."""
    use_bias: bool = True
    """Whether to use bias in the dense layers."""
    input_key: Optional[str] = None
    """The key of the input tensor."""
    output_key: Optional[str] = None
    """The key of the output tensor."""
    squeeze: bool = False
    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization method to use."""

    FID: ClassVar[str] = "NEURAL_NET"

    @nn.compact
    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
        """Applies the neural network to the given inputs.

        Args:
            inputs (Union[dict, jax.Array]): If input_key is None, a JAX array containing the inputs to the neural network. Else, a dictionary containing the inputs at the key input_key.

        Returns:
            Union[dict, jax.Array]: If output_key is None, the output tensor of the neural network. Else, a dictionary containing the original inputs and the output tensor at the key output_key.
        """
        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            x = inputs
        else:
            x = inputs[self.input_key]

        # activation = (
        #     activation_from_str(self.activation)
        #     if isinstance(self.activation, str)
        #     else self.activation
        # )
        kernel_init = (
            initializer_from_str(self.kernel_init)
            if isinstance(self.kernel_init, str)
            else self.kernel_init
        )
        ############################
        if isinstance(self.activation, str) and self.activation.lower() == "swiglu":
            for i, d in enumerate(self.neurons[:-1]):
                y = nn.Dense(
                    d,
                    use_bias=self.use_bias,
                    name=f"Layer_{i+1}",
                    kernel_init=kernel_init,
                )(x)
                z = jax.nn.swish(nn.Dense(
                    d,
                    use_bias=self.use_bias,
                    name=f"Mask_{i+1}",
                    kernel_init=kernel_init,
                )(x))
                x = y * z
        else:
            for i, d in enumerate(self.neurons[:-1]):
                x = nn.Dense(
                    d,
                    use_bias=self.use_bias,
                    name=f"Layer_{i+1}",
                    kernel_init=kernel_init,
                )(x)
                x = activation_from_str(self.activation)(x)
        x = nn.Dense(
            self.neurons[-1],
            use_bias=self.use_bias,
            name=f"Layer_{len(self.neurons)}",
            kernel_init=kernel_init,
        )(x)
        if self.squeeze and x.shape[-1] == 1:
            x = jnp.squeeze(x, axis=-1)
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: x} if output_key is not None else x
        return x


class ResMLP(nn.Module):
    """Residual neural network as defined in the SpookyNet paper.

    FID: RES_MLP
    """

    use_bias: bool = True
    """Whether to include bias in the linear layers."""
    input_key: Optional[str] = None
    """The key of the input tensor."""
    output_key: Optional[str] = None
    """The key of the output tensor."""

    kernel_init: Union[str, Callable] = "scaled_orthogonal(mode='fan_avg')"
    """The kernel initialization method to use."""
    res_only: bool = False
    """Whether to only apply the residual connection without additional activation and linear layer."""

    FID: ClassVar[str] = "RES_MLP"

    @nn.compact
    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            x = inputs
        else:
            x = inputs[self.input_key]

        kernel_init = (
            initializer_from_str(self.kernel_init)
            if isinstance(self.kernel_init, str)
            else self.kernel_init
        )
        ############################
        out = nn.Dense(x.shape[-1], use_bias=self.use_bias, kernel_init=kernel_init)(
            TrainableSiLU()(x)
        )
        out = x + nn.Dense(
            x.shape[-1], use_bias=self.use_bias, kernel_init=nn.initializers.zeros
        )(TrainableSiLU()(out))

        if not self.res_only:
            out = nn.Dense(
                x.shape[-1], use_bias=self.use_bias, kernel_init=kernel_init
            )(TrainableSiLU()(out))
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: out} if output_key is not None else out
        return out


class FullyResidualNet(nn.Module):
    """A neural network with skip connections at each layer.

    FID: SKIP_NET
    """

    dim: int
    """The dimension of the hidden layers."""
    output_dim: int
    """The dimension of the output layer."""
    nlayers: int
    """The number of layers in the network."""
    activation: Union[Callable, str] = "silu"
    """The activation function to use."""
    use_bias: bool = True
    """Whether to include bias terms in the linear layers."""
    input_key: Optional[str] = None
    """The key of the input tensor."""
    output_key: Optional[str] = None
    """The key of the output tensor."""
    squeeze: bool = False
    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization method to use."""

    FID: ClassVar[str] = "SKIP_NET"

    @nn.compact
    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            x = inputs
        else:
            x = inputs[self.input_key]

        # activation = (
        #     activation_from_str(self.activation)
        #     if isinstance(self.activation, str)
        #     else self.activation
        # )
        kernel_init = (
            initializer_from_str(self.kernel_init)
            if isinstance(self.kernel_init, str)
            else self.kernel_init
        )
        ############################
        if x.shape[-1] != self.dim:
            x = nn.Dense(
                self.dim,
                use_bias=self.use_bias,
                name=f"Reshape",
                kernel_init=kernel_init,
            )(x)

        for i in range(self.nlayers - 1):
            x = x + activation_from_str(self.activation)(
                nn.Dense(
                    self.dim,
                    use_bias=self.use_bias,
                    name=f"Layer_{i+1}",
                    kernel_init=kernel_init,
                )(x)
            )
        x = nn.Dense(
            self.output_dim,
            use_bias=self.use_bias,
            name=f"Layer_{self.nlayers}",
            kernel_init=kernel_init,
        )(x)
        if self.squeeze and x.shape[-1] == 1:
            x = jnp.squeeze(x, axis=-1)
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: x} if output_key is not None else x
        return x


class HierarchicalNet(nn.Module):
    """Neural network for a sequence of inputs (in axis=-2) with a decay factor

    FID: HIERARCHICAL_NET
    """

    neurons: Sequence[int]
    """A sequence of integers representing the number of neurons in each layer."""
    activation: Union[Callable, str] = "silu"
    """The activation function to use."""
    use_bias: bool = True
    """Whether to include bias terms in the linear layers."""
    input_key: Optional[str] = None
    """The key of the input tensor."""
    output_key: Optional[str] = None
    """The key of the output tensor."""
    decay: float = 0.01
    """The decay factor to scale each element of the sequence."""
    squeeze: bool = False
    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization method to use."""

    FID: ClassVar[str] = "HIERARCHICAL_NET"

    @nn.compact
    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            x = inputs
        else:
            x = inputs[self.input_key]

        ############################
        networks = nn.vmap(
            FullyConnectedNet,
            variable_axes={"params": 0},
            split_rngs={"params": True},
            in_axes=-2,
            out_axes=-2,
        )(self.neurons, self.activation, self.use_bias, kernel_init=self.kernel_init)

        out = networks(x)
        # scale each layer by a decay factor
        decay = jnp.asarray([self.decay**i for i in range(out.shape[-2])])
        out = out * decay[..., :, None]

        if self.squeeze and out.shape[-1] == 1:
            out = jnp.squeeze(out, axis=-1)
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: out} if output_key is not None else out
        return out


class SpeciesIndexNet(nn.Module):
    """Chemical-species-specific neural network using precomputed species index.

    FID: SPECIES_INDEX_NET

    A neural network that applies a species-specific fully connected network to each atom embedding.
    A species index must be provided to filter the embeddings for each species and apply the corresponding network.
    This index can be obtained using the SPECIES_INDEXER preprocessing module from `fennol.models.preprocessing.SpeciesIndexer`

    """

    output_dim: int
    """The dimension of the output of the fully connected networks."""
    hidden_neurons: Union[dict, FrozenDict, Sequence[int]]
    """The hidden dimensions of the fully connected networks.
    If a dictionary is provided, it should map species names to dimensions.
    If a sequence is provided, the same dimensions will be used for all species."""
    species_order: Optional[Union[str, Sequence[str]]] = None
    """The species for which to build a network. Only required if neurons is not a dictionary."""
    activation: Union[Callable, str] = "silu"
    """The activation function to use in the fully connected networks."""
    use_bias: bool = True
    """Whether to include bias terms in the fully connected networks."""
    input_key: Optional[str] = None
    """The key in the input dictionary that corresponds to the embeddings of the atoms."""
    species_index_key: str = "species_index"
    """The key in the input dictionary that corresponds to the species index of the atoms. See `fennol.models.preprocessing.SpeciesIndexer`"""
    output_key: Optional[str] = None
    """The key in the output dictionary that corresponds to the network's output."""

    squeeze: bool = False
    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization method for the fully connected networks."""
    check_unhandled: bool = True

    FID: ClassVar[str] = "SPECIES_INDEX_NET"

    def setup(self):
        if not (
            isinstance(self.hidden_neurons, dict)
            or isinstance(self.hidden_neurons, FrozenDict)
        ):
            assert (
                self.species_order is not None
            ), "species_order must be provided if hidden_neurons is not a dictionary"
            if isinstance(self.species_order, str):
                species_order = [el.strip() for el in self.species_order.split(",")]
            else:
                species_order = [el for el in self.species_order]
            neurons = {k: self.hidden_neurons for k in species_order}
        else:
            neurons = self.hidden_neurons
            species_order = list(neurons.keys())
        for species in species_order:
            assert (
                species in PERIODIC_TABLE
            ), f"species {species} not found in periodic table"

        self.networks = {
            k: FullyConnectedNet(
                [*neurons[k], self.output_dim],
                self.activation,
                self.use_bias,
                name=k,
                kernel_init=self.kernel_init,
            )
            for k in species_order
        }

    def __call__(
        self, inputs: Union[dict, Tuple[jax.Array, jax.Array]]
    ) -> Union[dict, jax.Array]:

        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            species, embedding, species_index = inputs
        else:
            species, embedding = inputs["species"], inputs[self.input_key]
            species_index = inputs[self.species_index_key]

        assert isinstance(
            species_index, dict
        ), "species_index must be a dictionary for SpeciesIndexNetHet"

        ############################
        # initialization => instantiate all networks
        if self.is_initializing():
            x = jnp.zeros((1, embedding.shape[-1]), dtype=embedding.dtype)
            [net(x) for net in self.networks.values()]

        if self.check_unhandled:
            for b in species_index.keys():
                if b not in self.networks.keys():
                    raise ValueError(f"Species {b} not found in networks. Handled species are {self.networks.keys()}")

        ############################
        outputs = []
        indices = []
        for s, net in self.networks.items():
            if s not in species_index:
                continue
            idx = species_index[s]
            o = net(embedding[idx])
            outputs.append(o)
            indices.append(idx)

        o = jnp.concatenate(outputs, axis=0)
        idx = jnp.concatenate(indices, axis=0)

        out = (
            jnp.zeros((species.shape[0], *o.shape[1:]), dtype=o.dtype)
            .at[idx]
            .set(o, mode="drop")
        )

        if self.squeeze and out.shape[-1] == 1:
            out = jnp.squeeze(out, axis=-1)
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: out} if output_key is not None else out
        return out


class ChemicalNet(nn.Module):
    """optimized Chemical-species-specific neural network.

    FID: CHEMICAL_NET

    A neural network that applies a fully connected network to each atom embedding in a chemical system and selects the output corresponding to the atom's species.
    This is an optimized version of ChemicalNetHet that uses vmap to apply the networks to all atoms at once.
    The optimization is allowed because all networks have the same shape.

    """

    species_order: Union[str, Sequence[str]]
    """The species for which to build a network."""
    neurons: Sequence[int]
    """The dimensions of the fully connected networks."""
    activation: Union[Callable, str] = "silu"
    """The activation function to use in the fully connected networks."""
    use_bias: bool = True
    """Whether to include bias terms in the fully connected networks."""
    input_key: Optional[str] = None
    """The key in the input dictionary that corresponds to the embeddings of the atoms."""
    output_key: Optional[str] = None
    """The key in the output dictionary that corresponds to the network's output."""
    squeeze: bool = False
    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization method for the fully connected networks."""

    FID: ClassVar[str] = "CHEMICAL_NET"

    @nn.compact
    def __call__(
        self, inputs: Union[dict, Tuple[jax.Array, jax.Array]]
    ) -> Union[dict, jax.Array]:
        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            species, embedding = inputs
        else:
            species, embedding = inputs["species"], inputs[self.input_key]

        ############################
        # build species to network index mapping (static => fixed when jitted)
        rev_idx = PERIODIC_TABLE_REV_IDX
        maxidx = max(rev_idx.values())
        if isinstance(self.species_order, str):
            species_order = [el.strip() for el in self.species_order.split(",")]
        else:
            species_order = [el for el in self.species_order]
        nspecies = len(species_order)
        conv_tensor_ = np.full((maxidx + 2,), -1, dtype=np.int32)
        for i, s in enumerate(species_order):
            conv_tensor_[rev_idx[s]] = i
        conv_tensor = jnp.asarray(conv_tensor_)
        indices = conv_tensor[species]

        ############################
        # build shape-sharing networks using vmap
        networks = nn.vmap(
            FullyConnectedNet,
            variable_axes={"params": 0},
            split_rngs={"params": True},
            in_axes=0,
        )(self.neurons, self.activation, self.use_bias, kernel_init=self.kernel_init)
        # repeat input along a new axis to compute for all species at once
        x = jnp.broadcast_to(
            embedding[None, :, :], (nspecies, *embedding.shape)
        )

        # apply networks to input and select the output corresponding to the species
        out = jnp.squeeze(
            jnp.take_along_axis(networks(x), indices[None, :, None], axis=0), axis=0
        )

        out = jnp.where((indices >= 0)[:, None], out, 0.0)
        if self.squeeze and out.shape[-1] == 1:
            out = jnp.squeeze(out, axis=-1)
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: out} if output_key is not None else out
        return out


class MOENet(nn.Module):
    """Mixture of Experts neural network.

    FID: MOE_NET

    This class represents a Mixture of Experts neural network. It takes in an input and applies a set of shape-sharing networks
    to the input based on a router. The outputs of the shape-sharing networks are then combined using weights computed by the router.

    """

    neurons: Sequence[int]
    """A sequence of integers representing the number of neurons in each shape-sharing network."""
    num_networks: int
    """The number of shape-sharing networks to create."""
    activation: Union[Callable, str] = "silu"
    """The activation function to use in the shape-sharing networks."""
    use_bias: bool = True
    """Whether to include bias in the shape-sharing networks."""
    input_key: Optional[str] = None
    """The key of the input tensor."""
    output_key: Optional[str] = None
    """The key of the output tensor."""
    squeeze: bool = False
    """Whether to remove the last axis of the output tensor if it is of dimension 1."""

    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization method to use in the shape-sharing networks."""
    router_key: Optional[str] = None
    """The key of the router tensor. If None, the router is assumed to be the same as the input tensor."""

    FID: ClassVar[str] = "MOE_NET"

    @nn.compact
    def __call__(
        self, inputs: Union[dict, Tuple[jax.Array, jax.Array]]
    ) -> Union[dict, jax.Array]:
        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            if isinstance(inputs, tuple):
                embedding, router = inputs
            else:
                embedding = router = inputs
        else:
            embedding = inputs[self.input_key]
            router = (
                inputs[self.router_key] if self.router_key is not None else embedding
            )

        ############################
        # build shape-sharing networks using vmap
        networks = nn.vmap(
            FullyConnectedNet,
            variable_axes={"params": 0},
            split_rngs={"params": True},
            in_axes=0,
            out_axes=0,
        )(self.neurons, self.activation, self.use_bias, kernel_init=self.kernel_init)
        # repeat input along a new axis to compute for all networks at once
        x = jnp.repeat(embedding[None, :, :], self.num_networks, axis=0)

        w = nn.softmax(nn.Dense(self.num_networks, name="router")(router), axis=-1)

        out = (networks(x) * w.T[:, :, None]).sum(axis=0)

        if self.squeeze and out.shape[-1] == 1:
            out = jnp.squeeze(out, axis=-1)
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: out} if output_key is not None else out
        return out


class ChannelNet(nn.Module):
    """Apply a different neural network to each channel.

    FID: CHANNEL_NET
    """

    neurons: Sequence[int]
    """A sequence of integers representing the number of neurons in each shape-sharing network."""
    activation: Union[Callable, str] = "silu"
    """The activation function to use in the shape-sharing networks."""
    use_bias: bool = True
    """Whether to include bias in the shape-sharing networks."""
    input_key: Optional[str] = None
    """The key of the input tensor."""
    output_key: Optional[str] = None
    """The key of the output tensor."""
    squeeze: bool = False
    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization method to use in the shape-sharing networks."""
    channel_axis: int = -2
    """The axis to use as channel. Its length will be the number of shape-sharing networks."""

    FID: ClassVar[str] = "CHANNEL_NET"

    @nn.compact
    def __call__(
        self, inputs: Union[dict, Tuple[jax.Array, jax.Array]]
    ) -> Union[dict, jax.Array]:
        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            x = inputs
        else:
            x = inputs[self.input_key]

        ############################
        # build shape-sharing networks using vmap
        networks = nn.vmap(
            FullyConnectedNet,
            variable_axes={"params": 0},
            split_rngs={"params": True},
            in_axes=self.channel_axis,
            out_axes=self.channel_axis,
        )(self.neurons, self.activation, self.use_bias, kernel_init=self.kernel_init)

        out = networks(x)

        if self.squeeze and out.shape[-1] == 1:
            out = jnp.squeeze(out, axis=-1)
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: out} if output_key is not None else out
        return out


class GatedPerceptron(nn.Module):
    """Gated Perceptron neural network.

    FID: GATED_PERCEPTRON

    This class represents a Gated Perceptron neural network model. It applies a gating mechanism
    to the input data and performs linear transformation using a dense layer followed by an activation function.
    """

    dim: int
    """The dimensionality of the output space."""
    use_bias: bool = True
    """Whether to include a bias term in the dense layer."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization method to use."""
    activation: Union[Callable, str] = "silu"
    """The activation function to use."""

    input_key: Optional[str] = None
    """The key of the input tensor."""
    output_key: Optional[str] = None
    """The key of the output tensor."""
    squeeze: bool = False
    """Whether to remove the last axis of the output tensor if it is of dimension 1."""

    FID: ClassVar[str] = "GATED_PERCEPTRON"

    @nn.compact
    def __call__(self, inputs):
        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            x = inputs
        else:
            x = inputs[self.input_key]

        # activation = (
        #     activation_from_str(self.activation)
        #     if isinstance(self.activation, str)
        #     else self.activation
        # )
        kernel_init = (
            initializer_from_str(self.kernel_init)
            if isinstance(self.kernel_init, str)
            else self.kernel_init
        )
        ############################
        gate = jax.nn.sigmoid(
            nn.Dense(self.dim, use_bias=self.use_bias, kernel_init=kernel_init)(x)
        )
        x = gate * activation_from_str(self.activation)(
            nn.Dense(self.dim, use_bias=self.use_bias, kernel_init=kernel_init)(x)
        )

        if self.squeeze and x.shape[-1] == 1:
            x = jnp.squeeze(x, axis=-1)
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: x} if output_key is not None else x
        return x


class ZAcNet(nn.Module):
    """A fully connected neural network module with affine Z-dependent adjustments of activations.

    FID: ZACNET
    """

    neurons: Sequence[int]
    """A sequence of integers representing the dimensions of the network."""
    zmax: int = 86
    """The maximum atomic number to consider."""
    activation: Union[Callable, str] = "silu"
    """The activation function to use."""
    use_bias: bool = True
    """Whether to use bias in the dense layers."""
    input_key: Optional[str] = None
    """The key of the input tensor."""
    output_key: Optional[str] = None
    """The key of the output tensor."""
    squeeze: bool = False
    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization method to use."""
    species_key: str = "species"
    """The key of the species tensor."""

    FID: ClassVar[str] = "ZACNET"

    @nn.compact
    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            species, x = inputs
        else:
            species, x = inputs[self.species_key], inputs[self.input_key]

        # activation = (
        #     activation_from_str(self.activation)
        #     if isinstance(self.activation, str)
        #     else self.activation
        # )
        kernel_init = (
            initializer_from_str(self.kernel_init)
            if isinstance(self.kernel_init, str)
            else self.kernel_init
        )
        ############################
        for i, d in enumerate(self.neurons[:-1]):
            x = nn.Dense(
                d, use_bias=self.use_bias, name=f"Layer_{i+1}", kernel_init=kernel_init
            )(x)
            sig = self.param(
                f"sig_{i+1}",
                lambda key, shape: jnp.ones(shape, dtype=x.dtype),
                (self.zmax + 2, d),
            )[species]
            if self.use_bias:
                b = self.param(
                    f"b_{i+1}",
                    lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
                    (self.zmax + 2, d),
                )[species]
            else:
                b = 0
            x = activation_from_str(self.activation)(sig * x + b)
        x = nn.Dense(
            self.neurons[-1],
            use_bias=self.use_bias,
            name=f"Layer_{len(self.neurons)}",
            kernel_init=kernel_init,
        )(x)
        sig = self.param(
            f"sig_{len(self.neurons)}",
            lambda key, shape: jnp.ones(shape, dtype=x.dtype),
            (self.zmax + 2, self.neurons[-1]),
        )[species]
        if self.use_bias:
            b = self.param(
                f"b_{len(self.neurons)}",
                lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
                (self.zmax + 2, self.neurons[-1]),
            )[species]
        else:
            b = 0
        x = sig * x + b
        if self.squeeze and x.shape[-1] == 1:
            x = jnp.squeeze(x, axis=-1)
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: x} if output_key is not None else x
        return x


class ZLoRANet(nn.Module):
    """A fully connected neural network module with Z-dependent low-rank adaptation.

    FID: ZLORANET
    """

    neurons: Sequence[int]
    """A sequence of integers representing the dimensions of the network."""
    ranks: Sequence[int]
    """A sequence of integers representing the ranks of the low-rank adaptation at each layer."""
    zmax: int = 86
    """The maximum atomic number to consider."""
    activation: Union[Callable, str] = "silu"
    """The activation function to use."""
    use_bias: bool = True
    """Whether to use bias in the dense layers."""
    input_key: Optional[str] = None
    """The key of the input tensor."""
    output_key: Optional[str] = None
    """The key of the output tensor."""
    squeeze: bool = False
    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization method to use."""
    species_key: str = "species"
    """The key of the species tensor."""

    FID: ClassVar[str] = "ZLORANET"

    @nn.compact
    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            species, x = inputs
        else:
            species, x = inputs[self.species_key], inputs[self.input_key]

        # activation = (
        #     activation_from_str(self.activation)
        #     if isinstance(self.activation, str)
        #     else self.activation
        # )
        kernel_init = (
            initializer_from_str(self.kernel_init)
            if isinstance(self.kernel_init, str)
            else self.kernel_init
        )
        ############################
        for i, d in enumerate(self.neurons[:-1]):
            xi = nn.Dense(
                d, use_bias=self.use_bias, name=f"Layer_{i+1}", kernel_init=kernel_init
            )(x)
            A = self.param(
                f"A_{i+1}",
                lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
                (self.zmax + 2, self.ranks[i], x.shape[-1]),
            )[species]
            B = self.param(
                f"B_{i+1}",
                lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
                (self.zmax + 2, d, self.ranks[i]),
            )[species]
            Ax = jnp.einsum("zrd,zd->zr", A, x)
            BAx = jnp.einsum("zrd,zd->zr", B, Ax)
            x = activation_from_str(self.activation)(xi + BAx)
        xi = nn.Dense(
            self.neurons[-1],
            use_bias=self.use_bias,
            name=f"Layer_{len(self.neurons)}",
            kernel_init=kernel_init,
        )(x)
        A = self.param(
            f"A_{len(self.neurons)}",
            lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
            (self.zmax + 2, self.ranks[-1], x.shape[-1]),
        )[species]
        B = self.param(
            f"B_{len(self.neurons)}",
            lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
            (self.zmax + 2, self.neurons[-1], self.ranks[-1]),
        )[species]
        Ax = jnp.einsum("zrd,zd->zr", A, x)
        BAx = jnp.einsum("zrd,zd->zr", B, Ax)
        x = xi + BAx
        if self.squeeze and x.shape[-1] == 1:
            x = jnp.squeeze(x, axis=-1)
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: x} if output_key is not None else x
        return x


class BlockIndexNet(nn.Module):
    """Chemical-species-specific neural network using precomputed species index.

    FID: BLOCK_INDEX_NET

    A neural network that applies a species-specific fully connected network to each atom embedding.
    A species index must be provided to filter the embeddings for each species and apply the corresponding network.
    This index can be obtained using the SPECIES_INDEXER preprocessing module from `fennol.models.preprocessing.SpeciesIndexer`

    """

    output_dim: int
    """The dimension of the output of the fully connected networks."""
    hidden_neurons: Sequence[int]
    """The hidden dimensions of the fully connected networks.
    If a dictionary is provided, it should map species names to dimensions.
    If a sequence is provided, the same dimensions will be used for all species."""
    used_blocks: Optional[Sequence[str]] = None
    """The blocks to use. If None, all blocks will be used."""
    activation: Union[Callable, str] = "silu"
    """The activation function to use in the fully connected networks."""
    use_bias: bool = True
    """Whether to include bias terms in the fully connected networks."""
    input_key: Optional[str] = None
    """The key in the input dictionary that corresponds to the embeddings of the atoms."""
    block_index_key: str = "block_index"
    """The key in the input dictionary that corresponds to the block index of the atoms. See `fennol.models.preprocessing.BlockIndexer`"""
    output_key: Optional[str] = None
    """The key in the output dictionary that corresponds to the network's output."""

    squeeze: bool = False
    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization method for the fully connected networks."""
    # check_unhandled: bool = True

    FID: ClassVar[str] = "BLOCK_INDEX_NET"

    # def setup(self):
    #     all_blocks = CHEMICAL_BLOCKS_NAMES
    #     if self.used_blocks is None:
    #         used_blocks = all_blocks
    #     else:
    #         used_blocks = []
    #         for b in self.used_blocks:
    #             b_ = str(b).strip().upper()
    #             if b_ not in all_blocks:
    #                 raise ValueError(f"Block {b} not found in {all_blocks}")
    #             used_blocks.append(b_)
    #         used_blocks = set(used_blocks)
    #     self._used_blocks = used_blocks

    #     if not (
    #         isinstance(self.hidden_neurons, dict)
    #         or isinstance(self.hidden_neurons, FrozenDict)
    #     ):
    #         neurons = {k: self.hidden_neurons for k in used_blocks}
    #     else:
    #         neurons = {}
    #         for b in self.hidden_neurons.keys():
    #             b_ = str(b).strip().upper()
    #             if b_ not in all_blocks:
    #                 raise ValueError(f"Block {b} does not exist. Available blocks are {all_blocks}")
    #             neurons[b_] = self.hidden_neurons[b]
    #         used_blocks = set(neurons.keys())
    #         if used_blocks != self._used_blocks and self.used_blocks is not None:
    #             print(
    #                 f"Warning: hidden neurons definitions do not match specified used_blocks {self.used_blocks}. Using blocks defined in hidden_neurons.")
    #         self._used_blocks = used_blocks

    #     self.networks = {
    #         k: FullyConnectedNet(
    #             [*neurons[k], self.output_dim],
    #             self.activation,
    #             self.use_bias,
    #             name=k,
    #             kernel_init=self.kernel_init,
    #         )
    #         for k in self._used_blocks
    #     }

    @nn.compact
    def __call__(
        self, inputs: Union[dict, Tuple[jax.Array, jax.Array]]
    ) -> Union[dict, jax.Array]:

        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            species, embedding, block_index = inputs
        else:
            species, embedding = inputs["species"], inputs[self.input_key]
            block_index = inputs[self.block_index_key]

        assert isinstance(
            block_index, dict
        ), "block_index must be a dictionary for BlockIndexNet"

        networks = {
            k: FullyConnectedNet(
                [*self.hidden_neurons, self.output_dim],
                self.activation,
                self.use_bias,
                name=k,
                kernel_init=self.kernel_init,
            )
            for k in block_index.keys()
        }

        ############################
        # initialization => instantiate all networks
        if self.is_initializing():
            x = jnp.zeros((1, embedding.shape[-1]), dtype=embedding.dtype)
            [net(x) for net in networks.values()]

        # if self.check_unhandled:
        #     for b in block_index.keys():
        #         if b not in networks.keys():
        #             raise ValueError(f"Block {b} not found in networks. Available blocks are {self.networks.keys()}")

        ############################
        outputs = []
        indices = []
        for s, net in networks.items():
            if s not in block_index:
                continue
            if block_index[s] is None:
                continue
            idx = block_index[s]
            o = net(embedding[idx])
            outputs.append(o)
            indices.append(idx)

        o = jnp.concatenate(outputs, axis=0)
        idx = jnp.concatenate(indices, axis=0)

        out = (
            jnp.zeros((species.shape[0], *o.shape[1:]), dtype=o.dtype)
            .at[idx]
            .set(o, mode="drop")
        )

        if self.squeeze and out.shape[-1] == 1:
            out = jnp.squeeze(out, axis=-1)
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: out} if output_key is not None else out
        return out
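All modules in this file share the same calling convention: if input_key is None they act directly on an array (or tuple of arrays), otherwise they read their input from a dictionary and write their output back under output_key. A minimal sketch of the dictionary-based usage, with illustrative shapes and key names (the "embedding" and "energy" keys are assumptions for the example, not fixed names):

import jax
import jax.numpy as jnp
from fennol.models.misc.nets import FullyConnectedNet

# hypothetical per-atom embedding: 10 atoms, 32 features
inputs = {"embedding": jnp.ones((10, 32))}
net = FullyConnectedNet(
    neurons=[64, 1], input_key="embedding", output_key="energy", squeeze=True
)
params = net.init(jax.random.PRNGKey(0), inputs)  # standard Flax init
out = net.apply(params, inputs)                   # {**inputs, "energy": array of shape (10,)}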
class FullyConnectedNet(nn.Module):
A fully connected neural network module.
FID: NEURAL_NET
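A sketch of direct (array-in, array-out) usage with illustrative dimensions; passing the string "swiglu" as activation selects the gated branch of the implementation:

import jax
import jax.numpy as jnp
from fennol.models.misc.nets import FullyConnectedNet

net = FullyConnectedNet(neurons=[64, 64, 16], activation="swiglu")
x = jnp.ones((10, 32))
params = net.init(jax.random.PRNGKey(0), x)
y = net.apply(params, x)  # shape (10, 16)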
class ResMLP(nn.Module):
Residual neural network as defined in the SpookyNet paper.
FID: RES_MLP
kernel_init: The kernel initialization method to use.
res_only: Whether to only apply the residual connection without additional activation and linear layer.
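A sketch of applying the residual block to an existing feature array (dimensions are illustrative; the output keeps the input's last-axis width):

import jax
import jax.numpy as jnp
from fennol.models.misc.nets import ResMLP

block = ResMLP()  # defaults: TrainableSiLU activations, scaled orthogonal init
x = jnp.ones((10, 64))
params = block.init(jax.random.PRNGKey(0), x)
y = block.apply(params, x)  # shape (10, 64), same width as the input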
class FullyResidualNet(nn.Module):
A neural network with skip connections at each layer.
FID: SKIP_NET
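A minimal usage sketch with assumed dimensions: the input is first projected to dim if its width differs, then passed through nlayers - 1 skip-connected layers and a final projection to output_dim:

import jax
import jax.numpy as jnp
from fennol.models.misc.nets import FullyResidualNet

net = FullyResidualNet(dim=64, output_dim=1, nlayers=4, squeeze=True)
x = jnp.ones((10, 32))  # 32 != dim, so the "Reshape" Dense layer is inserted
params = net.init(jax.random.PRNGKey(0), x)
y = net.apply(params, x)  # shape (10,)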
class HierarchicalNet(nn.Module):
Neural network for a sequence of inputs (in axis=-2) with a decay factor
FID: HIERARCHICAL_NET
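A sketch assuming the sequence dimension sits on axis -2 (for example, per-order contributions): each element i of the sequence goes through its own FullyConnectedNet and is scaled by decay**i. Shapes are illustrative:

import jax
import jax.numpy as jnp
from fennol.models.misc.nets import HierarchicalNet

net = HierarchicalNet(neurons=[32, 1], decay=0.1)
x = jnp.ones((10, 3, 16))  # 10 atoms, 3 hierarchy levels, 16 features
params = net.init(jax.random.PRNGKey(0), x)
y = net.apply(params, x)   # shape (10, 3, 1), level i scaled by 0.1**i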
class SpeciesIndexNet(nn.Module):
Chemical-species-specific neural network using precomputed species index.
FID: SPECIES_INDEX_NET
A neural network that applies a species-specific fully connected network to each atom embedding.
A species index must be provided to filter the embeddings for each species and apply the corresponding network.
This index can be obtained using the SPECIES_INDEXER preprocessing module from fennol.models.preprocessing.SpeciesIndexer
species_order: The species for which to build a network. Only required if neurons is not a dictionary.
activation: The activation function to use in the fully connected networks.
input_key: The key in the input dictionary that corresponds to the embeddings of the atoms.
species_index_key: The key in the input dictionary that corresponds to the species index of the atoms. See fennol.models.preprocessing.SpeciesIndexer.
output_key: The key in the output dictionary that corresponds to the network's output.
kernel_init: The kernel initialization method for the fully connected networks.
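A sketch of the tuple-based call (no input_key) for a hypothetical two-species system; in practice the species_index dictionary comes from the SPECIES_INDEXER preprocessing module, here it is hand-built for illustration:

import jax
import jax.numpy as jnp
from fennol.models.misc.nets import SpeciesIndexNet

net = SpeciesIndexNet(
    output_dim=1, hidden_neurons=[32, 32], species_order="H,O", squeeze=True
)
species = jnp.array([1, 1, 8])   # atomic numbers for H, H, O
embedding = jnp.ones((3, 16))
species_index = {"H": jnp.array([0, 1]), "O": jnp.array([2])}  # illustrative only
params = net.init(jax.random.PRNGKey(0), (species, embedding, species_index))
out = net.apply(params, (species, embedding, species_index))   # shape (3,)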
class ChemicalNet(nn.Module):
class MOENet(nn.Module):
    """Mixture of Experts neural network.

    FID: MOE_NET

    This class represents a Mixture of Experts neural network. It takes in an input and applies a set of shape-sharing networks
    to the input based on a router. The outputs of the shape-sharing networks are then combined using weights computed by the router.

    """

    neurons: Sequence[int]
    """A sequence of integers representing the number of neurons in each shape-sharing network."""
    num_networks: int
    """The number of shape-sharing networks to create."""
    activation: Union[Callable, str] = "silu"
    """The activation function to use in the shape-sharing networks."""
    use_bias: bool = True
    """Whether to include bias in the shape-sharing networks."""
    input_key: Optional[str] = None
    """The key of the input tensor."""
    output_key: Optional[str] = None
    """The key of the output tensor."""
    squeeze: bool = False
    """Whether to remove the last axis of the output tensor if it is of dimension 1."""

    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization method to use in the shape-sharing networks."""
    router_key: Optional[str] = None
    """The key of the router tensor. If None, the router is assumed to be the same as the input tensor."""

    FID: ClassVar[str] = "MOE_NET"

    @nn.compact
    def __call__(
        self, inputs: Union[dict, Tuple[jax.Array, jax.Array]]
    ) -> Union[dict, jax.Array]:
        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            if isinstance(inputs, tuple):
                embedding, router = inputs
            else:
                embedding = router = inputs
        else:
            embedding = inputs[self.input_key]
            router = (
                inputs[self.router_key] if self.router_key is not None else embedding
            )

        ############################
        # build shape-sharing networks using vmap
        networks = nn.vmap(
            FullyConnectedNet,
            variable_axes={"params": 0},
            split_rngs={"params": True},
            in_axes=0,
            out_axes=0,
        )(self.neurons, self.activation, self.use_bias, kernel_init=self.kernel_init)
        # repeat input along a new axis to compute for all networks at once
        x = jnp.repeat(embedding[None, :, :], self.num_networks, axis=0)

        w = nn.softmax(nn.Dense(self.num_networks, name="router")(router), axis=-1)

        out = (networks(x) * w.T[:, :, None]).sum(axis=0)

        if self.squeeze and out.shape[-1] == 1:
            out = jnp.squeeze(out, axis=-1)
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: out} if output_key is not None else out
        return out
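A hedged sketch of MOENet usage with assumed shapes: four experts mixed over a batch of five samples. Since router_key is left unset and the input is a plain array, the router tensor is the input itself and the expert outputs are combined with the softmax weights produced by the "router" dense layer.

    import jax
    import jax.numpy as jnp
    from fennol.models.misc.nets import MOENet

    moe = MOENet(neurons=[32, 8], num_networks=4)
    x = jnp.ones((5, 16))                  # assumed batch of 5, feature dim 16
    params = moe.init(jax.random.PRNGKey(0), x)
    y = moe.apply(params, x)               # shape (5, 8): softmax-weighted mix of the 4 experts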
class ChannelNet(nn.Module):
    """Apply a different neural network to each channel.

    FID: CHANNEL_NET
    """

    neurons: Sequence[int]
    """A sequence of integers representing the number of neurons in each shape-sharing network."""
    activation: Union[Callable, str] = "silu"
    """The activation function to use in the shape-sharing networks."""
    use_bias: bool = True
    """Whether to include bias in the shape-sharing networks."""
    input_key: Optional[str] = None
    """The key of the input tensor."""
    output_key: Optional[str] = None
    """The key of the output tensor."""
    squeeze: bool = False
    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization method to use in the shape-sharing networks."""
    channel_axis: int = -2
    """The axis to use as channel. Its length will be the number of shape-sharing networks."""

    FID: ClassVar[str] = "CHANNEL_NET"

    @nn.compact
    def __call__(
        self, inputs: Union[dict, Tuple[jax.Array, jax.Array]]
    ) -> Union[dict, jax.Array]:
        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            x = inputs
        else:
            x = inputs[self.input_key]

        ############################
        # build shape-sharing networks using vmap
        networks = nn.vmap(
            FullyConnectedNet,
            variable_axes={"params": 0},
            split_rngs={"params": True},
            in_axes=self.channel_axis,
            out_axes=self.channel_axis,
        )(self.neurons, self.activation, self.use_bias, kernel_init=self.kernel_init)

        out = networks(x)

        if self.squeeze and out.shape[-1] == 1:
            out = jnp.squeeze(out, axis=-1)
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: out} if output_key is not None else out
        return out
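An illustrative ChannelNet call (shapes are assumptions): four channels along the default channel_axis=-2, each processed by its own copy of the vmapped FullyConnectedNet.

    import jax
    import jax.numpy as jnp
    from fennol.models.misc.nets import ChannelNet

    cnet = ChannelNet(neurons=[16, 2])
    x = jnp.ones((10, 4, 8))               # (batch, channels=4, features=8)
    params = cnet.init(jax.random.PRNGKey(0), x)
    y = cnet.apply(params, x)               # (10, 4, 2): one independent network per channel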
class GatedPerceptron(nn.Module):
    """Gated Perceptron neural network.

    FID: GATED_PERCEPTRON

    This class represents a Gated Perceptron neural network model. It applies a gating mechanism
    to the input data and performs linear transformation using a dense layer followed by an activation function.
    """

    dim: int
    """The dimensionality of the output space."""
    use_bias: bool = True
    """Whether to include a bias term in the dense layer."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization method to use."""
    activation: Union[Callable, str] = "silu"
    """The activation function to use."""

    input_key: Optional[str] = None
    """The key of the input tensor."""
    output_key: Optional[str] = None
    """The key of the output tensor."""
    squeeze: bool = False
    """Whether to remove the last axis of the output tensor if it is of dimension 1."""

    FID: ClassVar[str] = "GATED_PERCEPTRON"

    @nn.compact
    def __call__(self, inputs):
        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            x = inputs
        else:
            x = inputs[self.input_key]

        kernel_init = (
            initializer_from_str(self.kernel_init)
            if isinstance(self.kernel_init, str)
            else self.kernel_init
        )
        ############################
        gate = jax.nn.sigmoid(
            nn.Dense(self.dim, use_bias=self.use_bias, kernel_init=kernel_init)(x)
        )
        x = gate * activation_from_str(self.activation)(
            nn.Dense(self.dim, use_bias=self.use_bias, kernel_init=kernel_init)(x)
        )

        # squeeze the computed tensor x (the original referenced an undefined `out`)
        if self.squeeze and x.shape[-1] == 1:
            x = jnp.squeeze(x, axis=-1)
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: x} if output_key is not None else x
        return x
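A short hedged example of the gating above (feature sizes are assumptions): the sigmoid branch modulates the activated linear branch elementwise, i.e. sigmoid(W1 x) * silu(W2 x).

    import jax
    import jax.numpy as jnp
    from fennol.models.misc.nets import GatedPerceptron

    gp = GatedPerceptron(dim=8)
    x = jnp.ones((5, 16))
    params = gp.init(jax.random.PRNGKey(0), x)
    y = gp.apply(params, x)     # (5, 8): gated, activated projection of x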
class ZAcNet(nn.Module):
    """A fully connected neural network module with affine Z-dependent adjustments of activations.

    FID: ZACNET
    """

    neurons: Sequence[int]
    """A sequence of integers representing the dimensions of the network."""
    zmax: int = 86
    """The maximum atomic number to consider."""
    activation: Union[Callable, str] = "silu"
    """The activation function to use."""
    use_bias: bool = True
    """Whether to use bias in the dense layers."""
    input_key: Optional[str] = None
    """The key of the input tensor."""
    output_key: Optional[str] = None
    """The key of the output tensor."""
    squeeze: bool = False
    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization method to use."""
    species_key: str = "species"
    """The key of the species tensor."""

    FID: ClassVar[str] = "ZACNET"

    @nn.compact
    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            species, x = inputs
        else:
            species, x = inputs[self.species_key], inputs[self.input_key]

        kernel_init = (
            initializer_from_str(self.kernel_init)
            if isinstance(self.kernel_init, str)
            else self.kernel_init
        )
        ############################
        for i, d in enumerate(self.neurons[:-1]):
            x = nn.Dense(
                d, use_bias=self.use_bias, name=f"Layer_{i+1}", kernel_init=kernel_init
            )(x)
            sig = self.param(
                f"sig_{i+1}",
                lambda key, shape: jnp.ones(shape, dtype=x.dtype),
                (self.zmax + 2, d),
            )[species]
            if self.use_bias:
                b = self.param(
                    f"b_{i+1}",
                    lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
                    (self.zmax + 2, d),
                )[species]
            else:
                b = 0
            x = activation_from_str(self.activation)(sig * x + b)
        x = nn.Dense(
            self.neurons[-1],
            use_bias=self.use_bias,
            name=f"Layer_{len(self.neurons)}",
            kernel_init=kernel_init,
        )(x)
        sig = self.param(
            f"sig_{len(self.neurons)}",
            lambda key, shape: jnp.ones(shape, dtype=x.dtype),
            (self.zmax + 2, self.neurons[-1]),
        )[species]
        if self.use_bias:
            b = self.param(
                f"b_{len(self.neurons)}",
                lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
                (self.zmax + 2, self.neurons[-1]),
            )[species]
        else:
            b = 0
        x = sig * x + b
        if self.squeeze and x.shape[-1] == 1:
            x = jnp.squeeze(x, axis=-1)
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: x} if output_key is not None else x
        return x
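A hedged ZAcNet example (sizes and dictionary keys are assumptions): per-layer scale `sig_i` and shift `b_i` parameters are looked up by atomic number, so each element gets its own affine adjustment of the shared dense layers. With input_key set, the module reads species from inputs["species"] and writes its result under output_key.

    import jax
    import jax.numpy as jnp
    from fennol.models.misc.nets import ZAcNet

    inputs = {"species": jnp.array([8, 1, 1]), "embedding": jnp.ones((3, 16))}
    net = ZAcNet(neurons=[32, 1], squeeze=True,
                 input_key="embedding", output_key="energy")
    params = net.init(jax.random.PRNGKey(0), inputs)
    out = net.apply(params, inputs)["energy"]    # (3,), one value per atom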
class ZLoRANet(nn.Module):
    """A fully connected neural network module with Z-dependent low-rank adaptation.

    FID: ZLORANET
    """

    neurons: Sequence[int]
    """A sequence of integers representing the dimensions of the network."""
    ranks: Sequence[int]
    """A sequence of integers representing the ranks of the low-rank adaptation at each layer."""
    zmax: int = 86
    """The maximum atomic number to consider."""
    activation: Union[Callable, str] = "silu"
    """The activation function to use."""
    use_bias: bool = True
    """Whether to use bias in the dense layers."""
    input_key: Optional[str] = None
    """The key of the input tensor."""
    output_key: Optional[str] = None
    """The key of the output tensor."""
    squeeze: bool = False
    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization method to use."""
    species_key: str = "species"
    """The key of the species tensor."""

    FID: ClassVar[str] = "ZLORANET"

    @nn.compact
    def __call__(self, inputs: Union[dict, jax.Array]) -> Union[dict, jax.Array]:
        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            species, x = inputs
        else:
            species, x = inputs[self.species_key], inputs[self.input_key]

        kernel_init = (
            initializer_from_str(self.kernel_init)
            if isinstance(self.kernel_init, str)
            else self.kernel_init
        )
        ############################
        for i, d in enumerate(self.neurons[:-1]):
            xi = nn.Dense(
                d, use_bias=self.use_bias, name=f"Layer_{i+1}", kernel_init=kernel_init
            )(x)
            A = self.param(
                f"A_{i+1}",
                lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
                (self.zmax + 2, self.ranks[i], x.shape[-1]),
            )[species]
            B = self.param(
                f"B_{i+1}",
                lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
                (self.zmax + 2, d, self.ranks[i]),
            )[species]
            Ax = jnp.einsum("zrd,zd->zr", A, x)
            BAx = jnp.einsum("zrd,zd->zr", B, Ax)
            x = activation_from_str(self.activation)(xi + BAx)
        xi = nn.Dense(
            self.neurons[-1],
            use_bias=self.use_bias,
            name=f"Layer_{len(self.neurons)}",
            kernel_init=kernel_init,
        )(x)
        A = self.param(
            f"A_{len(self.neurons)}",
            lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
            (self.zmax + 2, self.ranks[-1], x.shape[-1]),
        )[species]
        B = self.param(
            f"B_{len(self.neurons)}",
            lambda key, shape: jnp.zeros(shape, dtype=x.dtype),
            (self.zmax + 2, self.neurons[-1], self.ranks[-1]),
        )[species]
        Ax = jnp.einsum("zrd,zd->zr", A, x)
        BAx = jnp.einsum("zrd,zd->zr", B, Ax)
        x = xi + BAx
        if self.squeeze and x.shape[-1] == 1:
            x = jnp.squeeze(x, axis=-1)
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: x} if output_key is not None else x
        return x
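A hedged ZLoRANet example (sizes and keys are assumptions). Each layer adds a species-dependent low-rank correction B_Z (A_Z x) to the shared dense output; both A and B are zero-initialized, so before training the module behaves like the plain fully connected network. One rank is needed per layer, so len(ranks) should match len(neurons).

    import jax
    import jax.numpy as jnp
    from fennol.models.misc.nets import ZLoRANet

    inputs = {"species": jnp.array([8, 1, 1]), "embedding": jnp.ones((3, 16))}
    net = ZLoRANet(neurons=[32, 1], ranks=[4, 4], squeeze=True,
                   input_key="embedding", output_key="energy")
    params = net.init(jax.random.PRNGKey(0), inputs)
    out = net.apply(params, inputs)["energy"]    # (3,), identical to the base net at init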
class BlockIndexNet(nn.Module):
    """Chemical-species-specific neural network using precomputed species index.

    FID: BLOCK_INDEX_NET

    A neural network that applies a species-specific fully connected network to each atom embedding.
    A species index must be provided to filter the embeddings for each species and apply the corresponding network.
    This index can be obtained using the SPECIES_INDEXER preprocessing module from `fennol.models.preprocessing.SpeciesIndexer`

    """

    output_dim: int
    """The dimension of the output of the fully connected networks."""
    hidden_neurons: Sequence[int]
    """The hidden dimensions of the fully connected networks.
    If a dictionary is provided, it should map species names to dimensions.
    If a sequence is provided, the same dimensions will be used for all species."""
    used_blocks: Optional[Sequence[str]] = None
    """The blocks to use. If None, all blocks will be used."""
    activation: Union[Callable, str] = "silu"
    """The activation function to use in the fully connected networks."""
    use_bias: bool = True
    """Whether to include bias terms in the fully connected networks."""
    input_key: Optional[str] = None
    """The key in the input dictionary that corresponds to the embeddings of the atoms."""
    block_index_key: str = "block_index"
    """The key in the input dictionary that corresponds to the block index of the atoms. See `fennol.models.preprocessing.BlockIndexer`"""
    output_key: Optional[str] = None
    """The key in the output dictionary that corresponds to the network's output."""

    squeeze: bool = False
    """Whether to remove the last axis of the output tensor if it is of dimension 1."""
    kernel_init: Union[str, Callable] = "lecun_normal()"
    """The kernel initialization method for the fully connected networks."""
    # check_unhandled: bool = True

    FID: ClassVar[str] = "BLOCK_INDEX_NET"

    # def setup(self):
    #     all_blocks = CHEMICAL_BLOCKS_NAMES
    #     if self.used_blocks is None:
    #         used_blocks = all_blocks
    #     else:
    #         used_blocks = []
    #         for b in self.used_blocks:
    #             b_ = str(b).strip().upper()
    #             if b_ not in all_blocks:
    #                 raise ValueError(f"Block {b} not found in {all_blocks}")
    #             used_blocks.append(b_)
    #         used_blocks = set(used_blocks)
    #     self._used_blocks = used_blocks
    #
    #     if not (
    #         isinstance(self.hidden_neurons, dict)
    #         or isinstance(self.hidden_neurons, FrozenDict)
    #     ):
    #         neurons = {k: self.hidden_neurons for k in used_blocks}
    #     else:
    #         neurons = {}
    #         for b in self.hidden_neurons.keys():
    #             b_ = str(b).strip().upper()
    #             if b_ not in all_blocks:
    #                 raise ValueError(f"Block {b} does not exist. Available blocks are {all_blocks}")
    #             neurons[b_] = self.hidden_neurons[b]
    #         used_blocks = set(neurons.keys())
    #         if used_blocks != self._used_blocks and self.used_blocks is not None:
    #             print(
    #                 f"Warning: hidden neurons definitions do not match specified used_blocks {self.used_blocks}. Using blocks defined in hidden_neurons."
    #             )
    #         self._used_blocks = used_blocks
    #
    #     self.networks = {
    #         k: FullyConnectedNet(
    #             [*neurons[k], self.output_dim],
    #             self.activation,
    #             self.use_bias,
    #             name=k,
    #             kernel_init=self.kernel_init,
    #         )
    #         for k in self._used_blocks
    #     }

    @nn.compact
    def __call__(
        self, inputs: Union[dict, Tuple[jax.Array, jax.Array]]
    ) -> Union[dict, jax.Array]:

        if self.input_key is None:
            assert not isinstance(
                inputs, dict
            ), "input key must be provided if inputs is a dictionary"
            species, embedding, block_index = inputs
        else:
            species, embedding = inputs["species"], inputs[self.input_key]
            block_index = inputs[self.block_index_key]

        assert isinstance(
            block_index, dict
        ), "block_index must be a dictionary for BlockIndexNet"

        networks = {
            k: FullyConnectedNet(
                [*self.hidden_neurons, self.output_dim],
                self.activation,
                self.use_bias,
                name=k,
                kernel_init=self.kernel_init,
            )
            for k in block_index.keys()
        }

        ############################
        # initialization => instantiate all networks
        if self.is_initializing():
            x = jnp.zeros((1, embedding.shape[-1]), dtype=embedding.dtype)
            [net(x) for net in networks.values()]

        # if self.check_unhandled:
        #     for b in block_index.keys():
        #         if b not in networks.keys():
        #             raise ValueError(f"Block {b} not found in networks. Available blocks are {self.networks.keys()}")

        ############################
        outputs = []
        indices = []
        for s, net in networks.items():
            if s not in block_index:
                continue
            if block_index[s] is None:
                continue
            idx = block_index[s]
            o = net(embedding[idx])
            outputs.append(o)
            indices.append(idx)

        o = jnp.concatenate(outputs, axis=0)
        idx = jnp.concatenate(indices, axis=0)

        out = (
            jnp.zeros((species.shape[0], *o.shape[1:]), dtype=o.dtype)
            .at[idx]
            .set(o, mode="drop")
        )

        if self.squeeze and out.shape[-1] == 1:
            out = jnp.squeeze(out, axis=-1)
        ############################

        if self.input_key is not None:
            output_key = self.name if self.output_key is None else self.output_key
            return {**inputs, output_key: out} if output_key is not None else out
        return out
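A hedged sketch of calling BlockIndexNet with a hand-built block_index; in practice this dictionary is produced by the BlockIndexer preprocessing module. The block names ("BLOCK_A", "BLOCK_B"), shapes, and keys here are placeholder assumptions, not real chemical block names.

    import jax
    import jax.numpy as jnp
    from fennol.models.misc.nets import BlockIndexNet

    inputs = {
        "species": jnp.array([8, 1, 1]),
        "embedding": jnp.ones((3, 16)),
        "block_index": {"BLOCK_A": jnp.array([0]), "BLOCK_B": jnp.array([1, 2])},
    }
    net = BlockIndexNet(output_dim=1, hidden_neurons=[32], squeeze=True,
                        input_key="embedding", output_key="energy")
    params = net.init(jax.random.PRNGKey(0), inputs)
    out = net.apply(params, inputs)["energy"]   # (3,): per-block outputs scattered back per atom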