# Source code for torch_geometric.datasets.medshapenet

import os
import os.path as osp
from typing import Callable, List, Optional

import numpy as np
import torch

from torch_geometric.data import Data, InMemoryDataset


class MedShapeNet(InMemoryDataset):
    r"""The MedShapeNet datasets from the `"MedShapeNet -- A Large-Scale
    Dataset of 3D Medical Shapes for Computer Vision"
    <https://arxiv.org/abs/2308.16139>`_ paper,
    containing 8 different types of structures (classes).

    .. note::

        Data objects hold mesh faces instead of edge indices.
        To convert the mesh to a graph, use the
        :obj:`torch_geometric.transforms.FaceToEdge` as :obj:`pre_transform`.
        To convert the mesh to a point cloud, use the
        :obj:`torch_geometric.transforms.SamplePoints` as :obj:`transform` to
        sample a fixed number of points on the mesh faces according to their
        face area.

    .. note::

        Processing imports the :obj:`MedShapeNet` client package and fetches
        the meshes through it (:obj:`download_stl_as_numpy`) into
        per-class sub-directories of :obj:`root`.

    Each data object holds the mesh vertices as :obj:`pos` (float), the mesh
    faces as :obj:`face` (long, transposed so each column is one face), and
    the class index as :obj:`y`, following the order of
    :obj:`raw_file_names`.

    Args:
        root (str): Root directory where the dataset should be saved.
        size (int): Number of individual 3D structures to download per
            type (class). (default: :obj:`100`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in
            the final dataset. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
    """
    def __init__(
        self,
        root: str,
        size: int = 100,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
    ) -> None:
        # `size` must be set before `super().__init__`, which may call
        # `process()` (and `process()` reads `self.size`).
        self.size = size
        super().__init__(root, transform, pre_transform, pre_filter,
                         force_reload=force_reload)
        # NOTE(review): the processed file name does not encode `size`, so a
        # previously built `dataset.pt` may be reused for a different `size`;
        # confirm this is intended.
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        r"""Class names, in label order (index ``i`` is label ``i``)."""
        return [
            '3DTeethSeg', 'CoronaryArteries', 'FLARE', 'KITS', 'PULMONARY',
            'SurgicalInstruments', 'ThoracicAorta_Saitta', 'ToothFairy'
        ]

    @property
    def processed_file_names(self) -> List[str]:
        r"""Single processed file holding the collated dataset."""
        return ['dataset.pt']

    @property
    def raw_paths(self) -> List[str]:
        r"""The absolute filepaths that must be present in order to skip
        downloading.
        """
        # NOTE(review): `process()` writes meshes to `root/<class>`, not to
        # `raw_dir/<class>`; this only matters if a `download()` is added.
        return [osp.join(self.raw_dir, f) for f in self.raw_file_names]

    def process(self) -> None:
        r"""Fetches up to :obj:`size` meshes per class, converts them to
        :class:`Data` objects, applies :obj:`pre_filter` /
        :obj:`pre_transform`, and saves the result.
        """
        # Imported lazily so the optional client package is only required
        # when the dataset actually has to be built.
        from MedShapeNet import MedShapeNet as msn

        msn_instance = msn(timeout=120)

        # MedShapeNet buckets that are not among the classes used here.
        excluded = {
            'medshapenetcore/ASOCA',
            'medshapenetcore/AVT',
            'medshapenetcore/AutoImplantCraniotomy',
            'medshapenetcore/FaceVR',
        }
        datasets = [
            name for name in msn_instance.datasets(False)
            if name not in excluded
        ]

        subset: List[str] = []
        for dataset in datasets:
            out_dir = osp.join(self.root, dataset.split('/')[1])
            os.makedirs(out_dir, exist_ok=True)

            stl_files = msn_instance.dataset_files(dataset, '.stl')
            stl_files = stl_files[:self.size]
            subset.extend(stl_files)

            for stl_file in stl_files:
                msn_instance.download_stl_as_numpy(bucket_name=dataset,
                                                   stl_file=stl_file,
                                                   output_dir=out_dir,
                                                   print_output=False)

        # Derived from `raw_file_names` so labels and names cannot diverge.
        class_mapping = {
            name: index
            for index, name in enumerate(self.raw_file_names)
        }

        data_list = []
        for item in subset:
            target = class_mapping[item.split('/')[0]]
            # Swap only the trailing extension for `.npz`; splitting on the
            # substring "stl" would cut at its first occurrence anywhere in
            # the path.
            npz_path = osp.join(self.root, osp.splitext(item)[0] + '.npz')

            # An `.npz` archive is read lazily through an open file handle;
            # the context manager closes it once the arrays are copied out.
            with np.load(npz_path) as npz:
                pos = torch.tensor(npz['vertices'], dtype=torch.float)
                face = torch.tensor(npz['faces'], dtype=torch.long)

            data = Data(pos=pos, face=face.t().contiguous())
            data.y = torch.tensor([target], dtype=torch.long)
            data_list.append(data)

        if self.pre_filter is not None:
            data_list = [d for d in data_list if self.pre_filter(d)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]

        self.save(data_list, self.processed_paths[0])