Source code for torch_geometric.datasets.git_mol_dataset

import sys
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import torch
from tqdm import tqdm

from torch_geometric.data import (
    Data,
    InMemoryDataset,
    download_google_url,
    extract_zip,
)
from torch_geometric.io import fs


def safe_index(lst: List[Any], e: int) -> int:
    return lst.index(e) if e in lst else len(lst) - 1


[docs]class GitMolDataset(InMemoryDataset): r"""The dataset from the `"GIT-Mol: A Multi-modal Large Language Model for Molecular Science with Graph, Image, and Text" <https://arxiv.org/pdf/2308.06911>`_ paper. Args: root (str): Root directory where the dataset should be saved. transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before every access. (default: :obj:`None`) pre_transform (callable, optional): A function/transform that takes in an :obj:`torch_geometric.data.Data` object and returns a transformed version. The data object will be transformed before being saved to disk. (default: :obj:`None`) pre_filter (callable, optional): A function that takes in an :obj:`torch_geometric.data.Data` object and returns a boolean value, indicating whether the data object should be included in the final dataset. (default: :obj:`None`) force_reload (bool, optional): Whether to re-process the dataset. (default: :obj:`False`) split (int, optional): Datasets split, train/valid/test=0/1/2. (default: :obj:`0`) """ raw_url_id = '1loBXabD6ncAFY-vanRsVtRUSFkEtBweg' def __init__( self, root: str, transform: Optional[Callable] = None, pre_transform: Optional[Callable] = None, pre_filter: Optional[Callable] = None, force_reload: bool = False, split: int = 0, ): from torchvision import transforms self.split = split if self.split == 0: self.img_transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.RandomRotation(15), transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) else: self.img_transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload) self.load(self.processed_paths[0]) @property def raw_file_names(self) -> List[str]: return ['train_3500.pkl', 'valid_450.pkl', 'test_450.pkl'] @property def processed_file_names(self) -> str: return ['train.pt', 'valid.pt', 'test.pt'][self.split] def download(self) -> None: file_path = download_google_url( self.raw_url_id, self.raw_dir, 'gitmol.zip', ) extract_zip(file_path, self.raw_dir) def process(self) -> None: import pandas as pd from PIL import Image try: from rdkit import Chem, RDLogger RDLogger.DisableLog('rdApp.*') # type: ignore[attr-defined] WITH_RDKIT = True except ImportError: WITH_RDKIT = False if not WITH_RDKIT: print(("Using a pre-processed version of the dataset. Please " "install 'rdkit' to alternatively process the raw data."), file=sys.stderr) data_list = fs.torch_load(self.raw_paths[0]) data_list = [Data(**data_dict) for data_dict in data_list] if self.pre_filter is not None: data_list = [d for d in data_list if self.pre_filter(d)] if self.pre_transform is not None: data_list = [self.pre_transform(d) for d in data_list] self.save(data_list, self.processed_paths[0]) return allowable_features: Dict[str, List[Any]] = { 'possible_atomic_num_list': list(range(1, 119)) + ['misc'], 'possible_formal_charge_list': [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 'misc'], 'possible_chirality_list': [ Chem.rdchem.ChiralType.CHI_UNSPECIFIED, Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW, Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW, Chem.rdchem.ChiralType.CHI_OTHER ], 'possible_hybridization_list': [ Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2, Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.SP3D, Chem.rdchem.HybridizationType.SP3D2, Chem.rdchem.HybridizationType.UNSPECIFIED, 'misc' ], 'possible_numH_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 'misc'], 'possible_implicit_valence_list': [0, 1, 2, 3, 4, 5, 6], 'possible_degree_list': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 'misc'], 'possible_number_radical_e_list': [0, 1, 2, 3, 4, 'misc'], 'possible_is_aromatic_list': [False, True], 'possible_is_in_ring_list': [False, True], 'possible_bond_type_list': [ Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE, Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC, Chem.rdchem.BondType.ZERO ], 'possible_bond_dirs': [ # only for double bond stereo information Chem.rdchem.BondDir.NONE, Chem.rdchem.BondDir.ENDUPRIGHT, Chem.rdchem.BondDir.ENDDOWNRIGHT ], 'possible_bond_stereo_list': [ Chem.rdchem.BondStereo.STEREONONE, Chem.rdchem.BondStereo.STEREOZ, Chem.rdchem.BondStereo.STEREOE, Chem.rdchem.BondStereo.STEREOCIS, Chem.rdchem.BondStereo.STEREOTRANS, Chem.rdchem.BondStereo.STEREOANY, ], 'possible_is_conjugated_list': [False, True] } data = pd.read_pickle( f'{self.raw_dir}/igcdata_toy/{self.raw_file_names[self.split]}') data_list = [] for _, r in tqdm(data.iterrows(), total=data.shape[0]): smiles = r['isosmiles'] mol = Chem.MolFromSmiles(smiles.strip('\n')) if mol is not None: # text summary = r['summary'] # image cid = r['cid'] img_file = f'{self.raw_dir}/igcdata_toy/imgs/CID_{cid}.png' img = Image.open(img_file).convert('RGB') img = self.img_transform(img).unsqueeze(0) # graph atom_features_list = [] for atom in mol.GetAtoms(): atom_feature = [ safe_index( allowable_features['possible_atomic_num_list'], atom.GetAtomicNum()), allowable_features['possible_chirality_list'].index( atom.GetChiralTag()), safe_index(allowable_features['possible_degree_list'], atom.GetTotalDegree()), safe_index( allowable_features['possible_formal_charge_list'], atom.GetFormalCharge()), safe_index(allowable_features['possible_numH_list'], atom.GetTotalNumHs()), safe_index( allowable_features[ 'possible_number_radical_e_list'], atom.GetNumRadicalElectrons()), safe_index( allowable_features['possible_hybridization_list'], atom.GetHybridization()), allowable_features['possible_is_aromatic_list'].index( atom.GetIsAromatic()), allowable_features['possible_is_in_ring_list'].index( atom.IsInRing()), ] atom_features_list.append(atom_feature) x = torch.tensor(np.array(atom_features_list), dtype=torch.long) edges_list = [] edge_features_list = [] for bond in mol.GetBonds(): i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() edge_feature = [ safe_index( allowable_features['possible_bond_type_list'], bond.GetBondType()), allowable_features['possible_bond_stereo_list'].index( bond.GetStereo()), allowable_features['possible_is_conjugated_list']. index(bond.GetIsConjugated()), ] edges_list.append((i, j)) edge_features_list.append(edge_feature) edges_list.append((j, i)) edge_features_list.append(edge_feature) edge_index = torch.tensor( np.array(edges_list).T, dtype=torch.long, ) edge_attr = torch.tensor( np.array(edge_features_list), dtype=torch.long, ) data = Data( x=x, edge_index=edge_index, smiles=smiles, edge_attr=edge_attr, image=img, caption=summary, ) if self.pre_filter is not None and not self.pre_filter(data): continue if self.pre_transform is not None: data = self.pre_transform(data) data_list.append(data) self.save(data_list, self.processed_paths[0])