torch_geometric.llm.models.ProteinMPNN
- class ProteinMPNN(hidden_dim: int = 128, num_encoder_layers: int = 3, num_decoder_layers: int = 3, num_neighbors: int = 30, num_rbf: int = 16, dropout: float = 0.1, augment_eps: float = 0.2, num_positional_embedding: int = 16, vocab_size: int = 21)
Bases: Module
The ProteinMPNN model from the “Robust deep learning–based protein sequence design using ProteinMPNN” paper.
- Parameters:
hidden_dim (int) – Number of hidden channels. (default: 128)
num_encoder_layers (int) – Number of encoder layers. (default: 3)
num_decoder_layers (int) – Number of decoder layers. (default: 3)
num_neighbors (int) – Number of neighbors for each atom. (default: 30)
num_rbf (int) – Number of radial basis functions. (default: 16)
dropout (float) – Dropout rate. (default: 0.1)
augment_eps (float) – Scale of the noise added to the input coordinates for augmentation. (default: 0.2)
num_positional_embedding (int) – Number of positional embeddings. (default: 16)
vocab_size (int) – Vocabulary size. (default: 21)
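Three of the parameters above (num_neighbors, num_rbf, augment_eps) describe common geometric preprocessing steps. The sketch below illustrates those ideas in plain Python so they can be run without PyTorch. It is not the PyG implementation. The helper names knn_indices, rbf_expand and augment are hypothetical, and the 2 to 22 Å range of RBF centers and the Gaussian form of the coordinate noise are assumptions chosen for illustration.

```python
import math
import random

# Hypothetical helpers illustrating the ideas behind num_neighbors, num_rbf
# and augment_eps. This is NOT torch_geometric's ProteinMPNN code.


def knn_indices(coords, k):
    """For each point, return the indices of its k nearest other points."""
    neighbors = []
    for i, ci in enumerate(coords):
        # Sort (distance, index) pairs; ties are broken by the smaller index.
        dists = sorted(
            (math.dist(ci, cj), j) for j, cj in enumerate(coords) if j != i
        )
        neighbors.append([j for _, j in dists[:k]])
    return neighbors


def rbf_expand(d, num_rbf=16, d_min=2.0, d_max=22.0):
    """Expand a scalar distance into num_rbf Gaussian basis features.

    The centers are spaced evenly on [d_min, d_max] (an assumed range).
    """
    step = (d_max - d_min) / (num_rbf - 1)
    centers = [d_min + i * step for i in range(num_rbf)]
    sigma = (d_max - d_min) / num_rbf
    return [math.exp(-(((d - mu) / sigma) ** 2)) for mu in centers]


def augment(coords, eps=0.2, rng=random):
    """Add Gaussian noise scaled by eps to every coordinate (assumed scheme)."""
    return [tuple(c + eps * rng.gauss(0.0, 1.0) for c in xyz) for xyz in coords]


if __name__ == "__main__":
    random.seed(0)
    # Toy positions: 5 points spaced 3.8 apart along the x-axis.
    pts = [(3.8 * i, 0.0, 0.0) for i in range(5)]
    print(knn_indices(pts, k=2)[0])  # [1, 2]
    print(len(rbf_expand(math.dist(pts[0], pts[1]), num_rbf=16)))  # 16
    noisy = augment(pts, eps=0.2)
```

Expanding each distance into num_rbf smooth features gives the network a fixed-size, differentiable encoding of geometry, and restricting edges to the num_neighbors closest points keeps the graph sparse.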
Note
For an example of using ProteinMPNN, see examples/llm/protein_mpnn.py.
- forward(x: Tensor, chain_seq_label: Tensor, mask: Tensor, chain_mask_all: Tensor, residue_idx: Tensor, chain_encoding_all: Tensor, batch: Tensor) → Tensor
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- Return type:
Tensor