"""Source code for torch_geometric.datasets.molecule_gpt_dataset."""

import gzip
import json
import multiprocessing
import os
import sys
from collections import defaultdict
from multiprocessing import Pool
from typing import Callable, List, Optional, Tuple

import numpy as np
import requests
import torch
from tqdm import tqdm

from torch_geometric.data import Data, InMemoryDataset, download_url
from torch_geometric.io import fs
from torch_geometric.llm.models import LLM
from torch_geometric.utils import one_hot


def clean_up_description(description: str) -> str:
    """Return the first sentence of ``description`` after normalizing a
    number of known PubChem formatting quirks.

    A trailing space is appended so the sentence scanner below can always
    look one and two characters ahead safely.
    """
    description = description + " "

    # Drop the leading adjective "Pure " (replace-all mirrors the original
    # behavior, even though the rule is triggered by the prefix only).
    if description.startswith("Pure "):
        description = description.replace("Pure ", "")
    # Repair a known typo.
    if description.startswith("Mercurycombines"):
        description = description.replace("Mercurycombines",
                                          "Mercury combines")

    # Known special cases where a period or comma would otherwise cut the
    # first sentence short; rewrite them into "<name> is ..." form.
    special_cases = (
        ("17-Hydroxy-6-methylpregna-3,6-diene-3,20-dione. ",
         "17-Hydroxy-6-methylpregna-3,6-diene-3,20-dione is "),
        ("5-Thymidylic acid. ", "5-Thymidylic acid. is "),
        ("5'-S-(3-Amino-3-carboxypropyl)-5'-thioadenosine. ",
         "5'-S-(3-Amino-3-carboxypropyl)-5'-thioadenosine. is "),
        (("Guanosine 5'-(trihydrogen diphosphate), monoanhydride"
          " with phosphorothioic acid. "),
         ("Guanosine 5'-(trihydrogen diphosphate), monoanhydride"
          " with phosphorothioic acid is ")),
        ("5'-Uridylic acid. ", "5'-Uridylic acid is "),
        ("5'-Adenylic acid, ", "5'-Adenylic acid is "),
        ("Uridine 5'-(tetrahydrogen triphosphate). ",
         "Uridine 5'-(tetrahydrogen triphosphate). is "),
        ("Inosine 5'-Monophosphate. ", "Inosine 5'-Monophosphate. is "),
        ("Pivaloyloxymethyl butyrate (AN-9), ",
         "Pivaloyloxymethyl butyrate (AN-9) is "),
        ("4-Amino-5-cyano-7-(D-ribofuranosyl)-7H- pyrrolo(2,3-d)pyrimidine. ",
         "4-Amino-5-cyano-7-(D-ribofuranosyl)-7H- pyrrolo(2,3-d)pyrimidine is "
         ),
        ("Cardamonin (also known as Dihydroxymethoxychalcone), ",
         "Cardamonin (also known as Dihydroxymethoxychalcone) is "),
        ("Lithium has been used to treat ", "Lithium is "),
        ("4,4'-Methylenebis ", "4,4'-Methylenebis is "),
        ("2,3,7,8-Tetrachlorodibenzo-p-dioxin",
         "2,3,7,8-Tetrachlorodibenzo-p-dioxin is "),
        ("Exposure to 2,4,5-trichlorophenol ",
         "2,4,5-Trichlorophenol exposure "),
    )
    for old, new in special_cases:
        description = description.replace(old, new)

    # Known prefixes that themselves contain '. ' and must be skipped by the
    # sentence scanner; the first match wins.
    skip_prefixes = (
        'C.I. ',
        'Nectriapyrone. D ',
        'Salmonella enterica sv. Minnesota LPS core oligosaccharide',
    )
    start_index = 0
    for prefix in skip_prefixes:
        if description.startswith(prefix):
            start_index = len(prefix)
            break

    # Scan for the end of the first sentence: a period followed by a space
    # and an upper-case letter.  If no boundary is found, fall back to the
    # second-to-last character (mirroring the original loop exactly).
    length = len(description)
    end = 0
    for end in range(start_index, length - 1):
        if end == length - 2:
            break
        if (description[end] == '.' and description[end + 1] == ' '
                and 'A' <= description[end + 2] <= 'Z'):
            break

    return description[:end + 1]


