ICCV 2023 | 多尺度线性注意力模块,有效涨点,即插即用

文摘   2025-01-16 10:12   安徽  

点击下方卡片,关注“AI前沿速递”公众号

各种重磅干货,第一时间送达


标题:EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction

论文链接:https://arxiv.org/pdf/2205.14756

代码链接:https://github.com/mit-han-lab/efficientvit**

创新点

**

  1. 三明治布局设计:在FFN层间插入单个受内存限制的MHSA层,减少其时间成本,同时增加FFN层以增强通道通信,提升内存效率。
  2. 级联组注意力模块:将特征分割成不同部分提供给各注意力头,避免注意力头输入相同特征,减少计算冗余,还通过级联输出特征增加网络深度与注意力多样性。
  3. 参数重分配策略:扩大关键组件如值投影的通道宽度,缩小次要组件如FFN隐藏维度,优化参数使用,提高模型参数效率。
  4. 轻量级多尺度注意力机制:用基于ReLU的全局注意力替代传统自注意力,降低计算复杂度,避免硬件低效操作;通过小核卷积生成多尺度tokens并执行全局注意力,实现全局感受野与多尺度学习。
  5. 高效部署与实时推理能力:整体设计使EfficientViT在多个部署场景下能够实现实时推理,让ViT模型更高效地应用于实际任务,如目标检测、语义分割等。

整体结构

在这里插入图片描述

EfficientViT模型架构核心是一个多阶段主干网络,其关键在于利用ReLU线性注意力机制来获取全局上下文信息,并借助深度可分离卷积强化局部信息的处理。该模型注重多尺度学习,通过聚合Q/K/V的多尺度信息来提升特征提取效果。同时,借助特征金字塔融合不同阶段的特征图,最终经上采样和简单的MBConv块输出高分辨率预测结果。

具体到EfficientViT模块:

  • ReLU线性注意力:采用改进的线性注意力机制,即ReLU线性注意力,专注于捕捉全局上下文信息,以增强模型对整体图像内容的理解。
  • 深度可分离卷积:在每个前馈网络(FFN)层中嵌入深度卷积操作,目的是捕捉局部信息,从而提升模型处理高分辨率输入的能力,使模型能够更好地关注图像细节。
  • 多尺度学习:通过对Q/K/V进行多尺度信息聚合,增强模型的多尺度特征提取能力,同时利用分离卷积来避免降低硬件效率,确保模型在不同尺度下都能有效运行。

从消融研究结果来看,在Cityscapes数据集上,以mIoU和MAC为指标进行测量,输入分辨率为1024x2048。通过调整模型宽度使其MAC相同,结果显示多尺度学习和全局感受野对于语义分割性能至关重要。EfficientViT-L2-r384在ImageNet数据集上取得了86.0的top-1精度,相比EfficientNetV2-L提升了0.3的精度,在A100 GPU上更是实现了2.6倍的加速效果。

代码实现

Conv2d_BN 类

class Conv2d_BN(torch.nn.Sequential):    def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,                 groups=1, bn_weight_init=1, resolution=-10000):        super().__init__()        self.add_module('c', torch.nn.Conv2d(            a, b, ks, stride, pad, dilation, groups, bias=False))        self.add_module('bn', torch.nn.BatchNorm2d(b))        torch.nn.init.constant_(self.bn.weight, bn_weight_init)        torch.nn.init.constant_(self.bn.bias, 0)    @torch.no_grad()    def fuse(self):        c, bn = self._modules.values()        w = bn.weight / (bn.running_var + bn.eps)**0.5        w = c.weight * w[:, None, None, None]        b = bn.bias - bn.running_mean * bn.weight / \            (bn.running_var + bn.eps)**0.5        m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(            0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)        m.weight.data.copy_(w)        m.bias.data.copy_(b)        return m


EfficientViTBlock 类

class EfficientViTBlock(torch.nn.Module):    """ A basic EfficientViT building block.    Args:        type (str): Type for token mixer. Default: 's' for self-attention.        ed (int): Number of input channels.        kd (int): Dimension for query and key in the token mixer.        nh (int): Number of attention heads.        ar (int): Multiplier for the query dim for value dimension.        resolution (int): Input resolution.        window_resolution (int): Local window resolution.        kernels (List[int]): The kernel size of the dw conv on query.    """    def __init__(self, type,                 ed, kd, nh=8,                 ar=4,                 resolution=14,                 window_resolution=7,                 kernels=[5, 5, 5, 5],):        super().__init__()        self.dw0 = Residual(Conv2d_BN(ed, ed, 3, 1, 1, groups=ed, bn_weight_init=0., resolution=resolution))        self.ffn0 = Residual(FFN(ed, int(ed * 2), resolution))        if type == 's':            self.mixer = Residual(LocalWindowAttention(ed, kd, nh, attn_ratio=ar, \                    resolution=resolution, window_resolution=window_resolution, kernels=kernels))        self.dw1 = Residual(Conv2d_BN(ed, ed, 3, 1, 1, groups=ed, bn_weight_init=0., resolution=resolution))        self.ffn1 = Residual(FFN(ed, int(ed * 2), resolution))        def forward(self, x):        return self.ffn1(self.dw1(self.mixer(self.ffn0(self.dw0(x)))))


