Source code for torch_geometric.nn.models.lpformer

import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Parameter

from ...nn.conv import MessagePassing
from ...nn.dense.linear import Linear
from ...nn.inits import glorot, zeros
from ...typing import Adj, OptTensor, Tuple
from ...utils import get_ppr, is_sparse, scatter, softmax
from .basic_gnn import GCN


class LPFormer(nn.Module):
    r"""The LPFormer model from the `"LPFormer: An Adaptive Graph
    Transformer for Link Prediction" <https://arxiv.org/abs/2310.11009>`_
    paper.

    .. note::

        For an example of using LPFormer, see `examples/lpformer.py
        <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/
        lpformer.py>`_.

    Args:
        in_channels (int): Size of input dimension
        hidden_channels (int): Size of hidden dimension
        num_gnn_layers (int, optional): Number of GNN layers
            (default: :obj:`2`)
        gnn_dropout (float, optional): Dropout used for GNN
            (default: :obj:`0.1`)
        num_transformer_layers (int, optional): Number of Transformer layers
            (default: :obj:`1`)
        num_heads (int, optional): Number of heads to use in MHA
            (default: :obj:`1`)
        transformer_dropout (float, optional): Dropout used for Transformer
            (default: :obj:`0.1`)
        ppr_thresholds (list, optional): PPR thresholds for different types
            of nodes. Types include (in order) common neighbors, 1-Hop nodes
            (that aren't CNs), and all other nodes.
            (default: :obj:`[0, 1e-4, 1e-2]`)
        gcn_cache (bool, optional): Whether to cache edge indices during
            message passing. (default: :obj:`False`)
    """
    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        num_gnn_layers: int = 2,
        gnn_dropout: float = 0.1,
        num_transformer_layers: int = 1,
        num_heads: int = 1,
        transformer_dropout: float = 0.1,
        ppr_thresholds: Optional[list] = None,
        gcn_cache: bool = False,
    ):
        super().__init__()

        # Default thresholds
        if ppr_thresholds is None:
            ppr_thresholds = [0, 1e-4, 1e-2]

        if len(ppr_thresholds) == 3:
            self.thresh_cn = ppr_thresholds[0]
            self.thresh_1hop = ppr_thresholds[1]
            self.thresh_non1hop = ppr_thresholds[2]
        else:
            raise ValueError(
                "Argument 'ppr_thresholds' must only be length 3!")

        self.in_dim = in_channels
        self.hid_dim = hidden_channels
        self.gnn_drop = gnn_dropout
        self.trans_drop = transformer_dropout

        self.gnn = GCN(in_channels, hidden_channels, num_gnn_layers,
                       dropout=gnn_dropout, norm="layer_norm",
                       cached=gcn_cache)
        self.gnn_norm = nn.LayerNorm(hidden_channels)

        # Create Transformer Layers
        self.att_layers = nn.ModuleList()
        for il in range(num_transformer_layers):
            if il == 0:
                node_dim = None
                self.out_dim = self.hid_dim * 2 if num_transformer_layers > 1 \
                    else self.hid_dim
            elif il == num_transformer_layers - 1:
                node_dim = self.hid_dim
            else:
                self.out_dim = node_dim = self.hid_dim

            self.att_layers.append(
                LPAttLayer(self.hid_dim, self.out_dim, node_dim, num_heads,
                           self.trans_drop))

        self.elementwise_lin = MLP(self.hid_dim, self.hid_dim, self.hid_dim)

        # Relative Positional Encodings
        self.ppr_encoder_cn = MLP(2, self.hid_dim, self.hid_dim)
        self.ppr_encoder_onehop = MLP(2, self.hid_dim, self.hid_dim)
        self.ppr_encoder_non1hop = MLP(2, self.hid_dim, self.hid_dim)

        # thresh=1 implies ignoring some set of nodes
        # Also allows us to be more efficient later
        if self.thresh_non1hop == 1 and self.thresh_1hop == 1:
            self.mask = "cn"
        elif self.thresh_non1hop == 1 and self.thresh_1hop < 1:
            self.mask = "1-hop"
        else:
            self.mask = "all"

        # 4 is for counts of diff nodes
        pairwise_dim = self.hid_dim * num_heads + 4
        self.pairwise_lin = MLP(pairwise_dim, pairwise_dim, self.hid_dim)

        self.score_func = MLP(self.hid_dim * 2, self.hid_dim * 2, 1,
                              norm=None)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_dim}, '
                f'{self.hid_dim}, num_gnn_layers={self.gnn.num_layers}, '
                f'num_transformer_layers={len(self.att_layers)})')
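    # Instantiation sketch (hypothetical sizes, not part of the original
    # source): with the default thresholds, thresh_1hop and thresh_non1hop
    # are both < 1, so the model runs in the "all" masking mode, e.g.
    #
    #     model = LPFormer(in_channels=128, hidden_channels=64,
    #                      num_gnn_layers=2, num_transformer_layers=1)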
    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        self.gnn.reset_parameters()
        self.gnn_norm.reset_parameters()
        self.elementwise_lin.reset_parameters()
        self.pairwise_lin.reset_parameters()
        self.ppr_encoder_cn.reset_parameters()
        self.ppr_encoder_onehop.reset_parameters()
        self.ppr_encoder_non1hop.reset_parameters()
        self.score_func.reset_parameters()
        for att_layer in self.att_layers:
            att_layer.reset_parameters()
    def forward(
        self,
        batch: Tensor,
        x: Tensor,
        edge_index: Adj,
        ppr_matrix: Tensor,
    ) -> Tensor:
        r"""Forward pass of LPFormer. Returns the raw logits for each link.

        Args:
            batch (Tensor): The batch vector. Denotes which node pairs to
                predict.
            x (Tensor): Input node features
            edge_index (torch.Tensor, SparseTensor): The edge indices.
                Either in COO or SparseTensor format
            ppr_matrix (Tensor): PPR matrix
        """
        batch = batch.to(x.device)

        X_node = self.propagate(x, edge_index)
        x_i, x_j = X_node[batch[0]], X_node[batch[1]]
        elementwise_edge_feats = self.elementwise_lin(x_i * x_j)

        # Ensure in sparse format
        # Need as native torch.sparse for later computations
        # (necessary operations are not supported by PyG SparseTensor)
        if not edge_index.is_sparse:
            num_nodes = ppr_matrix.size(1)
            vals = torch.ones(len(edge_index[0]), device=edge_index.device)
            edge_index = torch.sparse_coo_tensor(edge_index, vals,
                                                 [num_nodes, num_nodes])

        # Checks if SparseTensor; if so, convert
        if is_sparse(edge_index) and not edge_index.is_sparse:
            edge_index = edge_index.to_torch_sparse_coo_tensor()

        # Ensure {0, 1}
        edge_index = edge_index.coalesce().bool().int()

        pairwise_feats = self.calc_pairwise(batch, X_node, edge_index,
                                            ppr_matrix)
        combined_feats = torch.cat((elementwise_edge_feats, pairwise_feats),
                                   dim=-1)

        logits = self.score_func(combined_feats)

        return logits
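    # Shape sketch for a single forward call (hypothetical values): with N
    # nodes, E edges and B candidate links,
    #
    #     batch:      LongTensor of shape [2, B] (source/target node pairs)
    #     x:          FloatTensor of shape [N, in_channels]
    #     edge_index: LongTensor of shape [2, E] or a sparse adjacency
    #     ppr_matrix: torch.sparse_coo_tensor of shape [N, N]
    #
    # and, since the scoring MLP squeezes its last dimension, the returned
    # logits have shape [B] (one logit per candidate link).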
    def propagate(self, x: Tensor, adj: Adj) -> Tensor:
        """Propagate via GNN.

        Args:
            x (Tensor): Node features
            adj (torch.Tensor, SparseTensor): Adjacency matrix
        """
        x = F.dropout(x, p=self.gnn_drop, training=self.training)
        X_node = self.gnn(x, adj)
        X_node = self.gnn_norm(X_node)

        return X_node
    def calc_pairwise(self, batch: Tensor, X_node: Tensor, adj_mask: Tensor,
                      ppr_matrix: Tensor) -> Tensor:
        r"""Calculate the pairwise features for the node pairs.

        Args:
            batch (Tensor): The batch vector. Denotes which node pairs to
                predict.
            X_node (Tensor): Node representations
            adj_mask (Tensor): Mask of adjacency matrix used for computing
                the different node types.
            ppr_matrix (Tensor): PPR matrix
        """
        k_i, k_j = X_node[batch[0]], X_node[batch[1]]
        pairwise_feats = torch.cat((k_i, k_j), dim=-1)

        cn_info, onehop_info, non1hop_info = self.compute_node_mask(
            batch, adj_mask, ppr_matrix)

        all_mask = cn_info[0]
        if onehop_info is not None:
            all_mask = torch.cat((all_mask, onehop_info[0]), dim=-1)
        if non1hop_info is not None:
            all_mask = torch.cat((all_mask, non1hop_info[0]), dim=-1)

        pes = self.get_pos_encodings(
            cn_info[1:],
            onehop_info[1:] if onehop_info is not None else None,
            non1hop_info[1:] if non1hop_info is not None else None)

        for lay in range(len(self.att_layers)):
            pairwise_feats = self.att_layers[lay](all_mask, pairwise_feats,
                                                  X_node, pes)

        num_cns, num_1hop, num_non1hop, num_neigh = self.get_structure_cnts(
            batch, cn_info, onehop_info, non1hop_info)

        pairwise_feats = torch.cat(
            (pairwise_feats, num_cns, num_1hop, num_non1hop, num_neigh),
            dim=-1)

        pairwise_feats = self.pairwise_lin(pairwise_feats)

        return pairwise_feats
    def get_pos_encodings(
            self, cn_ppr: Tuple[Tensor, Tensor],
            onehop_ppr: Optional[Tuple[Tensor, Tensor]] = None,
            non1hop_ppr: Optional[Tuple[Tensor, Tensor]] = None) -> Tensor:
        r"""Calculate the PPR-based relative positional encodings.

        Due to thresholds, sometimes we don't have 1-hop or >1-hop nodes.
        In those cases, the value of onehop_ppr and/or non1hop_ppr should
        be :obj:`None`.

        Args:
            cn_ppr (tuple): PPR scores of CNs.
            onehop_ppr (tuple, optional): PPR scores of 1-Hop.
                (default: :obj:`None`)
            non1hop_ppr (tuple, optional): PPR scores of >1-Hop.
                (default: :obj:`None`)
        """
        cn_a = self.ppr_encoder_cn(torch.stack((cn_ppr[0], cn_ppr[1])).t())
        cn_b = self.ppr_encoder_cn(torch.stack((cn_ppr[1], cn_ppr[0])).t())
        cn_pe = cn_a + cn_b

        if onehop_ppr is None:
            return cn_pe

        onehop_a = self.ppr_encoder_onehop(
            torch.stack((onehop_ppr[0], onehop_ppr[1])).t())
        onehop_b = self.ppr_encoder_onehop(
            torch.stack((onehop_ppr[1], onehop_ppr[0])).t())
        onehop_pe = onehop_a + onehop_b

        if non1hop_ppr is None:
            return torch.cat((cn_pe, onehop_pe), dim=0)

        non1hop_a = self.ppr_encoder_non1hop(
            torch.stack((non1hop_ppr[0], non1hop_ppr[1])).t())
        non1hop_b = self.ppr_encoder_non1hop(
            torch.stack((non1hop_ppr[1], non1hop_ppr[0])).t())
        non1hop_pe = non1hop_a + non1hop_b

        return torch.cat((cn_pe, onehop_pe, non1hop_pe), dim=0)
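    # Note that each encoding is symmetric in the two endpoints: for a node
    # v of link (a, b) (illustrative symbols), roughly
    #     pe(v) = MLP([ppr(a, v), ppr(b, v)]) + MLP([ppr(b, v), ppr(a, v)]),
    # so swapping the src and tgt of a link yields the same encoding.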
    def compute_node_mask(
        self, batch: Tensor, adj: Tensor, ppr_matrix: Tensor
    ) -> Tuple[Tuple, Optional[Tuple], Optional[Tuple]]:
        r"""Get the mask based on the type of node.

        When the mask type is not "cn", also return the PPR values for both
        the source and target.

        Args:
            batch (Tensor): The batch vector. Denotes which node pairs to
                predict.
            adj (SparseTensor): Adjacency matrix
            ppr_matrix (Tensor): PPR matrix
        """
        src_adj = torch.index_select(adj, 0, batch[0])
        tgt_adj = torch.index_select(adj, 0, batch[1])

        if self.mask == "cn":
            # 1 when CN, 0 otherwise
            pair_adj = src_adj * tgt_adj
        else:
            # Equals: {0: ">1-Hop", 1: "1-Hop (Non-CN)", 2: "CN"}
            pair_adj = src_adj + tgt_adj

        pair_ix, node_type, src_ppr, tgt_ppr = self.get_ppr_vals(
            batch, pair_adj, ppr_matrix)

        cn_filt_cond = (src_ppr >= self.thresh_cn) & \
                       (tgt_ppr >= self.thresh_cn)
        onehop_filt_cond = (src_ppr >= self.thresh_1hop) & \
                           (tgt_ppr >= self.thresh_1hop)

        if self.mask != "cn":
            filt_cond = torch.where(node_type == 1, onehop_filt_cond,
                                    cn_filt_cond)
        else:
            filt_cond = torch.where(node_type == 0, onehop_filt_cond,
                                    cn_filt_cond)

        pair_ix, node_type = pair_ix[:, filt_cond], node_type[filt_cond]
        src_ppr, tgt_ppr = src_ppr[filt_cond], tgt_ppr[filt_cond]

        # >1-Hop mask is computed separately
        if self.mask == "all":
            non1hop_ix, non1hop_sppr, non1hop_tppr = self.get_non_1hop_ppr(
                batch, adj, ppr_matrix)

        # Dropout
        if self.training and self.trans_drop > 0:
            pair_ix, src_ppr, tgt_ppr, node_type = self.drop_pairwise(
                pair_ix, src_ppr, tgt_ppr, node_type)
            if self.mask == "all":
                non1hop_ix, non1hop_sppr, non1hop_tppr, _ = \
                    self.drop_pairwise(non1hop_ix, non1hop_sppr,
                                       non1hop_tppr)

        # Separate out CN and 1-Hop
        if self.mask != "cn":
            cn_ind = node_type == 2
            cn_ix = pair_ix[:, cn_ind]
            cn_src_ppr = src_ppr[cn_ind]
            cn_tgt_ppr = tgt_ppr[cn_ind]

            one_hop_ind = node_type == 1
            onehop_ix = pair_ix[:, one_hop_ind]
            onehop_src_ppr = src_ppr[one_hop_ind]
            onehop_tgt_ppr = tgt_ppr[one_hop_ind]

        if self.mask == "cn":
            return (pair_ix, src_ppr, tgt_ppr), None, None
        elif self.mask == "1-hop":
            return (cn_ix, cn_src_ppr, cn_tgt_ppr), \
                   (onehop_ix, onehop_src_ppr, onehop_tgt_ppr), None
        else:
            return (cn_ix, cn_src_ppr, cn_tgt_ppr), \
                   (onehop_ix, onehop_src_ppr, onehop_tgt_ppr), \
                   (non1hop_ix, non1hop_sppr, non1hop_tppr)
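    # Worked example of the node-type encoding (hypothetical adjacency
    # rows): if the source row is [1, 1, 0, 0] and the target row is
    # [1, 0, 1, 0], then src_adj + tgt_adj = [2, 1, 1, 0], i.e. node 0 is a
    # common neighbor (2), nodes 1 and 2 are 1-hop-only neighbors (1), and
    # node 3 is a >1-hop node (0); src_adj * tgt_adj keeps only the CNs.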
    def get_ppr_vals(
            self, batch: Tensor, pair_diff_adj: Tensor,
            ppr_matrix: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        r"""Get the src and tgt PPR values.

        Returns: the link the node belongs to, the type of node (e.g., CN),
        the PPR relative to src, and the PPR relative to tgt.

        Args:
            batch (Tensor): The batch vector. Denotes which node pairs to
                predict.
            pair_diff_adj (SparseTensor): Combination of rows in adjacency
                for src and tgt nodes (e.g., X1 + X2)
            ppr_matrix (Tensor): PPR matrix
        """
        # Additional terms for also choosing scores when ppr=0
        # Multiplication removes any values for nodes not in batch
        # Addition then adds offset to ensure we select when ppr=0
        # All selected scores are +1 higher than their true val
        src_ppr_adj = torch.index_select(
            ppr_matrix, 0, batch[0]) * pair_diff_adj + pair_diff_adj
        tgt_ppr_adj = torch.index_select(
            ppr_matrix, 0, batch[1]) * pair_diff_adj + pair_diff_adj

        # Can now convert ppr scores to dense
        ppr_ix = src_ppr_adj.coalesce().indices()
        src_ppr = src_ppr_adj.coalesce().values()
        tgt_ppr = tgt_ppr_adj.coalesce().values()

        # TODO: Needed due to a bug in recent torch versions
        # see here for more - https://github.com/pytorch/pytorch/issues/114529
        # note that if one is 0 so is the other
        zero_vals = (src_ppr != 0)
        src_ppr = src_ppr[zero_vals]
        tgt_ppr = tgt_ppr[zero_vals]
        ppr_ix = ppr_ix[:, zero_vals]

        pair_diff_adj = pair_diff_adj.coalesce().values()
        node_type = pair_diff_adj[zero_vals]

        # Remove additional +1 from each ppr val
        src_ppr = (src_ppr - node_type) / node_type
        tgt_ppr = (tgt_ppr - node_type) / node_type

        return ppr_ix, node_type, src_ppr, tgt_ppr
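    # Numeric sketch of the offset trick above (hypothetical values): for a
    # CN (pair_diff_adj = 2) with a true PPR score of 0.0, the stored value
    # is 0.0 * 2 + 2 = 2, so the entry survives the sparse selection even
    # though its PPR is zero; undoing the offset via (2 - 2) / 2 = 0.0
    # recovers the true score.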
    def drop_pairwise(
        self,
        pair_ix: Tensor,
        src_ppr: Optional[Tensor] = None,
        tgt_ppr: Optional[Tensor] = None,
        node_indicator: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        r"""Perform dropout on pairwise information by randomly dropping a
        percentage of nodes. Done before performing attention for
        efficiency.

        Args:
            pair_ix (Tensor): Link the node belongs to
            src_ppr (Tensor, optional): PPR relative to src
                (default: :obj:`None`)
            tgt_ppr (Tensor, optional): PPR relative to tgt
                (default: :obj:`None`)
            node_indicator (Tensor, optional): Type of node (e.g., CN)
                (default: :obj:`None`)
        """
        num_indices = math.ceil(pair_ix.size(1) * (1 - self.trans_drop))
        indices = torch.randperm(pair_ix.size(1))[:num_indices]

        pair_ix = pair_ix[:, indices]
        if src_ppr is not None:
            src_ppr = src_ppr[indices]
        if tgt_ppr is not None:
            tgt_ppr = tgt_ppr[indices]
        if node_indicator is not None:
            node_indicator = node_indicator[indices]

        return pair_ix, src_ppr, tgt_ppr, node_indicator
    def get_structure_cnts(
        self,
        batch: Tensor,
        cn_info: Tuple[Tensor, Tensor],
        onehop_info: Tuple[Tensor, Tensor],
        non1hop_info: Optional[Tuple[Tensor, Tensor]],
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Counts for CNs, 1-Hop, and >1-Hop nodes that satisfy the PPR
        threshold. Also includes the total # of neighbors.

        Args:
            batch (Tensor): The batch vector. Denotes which node pairs to
                predict.
            cn_info (tuple): Information of CN nodes.
                Contains (ID of node, src ppr, tgt ppr)
            onehop_info (tuple): Information of 1-Hop nodes.
                Contains (ID of node, src ppr, tgt ppr)
            non1hop_info (tuple): Information of >1-Hop nodes.
                Contains (ID of node, src ppr, tgt ppr)
        """
        num_cns = self.get_num_ppr_thresh(batch, cn_info[0], cn_info[1],
                                          cn_info[2], self.thresh_cn)
        num_1hop = self.get_num_ppr_thresh(batch, onehop_info[0],
                                           onehop_info[1], onehop_info[2],
                                           self.thresh_1hop)

        # Total # of 1-hop neighbors (union of src and tgt neighborhoods)
        num_ppr_ones = self.get_num_ppr_thresh(batch, onehop_info[0],
                                               onehop_info[1],
                                               onehop_info[2], thresh=0)
        num_neighbors = num_cns + num_ppr_ones

        # Process for >1-hop is different, which is why we use get_count
        if non1hop_info is None:
            return num_cns, num_1hop, 0, num_neighbors
        else:
            num_non1hop = self.get_count(non1hop_info[0], batch)
            return num_cns, num_1hop, num_non1hop, num_neighbors
    def get_num_ppr_thresh(self, batch: Tensor, node_mask: Tensor,
                           src_ppr: Tensor, tgt_ppr: Tensor,
                           thresh: float) -> Tensor:
        """Get the # of nodes `v` where `ppr(a, v) >= eta` and
        `ppr(b, v) >= eta`.

        Args:
            batch (Tensor): The batch vector. Denotes which node pairs to
                predict.
            node_mask (Tensor): IDs of nodes
            src_ppr (Tensor): PPR relative to src node
            tgt_ppr (Tensor): PPR relative to tgt node
            thresh (float): PPR threshold for nodes (`eta`)
        """
        weight = torch.ones(node_mask.size(1), device=node_mask.device)

        ppr_above_thresh = (src_ppr >= thresh) & (tgt_ppr >= thresh)
        num_ppr = scatter(ppr_above_thresh.float() * weight,
                          node_mask[0].long(), dim=0,
                          dim_size=batch.size(1), reduce="sum")
        num_ppr = num_ppr.unsqueeze(-1)

        return num_ppr
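    # Counting sketch (hypothetical values): with node_mask[0] = [0, 0, 1]
    # and ppr_above_thresh = [True, False, True] for a batch of two links,
    # the scatter-sum yields counts [1, 1], returned as a column vector of
    # shape [batch.size(1), 1].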
    def get_count(
        self,
        node_mask: Tensor,
        batch: Tensor,
    ) -> Tensor:
        """# of nodes for each sample in the batch. The nodes have already
        been filtered by PPR beforehand.

        Args:
            node_mask (Tensor): IDs of nodes
            batch (Tensor): The batch vector. Denotes which node pairs to
                predict.
        """
        weight = torch.ones(node_mask.size(1), device=node_mask.device)

        num_nodes = scatter(weight, node_mask[0].long(), dim=0,
                            dim_size=batch.size(1), reduce="sum")
        num_nodes = num_nodes.unsqueeze(-1)

        return num_nodes
    def get_non_1hop_ppr(self, batch: Tensor, adj: Tensor,
                         ppr_matrix: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        r"""Get PPR scores for non-1-hop nodes.

        Args:
            batch (Tensor): Links in batch
            adj (Tensor): Adjacency matrix
            ppr_matrix (Tensor): Sparse PPR matrix
        """
        # NOTE: Use the original adj (one pass in forward() removes links in
        # the batch). Done since removing them converts src/tgt nodes to
        # >1-hop nodes. Therefore removing CN and 1-hop will also remove the
        # batch links.

        # During training we add back in the links in the batch
        # (they were removed from the adjacency before being passed to the
        # model). Done since otherwise they would mistakenly be seen as
        # >1-Hop nodes. Instead they're 1-Hop, and get ignored accordingly.
        # Ignored during eval since we know the links aren't in the adj.
        adj2 = adj
        if self.training:
            n = adj.size(0)
            batch_flip = torch.cat(
                (batch, torch.flip(batch, (0, )).to(batch.device)), dim=-1)
            batch_ones = torch.ones_like(batch_flip[0], device=batch.device)
            adj_edges = torch.sparse_coo_tensor(batch_flip, batch_ones,
                                                [n, n], device=batch.device)
            adj2 = (adj + adj_edges).coalesce().bool().int()

        src_adj = torch.index_select(adj2, 0, batch[0])
        tgt_adj = torch.index_select(adj2, 0, batch[1])

        src_ppr = torch.index_select(ppr_matrix, 0, batch[0])
        tgt_ppr = torch.index_select(ppr_matrix, 0, batch[1])

        # Remove CN scores
        src_ppr = src_ppr - src_ppr * (src_adj * tgt_adj)
        tgt_ppr = tgt_ppr - tgt_ppr * (src_adj * tgt_adj)

        # Also need to remove CN entries in Adj
        # Otherwise they leak into next computation
        src_adj = src_adj - src_adj * (src_adj * tgt_adj)
        tgt_adj = tgt_adj - tgt_adj * (src_adj * tgt_adj)

        # Remove 1-Hop scores
        src_ppr = src_ppr - src_ppr * (src_adj + tgt_adj)
        tgt_ppr = tgt_ppr - tgt_ppr * (src_adj + tgt_adj)

        # Make sure we include both when we convert to dense so indices
        # align. Do so by adding 1 to each based on the other.
        src_ppr_add = src_ppr + torch.sign(tgt_ppr)
        tgt_ppr_add = tgt_ppr + torch.sign(src_ppr)
        src_ix = src_ppr_add.coalesce().indices()
        src_vals = src_ppr_add.coalesce().values()
        tgt_vals = tgt_ppr_add.coalesce().values()

        # Now we can remove the added value, which is just 1.
        # Technically creates -1 scores for ppr scores that were 0.
        # Doesn't matter as they'll be filtered out by the condition below.
        src_vals = src_vals - 1
        tgt_vals = tgt_vals - 1

        ppr_condition = (src_vals >= self.thresh_non1hop) & \
                        (tgt_vals >= self.thresh_non1hop)
        src_ix = src_ix[:, ppr_condition]
        src_vals = src_vals[ppr_condition]
        tgt_vals = tgt_vals[ppr_condition]

        return src_ix, src_vals, tgt_vals
    def calc_sparse_ppr(self, edge_index: Tensor, num_nodes: int,
                        alpha: float = 0.15, eps: float = 5e-5) -> Tensor:
        r"""Calculate the PPR of the graph in sparse format.

        Args:
            edge_index (Tensor): The edge indices
            num_nodes (int): Number of nodes
            alpha (float, optional): The alpha value of the PageRank
                algorithm. (default: :obj:`0.15`)
            eps (float, optional): Threshold for stopping the PPR
                calculation. (default: :obj:`5e-5`)
        """
        ei, ei_w = get_ppr(edge_index.cpu(), alpha=alpha, eps=eps,
                           num_nodes=num_nodes)
        ppr_matrix = torch.sparse_coo_tensor(ei, ei_w,
                                             [num_nodes, num_nodes])

        return ppr_matrix
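    # Usage sketch (hypothetical names): the PPR matrix is typically
    # precomputed once and reused across batches, e.g.
    #
    #     ppr = model.calc_sparse_ppr(data.edge_index, data.num_nodes)
    #     logits = model(link_batch, data.x, data.edge_index, ppr)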
class LPAttLayer(MessagePassing):
    r"""Attention Layer for the pairwise interaction module.

    Args:
        in_channels (int): Size of input dimension
        out_channels (int): Size of output dimension
        node_dim (int): Dimension of nodes being aggregated
        num_heads (int): Number of heads to use in MHA
        dropout (float): Dropout on attention values
        concat (bool, optional): Whether to concat attention heads.
            Otherwise sum. (default: :obj:`True`)
    """
    _alpha: OptTensor

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        node_dim: int,
        num_heads: int,
        dropout: float,
        concat: bool = True,
        **kwargs,
    ):
        super().__init__(node_dim=0, flow="target_to_source", **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = num_heads
        self.concat = concat
        self.dropout = dropout
        self.negative_slope = 0.2  # LeakyReLU

        out_dim = 2
        if node_dim is None:
            node_dim = in_channels * out_dim
        else:
            node_dim = node_dim * out_dim

        self.lin_l = Linear(in_channels, self.heads * out_channels,
                            weight_initializer='glorot')
        self.lin_r = Linear(node_dim, self.heads * out_channels,
                            weight_initializer='glorot')

        att_out = out_channels
        self.att = Parameter(Tensor(1, self.heads, att_out))

        if concat:
            self.bias = Parameter(Tensor(self.heads * out_channels))
        else:
            self.bias = Parameter(Tensor(out_channels))

        self._alpha = None
        self.post_att_norm = nn.LayerNorm(out_channels)

        self.reset_parameters()

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, heads={self.heads})')

    def reset_parameters(self):
        self.lin_l.reset_parameters()
        self.lin_r.reset_parameters()
        self.post_att_norm.reset_parameters()
        glorot(self.att)
        zeros(self.bias)

    def forward(
        self,
        edge_index: Tensor,
        edge_feats: Tensor,
        node_feats: Tensor,
        ppr_rpes: Tensor,
    ) -> Tensor:
        """Runs the forward pass of the module.

        Args:
            edge_index (Tensor): The edge indices.
            edge_feats (Tensor): Concatenated representations of src and
                target nodes for each link
            node_feats (Tensor): Representations for individual nodes
            ppr_rpes (Tensor): Relative PEs for each node
        """
        out = self.propagate(edge_index, x=(edge_feats, node_feats),
                             ppr_rpes=ppr_rpes, size=None)

        alpha = self._alpha
        assert alpha is not None
        self._alpha = None

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if self.bias is not None:
            out = out + self.bias

        out = self.post_att_norm(out)
        out = F.dropout(out, p=self.dropout, training=self.training)

        return out

    def message(self, x_i: Tensor, x_j: Tensor, ppr_rpes: Tensor,
                index: Tensor, ptr: Tensor,
                size_i: Optional[int]) -> Tensor:
        H, C = self.heads, self.out_channels

        x_j = torch.cat((x_j, ppr_rpes), dim=-1)
        x_j = self.lin_r(x_j).view(-1, H, C)

        # e=(a, b) attending to v
        e1, e2 = x_i.chunk(2, dim=-1)
        e1 = self.lin_l(e1).view(-1, H, C)
        e2 = self.lin_l(e2).view(-1, H, C)

        x = x_j * (e1 + e2)
        x = F.leaky_relu(x, self.negative_slope)
        alpha = (x * self.att).sum(dim=-1)
        alpha = softmax(alpha, index, ptr, size_i)
        self._alpha = alpha

        return x_j * alpha.unsqueeze(-1)


class MLP(nn.Module):
    r"""An L-layer MLP."""
    def __init__(self, in_channels: int, hid_channels: int,
                 out_channels: int, num_layers: int = 2, drop: float = 0.,
                 norm: str = "layer"):
        super().__init__()
        self.dropout = drop

        if norm == "batch":
            self.norm = nn.BatchNorm1d(hid_channels)
        elif norm == "layer":
            self.norm = nn.LayerNorm(hid_channels)
        else:
            self.norm = None

        self.linears = torch.nn.ModuleList()
        if num_layers == 1:
            self.linears.append(nn.Linear(in_channels, out_channels))
        else:
            self.linears.append(nn.Linear(in_channels, hid_channels))
            for _ in range(num_layers - 2):
                self.linears.append(nn.Linear(hid_channels, hid_channels))
            self.linears.append(nn.Linear(hid_channels, out_channels))

    def reset_parameters(self):
        for lin in self.linears:
            lin.reset_parameters()
        if self.norm is not None:
            self.norm.reset_parameters()

    def forward(self, x: Tensor) -> Tensor:
        for lin in self.linears[:-1]:
            x = lin(x)
            x = self.norm(x) if self.norm is not None else x
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.linears[-1](x)
        return x.squeeze(-1)
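

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module: a small
    # random graph with hypothetical sizes, intended only to illustrate the
    # expected tensor shapes of a forward pass.
    torch.manual_seed(0)

    num_nodes, in_channels, hidden_channels = 100, 32, 64
    edge_index = torch.randint(0, num_nodes, (2, 500))
    # Make the graph undirected so src/tgt neighborhoods are symmetric.
    edge_index = torch.cat((edge_index, edge_index.flip(0)), dim=-1)
    x = torch.randn(num_nodes, in_channels)
    link_batch = torch.randint(0, num_nodes, (2, 16))  # candidate links

    model = LPFormer(in_channels, hidden_channels)
    ppr_matrix = model.calc_sparse_ppr(edge_index, num_nodes)

    model.eval()
    with torch.no_grad():
        logits = model(link_batch, x, edge_index, ppr_matrix)
    print(logits.shape)  # expected: torch.Size([16])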