ICML 2024 |多模态最新进展!单模态增益多模态学习,解决多模态和单模态学习目标梯度冲突问题

文摘   2024-07-05 08:02   英国  

论文链接:

https://arxiv.org/pdf/2405.17730

代码链接:

https://github.com/GeWu-Lab/MMPareto_ICML2024

简介

具有针对性的单模态学习目标的多模态学习方法在缓解多模态学习不平衡问题方面表现出了卓越的功效。然而,之前被忽视的多模态和单模态学习目标之间的梯度冲突,这可能会误导单模态编码器优化。为了很好地减少这些冲突,作者观察了多模态损失和单模态损失之间的差异,其中更容易学习的多模态损失的梯度幅度和协方差都小于单模态损失。利用这一特性,文中分析了多模态场景下的 Pareto 积分,并提出了 MMPareto 算法,该算法可以确保最终梯度的方向对所有学习目标都是通用的,并增强幅度以提高泛化能力,从而提供单模态辅助。

研究动机

多模态学习过程中存在模态不均衡问题,即大多数多模态模型不能很好地联合利用所有模态,对每种模态的利用不平衡。此外,在多任务场景下,模型优化中存在先前被忽视的风险,这也可能会限制模型的能力。不可否认,单模态学习目标有效地增强了相应模态的学习。同时,单模态编码器参数的优化受到多模态联合学习目标和自身单模态学习目标的影响。这需要同时最小化两个学习目标,但通常不存在一组可以满足该目标的参数。因此,这些多模态和单模态学习目标在优化过程中可能会发生冲突。在图 1a 中,以广泛使用的 Kinetics Sounds 数据集上的视频编码器为例。可看出负余弦相似度表明多模态和单模态梯度在优化过程中确实存在方向冲突。特别是,早期训练阶段的这些冲突可能会严重损害模型能力,从而导致主要的多模态学习可能会受到干扰。

论文贡献

(1)提出了多模态帕累托(MMPareto)算法,该算法在梯度积分时分别考虑方向和大小。它确保了无害的单模态辅助,其中最终梯度的方向是所有学习目标的共同方向,并增强了泛化能力。

(2)对该方法的收敛性进行了分析。基于多种类型数据集的结果,该方法有效缓解了不平衡的多模态学习问题,并且可以很好地配备具有密集跨模态交互的模型,例如多模态 Transformers 。单模态性能甚至优于单独训练的单模态模型,这是以前很少实现的。

(3)验证了所提出的方法还可以扩展到任务难度存在明显差异的多任务情况,表明其可扩展性。

MMPareto方法

类似多任务的多模态框架

在多模态学习中,模型有望通过整合多种模态的信息来产生正确的预测。因此,经常存在多模态联合损失,需要融合多模态特征进行预测。然而,仅利用这种联合损失来一起优化所有模态可能会导致优化过程由一种模态主导,而导致其他模态严重优化不足。为了克服这种不平衡的多模态学习问题,引入针对每种模态优化的单模态损失被广泛使用,并被证明可以有效缓解这种不平衡的多模态学习问题。在这些场景中,损失函数为:

其中 是多模态联合损失, 是模态 k 的单模态损失。n 是模态的数量。我们主要考虑多模态判别任务,并且所有损失都是交叉熵损失函数。这种类似多任务的多模态框架如图 2 的左侧部分所示。

SGD 属性和假设

多模态框架同时具有多模态损失函数和单模态损失函数。对于,模态k的单模态编码器参数、迭代t处的的梯度满足:

其中是批次采样协方差。在多模态情况下,单模态损失仅接收基于相应模态数据的预测。相比之下,多模态损失通过来自所有模态数据的更充分信息进行优化,使其更容易训练。经验证,多模态损失比单模态损失收敛速度更快,训练误差更低.

基于之前的研究和作者在文中的验证,可提出假设1:

假设1. 在多任务多模态情况下,对于共享单模态编码器,单模态损失的梯度往往比易于学习的多模态损失具有更大的幅度和更大的批量采样协方差。

