Source code for torch_geometric.llm.models.txt2kg

import os
import time
from typing import List, Optional, Tuple

import torch
import torch.multiprocessing as mp

CLIENT_INITD = False

CLIENT = None
GLOBAL_NIM_KEY = ""
SYSTEM_PROMPT = "Please convert the above text into a list of knowledge triples with the form ('entity', 'relation', 'entity'). Separate each with a new line. Do not output anything else. Try to focus on key triples that form a connected graph."  # noqa
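
The prompt above asks the model for one quoted triple per line and nothing else. For orientation, an illustrative hand-written response in the expected format (not output from any particular model) looks like:

    ('climate change', 'is driven by', 'greenhouse gas emissions')
    ('greenhouse gas emissions', 'come from', 'burning fossil fuels')

`_parse_n_check_triples()` further down in this module converts such a string into a list of Python tuples.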


class TXT2KG():
    """A class to convert text data into a Knowledge Graph (KG) format.
    Uses NVIDIA NIMs + prompt engineering by default.
    The default model, `nvidia/llama-3.1-nemotron-70b-instruct`, is on par
    with or better than GPT4o in benchmarks. A high quality model is needed
    to ensure a high quality KG; otherwise it is garbage in, garbage out
    for the rest of the GNN+LLM RAG pipeline.

    Use the local_LM flag for local debugging/dev. You still need to be
    able to run inference on a 14B param LLM,
    'VAGOsolutions/SauerkrautLM-v2-14b-DPO'. Smaller LLMs did not work at
    all in testing. Note that this 14B model requires a considerable
    amount of GPU memory. See examples/llm/txt2kg_rag.py for an example.

    Args:
        NVIDIA_NIM_MODEL (str, optional): The name of the NVIDIA NIM model
            to use.
            (default: "nvidia/llama-3.1-nemotron-70b-instruct").
        NVIDIA_API_KEY (str, optional): The API key for accessing NVIDIA's
            NIM models. (default: "").
        ENDPOINT_URL (str, optional): The URL hosting your model, in case
            you are not using the public NIM.
            (default: "https://integrate.api.nvidia.com/v1").
        local_LM (bool, optional): A flag indicating whether a local
            Language Model (LM) should be used. This uses HuggingFace and
            will be slower than deploying your own private NIM endpoint.
            This flag is mainly recommended for dev/debug.
            (default: False).
        chunk_size (int, optional): The size of the chunks in which the
            text data is processed. (default: 512).
    """
    def __init__(
        self,
        NVIDIA_NIM_MODEL: Optional[
            str] = "nvidia/llama-3.1-nemotron-70b-instruct",
        NVIDIA_API_KEY: Optional[str] = "",
        ENDPOINT_URL: Optional[str] = "https://integrate.api.nvidia.com/v1",
        local_LM: bool = False,
        chunk_size: int = 512,
    ) -> None:
        self.local_LM = local_LM
        # Initialize the local LM flag and the NIM model info accordingly
        if self.local_LM:
            # If using a local LM, defer its initialization until first use
            self.initd_LM = False
        else:
            # If not using a local LM, store the provided NIM model info
            self.NVIDIA_API_KEY = NVIDIA_API_KEY
            self.NIM_MODEL = NVIDIA_NIM_MODEL
            self.ENDPOINT_URL = ENDPOINT_URL

        # Set the chunk size for processing text data
        self.chunk_size = chunk_size

        # Initialize counters and storage for parsing results
        self.doc_id_counter = 0
        self.relevant_triples = {}
        self.total_chars_parsed = 0
        self.time_to_parse = 0.0
    def save_kg(self, path: str) -> None:
        """Saves the relevant triples in the knowledge graph (KG) to a file.

        Args:
            path (str): The file path where the KG will be saved.

        Returns:
            None
        """
        torch.save(self.relevant_triples, path)
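    # Note (illustrative, not part of the API): the file written by
    # `save_kg()` is a plain `torch.save` of `self.relevant_triples`, i.e.
    # a dict mapping each doc id (or QA pair) to its list of
    # (entity, relation, entity) triples, so it can be reloaded later with
    # `torch.load(path)`.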
    def _chunk_to_triples_str_local(self, txt: str) -> str:
        # call LLM on text
        chunk_start_time = time.time()
        if not self.initd_LM:
            from torch_geometric.nn.nlp import LLM
            LM_name = "VAGOsolutions/SauerkrautLM-v2-14b-DPO"
            self.model = LLM(LM_name).eval()
            self.initd_LM = True
        out_str = self.model.inference(question=[txt + '\n' + SYSTEM_PROMPT],
                                       max_tokens=self.chunk_size)[0]
        # for debug
        self.total_chars_parsed += len(txt)
        self.time_to_parse += round(time.time() - chunk_start_time, 2)
        self.avg_chars_parsed_per_sec = \
            self.total_chars_parsed / self.time_to_parse
        return out_str
    def add_doc_2_KG(
        self,
        txt: str,
        QA_pair: Optional[Tuple[str, str]] = None,
    ) -> None:
        """Add a document to the Knowledge Graph (KG).

        Args:
            txt (str): The text to extract triples from.
            QA_pair (Tuple[str, str], optional): A QA pair to associate
                with the extracted triples. Useful for downstream
                evaluation.

        Returns:
            None
        """
        if not self.local_LM:
            # Ensure NVIDIA_API_KEY is set before proceeding
            assert self.NVIDIA_API_KEY != '', \
                "Please init TXT2KG w/ NVIDIA_API_KEY or set local_LM=True"
        if QA_pair:
            # QA_pairs should be unique keys, check if already exists in KG
            if QA_pair in self.relevant_triples.keys():
                print("Warning: QA_Pair was already added to the set")
                print("Q=", QA_pair[0])
                print("A=", QA_pair[1])
                print("Previously parsed triples=",
                      self.relevant_triples[QA_pair])
                print("Skipping...")
            key = QA_pair
        else:
            # If no QA_pair, use the current doc_id_counter as the key
            key = self.doc_id_counter

        # Handle empty text (context-less QA pairs)
        if txt == "":
            self.relevant_triples[key] = []
        else:
            # Chunk the text into smaller pieces for processing
            chunks = _chunk_text(txt, chunk_size=self.chunk_size)

            if self.local_LM:
                # For debugging purposes...
                # process chunks sequentially on the local LM
                self.relevant_triples[key] = _llm_then_python_parse(
                    chunks, _parse_n_check_triples,
                    self._chunk_to_triples_str_local)
            else:
                # Process chunks in parallel using multiple processes
                num_procs = min(len(chunks), _get_num_procs())
                meta_chunk_size = int(len(chunks) / num_procs)
                in_chunks_per_proc = {
                    j:
                    chunks[j * meta_chunk_size:min((j + 1) * meta_chunk_size,
                                                   len(chunks))]
                    for j in range(num_procs)
                }
                for _retry_j in range(5):
                    try:
                        for _retry_i in range(200):
                            try:
                                # Spawn multiple processes
                                # to process chunks in parallel
                                mp.spawn(
                                    _multiproc_helper,
                                    args=(in_chunks_per_proc,
                                          _parse_n_check_triples,
                                          _chunk_to_triples_str_cloud,
                                          self.NVIDIA_API_KEY,
                                          self.NIM_MODEL, self.ENDPOINT_URL),
                                    nprocs=num_procs)
                                break
                            except:  # noqa
                                # keep retrying...
                                # txt2kg is costly -> stoppage is costly
                                pass

                        # Collect the results from each process
                        self.relevant_triples[key] = []
                        for rank in range(num_procs):
                            self.relevant_triples[key] += torch.load(
                                "/tmp/outs_for_proc_" + str(rank))
                            os.remove("/tmp/outs_for_proc_" + str(rank))
                        break
                    except:  # noqa
                        pass
        # Increment the doc_id_counter for the next document
        self.doc_id_counter += 1
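
A minimal usage sketch of the class above, assuming a PyG build that provides this module and an NVIDIA NIM API key. The names `builder`, `doc`, and the environment variable `NIM_API_KEY` are illustrative choices, not part of the API; because the non-local path spawns worker processes via `torch.multiprocessing`, run it from a script under an `if __name__ == "__main__":` guard:

    import os

    from torch_geometric.llm.models.txt2kg import TXT2KG

    if __name__ == "__main__":
        builder = TXT2KG(NVIDIA_API_KEY=os.environ.get("NIM_API_KEY", ""))
        # or, for local dev/debug (needs GPU memory for a 14B model):
        # builder = TXT2KG(local_LM=True)

        doc = ("Marie Curie won the Nobel Prize in Physics. "
               "She married Pierre Curie.")
        builder.add_doc_2_KG(doc)

        # Optionally key the extracted triples by a QA pair:
        builder.add_doc_2_KG(
            "Pierre Curie shared the 1903 Nobel Prize in Physics.",
            QA_pair=("Who shared the 1903 Nobel Prize in Physics?",
                     "Pierre Curie"),
        )

        builder.save_kg("kg.pt")
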
known_reasoners = [
    "llama-3.1-nemotron-ultra-253b-v1",
    "kimi-k2-instruct",
    "nemotron-super-49b-v1_5",
    "gpt-oss",
]


def _chunk_to_triples_str_cloud(
        txt: str, GLOBAL_NIM_KEY='',
        NIM_MODEL="nvidia/llama-3.1-nemotron-ultra-253b-v1",
        ENDPOINT_URL="https://integrate.api.nvidia.com/v1",
        post_text=SYSTEM_PROMPT) -> str:
    global CLIENT_INITD
    if not CLIENT_INITD:
        # We use NIMs since most PyG users may not be able to run a 70B+
        # model locally
        try:
            from openai import OpenAI
        except ImportError:
            quit("Failed to import `openai` package, "
                 "please install it and rerun the script")
        global CLIENT
        CLIENT = OpenAI(base_url=ENDPOINT_URL, api_key=GLOBAL_NIM_KEY)
        CLIENT_INITD = True
    txt_input = txt
    if post_text != "":
        txt_input += '\n' + post_text
    messages = []
    if any(model_name_str in NIM_MODEL
           for model_name_str in known_reasoners):
        messages.append({"role": "system", "content": "detailed thinking on"})
    messages.append({"role": "user", "content": txt_input})
    completion = CLIENT.chat.completions.create(model=NIM_MODEL,
                                                messages=messages,
                                                temperature=0, top_p=1,
                                                max_tokens=1024, stream=True)
    out_str = ""
    for chunk in completion:
        if chunk.choices[0].delta.content is not None:
            out_str += chunk.choices[0].delta.content
    return out_str


def _parse_n_check_triples(triples_str: str) -> List[Tuple[str, str, str]]:
    # use pythonic checks for triples
    processed = []
    split_by_newline = triples_str.split("\n")
    # sometimes LLM fails to obey the prompt
    if len(split_by_newline) > 1:
        split_triples = split_by_newline
        llm_obeyed = True
    else:
        # handles form "(e, r, e) (e, r, e) ... (e, r, e)"
        split_triples = triples_str[1:-1].split(") (")
        llm_obeyed = False
    for triple_str in split_triples:
        try:
            if llm_obeyed:
                # remove parentheses and single quotes for parsing
                triple_str = triple_str.replace("(", "").replace(
                    ")", "").replace("'", "")
            split_trip = triple_str.split(',')
            # remove blank space at beginning or end
            split_trip = [(i[1:] if i[0] == " " else i) for i in split_trip]
            split_trip = [(i[:-1].lower() if i[-1] == " " else i)
                          for i in split_trip]
            potential_trip = tuple(split_trip)
        except:  # noqa
            continue
        if isinstance(potential_trip, tuple) and len(
                potential_trip
        ) == 3 and "note:" not in potential_trip[0].lower():
            # additional check for empty node/edge attrs
            if potential_trip[0] != '' and potential_trip[
                    1] != '' and potential_trip[2] != '':
                processed.append(potential_trip)
    return processed


def _llm_then_python_parse(chunks, py_fn, llm_fn, **kwargs):
    relevant_triples = []
    for chunk in chunks:
        relevant_triples += py_fn(llm_fn(chunk, **kwargs))
    return relevant_triples


def _multiproc_helper(rank, in_chunks_per_proc, py_fn, llm_fn, NIM_KEY,
                      NIM_MODEL, ENDPOINT_URL):
    out = _llm_then_python_parse(in_chunks_per_proc[rank], py_fn, llm_fn,
                                 GLOBAL_NIM_KEY=NIM_KEY, NIM_MODEL=NIM_MODEL,
                                 ENDPOINT_URL=ENDPOINT_URL)
    torch.save(out, "/tmp/outs_for_proc_" + str(rank))


def _get_num_procs():
    # Default to None so we can fall back to os.cpu_count() below
    num_proc = None
    if hasattr(os, "sched_getaffinity"):
        try:
            num_proc = len(os.sched_getaffinity(0)) / (2)
        except Exception:
            pass
    if num_proc is None:
        num_proc = os.cpu_count() / (2)
    return int(num_proc)


def _chunk_text(text: str, chunk_size: int = 512) -> List[str]:
    """Function to chunk text into sentence-based segments.
    Co-authored with Claude AI.
    """
    # If the input text is empty or None, return an empty list
    if not text:
        return []

    # List of punctuation marks that typically end sentences
    sentence_endings = '.!?'
    # List to store the resulting chunks
    chunks = []

    # Continue processing the entire text
    while text:
        # If the remaining text is shorter than chunk_size, add it and break
        if len(text) <= chunk_size:
            chunks.append(text.strip())
            break

        # Start with the maximum possible chunk
        chunk = text[:chunk_size]

        # Try to find the last sentence ending within the chunk
        best_split = chunk_size
        for ending in sentence_endings:
            # Find the last occurrence of the ending punctuation
            last_ending = chunk.rfind(ending)
            if last_ending != -1:
                # Ensure we include the punctuation and any following space
                best_split = min(
                    best_split, last_ending + 1 +
                    (1 if last_ending + 1 < len(chunk)
                     and chunk[last_ending + 1].isspace() else 0))

        # Adjust to ensure we don't break words
        # If the next character is a letter, find the last space
        if best_split < len(text) and text[best_split].isalpha():
            # Find the last space before the current split point
            space_split = text[:best_split].rfind(' ')
            if space_split != -1:
                best_split = space_split

        # Append the chunk, ensuring it's stripped
        chunks.append(text[:best_split].strip())

        # Remove the processed part from the text
        text = text[best_split:].lstrip()

    return chunks
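
The helpers above are module-internal, but a small self-contained sketch makes the chunking and parsing behavior concrete. The input strings below are made up for illustration; with a deliberately small chunk_size the chunker falls back to sentence boundaries, and the parser turns a newline-separated triple string into Python tuples:

    # Illustrative only: exercise the sentence-aware chunker directly.
    sample = ("PyG is a library for graph learning. "
              "It builds on PyTorch. TXT2KG extracts triples.")
    print(_chunk_text(sample, chunk_size=40))
    # ['PyG is a library for graph learning.',
    #  'It builds on PyTorch.',
    #  'TXT2KG extracts triples.']

    # Illustrative only: parse a hand-written response in the prompt format.
    response = ("('Marie Curie', 'won', 'Nobel Prize in Physics')\n"
                "('Marie Curie', 'married', 'Pierre Curie')")
    print(_parse_n_check_triples(response))
    # [('Marie Curie', 'won', 'Nobel Prize in Physics'),
    #  ('Marie Curie', 'married', 'Pierre Curie')]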