推荐阅读 | 联邦学习经典高被引论文《原型对比联邦学习-FedProc》

文摘   2024-06-12 08:30   北京  


本次分享的论文是西安电子科技大学NSS实验室在2023年被期刊《Future Generation Computer Systems》录用的一篇论文:「FedProc: Prototypical contrastive federated learning on non-IID data」。FGCS是一本综合性的国际学术期刊,由Elsevier出版。它在计算机领域具有广泛的影响力,被SCI和SCIE数据库收录。该期刊的影响因子为7.5。在中科院分区中,属于大类 “计算机科学”2区,小类“计算机:理论方法”1区。

论文链接:

https://www.sciencedirect.com/science/article/pii/S0167739X23000262

论文代码:

https://github.com/XidianNSS/FedProc

论文作者:穆旭彤,沈玉龙(通信作者),程珂,耿雪莉,付家瑄,张涛,张志为

实验室主页:http://xidiannss.com/     

 

    

摘要  

联邦学习(FL)允许多个客户端在保持训练数据本地化的同时,共同训练高性能的深度学习模型。然而,当所有客户端的本地数据不是独立同分布(即non-IID)时,实现这种高效的协作学习便面临挑战。尽管已有广泛努力解决此问题,但在图像分类任务中的结果仍不尽人意。在本文中,我们提出了一种新的方法——FedProc:原型对比联邦学习。这种方法的核心思想是利用原型作为全局知识,以校正每个客户端本地训练的偏移。具体而言,我们设计了一个本地网络结构和一个全局原型对比损失函数,以规范本地模型的训练。这些措施使得本地优化的方向与全局最优解保持一致,从而使全局模型在non-IID数据上取得良好表现。评估研究及其理论意义的支持表明,与现有最先进的联邦学习方法相比,FedProc在可接受的计算成本下提高了1.6%至7.9%的准确率。

引言  

联邦学习(FL)作为一种前瞻性的机器学习解决方案,成功地缓解了隐私保护的担忧。该方法允许分布式客户端通过共享其本地模型的参数进行聚合,而无需访问数据的详细内容,从而共同训练一个高性能的全局模型。作为一个有效的通信和隐私保护学习方案,联邦学习已在现实世界应用中展示出其潜力,包括医学图像分析、生物特征分析、物体检测等领域。联邦学习已证明在独立同分布(IID)数据上表现良好。然而,在实际应用中,不同客户端持有的数据通常分布极不均匀,即大多数情况下,每个客户端的本地数据集是非独立同分布的(non-IID)。不平衡的数据分布导致每个客户端的本地模型训练发生偏移,使本地目标与全局最优解相差甚远,从而显著降低了FL的性能。如何减轻non-IID数据对FL的负面影响,这一问题仍然没有得到解决。

针对非IID数据问题,已有多种努力,主要从两个互补的角度出发:一是提高模型聚合的效率,如FedNova、FedBE;另一类方法,如MOON、FedProx和SCAFFOLD,专注于通过限制本地模型相对于全局模型在参数空间的偏离,稳定本地训练阶段。然而,无论是在模型聚合阶段还是本地训练阶段,这些方法都没有充分利用每个客户端提供的潜在知识。实验表明,这些方法的准确性和计算效率仍有很大的提升空间。

鉴于non-IID数据的挑战和以往研究努力的局限性,我们提出了一种原型对比联邦学习框架,称为FedProc。受原型网络的启发,我们创新性地将原型引入联邦学习,充分利用每个客户端的知识来纠正本地训练。原型被定义为每个类别中表征的均值向量。具体来说,服务器首先通过收集客户端的类别原型获得全局类原型,并将其广播给客户端,作为全局知识来纠正本地训练。然后,客户端使用我们精心设计的本地网络结构和损失函数来约束本地模型的训练,确保本地优化目标与全局最优解保持一致。该方法强制每个样本特征表征向其类别的全局原型靠拢,并远离其他类别的全局原型,从而提高本地网络的分类性能。总的来说,FedProc是一种高效且简洁的联邦学习范式,它从原型基础的对比学习的新视角处理数据non-IID导致的问题。    

