Source code for torch_geometric.nn.models.polynormer

from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor

from torch_geometric.nn import GATConv, GCNConv
from torch_geometric.nn.attention import PolynormerAttention
from torch_geometric.utils import to_dense_batch


class Polynormer(torch.nn.Module):
    r"""The Polynormer model from the `"Polynormer: Polynomial-Expressive
    Graph Transformer in Linear Time" <https://arxiv.org/abs/2403.01232>`_
    paper.

    Args:
        in_channels (int): Input channels.
        hidden_channels (int): Hidden channels.
        out_channels (int): Output channels.
        local_layers (int): The number of local attention layers.
            (default: :obj:`7`)
        global_layers (int): The number of global attention layers.
            (default: :obj:`2`)
        in_dropout (float): Input dropout rate. (default: :obj:`0.15`)
        dropout (float): Dropout rate. (default: :obj:`0.5`)
        global_dropout (float): Global dropout rate. (default: :obj:`0.5`)
        heads (int): The number of attention heads. (default: :obj:`1`)
        beta (float): The coefficient that balances the gated features
            against the residual in each layer. (default: :obj:`0.9`)
        qk_shared (bool, optional): Whether the weights of query and key are
            shared. (default: :obj:`False`)
        pre_ln (bool): Whether to apply pre layer normalization.
            (default: :obj:`False`)
        post_bn (bool): Whether to apply post batch normalization.
            (default: :obj:`True`)
        local_attn (bool): Whether to use local attention.
            (default: :obj:`False`)
    """
    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        local_layers: int = 7,
        global_layers: int = 2,
        in_dropout: float = 0.15,
        dropout: float = 0.5,
        global_dropout: float = 0.5,
        heads: int = 1,
        beta: float = 0.9,
        qk_shared: bool = False,
        pre_ln: bool = False,
        post_bn: bool = True,
        local_attn: bool = False,
    ) -> None:
        super().__init__()

        self._global = False
        self.in_drop = in_dropout
        self.dropout = dropout
        self.pre_ln = pre_ln
        self.post_bn = post_bn
        self.beta = beta

        self.h_lins = torch.nn.ModuleList()
        self.local_convs = torch.nn.ModuleList()
        self.lins = torch.nn.ModuleList()
        self.lns = torch.nn.ModuleList()
        if self.pre_ln:
            self.pre_lns = torch.nn.ModuleList()
        if self.post_bn:
            self.post_bns = torch.nn.ModuleList()

        # First layer:
        inner_channels = heads * hidden_channels
        self.h_lins.append(torch.nn.Linear(in_channels, inner_channels))
        if local_attn:
            self.local_convs.append(
                GATConv(in_channels, hidden_channels, heads=heads,
                        concat=True, add_self_loops=False, bias=False))
        else:
            self.local_convs.append(
                GCNConv(in_channels, inner_channels, cached=False,
                        normalize=True))

        self.lins.append(torch.nn.Linear(in_channels, inner_channels))
        self.lns.append(torch.nn.LayerNorm(inner_channels))
        if self.pre_ln:
            self.pre_lns.append(torch.nn.LayerNorm(in_channels))
        if self.post_bn:
            self.post_bns.append(torch.nn.BatchNorm1d(inner_channels))

        # Following layers:
        for _ in range(local_layers - 1):
            self.h_lins.append(
                torch.nn.Linear(inner_channels, inner_channels))
            if local_attn:
                self.local_convs.append(
                    GATConv(inner_channels, hidden_channels, heads=heads,
                            concat=True, add_self_loops=False, bias=False))
            else:
                self.local_convs.append(
                    GCNConv(inner_channels, inner_channels, cached=False,
                            normalize=True))

            self.lins.append(
                torch.nn.Linear(inner_channels, inner_channels))
            self.lns.append(torch.nn.LayerNorm(inner_channels))
            if self.pre_ln:
                self.pre_lns.append(
                    torch.nn.LayerNorm(heads * hidden_channels))
            if self.post_bn:
                self.post_bns.append(torch.nn.BatchNorm1d(inner_channels))

        self.lin_in = torch.nn.Linear(in_channels, inner_channels)
        self.ln = torch.nn.LayerNorm(inner_channels)

        self.global_attn = torch.nn.ModuleList()
        for _ in range(global_layers):
            self.global_attn.append(
                PolynormerAttention(
                    channels=hidden_channels,
                    heads=heads,
                    head_channels=hidden_channels,
                    beta=beta,
                    dropout=global_dropout,
                    qk_shared=qk_shared,
                ))

        self.pred_local = torch.nn.Linear(inner_channels, out_channels)
        self.pred_global = torch.nn.Linear(inner_channels, out_channels)

        self.reset_parameters()
    def reset_parameters(self) -> None:
        for local_conv in self.local_convs:
            local_conv.reset_parameters()
        for attn in self.global_attn:
            attn.reset_parameters()
        for lin in self.lins:
            lin.reset_parameters()
        for h_lin in self.h_lins:
            h_lin.reset_parameters()
        for ln in self.lns:
            ln.reset_parameters()
        if self.pre_ln:
            for p_ln in self.pre_lns:
                p_ln.reset_parameters()
        if self.post_bn:
            for p_bn in self.post_bns:
                p_bn.reset_parameters()
        self.lin_in.reset_parameters()
        self.ln.reset_parameters()
        self.pred_local.reset_parameters()
        self.pred_global.reset_parameters()
    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        batch: Optional[Tensor],
    ) -> Tensor:
        r"""Forward pass.

        Args:
            x (torch.Tensor): The input node features.
            edge_index (torch.Tensor or SparseTensor): The edge indices.
            batch (torch.Tensor, optional): The batch vector
                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                each element to a specific example.
        """
        x = F.dropout(x, p=self.in_drop, training=self.training)

        # Equivariant local attention:
        x_local = 0
        for i, local_conv in enumerate(self.local_convs):
            if self.pre_ln:
                x = self.pre_lns[i](x)
            h = self.h_lins[i](x)
            h = F.relu(h)
            x = local_conv(x, edge_index) + self.lins[i](x)
            if self.post_bn:
                x = self.post_bns[i](x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = (1 - self.beta) * self.lns[i](h * x) + self.beta * x
            x_local = x_local + x

        # Equivariant global attention:
        if self._global:
            batch, indices = batch.sort()
            rev_perm = torch.empty_like(indices)
            rev_perm[indices] = torch.arange(len(indices),
                                             device=indices.device)
            x_local = self.ln(x_local[indices])
            x_global, mask = to_dense_batch(x_local, batch)
            for attn in self.global_attn:
                x_global = attn(x_global, mask)
            x = x_global[mask][rev_perm]
            x = self.pred_global(x)
        else:
            x = self.pred_local(x_local)

        return F.log_softmax(x, dim=-1)
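A minimal usage sketch of the module above, not part of the library source: it assumes the class is importable as :class:`torch_geometric.nn.models.Polynormer` (matching this page's module path), builds a single random graph whose nodes all belong to batch :obj:`0`, and toggles the private :obj:`_global` flag between calls to reach both branches of :meth:`forward`, mirroring how that attribute is read in the code above.

import torch

from torch_geometric.nn.models import Polynormer

# Small model for illustration; all sizes here are arbitrary choices.
model = Polynormer(in_channels=16, hidden_channels=32, out_channels=4,
                   local_layers=2, global_layers=1)

x = torch.randn(10, 16)                     # 10 nodes, 16 features each
edge_index = torch.randint(0, 10, (2, 40))  # random edges, illustration only
batch = torch.zeros(10, dtype=torch.long)   # all nodes in one example

out_local = model(x, edge_index, batch)     # local-attention branch
model._global = True                        # assumption: flip the private flag
out_global = model(x, edge_index, batch)    # global-attention branch
print(out_local.shape, out_global.shape)    # torch.Size([10, 4]) for both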