Source code for torch_geometric.llm.models.glem

from typing import List, Optional, Union

import torch
import torch.nn as nn
from tqdm import tqdm

from torch_geometric.loader import DataLoader, NeighborLoader
from torch_geometric.nn.models import GraphSAGE, basic_gnn


def deal_nan(x):
    # Replace any NaN entries with zero so they do not propagate into the loss.
    if isinstance(x, torch.Tensor):
        x = x.clone()
        x[torch.isnan(x)] = 0.0
    return x


class GLEM(torch.nn.Module):
    r"""This GNN+LM co-training model is based on GLEM from the `"Learning on
    Large-scale Text-attributed Graphs via Variational Inference"
    <https://arxiv.org/abs/2210.14709>`_ paper.

    Args:
        lm_to_use (str): A text encoder from the Hugging Face model repository
            with a classification head. (default: TinyBERT)
        gnn_to_use (torch_geometric.nn.models): The GNN to co-train.
            (default: GraphSAGE)
        out_channels (int): The number of output channels for the LM and the
            GNN; they should be the same.
        gnn_loss: Loss function for the GNN. (default: CrossEntropyLoss)
        lm_loss: Loss function for the language model.
            (default: CrossEntropyLoss)
        alpha (float): Pseudo-label weight of the E-step (LM optimization).
            (default: 0.5)
        beta (float): Pseudo-label weight of the M-step (GNN optimization).
            (default: 0.5)
        lm_dtype (torch.dtype): The data type of the LM once loaded into
            memory. (default: torch.bfloat16)
        lm_use_lora (bool): Whether to fine-tune the LM with LoRA (PEFT).
            (default: True)
        lora_target_modules: The names of the target modules to apply the LoRA
            adapter to, e.g. ['q_proj', 'v_proj'] for an LLM. (default: None)

    .. note::
        See `examples/llm_plus_gnn/glem.py` for example usage.
    """
    def __init__(
        self,
        lm_to_use: str = 'prajjwal1/bert-tiny',
        gnn_to_use: basic_gnn = GraphSAGE,
        out_channels: int = 47,
        gnn_loss: Optional[nn.Module] = None,
        lm_loss: Optional[nn.Module] = None,
        alpha: float = 0.5,
        beta: float = 0.5,
        lm_dtype: torch.dtype = torch.bfloat16,
        lm_use_lora: bool = True,
        lora_target_modules: Optional[Union[List[str], str]] = None,
        device: Optional[Union[str, torch.device]] = None,
    ):
        super().__init__()
        if gnn_loss is None:
            gnn_loss = nn.CrossEntropyLoss(reduction='mean')
        if lm_loss is None:
            lm_loss = nn.CrossEntropyLoss(reduction='mean')
        if device is None:
            device = torch.device('cpu')
        self.device = device
        self.lm_loss = lm_loss
        self.gnn = gnn_to_use
        self.gnn_loss = gnn_loss
        self.alpha = alpha
        self.beta = beta

        from transformers import AutoModelForSequenceClassification
        self.lm = AutoModelForSequenceClassification.from_pretrained(
            lm_to_use, num_labels=out_channels, torch_dtype=lm_dtype,
            offload_folder="offload", trust_remote_code=True)
        if lm_use_lora:
            from peft import (
                LoraConfig,
                TaskType,
                get_peft_model,
                prepare_model_for_kbit_training,
            )
            print("Training LM with LORA!")
            self.lm = prepare_model_for_kbit_training(self.lm)
            config = LoraConfig(task_type=TaskType.SEQ_CLS, r=16,
                                lora_alpha=16, lora_dropout=0.05, bias="none",
                                target_modules=lora_target_modules)
            self.lm = get_peft_model(self.lm, config)
            self.lm.print_trainable_parameters()
        self.lm.config.pad_token_id = self.lm.config.eos_token_id
        self.lm_device = self.lm.device

        if self.lm.num_labels != self.gnn.out_channels:
            raise ValueError('The output channels of the language model '
                             'and the GNN should be the same.')

    def pre_train_gnn(self, train_loader: NeighborLoader,
                      optimizer: torch.optim.Optimizer, num_epochs: int,
                      patience: int, ext_pseudo_labels: torch.Tensor = None,
                      is_augmented: bool = False, verbose: bool = True):
        # Pretrain the GNN; an optional step if you do not have pseudo labels.
        best_acc = 0
        early_stopping = 0
        # training only based on gold data
        for epoch in range(0, num_epochs):
            acc, loss = self.train_gnn(train_loader, optimizer, epoch,
                                       ext_pseudo_labels, is_augmented,
                                       verbose)
            if acc < best_acc:
                early_stopping += 1
                if early_stopping > patience:
                    print(f'Early stopped by Epoch: {epoch}, '
                          f'Best acc: {best_acc}')
                    break
            best_acc = max(best_acc, acc)

    def pre_train_lm(self, train_loader: DataLoader,
                     optimizer: torch.optim.Optimizer, num_epochs: int,
                     patience: int, ext_pseudo_labels: torch.Tensor = None,
                     is_augmented: bool = False, verbose: bool = True):
        # Pretrain language model
        best_acc = 0
        early_stopping = 0
        for epoch in range(1, num_epochs + 1):
            acc, loss = self.train_lm(train_loader, optimizer, epoch,
                                      ext_pseudo_labels, is_augmented,
                                      verbose)
            if acc < best_acc:
                early_stopping += 1
                if early_stopping > patience:
                    print(f'Early stopped by Epoch: {epoch}, '
                          f'Best acc: {best_acc}')
                    break
            best_acc = max(best_acc, acc)
    def train(self, em_phase: str,
              train_loader: Union[DataLoader, NeighborLoader],
              optimizer: torch.optim.Optimizer, pseudo_labels: torch.Tensor,
              epoch: int, is_augmented: bool = False, verbose: bool = False):
        r"""GLEM training step (one EM step).

        Args:
            em_phase (str): 'gnn' or 'lm'; chooses which phase to train.
            train_loader (Union[DataLoader, NeighborLoader]): Use a DataLoader
                for LM training (tokenized data, labels, is_gold mask) and a
                NeighborLoader for GNN training (x, edge_index).
            optimizer (torch.optim.Optimizer): Optimizer for training.
            pseudo_labels (torch.Tensor): The predicted labels used as pseudo
                labels.
            epoch (int): Current epoch.
            is_augmented (bool): Whether to use pseudo labels.
            verbose (bool): Whether to print a training progress bar.

        Returns:
            acc (float): Training accuracy.
            loss (float): Loss value.
        """
        if pseudo_labels is not None:
            pseudo_labels = pseudo_labels.to(self.device)
        if em_phase == 'gnn':
            acc, loss = self.train_gnn(train_loader, optimizer, epoch,
                                       pseudo_labels, is_augmented, verbose)
        if em_phase == 'lm':
            acc, loss = self.train_lm(train_loader, optimizer, epoch,
                                      pseudo_labels, is_augmented, verbose)
        return acc, loss
    def train_lm(self, train_loader: DataLoader,
                 optimizer: torch.optim.Optimizer, epoch: int,
                 pseudo_labels: torch.Tensor = None,
                 is_augmented: bool = False, verbose: bool = True):
        r"""Language model training step in every epoch.

        Args:
            train_loader (loader.dataloader.DataLoader): Text token dataloader.
            optimizer (torch.optim.Optimizer): Model optimizer.
            epoch (int): Current training epoch.
            pseudo_labels (torch.Tensor): 1-D tensor of predictions from the
                GNN.
            is_augmented (bool): Whether to train with pseudo labels.
            verbose (bool): Whether to print a training progress bar.

        Returns:
            approx_acc (torch.Tensor): Training accuracy.
            loss (float): Loss value.
        """
        all_out = []
        total_loss = total_correct = 0
        num_nodes = train_loader.dataset.indices.size(0)
        self.lm.train()
        if verbose:
            pbar = tqdm(total=num_nodes)
            pbar.set_description(f'Epoch {epoch:02d}')
        for batch in train_loader:
            inputs = {k: v.to(self.device) for k, v in batch['input'].items()}
            out = self.lm(**inputs).logits
            labels = batch['labels'].to(self.device).squeeze()
            # training with pseudo labels or not
            if is_augmented:
                pl_batch = pseudo_labels[batch['n_id']].to(self.device)
            else:
                pl_batch = None
            loss = self.loss(out, labels, self.lm_loss,
                             batch['is_gold'].to(self.device), pl_batch,
                             self.alpha, is_augmented)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            all_out.append(out)
            total_correct += int(out.argmax(dim=-1).eq(labels).sum())
            total_loss += float(loss.detach())
            if verbose:
                pbar.update(batch['n_id'].size(0))
        all_out = torch.cat(all_out, dim=0)
        approx_acc = total_correct / num_nodes
        loss = total_loss / len(train_loader)
        if verbose:
            pbar.close()
            print(f'Epoch {epoch:02d} Loss: {loss:.4f} '
                  f'Approx. Train: {approx_acc:.4f}')
        return approx_acc, loss
    def train_gnn(self, train_loader: NeighborLoader,
                  optimizer: torch.optim.Optimizer, epoch: int,
                  pseudo_labels: torch.Tensor = None,
                  is_augmented: bool = False, verbose: bool = True):
        r"""GNN training step in every epoch.

        Args:
            train_loader (loader.NeighborLoader): GNN neighbor node loader.
            optimizer (torch.optim.Optimizer): Model optimizer.
            epoch (int): Current training epoch.
            pseudo_labels (torch.Tensor): 1-D tensor of predictions from the
                LM.
            is_augmented (bool): Whether to use pseudo-labeled nodes.
            verbose (bool): Whether to print training progress.

        Returns:
            approx_acc (torch.Tensor): Training accuracy.
            loss (float): Loss value.
        """
        self.gnn.train()
        num_nodes = train_loader.input_nodes.size(0)
        if verbose:
            pbar = tqdm(total=num_nodes)
            pbar.set_description(f'Epoch {epoch:02d}')
        total_loss = total_correct = 0
        all_out = []
        for batch in train_loader:
            batch = batch.to(self.device)
            out = self.gnn(batch.x, batch.edge_index)[:batch.batch_size]
            all_out.append(out)
            labels = batch.y[:batch.batch_size].squeeze()
            is_gold_batch = batch.is_gold[:batch.batch_size].squeeze()
            # training with pseudo labels or not
            if is_augmented and pseudo_labels is not None:
                pl_batch = pseudo_labels[batch.n_id[:batch.batch_size]]
            else:
                pl_batch = None
            loss = self.loss(out, labels, self.gnn_loss, is_gold_batch,
                             pl_batch, self.beta, is_augmented)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            total_loss += float(loss.detach())
            total_correct += int(out.argmax(dim=-1).eq(labels).sum())
            if verbose:
                pbar.update(batch.batch_size)
        all_out = torch.cat(all_out, dim=0)
        loss = total_loss / len(train_loader)
        approx_acc = total_correct / num_nodes
        if verbose:
            pbar.close()
            print(f'Epoch: {epoch:02d} Loss: {loss:.4f} '
                  f'Approx. Train: {approx_acc:.4f}')
        return approx_acc, loss
    @torch.no_grad()
    def inference(self, em_phase: str,
                  data_loader: Union[NeighborLoader, DataLoader],
                  verbose: bool = False):
        r"""GLEM inference step.

        Args:
            em_phase (str): 'gnn' or 'lm'.
            data_loader (Union[NeighborLoader, DataLoader]): Use a DataLoader
                for LM inference (tokenized data) and a NeighborLoader for GNN
                inference (x, edge_index).
            verbose (bool): Whether to print inference progress.

        Returns:
            out (torch.Tensor): An n * m tensor, where m is the number of
                classes and n is the number of nodes.
        """
        out = None
        if em_phase == 'gnn':
            self.gnn.eval()
            out = self.inference_gnn(data_loader, verbose)
        elif em_phase == 'lm':
            self.lm.eval()
            out = self.inference_lm(data_loader, verbose)
        return out
    @torch.no_grad()
    def inference_lm(self, data_loader: DataLoader, verbose: bool = True):
        r"""LM inference step.

        Args:
            data_loader (DataLoader): Includes tokens, labels, and the gold
                mask.
            verbose (bool): Whether to print a progress bar.

        Returns:
            preds (torch.Tensor): Predictions from the LM; convert to pseudo
                labels via preds.argmax(dim=-1).unsqueeze(1).
        """
        if verbose:
            pbar = tqdm(total=data_loader.dataset._data.num_nodes)
            pbar.set_description('LM inference stage')
        self.lm.eval()
        preds = []
        for batch in data_loader:
            inputs = {k: v.to(self.device) for k, v in batch['input'].items()}
            logits = self.lm(**inputs).logits
            preds.append(logits)
            if verbose:
                pbar.update(batch['n_id'].size(0))
        if verbose:
            pbar.close()
        preds = torch.cat(preds)
        return preds
    @torch.no_grad()
    def inference_gnn(self, data_loader: NeighborLoader,
                      verbose: bool = True):
        r"""GNN inference step.

        Args:
            data_loader (NeighborLoader): Includes x and edge_index.
            verbose (bool): Whether to print a progress bar.

        Returns:
            preds (torch.Tensor): Predictions from the GNN; convert to pseudo
                labels via preds.argmax(dim=-1).unsqueeze(1).
        """
        if verbose:
            pbar = tqdm(total=data_loader.data.num_nodes)
            pbar.set_description('GNN inference stage')
        preds = []
        self.gnn.eval()
        for batch in data_loader:
            batch = batch.to(self.device)
            out = self.gnn(batch.x, batch.edge_index)[:batch.batch_size]
            preds.append(out)
            if verbose:
                pbar.update(batch.batch_size)
        if verbose:
            pbar.close()
        preds = torch.cat(preds, dim=0)
        return preds
    def loss(self, logits: torch.Tensor, labels: torch.Tensor,
             loss_func: torch.nn.functional, is_gold: torch.Tensor,
             pseudo_labels: torch.Tensor = None, pl_weight: float = 0.5,
             is_augmented: bool = True):
        r"""Core function of variational EM inference. It combines the loss
        value on gold (original training) labels with the loss value on
        pseudo labels.

        Reference:
        <https://github.com/AndyJZhao/GLEM/blob/main/src/models/GLEM/GLEM_utils.py>  # noqa

        Args:
            logits (torch.Tensor): Predictions from the LM or GNN.
            labels (torch.Tensor): Combined node labels from ground truth and
                pseudo labels (if provided).
            loss_func (torch.nn.modules.loss): Loss function for
                classification.
            is_gold (torch.Tensor): A boolean tensor that masks ground-truth
                labels during training; ~is_gold therefore masks pseudo
                labels.
            pseudo_labels (torch.Tensor): Predictions from the other model.
            pl_weight (float): The weight of the pseudo-label loss used in the
                E-step and M-step optimization (alpha in the E-step and beta
                in the M-step, respectively).
            is_augmented (bool): Use EM, or just train the GNN and LM with
                gold data.
        """
        if is_augmented and (sum(~is_gold) > 0):
            mle_loss = deal_nan(loss_func(logits[is_gold], labels[is_gold]))
            # all labels other than the ground-truth (gold) labels
            pseudo_label_loss = deal_nan(
                loss_func(logits[~is_gold], pseudo_labels[~is_gold]))
            loss = pl_weight * pseudo_label_loss + (1 - pl_weight) * mle_loss
        else:
            loss = loss_func(logits, labels)
        return loss
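
The class docstring points to `examples/llm_plus_gnn/glem.py` for full usage. Purely as orientation, the sketch below shows one way the methods above can be wired into the pre-training plus EM loop described by the paper; the loaders (`text_loader`, `graph_loader`), feature sizes, learning rates, and number of EM rounds are placeholders and not part of this module.

# Illustrative sketch only (assumptions: `text_loader` is a DataLoader yielding
# dicts with 'input', 'labels', 'is_gold', 'n_id'; `graph_loader` is a
# NeighborLoader with x, edge_index, y, is_gold; both iterate over all nodes in
# node-ID order so predictions align with n_id; CUDA is available).
gnn = GraphSAGE(in_channels=128, hidden_channels=256, num_layers=2,
                out_channels=47)  # assumed 128-dim input node features
model = GLEM(lm_to_use='prajjwal1/bert-tiny', gnn_to_use=gnn,
             out_channels=47, device=torch.device('cuda'))
lm_opt = torch.optim.Adam(model.lm.parameters(), lr=1e-4)
gnn_opt = torch.optim.Adam(model.gnn.parameters(), lr=1e-3)

# Pre-train both models on gold labels only:
model.pre_train_lm(text_loader, lm_opt, num_epochs=3, patience=2)
model.pre_train_gnn(graph_loader, gnn_opt, num_epochs=5, patience=2)

for em_round in range(3):
    # E-step: train the LM on pseudo labels produced by the GNN.
    gnn_pseudo = model.inference('gnn', graph_loader).argmax(dim=-1)
    model.train('lm', text_loader, lm_opt, gnn_pseudo, em_round,
                is_augmented=True)

    # M-step: train the GNN on pseudo labels produced by the LM.
    lm_pseudo = model.inference('lm', text_loader).argmax(dim=-1)
    model.train('gnn', graph_loader, gnn_opt, lm_pseudo, em_round,
                is_augmented=True)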