本文的贡献如下:

1.提出了一种新的FL框架(FedProc)来解决non-IID数据问题。该框架受原型网络的启发,引入了全局类原型,并利用全局类原型来校正本地训练,使本地优化目标与全局优化目标保持一致。这一策略显著提高了联邦学习的性能,尤其是在客户端之间的数据分布为non-IID时。

2.设计了一种通用的混合本地网络结构,充分利用由全局类原型提供的基础知识。这种网络结构量身定制了一种结合交叉熵损失和全局原型对比损失的混合损失函数。该设计允许联邦学习在连续迭代过程中逐渐从特征学习过渡到分类学习。特征学习使得本地网络在全局类原型的约束下能够学习到更好的特征表征,从而使分类器更容易实施分类任务。

3.理论上,所提出的FedProc在non-IID数据集上的训练具有收敛保证。通过在不同数据集上实施FedProc并进行广泛的实验,实证研究与理论阐述相呼应,显示FedProc在准确性和计算效率方面显著优于现有的最先进技术。


图 1. CIFAR-10 上隐藏向量的 T-SNE 可视化。图 (a) 和 (b) 分别显示了客户端 C1 和 C2 处的 SOLO 表示。图(c)显示了全局分布。图 (d) 和 (e) 分别显示了客户端 C1 和 C2 处的 FedProc 表示。SOLO:一种基线方法,每个客户端仅通过输入其本地数据来训练模型,而不进行联合学习。

动机  

本文讨论了纠正本地训练的动机观察。首先研究了训练过程中本地网络架构隐藏层的特征分布。首先,命名SOLO为基线方法,其中每个客户端仅通过输入其本地数据而不进行联邦学习来训练模型。具体而言,使用SOLO根据不同客户端的本地数据训练模型,这些本地数据都是CIFAR-10的偏斜子集。然后,使用t-SNE来可视化这些来自两个不同客户端C1和C2的本地数据的隐藏层特征,如图1(a)和1(b)所示。观察发现,两个客户端的图像特征分布在簇中心和聚类程度方面有很大差异,且与图1(c)中显示的全局分布高度不同。结果,每个客户端的本地目标与全局最优解不一致,这会严重影响联邦学习的准确性。FedProc基于一个直观的想法解决上述问题:原型可以作为全局知识来纠正联邦学习中的本地训练。这一想法使客户端能够将同一类别的样本拉向其类别的全局原型,并远离其他类别的全局原型,从而使每个客户端的本地目标与全局最优解保持一致。为了展示这一想法的有效性,在上述客户端C1和C2的本地数据上运行FedProc,并展示图像的特征分布,如图1(d)和1(e)所示。发现两个客户端同一类别的点被约束在以全局类原型为中心的同一域中。此外,客户端C1和C2中的点的分布都与图1(c)中显示的全局分布相匹配。    

FedProc的总体框架  

图2. FedProc中本地网络架构。特征提取网络(包括基础编码器和投影头)提取表征,用于计算全局原型对比损失。通过输入表征,输出层预测各类的概率,用于计算交叉熵损失。引入系数在本地训练期间调整两种损失函数的权重。

