CVPR | DAT:高效灵活的可变形注意力Transformer,为视觉任务带来新突破

文摘   2025-01-07 17:20   上海  

论文介绍

题目:Efficient Attention: Attention with Linear Complexities

论文地址:https://openaccess.thecvf.com/content/CVPR2022/papers/Xia_Vision_Transformer_With_Deformable_Attention_CVPR_2022_paper.pdf

QQ深度学习交流群:994264161

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • 提出可变形自注意力模块(Deformable Self-Attention Module)

    • 传统的视觉Transformer模型(例如ViT和Swin Transformer)要么使用全局注意力,导致计算开销过大,要么使用固定的稀疏注意力模式,限制了建模长距离关系的能力。

    • 本文提出了基于数据的可变形注意力机制,通过动态选择关键和数值对的位置,使注意力模块能够专注于相关区域,提升了特征捕捉能力和模型的灵活性。

  • 构建可变形注意力Transformer主干模型(Deformable Attention Transformer, DAT)

    • 通过整合可变形注意力模块,设计了一个通用的主干模型,适用于图像分类和密集预测任务。

    • 其特点是灵活建模关键特征,同时保持线性空间复杂度。

  • 高效的注意力机制设计

    • 论文通过共享的采样偏移组而不是每个查询单独计算偏移,显著降低了可变形注意力的计算复杂度,使其适合作为主干模型。

    • 引入了一种变形的相对位置偏置机制(Deformable Relative Position Bias),进一步增强了注意力模块的表现。

  • 在多个基准数据集上的出色表现

    • 在ImageNet-1K上,与Swin Transformer相比,其模型在分类任务的Top-1准确率上提升了0.7%。

    • 在COCO对象检测任务中,对于不同大小的目标,该模型表现出更显著的优势,尤其是对于大型目标的检测提升了多达2.1个百分点。在ADE20K语义分割数据集上,mIoU的提升幅度在多个模型规模下均超过了1%。

方法

整体架构

     这篇论文提出了一种分层结构的视觉Transformer模型,称为Deformable Attention Transformer (DAT),通过引入可变形注意力模块,在后两个阶段(Stage 3 和 Stage 4)灵活捕捉全局关系,同时结合前两个阶段的本地注意力机制,逐层提取多尺度特征。模型的整体设计交替使用局部和全局注意力模块,专注于重要区域,实现高效且灵活的特征建模,适用于图像分类、目标检测和语义分割等任务。

1. 模型的分层结构

论文的模型采用了与Swin Transformer类似的金字塔结构,分为四个阶段(Stages)。输入图像首先被划分为固定大小的patch嵌入,然后逐层处理,特征图的分辨率逐步降低,通道数逐步增加,以提取多尺度特征。各阶段的关键设计如下:

  • Stage 1 和 Stage 2

    • 使用**本地注意力模块(Local Attention Module)Shift-Window Attention(移窗注意力)**进行特征学习。

    • 这两个阶段主要捕捉局部特征。

  • Stage 3 和 Stage 4

    • 引入可变形注意力模块(Deformable Attention Module),替换原有的Shift-Window Attention。

    • 可变形注意力模块使模型能够捕捉全局关系,并在局部增强的特征基础上进一步提取有意义的全局特征。


2. 可变形注意力模块的设计

  • 核心机制:通过可变形的采样点,灵活选择关键点和数值点,专注于重要区域以增强注意力效果。

  • 具体实现

    • 首先,生成均匀分布的参考点网格。

    • 然后,通过偏移网络(Offset Network)生成每个参考点的偏移值。

    • 最后,利用偏移点计算新的键(Key)和值(Value),并通过多头注意力机制进行特征聚合。

    • 偏移点的位置还引入了一种变形的相对位置偏置,以进一步增强注意力效果。


3. 模块交替设计

在后两个阶段(Stage 3 和 Stage 4),采用交替的注意力模块设计

  • 本地注意力模块负责局部特征的聚合。

  • 可变形注意力模块则用来捕捉全局特征。 这种设计结合了局部和全局的感受野,有助于模型学习更强的多尺度特征。


4. Patch 嵌入模块

输入图像首先被分割为大小为4×44\times4 的patch,经过一个非重叠的卷积操作进行嵌入。卷积核的大小为4×44\times4,步幅为4,从而将输入的H×W×3H \times W \times 3 的图像转换为H/4×W/4×CH/4 \times W/4 \times C 的特征图。


