论文介绍
题目:Global Filter Networks for Image Classification
论文地址:https://arxiv.org/pdf/2107.00645
QQ深度学习交流群:994264161
扫描下方二维码,加入深度学习论文指南星球!
加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务
创新点
频域中的全局交互建模:论文提出的Global Filter Network (GFNet) 使用频域操作替代了传统视觉Transformer中的自注意力机制。具体来说,GFNet通过2D离散傅里叶变换(FFT)将输入从空间域转换到频域,然后使用可学习的全局滤波器进行频域操作,再通过逆傅里叶变换将其转换回空间域。这种方法能够以对数线性复杂度(log-linear complexity)高效地学习长程依赖关系。
计算效率显著提升:相比于基于自注意力机制(复杂度为O(L²))和传统MLP(复杂度为O(H²W²)),GFNet采用FFT和点乘操作(复杂度为O(LlogL)),在保证建模能力的同时显著降低了计算成本。
减少归纳偏置(Inductive Bias):与传统的卷积神经网络不同,GFNet没有人为设计的局部感受野限制,其全局滤波器能够覆盖所有频率,支持同时捕获短期和长期的特征交互。
适配高分辨率输入和分层结构:GFNet能够适配更高分辨率的输入(如从224×224到384×384的图片),并且支持CNN风格的分层结构设计(例如在特征图尺寸逐层下采样的同时应用GFNet块),从而增强了密集预测任务(如目标检测和分割)的适用性。
多任务性能优异:实验表明,GFNet在ImageNet分类任务中超越了许多Transformer和MLP模型,同时在下游任务(如迁移学习、语义分割等)中表现出色。此外,其在鲁棒性和泛化能力测试中(例如对抗攻击和分布外数据)也取得了领先的结果。
方法
整体架构
GFNet 的整体架构以视觉 Transformer 为基础,使用全局滤波层(Global Filter Layer)替代了传统的自注意力机制,通过 2D 傅里叶变换(FFT)将空间特征转换到频域,与可学习的全局滤波器逐元素相乘后,再通过逆傅里叶变换(IFFT)返回空间域,实现高效的全局特征交互;同时结合前馈网络(FFN)和层归一化(Layer Norm)处理特征,最后通过全局平均池化和分类头输出结果。这种架构计算复杂度低(O(LlogL)),适配高分辨率输入,并在分类和密集预测任务中展现优异性。
输入处理:
将输入图像分割成固定大小的非重叠图像块(patch),并将每个块展平形成一系列的Token。
这些Token通过线性投影被嵌入到一个固定维度的表示空间。
全局滤波层 (Global Filter Layer):
频域转换:对输入Token特征进行2D离散傅里叶变换(2D FFT),将空间特征转换到频域表示。
全局滤波:在频域中,通过与可学习的全局滤波器进行逐元素乘法(Hadamard乘积)实现特征的全局交互。
逆变换:通过2D逆傅里叶变换(2D IFFT),将频域特征映射回空间域。
此全局滤波层代替了传统视觉Transformer中的自注意力机制,是GFNet的核心创新点。
前馈网络 (Feed Forward Network, FFN):
在全局滤波层之后,应用标准的前馈网络,主要包括多层感知机(MLP)和激活函数,用于非线性特征映射。
层归一化 (Layer Norm):
在全局滤波层和前馈网络的输入前,应用层归一化以稳定训练过程。
分类头:
最后一层特征通过全局平均池化(Global Average Pooling)聚合,然后通过全连接层实现最终分类。
即插即用模块作用
GFNet 作为一个即插即用模块:
图像分类任务:
GFNet在大规模数据集(如ImageNet)上的分类任务中表现优异,尤其在高分辨率输入(如384×384图像)的情况下,能够保持较高的计算效率和准确率。
它适用于需要高效处理高分辨率输入、需要减少计算成本但仍追求高精度的场景。
密集预测任务(语义分割、目标检测等):
GFNet可以处理更大的特征图(如56×56或更高分辨率),支持密集预测任务中的特征提取和特征融合。
在目标检测或语义分割场景中,GFNet通过捕获全局特征依赖关系,提高模型对复杂背景和长程依赖的理解能力。
高效模型部署:
GFNet由于其低复杂度(
)和低内存占用,非常适合部署在计算资源有限(如移动设备、嵌入式系统)或需要高吞吐量的场景。O ( L log L ) O(L \log L) 鲁棒性要求高的场景:
GFNet在分布外数据(如ImageNet-A、ImageNet-C)和对抗攻击(如FGSM、PGD)下展现了强鲁棒性,适用于要求高稳定性和可靠性的场景,如自动驾驶、医疗影像分析等。
消融实验结果
内容:比较了GFNet-XS与其他基线模型(包括带不同卷积核的局部卷积、ResMLP、FNet等)在ImageNet上的性能。实验表明:
GFNet使用全局滤波层在性能上优于基线方法,如ResMLP(76.6% vs. 78.6%)和FNet(71.2% vs. 78.6%)。
GFNet的全局滤波设计在保持较低计算复杂度的同时,提供了显著的性能提升。
使用局部卷积的模型(如3×3、5×5卷积)效果不如GFNet的全局滤波器,说明频域全局交互的优势。
GFNet在对抗攻击(如FGSM、PGD)和分布外数据上的表现均优于传统方法(如ResNet-50和ResMLP)。
泛化能力上,GFNet在ImageNet-V2和ImageNet-Real上的表现与SOTA模型相当甚至更优。
内容:在多个基准数据集(如ImageNet-A、ImageNet-C、ImageNet-V2、ImageNet-Real)上评估GFNet的鲁棒性和泛化能力。结果表明:
即插即用模块
import torch
from torch import nn
import math
from timm.models.layers import DropPath, to_2tuple
# 论文地址:https://arxiv.org/pdf/2107.00645
# 论文:Global Filter Networks for Image Classification
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
self.img_size = img_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
B, C, H, W = x.shape
# FIXME look at relaxing size constraints
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2)
return x
class GlobalFilter(nn.Module):
def __init__(self, dim, h=14, w=8):
super().__init__()
self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)
self.w = w
self.h = h
def forward(self, x, spatial_size=None):
B, N, C = x.shape
if spatial_size is None:
a = b = int(math.sqrt(N))
else:
a, b = spatial_size
x = x.view(B, a, b, C)
x = x.to(torch.float32)
x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
weight = torch.view_as_complex(self.complex_weight)
x = x * weight
x = torch.fft.irfft2(x, s=(a, b), dim=(1, 2), norm='ortho')
x = x.reshape(B, N, C)
return x
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
class Block(nn.Module):
def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, h=14, w=8):
super().__init__()
self.norm1 = norm_layer(dim)
self.filter = GlobalFilter(dim, h=h, w=w)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def forward(self, x):
x = x + self.drop_path(self.mlp(self.norm2(self.filter(self.norm1(x)))))
return x
class GFNet(nn.Module):
def __init__(self, embed_dim=384, img_size=224, patch_size=16, mlp_ratio=4, depth=4, num_classes=1000):
super().__init__()
self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)
self.embedding = nn.Linear((patch_size ** 2) * 3, embed_dim)
h = img_size // patch_size
w = h // 2 + 1
self.blocks = nn.ModuleList([
Block(dim=embed_dim, mlp_ratio=mlp_ratio, h=h, w=w)
for i in range(depth)
])
self.head = nn.Linear(embed_dim, num_classes)
self.softmax = nn.Softmax(1)
def forward(self, x):
x = self.patch_embed(x)
for blk in self.blocks:
x = blk(x)
x = x.mean(dim=1)
x = self.softmax(self.head(x))
return x
if __name__ == '__main__':
input = torch.randn(1, 3, 224, 224)
block = GFNet(embed_dim=384, img_size=224, patch_size=16, num_classes=1000)
out = block(input) print(out.shape)
便捷下载方式
浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules
更多分析可见原文