fennol.models.embeddings.mace
import functools
import math
import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Sequence, Dict, Union, ClassVar, Optional, Set
import dataclasses

from ..misc.encodings import RadialBasis
from ...utils.activations import activation_from_str


try:
    import e3nn_jax as e3nn

    E3NN_AVAILABLE = True
    E3NN_EXCEPTION = None
    Irreps = e3nn.Irreps
    Irrep = e3nn.Irrep
except Exception as e:
    E3NN_AVAILABLE = False
    E3NN_EXCEPTION = e
    e3nn = None

    class Irreps(tuple):
        pass

    class Irrep(tuple):
        pass


class MACE(nn.Module):
    """MACE equivariant message passing neural network.

    Adapted from the MACE-jax GitHub repository by M. Geiger and I. Batatia.

    T. Plé reordered some operations and changed defaults to match the recent
    mace-torch version. Compatibility with pretrained torch models therefore
    requires some work on the parameters:
        - the normalization of activation functions in e3nn differs between jax
          and pytorch => rescaling is needed
        - the multiplicity ordering and signs of the U matrices in
          SymmetricContraction differ => the weight tensors need to be reordered
          and sign-flipped
        - we use a maximum atomic number Z instead of a list of species =>
          species-dependent parameters need to be adapted

    References:
        - I. Batatia et al. "MACE: Higher order equivariant message passing
          neural networks for fast and accurate force fields." Advances in
          Neural Information Processing Systems 35 (2022): 11423-11436.
          https://doi.org/10.48550/arXiv.2206.07697
        - I. Batatia et al. "The design space of E(3)-equivariant atom-centered
          interatomic potentials." arXiv preprint arXiv:2205.06643 (2022).
          https://doi.org/10.48550/arXiv.2205.06643
    """

    _graphs_properties: Dict
    output_irreps: Union[Irreps, str] = "1x0e"
    """The output irreps of the model."""
    hidden_irreps: Union[Irreps, str] = "128x0e + 128x1o"
    """The hidden irreps of the model."""
    readout_mlp_irreps: Union[Irreps, str] = "16x0e"
    """The hidden irreps of the readout MLP."""
    graph_key: str = "graph"
    """The key in the input dictionary that corresponds to the molecular graph to use."""
    output_key: Optional[str] = None
    """The key of the embedding in the output dictionary."""
    avg_num_neighbors: float = 1.0
    """The expected average number of neighbors."""
    ninteractions: int = 2
    """The number of interaction layers."""
    num_features: Optional[int] = None
    """The number of features per node. Defaults to the gcd of the hidden_irreps multiplicities."""
    radial_basis: dict = dataclasses.field(
        default_factory=lambda: {"basis": "bessel", "dim": 8, "trainable": False}
    )
    """The dictionary of parameters for radial basis functions.
    See `fennol.models.misc.encodings.RadialBasis`."""
    lmax: int = 1
    """The maximum angular momentum to consider."""
    correlation: int = 3
    """The correlation order at each layer."""
    activation: str = "silu"
    """The activation function to use."""
    symmetric_tensor_product_basis: bool = False
    """Whether to use the symmetric tensor product basis."""
    interaction_irreps: Union[Irreps, str] = "o3_restricted"
    """The irreps kept in the interaction block ("o3_restricted", "o3_full", or an explicit Irreps string)."""
    skip_connection_first_layer: bool = True
    """Whether to apply the species-indexed linear skip connection in the first layer."""
    radial_network_hidden: Sequence[int] = dataclasses.field(
        default_factory=lambda: [64, 64, 64]
    )
    """The hidden layer sizes of the radial network."""
    scalar_output: bool = False
    """If True, only the scalar (0e) part of each layer output is kept, as a plain array."""
    zmax: int = 86
    """The maximum atomic number to consider."""
    convolution_mode: int = 1
    """Which tensor-product convolution to use (0, 1 or 2)."""

    FID: ClassVar[str] = "MACE"

    @nn.compact
    def __call__(self, inputs):
        if not E3NN_AVAILABLE:
            raise E3NN_EXCEPTION

        species_indices = inputs["species"]
        graph = inputs[self.graph_key]
        distances = graph["distances"]
        vec = e3nn.IrrepsArray("1o", graph["vec"])
        switch = graph["switch"]
        edge_src = graph["edge_src"]
        edge_dst = graph["edge_dst"]

        output_irreps = e3nn.Irreps(self.output_irreps)
        hidden_irreps = e3nn.Irreps(self.hidden_irreps)
        readout_mlp_irreps = e3nn.Irreps(self.readout_mlp_irreps)

        # extract or set num_features
        if self.num_features is None:
            num_features = functools.reduce(math.gcd, (mul for mul, _ in hidden_irreps))
            hidden_irreps = e3nn.Irreps(
                [(mul // num_features, ir) for mul, ir in hidden_irreps]
            )
        else:
            num_features = self.num_features

        # get interaction irreps
        if self.interaction_irreps == "o3_restricted":
            interaction_irreps = e3nn.Irreps.spherical_harmonics(self.lmax)
        elif self.interaction_irreps == "o3_full":
            interaction_irreps = e3nn.Irreps(e3nn.Irrep.iterator(self.lmax))
        else:
            interaction_irreps = e3nn.Irreps(self.interaction_irreps)
        convol_irreps = num_features * interaction_irreps

        # species are used directly as indices (atomic numbers up to zmax)
        num_species = self.zmax + 2

        # species encoding
        encoding_irreps: e3nn.Irreps = (
            (num_features * hidden_irreps).filter("0e").regroup()
        )
        species_encoding = self.param(
            "species_encoding",
            lambda key, shape: jax.nn.standardize(
                jax.random.normal(key, shape, dtype=jnp.float32)
            ),
            (num_species, encoding_irreps.dim),
        )[species_indices]
        # convert to IrrepsArray
        node_feats = e3nn.IrrepsArray(encoding_irreps, species_encoding)

        # radial embedding
        cutoff = self._graphs_properties[self.graph_key]["cutoff"]
        radial_embedding = (
            RadialBasis(
                **{
                    **self.radial_basis,
                    "end": cutoff,
                    "name": "RadialBasis",
                }
            )(distances)
            * switch[:, None]
        )

        # spherical harmonics
        assert self.convolution_mode in [0, 1, 2], "convolution_mode must be 0, 1 or 2"
        if self.convolution_mode == 0:
            Yij = e3nn.spherical_harmonics(range(0, self.lmax + 1), vec, True)
        elif self.convolution_mode == 1:
            Yij = e3nn.spherical_harmonics(range(1, self.lmax + 1), vec, True)

        outputs = []
        node_feats_all = []
        for layer in range(self.ninteractions):
            first = layer == 0
            last = layer == self.ninteractions - 1

            layer_irreps = num_features * (
                hidden_irreps if not last else hidden_irreps.filter(output_irreps)
            )

            # Linear skip connection
            sc = None
            if not first or self.skip_connection_first_layer:
                sc = e3nn.flax.Linear(
                    layer_irreps,
                    num_indexed_weights=num_species,
                    name=f"skip_tp_{layer}",
                    force_irreps_out=True,
                )(species_indices, node_feats)

            ################################################
            # Interaction block (message passing convolution)
            node_feats = e3nn.flax.Linear(node_feats.irreps, name=f"linear_up_{layer}")(
                node_feats
            )

            messages = node_feats[edge_src]
            if self.convolution_mode == 0:
                messages = e3nn.tensor_product(
                    messages,
                    Yij,
                    filter_ir_out=convol_irreps,
                    regroup_output=True,
                )
            elif self.convolution_mode == 1:
                messages = e3nn.concatenate(
                    [
                        messages.filter(convol_irreps),
                        e3nn.tensor_product(
                            messages,
                            Yij,
                            filter_ir_out=convol_irreps,
                        ),
                    ]
                ).regroup()
            else:
                messages = (
                    e3nn.tensor_product_with_spherical_harmonics(
                        messages, vec, self.lmax
                    )
                    .filter(convol_irreps)
                    .regroup()
                )

            mix = e3nn.flax.MultiLayerPerceptron(
                [*self.radial_network_hidden, messages.irreps.num_irreps],
                act=activation_from_str(self.activation),
                output_activation=False,
                name=f"radial_network_{layer}",
                gradient_normalization="element",
            )(radial_embedding)

            messages = messages * mix
            node_feats = (
                e3nn.IrrepsArray.zeros(
                    messages.irreps, node_feats.shape[:1], messages.dtype
                )
                .at[edge_dst]
                .add(messages)
            )

            node_feats = (
                e3nn.flax.Linear(convol_irreps, name=f"linear_dn_{layer}")(node_feats)
                / self.avg_num_neighbors
            )

            if first and not self.skip_connection_first_layer:
                node_feats = e3nn.flax.Linear(
                    node_feats.irreps,
                    num_indexed_weights=num_species,
                    name=f"skip_tp_{layer}",
                )(species_indices, node_feats)

            ################################################
            # Equivariant product basis block

            # symmetric contractions
            node_feats = SymmetricContraction(
                keep_irrep_out={ir for _, ir in layer_irreps},
                correlation=self.correlation,
                num_species=num_species,
                gradient_normalization="element",  # NOTE: This is to copy mace-torch
                symmetric_tensor_product_basis=self.symmetric_tensor_product_basis,
            )(node_feats, species_indices)

            node_feats = e3nn.flax.Linear(
                layer_irreps, name=f"linear_contraction_{layer}"
            )(node_feats)

            if sc is not None:
                # add skip connection
                node_feats = node_feats + sc

            ################################################
            # Readout block
            if last:
                num_vectors = readout_mlp_irreps.filter(drop=["0e", "0o"]).num_irreps
                layer_out = e3nn.flax.Linear(
                    (readout_mlp_irreps + e3nn.Irreps(f"{num_vectors}x0e")).simplify(),
                    name="hidden_linear_readout_last",
                )(node_feats)
                layer_out = e3nn.gate(
                    layer_out,
                    even_act=activation_from_str(self.activation),
                    even_gate_act=None,
                )
                layer_out = e3nn.flax.Linear(output_irreps, name="linear_readout_last")(
                    layer_out
                )
            else:
                layer_out = e3nn.flax.Linear(
                    output_irreps,
                    name=f"linear_readout_{layer}",
                )(node_feats)

            if self.scalar_output:
                layer_out = layer_out.filter("0e").array

            outputs.append(layer_out)
            node_feats_all.append(node_feats.filter("0e").array)

        if self.scalar_output:
            output = jnp.stack(outputs, axis=1)
        else:
            output = e3nn.stack(outputs, axis=1)

        node_feats_all = jnp.concatenate(node_feats_all, axis=-1)

        output_key = self.output_key if self.output_key is not None else self.name
        return {
            **inputs,
            output_key: output,
            output_key + "_node_feats": node_feats_all,
        }


class SymmetricContraction(nn.Module):
    """Symmetric contraction (equivariant product basis) block, adapted from MACE-jax."""

    correlation: int
    keep_irrep_out: Set[Irrep]
    num_species: int
    gradient_normalization: Union[str, float]
    symmetric_tensor_product_basis: bool

    @nn.compact
    def __call__(self, input, index):
        if not E3NN_AVAILABLE:
            raise E3NN_EXCEPTION

        if self.gradient_normalization is None:
            gradient_normalization = e3nn.config("gradient_normalization")
        else:
            gradient_normalization = self.gradient_normalization
        if isinstance(gradient_normalization, str):
            gradient_normalization = {"element": 0.0, "path": 1.0}[
                gradient_normalization
            ]

        keep_irrep_out = self.keep_irrep_out
        if isinstance(keep_irrep_out, str):
            keep_irrep_out = e3nn.Irreps(keep_irrep_out)
            assert all(mul == 1 for mul, _ in keep_irrep_out)

        keep_irrep_out = {e3nn.Irrep(ir) for ir in keep_irrep_out}

        input = input.mul_to_axis().remove_nones()

        ### PREPARE WEIGHTS
        ws = []
        Us = []
        for order in range(1, self.correlation + 1):  # 1, ..., correlation
            if self.symmetric_tensor_product_basis:
                U = e3nn.reduced_symmetric_tensor_product_basis(
                    input.irreps, order, keep_ir=keep_irrep_out
                )
            else:
                U = e3nn.reduced_tensor_product_basis(
                    [input.irreps] * order, keep_ir=keep_irrep_out
                )
            # U = U / order  # normalization TODO(mario): put back after testing
            # NOTE(mario): The normalization constants (/order and /mul**0.5)
            # have been numerically checked to be correct.

            # TODO(mario): implement norm_p
            Us.append(U)

            wsorder = []
            for (mul, ir_out), u in zip(U.irreps, U.list):
                u = u.astype(input.array.dtype)
                # u: ndarray [(irreps_x.dim)^order, multiplicity, ir_out.dim]

                w = self.param(
                    f"w{order}_{ir_out}",
                    nn.initializers.normal(
                        stddev=(mul**-0.5) ** (1.0 - gradient_normalization)
                    ),
                    (self.num_species, mul, input.shape[-2]),
                )
                w = w * (mul**-0.5) ** gradient_normalization
                wsorder.append(w)
            ws.append(wsorder)

        def fn(input: e3nn.IrrepsArray, index: jnp.ndarray):
            # - This operation is parallel over the feature dimension (but each
            #   feature has its own parameters).
            # - It is an efficient implementation of
            #   vmap(lambda w, x: FunctionalLinear(irreps_out)(w, concatenate([x, tensor_product(x, x), tensor_product(x, x, x), ...])))(w, x)
            #   up to x to the power self.correlation.
            assert input.ndim == 2  # [num_features, irreps_x.dim]
            assert index.ndim == 0  # int

            out = dict()
            x_ = input.array

            for order in range(self.correlation, 0, -1):  # correlation, ..., 1
                U = Us[order - 1]

                # ((w3 x + w2) x + w1) x
                #  \-----------/
                #       out

                for ii, ((mul, ir_out), u) in enumerate(zip(U.irreps, U.list)):
                    u = u.astype(x_.dtype)
                    # u: ndarray [(irreps_x.dim)^order, multiplicity, ir_out.dim]

                    w = ws[order - 1][ii][index]
                    if ir_out not in out:
                        out[ir_out] = (
                            "special",
                            jnp.einsum("...jki,kc,cj->c...i", u, w, x_),
                        )  # [num_features, (irreps_x.dim)^(order-1), ir_out.dim]
                    else:
                        out[ir_out] += jnp.einsum(
                            "...ki,kc->c...i", u, w
                        )  # [num_features, (irreps_x.dim)^order, ir_out.dim]

                # ((w3 x + w2) x + w1) x
                #  \----------------/
                #         out (in the normal case)

                for ir_out in out:
                    if isinstance(out[ir_out], tuple):
                        out[ir_out] = out[ir_out][1]
                        continue  # already done (special case optimization above)

                    out[ir_out] = jnp.einsum(
                        "c...ji,cj->c...i", out[ir_out], x_
                    )  # [num_features, (irreps_x.dim)^(order-1), ir_out.dim]

                # ((w3 x + w2) x + w1) x
                #  \-------------------/
                #            out

            # out[ir_out]: [num_features, ir_out.dim]
            irreps_out = e3nn.Irreps(sorted(out.keys()))
            return e3nn.IrrepsArray.from_list(
                irreps_out,
                [out[ir][:, None, :] for (_, ir) in irreps_out],
                (input.shape[0],),
            )

        # Treat batch indices using vmap
        shape = jnp.broadcast_shapes(input.shape[:-2], index.shape)
        input = input.broadcast_to(shape + input.shape[-2:])
        index = jnp.broadcast_to(index, shape)

        fn_mapped = fn
        for _ in range(input.ndim - 2):
            fn_mapped = jax.vmap(fn_mapped)

        return fn_mapped(input, index).axis_to_mul()
MACE equivariant message passing neural network.
Adapted from the MACE-jax GitHub repository by M. Geiger and I. Batatia.
T. Plé reordered some operations and changed defaults to match the recent mace-torch version. Compatibility with pretrained torch models therefore requires some work on the parameters:

- the normalization of activation functions in e3nn differs between jax and pytorch => rescaling is needed
- the multiplicity ordering and signs of the U matrices in SymmetricContraction differ => the weight tensors need to be reordered and sign-flipped
- we use a maximum atomic number Z instead of a list of species => species-dependent parameters need to be adapted
References:

- I. Batatia et al. "MACE: Higher order equivariant message passing neural networks for fast and accurate force fields." Advances in Neural Information Processing Systems 35 (2022): 11423-11436. https://doi.org/10.48550/arXiv.2206.07697
- I. Batatia et al. "The design space of E(3)-equivariant atom-centered interatomic potentials." arXiv preprint arXiv:2205.06643 (2022). https://doi.org/10.48550/arXiv.2205.06643
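The last point above can be made concrete. Below is a minimal, hypothetical sketch of re-indexing one species-dependent parameter table from a torch checkpoint (indexed by position in that model's species list) into the atomic-number-indexed table of size zmax + 2 that this module uses; `torch_table` and `torch_species` are illustrative names, not part of FeNNol or mace-torch.

import jax.numpy as jnp

def remap_species_table(torch_table, torch_species, zmax=86):
    """Scatter rows of a (n_species, ...) table into a (zmax + 2, ...) table
    indexed directly by atomic number. Unlisted elements keep zero rows."""
    new_table = jnp.zeros((zmax + 2, *torch_table.shape[1:]), torch_table.dtype)
    for i, z in enumerate(torch_species):  # z = atomic number of species i
        new_table = new_table.at[z].set(torch_table[i])
    return new_table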
readout_mlp_irreps: The hidden irreps of the readout MLP.

graph_key: The key in the input dictionary that corresponds to the molecular graph to use.

num_features: The number of features per node. Defaults to the gcd of the hidden_irreps multiplicities.

radial_basis: The dictionary of parameters for radial basis functions. See `fennol.models.misc.encodings.RadialBasis`.
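For orientation, here is a minimal initialization sketch inferred from the input keys read in `__call__` (species, distances, vec, switch, edge_src, edge_dst, and the graph cutoff in `_graphs_properties`). It assumes e3nn_jax is installed and that the module can be exercised standalone; the two-atom toy graph is hypothetical, and a real FeNNol pipeline would build these arrays during preprocessing.

import jax
import jax.numpy as jnp
from fennol.models.embeddings.mace import MACE

mace = MACE(
    _graphs_properties={"graph": {"cutoff": 5.0}},
    hidden_irreps="32x0e + 32x1o",
    radial_basis={"basis": "bessel", "dim": 8, "trainable": False},
    lmax=1,
    ninteractions=2,
    output_key="embedding",
)
inputs = {
    "species": jnp.array([8, 1]),  # atomic numbers (O, H)
    "graph": {
        "edge_src": jnp.array([0, 1]),
        "edge_dst": jnp.array([1, 0]),
        "vec": jnp.array([[0.0, 0.0, 0.96], [0.0, 0.0, -0.96]]),  # edge vectors
        "distances": jnp.array([0.96, 0.96]),
        "switch": jnp.array([1.0, 1.0]),  # smooth cutoff value per edge
    },
}
params = mace.init(jax.random.PRNGKey(0), inputs)
out = mace.apply(params, inputs)  # returns inputs plus "embedding" and "embedding_node_feats"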