ICLR 即插即用MobileViTAttention,结合卷积的局部特征提取能力和Transformer的全局特征建模能力

文摘   2025-01-04 17:20   中国香港  

论文介绍

题目:https://arxiv.org/pdf/2108.00154

论文地址:https://arxiv.org/pdf/2108.00154

QQ深度学习交流群:994264161

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • 轻量化设计MobileViT将CNN(卷积神经网络)的空间归纳偏差与ViT(视觉变换器)的全局信息处理能力相结合,设计了一种轻量级、通用且适用于移动设备的视觉变换器。提出了一种新的MobileViT块,它通过结合局部卷积操作和全局Transformer机制,在较少参数的情况下有效地学习局部和全局特征。

  • 简单训练配方MobileViT无需依赖复杂的数据增强或大规模正则化(如L2正则化),能够以基本的数据增强方法(如随机裁剪和水平翻转)实现优异性能。

  • 对比同样大小的CNN和其他轻量级ViT模型,MobileViT的训练更加高效和鲁棒。

  • 优异的性能表现在ImageNet-1k数据集上,MobileViT取得了78.4%的Top-1准确率,比同等参数量的MobileNetv3高出3.2%,比DeIT高出6.2%。

  • 在MS-COCO目标检测任务中,MobileViT相比MobileNetv3提升了5.7%的mAP。

  • 通用性与高效性MobileViT可以作为骨干网络,广泛应用于目标检测和语义分割等下游任务,并在性能和模型尺寸上优于其他轻量级和重型模型。MobileViT模型在移动设备(如iPhone 12)上运行具有较低延迟,能实现实时推理(>30 FPS)。

    多尺度采样策略

  • 提出了多尺度采样方法,在训练过程中动态调整输入分辨率,提高了训练效率并减少了优化步骤,同时增强了模型的多尺度特征学习能力。

方法

整体架构

     MobileViT模型通过结合MobileNetv2的高效倒置残差块(MV2块)和创新的MobileViT块,构建了一种轻量化、适用于移动设备的网络架构。其结构包括初始卷积层用于基本特征提取,MV2块负责局部特征学习和下采样,MobileViT块通过整合局部卷积操作和全局Transformer机制同时捕获局部和全局特征,最终通过全局池化和全连接层完成分类任务。这种架构兼具高效性和灵活性,能够适应多种视觉任务,同时优化了移动设备上的推理性能。

