近年来,机器学习在各个领域展现出了惊人的性能表现,然而,获取高质量的大规模标注数据在实际应用中往往困难重重。本文介绍了一个应对这一挑战的通用框架——从弱监督中学习的通用框架(GLWS)。本文由来自卡耐基梅隆大学、微软研究院、新加坡科技设计大学等机构的研究人员共同完成,展示了一种通过期望最大化(EM)算法学习来自各种弱监督源的通用方法,在十几个弱监督问题中显著提升了模型的可扩展性和性能。
论文标题:A General Framework for Learning from Weak Supervision 论文链接:https://arxiv.org/abs/2402.01922 论文代码:https://github.com/Hhhhhhao/General-Framework-Weak-Supervision
背景介绍:弱监督学习的挑战
弱监督标签在机器学习应用时广泛存在,比如噪音标签(noisy label), 单个数据对应多个标签(partial label/crowdsourcing), 多个数据对应单个标签(multiple-instance learning/label proportion)。在每个不同标签的场景下都有很多方法被提出。然而弱监督学习仍然面临着两个主要挑战:
处理多种弱监督配置的普适性: 过去的传统方法通常需要针对特定形式的弱监督设计特定的解决方案,难以在多种弱监督形式下普遍适用。然而实际应用中非常可能多种弱监督标签共同存在。 现有算法的复杂性导致的可扩展性问题: 过去的方法通常通过过于简单的假设或者过于复杂的模块设计来解决多种弱监督的问题,导致这些方法没办法很好的被大规模的实际应用。
本文提出GLWS,一种基于最大期望算法(Expectation-Maximization, EM)的弱监督学习框架,通过将各种弱监督形式建模为非确定性有限自动机(Non-determinstic Finite Automata, NFA),并结合前向后向(Forward-Backward Algorithm)算法,高效的解决所提出的EM框架。GLWS使得EM计算的时间复杂度从传统方法的二次或阶乘级别降低到了线性级别,并且可以广泛的应用于不同的弱监督场景(14+)。
弱监督分类学习的通用EM框架
我们用 表示一对有准确标签的训练数据; 表示可学习的分类器, 用来预测 .
全监督学习
对于所有标签完整且准确的全监督学习,我们有学习目标:
以及对应的损失函数:
弱监督学习
在实际应用中,我们往往接触不到完整且准确的标签( unkown),能接触到的只有弱监督标签。这里我们把弱监督标签抽象的表示为 ,用来代表不同形式的弱监督信息,比如:
Partial label learning中的多个标签 Multiple instance learning中的标签统计 Label proportion learning中的标签数量统计 对于不同的弱监督标签/信息,我们的优化目标为:
因为 未知以及对 的marginalization需要已知 ,以上优化目标通常只能通过迭代 -- EM算法 -- 来解决:
为了进一步推到基于EM的通用弱监督学习的损失函数,我们把训练数据重新表示为 和 , 。不同种类的弱监督标签可以理解为在 上的已知信息。基于条件概率独立假设 ,我们可以推导基于EM的通用弱监督学习的损失函数为:
注:以上假设对于non-sequential network来说是完全准确的。
GLWS: 高效解决EM弱监督学习
尽管有了通用的弱监督学习的损失函数,可以发现这个损失函数仍然是难以解决的,计算 需要找到 所有当前弱监督信息 满足的可能的标签组合 。对于一些弱监督场景,计算 的复杂度可以高达 或 .
为了解决计算复杂度的问题,我们提出了一个新颖且有趣的角度 -- 非确定性有限自动机(NFA).
非确定性有限自动机(NFA)
基于我们的建模,我们可以把“找到 所有当前弱监督信息 满足的可能的标签组合 ”这个问题表示为一个NFA (详情可见维基百科)。
对于不同的弱监督标签,我们可以用不同的NFA来表示
动态规划算法
有了不同弱监督场景的NFA之后,我们可以进一步基于模型预测的output的线性图和弱监督的NFA来把所有满足弱监督信息的标签也表示为一个线性图:
其中每条trelli就表示一组满足的可能的标签。在所得图上,我们可以采用动态规划算法 -- forward-backward algorithm, 来以线性复杂度计算 , 以高效的解决EM损失函数。
对于其中每个节点,我们可以结合前向和后向来计算:
以上算法可以通过把不同类别都表示为一个二分类问题从而简单的扩展到多分类问题上。
实验结果
我们在CIFAR-10、CIFAR-100、STL-10和ImageNet-100等多个数据集上进行了实验,GLWS在14个弱监督学习任务中都表现出色。例如,在ImageNet-100数据集上,GLWS在部分标签学习任务中的准确率相比之前最好的方法提高了1.28%。这里我们只展示部分结果,更多结果可以查看论文。
算法分析
我们同时对GLWS进行了一些算法层面的分析。
相比于之前的方法,GLWS展现出来稳定的快速收敛。
对于不同, GLWS展现符合预期的线性复杂度。
实践意义
GLWS不仅提高了机器学习模型在弱监督条件下的扩展性和性能,还为实际应用中的大规模部署铺平了道路。代码已开源,可供研究人员和开发者进一步研究和应用。
通过GLWS框架,弱监督学习不再局限于特定的场景,变得更加普遍适用和高效。GLWS的计算复杂度可以进一步被优化,融入NFA minimization和determinization来简化图。GLWS也可以被扩展到其他的sequential的任务中。期待未来更多的研究能够基于此框架以及GLWS和foundation model的交叉。
引用论文:
Wei, Z., Feng, L., Han, B., Liu, T., Niu, G., Zhu, X. and Shen, H.T., 2023, July. A universal unbiased method for classification from aggregate observations. In International Conference on Machine Learning (pp. 36804-36820). PMLR. Shukla, V., Zeng, Z., Ahmed, K. and Van den Broeck, G., 2024. A Unified Approach to Count-Based Weakly Supervised Learning. Advances in Neural Information Processing Systems, 36.