CVPR 2024 | 最新即插即用注意力机制+ffn层

文摘   2025-01-14 11:13   安徽  

点击下方卡片,关注“AI前沿速递”公众号

各种重磅干货,第一时间送达


标题:Adapt or Perish: Adaptive Sparse Transformer with Attentive Feature Refinement for Image Restoration

论文链接:https://openaccess.thecvf.com/content/CVPR2024/papers/Zhou_Adapt_or_Perish_Adaptive_Sparse_Transformer_with_Attentive_Feature_Refinement_CVPR_2024_paper.pdf

代码链接:https://github.com/joshyZhou/AST

来源:CVPR 2024

ASSA模块





定义与结构

- 双分支模式:ASSA模块采用双分支模式,包括稀疏自注意力分支(SSA)和密集自注意力分支(DSA)。SSA用于过滤掉低查询-键匹配分数的负面影响,而DSA则确保足够的信息流通过网络,以学习判别性表示。- 自适应加权:通过自适应的加权机制,将SSA和DSA的输出进行融合。这种设计使模型能够动态调整稀疏与密集注意力的权重,从而根据具体的任务和输入内容有效地平衡信息流,既能过滤掉无关特征,又保留必要的信息。

工作原理

- 稀疏自注意力(SSA):使用基于ReLU的稀疏注意力机制,过滤掉查询与键之间低匹配的无关交互,减少无效特征的参与,帮助聚焦在最有价值的信息交互上。- 密集自注意力(DSA):采用标准的softmax密集注意力机制,补充SSA,以确保在稀疏处理过程中不会丢失关键信息。

应用场景

- 图像恢复任务:ASSA最早应用于图像恢复任务,通过减少噪声交互并保留重要的特征信息,显著提升了模型的处理效率。- 目标检测:在YOLOv11模型中引入ASSA机制,可以优化特征提取过程,减少特征冗余或噪声干扰,进一步提升模型对复杂场景的适应性和检测性能。- 时间序列预测:将ASSA机制和LSTM处理后的特征输入到Transformer网络进行预测,可以进一步提高预测的准确性和效率,在顶刊ETTh开源数据集达到了不错的效果- 医疗影像处理:对于需要从复杂的医学图像中提取关键特征的任务,如癌症检测、CT或MRI图像分析,ASSA能够有效过滤无用信息,提升对病灶区域的关注和检测精度。

FRFN模块

定义与结构

特征细化前馈网络(Feature Refinement Feed-forward Network, FRFN)是一种专门设计的深度学习结构,旨在提高图像处理任务中的特征表示能力。其核心设计理念是通过逐层细化和优化特征图,从而实现更高的分类和检测精度。

- 线性层1:将输入特征维度扩展到隐藏维度的两倍,并通过激活函数进行非线性变换。- 深度可分离卷积:对扩展后的特征进行深度可分离卷积操作,进一步提取局部特征。- 线性层2:将特征维度压缩回原始维度。- 部分卷积:对部分特征通道进行卷积操作,以增强特征中的有用元素。- 门控机制:通过门控机制减少冗余信息的处理负担,提升特征的纯净度。

工作原理

FRFN模块的工作原理可以概括为以下几个步骤:- 特征扩展:通过线性层将输入特征的维度扩展到隐藏维度的两倍,增加特征的表达能力。- 深度可分离卷积:对扩展后的特征进行深度可分离卷积操作,提取局部特征,同时减少计算量。- 特征压缩:通过线性层将特征维度压缩回原始维度,减少特征冗余。- 部分卷积:对部分特征通道进行卷积操作,增强特征中的有用元素。- 门控机制:通过门控机制减少冗余信息的处理负担,提升特征的纯净度。

应用场景

- 图像恢复任务:在去噪、去雨滴、去雾、超分辨率等场景中,FRFN能够有效减少通道维度上的冗余信息,提升重要特征的表达,从而提高恢复图像的质量和细节还原能力。

- 图像分类和检测任务:在处理复杂图像时,FRFN可以通过精炼和增强有价值的特征,帮助模型更准确地分类或检测目标,特别是在多类或高维度特征的任务中表现出色。

- 高分辨率图像处理:在高分辨率图像或视频处理中,FRFN能够减少不必要的信息流,增强重要特征的表达,使模型更高效地处理大规模图像数据。