1. 总体架构

  • 初始卷积层

    • 模型首先使用一个带步长的标准3×33 \times 3 卷积层,对输入图像进行下采样和初步特征提取。

  • MobileNetv2模块 (MV2 Blocks)

    • 使用了MobileNetv2的倒置残差块(Inverted Residual Block),作为主要的特征提取模块之一,负责进一步的局部特征提取和下采样。

  • MobileViT块 (MobileViT Blocks)

    • MobileViT块是模型的核心创新模块,它结合了局部卷积操作和全局Transformer机制,用于同时学习局部和全局特征。

    • 在每个MobileViT块中:

  1. 使用标准卷积提取局部特征。

  2. 通过展开(unfolding)操作将特征映射划分为多个非重叠的patch。

  3. 使用Transformer机制在patch之间建立全局关系。

  4. 通过折叠(folding)操作将处理后的特征重新组合成完整的特征映射。

  5. 最终用卷积层对局部和全局特征进行融合。

  • 输出层

    • 在最后阶段,模型通过全局池化层和全连接层,生成用于分类任务的最终预测。


    2. 分层设计与参数配置

    • 模型变体

      • MobileViT提供了多个变体,包括MobileViT-XXS(极小型)、MobileViT-XS(小型)、MobileViT-S(标准型)。这些变体在网络深度(层数)和宽度(通道数)上有所不同,以适应不同的任务需求和设备资源限制。

    • 空间分辨率与多尺度特性

      • 不同阶段的特征映射分辨率逐步下降,如128×128128 \times 1281×11 \times 1,以实现有效的空间压缩。

      • MobileViT块中的Transformer在不同分辨率下处理特征,以保证多尺度特征的提取和融合。

    即插即用模块作用

    MobileViTAttention 作为一个即插即用模块

    • 增强全局信息感知能力

      • 在视觉任务中,局部特征和全局依赖的结合有助于提升对复杂场景和长距离依赖关系的建模能力。

    • 提高任务性能

      • 在分类、检测和分割任务中,通过引入MobileViTAttention,可以显著提升模型的准确性和鲁棒性。

    • 低计算开销

      • 模块化的设计使其能够以较低的额外计算代价提升性能,尤其适合实时性要求高的任务。

    • 模型通用性增强

      • 作为即插即用模块,它能够帮助现有模型在新数据集和新任务上更好地泛化。

    消融实验结果

    • 比较了MobileViT与MobileNetv2、DeiT和PiT在推理时间、参数量和FLOPs上的表现。

    • 分析

      • MobileViT在推理速度和性能上优于其他ViT变体,虽然在移动设备上略慢于MobileNetv2,但在参数量和准确率之间取得了更好的平衡。

    即插即用模块

    from torch import nn
    import torch
    from einops import rearrange

    # 论文题目:MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer
    # 论文链接:https://arxiv.org/pdf/2110.02178


    # 预定义一个带有层归一化的预处理模块
    class PreNorm(nn.Module):
        def __init__(self, dim, fn):
            super().__init__()
            self.ln = nn.LayerNorm(dim) # 层归一化,标准化输入
            self.fn = fn # 用于传入的函数(例如 Attention 或 FeedForward)

        def forward(self, x, **kwargs):
            return self.fn(self.ln(x), **kwargs) # 对归一化后的输入应用函数


    # 定义一个前馈神经网络模块,用于 MLP 层
    class FeedForward(nn.Module):
        def __init__(self, dim, mlp_dim, dropout):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(dim, mlp_dim), # 线性层,输入维度到 MLP 维度
                nn.SiLU(), # SiLU 激活函数
                nn.Dropout(dropout), # Dropout,防止过拟合
                nn.Linear(mlp_dim, dim), # 线性层,将 MLP 维度还原为输入维度
                nn.Dropout(dropout) # Dropout
            )

        def forward(self, x):
            return self.net(x) # 输出前馈网络的结果


    # 定义注意力模块,用于计算多头自注意力
    class Attention(nn.Module):
        def __init__(self, dim, heads, head_dim, dropout):
            super().__init__()
            inner_dim = heads * head_dim # 内部维度为头数乘以每头的维度
            project_out = not (heads == 1 and head_dim == dim) # 判断是否需要输出投影

            self.heads = heads # 注意力头的数量
            self.scale = head_dim ** -0.5  # 缩放因子,用于稳定训练

            self.attend = nn.Softmax(dim=-1) # 使用 Softmax 计算注意力权重
            self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) # 线性变换生成查询、键、值

            # 输出层,如果没有单独的投影层则直接使用 Identity
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            ) if project_out else nn.Identity()

        def forward(self, x):
            qkv = self.to_qkv(x).chunk(3, dim=-1) # 将查询、键和值分成三个部分
            q, k, v = map(lambda t: rearrange(t, 'b p n (h d) -> b p h n d', h=self.heads), qkv) # 重排维度
            dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale # 计算注意力分数
            attn = self.attend(dots) # 对注意力分数应用 Softmax
            out = torch.matmul(attn, v) # 根据注意力权重加权值向量
            out = rearrange(out, 'b p h n d -> b p n (h d)') # 重排回原始维度
            return self.to_out(out) # 返回投影输出


    # Transformer 模块,由多层注意力和前馈网络组成
    class Transformer(nn.Module):
        def __init__(self, dim, depth, heads, head_dim, mlp_dim, dropout=0.):
            super().__init__()
            self.layers = nn.ModuleList([]) # 初始化层列表
            for _ in range(depth): # 根据深度循环添加层
                self.layers.append(nn.ModuleList([
                    PreNorm(dim, Attention(dim, heads, head_dim, dropout)), # 预归一化注意力模块
                    PreNorm(dim, FeedForward(dim, mlp_dim, dropout)) # 预归一化前馈模块
                ]))

        def forward(self, x):
            out = x
            for att, ffn in self.layers: # 遍历注意力和前馈网络层
                out = out + att(out) # 残差连接,应用注意力
                out = out + ffn(out) # 残差连接,应用前馈网络
            return out


    # MobileViT 的注意力模块,结合了局部和全局表示
    class MobileViTAttention(nn.Module):
        def __init__(self, in_channel=3, dim=512, kernel_size=3, patch_size=7):
            super().__init__()
            self.ph, self.pw = patch_size, patch_size # 设置 patch 的高度和宽度
            self.conv1 = nn.Conv2d(in_channel, in_channel, kernel_size=kernel_size, padding=kernel_size // 2) # 局部卷积
            self.conv2 = nn.Conv2d(in_channel, dim, kernel_size=1) # 用于通道变换的 1x1 卷积

            self.trans = Transformer(dim=dim, depth=3, heads=8, head_dim=64, mlp_dim=1024) # Transformer 模块用于全局表示

            self.conv3 = nn.Conv2d(dim, in_channel, kernel_size=1) # 将维度变换回原通道
            self.conv4 = nn.Conv2d(2 * in_channel, in_channel, kernel_size=kernel_size, padding=kernel_size // 2) # 用于融合的卷积层

        def forward(self, x):
            y = x.clone() # 复制输入张量 y = x 以保留局部特征

            ## 局部表示
            y = self.conv2(self.conv1(x)) # 使用卷积层获得局部特征

            ## 全局表示
            _, _, h, w = y.shape # 获取 y 的高度和宽度
            y = rearrange(y, 'bs dim (nh ph) (nw pw) -> bs (ph pw) (nh nw) dim', ph=self.ph, pw=self.pw) # 重排为 patch 格式
            y = self.trans(y) # 应用 Transformer 进行全局特征提取
            y = rearrange(y, 'bs (ph pw) (nh nw) dim -> bs dim (nh ph) (nw pw)', ph=self.ph, pw=self.pw, nh=h // self.ph,
                          nw=w // self.pw) # 恢复为原始形状

            ## 融合
            y = self.conv3(y) # 维度变换回原通道
            y = torch.cat([x, y], 1) # 拼接局部和全局特征
            y = self.conv4(y) # 融合后的卷积操作

            return y # 返回融合结果

    if __name__ == '__main__':
        m = MobileViTAttention(in_channel=512)
        input = torch.randn(1, 512, 49, 49) # 生成输入张量,大小为 (1, 512, 49, 49)
        output = m(input) # 应用 MobileViTAttention 模块
        print(input.shape) # 打印输入张量的形状
        print(output.shape) # 打印输出张量的形状

    便捷下载方式

    浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

    更多分析可见原文


    ai缝合大王
    聚焦AI前沿,分享相关技术、论文,研究生自救指南
     最新文章