from typing import Any, Dict, Optional, Tuple, Union
import torch
from torch import Tensor
from torch_geometric.data import FeatureStore
from torch_geometric.distributed import LocalGraphStore
from torch_geometric.sampler import (
BidirectionalNeighborSampler,
NodeSamplerInput,
SamplerOutput,
)
from torch_geometric.utils import index_sort
# Accepted edge-index representations:
#   * a single Tensor of shape [2, num_edges], where Tensor[0, :] holds the
#     row indices and Tensor[1, :] holds the col indices (the default);
#   * a pair of Tensors in one of these layouts:
#       - COO: (row, col)
#       - CSC: (row, colptr)
#       - CSR: (rowptr, col)
_EdgeTensorType = Union[
    Tensor,
    Tuple[Tensor, Tensor],
]
class NeighborSamplingRAGGraphStore(LocalGraphStore):
    """Neighbor sampling based graph-store to store & retrieve graph data.

    A :class:`BidirectionalNeighborSampler` over ``(feature_store, self)`` is
    built lazily on the first :meth:`sample_subgraph` call. Storing new edge
    ids / edge indices or registering a new feature store marks the sampler
    as stale so it is rebuilt on the next sampling call. Only homogeneous
    sampling is currently supported.
    """
    def __init__(  # type: ignore[no-untyped-def]
        self,
        feature_store: Optional[FeatureStore] = None,
        **kwargs,
    ):
        """Initializes the graph store.

        Args:
            feature_store (FeatureStore, optional): The feature store to use.
                :obj:`None` if not yet registered.
            **kwargs (optional): Additional keyword arguments forwarded to
                :class:`BidirectionalNeighborSampler`.
        """
        self.feature_store = feature_store
        self.sample_kwargs = kwargs
        self._sampler_is_initialized = False
        self._config: Dict[str, Any] = {}
        # Set through the ``config`` setter.
        self.num_neighbors = None
        # Permutation produced by the ``edge_index`` setter when it sorts the
        # edges by column; maps sorted edge positions back to the caller's
        # edge order. Stays ``None`` if edges were stored another way.
        self.perm: Optional[Tensor] = None
        super().__init__()

    @property
    def config(self) -> Dict[str, Any]:
        """Get the config for the graph store."""
        return self._config

    def _set_from_config(self, config: Dict[str, Any], attr_name: str) -> None:
        """Set an attribute from the config.

        Args:
            config (Dict[str, Any]): Config dictionary.
            attr_name (str): Name of attribute to set.

        Raises:
            ValueError: If the required attribute is not found in config.
        """
        if attr_name not in config:
            raise ValueError(
                f"Required config parameter '{attr_name}' not found")
        setattr(self, attr_name, config[attr_name])

    @config.setter  # type: ignore
    def config(self, config: Dict[str, Any]) -> None:
        """Set the config for the graph store.

        If the sampler already exists, its ``num_neighbors`` is updated in
        place rather than rebuilding the sampler.

        Args:
            config (Dict[str, Any]): Config dictionary containing the
                required ``num_neighbors`` parameter.

        Raises:
            ValueError: If required parameters are missing from config.
        """
        self._set_from_config(config, "num_neighbors")
        if hasattr(self, 'sampler'):
            self.sampler.num_neighbors = (  # type: ignore[has-type]
                self.num_neighbors)
        self._config = config

    def _init_sampler(self) -> None:
        """Initializes the neighbor sampler with the registered feature store.

        Raises:
            AttributeError: If no feature store has been registered.
        """
        if self.feature_store is None:
            raise AttributeError("Feature store not registered yet.")
        assert self.num_neighbors is not None, \
            "Please set num_neighbors through config"
        self.sampler = BidirectionalNeighborSampler(
            data=(self.feature_store, self), num_neighbors=self.num_neighbors,
            **self.sample_kwargs)
        self._sampler_is_initialized = True

    def register_feature_store(self, feature_store: FeatureStore) -> None:
        """Registers a feature store with the graph store.

        The sampler is marked stale and rebuilt on the next sampling call.

        Args:
            feature_store (FeatureStore): The feature store to register.
        """
        self.feature_store = feature_store
        self._sampler_is_initialized = False

    def put_edge_id(  # type: ignore[no-untyped-def]
            self, edge_id: Tensor, *args, **kwargs) -> bool:
        """Stores edge IDs in the graph store.

        The tensor is made contiguous before being stored, and the sampler is
        marked stale.

        Args:
            edge_id (Tensor): The edge IDs to store.
            *args, **kwargs: Forwarded to :meth:`LocalGraphStore.put_edge_id`.

        Returns:
            bool: Whether the operation was successful.
        """
        ret = super().put_edge_id(edge_id.contiguous(), *args, **kwargs)
        self._sampler_is_initialized = False
        return ret

    @property
    def edge_index(self) -> _EdgeTensorType:
        """Gets the edge index of the graph.

        Looks up the edge index using the arguments of the most recent
        :meth:`put_edge_index` call, so that method must have been called
        at least once.

        Returns:
            The stored edge index.
        """
        return self.get_edge_index(*self.edge_idx_args, **self.edge_idx_kwargs)

    def put_edge_index(  # type: ignore[no-untyped-def]
            self, edge_index: _EdgeTensorType, *args, **kwargs) -> bool:
        """Stores an edge index in the graph store.

        Args:
            edge_index (Tensor or Tuple[Tensor, Tensor]): The edge index to
                store.
            *args, **kwargs: Edge attributes forwarded to
                :meth:`LocalGraphStore.put_edge_index`.

        Returns:
            bool: Whether the operation was successful.
        """
        ret = super().put_edge_index(edge_index, *args, **kwargs)
        # Remember how this edge index was keyed so the ``edge_index``
        # property can retrieve it later. Only the latest call is kept.
        self.edge_idx_args = args
        self.edge_idx_kwargs = kwargs
        self._sampler_is_initialized = False
        return ret

    @staticmethod
    def _infer_num_nodes(*index_tensors: Tensor) -> int:
        """Returns ``max index + 1`` over all given index tensors, or ``0``
        when every tensor is empty.
        """
        maxima = [int(t.max()) for t in index_tensors if t.numel() > 0]
        return max(maxima) + 1 if maxima else 0

    # NOTE: ``self.perm`` is refreshed only here; calling ``put_edge_index``
    # directly afterwards does not update it.
    @edge_index.setter  # type: ignore
    def edge_index(self, edge_index: _EdgeTensorType) -> None:
        """Sets the edge index of the graph.

        The edges are stably sorted by column index and stored in COO layout.
        The sorting permutation is kept in :obj:`self.perm` so that sampled
        edge ids can be mapped back to the order of :obj:`edge_index`.

        Args:
            edge_index (Tensor or Tuple[Tensor, Tensor]): Either a tensor of
                shape ``[2, num_edges]`` or a ``(row, col)`` tuple.
        """
        # The node list is built from triples, so every node is assumed to
        # appear in some edge and ``max index + 1`` is the node count.
        if isinstance(edge_index, Tensor):
            num_nodes = self._infer_num_nodes(edge_index)
        else:
            assert isinstance(edge_index, tuple) \
                and len(edge_index) == 2 \
                and isinstance(edge_index[0], Tensor) \
                and isinstance(edge_index[1], Tensor), \
                ("edge_index must be a Tensor of [2, num_edges] "
                 "or a tuple of Tensors, (row, col).")
            # Look at both endpoints: a col index may exceed every row index.
            num_nodes = self._infer_num_nodes(edge_index[0], edge_index[1])
        attr = dict(
            edge_type=None,
            layout='coo',
            size=(num_nodes, num_nodes),
            is_sorted=False,
        )
        # Sort by column (stable, so ties keep input order) and keep the
        # permutation to remap sampled edge ids in ``sample_subgraph``.
        col_sorted, self.perm = index_sort(edge_index[1], num_nodes,
                                           stable=True)
        row_sorted = edge_index[0][self.perm]
        edge_index_sorted = torch.stack([row_sorted, col_sorted], dim=0)
        self.put_edge_index(edge_index_sorted, **attr)

    def sample_subgraph(
        self,
        seed_nodes: Tensor,
    ) -> SamplerOutput:
        """Sample the graph starting from the given nodes using the
        in-built :class:`BidirectionalNeighborSampler`.

        The sampler is (re)built first if it is missing or stale. The
        number of neighbors per hop comes from the ``num_neighbors`` config
        parameter.

        Args:
            seed_nodes (Tensor): Seed node indices to start sampling from.
                Duplicates are removed before sampling.

        Returns:
            SamplerOutput: The sampler output for the seed nodes. When the
            graph was set via the ``edge_index`` property, ``out.edge`` is
            remapped to the caller's original edge order.
        """
        # TODO add support for Hetero
        if not self._sampler_is_initialized:
            self._init_sampler()

        seed_nodes = seed_nodes.unique().contiguous()
        node_sample_input = NodeSamplerInput(input_id=None, node=seed_nodes)
        out = self.sampler.sample_from_nodes(  # type: ignore[has-type]
            node_sample_input)

        # Sampled edge ids index the column-sorted storage order; map them
        # back through the permutation saved by the ``edge_index`` setter.
        # ``SamplerOutput.edge`` is optional, and ``perm`` is absent when
        # edges were stored without the setter.
        if out.edge is not None and self.perm is not None:
            out.edge = self.perm[out.edge]
        return out