在大模型被广泛应用的今天,研究半监督学习还有意义吗?
我们的答案是:有的。即使大模型大行其道,在下游任务微调时,仍然不可避免要遇到标签数据稀少的问题,而半监督学习旨在只有少量标注数据的时候利用大量无标注数据提升模型的泛化性。我们的NeurIPS 2022的工作 USB中已验证了预训练模型在半监督场景下的局限性。现在,我们将从算法创新层面再对半监督进行提升。
近年来,基于阈值的伪标签方法的半监督方法取得了巨大的成功。然而,我们认为现有的方法可能无法有效地利用未标记的数据,因为它们要么 「使用预定义 / 固定阈值」,要么 「使用专门的启发式阈值调整方案」。这将导致模型性能低下和收敛速度慢。在本文中,我们首先 「理论」 分析一个简单的二分类模型,以获得关于理想阈值和模型学习状态之间关系的直觉。基于分析,我们因此提出 「FreeMatch」 来根据模型的学习状态以 「自适应方式」 调整置信度阈值。我们进一步引入自适应类公平正则化惩罚,以鼓励模型在早期训练阶段进行多样化预测。广泛的实验表明FreeMatch的优越性,尤其是当标记数据极其稀少时。
文章已被机器学习顶级会议 「ICLR」 2023录用,其在录用之前就收到了多方关注、有多人索要代码。文章第一作者为微软亚洲研究院实习生、东京工业大学硕士生王一栋,共同第一作者为卡耐基梅隆大学的陈皓。通讯作者为微软亚洲研究院的王晋东。
论文标题:FreeMatch: Self-adaptive Thresholding for Semi-supervised Learning 论文链接: https://arxiv.org/abs/2205.07246 代码链接: https://github.com/microsoft/Semi-supervised-learning
对半监督学习而言,什么是好的阈值?
我们从一个简单的二分类问题出发来分析:好的半监督学习的阈值是怎么样的。
假设真实的数据分布来自两个高斯分布的混合:
再考虑输出概率如下的分类器:
如果我们采用一个固定的阈值, 那么不难证明伪标签 有如下的概率分布;
详细证明过程可以参考论文。
观察上面的公式,我们可以获得一些明显的推论:
首先,不难看出未标注数据的采样率是直接由决定的:越大,伪标签的数量越少。更有趣的是,当时,。这可能导致伪标签分布不均匀从而损害模型表现。 同时,伪标签采用率 随着 变小而下降。换言之,两个类越接近,模型的置信度越低,因此也应相应降低以保证伪标签的分布均匀。
这些推论为我们设计一个自适应阈值提供了如下的启发:
在训练的早期,应该相对较小,以促使伪标签多元化,提升未标注数据的利用率,提升模型收敛速度。 随着训练的进行(变大),较低的阈值会导致确认误差。在理想的情况下,应该随着变大以维持一个稳定的伪标签采用比例。 同时由于类内多样性()以及类邻接 (相对较小),某些类的分类难度要大于其余类,我们应该对每个类设置一个局部阈值。
FreeMatch:自适应阈值方法
我们提出的FreeMatch包含两部分:「自适应阈值」 和 「自适应公平正则化惩罚」。下面分别进行介绍。
自适应阈值 (SAT)
如下图所示,自适应阈值具体可以分为自适应全局阈值、自适应局部阈值。局部阈值旨在以类特定的方式调整全局阈值,以考虑类内多样性和可能的类邻接。
自适应全局阈值
我们根据以下两个原则设计全局阈值。首先,全局阈值应该与模型对未标记数据的置信度相关,反映整体学习状态。此外,全局阈值应在训练期间稳定增加,以确保在训练后期丢弃噪声伪标签。我们将全局阈值 设置为模型对未标记数据的 「平均置信度」,其中 表示第 个时间步(迭代)。
然而,由于未标注数据数量庞大,在每个时间步甚至每个训练时期计算所有未标记数据的置信度将非常耗时。因此,我们将全局置信度估计为每个训练时间步长置信度的指数移动平均值 (EMA)。具体来说,我们将 初始化为 ,其中 表示类数。
具体而言,全局阈值 定义和调整为:
其中 是 EMA 的动量衰减。
自适应局部阈值
我们计算模型对每个类别 的预测的期望,以估计特定于类别的学习状态:
其中 是包含所有 的列表。
最终的阈值自适应调整
整合全局和局部阈值,我们得到最终的自适应阈值 为:
其中 是最大归一化(即 )。
最后,第 次迭代的无监督训练目标 是:
自适应公平正则化惩罚 (SAF)
我们没有使用之前常被使用的类平均先验来惩罚模型(因为真实场景往往不满足类平衡条件),而是使用来自模型预测的滑动平均EMA 作为期望的估计未标记数据的预测分布。
我们优化 和 的交叉熵批处理作为 的估计。
考虑到潜在的伪标签分布可能不均匀,我们建议以自适应的方式调节公平性目标,即通过伪标签的直方图分布对概率的期望进行归一化,以抵消不平衡的负面影响:
与相似, 我们这样计算:
第步的自适应公平正则化惩罚(SAF) 表示如下:
最终模型的训练目标由对标注数据的交叉熵,无监督训练目标和自适应公平正则化惩罚组成。
具体细节可以参考文章内容。
实验
我们进行了详尽的实验,包括在经典benchmark与之前的算法进行对比(Table 1)和ImageNet结果对比(Table2)。为了证明FreeMatch不需要预定义阀值,我们在表十中提供了FixMatch和FlexMatch不同阀值的实验。
从表一,表二和表十可以看出,「FreeMatch有助于减少超参数调整计算或整体训练时间(在别的算法使用最佳选择的阈值情况下,FreeMatch无需预定义阈值即可获得更优异的性能)并且FreeMatch的性能优于任何固定阈值的方法」。
为了更好的理解FreeMatch,我们在图3中分析了FreeMatch在STL-10 40标签的实验中阈值,无标签数据的利用率,和准确率随训练的变化。可以看出,FreeMatch在训练初始阶段自适应的采取了较低的阈值,所以更多的无标签数据参与到学习中。随着模型从无标签数据中学习,阈值快速上升(与dash手动定义相比),使得不准确的无标签数据被筛出,从而达到更准确的利用无标签数据的目的。
我们还在论文中提供了详细的消融实验,感兴趣的读者可以自行查看。
总结
我们提出了FreeMatch方法,该方法提出了自适应阈值和自适应公平性正则化。FreeMatch在各种SSL基准测试中优于其他SOTA算法,尤其是在标注数据极其稀少的情况下。我们认为置信度阈值在SSL中具有很大的潜力。我们希望我们的方法能够激发更多关于最优阈值的研究。