Creating Message Passing Networks
=================================

Generalizing the convolution operator to irregular domains is typically expressed as a *neighborhood aggregation* or *message passing* scheme.
With :math:`\mathbf{x}^{(k-1)}_i \in \mathbb{R}^F` denoting node features of node :math:`i` in layer :math:`(k-1)` and :math:`\mathbf{e}_{j,i} \in \mathbb{R}^D` denoting (optional) edge features from node :math:`j` to node :math:`i`, message passing graph neural networks can be described as

.. math::

    \mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \bigoplus_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)}, \mathbf{e}_{j,i}\right) \right),

where :math:`\bigoplus` denotes a differentiable, permutation invariant function, *e.g.*, sum, mean or max, and :math:`\gamma` and :math:`\phi` denote differentiable functions such as MLPs (Multi Layer Perceptrons).

.. contents::
    :local:

The "MessagePassing" Base Class
-------------------------------

:pyg:`PyG` provides the :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` base class, which helps in creating such kinds of message passing graph neural networks by automatically taking care of message propagation.
The user only has to define the functions :math:`\phi`, *i.e.* :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.message`, and :math:`\gamma`, *i.e.* :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.update`, as well as the aggregation scheme to use, *i.e.* :obj:`aggr="add"`, :obj:`aggr="mean"` or :obj:`aggr="max"`.

This is done with the help of the following methods:

* :obj:`MessagePassing(aggr="add", flow="source_to_target", node_dim=-2)`: Defines the aggregation scheme to use (:obj:`"add"`, :obj:`"mean"` or :obj:`"max"`) and the flow direction of message passing (either :obj:`"source_to_target"` or :obj:`"target_to_source"`).
  Furthermore, the :obj:`node_dim` attribute indicates along which axis to propagate.
* :obj:`MessagePassing.propagate(edge_index, size=None, **kwargs)`: The initial call to start propagating messages.
  Takes in the edge indices and all additional data which is needed to construct messages and to update node embeddings.
  Note that :func:`~torch_geometric.nn.conv.message_passing.MessagePassing.propagate` is not limited to exchanging messages in square adjacency matrices of shape :obj:`[N, N]` only, but can also exchange messages in general sparse assignment matrices, *e.g.*, bipartite graphs, of shape :obj:`[N, M]` by passing :obj:`size=(N, M)` as an additional argument.
  If set to :obj:`None`, the assignment matrix is assumed to be a square matrix.
  For bipartite graphs with two independent sets of nodes and indices, and each set holding its own information, this split can be marked by passing the information as a tuple, *e.g.* :obj:`x=(x_N, x_M)` (a short sketch of the bipartite case follows this list).
* :obj:`MessagePassing.message(...)`: Constructs messages to node :math:`i` in analogy to :math:`\phi` for each edge :math:`(j,i) \in \mathcal{E}` if :obj:`flow="source_to_target"` and :math:`(i,j) \in \mathcal{E}` if :obj:`flow="target_to_source"`.
  Can take any argument which was initially passed to :meth:`propagate`.
  In addition, tensors passed to :meth:`propagate` can be mapped to the respective nodes :math:`i` and :math:`j` by appending :obj:`_i` or :obj:`_j` to the variable name, *e.g.* :obj:`x_i` and :obj:`x_j`.
  Note that we generally refer to :math:`i` as the central node that aggregates information, and to :math:`j` as a neighboring node, since this is the most common notation.
* :obj:`MessagePassing.update(aggr_out, ...)`: Updates node embeddings in analogy to :math:`\gamma` for each node :math:`i \in \mathcal{V}`.
  Takes in the output of aggregation as first argument and any argument which was initially passed to :func:`~torch_geometric.nn.conv.message_passing.MessagePassing.propagate`.
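As referenced above, messages can also be exchanged on bipartite graphs with two node sets of different sizes.
The following is a minimal, purely illustrative sketch (the class name is hypothetical and not part of :pyg:`PyG`) of a layer that simply sums source node features into the target nodes:

.. code-block:: python

    from torch_geometric.nn import MessagePassing

    class BipartiteSumConv(MessagePassing):
        # Hypothetical example layer: aggregates features from a source node set
        # into a (differently sized) target node set of a bipartite graph.
        def __init__(self):
            super().__init__(aggr='add', flow='source_to_target')

        def forward(self, x_src, x_dst, edge_index):
            # x_src has shape [N, F], x_dst has shape [M, F].
            # edge_index has shape [2, E]; edge_index[0] holds indices into
            # x_src (source nodes), edge_index[1] indices into x_dst (targets).
            return self.propagate(edge_index, x=(x_src, x_dst),
                                  size=(x_src.size(0), x_dst.size(0)))

        def message(self, x_j):
            # x_j holds the lifted source node features of shape [E, F].
            return x_j

Here, :obj:`x_j` in :meth:`message` automatically refers to the first entry of the tuple (the source node set), while :obj:`x_i` would refer to the second one, and the returned tensor has shape :obj:`[M, F]`.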
Let us verify this by re-implementing two popular GNN variants, the `GCN layer from Kipf and Welling <https://arxiv.org/abs/1609.02907>`_ and the `EdgeConv layer from Wang et al. <https://arxiv.org/abs/1801.07829>`_.

Implementing the GCN Layer
--------------------------

The `GCN layer <https://arxiv.org/abs/1609.02907>`_ is mathematically defined as

.. math::

    \mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{W}^{\top} \cdot \mathbf{x}_j^{(k-1)} \right) + \mathbf{b},

where neighboring node features are first transformed by a weight matrix :math:`\mathbf{W}`, normalized by their degree, and finally summed up.
Lastly, we apply the bias vector :math:`\mathbf{b}` to the aggregated output.
This formula can be divided into the following steps:

1. Add self-loops to the adjacency matrix.
2. Linearly transform node feature matrix.
3. Compute normalization coefficients.
4. Normalize node features in :math:`\phi`.
5. Sum up neighboring node features (:obj:`"add"` aggregation).
6. Apply a final bias vector.

Steps 1-3 are typically computed before message passing takes place.
Steps 4-5 can be easily processed using the :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` base class.
The full layer implementation is shown below:

.. code-block:: python

    import torch
    from torch.nn import Linear, Parameter
    from torch_geometric.nn import MessagePassing
    from torch_geometric.utils import add_self_loops, degree

    class GCNConv(MessagePassing):
        def __init__(self, in_channels, out_channels):
            super().__init__(aggr='add')  # "Add" aggregation (Step 5).
            self.lin = Linear(in_channels, out_channels, bias=False)
            self.bias = Parameter(torch.empty(out_channels))
            self.reset_parameters()

        def reset_parameters(self):
            self.lin.reset_parameters()
            self.bias.data.zero_()

        def forward(self, x, edge_index):
            # x has shape [N, in_channels]
            # edge_index has shape [2, E]

            # Step 1: Add self-loops to the adjacency matrix.
            edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

            # Step 2: Linearly transform node feature matrix.
            x = self.lin(x)

            # Step 3: Compute normalization.
            row, col = edge_index
            deg = degree(col, x.size(0), dtype=x.dtype)
            deg_inv_sqrt = deg.pow(-0.5)
            deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
            norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

            # Step 4-5: Start propagating messages.
            out = self.propagate(edge_index, x=x, norm=norm)

            # Step 6: Apply a final bias vector.
            out = out + self.bias

            return out

        def message(self, x_j, norm):
            # x_j has shape [E, out_channels]

            # Step 4: Normalize node features.
            return norm.view(-1, 1) * x_j

:class:`~torch_geometric.nn.conv.GCNConv` inherits from :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` with :obj:`"add"` propagation.
All the logic of the layer takes place in its :meth:`forward` method.
Here, we first add self-loops to our edge indices using the :meth:`torch_geometric.utils.add_self_loops` function (step 1), as well as linearly transform node features by calling the :class:`torch.nn.Linear` instance (step 2).

The normalization coefficients are derived from the degree :math:`\deg(i)` of each node :math:`i`, which gets transformed to :math:`1/(\sqrt{\deg(i)} \cdot \sqrt{\deg(j)})` for each edge :math:`(j,i) \in \mathcal{E}`.
The result is saved in the tensor :obj:`norm` of shape :obj:`[num_edges, ]` (step 3).

We then call :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.propagate`, which internally calls :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.message`, :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.aggregate` and :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.update`.
We pass the node embeddings :obj:`x` and the normalization coefficients :obj:`norm` as additional arguments for message propagation.

In the :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.message` function, we need to normalize the neighboring node features :obj:`x_j` by :obj:`norm`.
Here, :obj:`x_j` denotes a *lifted* tensor, which contains the source node features of each edge, *i.e.*, the neighbors of each node.
Node features can be automatically lifted by appending :obj:`_i` or :obj:`_j` to the variable name.
In fact, any tensor can be converted this way, as long as it holds source or destination node features.
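To make this lifting explicit: with the default :obj:`flow="source_to_target"`, the first row of :obj:`edge_index` holds the source nodes :math:`j` and the second row the target nodes :math:`i`, so :obj:`x_j` is conceptually just :obj:`x` indexed by the source node of each edge.
A rough, purely illustrative equivalent of what happens inside :meth:`propagate` is:

.. code-block:: python

    # Illustrative only: what the automatic lifting inside propagate() boils down to.
    # Here, `x` and `edge_index` refer to the (already transformed) tensors in forward().
    row, col = edge_index   # source nodes j and target nodes i, each of shape [E]
    x_j = x[row]            # lifted source node features, shape [E, out_channels]
    x_i = x[col]            # lifted target node features, shape [E, out_channels]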
That is all that it takes to create a simple message passing layer.
You can use this layer as a building block for deep architectures.
Initializing and calling it is straightforward:

.. code-block:: python

    conv = GCNConv(16, 32)
    x = conv(x, edge_index)
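As a sketch of how such a layer can serve as a building block for a deeper architecture, the following hypothetical two-layer model stacks the custom :class:`GCNConv` defined above with a non-linearity in between (layer sizes are arbitrary):

.. code-block:: python

    import torch
    import torch.nn.functional as F

    class TwoLayerGCN(torch.nn.Module):
        # Hypothetical model built from the custom GCNConv defined above.
        def __init__(self, in_channels, hidden_channels, out_channels):
            super().__init__()
            self.conv1 = GCNConv(in_channels, hidden_channels)
            self.conv2 = GCNConv(hidden_channels, out_channels)

        def forward(self, x, edge_index):
            x = self.conv1(x, edge_index)
            x = F.relu(x)
            x = self.conv2(x, edge_index)
            return x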
Implementing the Edge Convolution
---------------------------------

The `edge convolutional layer <https://arxiv.org/abs/1801.07829>`_ processes graphs or point clouds and is mathematically defined as

.. math::

    \mathbf{x}_i^{(k)} = \max_{j \in \mathcal{N}(i)} h_{\mathbf{\Theta}} \left( \mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)} - \mathbf{x}_i^{(k-1)} \right),

where :math:`h_{\mathbf{\Theta}}` denotes an MLP.
In analogy to the GCN layer, we can use the :class:`~torch_geometric.nn.conv.message_passing.MessagePassing` class to implement this layer, this time using the :obj:`"max"` aggregation:

.. code-block:: python

    import torch
    from torch.nn import Sequential as Seq, Linear, ReLU
    from torch_geometric.nn import MessagePassing

    class EdgeConv(MessagePassing):
        def __init__(self, in_channels, out_channels):
            super().__init__(aggr='max')  # "Max" aggregation.
            self.mlp = Seq(Linear(2 * in_channels, out_channels),
                           ReLU(),
                           Linear(out_channels, out_channels))

        def forward(self, x, edge_index):
            # x has shape [N, in_channels]
            # edge_index has shape [2, E]

            return self.propagate(edge_index, x=x)

        def message(self, x_i, x_j):
            # x_i has shape [E, in_channels]
            # x_j has shape [E, in_channels]

            tmp = torch.cat([x_i, x_j - x_i], dim=1)  # tmp has shape [E, 2 * in_channels]
            return self.mlp(tmp)

Inside the :meth:`~torch_geometric.nn.conv.message_passing.MessagePassing.message` function, we use :obj:`self.mlp` to transform both the target node features :obj:`x_i` and the relative source node features :obj:`x_j - x_i` for each edge :math:`(j,i) \in \mathcal{E}`.

The edge convolution is actually a dynamic convolution, which recomputes the graph for each layer using nearest neighbors in the feature space.
Luckily, :pyg:`PyG` comes with a GPU accelerated batch-wise k-NN graph generation method named :meth:`torch_geometric.nn.pool.knn_graph`:

.. code-block:: python

    from torch_geometric.nn import knn_graph

    class DynamicEdgeConv(EdgeConv):
        def __init__(self, in_channels, out_channels, k=6):
            super().__init__(in_channels, out_channels)
            self.k = k

        def forward(self, x, batch=None):
            edge_index = knn_graph(x, self.k, batch, loop=False, flow=self.flow)
            return super().forward(x, edge_index)

Here, :meth:`~torch_geometric.nn.pool.knn_graph` computes a nearest neighbor graph, which is further used to call the :meth:`forward` method of :class:`~torch_geometric.nn.conv.EdgeConv`.

This leaves us with a clean interface for initializing and calling this layer:

.. code-block:: python

    conv = DynamicEdgeConv(3, 128, k=6)
    x = conv(x, batch)

Exercises
---------

Imagine we are given the following :class:`~torch_geometric.data.Data` object:

.. code-block:: python

    import torch
    from torch_geometric.data import Data

    edge_index = torch.tensor([[0, 1],
                               [1, 0],
                               [1, 2],
                               [2, 1]], dtype=torch.long)
    x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

    data = Data(x=x, edge_index=edge_index.t().contiguous())

Try to answer the following questions related to :class:`~torch_geometric.nn.conv.GCNConv`:

1. What information do :obj:`row` and :obj:`col` hold?

2. What does :meth:`~torch_geometric.utils.degree` do?

3. Why do we use :obj:`degree(col, ...)` rather than :obj:`degree(row, ...)`?

4. What do :obj:`deg_inv_sqrt[col]` and :obj:`deg_inv_sqrt[row]` do?

5. What information does :obj:`x_j` hold in the :meth:`~torch_geometric.nn.conv.MessagePassing.message` function?
   If :obj:`self.lin` denotes the identity function, what is the exact content of :obj:`x_j`?

6. Add an :meth:`~torch_geometric.nn.conv.MessagePassing.update` function to :class:`~torch_geometric.nn.conv.GCNConv` that adds transformed central node features to the aggregated output.

Try to answer the following questions related to :class:`~torch_geometric.nn.conv.EdgeConv`:

1. What are :obj:`x_i` and :obj:`x_j - x_i`?

2. What does :obj:`torch.cat([x_i, x_j - x_i], dim=1)` do? Why :obj:`dim = 1`?
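One practical way to explore these questions is to instantiate the layers defined above on this toy graph and inspect the intermediate tensors, for example by adding print statements inside :meth:`forward` or :meth:`message`.
A minimal, purely illustrative starting point (channel sizes are arbitrary):

.. code-block:: python

    # Run the custom layers from this tutorial on the toy `data` object above.
    conv = GCNConv(in_channels=1, out_channels=2)
    out = conv(data.x, data.edge_index)  # out has shape [3, 2]

    conv = EdgeConv(in_channels=1, out_channels=2)
    out = conv(data.x, data.edge_index)  # out has shape [3, 2]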