[Graph Neural Networks] GNNExplainer Code Walkthrough and Its PyG Implementation



• Using GNNExplainer
• A quick read of the GNNExplainer source
  • Forward pass
  • Loss function
• PyG code example: explaining graph classification with GNNExplainer
• References

Following up on the previous post, "Explainability methods for graph neural networks and a GNNExplainer code example", this post takes a quick look at the GNNExplainer source code and then implements the method by hand with PyTorch Geometric.

GNNExplainer source code: https://github.com/RexYing/gnn-model-explainer

Using GNNExplainer

(1) Installation:

      git clone https://github.com/RexYing/gnn-model-explainer
      

Python 3.7 is recommended, along with a virtual environment:

virtualenv venv -p /usr/local/bin/python3
      source venv/bin/activate
      

(2) Train a GCN model

      python train.py --dataset=EXPERIMENT_NAME
      

where EXPERIMENT_NAME is the name of the experiment to reproduce.
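For example, to train on one of the repository's synthetic benchmarks (syn1 corresponds to the BA-Shapes task; if your checkout uses different names, consult the option list below):

      python train.py --dataset=syn1
      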

The full list of options for training the GCN model:

      python train.py --help
      

(3) Explain a GCN model

To run the explainer:

      python explainer_main.py --dataset=EXPERIMENT_NAME
      

(4) Visualize the explanations

With TensorBoard: the optimization results can be visualized in TensorBoard.

      tensorboard --logdir log
      

A quick read of the GNNExplainer source

GNNExplainer explains a graph from two angles:

• Edges: it learns an edge mask that assigns each edge a probability of appearing in the explanation, as a float between 0 and 1. The edge mask can also be treated as an edge weight, so the subgraph formed by the top-k edges can serve as the explanation (see the sketch after this list).
• Node features: the node feature (NF) vector; for example, a 128-dimensional node has 128 features, and a corresponding NF mask is learned to weight each feature. This part is optional.
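As an illustration, a minimal sketch of keeping the top-k edges as the explanation subgraph (topk_edge_subgraph is a hypothetical helper, not part of the repository; tensors follow the PyG convention):

    import torch

    def topk_edge_subgraph(edge_index, edge_mask, k=10):
        # edge_index: [2, E] edge list; edge_mask: [E] learned edge weights.
        # Keep the k highest-weight edges as the explanation subgraph.
        idx = torch.topk(edge_mask, k=min(k, edge_mask.numel())).indices
        return edge_index[:, idx], edge_mask[idx]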

Code layout

• The ExplainModule class under the explainer directory defines the module structure of the GNNExplainer network and inherits from torch.nn.Module:

  • In __init__, the two masks to be learned are initialized by construct_edge_mask and construct_feat_mask (two nn.Parameter variables: an $n \times n$ mask and a $d$-dimensional all-zero feat_mask); diag_mask is a matrix with zeros on the main diagonal and ones everywhere else, used by the _masked_adj function.
  • The _masked_adj function activates the mask with sigmoid or ReLU, adds its transpose and divides by 2 to make it symmetric, then multiplies by diag_mask, finally turning the original adjacency matrix adj into masked_adj (a sketch appears in the forward-pass section below).
  • The Explainer class implements the explanation logic; its main function is explain, which explains the original model's prediction for a single node. The main steps are:

    1. Take the subgraph's adj, x, and label. For graph explanation, take the entire computation graph for graph_idx; for node explanation, call extract_neighborhood to take the node's num_gc_layers-hop neighborhood.
    2. Convert the model's prediction output pred (passed in) into pred_label.
    3. Build an ExplainModule and train it for num_epochs rounds (forward + backward passes):
            adj   = torch.tensor(sub_adj, dtype=torch.float)
            x     = torch.tensor(sub_feat, requires_grad=True, dtype=torch.float)
            label = torch.tensor(sub_label, dtype=torch.long)
            if self.graph_mode:
            	pred_label = np.argmax(self.pred[0][graph_idx], axis=0)
            	print("Graph predicted label: ", pred_label)
            else:
            	pred_label = np.argmax(self.pred[graph_idx][neighbors], axis=1)
            	print("Node predicted label: ", pred_label[node_idx_new])
            explainer = ExplainModule(
            	adj=adj,
            	x=x,
            	model=self.model,
            	label=label,
            	args=self.args,
            	writer=self.writer,
            	graph_idx=self.graph_idx,
            	graph_mode=self.graph_mode,
            )
            if self.args.gpu:
            	explainer = explainer.cuda()
            ...
            # NODE EXPLAINER
            def explain_nodes(self, node_indices, args, graph_idx=0):
            ...
            def explain_nodes_gnn_stats(self, node_indices, args, graph_idx=0, model="exp"):
            ...
            # GRAPH EXPLAINER
            def explain_graphs(self, graph_indices):
            ...
            

The three functions explain_nodes, explain_nodes_gnn_stats, and explain_graphs are all built on top of explain.

The forward and loss functions are analyzed below.

Forward pass

First, the learnable parameters mask and feat_mask are multiplied with the original adjacency matrix and the node features respectively, producing the transformed masked_adj and x. The former is done by the _masked_adj function; the latter is implemented as follows:

            feat_mask = (
            	torch.sigmoid(self.feat_mask)
            	if self.use_sigmoid
            	else self.feat_mask
            )
            if marginalize:
            	std_tensor = torch.ones_like(x, dtype=torch.float) / 2
            	mean_tensor = torch.zeros_like(x, dtype=torch.float) - x
            	z = torch.normal(mean=mean_tensor, std=std_tensor)
            	x = x + z * (1 - feat_mask)
            else:
            	x = x * feat_mask
            
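For reference, a minimal sketch of the _masked_adj step described earlier (a paraphrase of the repository logic under the sigmoid setting, not verbatim source):

    def _masked_adj(self):
        # Activate the learnable mask (sigmoid here; ReLU is the other option),
        # symmetrize it, and zero the diagonal via diag_mask.
        sym_mask = torch.sigmoid(self.mask)
        sym_mask = (sym_mask + sym_mask.t()) / 2
        return self.adj * sym_mask * self.diag_mask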

Full code: [figure in the original post: source of the forward function]

One case that deserves explanation is marginalize=True; see "Learning binary feature selector F" in the paper:

[Figure: excerpt from the paper, "Learning binary feature selector F"]

• If feat_mask were trained exactly like mask, in some cases even important features would be suppressed (their learned mask values would also end up close to 0). Instead, Monte Carlo sampling over the empirical marginal distribution of $X_S$ is used to obtain $X = X_S^F$.
• To allow backpropagation through the random variable $X$, the "reparameterization" trick is introduced: $X$ is expressed as a deterministic transformation of a parameter-free random variable $Z$, $X = Z + (X_S - Z) \odot F$, subject to $\sum_{j} F_j \le K_F$,

  where $Z$ is a $d$-dimensional random variable sampled from the empirical distribution and $K_F$ is a parameter giving the maximum number of features to keep (see the denoise_graph function in utils/io_utils.py). Note how the code above matches this: z is drawn with mean $-x$, so $x + z \odot (1-F)$ has expectation $x \odot F$.

Then masked_adj and x are fed into the original model, yielding ExplainModule's prediction pred.
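A condensed sketch of how forward wires this together (paraphrased from the repository code; tensor shapes and minor branches omitted):

    def forward(self, node_idx, marginalize=False):
        # Apply both learned masks, then query the frozen original model.
        self.masked_adj = self._masked_adj()
        x = self.x * torch.sigmoid(self.feat_mask)
        ypred, _ = self.model(x, self.masked_adj)
        if self.graph_mode:
            return nn.Softmax(dim=0)(ypred[0])   # one prediction per graph
        return nn.Softmax(dim=0)(ypred[self.graph_idx, node_idx, :])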

Loss function

              loss = pred_loss + size_loss + lap_loss + mask_ent_loss + feat_size_loss
              

So the total loss consists of five terms. Apart from pred_loss, which corresponds to the loss formula in the paper, the role of each remaining term is described in the paper's section "Integrating additional constraints into explanations"; their weights are defined in coeffs:

              self.coeffs = {
              	"size": 0.005,
              	"feat_size": 1.0,
              	"ent": 1.0,
              	"feat_ent": 0.1,
              	"grad": 0,
              	"lap": 1.0,
              }
              

[Figure: excerpt from the paper, "Integrating additional constraints into explanations"]
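Collecting the pieces, the minimized objective can be written as follows (a reconstruction from the code below, with the $\lambda$ weights taken from coeffs):

$\mathcal{L} = -\log \hat{p}_y + \lambda_{size} \sum_{i,j} \sigma(M)_{ij} + \lambda_{ent}\, \overline{H}(\sigma(M)) + \lambda_{feat}\, \overline{\sigma(M_F)} + \lambda_{lap}\, \hat{y}^\top L\, \hat{y} \,/\, |A|$

where $\hat{p}_y$ is the masked model's probability for the explained class $y$, $M$ and $M_F$ are the edge and feature masks, $\overline{H}$ is the mean binary entropy, $\hat{y}$ is the predicted-label vector, $L$ is the Laplacian of the masked graph, and $|A|$ is the number of adjacency entries.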

              1. pred_loss
              mi_obj = False
              if mi_obj:
              	pred_loss = -torch.sum(pred * torch.log(pred))
              else:
              	pred_label_node = pred_label if self.graph_mode else pred_label[node_idx]
              	gt_label_node = self.label if self.graph_mode else self.label[0][node_idx]
              	logit = pred[gt_label_node]
              	pred_loss = -torch.log(logit)
              

Here pred is the current prediction (on the masked inputs) and pred_label is the prediction obtained on the original features.

2. mask_ent_loss
              # entropy
              mask_ent = -mask * torch.log(mask) - (1 - mask) * torch.log(1 - mask)
              mask_ent_loss = self.coeffs["ent"] * torch.mean(mask_ent)
              
3. size_loss
              # size
              mask = self.mask
              if self.mask_act == "sigmoid":
              	mask = torch.sigmoid(self.mask)
              elif self.mask_act == "ReLU":
              	mask = nn.ReLU()(self.mask)
              size_loss = self.coeffs["size"] * torch.sum(mask)
              
4. feat_size_loss
              # pre_mask_sum = torch.sum(self.feat_mask)
              feat_mask = (
              	torch.sigmoid(self.feat_mask) if self.use_sigmoid else self.feat_mask
              )
              feat_size_loss = self.coeffs["feat_size"] * torch.mean(feat_mask)
              
5. lap_loss
              # laplacian
              D = torch.diag(torch.sum(self.masked_adj[0], 0))
              m_adj = self.masked_adj if self.graph_mode else self.masked_adj[self.graph_idx]
              L = D - m_adj
              pred_label_t = torch.tensor(pred_label, dtype=torch.float)
              if self.args.gpu:
              	pred_label_t = pred_label_t.cuda()
              	L = L.cuda()
              if self.graph_mode:
              	lap_loss = 0
              else:
              	lap_loss = (self.coeffs["lap"] * (pred_label_t @ L @ pred_label_t) / self.adj.numel())
              
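The Laplacian term acts as a smoothness prior: by the identity $f^\top L f = \frac{1}{2}\sum_{i,j} A_{ij}(f_i - f_j)^2$, lap_loss penalizes keeping edges that connect nodes with different predicted labels, so the retained subgraph tends to be label-consistent.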


PyG code example: explaining graph classification with GNNExplainer

For explaining a graph-classification prediction, there are two key points:

• The mask to be learned acts on the whole graph, so no subgraph extraction is needed
• The label prediction and the loss function operate on a single graph

The implementation is as follows:

                #!/usr/bin/env python
                # encoding: utf-8
                # Created by BIT09 at 2023/4/28
                import torch
                import networkx as nx
                import numpy as np
                import matplotlib.pyplot as plt
                from math import sqrt
                from tqdm import tqdm
                from torch_geometric.nn import MessagePassing
                from torch_geometric.data import Data
                from torch_geometric.utils import k_hop_subgraph, to_networkx
                EPS = 1e-15
                class GNNExplainer(torch.nn.Module):
                    r"""
                    Args:
                        model (torch.nn.Module): The GNN module to explain.
                        epochs (int, optional): The number of epochs to train.
                            (default: :obj:`100`)
                        lr (float, optional): The learning rate to apply.
                            (default: :obj:`0.01`)
                        log (bool, optional): If set to :obj:`False`, will not log any learning
                            progress. (default: :obj:`True`)
                    """
                    coeffs = {
                        'edge_size': 0.001,
                        'node_feat_size': 1.0,
                        'edge_ent': 1.0,
                        'node_feat_ent': 0.1,
                    }
                    def __init__(self, model, epochs=100, lr=0.01, log=True, node=False):  # disable node_feat_mask by default
                        super(GNNExplainer, self).__init__()
                        self.model = model
                        self.epochs = epochs
                        self.lr = lr
                        self.log = log
                        self.node = node
                    def __set_masks__(self, x, edge_index, init="normal"):
                        (N, F), E = x.size(), edge_index.size(1)
                        std = 0.1
                        if self.node:
                            self.node_feat_mask = torch.nn.Parameter(torch.randn(F) * 0.1)
                        std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
                        self.edge_mask = torch.nn.Parameter(torch.randn(E) * std)  # learnable edge-mask logits
                        for module in self.model.modules():
                            if isinstance(module, MessagePassing):
                                module.__explain__ = True
                                module.__edge_mask__ = self.edge_mask
                    def __clear_masks__(self):
                        for module in self.model.modules():
                            if isinstance(module, MessagePassing):
                                module.__explain__ = False
                                module.__edge_mask__ = None
                        if self.node:
                            self.node_feat_masks = None
                        self.edge_mask = None
                    def __num_hops__(self):
                        num_hops = 0
                        for module in self.model.modules():
                            if isinstance(module, MessagePassing):
                                num_hops += 1
                        return num_hops
                    def __flow__(self):
                        for module in self.model.modules():
                            if isinstance(module, MessagePassing):
                                return module.flow
                        return 'source_to_target'
                    def __subgraph__(self, node_idx, x, edge_index, **kwargs):
                        num_nodes, num_edges = x.size(0), edge_index.size(1)
                        if node_idx is not None:
                            subset, edge_index, mapping, edge_mask = k_hop_subgraph(
                                node_idx, self.__num_hops__(), edge_index, relabel_nodes=True,
                                num_nodes=num_nodes, flow=self.__flow__())
                            x = x[subset]
                        else:
                            x = x
                            edge_index = edge_index
                            row, col = edge_index
                            edge_mask = row.new_empty(row.size(0), dtype=torch.bool)
                            edge_mask[:] = True
                            mapping = None
                        for key, item in kwargs.items():  # carry node-/edge-level kwargs into the subgraph
                            if torch.is_tensor(item) and item.size(0) == num_nodes:
                                item = item[subset]
                            elif torch.is_tensor(item) and item.size(0) == num_edges:
                                item = item[edge_mask]
                            kwargs[key] = item
                        return x, edge_index, mapping, edge_mask, kwargs
                    def __graph_loss__(self, log_logits, pred_label):
                        loss = -torch.log(log_logits[0, pred_label] + EPS)  # NLL of the originally predicted class
                        m = self.edge_mask.sigmoid()
                        loss = loss + self.coeffs['edge_size'] * m.sum()
                        ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)
                        loss = loss + self.coeffs['edge_ent'] * ent.mean()
                        return loss
                    def visualize_subgraph(self, node_idx, edge_index, edge_mask, y=None,
                                           threshold=None, **kwargs):
                        r"""Visualizes the subgraph around :attr:`node_idx` given an edge mask
                        :attr:`edge_mask`.
                        Args:
                            node_idx (int): The node id to explain.
                            edge_index (LongTensor): The edge indices.
                            edge_mask (Tensor): The edge mask.
                            y (Tensor, optional): The ground-truth node-prediction labels used
                                as node colorings. (default: :obj:`None`)
                            threshold (float, optional): Sets a threshold for visualizing
                                important edges. If set to :obj:`None`, will visualize all
                                edges with transparancy indicating the importance of edges.
                                (default: :obj:`None`)
                            **kwargs (optional): Additional arguments passed to
                                :func:`nx.draw`.
                        :rtype: :class:`matplotlib.axes.Axes`, :class:`networkx.DiGraph`
                        """
                        assert edge_mask.size(0) == edge_index.size(1)
                        if node_idx is not None:
                            # Only operate on a k-hop subgraph around `node_idx`.
                            subset, edge_index, _, hard_edge_mask = k_hop_subgraph(
                                node_idx, self.__num_hops__(), edge_index, relabel_nodes=True,
                                num_nodes=None, flow=self.__flow__())
                            edge_mask = edge_mask[hard_edge_mask]
                            subset = subset.tolist()
                            if y is None:
                                y = torch.zeros(edge_index.max().item() + 1,
                                                device=edge_index.device)
                            else:
                                y = y[subset].to(torch.float) / y.max().item()
                                y = y.tolist()
                        else:
                            subset = []
                            for index, mask in enumerate(edge_mask):
                                node_a = edge_index[0, index]
                                node_b = edge_index[1, index]
                                if node_a not in subset:
                                    subset.append(node_a.item())
                                if node_b not in subset:
                                    subset.append(node_b.item())
                            y = [y for i in range(len(subset))]
                        if threshold is not None:
                            edge_mask = (edge_mask >= threshold).to(torch.float)
                        data = Data(edge_index=edge_index, att=edge_mask, y=y,
                                    num_nodes=len(y)).to('cpu')
                        G = to_networkx(data, edge_attrs=['att'])  # , node_attrs=['y']
                        mapping = {k: i for k, i in enumerate(subset)}
                        G = nx.relabel_nodes(G, mapping)
                        kwargs['with_labels'] = kwargs.get('with_labels', True)
                        kwargs['font_size'] = kwargs.get('font_size') or 10
                        kwargs['node_size'] = kwargs.get('node_size') or 800
                        kwargs['cmap'] = kwargs.get('cmap') or 'cool'
                        pos = nx.spring_layout(G)
                        ax = plt.gca()
                        for source, target, data in G.edges(data=True):
                            ax.annotate(
                                '', xy=pos[target], xycoords='data', xytext=pos[source],
                                textcoords='data', arrowprops=dict(
                                    arrowstyle="->",
                                    alpha=max(data['att'], 0.1),
                                    shrinkA=sqrt(kwargs['node_size']) / 2.0,
                                    shrinkB=sqrt(kwargs['node_size']) / 2.0,
                                    connectionstyle="arc3,rad=0.1",
                                ))
                        nx.draw_networkx_nodes(G, pos, node_color=y, **kwargs)
                        nx.draw_networkx_labels(G, pos, **kwargs)
                        return ax, G
                    def explain_graph(self, data, **kwargs):
                        self.model.eval()
                        self.__clear_masks__()
                        x, edge_index, batch = data.x, data.edge_index, data.batch
                        num_edges = edge_index.size(1)
                        # Only operate on a k-hop subgraph around `node_idx`.
                        x, edge_index, _, hard_edge_mask, kwargs = self.__subgraph__(node_idx=None, x=x, edge_index=edge_index,
                                      **kwargs)
                        # Get the initial prediction.
                        with torch.no_grad():
                            log_logits = self.model(data, **kwargs)
                            probs_Y = torch.softmax(log_logits, 1)
                            pred_label = probs_Y.argmax(dim=-1)
                        self.__set_masks__(x, edge_index)
                        self.to(x.device)
                        if self.node:
                            optimizer = torch.optim.Adam([self.node_feat_mask, self.edge_mask],
                                                         lr=self.lr)
                        else:
                            optimizer = torch.optim.Adam([self.edge_mask], lr=self.lr)
                        epoch_losses = []
                        for epoch in range(1, self.epochs + 1):
                            epoch_loss = 0
                            optimizer.zero_grad()
                            if self.node:
                                # Feed the masked features to the model so the learned
                                # node_feat_mask actually affects the prediction.
                                data.x = x * self.node_feat_mask.view(1, -1).sigmoid()
                            log_logits = self.model(data, **kwargs)
                            pred = torch.softmax(log_logits, 1)
                            loss = self.__graph_loss__(pred, pred_label)
                            loss.backward()
                            optimizer.step()
                            epoch_loss += loss.detach().item()
                            epoch_losses.append(epoch_loss)
                        edge_mask = self.edge_mask.detach().sigmoid()
                        print(edge_mask)
                        self.__clear_masks__()
                        return edge_mask, epoch_losses
                    def __repr__(self):
                        return f'{self.__class__.__name__}()'
                
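A minimal usage sketch (model and data are hypothetical: any trained PyG graph-classification model whose forward takes a Data object and returns per-graph logits, plus a single graph to explain):

    explainer = GNNExplainer(model, epochs=200, lr=0.01)
    edge_mask, losses = explainer.explain_graph(data)

    # Visualize all edges, with transparency proportional to the learned weight;
    # node_idx=None because the whole graph is being explained.
    ax, G = explainer.visualize_subgraph(None, data.edge_index, edge_mask, y=0)
    plt.show()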

References

1. gnn-explainer
2. Explainability methods for graph neural networks and a GNNExplainer code example
3. A PyTorch implementation of GNNExplainer
4. How to Explain Graph Neural Network — GNNExplainer
5. https://gist.github.com/hongxuenong/9f7d4ce96352d4313358bc8368801707