NeurIPS 2024 | 用高斯邻域最小化提升视觉提示词微调在长尾视觉识别上的性能

科技   2024-12-26 12:42   北京  

导读

本文是VCC刘烨同学对论文 Improving Visual Prompt Tuning by Gaussian Neighborhood Minimization for Long-Tailed Visual Recognition 的解读,该工作来自深圳大学可视计算研究中心及光明实验室黄惠教授课题组,和厦门大学、广东工业大学及香港浸会大学联合研究,已被机器学习顶级会议 NeurIPS 2024 录用,同时获得中国发明专利授权和软件著作权登记。

项目主页: 
https://vcc.tech/research/2024/GNM-PT

该工作提出了一种针对长尾问题的训练优化策略,旨在平衡地提升视觉提示词微调对各个类别的泛化能力。此训练优化策略新提出的基于高斯邻域最小化的损失,能够帮助模型在长尾数据上训练时收敛到更平坦的损失极小值点,平衡地提升模型对头类和尾类的泛化能力,并且几乎不引入额外的计算代价。大量实验证明,提出的高斯邻域最小化方法能够使得模型在长尾分布数据上的损失平面更加平坦,且几乎不增加额外的计算开销。该方法有效平衡了模型对头类和尾类的泛化能力,并在多个长尾任务中展现出卓越的性能和效率优势



I


 引言 

从真实世界中采集的数据通常呈现长尾分布,其中少数类别(头部类)拥有丰富的样本,而大量类别(尾类)则仅占据极少的样本。这种不平衡的分布对深度学习模型的训练构成了严重障碍。因此近年来,长尾视觉识别问题引起了广泛关注,并促使研究者提出了许多有效的解决方案。大多数现有方法集中于从头开始训练模型,主要从数据处理、表征能力提升和模型输出修正等角度着手,试图缓解长尾问题。近期,一些研究开始探索在微调预训练模型的基础上进行长尾视觉识别的改进[1]这些方法借助参数有效微调 (PEFT) 技术和更具鲁棒性的预训练模型,取得了良好的性能。然而,即使引入了大规模预训练知识,使用视觉提示词微调 (VPT) [2]等PEFT技术时,模型在尾类上的泛化能力依然远逊于头类。Sharpness-Aware Minimization (SAM) [3]优化器能够使模型在训练过程中收敛到平坦的损失极小值点,从而提高其泛化能力。然而,在长尾数据上应用SAM时,模型优化通常由头类主导,忽略了尾类的贡献。此外,SAM需要计算两次梯度,带来了额外的计算代价。因此,迫切需要一种能够提升模型在长尾数据上泛化能力且计算高效的方法。


本论文提出了一种针对长尾数据分布提升VPT泛化能力的新方法 — Gaussian Neighborhood Minimization Prompt Tuning (GNM-PT)。该方法的核心原理基于Sharpness-Aware Minimization (SAM),通过使损失平面更加平坦来增强模型的泛化能力。SAM优化器在训练模型时,通过最小化当前参数邻域内的最大损失值,使得模型极小值点附近的损失平面更平坦。然而,由于长尾数据中大量头类样本的主导,SAM优化策略使得修正后的梯度方向更偏向于优化头类。为了解决这一问题,GNM-PT提出了一种新的优化策略  Gaussian neighborhood minimization (GNM)。与SAM不同,GNM在优化过程中仅需要计算一次梯度,避免了额外的计算开销。通过最小化损失平面中高斯邻域内采样点的损失,GNM能够使模型收敛到一个平坦且不受头类主导的损失极小值点,从而平衡地提升模型对所有类别的泛化能力。此外,GNM-PT还进一步利用提示词中的信息,增强了分类器的鲁棒性。图1展示了损失平面[4]的可视化结果。图1 (a) 表明,GNM与SAM在效果上相似,能够使损失极小值点附近的损失平面更平坦。图1 (b) 表明,在长尾分布数据上,GNM凸性更好,进一步提高了模型的泛化能力。

图1 CoLOD与渐进式模型简化方法的对比

II


 技术贡献 

本工作主要贡献如下:

  • 研究了预训练模型在长尾问题上的潜力,并且提出了对预训练模型的迫切需求:增强对所有类泛化能力的同时减少计算代价;

  • 提出了一个高效的基于VPT的长尾视觉识别算法:GNM-PT,可以在提升模型泛化能力的同时节省计算开


III


 方法介绍 
GNM-PT方法使用VPT微调预训练的ViT模型[5],利用GNM优化器更新VPT的参数,使模型参数收敛到平坦且不受头类主导的损失极小值点,提升泛化能力;另外,将高水平提示词中的信息融合进ViT最后输出的特征中,增强分类性能。

GNM在长尾分布数据上优化模型时,通过最小化损失平面中高斯邻域内采样点的损失来更新模型参数。首先,从正态分布中采样出一个随机向量  并利用  和高斯邻域半径  生成高斯扰动  如公式所示: 
之后,利用  在长尾分布上计算出当前参数高斯邻域内采样点的损失,并最小化该损失,更新模型参数,其过程如公式所示:
按照上述GNM的优化步骤进行训练,最终参数便可收敛到一个平坦且不受头类主导的损失极小值点。整个GNM优化器的参数更新示意图如图2所示:
图2 参数更新示意图
(  和  分别表示在  轮参数更新时未使用和使用GNM情况下的梯度更新)


使用VPT微调预训练ViT模型时,提示词中也编码了大量与当前任务相关的信息。为了进一步提升分类性能,GNM-PT按照下面公式所示的方式进行提示词信息融合:
将最后一层Transformer block的提示词信息  融合进最后一层输出的  中得到  作为ViT最终输出的特征,再将  送入分类器中进行分类。

IV


 部分结果展示 

