点击下方卡片,关注“AI前沿速递”公众号
点击下方卡片,关注“AI前沿速递”公众号
各种重磅干货,第一时间送达
各种重磅干货,第一时间送达
论文链接:https://arxiv.org/pdf/2303.03667
代码链接:https://github.com/JierunChen/FasterNet
来源:CVPR 2023
PConv(Partial Convolution)模块
PConv 是 FasterNet 的核心模块,其设计目的是通过减少冗余计算和内存访问来提高计算效率。 PConv 只对输入通道的一部分进行卷积操作,而保持其余通道不变。具体来说,PConv 的设计如下:
- 输入:输入特征图。
- 部分卷积:只对其中的个通道进行卷积操作,其余个通道保持不变。
- 计算量:PConv 的计算量(FLOPs)为,相比常规卷积大幅减少。
- 内存访问:PConv 的内存访问量为也显著减少。·实现:PConv通过 split 、conv 和 cat 操作实现。
- 优势:相比常规卷积,PConv 的计算量(FLOPs)大幅减少,仅为常规卷积的 (以部分比为例);同时,其内存访问量也显著降低,仅为常规卷积的。
- 与PWConv结合:PConv 后接一个逐点卷积 (PWConv),可以更好地利用所有通道的信息。这 种组合在输入特征图上的有效感受野类似于T形卷积,更关注中心位置,与常规卷积均匀处理一个区域的方式不同。而且,将T形卷积分解为 PConv 和 PWConv 可以进一步利用滤波器间的冗余, 节省计算量。
FasterNet 模块
FasterNet 是基于 PConv 和 PWConv 构建的神经网络,具有以下结构:
- 整体架构:FasterNet 包含四个层次阶段,每个阶段前有一个嵌入层(4×4 的常规卷积,步长为 4)或合并层(2×2 的常规卷积,步长为 2),用于空间下采样和通道数扩展。
- FasterNet 块:每个阶段包含多个 FasterNet 块,每个 FasterNet 块包含一个 PConv 层后接两个 PWConv(或 1×1 卷积)层。这些层构成一个倒置残差块,中间层通道数扩展,并且有一个快捷连接用于重用输入特征。
- 归一化和激活层:只在每个中间 PWConv 后放置归一化和激活层,以保持特征多样性并降低延迟。使用批量归一化(BN)而不是其他替代方案,因为 BN 可以合并到相邻的卷积层中以加快推理速度。对于激活层,小的 FasterNet 变体使用 GELU,大的 FasterNet 变体使用 ReLU。
- 分类层:最后三层用于特征转换和分类,包括全局平均池化、1×1 卷积和全连接层。
FasterNet 的变体
FasterNet 提供了多种变体,以满足不同计算预算的需求:- FasterNet-T0/1/2:小型变体,适用于资源受限的设备。- FasterNet-S:中等变体,适用于一般的计算任务。- FasterNet-M:较大变体,适用于需要更高精度的任务。- FasterNet-L:大型变体,适用于高性能计算任务。 这些变体在深度和宽度上有所不同,但整体架构保持一致。
代码实现
import torch
import torch.nn as nn
from pyzjr.Models.bricks import DropPath
class PartialConv(nn.Module):
def __init__(self, dim, n_div=4, kernel_size=3, forward='split_cat'):
super().__init__()
self.dim_conv3 = dim // n_div
self.dim_untouched = dim - self.dim_conv3
self.partial_conv3 = nn.Conv2d(self.dim_conv3, self.dim_conv3, kernel_size, 1, 1, bias=False)
if forward == 'slicing':
self.forward = self.forward_slicing
elif forward == 'split_cat':
self.forward = self.forward_split_cat
else:
raise NotImplementedError
def forward_slicing(self, x):
x = x.clone()
x[:, :self.dim_conv3, :, :] = self.partial_conv3(x[:, :self.dim_conv3, :, :])
return x
def forward_split_cat(self, x):
x1, x2 = torch.split(x, [self.dim_conv3, self.dim_untouched], dim=1)
x1 = self.partial_conv3(x1)
x = torch.cat((x1, x2), 1)
return x
class FasterNetBlock(nn.Module):
def __init__(self, dim, expand_ratio=2, act_layer=nn.ReLU, drop_path_rate=0.0, forward='split_cat'):
super().__init__()
self.pconv = PartialConv(dim, forward=forward)
self.conv1 = nn.Conv2d(dim, dim * expand_ratio, 1, bias=False)
self.bn = nn.BatchNorm2d(dim * expand_ratio)
self.act_layer = act_layer()
self.conv2 = nn.Conv2d(dim * expand_ratio, dim, 1, bias=False)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
def forward(self, x):
residual = x
x = self.pconv(x)
x = self.conv1(x)
x = self.bn(x)
x = self.act_layer(x)
x = self.conv2(x)
x = residual + self.drop_path(x)
return x
class FasterNet(nn.Module):
def __init__(self, in_channel=3, embed_dim=40, act_layer=None, num_classes=1000, depths=None, drop_rate=0.0):
super().__init__()
self.stem = nn.Sequential(
nn.Conv2d(in_channel, embed_dim, 4, stride=4, bias=False),
nn.BatchNorm2d(embed_dim),
act_layer()
)
drop_path_list = [x.item() for x in torch.linspace(0, drop_rate, sum(depths))]
self.feature = []
embed_dim = embed_dim
for idx, depth in enumerate(depths):
self.feature.append(nn.Sequential(
*[FasterNetBlock(embed_dim, act_layer=act_layer, drop_path_rate=drop_path_list[sum(depths[:idx]) + i]) for i in range(depth)]
))
if idx < len(depths) - 1:
self.feature.append(nn.Sequential(
nn.Conv2d(embed_dim, embed_dim * 2, 2, stride=2, bias=False),
nn.BatchNorm2d(embed_dim * 2),
act_layer()
))
embed_dim = embed_dim * 2
self.feature = nn.Sequential(*self.feature)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv1 = nn.Conv2d(embed_dim, 1280, 1, bias=False)
self.act_layer = act_layer()
self.fc = nn.Linear(1280, num_classes)
def forward(self, x):
x = self.stem(x)
x = self.feature(x)
x = self.avg_pool(x)
x = self.conv1(x)
x = self.act_layer(x)
x = self.fc(x.flatten(1))
return x
本文内容为论文学习收获分享,受限于知识能力,本文对原文的理解可能存在偏差,最终内容以原论文为准。本文信息旨在传播和学术交流,其内容由作者负责,不代表本号观点。文中作品文字、图片等如涉及内容、版权和其他问题,请及时与我们联系,我们将在第一时间回复并处理。