torch_geometric.nn.models.LPFormer
- class LPFormer(in_channels: int, hidden_channels: int, num_gnn_layers: int = 2, gnn_dropout: float = 0.1, num_transformer_layers: int = 1, num_heads: int = 1, transformer_dropout: float = 0.1, ppr_thresholds: Optional[list] = None, gcn_cache=False)
Bases: Module
The LPFormer model from the “LPFormer: An Adaptive Graph Transformer for Link Prediction” paper.
Note
For an example of using LPFormer, see examples/lpformer.py.
- Parameters:
in_channels (int) – Size of the input dimension.
hidden_channels (int) – Size of the hidden dimension.
num_gnn_layers (int, optional) – Number of GNN layers. (default: 2)
gnn_dropout (float, optional) – Dropout used for the GNN. (default: 0.1)
num_transformer_layers (int, optional) – Number of Transformer layers. (default: 1)
num_heads (int, optional) – Number of heads to use in multi-head attention (MHA). (default: 1)
transformer_dropout (float, optional) – Dropout used for the Transformer. (default: 0.1)
ppr_thresholds (list, optional) – PPR thresholds for the different types of nodes. The types are, in order: common neighbors (CNs), 1-hop nodes that aren’t CNs, and all other nodes. (default: [0, 1e-4, 1e-2])
gcn_cache (bool, optional) – Whether to cache edge indices during message passing. (default: False)
- forward(batch: Tensor, x: Tensor, edge_index: Union[Tensor, SparseTensor], ppr_matrix: Tensor) → Tensor
Forward pass of LPFormer.
Returns the raw logits for each link.
- Parameters:
batch (Tensor) – The batch vector. Denotes which node pairs to predict.
x (Tensor) – Input node features
edge_index (torch.Tensor, SparseTensor) – The edge indices, in either COO or SparseTensor format.
ppr_matrix (Tensor) – PPR matrix
- Return type: Tensor
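The ppr_matrix argument holds personalized PageRank (PPR) scores between nodes. As a rough illustration of what such scores are, the following pure-Python sketch approximates one row of a PPR matrix by power iteration. The function name, the adjacency-list input, and the restart probability alpha = 0.15 are assumptions made for this example; it is not how the library computes or stores the matrix.

```python
def personalized_pagerank(adj_list, source, alpha=0.15, num_iters=100):
    """Approximate PPR scores of every node relative to `source`.

    adj_list: dict mapping each node to a list of its neighbors.
    alpha: restart probability (an assumed value for this sketch).
    """
    nodes = list(adj_list)
    scores = {v: 0.0 for v in nodes}
    scores[source] = 1.0
    for _ in range(num_iters):
        nxt = {v: 0.0 for v in nodes}
        for u in nodes:
            neighbors = adj_list[u]
            if not neighbors:
                # Dangling node: return its walking mass to the source.
                nxt[source] += (1.0 - alpha) * scores[u]
                continue
            share = (1.0 - alpha) * scores[u] / len(neighbors)
            for v in neighbors:
                nxt[v] += share
        # With probability alpha the walk restarts at the source.
        nxt[source] += alpha
        scores = nxt
    return scores


# Path graph 0 - 1 - 2 - 3; scores are relative to node 0.
path_graph = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
ppr_from_0 = personalized_pagerank(path_graph, source=0)
```

Stacking one such row per source node gives a dense PPR matrix; in practice, small entries are usually dropped so that the matrix can be kept sparse.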
- propagate(x: Tensor, adj: Union[Tensor, SparseTensor]) → Tensor
Propagate via GNN.
- Parameters:
x (Tensor) – Node features
adj (torch.Tensor, SparseTensor) – Adjacency matrix
- Return type: Tensor
- calc_pairwise(batch: Tensor, X_node: Tensor, adj_mask: Tensor, ppr_matrix: Tensor) → Tensor
Calculate the pairwise features for the node pairs.
- Parameters:
batch (Tensor) – The batch vector. Denotes which node pairs to predict.
X_node (Tensor) – Node representations
adj_mask (Tensor) – Mask of adjacency matrix used for computing the different node types.
ppr_matrix (Tensor) – PPR matrix
- Return type: Tensor
- get_pos_encodings(cn_ppr: Tuple[Tensor, Tensor], onehop_ppr: Optional[Tuple[Tensor, Tensor]] = None, non1hop_ppr: Optional[Tuple[Tensor, Tensor]] = None) → Tensor
Calculate the PPR-based relative positional encodings.
Due to thresholds, sometimes we don’t have 1-hop or >1-hop nodes. In those cases, the value of onehop_ppr and/or non1hop_ppr should be None.
- compute_node_mask(batch: Tensor, adj: Tensor, ppr_matrix: Tensor) → Tuple[Tuple, Optional[Tuple], Optional[Tuple]]
Get the mask for each type of node.
When mask_type is not “cn”, also return the PPR values for both the source and target.
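The three node types used throughout this class (common neighbors, 1-hop nodes that aren't CNs, and all other nodes) can be illustrated with plain sets. This is a minimal sketch that assumes an adjacency-list graph and excludes the two endpoints of the link from every group; the helper name classify_nodes is hypothetical, and the library itself works on tensor masks rather than Python sets.

```python
def classify_nodes(adj_list, src, tgt):
    """Split nodes into (common neighbors, other 1-hop nodes, all other nodes)
    for the candidate link (src, tgt)."""
    endpoints = {src, tgt}
    src_nbrs = set(adj_list[src]) - endpoints
    tgt_nbrs = set(adj_list[tgt]) - endpoints
    # Common neighbors are adjacent to both endpoints.
    common = src_nbrs & tgt_nbrs
    # 1-hop nodes are adjacent to exactly one endpoint.
    one_hop = (src_nbrs | tgt_nbrs) - common
    # Everything else is more than one hop away from both endpoints.
    others = set(adj_list) - common - one_hop - endpoints
    return common, one_hop, others


graph = {0: [1, 2], 1: [0, 2, 3], 2: [0, 1, 4], 3: [1], 4: [2], 5: []}
cn, one_hop, others = classify_nodes(graph, src=0, tgt=1)
```

For the link (0, 1) in this graph, node 2 is adjacent to both endpoints, node 3 is adjacent only to node 1, and nodes 4 and 5 fall into the last group. In the model, each group is additionally filtered by its own PPR threshold.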
- get_ppr_vals(batch: Tensor, pair_diff_adj: Tensor, ppr_matrix: Tensor) → Tuple[Tensor, Tensor, Tensor, Tensor]
Get the source and target PPR values.
Returns, for each node: the link it belongs to, the node type (e.g., CN), the PPR score relative to the source, and the PPR score relative to the target.
- drop_pairwise(pair_ix: Tensor, src_ppr: Optional[Tensor] = None, tgt_ppr: Optional[Tensor] = None, node_indicator: Optional[Tensor] = None) → Tuple[Tensor, Tensor, Tensor, Tensor]
Perform dropout on the pairwise information by randomly dropping a percentage of nodes.
This is done before attention is performed, for efficiency.
- Parameters:
- Return type: Tuple[Tensor, Tensor, Tensor, Tensor]
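The key point of dropping nodes from the pairwise information is that the same entries must be removed from every aligned array, so that each surviving node keeps its own PPR values. A minimal sketch of that idea with Python lists is shown here; the function name, the explicit drop_rate and rng arguments, and the (link, node) tuple layout are assumptions for illustration and do not mirror the method's actual signature.

```python
import random


def drop_pairwise_sketch(pair_ix, src_ppr, tgt_ppr, drop_rate, rng):
    """Keep a random (1 - drop_rate) fraction of the entries, selecting the
    same positions in every list so the lists stay aligned."""
    num_keep = int(len(pair_ix) * (1.0 - drop_rate))
    keep = sorted(rng.sample(range(len(pair_ix)), num_keep))
    return (
        [pair_ix[i] for i in keep],
        [src_ppr[i] for i in keep],
        [tgt_ppr[i] for i in keep],
    )


# Each entry is (link index, node index) with that node's PPR values.
pair_ix = [(0, 5), (0, 7), (1, 2), (1, 5), (1, 9), (2, 4)]
src_ppr = [0.20, 0.05, 0.10, 0.02, 0.30, 0.07]
tgt_ppr = [0.10, 0.15, 0.01, 0.12, 0.03, 0.09]
kept_ix, kept_src, kept_tgt = drop_pairwise_sketch(
    pair_ix, src_ppr, tgt_ppr, drop_rate=0.5, rng=random.Random(0)
)
```

Dropping before attention means the attention step only processes the surviving nodes, which is where the efficiency gain comes from.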
- get_structure_cnts(batch: Tensor, cn_info: Tuple[Tensor, Tensor], onehop_info: Tuple[Tensor, Tensor], non1hop_info: Optional[Tuple[Tensor, Tensor]]) → Tuple[Tensor, Tensor, Tensor, Tensor]
Counts of CNs, 1-hop, and >1-hop nodes that satisfy the PPR threshold.
Also includes the total number of neighbors.
- Parameters:
batch (Tensor) – The batch vector. Denotes which node pairs to predict.
cn_info (tuple) – Information of CN nodes. Contains (ID of node, src ppr, tgt ppr)
onehop_info (tuple) – Information of 1-Hop nodes. Contains (ID of node, src ppr, tgt ppr)
non1hop_info (tuple) – Information of >1-Hop nodes. Contains (ID of node, src ppr, tgt ppr)
- Return type: Tuple[Tensor, Tensor, Tensor, Tensor]
- get_num_ppr_thresh(batch: Tensor, node_mask: Tensor, src_ppr: Tensor, tgt_ppr: Tensor, thresh: float) → Tensor
Get the number of nodes v where ppr(a, v) >= eta and ppr(b, v) >= eta.
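The thresholded count can be sketched as follows: for each link (a, b), count the candidate nodes v whose PPR scores relative to both endpoints reach eta. This is a pure-Python illustration with a hypothetical helper name and list inputs; the library operates on tensors.

```python
def count_above_threshold(link_ids, src_ppr, tgt_ppr, num_links, eta):
    """For each link, count nodes v with ppr(a, v) >= eta and ppr(b, v) >= eta.

    link_ids[i] is the link that candidate node i belongs to;
    src_ppr[i] and tgt_ppr[i] are its PPR scores relative to a and b.
    """
    counts = [0] * num_links
    for link, s, t in zip(link_ids, src_ppr, tgt_ppr):
        if s >= eta and t >= eta:
            counts[link] += 1
    return counts


link_ids = [0, 0, 1, 1, 1]
src_ppr = [0.2, 0.001, 0.05, 0.3, 0.0]
tgt_ppr = [0.1, 0.5, 0.05, 0.002, 0.4]
counts = count_above_threshold(link_ids, src_ppr, tgt_ppr, num_links=3, eta=0.01)
```

A node must pass the threshold relative to both endpoints to be counted, so a high score on one side does not compensate for a low score on the other. Links with no qualifying nodes, such as link 2 here, get a count of zero.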
- get_count(node_mask: Tensor, batch: Tensor) → Tensor
Number of nodes for each sample in the batch.
The nodes have already been filtered by PPR beforehand.
- Parameters:
node_mask (Tensor) – IDs of nodes
batch (Tensor) – The batch vector. Denotes which node pairs to predict.
- Return type: Tensor
- get_non_1hop_ppr(batch: Tensor, adj: Tensor, ppr_matrix: Tensor) → Tensor
Get PPR scores for non-1hop nodes.
- Parameters:
batch (Tensor) – Links in batch
adj (Tensor) – Adjacency matrix
ppr_matrix (Tensor) – Sparse PPR matrix
- Return type: Tensor