论文介绍
题目:Learning A Sparse Transformer Network for Effective Image Deraining
论文地址:https://arxiv.org/pdf/2303.11950
QQ深度学习交流群:994264161
扫描下方二维码,加入深度学习论文指南星球!
加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务
创新点
稀疏Transformer设计(Sparse Transformer)
提出了一种新的稀疏Transformer架构,用于高效的图像去雨任务。与传统Transformer不同,该模型通过引入一种可学习的Top-K选择操作,仅保留最相关的自注意力值,从而减少无用特征的干扰,提升图像恢复的质量。Top-K稀疏注意力机制(Top-K Sparse Attention, TKSA)
传统的Transformer通常计算所有Query-Key对的相似性,但这可能引入无关或噪声信息。该方法通过Top-K选择,仅保留对特定Query最有用的K个注意力值,大幅优化了特征聚合的过程。混合尺度前馈网络(Mixed-Scale Feed-Forward Network, MSFN)
在传统前馈网络中引入多尺度卷积操作,以提取不同尺度的图像局部信息,从而更好地建模和处理复杂的雨条纹。专家特征补偿器(Mixture of Experts Feature Compensator, MEFC)
在模型的早期和末期阶段引入一个专家特征补偿器,通过多个稀疏CNN操作来实现数据和内容的稀疏性联合探索。这种设计提高了图像恢复的细节保留能力。动态学习的稀疏性控制
在Top-K选择操作中引入了动态可学习的稀疏性范围(例如[1/2, 4/5]),允许模型根据输入内容自适应调整稀疏程度,提升了灵活性和泛化能力。综合性能提升
在多个合成和真实数据集上的实验表明,该方法在PSNR、SSIM等指标上优于现有的最新方法,同时在去除雨条纹和保留图像细节方面表现更加出色。
方法
整体架构
论文提出了一种基于编码器-解码器框架的稀疏Transformer模型(DRSformer),通过引入Top-K稀疏注意力机制(TKSA)动态选择最相关的特征,减少无用信息干扰,同时结合混合尺度前馈网络(MSFN)提取多尺度局部信息。在网络的早期和末期阶段加入专家特征补偿模块(MEFC),进一步优化特征表示。该模型通过残差学习方法恢复清晰图像,并通过跳跃连接保留细节,在多种去雨场景中展现了卓越性能和高效性。
1. 输入与特征嵌入
输入的雨图像
I r a i n ∈ R H × W × 3 被分割成重叠的图像块,并通过 卷积操作嵌入为特征表示。3 × 3 3 \times 3 每层的特征空间有不同的空间分辨率和通道数,以提取多尺度信息。
2. 编码器-解码器架构
编码器:
多个稀疏Transformer块(Sparse Transformer Block, STB)堆叠而成,每个块包括:
特征的降采样通过Pixel-Unshuffle操作完成,减少空间维度,增强计算效率。
Top-K稀疏注意力(TKSA):通过动态选择机制保留最相关的K个自注意力值。
混合尺度前馈网络(MSFN):结合多尺度卷积(如
和3 × 3 3 \times 3 ),捕获不同尺度的局部信息。5 × 5 5 \times 5
解码器:
对应编码器的反向结构,通过Pixel-Shuffle操作进行上采样,逐步恢复图像的分辨率。
使用跳跃连接(Skip Connections)跨层传递特征,以保留细节并稳定训练。
3. 专家特征补偿器(MEFC)
在模型的早期和末期阶段,加入混合专家模块来进一步优化特征表示。
该模块结合多种稀疏CNN操作(如平均池化、分离卷积和膨胀卷积),在不同感受野上对特征进行补偿和增强。
4. 输出结果
最终的去雨结果通过:
其中I d e r a i n = F ( I r a i n ) + I r a i n I_{derain} = F(I_{rain}) + I_{rain} 是整个网络的输出,基于残差学习方法增强去雨效果。F ( ⋅ ) F(\cdot)
5. 网络优化与损失函数
使用
范数作为损失函数:L 1 L_1 其中L = ∥ I d e r a i n − I g t ∥ 1 L = \| I_{derain} - I_{gt} \|_1 是去雨图像的真实值。I g t I_{gt}
即插即用模块作用
TKSA 作为一个即插即用模块:
图像复原任务
图像去雨、去雾、去噪等低级视觉任务中,需处理复杂的空间依赖和局部信息。
图像超分辨率,需要在不同尺度和特征间选择最有效的注意力。
高效的Transformer架构
在Transformer应用于图像或视频处理任务时,TKSA可减少不相关特征的干扰,适用于视频去雨、视频去噪等时空特征相关性强的任务。
在对大分辨率输入的任务(如高分辨率图像生成、遥感图像处理)中,TKSA通过稀疏选择减少计算成本,提升效率。
跨领域任务
自然语言处理(NLP):处理长文本的任务中(如机器翻译、摘要生成),可以使用TKSA聚焦于最相关的上下文内容。
多模态学习:在图像与文本交互场景下(如视觉问答、图文生成),通过稀疏注意力选择最重要的跨模态特征。
消融实验结果
对比了传统前馈网络(FN)、深度卷积前馈网络(DFN)、门控深度卷积前馈网络(GDFN)以及混合尺度前馈网络(MSFN)的性能。结果表明,MSFN通过结合多尺度信息显著提升了PSNR和SSIM,相比GDFN获得了0.21 dB的PSNR提升。
比较了有无MEFC模块以及不同专家数量对性能的影响。结果显示,添加MEFC模块后性能显著提升,尤其是采用多专家结构(如8个专家)能够进一步增强图像细节恢复能力。
即插即用模块
import torch
import torch.nn as nn
from einops import rearrange
class Attention(nn.Module):
def __init__(self, dim, num_heads, bias):
super(Attention, self).__init__()
self.num_heads = num_heads
self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
self.qkv_dwconv = nn.Conv2d(dim * 3, dim * 3, kernel_size=3, stride=1, padding=1, groups=dim * 3, bias=bias)
self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
self.attn_drop = nn.Dropout(0.)
self.attn1 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
self.attn2 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
self.attn3 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
self.attn4 = torch.nn.Parameter(torch.tensor([0.2]), requires_grad=True)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.qkv_dwconv(self.qkv(x))
q, k, v = qkv.chunk(3, dim=1)
q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
q = torch.nn.functional.normalize(q, dim=-1)
k = torch.nn.functional.normalize(k, dim=-1)
_, _, C, _ = q.shape
mask1 = torch.zeros(b, self.num_heads, C, C, device=x.device, requires_grad=False)
mask2 = torch.zeros(b, self.num_heads, C, C, device=x.device, requires_grad=False)
mask3 = torch.zeros(b, self.num_heads, C, C, device=x.device, requires_grad=False)
mask4 = torch.zeros(b, self.num_heads, C, C, device=x.device, requires_grad=False)
attn = (q @ k.transpose(-2, -1)) * self.temperature
index = torch.topk(attn, k=int(C/2), dim=-1, largest=True)[1]
mask1.scatter_(-1, index, 1.)
attn1 = torch.where(mask1 > 0, attn, torch.full_like(attn, float('-inf')))
index = torch.topk(attn, k=int(C*2/3), dim=-1, largest=True)[1]
mask2.scatter_(-1, index, 1.)
attn2 = torch.where(mask2 > 0, attn, torch.full_like(attn, float('-inf')))
index = torch.topk(attn, k=int(C*3/4), dim=-1, largest=True)[1]
mask3.scatter_(-1, index, 1.)
attn3 = torch.where(mask3 > 0, attn, torch.full_like(attn, float('-inf')))
index = torch.topk(attn, k=int(C*4/5), dim=-1, largest=True)[1]
mask4.scatter_(-1, index, 1.)
attn4 = torch.where(mask4 > 0, attn, torch.full_like(attn, float('-inf')))
attn1 = attn1.softmax(dim=-1)
attn2 = attn2.softmax(dim=-1)
attn3 = attn3.softmax(dim=-1)
attn4 = attn4.softmax(dim=-1)
out1 = (attn1 @ v)
out2 = (attn2 @ v)
out3 = (attn3 @ v)
out4 = (attn4 @ v)
out = out1 * self.attn1 + out2 * self.attn2 + out3 * self.attn3 + out4 * self.attn4
out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
out = self.project_out(out)
return out
if __name__ == '__main__':
block = Attention(dim=3, num_heads=3, bias=False)
input = torch.rand(32, 3, 224, 224)
output = block(input)
print(input.size()) print(output.size())
便捷下载方式
浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules
更多分析可见原文