EfficientViT 类

class EfficientViT(torch.nn.Module):    def __init__(self, img_size=400,                 patch_size=16,                 frozen_stages=0,                 in_chans=3,                 stages=['s', 's', 's'],                 embed_dim=[64, 128, 192],                 key_dim=[16, 16, 16],                 depth=[1, 2, 3],                 num_heads=[4, 4, 4],                 window_size=[7, 7, 7],                 kernels=[5, 5, 5, 5],                 down_ops=[['subsample', 2], ['subsample', 2], ['']],                 pretrained=None,                 distillation=False,):        super().__init__()        resolution = img_size        self.patch_embed = torch.nn.Sequential(Conv2d_BN(in_chans, embed_dim[0] // 8, 3, 2, 1, resolution=resolution), torch.nn.ReLU(),                           Conv2d_BN(embed_dim[0] // 8, embed_dim[0] // 4, 3, 2, 1, resolution=resolution // 2), torch.nn.ReLU(),                           Conv2d_BN(embed_dim[0] // 4, embed_dim[0] // 2, 3, 2, 1, resolution=resolution // 4), torch.nn.ReLU(),                           Conv2d_BN(embed_dim[0] // 2, embed_dim[0], 3, 1, 1, resolution=resolution // 8))        resolution = img_size // patch_size        attn_ratio = [embed_dim[i] / (key_dim[i] * num_heads[i]) for i in range(len(embed_dim))]        self.blocks1 = []        self.blocks2 = []        self.blocks3 = []        for i, (stg, ed, kd, dpth, nh, ar, wd, do) in enumerate(                zip(stages, embed_dim, key_dim, depth, num_heads, attn_ratio, window_size, down_ops)):            for d in range(dpth):                eval('self.blocks' + str(i+1)).append(EfficientViTBlock(stg, ed, kd, nh, ar, resolution, wd, kernels))            if do[0] == 'subsample':                #('Subsample' stride)                blk = eval('self.blocks' + str(i+2))                resolution_ = (resolution - 1) // do[1] + 1                blk.append(torch.nn.Sequential(Residual(Conv2d_BN(embed_dim[i], embed_dim[i], 3, 1, 1, groups=embed_dim[i], resolution=resolution)),                                    Residual(FFN(embed_dim[i], int(embed_dim[i] * 2), resolution)),))                blk.append(PatchMerging(*embed_dim[i:i + 2], resolution))                resolution = resolution_                blk.append(torch.nn.Sequential(Residual(Conv2d_BN(embed_dim[i + 1], embed_dim[i + 1], 3, 1, 1, groups=embed_dim[i + 1], resolution=resolution)),                                    Residual(FFN(embed_dim[i + 1], int(embed_dim[i + 1] * 2), resolution)),))        self.blocks1 = torch.nn.Sequential(*self.blocks1)        self.blocks2 = torch.nn.Sequential(*self.blocks2)        self.blocks3 = torch.nn.Sequential(*self.blocks3)        self.channel = [i.size(1) for i in self.forward(torch.randn(1, 3, 640, 640))]        def forward(self, x):        outs = []        x = self.patch_embed(x)        x = self.blocks1(x)        outs.append(x)        x = self.blocks2(x)        outs.append(x)        x = self.blocks3(x)        outs.append(x)        return outs


本文内容为论文学习收获分享,受限于知识能力,本文对原文的理解可能存在偏差,最终内容以原论文为准。本文信息旨在传播和学术交流,其内容由作者负责,不代表本号观点。文中作品文字、图片等如涉及内容、版权和其他问题,请及时与我们联系,我们将在第一时间回复并处理。


欢迎投稿

想要让高质量的内容更快地触达读者,降低他们寻找优质信息的成本吗?关键在于那些你尚未结识的人。他们可能掌握着你渴望了解的知识。【AI前沿速递】愿意成为这样的一座桥梁,连接不同领域、不同背景的学者,让他们的学术灵感相互碰撞,激发出无限可能。

【AI前沿速递】欢迎各高校实验室和个人在我们的平台上分享各类精彩内容,无论是最新的论文解读,还是对学术热点的深入分析,或是科研心得和竞赛经验的分享,我们的目标只有一个:让知识自由流动。

📝 投稿指南

  • 确保文章为个人原创,未在任何公开渠道发布。若文章已在其他平台发表或即将发表,请明确说明。

  • 建议使用Markdown格式撰写稿件,并以附件形式发送清晰、无版权争议的配图。

  • 【AI前沿速递】尊重作者的署名权,并为每篇被采纳的原创首发稿件提供具有市场竞争力的稿酬。稿酬将根据文章的阅读量和质量进行阶梯式结算。

📬 投稿方式

  • 您可以通过添加我们的小助理微信(aiqysd)进行快速投稿。请在添加时备注“投稿-姓名-学校-研究方向”


    长按添加AI前沿速递小助理


AI前沿速递
持续分享最新AI前沿论文成果
 最新文章