torch_geometric.llm.models.GLEM

class GLEM(lm_to_use: str = 'prajjwal1/bert-tiny', gnn_to_use: torch_geometric.nn.models = GraphSAGE, out_channels: int = 47, gnn_loss: Optional[Module] = None, lm_loss: Optional[Module] = None, alpha: float = 0.5, beta: float = 0.5, lm_dtype: torch.dtype = torch.bfloat16, lm_use_lora: bool = True, lora_target_modules: Optional[Union[str, List[str]]] = None, device: Optional[Union[torch.device, str]] = None)[source]

Bases: Module

This GNN+LM co-training model is based on GLEM from the “Learning on Large-scale Text-attributed Graphs via Variational Inference” paper.

Parameters:
  • lm_to_use (str) – A text encoder from the Hugging Face model hub with a classification head. (default: 'prajjwal1/bert-tiny' (TinyBERT))

  • gnn_to_use (torch_geometric.nn.models) – The GNN model class to use. (default: GraphSAGE)

  • out_channels (int) – Number of output channels, which must be the same for the LM and the GNN. (default: 47)

  • num_gnn_heads (Optional[int]) – Number of attention heads for the GNN, if needed.

  • num_gnn_layers (int) – Number of GNN layers.

  • gnn_loss (Optional[Module]) – Loss function for the GNN. (default: CrossEntropyLoss)

  • lm_loss (Optional[Module]) – Loss function for the language model. (default: CrossEntropyLoss)

  • alpha (float) – Pseudo-label weight for the E-step (LM optimization). (default: 0.5)

  • beta (float) – Pseudo-label weight for the M-step (GNN optimization). (default: 0.5)

  • lm_dtype (torch.dtype) – Data type used when loading the LM into memory. (default: torch.bfloat16)

  • lm_use_lora (bool) – Whether to fine-tune the LM with LoRA PEFT adapters. (default: True)

  • lora_target_modules (Optional[Union[str, List[str]]]) – Names of the target modules to apply the LoRA adapter to, e.g., ['q_proj', 'v_proj'] for an LLM. (default: None)

Note

See examples/llm_plus_gnn/glem.py for example usage.
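A minimal instantiation sketch, assuming the import paths shown on this page; the keyword values simply mirror the documented defaults:

    import torch
    from torch_geometric.llm.models import GLEM
    from torch_geometric.nn.models import GraphSAGE

    model = GLEM(
        lm_to_use='prajjwal1/bert-tiny',  # Hugging Face encoder with a classification head
        gnn_to_use=GraphSAGE,             # GNN class trained in the M-step
        out_channels=47,                  # number of classes, shared by LM and GNN
        alpha=0.5,                        # pseudo-label weight in the E-step (LM)
        beta=0.5,                         # pseudo-label weight in the M-step (GNN)
        lm_dtype=torch.bfloat16,
        lm_use_lora=True,
        device='cuda' if torch.cuda.is_available() else 'cpu',
    )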

train(em_phase: str, train_loader: Union[DataLoader, NeighborLoader], optimizer: Optimizer, pseudo_labels: Tensor, epoch: int, is_augmented: bool = False, verbose: bool = False)[source]

GLEM training step for a single EM phase.

Parameters:
  • em_phase (str) – 'gnn' or 'lm', the phase being trained.

  • train_loader (Union[DataLoader, NeighborLoader]) – Use a DataLoader for LM training (containing tokenized data, labels, and the is_gold mask); use a NeighborLoader for GNN training (containing x and edge_index).

  • optimizer (torch.optim.Optimizer) – Optimizer for training.

  • pseudo_labels (torch.Tensor) – Predicted labels used as pseudo labels.

  • epoch (int) – Current epoch.

  • is_augmented (bool) – Whether to use pseudo_labels.

  • verbose (bool) – Whether to print a training progress bar.

Returns:

  • acc (float) – training accuracy

  • loss (float) – loss value
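A hedged sketch of alternating E- and M-steps with train(); the loaders, optimizers, number of rounds, and initial pseudo labels are assumed to be prepared as in examples/llm_plus_gnn/glem.py and are not defined here:

    for epoch in range(num_em_rounds):
        # E-step: fine-tune the LM on gold labels plus GNN pseudo labels.
        model.train('lm', lm_loader, lm_opt, pseudo_labels, epoch,
                    is_augmented=True, verbose=True)
        # Refresh pseudo labels from the LM for the M-step (conversion as
        # documented under inference_lm/inference_gnn).
        pseudo_labels = model.inference('lm', lm_loader).argmax(dim=-1)

        # M-step: train the GNN on gold labels plus LM pseudo labels.
        model.train('gnn', gnn_loader, gnn_opt, pseudo_labels, epoch,
                    is_augmented=True, verbose=True)
        # Refresh pseudo labels from the GNN for the next E-step.
        pseudo_labels = model.inference('gnn', gnn_loader).argmax(dim=-1)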

train_lm(train_loader: DataLoader, optimizer: Optimizer, epoch: int, pseudo_labels: Optional[Tensor] = None, is_augmented: bool = False, verbose: bool = True)[source]

Language model training step, run every epoch.

Parameters:
  • train_loader (DataLoader) – LM DataLoader containing tokenized data, labels, and the is_gold mask.

  • optimizer (torch.optim.Optimizer) – model optimizer.

  • epoch (int) – current training epoch.

  • pseudo_labels (torch.Tensor) – predictions from the GNN, used as pseudo labels.

  • is_augmented (bool) – whether to train on pseudo-labeled nodes.

  • verbose (bool) – whether to print training progress.
Returns:

  • approx_acc (torch.Tensor) – training accuracy

  • loss (torch.float) – loss value

train_gnn(train_loader: NeighborLoader, optimizer: Optimizer, epoch: int, pseudo_labels: Optional[Tensor] = None, is_augmented: bool = False, verbose: bool = True)[source]

GNN training step in every epoch.

Parameters:
  • train_loader (NeighborLoader) – GNN neighbor node loader.

  • optimizer (torch.optim.Optimizer) – model optimizer.

  • epoch (int) – current training epoch.

  • pseudo_labels (torch.Tensor) – 1-D tensor of predictions from the LM, used as pseudo labels.

  • is_augmented (bool) – whether to train on pseudo-labeled nodes.

  • verbose (bool) – whether to print training progress.

Returns:

  • approx_acc (torch.Tensor) – training accuracy

  • loss (torch.float) – loss value

inference(em_phase: str, data_loader: Union[NeighborLoader, DataLoader], verbose: bool = False)[source]

GLEM inference step.

Parameters:
  • em_phase (str) – 'gnn' or 'lm', the phase to run inference for.

  • data_loader (Union[DataLoader, NeighborLoader]) – a DataLoader for LM inference (containing tokenized data) or a NeighborLoader for GNN inference (containing x and edge_index).

  • verbose (bool) – whether to print inference progress.

Returns:

an [n, m] tensor, where n is the number of nodes and m is the number of classes

Return type:

out (torch.Tensor)
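A short sketch of dispatching inference by phase; lm_loader and gnn_loader are assumed loaders of the kinds described above and are not defined here:

    lm_out = model.inference('lm', lm_loader, verbose=False)    # [n, num_classes] LM predictions
    gnn_out = model.inference('gnn', gnn_loader, verbose=False)  # [n, num_classes] GNN predictions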

inference_lm(data_loader: DataLoader, verbose: bool = True)[source]

LM inference step.

Parameters:
  • data_loader (DataLoader) – contains tokens, labels, and the is_gold mask.

  • verbose (bool) – whether to print a progress bar.

Returns:

predictions from the LM; convert them to pseudo labels via preds.argmax(dim=-1).unsqueeze(1)

Return type:

preds (torch.Tensor)

inference_gnn(data_loader: NeighborLoader, verbose: bool = True)[source]

GNN inference step.

Parameters:
  • data_loader (NeighborLoader) – contains x and edge_index.

  • verbose (bool) – whether to print a progress bar.

Returns:

predictions from the GNN; convert them to pseudo labels via preds.argmax(dim=-1).unsqueeze(1)

Return type:

preds (torch.Tensor)
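A small sketch of turning inference_gnn() output into pseudo labels for the next E-step, following the conversion described above; gnn_loader is an assumed NeighborLoader over the graph:

    preds = model.inference_gnn(gnn_loader, verbose=False)  # [num_nodes, num_classes]
    pseudo_labels = preds.argmax(dim=-1).unsqueeze(1)        # hard labels, shape [num_nodes, 1]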

loss(logits: Tensor, labels: Tensor, loss_func: torch.nn.functional, is_gold: Tensor, pseudo_labels: Optional[Tensor] = None, pl_weight: float = 0.5, is_augmented: bool = True)[source]

Core function of variational EM inference: it combines the loss on gold (original training) labels with the loss on pseudo labels.

Reference: https://github.com/AndyJZhao/GLEM/blob/main/src/models/GLEM/GLEM_utils.py

Parameters:
  • logits (torch.Tensor) – predictions from the LM or GNN

  • labels (torch.Tensor) – node labels combined from ground truth and pseudo labels (if provided)

  • loss_func (torch.nn.modules.loss) – loss function for classification

  • is_gold (torch.Tensor) – boolean mask selecting ground-truth labels during training; ~is_gold therefore masks the pseudo labels

  • pseudo_labels (torch.Tensor) – predictions from the other model

  • pl_weight (float) – weight of the pseudo-label loss term: alpha in the E-step and beta in the M-step, respectively. (default: 0.5)

  • is_augmented (bool) – whether to use EM (pseudo labels) or train the GNN and LM on gold data only. (default: True)
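An illustrative sketch (not the library's exact code) of the weighted combination this method performs: the loss on gold-labeled nodes is mixed with the loss on pseudo-labeled nodes via pl_weight. The cross-entropy choice and the pseudo-label shape are assumptions for the example:

    import torch
    import torch.nn.functional as F

    def combined_loss(logits, labels, is_gold, pseudo_labels, pl_weight=0.5):
        # Loss on ground-truth (gold) nodes only.
        gold_loss = F.cross_entropy(logits[is_gold], labels[is_gold])
        # Loss on the remaining nodes against pseudo labels (assumed class indices).
        pl_loss = F.cross_entropy(logits[~is_gold], pseudo_labels[~is_gold])
        # pl_weight plays the role of alpha in the E-step and beta in the M-step.
        return pl_weight * pl_loss + (1 - pl_weight) * gold_loss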