论文介绍
题目:Efficient Attention: Attention with Linear Complexities
论文地址:https://openaccess.thecvf.com/content/CVPR2022/papers/Xia_Vision_Transformer_With_Deformable_Attention_CVPR_2022_paper.pdf
QQ深度学习交流群:994264161
扫描下方二维码,加入深度学习论文指南星球!
加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务
创新点
提出可变形自注意力模块(Deformable Self-Attention Module):
传统的视觉Transformer模型(例如ViT和Swin Transformer)要么使用全局注意力,导致计算开销过大,要么使用固定的稀疏注意力模式,限制了建模长距离关系的能力。
本文提出了基于数据的可变形注意力机制,通过动态选择关键和数值对的位置,使注意力模块能够专注于相关区域,提升了特征捕捉能力和模型的灵活性。
构建可变形注意力Transformer主干模型(Deformable Attention Transformer, DAT):
通过整合可变形注意力模块,设计了一个通用的主干模型,适用于图像分类和密集预测任务。
其特点是灵活建模关键特征,同时保持线性空间复杂度。
高效的注意力机制设计:
论文通过共享的采样偏移组而不是每个查询单独计算偏移,显著降低了可变形注意力的计算复杂度,使其适合作为主干模型。
引入了一种变形的相对位置偏置机制(Deformable Relative Position Bias),进一步增强了注意力模块的表现。
在多个基准数据集上的出色表现:
在ImageNet-1K上,与Swin Transformer相比,其模型在分类任务的Top-1准确率上提升了0.7%。
在COCO对象检测任务中,对于不同大小的目标,该模型表现出更显著的优势,尤其是对于大型目标的检测提升了多达2.1个百分点。在ADE20K语义分割数据集上,mIoU的提升幅度在多个模型规模下均超过了1%。
方法
整体架构
这篇论文提出了一种分层结构的视觉Transformer模型,称为Deformable Attention Transformer (DAT),通过引入可变形注意力模块,在后两个阶段(Stage 3 和 Stage 4)灵活捕捉全局关系,同时结合前两个阶段的本地注意力机制,逐层提取多尺度特征。模型的整体设计交替使用局部和全局注意力模块,专注于重要区域,实现高效且灵活的特征建模,适用于图像分类、目标检测和语义分割等任务。
1. 模型的分层结构
论文的模型采用了与Swin Transformer类似的金字塔结构,分为四个阶段(Stages)。输入图像首先被划分为固定大小的patch嵌入,然后逐层处理,特征图的分辨率逐步降低,通道数逐步增加,以提取多尺度特征。各阶段的关键设计如下:
Stage 1 和 Stage 2:
使用**本地注意力模块(Local Attention Module)和Shift-Window Attention(移窗注意力)**进行特征学习。
这两个阶段主要捕捉局部特征。
Stage 3 和 Stage 4:
引入可变形注意力模块(Deformable Attention Module),替换原有的Shift-Window Attention。
可变形注意力模块使模型能够捕捉全局关系,并在局部增强的特征基础上进一步提取有意义的全局特征。
2. 可变形注意力模块的设计
核心机制:通过可变形的采样点,灵活选择关键点和数值点,专注于重要区域以增强注意力效果。
具体实现:
首先,生成均匀分布的参考点网格。
然后,通过偏移网络(Offset Network)生成每个参考点的偏移值。
最后,利用偏移点计算新的键(Key)和值(Value),并通过多头注意力机制进行特征聚合。
偏移点的位置还引入了一种变形的相对位置偏置,以进一步增强注意力效果。
3. 模块交替设计
在后两个阶段(Stage 3 和 Stage 4),采用交替的注意力模块设计:
本地注意力模块负责局部特征的聚合。
可变形注意力模块则用来捕捉全局特征。 这种设计结合了局部和全局的感受野,有助于模型学习更强的多尺度特征。
4. Patch 嵌入模块
输入图像首先被分割为大小为
5. 模型的具体变体
论文设计了三个变体(DAT-T、DAT-S 和 DAT-B),主要区别在于:
每一阶段的通道数(C)和堆叠的注意力模块数量(N)。
可变形注意力模块的头数和偏移组数(Offset Groups)。
即插即用模块作用
DAT 作为一个即插即用模块:
图像分类(Image Classification):
DAT作为主干网络的一部分,能够有效捕捉多尺度特征。
在大规模图像分类任务(如ImageNet-1K)中,DAT通过灵活的注意力机制提高了分类准确率。
目标检测(Object Detection):
作为目标检测模型(如RetinaNet、Mask R-CNN和Cascade Mask R-CNN)的主干网络,DAT能够有效建模对象的长距离依赖关系。
在COCO目标检测任务中,DAT对小目标和大目标的检测性能均有显著提升,尤其对大目标的检测提升尤为明显(最高提升+2.1 mAP)。
语义分割(Semantic Segmentation):
在语义分割任务(如ADE20K数据集)中,DAT被用于细粒度的特征分割建模。
通过灵活的注意力模式,DAT能够更好地识别复杂场景中的小物体和局部细节,提高分割的mIoU和mAcc。
密集预测任务(Dense Prediction Tasks):
DAT特别适用于需要多尺度建模的密集预测任务,如实例分割和场景理解任务。
它通过捕捉全局和局部关系,提高了对小物体和场景复杂区域的建模能力。
消融实验结果
内容:评估不同几何信息(如偏移和位置嵌入)的使用方式对模型性能的影响。实验包括以下几种配置:
是否使用偏移点(Offsets)。
是否使用相对位置嵌入(Relative Position Embedding)。
相比固定的偏置或深度卷积位置嵌入,论文提出的变形相对位置嵌入的效果。
结果:引入偏移点和变形相对位置嵌入分别带来了性能提升(+0.3%)。两者结合时,表现最佳,证明了它们在可变形注意力中的兼容性和有效性。
内容:探讨在不同的网络阶段引入可变形注意力的效果。实验逐步将可变形注意力模块替换掉原有的Shift-Window Attention模块,并观察性能变化。
结果:
仅在最后一个阶段使用可变形注意力,性能小幅提升(+0.1%)。
在最后两个阶段使用可变形注意力时,性能大幅提升(+0.7%),达到最佳。
在所有阶段都使用可变形注意力时,性能略有下降,表明早期阶段引入全局注意力可能会干扰局部特征学习。
结论:可变形注意力更适合后期阶段,用于捕捉全局关系,而早期阶段更适合局部建模。
即插即用模块
# 论文题目:Efficient Attention: Attention with Linear Complexities
# 论文链接:https://openaccess.thecvf.com/content/CVPR2022/papers/Xia_Vision_Transformer_With_Deformable_Attention_CVPR_2022_paper.pdf
import torch, einops
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from timm.models.layers import trunc_normal_
class LayerNormProxy(nn.Module):
def __init__(self, dim):
super().__init__()
self.norm = nn.LayerNorm(dim)
def forward(self, x):
x = einops.rearrange(x, 'b c h w -> b h w c')
x = self.norm(x)
return einops.rearrange(x, 'b h w c -> b c h w')
class DAttention(nn.Module):
# Vision Transformer with Deformable Attention CVPR2022
# fixed_pe=True need adujust 640x640
def __init__(
self, channel, q_size, n_heads=8, n_groups=4,
attn_drop=0.0, proj_drop=0.0, stride=1,
offset_range_factor=4, use_pe=True, dwc_pe=True,
no_off=False, fixed_pe=False, ksize=3, log_cpb=False, kv_size=None
):
super().__init__()
n_head_channels = channel // n_heads
self.dwc_pe = dwc_pe
self.n_head_channels = n_head_channels
self.scale = self.n_head_channels ** -0.5
self.n_heads = n_heads
self.q_h, self.q_w = q_size
# self.kv_h, self.kv_w = kv_size
self.kv_h, self.kv_w = self.q_h // stride, self.q_w // stride
self.nc = n_head_channels * n_heads
self.n_groups = n_groups
self.n_group_channels = self.nc // self.n_groups
self.n_group_heads = self.n_heads // self.n_groups
self.use_pe = use_pe
self.fixed_pe = fixed_pe
self.no_off = no_off
self.offset_range_factor = offset_range_factor
self.ksize = ksize
self.log_cpb = log_cpb
self.stride = stride
kk = self.ksize
pad_size = kk // 2 if kk != stride else 0
self.conv_offset = nn.Sequential(
nn.Conv2d(self.n_group_channels, self.n_group_channels, kk, stride, pad_size, groups=self.n_group_channels),
LayerNormProxy(self.n_group_channels),
nn.GELU(),
nn.Conv2d(self.n_group_channels, 2, 1, 1, 0, bias=False)
)
if self.no_off:
for m in self.conv_offset.parameters():
m.requires_grad_(False)
self.proj_q = nn.Conv2d(
self.nc, self.nc,
kernel_size=1, stride=1, padding=0
)
self.proj_k = nn.Conv2d(
self.nc, self.nc,
kernel_size=1, stride=1, padding=0
)
self.proj_v = nn.Conv2d(
self.nc, self.nc,
kernel_size=1, stride=1, padding=0
)
self.proj_out = nn.Conv2d(
self.nc, self.nc,
kernel_size=1, stride=1, padding=0
)
self.proj_drop = nn.Dropout(proj_drop, inplace=True)
self.attn_drop = nn.Dropout(attn_drop, inplace=True)
if self.use_pe and not self.no_off:
if self.dwc_pe:
self.rpe_table = nn.Conv2d(
self.nc, self.nc, kernel_size=3, stride=1, padding=1, groups=self.nc)
elif self.fixed_pe:
self.rpe_table = nn.Parameter(
torch.zeros(self.n_heads, self.q_h * self.q_w, self.kv_h * self.kv_w)
)
trunc_normal_(self.rpe_table, std=0.01)
elif self.log_cpb:
# Borrowed from Swin-V2
self.rpe_table = nn.Sequential(
nn.Linear(2, 32, bias=True),
nn.ReLU(inplace=True),
nn.Linear(32, self.n_group_heads, bias=False)
)
else:
self.rpe_table = nn.Parameter(
torch.zeros(self.n_heads, self.q_h * 2 - 1, self.q_w * 2 - 1)
)
trunc_normal_(self.rpe_table, std=0.01)
else:
self.rpe_table = None
@torch.no_grad()
def _get_ref_points(self, H_key, W_key, B, dtype, device):
ref_y, ref_x = torch.meshgrid(
torch.linspace(0.5, H_key - 0.5, H_key, dtype=dtype, device=device),
torch.linspace(0.5, W_key - 0.5, W_key, dtype=dtype, device=device),
indexing='ij'
)
ref = torch.stack((ref_y, ref_x), -1)
ref[..., 1].div_(W_key - 1.0).mul_(2.0).sub_(1.0)
ref[..., 0].div_(H_key - 1.0).mul_(2.0).sub_(1.0)
ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1) # B * g H W 2
return ref
@torch.no_grad()
def _get_q_grid(self, H, W, B, dtype, device):
ref_y, ref_x = torch.meshgrid(
torch.arange(0, H, dtype=dtype, device=device),
torch.arange(0, W, dtype=dtype, device=device),
indexing='ij'
)
ref = torch.stack((ref_y, ref_x), -1)
ref[..., 1].div_(W - 1.0).mul_(2.0).sub_(1.0)
ref[..., 0].div_(H - 1.0).mul_(2.0).sub_(1.0)
ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1) # B * g H W 2
return ref
def forward(self, x):
B, C, H, W = x.size()
dtype, device = x.dtype, x.device
q = self.proj_q(x)
q_off = einops.rearrange(q, 'b (g c) h w -> (b g) c h w', g=self.n_groups, c=self.n_group_channels)
offset = self.conv_offset(q_off).contiguous() # B * g 2 Hg Wg
Hk, Wk = offset.size(2), offset.size(3)
n_sample = Hk * Wk
if self.offset_range_factor >= 0 and not self.no_off:
offset_range = torch.tensor([1.0 / (Hk - 1.0), 1.0 / (Wk - 1.0)], device=device).reshape(1, 2, 1, 1)
offset = offset.tanh().mul(offset_range).mul(self.offset_range_factor)
offset = einops.rearrange(offset, 'b p h w -> b h w p')
reference = self._get_ref_points(Hk, Wk, B, dtype, device)
if self.no_off:
offset = offset.fill_(0.0)
if self.offset_range_factor >= 0:
pos = offset + reference
else:
pos = (offset + reference).clamp(-1., +1.)
if self.no_off:
x_sampled = F.avg_pool2d(x, kernel_size=self.stride, stride=self.stride)
assert x_sampled.size(2) == Hk and x_sampled.size(3) == Wk, f"Size is {x_sampled.size()}"
else:
pos = pos.type(x.dtype)
x_sampled = F.grid_sample(
input=x.reshape(B * self.n_groups, self.n_group_channels, H, W),
grid=pos[..., (1, 0)], # y, x -> x, y
mode='bilinear', align_corners=True) # B * g, Cg, Hg, Wg
x_sampled = x_sampled.reshape(B, C, 1, n_sample)
q = q.reshape(B * self.n_heads, self.n_head_channels, H * W)
k = self.proj_k(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)
v = self.proj_v(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)
attn = torch.einsum('b c m, b c n -> b m n', q, k) # B * h, HW, Ns
attn = attn.mul(self.scale)
if self.use_pe and (not self.no_off):
if self.dwc_pe:
residual_lepe = self.rpe_table(q.reshape(B, C, H, W)).reshape(B * self.n_heads, self.n_head_channels, H * W)
elif self.fixed_pe:
rpe_table = self.rpe_table
attn_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
attn = attn + attn_bias.reshape(B * self.n_heads, H * W, n_sample)
elif self.log_cpb:
q_grid = self._get_q_grid(H, W, B, dtype, device)
displacement = (q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) - pos.reshape(B * self.n_groups, n_sample, 2).unsqueeze(1)).mul(4.0) # d_y, d_x [-8, +8]
displacement = torch.sign(displacement) * torch.log2(torch.abs(displacement) + 1.0) / np.log2(8.0)
attn_bias = self.rpe_table(displacement) # B * g, H * W, n_sample, h_g
attn = attn + einops.rearrange(attn_bias, 'b m n h -> (b h) m n', h=self.n_group_heads)
else:
rpe_table = self.rpe_table
rpe_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
q_grid = self._get_q_grid(H, W, B, dtype, device)
displacement = (q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) - pos.reshape(B * self.n_groups, n_sample, 2).unsqueeze(1)).mul(0.5)
attn_bias = F.grid_sample(
input=einops.rearrange(rpe_bias, 'b (g c) h w -> (b g) c h w', c=self.n_group_heads, g=self.n_groups),
grid=displacement[..., (1, 0)],
mode='bilinear', align_corners=True) # B * g, h_g, HW, Ns
attn_bias = attn_bias.reshape(B * self.n_heads, H * W, n_sample)
attn = attn + attn_bias
attn = F.softmax(attn, dim=2)
attn = self.attn_drop(attn)
out = torch.einsum('b m n, b c n -> b c m', attn, v)
if self.use_pe and self.dwc_pe:
out = out + residual_lepe
out = out.reshape(B, C, H, W)
y = self.proj_drop(self.proj_out(out))
return y
if __name__ == '__main__':
# 设置模型超参数
channel = 64
q_size = (32, 32) # 假设查询大小为 32x32
n_heads = 8 # 8 个注意力头
n_groups = 4 # 分成 4 组
stride = 1 # 卷积步长为 1
# 创建 DAttention 模块实例
model = DAttention(
channel=channel, q_size=q_size, n_heads=n_heads, n_groups=n_groups, stride=stride
)
batch_size = 4 # 假设批次大小为 4
height, width = 64, 64 # 假设输入图像的尺寸为 64x64
input = torch.randn(batch_size, channel, height, width)
output = model(input)
print(input.shape) print(output.shape)
便捷下载方式
浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules
更多分析可见原文