论文介绍
题目:EPSANet: An Efficient Pyramid Squeeze Attention Block on Convolutional Neural Network
论文地址:https://arxiv.org/pdf/2105.14447
QQ深度学习交流群:994264161
扫描下方二维码,加入深度学习论文指南星球!
加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务
创新点
提出了新的注意力模块(PSA模块):设计了一种名为“Pyramid Squeeze Attention (PSA)”的轻量级模块,通过多尺度金字塔卷积结构对输入特征图的信息进行整合,同时提取空间信息和跨维度的特征交互,解决了传统注意力模块中多尺度特征提取和长距离通道依赖不足的问题。
构建了EPSA模块:通过将PSA模块替代ResNet的3x3卷积,开发了一种名为“Efficient Pyramid Squeeze Attention (EPSA)”的新型结构模块。该模块不仅提升了多尺度特征提取能力,还增强了通道间的长距离依赖关系,且灵活易于集成。
开发了EPSANet架构:通过堆叠EPSA模块,设计了一个新型的主干网络EPSANet,与传统注意力模型相比,提供了更强的多尺度特征表示能力,显著提高了多个视觉任务(如图像分类、目标检测、实例分割)的性能。
性能提升和高效性:
在ImageNet数据集上,相比于SENet-50,Top-1准确率提高了1.93%。
在目标检测任务中,使用Mask R-CNN在MS-COCO数据集上的Box AP提升了2.7个百分点,Mask AP提升了1.7个百分点。
EPSANet相比其他方法具有更低的参数量和计算复杂度,同时性能更优,表现出高效的计算资源利用率。
方法
整体架构
EPSANet 是一种基于 ResNet 的高效多尺度注意力网络,通过引入 Pyramid Squeeze Attention (PSA) 模块替代传统 3x3 卷积,形成 Efficient Pyramid Squeeze Attention (EPSA) 块,从而增强多尺度特征提取和通道间长距离依赖能力。模型通过堆叠多个 EPSA 块,构建出灵活高效的网络架构,既提升了图像分类、目标检测和实例分割等任务的性能,又显著降低了参数量和计算复杂度,适用于多种视觉任务。
1. 基础架构
论文以ResNet为基础模型框架,在其瓶颈模块中替换传统的3x3卷积,加入了新提出的**Pyramid Squeeze Attention (PSA)**模块,从而形成新的块结构——**Efficient Pyramid Squeeze Attention (EPSA)**块。通过这种方式,模型能够更高效地提取多尺度特征,并建立通道间的长距离依赖。
2. 核心模块:Pyramid Squeeze Attention (PSA)
PSA模块分为以下几个步骤:
多尺度特征提取:通过**Squeeze and Concat (SPC)**模块从输入特征图中提取不同尺度的特征。SPC模块采用多分支结构,对输入的特征通道进行划分,并使用不同尺寸的卷积核(如3x3, 5x5, 7x7)在各分支上提取特征。
通道注意力权重计算:利用SE模块对不同尺度的特征图分别计算通道注意力权重。
权重重标定:通过Softmax对每个尺度的通道注意力权重进行重新校准,融合不同尺度的上下文信息。
特征重新组合:将重新校准后的通道注意力权重与原始多尺度特征图进行逐元素乘积操作,并最终通过拼接生成新的特征图。
3. EPSA块
EPSA块将上述PSA模块嵌入到ResNet的瓶颈结构中,替代原有的3x3卷积。这种设计使得:
可以以更细粒度提取多尺度特征。
在不显著增加模型复杂度的情况下,增强了通道间的长距离依赖。
提升了模型的特征表达能力。
4. EPSANet架构
通过堆叠多个EPSA块,论文设计了两种版本的主干网络架构:
EPSANet(Small):适用于低计算成本的场景,使用较小的组卷积和卷积核。
EPSANet(Large):适用于更高精度要求的场景,使用更大的组卷积和卷积核。
即插即用模块作用
EPSA 作为一个即插即用模块:
多尺度特征提取
EPSA 利用 Pyramid Squeeze Attention (PSA) 提取输入特征的多尺度空间信息,能够在细粒度上增强特征表示,尤其适合需要处理多尺度特征的任务(如目标检测和实例分割)。
跨通道注意力增强
EPSA 模块通过计算通道之间的长距离依赖,能够有效捕获全局上下文信息,从而优化通道注意力权重分配,提升模型感知全局信息的能力。
灵活的模块化设计
EPSA 模块设计为即插即用组件,可以轻松集成到现有的深度学习模型中(如 ResNet),适应不同的网络架构和应用场景,而无需显著增加计算复杂度。
性能和效率的平衡
EPSA 提供了更高的精度,同时减少参数量和计算成本,使其适用于资源有限的场景(如移动设备上的视觉应用)。
消融实验结果
内容:表中对比了不同组卷积(Group Size)设置下的核大小对模型性能(Top-1 和 Top-5 准确率)的影响。
意义:
调整核大小和组大小可以在性能和计算成本之间取得平衡。
最优设置为核大小 (3, 5, 7, 9) 和组大小 (1, 4, 8, 16),该配置下 EPSANet 在 ImageNet 数据集上实现了最高的 Top-1 和 Top-5 准确率(77.49% 和 93.54%)。
即插即用模块
import numpy as np
import torch
from torch import nn
from torch.nn import init
# 论文地址:https://arxiv.org/pdf/2105.14447
# 论文:EPSANet: An Efficient Pyramid Squeeze Attention Block on Convolutional Neural Network
class PSA(nn.Module):
def __init__(self, channel=512,reduction=4,S=4):
super().__init__()
self.S=S
self.convs=[]
for i in range(S):
self.convs.append(nn.Conv2d(channel//S,channel//S,kernel_size=2*(i+1)+1,padding=i+1))
self.se_blocks=[]
for i in range(S):
self.se_blocks.append(nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(channel//S, channel // (S*reduction),kernel_size=1, bias=False),
nn.ReLU(inplace=True),
nn.Conv2d(channel // (S*reduction), channel//S,kernel_size=1, bias=False),
nn.Sigmoid()
))
self.softmax=nn.Softmax(dim=1)
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
init.normal_(m.weight, std=0.001)
if m.bias is not None:
init.constant_(m.bias, 0)
def forward(self, x):
b, c, h, w = x.size()
#Step1:SPC module
SPC_out=x.view(b,self.S,c//self.S,h,w) #bs,s,ci,h,w
for idx,conv in enumerate(self.convs):
SPC_out[:,idx,:,:,:]=conv(SPC_out[:,idx,:,:,:])
#Step2:SE weight
se_out=[]
for idx,se in enumerate(self.se_blocks):
se_out.append(se(SPC_out[:,idx,:,:,:]))
SE_out=torch.stack(se_out,dim=1)
SE_out=SE_out.expand_as(SPC_out)
#Step3:Softmax
softmax_out=self.softmax(SE_out)
#Step4:SPA
PSA_out=SPC_out*softmax_out
PSA_out=PSA_out.view(b,-1,h,w)
return PSA_out
if __name__ == '__main__':
input=torch.randn(50,512,7,7)
block = PSA(channel=512,reduction=8)
output=block(input)
a=output.view(-1).sum()
a.backward() print(output.shape)
便捷下载方式
浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules
更多分析可见原文