即插即用稀疏注意力机制TKSA,涨点起飞起飞了

文摘   2025-01-16 17:20   上海  

论文介绍

题目:Learning A Sparse Transformer Network for Effective Image Deraining

论文地址:https://arxiv.org/pdf/2303.11950

QQ深度学习交流群:994264161

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • 稀疏Transformer设计(Sparse Transformer)
    提出了一种新的稀疏Transformer架构,用于高效的图像去雨任务。与传统Transformer不同,该模型通过引入一种可学习的Top-K选择操作,仅保留最相关的自注意力值,从而减少无用特征的干扰,提升图像恢复的质量。

  • Top-K稀疏注意力机制(Top-K Sparse Attention, TKSA)
    传统的Transformer通常计算所有Query-Key对的相似性,但这可能引入无关或噪声信息。该方法通过Top-K选择,仅保留对特定Query最有用的K个注意力值,大幅优化了特征聚合的过程。

  • 混合尺度前馈网络(Mixed-Scale Feed-Forward Network, MSFN)
    在传统前馈网络中引入多尺度卷积操作,以提取不同尺度的图像局部信息,从而更好地建模和处理复杂的雨条纹。

  • 专家特征补偿器(Mixture of Experts Feature Compensator, MEFC)
    在模型的早期和末期阶段引入一个专家特征补偿器,通过多个稀疏CNN操作来实现数据和内容的稀疏性联合探索。这种设计提高了图像恢复的细节保留能力。

  • 动态学习的稀疏性控制
    在Top-K选择操作中引入了动态可学习的稀疏性范围(例如[1/2, 4/5]),允许模型根据输入内容自适应调整稀疏程度,提升了灵活性和泛化能力。

  • 综合性能提升
    在多个合成和真实数据集上的实验表明,该方法在PSNR、SSIM等指标上优于现有的最新方法,同时在去除雨条纹和保留图像细节方面表现更加出色。

方法

整体架构

       论文提出了一种基于编码器-解码器框架的稀疏Transformer模型(DRSformer),通过引入Top-K稀疏注意力机制(TKSA)动态选择最相关的特征,减少无用信息干扰,同时结合混合尺度前馈网络(MSFN)提取多尺度局部信息。在网络的早期和末期阶段加入专家特征补偿模块(MEFC),进一步优化特征表示。该模型通过残差学习方法恢复清晰图像,并通过跳跃连接保留细节,在多种去雨场景中展现了卓越性能和高效性。

1. 输入与特征嵌入

  • 输入的雨图像IrainRH×W×3 被分割成重叠的图像块,并通过 3×33 \times 3 卷积操作嵌入为特征表示。

  • 每层的特征空间有不同的空间分辨率和通道数,以提取多尺度信息。


