Source code for torch_geometric.nn.models.sgformer
from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor

from torch_geometric.nn.attention import SGFormerAttention
from torch_geometric.nn.conv import GCNConv
from torch_geometric.utils import to_dense_batch


class GraphModule(torch.nn.Module):
    def __init__(
        self,
        in_channels,
        hidden_channels,
        num_layers=2,
        dropout=0.5,
    ):
        super().__init__()

        self.convs = torch.nn.ModuleList()
        self.fcs = torch.nn.ModuleList()
        self.fcs.append(torch.nn.Linear(in_channels, hidden_channels))

        self.bns = torch.nn.ModuleList()
        self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        for _ in range(num_layers):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
            self.bns.append(torch.nn.BatchNorm1d(hidden_channels))

        self.dropout = dropout
        self.activation = F.relu

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()
        for fc in self.fcs:
            fc.reset_parameters()

    def forward(self, x, edge_index):
        x = self.fcs[0](x)
        x = self.bns[0](x)
        x = self.activation(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        last_x = x

        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            x = self.bns[i + 1](x)
            x = self.activation(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = x + last_x
        return x


class SGModule(torch.nn.Module):
    def __init__(
        self,
        in_channels,
        hidden_channels,
        num_layers=2,
        num_heads=1,
        dropout=0.5,
    ):
        super().__init__()

        self.attns = torch.nn.ModuleList()
        self.fcs = torch.nn.ModuleList()
        self.fcs.append(torch.nn.Linear(in_channels, hidden_channels))
        self.bns = torch.nn.ModuleList()
        self.bns.append(torch.nn.LayerNorm(hidden_channels))
        for _ in range(num_layers):
            self.attns.append(
                SGFormerAttention(hidden_channels, num_heads, hidden_channels))
            self.bns.append(torch.nn.LayerNorm(hidden_channels))

        self.dropout = dropout
        self.activation = F.relu

    def reset_parameters(self):
        for attn in self.attns:
            attn.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()
        for fc in self.fcs:
            fc.reset_parameters()

    def forward(self, x: Tensor, batch: Tensor):
        # `to_dense_batch` expects a sorted batch vector:
        batch, indices = batch.sort(stable=True)
        rev_perm = torch.empty_like(indices)
        rev_perm[indices] = torch.arange(len(indices), device=indices.device)
        x = x[indices]

        x, mask = to_dense_batch(x, batch)
        layer_ = []

        # Input MLP layer:
        x = self.fcs[0](x)
        x = self.bns[0](x)
        x = self.activation(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # Store as residual link:
        layer_.append(x)

        for i, attn in enumerate(self.attns):
            x = attn(x, mask)
            x = (x + layer_[i]) / 2.
            x = self.bns[i + 1](x)
            x = self.activation(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            layer_.append(x)

        x_mask = x[mask]
        # Reverse the sorting:
        unsorted_x_mask = x_mask[rev_perm]
        return unsorted_x_mask
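

# NOTE: The sort/unsort round-trip in `SGModule.forward` exists only because
# `to_dense_batch` requires a sorted batch vector. The helper below is a
# minimal, self-contained sketch of that trick; the function name and the toy
# tensors are illustrative only and are not part of the original module.
def _sorted_batch_round_trip_example() -> None:
    batch = torch.tensor([1, 0, 1, 0])  # unsorted graph assignment per node
    x = torch.arange(4, dtype=torch.float).view(-1, 1)  # one feature per node

    # Sort the batch vector and remember how to undo the permutation:
    batch_sorted, perm = batch.sort(stable=True)
    inv_perm = torch.empty_like(perm)
    inv_perm[perm] = torch.arange(len(perm))

    # Densify to [num_graphs, max_nodes_per_graph, channels] plus a mask that
    # marks the valid (non-padded) entries:
    dense, mask = to_dense_batch(x[perm], batch_sorted)

    # Flatten the valid entries and restore the original node order:
    restored = dense[mask][inv_perm]
    assert torch.equal(restored, x)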


class SGFormer(torch.nn.Module):
    r"""The SGFormer module from the `"SGFormer: Simplifying and Empowering
    Transformers for Large-Graph Representations"
    <https://arxiv.org/abs/2306.10759>`_ paper.

    Args:
        in_channels (int): Input channels.
        hidden_channels (int): Hidden channels.
        out_channels (int): Output channels.
        trans_num_layers (int): The number of layers for all-pair attention.
            (default: :obj:`2`)
        trans_num_heads (int): The number of attention heads.
            (default: :obj:`1`)
        trans_dropout (float): Dropout rate of the all-pair attention module.
            (default: :obj:`0.5`)
        gnn_num_layers (int): The number of GNN layers.
            (default: :obj:`3`)
        gnn_dropout (float): Dropout rate of the GNN module.
            (default: :obj:`0.5`)
        graph_weight (float): The weight for balancing the GNN module against
            the global attention module. (default: :obj:`0.5`)
        aggregate (str): The aggregation scheme for combining the two branches
            (:obj:`'add'` or :obj:`'cat'`). (default: :obj:`'add'`)
    """
    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        trans_num_layers: int = 2,
        trans_num_heads: int = 1,
        trans_dropout: float = 0.5,
        gnn_num_layers: int = 3,
        gnn_dropout: float = 0.5,
        graph_weight: float = 0.5,
        aggregate: str = 'add',
    ):
        super().__init__()
        self.trans_conv = SGModule(
            in_channels,
            hidden_channels,
            trans_num_layers,
            trans_num_heads,
            trans_dropout,
        )
        self.graph_conv = GraphModule(
            in_channels,
            hidden_channels,
            gnn_num_layers,
            gnn_dropout,
        )
        self.graph_weight = graph_weight

        self.aggregate = aggregate
        if aggregate == 'add':
            self.fc = torch.nn.Linear(hidden_channels, out_channels)
        elif aggregate == 'cat':
            self.fc = torch.nn.Linear(2 * hidden_channels, out_channels)
        else:
            raise ValueError(f'Invalid aggregate type: {aggregate}')

        self.params1 = list(self.trans_conv.parameters())
        self.params2 = list(self.graph_conv.parameters())
        self.params2.extend(list(self.fc.parameters()))

        self.out_channels = out_channels

    def reset_parameters(self) -> None:
        self.trans_conv.reset_parameters()
        self.graph_conv.reset_parameters()
        self.fc.reset_parameters()

    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        batch: Optional[Tensor],
    ) -> Tensor:
        r"""Forward pass.

        Args:
            x (torch.Tensor): The input node features.
            edge_index (torch.Tensor or SparseTensor): The edge indices.
            batch (torch.Tensor, optional): The batch vector
                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                each node to a specific example.
        """
        x1 = self.trans_conv(x, batch)
        x2 = self.graph_conv(x, edge_index)
        if self.aggregate == 'add':
            x = self.graph_weight * x2 + (1 - self.graph_weight) * x1
        else:
            x = torch.cat((x1, x2), dim=1)
        x = self.fc(x)
        return F.log_softmax(x, dim=-1)
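

# A minimal usage sketch (not part of the original file): the toy graph,
# feature sizes and hyper-parameters below are made up purely for
# illustration.
if __name__ == '__main__':
    num_nodes, num_features, num_classes = 6, 16, 3
    x = torch.randn(num_nodes, num_features)
    edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
                               [1, 2, 3, 4, 5, 0]])  # a 6-node directed ring
    batch = torch.zeros(num_nodes, dtype=torch.long)  # all nodes in one graph

    model = SGFormer(in_channels=num_features, hidden_channels=32,
                     out_channels=num_classes)
    out = model(x, edge_index, batch)
    print(out.shape)  # torch.Size([6, 3]) -- per-node class log-probabilities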