torch_geometric.llm.utils.DocumentRetriever

class DocumentRetriever(raw_docs: List[str], embedded_docs: Optional[Tensor] = None, k_for_docs: int = 2, model: Optional[Union[SentenceTransformer, Module, Callable]] = None, model_kwargs: Optional[Dict[str, Any]] = None)

Bases: VectorRetriever

Retrieve documents from a vector database.
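The constructor takes either precomputed document embeddings (embedded_docs) or a model that can produce them. The sketch below assumes, as the two optional parameters suggest, that the model encodes raw_docs only when embedded_docs is not supplied. ToyDocumentStore and toy_encoder are hypothetical stand-ins written in plain Python, not PyG's implementation.

```python
from typing import Callable, List, Optional, Sequence

Vector = List[float]
Encoder = Callable[[Sequence[str]], List[Vector]]


def toy_encoder(texts: Sequence[str]) -> List[Vector]:
    """Hypothetical encoder: (character count, space count) per text."""
    return [[float(len(t)), float(t.count(" "))] for t in texts]


class ToyDocumentStore:
    """Hypothetical mirror of the constructor arguments, not PyG's implementation."""

    def __init__(self, raw_docs: List[str],
                 embedded_docs: Optional[List[Vector]] = None,
                 k_for_docs: int = 2,
                 model: Optional[Encoder] = None) -> None:
        if embedded_docs is None:
            # Assumed behaviour: without precomputed embeddings, the model encodes raw_docs.
            if model is None:
                raise ValueError("model is required when embedded_docs is None")
            embedded_docs = model(raw_docs)
        if len(embedded_docs) != len(raw_docs):
            raise ValueError("expected one embedding per document")
        self.raw_docs = raw_docs
        self.embedded_docs = embedded_docs
        self.k_for_docs = k_for_docs  # number of documents a query would return
        self.model = model


docs = ["graph neural networks", "banana bread recipe"]
encoded = ToyDocumentStore(docs, model=toy_encoder)  # embeddings computed by the model
reused = ToyDocumentStore(docs, embedded_docs=encoded.embedded_docs)  # no model needed
```

Passing precomputed embeddings avoids re-encoding a large corpus every time a retriever is built.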

query(query: Union[str, Tensor]) → List[str]

Retrieve documents from the vector database.

Parameters:

query (Union[str, Tensor]) – The query to retrieve documents for.

Returns:

Documents retrieved from the vector database.

Return type:

List[str]
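A query accepts either a string or a Tensor. A minimal sketch of what such a lookup can look like, assuming that a string is first encoded by the model while a Tensor is used directly as the query embedding, and that the k_for_docs best-scoring documents are returned. The similarity measure PyG actually uses is not stated in this reference. The sketch uses dot-product scores over unit-length vectors (i.e. cosine similarity), and toy_encoder and toy_query are hypothetical plain-Python stand-ins.

```python
import math
from typing import Callable, List, Sequence, Union

Vector = List[float]


def toy_encoder(texts: Sequence[str]) -> List[Vector]:
    """Hypothetical stand-in for a real embedding model: unit-length letter counts."""
    vectors = []
    for text in texts:
        counts = [0.0] * 26
        for ch in text.lower():
            if "a" <= ch <= "z":
                counts[ord(ch) - ord("a")] += 1.0
        norm = math.sqrt(sum(c * c for c in counts)) or 1.0  # guard against empty text
        vectors.append([c / norm for c in counts])
    return vectors


def toy_query(query: Union[str, Vector], raw_docs: List[str],
              embedded_docs: List[Vector],
              encoder: Callable[[Sequence[str]], List[Vector]],
              k_for_docs: int = 2) -> List[str]:
    """Return the k_for_docs documents whose embeddings score highest against the query."""
    # A string is embedded first; anything else is treated as a ready-made embedding.
    query_vec = encoder([query])[0] if isinstance(query, str) else query
    # Dot-product scores; equal to cosine similarity when both vectors are unit-length.
    scores = [sum(q * d for q, d in zip(query_vec, doc)) for doc in embedded_docs]
    ranked = sorted(range(len(raw_docs)), key=lambda i: scores[i], reverse=True)
    return [raw_docs[i] for i in ranked[:k_for_docs]]


docs = ["graph neural networks", "banana bread recipe", "graph attention layers"]
embedded = toy_encoder(docs)
top_two = toy_query("graph", docs, embedded, toy_encoder)
```

Passing an already-encoded query vector gives the same result as passing the string, which is useful when the query embedding is computed elsewhere in a pipeline.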

save(path: str) → None

Save the DocumentRetriever instance to disk.

Parameters:

path (str) – Path at which to save the retriever.

Return type:

None

classmethod load(path: str, model: Union[SentenceTransformer, Module, Callable], model_kwargs: Optional[Dict[str, Any]] = None) → VectorRetriever

Load a DocumentRetriever instance from disk.

Parameters:
  • path (str) – Path to the saved retriever.

  • model (Union[SentenceTransformer, Module, Callable]) – Model to use for encoding. If None is passed, the saved model is used, if available.

  • model_kwargs (Optional[Dict[str, Any]], default: None) – Keyword arguments to pass to the model.

Returns:

The loaded retriever.

Return type:

DocumentRetriever
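Because load takes the model as an argument, a save/load round trip can persist the documents and their embeddings while the encoder is supplied again when loading. The sketch below illustrates that pattern with JSON. ToyRetriever and toy_encoder are hypothetical, and the file format is an assumption: this reference does not specify how PyG serializes a DocumentRetriever or whether it stores the model.

```python
import json
import os
import tempfile
from typing import Any, Callable, Dict, List, Optional, Sequence

Vector = List[float]
Encoder = Callable[[Sequence[str]], List[Vector]]


def toy_encoder(texts: Sequence[str]) -> List[Vector]:
    """Hypothetical one-dimensional encoder: the length of each text."""
    return [[float(len(t))] for t in texts]


class ToyRetriever:
    """Hypothetical retriever illustrating a save/load round trip, not PyG's file format."""

    def __init__(self, raw_docs: List[str], embedded_docs: List[Vector],
                 k_for_docs: int = 2, model: Optional[Encoder] = None,
                 model_kwargs: Optional[Dict[str, Any]] = None) -> None:
        self.raw_docs = raw_docs
        self.embedded_docs = embedded_docs
        self.k_for_docs = k_for_docs
        self.model = model
        self.model_kwargs = model_kwargs or {}

    def save(self, path: str) -> None:
        # Only plain data is written; the encoder is supplied again at load time.
        state = {"raw_docs": self.raw_docs,
                 "embedded_docs": self.embedded_docs,
                 "k_for_docs": self.k_for_docs}
        with open(path, "w", encoding="utf-8") as f:
            json.dump(state, f)

    @classmethod
    def load(cls, path: str, model: Encoder,
             model_kwargs: Optional[Dict[str, Any]] = None) -> "ToyRetriever":
        with open(path, encoding="utf-8") as f:
            state = json.load(f)
        return cls(state["raw_docs"], state["embedded_docs"], state["k_for_docs"],
                   model=model, model_kwargs=model_kwargs)


docs = ["short", "a longer document"]
original = ToyRetriever(docs, toy_encoder(docs), k_for_docs=1, model=toy_encoder)
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "retriever.json")
    original.save(path)
    restored = ToyRetriever.load(path, model=toy_encoder)
```

Keeping the model out of the saved file keeps the file small and lets a caller swap in a compatible encoder, at the cost of having to supply one on every load.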