Source code for torch_geometric.nn.conv.meshcnn_conv

# The below is to suppress the warning on
# torch_geometric.nn.conv.MeshCNNConv.update
# pyright: reportIncompatibleMethodOverride=false
import warnings
from typing import Optional

import torch
from torch.nn import Linear, Module, ModuleList

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.typing import Tensor


class MeshCNNConv(MessagePassing):
    r"""The convolutional layer introduced by the paper
    `"MeshCNN: A Network With An Edge" <https://arxiv.org/abs/1809.05910>`_.

    Recall that, given a set of categories :math:`C`, MeshCNN is a function
    that takes as its input a triangular mesh
    :math:`\mathcal{m} = (V, F) \in
    \mathbb{R}^{|V| \times 3} \times \{0,...,|V|-1\}^{3 \times |F|}`,
    and returns as its output a :math:`|C|`-dimensional vector, whose
    :math:`i` th component denotes the probability of the input mesh
    belonging to category :math:`c_i \in C`.

    Let :math:`X^{(k)} \in \mathbb{R}^{|E| \times \text{Dim-Out}(k)}` denote
    the output value of the prior (e.g. :math:`k` th) layer of our neural
    network. The :math:`i` th row of :math:`X^{(k)}` is a
    :math:`\text{Dim-Out}(k)`-dimensional vector that represents the
    features computed by the :math:`k` th layer for edge :math:`e_i` of the
    input mesh :math:`\mathcal{m}`.

    Let :math:`A \in \{0, ..., |E|-1\}^{2 \times 4*|E|}` denote the
    *edge adjacency* matrix of our input mesh :math:`\mathcal{m}`. The
    :math:`j` th column of :math:`A` holds a pair of indices
    :math:`k,l \in \{0,...,|E|-1\}`, which means that edge :math:`e_k` is
    adjacent to edge :math:`e_l` in our input mesh :math:`\mathcal{m}`. The
    definition of edge adjacency in a triangular mesh is illustrated in
    Figure 1. In a triangular mesh, each edge :math:`e_i` is expected to be
    adjacent to exactly :math:`4` neighboring edges, hence the number of
    columns of :math:`A`: :math:`4*|E|`. We write *the neighborhood* of edge
    :math:`e_i` as :math:`\mathcal{N}(i) = (a(i), b(i), c(i), d(i))`, where

    1. :math:`a(i)` denotes the index of the *first* counter-clockwise edge
       of the face *above* :math:`e_i`.
    2. :math:`b(i)` denotes the index of the *second* counter-clockwise edge
       of the face *above* :math:`e_i`.
    3. :math:`c(i)` denotes the index of the *first* counter-clockwise edge
       of the face *below* :math:`e_i`.
    4. :math:`d(i)` denotes the index of the *second* counter-clockwise edge
       of the face *below* :math:`e_i`.

    .. figure:: ../_figures/meshcnn_edge_adjacency.svg
        :align: center
        :width: 80%

        **Figure 1:** The neighbors of edge :math:`\mathbf{e_1}` are
        :math:`\mathbf{e_2}, \mathbf{e_3}, \mathbf{e_4}` and
        :math:`\mathbf{e_5}`, respectively. We write this as
        :math:`\mathcal{N}(1) = (a(1), b(1), c(1), d(1)) = (2, 3, 4, 5)`.

    Because of this ordering constraint, :obj:`MeshCNNConv` **requires that
    the columns of** :math:`A` **be ordered in the following way**:

    .. math::
        &A[:,0] = (0, \text{The index of the "a" edge for edge } 0) \\
        &A[:,1] = (0, \text{The index of the "b" edge for edge } 0) \\
        &A[:,2] = (0, \text{The index of the "c" edge for edge } 0) \\
        &A[:,3] = (0, \text{The index of the "d" edge for edge } 0) \\
        \vdots \\
        &A[:,4*|E|-4] = \bigl(|E|-1, a\bigl(|E|-1\bigr)\bigr) \\
        &A[:,4*|E|-3] = \bigl(|E|-1, b\bigl(|E|-1\bigr)\bigr) \\
        &A[:,4*|E|-2] = \bigl(|E|-1, c\bigl(|E|-1\bigr)\bigr) \\
        &A[:,4*|E|-1] = \bigl(|E|-1, d\bigl(|E|-1\bigr)\bigr)

    Stated a bit more compactly, for every edge :math:`e_i` in the input
    mesh, :math:`A` should have the following entries:

    .. math::
        A[:, 4*i] &= (i, a(i)) \\
        A[:, 4*i + 1] &= (i, b(i)) \\
        A[:, 4*i + 2] &= (i, c(i)) \\
        A[:, 4*i + 3] &= (i, d(i))

    To summarize so far, we have defined 3 things:

    1. The activation of the prior (e.g. :math:`k` th) layer,
       :math:`X^{(k)} \in \mathbb{R}^{|E| \times \text{Dim-Out}(k)}`.
    2. The edge adjacency matrix and the definition of edge adjacency,
       :math:`A \in \{0,...,|E|-1\}^{2 \times 4*|E|}`.
    3. The way the columns of :math:`A` must be ordered.
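    As a concrete illustration of this ordering requirement (this sketch is
    not from the paper; the per-edge neighbor index tensors :obj:`a`,
    :obj:`b`, :obj:`c`, and :obj:`d`, each of shape :obj:`[|E|]`, are
    assumed to have been computed elsewhere), :math:`A` can be assembled as
    follows:

    .. code-block:: python

        import torch

        # Hypothetical per-edge neighbor indices: a[i], b[i], c[i], d[i]
        # are the indices of the four edges adjacent to edge i, in the
        # order described above.
        num_edges = a.numel()

        # Row 0 repeats each edge index four times: i, i, i, i, ...
        row = torch.arange(num_edges).repeat_interleave(4)
        # Row 1 interleaves the neighbors: a(i), b(i), c(i), d(i), ...
        col = torch.stack([a, b, c, d], dim=1).reshape(-1)

        A = torch.stack([row, col], dim=0)  # shape: [2, 4*|E|]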
    We are now finally able to define the :obj:`MeshCNNConv` class/layer.
    In the following definition we assume :obj:`MeshCNNConv` is at the
    :math:`k+1` th layer of our neural network. The :obj:`MeshCNNConv`
    layer is a function,

    .. math::
        \text{MeshCNNConv}^{(k+1)}(X^{(k)}, A) = X^{(k+1)},

    that, given the prior layer's output
    :math:`X^{(k)} \in \mathbb{R}^{|E| \times \text{Dim-Out}(k)}` and the
    edge adjacency matrix :math:`A` of the input mesh (graph)
    :math:`\mathcal{m}`, returns a new edge feature tensor
    :math:`X^{(k+1)} \in \mathbb{R}^{|E| \times \text{Dim-Out}(k+1)}`,
    where the :math:`i` th row of :math:`X^{(k+1)}`, denoted by
    :math:`x^{(k+1)}_i`, represents the
    :math:`\text{Dim-Out}(k+1)`-dimensional feature vector of edge
    :math:`e_i`, **and is defined as follows**:

    .. math::
        x^{(k+1)}_i &= W^{(k+1)}_0 x^{(k)}_i \\
        &+ W^{(k+1)}_1 \bigl| x^{(k)}_{a(i)} - x^{(k)}_{c(i)} \bigr| \\
        &+ W^{(k+1)}_2 \bigl( x^{(k)}_{a(i)} + x^{(k)}_{c(i)} \bigr) \\
        &+ W^{(k+1)}_3 \bigl| x^{(k)}_{b(i)} - x^{(k)}_{d(i)} \bigr| \\
        &+ W^{(k+1)}_4 \bigl( x^{(k)}_{b(i)} + x^{(k)}_{d(i)} \bigr).

    :math:`W_0^{(k+1)}, W_1^{(k+1)}, W_2^{(k+1)}, W_3^{(k+1)},
    W_4^{(k+1)} \in
    \mathbb{R}^{\text{Dim-Out}(k+1) \times \text{Dim-Out}(k)}` are
    trainable linear functions (i.e. "the weights" of this layer).
    :math:`x^{(k)}_i` is the :math:`\text{Dim-Out}(k)`-dimensional feature
    vector of edge :math:`e_i` computed by the prior (:math:`k` th) layer.
    :math:`x^{(k)}_{a(i)}, x^{(k)}_{b(i)}, x^{(k)}_{c(i)}`, and
    :math:`x^{(k)}_{d(i)}` are the
    :math:`\text{Dim-Out}(k)`-dimensional feature vectors, computed in the
    :math:`k` th layer, that are associated with the :math:`4` neighboring
    edges of :math:`e_i`.

    Args:
        in_channels (int): Corresponds to :math:`\text{Dim-Out}(k)` in the
            above overview. This represents the output dimension of the
            prior layer. For the given input mesh
            :math:`\mathcal{m} = (V, F)`, the prior layer is expected to
            output a
            :math:`X \in \mathbb{R}^{|E| \times \textit{in_channels}}`
            feature matrix. Assuming the instance of this class is situated
            at layer :math:`k+1`, we write that
            :math:`X^{(k)} \in
            \mathbb{R}^{|E| \times \textit{in_channels}}`.
        out_channels (int): Corresponds to :math:`\text{Dim-Out}(k+1)` in
            the above overview. This represents the output dimension of
            this layer. Assuming the instance of this class is situated at
            layer :math:`k+1`, we write that
            :math:`X^{(k+1)} \in
            \mathbb{R}^{|E| \times \textit{out_channels}}`.
        kernels (torch.nn.ModuleList, optional): A list of length 5, where
            each element is a :class:`torch.nn.Module` (i.e. a neural
            network) that MUST take as input a vector of dimension
            :obj:`in_channels` and return a vector of dimension
            :obj:`out_channels`. In particular, :obj:`kernels[0]` is
            :math:`W^{(k+1)}_0` in the above overview (see
            :obj:`MeshCNNConv`), :obj:`kernels[1]` is :math:`W^{(k+1)}_1`,
            :obj:`kernels[2]` is :math:`W^{(k+1)}_2`, :obj:`kernels[3]` is
            :math:`W^{(k+1)}_3`, and :obj:`kernels[4]` is
            :math:`W^{(k+1)}_4`. Note that this input is optional, in which
            case each of the 5 elements of :obj:`kernels` will be a linear
            neural network (:class:`torch.nn.Linear`) correctly configured
            to take as input :attr:`in_channels`-dimensional vectors and to
            return vectors of dimension :attr:`out_channels`.
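    As an illustrative sketch (not taken from the paper, and assuming the
    class is importable as :obj:`torch_geometric.nn.conv.MeshCNNConv`), the
    optional :obj:`kernels` argument can be supplied explicitly, which is
    equivalent to the default behavior:

    .. code-block:: python

        from torch.nn import Linear, ModuleList

        from torch_geometric.nn.conv import MeshCNNConv

        in_channels, out_channels = 5, 16
        kernels = ModuleList(
            [Linear(in_channels, out_channels) for _ in range(5)])
        conv = MeshCNNConv(in_channels, out_channels, kernels)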
    Discussion:
        The key difference that separates :obj:`MeshCNNConv` from a
        traditional message passing graph neural network is that
        :obj:`MeshCNNConv` requires the set of neighbors of a node,
        :math:`\mathcal{N}(u) = (v_1, v_2, ...)`, to *be an ordered set*
        (i.e. a tuple). In fact, :obj:`MeshCNNConv` goes further, requiring
        that :math:`\mathcal{N}(u)` always return a set of size :math:`4`.
        This is different from most message passing graph neural networks,
        which assume that :math:`\mathcal{N}(u) = \{v_1, v_2, ...\}`
        returns an *unordered* set. This lends :obj:`MeshCNNConv` more
        expressive power, at the cost of no longer being invariant under
        permutations of the neighbors (the symmetric group
        :math:`\mathbb{S}_4`). Put more plainly, in traditional message
        passing GNNs, the network is *unable* to distinguish one
        neighboring node from another. In contrast, in :obj:`MeshCNNConv`,
        each of the 4 neighbors has a "role", either the "a", "b", "c", or
        "d" neighbor. We encode this fact by requiring that
        :math:`\mathcal{N}` return a 4-tuple, where the first component is
        the "a" neighbor, and so on.

        To summarize this comparison, we may re-define :obj:`MeshCNNConv`
        in terms of :math:`\text{UPDATE}` and :math:`\text{AGGREGATE}`
        functions, which is a general way to define a traditional GNN
        layer. If we let :math:`x_i^{(k+1)}` denote the output of a GNN
        layer for node :math:`i` at layer :math:`k+1`, and let
        :math:`\mathcal{N}(i)` denote the set of nodes adjacent to node
        :math:`i`, then we can describe the :math:`k+1` th layer of a
        traditional GNN as

        .. math::
            x_i^{(k+1)} = \text{UPDATE}^{(k+1)}\bigl(x^{(k)}_i,
            \text{AGGREGATE}^{(k+1)}\bigl(\mathcal{N}(i)\bigr)\bigr).

        Here, :math:`\text{UPDATE}^{(k+1)}` is a function of :math:`2`
        :math:`\text{Dim-Out}(k)`-dimensional vectors, and returns a
        :math:`\text{Dim-Out}(k+1)`-dimensional vector. The
        :math:`\text{AGGREGATE}^{(k+1)}` function is a function of an
        *unordered set* of nodes that are neighbors of node :math:`i`, as
        defined by :math:`\mathcal{N}(i)`. Usually the size of this set
        varies across different nodes :math:`i`, and one of the most basic
        examples of such a function is the "sum aggregation", defined as
        :math:`\text{AGGREGATE}^{(k+1)}(\mathcal{N}(i)) =
        \sum_{j \in \mathcal{N}(i)} x^{(k)}_j`. See
        :class:`SumAggregation <torch_geometric.nn.aggr.basic.SumAggregation>`
        for more.

        In contrast, while :obj:`MeshCNNConv`'s :math:`\text{UPDATE}`
        function follows that of a traditional GNN, its
        :math:`\text{AGGREGATE}` is a function of a tuple (i.e. an ordered
        set) of neighbors rather than an unordered set of neighbors. In
        particular, the :math:`\text{UPDATE}` function of
        :obj:`MeshCNNConv` for :math:`e_i` is

        .. math::
            x_i^{(k+1)} = \text{UPDATE}^{(k+1)}(x_i^{(k)}, s_i^{(k+1)})
            = W_0^{(k+1)}x_i^{(k)} + s_i^{(k+1)},

        while its :math:`\text{AGGREGATE}` function is

        .. math::
            s_i^{(k+1)} = \text{AGGREGATE}^{(k+1)}(A, B, C, D)
            &= W_1^{(k+1)}\bigl|A - C \bigr| \\
            &+ W_2^{(k+1)}\bigl(A + C \bigr) \\
            &+ W_3^{(k+1)}\bigl|B - D \bigr| \\
            &+ W_4^{(k+1)}\bigl(B + D \bigr),

        where :math:`A=x_{a(i)}^{(k)}, B=x_{b(i)}^{(k)},
        C=x_{c(i)}^{(k)},` and :math:`D=x_{d(i)}^{(k)}`.

    .. The :math:`i` th row of :math:`V \in \mathbb{R}^{|V| \times 3}`
        holds the cartesian :math:`xyz` coordinates for node :math:`v_i` in
        the mesh, and the :math:`j` th column in
        :math:`F \in \{0,...,|V|-1\}^{3 \times |F|}` holds the :math:`3`
        indices :math:`(k,l,m)` that correspond to the :math:`3` nodes
        :math:`(v_k, v_l, v_m)` that construct face :math:`j` of the mesh.
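    To make the :math:`\text{UPDATE}` and :math:`\text{AGGREGATE}`
    decomposition above concrete, the following is a minimal dense sketch
    (not part of this module; :obj:`in_channels`, :obj:`out_channels`, and
    the hypothetical per-edge neighbor index tensors :obj:`a`, :obj:`b`,
    :obj:`c`, :obj:`d`, each of shape :obj:`[|E|]`, are assumed to be
    defined) that computes :math:`X^{(k+1)}` directly from :math:`X^{(k)}`:

    .. code-block:: python

        from torch.nn import Linear

        # One linear map per weight matrix W_0, ..., W_4.
        W = [Linear(in_channels, out_channels) for _ in range(5)]

        def mesh_conv_reference(x, a, b, c, d):
            # AGGREGATE: combine the ordered neighbors of every edge.
            s = (W[1]((x[a] - x[c]).abs()) + W[2](x[a] + x[c]) +
                 W[3]((x[b] - x[d]).abs()) + W[4](x[b] + x[d]))
            # UPDATE: combine each edge's own feature with the aggregate.
            return W[0](x) + s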
""" def __init__(self, in_channels: int, out_channels: int, kernels: Optional[ModuleList] = None): super().__init__(aggr='add') self.in_channels = in_channels self.out_channels = out_channels if kernels is None: self.kernels = ModuleList( [Linear(in_channels, out_channels) for _ in range(5)]) else: # ensures kernels is properly formed, otherwise throws # the appropriate error. self._assert_kernels(kernels) self.kernels = kernels
    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        r"""Forward pass.

        Args:
            x (torch.Tensor): :math:`X^{(k)} \in
                \mathbb{R}^{|E| \times \textit{in_channels}}`. The edge
                feature tensor returned by the prior (:math:`k` th) layer.
                The tensor is of shape
                :math:`|E| \times \text{Dim-Out}(k)`, or equivalently,
                :obj:`(|E|, self.in_channels)`.
            edge_index (torch.Tensor): :math:`A \in
                \{0,...,|E|-1\}^{2 \times 4*|E|}`. The edge adjacency
                tensor of the network's input mesh
                :math:`\mathcal{m} = (V, F)`. The edge adjacency tensor
                **MUST** have the following form:

                .. math::
                    &A[:,0] = (0, \text{The index of the "a" edge for edge } 0) \\
                    &A[:,1] = (0, \text{The index of the "b" edge for edge } 0) \\
                    &A[:,2] = (0, \text{The index of the "c" edge for edge } 0) \\
                    &A[:,3] = (0, \text{The index of the "d" edge for edge } 0) \\
                    \vdots \\
                    &A[:,4*|E|-4] = \bigl(|E|-1, a\bigl(|E|-1\bigr)\bigr) \\
                    &A[:,4*|E|-3] = \bigl(|E|-1, b\bigl(|E|-1\bigr)\bigr) \\
                    &A[:,4*|E|-2] = \bigl(|E|-1, c\bigl(|E|-1\bigr)\bigr) \\
                    &A[:,4*|E|-1] = \bigl(|E|-1, d\bigl(|E|-1\bigr)\bigr)

                See :obj:`MeshCNNConv` for what "index of the 'a' (b, c, d)
                edge for edge i" means, and also for the general definition
                of edge adjacency in MeshCNN. These definitions are also
                provided in the
                `paper <https://arxiv.org/abs/1809.05910>`_ itself.

        Returns:
            torch.Tensor: :math:`X^{(k+1)} \in
            \mathbb{R}^{|E| \times \textit{out_channels}}`. The edge
            feature tensor for this (the :math:`k+1` th) layer. The
            :math:`i` th row of :math:`X^{(k+1)}` is computed according to
            the formula

            .. math::
                x^{(k+1)}_i &= W^{(k+1)}_0 x^{(k)}_i \\
                &+ W^{(k+1)}_1 \bigl| x^{(k)}_{a(i)} - x^{(k)}_{c(i)} \bigr| \\
                &+ W^{(k+1)}_2 \bigl( x^{(k)}_{a(i)} + x^{(k)}_{c(i)} \bigr) \\
                &+ W^{(k+1)}_3 \bigl| x^{(k)}_{b(i)} - x^{(k)}_{d(i)} \bigr| \\
                &+ W^{(k+1)}_4 \bigl( x^{(k)}_{b(i)} + x^{(k)}_{d(i)} \bigr),

            where :math:`W_0^{(k+1)}, W_1^{(k+1)}, W_2^{(k+1)},
            W_3^{(k+1)}, W_4^{(k+1)} \in
            \mathbb{R}^{\text{Dim-Out}(k+1) \times \text{Dim-Out}(k)}` are
            the trainable linear functions (i.e. the trainable "weights")
            of this layer, and :math:`x^{(k)}_{a(i)}, x^{(k)}_{b(i)},
            x^{(k)}_{c(i)}, x^{(k)}_{d(i)}` are the
            :math:`\text{Dim-Out}(k)`-dimensional edge feature vectors
            computed by the prior (:math:`k` th) layer, that are associated
            with the :math:`4` neighboring edges of :math:`e_i`.
        """
        return self.propagate(edge_index, x=x)
    def message(self, x_j: Tensor) -> Tensor:
        r"""The message passing step of :obj:`MeshCNNConv`.

        Args:
            x_j (torch.Tensor): A :obj:`[4*|E|, in_channels]` tensor. Its
                :obj:`i` th row holds the feature vector, computed by the
                previous layer, of the source node of edge :obj:`i` in the
                edge adjacency graph.

        Returns:
            torch.Tensor: A :obj:`[4*|E|, out_channels]` tensor, whose
            :obj:`i` th row holds the message that the target node of edge
            :obj:`i` will receive.
        """
        # The following variable names are taken from the paper.
        # MeshCNN computes the features associated with edge e from
        # (|a - c|, a + c, |b - d|, b + d), where a, b, c, d are the
        # neighboring edges of e: a is the first edge of the upper face,
        # b is the second edge of the upper face, c is the first edge of
        # the lower face, and d is the second edge of the lower face of
        # the input mesh.

        # TODO: It is unclear if `view` is faster. If it is not, then we
        # should prefer the strided method commented out below.
        E4, in_channels = x_j.size()  # E4 = 4*|E| (edges in the line graph)

        # Option 1
        n_a = x_j[0::4]  # shape: |E| x in_channels
        n_b = x_j[1::4]  # shape: |E| x in_channels
        n_c = x_j[2::4]  # shape: |E| x in_channels
        n_d = x_j[3::4]  # shape: |E| x in_channels
        m = torch.empty(E4, self.out_channels, device=x_j.device,
                        dtype=x_j.dtype)
        m[0::4] = self.kernels[1].forward(torch.abs(n_a - n_c))
        m[1::4] = self.kernels[2].forward(n_a + n_c)
        m[2::4] = self.kernels[3].forward(torch.abs(n_b - n_d))
        m[3::4] = self.kernels[4].forward(n_b + n_d)
        return m

        # Option 2
        # E4, in_channels = x_j.size()
        # E = E4 // 4
        # x_j = x_j.view(E, 4, in_channels)  # shape: |E| x 4 x in_channels
        # n_a, n_b, n_c, n_d = x_j.unbind(
        #     dim=1)  # each of shape: |E| x in_channels
        # m = torch.stack(
        #     [
        #         self.kernels[1](torch.abs(n_a - n_c)),
        #         self.kernels[2](n_a + n_c),
        #         self.kernels[3](torch.abs(n_b - n_d)),
        #         self.kernels[4](n_b + n_d),
        #     ],
        #     dim=1)  # shape: |E| x 4 x out_channels
        # return m.view(E4, self.out_channels)  # shape: 4*|E| x out_channels

    def update(self, inputs: Tensor, x: Tensor) -> Tensor:
        r"""The UPDATE step, in reference to the UPDATE and AGGREGATE
        formulation of message passing convolution.

        Args:
            inputs (torch.Tensor): The :obj:`(|E|, out_channels)`-shaped
                tensor returned by the aggregation step.
            x (torch.Tensor): :math:`X^{(k)}`. The original input to this
                layer.

        Returns:
            torch.Tensor: :math:`X^{(k+1)}`. The output of this layer,
            which has shape :obj:`(|E|, out_channels)`.
        """
        return self.kernels[0].forward(x) + inputs

    def _assert_kernels(self, kernels: ModuleList):
        r"""Ensures that :obj:`kernels` is a :obj:`torch.nn.ModuleList` of
        5 :obj:`torch.nn.Module` modules (i.e. networks). In addition, it
        also ensures that each network takes input of dimension
        :attr:`in_channels` and returns output of dimension
        :attr:`out_channels`.

        .. warning::
            This method raises an error if :obj:`kernels` is not valid.
            (Otherwise this method returns nothing.)
        """
        assert isinstance(kernels, ModuleList), \
            "Parameter 'kernels' must be a torch.nn.ModuleList with 5 " \
            f"members, but we got {type(kernels)}."
        assert len(kernels) == 5, \
            "Parameter 'kernels' must be a torch.nn.ModuleList with " \
            "exactly 5 members."
        for i, network in enumerate(kernels):
            assert isinstance(network, Module), \
                f"kernels[{i}] must be a torch.nn.Module, " \
                f"but we got {type(network)}."

            if not hasattr(network, "in_channels") and \
                    not hasattr(network, "in_features"):
                warnings.warn(
                    f"kernels[{i}] has neither attribute 'in_channels' nor "
                    f"'in_features'. The network must take as input a "
                    f"{self.in_channels}-dimensional tensor.", stacklevel=2)
            else:
                input_dimension = getattr(network, "in_channels",
                                          getattr(network, "in_features",
                                                  None))
                assert input_dimension == self.in_channels, \
                    "The input dimension of the neural network in " \
                    f"kernels[{i}] must be equal to 'in_channels', but " \
                    f"input_dimension={input_dimension} and " \
                    f"self.in_channels={self.in_channels}."

            if not hasattr(network, "out_channels") and \
                    not hasattr(network, "out_features"):
                warnings.warn(
                    f"kernels[{i}] has neither attribute 'out_channels' "
                    f"nor 'out_features'. The network must return a "
                    f"{self.out_channels}-dimensional tensor.", stacklevel=2)
            else:
                output_dimension = getattr(network, "out_channels",
                                           getattr(network, "out_features",
                                                   None))
                assert output_dimension == self.out_channels, \
                    "The output dimension of the neural network in " \
                    f"kernels[{i}] must be equal to 'out_channels', but " \
                    f"output_dimension={output_dimension} and " \
                    f"self.out_channels={self.out_channels}."
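

# The following is an illustrative, self-contained smoke test, not part of
# the library module itself. It builds a `MeshCNNConv` layer and applies it
# to random edge features together with a randomly generated (hence
# geometrically meaningless) edge adjacency tensor that only respects the
# required (i, a(i)), (i, b(i)), (i, c(i)), (i, d(i)) column ordering.
if __name__ == "__main__":
    num_edges, in_channels, out_channels = 10, 5, 16

    x = torch.rand(num_edges, in_channels)

    # Random neighbor indices, interleaved per edge as a(i), b(i), c(i),
    # d(i). A real mesh would derive these from its faces.
    neighbors = torch.randint(0, num_edges, (num_edges, 4)).reshape(-1)
    row = torch.arange(num_edges).repeat_interleave(4)
    edge_index = torch.stack([row, neighbors], dim=0)  # shape: [2, 4*|E|]

    conv = MeshCNNConv(in_channels, out_channels)
    out = conv(x, edge_index)
    assert out.size() == (num_edges, out_channels)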