(ACCV 2024) 局部重要性驱动注意力机制LIA,涨点起飞起飞了!

文摘   2024-12-22 17:20   上海  

论文介绍

题目:PlainUSR: Chasing Faster ConvNet for Efficient Super-Resolution

论文地址:

https://openaccess.thecvf.com/content/ACCV2024/papers/Wang_PlainUSR_Chasing_Faster_ConvNet_for_Efficient_Super-Resolution_ACCV_2024_paper.pdf

QQ深度学习交流群:994264161

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • 模块重设计

  • 提出了Reparameterized MBConv (RepMBConv),通过重新参数化技术,将复杂的训练时块转化为推理时的简单卷积块,同时在不牺牲性能的情况下实现了2.9倍的加速通过减少内存访问并优化计算与内存之间的平衡,RepMBConv在低延迟的同时保持了较高的特征提取能力。

  • 注意力机制优化

  • 引入了Local Importance-based Attention (LIA),这是一个局部重要性驱动的注意力机制,相比现有的注意力方法,LIA既能实现高阶信息交互,又保持较低的计算复杂度。与传统方法相比,LIA在多个基准测试中展现了更优的性能和效率。

  • 主干网络简化

  • 提出了PlainU-Net,通过在通道维度上进行分割与拼接,实现了高效的层次化特征编码与解码。该设计在推理时以通道索引的形式简化了计算,从而减少了总体推理延迟。

  • 整体框架

  • 将上述模块集成到一个统一的框架PlainUSR中。该框架在图像超分辨率任务中表现出低延迟、良好的扩展性和竞争性的重建质量。在实验中,PlainUSR与多个主流模型(如IMDN、RFDN、NGswin等)相比,展现了显著的速度提升和性能保持。例如,与NGswin相比,PlainUSR-L在保持相似质量的同时快了16.4倍

  • 实验结果

    在多个基准测试集上,PlainUSR在推理速度、内存占用和模型参数量等方面优于当前主流方法,同时在PSNR和SSIM等质量指标上也具有竞争力。

方法

整体架构

     论文提出的 PlainUSR 框架是一个基于 U-Net 的高效超分辨率模型,通过浅层特征提取模块、主干网络(结合重参数化的 RepMBConv 和局部重要性注意力模块 LIA)、以及重建模块协同工作,从低分辨率图像生成高分辨率图像。主干网络采用 PlainU-Net 结构,利用通道维度的分割与拼接实现高效特征编码与解码,整体框架在低延迟、低内存占用和高性能之间实现了出色的平衡

关键模块

1) 浅层特征提取

通过简单的卷积层从输入的低分辨率图像中提取浅层特征。

2) 主干网络

  • PlainU-Net

    • 设计上以通道分割和拼接的形式对特征进行层次化编码和解码。

    • 在训练阶段采用分层特征处理,在推理阶段则通过通道索引简化处理以降低计算复杂度。

  • RepMBConv

    • 在主干网络中堆叠多个 RepMBConv 块,用于有效地提取特征,同时减少推理延迟。

    • 通过重新参数化,将复杂的训练结构转化为推理时的单一卷积。

  • LIA (Local Importance-based Attention)

    • 在浅层特征的基础上,通过计算局部重要性并利用注意力机制增强特定区域的特征响应。

    • 与其他注意力机制相比,LIA 实现了高效的高阶信息交互,同时保持了较低的计算成本。

3) 重建模块

最后的卷积层将经过主干网络处理的特征映射重建为高分辨率图像。

即插即用模块作用

LIA 作为一个即插即用模块

  • 提升特征表达能力

  • LIA 通过计算局部重要性并结合注意力机制,自适应地增强有用特征,抑制无关噪声,使模型能够更好地捕捉图像的细节和结构。

  • 降低计算复杂度

  • 与传统的高阶注意力机制(如非局部注意力 NLSA)相比,LIA 使用简单操作实现高阶信息交互,极大降低了计算复杂度和延迟。

  • 适配轻量化模型

  • LIA 的设计注重效率,特别是减少了内存占用和计算需求,非常适合嵌入到轻量化的神经网络中。

  • 高效的高阶信息建模

  • 虽然 LIA 是局部注意力机制,但通过局部重要性建模和门机制,成功实现了类似于全局注意力的高阶信息交互效果。

消融实验结果

  • 结果

    • RepMBConv 的性能(PSNR 和 SSIM)接近原始的 MBConv,但推理速度显著提升。

    • 相比其他重新参数化策略(如 RepVGG 和 ECB),RepMBConv 在推理延迟、内存占用和激活量等方面表现更优。

  • 说明:证明了 RepMBConv 的有效性,不仅保留了性能,还通过减少内存访问显著降低了延迟


  • 结果

    • LIA 相比传统的 1 阶注意力(如 SE 和 ESA),性能更优且延迟更低。

    • 相比复杂的 2 阶注意力(如 NLSA),LIA 的计算量显著减少,延迟降低了 23 倍,同时性能仅有微小差距。

  • 说明:LIA 在质量与效率之间达成了良好的平衡,适合高效的超分辨率任务。


  • 结果

    • 去除局部重要性或门机制会导致性能明显下降(PSNR 和 SSIM 均下降)。

    • 每个组件在 LIA 中都发挥了重要作用。

  • 说明:验证了 LIA 中各部分设计的必要性和有效性

即插即用模块

import torch
import torch.nn as nn
import torch.nn.functional as F

# 论文题目:PlainUSR: Chasing Faster ConvNet for Efficient Super-Resolution
# 论文地址:https://openaccess.thecvf.com/content/ACCV2024/papers/Wang_PlainUSR_Chasing_Faster_ConvNet_for_Efficient_Super-Resolution_ACCV_2024_paper.pdf

class SoftPooling2D(torch.nn.Module):
    def __init__(self,kernel_size,stride=None,padding=0):
        super(SoftPooling2D, self).__init__()
        self.avgpool = torch.nn.AvgPool2d(kernel_size,stride,padding, count_include_pad=False)
    def forward(self, x):
        x_exp = torch.exp(x)
        x_exp_pool = self.avgpool(x_exp)
        x = self.avgpool(x_exp*x)
        return x/x_exp_pool
    
class LocalAttention(nn.Module):
    ''' attention based on local importance'''
    def __init__(self, channels, f=16):
        super().__init__()
        self.body = nn.Sequential(
            # sample importance
            nn.Conv2d(channels, f, 1),
            SoftPooling2D(7, stride=3),
            nn.Conv2d(f, f, kernel_size=3, stride=2, padding=1),
            nn.Conv2d(f, channels, 3, padding=1),
            # to heatmap
            nn.Sigmoid(),
        )
        self.gate = nn.Sequential(
            nn.Sigmoid(),
        )
    def forward(self, x):
        ''' forward '''
        # interpolate the heat map
        g = self.gate(x[:,:1].clone())
        w = F.interpolate(self.body(x), (x.size(2), x.size(3)), mode='bilinear', align_corners=False)

        return x * w * g #(w + g) #self.gate(x, w)

if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    block = LocalAttention(channels=32).to(device)
    input = torch.rand(1, 32, 256, 256).to(device)

    output = block(input)
    print(input.shape)    print(output.shape)

便捷下载方式

浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

更多分析可见原文


ai缝合大王
聚焦AI前沿,分享相关技术、论文,研究生自救指南
 最新文章