import warnings
from typing import Any, Dict, List, Optional, Type

import torch
from torch import Tensor

from torch_geometric.typing import NodeType
def get_embeddings(
model: torch.nn.Module,
*args: Any,
**kwargs: Any,
) -> List[Tensor]:
"""Returns the output embeddings of all
:class:`~torch_geometric.nn.conv.MessagePassing` layers in
:obj:`model`.
Internally, this method registers forward hooks on all
:class:`~torch_geometric.nn.conv.MessagePassing` layers of a :obj:`model`,
and runs the forward pass of the :obj:`model` by calling
:obj:`model(*args, **kwargs)`.

    Args:
model (torch.nn.Module): The message passing model.
*args: Arguments passed to the model.
**kwargs (optional): Additional keyword arguments passed to the model.
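
    Example:
        A minimal usage sketch; the model, feature sizes, and variable names
        below are illustrative placeholders, not part of the API:

        .. code-block:: python

            import torch
            from torch_geometric.nn import GCNConv

            class GCN(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.conv1 = GCNConv(16, 32)
                    self.conv2 = GCNConv(32, 7)

                def forward(self, x, edge_index):
                    x = self.conv1(x, edge_index).relu()
                    return self.conv2(x, edge_index)

            model = GCN()
            x = torch.randn(10, 16)
            edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])

            embeddings = get_embeddings(model, x, edge_index)
            # One tensor per `MessagePassing` layer, in execution order:
            assert len(embeddings) == 2
            assert embeddings[0].size() == (10, 32)
            assert embeddings[1].size() == (10, 7)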
"""
from torch_geometric.nn import MessagePassing
embeddings: List[Tensor] = []
def hook(model: torch.nn.Module, inputs: Any, outputs: Any) -> None:
        # Clone the output in case it is later modified in-place:
outputs = outputs[0] if isinstance(outputs, tuple) else outputs
assert isinstance(outputs, Tensor)
embeddings.append(outputs.clone())
hook_handles = []
for module in model.modules(): # Register forward hooks:
if isinstance(module, MessagePassing):
hook_handles.append(module.register_forward_hook(hook))
if len(hook_handles) == 0:
warnings.warn("The 'model' does not have any 'MessagePassing' layers",
stacklevel=2)
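    # Temporarily switch the model to evaluation mode so that the hooked
    # forward pass uses inference behavior (e.g., dropout disabled), and
    # restore the previous training mode afterwards: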
training = model.training
model.eval()
with torch.no_grad():
model(*args, **kwargs)
model.train(training)
for handle in hook_handles: # Remove hooks:
handle.remove()
return embeddings


def get_embeddings_hetero(
model: torch.nn.Module,
supported_models: Optional[List[Type[torch.nn.Module]]] = None,
*args: Any,
**kwargs: Any,
) -> Dict[NodeType, List[Tensor]]:
"""Returns the output embeddings of all
:class:`~torch_geometric.nn.conv.MessagePassing` layers in a heterogeneous
    :obj:`model`, organized by node type.
    Internally, this method registers forward hooks on all heterogeneous
    message passing layers of the :obj:`model` (including
    :obj:`ModuleDict`-based modules generated by
    :meth:`~torch_geometric.nn.to_hetero`) and runs the forward pass of the
    :obj:`model` by calling :obj:`model(*args, **kwargs)`.
For heterogeneous models, the output is a dictionary where each key is a
node type and each value is a list of embeddings from different layers.

    Args:
model (torch.nn.Module): The heterogeneous GNN model.
        supported_models (List[Type[torch.nn.Module]], optional): The
            heterogeneous layer classes on which forward hooks are
            registered. If set to :obj:`None`, defaults to
            :class:`~torch_geometric.nn.conv.HGTConv`,
            :class:`~torch_geometric.nn.conv.HANConv` and
            :class:`~torch_geometric.nn.conv.HeteroConv`.
            (default: :obj:`None`)
*args: Arguments passed to the model.
**kwargs (optional): Additional keyword arguments passed to the model.
Returns:
Dict[NodeType, List[Tensor]]: A dictionary mapping each node type to
a list of embeddings from different layers.
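
    Example:
        A minimal usage sketch; the model, metadata, feature sizes, and
        variable names below are illustrative placeholders, not part of the
        API. Model inputs are passed as keyword arguments so that they are
        not mistaken for the positional :obj:`supported_models` argument:

        .. code-block:: python

            import torch
            from torch_geometric.nn import HeteroConv, SAGEConv

            class HeteroGNN(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.conv = HeteroConv({
                        ('author', 'writes', 'paper'):
                            SAGEConv((16, 16), 32),
                        ('paper', 'rev_writes', 'author'):
                            SAGEConv((16, 16), 32),
                    }, aggr='sum')

                def forward(self, x_dict, edge_index_dict):
                    return self.conv(x_dict, edge_index_dict)

            model = HeteroGNN()
            x_dict = {
                'paper': torch.randn(4, 16),
                'author': torch.randn(3, 16),
            }
            edge_index_dict = {
                ('author', 'writes', 'paper'):
                    torch.tensor([[0, 1, 2], [0, 1, 3]]),
                ('paper', 'rev_writes', 'author'):
                    torch.tensor([[0, 1, 3], [0, 1, 2]]),
            }

            embeddings = get_embeddings_hetero(
                model, x_dict=x_dict, edge_index_dict=edge_index_dict)
            # One list of per-layer embeddings per node type:
            assert set(embeddings.keys()) == {'paper', 'author'}
            assert embeddings['paper'][0].size() == (4, 32)
            assert embeddings['author'][0].size() == (3, 32)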
"""
from torch_geometric.nn import HANConv, HeteroConv, HGTConv
if not supported_models:
supported_models = [HGTConv, HANConv, HeteroConv]
# Dictionary to store node embeddings by type
node_embeddings_dict: Dict[NodeType, List[Tensor]] = {}
# Hook function to capture node embeddings
def hook(model: torch.nn.Module, inputs: Any, outputs: Any) -> None:
        # Only capture outputs that map node types to embeddings. This
        # filters out modules whose output is not such a dictionary:
        if isinstance(outputs, dict) and outputs:
            for node_type, embedding in outputs.items():
                if node_type not in node_embeddings_dict:
                    node_embeddings_dict[node_type] = []
                # Clone the output in case it is later modified in-place:
                node_embeddings_dict[node_type].append(embedding.clone())
# List to store hook handles
hook_handles = []
    # Register forward hooks on all heterogeneous layers:
for _, module in model.named_modules():
        # Handle natively heterogeneous layers such as `HGTConv`, `HANConv`
        # and `HeteroConv`:
if isinstance(module, tuple(supported_models)):
hook_handles.append(module.register_forward_hook(hook))
        else:
            # Handle heterogeneous models generated by calling `to_hetero()`
            # on a homogeneous model, which wrap their type-specific
            # sub-modules in a `torch.nn.ModuleDict`:
            contains_module_dict = any(
                isinstance(submodule, torch.nn.ModuleDict)
                for submodule in module.children())
            if contains_module_dict:
                hook_handles.append(module.register_forward_hook(hook))
if len(hook_handles) == 0:
warnings.warn(
"The 'model' does not have any heterogenous "
"'MessagePassing' layers", stacklevel=2)
    # Run the forward pass in evaluation mode, restoring the previous mode:
training = model.training
model.eval()
with torch.no_grad():
model(*args, **kwargs)
model.train(training)
# Clean up hooks
for handle in hook_handles:
handle.remove()
return node_embeddings_dict