为了证明GNM在长尾分布上的优势,我们分别使用SAM和GNM两种优化器,在长尾分布数据上利用GCL[6]损失函数训练模型,并可视化两者的损失平面,结果如图3所示。可以看出,GNM能使模型得到更小的损失值,且损失平面几乎没有波动,有助于提高模型泛化能力。

图3 GCL损失下SAM与GNM损失平面对比
为了验证GNM方式在分类精度和计算效率上的优势,我们在保证其他设置相同的情况下分别使用SAM和GNM两种优化器进行训练,对比优化器的执行时间和得到的模型的分类精度。表1中的结果表明,SAM的计算时间比基线方法超出了1.8倍。相比之下,GNM只增加了不到两秒的计算时间,几乎可忽略不计,同时还能够提升分类精度。
表1 SAM和GNM的执行时间与精度对比

图4中统计了SAM和GNM两种优化器分别对不同类的精度影响。可以看出,在使用GCL损失时,SAM降低了模型对于尾类的性能,而论文提出GNM平衡地提升了模型对所有类的性能,更适合解决长尾问题。
图4 不同类的精度对比

为了展示GNM-PT方法与其他先进的长尾视觉识别算法的性能对比,我们在常见的长尾数据集上进行实验,其结果如表2-4所示。GNM-PT在各个数据集上均展现出了较好的分类性能。
表2 在CIFAR100-LT上的top-1分类精度 (%)

表3 在iNaturalist2018上的top-1分类精度 (%)

表4 在Places-LT上的top-1分类精度 (%)

V


 总结与展望 
在长尾学习中,预训练模型对各个类的泛化能力仍然存在偏差,虽然SAM优化器可以让模型收敛到更加平坦的损失极小值点,提升模型的泛化能力,但是需要额外的前向反向传播,加倍了计算开销,且容易受到头类的主导。基于此问题,本论文提出了GNM-PT方法,该方法中提出的GNM 优化器只需要计算一次梯度,几乎没有额外的计算开销,而且GNM最小化损失平面中高斯邻域内的采样点损失,使得模型最终收敛到一个平坦且不受头类主导的损失极小值点。另外,GNM-PT充分利用了高水平提示词中的信息,进一步增强了分类器的鲁棒性。论文中大量的对比实验和消融实验都证明了GNM-PT的卓越效果。

尽管GNM-PT在处理长尾问题上展现了其有效性,然而也有一些局限。表3和表4中的结果显示,进一步重平衡分类器,才能达到整体更好的分类效果,这牺牲了一点点头类性能。未来将会进一步聚焦于同时提升模型的特征表示能力和分类器的性能,以期增强模型对所有类别的泛化能力。

VI


 思考与讨论 
Q: SAM优化器如何使模型收敛到平坦的损失极小值点?
A: SAM优化器的执行过程分为两步:第一步,在当前参数的邻域内计算出一个扰动;第二步,利用扰动后的参数计算损失,并计算梯度,即扰动后的梯度,利用此梯度更新模型。第一步的扰动是由梯度和扰动半径计算得出,如以下公式所示:
梯度为损失变化最快的方向,该扰动对应到损失平面中也是最陡峭的参数点,参数沿着扰动后的梯度下降能使得此参数周围最陡峭的方向趋于平坦,因此SAM优化器能使模型收敛到平坦的损失极小值点。

Q: GNM优化器相比于SAM优化器,为什么能够节省计算代价,并且更适用于长尾分布数据上的模型训练? 
A: GNM最小化损失平面中高斯邻域的采样点损失,计算该损失时使用的扰动由正态分布中随机采样得来,代替了上述公式中利用梯度计算扰动的方式,因此节省了一次前向反向传播。从上述公式中可以得知,SAM使用的梯度计算扰动的方式会受到大量头类样本的主导,而GNM则是借助随机采样得到,免受头类干扰,对所有类更加公平,因此更适用长尾分布上训练模型的情况。 

以下是开放性问题,欢迎读者朋友留言讨论: 
Q: 如何进一步设计GNM扰动获取方式和梯度更新方式,使其能受到类先验的引导,在不引入额外计算代价的同时更加平衡地优化模型参数,进一步提升GNM鲁棒性?

-- End --


导 读 | 刘烨
审 核 | 李梦柯
编 辑 | 申金、余鑫泉

参考文献

[1] Bowen Dong, Pan Zhou, Shuicheng Yan, and Wangmeng Zuo. LPT: long-tailed prompt tuning for image classification. International Conference on Learning Representations (ICLR). 2023. 

[2] Menglin Jia, Luming Tang, Bor-Chun Chen, Claire Cardie, Serge Belongie, Bharath Hariharan, and Ser-Nam Lim. Visual prompt tuning. European Conference on Computer Vision (ECCV). 709-727, 2022. 

[3] Pierre Foret, Ariel Kleiner, Hossein Mobahi, and Behnam Neyshabur. Sharpness-aware minimization for efficiently improving generalization. International Conference on Learning Representations (ICLR). 2021. 

[4] Hao Li, Zheng Xu, Gavin Taylor, Christoph Studer, and Tom Goldstein. Visualizing the loss landscape of neural nets. Conference and Workshop on Neural Information Processing Systems (NeurIPS). 6391-6401, 2018. 

[5] Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly. An image is worth 16x16 words: Transformers for image recognition at scale. International Conference on Learning Representations (ICLR). 2021. 

[6] Mengke Li, Yiu-ming Cheung, and Yang Lu. Long-tailed visual recognition via gaussian clouded logit adjustment. Conference on Computer Vision and Pattern Recognition (CVPR). 6929-6938, 2022.


arXiv每日学术速递
工作日更新学术速递!官网www.arxivdaily.com。
 最新文章