即插即用极化自注意力模块PSAN,涨点起飞起飞了!

文摘   2024-12-27 18:51   中国香港  

论文介绍

题目:Polarized Self-Attention: Towards High-quality Pixel-wise Regression

论文地址:https://arxiv.org/pdf/2107.00782

QQ深度学习交流群:994264161

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • 极化滤波(Polarized Filtering)

    • 在通道注意力和空间注意力的计算中保持高分辨率,但完全压缩输入张量在相应的正交方向。这种设计能够在低计算和存储开销的情况下最大化保留高分辨率的信息。

  • 非线性增强(Non-linear Enhancement)

    • 通过结合 Softmax 和 Sigmoid 的组合,使得输出非线性能够直接匹配像素级回归的典型分布(如关键点热图中的二维高斯分布或语义分割中的二项分布)。这一增强减少了学习的负担,提升了模型性能。

  • 通道与空间分支的平行与序列化设计

    • PSA 模块分为通道注意力和空间注意力两部分,且可以采用平行或序列的方式组合。实验表明,两种组合方式的性能差异非常小,这表明 PSA 已经充分利用了其在通道和空间维度上的表示能力。

  • 性能提升

  • PSA 在多个基准任务(如2D人体姿态估计和语义分割)中显著提升了性能。例如,在 MS-COCO 骨骼点检测和 Pascal VOC2012 分割任务中,与基准模型相比,PSA 的加入带来了 2-4 个点的性能提升。

方法

整体架构

     论文提出了一种基于 Polarized Self-Attention (PSA) 的模型结构,将 PSA 模块作为一种可插拔组件集成到主干网络(如 ResNet 或 HRNet)的残差块中,通过通道注意力和空间注意力分支保持高分辨率特征表示,并采用 Softmax-Sigmoid 组合建模逐像素回归的分布。模型整体通过 PSA 增强特征表达能力,并结合解码器生成高精度的关键点热图或语义分割掩码,适用于各种高分辨率的逐像素任务,同时保持较低的计算开销和内存需求。

1. 输入与特征提取

  • 输入数据经过一个深度卷积神经网络(如 ResNet 或 HRNet)作为 主干网络(backbone),用于提取多尺度的特征。

  • 主干网络生成一个高分辨率的特征图,作为 PSA 模块的输入。

2. Polarized Self-Attention 模块(PSA)

PSA 模块是论文的核心设计,用于增强主干网络提取的特征,其结构由两个独立的分支构成:

  • 通道注意力分支(Channel-only Attention)

    • 通过压缩空间维度,仅在通道维度上操作,保持高分辨率的通道信息。

    • 使用 Softmax 和 Sigmoid 的组合进行非线性建模,增强特征的表达能力。

  • 空间注意力分支(Spatial-only Attention)

    • 通过全局池化压缩通道维度,仅在空间维度上操作,保留高分辨率的空间信息。

    • 同样采用 Softmax 和 Sigmoid 的非线性组合,适配逐像素回归的输出分布(如高斯分布或二项分布)。

  • 组合方式

    • PSA 模块的两个分支可以通过**并行(parallel)序列化(sequential)**的方式组合。

    • 并行方式:通道和空间分支的输出直接相加。

    • 序列化方式:空间分支的输出作为通道分支的输入,进一步细化特征。

3. 集成 PSA 到主干网络:

  • PSA 模块以**可插拔模块(plug-and-play)**的形式集成到主干网络中,具体如下:

    • 在主干网络的每个残差块(residual block)中的第一层 3×3 卷积后插入一个 PSA 模块。

    • 这一设计确保 PSA 能够对多尺度特征进行全局增强,同时保留高分辨率信息。

4. 解码器与输出层

  • 解码器部分根据具体任务的需求进行调整:

    • 关键点回归任务:解码器生成关键点热图(通常是二维高斯分布)。

    • 语义分割任务:解码器生成逐像素的分类掩码(如多通道的二项分布)。

  • PSA 的输出直接融入到解码器,提升逐像素预测的精度。