FedProc框架的核心设计体现在本地训练阶段,包括本地网络架构和损失函数,以促进更好的表征学习。图2描述了所提出的本地网络架构概览。本地网络由三个模块组成:基础编码器、投影头和输出层。首先,基础编码器从输入中提取表征。其次,投影头将表征映射为向量表征,该向量用于计算全局原型对比损失。注意,使用带有一个隐藏层的多层感知器(MLP)来实现投影头,这有助于提高其前一层的表征能力。最后,通过输入图像表征,输出层(即单个线性层)预测类别的对数概率,用于计算交叉熵损失。模型权重表示为表示特征提取网络的权重,由基础编码器和投影头组成,表示输出层的权重。用表示整个网络。相应地,(带有可学习参数表示特征提取网络,(带有可学习参数)表示输出层网络。即是输入的映射表征,是表征的预测向量。    

本地网络的损失函数由两部分组成。第一部分是提出的全局原型对比损失项。该项使本地网络学习具有类内紧凑性和类间可分性的嵌入空间。第二部分是典型的交叉熵损失,用于分类器学习,可从上述嵌入空间中获益。受课程学习的启发,引入一个系数在本地训练阶段调整两个项的权重。具体来说,总通信轮次为,当前轮次为通过计算。网络的最终损失函数为。这种方法使得本地学习逐渐从特征学习过渡到分类器学习,随着轮次的增加。在本地训练中,每个客户端使用随机梯度下降(SGD)根据其本地训练数据更新模型。本地目标是最小化

为了使全局类原型作为知识纠正每个客户端的本地训练,提出了全局原型对比损失迫使客户端的每个样本靠近其类的全局原型,并远离其他类的全局原型。定义全局原型对比损失为

其中是余弦相似度,是输入时特征提取网络提取的表征。表示第轮属于类的全局表征。

实验与验证  

实验覆盖了三个标准数据集:CIFAR-10(包含60,000张图像,10个类别)、CIFAR-100(包含60,000张图像,100个类别)和Tiny-ImageNet(包含100,000张图像,200个类别)。为确保公平比较,所有方法均使用相同的本地网络模块。在具体实现上,对于CIFAR-10数据集采用简单的CNN模型作为基础编码器,而CIFAR-100和TinyImageNet数据集则采用ResNet-50作为基础编码器。non-IID数据分布通过Dirichlet分布生成,根据集中参数,将类别的示例按的百分比分配给客户端。在实验中,网络模型的参数、学习率、批次大小、集中参数及本地训练周期数保持不变,唯一变化的是联邦学习的框架设置。

实验结果1. 总体性能对比

表1. FedProc 和其他方法在测试数据集上的 top-1 准确性

表1列出了所有方法的Top-1测试准确率。SOLO显示了所有方法中最差的结果,这突显了联邦学习的优势。FedAvg是第一个使用交叉熵损失来训练本地网络的FL框架,可以视为FL的基线。其他FL框架,包括SCAFFOLD、FedProx和MOON,都旨在解决非IID数据问题。由于FedAvg没有针对非IID设置进行任何优化,其准确率在所有FL算法中相对较低。此外,尽管SCAFFOLD声称能在CIFAR-10上提高准确率,但在CIFAR-100和Tiny-ImageNet上的表现却远不及FedAvg。FedProx的准确率与FedAvg非常接近,这是因为FedProx只对FedAvg进行了微小的改动,使用了重参数化技术。MOON提出了模型对比联邦学习,比较不同模型学习到的表征。这种方法在不同数据集上的准确率比FedAvg高出1.3%到3%。至于FedProc方法,可以观察到其在所有数据集上的准确率始终优于其他方法。具体来说,该方法比MOON在不同数据集上高出1.6%到7.9%。这表明该方法(原型对比联邦学习)能有效纠正本地训练。

实验结果2. 不同参数对准确率的影响

图 4 描述了每轮训练的准确率。从训练结束的结果可以看出,FedProc 取得了最好的性能。此外,图4中的曲线表明FedProc以缓慢的收敛速度为代价提高了精度。这是因为训练开始时特征学习很重要,然后分类器学习逐渐取代特征学习。换句话说,FedProc 在训练过程的早期学习更强的表示,这可以帮助以后的分类器学习。

图 5 描述了在整个训练过程中,随着局部 epoch 数量的增加,准确率也随之提高。我们得出结论,当局部历元的数量为 E = 10 时,大多数方法的准确率是最大的。这是因为当 E 很小时,局部网络无法得到充分的训练。然而,当 E > 10 时,对倾斜数据的局部训练可能会过度拟合,从而导致全局模型的准确性下降。    

表2.  时的 top-1 测试精度。

在CIFAR-100数据集上,通过调整Dirichlet分布的集中参数β来研究数据异质性对准确率的影响。较小的β值表示更偏斜的数据分布。表2的结果显示,FedProc在所有不平衡水平上始终实现最佳准确率。具体来说,当时,FedProc的准确率比MOON高出7.6%。当数据分布高度异质()时,FedProc的准确率分别比MOON高出7.1%和4.9%。这一结果验证了FedProc的鲁棒性,其在不同分布水平下均表现良好。FedProc的高性能得益于引入全局类原型,这些原型作为全局信息,允许本地训练全面学习全局知识。相比之下,其他方法没有充分利用潜在知识。  

表3. 不同类型局部目标损失的 top-1 精度。

在FedProc框架中,通过调整系数来改变本地训练中特征学习和分类器学习的权重。为展示该方法的优越性,设计了一种灵感来自双阶段工作的两阶段联邦学习方法。在第一阶段,该方法使用损失训练特征,然后在第二阶段固定特征训练分类器。如表6所示,该方法的准确率明显高于两阶段训练方法。这是因为两阶段训练破坏了特征学习与分类器学习之间的兼容性。为验证设置方法的有效性,固定重新运行FedProc。显然,这种设置的结果较之变化设置的结果更差。当时,FedProc在训练初期能学习到更好的表征,从而在后期具有更强的分类能力。

表4. 在 CIFAR-100 上不同客户端数量 (m) 和不同通信轮数 (T ) 下的 top-1 测试精度。

为了演示FedProc的可扩展性,在CIFAR-100上进行了大量客户端的实验。客户端数量分别设置为{50, 100, 150},相应的采样率分别设置为{1, 0.2, 0.1}。请注意,意味着每轮从100个客户中随机选择20个客户参加训练。表4和图6的结果显示,FedProc具有出色的可扩展性,其准确率远远高于其他方法。特别是在轮数且客户端数量时,该方法的准确率比MOON高出10%。FedProc的出色可扩展性得益于原型对比学习的引入。这一改进确保了每个客户端的局部目标与全局最优保持一致,从而使得FedProc的性能不随客户端数量增加而受到影响。    

总结  

本文提出了原型对比联邦学习(FedProc),这是一种简单有效的联邦学习框架,用于解决非独立同分布数据问题。FedProc引入类原型作为全局知识来纠正联邦学习中的局部训练。从技术上讲,本文设计了局部网络架构和全局原型对比损失,使局部目标与全局最优值一致,从而产生全局模型良好的分类性能。对多个数据集的大量实验证明了 FedProc 在非 IID 数据上的优势。    

推荐阅读:

基于因果路径的层次图卷积注意力网络在复杂机电系统故障检测中的应用
考虑大规模电池储能热致事故的可再生能源系统可靠性评估
用于预测和健康管理的类ChatGPT大型基础模型:综述和路线图
基于动态贝叶斯网络和数字孪生的水下控制模块可靠性分析
用于锂电池参数识别的分类器辅助贝叶斯优化方法
考虑非线性能耗模型的多路电动公交线路调度优化
基于数据驱动与迁移堆叠的锂离子电池SOH估计方法
在混合操作条件下使用物理驱动的机器学习进行原位电池寿命预测
虚拟断层扫描技术:基于机器学习支持测量的加工过程阶段分割新方法
一种用于动态工作条件下锂离子电池多状态估计的CNN-SAM-LSTM混合神经网络

学术人人
传播科学与学术研究动态,发布学术领域重要研究成果。 重点推广可靠性系统工程(包括可靠性、维修性、保障性、测试性、安全性和环境适应性)理论研究成果,传播相关知识。
 最新文章