Source code for torch_geometric.transforms.face_to_edge

import torch

from torch_geometric.data import Data
from torch_geometric.data.datapipes import functional_transform
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import to_undirected


@functional_transform('face_to_edge')
class FaceToEdge(BaseTransform):
    r"""Converts mesh faces of shape :obj:`[3, num_faces]` or
    :obj:`[4, num_faces]` to edge indices of shape :obj:`[2, num_edges]`
    (functional name: :obj:`face_to_edge`).

    This transform supports both 2D triangular faces, represented by a
    tensor of shape :obj:`[3, num_faces]`, and 3D tetrahedral mesh faces,
    represented by a tensor of shape :obj:`[4, num_faces]`. It will convert
    these faces into edge indices, where each edge is defined by the indices
    of its two endpoints.

    If :obj:`data` has no :obj:`face` attribute, or the attribute is
    :obj:`None`, the object is returned unchanged.

    Args:
        remove_faces (bool, optional): If set to :obj:`False`, the face tensor
            will not be removed. (default: :obj:`True`)
    """
    def __init__(self, remove_faces: bool = True) -> None:
        self.remove_faces = remove_faces

    def forward(self, data: Data) -> Data:
        r"""Builds :obj:`data.edge_index` from :obj:`data.face`.

        Args:
            data (Data): The object to transform. It is modified in place.

        Returns:
            Data: The same :obj:`data` object, with :obj:`edge_index` set and,
            if :obj:`remove_faces` is :obj:`True`, :obj:`face` set to
            :obj:`None`.

        Raises:
            RuntimeError: If the first dimension of :obj:`data.face` is
                neither 3 nor 4.
        """
        # A missing face and a `None` face are both treated as "nothing to
        # convert". This transform writes `data.face = None` itself when
        # `remove_faces` is enabled, so applying it a second time must be a
        # no-op rather than a failure. An explicit check is used instead of
        # `assert`, which is stripped when Python runs with `-O`.
        face = getattr(data, 'face', None)
        if face is None:
            return data

        num_corners = face.size(0)
        if num_corners not in (3, 4):
            raise RuntimeError(f"Expected 'face' tensor with shape "
                               f"[3, num_faces] or [4, num_faces] "
                               f"(got {list(face.size())})")

        # Every row slice below selects two of the corner rows, i.e. one
        # edge per face; together they enumerate all corner pairs.
        if num_corners == 3:
            corner_pairs = [
                face[:2],  # corners (0, 1)
                face[1:],  # corners (1, 2)
                face[::2],  # corners (0, 2)
            ]
        else:
            corner_pairs = [
                face[:2],  # corners (0, 1)
                face[1:3],  # corners (1, 2)
                face[2:4],  # corners (2, 3)
                face[::2],  # corners (0, 2)
                face[1::2],  # corners (1, 3)
                face[::3],  # corners (0, 3)
            ]

        edge_index = torch.cat(corner_pairs, dim=1)
        edge_index = to_undirected(edge_index, num_nodes=data.num_nodes)

        data.edge_index = edge_index
        if self.remove_faces:
            data.face = None

        return data