torch_geometric.data.OnDiskDataset

class OnDiskDataset(root: str, transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, backend: str = 'sqlite', schema: Union[Any, Dict[str, Any], Tuple[Any], List[Any]] = object, log: bool = True)[source]

Bases: Dataset

Dataset base class for creating large graph datasets that do not fit into CPU memory at once, leveraging a Database backend for on-disk storage and access of data objects.

Parameters:
  • root (str) – Root directory where the dataset should be saved.

  • transform (callable, optional) – A function/transform that takes in a Data or HeteroData object and returns a transformed version. The data object will be transformed before every access. (default: None)

  • pre_filter (callable, optional) – A function that takes in a Data or HeteroData object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: None)

  • backend (str) – The Database backend to use (one of "sqlite" or "rocksdb"). (default: "sqlite")

  • schema (Any or Tuple[Any] or Dict[str, Any], optional) – The schema of the input data. Can take int, float, str, object, or a dictionary with dtype and size keys (for specifying tensor data) as input, and can be nested as a tuple or dictionary. Specifying the schema improves efficiency, since by default the database uses Python pickling for serializing and deserializing. If set to anything other than object, subclasses of OnDiskDataset need to override the serialize() and deserialize() methods, as in the sketch after this parameter list. (default: object)

  • log (bool, optional) – Whether to print any console output while downloading and processing the dataset. (default: True)
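
A minimal sketch of the intended subclassing pattern follows. The class name, schema keys, tensor sizes, and synthetic data are illustrative assumptions, not part of the API:

import torch

from torch_geometric.data import Data, OnDiskDataset

class MyOnDiskDataset(OnDiskDataset):
    def __init__(self, root: str):
        # A fixed schema lets the database store tensors natively instead
        # of falling back to Python pickling (a size of -1 means dynamic):
        schema = {
            'x': dict(dtype=torch.float, size=(-1, 16)),
            'edge_index': dict(dtype=torch.long, size=(2, -1)),
            'y': float,
        }
        super().__init__(root, schema=schema)

    def process(self):
        # Create examples (here: synthetic) and write them to the database:
        data_list = [
            Data(x=torch.randn(8, 16),
                 edge_index=torch.randint(0, 8, (2, 20)),
                 y=float(i))
            for i in range(100)
        ]
        self.extend(data_list)

    def serialize(self, data: Data) -> dict:
        # Flatten a `Data` object into a row matching the schema above:
        return dict(x=data.x, edge_index=data.edge_index, y=data.y)

    def deserialize(self, row: dict) -> Data:
        # Rebuild the `Data` object from a database row:
        return Data(x=row['x'], edge_index=row['edge_index'], y=row['y'])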

property processed_file_names: str

The name of the file in the self.processed_dir folder that must be present in order to skip processing.

Return type: str

property db: Database

Returns the underlying Database.

Return type: Database

close() → None[source]

Closes the connection to the underlying database.

Return type: None

serialize(data: BaseData) → Any[source]

Serializes the Data or HeteroData object into the expected DB schema.

Return type: Any

deserialize(data: Any) → BaseData[source]

Deserializes the DB entry into a Data or HeteroData object.

Return type: BaseData

append(data: BaseData) → None[source]

Appends the data object to the dataset.

Return type: None

extend(data_list: Sequence[BaseData], batch_size: Optional[int] = None) → None[source]

Extends the dataset by a list of data objects.

Return type: None
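
Assuming the MyOnDiskDataset sketch from above, filling the dataset incrementally might look like (paths and batch size are placeholders):

dataset = MyOnDiskDataset(root='/tmp/my_dataset')

# Append a single example:
dataset.append(Data(x=torch.randn(4, 16),
                    edge_index=torch.randint(0, 4, (2, 6)),
                    y=0.0))

# Write many examples, chunking the underlying database writes:
more_data = [Data(x=torch.randn(4, 16),
                  edge_index=torch.randint(0, 4, (2, 6)),
                  y=float(i)) for i in range(1000)]
dataset.extend(more_data, batch_size=256)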

get(idx: int) → BaseData[source]

Gets the data object at index idx.

Return type: BaseData

multi_get(indices: Union[Iterable[int], Tensor, slice, range], batch_size: Optional[int] = None) → List[BaseData][source]

Gets a list of data objects from the specified indices.

Return type: List[BaseData]
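
Continuing the hypothetical dataset from above:

data = dataset.get(0)                                # a single data object
chunk = dataset.multi_get(range(64), batch_size=32)  # read in chunks of 32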

len() → int[source]

Returns the number of data objects stored in the dataset.

Return type: int

download() → None

Downloads the dataset to the self.raw_dir folder.

Return type: None

get_summary() → Any

Collects summary statistics for the dataset.

Return type: Any

property has_download: bool

Checks whether the dataset defines a download() method.

Return type: bool

property has_process: bool

Checks whether the dataset defines a process() method.

Return type: bool

index_select(idx: Union[slice, Tensor, ndarray, Sequence]) → Dataset

Creates a subset of the dataset from specified indices idx. Indices idx can be a slicing object, e.g., [2:5], a list, a tuple, or a torch.Tensor or np.ndarray of type long or bool.

Return type: Dataset
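
For example, each of the following produces a three-element subset (indexing a dataset with anything other than a single integer dispatches to index_select()):

subset = dataset.index_select(torch.tensor([2, 3, 4]))
subset = dataset.index_select(slice(2, 5))
subset = dataset[2:5]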

property num_classes: int

Returns the number of classes in the dataset.

Return type: int

property num_edge_features: int

Returns the number of features per edge in the dataset.

Return type: int

property num_features: int

Returns the number of features per node in the dataset. Alias for num_node_features.

Return type: int

property num_node_features: int

Returns the number of features per node in the dataset.

Return type: int

print_summary(fmt: str = 'psql') → None

Prints summary statistics of the dataset to the console.

Parameters:
fmt (str, optional) – The summary table format, following the table formats supported by the tabulate package. (default: "psql")

Return type: None
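
A usage sketch ('github' is one of the tabulate table styles):

summary = dataset.get_summary()      # summary statistics object
dataset.print_summary(fmt='github')  # statistics rendered as a table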

process() → None

Processes the dataset to the self.processed_dir folder.

Return type: None

property processed_paths: List[str]

The absolute filepaths that must be present in order to skip processing.

Return type: List[str]

property raw_file_names: Union[str, List[str], Tuple[str, ...]]

The names of the files in the self.raw_dir folder that must be present in order to skip downloading.

Return type: Union[str, List[str], Tuple[str, ...]]

property raw_paths: List[str]

The absolute filepaths that must be present in order to skip downloading.

Return type: List[str]

shuffle(return_perm: bool = False) → Union[Dataset, Tuple[Dataset, Tensor]]

Randomly shuffles the examples in the dataset.

Parameters:
return_perm (bool, optional) – If set to True, will also return the random permutation used to shuffle the dataset. (default: False)

Return type: Union[Dataset, Tuple[Dataset, Tensor]]
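
For example:

dataset = dataset.shuffle()
# Or keep the permutation, e.g., to reorder aligned metadata the same way:
dataset, perm = dataset.shuffle(return_perm=True)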

to_datapipe() → Any

Converts the dataset into a torch.utils.data.DataPipe.

The returned instance can then be used with built-in DataPipes for batching graphs as follows:

from torch_geometric.datasets import QM9

dp = QM9(root='./data/QM9/').to_datapipe()
dp = dp.batch_graphs(batch_size=2, drop_last=True)

for batch in dp:
    pass

See the PyTorch tutorial for further background on DataPipes.

Return type: Any