点击下方卡片,关注“AI前沿速递”公众号
点击下方卡片,关注“AI前沿速递”公众号
各种重磅干货,第一时间送达
各种重磅干货,第一时间送达
标题:Adapt or Perish: Adaptive Sparse Transformer with Attentive Feature Refinement for Image Restoration
论文链接:https://openaccess.thecvf.com/content/CVPR2024/papers/Zhou_Adapt_or_Perish_Adaptive_Sparse_Transformer_with_Attentive_Feature_Refinement_CVPR_2024_paper.pdf
代码链接:https://github.com/joshyZhou/AST
来源:CVPR 2024
ASSA模块
定义与结构
- 双分支模式:ASSA模块采用双分支模式,包括稀疏自注意力分支(SSA)和密集自注意力分支(DSA)。SSA用于过滤掉低查询-键匹配分数的负面影响,而DSA则确保足够的信息流通过网络,以学习判别性表示。- 自适应加权:通过自适应的加权机制,将SSA和DSA的输出进行融合。这种设计使模型能够动态调整稀疏与密集注意力的权重,从而根据具体的任务和输入内容有效地平衡信息流,既能过滤掉无关特征,又保留必要的信息。
工作原理
- 稀疏自注意力(SSA):使用基于ReLU的稀疏注意力机制,过滤掉查询与键之间低匹配的无关交互,减少无效特征的参与,帮助聚焦在最有价值的信息交互上。- 密集自注意力(DSA):采用标准的softmax密集注意力机制,补充SSA,以确保在稀疏处理过程中不会丢失关键信息。
应用场景
- 图像恢复任务:ASSA最早应用于图像恢复任务,通过减少噪声交互并保留重要的特征信息,显著提升了模型的处理效率。- 目标检测:在YOLOv11模型中引入ASSA机制,可以优化特征提取过程,减少特征冗余或噪声干扰,进一步提升模型对复杂场景的适应性和检测性能。- 时间序列预测:将ASSA机制和LSTM处理后的特征输入到Transformer网络进行预测,可以进一步提高预测的准确性和效率,在顶刊ETTh开源数据集达到了不错的效果- 医疗影像处理:对于需要从复杂的医学图像中提取关键特征的任务,如癌症检测、CT或MRI图像分析,ASSA能够有效过滤无用信息,提升对病灶区域的关注和检测精度。
FRFN模块
定义与结构
特征细化前馈网络(Feature Refinement Feed-forward Network, FRFN)是一种专门设计的深度学习结构,旨在提高图像处理任务中的特征表示能力。其核心设计理念是通过逐层细化和优化特征图,从而实现更高的分类和检测精度。
- 线性层1:将输入特征维度扩展到隐藏维度的两倍,并通过激活函数进行非线性变换。- 深度可分离卷积:对扩展后的特征进行深度可分离卷积操作,进一步提取局部特征。- 线性层2:将特征维度压缩回原始维度。- 部分卷积:对部分特征通道进行卷积操作,以增强特征中的有用元素。- 门控机制:通过门控机制减少冗余信息的处理负担,提升特征的纯净度。
工作原理
FRFN模块的工作原理可以概括为以下几个步骤:- 特征扩展:通过线性层将输入特征的维度扩展到隐藏维度的两倍,增加特征的表达能力。- 深度可分离卷积:对扩展后的特征进行深度可分离卷积操作,提取局部特征,同时减少计算量。- 特征压缩:通过线性层将特征维度压缩回原始维度,减少特征冗余。- 部分卷积:对部分特征通道进行卷积操作,增强特征中的有用元素。- 门控机制:通过门控机制减少冗余信息的处理负担,提升特征的纯净度。
应用场景
- 图像恢复任务:在去噪、去雨滴、去雾、超分辨率等场景中,FRFN能够有效减少通道维度上的冗余信息,提升重要特征的表达,从而提高恢复图像的质量和细节还原能力。
- 图像分类和检测任务:在处理复杂图像时,FRFN可以通过精炼和增强有价值的特征,帮助模型更准确地分类或检测目标,特别是在多类或高维度特征的任务中表现出色。
- 高分辨率图像处理:在高分辨率图像或视频处理中,FRFN能够减少不必要的信息流,增强重要特征的表达,使模型更高效地处理大规模图像数据。
- 医学图像分析:在处理复杂的医学影像时,FRFN有助于减少噪声和干扰,聚焦于病变区域的关键特征,提升医疗影像分析的精度和效率。
集成到YOLOv11和RT-DETR
将FRFN模块集成到YOLOv11和RT-DETR模型中的步骤如下:
创建脚本文件:在 ultralytics->nn
路径下创建blocks.py
脚本,用于存放模块代码。复制代码:将上述FRFN模块的代码复制到 blocks.py
脚本中。更改 task.py
文件:在ultralytics->nn->modules->task.py
中导入FRFN模块。修改模型配置:在模型配置文件中添加FRFN模块的配置。 训练模型:创建训练脚本,使用修改后的模型配置进行训练。
代码实现
ASSA模块代码
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_
from einops import repeat
class LinearProjection(nn.Module):
def __init__(self, dim, heads=8, dim_head=64, bias=True):
super().__init__()
inner_dim = dim_head * heads
self.heads = heads
self.to_q = nn.Linear(dim, inner_dim, bias=bias)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=bias)
self.dim = dim
self.inner_dim = inner_dim
def forward(self, x, attn_kv=None):
B_, N, C = x.shape
if attn_kv is not None:
attn_kv = attn_kv.unsqueeze(0).repeat(B_, 1, 1)
else:
attn_kv = x
N_kv = attn_kv.size(1)
q = self.to_q(x).reshape(B_, N, 1, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
kv = self.to_kv(attn_kv).reshape(B_, N_kv, 2, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
q = q[0]
k, v = kv[0], kv[1]
return q, k, v
class WindowAttention_sparse(nn.Module):
def __init__(self, dim, win_size, num_heads=8, token_projection='linear', qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.win_size = win_size # Wh, Ww
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
torch.zeros((2 * win_size[0] - 1) * (2 * win_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.win_size[0]) # [0,...,Wh-1]
coords_w = torch.arange(self.win_size[1]) # [0,...,Ww-1]
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.win_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.win_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.win_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)
trunc_normal_(self.relative_position_bias_table, std=.02)
if token_projection == 'linear':
self.qkv = LinearProjection(dim, num_heads, dim // num_heads, bias=qkv_bias)
else:
raise Exception("Projection error!")
self.token_projection = token_projection
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.softmax = nn.Softmax(dim=-1)
self.relu = nn.ReLU()
self.w = nn.Parameter(torch.ones(2)) # 自适应权重参数
def forward(self, x, attn_kv=None, mask=None):
# 调整输入维度,从 (B, C, H, W) 转为 (B, H, W, C)
x = x.permute(0, 2, 3, 1).reshape(x.shape[0], x.shape[2] * x.shape[3], x.shape[1])
B_, N, C = x.shape
q, k, v = self.qkv(x, attn_kv)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.win_size[0] * self.win_size[1], self.win_size[0] * self.win_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
ratio = attn.size(-1) // relative_position_bias.size(-1)
relative_position_bias = repeat(relative_position_bias, 'nH l c -> nH l (c d)', d=ratio)
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
mask = repeat(mask, 'nW m n -> nW m (n d)', d=ratio)
attn = attn.view(B_ // nW, nW, self.num_heads, N, N * ratio) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N * ratio)
attn0 = self.softmax(attn)
attn1 = self.relu(attn) ** 2 # b,h,w,c
w1 = torch.exp(self.w[0]) / torch.sum(torch.exp(self.w))
w2 = torch.exp(self.w[1]) / torch.sum(torch.exp(self.w))
attn = attn0 * w1 + attn1 * w2
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
x = x.reshape(x.shape[0], int(math.sqrt(x.shape[1])), int(math.sqrt(x.shape[1])), x.shape[2]).permute(0, 3, 1, 2)
return x
def extra_repr(self) -> str:
return f'dim={self.dim}, win_size={self.win_size}, num_heads={self.num_heads}'
FRFN模块
import torch
import torch.nn as nn
from einops import rearrange
class FRFN(nn.Module):
def __init__(self, dim=32, hidden_dim=128, act_layer=nn.GELU, drop=0., use_eca=False):
super().__init__()
self.linear1 = nn.Sequential(nn.Linear(dim, hidden_dim * 2), act_layer())
self.dwconv = nn.Sequential(
nn.Conv2d(hidden_dim, hidden_dim, groups=hidden_dim, kernel_size=3, stride=1, padding=1),
act_layer()
)
self.linear2 = nn.Sequential(nn.Linear(hidden_dim, dim))
self.dim = dim
self.hidden_dim = hidden_dim
self.dim_conv = self.dim // 4
self.dim_untouched = self.dim - self.dim_conv
self.partial_conv3 = nn.Conv2d(self.dim_conv, self.dim_conv, 3, 1, 1, bias=False)
def forward(self, x):
c, bs, hh, hw = x.size()
x1, x2 = torch.split(x, [self.dim_conv, self.dim_untouched], dim=1)
x1 = self.partial_conv3(x1)
x = torch.cat((x1, x2), 1)
x = rearrange(x, 'b c h w -> b (h w) c', h=hh, w=hw)
x = self.linear1(x)
x_1, x_2 = x.chunk(2, dim=-1)
x_1 = rearrange(x_1, 'b (h w) c -> b c h w', h=hh, w=hw)
x_1 = self.dwconv(x_1)
x_1 = rearrange(x_1, 'b c h w -> b (h w) c', h=hh, w=hw)
x = x_1 * x_2
x = self.linear2(x)
return rearrange(x, 'b (h w) c -> b c h w', h=hh, w=hw)
确保文章为个人原创,未在任何公开渠道发布。若文章已在其他平台发表或即将发表,请明确说明。
建议使用Markdown格式撰写稿件,并以附件形式发送清晰、无版权争议的配图。
【AI前沿速递】尊重作者的署名权,并为每篇被采纳的原创首发稿件提供具有市场竞争力的稿酬。稿酬将根据文章的阅读量和质量进行阶梯式结算。
您可以通过添加我们的小助理微信(aiqysd)进行快速投稿。请在添加时备注“投稿-姓名-学校-研究方向”
长按添加AI前沿速递小助理