点击下方卡片,关注“AI前沿速递”公众号
点击下方卡片,关注“AI前沿速递”公众号
各种重磅干货,第一时间送达
各种重磅干货,第一时间送达
标题:EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
论文链接:https://arxiv.org/pdf/2205.14756
代码链接:https://github.com/mit-han-lab/efficientvit**
创新点
**
三明治布局设计:在FFN层间插入单个受内存限制的MHSA层,减少其时间成本,同时增加FFN层以增强通道通信,提升内存效率。 级联组注意力模块:将特征分割成不同部分提供给各注意力头,避免注意力头输入相同特征,减少计算冗余,还通过级联输出特征增加网络深度与注意力多样性。 参数重分配策略:扩大关键组件如值投影的通道宽度,缩小次要组件如FFN隐藏维度,优化参数使用,提高模型参数效率。 轻量级多尺度注意力机制:用基于ReLU的全局注意力替代传统自注意力,降低计算复杂度,避免硬件低效操作;通过小核卷积生成多尺度tokens并执行全局注意力,实现全局感受野与多尺度学习。 高效部署与实时推理能力:整体设计使EfficientViT在多个部署场景下能够实现实时推理,让ViT模型更高效地应用于实际任务,如目标检测、语义分割等。
整体结构
EfficientViT模型架构核心是一个多阶段主干网络,其关键在于利用ReLU线性注意力机制来获取全局上下文信息,并借助深度可分离卷积强化局部信息的处理。该模型注重多尺度学习,通过聚合Q/K/V的多尺度信息来提升特征提取效果。同时,借助特征金字塔融合不同阶段的特征图,最终经上采样和简单的MBConv块输出高分辨率预测结果。
具体到EfficientViT模块:
ReLU线性注意力:采用改进的线性注意力机制,即ReLU线性注意力,专注于捕捉全局上下文信息,以增强模型对整体图像内容的理解。 深度可分离卷积:在每个前馈网络(FFN)层中嵌入深度卷积操作,目的是捕捉局部信息,从而提升模型处理高分辨率输入的能力,使模型能够更好地关注图像细节。 多尺度学习:通过对Q/K/V进行多尺度信息聚合,增强模型的多尺度特征提取能力,同时利用分离卷积来避免降低硬件效率,确保模型在不同尺度下都能有效运行。
从消融研究结果来看,在Cityscapes数据集上,以mIoU和MAC为指标进行测量,输入分辨率为1024x2048。通过调整模型宽度使其MAC相同,结果显示多尺度学习和全局感受野对于语义分割性能至关重要。EfficientViT-L2-r384在ImageNet数据集上取得了86.0的top-1精度,相比EfficientNetV2-L提升了0.3的精度,在A100 GPU上更是实现了2.6倍的加速效果。
代码实现
Conv2d_BN 类
class Conv2d_BN(torch.nn.Sequential):
def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
groups=1, bn_weight_init=1, resolution=-10000):
super().__init__()
self.add_module('c', torch.nn.Conv2d(
a, b, ks, stride, pad, dilation, groups, bias=False))
self.add_module('bn', torch.nn.BatchNorm2d(b))
torch.nn.init.constant_(self.bn.weight, bn_weight_init)
torch.nn.init.constant_(self.bn.bias, 0)
@torch.no_grad()
def fuse(self):
c, bn = self._modules.values()
w = bn.weight / (bn.running_var + bn.eps)**0.5
w = c.weight * w[:, None, None, None]
b = bn.bias - bn.running_mean * bn.weight / \
(bn.running_var + bn.eps)**0.5
m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(
0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
EfficientViTBlock 类
class EfficientViTBlock(torch.nn.Module):
""" A basic EfficientViT building block.
Args:
type (str): Type for token mixer. Default: 's' for self-attention.
ed (int): Number of input channels.
kd (int): Dimension for query and key in the token mixer.
nh (int): Number of attention heads.
ar (int): Multiplier for the query dim for value dimension.
resolution (int): Input resolution.
window_resolution (int): Local window resolution.
kernels (List[int]): The kernel size of the dw conv on query.
"""
def __init__(self, type,
ed, kd, nh=8,
ar=4,
resolution=14,
window_resolution=7,
kernels=[5, 5, 5, 5],):
super().__init__()
self.dw0 = Residual(Conv2d_BN(ed, ed, 3, 1, 1, groups=ed, bn_weight_init=0., resolution=resolution))
self.ffn0 = Residual(FFN(ed, int(ed * 2), resolution))
if type == 's':
self.mixer = Residual(LocalWindowAttention(ed, kd, nh, attn_ratio=ar, \
resolution=resolution, window_resolution=window_resolution, kernels=kernels))
self.dw1 = Residual(Conv2d_BN(ed, ed, 3, 1, 1, groups=ed, bn_weight_init=0., resolution=resolution))
self.ffn1 = Residual(FFN(ed, int(ed * 2), resolution))
def forward(self, x):
return self.ffn1(self.dw1(self.mixer(self.ffn0(self.dw0(x)))))
EfficientViT 类
class EfficientViT(torch.nn.Module):
def __init__(self, img_size=400,
patch_size=16,
frozen_stages=0,
in_chans=3,
stages=['s', 's', 's'],
embed_dim=[64, 128, 192],
key_dim=[16, 16, 16],
depth=[1, 2, 3],
num_heads=[4, 4, 4],
window_size=[7, 7, 7],
kernels=[5, 5, 5, 5],
down_ops=[['subsample', 2], ['subsample', 2], ['']],
pretrained=None,
distillation=False,):
super().__init__()
resolution = img_size
self.patch_embed = torch.nn.Sequential(Conv2d_BN(in_chans, embed_dim[0] // 8, 3, 2, 1, resolution=resolution), torch.nn.ReLU(),
Conv2d_BN(embed_dim[0] // 8, embed_dim[0] // 4, 3, 2, 1, resolution=resolution // 2), torch.nn.ReLU(),
Conv2d_BN(embed_dim[0] // 4, embed_dim[0] // 2, 3, 2, 1, resolution=resolution // 4), torch.nn.ReLU(),
Conv2d_BN(embed_dim[0] // 2, embed_dim[0], 3, 1, 1, resolution=resolution // 8))
resolution = img_size // patch_size
attn_ratio = [embed_dim[i] / (key_dim[i] * num_heads[i]) for i in range(len(embed_dim))]
self.blocks1 = []
self.blocks2 = []
self.blocks3 = []
for i, (stg, ed, kd, dpth, nh, ar, wd, do) in enumerate(
zip(stages, embed_dim, key_dim, depth, num_heads, attn_ratio, window_size, down_ops)):
for d in range(dpth):
eval('self.blocks' + str(i+1)).append(EfficientViTBlock(stg, ed, kd, nh, ar, resolution, wd, kernels))
if do[0] == 'subsample':
#('Subsample' stride)
blk = eval('self.blocks' + str(i+2))
resolution_ = (resolution - 1) // do[1] + 1
blk.append(torch.nn.Sequential(Residual(Conv2d_BN(embed_dim[i], embed_dim[i], 3, 1, 1, groups=embed_dim[i], resolution=resolution)),
Residual(FFN(embed_dim[i], int(embed_dim[i] * 2), resolution)),))
blk.append(PatchMerging(*embed_dim[i:i + 2], resolution))
resolution = resolution_
blk.append(torch.nn.Sequential(Residual(Conv2d_BN(embed_dim[i + 1], embed_dim[i + 1], 3, 1, 1, groups=embed_dim[i + 1], resolution=resolution)),
Residual(FFN(embed_dim[i + 1], int(embed_dim[i + 1] * 2), resolution)),))
self.blocks1 = torch.nn.Sequential(*self.blocks1)
self.blocks2 = torch.nn.Sequential(*self.blocks2)
self.blocks3 = torch.nn.Sequential(*self.blocks3)
self.channel = [i.size(1) for i in self.forward(torch.randn(1, 3, 640, 640))]
def forward(self, x):
outs = []
x = self.patch_embed(x)
x = self.blocks1(x)
outs.append(x)
x = self.blocks2(x)
outs.append(x)
x = self.blocks3(x)
outs.append(x)
return outs
确保文章为个人原创,未在任何公开渠道发布。若文章已在其他平台发表或即将发表,请明确说明。
建议使用Markdown格式撰写稿件,并以附件形式发送清晰、无版权争议的配图。
【AI前沿速递】尊重作者的署名权,并为每篇被采纳的原创首发稿件提供具有市场竞争力的稿酬。稿酬将根据文章的阅读量和质量进行阶梯式结算。
您可以通过添加我们的小助理微信(aiqysd)进行快速投稿。请在添加时备注“投稿-姓名-学校-研究方向”
长按添加AI前沿速递小助理