Source code for torch_geometric.nn.attention.performer

import math
from typing import Callable, Optional

import torch
from torch import Tensor


def _orthogonal_matrix(dim: int) -> Tensor:
    r"""Get an orthogonal matrix by applying QR decomposition."""
    # Random matrix from normal distribution
    mat = torch.randn((dim, dim))
    # QR decomposition to two orthogonal matrices
    q, _ = torch.linalg.qr(mat.cpu(), mode='reduced')
    return q.t()


def orthogonal_matrix(num_rows: int, num_cols: int) -> Tensor:
    r"""Generate an orthogonal matrix with `num_rows` rows
    and `num_cols` columns.
    """
    num_full_blocks = int(num_rows / num_cols)
    blocks = []
    for _ in range(num_full_blocks):
        q = _orthogonal_matrix(num_cols)
        blocks.append(q)
    remain_rows = num_rows - num_full_blocks * num_cols
    if remain_rows > 0:
        q = _orthogonal_matrix(num_cols)
        blocks.append(q[:remain_rows])
    mat = torch.cat(blocks)
    # multiplier = torch.randn((num_rows, num_cols)).norm(dim=1)
    # scaler = torch.diag(multiplier)
    # mat = scaler @ mat
    return mat
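

# A minimal sanity check (illustrative sketch, not part of the original
# module): each stacked block of `orthogonal_matrix` comes from an independent
# QR decomposition, so the rows within a block are orthonormal. The helper
# name `_example_orthogonal_matrix` and the sizes below are assumptions chosen
# for demonstration only.
def _example_orthogonal_matrix(num_rows: int = 10, num_cols: int = 4) -> Tensor:
    mat = orthogonal_matrix(num_rows, num_cols)
    assert mat.size() == (num_rows, num_cols)
    # The first `num_cols` rows form an orthonormal set:
    block = mat[:num_cols]
    assert torch.allclose(block @ block.t(), torch.eye(num_cols), atol=1e-5)
    return mat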


def linear_attention(q: Tensor, k: Tensor, v: Tensor) -> Tensor:
    r"""Efficient attention mechanism from the
    `"Rethinking Attention with Performers"
    <https://arxiv.org/abs/2009.14794>`_ paper.

    .. math::
        \mathbf{\hat{D}}^{-1}(\mathbf{Q}'((\mathbf{K}')^{\top} \mathbf{V}))

    """
    D_inv = 1.0 / (q @ k.sum(dim=-2).unsqueeze(-1))
    kv = k.transpose(-2, -1) @ v
    qkv = q @ kv
    out = torch.einsum('...L,...Ld->...Ld', D_inv.squeeze(-1), qkv)
    return out
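

# Illustrative sketch (not part of the original module): `linear_attention`
# expects queries and keys that were already passed through a non-negative
# feature map (e.g., `generalized_kernel`), and computes attention in linear
# time by first contracting keys with values. The helper name and the shapes
# (batch of 2, 4 heads, 16 tokens, 8 channels) are assumptions for
# demonstration only.
def _example_linear_attention() -> Tensor:
    q = torch.rand(2, 4, 16, 8)   # non-negative feature-mapped queries
    k = torch.rand(2, 4, 16, 8)   # non-negative feature-mapped keys
    v = torch.randn(2, 4, 16, 8)  # values
    out = linear_attention(q, k, v)
    assert out.size() == (2, 4, 16, 8)
    return out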


def generalized_kernel(
        x: Tensor,
        mat: Tensor,
        kernel: Callable = torch.nn.ReLU(),
        epsilon: float = 0.001,
) -> Tensor:
    r"""Project :obj:`x` with the random projection matrix :obj:`mat` and
    apply :obj:`kernel`, adding a small :obj:`epsilon` for numerical
    stability.
    """
    batch_size, num_heads = x.size()[:2]
    projection = mat.t().expand(batch_size, num_heads, -1, -1)
    x = x @ projection
    out = kernel(x) + epsilon
    return out
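

# Illustrative sketch (not part of the original module): `generalized_kernel`
# maps head features of size `num_cols` onto `num_rows` random features and
# keeps them strictly positive via the ReLU kernel plus `epsilon`. The helper
# name and the shapes below are assumptions for demonstration only.
def _example_generalized_kernel() -> Tensor:
    x = torch.randn(2, 4, 16, 8)  # (batch, heads, tokens, head_channels)
    mat = orthogonal_matrix(num_rows=24, num_cols=8)
    phi = generalized_kernel(x, mat)  # ReLU kernel by default
    assert phi.size() == (2, 4, 16, 24)
    assert bool((phi > 0).all())
    return phi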


class PerformerProjection(torch.nn.Module):
    r"""The fast attention that uses a projection matrix
    from the `"Rethinking Attention with Performers"
    <https://arxiv.org/abs/2009.14794>`_ paper. This class
    projects the :math:`\mathbf{Q}` and :math:`\mathbf{K}` matrices
    with a specified kernel.

    Args:
        num_cols (int): The number of columns of the projection matrix.
        kernel (Callable, optional): The kernel function for generalized
            attention. If not specified, the ReLU kernel will be used.
            (default: :obj:`torch.nn.ReLU()`)
    """
    def __init__(self, num_cols: int, kernel: Callable = torch.nn.ReLU()):
        super().__init__()
        num_rows = int(num_cols * math.log(num_cols))
        self.num_rows = num_rows
        self.num_cols = num_cols
        # Generate an orthogonal projection matrix
        # with the shape (num_rows, num_cols)
        projection_matrix = orthogonal_matrix(self.num_rows, self.num_cols)
        self.register_buffer('projection_matrix', projection_matrix)
        assert kernel is not None
        self.kernel = kernel

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        q = generalized_kernel(q, self.projection_matrix, self.kernel)
        k = generalized_kernel(k, self.projection_matrix, self.kernel)
        out = linear_attention(q, k, v)
        return out
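

# Illustrative sketch (not part of the original module): `PerformerProjection`
# replaces softmax attention over 64-dimensional heads with kernelized linear
# attention. The helper name and the shapes below are assumptions for
# demonstration only.
def _example_performer_projection() -> Tensor:
    proj = PerformerProjection(num_cols=64)
    q = torch.randn(2, 4, 16, 64)  # (batch, heads, tokens, head_channels)
    k = torch.randn(2, 4, 16, 64)
    v = torch.randn(2, 4, 16, 64)
    out = proj(q, k, v)
    assert out.size() == (2, 4, 16, 64)
    return out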


class PerformerAttention(torch.nn.Module):
    r"""The linear scaled attention mechanism from the
    `"Rethinking Attention with Performers"
    <https://arxiv.org/abs/2009.14794>`_ paper.

    Args:
        channels (int): Size of each input sample.
        heads (int): Number of parallel attention heads.
        head_channels (int, optional): Size of each attention head.
            (default: :obj:`64`)
        kernel (Callable, optional): The kernel function for generalized
            attention. If not specified, the ReLU kernel will be used.
            (default: :obj:`torch.nn.ReLU()`)
        qkv_bias (bool, optional): If specified, adds a bias to the query,
            key and value projections in the self-attention.
            (default: :obj:`False`)
        attn_out_bias (bool, optional): If specified, adds a bias to the
            attention output. (default: :obj:`True`)
        dropout (float, optional): Dropout probability of the final
            attention output. (default: :obj:`0.0`)
    """
    def __init__(
        self,
        channels: int,
        heads: int,
        head_channels: int = 64,
        kernel: Callable = torch.nn.ReLU(),
        qkv_bias: bool = False,
        attn_out_bias: bool = True,
        dropout: float = 0.0,
    ):
        super().__init__()
        assert channels % heads == 0
        if head_channels is None:
            head_channels = channels // heads

        self.heads = heads
        self.head_channels = head_channels
        self.kernel = kernel
        self.fast_attn = PerformerProjection(head_channels, kernel)

        inner_channels = head_channels * heads
        self.q = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
        self.k = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
        self.v = torch.nn.Linear(channels, inner_channels, bias=qkv_bias)
        self.attn_out = torch.nn.Linear(inner_channels, channels,
                                        bias=attn_out_bias)
        self.dropout = torch.nn.Dropout(dropout)
    def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor:
        r"""Forward pass.

        Args:
            x (torch.Tensor): Node feature tensor
                :math:`\mathbf{X} \in \mathbb{R}^{B \times N \times F}`, with
                batch-size :math:`B`, (maximum) number of nodes :math:`N` for
                each graph, and feature dimension :math:`F`.
            mask (torch.Tensor, optional): Mask matrix
                :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}` indicating
                the valid nodes for each graph. (default: :obj:`None`)
        """
        B, N, *_ = x.shape
        q, k, v = self.q(x), self.k(x), self.v(x)

        # Reshape and permute q, k and v from
        # (B, N, heads * head_channels) to (B, heads, N, head_channels):
        q, k, v = map(
            lambda t: t.reshape(B, N, self.heads, self.head_channels).permute(
                0, 2, 1, 3), (q, k, v))

        if mask is not None:
            # Zero out the values of padded (invalid) nodes:
            mask = mask[:, None, :, None]
            v.masked_fill_(~mask, 0.)

        out = self.fast_attn(q, k, v)
        out = out.permute(0, 2, 1, 3).reshape(B, N, -1)
        out = self.attn_out(out)
        out = self.dropout(out)
        return out
    @torch.no_grad()
    def redraw_projection_matrix(self):
        r"""As described in the paper, periodically redraws the random
        projection matrix to improve the overall approximation of attention.
        """
        num_rows = self.fast_attn.num_rows
        num_cols = self.fast_attn.num_cols
        projection_matrix = orthogonal_matrix(num_rows, num_cols)
        self.fast_attn.projection_matrix.copy_(projection_matrix)
        del projection_matrix
    def _reset_parameters(self):
        self.q.reset_parameters()
        self.k.reset_parameters()
        self.v.reset_parameters()
        self.attn_out.reset_parameters()
        self.redraw_projection_matrix()

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}('
                f'heads={self.heads}, '
                f'head_channels={self.head_channels}, '
                f'kernel={self.kernel})')
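

# Illustrative end-to-end sketch (not part of the original module): running
# `PerformerAttention` on a dense node-feature tensor with a padding mask and
# re-sampling the projection matrix afterwards. The helper name and the sizes
# below are assumptions for demonstration only.
def _example_performer_attention() -> Tensor:
    attn = PerformerAttention(channels=64, heads=4)
    x = torch.randn(2, 10, 64)                 # (batch, nodes, channels)
    mask = torch.ones(2, 10, dtype=torch.bool)
    mask[1, 7:] = False                        # second graph has 7 valid nodes
    out = attn(x, mask)
    assert out.size() == (2, 10, 64)
    attn.redraw_projection_matrix()            # re-sample the random features
    return out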