论文介绍
题目:Parametric Contrastive Learning
论文地址:https://arxiv.org/abs/2107.12028
QQ深度学习交流群:994264161
扫描下方二维码,加入深度学习论文指南星球!
加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务
创新点
提出了参数化对比学习(PaCo):针对长尾分布数据中的不平衡问题,论文提出了一种新的参数化对比学习方法。PaCo 引入了一组可学习的类中心,旨在通过优化的方式实现类别重平衡。
理论分析与验证:通过理论分析指出,传统的监督对比损失在不平衡数据集上倾向于高频类,导致模型性能的偏差。论文证明了 PaCo 损失可以在训练过程中自适应地增强同类样本的聚集强度,特别是在学习困难样本时。
实验验证了 PaCo 在多个数据集上的有效性:在长尾版本的 CIFAR、ImageNet、Places 和 iNaturalist 数据集上,PaCo 在长尾识别任务中实现了新的最先进性能。即使在平衡数据集上(如完整的 ImageNet 和 CIFAR),PaCo 也表现出了优于传统监督对比学习的效果。
引入中心学习重平衡策略:将平衡 Softmax(Balanced Softmax)融入到中心学习中,从而进一步改进了对长尾问题的处理。
与现有方法的对比:在推理时间、模型复杂度以及性能(如 ImageNet-LT 和 iNaturalist 数据集的分类准确率)上,PaCo 优于现有的监督对比学习方法以及其他基于模型重采样和权重重平衡的方法。
方法
整体架构
PaCo 模型基于 MoCo 对比学习框架,结合参数化类中心设计,将输入样本通过查询网络和动量网络生成特征表示,与动量队列中的负样本及可学习的类中心进行对比,通过 PaCo 损失实现样本间和样本-类中心间的优化。模型通过类中心的动态调整平衡长尾数据中的类别分布,并结合 Balanced Softmax 改善类中心学习,最终在长尾和平衡数据集上均实现了优异性能。
1. 特征提取模块
骨干网络(Backbone Network):使用了标准的卷积神经网络(如 ResNet 和 ResNeXt)作为特征提取模块,用于生成输入样本的特征表示。
2. 对比学习框架
对比学习的基础:基于 MoCo(Momentum Contrastive Learning)的架构,包含以下两个核心部分:
查询网络(Query Network):用于生成查询样本的特征表示。
动量网络(Momentum Network):通过动量更新方式生成负样本特征池。
双视图输入:对输入图像进行两种不同的随机数据增强(RandAugment 和 SimAugment),并将增强后的图像分别输入查询网络和动量网络,以生成查询特征和关键特征。
3. 参数化类中心模块
引入类中心(Class-wise Learnable Centers):为每个类别引入一个可学习的类中心,这些类中心用于捕捉类别的全局分布特征。
类中心的参数化通过对比学习优化,类中心的分布动态调整以适应长尾分布的样本特性。
学习目标:类中心帮助调整类别间的对比损失,从而更好地处理低频类别的学习问题。
4. 损失函数设计
PaCo损失:在监督对比损失的基础上加入类中心,构成 PaCo 的核心损失函数:
样本-样本对比损失:增强同类样本的聚集。
样本-类中心对比损失:引导样本靠近其对应类别的类中心,平衡长尾类别的影响。
中心学习的重平衡策略:结合 Balanced Softmax 对类中心的学习进行权重调整,进一步缓解类别不平衡问题。
5. 动量队列(Momentum Queue)
用于存储大量的负样本特征,以提高对比学习效率。
动量队列中的负样本特征通过动量网络持续更新,确保负样本的多样性和新鲜度。
6. 训练流程
前向传播:输入样本经过骨干网络提取特征,并通过 PaCo 模块计算对比损失。
优化过程:通过梯度下降优化类中心、查询网络和动量网络的参数,动态调整样本与类中心的分布。
即插即用模块作用
DPTAM 作为一个即插即用模块:
缓解类别不平衡问题:通过参数化类中心,提升低频类别样本的学习效果,从而增强模型在长尾数据上的性能。
增强难例学习能力:PaCoLoss 动态调整对比学习强度,特别对“难例”样本表现优异,使其在特征空间内更接近对应的类中心。
提高泛化能力:PaCoLoss 的设计同时适用于长尾分布数据和平衡数据,提高模型在多个任务和数据分布下的泛化性能。
对比学习的优化扩展:将类中心引入到对比学习框架中,改善样本-样本对比的同时,补充样本-类中心的对比,提高表示学习质量。
消融实验结果
在 ImageNet-LT 数据集上比较了交叉熵损失(Cross-Entropy)、监督对比损失(SupCon)和引入 PaCo 损失后的性能。
消融内容:验证 PaCo 损失是否优于传统损失函数。
结果说明:PaCo 在低频类别(Few)、中频类别(Medium)以及整体性能(All)上均显著优于 SupCon 和交叉熵,说明其对低频类别的处理更有效。
表2(Table 2):
比较了不同数据增强策略对 PaCo 在 ImageNet-LT 数据集上性能的影响。
消融内容:SimAugment、RandAugment 及其组合策略的效果对比。
结果说明:结合 SimAugment 和 RandAugment 的增强策略(策略3)性能最佳,说明类中心学习需要更强的数据增强以提升模型泛化性。
即插即用模块
import torch
import torch.nn as nn
class PaCoLoss(nn.Module):
def __init__(self, alpha=1.0, beta=1.0, gamma=0.0, supt=1.0, temperature=1.0, base_temperature=None, K=128,
num_classes=1000):
super(PaCoLoss, self).__init__()
self.temperature = temperature
self.base_temperature = temperature if base_temperature is None else base_temperature
self.K = K
self.alpha = alpha
self.beta = beta
self.gamma = gamma
self.supt = supt
self.num_classes = num_classes
def forward(self, features, labels=None, sup_logits=None):
device = torch.device('cuda' if features.is_cuda else 'cpu')
batch_size = features.shape[0]
labels = labels.contiguous().view(-1, 1)
mask = torch.eq(labels[:batch_size], labels.T).float().to(device)
# compute logits using complete features tensor
anchor_dot_contrast = torch.div(
torch.matmul(features, features.T),
self.temperature)
# add supervised logits
anchor_dot_contrast = torch.cat(((sup_logits) / self.supt, anchor_dot_contrast), dim=1)
# for numerical stability
logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
logits = anchor_dot_contrast - logits_max.detach()
# mask-out self-contrast cases
logits_mask = torch.scatter(
torch.ones_like(mask),
1,
torch.arange(batch_size).view(-1, 1).to(device),
0
)
mask = mask * logits_mask
# add ground truth
one_hot_label = torch.nn.functional.one_hot(labels[:batch_size, ].view(-1, ), num_classes=self.num_classes).to(
torch.float32)
mask = torch.cat((one_hot_label * self.beta, mask * self.alpha), dim=1)
# compute log_prob
logits_mask = torch.cat((torch.ones(batch_size, self.num_classes).to(device), self.gamma * logits_mask), dim=1)
exp_logits = torch.exp(logits) * logits_mask
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-12)
# compute mean of log-likelihood over positive
mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)
# loss
loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
loss = loss.mean()
return loss
if __name__ == '__main__':
block = PaCoLoss()
input_features = torch.rand(64, 64)
labels = torch.randint(0, 10, (64,))
sup_logits = torch.rand(64, 1000)
loss = block(input_features, labels=labels, sup_logits=sup_logits)
print(input_features.size()) print(loss.item())
便捷下载方式
浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules
更多分析可见原文