点击下方卡片,关注计算机视觉Daily
添加微信号:CVer2233,小助手会拉你进群!
扫描下方二维码,加入CVer学术星球!可以获得最新顶会/顶刊上的论文idea和CV从入门到精通资料,及最前沿应用!发论文/搞科研/涨薪,强烈推荐!
添加微信号:CVer2233,小助手会拉你进群!
扫描下方二维码,加入CVer学术星球!可以获得最新顶会/顶刊上的论文idea和CV从入门到精通资料,及最前沿应用!发论文/搞科研/涨薪,强烈推荐!
Decouple Graph Neural Networks: Train Multiple Simple GNNs Simultaneously Instead of One论文:https://arxiv.org/abs/2304.10126
解耦图神经网络:同时训练多个简单的GNN而不是一个
Hongyuan Zhang; Yanan Zhu; Xuelong Li
摘要
图神经网络(GNN)由于节点依赖随着层数增加呈指数增长,导致严重的效率问题。这极大地限制了随机优化算法的应用,使得GNN的训练通常耗时较长。为了解决这个问题,我们提出了将多层GNN解耦为多个简单模块以实现更高效的训练,该框架包括经典的前向训练(FT)和设计的反向训练(BT)。在所提出的框架下,每个模块都可以在FT中通过随机算法高效地训练,而不会因简单性而扭曲图信息。为了避免FT中单向信息传递以及充分训练浅层模块与深层模块,我们开发了一种反向训练机制,使前层模块能够感知后层模块,灵感来自经典的反向传播算法。反向训练引入了解耦模块中的反向信息传递以及前向信息传递。为了研究解耦和贪婪训练如何影响表征能力,我们理论上证明了在线性模块中产生的误差在大多数情况下不会在无监督任务中累积。理论和实验结果表明,所提出的框架具有高效性并且性能合理,值得进一步研究。
关键词
反向训练
高效训练
图神经网络
I. 引言
近年来,由于其令人印象深刻的性能,神经网络[1]、[2]已扩展到图数据,被称为图神经网络(GNN)[3]。随着GNN显著提高了图任务的结果,它已从不同方面得到了广泛研究,例如图卷积网络(GCN)[4]、[5]、图注意力网络(GATs)[6]、[7]、时空GNN(STGNN)[8]、图自编码器[9]、[10]、图对比学习[11]等。
除了来自不同视角的变体之外,一个重要的议题是众所周知的GNN效率问题。在经典神经网络[2]中,优化通常基于随机算法进行,批量大小有限[12]、[13],因为样本彼此独立。然而,在[14]中定义的聚合类操作导致每个节点依赖于其邻居,并且一个节点的依赖节点数量随着层数的增长呈指数增长,导致批量大小意外增加。一些工作基于邻居采样[14]、[15]、[16]、[17]和图近似[18]提出以限制批量大小,而一些方法[19]、[20]尝试直接应用高阶图操作并牺牲大部分非线性。尽管VRGCN[16]试图通过改进采样来控制方差,但采样方法的训练稳定性仍然是一个问题。注意到所需的节点可能仍然会随着深度的增加而增长(缓慢)。Cluster-GCN[18]找到了一个具有许多连通分量的近似图,以便批量大小被严格上界限制。这些方法的主要挑战是采样期间的信息缺失。简化的方法[19]、[20]是高效的,但有限的非线性可能是这些方法的瓶颈。这些方法可能结合GIN[21]的思想来提高容量[20]。
为了在保留完整图结构的同时应用随机优化,我们提出了一个框架,即堆叠图神经网络(SGNN),它将多层GNN解耦为多个简单的GNN模块,然后同时训练它们,而不是随着深度的增加而连接它们。受反向传播算法的启发,我们发现堆叠网络[22]和经典网络之间的主要区别是没有从后层模块传递到前层的训练信息。缺乏反向信息传递可能是堆叠模型性能限制的主要原因。贡献总结如下:(1)我们相应地提出了一种反向训练策略,让前层模块接收来自最终损失和后层模块的信息,从而形成了一个循环训练框架,以控制偏差并正确训练浅层模块。(2)在这个框架下,多层GNN可以被解耦为多个简单的GNN,本文中称为可分离GNN,以便每个训练步骤都可以使用随机优化而不进行任何采样或对图的更改。因此,SGNN可以同时考虑非线性和高效率。(3)我们研究了解耦和贪婪训练如何影响线性SGNN的表征能力。证明了在大多数情况下,当最终目标是图重建时,线性模块中产生的误差不会累积。
II. 背景
图神经网络:在过去的几年中,图神经网络[4]、[5]、[6]、[21]、[23]、[24]越来越受到关注。GNN不仅应用于图任务[25](例如,推荐系统[26]),还应用于其他应用[27]、[28](例如,计算机视觉[29])。特别是,图卷积网络(GCN)[4]已成为重要的基线。通过引入自注意力技术[30],图注意力网络(GAT)[6]、[7]被提出并应用于其他应用[29]、[31]。正如[32]所声称的,GNN遭受过平滑问题的困扰,GALA[10]发展了图锐化,ResGCN[33]试图设计一个更深的架构。理论工作[32]、[34]、[35]对GNN的深度有不同的看法。一些工作[32]、[34]声称,随着层数的增加,GNN的表达能力会降低,而其他工作则认为[34]中的假设可能不成立,更深的GNN具有更强的表达能力[35]。此外,一些工作[21]、[36]通过展示Weisfeiler-Lehman测试[37]和GNN之间的联系来研究表达能力。然而,大多数人忽略了GNN的效率问题。
高效图神经网络:为了加速通过批量梯度下降对GNN的优化而不产生太大偏差,几种模型[14]、[15]、[16]、[17]提出根据图拓扑采样数据点。这些模型提出了不同的采样策略以获得稳定的结果。GraphSAGE[14]为每个节点生成了一个有限邻居的子图,而FastGCN[15]通过重要性采样为每层采样固定节点。在[16]中进一步控制了采样的方差。Cluster-GCN[18]旨在生成一个具有许多连通分量的近似图,以便每个分量可以作为每一步的批量使用。AnchorGAE[38]提出通过引入锚点将原始图转换为二分图来加速图操作,从而将复杂度降低到与现有模型[39]相比的O(n)。SGC[19]通过将中间层的所有激活函数设置为线性函数来简化GCN,SSGC[20]进一步改进了它。总之,本文提出的SGNN保留了非线性,不需要节点采样或子图采样。L2-GCN[40]试图将经典的堆叠自编码器的思想扩展到流行的GCN,而DGL-GNN[41]进一步开发了一个并行版本。它们都没有训练所有GNN模块,但SGNN首次提供了一个新颖的框架来像训练常规神经网络中的层一样训练它们。
与现有模型的联系:堆叠自编码器(SAE)[22]是应用于神经网络预训练的模型。它训练当前的两层自编码器[42],然后将中间层输出的潜在特征传递给下一个自编码器。该模型通常用作预训练模型而不是正式模型。MGAE[43]是SAE的扩展,其基本模块是图自编码器[9]。与所提出的模型相比,主要区别在于每个模块是否可以被来自正向和反向方向的模块感知。堆叠范式类似于经典的提升模型[44]、[45]、[46],而一些工作[47]、[48]也研究了神经网络的提升算法。近年来,一些提升GNN模型[49]、[50]也得到了发展。大多数提升算法(例如,[47]、[49])的目标是逐步学习预测函数,而所提出的SGNN的目标是逐步学习理想的嵌入。请注意,AdaGCN[50]也是逐步训练的,并且使用AdaBoost[44]结合特征。更重要的是,所有这些GNN的提升方法只进行前向训练,缺少反向训练。深度神经接口[51]提出了解耦神经网络以异步加速梯度计算。解耦是计算L层网络梯度的加速技巧,而本文提出的SGNN明确地将L层GNN分离为L个简单模块。换句话说,SGNN的最终目标不是优化L层GNN。
III. 提出的方法
受SAE[22]的启发,以及简化模型[19]、[20]对GNN的高效率,我们因此重新思考堆叠网络和多层GNN之间的实质区别。总之,我们试图回答本文中的以下两个问题:Q1: 如何将复杂的GNN解耦为多个简单的GNN并联合训练它们?Q2: 解耦如何影响表征能力和最终性能?我们将在本节讨论第一个问题,然后在第四部分详细阐述另一个问题。
预备知识
堆叠图神经网络
可分离性:效率的关键概念
前向训练(FT):第一个挑战是如何为每个模块 设置训练目标。这对于将SGNN应用于监督和无监督场景至关重要。假设我们有一个可分离的GNN模块 并且让
反向训练(BT):第二个挑战是如何同时训练多个可分离的GNN以确保性能。简而言之,由于前向传播(FP)和反向传播(BP)的信息传递,经典多层神经网络的所有层的梯度被精确计算。BP通过反馈让浅层感知深层。在SGNN中,尾部模块在FT中对头部模块是不可见的。我们因此设计了反向训练(BT)来实现信息的反向传递。具体来说,对于一个可分离的GNN层作为SGNN中的一个模块:
复杂度分析
IV. 理论分析
V. 实验
节点聚类
实验设置:我们首先验证SGNN在节点聚类上的有效性。我们将其与10种方法进行比较,包括一个基线聚类模型Kmeans,三种不考虑训练效率的GCN模型(GAE[9]、ARGA[54]、MGAE[43]),以及六种具有GAE损失的快速GCN模型(GraphSAGE[14]、FastGAE[15]、ClusterGAE[18]、AGC[55](SGC[19]的无监督扩展)、S2GC[20]和GAE-S2GC)。
性能:从表III中,我们发现SGNN在大多数数据集上表现优异。如果发布的代码因内存不足(OOM)而无法在Reddit上运行,我们用“N/A”代替结果。特别是,SGNN-BT在Reddit上取得了高达8%的提升,超过了众所周知的GraphSAGE。SGNN-FT在某些数据集上表现优于平均水平。它通常优于GraphSAGE,但未能超过SGC。由于由多个模块引起的更深层结构,SGNN的性能优于简单的GAE。由于多个模块引入的更多非线性,它也优于SGC。注意S2GC和SGC是强大的竞争者,而SGNN可以轻松地将它们作为基础模块使用,因为它们是可分离的,这由SGNN-S2GC所示。可以很容易地发现SGNN-S2GC通常与S2GC取得了相似的结果。由于它比SGNN-BT慢,并且性能提升不稳定,我们建议在实践中使用简单的SGNN-FT而不是SGNN-S2GC。从表中,我们发现将S2GC修改为GAE-S2GC是不必要的,因为GAE架构根本没有提高S2GC的性能。从消融实验中,SGNN-BT比SGNN-FT表现更好,这表明了反向训练的必要性。
效率:图2显示了在Pubmed和Reddit上几种GNN的消耗时间,这些GNN具有更高的效率。我们不仅忽略了预处理操作,还测量了效率,即从将数据加载到RAM后开始,直到完成所有参数更新的总时间除以GNNs的更新参数总数。这种度量可以反映旨在将基于批量的算法应用于GNN的不同训练技术之间的真实差异。需要强调的是,SGC在消耗时间上比SGNN差的原因。关键是它们的预处理操作的成本不同。对于一个L阶SGC,计算的计算成本至少为,而具有L个一阶模块的SGNN总共需要进行相同的预处理操作。该度量还为SGC和其他模型之间的公平比较提供了依据,因为停止标准总是不同的。
节点分类
实验设置:我们还在四个数据集上进行了半监督分类的实验。数据集的拆分遵循[19],如表II所示。我们将SGNN与GCN[4]、GAT[6]、DGI[56]、APPNP[57]、L2GCN[40]、FastGCN[15]、GraphSAGE[14]、Cluster-GCN[18]、SGC[19]、GCNII[58]和S2GC[20]进行比较。同样,我们用两种不同的基础模型测试SGNN,即SGNN-BT和SGNN-S2GC。实验设置与节点聚类实验相同。对于GraphSAGE,我们默认使用均值操作符,如果使用额外的操作符,我们会添加一些注释。在引用网络上,学习率设为0.01,而在Reddit上设为。由于引用网络上的训练节点少于200,我们在每次迭代中对所有方法使用所有训练点,而在Reddit上,所有基于批量的模型的批量大小设为512。我们不使用[4]中使用早期停止标准,最大迭代次数遵循SGC的设置。每个模块的嵌入维度与节点聚类设置相同。为了公平起见,我们报告了使用两个模块的SGNN的结果,这些模块使用一阶操作。前向训练损失在(6)中定义。此外,所有比较的模型共享相同的批量迭代器、损失函数和邻域采样器(如适用)的实现。默认情况下,LFT和LBT之间的平衡系数设为1。我们在Cora上优化超参数,并在表V和表VI上报告了在引用数据集上平均10次运行的结果,以及在Reddit上平均5次运行的结果。Reddit上的结果记录在表VI中。超参数在所有数据集上共享,这些超参数在Cora上进行了优化。 性能:表V中比较方法的结果来自相应的论文。当实验结果缺失时,我们运行公开发布的代码,并在结果上加注†。从表V和表VI中,我们得出结论,SGNN在引用网络上超过了使用邻居采样的模型,例如GraphSAGE、FastGCN和Cluster-GCN,并且在Reddit上超过了大多数模型。在简单的引用网络上,SGNN比其他基于批量的模型损失的准确性最少,这接近GCN。由于每个模块的可分离性,批量采样不需要邻居采样,也不会造成图信息的丢失。
可视化解耦的影响:在图3中,我们可视化了由3个模块组成的SGNN和3层GCN的输出,直接展示了解耦不会导致平凡特征,这对应于第IV节中的理论结论。为了展示SGNN和反向训练引入的非线性和灵活性的好处,图7显示了SGC、SGNN-FT和SGNN-BT的收敛曲线。注意,该图显示了最终损失的变化。在SGNN中,最终损失是ML的损失,而在SGC中,它唯一的训练损失。SGC使用了L阶图操作。从这个图中,我们可以得出结论:(1)非线性确实会导致更好的损失值;(2)反向训练显著降低了损失。总之,解耦在实证上不会造成负面影响。
在OGB数据集上的实验
VI. 结论
绘图神器下载
后台回复:绘图神器,即可下载绘制神经网络结构的神器!
何恺明在MIT授课的课件PPT下载
在CVer公众号后台回复:何恺明,即可下载本课程的所有566页课件PPT!赶紧学起来!
CVPR 2024 论文和代码下载
在CVer公众号后台回复:CVPR2024,即可下载CVPR 2024论文和代码开源的论文合集
Mamba、多模态和扩散模型交流群成立
扫描下方二维码,或者添加微信号:CVer2233,即可添加CVer小助手微信,便可申请加入CVer-Mamba、多模态学习或者扩散模型微信交流群。另外其他垂直方向已涵盖:目标检测、图像分割、目标跟踪、人脸检测&识别、OCR、姿态估计、超分辨率、SLAM、医疗影像、Re-ID、GAN、NAS、深度估计、自动驾驶、强化学习、车道线检测、模型剪枝&压缩、去噪、去雾、去雨、风格迁移、遥感图像、行为识别、视频理解、图像融合、图像检索、论文投稿&交流、PyTorch、TensorFlow和Transformer、NeRF、3DGS、Mamba等。
一定要备注:研究方向+地点+学校/公司+昵称(如Mamba、多模态学习或者扩散模型+上海+上交+卡卡),根据格式备注,可更快被通过且邀请进群
▲扫码或加微信号: CVer2233,进交流群
CVer计算机视觉(知识星球)来了!想要了解最新最快最好的CV/DL/AI论文速递、优质实战项目、AI行业前沿、从入门到精通学习教程等资料,欢迎扫描下方二维码,加入CVer计算机视觉(知识星球),已汇集近万人!
▲扫码加入星球学习
整理不易,请赞和在看