torch_geometric.nn.models.SGFormer

class SGFormer(in_channels: int, hidden_channels: int, out_channels: int, trans_num_layers: int = 2, trans_num_heads: int = 1, trans_dropout: float = 0.5, gnn_num_layers: int = 3, gnn_dropout: float = 0.5, graph_weight: float = 0.5, aggregate: str = 'add')[source]

Bases: Module

The SGFormer module from the “SGFormer: Simplifying and Empowering Transformers for Large-Graph Representations” paper. A usage sketch is shown after the parameter list below.

Parameters:
  • in_channels (int) – Input channels.

  • hidden_channels (int) – Hidden channels.

  • out_channels (int) – Output channels.

  • trans_num_layers (int) – The number of layers for all-pair attention. (default: 2)

  • trans_num_heads (int) – The number of heads for attention. (default: 1)

  • trans_dropout (float) – Dropout rate for the all-pair attention module. (default: 0.5)

  • gnn_num_layers (int) – The number of GNN layers. (default: 3)

  • gnn_dropout (float) – Dropout rate for the GNN module. (default: 0.5)

  • graph_weight (float) – The weight used to balance the GNN module against the global attention module. (default: 0.5)

  • aggregate (str) – The aggregation scheme used to combine the outputs of the global attention and GNN modules. (default: "add")
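
A minimal usage sketch (the import path follows this page's title; the hidden size, number of output classes, and the toy graph are made-up values for illustration):

    import torch
    from torch_geometric.nn.models import SGFormer

    # Toy graph: 4 nodes with 16-dimensional features and three undirected edges.
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                               [1, 0, 2, 1, 3, 2]])
    batch = torch.zeros(4, dtype=torch.long)  # all nodes belong to example 0

    model = SGFormer(
        in_channels=16,
        hidden_channels=64,   # illustrative hidden size
        out_channels=7,       # e.g. number of classes
        trans_num_layers=2,   # all-pair attention layers
        gnn_num_layers=3,     # GNN layers
        graph_weight=0.5,     # balance between GNN and global attention
        aggregate='add',
    )

    out = model(x, edge_index, batch)  # node-level output of shape [4, 7]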

forward(x: Tensor, edge_index: Tensor, batch: Optional[Tensor]) → Tensor[source]

Forward pass. A mini-batch usage sketch is shown after the return type below.

Parameters:
  • x (torch.Tensor) – The input node features.

  • edge_index (torch.Tensor or SparseTensor) – The edge indices.

  • batch (torch.Tensor, optional) – The batch vector \(\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N\), which assigns each element to a specific example.

Return type:

Tensor
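
As a sketch of how the batch vector is typically supplied when several graphs are packed into one mini-batch (Data and Batch are standard torch_geometric.data utilities; the graph sizes and channel counts are illustrative):

    import torch
    from torch_geometric.data import Batch, Data
    from torch_geometric.nn.models import SGFormer

    model = SGFormer(in_channels=16, hidden_channels=64, out_channels=7)

    # Two small graphs packed into one mini-batch. `data.batch` assigns each
    # node to its example (0 or 1), matching the `batch` argument of forward().
    g1 = Data(x=torch.randn(3, 16), edge_index=torch.tensor([[0, 1], [1, 2]]))
    g2 = Data(x=torch.randn(2, 16), edge_index=torch.tensor([[0], [1]]))
    data = Batch.from_data_list([g1, g2])

    out = model(data.x, data.edge_index, data.batch)  # shape: [5, 7]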

reset_parameters() → None[source]

Resets all learnable parameters of the module.

Return type:

None