分割注意力模块S2Attention,涨点起飞起飞了!

文摘   2024-12-25 17:20   中国香港  

论文介绍

题目:S2-MLPv2: Improved Spatial-Shift MLP Architecture for Vision

论文地址:https://arxiv.org/pdf/2108.01072

QQ深度学习交流群:994264161

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • 改进的空间偏移操作:相比于原始的S2-MLP方法,S2-MLPv2在通道维度上扩展特征图并分割为多个部分,对不同部分执行不同的空间偏移操作,并利用分割注意力(Split Attention)机制融合这些部分。这种设计显著提升了跨区域信息交流的能力。

  • 金字塔结构的应用:S2-MLPv2采用了较小尺度的图像块并利用分层的金字塔结构,从而提升了对细粒度图像特征的捕获能力,提高了图像识别精度。这种结构使其与其他先进的MLP架构(如Vision Permutator和CycleMLP)保持一致。

  • 对称与非对称偏移的结合:S2-MLPv2引入了两种互补的空间偏移操作(对称和非对称偏移),在对特征图进行偏移时提升了信息混合的多样性和表达能力。

  • 高效的性能表现:在ImageNet-1K数据集上的实验表明,S2-MLPv2在没有自注意力和外部训练数据的情况下,能以中等规模(55M参数量)达到83.6%的Top-1准确率。这一结果在中等规模的MLP模型中处于最先进的水平,并且超越了一些大规模的MLP模型。

  • 计算复杂度与参数量的平衡:尽管S2-MLPv2在某些情况下需要更多的FLOPs(如与GFNet和AS-MLP相比),但其通过简单的MLP结构和分割注意力机制显著减少了模型的复杂性,在实际部署中具有竞争力。

方法

整体架构

     S2-MLPv2模型由图像块嵌入层、改进的S2-MLP块、金字塔结构和分类头组成。图像块嵌入层将输入图片分割为小块并映射为特征向量,改进的S2-MLP块通过通道扩展、空间偏移(对称和非对称)以及分割注意力融合特征,同时结合通道混合MLP处理特征,金字塔结构逐层提取多尺度特征,最后分类头完成分类任务。这种设计结合了高效的特征提取和融合机制,实现了性能的显著提升。

  • Patch Embedding Layer(图像块嵌入层)

    • 输入一张大小为W×H×3W \times H \times 3 的图片,将其分割为p×pp \times p 大小的图像块。

    • 每个图像块被映射为一个dd-维向量,通过一个全连接层实现图像块到特征向量的映射。

  • S2-MLPv2 Blocks(改进的S2-MLP块)


    • S2-MLPv2 Component(改进的空间偏移组件)

    • Channel-Mixing MLP(通道混合MLP)

    • 第一部分(X1X_1)执行对称偏移。

    • 第二部分(X2X_2)执行非对称偏移。

    • 第三部分(X3X_3)保持不变。

    • 将输入特征图的通道数从cc 扩展到3c3c,并将其分为三个部分。

      对每个部分应用不同的空间偏移操作(对称与非对称偏移)

    • 使用**分割注意力(Split Attention)**机制将这三个部分融合为一个输出特征图。

    • 使用多层感知机(MLP)在通道维度上混合特征。

    • 该部分与MLP-Mixer和ResMLP中的通道混合机制类似。

    • 每个S2-MLPv2块由两个核心组件组成:

    • 每个块通过残差连接(Residual Connection)实现输入与输出的融合,从而增强训练稳定性。

  • Pyramid Structure(金字塔结构)

    • 模型采用两层金字塔结构,将图像分为不同尺度的图像块。

    • 第一层使用较大的图像块(如7×77 \times 7),后续层逐渐减小块的大小(如2×22 \times 2)。

    • 金字塔结构有助于捕捉多层次的视觉特征,提高模型的表达能力。

  • Classification Head(分类头)

    • 最后一层使用全连接层将特征图映射为分类标签。

    • 用于实现分类任务中的最终输出。

即插即用模块作用

S2Attention 作为一个即插即用模块

  • 多分支特征融合任务在需要结合来自不同分支的特征时(如多尺度特征、多模态输入),S2-Attention模块能够高效地动态融合特征,适用于分类、检测和分割任务。

  • 轻量化模型设计对资源有限的设备(如嵌入式设备、移动设备),该模块可以作为自注意力的替代方案,提供高效的特征选择和融合能力,同时保持低计算开销。

  • 跨层或跨模块特征融合在深度学习模型中,不同层次或不同模块的特征常常需要融合,如跨层连接(skip connections)或多尺度金字塔结构中的特征对齐,S2-Attention能够动态调整权重以优化融合效果。

  • 多模态任务在图像与文本、语音等模态融合场景中,S2-Attention可以动态平衡不同模态的贡献,提升多模态任务(如视觉问答、跨模态检索等)的性能。

  • 复杂特征表达任务对于需要提取和融合复杂视觉特征的任务(如医学图像处理、遥感图像分类),S2-Attention可以更好地选择和突出关键特征。

