torch_geometric.nn.models.LPFormer

class LPFormer(in_channels: int, hidden_channels: int, num_gnn_layers: int = 2, gnn_dropout: float = 0.1, num_transformer_layers: int = 1, num_heads: int = 1, transformer_dropout: float = 0.1, ppr_thresholds: Optional[list] = None, gcn_cache=False)

Bases: Module

The LPFormer model from the “LPFormer: An Adaptive Graph Transformer for Link Prediction” paper.

Note

For an example of using LPFormer, see examples/lpformer.py.

Parameters:
  • in_channels (int) – Size of the input node features

  • hidden_channels (int) – Size of the hidden dimension

  • num_gnn_layers (int, optional) – Number of GNN layers (default: 2)

  • gnn_dropout (float, optional) – Dropout probability used in the GNN (default: 0.1)

  • num_transformer_layers (int, optional) – Number of Transformer layers (default: 1)

  • num_heads (int, optional) – Number of attention heads used in multi-head attention (MHA) (default: 1)

  • transformer_dropout (float, optional) – Dropout probability used in the Transformer (default: 0.1)

  • ppr_thresholds (list, optional) – Personalized PageRank (PPR) thresholds for the different node types. In order: common neighbors (CNs), 1-hop nodes that are not CNs, and all other nodes. (default: [0, 1e-4, 1e-2])

  • gcn_cache (bool, optional) – Whether to cache edge indices during message passing. (default: False)
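The sketch below illustrates one reading of ppr_thresholds in plain Python: each candidate node is assigned one of the three types relative to a link (a, b), and it is kept only when its PPR relative to both a and b meets the threshold for its type. The "both scores" rule follows the description of get_num_ppr_thresh; the helper names (node_type, filter_by_type) and the set/dict inputs are hypothetical and not part of the PyG API.

```python
from typing import Dict, List, Sequence, Set, Tuple

# Default thresholds in the documented order: CNs, 1-hop (non-CN), all others.
DEFAULT_THRESHOLDS: Tuple[float, float, float] = (0.0, 1e-4, 1e-2)


def node_type(v: int, nbrs_a: Set[int], nbrs_b: Set[int]) -> int:
    """Hypothetical helper: 0 = common neighbor, 1 = 1-hop non-CN, 2 = other."""
    in_a, in_b = v in nbrs_a, v in nbrs_b
    if in_a and in_b:
        return 0
    if in_a or in_b:
        return 1
    return 2


def filter_by_type(
    candidates: List[int],
    nbrs_a: Set[int],
    nbrs_b: Set[int],
    ppr_a: Dict[int, float],
    ppr_b: Dict[int, float],
    thresholds: Sequence[float] = DEFAULT_THRESHOLDS,
) -> List[Tuple[int, int]]:
    """Keep (node, type) pairs whose PPR w.r.t. both a and b meets the type's threshold."""
    kept = []
    for v in candidates:
        t = node_type(v, nbrs_a, nbrs_b)
        eta = thresholds[t]
        # Missing PPR entries are treated as 0 (assumption of this sketch).
        if ppr_a.get(v, 0.0) >= eta and ppr_b.get(v, 0.0) >= eta:
            kept.append((v, t))
    return kept
```

With the default threshold of 0 for CNs, every common neighbor is kept, while 1-hop and more distant nodes need increasingly large PPR scores to survive.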

forward(batch: Tensor, x: Tensor, edge_index: Union[Tensor, SparseTensor], ppr_matrix: Tensor) → Tensor

Forward pass of LPFormer.

Returns the raw logits for each link in batch.

Parameters:
  • batch (Tensor) – The batch vector. Denotes which node pairs to predict.

  • x (Tensor) – Input node features

  • edge_index (torch.Tensor, SparseTensor) – The edge indices, either in COO format or as a SparseTensor

  • ppr_matrix (Tensor) – PPR matrix

Return type:

Tensor
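Because forward returns raw logits, turning them into link probabilities or a training loss is left to the caller. The plain-Python sketch below shows one common way to do that: a numerically stable sigmoid and a mean binary cross-entropy computed directly from logits. It does not call LPFormer, and the function names are illustrative.

```python
import math
from typing import List


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def bce_with_logits(logits: List[float], labels: List[int]) -> float:
    """Mean binary cross-entropy computed directly from logits.

    Uses the stable form max(z, 0) - z * y + log(1 + exp(-|z|)),
    which avoids computing log(sigmoid(z)) for large |z|.
    """
    total = 0.0
    for z, y in zip(logits, labels):
        total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
    return total / len(logits)
```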

reset_parameters()

Resets all learnable parameters of the module.

propagate(x: Tensor, adj: Union[Tensor, SparseTensor]) → Tensor

Propagate the node features through the GNN.

Parameters:
  • x (Tensor) – Node features

  • adj (torch.Tensor, SparseTensor) – Adjacency matrix

Return type:

Tensor
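As a rough picture of one propagation step, the sketch below performs GCN-style symmetric normalized aggregation with self-loops, H' = D^-1/2 (A + I) D^-1/2 H, where D is the degree matrix of A + I, in plain Python. That LPFormer's GNN uses exactly this normalization (suggested by the gcn_cache parameter) is an assumption; the sketch also omits the learnable weights, nonlinearity, and dropout, and gcn_propagate is a hypothetical name.

```python
import math
from typing import List, Tuple


def gcn_propagate(
    x: List[List[float]], edges: List[Tuple[int, int]]
) -> List[List[float]]:
    """One step of D^-1/2 (A + I) D^-1/2 X on an unweighted, undirected graph."""
    n, d = len(x), len(x[0])
    # Neighbor sets with self-loops; each edge is treated as undirected.
    nbrs = [{i} for i in range(n)]
    for u, v in edges:
        nbrs[u].add(v)
        nbrs[v].add(u)
    deg = [len(s) for s in nbrs]
    out = [[0.0] * d for _ in range(n)]
    for i in range(n):
        for j in nbrs[i]:
            # Symmetric normalization weight 1 / sqrt(deg(i) * deg(j)).
            w = 1.0 / math.sqrt(deg[i] * deg[j])
            for k in range(d):
                out[i][k] += w * x[j][k]
    return out
```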

calc_pairwise(batch: Tensor, X_node: Tensor, adj_mask: Tensor, ppr_matrix: Tensor) → Tensor

Calculate the pairwise features for the node pairs.

Parameters:
  • batch (Tensor) – The batch vector. Denotes which node pairs to predict.

  • X_node (Tensor) – Node representations

  • adj_mask (Tensor) – Mask of adjacency matrix used for computing the different node types.

  • ppr_matrix (Tensor) – PPR matrix

Return type:

Tensor

get_pos_encodings(cn_ppr: Tuple[Tensor, Tensor], onehop_ppr: Optional[Tuple[Tensor, Tensor]] = None, non1hop_ppr: Optional[Tuple[Tensor, Tensor]] = None) → Tensor

Calculate the PPR-based relative positional encodings.

Because of the PPR thresholds, a batch may contain no 1-hop or no >1-hop nodes. In that case, pass None for onehop_ppr and/or non1hop_ppr.

Parameters:
  • cn_ppr (tuple) – PPR scores of the CNs

  • onehop_ppr (tuple, optional) – PPR scores of the 1-hop nodes (default: None)

  • non1hop_ppr (tuple, optional) – PPR scores of the >1-hop nodes (default: None)

Return type:

Tensor

compute_node_mask(batch: Tensor, adj: Tensor, ppr_matrix: Tensor) → Tuple[Tuple, Optional[Tuple], Optional[Tuple]]

Get the node mask for each node type (CNs, 1-hop nodes, and >1-hop nodes).

For the non-CN node types, also return the PPR values relative to both the source and the target.

Parameters:
  • batch (Tensor) – The batch vector. Denotes which node pairs to predict.

  • adj (SparseTensor) – Adjacency matrix

  • ppr_matrix (Tensor) – PPR matrix

Return type:

Tuple[Tuple, Optional[Tuple], Optional[Tuple]]

get_ppr_vals(batch: Tensor, pair_diff_adj: Tensor, ppr_matrix: Tensor) → Tuple[Tensor, Tensor, Tensor, Tensor]

Get the PPR values relative to the src and tgt nodes.

Returns, in order: the link each node belongs to, the node type (e.g., CN), the PPR relative to src, and the PPR relative to tgt.

Parameters:
  • batch (Tensor) – The batch vector. Denotes which node pairs to predict.

  • pair_diff_adj (SparseTensor) – Combination of the adjacency rows of the src and tgt nodes (e.g., X1 + X2)

  • ppr_matrix (Tensor) – PPR matrix

Return type:

Tuple[Tensor, Tensor, Tensor, Tensor]
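The pair_diff_adj description suggests summing the adjacency rows of the source and target: with a 0/1 adjacency, an entry of 2 marks a common neighbor and an entry of 1 marks a 1-hop node that is not a CN. The pure-Python sketch below builds the four parallel outputs listed above from dense rows. The dense list-of-lists inputs, the type codes (0 = CN, 1 = 1-hop), and the function name are illustrative assumptions; the real method works on sparse tensors.

```python
from typing import List, Tuple


def ppr_vals_from_rows(
    links: List[Tuple[int, int]],
    adj: List[List[int]],
    ppr: List[List[float]],
) -> Tuple[List[int], List[int], List[float], List[float]]:
    """Return (link index, node type, src PPR, tgt PPR) for every neighbor of each link.

    Node type: 0 = common neighbor (row sum 2), 1 = 1-hop non-CN (row sum 1).
    """
    pair_ix: List[int] = []
    node_type: List[int] = []
    src_ppr: List[float] = []
    tgt_ppr: List[float] = []
    for k, (a, b) in enumerate(links):
        # Sum of the two adjacency rows, analogous to pair_diff_adj (X1 + X2).
        row_sum = [x + y for x, y in zip(adj[a], adj[b])]
        for v, s in enumerate(row_sum):
            if s == 0:
                continue  # not adjacent to a or b
            pair_ix.append(k)
            node_type.append(0 if s == 2 else 1)
            src_ppr.append(ppr[a][v])
            tgt_ppr.append(ppr[b][v])
    return pair_ix, node_type, src_ppr, tgt_ppr
```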

drop_pairwise(pair_ix: Tensor, src_ppr: Optional[Tensor] = None, tgt_ppr: Optional[Tensor] = None, node_indicator: Optional[Tensor] = None) → Tuple[Tensor, Tensor, Tensor, Tensor]

Perform dropout on pairwise information by randomly dropping a percentage of nodes.

This is done before computing attention, for efficiency.

Parameters:
  • pair_ix (Tensor) – The link each node belongs to

  • src_ppr (Tensor, optional) – PPR relative to src (default: None)

  • tgt_ppr (Tensor, optional) – PPR relative to tgt (default: None)

  • node_indicator (Tensor, optional) – Type of node (e.g., CN) (default: None)

Return type:

Tuple[Tensor, Tensor, Tensor, Tensor]
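The idea of dropping pairwise entries can be sketched on parallel Python lists: each (link, node) entry is kept independently with probability 1 - p, and all four lists stay aligned. Treating p as a per-entry independent drop probability, and using Python's random module with an explicit seed, are assumptions for illustration; the actual drop rate and sampling scheme in LPFormer may differ, and drop_entries is a hypothetical name.

```python
import random
from typing import List, Optional, Sequence, Tuple


def drop_entries(
    pair_ix: Sequence[int],
    src_ppr: Sequence[float],
    tgt_ppr: Sequence[float],
    node_type: Sequence[int],
    p: float,
    seed: Optional[int] = None,
) -> Tuple[List[int], List[float], List[float], List[int]]:
    """Keep each entry independently with probability 1 - p; outputs stay aligned."""
    if not 0.0 <= p < 1.0:
        raise ValueError("p must be in [0, 1)")
    rng = random.Random(seed)
    # One keep/drop decision per entry, shared across all four lists.
    keep = [rng.random() >= p for _ in pair_ix]

    def select(xs: Sequence) -> list:
        return [x for x, k in zip(xs, keep) if k]

    return select(pair_ix), select(src_ppr), select(tgt_ppr), select(node_type)
```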

get_structure_cnts(batch: Tensor, cn_info: Tuple[Tensor, Tensor], onehop_info: Tuple[Tensor, Tensor], non1hop_info: Optional[Tuple[Tensor, Tensor]]) → Tuple[Tensor, Tensor, Tensor, Tensor]

Count the CNs, 1-hop nodes, and >1-hop nodes that satisfy the PPR thresholds.

Also includes the total number of neighbors.

Parameters:
  • batch (Tensor) – The batch vector. Denotes which node pairs to predict.

  • cn_info (tuple) – Information about the CN nodes. Contains (node ID, src PPR, tgt PPR)

  • onehop_info (tuple) – Information about the 1-hop nodes. Contains (node ID, src PPR, tgt PPR)

  • non1hop_info (tuple) – Information about the >1-hop nodes. Contains (node ID, src PPR, tgt PPR)

Return type:

Tuple[Tensor, Tensor, Tensor, Tensor]

get_num_ppr_thresh(batch: Tensor, node_mask: Tensor, src_ppr: Tensor, tgt_ppr: Tensor, thresh: float) → Tensor

Get the number of nodes v for which ppr(a, v) >= eta and ppr(b, v) >= eta, where (a, b) is a link in the batch.

Parameters:
  • batch (Tensor) – The batch vector. Denotes which node pairs to predict.

  • node_mask (Tensor) – IDs of nodes

  • src_ppr (Tensor) – PPR relative to src node

  • tgt_ppr (Tensor) – PPR relative to tgt node

  • thresh (float) – PPR threshold for nodes (eta)

Return type:

Tensor
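The count can be sketched per link: for each entry belonging to link k, increment link k's count when both its PPR relative to the src node and its PPR relative to the tgt node are at least eta. The list-based inputs (a flat link index per entry) and the function name are illustrative assumptions; the method itself operates on tensors.

```python
from typing import List, Sequence


def num_ppr_thresh(
    pair_ix: Sequence[int],
    src_ppr: Sequence[float],
    tgt_ppr: Sequence[float],
    num_links: int,
    thresh: float,
) -> List[int]:
    """Per link (a, b), count nodes v with ppr(a, v) >= thresh and ppr(b, v) >= thresh."""
    counts = [0] * num_links
    for k, pa, pb in zip(pair_ix, src_ppr, tgt_ppr):
        if pa >= thresh and pb >= thresh:
            counts[k] += 1
    return counts
```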

get_count(node_mask: Tensor, batch: Tensor) → Tensor

Number of nodes for each sample in the batch.

The nodes have already been filtered by PPR beforehand.

Parameters:
  • node_mask (Tensor) – IDs of nodes

  • batch (Tensor) – The batch vector. Denotes which node pairs to predict.

Return type:

Tensor
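Since the nodes have already been filtered by PPR, the count reduces to tallying how many entries belong to each link, a bincount over the link index. A stdlib-only sketch, where the function name and the flat list of link indices are assumptions:

```python
from collections import Counter
from typing import List, Sequence


def count_per_link(pair_ix: Sequence[int], num_links: int) -> List[int]:
    """Number of (already PPR-filtered) nodes attached to each link in the batch."""
    c = Counter(pair_ix)
    # Links with no surviving nodes get an explicit zero.
    return [c.get(k, 0) for k in range(num_links)]
```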

get_non_1hop_ppr(batch: Tensor, adj: Tensor, ppr_matrix: Tensor) → Tensor

Get PPR scores for non-1hop nodes.

Parameters:
  • batch (Tensor) – The links in the batch

  • adj (Tensor) – Adjacency matrix

  • ppr_matrix (Tensor) – Sparse PPR matrix

Return type:

Tensor
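One way to picture non-1-hop PPR scores: for a link (a, b), take the nodes that have a stored PPR score relative to a or b but are neither a, b, nor a neighbor of either. The sketch below uses a dict of dicts as a stand-in for the sparse PPR matrix. Excluding a and b themselves, treating a missing entry as a PPR of 0, and the function name are assumptions of this sketch.

```python
from typing import Dict, List, Set, Tuple


def non_1hop_ppr(
    link: Tuple[int, int],
    nbrs: Dict[int, Set[int]],
    ppr: Dict[int, Dict[int, float]],
) -> List[Tuple[int, float, float]]:
    """Return (node, ppr(a, v), ppr(b, v)) for nodes more than one hop from both a and b."""
    a, b = link
    excluded = {a, b} | nbrs.get(a, set()) | nbrs.get(b, set())
    row_a, row_b = ppr.get(a, {}), ppr.get(b, {})
    # Candidates are nodes with a stored score in either sparse PPR row.
    candidates = sorted((set(row_a) | set(row_b)) - excluded)
    return [(v, row_a.get(v, 0.0), row_b.get(v, 0.0)) for v in candidates]
```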

calc_sparse_ppr(edge_index: Tensor, num_nodes: int, alpha: float = 0.15, eps: float = 5e-05) → Tensor

Calculate the PPR of the graph in sparse format.

Parameters:
  • edge_index (Tensor) – The edge indices

  • num_nodes (int) – Number of nodes

  • alpha (float, optional) – The alpha value of the PageRank algorithm. (default: 0.15)

  • eps (float, optional) – Threshold for stopping the PPR calculation (default: 5e-5)

Return type:

Tensor
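A standard way to compute a sparse, approximate PPR from an (alpha, eps) pair is the local push algorithm of Andersen, Chung and Lang. The documentation above does not say which variant calc_sparse_ppr uses, so the following pure-Python sketch illustrates the technique for a single source node, not the method's implementation. It assumes an undirected graph, treats alpha as the restart probability, and stops once every residual r[u] < eps * deg(u); a full sparse matrix would come from one run per source node. push_ppr is a hypothetical name.

```python
from collections import deque
from typing import Dict, List, Tuple


def push_ppr(
    edges: List[Tuple[int, int]],
    num_nodes: int,
    source: int,
    alpha: float = 0.15,
    eps: float = 5e-5,
) -> Dict[int, float]:
    """Approximate PPR vector of `source` via local push on an undirected graph."""
    nbrs: List[List[int]] = [[] for _ in range(num_nodes)]
    for u, v in edges:
        nbrs[u].append(v)
        nbrs[v].append(u)
    p: Dict[int, float] = {}          # PPR estimates (only touched nodes appear)
    r: Dict[int, float] = {source: 1.0}  # residual mass still to be pushed
    queue = deque([source])
    in_queue = {source}
    while queue:
        u = queue.popleft()
        in_queue.discard(u)
        ru = r.pop(u, 0.0)
        deg = len(nbrs[u])
        if deg == 0:
            # An isolated node keeps all of its mass (the walk never leaves it).
            p[u] = p.get(u, 0.0) + ru
            continue
        # Settle an alpha fraction at u, spread the rest evenly to neighbors.
        p[u] = p.get(u, 0.0) + alpha * ru
        share = (1.0 - alpha) * ru / deg
        for v in nbrs[u]:
            r[v] = r.get(v, 0.0) + share
            if v not in in_queue and r[v] >= eps * len(nbrs[v]):
                queue.append(v)
                in_queue.add(v)
    return p
```

Smaller eps values give more accurate scores at the cost of more push operations and denser results, which is the usual trade-off behind a sparse PPR threshold.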