即插即用模块作用

PSAN 作为一个即插即用模块

  • 增强逐像素回归的特征表达能力

    • 通过通道注意力和空间注意力分支分别建模高分辨率的通道和空间信息,PSA 模块能更好地捕获细粒度的上下文依赖,提升逐像素预测的准确性。

  • 保留高分辨率信息

    • PSA 通过极化滤波(Polarized Filtering),避免传统注意力机制中分辨率的丢失,保留特征图的空间细节和边界信息。

  • 适配输出分布

    • PSA 结合 Softmax 和 Sigmoid 的非线性设计,能够适配关键点热图的高斯分布和语义分割掩码的二项分布,从而减轻学习复杂度。

消融实验结果

  • 单独分支的效果

    • 通道注意力分支(Ach)优于空间注意力分支(Asp),说明通道维度的注意力对逐像素回归的影响更显著。

    • 单独的 Ach 和 Asp 相较于基准模型均有显著性能提升,分别提高了 4.1 和 2.8 AP,表明 PSA 的两种注意力分支都有益于特征增强。

  • 分支组合的效果

    • 并行组合([Ach | Asp])和序列组合(Asp(Ach))的性能相当,两者的 AP 提升分别为 4.3 和 4.4,表明 PSA 模块的通道和空间分支已经充分挖掘了其在各自维度上的表示能力,组合方式对性能的影响较小。

  • 与其他注意力模块的比较

    • PSA 模块相比其他注意力机制(如 Non-Local、GC、SE、CBAM)表现更优,特别是 PSA 的通道注意力分支 Ach 在轻量化的情况下优于 GC 和 SE。

    • PSA 的并行和序列组合方式在相似的计算开销下,显著优于 CBAM 等传统通道+空间注意力模块。

  • 模型效率

  • PSA 模块在保留高分辨率信息的同时,保持了较低的计算开销(Flops)和参数数量(mPara),并且推理时间和内存使用也较为高效,进一步验证了其轻量化设计的优势。

即插即用模块

import numpy as np
import torch
from torch import nn
from torch.nn import init

