即插即用多分辨率特征融合模块SAM,涨点起飞起飞了

文摘   2025-01-22 17:20   中国香港  

论文介绍

题目:Attention Attention Everywhere: Monocular Depth Prediction with Skip Attention

论文地址:https://arxiv.org/pdf/2210.09071

QQ深度学习交流群:994264161

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • Skip Attention Module (SAM)

    • 提出了基于窗口的跨注意力模块,用于将解码器特征与编码器特征进行融合。与传统卷积操作相比,该模块通过计算像素查询和特定窗口内编码器特征的相似性,能够更高效地结合局部和全局语义信息。

    • SAM模块有效克服了传统跳跃连接在结合高分辨率和全局上下文信息时的局限性。

  • Pixel Query Refinement

    • 将单目深度估计建模为像素查询精炼问题。利用编码器的最粗特征图生成初始像素查询,通过解码器特征的引导逐步优化像素查询到更高分辨率。

  • Bin Center Predictor (BCP)

    • 提出了轻量化的深度分箱预测模块,基于最粗分辨率的初始像素查询自适应地预测每张图像的深度分箱。相比现有方法,该模块更加高效,同时通过直接监督初始像素查询,增强了深度嵌入。

  • 性能提升

    • 在NYUV2(室内数据集)和KITTI(室外数据集)上分别实现了5.3%和3.9%的性能提升(绝对相对误差和平方相对误差)。在SUNRGBD(室内数据集)上实现了9.4%的泛化性能提升。

    • 通过设计独特的模块(如SAM和BCP),改善了边界对齐和深度估计的准确性,特别是在不同深度范围内的适应性更强。

  • 整合架构

    • 结合了视觉Transformer和编码器-解码器框架,通过引入跨尺度和跨模块的注意力机制,实现了对全局上下文信息和局部细节的全面捕获。

方法

整体架构

       这篇论文提出的 PixelFormer 模型基于编码器-解码器架构,利用 Swin Transformer 提取多尺度特征,并将单目深度估计建模为像素查询精炼问题。通过像素查询初始化模块(PQI)生成全局信息丰富的初始像素查询,结合跳跃注意力模块(SAM)在解码过程中逐级融合编码器特征,增强全局与局部信息的交互。同时,深度分箱预测模块(BCP)自适应地生成深度分箱以优化估计精度。最终,通过解码器输出每像素的深度预测,显著提升了室内外数据集的估计性能和泛化能力。

1. 整体架构

  • 编码器:使用基于 Swin Transformer 的编码器提取多尺度特征图。特征图的分辨率为14,18,116,132\frac{1}{4}, \frac{1}{8}, \frac{1}{16}, \frac{1}{32}

  • 像素查询初始化模块(Pixel Query Initialiser, PQI)

    • 利用编码器中最粗分辨率的特征图 (132\frac{1}{32}) 初始化全局像素查询。

    • 使用金字塔池化(Pyramid Spatial Pooling, PSP)和全局平均池化提取场景全局信息。

  • 跳跃注意力模块(Skip Attention Module, SAM)

    • 用于解码过程中逐步融合编码器的多尺度特征。

    • SAM通过跨注意力机制,将解码器特征与高分辨率的编码器特征融合,从而精炼像素查询。

  • 深度分箱预测模块(Bin Center Predictor, BCP)

    • 基于最粗分辨率的初始像素查询,自适应地预测深度分箱(bin centers),从而提供每张图像的深度离散化。

  • 解码器

    • 解码器通过逐级精炼像素查询,最终生成深度估计结果。


2. 模块细节

(1) 编码器

  • 使用 Swin Transformer 作为主干网络(backbone),提取多尺度特征图。

  • 特征图具有全局感受野,捕获长距离依赖。

(2) 像素查询初始化模块(PQI)

  • 输入:编码器生成的最粗分辨率特征图。

  • 输出:初始像素查询,包含场景的全局信息。

  • 方法:

    • 对输入特征图进行多尺度池化操作(包括 1×1、2×2、3×3 和 6×6 的全局池化)。

    • 将池化结果上采样到相同分辨率后拼接,通过卷积操作整合信息,生成初始像素查询。

