Source code for torch_geometric.explain.algorithm.dummy_explainer

from collections import defaultdict
from typing import Dict, Optional, Union

import torch
from torch import Tensor

from torch_geometric.explain import Explanation, HeteroExplanation
from torch_geometric.explain.algorithm import ExplainerAlgorithm
from torch_geometric.explain.config import MaskType
from torch_geometric.typing import EdgeType, NodeType


[docs]class DummyExplainer(ExplainerAlgorithm): r"""A dummy explainer for testing purposes.""" def forward( self, model: torch.nn.Module, x: Union[Tensor, Dict[NodeType, Tensor]], edge_index: Union[Tensor, Dict[EdgeType, Tensor]], edge_attr: Optional[Union[Tensor, Dict[EdgeType, Tensor]]] = None, **kwargs, ) -> Union[Explanation, HeteroExplanation]: r"""Returns random explanations based on the shape of the inputs. Args: model (torch.nn.Module): The model to explain. x (Union[torch.Tensor, Dict[NodeType, torch.Tensor]]): The input node features of a homogeneous or heterogeneous graph. edge_index (Union[torch.Tensor, Dict[NodeType, torch.Tensor]]): The input edge indices of a homogeneous or heterogeneous graph. edge_attr (Union[torch.Tensor, Dict[EdgeType, torch.Tensor]], optional): The inputs edge attributes of a homogeneous or heterogeneous graph. (default: :obj:`None`) Returns: Union[Explanation, HeteroExplanation]: A random explanation based on the shape of the inputs. """ assert isinstance(x, (Tensor, dict)) node_mask_type = self.explainer_config.node_mask_type edge_mask_type = self.explainer_config.edge_mask_type if isinstance(x, Tensor): assert isinstance(edge_index, Tensor) node_mask = None if node_mask_type == MaskType.object: node_mask = torch.rand(x.size(0), 1, device=x.device) elif node_mask_type == MaskType.common_attributes: node_mask = torch.rand(1, x.size(1), device=x.device) elif node_mask_type == MaskType.attributes: node_mask = torch.rand_like(x) edge_mask = None if edge_mask_type == MaskType.object: edge_mask = torch.rand(edge_index.size(1), device=x.device) return Explanation(node_mask=node_mask, edge_mask=edge_mask) else: assert isinstance(edge_index, dict) node_dict = defaultdict(dict) for k, v in x.items(): node_mask = None if node_mask_type == MaskType.object: node_mask = torch.rand(v.size(0), 1, device=v.device) elif node_mask_type == MaskType.common_attributes: node_mask = torch.rand(1, v.size(1), device=v.device) elif node_mask_type == MaskType.attributes: node_mask = torch.rand_like(v) if node_mask is not None: node_dict[k]['node_mask'] = node_mask edge_dict = defaultdict(dict) for k, v in edge_index.items(): edge_mask = None if edge_mask_type == MaskType.object: edge_mask = torch.rand(v.size(1), device=v.device) if edge_mask is not None: edge_dict[k]['edge_mask'] = edge_mask return HeteroExplanation({**node_dict, **edge_dict}) def supports(self) -> bool: return True