torch_geometric.llm.models.ProteinMPNN

class ProteinMPNN(hidden_dim: int = 128, num_encoder_layers: int = 3, num_decoder_layers: int = 3, num_neighbors: int = 30, num_rbf: int = 16, dropout: float = 0.1, augment_eps: float = 0.2, num_positional_embedding: int = 16, vocab_size: int = 21)[source]

Bases: Module

The ProteinMPNN model from the “Robust deep learning–based protein sequence design using ProteinMPNN” paper.

Parameters:
  • hidden_dim (int) – Hidden channels. (default: 128)

  • num_encoder_layers (int) – Number of encoder layers. (default: 3)

  • num_decoder_layers (int) – Number of decoder layers. (default: 3)

  • num_neighbors (int) – Number of neighbors for each atom. (default: 30)

  • num_rbf (int) – Number of radial basis functions. (default: 16)

  • dropout (float) – Dropout rate. (default: 0.1)

  • augment_eps (float) – Augmentation epsilon for input coordinates. (default: 0.2)

  • num_positional_embedding (int) – Number of positional embeddings. (default: 16)

  • vocab_size (int) – Size of the vocabulary. (default: 21)
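To give a feel for what num_rbf controls, the following is a minimal pure-Python sketch of a Gaussian radial basis expansion of a single distance. It is an illustration under stated assumptions, not this class's implementation: the helper name rbf_expand, the 2 to 22 distance range, and the width rule are all hypothetical choices made for the example.

```python
import math


def rbf_expand(distance, num_rbf=16, d_min=2.0, d_max=22.0):
    """Expand one scalar distance into `num_rbf` Gaussian features.

    Hypothetical helper for illustration; the range [d_min, d_max] and
    the width rule below are assumptions, not values read from ProteinMPNN.
    """
    if num_rbf < 2:
        raise ValueError("num_rbf must be at least 2 for evenly spaced centers")
    # Evenly spaced Gaussian centers covering [d_min, d_max].
    step = (d_max - d_min) / (num_rbf - 1)
    centers = [d_min + i * step for i in range(num_rbf)]
    # Shared width: the covered range divided by the number of functions.
    sigma = (d_max - d_min) / num_rbf
    # Each feature is largest when the distance sits on its center.
    return [math.exp(-(((distance - c) / sigma) ** 2)) for c in centers]


features = rbf_expand(7.0)
# Index of the basis function whose center is closest to the distance.
closest = max(range(len(features)), key=features.__getitem__)
```

With this kind of encoding, a larger num_rbf gives a finer resolution over the distance range at the cost of a wider edge feature vector.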

Note

For an example of using ProteinMPNN, see examples/llm/protein_mpnn.py.

forward(x: Tensor, chain_seq_label: Tensor, mask: Tensor, chain_mask_all: Tensor, residue_idx: Tensor, chain_encoding_all: Tensor, batch: Tensor) → Tensor[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass is defined in this function, call the Module instance itself rather than forward() directly: calling the instance runs the registered hooks, whereas calling forward() directly silently skips them.
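The difference between calling the instance and calling forward() can be seen in a toy stand-in. The sketch below assumes a simplified dispatch in which `__call__` runs forward() and then each registered hook. ToyModule is a hypothetical class written for this example, not torch.nn.Module, whose real call path is more involved.

```python
class ToyModule:
    """Hypothetical stand-in for a module whose __call__ wraps forward()."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        # Hooks receive (module, input, output), mirroring the usual convention.
        self._forward_hooks.append(hook)

    def forward(self, x):
        # The actual computation lives here.
        return x * 2

    def __call__(self, x):
        # Calling the instance runs forward() and then every registered hook.
        out = self.forward(x)
        for hook in self._forward_hooks:
            hook(self, x, out)
        return out


calls = []
module = ToyModule()
module.register_forward_hook(lambda mod, inp, out: calls.append(out))

via_call = module(3)             # forward() runs, then the hook records the output
via_forward = module.forward(3)  # forward() runs, but the hook is skipped
```

Both paths return the same value, but only the instance call leaves a record in calls. For the same reason, a ProteinMPNN instance should be invoked as model(...) with the arguments listed in the signature, not as model.forward(...).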

Return type:

Tensor