Source code for torch_geometric.nn.attention.sgformer

from typing import Optional

import torch
from torch import Tensor


class SGFormerAttention(torch.nn.Module):
    r"""The simple global attention mechanism from the
    `"SGFormer: Simplifying and Empowering Transformers for Large-Graph
    Representations" <https://arxiv.org/abs/2306.10759>`_ paper.

    Args:
        channels (int): Size of each input sample.
        heads (int, optional): Number of parallel attention heads.
            (default: :obj:`1`)
        head_channels (int, optional): Size of each attention head.
            (default: :obj:`64`)
        qkv_bias (bool, optional): If set to :obj:`True`, adds a bias term to
            the query, key and value projections of the self attention.
            (default: :obj:`False`)
    """
    def __init__(
        self,
        channels: int,
        heads: int = 1,
        head_channels: int = 64,
        qkv_bias: bool = False,
    ) -> None:
        super().__init__()
        assert channels % heads == 0
        if head_channels is None:
            head_channels = channels // heads

        self.heads = heads
        self.head_channels = head_channels

        inner_channels = head_channels * heads
        self.q = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
        self.k = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
        self.v = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        r"""Forward pass.

        Args:
            x (torch.Tensor): Node feature tensor
                :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
                batch-size :math:`B`, (maximum) number of nodes :math:`N` for
                each graph, and feature dimension :math:`F`.
            mask (torch.Tensor, optional): Mask matrix
                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
                the valid nodes for each graph. (default: :obj:`None`)
        """
        B, N, *_ = x.shape
        qs, ks, vs = self.q(x), self.k(x), self.v(x)

        # Reshape q, k and v from [B, N, heads * head_channels]
        # to [B, N, heads, head_channels]:
        qs, ks, vs = map(
            lambda t: t.reshape(B, N, self.heads, self.head_channels),
            (qs, ks, vs))

        # Zero out the values of padded (invalid) nodes:
        if mask is not None:
            mask = mask[:, :, None, None]
            vs.masked_fill_(~mask, 0.)

        # Replace zero entries with a small epsilon to avoid zero-norm
        # vectors (and thus division by zero) during normalization:
        epsilon = 1e-6
        qs[qs == 0] = epsilon
        ks[ks == 0] = epsilon

        # L2-normalize queries and keys (shape is unchanged):
        qs, ks = map(
            lambda t: t / torch.linalg.norm(t, ord=2, dim=-1, keepdim=True),
            (qs, ks))

        # Numerator:
        kvs = torch.einsum("blhm,blhd->bhmd", ks, vs)
        attention_num = torch.einsum("bnhm,bhmd->bnhd", qs, kvs)
        attention_num += N * vs

        # Denominator:
        all_ones = torch.ones([B, N]).to(ks.device)
        ks_sum = torch.einsum("blhm,bl->bhm", ks, all_ones)
        attention_normalizer = torch.einsum("bnhm,bhm->bnh", qs, ks_sum)

        # Attentive aggregated results:
        attention_normalizer = torch.unsqueeze(
            attention_normalizer, len(attention_normalizer.shape))
        attention_normalizer += torch.ones_like(attention_normalizer) * N
        attn_output = attention_num / attention_normalizer

        # Average over the head dimension:
        return attn_output.mean(dim=2)
    def reset_parameters(self):
        self.q.reset_parameters()
        self.k.reset_parameters()
        self.v.reset_parameters()
    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}('
                f'heads={self.heads}, '
                f'head_channels={self.head_channels})')
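
A minimal usage sketch (not part of the library source), assuming
:class:`SGFormerAttention` is importable from
:mod:`torch_geometric.nn.attention`. The tensor shapes follow the
:meth:`forward` docstring above: a batch of :math:`B` graphs, each padded to
:math:`N` nodes with :math:`F` input features. The output has
``head_channels`` features per node, since the heads are averaged at the end
of :meth:`forward`.

# Hypothetical example, not part of the module above.
import torch

from torch_geometric.nn.attention import SGFormerAttention

attn = SGFormerAttention(channels=32, heads=2, head_channels=16)

B, N, F = 4, 10, 32  # batch size, (maximum) number of nodes, feature dim
x = torch.randn(B, N, F)

# Mark the first 8 nodes of each graph as valid, the rest as padding:
mask = torch.zeros(B, N, dtype=torch.bool)
mask[:, :8] = True

out = attn(x, mask)
print(out.shape)  # torch.Size([4, 10, 16]), i.e. [B, N, head_channels]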