多模态学习中的帕累托积分

在多模态情况下,多模态损失和单模态损失紧密相关,但它们的梯度仍然可能存在冲突,如图1a。因此,如何很好地整合是需要解决的问题。这符合多任务学习中帕累托方法的思想。在帕累托方法中,在每次迭代时,梯度被分配不同的权重,加权组合是最终的梯度,它可以提供有利于所有学习目标的下降方向。最后,参数可以收敛到权衡状态,即帕累托最优,其中任何目标都不能在不损害任何其他目标的情况下推进。将帕累托积分引入多模态框架是很自然的,避免了多模态和单模态梯度之间的冲突。对于模态 k,帕累托算法被公式化来求解:

其中表示L2范数。为了简洁起见,在某些部分将模态 k 表示为。这个问题等价于寻找梯度向量族的凸包中的最小范数。帕累托最优的必要条件是这个优化问题的最小范数为 0,并且相应的参数是帕累托平稳,或者它可以提供所有学习共同的下降方向目标。

多模态帕累托算法

基于以上分析,传统的帕累托方法在多模态学习中可能会导致极小值,进而削弱模型泛化能力。文中提出了多模态帕累托(MMPareto)算法,分别考虑冲突情况和非冲突情况。整体算法如图2所示。文中以模态k的编码器为例,所有模态的编码器都遵循相同的积分。为了简洁起见,还省略了

非冲突情况 首先考虑cos β ≥ 0 的情况。在这种情况下, 之间的余弦相似度为正。对于方向,梯度向量族 的任意凸组合对于所有学习目标都是通用的。因此,在这种情况下,在积分过程中指定 2 = 2 = 1 而不是 Pareto 解析解,以增强 SGD 噪声项。通过此设置,最终梯度为 ,噪声项为与传统 Pareto 噪声项相比,强度有所增强。

冲突情况 对于 cos β < 0 的情况,必须找到所有损失的共同方向,并在梯度积分过程中增强 SGD 噪声强度。因此,首先解决Pareto优化问题,得到,这可以提供一个不冲突的方向。此外,为了增强噪声项的强度,增加了最终梯度的大小。以统一基线的大小为基准,在适当的范围内调整:

总体而言,MMPareto 提供了无冲突方向和增强的 SGD 噪声强度,帮助模型收敛到更平坦的最小值并更好地泛化。除此之外,我们还分析了所提出的 MMPareto 方法的收敛性.

实验结果

根据表1,可以得出统一基线可以获得相当可观的性能,甚至可以优于或与这些不平衡的多模态学习方法相媲美。原因可能是单模态损失的引入有效地增强了每种模态的学习,这符合这些比较方法的核心思想。此外,与现有的多模态预测方法相比,MMpareto 方法具有无冲突优化过程,取得了相当大的改进。更重要的是,MMPareto方法同时表现出出色的单模态性能,甚至可以超越单独训练的单模态模型。例如,在 CREMA-D 和 Kinetics Sounds 数据集上,MMPareto 的音频准确性优于纯音频方法。这在之前的研究中是很少实现的。

更详细的内容和实施过程请访问点击👉原文链接

喜欢的话,请别忘记点赞👍➕关注哦~


推荐阅读

NeurIPS 2023|浙大&上海AI Lab&华为联合发表--跨模态泛化的多模态统一表示

CVPR2024—重磅来袭!西工大团队提出通用多模态医学数据表示学习方法!持续自监督学习!

CVPR2023-动态多模态特征融合!模态级分类!融合级语义分割!

浙江大学最新发布!从ChatGPT到WorldGPT-基于多模态LLM的通用世界模型

AAAI2024-南京大学、腾讯联合发表--MmAP:跨域多任务学习的多模态对齐提示

多模态机器学习与大模型
多模态机器学习与大模型 致力于推荐、分享、解读多模态机器学习相关的前沿论文成果,讨论大语言模型先进技术,助力AI研究者进步。 合作交流请+V:Multimodal2024,谢谢❤️
 最新文章