Source code for torch_geometric.datasets.instruct_mol_dataset

import json
import sys
from typing import Callable, List, Optional

import torch
from tqdm import tqdm

from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.io import fs
from torch_geometric.utils import one_hot


class InstructMolDataset(InMemoryDataset):
    r"""The dataset from the `"InstructMol: Multi-Modal Integration for
    Building a Versatile and Reliable Molecular Assistant in Drug Discovery"
    <https://arxiv.org/pdf/2311.16208>`_ paper.

    The raw file maps SMILES strings to lists of ``(question, answer)`` pairs.
    When :obj:`rdkit` is available, every parsable molecule is converted into
    a graph and one :class:`~torch_geometric.data.Data` object is emitted per
    pair, holding:

    * :obj:`x`: one-hot atom type over ``H, C, N, O, F`` plus a catch-all
      class for every other element,
    * :obj:`edge_index`: both directions of every bond,
    * :obj:`edge_attr`: one-hot bond type over single, double, triple and
      aromatic bonds,
    * :obj:`smiles`: the SMILES string of the molecule,
    * :obj:`instruction`: the question,
    * :obj:`y`: the answer.

    Args:
        root (str): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in
            the final dataset. (default: :obj:`None`)
        force_reload (bool, optional): Whether to re-process the dataset.
            (default: :obj:`False`)
    """
    raw_url = 'https://huggingface.co/datasets/OpenMol/PubChemSFT/resolve/main'

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
        force_reload: bool = False,
    ):
        super().__init__(root, transform, pre_transform, pre_filter,
                         force_reload=force_reload)
        self.load(self.processed_paths[0])

    @property
    def raw_file_names(self) -> List[str]:
        return ['all_clean.json']

    @property
    def processed_file_names(self) -> List[str]:
        return ['data.pt']

    def download(self) -> None:
        """Copy the raw JSON file from :attr:`raw_url` into ``raw_dir``."""
        print('downloading dataset...')
        fs.cp(f'{self.raw_url}/all_clean.json', self.raw_dir)

    def process(self) -> None:
        """Build the list of samples and save it to ``processed_paths[0]``.

        Molecules that :obj:`rdkit` cannot parse are skipped. ``pre_filter``
        is applied before ``pre_transform`` to each sample.
        """
        # rdkit is an optional dependency, so it is imported lazily here.
        try:
            from rdkit import Chem
            from rdkit.Chem.rdchem import BondType as BT
            WITH_RDKIT = True
        except ImportError:
            WITH_RDKIT = False

        if not WITH_RDKIT:
            print(("Using a pre-processed version of the dataset. Please "
                   "install 'rdkit' to alternatively process the raw data."),
                  file=sys.stderr)

            # NOTE(review): the only raw file is 'all_clean.json', yet it is
            # read with `fs.torch_load` here. Confirm that a torch-serialized
            # list of dicts is really what is expected on this path; the
            # behavior is left unchanged.
            data_list = fs.torch_load(self.raw_paths[0])
            data_list = [Data(**data_dict) for data_dict in data_list]

            if self.pre_filter is not None:
                data_list = [d for d in data_list if self.pre_filter(d)]

            if self.pre_transform is not None:
                data_list = [self.pre_transform(d) for d in data_list]

            self.save(data_list, self.processed_paths[0])
            return

        # types of atom and bond
        types = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4, 'Unknow': 5}
        bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3}
        # Elements outside the table share the catch-all class; look the index
        # up instead of hard-coding it so it cannot drift from `types`.
        unknown_type = types['Unknow']

        # load data; the context manager closes the file handle once parsed
        with open(f'{self.raw_dir}/all_clean.json', encoding='utf-8') as f:
            mols = json.load(f)

        data_list = []
        for smiles, qa_pairs in tqdm(mols.items(), total=len(mols)):
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:  # SMILES could not be parsed
                continue

            x: torch.Tensor = torch.tensor([
                types.get(atom.GetSymbol(), unknown_type)
                for atom in mol.GetAtoms()
            ])
            x = one_hot(x, num_classes=len(types), dtype=torch.float)

            rows, cols, edge_types = [], [], []
            for bond in mol.GetBonds():
                i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                # NOTE(review): a bond type outside `bonds` raises `KeyError`
                # here; confirm the raw data never contains one.
                edge_types += [bonds[bond.GetBondType()]] * 2
                # Store each bond in both directions.
                rows += [i, j]
                cols += [j, i]

            edge_index = torch.tensor([rows, cols], dtype=torch.long)
            edge_type = torch.tensor(edge_types, dtype=torch.long)
            edge_attr = one_hot(edge_type, num_classes=len(bonds))

            # One sample per (question, answer) pair; the graph tensors are
            # shared between all samples of the same molecule.
            for question, answer in qa_pairs:
                data = Data(
                    x=x,
                    edge_index=edge_index,
                    edge_attr=edge_attr,
                    smiles=smiles,
                    instruction=question,
                    y=answer,
                )

                if self.pre_filter is not None and not self.pre_filter(data):
                    continue
                if self.pre_transform is not None:
                    data = self.pre_transform(data)

                data_list.append(data)

        self.save(data_list, self.processed_paths[0])