Source code for torch_geometric.llm.rag_loader

from abc import abstractmethod
from typing import Any, Callable, Dict, Optional, Protocol, Tuple, Union

from torch_geometric.data import Data, FeatureStore, HeteroData
from torch_geometric.llm.utils.vectorrag import VectorRetriever
from torch_geometric.sampler import HeteroSamplerOutput, SamplerOutput
from torch_geometric.typing import InputEdges, InputNodes


class RAGFeatureStore(Protocol):
    """Feature store template for remote GNN RAG backend.

    This is a :class:`typing.Protocol`, so implementations are matched
    structurally and do not need to inherit from it.
    """
    @abstractmethod
    def retrieve_seed_nodes(self, query: Any,
                            **kwargs) -> Tuple[InputNodes, Any]:
        """Makes a comparison between the query and all the nodes to get all
        the closest nodes.

        Args:
            query (Any): The query to compare against the stored nodes.
            **kwargs: Implementation-specific retrieval options.

        Returns:
            A 2-tuple. The first element holds the indices of the nodes that
            are to be seeds for the RAG sampler. The second element is
            implementation-defined. ``RAGQueryLoader.query`` unpacks it but
            does not use it; presumably it is the encoded query (confirm
            against the concrete implementation).
        """
        ...

    @property
    @abstractmethod
    def config(self) -> Dict[str, Any]:
        """Get the config for the RAGFeatureStore."""
        ...

    @config.setter
    @abstractmethod
    def config(self, config: Dict[str, Any]):
        """Set the config for the RAGFeatureStore."""
        ...

    @abstractmethod
    def retrieve_seed_edges(self, query: Any, **kwargs) -> InputEdges:
        """Makes a comparison between the query and all the edges to get all
        the closest edges.

        Args:
            query (Any): The query to compare against the stored edges.
            **kwargs: Implementation-specific retrieval options.

        Returns:
            The edge indices that are to be the seeds for the RAG sampler.
        """
        ...

    @abstractmethod
    def load_subgraph(
        self, sample: Union[SamplerOutput, HeteroSamplerOutput]
    ) -> Union[Data, HeteroData]:
        """Combines sampled subgraph output with features in a Data object.

        Args:
            sample (Union[SamplerOutput, HeteroSamplerOutput]): The output of
                a graph store's ``sample_subgraph`` call.

        Returns:
            A ``Data`` (or ``HeteroData``) object holding the sampled
            structure together with its feature attributes.
        """
        ...


class RAGGraphStore(Protocol):
    """Graph store template for remote GNN RAG backend.

    This is a :class:`typing.Protocol`, so implementations are matched
    structurally and do not need to inherit from it.
    """
    @abstractmethod
    def sample_subgraph(
        self,
        seed_nodes: InputNodes,
        seed_edges: Optional[InputEdges] = None,
        **kwargs,
    ) -> Union[SamplerOutput, HeteroSamplerOutput]:
        """Sample a subgraph using the seeded nodes and, optionally, edges.

        Args:
            seed_nodes (InputNodes): Node indices to seed the sampler with.
            seed_edges (Optional[InputEdges], optional): Edge indices to seed
                the sampler with. ``RAGQueryLoader.query`` calls this method
                with ``seed_nodes`` only, so implementations must accept the
                argument being omitted. Defaults to None.
            **kwargs: Implementation-specific sampling options.

        Returns:
            The sampler output describing the sampled subgraph.
        """
        ...

    @property
    @abstractmethod
    def config(self) -> Dict[str, Any]:
        """Get the config for the RAGGraphStore."""
        ...

    @config.setter
    @abstractmethod
    def config(self, config: Dict[str, Any]):
        """Set the config for the RAGGraphStore."""
        ...

    @abstractmethod
    def register_feature_store(self, feature_store: FeatureStore):
        """Register a feature store to be used with the sampler. Samplers need
        info from the feature store in order to work properly on HeteroGraphs.

        Args:
            feature_store (FeatureStore): The feature store paired with this
                graph store.
        """
        ...