2. 编码器-解码器架构

  • 编码器

    • 多个稀疏Transformer块(Sparse Transformer Block, STB)堆叠而成,每个块包括:

    • 特征的降采样通过Pixel-Unshuffle操作完成,减少空间维度,增强计算效率。

  1. Top-K稀疏注意力(TKSA):通过动态选择机制保留最相关的K个自注意力值。

  2. 混合尺度前馈网络(MSFN):结合多尺度卷积(如3×33 \times 3 和5×55 \times 5),捕获不同尺度的局部信息。

  • 解码器

    • 对应编码器的反向结构,通过Pixel-Shuffle操作进行上采样,逐步恢复图像的分辨率。

    • 使用跳跃连接(Skip Connections)跨层传递特征,以保留细节并稳定训练。


    3. 专家特征补偿器(MEFC)

    • 在模型的早期和末期阶段,加入混合专家模块来进一步优化特征表示。

    • 该模块结合多种稀疏CNN操作(如平均池化、分离卷积和膨胀卷积),在不同感受野上对特征进行补偿和增强。


    4. 输出结果

    • 最终的去雨结果通过:Iderain=F(Irain)+IrainI_{derain} = F(I_{rain}) + I_{rain}其中F()F(\cdot) 是整个网络的输出,基于残差学习方法增强去雨效果。


    5. 网络优化与损失函数

    • 使用L1L_1范数作为损失函数:L=IderainIgt1L = \| I_{derain} - I_{gt} \|_1其中IgtI_{gt} 是去雨图像的真实值。

    即插即用模块作用

    TKSA 作为一个即插即用模块

    • 图像复原任务

      • 图像去雨、去雾、去噪等低级视觉任务中,需处理复杂的空间依赖和局部信息。

      • 图像超分辨率,需要在不同尺度和特征间选择最有效的注意力。

    • 高效的Transformer架构

      • 在Transformer应用于图像或视频处理任务时,TKSA可减少不相关特征的干扰,适用于视频去雨、视频去噪等时空特征相关性强的任务。

      • 在对大分辨率输入的任务(如高分辨率图像生成、遥感图像处理)中,TKSA通过稀疏选择减少计算成本,提升效率。

    • 跨领域任务

      • 自然语言处理(NLP):处理长文本的任务中(如机器翻译、摘要生成),可以使用TKSA聚焦于最相关的上下文内容。

      • 多模态学习:在图像与文本交互场景下(如视觉问答、图文生成),通过稀疏注意力选择最重要的跨模态特征。

    消融实验结果

    对比了传统前馈网络(FN)、深度卷积前馈网络(DFN)、门控深度卷积前馈网络(GDFN)以及混合尺度前馈网络(MSFN)的性能。结果表明,MSFN通过结合多尺度信息显著提升了PSNR和SSIM,相比GDFN获得了0.21 dB的PSNR提升。

    比较了有无MEFC模块以及不同专家数量对性能的影响。结果显示,添加MEFC模块后性能显著提升,尤其是采用多专家结构(如8个专家)能够进一步增强图像细节恢复能力。

    即插即用模块

    import torch
    import torch.nn as nn
    from einops import rearrange

    class Attention(nn.Module):
        def __init__(self, dim, num_heads, bias):
            super(Attention, self).__init__()
            self.num_heads = num_heads

            self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

            self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
            self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias)
            self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
            self.attn_drop = nn.Dropout(0.)

            self.attn1 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
            self.attn2 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
            self.attn3 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
            self.attn4 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)

        def forward(self, x):
            b, c, h, w = x.shape

            qkv = self.qkv_dwconv(self.qkv(x))
            q, k, v = qkv.chunk(3, dim=1)

            q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
            k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
            v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)

            q = torch.nn.functional.normalize(q, dim=-1)
            k = torch.nn.functional.normalize(k, dim=-1)

            _, _, C, _ = q.shape

            mask1 = torch.zeros(b, self.num_heads, C, C, device=x.device, requires_grad=False)
            mask2 = torch.zeros(b, self.num_heads, C, C, device=x.device, requires_grad=False)
            mask3 = torch.zeros(b, self.num_heads, C, C, device=x.device, requires_grad=False)
            mask4 = torch.zeros(b, self.num_heads, C, C, device=x.device, requires_grad=False)

            attn = (q @ k.transpose(-2, -1)) * self.temperature

            index = torch.topk(attn, k=int(C/2), dim=-1, largest=True)[1]
            mask1.scatter_(-1, index, 1.)
            attn1 = torch.where(mask1 > 0, attn, torch.full_like(attn, float('-inf')))

            index = torch.topk(attn, k=int(C*2/3), dim=-1, largest=True)[1]
            mask2.scatter_(-1, index, 1.)
            attn2 = torch.where(mask2 > 0, attn, torch.full_like(attn, float('-inf')))

            index = torch.topk(attn, k=int(C*3/4), dim=-1, largest=True)[1]
            mask3.scatter_(-1, index, 1.)
            attn3 = torch.where(mask3 > 0, attn, torch.full_like(attn, float('-inf')))

            index = torch.topk(attn, k=int(C*4/5), dim=-1, largest=True)[1]
            mask4.scatter_(-1, index, 1.)
            attn4 = torch.where(mask4 > 0, attn, torch.full_like(attn, float('-inf')))

            attn1 = attn1.softmax(dim=-1)
            attn2 = attn2.softmax(dim=-1)
            attn3 = attn3.softmax(dim=-1)
            attn4 = attn4.softmax(dim=-1)

            out1 = (attn1 @ v)
            out2 = (attn2 @ v)
            out3 = (attn3 @ v)
            out4 = (attn4 @ v)

            out = out1 * self.attn1 + out2 * self.attn2 + out3 * self.attn3 + out4 * self.attn4

            out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)

            out = self.project_out(out)
            return out


    if __name__ == '__main__':
        block = Attention(dim=3, num_heads=3, bias=False)
        input = torch.rand(32, 3, 224, 224)
        output = block(input)
        print(input.size())    print(output.size())

    便捷下载方式

    浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

    更多分析可见原文


    ai缝合大王
    聚焦AI前沿,分享相关技术、论文,研究生自救指南
     最新文章