torch_geometric.llm.models.GLEM
- class GLEM(lm_to_use: str = 'prajjwal1/bert-tiny', gnn_to_use: torch_geometric.nn.models.basic_gnn = GraphSAGE, out_channels: int = 47, gnn_loss: Optional[Module] = None, lm_loss: Optional[Module] = None, alpha: float = 0.5, beta: float = 0.5, lm_dtype: torch.dtype = torch.bfloat16, lm_use_lora: bool = True, lora_target_modules: Optional[Union[str, List[str]]] = None, device: Optional[Union[torch.device, str]] = None)[source]
Bases:
Module
This GNN+LM co-training model is based on GLEM from the “Learning on Large-scale Text-attributed Graphs via Variational Inference” paper.
- Parameters:
lm_to_use (str) – a text encoder from the Hugging Face model repository with a classification head (default: TinyBERT)
gnn_to_use (torch_geometric.nn.models) – the GNN model class to use (default: GraphSAGE)
out_channels (int) – number of output channels, which must be the same for the LM and the GNN (i.e. the number of classes)
num_gnn_heads (Optional[int]) – number of attention heads, if the chosen GNN uses attention
num_gnn_layers (int) – number of GNN layers
gnn_loss (Optional[Module], default: None) – loss function for the GNN (default: CrossEntropyLoss)
lm_loss (Optional[Module], default: None) – loss function for the language model (default: CrossEntropyLoss)
alpha (float) – pseudo-label weight of the E-step (LM optimization) (default: 0.5)
beta (float) – pseudo-label weight of the M-step (GNN optimization) (default: 0.5)
lm_dtype (torch.dtype) – the data type the LM is loaded into memory with (default: torch.bfloat16)
lm_use_lora (bool) – whether to fine-tune the LM with LoRA PEFT (default: True)
lora_target_modules (Optional[Union[str, List[str]]], default: None) – the names of the target modules to apply the LoRA adapter to, e.g. ['q_proj', 'v_proj'] for an LLM
device (Optional[Union[torch.device, str]], default: None) – the device to place the model on
Note
See examples/llm_plus_gnn/glem.py for example usage.
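A minimal construction sketch based on the signature above; the import path follows this page's module name and may differ across torch_geometric versions, and all argument values simply restate the documented defaults:

    import torch

    from torch_geometric.llm.models import GLEM
    from torch_geometric.nn.models import GraphSAGE

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = GLEM(
        lm_to_use='prajjwal1/bert-tiny',  # Hugging Face text encoder with a classifier head
        gnn_to_use=GraphSAGE,             # GNN model class, instantiated internally
        out_channels=47,                  # shared LM/GNN output size (number of classes)
        alpha=0.5,                        # pseudo-label weight of the E-step (LM)
        beta=0.5,                         # pseudo-label weight of the M-step (GNN)
        lm_dtype=torch.bfloat16,          # dtype the LM is loaded with
        lm_use_lora=True,                 # fine-tune the LM with LoRA PEFT
        device=device,
    )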
- train(em_phase: str, train_loader: Union[DataLoader, NeighborLoader], optimizer: Optimizer, pseudo_labels: Tensor, epoch: int, is_augmented: bool = False, verbose: bool = False)[source]
GLEM training step; dispatches to the E-step (LM) or the M-step (GNN).
- Parameters:
em_phase (str) – 'gnn' or 'lm', selecting which phase to train
train_loader (Union[DataLoader, NeighborLoader]) – a DataLoader for LM training (yielding tokenized data, labels, and the is_gold mask), or a NeighborLoader for GNN training (yielding x and edge_index)
optimizer (torch.optim.Optimizer) – optimizer for training
pseudo_labels (torch.Tensor) – the predicted labels used as pseudo labels
epoch (int) – current epoch
is_augmented (bool) – whether to train with pseudo_labels
verbose (bool) – whether to print a training progress bar
- Returns:
training accuracy
loss (float): loss value
- Return type:
acc (float)
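A hedged sketch of one EM round driven through train(). Here lm_loader (a DataLoader over tokenized text), gnn_loader (a NeighborLoader), the two optimizers, and num_em_epochs are placeholders; see examples/llm_plus_gnn/glem.py for the full recipe:

    for epoch in range(num_em_epochs):
        # E-step: train the LM, supervised by GNN predictions as pseudo labels.
        gnn_pseudo = model.inference('gnn', gnn_loader).argmax(dim=-1).unsqueeze(1)
        acc = model.train('lm', lm_loader, lm_optimizer,
                          pseudo_labels=gnn_pseudo, epoch=epoch,
                          is_augmented=True, verbose=True)

        # M-step: train the GNN, supervised by LM predictions as pseudo labels.
        lm_pseudo = model.inference('lm', lm_loader).argmax(dim=-1).unsqueeze(1)
        acc = model.train('gnn', gnn_loader, gnn_optimizer,
                          pseudo_labels=lm_pseudo, epoch=epoch,
                          is_augmented=True, verbose=True)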
- train_lm(train_loader: DataLoader, optimizer: Optimizer, epoch: int, pseudo_labels: Optional[Tensor] = None, is_augmented: bool = False, verbose: bool = True)[source]
Language model training step, run once per epoch.
- Parameters:
train_loader (loader.dataloader.DataLoader) – dataloader over tokenized text
optimizer (torch.optim.Optimizer) – model optimizer
epoch (int) – current training epoch
pseudo_labels (torch.Tensor) – 1-D tensor of predictions from the GNN, used as pseudo labels
is_augmented (bool) – whether to train with pseudo labels
verbose (bool) – whether to print a training progress bar
- Returns:
training accuracy
loss (torch.float): loss value
- Return type:
approx_acc (torch.Tensor)
- train_gnn(train_loader: NeighborLoader, optimizer: Optimizer, epoch: int, pseudo_labels: Optional[Tensor] = None, is_augmented: bool = False, verbose: bool = True)[source]
GNN training step, run once per epoch.
- Parameters:
train_loader (loader.NeighborLoader) – neighbor loader for GNN training
optimizer (torch.optim.Optimizer) – model optimizer
epoch (int) – current training epoch
pseudo_labels (torch.Tensor) – 1-D tensor of predictions from the LM, used as pseudo labels
is_augmented (bool) – whether to use pseudo-labeled nodes
verbose (bool) – whether to print a training progress bar
- Returns:
training accuracy
loss (torch.float): loss value
- Return type:
approx_acc (torch.Tensor)
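When driving the phases manually instead of through train(), the per-phase methods can be called directly; same placeholder names as in the sketch above:

    # Equivalent to train('lm', ...) and train('gnn', ...), respectively.
    approx_acc = model.train_lm(lm_loader, lm_optimizer, epoch,
                                pseudo_labels=gnn_pseudo, is_augmented=True)
    approx_acc = model.train_gnn(gnn_loader, gnn_optimizer, epoch,
                                 pseudo_labels=lm_pseudo, is_augmented=True)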
- inference(em_phase: str, data_loader: Union[NeighborLoader, DataLoader], verbose: bool = False)[source]
GLEM inference step.
- Parameters:
em_phase (str) – 'gnn' or 'lm', selecting which phase to run inference with
data_loader (Union[NeighborLoader, DataLoader]) – a NeighborLoader for GNN inference, or a DataLoader for LM inference
verbose (bool) – whether to print a progress bar
- Returns:
an [n, m] tensor, where n is the number of nodes and m is the number of classes
- Return type:
out (torch.Tensor)
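A small usage sketch, where gnn_loader is a placeholder NeighborLoader and 47 mirrors the default out_channels:

    # inference() dispatches on em_phase and returns an [n, m] tensor.
    out = model.inference('gnn', gnn_loader, verbose=False)
    assert out.size(-1) == 47  # m = number of classes (out_channels)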
- inference_lm(data_loader: DataLoader, verbose: bool = True)[source]
LM inference step.
- Parameters:
data_loader (DataLoader) – yields tokens, labels, and the is_gold mask
verbose (bool) – whether to print a progress bar
- Returns:
predictions from the LM, converted to pseudo labels by preds.argmax(dim=-1).unsqueeze(1)
- Return type:
preds (torch.Tensor)
- inference_gnn(data_loader: NeighborLoader, verbose: bool = True)[source]
GNN inference step.
- Parameters:
data_loader (NeighborLoader) – yields x and edge_index
verbose (bool) – whether to print a progress bar
- Returns:
predictions from the GNN, converted to pseudo labels by preds.argmax(dim=-1).unsqueeze(1)
- Return type:
preds (torch.Tensor)
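Both phase-specific inference methods return predictions that are typically converted into pseudo labels for the opposite phase, using the conversion quoted in the docstrings above:

    # [n, m] predictions -> [n, 1] pseudo labels for the other phase.
    preds = model.inference_gnn(gnn_loader, verbose=False)
    pseudo_labels = preds.argmax(dim=-1).unsqueeze(1)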
- loss(logits: Tensor, labels: Tensor, loss_func: Module, is_gold: Tensor, pseudo_labels: Optional[Tensor] = None, pl_weight: float = 0.5, is_augmented: bool = True)[source]
Core function of variational EM inference; it combines the loss on gold (original training) labels with the loss on pseudo labels.
Reference: https://github.com/AndyJZhao/GLEM/blob/main/src/models/GLEM/GLEM_utils.py
- Parameters:
logits (torch.Tensor) – predictions from the LM or GNN
labels (torch.Tensor) – combined node labels from ground truth and pseudo labels (if provided)
loss_func (torch.nn.modules.loss) – loss function for classification
is_gold (torch.Tensor) – boolean mask that selects ground-truth labels during training; ~is_gold therefore masks the pseudo labels
pseudo_labels (torch.Tensor) – predictions from the other model
pl_weight (float, default: 0.5) – weight of the pseudo labels used in the E-step and M-step optimization (alpha in the E-step and beta in the M-step, respectively)
is_augmented (bool, default: True) – whether to use EM, or to just train the GNN and LM on gold data
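A sketch of the pseudo-label mixing this method describes, modeled on the linked GLEM_utils.py reference; glem_style_loss is a hypothetical stand-in and may not match the method line-for-line:

    import torch
    import torch.nn.functional as F

    def glem_style_loss(logits: torch.Tensor, labels: torch.Tensor,
                        is_gold: torch.Tensor,
                        pseudo_labels: torch.Tensor = None,
                        pl_weight: float = 0.5,
                        is_augmented: bool = True) -> torch.Tensor:
        # Without augmentation, fall back to plain supervised loss on gold labels.
        if not is_augmented or pseudo_labels is None:
            return F.cross_entropy(logits[is_gold], labels[is_gold])
        # Convex combination: pseudo-label loss on the ~is_gold rows, weighted by
        # pl_weight (alpha in the E-step, beta in the M-step), plus gold-label loss.
        pl_loss = F.cross_entropy(logits[~is_gold],
                                  pseudo_labels[~is_gold].squeeze(-1))
        gold_loss = F.cross_entropy(logits[is_gold], labels[is_gold])
        return pl_weight * pl_loss + (1 - pl_weight) * gold_loss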