# TODO: Make compatible with Heterographs


class RAGQueryLoader:
    """Loader meant for making RAG queries from a remote backend."""
    def __init__(self, graph_data: Tuple[RAGFeatureStore, RAGGraphStore],
                 subgraph_filter: Optional[Callable[[Data, Any], Data]] = None,
                 augment_query: bool = False,
                 vector_retriever: Optional[VectorRetriever] = None,
                 config: Optional[Dict[str, Any]] = None):
        """Loader meant for making queries from a remote backend.

        Args:
            graph_data (Tuple[RAGFeatureStore, RAGGraphStore]): Remote
                FeatureStore and GraphStore to load from. Assumed to conform
                to the ``RAGFeatureStore`` and ``RAGGraphStore`` protocols.
                The graph store must also expose an ``edge_index`` attribute
                supporting ``.contiguous()``.
            subgraph_filter (Optional[Callable[[Data, Any], Data]], optional):
                Optional local transform to apply to data after retrieval.
                It is called as ``subgraph_filter(data, query)``.
                Defaults to None.
            augment_query (bool, optional): Whether to augment the query with
                retrieved documents. This only has an effect when a
                ``vector_retriever`` is given. Defaults to False.
            vector_retriever (Optional[VectorRetriever], optional):
                VectorRetriever to use for retrieving documents.
                Defaults to None.
            config (Optional[Dict[str, Any]], optional): Config to pass into
                the RAGQueryLoader and on to both stores. When None, the
                stores' existing configs are left untouched.
                Defaults to None.
        """
        fstore, gstore = graph_data
        self.vector_retriever = vector_retriever
        self.augment_query = augment_query
        self.feature_store = fstore
        self.graph_store = gstore
        # Force a contiguous edge_index once, up front. This is presumably
        # for the sampler's benefit (confirm against the graph store).
        self.graph_store.edge_index = self.graph_store.edge_index.contiguous()
        self.graph_store.register_feature_store(self.feature_store)
        self.subgraph_filter = subgraph_filter
        self.config = config

    def _propagate_config(self, config: Dict[str, Any]):
        """Propagate the config to the relevant components."""
        self.feature_store.config = config
        self.graph_store.config = config

    @property
    def config(self):
        """Get the config for the RAGQueryLoader."""
        return self._config

    @config.setter
    def config(self, config: Optional[Dict[str, Any]]):
        """Set the config for the RAGQueryLoader.

        Args:
            config (Optional[Dict[str, Any]]): The config to set. A dict is
                forwarded to both the feature store and the graph store.
                None is stored locally but not forwarded. Both store
                protocols type their ``config`` as a dict, and forwarding
                None would overwrite their own configuration.
        """
        if config is not None:
            self._propagate_config(config)
        self._config = config

    def query(self, query: Any) -> Data:
        """Retrieve a subgraph associated with the query with all its feature
        attributes.

        Args:
            query (Any): The query. If a vector retriever is configured and
                ``augment_query`` is set, the query is replaced by
                ``[query, *retrieved_docs]`` before seed retrieval and
                filtering.

        Returns:
            Data: The retrieved subgraph. When a vector retriever is
            configured, the retrieved documents are attached as
            ``data.text_context``.
        """
        retrieved_docs = None
        if self.vector_retriever is not None:
            retrieved_docs = self.vector_retriever.query(query)
            if self.augment_query:
                # Unpacking accepts any iterable of documents, not only lists.
                query = [query, *retrieved_docs]

        # The feature store returns (seed_nodes, <extra>). Only the seeds are
        # needed to drive the sampler here.
        seed_nodes, _ = self.feature_store.retrieve_seed_nodes(query)
        subgraph_sample = self.graph_store.sample_subgraph(seed_nodes)
        data = self.feature_store.load_subgraph(sample=subgraph_sample)

        # Apply the optional local filter.
        if self.subgraph_filter is not None:
            data = self.subgraph_filter(data, query)
        if self.vector_retriever is not None:
            data.text_context = retrieved_docs
        return data