论文介绍
题目:S2-MLPv2: Improved Spatial-Shift MLP Architecture for Vision
论文地址:https://arxiv.org/pdf/2108.01072
QQ深度学习交流群:994264161
扫描下方二维码,加入深度学习论文指南星球!
加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务
创新点
改进的空间偏移操作:相比于原始的S2-MLP方法,S2-MLPv2在通道维度上扩展特征图并分割为多个部分,对不同部分执行不同的空间偏移操作,并利用分割注意力(Split Attention)机制融合这些部分。这种设计显著提升了跨区域信息交流的能力。
金字塔结构的应用:S2-MLPv2采用了较小尺度的图像块并利用分层的金字塔结构,从而提升了对细粒度图像特征的捕获能力,提高了图像识别精度。这种结构使其与其他先进的MLP架构(如Vision Permutator和CycleMLP)保持一致。
对称与非对称偏移的结合:S2-MLPv2引入了两种互补的空间偏移操作(对称和非对称偏移),在对特征图进行偏移时提升了信息混合的多样性和表达能力。
高效的性能表现:在ImageNet-1K数据集上的实验表明,S2-MLPv2在没有自注意力和外部训练数据的情况下,能以中等规模(55M参数量)达到83.6%的Top-1准确率。这一结果在中等规模的MLP模型中处于最先进的水平,并且超越了一些大规模的MLP模型。
计算复杂度与参数量的平衡:尽管S2-MLPv2在某些情况下需要更多的FLOPs(如与GFNet和AS-MLP相比),但其通过简单的MLP结构和分割注意力机制显著减少了模型的复杂性,在实际部署中具有竞争力。
方法
整体架构
S2-MLPv2模型由图像块嵌入层、改进的S2-MLP块、金字塔结构和分类头组成。图像块嵌入层将输入图片分割为小块并映射为特征向量,改进的S2-MLP块通过通道扩展、空间偏移(对称和非对称)以及分割注意力融合特征,同时结合通道混合MLP处理特征,金字塔结构逐层提取多尺度特征,最后分类头完成分类任务。这种设计结合了高效的特征提取和融合机制,实现了性能的显著提升。
Patch Embedding Layer(图像块嵌入层):
输入一张大小为
的图片,将其分割为W × H × 3 W \times H \times 3 大小的图像块。p × p p \times p 每个图像块被映射为一个
-维向量,通过一个全连接层实现图像块到特征向量的映射。d d S2-MLPv2 Blocks(改进的S2-MLP块):
S2-MLPv2 Component(改进的空间偏移组件):
Channel-Mixing MLP(通道混合MLP):
第一部分(
)执行对称偏移。X 1 X_1 第二部分(
)执行非对称偏移。X 2 X_2 第三部分(
)保持不变。X 3 X_3 将输入特征图的通道数从
扩展到c c ,并将其分为三个部分。3 c 3c 对每个部分应用不同的空间偏移操作(对称与非对称偏移)
使用**分割注意力(Split Attention)**机制将这三个部分融合为一个输出特征图。
使用多层感知机(MLP)在通道维度上混合特征。
该部分与MLP-Mixer和ResMLP中的通道混合机制类似。
每个S2-MLPv2块由两个核心组件组成:
每个块通过残差连接(Residual Connection)实现输入与输出的融合,从而增强训练稳定性。
Pyramid Structure(金字塔结构):
模型采用两层金字塔结构,将图像分为不同尺度的图像块。
第一层使用较大的图像块(如
),后续层逐渐减小块的大小(如7 × 7 7 \times 7 )。2 × 2 2 \times 2 金字塔结构有助于捕捉多层次的视觉特征,提高模型的表达能力。
Classification Head(分类头):
最后一层使用全连接层将特征图映射为分类标签。
用于实现分类任务中的最终输出。
即插即用模块作用
S2Attention 作为一个即插即用模块:
多分支特征融合任务:在需要结合来自不同分支的特征时(如多尺度特征、多模态输入),S2-Attention模块能够高效地动态融合特征,适用于分类、检测和分割任务。
轻量化模型设计:对资源有限的设备(如嵌入式设备、移动设备),该模块可以作为自注意力的替代方案,提供高效的特征选择和融合能力,同时保持低计算开销。
跨层或跨模块特征融合:在深度学习模型中,不同层次或不同模块的特征常常需要融合,如跨层连接(skip connections)或多尺度金字塔结构中的特征对齐,S2-Attention能够动态调整权重以优化融合效果。
多模态任务:在图像与文本、语音等模态融合场景中,S2-Attention可以动态平衡不同模态的贡献,提升多模态任务(如视觉问答、跨模态检索等)的性能。
复杂特征表达任务:对于需要提取和融合复杂视觉特征的任务(如医学图像处理、遥感图像分类),S2-Attention可以更好地选择和突出关键特征。
消融实验结果
表5:金字塔结构的影响
比较了两种配置(Small/7 和 Small/14)的性能,Small/7 采用更小的图像块并结合金字塔结构,而 Small/14 使用较大的图像块且没有金字塔结构。
结果显示,Small/7 配置的 Top-1 准确率(82.0%)显著优于 Small/14(80.9%),说明金字塔结构和较小的图像块有助于捕获更精细的图像特征。
表6:分割注意力(Split Attention)的作用
比较了分割注意力与简单的加权平均(Sum-pooling)方法的性能。
使用分割注意力时的 Top-1 准确率为 82.0%,高于 Sum-pooling 的 79.8%,且参数量和计算复杂度的增加较小,说明分割注意力能够更有效地融合特征图。
表7:每个分割的影响
研究了移除某些分割对性能的影响,例如仅保留
和X 1 X_1 ,或者移除X 2 X_2 。X 3 X_3 结果显示,当仅使用两个分割时,模型的 Top-1 准确率下降到 81.6%-81.7%,而完整使用三个分割时准确率为 82.0%,说明每个分割都对性能有重要贡献。
即插即用模块
import numpy as np
import torch
from torch import nn
from torch.nn import init
# 论文地址:https://arxiv.org/pdf/2108.01072
# 论文:S2-MLPv2: Improved Spatial-Shift MLP Architecture for Vision
def spatial_shift1(x):
b,w,h,c = x.size()
x[:,1:,:,:c//4] = x[:,:w-1,:,:c//4]
x[:,:w-1,:,c//4:c//2] = x[:,1:,:,c//4:c//2]
x[:,:,1:,c//2:c*3//4] = x[:,:,:h-1,c//2:c*3//4]
x[:,:,:h-1,3*c//4:] = x[:,:,1:,3*c//4:]
return x
def spatial_shift2(x):
b,w,h,c = x.size()
x[:,:,1:,:c//4] = x[:,:,:h-1,:c//4]
x[:,:,:h-1,c//4:c//2] = x[:,:,1:,c//4:c//2]
x[:,1:,:,c//2:c*3//4] = x[:,:w-1,:,c//2:c*3//4]
x[:,:w-1,:,3*c//4:] = x[:,1:,:,3*c//4:]
return x
class SplitAttention(nn.Module):
def __init__(self,channel=512,k=3):
super().__init__()
self.channel=channel
self.k=k
self.mlp1=nn.Linear(channel,channel,bias=False)
self.gelu=nn.GELU()
self.mlp2=nn.Linear(channel,channel*k,bias=False)
self.softmax=nn.Softmax(1)
def forward(self,x_all):
b,k,h,w,c=x_all.shape
x_all=x_all.reshape(b,k,-1,c) #bs,k,n,c
a=torch.sum(torch.sum(x_all,1),1) #bs,c
hat_a=self.mlp2(self.gelu(self.mlp1(a))) #bs,kc
hat_a=hat_a.reshape(b,self.k,c) #bs,k,c
bar_a=self.softmax(hat_a) #bs,k,c
attention=bar_a.unsqueeze(-2) # #bs,k,1,c
out=attention*x_all # #bs,k,n,c
out=torch.sum(out,1).reshape(b,h,w,c)
return out
class S2Attention(nn.Module):
def __init__(self, channels=512 ):
super().__init__()
self.mlp1 = nn.Linear(channels,channels*3)
self.mlp2 = nn.Linear(channels,channels)
self.split_attention = SplitAttention()
def forward(self, x):
b,c,w,h = x.size()
x=x.permute(0,2,3,1)
x = self.mlp1(x)
x1 = spatial_shift1(x[:,:,:,:c])
x2 = spatial_shift2(x[:,:,:,c:c*2])
x3 = x[:,:,:,c*2:]
x_all=torch.stack([x1,x2,x3],1)
a = self.split_attention(x_all)
x = self.mlp2(a)
x=x.permute(0,3,1,2)
return x
if __name__ == '__main__':
input=torch.randn(50,512,7,7)
block = S2Attention(channels=512)
output=block(input) print(output.shape)
便捷下载方式
浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules
更多分析可见原文