from math import sqrt
from typing import Dict, Optional, Tuple, Union, overload
import torch
from torch import Tensor
from torch.nn.parameter import Parameter
from torch_geometric.explain import (
ExplainerConfig,
Explanation,
HeteroExplanation,
ModelConfig,
)
from torch_geometric.explain.algorithm import ExplainerAlgorithm
from torch_geometric.explain.algorithm.utils import (
clear_masks,
set_hetero_masks,
set_masks,
)
from torch_geometric.explain.config import MaskType, ModelMode, ModelTaskLevel
from torch_geometric.typing import EdgeType, NodeType
class GNNExplainer(ExplainerAlgorithm):
    r"""The GNN-Explainer model from the `"GNNExplainer: Generating
    Explanations for Graph Neural Networks"
    <https://arxiv.org/abs/1903.03894>`_ paper for identifying compact subgraph
    structures and node features that play a crucial role in the predictions
    made by a GNN.

    .. note::

        For an example of using :class:`GNNExplainer`, see
        `examples/explain/gnn_explainer.py <https://github.com/pyg-team/
        pytorch_geometric/blob/master/examples/explain/gnn_explainer.py>`_,
        `examples/explain/gnn_explainer_ba_shapes.py <https://github.com/
        pyg-team/pytorch_geometric/blob/master/examples/
        explain/gnn_explainer_ba_shapes.py>`_, and `examples/explain/
        gnn_explainer_link_pred.py <https://github.com/pyg-team/
        pytorch_geometric/blob/master/examples/explain/
        gnn_explainer_link_pred.py>`_.

    .. note::

        The :obj:`edge_size` coefficient is multiplied by the number of nodes
        in the explanation at every iteration, and the resulting value is
        added to the loss as a regularization term, with the goal of producing
        compact explanations.
        A higher value will push the algorithm towards explanations with less
        elements.
        Consider adjusting the :obj:`edge_size` coefficient according to the
        average node degree in the dataset, especially if this value is bigger
        than in the datasets used in the original paper.

    Args:
        epochs (int, optional): The number of epochs to train.
            (default: :obj:`100`)
        lr (float, optional): The learning rate to apply.
            (default: :obj:`0.01`)
        **kwargs (optional): Additional hyper-parameters to override default
            settings in
            :attr:`~torch_geometric.explain.algorithm.GNNExplainer.coeffs`.
    """

    # Default regularization coefficients; individual entries can be
    # overridden per-instance via ``**kwargs`` in :meth:`__init__`.
    default_coeffs = {
        'edge_size': 0.005,
        'edge_reduction': 'sum',
        'node_feat_size': 1.0,
        'node_feat_reduction': 'mean',
        'edge_ent': 1.0,
        'node_feat_ent': 0.1,
        'EPS': 1e-15,
    }