(3) 跳跃注意力模块(SAM)

  • 输入:当前分辨率的像素查询和对应分辨率的编码器特征。

  • 输出:精炼后的像素查询。

  • 方法:

    • SAM通过跨注意力机制计算像素查询与编码器特征的相关性,并在窗口内融合信息。

    • 采用多头注意力机制(window-based cross-attention),同时利用残差连接确保梯度稳定。

(4) 深度分箱预测模块(BCP)

  • 输入:初始像素查询。

  • 输出:图像自适应的深度分箱(bin centers)。

  • 方法:

    • 对像素查询进行全局平均池化。

    • 使用多层感知机(MLP)预测深度分箱(bins)。

(5) 解码器

  • 解码器逐级上采样像素查询(如使用 Pixel Shuffle 方法),并通过 SAM 模块与编码器特征融合。

  • 最终,通过卷积和 Softmax 操作,生成每像素的深度概率分布。


3. 深度估计输出

  • 深度预测

    • 对每个像素,计算深度概率分布与对应分箱中心的加权和,得到最终深度值。

    • 深度did_i 计算公式为:di=k=1nbinsc(bk)pikd_i = \sum_{k=1}^{n_{\text{bins}}} c(b_k) \cdot p_{ik}其中,c(bk)c(b_k) 是第kk 个深度分箱的中心,pikp_{ik} 是第ii个像素在该分箱的概率。


4. 训练与优化

  • 损失函数:使用 Scale-Invariant Logarithmic Loss (SILog),同时直接对初始像素查询和最终深度预测进行监督。

即插即用模块作用

SAM 作为一个即插即用模块

  • 高效融合局部细节与全局信息

    • SAM 使用跨注意力机制,通过比较解码器中的像素查询与编码器特征图的局部区域,动态融合全局和局部特征。

    • 解决了传统卷积或简单跳跃连接方法中,高分辨率细节和全局语义信息结合不足的问题。

  • 提升长距离依赖的建模能力

    • SAM 能够在特定窗口内捕获跨像素的相似性,增强模型对边界信息和复杂场景结构的建模能力。

  • 增强特征交互,提高预测精度

    • 通过自适应地融合编码器和解码器特征,改善像素级预测的准确性,尤其在边界区域和深度梯度变化大的区域。

消融实验结果

  • 内容

    • 比较了使用不同特征融合方法(Add-Conv、Cat-Conv 和 SAM)的性能。

    • 评价指标包括 Abs Rel(绝对相对误差)Sq Rel(平方相对误差)

  • 结果

    • SAM 模块相比 Add-Conv 和 Cat-Conv 显著降低了误差(Abs Rel 降低 4.0%,Sq Rel 降低 4.2%)。

    • 说明 SAM 模块在结合编码器和解码器特征时更高效,能更好地捕获长距离依赖。


  • 内容

    • 比较了不同方法对深度分箱预测的影响,包括 mViT-Last(基于高分辨率特征图)、mViT-First(基于初始像素查询)和 BCP(本文提出的方法)。

    • 评价指标为 Abs Rel、Sq Rel 和 δ1(阈值精度)

  • 结果

    • BCP 模块相比 mViT-Last 和 mViT-First 有更低的误差(Abs Rel 降低 3%,Sq Rel 降低 3.7%),并且提升了 δ1 精度。

    • 说明直接在初始像素查询上嵌入深度信息能显著提升预测准确性。

即插即用模块

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath, to_2tuple, trunc_normal_