# 论文地址:https://arxiv.org/pdf/2107.00782
# 论文:Polarized Self-Attention: Towards High-quality Pixel-wise Regression
class ParallelPolarizedSelfAttention(nn.Module):

    def __init__(self, channel=512):
        super().__init__()
        self.ch_wv=nn.Conv2d(channel,channel//2,kernel_size=(1,1))
        self.ch_wq=nn.Conv2d(channel,1,kernel_size=(1,1))
        self.softmax_channel=nn.Softmax(1)
        self.softmax_spatial=nn.Softmax(-1)
        self.ch_wz=nn.Conv2d(channel//2,channel,kernel_size=(1,1))
        self.ln=nn.LayerNorm(channel)
        self.sigmoid=nn.Sigmoid()
        self.sp_wv=nn.Conv2d(channel,channel//2,kernel_size=(1,1))
        self.sp_wq=nn.Conv2d(channel,channel//2,kernel_size=(1,1))
        self.agp=nn.AdaptiveAvgPool2d((1,1))

    def forward(self, x):
        b, c, h, w = x.size()

        #Channel-only Self-Attention
        channel_wv=self.ch_wv(x) #bs,c//2,h,w
        channel_wq=self.ch_wq(x) #bs,1,h,w
        channel_wv=channel_wv.reshape(b,c//2,-1) #bs,c//2,h*w
        channel_wq=channel_wq.reshape(b,-1,1) #bs,h*w,1
        channel_wq=self.softmax_channel(channel_wq)
        channel_wz=torch.matmul(channel_wv,channel_wq).unsqueeze(-1) #bs,c//2,1,1
        channel_weight=self.sigmoid(self.ln(self.ch_wz(channel_wz).reshape(b,c,1).permute(0,2,1))).permute(0,2,1).reshape(b,c,1,1) #bs,c,1,1
        channel_out=channel_weight*x

        #Spatial-only Self-Attention
        spatial_wv=self.sp_wv(x) #bs,c//2,h,w
        spatial_wq=self.sp_wq(x) #bs,c//2,h,w
        spatial_wq=self.agp(spatial_wq) #bs,c//2,1,1
        spatial_wv=spatial_wv.reshape(b,c//2,-1) #bs,c//2,h*w
        spatial_wq=spatial_wq.permute(0,2,3,1).reshape(b,1,c//2) #bs,1,c//2
        spatial_wq=self.softmax_spatial(spatial_wq)
        spatial_wz=torch.matmul(spatial_wq,spatial_wv) #bs,1,h*w
        spatial_weight=self.sigmoid(spatial_wz.reshape(b,1,h,w)) #bs,1,h,w
        spatial_out=spatial_weight*x
        out=spatial_out+channel_out
        return out


class SequentialPolarizedSelfAttention(nn.Module):

    def __init__(self, channel=512):
        super().__init__()
        self.ch_wv=nn.Conv2d(channel,channel//2,kernel_size=(1,1))
        self.ch_wq=nn.Conv2d(channel,1,kernel_size=(1,1))
        self.softmax_channel=nn.Softmax(1)
        self.softmax_spatial=nn.Softmax(-1)
        self.ch_wz=nn.Conv2d(channel//2,channel,kernel_size=(1,1))
        self.ln=nn.LayerNorm(channel)
        self.sigmoid=nn.Sigmoid()
        self.sp_wv=nn.Conv2d(channel,channel//2,kernel_size=(1,1))
        self.sp_wq=nn.Conv2d(channel,channel//2,kernel_size=(1,1))
        self.agp=nn.AdaptiveAvgPool2d((1,1))

    def forward(self, x):
        b, c, h, w = x.size()

        #Channel-only Self-Attention
        channel_wv=self.ch_wv(x) #bs,c//2,h,w
        channel_wq=self.ch_wq(x) #bs,1,h,w
        channel_wv=channel_wv.reshape(b,c//2,-1) #bs,c//2,h*w
        channel_wq=channel_wq.reshape(b,-1,1) #bs,h*w,1
        channel_wq=self.softmax_channel(channel_wq)
        channel_wz=torch.matmul(channel_wv,channel_wq).unsqueeze(-1) #bs,c//2,1,1
        channel_weight=self.sigmoid(self.ln(self.ch_wz(channel_wz).reshape(b,c,1).permute(0,2,1))).permute(0,2,1).reshape(b,c,1,1) #bs,c,1,1
        channel_out=channel_weight*x

        #Spatial-only Self-Attention
        spatial_wv=self.sp_wv(channel_out) #bs,c//2,h,w
        spatial_wq=self.sp_wq(channel_out) #bs,c//2,h,w
        spatial_wq=self.agp(spatial_wq) #bs,c//2,1,1
        spatial_wv=spatial_wv.reshape(b,c//2,-1) #bs,c//2,h*w
        spatial_wq=spatial_wq.permute(0,2,3,1).reshape(b,1,c//2) #bs,1,c//2
        spatial_wq=self.softmax_spatial(spatial_wq)
        spatial_wz=torch.matmul(spatial_wq,spatial_wv) #bs,1,h*w
        spatial_weight=self.sigmoid(spatial_wz.reshape(b,1,h,w)) #bs,1,h,w
        spatial_out=spatial_weight*channel_out
        return spatial_out
if __name__ == '__main__':
    input=torch.randn(1,512,7,7)
    block = SequentialPolarizedSelfAttention(channel=512)
    output=block(input)    print(output.shape)

便捷下载方式

浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

更多分析可见原文


ai缝合大王
聚焦AI前沿,分享相关技术、论文,研究生自救指南
 最新文章