def extract_name(
    name_raw: str,
    description: str,
) -> Tuple[Optional[str], str, str]:
    """Extract the molecule name from a PubChem description.

    The first sentence of ``description`` is split at a linking phrase
    (e.g. ``' is '``); the text before the phrase is taken as the name and
    replaced in the description by a generic subject.

    Args:
        name_raw (str): The raw name reported by PubChem.
            NOTE(review): callers may pass :obj:`None` here; that is only
            safe when a linking phrase matches first — confirm upstream.
        description (str): The PubChem description text.

    Returns:
        Tuple of (extracted name or :obj:`None`, description with the name
        replaced by a generic subject, first sentence after splitting).
    """
    first_sentence = clean_up_description(description)

    splitter = '  --  --  '
    # Choose the generic subject before rewriting, since rewriting removes
    # the ' are '/' were ' markers we test for.
    if ' are ' in first_sentence or ' were ' in first_sentence:
        replaced_words = 'These molecules'
    else:
        replaced_words = 'This molecule'

    # Linking phrases that separate the molecule name from the rest of the
    # first sentence.  Each occurrence is rewritten to the splitter token.
    linking_phrases = (
        ' is ',
        ' are ',
        ' was ',
        ' were ',
        ' appears ',
        ' occurs ',
        ' stands for ',
        ' belongs to ',
        ' exists ',  # only for CID=11443
        ' has been used in trials ',
        ' has been investigated ',
        ' has many uses ',
    )
    for phrase in linking_phrases:
        first_sentence = first_sentence.replace(phrase, splitter)

    if splitter in first_sentence:
        extracted_name = first_sentence.split(splitter, 1)[0]
    elif first_sentence.startswith(name_raw):
        extracted_name = name_raw
    elif name_raw in first_sentence:
        # Name appears mid-sentence without a linking phrase: ambiguous, so
        # give up and log the case for manual inspection.  (The original
        # code contained a dead `extracted_name = name_raw` store here.)
        extracted_name = None
        print("=====", name_raw)
        print("first sentence: ", first_sentence)
    else:
        extracted_name = None

    if extracted_name is not None:
        extracted_description = description.replace(extracted_name,
                                                    replaced_words)
    else:
        extracted_description = description

    return extracted_name, extracted_description, first_sentence


