CVPR 2024 | 单头注意力机制(SHSA),即插即用,涨点起飞!

文摘   2025-01-15 11:47   安徽  

标题:SHViT: Single-Head Vision Transformer with Memory Efficient Macro Design

论文链接:https://arxiv.org/pdf/2401.16456

代码链接:https://github.com/ysj9909/SHViT

来源:CVPR 2024

单头自注意力模块(SHSA)

基本结构

在这里插入图片描述

SHSA模块主要包含以下几个部分:

-输入分割:将输入通道分为两部分,一部分是参与注意力计算的通道,另一部分是保持不变的通道。默认设置中,参与注意力计算的通道数占总通道数的比例为 1/4.67。

-自注意力计算:对参与注意力计算的通道进行自注意力操作。具体来说,先对进行线性变换得到查询、键和值,然后计算的点积并进行 softmax 归一化得到注意力权重,最后将注意力权重与 相乘得到注意力特征 

-特征拼接与投影:将注意力特征与保持不变的通道进行拼接,再通过一个投影层输出最终结果。

计算公式:

其中,是投影权重,是查询和键的维度,默认为 16, Concat( ) 是拼接操作。

优势特点

在这里插入图片描述

- 减少计算冗余:相比于多头自注意力机制,SHSA仅对部分通道进行注意力计算,避免了多头机制中的计算冗余,降低了计算量和内存访问成本。

- 并行结合全局和局部信息:SHSA通过将注意力特征与保持不变的通道进行拼接,能够在并行计算中同时结合全局和局部信息,提高了特征的丰富性和模型的性能。

- 内存访问高效:SHSA减少了对内存绑定操作(如 reshape 和 normalization)的使用,或者将这些操作应用于较少的输入通道,从而提高了计算效率,充分发挥了 GPU/CPUs 的计算能力。

应用场景

在这里插入图片描述

SHSA模块在SHViT模型中得到了广泛应用,用于图像分类、目标检测和实例分割等任务。例如,在ImageNet-1k图像分类任务中,SHViT-S4模型在Nvidia A100 GPU上达到了14283 images/s的吞吐量,同时取得了79.4%的Top-1准确率。在MS COCO目标检测和实例分割任务中,SHViT模型使用Mask R-CNN检测器,显著优于EfficientViT-M4等模型,同时在各种设备上展现出更低的骨干网络延迟。

代码实现

import torchimport torch.nn as nnimport torch.nn.functional as F
class GroupNorm(nn.GroupNorm): """Group Normalization with 1 group. Input: tensor in shape [B, C, H, W] """ def __init__(self, num_channels, **kwargs): super().__init__(1, num_channels, **kwargs)
class Conv2d_BN(nn.Sequential): def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1): super().__init__() self.add_module('c', nn.Conv2d( a, b, ks, stride, pad, dilation, groups, bias=False)) self.add_module('bn', nn.BatchNorm2d(b)) nn.init.constant_(self.bn.weight, bn_weight_init) nn.init.constant_(self.bn.bias, 0)
@torch.no_grad() def fuse(self): c, bn = self._modules.values() w = bn.weight / (bn.running_var + bn.eps)**0.5 b = bn.bias - self.bn.running_mean * \ self.bn.weight / (bn.running_var + bn.eps)**0.5 m = nn.Conv2d(w.size(1) * self.c.groups, w.size( 0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups, device=c.weight.device) m.weight.data.copy_(w) m.bias.data.copy_(b) return m
class BN_Linear(nn.Sequential): def __init__(self, a, b, bias=True, std=0.02): super().__init__() self.add_module('bn', nn.BatchNorm1d(a)) self.add_module('l', nn.Linear(a, b, bias=bias)) trunc_normal_(self.l.weight, std=std) if bias: nn.init.constant_(self.l.bias, 0)
@torch.no_grad() def fuse(self): bn, l = self._modules.values() w = bn.weight / (bn.running_var + bn.eps)**0.5 b = bn.bias - self.bn.running_mean * \ self.bn.weight / (bn.running_var + bn.eps)**0.5 b = (l.weight @ b[:, None]).view(-1) + self.l.bias m = nn.Linear(w.size(1), w.size(0)) m.weight.data.copy_(w) m.bias.data.copy_(b) return m
class SHSA(nn.Module): """Single-Head Self-Attention""" def __init__(self, dim, qk_dim, pdim): super().__init__() self.scale = qk_dim ** -0.5 self.qk_dim = qk_dim self.dim = dim self.pdim = pdim self.pre_norm = GroupNorm(pdim) self.qkv = Conv2d_BN(pdim, qk_dim * 2 + pdim) self.proj = nn.Sequential(nn.ReLU(), Conv2d_BN( dim, dim, bn_weight_init=0))
def forward(self, x): B, C, H, W = x.shape x1, x2 = torch.split(x, [self.pdim, self.dim - self.pdim], dim=1) x1 = self.pre_norm(x1) qkv = self.qkv(x1) q, k, v = qkv.split([self.qk_dim, self.qk_dim, self.pdim], dim=1) q, k, v = q.flatten(2), k.flatten(2), v.flatten(2) attn = (q.transpose(-2, -1) @ k) * self.scale attn = attn.softmax(dim=-1) x1 = (v @ attn.transpose(-2, -1)).reshape(B, self.pdim, H, W) x = self.proj(torch.cat([x1, x2], dim=1)) return x
if __name__ == '__main__': x = torch.randn(1, 64, 32, 32) shsa = SHSA(dim=64, qk_dim=64, pdim=64) print(shsa) output = shsa(x) print(f"Input shape: {x.shape}") print(f"Output shape: {output.shape}")


本文内容为论文学习收获分享,受限于知识能力,本文对原文的理解可能存在偏差,最终内容以原论文为准。本文信息旨在传播和学术交流,其内容由作者负责,不代表本号观点。文中作品文字、图片等如涉及内容、版权和其他问题,请及时与我们联系,我们将在第一时间回复并处理。


欢迎投稿

想要让高质量的内容更快地触达读者,降低他们寻找优质信息的成本吗?关键在于那些你尚未结识的人。他们可能掌握着你渴望了解的知识。【AI前沿速递】愿意成为这样的一座桥梁,连接不同领域、不同背景的学者,让他们的学术灵感相互碰撞,激发出无限可能。

【AI前沿速递】欢迎各高校实验室和个人在我们的平台上分享各类精彩内容,无论是最新的论文解读,还是对学术热点的深入分析,或是科研心得和竞赛经验的分享,我们的目标只有一个:让知识自由流动。

📝 投稿指南

  • 确保文章为个人原创,未在任何公开渠道发布。若文章已在其他平台发表或即将发表,请明确说明。

  • 建议使用Markdown格式撰写稿件,并以附件形式发送清晰、无版权争议的配图。

  • 【AI前沿速递】尊重作者的署名权,并为每篇被采纳的原创首发稿件提供具有市场竞争力的稿酬。稿酬将根据文章的阅读量和质量进行阶梯式结算。

📬 投稿方式

  • 您可以通过添加我们的小助理微信(aiqysd)进行快速投稿。请在添加时备注“投稿-姓名-学校-研究方向”


    长按添加AI前沿速递小助理


AI前沿速递
持续分享最新AI前沿论文成果
 最新文章