Source code for torch_geometric.utils.embedding

import warnings
from typing import Any, Dict, List, Optional, Type

import torch
from torch import Tensor

from torch_geometric.typing import NodeType


def get_embeddings(
    model: torch.nn.Module,
    *args: Any,
    **kwargs: Any,
) -> List[Tensor]:
    """Returns the output embeddings of all
    :class:`~torch_geometric.nn.conv.MessagePassing` layers in :obj:`model`.

    Internally, this method registers forward hooks on all
    :class:`~torch_geometric.nn.conv.MessagePassing` layers of the
    :obj:`model`, and runs the forward pass of the :obj:`model` by calling
    :obj:`model(*args, **kwargs)`.

    Args:
        model (torch.nn.Module): The message passing model.
        *args: Arguments passed to the model.
        **kwargs (optional): Additional keyword arguments passed to the
            model.
    """
    from torch_geometric.nn import MessagePassing

    embeddings: List[Tensor] = []

    def hook(model: torch.nn.Module, inputs: Any, outputs: Any) -> None:
        # Clone output in case it will be later modified in-place:
        outputs = outputs[0] if isinstance(outputs, tuple) else outputs
        assert isinstance(outputs, Tensor)
        embeddings.append(outputs.clone())

    hook_handles = []
    for module in model.modules():  # Register forward hooks:
        if isinstance(module, MessagePassing):
            hook_handles.append(module.register_forward_hook(hook))

    if len(hook_handles) == 0:
        warnings.warn("The 'model' does not have any 'MessagePassing' "
                      "layers", stacklevel=2)

    training = model.training
    model.eval()
    with torch.no_grad():
        model(*args, **kwargs)
    model.train(training)

    for handle in hook_handles:  # Remove hooks:
        handle.remove()

    return embeddings
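A minimal usage sketch follows; the two-layer GCN, the feature sizes, and the random graph are made up for illustration and are not part of the library. :meth:`get_embeddings` returns one tensor per :class:`~torch_geometric.nn.conv.MessagePassing` layer, in the order the layers are executed:

    import torch
    from torch_geometric.nn import GCNConv
    from torch_geometric.utils import get_embeddings

    class GCN(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = GCNConv(16, 32)  # Hypothetical feature sizes.
            self.conv2 = GCNConv(32, 7)

        def forward(self, x, edge_index):
            x = self.conv1(x, edge_index).relu()
            return self.conv2(x, edge_index)

    model = GCN()
    x = torch.randn(10, 16)                     # 10 nodes, 16 features each.
    edge_index = torch.randint(0, 10, (2, 40))  # Random connectivity.

    embeddings = get_embeddings(model, x, edge_index)
    # One entry per MessagePassing layer:
    # embeddings[0].shape == torch.Size([10, 32])
    # embeddings[1].shape == torch.Size([10, 7])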
def get_embeddings_hetero(
    model: torch.nn.Module,
    supported_models: Optional[List[Type[torch.nn.Module]]] = None,
    *args: Any,
    **kwargs: Any,
) -> Dict[NodeType, List[Tensor]]:
    """Returns the output embeddings of all
    :class:`~torch_geometric.nn.conv.MessagePassing` layers in a
    heterogeneous :obj:`model`, organized by node type.

    Internally, this method registers forward hooks on all modules that
    process heterogeneous graphs in the model and runs the forward pass of
    the model.

    For heterogeneous models, the output is a dictionary where each key is a
    node type and each value is a list of embeddings from different layers.

    Args:
        model (torch.nn.Module): The heterogeneous GNN model.
        supported_models (List[Type[torch.nn.Module]], optional): A list of
            supported model classes. If not provided, defaults to
            :obj:`[HGTConv, HANConv, HeteroConv]`.
        *args: Arguments passed to the model.
        **kwargs (optional): Additional keyword arguments passed to the
            model.

    Returns:
        Dict[NodeType, List[Tensor]]: A dictionary mapping each node type to
        a list of embeddings from different layers.
    """
    from torch_geometric.nn import HANConv, HeteroConv, HGTConv

    if not supported_models:
        supported_models = [HGTConv, HANConv, HeteroConv]

    # Dictionary to store node embeddings by type:
    node_embeddings_dict: Dict[NodeType, List[Tensor]] = {}

    # Hook function to capture node embeddings:
    def hook(model: torch.nn.Module, inputs: Any, outputs: Any) -> None:
        # Only capture outputs that are dictionaries mapping node types to
        # embeddings, which filters out false positives:
        if isinstance(outputs, dict) and outputs:
            # Store embeddings for each node type:
            for node_type, embedding in outputs.items():
                if node_type not in node_embeddings_dict:
                    node_embeddings_dict[node_type] = []
                node_embeddings_dict[node_type].append(embedding.clone())

    # List to store hook handles:
    hook_handles = []

    for _, module in model.named_modules():
        # Handle natively heterogeneous models, e.g., HGTConv, HANConv and
        # HeteroConv:
        if isinstance(module, tuple(supported_models)):
            hook_handles.append(module.register_forward_hook(hook))
        else:
            # Handle heterogeneous models that are generated by calling
            # to_hetero() on a homogeneous model:
            submodules = list(module.children())
            submodules_contains_module_dict = any(
                isinstance(submodule, torch.nn.ModuleDict)
                for submodule in submodules)
            if submodules_contains_module_dict:
                hook_handles.append(module.register_forward_hook(hook))

    if len(hook_handles) == 0:
        warnings.warn(
            "The 'model' does not have any heterogeneous "
            "'MessagePassing' layers", stacklevel=2)

    # Run the model forward pass:
    training = model.training
    model.eval()
    with torch.no_grad():
        model(*args, **kwargs)
    model.train(training)

    # Clean up hooks:
    for handle in hook_handles:
        handle.remove()

    return node_embeddings_dict
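A minimal heterogeneous sketch follows; the HGTConv stack, node/edge types, and tensor sizes are made up for illustration. Since the forward-pass arguments come after :obj:`supported_models` in the signature, :obj:`None` is passed explicitly to fall back to the default supported classes:

    import torch
    from torch_geometric.nn import HGTConv
    from torch_geometric.utils.embedding import get_embeddings_hetero

    # Hypothetical metadata: two node types and two edge types.
    metadata = (['paper', 'author'],
                [('author', 'writes', 'paper'),
                 ('paper', 'rev_writes', 'author')])

    class HeteroGNN(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = HGTConv(16, 32, metadata)
            self.conv2 = HGTConv(32, 32, metadata)

        def forward(self, x_dict, edge_index_dict):
            x_dict = self.conv1(x_dict, edge_index_dict)
            return self.conv2(x_dict, edge_index_dict)

    model = HeteroGNN()
    x_dict = {
        'paper': torch.randn(8, 16),
        'author': torch.randn(4, 16),
    }
    edge_index_dict = {
        ('author', 'writes', 'paper'):
        torch.tensor([[0, 1, 2, 3], [0, 2, 4, 6]]),
        ('paper', 'rev_writes', 'author'):
        torch.tensor([[0, 2, 4, 6], [0, 1, 2, 3]]),
    }

    # Pass 'None' for 'supported_models' to use the defaults:
    embeddings_dict = get_embeddings_hetero(model, None, x_dict,
                                            edge_index_dict)
    # {'paper': [Tensor(8, 32), Tensor(8, 32)],
    #  'author': [Tensor(4, 32), Tensor(4, 32)]}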