基于图的消息传递会加剧类别不平衡对分类器的影响,这种现象称为拓扑偏差。少数类节点在邻域中噪声比例上升更快,有效信号减少更多,导致其更容易被误分类。为解决这一问题,本文提出了轻量化的拓扑增强模块BAT,包括节点风险估计、后验概率估计和拓扑增强三步。BAT通过虚拟边和超级节点缓解偏差,易于集成,适用于多种GNN架构和类别不平衡技术,进一步提升分类性能。
基于图的消息传递会加剧类别不平衡对分类器的影响,这种现象称为拓扑偏差。少数类节点在邻域中噪声比例上升更快,有效信号减少更多,导致其更容易被误分类。为解决这一问题,本文提出了轻量化的拓扑增强模块BAT,包括节点风险估计、后验概率估计和拓扑增强三步。BAT通过虚拟边和超级节点缓解偏差,易于集成,适用于多种GNN架构和类别不平衡技术,进一步提升分类性能。
论文标题:
Class-Imbalanced Graph Learning without Class Rebalancing
论文地址:
https://arxiv.org/abs/2308.14181
代码地址:
https://github.com/ZhiningLiu1998/BAT
Theoretical Insights:基于图拓扑结构的消息传递(message-passing)会加剧类别不平衡对分类器学习的影响,我们称之为 topology-sourced bias。具体地,随着类别不平衡加剧,相比于多数类,少数类节点的邻域中噪声(即异类标记节点,见 Theorem 2.1)比例会更快增加,并且有效训练信号(即己类标记节点,见 Theorem 2.2)比例会更快减少。换言之,在 GNN 的消息传递过程中,少数类节点具有更低的信噪比,因此更容易被误分类。
Practical Algorithm:为了 handle topology-sourced bias,我们提出了一个通用、轻量化的拓扑增强模块 BAT。BAT 的主要流程有 3 步:1)节点风险估计: 估计节点受拓扑 bias 影响而误分类的概率;2)后验概率估计:对已知高误分类风险的节点估计其实际类别标签;3)拓扑增强:为每个类创建 super node 并根据前两步采样虚拟边。
Practical Implementation:我们的 GitHub 实现提供了简洁的 API、文档和用例 ,用户仅需添加10 行左右的代码便可将 BAT 嵌入自己的节点分类训练流程中。BAT 旨在通过动态数据增强来缓解拓扑结构带来的 bias,因此可以与各种现有的类别不平衡节点分类技术和 GNN 架构一同使用,并进一步提升其性能。
有兴趣进一步了解的读者可以继续阅读,下面我将简短地介绍以下内容:背景,研究动机,主要观点,算法设计,实验结果,以及本工作的局限性和未来可能的扩展。
类别不平衡(标签中不同类别的标注数据数量的不平衡)是一个已经被广泛研究的问题 [1],但大部分都是针对 i.i.d.(独立同分布)的数据。而在图节点分类任务中,不同的节点之间具有复杂的 dependence(因此 non i.i.d.),由节点之间的边(edge)来刻画,而这些边组成了图的拓扑结构。
GNN 可以通过整合邻居节点的信息来利用图的拓扑结构帮助表示学习,随着 GNN 在多个领域的成功,图上的类别不平衡学习(CIGL,class-imbalanced graph learning)问题在最近也开始得到关注。
三、Motivation/研究动机
现有的 CIGL 研究多关注于如何设计类别重平衡(CR,class-rebalancing) 机制来缓解图上的类别不平衡,如调整节点的权重 [2]、通过数据增强增加少数类节点数量 [3][4]。
尽管这些工作设计中都考虑了图的拓扑结构,但其核心思想仍是如何利用 topology 来帮助 class-rebalancing,我们称这类工作为 CR-oriented CIGL。但我们认为 CR 并不能完全解决 CIGL 问题:CIGL 的本质挑战在于其复杂的拓扑结构与类别不平衡对表示学习的共同影响。
受到 topology-imbalance 相关研究 [2] 的启发,我们想要探究如下问题:在类别不平衡时,拓扑结构和 message-passing 在 CIGL 中扮演何种角色?
四、Insights/观点
通过理论分析和实验观察,我们发现两种 local topological phenomena 可以放大/加剧类别不平衡的影响:
Ambivalent message-passing (AMP): 节点邻域中存在高比例的异类;
labeled nodes Distant message-passing (DMP): 节点与己类 labeled nodes的连接性弱。
图1:AMP(左)与DMP(右)的概念与其对多数类(蓝色)和少数类(橙色)分类性能的影响。
五、Algorithm/算法
无论是多数类/少数类,都只有一小部分节点受AMP/DMP影响(左),且高AMP/DMP的节点具有显著降低的预测准确率(右)。
受此启发,我们希望设计一种算法来直接针对性地缓解 AMP/DMP 对这一部分关键节点的影响。这指导了我们算法的设计:首先定位受 AMP/DMP 影响的节点,然后拓扑增强其在图中的 context(即邻域)以缓解 AMP/DMP。
随着这一思路,我们设计了一个简单的 pipeline:
Node Risk Estimation/节点风险估计:受 AMP/DMP 影响的节点在 message-passing 中具有更低的信噪比,通常也具有更高的预测不确定性。这一步通过量化模型对每个节点的预测的不确定性来定位高风险节点。
Posterior Likelihood Estimation/后验概率估计:高风险节点具有更高的误分类概率,因此其当前的预测标签并不可信。为了选择合适的信息来增强其上下文/邻域,这一步使用 0/1 阶信息重新估计高风险节点的实际标签。
Virtual Topology Augmentation/虚拟拓扑增强:根据当前预测,为每个类别创建一个 prototype 虚拟节点(virtual super node)来表示该类的 general pattern。根据前两步的结果,在高风险节点和其对应的实际标签之间构建虚拟边。注意 BAT 具有线性复杂度,因此计算十分高效,我们在每一步训练中都会重新计算 BAT 来实现动态数据增强。
算法流程图,从左至右
六、Experiment/实验
6.1 Performance/性能
主要结果1:灰色部分为BAT在与其他CIGL方法结合时为其带来的平均/最佳性能增益。
主要结果2:在更多数据集以及类别不平衡设置下的结果,与上面一致,灰色部分为BAT带来的性能增益。
6.2 Efficiency/效率
Runtime Results of BAT
七、Limitation and Future Directions/局限性及可能的扩展
BAT 的理论分析是在设计算法之后做的,因此算法本身并非是由 theoretical findings motivated。但我们认为从拓扑视角出发的对类别不平衡的分析是很有趣的,在这一方向进一步挖掘应该可以得到更加优雅的(且 unified)结论和算法。
BAT 的设计理念以简单实用为第一准则,在每个步骤/模块我们都使用了最简单直接的解决方案。有许多更高级的技术可以用于节点 risk/后验概率估计等步骤。
图数据的拓扑结构带来了许多独特的挑战。除了本文考虑的节点层面的不平衡之外,在拓扑结构中还可能存在节点出入度的不平衡/子图层面的不平衡等等,如何同时 handle 这些层面的 bias/skewness(以及他们之间的 interaction)也将涉及许多有趣的问题。关于这个问题有一篇 survey [5] 可供感兴趣的各位参考。
参考文献
作者:刘芷宁 来源:公众号【PaperWeekly】
扫码观看!
“AI技术流”原创投稿计划
TechBeat是由将门创投建立的AI学习社区(www.techbeat.net)。社区上线500+期talk视频,3000+篇技术干货文章,方向覆盖CV/NLP/ML/Robotis等;每月定期举办顶会及其他线上交流活动,不定期举办技术人线下聚会交流活动。我们正在努力成为AI人才喜爱的高质量、知识型交流平台,希望为AI人才打造更专业的服务和体验,加速并陪伴其成长。
投稿内容
// 最新技术解读/系统性知识分享 //
// 前沿资讯解说/心得经历讲述 //
投稿须知
稿件需要为原创文章,并标明作者信息。
我们会选择部分在深度技术解析及科研心得方向,对用户启发更大的文章,做原创性内容奖励
投稿方式
发送邮件到
melodybai@thejiangmen.com
或添加工作人员微信(yellowsubbj)投稿,沟通投稿详情;还可以关注“将门创投”公众号,后台回复“投稿”二字,获得投稿说明。