标题:SHViT: Single-Head Vision Transformer with Memory Efficient Macro Design
论文链接:https://arxiv.org/pdf/2401.16456
代码链接:https://github.com/ysj9909/SHViT
来源:CVPR 2024
单头自注意力模块(SHSA)
基本结构
SHSA模块主要包含以下几个部分:
-输入分割:将输入通道分为两部分,一部分是参与注意力计算的通道,另一部分是保持不变的通道。默认设置中,参与注意力计算的通道数占总通道数的比例为 1/4.67。
-自注意力计算:对参与注意力计算的通道进行自注意力操作。具体来说,先对进行线性变换得到查询、键和值,然后计算和的点积并进行 softmax 归一化得到注意力权重,最后将注意力权重与 相乘得到注意力特征
-特征拼接与投影:将注意力特征与保持不变的通道进行拼接,再通过一个投影层输出最终结果。
计算公式:
其中,是投影权重,是查询和键的维度,默认为 16, Concat( ) 是拼接操作。
优势特点
- 减少计算冗余:相比于多头自注意力机制,SHSA仅对部分通道进行注意力计算,避免了多头机制中的计算冗余,降低了计算量和内存访问成本。
- 并行结合全局和局部信息:SHSA通过将注意力特征与保持不变的通道进行拼接,能够在并行计算中同时结合全局和局部信息,提高了特征的丰富性和模型的性能。
- 内存访问高效:SHSA减少了对内存绑定操作(如 reshape 和 normalization)的使用,或者将这些操作应用于较少的输入通道,从而提高了计算效率,充分发挥了 GPU/CPUs 的计算能力。
应用场景
SHSA模块在SHViT模型中得到了广泛应用,用于图像分类、目标检测和实例分割等任务。例如,在ImageNet-1k图像分类任务中,SHViT-S4模型在Nvidia A100 GPU上达到了14283 images/s的吞吐量,同时取得了79.4%的Top-1准确率。在MS COCO目标检测和实例分割任务中,SHViT模型使用Mask R-CNN检测器,显著优于EfficientViT-M4等模型,同时在各种设备上展现出更低的骨干网络延迟。
代码实现
import torch
import torch.nn as nn
import torch.nn.functional as F
class GroupNorm(nn.GroupNorm):
"""Group Normalization with 1 group.
Input: tensor in shape [B, C, H, W]
"""
def __init__(self, num_channels, **kwargs):
super().__init__(1, num_channels, **kwargs)
class Conv2d_BN(nn.Sequential):
def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
groups=1, bn_weight_init=1):
super().__init__()
self.add_module('c', nn.Conv2d(
a, b, ks, stride, pad, dilation, groups, bias=False))
self.add_module('bn', nn.BatchNorm2d(b))
nn.init.constant_(self.bn.weight, bn_weight_init)
nn.init.constant_(self.bn.bias, 0)
@torch.no_grad()
def fuse(self):
c, bn = self._modules.values()
w = bn.weight / (bn.running_var + bn.eps)**0.5
b = bn.bias - self.bn.running_mean * \
self.bn.weight / (bn.running_var + bn.eps)**0.5
m = nn.Conv2d(w.size(1) * self.c.groups, w.size(
0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups,
device=c.weight.device)
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
class BN_Linear(nn.Sequential):
def __init__(self, a, b, bias=True, std=0.02):
super().__init__()
self.add_module('bn', nn.BatchNorm1d(a))
self.add_module('l', nn.Linear(a, b, bias=bias))
trunc_normal_(self.l.weight, std=std)
if bias:
nn.init.constant_(self.l.bias, 0)
@torch.no_grad()
def fuse(self):
bn, l = self._modules.values()
w = bn.weight / (bn.running_var + bn.eps)**0.5
b = bn.bias - self.bn.running_mean * \
self.bn.weight / (bn.running_var + bn.eps)**0.5
b = (l.weight @ b[:, None]).view(-1) + self.l.bias
m = nn.Linear(w.size(1), w.size(0))
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
class SHSA(nn.Module):
"""Single-Head Self-Attention"""
def __init__(self, dim, qk_dim, pdim):
super().__init__()
self.scale = qk_dim ** -0.5
self.qk_dim = qk_dim
self.dim = dim
self.pdim = pdim
self.pre_norm = GroupNorm(pdim)
self.qkv = Conv2d_BN(pdim, qk_dim * 2 + pdim)
self.proj = nn.Sequential(nn.ReLU(), Conv2d_BN(
dim, dim, bn_weight_init=0))
def forward(self, x):
B, C, H, W = x.shape
x1, x2 = torch.split(x, [self.pdim, self.dim - self.pdim], dim=1)
x1 = self.pre_norm(x1)
qkv = self.qkv(x1)
q, k, v = qkv.split([self.qk_dim, self.qk_dim, self.pdim], dim=1)
q, k, v = q.flatten(2), k.flatten(2), v.flatten(2)
attn = (q.transpose(-2, -1) @ k) * self.scale
attn = attn.softmax(dim=-1)
x1 = (v @ attn.transpose(-2, -1)).reshape(B, self.pdim, H, W)
x = self.proj(torch.cat([x1, x2], dim=1))
return x
if __name__ == '__main__':
x = torch.randn(1, 64, 32, 32)
shsa = SHSA(dim=64, qk_dim=64, pdim=64)
print(shsa)
output = shsa(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
确保文章为个人原创,未在任何公开渠道发布。若文章已在其他平台发表或即将发表,请明确说明。
建议使用Markdown格式撰写稿件,并以附件形式发送清晰、无版权争议的配图。
【AI前沿速递】尊重作者的署名权,并为每篇被采纳的原创首发稿件提供具有市场竞争力的稿酬。稿酬将根据文章的阅读量和质量进行阶梯式结算。
您可以通过添加我们的小助理微信(aiqysd)进行快速投稿。请在添加时备注“投稿-姓名-学校-研究方向”
长按添加AI前沿速递小助理