即插即用CSRA残差注意力机制,涨点起飞起飞了!

文摘   2025-01-06 17:20   上海  

论文介绍

题目:Residual Attention: A Simple but Effective Method for Multi-Label Recognition

论文地址:https://arxiv.org/pdf/2108.02456

QQ深度学习交流群:994264161

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • 提出了一个简单但有效的模块——类别特定残差注意力模块(Class-Specific Residual Attention, CSRA)

    • 通过引入类别特定的残差注意力机制,捕获不同类别对象的空间区域。

    • 将全局平均池化与类别特定的空间注意力分数相结合,从而生成类别特定的特征。

  • 大幅度简化了多标签识别的方法

    • 相较于现有复杂的多标签分类模型,CSRA的实现非常简单。

    • 使用仅需几行代码的简单修改(例如结合全局最大池化和全局平均池化),即可显著提升多种预训练模型的性能。

  • 轻量级且计算开销低

    • CSRA模块几乎没有额外的计算成本,可以直接整合到现有模型中。

    • 多头注意力扩展版本允许多个注意力分支同时运行,提高了鲁棒性和分类性能。

  • 通用性强且解释性好

    • 不仅适用于经典的卷积神经网络(如ResNet),也适用于视觉变换器(Vision Transformers, ViT)等新兴架构。

    • 提供了直观的解释和可视化,证明其注意力机制可以准确地定位多类对象。

  • 在多项基准测试数据集上达到了最新的最优性能

    包括VOC2007、VOC2012、MS-COCO和WIDER-Attribute等数据集,均取得了高于现有方法的性能。

方法

整体架构

     这篇论文的模型结构由预训练主干网络(如ResNet或ViT)进行特征提取,然后通过类别特定残差注意力模块(CSRA)生成类别特定的特征,其中全局类别无关特征与类别特定的残差特征相结合,以提高分类精度;多头注意力机制(可选)进一步增强鲁棒性。最终,每个类别的特征通过分类器输出预测得分,使用二元交叉熵损失优化模型,整体结构简单高效,适用于多标签图像分类任务

  • 输入与特征提取

    • 图像首先通过一个预训练的主干网络(如ResNet或Vision Transformer, ViT)进行特征提取。

    • 提取的特征张量表示为xRd×h×w,其中 dd 是通道数,hh 和ww 是特征图的高度和宽度。

  • 类别特定残差注意力模块(CSRA)

    • 将类别特定特征aia_i 和全局类别无关的特征gg 进行融合,得到类别特定的残差特征:fi=g+λai,f_i = g + \lambda a_i,其中gg 是通过全局平均池化计算得到的全局特征。

    • 利用注意力分数对特征进行加权组合,生成类别特定的特征向量:ai=k=1Nsikxk.

    • 通过计算每个类别在各空间位置的注意力分数sijs_{ij},公式如下:sij=exp(TxjTmi)k=1Nexp(TxkTmi),s_{ij} = \frac{\exp(T x_j^T m_i)}{\sum_{k=1}^{N} \exp(T x_k^T m_i)},其中TT 为控制分数锐度的超参数,mim_i 是第ii 类别的分类器权重,xjx_j 是第jj 个位置的特征向量。

  • 多头注意力机制(可选)

    • 在多头注意力版本中,模型对同一特征张量xx使用多个注意力分支,每个分支使用不同的温度TT 来计算注意力分数。

    • 不同分支的输出会被融合,得到最终的分类结果。

  • 分类器

    • 每个类别特定的残差特征fif_i 会输入到一个线性分类器,计算最终的分类得分y^i\hat{y}_iy^i=miTfi.

  • 损失函数

    • 模型使用经典的二元交叉熵(Binary Cross Entropy, BCE)损失函数来优化多标签分类任务。

  • 输出

    • 输出为每个类别的预测得分,表示图像中对应类别存在的概率。

即插即用模块作用

CSRA 作为一个即插即用模块

  • 提升分类精度通过结合类别无关的全局特征和类别特定的残差注意力特征,CSRA能够捕获目标类别的细粒度空间信息,显著提高分类模型的精度。


  • 增强模型的鲁棒性在处理复杂场景(如多目标、目标遮挡或目标尺寸变化)时,CSRA能够通过类别特定注意力聚焦于关键区域,减少背景干扰,提高模型在实际场景中的鲁棒性。


  • 通用性强CSRA模块可以轻松集成到多种主干网络(如ResNet、ViT)中,适配不同的任务和模型架构,展示出良好的兼容性和扩展性。

  • 节省计算资源CSRA模块结构简单,计算开销极低,可在不增加模型复杂度的情况下提供一致的性能增益,适合部署在资源受限的场景。


  • 增强模型可解释性CSRA通过可视化注意力分数,能够直观展示模型如何关注不同类别的区域,为任务结果提供了更强的解释性,尤其适合需要结果可追溯性的场景(如医疗和法律领域)。

消融实验结果

  • 内容:分析了全局类别无关特征(全局平均池化)和类别特定特征(空间池化)在多标签分类任务中的贡献。

  • 结果

    • 仅使用全局平均池化(average pooling)时,mAP为82.1%。

    • 仅使用空间池化(spatial pooling)时,mAP提升至84.2%。

    • 结合两者(即CSRA模块),mAP进一步提升至85.3%。

  • 说明:类别特定特征(空间池化)在多标签分类中比全局特征更重要,而结合两者能够获得最佳性能。


    • 内容:分析了多头注意力机制中注意力分支(head)的数量对模型性能的影响。

    • 结果

      • 对于VIT-L16主干网络,mAP从H=1的85.8%逐渐提升到H=8的86.5%。

      • 对于ResNet-cut主干网络,mAP从H=1的85.3%提升到H=6的85.6%,但H=8略微下降至85.5%。

    • 说明:增加注意力分支的数量可以提升性能,但过多的分支可能导致边际收益减少甚至性能下降。

即插即用模块

import numpy as np
import torch
from torch import nn
from torch.nn import init

# 论文地址:https://arxiv.org/pdf/2108.02456
# 论文:Residual Attention: A Simple but Effective Method for Multi-Label Recognition



class ResidualAttention(nn.Module):

    def __init__(self, channel=512 , num_class=1000,la=0.2):
        super().__init__()
        self.la=la
        self.fc=nn.Conv2d(in_channels=channel,out_channels=num_class,kernel_size=1,stride=1,bias=False)

    def forward(self, x):
        b,c,h,w=x.shape
        y_raw=self.fc(x).flatten(2) #b,num_class,hxw
        y_avg=torch.mean(y_raw,dim=2) #b,num_class
        y_max=torch.max(y_raw,dim=2)[0] #b,num_class
        score=y_avg+self.la*y_max
        return score

        


if __name__ == '__main__':
    input=torch.randn(50,512,7,7)
    resatt = ResidualAttention(channel=512,num_class=1000,la=0.2)
    output=resatt(input)    print(output.shape)

便捷下载方式

浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

更多分析可见原文


ai缝合大王
聚焦AI前沿,分享相关技术、论文,研究生自救指南
 最新文章