Source code for torch_geometric.llm.utils.vectorrag

# mypy: ignore-errors
import os
from abc import abstractmethod
from typing import Any, Callable, Dict, List, Optional, Protocol, Union

import torch
from torch import Tensor

from torch_geometric.data import Data
from torch_geometric.llm.models import SentenceTransformer
from torch_geometric.llm.utils.backend_utils import batch_knn


class VectorRetriever(Protocol):
    """Structural interface for VectorRAG retrievers.

    A retriever exposes a single :meth:`query` method that maps a query to
    the context retrieved for it.
    """
    @abstractmethod
    def query(self, query: Any, **kwargs: Optional[Dict[str, Any]]) -> Data:
        """Retrieve a context for a given query.

        Args:
            query: The query to retrieve a context for.
            **kwargs: Additional keyword arguments; their meaning is left to
                the implementing class.

        Returns:
            The retrieved context. This placeholder has no body beyond its
            docstring and therefore returns :obj:`None` if called directly.
        """


class DocumentRetriever(VectorRetriever):
    """Retrieve documents from a vector database."""
    def __init__(self, raw_docs: List[str],
                 embedded_docs: Optional[Tensor] = None, k_for_docs: int = 2,
                 model: Optional[Union[SentenceTransformer, torch.nn.Module,
                                       Callable]] = None,
                 model_kwargs: Optional[Dict[str, Any]] = None):
        """Retrieve documents from a vector database.

        Args:
            raw_docs: List[str]: List of raw documents.
            embedded_docs: Optional[Tensor]: Embedded documents. If
                :obj:`None`, they are computed by calling :obj:`model` on
                :obj:`raw_docs`.
            k_for_docs: int: Number of documents to retrieve.
            model: Optional[Union[SentenceTransformer, torch.nn.Module]]:
                Model to use for encoding. Required if
                :obj:`embedded_docs` is not provided or if string queries
                are going to be issued.
            model_kwargs: Optional[Dict[str, Any]]:
                Keyword arguments to pass to the model. The dictionary is
                copied, so the caller's object is never modified.
        """
        self.raw_docs = raw_docs
        self.embedded_docs = embedded_docs
        self.k_for_docs = k_for_docs
        self.model = model

        if self.model is not None:
            self.encoder = self.model

        # Work on a private copy: the "verbose" entry may be removed below
        # and that must not leak into the caller's dictionary. Falling back
        # to an empty dict guarantees `**self.model_kwargs` is always valid.
        encoder_kwargs = dict(model_kwargs) if model_kwargs else {}

        if self.embedded_docs is None:
            assert self.model is not None, \
                "Model must be provided if embedded_docs is not provided"
            self.embedded_docs = self.encoder(self.raw_docs,
                                              **encoder_kwargs)
            # we don't want to print the verbose output in `query`
            encoder_kwargs.pop("verbose", None)

        self.model_kwargs = encoder_kwargs

    def query(self, query: Union[str, Tensor]) -> List[str]:
        """Retrieve documents from the vector database.

        Args:
            query: Union[str, Tensor]: Query to retrieve documents for. A
                string is encoded with the model first; any other value is
                used as the query encoding as-is.

        Returns:
            List[str]: Documents retrieved from the vector database.

        Raises:
            AttributeError: If :obj:`query` is a string but no model was
                provided to encode it.
        """
        if isinstance(query, str):
            encoder = getattr(self, "encoder", None)
            if encoder is None:
                raise AttributeError(
                    "Cannot encode a string query because no model was "
                    "provided to this DocumentRetriever")
            query_enc = encoder(query, **self.model_kwargs)
        else:
            query_enc = query

        # Only the first result produced by `batch_knn` is consumed; its
        # first element is used as the indices into `raw_docs`.
        selected_doc_idxs, _ = next(
            batch_knn(query_enc, self.embedded_docs, self.k_for_docs))
        return [self.raw_docs[i] for i in selected_doc_idxs]

    def save(self, path: str) -> None:
        """Save the DocumentRetriever instance to disk.

        Only the raw documents, their embeddings and :obj:`k_for_docs` are
        stored; the model is not serialized.

        Args:
            path: str: Path where to save the retriever. Missing parent
                directories are created.
        """
        os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)

        # Prepare data to save
        save_dict = {
            'raw_docs': self.raw_docs,
            'embedded_docs': self.embedded_docs,
            'k_for_docs': self.k_for_docs,
        }

        # We do not serialize the model
        torch.save(save_dict, path)

    @classmethod
    def load(cls, path: str, model: Union[SentenceTransformer,
                                          torch.nn.Module, Callable],
             model_kwargs: Optional[Dict[str, Any]] = None) -> VectorRetriever:
        """Load a DocumentRetriever instance from disk.

        Args:
            path: str: Path to the saved retriever.
            model: Union[SentenceTransformer, torch.nn.Module, Callable]:
                Model to use for encoding. The model is not stored by
                :meth:`save`, so it has to be supplied here.
            model_kwargs: Optional[Dict[str, Any]]
                Key word args to be passed to model. The dictionary is not
                modified.

        Returns:
            DocumentRetriever: The loaded retriever.

        Raises:
            FileNotFoundError: If :obj:`path` does not exist.
        """
        if not os.path.exists(path):
            raise FileNotFoundError(
                f"No saved document retriever found at {path}")

        # NOTE(security): `weights_only=False` unpickles arbitrary Python
        # objects and can execute code embedded in the file. Only load
        # files that come from a trusted source.
        save_dict = torch.load(path, weights_only=False)

        # With tensor embeddings already available, drop "verbose" from the
        # keyword arguments. Filter into a new dict rather than popping so
        # the caller's dictionary is left intact.
        if model_kwargs is not None \
                and isinstance(save_dict['embedded_docs'], Tensor):
            model_kwargs = {
                key: value
                for key, value in model_kwargs.items() if key != "verbose"
            }

        # Create a new DocumentRetriever with the loaded data
        return cls(raw_docs=save_dict['raw_docs'],
                   embedded_docs=save_dict['embedded_docs'],
                   k_for_docs=save_dict['k_for_docs'], model=model,
                   model_kwargs=model_kwargs)