class Mlp(nn.Module):
    """ Multilayer perceptron."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """

    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """

    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class WindowAttention(nn.Module):
    """ Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """


    def __init__(self, dim, window_size, num_heads, v_dim, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(v_dim, v_dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, v, mask=None):
        """ Forward function.

        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """

        B_, N, C = x.shape
        q = self.q(x).view(B_, N, self.num_heads, -1).transpose(1, 2)
        kv = self.kv(v).reshape(B_, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1] # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class SAMBLOCK(nn.Module):
    """
    Args:
        dim (int): Number of feature channels
        num_heads (int): Number of attention head.
        window_size (int): Local window size. Default: 7.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """


    def __init__(self,
                 dim,
                 num_heads,
                 v_dim,
                 window_size=7,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 )
:

        super().__init__()
        self.window_size = window_size
        self.dim = dim
        self.num_heads = num_heads
        self.v_dim = v_dim
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio
        act_layer = nn.GELU
        norm_layer = nn.LayerNorm

        self.norm1 = norm_layer(dim)
        self.normv = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, v_dim=v_dim,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(v_dim)
        mlp_hidden_dim = int(v_dim * mlp_ratio)
        self.mlp = Mlp(in_features=v_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, v, H, W):
        """ Forward function.

        Args:
            x: Input feature, tensor size (B, H*W, C).
            H, W: Spatial resolution of the input feature.
        """


        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        shortcut_v = v
        v = self.normv(v)
        v = v.view(B, H, W, C)

        # pad feature maps to multiples of window size
        pad_l = pad_t = 0
        pad_r = (self.window_size - W % self.window_size) % self.window_size
        pad_b = (self.window_size - H % self.window_size) % self.window_size
        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
        v = F.pad(v, (0, 0, pad_l, pad_r, pad_t, pad_b))
        _, Hp, Wp, _ = x.shape

        # partition windows
        x_windows = window_partition(x, self.window_size) # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
        v_windows = window_partition(v, self.window_size) # nW*B, window_size, window_size, C
        v_windows = v_windows.view(-1, self.window_size * self.window_size,
                                   v_windows.shape[-1]) # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, v_windows, mask=None) # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.v_dim)
        x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C

        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

        x = x.view(B, H * W, self.v_dim)

        # FFN
        x = self.drop_path(x) + shortcut
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x, H, W


class SAM(nn.Module):
    def __init__(self,
                 input_dim=96,
                 embed_dim=96,
                 v_dim=64,
                 window_size=7,
                 num_heads=4,
                 patch_size=4,
                 in_chans=3,
                 norm_layer=nn.LayerNorm,
                 patch_norm=True)
:

        super().__init__()

        self.embed_dim = embed_dim

        if input_dim != embed_dim:
            self.proj_e = nn.Conv2d(input_dim, embed_dim, 3, padding=1)
        else:
            self.proj_e = None

        if v_dim != embed_dim:
            self.proj_q = nn.Conv2d(v_dim, embed_dim, 3, padding=1)
        elif embed_dim % v_dim == 0:
            self.proj_q = None
        self.proj = nn.Conv2d(embed_dim, embed_dim, 3, padding=1)

        v_dim = embed_dim
        self.sam_block = SAMBLOCK(
            dim=embed_dim,
            num_heads=num_heads,
            v_dim=v_dim,
            window_size=window_size,
            mlp_ratio=4.,
            qkv_bias=True,
            qk_scale=None,
            drop=0.,
            attn_drop=0.,
            drop_path=0.,
            norm_layer=norm_layer)

        layer = norm_layer(embed_dim)
        layer_name = 'norm_sam'
        self.add_module(layer_name, layer)

    def forward(self, e, q):
        if self.proj_q is not None:
            q = self.proj_q(q)
        if self.proj_e is not None:
            e = self.proj_e(e)
        e_proj = e
        q_proj = q

        Wh, Ww = q.size(2), q.size(3)
        q = q.flatten(2).transpose(1, 2)
        e = e.flatten(2).transpose(1, 2)

        q_out, H, W = self.sam_block(q, e, Wh, Ww)
        norm_layer = getattr(self, f'norm_sam')
        q_out = norm_layer(q_out)
        q_out = q_out.view(-1, H, W, self.embed_dim).permute(0, 3, 1, 2).contiguous()

        return q_out + e_proj + q_proj



if __name__ == '__main__':

    model = SAM(input_dim=96, embed_dim=96, v_dim=96, window_size=7, num_heads=4, patch_size=4, in_chans=3)

    B, H, W = 2, 128, 128
    e = torch.rand(B, 96, H, W)
    q = torch.rand(B, 96, H, W)

    output = model(e, q)

    print(e.size())
    print(q.size())    print(output.size())

便捷下载方式

浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

更多分析可见原文


ai缝合大王
聚焦AI前沿,分享相关技术、论文,研究生自救指南
 最新文章