消融实验结果

  • 表5:金字塔结构的影响

  • 比较了两种配置(Small/7 和 Small/14)的性能,Small/7 采用更小的图像块并结合金字塔结构,而 Small/14 使用较大的图像块且没有金字塔结构。

  • 结果显示,Small/7 配置的 Top-1 准确率(82.0%)显著优于 Small/14(80.9%),说明金字塔结构和较小的图像块有助于捕获更精细的图像特征。


  • 表6:分割注意力(Split Attention)的作用

  • 比较了分割注意力与简单的加权平均(Sum-pooling)方法的性能。

  • 使用分割注意力时的 Top-1 准确率为 82.0%,高于 Sum-pooling 的 79.8%,且参数量和计算复杂度的增加较小,说明分割注意力能够更有效地融合特征图。


  • 表7:每个分割的影响

  • 研究了移除某些分割对性能的影响,例如仅保留X1X_1X2X_2,或者移除X3X_3

  • 结果显示,当仅使用两个分割时,模型的 Top-1 准确率下降到 81.6%-81.7%,而完整使用三个分割时准确率为 82.0%,说明每个分割都对性能有重要贡献。

即插即用模块

import numpy as np
import torch
from torch import nn
from torch.nn import init

# 论文地址:https://arxiv.org/pdf/2108.01072
# 论文:S2-MLPv2: Improved Spatial-Shift MLP Architecture for Vision


def spatial_shift1(x):
    b,w,h,c = x.size()
    x[:,1:,:,:c//4] = x[:,:w-1,:,:c//4]
    x[:,:w-1,:,c//4:c//2] = x[:,1:,:,c//4:c//2]
    x[:,:,1:,c//2:c*3//4] = x[:,:,:h-1,c//2:c*3//4]
    x[:,:,:h-1,3*c//4:] = x[:,:,1:,3*c//4:]
    return x


def spatial_shift2(x):
    b,w,h,c = x.size()
    x[:,:,1:,:c//4] = x[:,:,:h-1,:c//4]
    x[:,:,:h-1,c//4:c//2] = x[:,:,1:,c//4:c//2]
    x[:,1:,:,c//2:c*3//4] = x[:,:w-1,:,c//2:c*3//4]
    x[:,:w-1,:,3*c//4:] = x[:,1:,:,3*c//4:]
    return x


class SplitAttention(nn.Module):
    def __init__(self,channel=512,k=3):
        super().__init__()
        self.channel=channel
        self.k=k
        self.mlp1=nn.Linear(channel,channel,bias=False)
        self.gelu=nn.GELU()
        self.mlp2=nn.Linear(channel,channel*k,bias=False)
        self.softmax=nn.Softmax(1)
    
    def forward(self,x_all):
        b,k,h,w,c=x_all.shape
        x_all=x_all.reshape(b,k,-1,c) #bs,k,n,c
        a=torch.sum(torch.sum(x_all,1),1) #bs,c
        hat_a=self.mlp2(self.gelu(self.mlp1(a))) #bs,kc
        hat_a=hat_a.reshape(b,self.k,c) #bs,k,c
        bar_a=self.softmax(hat_a) #bs,k,c
        attention=bar_a.unsqueeze(-2) # #bs,k,1,c
        out=attention*x_all # #bs,k,n,c
        out=torch.sum(out,1).reshape(b,h,w,c)
        return out


class S2Attention(nn.Module):

    def __init__(self, channels=512 ):
        super().__init__()
        self.mlp1 = nn.Linear(channels,channels*3)
        self.mlp2 = nn.Linear(channels,channels)
        self.split_attention = SplitAttention()

    def forward(self, x):
        b,c,w,h = x.size()
        x=x.permute(0,2,3,1)
        x = self.mlp1(x)
        x1 = spatial_shift1(x[:,:,:,:c])
        x2 = spatial_shift2(x[:,:,:,c:c*2])
        x3 = x[:,:,:,c*2:]
        x_all=torch.stack([x1,x2,x3],1)
        a = self.split_attention(x_all)
        x = self.mlp2(a)
        x=x.permute(0,3,1,2)
        return x

        


if __name__ == '__main__':
    input=torch.randn(50,512,7,7)
    block = S2Attention(channels=512)
    output=block(input)    print(output.shape)

便捷下载方式

浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

更多分析可见原文


ai缝合大王
聚焦AI前沿,分享相关技术、论文,研究生自救指南
 最新文章