图解SimCLR对比学习框架

文摘   科技   2024-07-11 07:30   江苏  
点击蓝字
 
关注我们










01


引言



近年来,人们提出了许多用于学习图像特征表示的自监督学习方法,每种方法都比以前的更好。但是,他们的表现仍然低于受监督的同行。


当Chen等人在他们的研究论文中提出一个新框架SimCLR时,这种情况发生了变化。

论文链接:https://arxiv.org/abs/2002.05709

SimCLR不仅改进了以前最先进的自监督学习方法,而且在使用更强大的主干架构时,在ImageNet分类上击败了监督学习的方法。

在本文中,我将使用图表解释研究论文中提出的该框架的关键思想。






02


举个栗子


当我还是个小孩子的时候,我记得我们必须在教科书中解决类似这样的难题。

孩子解决这个问题的方法是看左侧动物的图片,知道这是一只猫,然后在右侧搜索一只猫。过程如下:

这样的练习是为了让孩子能够识别一个物体并将其与其他物体进行对比。我们能同样地教会机器吗?

事实证明,我们可以通过一种称为对比学习的技术。它试图教机器区分相似和不同的事物。





03


 问题描述


为了对上述问题进行建模,我们需要一台机器而不是一个孩子,此时我们需要做 3 件事:

  • 相似和不同的图像
我们需要相似的图像对和不同的图像对来训练模型。


一般的监督学习需要人类手动标注这些图像对。为了实现这一点的自动化,我们可以利用自监督学习技术。但是我们如何来实现呢?


  • 获取图像特征

我们需要一些机制确保机器能够理解图像的表示。

  • 量化图像相似度
我们需要某种机制来计算两张图像的相似性。





04


  SimCLR核心思想


论文提出了一种名为SimCLR的框架,用于以自监督的方式对上述问题进行建模。它将对比学习的概念与一些新颖的想法相结合,可以在没有人类监督的情况下学习视觉特征表示。

整体框架如下:

SimCLR框架的思想非常简单。拍摄一张图像,并对其应用随机变换,以获得一对两张增广后的图像𝑥i𝑥j ,该配对中的两个图像都通过编码器以获取图像特征表示。然后应用非线性全连接层来得到最后的特征表示 z。该任务训练的目标是最大化对于相同图像的这两种特征表示zi  z之间的相似性。






05


 步骤一


现在让我们通过一个示例来探索 SimCLR 框架的各个组件。假设我们有一个包含数百万张未标记图像的训练数据集。
  • 数据增强
首先,我们从原始图像中生成大小为 N 的批次。为简单起见,我们采用一批大小为 N = 2 的批次。在论文中,他们使用了 8192 的大批量。

本文定义了一个随机变换函数 T,该函数获取图像并应用以下数据增强的不同组合。

random (crop + flip + color jitter + grayscale)

对于每个批次中的每张图像,应用随机变换函数来获取一对 2 张图像。因此,对于batch= 2 的输入设置,我们可以得到 2*N = 2*2 = 4 个图像总数。





06


 步骤二


接着,成对中的每个增强图像都通过编码器以获得对应的图像表示。使用的编码器是通用的,可以与其他架构替换。下面显示的两个编码器具有共享的权重,我们得到向量ℎi和ℎj, 如下所示:

在论文中,作者使用ResNet-50作为特征提取的主干网络,输出的特征向量h的维度为2048维。






07


步骤三

数据增强后的两张图像经过主干特征提取网络后,获得对应的特征表示ij ,接着对应的特征表示经过一系列非线性全连接层后,得到最终的特征表示 zizj , 论文中的这一步描述为g(.) , 又被成为projection head





08


步骤四


在上一步骤中,我们获取了数据增强后每张图像的特征表示:

接着我们来定义特征表示的相似度,如下:

我们定义相似度计算公式如下:

上述公式中,相关说明如下:

  • T 是可调节的控制参数。它可以缩放输入并扩大余弦相似性的范围

使用上述公式计算一个batch中每个增强图像之间的成对余弦相似度。如下图所示,在理想情况下,猫和其增强图像之间的相似性较高,而猫和大象图像之间的相似度较低。








09

步骤五


训练过程中,SimCLR使用的损失函数为NT-Xent loss我们来具体进行讲解。


首先,一个接一个地获取batch中的增广图像对。


接下来,我们应用softmax函数来获得这两个图像相似的概率。


该softmax计算等效于获得第二张增强后的猫图像与该对中的第一张猫图像最相似的概率。这里,该批次中的所有剩余图像都被采样为不同图像(负对)。因此,我们不需要以前的方法(如MoCo)所需的专门架构、或队列。

然后,通过取上述结果的对数的负数来计算一对的损耗。这个公式就是噪声对比估计(NCE)损失。

如果同一对的图像的位置发生互换,我们需要再次计算其损失:

最后,我们计算batch=2中所有配对的损耗,并计算平均值作为结果。

基于上述损失函数,编码器和投影头表示随着时间的推移而改进,并且所获得的特征表示将相似的图像放置在空间中更接近的位置。







10


用于下游任务


一旦在对比学习任务上训练了SimCLR模型,它就可以用于迁移学习。为此,使用来自编码器的特征表示,而不是从投影头获得的特征表示。这些表示可以用于下游任务,如ImageNet Classification。






11


模型评价


SimCLR模型的性能优于ImageNet上以前的自监督的方法。下图显示了在ImageNet上使用不同自监督方法学习的表示上训练的线性分类器的Top1分类精度。灰叉由ResNet50监督学习得到,SimCLR以粗体显示。

  • 在ImageNet ILSVRC-2012上,它实现了76.5%的Top1准确率,比以前的SOTA自监督方法CPC提高了7%,与监督ResNet50不相上下。

  • 当在1%的标签上训练时,它实现了85.8%的Top5准确率,比AlexNet少100倍的标签。




12



代码


论文作者在Tensorflow中对SimCLR的官方实现可以在GitHub上获得。

官方代码链接:https://github.com/google-research/simclr


他们还提供了使用Tensorflow Hub的ResNet50架构的1x、2x和3x变体的预训练模型。

下载链接:https://github.com/google-research/simclr#pre-trained-models-for-simclrv1


此外,网上还有各种非官方的SimCLR PyTorch实现,这些实现已经在CIFAR-10和STL-10等小型数据集上进行了测试。

链接1:https://github.com/leftthomas/SimCLR

链接2:https://github.com/Spijkervet/SimCLR






13



总结



总之,SimCLR为在这个方向上进行进一步的研究提供了一个强大的框架,并改善了计算机视觉的自监督学习状态。



您学废了吗?







点击上方小卡片关注我




添加个人微信,进专属粉丝群!


AI算法之道
一个专注于深度学习、计算机视觉和自动驾驶感知算法的公众号,涵盖视觉CV、神经网络、模式识别等方面,包括相应的硬件和软件配置,以及开源项目等。
 最新文章