def __init__(self, epochs: int = 100, lr: float = 0.01, **kwargs):
super().__init__()
self.epochs = epochs
self.lr = lr
self.coeffs = dict(self.default_coeffs)
self.coeffs.update(kwargs)
self.node_mask = self.hard_node_mask = None
self.edge_mask = self.hard_edge_mask = None
self.is_hetero = False
@overload
def forward(
self,
model: torch.nn.Module,
x: Tensor,
edge_index: Tensor,
*,
target: Tensor,
index: Optional[Union[int, Tensor]] = None,
**kwargs,
) -> Explanation:
...
@overload
def forward(
self,
model: torch.nn.Module,
x: Dict[NodeType, Tensor],
edge_index: Dict[EdgeType, Tensor],
*,
target: Tensor,
index: Optional[Union[int, Tensor]] = None,
**kwargs,
) -> HeteroExplanation:
...
[docs] def forward(
self,
model: torch.nn.Module,
x: Union[Tensor, Dict[NodeType, Tensor]],
edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
*,
target: Tensor,
index: Optional[Union[int, Tensor]] = None,
**kwargs,
) -> Union[Explanation, HeteroExplanation]:
self.is_hetero = isinstance(x, dict)
self._train(model, x, edge_index, target=target, index=index, **kwargs)
explanation = self._create_explanation()
self._clean_model(model)
return explanation
def _create_explanation(self) -> Union[Explanation, HeteroExplanation]:
"""Create an explanation object from the current masks."""
if self.is_hetero:
# For heterogeneous graphs, process each type separately
node_mask_dict = {}
edge_mask_dict = {}
for node_type, mask in self.node_mask.items():
if mask is not None:
node_mask_dict[node_type] = self._post_process_mask(
mask,
self.hard_node_mask[node_type],
apply_sigmoid=True,
)
for edge_type, mask in self.edge_mask.items():
if mask is not None:
edge_mask_dict[edge_type] = self._post_process_mask(
mask,
self.hard_edge_mask[edge_type],
apply_sigmoid=True,
)
# Create heterogeneous explanation
explanation = HeteroExplanation()
explanation.set_value_dict('node_mask', node_mask_dict)
explanation.set_value_dict('edge_mask', edge_mask_dict)
else:
# For homogeneous graphs, process single masks
node_mask = self._post_process_mask(
self.node_mask,
self.hard_node_mask,
apply_sigmoid=True,
)
edge_mask = self._post_process_mask(
self.edge_mask,
self.hard_edge_mask,
apply_sigmoid=True,
)
# Create homogeneous explanation
explanation = Explanation(node_mask=node_mask, edge_mask=edge_mask)
return explanation
[docs] def supports(self) -> bool:
return True
@overload
def _train(
self,
model: torch.nn.Module,
x: Tensor,
edge_index: Tensor,
*,
target: Tensor,
index: Optional[Union[int, Tensor]] = None,
**kwargs,
) -> None:
...
@overload
def _train(
self,
model: torch.nn.Module,
x: Dict[NodeType, Tensor],
edge_index: Dict[EdgeType, Tensor],
*,
target: Tensor,
index: Optional[Union[int, Tensor]] = None,
**kwargs,
) -> None:
...
def _train(
self,
model: torch.nn.Module,
x: Union[Tensor, Dict[NodeType, Tensor]],
edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
*,
target: Tensor,
index: Optional[Union[int, Tensor]] = None,
**kwargs,
) -> None:
# Initialize masks based on input type
self._initialize_masks(x, edge_index)
# Collect parameters for optimization
parameters = self._collect_parameters(model, edge_index)
# Create optimizer
optimizer = torch.optim.Adam(parameters, lr=self.lr)
# Training loop
for i in range(self.epochs):
optimizer.zero_grad()
# Forward pass with masked inputs
y_hat = self._forward_with_masks(model, x, edge_index, **kwargs)
y = target
# Handle index if provided
if index is not None:
y_hat, y = y_hat[index], y[index]
# Calculate loss
loss = self._loss(y_hat, y)
# Backward pass
loss.backward()
optimizer.step()
# In the first iteration, collect gradients to identify important
# nodes/edges
if i == 0:
self._collect_gradients()
def _collect_parameters(self, model, edge_index):
"""Collect parameters for optimization."""
parameters = []
if self.is_hetero:
# For heterogeneous graphs, collect parameters from all types
for mask in self.node_mask.values():
if mask is not None:
parameters.append(mask)
if any(v is not None for v in self.edge_mask.values()):
set_hetero_masks(model, self.edge_mask, edge_index)
for mask in self.edge_mask.values():
if mask is not None:
parameters.append(mask)
else:
# For homogeneous graphs, collect single parameters
if self.node_mask is not None:
parameters.append(self.node_mask)
if self.edge_mask is not None:
set_masks(model, self.edge_mask, edge_index,
apply_sigmoid=True)
parameters.append(self.edge_mask)
return parameters
@overload
def _forward_with_masks(
self,
model: torch.nn.Module,
x: Tensor,
edge_index: Tensor,
**kwargs,
) -> Tensor:
...
@overload
def _forward_with_masks(
self,
model: torch.nn.Module,
x: Dict[NodeType, Tensor],
edge_index: Dict[EdgeType, Tensor],
**kwargs,
) -> Tensor:
...
def _forward_with_masks(
self,
model: torch.nn.Module,
x: Union[Tensor, Dict[NodeType, Tensor]],
edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
**kwargs,
) -> Tensor:
"""Forward pass with masked inputs."""
if self.is_hetero:
# Apply masks to heterogeneous inputs
h_dict = {}
for node_type, features in x.items():
if node_type in self.node_mask and self.node_mask[
node_type] is not None:
h_dict[node_type] = features * self.node_mask[
node_type].sigmoid()
else:
h_dict[node_type] = features
# Forward pass with masked features
return model(h_dict, edge_index, **kwargs)
else:
# Apply mask to homogeneous input
h = x if self.node_mask is None else x * self.node_mask.sigmoid()
# Forward pass with masked features
return model(h, edge_index, **kwargs)
def _initialize_masks(
self,
x: Union[Tensor, Dict[NodeType, Tensor]],
edge_index: Union[Tensor, Dict[EdgeType, Tensor]],
) -> None:
node_mask_type = self.explainer_config.node_mask_type
edge_mask_type = self.explainer_config.edge_mask_type
if self.is_hetero:
# Initialize dictionaries for heterogeneous masks
self.node_mask = {}
self.hard_node_mask = {}
self.edge_mask = {}
self.hard_edge_mask = {}
# Initialize node masks for each node type
for node_type, features in x.items():
device = features.device
N, F = features.size()
self._initialize_node_mask(node_mask_type, node_type, N, F,
device)
# Initialize edge masks for each edge type
for edge_type, indices in edge_index.items():
device = indices.device
E = indices.size(1)
N = max(indices.max().item() + 1,
max(feat.size(0) for feat in x.values()))
self._initialize_edge_mask(edge_mask_type, edge_type, E, N,
device)
else:
# Initialize masks for homogeneous graph
device = x.device
(N, F), E = x.size(), edge_index.size(1)
# Initialize homogeneous node and edge masks
self._initialize_homogeneous_masks(node_mask_type, edge_mask_type,
N, F, E, device)
def _initialize_node_mask(
self,
node_mask_type,
node_type,
N,
F,
device,
) -> None:
"""Initialize node mask for a specific node type."""
std = 0.1
if node_mask_type is None:
self.node_mask[node_type] = None
self.hard_node_mask[node_type] = None
elif node_mask_type == MaskType.object:
self.node_mask[node_type] = Parameter(
torch.randn(N, 1, device=device) * std)
self.hard_node_mask[node_type] = None
elif node_mask_type == MaskType.attributes:
self.node_mask[node_type] = Parameter(
torch.randn(N, F, device=device) * std)
self.hard_node_mask[node_type] = None
elif node_mask_type == MaskType.common_attributes:
self.node_mask[node_type] = Parameter(
torch.randn(1, F, device=device) * std)
self.hard_node_mask[node_type] = None
else:
raise ValueError(f"Invalid node mask type: {node_mask_type}")
def _initialize_edge_mask(self, edge_mask_type, edge_type, E, N, device):
"""Initialize edge mask for a specific edge type."""
if edge_mask_type is None:
self.edge_mask[edge_type] = None
self.hard_edge_mask[edge_type] = None
elif edge_mask_type == MaskType.object:
std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
self.edge_mask[edge_type] = Parameter(
torch.randn(E, device=device) * std)
self.hard_edge_mask[edge_type] = None
else:
raise ValueError(f"Invalid edge mask type: {edge_mask_type}")
def _initialize_homogeneous_masks(self, node_mask_type, edge_mask_type, N,
F, E, device):
"""Initialize masks for homogeneous graph."""
# Initialize node mask
std = 0.1
if node_mask_type is None:
self.node_mask = None
elif node_mask_type == MaskType.object:
self.node_mask = Parameter(torch.randn(N, 1, device=device) * std)
elif node_mask_type == MaskType.attributes:
self.node_mask = Parameter(torch.randn(N, F, device=device) * std)
elif node_mask_type == MaskType.common_attributes:
self.node_mask = Parameter(torch.randn(1, F, device=device) * std)
else:
raise ValueError(f"Invalid node mask type: {node_mask_type}")
# Initialize edge mask
if edge_mask_type is None:
self.edge_mask = None
elif edge_mask_type == MaskType.object:
std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
self.edge_mask = Parameter(torch.randn(E, device=device) * std)
else:
raise ValueError(f"Invalid edge mask type: {edge_mask_type}")
def _collect_gradients(self) -> None:
if self.is_hetero:
self._collect_hetero_gradients()
else:
self._collect_homo_gradients()
def _collect_hetero_gradients(self):
"""Collect gradients for heterogeneous graph."""
for node_type, mask in self.node_mask.items():
if mask is not None:
if mask.grad is None:
raise ValueError(
f"Could not compute gradients for node masks of type "
f"'{node_type}'. Please make sure that node masks are "
f"used inside the model or disable it via "
f"`node_mask_type=None`.")
self.hard_node_mask[node_type] = mask.grad != 0.0
for edge_type, mask in self.edge_mask.items():
if mask is not None:
if mask.grad is None:
raise ValueError(
f"Could not compute gradients for edge masks of type "
f"'{edge_type}'. Please make sure that edge masks are "
f"used inside the model or disable it via "
f"`edge_mask_type=None`.")
self.hard_edge_mask[edge_type] = mask.grad != 0.0
def _collect_homo_gradients(self):
"""Collect gradients for homogeneous graph."""
if self.node_mask is not None:
if self.node_mask.grad is None:
raise ValueError("Could not compute gradients for node "
"features. Please make sure that node "
"features are used inside the model or "
"disable it via `node_mask_type=None`.")
self.hard_node_mask = self.node_mask.grad != 0.0
if self.edge_mask is not None:
if self.edge_mask.grad is None:
raise ValueError("Could not compute gradients for edges. "
"Please make sure that edges are used "
"via message passing inside the model or "
"disable it via `edge_mask_type=None`.")
self.hard_edge_mask = self.edge_mask.grad != 0.0
def _loss(self, y_hat: Tensor, y: Tensor) -> Tensor:
# Calculate base loss based on model configuration
loss = self._calculate_base_loss(y_hat, y)
# Apply regularization based on graph type
if self.is_hetero:
# Apply regularization for heterogeneous graph
loss = self._apply_hetero_regularization(loss)
else:
# Apply regularization for homogeneous graph
loss = self._apply_homo_regularization(loss)
return loss
def _calculate_base_loss(self, y_hat, y):
"""Calculate base loss based on model configuration."""
if self.model_config.mode == ModelMode.binary_classification:
return self._loss_binary_classification(y_hat, y)
elif self.model_config.mode == ModelMode.multiclass_classification:
return self._loss_multiclass_classification(y_hat, y)
elif self.model_config.mode == ModelMode.regression:
return self._loss_regression(y_hat, y)
else:
raise ValueError(f"Invalid model mode: {self.model_config.mode}")
def _apply_hetero_regularization(self, loss):
"""Apply regularization for heterogeneous graph."""
# Apply regularization for each edge type
for edge_type, mask in self.edge_mask.items():
if (mask is not None
and self.hard_edge_mask[edge_type] is not None):
loss = self._add_mask_regularization(
loss, mask, self.hard_edge_mask[edge_type],
self.coeffs['edge_size'], self.coeffs['edge_reduction'],
self.coeffs['edge_ent'])
# Apply regularization for each node type
for node_type, mask in self.node_mask.items():
if (mask is not None
and self.hard_node_mask[node_type] is not None):
loss = self._add_mask_regularization(
loss, mask, self.hard_node_mask[node_type],
self.coeffs['node_feat_size'],
self.coeffs['node_feat_reduction'],
self.coeffs['node_feat_ent'])
return loss
def _apply_homo_regularization(self, loss):
"""Apply regularization for homogeneous graph."""
# Apply regularization for edge mask
if self.hard_edge_mask is not None:
assert self.edge_mask is not None
loss = self._add_mask_regularization(loss, self.edge_mask,
self.hard_edge_mask,
self.coeffs['edge_size'],
self.coeffs['edge_reduction'],
self.coeffs['edge_ent'])
# Apply regularization for node mask
if self.hard_node_mask is not None:
assert self.node_mask is not None
loss = self._add_mask_regularization(
loss, self.node_mask, self.hard_node_mask,
self.coeffs['node_feat_size'],
self.coeffs['node_feat_reduction'],
self.coeffs['node_feat_ent'])
return loss
def _add_mask_regularization(self, loss, mask, hard_mask, size_coeff,
reduction_name, ent_coeff):
"""Add size and entropy regularization for a mask."""
m = mask[hard_mask].sigmoid()
reduce_fn = getattr(torch, reduction_name)
# Add size regularization
loss = loss + size_coeff * reduce_fn(m)
# Add entropy regularization
ent = -m * torch.log(m + self.coeffs['EPS']) - (
1 - m) * torch.log(1 - m + self.coeffs['EPS'])
loss = loss + ent_coeff * ent.mean()
return loss
def _clean_model(self, model):
clear_masks(model)
self.node_mask = self.hard_node_mask = None
self.edge_mask = self.hard_edge_mask = None
class GNNExplainer_:
    r"""Deprecated version for :class:`GNNExplainer`."""

    # Expose the same coefficient defaults as the new implementation:
    coeffs = GNNExplainer.default_coeffs

    # Maps legacy `feat_mask_type` values to the new mask-type names:
    conversion_node_mask_type = {
        'feature': 'common_attributes',
        'individual_feature': 'attributes',
        'scalar': 'object',
    }

    # Maps legacy `return_type` values to the new return-type names:
    conversion_return_type = {
        'log_prob': 'log_probs',
        'prob': 'probs',
        'raw': 'raw',
        'regression': 'raw',
    }

    def __init__(
        self,
        model: torch.nn.Module,
        epochs: int = 100,
        lr: float = 0.01,
        return_type: str = 'log_prob',
        feat_mask_type: str = 'feature',
        allow_edge_mask: bool = True,
        **kwargs,
    ):
        assert feat_mask_type in ['feature', 'individual_feature', 'scalar']

        # Translate the legacy arguments into the new configuration objects:
        explainer_config = ExplainerConfig(
            explanation_type='model',
            node_mask_type=self.conversion_node_mask_type[feat_mask_type],
            edge_mask_type=MaskType.object if allow_edge_mask else None,
        )
        is_regression = return_type == 'regression'
        model_config = ModelConfig(
            mode=('regression' if is_regression else
                  'multiclass_classification'),
            task_level=ModelTaskLevel.node,
            return_type=self.conversion_return_type[return_type],
        )

        self.model = model
        self._explainer = GNNExplainer(epochs=epochs, lr=lr, **kwargs)
        self._explainer.connect(explainer_config, model_config)
