Memory-Efficient Aggregations
=============================

The :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` interface of :pyg:`PyG` relies on a gather-scatter scheme to aggregate messages from neighboring nodes.
For example, consider the message passing layer

.. math::

    \mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} \textrm{MLP}(\mathbf{x}_j - \mathbf{x}_i),

which can be implemented as:

.. code-block:: python

    from torch_geometric.nn import MessagePassing

    x = ...           # Node features of shape [num_nodes, num_features]
    edge_index = ...  # Edge indices of shape [2, num_edges]

    class MyConv(MessagePassing):
        def __init__(self):
            super().__init__(aggr="add")

        def forward(self, x, edge_index):
            return self.propagate(edge_index, x=x)

        def message(self, x_i, x_j):
            return MLP(x_j - x_i)

Under the hood, the :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` implementation produces code that looks as follows:

.. code-block:: python

    from torch_geometric.utils import scatter

    x = ...           # Node features of shape [num_nodes, num_features]
    edge_index = ...  # Edge indices of shape [2, num_edges]

    x_j = x[edge_index[0]]  # Source node features [num_edges, num_features]
    x_i = x[edge_index[1]]  # Target node features [num_edges, num_features]

    msg = MLP(x_j - x_i)  # Compute message for each edge

    # Aggregate messages based on target node indices
    out = scatter(msg, edge_index[1], dim=0, dim_size=x.size(0), reduce='sum')

While the gather-scatter formulation generalizes to a lot of useful GNN implementations, it has the disadvantage of explicitly materializing :obj:`x_j` and :obj:`x_i`, resulting in a high memory footprint on large and dense graphs.

Luckily, not all GNNs need to be implemented by explicitly materializing :obj:`x_j` and/or :obj:`x_i`.
In some cases, GNNs can also be implemented as a simple sparse-matrix multiplication.
As a general rule of thumb, this holds true for GNNs that do not make use of the central node features :obj:`x_i` or multi-dimensional edge features when computing messages.
For example, the :class:`~torch_geometric.nn.conv.GINConv` layer

.. math::

    \mathbf{x}^{\prime}_i = \textrm{MLP} \left( (1 + \epsilon) \cdot \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right),

is equivalent to computing

.. math::

    \mathbf{X}^{\prime} = \textrm{MLP} \left( (1 + \epsilon) \cdot \mathbf{X} + \mathbf{A}\mathbf{X} \right),

where :math:`\mathbf{A}` denotes a sparse adjacency matrix of shape :obj:`[num_nodes, num_nodes]`.
This formulation allows us to leverage dedicated and fast sparse-matrix multiplication implementations.

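To make this equivalence concrete, here is a minimal sketch on a hypothetical toy graph that checks that the scatter-based neighborhood sum coincides with the matrix product :math:`\mathbf{A}\mathbf{X}` (with :math:`\mathbf{A}` built densely here purely for the comparison):

.. code-block:: python

    import torch
    from torch_geometric.utils import scatter

    num_nodes, num_features = 4, 8
    x = torch.randn(num_nodes, num_features)
    edge_index = torch.tensor([[0, 1, 2, 3],   # source nodes j
                               [1, 0, 1, 2]])  # target nodes i

    # Gather-scatter: gather source features, then sum them per target node.
    out1 = scatter(x[edge_index[0]], edge_index[1], dim=0,
                   dim_size=num_nodes, reduce='sum')

    # Matrix form: A[i, j] = 1 if there is an edge j -> i, so (A @ X)[i]
    # sums the features of all neighbors j of node i.
    A = torch.zeros(num_nodes, num_nodes)
    A[edge_index[1], edge_index[0]] = 1.0
    out2 = A @ x

    assert torch.allclose(out1, out2)
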
In :pyg:`null` **PyG >= 1.6.0**, we officially introduce better support for sparse-matrix multiplication GNNs, resulting in a **lower memory footprint** and a **faster execution time**.
As a result, we introduce the :class:`SparseTensor` class (from the :obj:`torch_sparse` package), which implements fast forward and backward passes for sparse-matrix multiplication based on the `"Design Principles for Sparse Matrix Multiplication on the GPU" <https://arxiv.org/abs/1803.08601>`_ paper.

Using the :class:`SparseTensor` class is straightforward and similar to the way :obj:`scipy` treats sparse matrices:

.. code-block:: python

    from torch_sparse import SparseTensor

    adj = SparseTensor(row=edge_index[0], col=edge_index[1], value=...,
                       sparse_sizes=(num_nodes, num_nodes))
    # value is optional and can be None

    # Obtain different representations (COO, CSR, CSC):
    row,    col, value = adj.coo()
    rowptr, col, value = adj.csr()
    colptr, row, value = adj.csc()

    adj = adj[:100, :100]  # Slicing, indexing and masking support
    adj = adj.set_diag()   # Add diagonal entries
    adj_t = adj.t()        # Transpose
    out = adj.matmul(x)    # Sparse-dense matrix multiplication
    adj = adj.matmul(adj)  # Sparse-sparse matrix multiplication

    # Creating SparseTensor instances:
    adj = SparseTensor.from_dense(mat)
    adj = SparseTensor.eye(100, 100)
    adj = SparseTensor.from_scipy(mat)

Our :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` interface can handle both :obj:`torch.Tensor` and :class:`SparseTensor` as input for propagating messages.
However, when holding a directed graph in a :class:`SparseTensor`, you need to make sure to input the **transposed sparse matrix** to :func:`~torch_geometric.nn.conv.message_passing.MessagePassing.propagate`:

.. code-block:: python

    conv = GCNConv(16, 32)
    out1 = conv(x, edge_index)
    out2 = conv(x, adj.t())
    assert torch.allclose(out1, out2)

    conv = GINConv(nn=Sequential(Linear(16, 32), ReLU(), Linear(32, 32)))
    out1 = conv(x, edge_index)
    out2 = conv(x, adj.t())
    assert torch.allclose(out1, out2)

To leverage sparse-matrix multiplications, the :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` interface introduces the :func:`~torch_geometric.nn.conv.message_passing.message_and_aggregate` function (which fuses the :func:`~torch_geometric.nn.conv.message_passing.message` and :func:`~torch_geometric.nn.conv.message_passing.aggregate` functions into a single computation step), which gets called whenever it is implemented and receives a :class:`SparseTensor` as input for :obj:`edge_index`.
With it, the :class:`~torch_geometric.nn.conv.GINConv` layer can now be implemented as follows:

.. code-block:: python

    import torch_sparse

    class GINConv(MessagePassing):
        def __init__(self):
            super().__init__(aggr="add")

        def forward(self, x, edge_index):
            out = self.propagate(edge_index, x=x)
            return MLP((1 + eps) * x + out)

        def message(self, x_j):
            return x_j

        def message_and_aggregate(self, adj_t, x):
            return torch_sparse.matmul(adj_t, x, reduce=self.aggr)

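As a quick sanity check, the following minimal sketch feeds the layer above both input formats on a small toy graph; it assumes a concrete two-layer network and :obj:`eps = 0.0` as hypothetical stand-ins for the :obj:`MLP` and :obj:`eps` placeholders used above.
The :obj:`edge_index` input goes through the gather-scatter path via :obj:`message`, while the transposed :class:`SparseTensor` input triggers the fused :obj:`message_and_aggregate` path, and both should yield the same result:

.. code-block:: python

    import torch
    from torch.nn import Linear, ReLU, Sequential
    from torch_sparse import SparseTensor

    # Hypothetical stand-ins for the placeholders used in the layer above:
    MLP = Sequential(Linear(16, 32), ReLU(), Linear(32, 32))
    eps = 0.0

    conv = GINConv()  # the custom layer defined above

    x = torch.randn(10, 16)
    edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                               [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]])
    adj = SparseTensor(row=edge_index[0], col=edge_index[1],
                       sparse_sizes=(10, 10))

    out1 = conv(x, edge_index)  # gather-scatter path (message + aggregate)
    out2 = conv(x, adj.t())     # fused path (message_and_aggregate)
    assert torch.allclose(out1, out2, atol=1e-6)
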
Playing around with the new :class:`SparseTensor` format is straightforward since all of our GNNs work with it out-of-the-box.
To convert the :obj:`edge_index` format to the newly introduced :class:`SparseTensor` format, you can make use of the :class:`torch_geometric.transforms.ToSparseTensor` transform:

.. code-block:: python

    import torch
    import torch.nn.functional as F
    from torch_geometric.nn import GCNConv
    import torch_geometric.transforms as T
    from torch_geometric.datasets import Planetoid

    dataset = Planetoid("Planetoid", name="Cora", transform=T.ToSparseTensor())
    data = dataset[0]
    >>> Data(adj_t=[2708, 2708, nnz=10556], x=[2708, 1433], y=[2708], ...)

    class GNN(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = GCNConv(dataset.num_features, 16, cached=True)
            self.conv2 = GCNConv(16, dataset.num_classes, cached=True)

        def forward(self, x, adj_t):
            x = self.conv1(x, adj_t)
            x = F.relu(x)
            x = self.conv2(x, adj_t)
            return F.log_softmax(x, dim=1)

    model = GNN()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    def train(data):
        model.train()
        optimizer.zero_grad()
        out = model(data.x, data.adj_t)
        loss = F.nll_loss(out, data.y)
        loss.backward()
        optimizer.step()
        return float(loss)

    for epoch in range(1, 201):
        loss = train(data)

All code remains the same as before, except for the :obj:`data` transform via :obj:`T.ToSparseTensor()`.
As an additional advantage, :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` implementations that utilize the :class:`SparseTensor` class are deterministic on the GPU since aggregations no longer rely on atomic operations.

Notably, the GNN layer execution changes slightly in case GNNs incorporate single- or multi-dimensional edge information (:obj:`edge_weight` or :obj:`edge_attr`, respectively) into their message passing formulation.
In particular, it is now expected that these attributes are directly added as values to the :class:`SparseTensor` object.
Instead of calling the GNN as

.. code-block:: python

    conv = GMMConv(16, 32, dim=3)
    out = conv(x, edge_index, edge_attr)

we now execute our GNN operator as

.. code-block:: python

    conv = GMMConv(16, 32, dim=3)
    adj = SparseTensor(row=edge_index[0], col=edge_index[1], value=edge_attr)
    out = conv(x, adj.t())

.. note::

    Since this feature is still experimental, some operations, *e.g.*, graph pooling methods, may still require you to input the :obj:`edge_index` format.
    You can convert :obj:`adj_t` back to :obj:`(edge_index, edge_attr)` via:

    .. code-block:: python

        row, col, edge_attr = adj_t.t().coo()
        edge_index = torch.stack([row, col], dim=0)

Please let us know what you think of :class:`SparseTensor`, how we can improve it, and about any unexpected behavior you encounter.