torch_geometric.nn.attention.PerformerAttention

class PerformerAttention(channels: int, heads: int, head_channels: int = 64, kernel: Callable = ReLU(), qkv_bias: bool = False, attn_out_bias: bool = True, dropout: float = 0.0)[source]

Bases: Module

The linear scaled attention mechanism from the “Rethinking Attention with Performers” paper.

Parameters:
  • channels (int) – Size of each input sample.

  • heads (int) – Number of parallel attention heads.

  • head_channels (int, optional) – Size of each attention head. (default: 64)

  • kernel (Callable, optional) – Kernel function for generalized attention. If not specified, the ReLU kernel is used. (default: torch.nn.ReLU())

  • qkv_bias (bool, optional) – If set to True, adds a bias term to the query, key and value projections. (default: False)

  • attn_out_bias (bool, optional) – If set to True, adds a bias term to the attention output projection. (default: True)

  • dropout (float, optional) – Dropout probability of the final attention output. (default: 0.0)
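
Example (a minimal usage sketch based on the signature above; the batch size, node count and hyper-parameter values are illustrative assumptions, not recommended settings):

    import torch
    from torch_geometric.nn.attention import PerformerAttention

    # Batched node features: B = 2 graphs, N = 10 nodes each, F = 64 channels.
    x = torch.randn(2, 10, 64)

    # Linear-attention module with 4 heads of 16 channels each.
    attn = PerformerAttention(channels=64, heads=4, head_channels=16)

    out = attn(x)
    print(out.shape)  # expected: torch.Size([2, 10, 64])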

forward(x: Tensor, mask: Optional[Tensor] = None) Tensor[source]

Forward pass.

Parameters:
  • x (torch.Tensor) – Node feature tensor \(\mathbf{X} \in \mathbb{R}^{B \times N \times F}\), with batch-size \(B\), (maximum) number of nodes \(N\) for each graph, and feature dimension \(F\).

  • mask (torch.Tensor, optional) – Mask matrix \(\mathbf{M} \in {\{ 0, 1 \}}^{B \times N}\) indicating the valid nodes for each graph. (default: None)

Return type:

Tensor
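
A sketch of a masked forward pass; the boolean mask layout below is an assumption consistent with the \(\{0, 1\}^{B \times N}\) description above:

    import torch
    from torch_geometric.nn.attention import PerformerAttention

    attn = PerformerAttention(channels=32, heads=2)

    # Two graphs padded to N = 5 nodes; the second graph has only 3 valid nodes.
    x = torch.randn(2, 5, 32)
    mask = torch.tensor([
        [True, True, True, True, True],
        [True, True, True, False, False],
    ])

    out = attn(x, mask=mask)  # padded (invalid) nodes are masked out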

redraw_projection_matrix()[source]

As described in the paper, periodically redraw the random projection matrix to improve the overall approximation of attention.
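
For instance, the redraw can be triggered on a fixed step interval during training; the interval and the toy loop below are purely illustrative:

    import torch
    from torch_geometric.nn.attention import PerformerAttention

    attn = PerformerAttention(channels=32, heads=2)
    redraw_interval = 100  # arbitrary choice for this sketch

    for step in range(500):
        x = torch.randn(2, 5, 32)  # stand-in for a real mini-batch
        out = attn(x)
        # ... loss computation and optimizer step would go here ...
        if step % redraw_interval == 0:
            attn.redraw_projection_matrix()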