torch_geometric.explain
Warning
This module is in active development and may not be stable. Access requires installing PyTorch Geometric from master.
Philosophy
This module provides a set of tools to explain the predictions of a PyG model or to explain the underlying phenomenon of a dataset (see the “GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks” paper for more details).
We represent explanations using the torch_geometric.explain.Explanation class, which is a Data object containing masks for the nodes, edges, features and any attributes of the data.
The torch_geometric.explain.Explainer class is designed to handle all explainability parameters (see the torch_geometric.explain.config.ExplainerConfig class for more details):
- which algorithm from the torch_geometric.explain.algorithm module to use (e.g., GNNExplainer)
- the type of explanation to compute (e.g., explanation_type="phenomenon" or explanation_type="model")
- the different types of masks for nodes and edges (e.g., mask="object" or mask="attributes")
- any postprocessing of the masks (e.g., threshold_type="topk" or threshold_type="hard")
This class allows the user to easily compare different explainability methods and to easily switch between different types of masks, while making sure the high-level framework stays the same.
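For instance, a minimal sketch of the workflow (assuming a trained node-level GCN model with node features x and connectivity edge_index; these names are placeholders, not part of the API):

    from torch_geometric.explain import Explainer, GNNExplainer

    explainer = Explainer(
        model=model,                        # hypothetical trained model
        algorithm=GNNExplainer(epochs=200),
        explanation_type='model',           # explain the model's own predictions
        node_mask_type='attributes',        # mask each feature across all nodes
        edge_mask_type='object',            # mask individual edges
        model_config=dict(
            mode='multiclass_classification',
            task_level='node',
            return_type='log_probs',
        ),
    )

    explanation = explainer(x, edge_index, index=10)  # explain node 10
    print(explanation.edge_mask.size(), explanation.node_mask.size())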
Explainer
- class Explainer(model: Module, algorithm: ExplainerAlgorithm, explanation_type: Union[ExplanationType, str], model_config: Union[ModelConfig, Dict[str, Any]], node_mask_type: Optional[Union[MaskType, str]] = None, edge_mask_type: Optional[Union[MaskType, str]] = None, threshold_config: Optional[ThresholdConfig] = None)[source]
An explainer class for instance-level explanations of Graph Neural Networks.
- Parameters
model (torch.nn.Module) – The model to explain.
algorithm (ExplainerAlgorithm) – The explanation algorithm.
explanation_type (ExplanationType or str) – The type of explanation to compute. The possible values are:
- "model": Explains the model prediction.
- "phenomenon": Explains the phenomenon that the model is trying to predict.
In practice, this means that the explanation algorithm will either compute its loss with respect to the model output ("model") or the target output ("phenomenon").
model_config (ModelConfig or dict) – The model configuration. See ModelConfig for available options.
node_mask_type (MaskType or str, optional) – The type of mask to apply on nodes. The possible values are (default: None):
- None: Will not apply any mask on nodes.
- "object": Will mask each node.
- "common_attributes": Will mask each feature.
- "attributes": Will mask each feature across all nodes.
edge_mask_type (MaskType or str, optional) – The type of mask to apply on edges. Has the same possible values as node_mask_type. (default: None)
threshold_config (ThresholdConfig, optional) – The threshold configuration. See ThresholdConfig for available options. (default: None)
- __call__(x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], *, target: Optional[Tensor] = None, index: Optional[Union[int, Tensor]] = None, **kwargs) Union[Explanation, HeteroExplanation] [source]
Computes the explanation of the GNN for the given inputs and target.
Note
If you get an error message like “Trying to backward through the graph a second time”, make sure that the target you provided was computed with torch.no_grad().
- Parameters
x (Union[torch.Tensor, Dict[NodeType, torch.Tensor]]) – The input node features of a homogeneous or heterogeneous graph.
edge_index (Union[torch.Tensor, Dict[EdgeType, torch.Tensor]]) – The input edge indices of a homogeneous or heterogeneous graph.
target (torch.Tensor, optional) – The target of the model. If the explanation type is "phenomenon", the target has to be provided. If the explanation type is "model", the target should be set to None and will get automatically inferred. (default: None)
index (Union[int, Tensor], optional) – The index of the model output to explain. Can be a single index or a tensor of indices. (default: None)
**kwargs – Additional arguments passed to the GNN.
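For "phenomenon" explanations (an explainer configured with explanation_type='phenomenon'), the target must be supplied explicitly, e.g., the ground-truth labels. If the target is itself derived from a model forward pass, compute it under torch.no_grad() to avoid the double-backward error mentioned above. A sketch, reusing the hypothetical names from the previous example plus ground-truth labels y:

    import torch

    # Option 1: explain against ground-truth labels:
    explanation = explainer(x, edge_index, target=y, index=10)

    # Option 2: explain against a model-derived target, detached from autograd:
    with torch.no_grad():
        target = explainer.get_prediction(x, edge_index).argmax(dim=-1)
    explanation = explainer(x, edge_index, target=target, index=10)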
- get_prediction(*args, **kwargs) Tensor [source]
Returns the prediction of the model on the input graph.
If the model mode is "regression", the prediction is returned as a scalar value. If the model mode is "multiclass_classification" or "binary_classification", the prediction is returned as the predicted class label.
- Parameters
*args – Arguments passed to the model.
**kwargs (optional) – Additional keyword arguments passed to the model.
- get_masked_prediction(x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], node_mask: Optional[Union[Tensor, Dict[str, Tensor]]] = None, edge_mask: Optional[Union[Tensor, Dict[Tuple[str, str, str], Tensor]]] = None, **kwargs) Tensor [source]
Returns the prediction of the model on the input graph with node and edge masks applied.
- get_target(prediction: Tensor) Tensor [source]
Returns the target of the model from a given prediction.
If the model mode is of type "regression", the prediction is returned as it is. If the model mode is of type "multiclass_classification" or "binary_classification", the prediction is returned as the predicted class label.
- class ExplainerConfig(explanation_type: Union[ExplanationType, str], node_mask_type: Optional[Union[MaskType, str]] = None, edge_mask_type: Optional[Union[MaskType, str]] = None)[source]
Configuration class to store and validate high level explanation parameters.
- Parameters
explanation_type (ExplanationType or str) –
The type of explanation to compute. The possible values are:
- "model": Explains the model prediction.
- "phenomenon": Explains the phenomenon that the model is trying to predict.
In practice, this means that the explanation algorithm will either compute its loss with respect to the model output ("model") or the target output ("phenomenon").
node_mask_type (MaskType or str, optional) – The type of mask to apply on nodes. The possible values are (default: None):
- None: Will not apply any mask on nodes.
- "object": Will mask each node.
- "common_attributes": Will mask each feature.
- "attributes": Will mask each feature across all nodes.
edge_mask_type (MaskType or str, optional) – The type of mask to apply on edges. Has the same possible values as node_mask_type. (default: None)
- class ModelConfig(mode: Union[ModelMode, str], task_level: Union[ModelTaskLevel, str], return_type: Optional[Union[ModelReturnType, str]] = None)[source]
Configuration class to store model parameters.
- Parameters
mode (ModelMode or str) –
The mode of the model. The possible values are:
- "binary_classification": A binary classification model.
- "multiclass_classification": A multiclass classification model.
- "regression": A regression model.
task_level (ModelTaskLevel or str) – The task-level of the model. The possible values are:
- "node": A node-level prediction model.
- "edge": An edge-level prediction model.
- "graph": A graph-level prediction model.
return_type (ModelReturnType or str, optional) – The return type of the model. The possible values are (default: None):
- "raw": The model returns raw values.
- "probs": The model returns probabilities.
- "log_probs": The model returns log-probabilities.
- class ThresholdConfig(threshold_type: Union[ThresholdType, str], value: Union[float, int])[source]
Configuration class to store and validate threshold parameters.
- Parameters
threshold_type (ThresholdType or str) –
The type of threshold to apply. The possible values are:
- None: No threshold is applied.
- "hard": A hard threshold is applied to each mask. The elements of the mask with a value below value are set to 0, the others are set to 1.
- "topk": A soft threshold is applied to each mask. The top value elements of each mask are kept, the others are set to 0.
- "topk_hard": Same as "topk", but values are set to 1 for all elements which are kept.
value (int or float, optional) – The value to use when thresholding. (default: None)
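As a sketch of how thresholding plugs into the Explainer from the example above (hypothetical model/data names as before), keeping only the ten most important edges:

    from torch_geometric.explain import Explainer, GNNExplainer
    from torch_geometric.explain.config import ThresholdConfig

    explainer = Explainer(
        model=model,
        algorithm=GNNExplainer(epochs=200),
        explanation_type='model',
        edge_mask_type='object',
        model_config=dict(mode='multiclass_classification',
                          task_level='node', return_type='log_probs'),
        # Keep the top-10 entries of each mask and zero out the rest:
        threshold_config=ThresholdConfig(threshold_type='topk', value=10),
    )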
Explanations
- class Explanation(x: Optional[Tensor] = None, edge_index: Optional[Tensor] = None, edge_attr: Optional[Tensor] = None, y: Optional[Tensor] = None, pos: Optional[Tensor] = None, **kwargs)[source]
Holds all the obtained explanations of a homogeneous graph.
The explanation object is a Data object and can hold node attributions and edge attributions. It can also hold the original graph if needed.
- validate(raise_on_error: bool = True) bool [source]
Validates the correctness of the Explanation object.
- get_explanation_subgraph() Explanation [source]
Returns the induced subgraph, in which all nodes and edges with zero attribution are masked out.
- get_complement_subgraph() Explanation [source]
Returns the induced subgraph, in which all nodes and edges with any attribution are masked out.
- visualize_feature_importance(path: Optional[str] = None, feat_labels: Optional[List[str]] = None, top_k: Optional[int] = None)[source]
Creates a bar plot of the node feature importances by summing up self.node_mask across all nodes.
- Parameters
path (str, optional) – The path to where the plot is saved. If set to None, will visualize the plot on-the-fly. (default: None)
feat_labels (List[str], optional) – Optional labels for features. (default: None)
top_k (int, optional) – Top k features to plot. If None, plots all features. (default: None)
- visualize_graph(path: Optional[str] = None, backend: Optional[str] = None)[source]
Visualizes the explanation graph with edge opacity corresponding to edge importance.
- Parameters
path (str, optional) – The path to where the plot is saved. If set to None, will visualize the plot on-the-fly. (default: None)
backend (str, optional) – The graph drawing backend to use for visualization ("graphviz", "networkx"). If set to None, will use the most appropriate visualization backend based on available system packages. (default: None)
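A short usage sketch (assuming an explanation object produced as in the earlier examples, and that a plotting backend such as matplotlib, graphviz, or networkx is installed):

    # Plot the 10 most important node features:
    explanation.visualize_feature_importance('feature_importance.png', top_k=10)

    # Draw the explanation subgraph with edge opacity proportional to importance:
    explanation.visualize_graph('subgraph.png', backend='networkx')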
Explainer Algorithms
DummyExplainer: A dummy explainer for testing purposes.
GNNExplainer: The GNN-Explainer model from the "GNNExplainer: Generating Explanations for Graph Neural Networks" paper for identifying compact subgraph structures and node features that play a crucial role in the predictions made by a GNN.
PGExplainer: The PGExplainer model from the "Parameterized Explainer for Graph Neural Network" paper.
AttentionExplainer: An explainer that uses the attention coefficients produced by an attention-based GNN (e.g., GATConv, GATv2Conv, or TransformerConv) as edge explanation.
- class ExplainerAlgorithm[source]
- abstract forward(model: Module, x: Union[Tensor, Dict[str, Tensor]], edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs) Explanation [source]
Computes the explanation.
- Parameters
model (torch.nn.Module) – The model to explain.
x (Union[torch.Tensor, Dict[NodeType, torch.Tensor]]) – The input node features of a homogeneous or heterogeneous graph.
edge_index (Union[torch.Tensor, Dict[EdgeType, torch.Tensor]]) – The input edge indices of a homogeneous or heterogeneous graph.
target (torch.Tensor) – The target of the model.
index (Union[int, Tensor], optional) – The index of the model output to explain. Can be a single index or a tensor of indices. (default: None)
**kwargs (optional) – Additional keyword arguments passed to model.
- abstract supports() bool [source]
Checks if the explainer supports the user-defined settings provided in self.explainer_config and self.model_config.
- property explainer_config: ExplainerConfig
Returns the connected explainer configuration.
- property model_config: ModelConfig
Returns the connected model configuration.
- connect(explainer_config: ExplainerConfig, model_config: ModelConfig)[source]
Connects an explainer and model configuration to the explainer algorithm.
- class GNNExplainer(epochs: int = 100, lr: float = 0.01, **kwargs)[source]
The GNN-Explainer model from the “GNNExplainer: Generating Explanations for Graph Neural Networks” paper for identifying compact subgraph structures and node features that play a crucial role in the predictions made by a GNN.
Note
For an example of using GNNExplainer, see examples/gnn_explainer.py and examples/gnn_explainer_ba_shapes.py.
- class PGExplainer(epochs: int, lr: float = 0.003, **kwargs)[source]
The PGExplainer model from the “Parameterized Explainer for Graph Neural Network” paper. Internally, it utilizes a neural network to identify subgraph structures that play a crucial role in the predictions made by a GNN. Importantly, the PGExplainer needs to be trained via train() before being able to generate explanations:

    explainer = Explainer(
        model=model,
        algorithm=PGExplainer(epochs=30, lr=0.003),
        explanation_type='phenomenon',
        edge_mask_type='object',
        model_config=ModelConfig(...),
    )

    # Train against a variety of node-level or graph-level predictions:
    for epoch in range(30):
        for index in [...]:  # Indices to train against.
            loss = explainer.algorithm.train(epoch, model, x, edge_index,
                                             target=target, index=index)

    # Get the final explanations:
    explanation = explainer(x, edge_index, target=target, index=0)
- Parameters
epochs (int) – The number of epochs to train.
lr (float, optional) – The learning rate to apply. (default: 0.003)
- train(epoch: int, model: Module, x: Tensor, edge_index: Tensor, *, target: Tensor, index: Optional[Union[int, Tensor]] = None, **kwargs)[source]
Trains the underlying explainer model. Needs to be called before being able to make predictions.
- Parameters
epoch (int) – The current epoch of the training phase.
model (torch.nn.Module) – The model to explain.
x (torch.Tensor) – The input node features of a homogeneous graph.
edge_index (torch.Tensor) – The input edge indices of a homogeneous graph.
target (torch.Tensor) – The target of the model.
index (int or torch.Tensor, optional) – The index of the model output to explain. Needs to be a single index. (default: None)
**kwargs (optional) – Additional keyword arguments passed to model.
- class AttentionExplainer(reduce: str = 'max')[source]
An explainer that uses the attention coefficients produced by an attention-based GNN (e.g., GATConv, GATv2Conv, or TransformerConv) as edge explanation. Attention scores across layers and heads will be aggregated according to the reduce argument.
- Parameters
reduce (str, optional) – The method to reduce the attention scores across layers and heads. (default: "max")
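A minimal sketch (assuming a trained GAT-based model; the data names are placeholders):

    from torch_geometric.explain import AttentionExplainer, Explainer

    explainer = Explainer(
        model=gat_model,  # hypothetical attention-based model
        algorithm=AttentionExplainer(reduce='max'),
        explanation_type='model',
        edge_mask_type='object',
        model_config=dict(mode='multiclass_classification',
                          task_level='node', return_type='log_probs'),
    )
    explanation = explainer(x, edge_index, index=10)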
Explanation Metrics
The quality of an explanation can be judged by a variety of different methods. PyG supports the following metrics out-of-the-box:
groundtruth_metrics: Compares and evaluates an explanation mask with the ground-truth explanation mask.
fidelity: Evaluates the fidelity of an Explainer given an Explanation.
characterization_score: Returns the componentwise characterization score as described in the "GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks" paper.
fidelity_curve_auc: Returns the AUC for the fidelity curve as described in the "GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks" paper.
unfaithfulness: Evaluates how faithful an Explanation is to an underlying GNN predictor.
- groundtruth_metrics(pred_mask: Tensor, target_mask: Tensor, metrics: Optional[Union[str, List[str]]] = None, threshold: float = 0.5) Union[float, Tuple[float, ...]] [source]
Compares and evaluates an explanation mask with the ground-truth explanation mask.
- Parameters
pred_mask (torch.Tensor) – The prediction mask to evaluate.
target_mask (torch.Tensor) – The ground-truth target mask.
metrics (str or List[str], optional) – The metrics to return ("accuracy", "recall", "precision", "f1_score", "auroc"). (default: ["accuracy", "recall", "precision", "f1_score", "auroc"])
threshold (float, optional) – The threshold value to perform hard thresholding of mask and groundtruth. (default: 0.5)
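A quick sketch with synthetic masks (purely illustrative values):

    import torch
    from torch_geometric.explain.metric import groundtruth_metrics

    pred_mask = torch.tensor([0.9, 0.1, 0.8, 0.3])    # soft predicted mask
    target_mask = torch.tensor([1.0, 0.0, 1.0, 1.0])  # binary ground truth

    accuracy, recall = groundtruth_metrics(
        pred_mask, target_mask, metrics=['accuracy', 'recall'])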
- fidelity(explainer: Explainer, explanation: Explanation) Tuple[float, float] [source]
Evaluates the fidelity of an Explainer given an Explanation, as described in the “GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks” paper.
Fidelity evaluates the contribution of the produced explanatory subgraph to the initial prediction, either by giving only the subgraph to the model (fidelity-) or by removing it from the entire graph (fidelity+). The fidelity scores capture how well an explainable model reproduces the natural phenomenon or the GNN model logic.
For phenomenon explanations, the fidelity scores are given by:
\[\textrm{fid}_{+} = \frac{1}{N} \sum_{i = 1}^N \| \mathbb{1}(\hat{y}_i = y_i) - \mathbb{1}(\hat{y}_i^{G_{C \setminus S}} = y_i) \|\]
\[\textrm{fid}_{-} = \frac{1}{N} \sum_{i = 1}^N \| \mathbb{1}(\hat{y}_i = y_i) - \mathbb{1}(\hat{y}_i^{G_S} = y_i) \|\]
For model explanations, the fidelity scores are given by:
\[\textrm{fid}_{+} = 1 - \frac{1}{N} \sum_{i = 1}^N \mathbb{1}(\hat{y}_i^{G_{C \setminus S}} = \hat{y}_i)\]
\[\textrm{fid}_{-} = 1 - \frac{1}{N} \sum_{i = 1}^N \mathbb{1}(\hat{y}_i^{G_S} = \hat{y}_i)\]
- Parameters
explainer (Explainer) – The explainer to evaluate.
explanation (Explanation) – The explanation to evaluate.
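A sketch of evaluating an explanation produced by one of the explainers above (hypothetical explainer and explanation objects):

    from torch_geometric.explain.metric import fidelity

    pos_fid, neg_fid = fidelity(explainer, explanation)
    # Good explanations have high fidelity+ and low fidelity-.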
- characterization_score(pos_fidelity: Tensor, neg_fidelity: Tensor, pos_weight: float = 0.5, neg_weight: float = 0.5) Tensor [source]
Returns the componentwise characterization score as described in the “GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks” paper:
\[\textrm{charact} = \frac{w_{+} + w_{-}}{\frac{w_{+}}{\textrm{fid}_{+}} + \frac{w_{-}}{1 - \textrm{fid}_{-}}}\]
- Parameters
pos_fidelity (torch.Tensor) – The positive fidelity \(\textrm{fid}_{+}\).
neg_fidelity (torch.Tensor) – The negative fidelity \(\textrm{fid}_{-}\).
pos_weight (float, optional) – The weight \(w_{+}\) for \(\textrm{fid}_{+}\). (default: 0.5)
neg_weight (float, optional) – The weight \(w_{-}\) for \(\textrm{fid}_{-}\). (default: 0.5)
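Continuing the sketch above, the two fidelity values can be combined into a single score, here with illustrative numbers:

    import torch
    from torch_geometric.explain.metric import characterization_score

    pos_fid = torch.tensor([0.8])
    neg_fid = torch.tensor([0.2])
    # With equal weights: (0.5 + 0.5) / (0.5/0.8 + 0.5/(1 - 0.2)) = 0.8
    score = characterization_score(pos_fid, neg_fid)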
- fidelity_curve_auc(pos_fidelity: Tensor, neg_fidelity: Tensor, x: Tensor) Tensor [source]
Returns the AUC for the fidelity curve as described in the “GraphFramEx: Towards Systematic Evaluation of Explainability Methods for Graph Neural Networks” paper.
More precisely, returns the AUC of
\[f(x) = \frac{\textrm{fid}_{+}}{1 - \textrm{fid}_{-}}\]
- Parameters
pos_fidelity (torch.Tensor) – The positive fidelity \(\textrm{fid}_{+}\).
neg_fidelity (torch.Tensor) – The negative fidelity \(\textrm{fid}_{-}\).
x (torch.Tensor) – Tensor containing the points on the \(x\)-axis. Needs to be sorted in ascending order.
- unfaithfulness(explainer: Explainer, explanation: Explanation, top_k: Optional[int] = None) float [source]
Evaluates how faithful an
Explanation
is to an underyling GNN predictor, as described in the “Evaluating Explainability for Graph Neural Networks” paper.In particular, the graph explanation unfaithfulness metric is defined as
\[\textrm{GEF}(y, \hat{y}) = 1 - \exp(- \textrm{KL}(y || \hat{y}))\]
where \(y\) refers to the prediction probability vector obtained from the original graph, and \(\hat{y}\) refers to the prediction probability vector obtained from the masked subgraph. Finally, the Kullback-Leibler (KL) divergence score quantifies the distance between the two probability distributions.
- Parameters
explainer (Explainer) – The explainer to evaluate.
explanation (Explanation) – The explanation to evaluate.
top_k (int, optional) – If set, will only keep the original values of the top-\(k\) node features identified by an explanation. If set to None, will use explanation.node_mask as it is for masking node features. (default: None)
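A final sketch (hypothetical explainer and explanation objects as above):

    from torch_geometric.explain.metric import unfaithfulness

    # Lower is better; 0 means the masked subgraph reproduces the
    # original prediction distribution exactly:
    gef = unfaithfulness(explainer, explanation, top_k=10)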