5. 模型的具体变体

论文设计了三个变体(DAT-T、DAT-S 和 DAT-B),主要区别在于:

  • 每一阶段的通道数(C)和堆叠的注意力模块数量(N)。

  • 可变形注意力模块的头数和偏移组数(Offset Groups)。

即插即用模块作用

DAT 作为一个即插即用模块

  • 图像分类(Image Classification)

    • DAT作为主干网络的一部分,能够有效捕捉多尺度特征。

    • 在大规模图像分类任务(如ImageNet-1K)中,DAT通过灵活的注意力机制提高了分类准确率。

  • 目标检测(Object Detection)

    • 作为目标检测模型(如RetinaNet、Mask R-CNN和Cascade Mask R-CNN)的主干网络,DAT能够有效建模对象的长距离依赖关系。

    • 在COCO目标检测任务中,DAT对小目标和大目标的检测性能均有显著提升,尤其对大目标的检测提升尤为明显(最高提升+2.1 mAP)。

  • 语义分割(Semantic Segmentation)

    • 在语义分割任务(如ADE20K数据集)中,DAT被用于细粒度的特征分割建模。

    • 通过灵活的注意力模式,DAT能够更好地识别复杂场景中的小物体和局部细节,提高分割的mIoU和mAcc。

  • 密集预测任务(Dense Prediction Tasks)

    • DAT特别适用于需要多尺度建模的密集预测任务,如实例分割和场景理解任务。

      它通过捕捉全局和局部关系,提高了对小物体和场景复杂区域的建模能力。

消融实验结果

  • 内容:评估不同几何信息(如偏移和位置嵌入)的使用方式对模型性能的影响。实验包括以下几种配置:

    • 是否使用偏移点(Offsets)。

    • 是否使用相对位置嵌入(Relative Position Embedding)。

    • 相比固定的偏置或深度卷积位置嵌入,论文提出的变形相对位置嵌入的效果。

  • 结果:引入偏移点和变形相对位置嵌入分别带来了性能提升(+0.3%)。两者结合时,表现最佳,证明了它们在可变形注意力中的兼容性和有效性。


  • 内容:探讨在不同的网络阶段引入可变形注意力的效果。实验逐步将可变形注意力模块替换掉原有的Shift-Window Attention模块,并观察性能变化。

  • 结果

    • 仅在最后一个阶段使用可变形注意力,性能小幅提升(+0.1%)。

    • 在最后两个阶段使用可变形注意力时,性能大幅提升(+0.7%),达到最佳。

    • 在所有阶段都使用可变形注意力时,性能略有下降,表明早期阶段引入全局注意力可能会干扰局部特征学习。

  • 结论:可变形注意力更适合后期阶段,用于捕捉全局关系,而早期阶段更适合局部建模。

即插即用模块

# 论文题目:Efficient Attention: Attention with Linear Complexities
# 论文链接:https://openaccess.thecvf.com/content/CVPR2022/papers/Xia_Vision_Transformer_With_Deformable_Attention_CVPR_2022_paper.pdf

import torch, einops
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from timm.models.layers import trunc_normal_

