torch_geometric.nn.models.ARLinkPredictor
- class ARLinkPredictor(in_channels, hidden_channels, out_channels=None, num_layers=2, dropout=0.0, attract_ratio=0.5)
Bases: torch.nn.Module
Link predictor using Attract-Repel embeddings from the paper “Pseudo-Euclidean Attract-Repel Embeddings for Undirected Graphs”.
This model splits each node embedding into two parts: an attract component and a repel component. The score of an edge (u, v) is the dot product of the attract components of u and v minus the dot product of their repel components.
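This score is an indefinite (pseudo-Euclidean) inner product: if the attract and repel parts are concatenated and the repel coordinates get a minus sign, the result is the same number. The dependency-free sketch below checks this for one node pair. Plain Python lists stand in for tensors, and ar_score and pseudo_euclidean_inner are hypothetical helper names, not part of the PyG API.

```python
# Sketch: the attract-repel score for one node pair, computed two ways.
# Lists stand in for tensors; all names here are illustrative only.

def dot(a, b):
    """Euclidean dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def ar_score(attract_u, repel_u, attract_v, repel_v):
    """Attract dot product minus repel dot product."""
    return dot(attract_u, attract_v) - dot(repel_u, repel_v)


def pseudo_euclidean_inner(z_u, z_v, attract_dim):
    """Inner product with metric diag(+1, ..., +1, -1, ..., -1).

    The first `attract_dim` coordinates are weighted +1, the rest -1.
    """
    return sum(
        (1 if i < attract_dim else -1) * x * y
        for i, (x, y) in enumerate(zip(z_u, z_v))
    )


# Two nodes, each with 2 attract and 2 repel dimensions.
a_u, r_u = [1.0, 2.0], [0.5, 0.25]
a_v, r_v = [3.0, -1.0], [2.0, 4.0]

print(ar_score(a_u, r_u, a_v, r_v))                               # -1.0
print(pseudo_euclidean_inner(a_u + r_u, a_v + r_v, attract_dim=2))  # -1.0
```

One consequence of the minus sign: a node paired with itself scores the squared norm of its attract part minus the squared norm of its repel part, which is negative when the repel part dominates. A plain dot product can never score a self-pair below zero.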
- Parameters:
in_channels (int) – Size of each input sample.
hidden_channels (int) – Size of hidden embeddings.
out_channels (int, optional) – Size of output embeddings. If set to None, defaults to hidden_channels. (default: None)
num_layers (int) – Number of message passing layers. (default: 2)
dropout (float) – Dropout probability. (default: 0.0)
attract_ratio (float) – Fraction of the output embedding dimensions assigned to the attract component; the remaining dimensions form the repel component. Must be between 0 and 1. (default: 0.5)
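The constructor arguments determine how wide the two embedding parts are. A minimal sketch, assuming the attract part receives int(out_channels * attract_ratio) dimensions and the repel part receives the remainder; the rounding rule and the ValueError on an out-of-range ratio are assumptions, not taken from this page, and ar_dims is a hypothetical helper.

```python
# Sketch: embedding widths implied by the constructor arguments.
# ASSUMPTION: attract_dim = int(out_channels * attract_ratio) (truncation);
# repel_dim takes the rest. PyG's exact rounding may differ.

def ar_dims(hidden_channels, out_channels=None, attract_ratio=0.5):
    """Return (out_channels, attract_dim, repel_dim)."""
    if not 0.0 <= attract_ratio <= 1.0:
        # ASSUMPTION: invalid ratios are rejected with a ValueError.
        raise ValueError(f"attract_ratio must be in [0, 1], got {attract_ratio}")
    if out_channels is None:
        # Documented default: fall back to hidden_channels.
        out_channels = hidden_channels
    attract_dim = int(out_channels * attract_ratio)
    repel_dim = out_channels - attract_dim
    return out_channels, attract_dim, repel_dim


print(ar_dims(hidden_channels=64))                                        # (64, 32, 32)
print(ar_dims(hidden_channels=64, out_channels=10, attract_ratio=0.75))   # (10, 7, 3)
```

With the default attract_ratio of 0.5 and an even out_channels, both parts get the same width; the extreme ratios 0 and 1 leave one part empty.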
- forward(x, edge_index)
Forward pass for link prediction.
- Parameters:
x (torch.Tensor) – Node feature matrix of shape [num_nodes, in_channels].
edge_index (torch.Tensor) – Indices of the node pairs to score, of shape [2, num_edges].
- Returns:
Predicted edge scores.
- Return type:
torch.Tensor
- encode(x, *args, **kwargs)
Encode node features into attract-repel embeddings.
- Parameters:
x (torch.Tensor) – Node feature matrix of shape [num_nodes, in_channels].
*args – Variable-length argument list.
**kwargs – Arbitrary keyword arguments.
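The encoder architecture is not spelled out on this page. As a hedged illustration, the sketch below assumes a single linear projection to out_channels dimensions, followed by a split into the first attract_dim coordinates (attract) and the remaining repel_dim coordinates (repel). The names linear and encode are hypothetical, and activations, dropout, and additional layers are omitted.

```python
# Sketch: map node features to attract and repel embeddings.
# ASSUMPTION: one linear map to out_channels dims, then a split into the
# first attract_dim coordinates and the remaining repel_dim coordinates.

def linear(x, weight):
    """y = W x for one node; weight is a list of rows (out_dim x in_dim)."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


def encode(x, weight, attract_dim):
    """Return (attract_z, repel_z), each a list with one row per node."""
    attract_z, repel_z = [], []
    for features in x:                # x has shape [num_nodes, in_channels]
        z = linear(features, weight)  # z has length out_channels
        attract_z.append(z[:attract_dim])
        repel_z.append(z[attract_dim:])
    return attract_z, repel_z


# 2 nodes, in_channels = 2, out_channels = 3 (attract_dim = 2, repel_dim = 1).
x = [[1.0, 0.0], [0.0, 2.0]]
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
attract_z, repel_z = encode(x, weight, attract_dim=2)
print(attract_z)  # [[1.0, 0.0], [0.0, 2.0]]
print(repel_z)    # [[1.0], [2.0]]
```

Slicing one projection this way is equivalent to using two separate linear heads, one per row block of the weight matrix, so the split itself does not depend on how the layers are organized.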
- decode(attract_z, repel_z, edge_index)
Decode edge scores from attract-repel embeddings.
- Parameters:
attract_z (torch.Tensor) – Attract embeddings of shape [num_nodes, attract_dim].
repel_z (torch.Tensor) – Repel embeddings of shape [num_nodes, repel_dim].
edge_index (torch.Tensor) – Edge indices of shape [2, num_edges].
- Returns:
Edge prediction scores.
- Return type:
torch.Tensor
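Decoding applies the attract-minus-repel score to every column of edge_index, where row 0 holds source node indices and row 1 holds destination node indices. The loop-based sketch below shows the per-edge computation; with tensors the same thing is usually written in vectorized form, e.g. (attract_z[src] * attract_z[dst]).sum(dim=-1) minus the analogous repel term. The function name decode here is an illustrative stand-in, not the PyG method itself.

```python
# Sketch: one score per edge in edge_index (shape [2, num_edges]).

def dot(a, b):
    """Euclidean dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def decode(attract_z, repel_z, edge_index):
    """Return <a_src, a_dst> - <r_src, r_dst> for every edge."""
    src, dst = edge_index
    return [
        dot(attract_z[s], attract_z[d]) - dot(repel_z[s], repel_z[d])
        for s, d in zip(src, dst)
    ]


attract_z = [[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]]
repel_z = [[1.0], [2.0], [0.0]]
edge_index = [[0, 1, 0], [2, 2, 1]]  # edges (0, 2), (1, 2), (0, 1)
print(decode(attract_z, repel_z, edge_index))  # [1.0, 2.0, -2.0]
```

Because both dot products are symmetric, (u, v) and (v, u) receive the same score, which fits the undirected-graph setting of the paper.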
- calculate_r_fraction(attract_z, repel_z)
Calculate the R-fraction (proportion of energy in repel space).
- Parameters:
attract_z (torch.Tensor) – Attract embeddings.
repel_z (torch.Tensor) – Repel embeddings.
- Returns:
R-fraction value.
- Return type:
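The R-fraction summarizes how much of the embedding magnitude sits in the repel part. A minimal sketch, assuming "energy" means the sum of squared entries over all nodes, with a small eps added to the denominator (also an assumption) to avoid division by zero; r_fraction and squared_norm are hypothetical names.

```python
# Sketch: R-fraction = repel energy / (attract energy + repel energy).
# ASSUMPTION: energy = sum of squared entries (squared Frobenius norm)
# over all nodes; eps guards against an all-zero input.

def squared_norm(rows):
    """Sum of squared entries of a list of rows."""
    return sum(v * v for row in rows for v in row)


def r_fraction(attract_z, repel_z, eps=1e-10):
    attract_energy = squared_norm(attract_z)
    repel_energy = squared_norm(repel_z)
    return repel_energy / (attract_energy + repel_energy + eps)


attract_z = [[1.0, 0.0], [0.0, 2.0]]
repel_z = [[1.0], [2.0]]
print(round(r_fraction(attract_z, repel_z), 6))  # 0.5
```

A value near 0 means the embeddings are almost purely attract, so the score behaves like an ordinary dot product; a value near 1 means the repel part dominates.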