class MoleculeGPTDataset(InMemoryDataset):
    r"""The dataset from the `"MoleculeGPT: Instruction Following Large
    Language Models for Molecular Property Prediction"
    <https://ai4d3.github.io/papers/34.pdf>`_ paper.

    Args:
        root (str): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in
            the final dataset. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
        total_page_num (int, optional): The number of pages from PubChem.
            (default: :obj:`10`)
        total_block_num (int, optional): The blocks of SDF files from PubChem.
            (default: :obj:`1`)
        num_units (int, optional): Number of units of the sample.
            (default: :obj:`-1`, which means all units will be used)
    """
    description_url = (
        'https://pubchem.ncbi.nlm.nih.gov/rest/pug_view/annotations/'
        'heading/json?heading_type=Compound&heading=Record+Description&page={}'
    )
    compound_url = ('https://ftp.ncbi.nlm.nih.gov/pubchem/Compound/'
                    'CURRENT-Full/SDF')

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
        total_page_num: int = 10,
        total_block_num: int = 1,
        num_units: int = -1,
    ):
        self.total_page_num = total_page_num
        self.total_block_num = total_block_num
        self.num_units = num_units
        super().__init__(root, transform, pre_transform, pre_filter,
                         force_reload=force_reload)
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        return ['pubchem.csv']

    @property
    def processed_file_names(self) -> List[str]:
        return ['data.pt']

    def download(self) -> None:
        # Step 01. Extract descriptions from the PubChem PUG View API and
        # build CID -> name / text mappings.
        step1_folder = f"{self.raw_dir}/step_01_PubChemSTM_description"
        if not os.path.exists(step1_folder):
            os.makedirs(step1_folder)
            valid_CID_set = set()
            CID2name_raw, CID2name_extracted = defaultdict(list), defaultdict(
                list)
            CID2text_raw, CID2text_extracted = defaultdict(list), defaultdict(
                list)

            for page_index in tqdm(range(self.total_page_num)):
                page_num = page_index + 1
                f_out = open(
                    f"{step1_folder}/Compound_description_{page_num}.txt",
                    "w")

                description_data = requests.get(
                    self.description_url.format(page_num)).json()

                description_data = description_data["Annotations"]
                assert description_data["Page"] == page_num

                record_list = description_data["Annotation"]

                for record in record_list:
                    try:
                        CID = record["LinkedRecords"]["CID"][0]
                        if "Name" in record:
                            name_raw = record["Name"]
                            CID2name_raw[CID].append(name_raw)
                        else:
                            name_raw = None

                        data_list = record["Data"]
                        for data in data_list:
                            description = data["Value"]["StringWithMarkup"][0][
                                "String"].strip()

                            extracted_name, extracted_description, _ = extract_name(  # noqa: E501
                                name_raw, description)
                            if extracted_name is not None:
                                CID2name_extracted[CID].append(extracted_name)

                            CID2text_raw[CID].append(description)
                            CID2text_extracted[CID].append(
                                extracted_description)

                            valid_CID_set.add(CID)
                            f_out.write(f"{CID}\n")
                            f_out.write(f"{extracted_description}\n\n")
                    except Exception:
                        # Best-effort scraping: skip malformed records.
                        continue

            valid_CID_list = sorted(list(valid_CID_set))
            print(f"Total CID (with raw name) {len(CID2name_raw)}")
            print(f"Total CID (with extracted name) {len(CID2name_extracted)}")
            print(f"Total CID {len(valid_CID_list)}")

            with open(f"{self.raw_dir}/CID2name_raw.json", "w") as f:
                json.dump(CID2name_raw, f)

            with open(f"{self.raw_dir}/CID2name.json", "w") as f:
                json.dump(CID2name_extracted, f)

            with open(f"{self.raw_dir}/CID2text_raw.json", "w") as f:
                json.dump(CID2text_raw, f)

            with open(f"{self.raw_dir}/CID2text.json", "w") as f:
                json.dump(CID2text_extracted, f)

        # Step 02. Download gzipped SDF blocks from the PubChem FTP mirror.
        step2_folder = f"{self.raw_dir}/step_02_PubChemSTM_SDF"
        if not os.path.exists(step2_folder):
            for block_id in tqdm(range(self.total_block_num)):
                block_size = 500000
                l_id = block_id * block_size + 1
                r_id = (block_id + 1) * block_size
                compound_file_name = f"Compound_{l_id:09d}_{r_id:09d}.sdf.gz"
                download_url(f"{self.compound_url}/{compound_file_name}",
                             step2_folder)

    def process(self, use_mp: bool = False) -> None:
        try:
            from rdkit import Chem
            from rdkit.Chem.rdchem import BondType as BT
            WITH_RDKIT = True

        except ImportError:
            WITH_RDKIT = False

        if not WITH_RDKIT:
            # Without RDKit, fall back to a pre-processed dump of the data.
            print(("Using a pre-processed version of the dataset. Please "
                   "install 'rdkit' to alternatively process the raw data."),
                  file=sys.stderr)

            data_list = fs.torch_load(self.raw_paths[0])
            data_list = [Data(**data_dict) for data_dict in data_list]

            if self.pre_filter is not None:
                data_list = [d for d in data_list if self.pre_filter(d)]

            if self.pre_transform is not None:
                data_list = [self.pre_transform(d) for d in data_list]

            self.save(data_list, self.processed_paths[0])
            return

        # Step 03. Filter the downloaded SDF blocks down to CIDs that have a
        # description.
        step2_folder = f"{self.raw_dir}/step_02_PubChemSTM_SDF"
        step3_folder = f"{self.raw_dir}/step_03_PubChemSTM_filtered"
        if not os.path.exists(step3_folder):
            os.makedirs(step3_folder)

            with open(f"{self.raw_dir}/CID2text.json") as f:
                CID2text = json.load(f)
            target_CID_list = set(CID2text.keys())

            block_size = 500000

            def extract_one_SDF_file(block_id: int) -> None:
                # Stream one gzipped SDF block and keep only target CIDs.
                valid_mol_count = 0

                writer = Chem.SDWriter(
                    f'{step3_folder}/filtered_{block_id}.sdf')
                l_id = block_id * block_size + 1
                r_id = (block_id + 1) * block_size
                compound_file_name = f"Compound_{l_id:09d}_{r_id:09d}.sdf.gz"
                gzip_loader = gzip.open(
                    f"{step2_folder}/{compound_file_name}")
                suppl = Chem.ForwardSDMolSupplier(gzip_loader)

                for mol in tqdm(suppl):
                    if mol is None:
                        continue
                    cid = mol.GetProp("PUBCHEM_COMPOUND_CID")

                    if cid not in target_CID_list:
                        continue

                    writer.write(mol)
                    valid_mol_count += 1

                writer.close()
                print(f"block id: {block_id}\nfound {valid_mol_count}\n\n")
                sys.stdout.flush()
                return

            if use_mp:
                num_process = multiprocessing.cpu_count()
                print(f"{num_process} CPUs")
                num_process = 8
                p = Pool(num_process)
                block_id_list = np.arange(self.total_block_num)
                with p:
                    p.map(extract_one_SDF_file, block_id_list)
            else:
                for block_id in range(self.total_block_num):
                    extract_one_SDF_file(block_id)

        # Step 04. Merge the filtered per-block SDF files into one file.
        with open(f"{self.raw_dir}/CID2text.json") as f:
            CID2text = json.load(f)
        target_CID_list = set(CID2text.keys())
        print(f'The length of target_CID_list: {len(target_CID_list)}')

        writer = Chem.SDWriter(f'{self.raw_dir}/molecules.sdf')
        found_CID_set = set()
        for block_id in range(self.total_block_num + 1):
            compound_file_path = f"{step3_folder}/filtered_{block_id}.sdf"
            try:
                suppl = Chem.SDMolSupplier(compound_file_path)

                for mol in tqdm(suppl):
                    writer.write(mol)
                    cid = mol.GetProp("PUBCHEM_COMPOUND_CID")
                    found_CID_set.add(cid)
            except Exception:
                # Missing/empty block files are tolerated.
                print(f"block id: {block_id} with 0 valid SDF file")
                continue

        writer.close()
        print(f"In total: {len(found_CID_set)} molecules")

        # Step 05. Convert molecules to PyG data format, generating one
        # instruction per molecule with a small LLM.
        types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'Unknow': 5}
        bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}

        data_list = []
        # Real data
        CID2text_file = f'{self.raw_dir}/CID2text.json'

        with open(CID2text_file) as f:
            CID2text_data = json.load(f)

        suppl = Chem.SDMolSupplier(f'{self.raw_dir}/molecules.sdf')

        llm = LLM(
            # model_name='lmsys/vicuna-7b-v1.5',
            model_name='TinyLlama/TinyLlama-1.1B-Chat-v0.1',
            num_params=1,
            dtype=torch.bfloat16,
        )
        prompt = ("Propose a question regarding the molecule '∼' "
                  "whose answer is: {}:")
        for mol in tqdm(suppl):
            if mol.HasProp('PUBCHEM_COMPOUND_CID'):
                CID = mol.GetProp("PUBCHEM_COMPOUND_CID")
                CAN_SMILES = mol.GetProp("PUBCHEM_SMILES")

                m: Chem.Mol = Chem.MolFromSmiles(CAN_SMILES)
                if m is None:
                    continue
                RDKit_CAN_SMILES = Chem.MolToSmiles(m)

                ground_truth = CID2text_data[CID][0]

                instruction = llm.inference([prompt.format(ground_truth)])[0]

                # Atom features: one-hot over the known element types,
                # unknown symbols map to index 5 ('Unknow').
                x: torch.Tensor = torch.tensor([
                    types[atom.GetSymbol()]
                    if atom.GetSymbol() in types else 5
                    for atom in m.GetAtoms()
                ])
                x = one_hot(x, num_classes=len(types), dtype=torch.float)

                # Undirected edges: add both directions per RDKit bond.
                rows, cols, edge_types = [], [], []
                for bond in m.GetBonds():
                    i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                    edge_types += [bonds[bond.GetBondType()]] * 2
                    rows += [i, j]
                    cols += [j, i]

                edge_index = torch.tensor([rows, cols], dtype=torch.long)
                edge_type = torch.tensor(edge_types, dtype=torch.long)
                edge_attr = one_hot(edge_type, num_classes=len(bonds))

                data = Data(
                    x=x,
                    edge_index=edge_index,
                    edge_attr=edge_attr,
                    smiles=RDKit_CAN_SMILES,
                    instruction=instruction,
                    y=ground_truth,
                )

                if self.pre_filter is not None and not self.pre_filter(data):
                    continue
                if self.pre_transform is not None:
                    data = self.pre_transform(data)

                data_list.append(data)
                # Stop early once the requested number of units is reached.
                if self.num_units > 0 and len(data_list) >= self.num_units:
                    break

        self.save(data_list, self.processed_paths[0])