@torch.no_grad()
def get_initial_prediction(self, *args, **kwargs) -> Tensor:
training = self.model.training
self.model.eval()
out = self.model(*args, **kwargs)
if (self._explainer.model_config.mode ==
ModelMode.multiclass_classification):
out = out.argmax(dim=-1)
self.model.train(training)
return out
def explain_graph(
self,
x: Tensor,
edge_index: Tensor,
**kwargs,
) -> Tuple[Tensor, Tensor]:
self._explainer.model_config.task_level = ModelTaskLevel.graph
explanation = self._explainer(
self.model,
x,
edge_index,
target=self.get_initial_prediction(x, edge_index, **kwargs),
**kwargs,
)
return self._convert_output(explanation, edge_index)
def explain_node(
self,
node_idx: int,
x: Tensor,
edge_index: Tensor,
**kwargs,
) -> Tuple[Tensor, Tensor]:
self._explainer.model_config.task_level = ModelTaskLevel.node
explanation = self._explainer(
self.model,
x,
edge_index,
target=self.get_initial_prediction(x, edge_index, **kwargs),
index=node_idx,
**kwargs,
)
return self._convert_output(explanation, edge_index, index=node_idx,
x=x)
def _convert_output(self, explanation, edge_index, index=None, x=None):
node_mask = explanation.get('node_mask')
edge_mask = explanation.get('edge_mask')
if node_mask is not None:
node_mask_type = self._explainer.explainer_config.node_mask_type
if node_mask_type in {MaskType.object, MaskType.common_attributes}:
node_mask = node_mask.view(-1)
if edge_mask is None:
if index is not None:
_, edge_mask = self._explainer._get_hard_masks(
self.model, index, edge_index, num_nodes=x.size(0))
edge_mask = edge_mask.to(x.dtype)
else:
edge_mask = torch.ones(edge_index.size(1),
device=edge_index.device)
return node_mask, edge_mask