Source code for torch_geometric.explain.algorithm.attention_explainer

import logging
from typing import Dict, List, Optional, Union, overload

import torch
from torch import Tensor

from torch_geometric.explain import Explanation, HeteroExplanation
from torch_geometric.explain.algorithm import ExplainerAlgorithm
from torch_geometric.explain.config import ExplanationType, ModelTaskLevel
from torch_geometric.nn.conv.message_passing import MessagePassing
from torch_geometric.typing import EdgeType, NodeType


class AttentionExplainer(ExplainerAlgorithm):
    r"""An explainer that uses the attention coefficients produced by an
    attention-based GNN (*e.g.*,
    :class:`~torch_geometric.nn.conv.GATConv`,
    :class:`~torch_geometric.nn.conv.GATv2Conv`, or
    :class:`~torch_geometric.nn.conv.TransformerConv`) as edge explanation.
    Attention scores across layers and heads will be aggregated according to
    the :obj:`reduce` argument.

    Args:
        reduce (str, optional): The method to reduce the attention scores
            across layers and heads. (default: :obj:`"max"`)
    """
    def __init__(self, reduce: str = 'max'):
        super().__init__()
        self.reduce = reduce
        self.is_hetero = False

    @overload
    def forward(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ) -> Explanation:
        ...

    @overload
    def forward(
        self,
        model: torch.nn.Module,
        x: Dict[NodeType, Tensor],
        edge_index: Dict[EdgeType, Tensor],
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ) -> HeteroExplanation:
        ...
    def forward(
        self,
        model: torch.nn.Module,
        x: Union[Tensor, Dict[NodeType, Tensor]],
        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
        *,
        target: Tensor,
        index: Optional[Union[int, Tensor]] = None,
        **kwargs,
    ) -> Union[Explanation, HeteroExplanation]:
        """Generate explanations based on attention coefficients."""
        self.is_hetero = isinstance(x, dict)

        # Collect attention coefficients
        alphas_dict = self._collect_attention_coefficients(
            model, x, edge_index, **kwargs)

        # Process attention coefficients
        if self.is_hetero:
            return self._create_hetero_explanation(model, alphas_dict,
                                                   edge_index, index, x)
        else:
            return self._create_homo_explanation(model, alphas_dict,
                                                 edge_index, index, x)
    @overload
    def _collect_attention_coefficients(
        self,
        model: torch.nn.Module,
        x: Tensor,
        edge_index: Tensor,
        **kwargs,
    ) -> List[Tensor]:
        ...

    @overload
    def _collect_attention_coefficients(
        self,
        model: torch.nn.Module,
        x: Dict[NodeType, Tensor],
        edge_index: Dict[EdgeType, Tensor],
        **kwargs,
    ) -> Dict[EdgeType, List[Tensor]]:
        ...

    def _collect_attention_coefficients(
        self,
        model: torch.nn.Module,
        x: Union[Tensor, Dict[NodeType, Tensor]],
        edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
        **kwargs,
    ) -> Union[List[Tensor], Dict[EdgeType, List[Tensor]]]:
        """Collect attention coefficients from model layers."""
        if self.is_hetero:
            # For heterogeneous graphs, store alphas by edge type
            alphas_dict: Dict[EdgeType, List[Tensor]] = {}

            # Get list of edge types
            edge_types = list(edge_index.keys())

            # Hook function to capture attention coefficients by edge type
            def hook(module, msg_kwargs, out):
                # Find edge type from the module's full name
                module_name = getattr(module, '_name', None)
                if module_name is None:
                    return

                edge_type = None
                for edge_tuple in edge_types:
                    src_type, edge_name, dst_type = edge_tuple
                    # Check if all components appear in the module name in
                    # order
                    try:
                        src_idx = module_name.index(src_type)
                        edge_idx = module_name.index(edge_name, src_idx)
                        dst_idx = module_name.index(dst_type, edge_idx)
                        if src_idx < edge_idx < dst_idx:
                            edge_type = edge_tuple
                            break
                    except ValueError:  # Component not found
                        continue

                if edge_type is None:
                    return

                if edge_type not in alphas_dict:
                    alphas_dict[edge_type] = []

                # Extract alpha from message kwargs or module
                if 'alpha' in msg_kwargs[0]:
                    alphas_dict[edge_type].append(
                        msg_kwargs[0]['alpha'].detach())
                elif getattr(module, '_alpha', None) is not None:
                    alphas_dict[edge_type].append(module._alpha.detach())
        else:
            # For homogeneous graphs, store all alphas in a list
            alphas: List[Tensor] = []

            def hook(module, msg_kwargs, out):
                if 'alpha' in msg_kwargs[0]:
                    alphas.append(msg_kwargs[0]['alpha'].detach())
                elif getattr(module, '_alpha', None) is not None:
                    alphas.append(module._alpha.detach())

        # Register hooks for all message passing modules
        hook_handles = []
        for name, module in model.named_modules():
            if (isinstance(module, MessagePassing)
                    and module.explain is not False):
                # Store name for hetero graph lookup in the hook
                if self.is_hetero:
                    module._name = name
                hook_handles.append(
                    module.register_message_forward_hook(hook))

        # Forward pass to collect attention coefficients.
        model(x, edge_index, **kwargs)

        # Remove hooks
        for handle in hook_handles:
            handle.remove()

        # Check if we collected any attention coefficients.
        if self.is_hetero:
            if not alphas_dict:
                raise ValueError(
                    "Could not collect any attention coefficients. "
                    "Please ensure that your model is using "
                    "attention-based GNN layers.")
            return alphas_dict
        else:
            if not alphas:
                raise ValueError(
                    "Could not collect any attention coefficients. "
                    "Please ensure that your model is using "
                    "attention-based GNN layers.")
            return alphas

    def _process_attention_coefficients(
        self,
        alphas: List[Tensor],
        edge_index_size: int,
    ) -> Tensor:
        """Process collected attention coefficients into a single mask."""
        for i, alpha in enumerate(alphas):
            # Ensure alpha doesn't exceed edge_index size
            alpha = alpha[:edge_index_size]
            # Reduce multi-head attention
            if alpha.dim() == 2:
                alpha = getattr(torch, self.reduce)(alpha, dim=-1)
                if isinstance(alpha, tuple):  # Handle torch.max output
                    alpha = alpha[0]
            elif alpha.dim() > 2:
                raise ValueError(f"Cannot reduce attention coefficients of "
                                 f"shape {list(alpha.size())}")
            alphas[i] = alpha

        # Combine attention coefficients across layers
        if len(alphas) > 1:
            alpha = torch.stack(alphas, dim=-1)
            alpha = getattr(torch, self.reduce)(alpha, dim=-1)
            if isinstance(alpha, tuple):  # Handle torch.max output
                alpha = alpha[0]
        else:
            alpha = alphas[0]

        return alpha

    def _create_homo_explanation(
        self,
        model: torch.nn.Module,
        alphas: List[Tensor],
        edge_index: Tensor,
        index: Optional[Union[int, Tensor]],
        x: Tensor,
    ) -> Explanation:
        """Create explanation for homogeneous graph."""
        # Get hard edge mask for node-level tasks
        hard_edge_mask = None
        if self.model_config.task_level == ModelTaskLevel.node:
            _, hard_edge_mask = self._get_hard_masks(model, index, edge_index,
                                                     num_nodes=x.size(0))

        # Process attention coefficients
        alpha = self._process_attention_coefficients(alphas,
                                                     edge_index.size(1))

        # Post-process mask with hard edge mask if needed
        alpha = self._post_process_mask(alpha, hard_edge_mask,
                                        apply_sigmoid=False)

        return Explanation(edge_mask=alpha)

    def _create_hetero_explanation(
        self,
        model: torch.nn.Module,
        alphas_dict: Dict[EdgeType, List[Tensor]],
        edge_index: Dict[EdgeType, Tensor],
        index: Optional[Union[int, Tensor]],
        x: Dict[NodeType, Tensor],
    ) -> HeteroExplanation:
        """Create explanation for heterogeneous graph."""
        edge_masks_dict = {}

        # Process each edge type separately
        for edge_type, alphas in alphas_dict.items():
            if not alphas:
                continue

            # Get hard edge mask for node-level tasks
            hard_edge_mask = None
            if self.model_config.task_level == ModelTaskLevel.node:
                src_type, _, dst_type = edge_type
                _, hard_edge_mask = self._get_hard_masks(
                    model, index, edge_index[edge_type],
                    num_nodes=max(x[src_type].size(0), x[dst_type].size(0)))

            # Process attention coefficients for this edge type
            alpha = self._process_attention_coefficients(
                alphas, edge_index[edge_type].size(1))

            # Apply hard mask if available
            edge_masks_dict[edge_type] = self._post_process_mask(
                alpha, hard_edge_mask, apply_sigmoid=False)

        # Create heterogeneous explanation
        explanation = HeteroExplanation()
        explanation.set_value_dict('edge_mask', edge_masks_dict)
        return explanation
    def supports(self) -> bool:
        explanation_type = self.explainer_config.explanation_type
        if explanation_type != ExplanationType.model:
            logging.error(f"'{self.__class__.__name__}' only supports "
                          f"model explanations "
                          f"got (`explanation_type={explanation_type.value}`)")
            return False

        node_mask_type = self.explainer_config.node_mask_type
        if node_mask_type is not None:
            logging.error(f"'{self.__class__.__name__}' does not support "
                          f"explaining input node features "
                          f"got (`node_mask_type={node_mask_type.value}`)")
            return False

        return True
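
Usage sketch (not part of the module above): since supports() requires explanation_type="model" and no node feature masks, a typical setup wraps AttentionExplainer in torch_geometric.explain.Explainer with edge_mask_type="object". The two-layer GAT model, the random features, and the queried node index below are illustrative assumptions, not part of the library source.

import torch
from torch_geometric.explain import Explainer
from torch_geometric.explain.algorithm import AttentionExplainer
from torch_geometric.nn import GATConv


class GAT(torch.nn.Module):  # hypothetical two-layer GAT for illustration
    def __init__(self):
        super().__init__()
        self.conv1 = GATConv(16, 32, heads=4)
        self.conv2 = GATConv(4 * 32, 7, heads=1)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)


explainer = Explainer(
    model=GAT(),
    algorithm=AttentionExplainer(reduce='max'),
    explanation_type='model',  # required by `supports()`
    node_mask_type=None,       # node feature masks are not supported
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='raw',
    ),
)

x = torch.randn(10, 16)                      # assumed toy node features
edge_index = torch.randint(0, 10, (2, 40))   # assumed toy edges
explanation = explainer(x, edge_index, index=3)  # explain node 3
print(explanation.edge_mask)  # one aggregated attention score per edge

Because task_level is "node", the resulting edge_mask is additionally restricted to the edges inside the queried node's receptive field via the hard edge mask computed in _create_homo_explanation.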