class LayerNormProxy(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        x = einops.rearrange(x, 'b c h w -> b h w c')
        x = self.norm(x)
        return einops.rearrange(x, 'b h w c -> b c h w')

class DAttention(nn.Module):
    # Vision Transformer with Deformable Attention CVPR2022
    # fixed_pe=True need adujust 640x640
    def __init__(
        self, channel, q_size, n_heads=8, n_groups=4,
        attn_drop=0.0, proj_drop=0.0, stride=1,
        offset_range_factor=4, use_pe=True, dwc_pe=True,
        no_off=False, fixed_pe=False, ksize=3, log_cpb=False, kv_size=None
    )
:

        super().__init__()
        n_head_channels = channel // n_heads
        self.dwc_pe = dwc_pe
        self.n_head_channels = n_head_channels
        self.scale = self.n_head_channels ** -0.5
        self.n_heads = n_heads
        self.q_h, self.q_w = q_size
        # self.kv_h, self.kv_w = kv_size
        self.kv_h, self.kv_w = self.q_h // stride, self.q_w // stride
        self.nc = n_head_channels * n_heads
        self.n_groups = n_groups
        self.n_group_channels = self.nc // self.n_groups
        self.n_group_heads = self.n_heads // self.n_groups
        self.use_pe = use_pe
        self.fixed_pe = fixed_pe
        self.no_off = no_off
        self.offset_range_factor = offset_range_factor
        self.ksize = ksize
        self.log_cpb = log_cpb
        self.stride = stride
        kk = self.ksize
        pad_size = kk // 2 if kk != stride else 0

        self.conv_offset = nn.Sequential(
            nn.Conv2d(self.n_group_channels, self.n_group_channels, kk, stride, pad_size, groups=self.n_group_channels),
            LayerNormProxy(self.n_group_channels),
            nn.GELU(),
            nn.Conv2d(self.n_group_channels, 2, 1, 1, 0, bias=False)
        )
        if self.no_off:
            for m in self.conv_offset.parameters():
                m.requires_grad_(False)

        self.proj_q = nn.Conv2d(
            self.nc, self.nc,
            kernel_size=1, stride=1, padding=0
        )

        self.proj_k = nn.Conv2d(
            self.nc, self.nc,
            kernel_size=1, stride=1, padding=0
        )

        self.proj_v = nn.Conv2d(
            self.nc, self.nc,
            kernel_size=1, stride=1, padding=0
        )

        self.proj_out = nn.Conv2d(
            self.nc, self.nc,
            kernel_size=1, stride=1, padding=0
        )

        self.proj_drop = nn.Dropout(proj_drop, inplace=True)
        self.attn_drop = nn.Dropout(attn_drop, inplace=True)

        if self.use_pe and not self.no_off:
            if self.dwc_pe:
                self.rpe_table = nn.Conv2d(
                    self.nc, self.nc, kernel_size=3, stride=1, padding=1, groups=self.nc)
            elif self.fixed_pe:
                self.rpe_table = nn.Parameter(
                    torch.zeros(self.n_heads, self.q_h * self.q_w, self.kv_h * self.kv_w)
                )
                trunc_normal_(self.rpe_table, std=0.01)
            elif self.log_cpb:
                # Borrowed from Swin-V2
                self.rpe_table = nn.Sequential(
                    nn.Linear(2, 32, bias=True),
                    nn.ReLU(inplace=True),
                    nn.Linear(32, self.n_group_heads, bias=False)
                )
            else:
                self.rpe_table = nn.Parameter(
                    torch.zeros(self.n_heads, self.q_h * 2 - 1, self.q_w * 2 - 1)
                )
                trunc_normal_(self.rpe_table, std=0.01)
        else:
            self.rpe_table = None

    @torch.no_grad()
    def _get_ref_points(self, H_key, W_key, B, dtype, device):

        ref_y, ref_x = torch.meshgrid(
            torch.linspace(0.5, H_key - 0.5, H_key, dtype=dtype, device=device),
            torch.linspace(0.5, W_key - 0.5, W_key, dtype=dtype, device=device),
            indexing='ij'
        )
        ref = torch.stack((ref_y, ref_x), -1)
        ref[..., 1].div_(W_key - 1.0).mul_(2.0).sub_(1.0)
        ref[..., 0].div_(H_key - 1.0).mul_(2.0).sub_(1.0)
        ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1) # B * g H W 2

        return ref
    
    @torch.no_grad()
    def _get_q_grid(self, H, W, B, dtype, device):

        ref_y, ref_x = torch.meshgrid(
            torch.arange(0, H, dtype=dtype, device=device),
            torch.arange(0, W, dtype=dtype, device=device),
            indexing='ij'
        )
        ref = torch.stack((ref_y, ref_x), -1)
        ref[..., 1].div_(W - 1.0).mul_(2.0).sub_(1.0)
        ref[..., 0].div_(H - 1.0).mul_(2.0).sub_(1.0)
        ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1) # B * g H W 2

        return ref

    def forward(self, x):

        B, C, H, W = x.size()
        dtype, device = x.dtype, x.device

        q = self.proj_q(x)
        q_off = einops.rearrange(q, 'b (g c) h w -> (b g) c h w', g=self.n_groups, c=self.n_group_channels)
        offset = self.conv_offset(q_off).contiguous() # B * g 2 Hg Wg
        Hk, Wk = offset.size(2), offset.size(3)
        n_sample = Hk * Wk

        if self.offset_range_factor >= 0 and not self.no_off:
            offset_range = torch.tensor([1.0 / (Hk - 1.0), 1.0 / (Wk - 1.0)], device=device).reshape(1, 2, 1, 1)
            offset = offset.tanh().mul(offset_range).mul(self.offset_range_factor)

        offset = einops.rearrange(offset, 'b p h w -> b h w p')
        reference = self._get_ref_points(Hk, Wk, B, dtype, device)

        if self.no_off:
            offset = offset.fill_(0.0)

        if self.offset_range_factor >= 0:
            pos = offset + reference
        else:
            pos = (offset + reference).clamp(-1., +1.)

        if self.no_off:
            x_sampled = F.avg_pool2d(x, kernel_size=self.stride, stride=self.stride)
            assert x_sampled.size(2) == Hk and x_sampled.size(3) == Wk, f"Size is {x_sampled.size()}"
        else:
            pos = pos.type(x.dtype)
            x_sampled = F.grid_sample(
                input=x.reshape(B * self.n_groups, self.n_group_channels, H, W),
                grid=pos[..., (1, 0)], # y, x -> x, y
                mode='bilinear', align_corners=True) # B * g, Cg, Hg, Wg
                

        x_sampled = x_sampled.reshape(B, C, 1, n_sample)

        q = q.reshape(B * self.n_heads, self.n_head_channels, H * W)
        k = self.proj_k(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)
        v = self.proj_v(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)

        attn = torch.einsum('b c m, b c n -> b m n', q, k) # B * h, HW, Ns
        attn = attn.mul(self.scale)

        if self.use_pe and (not self.no_off):

            if self.dwc_pe:
                residual_lepe = self.rpe_table(q.reshape(B, C, H, W)).reshape(B * self.n_heads, self.n_head_channels, H * W)
            elif self.fixed_pe:
                rpe_table = self.rpe_table
                attn_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
                attn = attn + attn_bias.reshape(B * self.n_heads, H * W, n_sample)
            elif self.log_cpb:
                q_grid = self._get_q_grid(H, W, B, dtype, device)
                displacement = (q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) - pos.reshape(B * self.n_groups, n_sample, 2).unsqueeze(1)).mul(4.0) # d_y, d_x [-8, +8]
                displacement = torch.sign(displacement) * torch.log2(torch.abs(displacement) + 1.0) / np.log2(8.0)
                attn_bias = self.rpe_table(displacement) # B * g, H * W, n_sample, h_g
                attn = attn + einops.rearrange(attn_bias, 'b m n h -> (b h) m n', h=self.n_group_heads)
            else:
                rpe_table = self.rpe_table
                rpe_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
                q_grid = self._get_q_grid(H, W, B, dtype, device)
                displacement = (q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) - pos.reshape(B * self.n_groups, n_sample, 2).unsqueeze(1)).mul(0.5)
                attn_bias = F.grid_sample(
                    input=einops.rearrange(rpe_bias, 'b (g c) h w -> (b g) c h w', c=self.n_group_heads, g=self.n_groups),
                    grid=displacement[..., (1, 0)],
                    mode='bilinear', align_corners=True) # B * g, h_g, HW, Ns

                attn_bias = attn_bias.reshape(B * self.n_heads, H * W, n_sample)
                attn = attn + attn_bias

        attn = F.softmax(attn, dim=2)
        attn = self.attn_drop(attn)

        out = torch.einsum('b m n, b c n -> b c m', attn, v)

        if self.use_pe and self.dwc_pe:
            out = out + residual_lepe
        out = out.reshape(B, C, H, W)

        y = self.proj_drop(self.proj_out(out))

        return y

if __name__ == '__main__':
    # 设置模型超参数
    channel = 64
    q_size = (32, 32) # 假设查询大小为 32x32
    n_heads = 8  # 8 个注意力头
    n_groups = 4  # 分成 4 组
    stride = 1  # 卷积步长为 1

    # 创建 DAttention 模块实例
    model = DAttention(
        channel=channel, q_size=q_size, n_heads=n_heads, n_groups=n_groups, stride=stride
    )

    batch_size = 4  # 假设批次大小为 4
    height, width = 64, 64  # 假设输入图像的尺寸为 64x64
    input = torch.randn(batch_size, channel, height, width)

    output = model(input)

    print(input.shape)    print(output.shape)

便捷下载方式

浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

更多分析可见原文


ai缝合大王
聚焦AI前沿,分享相关技术、论文,研究生自救指南
 最新文章