GNNExplainer: Code Walkthrough and a PyG Implementation
- Using GNNExplainer
- A Quick Read of the GNNExplainer Source
- Forward pass
- Loss function
- PyG Code Example for Explaining Graph Classification with GNNExplainer
- References
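Following up on my previous post, "Interpretability Methods for Graph Neural Networks and a GNNexplainer Code Example", this post briefly walks through the GNNExplainer source code and then implements it by hand with PyTorch Geometric.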
GNNExplainer source code: https://github.com/RexYing/gnn-model-explainer
Using GNNExplainer
(1) Installation:
```bash
git clone https://github.com/RexYing/gnn-model-explainer
```
Python 3.7 and a virtual environment are recommended:
```bash
virtualenv venv -p /usr/local/bin/python3
source venv/bin/activate
```
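(2) Train a GCN model
```bash
python train.py --dataset=EXPERIMENT_NAME
```
Here EXPERIMENT_NAME is the name of the experiment you want to reproduce.
For the full list of options for training the GCN model:
```bash
python train.py --help
```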
(3) Explain a GCN model
To run the explainer, run:
```bash
python explainer_main.py --dataset=EXPERIMENT_NAME
```
(4) Visualize the explanations
Use TensorBoard: the optimization results can be visualized in TensorBoard.
```bash
tensorboard --logdir log
```
A Quick Read of the GNNExplainer Source
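GNNExplainer explains a graph from two perspectives:
- Edges: it generates an edge mask that gives, for each edge, the probability that the edge belongs to the explanation, as a float between 0 and 1. The edge mask can also be read as edge weights, and the subgraph formed by the top-k edges can serve as the explanation.
- Node features: the node features (NF) are the node vectors. For example, a 128-dimensional node vector has 128 features, and GNNExplainer can also generate an NF mask that weights each feature. This part is optional.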
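The top-k idea from the edge bullet can be illustrated with a small plain-PyTorch sketch. It keeps the k highest-scoring edges of a PyG-style `edge_index` given a learned `edge_mask`. The helper name `topk_edge_subgraph` is hypothetical and is not part of either codebase.

```python
import torch


def topk_edge_subgraph(edge_index, edge_mask, k):
    """Hypothetical helper: keep the k edges with the largest mask values.

    edge_index: LongTensor of shape (2, E); edge_mask: float tensor of shape (E,)
    """
    k = min(k, edge_mask.numel())
    top = torch.topk(edge_mask, k).indices   # positions of the k most important edges
    top = top.sort().values                  # restore the original edge order
    return edge_index[:, top], edge_mask[top]


edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])
edge_mask = torch.tensor([0.9, 0.1, 0.7, 0.3])
sub_edges, sub_mask = topk_edge_subgraph(edge_index, edge_mask, k=2)
print(sub_edges.tolist())  # [[0, 2], [1, 3]]
```

A threshold on the mask (for example `edge_mask >= 0.5`) is a common alternative to a fixed k.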
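The ExplainModule class in the explainer directory defines the structure of the GNNExplainer network and inherits from torch.nn.Module:
- In `__init__`, construct_edge_mask and construct_feat_mask initialize the two masks to be learned. These are two nn.Parameter variables: an n×n mask and a d-dimensional feat_mask initialized to all zeros. diag_mask is a matrix with zeros on the main diagonal and ones everywhere else; it is used by the _masked_adj function.
- _masked_adj first activates mask with sigmoid or ReLU. It then adds the result's transpose and divides by 2 to make it symmetric, multiplies it element-wise with the original adjacency matrix adj, and finally multiplies by diag_mask (which removes self-loops). This turns adj into masked_adj.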
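The two bullets above can be condensed into a small PyTorch sketch. The class name `MaskedAdjSketch` is hypothetical. It mirrors the described behavior (n×n mask, zero-initialized feat_mask, diag_mask, symmetrization) rather than reproducing ExplainModule line by line. The mean and standard deviation used to initialize the edge mask are assumptions.

```python
import math

import torch
import torch.nn as nn


class MaskedAdjSketch(nn.Module):
    """Hypothetical sketch of the mask parameters and _masked_adj.

    Assumes a dense (n, n) adjacency `adj`; init values are illustrative.
    """

    def __init__(self, adj, feat_dim, mask_act="sigmoid"):
        super().__init__()
        n = adj.size(0)
        self.adj = adj
        self.mask_act = mask_act
        std = nn.init.calculate_gain("relu") * math.sqrt(2.0 / (n + n))
        # n x n edge mask, normal around 1.0 (assumed init)
        self.mask = nn.Parameter(torch.randn(n, n) * std + 1.0)
        # d-dimensional feature mask, all zeros
        self.feat_mask = nn.Parameter(torch.zeros(feat_dim))
        # zeros on the main diagonal, ones elsewhere
        self.register_buffer("diag_mask", torch.ones(n, n) - torch.eye(n))

    def _masked_adj(self):
        if self.mask_act == "sigmoid":
            sym_mask = torch.sigmoid(self.mask)
        else:  # "ReLU"
            sym_mask = torch.relu(self.mask)
        sym_mask = (sym_mask + sym_mask.t()) / 2   # make the mask symmetric
        return self.adj * sym_mask * self.diag_mask


adj = torch.tensor([[0., 1., 1.],
                    [1., 0., 0.],
                    [1., 0., 0.]])
m = MaskedAdjSketch(adj, feat_dim=4)
masked = m._masked_adj()
```

Because of the symmetrization, both directions of an undirected edge always receive the same weight.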
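The Explainer class implements the explanation logic. Its main method is explain, which explains the original model's prediction for a single node (or, in graph mode, a single graph). The main steps are:
- Extract the subgraph's adj, x and label. For a graph explanation, take the whole computation graph of graph_idx. For a node explanation, call extract_neighborhood to take the node's neighbors within num_gc_layers hops.
- Convert the model's prediction output pred, which is passed in, into pred_label.
- Build an ExplainModule and train it for num_epochs epochs (forward + backward pass).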
```python
adj = torch.tensor(sub_adj, dtype=torch.float)
x = torch.tensor(sub_feat, requires_grad=True, dtype=torch.float)
label = torch.tensor(sub_label, dtype=torch.long)

if self.graph_mode:
    pred_label = np.argmax(self.pred[0][graph_idx], axis=0)
    print("Graph predicted label: ", pred_label)
else:
    pred_label = np.argmax(self.pred[graph_idx][neighbors], axis=1)
    print("Node predicted label: ", pred_label[node_idx_new])

explainer = ExplainModule(
    adj=adj,
    x=x,
    model=self.model,
    label=label,
    args=self.args,
    writer=self.writer,
    graph_idx=self.graph_idx,
    graph_mode=self.graph_mode,
)
if self.args.gpu:
    explainer = explainer.cuda()
...

# NODE EXPLAINER
def explain_nodes(self, node_indices, args, graph_idx=0):
    ...

def explain_nodes_gnn_stats(self, node_indices, args, graph_idx=0, model="exp"):
    ...

# GRAPH EXPLAINER
def explain_graphs(self, graph_indices):
    ...
```
The three methods explain_nodes, explain_nodes_gnn_stats and explain_graphs are all built on top of explain.
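The first step for node explanations (taking the neighbors within num_gc_layers hops) can be illustrated with dense NumPy matrices. `extract_neighborhood_sketch` is a hypothetical, simplified stand-in; the repository's extract_neighborhood may compute neighborhoods differently. Its return values roughly play the roles of node_idx_new, sub_adj, sub_feat, sub_label and neighbors in the snippet above.

```python
import numpy as np


def extract_neighborhood_sketch(adj, feat, label, node_idx, n_hops):
    """Hypothetical stand-in for extract_neighborhood.

    adj: dense (N, N) 0/1 integer adjacency; feat: (N, d); label: (N,)
    """
    n = adj.shape[0]
    reach = np.eye(n, dtype=bool)       # nodes reachable within <= h hops
    frontier = np.eye(n, dtype=bool)
    for _ in range(n_hops):
        frontier = (frontier.astype(int) @ adj) > 0   # extend by one hop
        reach |= frontier
    neighbors = np.nonzero(reach[node_idx])[0]        # sorted node ids
    node_idx_new = int(np.searchsorted(neighbors, node_idx))  # target's new index
    sub_adj = adj[np.ix_(neighbors, neighbors)]
    return node_idx_new, sub_adj, feat[neighbors], label[neighbors], neighbors


# Path graph 0-1-2-3-4
adj = np.zeros((5, 5), dtype=int)
for i in range(4):
    adj[i, i + 1] = adj[i + 1, i] = 1
feat = np.arange(10).reshape(5, 2)
label = np.array([0, 1, 1, 0, 1])

idx, sub_adj, sub_feat, sub_label, nbrs = extract_neighborhood_sketch(
    adj, feat, label, node_idx=2, n_hops=1)
print(nbrs.tolist(), idx)  # [1, 2, 3] 1
```

With a 2-layer GCN, n_hops=2 covers exactly the nodes that can influence the target's prediction, which is why the subgraph loses no information.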
Below we analyze ExplainModule's forward and loss functions.
Forward pass
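First, the learnable parameters mask and feat_mask are applied to the original adjacency matrix and the feature matrix, respectively, giving the transformed masked_adj and x. The former is produced by calling _masked_adj; the latter is implemented as follows: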
```python
feat_mask = (
    torch.sigmoid(self.feat_mask)
    if self.use_sigmoid
    else self.feat_mask
)
if marginalize:
    std_tensor = torch.ones_like(x, dtype=torch.float) / 2
    mean_tensor = torch.zeros_like(x, dtype=torch.float) - x
    z = torch.normal(mean=mean_tensor, std=std_tensor)
    x = x + z * (1 - feat_mask)
else:
    x = x * feat_mask
```
The marginalize=True case needs some explanation; see the section "Learning binary feature selector F" in the paper:
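- If feat_mask were learned the same way as mask, important features could in some cases be ignored as well (the learned feature mask also ends up with values close to 0). Therefore, $X = X_S^F$ is obtained by Monte Carlo sampling from the empirical marginal distribution of $X_S$.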
- To backpropagate through the random variable $X$, the "reparameterization" trick is used: $X$ is expressed as a deterministic transformation of a parameter-free random variable $Z$:

$$
X = Z + (X_S - Z) \odot F \quad \text{s.t.} \quad \sum_{j} F_j \le K_F
$$
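Here $Z$ is a $d$-dimensional random variable sampled from the empirical distribution, and $K_F$ is a parameter giving the maximum number of features to keep (see the denoise_graph function in utils/io_utils.py).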
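The marginalize branch of the forward snippet is an instance of this trick. Write its noise as $z = -x + \epsilon$ with $\epsilon \sim \mathcal{N}(0, (1/2)^2)$. Then $x + z(1 - F) = \epsilon + (x - \epsilon)\odot F$, which is the formula above with $Z = \epsilon$. So in that snippet $Z$ is Gaussian noise rather than a draw from the empirical feature distribution. The sketch below checks the equivalence numerically; all tensor names are illustrative.

```python
import torch

torch.manual_seed(0)
x_s = torch.randn(4, 3)                            # X_S: 4 nodes, d = 3 features
feat_logits = torch.zeros(3, requires_grad=True)   # learnable, like self.feat_mask
f = torch.sigmoid(feat_logits)                     # F: one weight per feature

# Reparameterization: the noise has no learnable parameters,
# so gradients reach feat_logits only through f.
eps = torch.randn_like(x_s) * 0.5                  # std 1/2, as in the snippet
x = eps + (x_s - eps) * f                          # X = Z + (X_S - Z) ⊙ F, Z = eps

# Repo-style update x + z * (1 - F) with z = eps - x_s, i.e. z ~ N(-x_s, 1/2)
z = eps - x_s
x_repo = x_s + z * (1 - f)
assert torch.allclose(x, x_repo)

x.sum().backward()
print(feat_logits.grad)
```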
Next, masked_adj and x are fed into the original model, and its output is returned by ExplainModule as pred.
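The whole forward pass (mask the adjacency, mask the features, run the model, apply softmax) fits in a few lines. The sketch below uses a hypothetical toy dense GCN, `DenseGCNToy`, rather than the repository's model. It assumes the model takes `(x, adj)` and mean-pools node outputs into graph logits, so the final softmax plays the role of pred in graph mode. It uses the non-marginalized feature mask.

```python
import torch
import torch.nn as nn


class DenseGCNToy(nn.Module):
    """Hypothetical 2-layer dense GCN used only to illustrate the forward pass."""

    def __init__(self, in_dim, hid_dim, n_classes):
        super().__init__()
        self.lin1 = nn.Linear(in_dim, hid_dim)
        self.lin2 = nn.Linear(hid_dim, n_classes)

    def forward(self, x, adj):
        h = torch.relu(adj @ self.lin1(x))   # aggregate neighbors with the (masked) adj
        h = adj @ self.lin2(h)
        return h.mean(dim=0)                 # graph-level logits via mean pooling


torch.manual_seed(0)
n, d = 4, 3
adj = torch.tensor([[0., 1., 0., 1.],
                    [1., 0., 1., 0.],
                    [0., 1., 0., 1.],
                    [1., 0., 1., 0.]])
x = torch.randn(n, d)
model = DenseGCNToy(d, 8, 2)

edge_logits = torch.zeros(n, n, requires_grad=True)  # role of self.mask
feat_logits = torch.zeros(d, requires_grad=True)     # role of self.feat_mask

sym = torch.sigmoid(edge_logits)
sym = (sym + sym.t()) / 2
masked_adj = adj * sym * (1 - torch.eye(n))          # _masked_adj
masked_x = x * torch.sigmoid(feat_logits)            # non-marginalized feature mask
pred = torch.softmax(model(masked_x, masked_adj), dim=0)
```

Only `edge_logits` and `feat_logits` would be optimized; the model's weights stay fixed during explanation.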
Loss function
```python
loss = pred_loss + size_loss + lap_loss + mask_ent_loss + feat_size_loss
```
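So the total loss has five terms. Apart from pred_loss, which corresponds to the loss function in the paper, the roles of the other terms are described in the paper's section "Integrating additional constraints into explanations". Their weights are defined in coeffs: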
```python
self.coeffs = {
    "size": 0.005,
    "feat_size": 1.0,
    "ent": 1.0,
    "feat_ent": 0.1,
    "grad": 0,
    "lap": 1.0,
}
```
- pred_loss
```python
mi_obj = False
if mi_obj:
    pred_loss = -torch.sum(pred * torch.log(pred))
else:
    pred_label_node = pred_label if self.graph_mode else pred_label[node_idx]
    gt_label_node = self.label if self.graph_mode else self.label[0][node_idx]
    logit = pred[gt_label_node]
    pred_loss = -torch.log(logit)
```
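Here pred is the current prediction (on the masked input), and pred_label is the prediction on the original input. Note that in the non-MI branch the loss indexes pred with the ground-truth label gt_label_node; pred_label_node is computed but not used.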
- mask_ent_loss (uses the activated `mask` computed in the size_loss code)
```python
# entropy
mask_ent = -mask * torch.log(mask) - (1 - mask) * torch.log(1 - mask)
mask_ent_loss = self.coeffs["ent"] * torch.mean(mask_ent)
```
- size_loss
```python
# size
mask = self.mask
if self.mask_act == "sigmoid":
    mask = torch.sigmoid(self.mask)
elif self.mask_act == "ReLU":
    mask = nn.ReLU()(self.mask)
size_loss = self.coeffs["size"] * torch.sum(mask)
```
- feat_size_loss
```python
# pre_mask_sum = torch.sum(self.feat_mask)
feat_mask = (
    torch.sigmoid(self.feat_mask)
    if self.use_sigmoid
    else self.feat_mask
)
feat_size_loss = self.coeffs["feat_size"] * torch.mean(feat_mask)
```
- lap_loss (set to 0 in graph mode, so it only affects node explanations)
```python
# laplacian
D = torch.diag(torch.sum(self.masked_adj[0], 0))
m_adj = self.masked_adj if self.graph_mode else self.masked_adj[self.graph_idx]
L = D - m_adj
pred_label_t = torch.tensor(pred_label, dtype=torch.float)
if self.args.gpu:
    pred_label_t = pred_label_t.cuda()
    L = L.cuda()
if self.graph_mode:
    lap_loss = 0
else:
    lap_loss = (self.coeffs["lap"]
                * (pred_label_t @ L @ pred_label_t)
                / self.adj.numel())
```
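To see how the five terms add up, the function below recomputes them on toy tensors. `explain_loss_sketch` is hypothetical. It assumes sigmoid activations for both masks and indexes pred with a given target class, as the pred_loss snippet does. Unlike the repository snippets, it adds a small eps inside the logs. The Laplacian term $y^\top L y$ equals the sum over edges of $A_{ij}(y_i - y_j)^2$, so it penalizes strongly weighted edges between nodes with different predicted labels.

```python
import torch


def explain_loss_sketch(pred, target, mask_logits, feat_logits, masked_adj,
                        pred_label_vec, coeffs, graph_mode=True):
    """Hypothetical combination of the five loss terms.

    pred: class probabilities (C,); target: class index;
    mask_logits: (n, n); feat_logits: (d,); masked_adj: (n, n);
    pred_label_vec: (n,) predicted node labels (used only by lap_loss).
    """
    eps = 1e-15  # keeps log() finite; not in the repo snippets
    pred_loss = -torch.log(pred[target] + eps)

    mask = torch.sigmoid(mask_logits)
    size_loss = coeffs["size"] * mask.sum()
    ent = -mask * torch.log(mask + eps) - (1 - mask) * torch.log(1 - mask + eps)
    mask_ent_loss = coeffs["ent"] * ent.mean()

    feat_mask = torch.sigmoid(feat_logits)
    feat_size_loss = coeffs["feat_size"] * feat_mask.mean()

    if graph_mode:
        lap_loss = torch.tensor(0.0)         # disabled for graph explanations
    else:
        D = torch.diag(masked_adj.sum(0))
        L = D - masked_adj                   # Laplacian of the masked graph
        y = pred_label_vec.float()
        lap_loss = coeffs["lap"] * (y @ L @ y) / masked_adj.numel()

    return pred_loss + size_loss + lap_loss + mask_ent_loss + feat_size_loss


coeffs = {"size": 0.005, "feat_size": 1.0, "ent": 1.0,
          "feat_ent": 0.1, "grad": 0, "lap": 1.0}
pred = torch.tensor([0.25, 0.75])
mask_logits = torch.zeros(3, 3)              # every edge weight sigmoid(0) = 0.5
feat_logits = torch.zeros(2)
masked_adj = torch.tensor([[0., 1., 0.],
                           [1., 0., 1.],
                           [0., 1., 0.]])
labels = torch.tensor([0, 1, 1])

g = explain_loss_sketch(pred, 1, mask_logits, feat_logits, masked_adj,
                        labels, coeffs, graph_mode=True)
n = explain_loss_sketch(pred, 1, mask_logits, feat_logits, masked_adj,
                        labels, coeffs, graph_mode=False)
```

With all masks at 0.5, the entropy term reaches its maximum ($\ln 2$), which is exactly what the entropy loss pushes away from.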
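PyG Code Example for Explaining Graph Classification with GNNExplainer
Explaining a graph classification comes down to two key points:
- The mask being learned covers the whole graph, so there is no need to extract a subgraph
- Label prediction and the loss function are computed for a single graph
The implementation is as follows: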
```python
#!/usr/bin/env python
# encoding: utf-8
# Created by BIT09 at 2023/4/28
import torch
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
from math import sqrt
from tqdm import tqdm
from torch_geometric.nn import MessagePassing
from torch_geometric.data import Data
from torch_geometric.utils import k_hop_subgraph, to_networkx

EPS = 1e-15


class GNNExplainer(torch.nn.Module):
    r"""
    Args:
        model (torch.nn.Module): The GNN module to explain.
        epochs (int, optional): The number of epochs to train.
            (default: :obj:`100`)
        lr (float, optional): The learning rate to apply.
            (default: :obj:`0.01`)
        log (bool, optional): If set to :obj:`False`, will not log any
            learning progress. (default: :obj:`True`)
        node (bool, optional): If set to :obj:`True`, also learn a node
            feature mask. (default: :obj:`False`)
    """

    coeffs = {
        'edge_size': 0.001,
        'node_feat_size': 1.0,
        'edge_ent': 1.0,
        'node_feat_ent': 0.1,
    }

    def __init__(self, model, epochs=100, lr=0.01, log=True, node=False):
        # disable node_feat_mask by default
        super(GNNExplainer, self).__init__()
        self.model = model
        self.epochs = epochs
        self.lr = lr
        self.log = log
        self.node = node

    def __set_masks__(self, x, edge_index, init="normal"):
        (N, F), E = x.size(), edge_index.size(1)
        std = 0.1
        if self.node:
            self.node_feat_mask = torch.nn.Parameter(torch.randn(F) * 0.1)

        std = torch.nn.init.calculate_gain('relu') * sqrt(2.0 / (2 * N))
        self.edge_mask = torch.nn.Parameter(torch.randn(E) * std)
        # NOTE: this overrides the random init above; every edge starts at sigmoid(0) = 0.5
        self.edge_mask = torch.nn.Parameter(torch.zeros(E) * 50)

        # Inject the edge mask into every MessagePassing layer
        # (these attribute names depend on the PyG version)
        for module in self.model.modules():
            if isinstance(module, MessagePassing):
                module.__explain__ = True
                module.__edge_mask__ = self.edge_mask

    def __clear_masks__(self):
        for module in self.model.modules():
            if isinstance(module, MessagePassing):
                module.__explain__ = False
                module.__edge_mask__ = None
        if self.node:
            self.node_feat_mask = None
        self.edge_mask = None

    def __num_hops__(self):
        num_hops = 0
        for module in self.model.modules():
            if isinstance(module, MessagePassing):
                num_hops += 1
        return num_hops

    def __flow__(self):
        for module in self.model.modules():
            if isinstance(module, MessagePassing):
                return module.flow
        return 'source_to_target'

    def __subgraph__(self, node_idx, x, edge_index, **kwargs):
        num_nodes, num_edges = x.size(0), edge_index.size(1)

        if node_idx is not None:
            subset, edge_index, mapping, edge_mask = k_hop_subgraph(
                node_idx, self.__num_hops__(), edge_index, relabel_nodes=True,
                num_nodes=num_nodes, flow=self.__flow__())
            x = x[subset]
        else:
            # Graph-level explanation: keep every node and every edge
            row, col = edge_index
            edge_mask = row.new_empty(row.size(0), dtype=torch.bool)
            edge_mask[:] = True
            mapping = None
            subset = torch.arange(num_nodes, device=x.device)

        for key, item in kwargs.items():
            if torch.is_tensor(item) and item.size(0) == num_nodes:
                item = item[subset]
            elif torch.is_tensor(item) and item.size(0) == num_edges:
                item = item[edge_mask]
            kwargs[key] = item

        return x, edge_index, mapping, edge_mask, kwargs

    def __graph_loss__(self, probs, pred_label):
        # Prediction loss: negative log-probability of the originally predicted class
        loss = -torch.log(probs[0, pred_label] + EPS)
        m = self.edge_mask.sigmoid()
        # Size loss: prefer explanations with few edges
        loss = loss + self.coeffs['edge_size'] * m.sum()
        # Entropy loss: push mask values towards 0 or 1
        ent = -m * torch.log(m + EPS) - (1 - m) * torch.log(1 - m + EPS)
        loss = loss + self.coeffs['edge_ent'] * ent.mean()
        return loss

    def visualize_subgraph(self, node_idx, edge_index, edge_mask, y=None,
                           threshold=None, **kwargs):
        r"""Visualizes the subgraph around :attr:`node_idx` given an edge mask
        :attr:`edge_mask`.

        Args:
            node_idx (int): The node id to explain.
            edge_index (LongTensor): The edge indices.
            edge_mask (Tensor): The edge mask.
            y (Tensor, optional): The ground-truth node-prediction labels used
                as node colorings. (default: :obj:`None`)
            threshold (float, optional): Sets a threshold for visualizing
                important edges. If set to :obj:`None`, will visualize all
                edges with transparency indicating the importance of edges.
                (default: :obj:`None`)
            **kwargs (optional): Additional arguments passed to
                :func:`nx.draw`.

        :rtype: :class:`matplotlib.axes.Axes`, :class:`networkx.DiGraph`
        """
        assert edge_mask.size(0) == edge_index.size(1)

        if node_idx is not None:
            # Only operate on a k-hop subgraph around `node_idx`.
            subset, edge_index, _, hard_edge_mask = k_hop_subgraph(
                node_idx, self.__num_hops__(), edge_index, relabel_nodes=True,
                num_nodes=None, flow=self.__flow__())

            edge_mask = edge_mask[hard_edge_mask]
            subset = subset.tolist()
            if y is None:
                y = torch.zeros(edge_index.max().item() + 1,
                                device=edge_index.device)
            else:
                y = y[subset].to(torch.float) / y.max().item()
            y = y.tolist()
        else:
            # Graph-level: all nodes that appear in an edge, in ascending id order,
            # so the relabel mapping below does not scramble node ids
            subset = sorted(set(edge_index.view(-1).tolist()))
            y = [y for _ in range(len(subset))]

        if threshold is not None:
            edge_mask = (edge_mask >= threshold).to(torch.float)

        data = Data(edge_index=edge_index, att=edge_mask, y=y,
                    num_nodes=len(y)).to('cpu')
        G = to_networkx(data, edge_attrs=['att'])  # , node_attrs=['y']
        mapping = {k: i for k, i in enumerate(subset)}
        G = nx.relabel_nodes(G, mapping)

        kwargs['with_labels'] = kwargs.get('with_labels') or True
        kwargs['font_size'] = kwargs.get('font_size') or 10
        kwargs['node_size'] = kwargs.get('node_size') or 800
        kwargs['cmap'] = kwargs.get('cmap') or 'cool'

        pos = nx.spring_layout(G)
        ax = plt.gca()
        for source, target, data in G.edges(data=True):
            ax.annotate(
                '', xy=pos[target], xycoords='data', xytext=pos[source],
                textcoords='data', arrowprops=dict(
                    arrowstyle="->",
                    alpha=max(data['att'], 0.1),
                    shrinkA=sqrt(kwargs['node_size']) / 2.0,
                    shrinkB=sqrt(kwargs['node_size']) / 2.0,
                    connectionstyle="arc3,rad=0.1",
                ))
        nx.draw_networkx_nodes(G, pos, node_color=y, **kwargs)
        nx.draw_networkx_labels(G, pos, **kwargs)

        return ax, G

    def explain_graph(self, data, **kwargs):
        self.model.eval()
        self.__clear_masks__()

        x, edge_index, batch = data.x, data.edge_index, data.batch
        num_edges = edge_index.size(1)

        # Graph-level explanation: node_idx=None keeps the whole graph.
        x, edge_index, _, hard_edge_mask, kwargs = self.__subgraph__(
            node_idx=None, x=x, edge_index=edge_index, **kwargs)

        # Get the initial prediction.
        with torch.no_grad():
            log_logits = self.model(data, **kwargs)
            probs_Y = torch.softmax(log_logits, 1)
            pred_label = probs_Y.argmax(dim=-1)

        self.__set_masks__(x, edge_index)
        self.to(x.device)

        if self.node:
            optimizer = torch.optim.Adam([self.node_feat_mask, self.edge_mask],
                                         lr=self.lr)
        else:
            optimizer = torch.optim.Adam([self.edge_mask], lr=self.lr)

        epoch_losses = []
        for epoch in range(1, self.epochs + 1):
            epoch_loss = 0
            optimizer.zero_grad()
            if self.node:
                # Feed the feature-masked copy of the graph to the model
                masked_data = data.clone()
                masked_data.x = x * self.node_feat_mask.view(1, -1).sigmoid()
                log_logits = self.model(masked_data, **kwargs)
            else:
                log_logits = self.model(data, **kwargs)
            pred = torch.softmax(log_logits, 1)
            loss = self.__graph_loss__(pred, pred_label)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.detach().item()
            epoch_losses.append(epoch_loss)

        edge_mask = self.edge_mask.detach().sigmoid()
        print(edge_mask)

        self.__clear_masks__()
        return edge_mask, epoch_losses

    def __repr__(self):
        return f'{self.__class__.__name__}()'
```
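The optimization inside explain_graph can also be tried without depending on a particular PyG version (the mask-injection attributes may differ between releases). Below is a self-contained PyTorch sketch of the same loop. `EdgeWeightedGNN` and `explain_graph_sketch` are hypothetical. Scaling each message by the sigmoid of the edge mask is an assumption meant to approximate what the injected edge mask does. The loss has the same three terms as `__graph_loss__`.

```python
import torch
import torch.nn as nn


class EdgeWeightedGNN(nn.Module):
    """Hypothetical 2-layer GNN whose messages are scaled by per-edge weights."""

    def __init__(self, in_dim, hid_dim, n_classes):
        super().__init__()
        self.lin1 = nn.Linear(in_dim, hid_dim)
        self.lin2 = nn.Linear(hid_dim, n_classes)

    def propagate(self, h, edge_index, w):
        src, dst = edge_index
        out = torch.zeros_like(h)
        out.index_add_(0, dst, h[src] * w.view(-1, 1))  # sum of weighted messages
        return out + h                                   # plus a self term

    def forward(self, x, edge_index, w):
        h = torch.relu(self.propagate(self.lin1(x), edge_index, w))
        h = self.propagate(self.lin2(h), edge_index, w)
        return h.mean(dim=0, keepdim=True)               # (1, C) graph logits


def explain_graph_sketch(model, x, edge_index, epochs=100, lr=0.01,
                         size_coef=0.001, ent_coef=1.0, eps=1e-15):
    model.eval()
    for p in model.parameters():                         # only the mask is trained
        p.requires_grad_(False)
    ones = torch.ones(edge_index.size(1))
    with torch.no_grad():                                # original prediction
        pred_label = model(x, edge_index, ones).softmax(-1).argmax(-1)
    edge_mask = nn.Parameter(torch.zeros(edge_index.size(1)))  # sigmoid(0) = 0.5
    opt = torch.optim.Adam([edge_mask], lr=lr)
    losses = []
    for _ in range(epochs):
        opt.zero_grad()
        m = edge_mask.sigmoid()
        probs = model(x, edge_index, m).softmax(-1)
        loss = -torch.log(probs[0, pred_label] + eps).sum()
        loss = loss + size_coef * m.sum()
        ent = -m * torch.log(m + eps) - (1 - m) * torch.log(1 - m + eps)
        loss = loss + ent_coef * ent.mean()
        loss.backward()
        opt.step()
        losses.append(loss.item())
    return edge_mask.detach().sigmoid(), losses


torch.manual_seed(0)
x = torch.randn(5, 4)
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4],
                           [1, 0, 2, 1, 3, 2, 4, 3]])
model = EdgeWeightedGNN(4, 16, 2)
mask, losses = explain_graph_sketch(model, x, edge_index, epochs=50)
```

The returned `mask` has one value per directed edge and can be passed to visualize_subgraph (with `node_idx=None`) or to a top-k selection.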
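References
- gnn-explainer
- Interpretability Methods for Graph Neural Networks and a GNNexplainer Code Example (图神经网络的可解释性方法及GNNexplainer代码示例)
- Implementing GNNExplainer in PyTorch (Pytorch实现GNNExplainer)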
- How to Explain Graph Neural Network — GNNExplainer
- https://gist.github.com/hongxuenong/9f7d4ce96352d4313358bc8368801707