- 医学图像分析:在处理复杂的医学影像时,FRFN有助于减少噪声和干扰,聚焦于病变区域的关键特征,提升医疗影像分析的精度和效率。

集成到YOLOv11和RT-DETR

将FRFN模块集成到YOLOv11和RT-DETR模型中的步骤如下:

  1. 创建脚本文件:在ultralytics->nn路径下创建blocks.py脚本,用于存放模块代码。
  2. 复制代码:将上述FRFN模块的代码复制到blocks.py脚本中。
  3. 更改task.py文件:在ultralytics->nn->modules->task.py中导入FRFN模块。
  4. 修改模型配置:在模型配置文件中添加FRFN模块的配置。
  5. 训练模型:创建训练脚本,使用修改后的模型配置进行训练。

代码实现

ASSA模块代码

import torchimport torch.nn as nnfrom timm.models.layers import trunc_normal_from einops import repeat
class LinearProjection(nn.Module): def __init__(self, dim, heads=8, dim_head=64, bias=True): super().__init__() inner_dim = dim_head * heads self.heads = heads self.to_q = nn.Linear(dim, inner_dim, bias=bias) self.to_kv = nn.Linear(dim, inner_dim * 2, bias=bias) self.dim = dim self.inner_dim = inner_dim
def forward(self, x, attn_kv=None): B_, N, C = x.shape if attn_kv is not None: attn_kv = attn_kv.unsqueeze(0).repeat(B_, 1, 1) else: attn_kv = x N_kv = attn_kv.size(1) q = self.to_q(x).reshape(B_, N, 1, self.heads, C // self.heads).permute(2, 0, 3, 1, 4) kv = self.to_kv(attn_kv).reshape(B_, N_kv, 2, self.heads, C // self.heads).permute(2, 0, 3, 1, 4) q = q[0] k, v = kv[0], kv[1] return q, k, v
class WindowAttention_sparse(nn.Module): def __init__(self, dim, win_size, num_heads=8, token_projection='linear', qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): super().__init__() self.dim = dim self.win_size = win_size # Wh, Ww self.num_heads = num_heads head_dim = dim // num_heads self.scale = qk_scale or head_dim ** -0.5 # define a parameter table of relative position bias self.relative_position_bias_table = nn.Parameter( torch.zeros((2 * win_size[0] - 1) * (2 * win_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH # get pair-wise relative position index for each token inside the window coords_h = torch.arange(self.win_size[0]) # [0,...,Wh-1] coords_w = torch.arange(self.win_size[1]) # [0,...,Ww-1] coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing='ij')) # 2, Wh, Ww coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 relative_coords[:, :, 0] += self.win_size[0] - 1 # shift to start from 0 relative_coords[:, :, 1] += self.win_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.win_size[1] - 1 relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww self.register_buffer("relative_position_index", relative_position_index) trunc_normal_(self.relative_position_bias_table, std=.02) if token_projection == 'linear': self.qkv = LinearProjection(dim, num_heads, dim // num_heads, bias=qkv_bias) else: raise Exception("Projection error!") self.token_projection = token_projection self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) self.softmax = nn.Softmax(dim=-1) self.relu = nn.ReLU() self.w = nn.Parameter(torch.ones(2)) # 自适应权重参数
def forward(self, x, attn_kv=None, mask=None): # 调整输入维度,从 (B, C, H, W) 转为 (B, H, W, C) x = x.permute(0, 2, 3, 1).reshape(x.shape[0], x.shape[2] * x.shape[3], x.shape[1]) B_, N, C = x.shape q, k, v = self.qkv(x, attn_kv) q = q * self.scale attn = (q @ k.transpose(-2, -1)) relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( self.win_size[0] * self.win_size[1], self.win_size[0] * self.win_size[1], -1) # Wh*Ww,Wh*Ww,nH relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww ratio = attn.size(-1) // relative_position_bias.size(-1) relative_position_bias = repeat(relative_position_bias, 'nH l c -> nH l (c d)', d=ratio) attn = attn + relative_position_bias.unsqueeze(0) if mask is not None: nW = mask.shape[0] mask = repeat(mask, 'nW m n -> nW m (n d)', d=ratio) attn = attn.view(B_ // nW, nW, self.num_heads, N, N * ratio) + mask.unsqueeze(1).unsqueeze(0) attn = attn.view(-1, self.num_heads, N, N * ratio) attn0 = self.softmax(attn) attn1 = self.relu(attn) ** 2 # b,h,w,c w1 = torch.exp(self.w[0]) / torch.sum(torch.exp(self.w)) w2 = torch.exp(self.w[1]) / torch.sum(torch.exp(self.w)) attn = attn0 * w1 + attn1 * w2 attn = self.attn_drop(attn) x = (attn @ v).transpose(1, 2).reshape(B_, N, C) x = self.proj(x) x = self.proj_drop(x) x = x.reshape(x.shape[0], int(math.sqrt(x.shape[1])), int(math.sqrt(x.shape[1])), x.shape[2]).permute(0, 3, 1, 2) return x
def extra_repr(self) -> str: return f'dim={self.dim}, win_size={self.win_size}, num_heads={self.num_heads}'

FRFN模块

import torchimport torch.nn as nnfrom einops import rearrange
class FRFN(nn.Module): def __init__(self, dim=32, hidden_dim=128, act_layer=nn.GELU, drop=0., use_eca=False): super().__init__() self.linear1 = nn.Sequential(nn.Linear(dim, hidden_dim * 2), act_layer()) self.dwconv = nn.Sequential( nn.Conv2d(hidden_dim, hidden_dim, groups=hidden_dim, kernel_size=3, stride=1, padding=1), act_layer() ) self.linear2 = nn.Sequential(nn.Linear(hidden_dim, dim)) self.dim = dim self.hidden_dim = hidden_dim self.dim_conv = self.dim // 4 self.dim_untouched = self.dim - self.dim_conv self.partial_conv3 = nn.Conv2d(self.dim_conv, self.dim_conv, 3, 1, 1, bias=False)
def forward(self, x): c, bs, hh, hw = x.size() x1, x2 = torch.split(x, [self.dim_conv, self.dim_untouched], dim=1) x1 = self.partial_conv3(x1) x = torch.cat((x1, x2), 1) x = rearrange(x, 'b c h w -> b (h w) c', h=hh, w=hw) x = self.linear1(x) x_1, x_2 = x.chunk(2, dim=-1) x_1 = rearrange(x_1, 'b (h w) c -> b c h w', h=hh, w=hw) x_1 = self.dwconv(x_1) x_1 = rearrange(x_1, 'b c h w -> b (h w) c', h=hh, w=hw) x = x_1 * x_2 x = self.linear2(x) return rearrange(x, 'b (h w) c -> b c h w', h=hh, w=hw)

本文内容为论文学习收获分享,受限于知识能力,本文对原文的理解可能存在偏差,最终内容以原论文为准。本文信息旨在传播和学术交流,其内容由作者负责,不代表本号观点。文中作品文字、图片等如涉及内容、版权和其他问题,请及时与我们联系,我们将在第一时间回复并处理。


欢迎投稿

想要让高质量的内容更快地触达读者,降低他们寻找优质信息的成本吗?关键在于那些你尚未结识的人。他们可能掌握着你渴望了解的知识。【AI前沿速递】愿意成为这样的一座桥梁,连接不同领域、不同背景的学者,让他们的学术灵感相互碰撞,激发出无限可能。

【AI前沿速递】欢迎各高校实验室和个人在我们的平台上分享各类精彩内容,无论是最新的论文解读,还是对学术热点的深入分析,或是科研心得和竞赛经验的分享,我们的目标只有一个:让知识自由流动。

📝 投稿指南

  • 确保文章为个人原创,未在任何公开渠道发布。若文章已在其他平台发表或即将发表,请明确说明。

  • 建议使用Markdown格式撰写稿件,并以附件形式发送清晰、无版权争议的配图。

  • 【AI前沿速递】尊重作者的署名权,并为每篇被采纳的原创首发稿件提供具有市场竞争力的稿酬。稿酬将根据文章的阅读量和质量进行阶梯式结算。

📬 投稿方式

  • 您可以通过添加我们的小助理微信(aiqysd)进行快速投稿。请在添加时备注“投稿-姓名-学校-研究方向”


    长按添加AI前沿速递小助理


AI前沿速递
持续分享最新AI前沿论文成果
 最新文章