突破计算瓶颈:GFNet在图像分类、目标检测与语义分割中的高效创新应用

文摘   2025-01-01 17:20   中国香港  

论文介绍

题目:Global Filter Networks for Image Classification

论文地址:https://arxiv.org/pdf/2107.00645

QQ深度学习交流群:994264161

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • 频域中的全局交互建模:论文提出的Global Filter Network (GFNet) 使用频域操作替代了传统视觉Transformer中的自注意力机制。具体来说,GFNet通过2D离散傅里叶变换(FFT)将输入从空间域转换到频域,然后使用可学习的全局滤波器进行频域操作,再通过逆傅里叶变换将其转换回空间域。这种方法能够以对数线性复杂度(log-linear complexity)高效地学习长程依赖关系。

  • 计算效率显著提升:相比于基于自注意力机制(复杂度为O(L²))和传统MLP(复杂度为O(H²W²)),GFNet采用FFT和点乘操作(复杂度为O(LlogL)),在保证建模能力的同时显著降低了计算成本。

  • 减少归纳偏置(Inductive Bias):与传统的卷积神经网络不同,GFNet没有人为设计的局部感受野限制,其全局滤波器能够覆盖所有频率,支持同时捕获短期和长期的特征交互。

  • 适配高分辨率输入和分层结构:GFNet能够适配更高分辨率的输入(如从224×224到384×384的图片),并且支持CNN风格的分层结构设计(例如在特征图尺寸逐层下采样的同时应用GFNet块),从而增强了密集预测任务(如目标检测和分割)的适用性。

  • 多任务性能优异:实验表明,GFNet在ImageNet分类任务中超越了许多Transformer和MLP模型,同时在下游任务(如迁移学习、语义分割等)中表现出色。此外,其在鲁棒性和泛化能力测试中(例如对抗攻击和分布外数据)也取得了领先的结果。

方法

整体架构

     GFNet 的整体架构以视觉 Transformer 为基础,使用全局滤波层(Global Filter Layer)替代了传统的自注意力机制,通过 2D 傅里叶变换(FFT)将空间特征转换到频域,与可学习的全局滤波器逐元素相乘后,再通过逆傅里叶变换(IFFT)返回空间域,实现高效的全局特征交互;同时结合前馈网络(FFN)和层归一化(Layer Norm)处理特征,最后通过全局平均池化和分类头输出结果。这种架构计算复杂度低(O(LlogL)),适配高分辨率输入,并在分类和密集预测任务中展现优异性。

  1. 输入处理

  • 将输入图像分割成固定大小的非重叠图像块(patch),并将每个块展平形成一系列的Token。

  • 这些Token通过线性投影被嵌入到一个固定维度的表示空间。

  • 全局滤波层 (Global Filter Layer)

    • 频域转换:对输入Token特征进行2D离散傅里叶变换(2D FFT),将空间特征转换到频域表示。

    • 全局滤波:在频域中,通过与可学习的全局滤波器进行逐元素乘法(Hadamard乘积)实现特征的全局交互。

    • 逆变换:通过2D逆傅里叶变换(2D IFFT),将频域特征映射回空间域。

    • 此全局滤波层代替了传统视觉Transformer中的自注意力机制,是GFNet的核心创新点。

  • 前馈网络 (Feed Forward Network, FFN)

    • 在全局滤波层之后,应用标准的前馈网络,主要包括多层感知机(MLP)和激活函数,用于非线性特征映射。

  • 层归一化 (Layer Norm)

    • 在全局滤波层和前馈网络的输入前,应用层归一化以稳定训练过程。

  • 分类头

    • 最后一层特征通过全局平均池化(Global Average Pooling)聚合,然后通过全连接层实现最终分类。

    即插即用模块作用

    GFNet 作为一个即插即用模块

    • 图像分类任务

      • GFNet在大规模数据集(如ImageNet)上的分类任务中表现优异,尤其在高分辨率输入(如384×384图像)的情况下,能够保持较高的计算效率和准确率。

      • 它适用于需要高效处理高分辨率输入、需要减少计算成本但仍追求高精度的场景。

    • 密集预测任务(语义分割、目标检测等)

      • GFNet可以处理更大的特征图(如56×56或更高分辨率),支持密集预测任务中的特征提取和特征融合。

      • 在目标检测或语义分割场景中,GFNet通过捕获全局特征依赖关系,提高模型对复杂背景和长程依赖的理解能力。

    • 高效模型部署

      • GFNet由于其低复杂度(O(LlogL)O(L \log L))和低内存占用,非常适合部署在计算资源有限(如移动设备、嵌入式系统)或需要高吞吐量的场景。

    • 鲁棒性要求高的场景

    • GFNet在分布外数据(如ImageNet-A、ImageNet-C)和对抗攻击(如FGSM、PGD)下展现了强鲁棒性,适用于要求高稳定性和可靠性的场景,如自动驾驶、医疗影像分析等。

    消融实验结果

    内容:比较了GFNet-XS与其他基线模型(包括带不同卷积核的局部卷积、ResMLP、FNet等)在ImageNet上的性能。实验表明:

    • GFNet使用全局滤波层在性能上优于基线方法,如ResMLP(76.6% vs. 78.6%)和FNet(71.2% vs. 78.6%)。

    • GFNet的全局滤波设计在保持较低计算复杂度的同时,提供了显著的性能提升。

    • 使用局部卷积的模型(如3×3、5×5卷积)效果不如GFNet的全局滤波器,说明频域全局交互的优势。


    • 内容:在多个基准数据集(如ImageNet-A、ImageNet-C、ImageNet-V2、ImageNet-Real)上评估GFNet的鲁棒性和泛化能力。结果表明:

      • GFNet在对抗攻击(如FGSM、PGD)和分布外数据上的表现均优于传统方法(如ResNet-50和ResMLP)。

      • 泛化能力上,GFNet在ImageNet-V2和ImageNet-Real上的表现与SOTA模型相当甚至更优。

    即插即用模块

    import torch
    from torch import nn
    import math
    from timm.models.layers import DropPath, to_2tuple

    # 论文地址:https://arxiv.org/pdf/2107.00645
    # 论文:Global Filter Networks for Image Classification


    class PatchEmbed(nn.Module):
        """ Image to Patch Embedding
        """

        def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
            super().__init__()
            img_size = to_2tuple(img_size)
            patch_size = to_2tuple(patch_size)
            num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
            self.img_size = img_size
            self.patch_size = patch_size
            self.num_patches = num_patches
            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

        def forward(self, x):
            B, C, H, W = x.shape
            # FIXME look at relaxing size constraints
            assert H == self.img_size[0] and W == self.img_size[1], \
                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
            x = self.proj(x).flatten(2).transpose(1, 2)
            return x

    class GlobalFilter(nn.Module):
        def __init__(self, dim, h=14, w=8):
            super().__init__()
            self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2, dtype=torch.float32) * 0.02)
            self.w = w
            self.h = h

        def forward(self, x, spatial_size=None):
            B, N, C = x.shape
            if spatial_size is None:
                a = b = int(math.sqrt(N))
            else:
                a, b = spatial_size

            x = x.view(B, a, b, C)

            x = x.to(torch.float32)

            x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
            weight = torch.view_as_complex(self.complex_weight)
            x = x * weight
            x = torch.fft.irfft2(x, s=(a, b), dim=(1, 2), norm='ortho')

            x = x.reshape(B, N, C)
            return x

    class Mlp(nn.Module):
        def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
            super().__init__()
            out_features = out_features or in_features
            hidden_features = hidden_features or in_features
            self.fc1 = nn.Linear(in_features, hidden_features)
            self.act = act_layer()
            self.fc2 = nn.Linear(hidden_features, out_features)
            self.drop = nn.Dropout(drop)

        def forward(self, x):
            x = self.fc1(x)
            x = self.act(x)
            x = self.drop(x)
            x = self.fc2(x)
            x = self.drop(x)
            return x

    class Block(nn.Module):
        def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, h=14, w=8):
            super().__init__()
            self.norm1 = norm_layer(dim)
            self.filter = GlobalFilter(dim, h=h, w=w)
            self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
            self.norm2 = norm_layer(dim)
            mlp_hidden_dim = int(dim * mlp_ratio)
            self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        def forward(self, x):
            x = x + self.drop_path(self.mlp(self.norm2(self.filter(self.norm1(x)))))
            return x


    class GFNet(nn.Module):
        def __init__(self, embed_dim=384, img_size=224, patch_size=16, mlp_ratio=4, depth=4, num_classes=1000):
            super().__init__()
            self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, embed_dim=embed_dim)
            self.embedding = nn.Linear((patch_size ** 2) * 3, embed_dim)

            h = img_size // patch_size
            w = h // 2 + 1


            self.blocks = nn.ModuleList([
                Block(dim=embed_dim, mlp_ratio=mlp_ratio, h=h, w=w)
                for i in range(depth)
            ])

            self.head = nn.Linear(embed_dim, num_classes)
            self.softmax = nn.Softmax(1)

        def forward(self, x):
            x = self.patch_embed(x)
            for blk in self.blocks:
                x = blk(x)
            x = x.mean(dim=1)
            x = self.softmax(self.head(x))
            return x

    if __name__ == '__main__':
        input = torch.randn(1, 3, 224, 224)
        block = GFNet(embed_dim=384, img_size=224, patch_size=16, num_classes=1000)
        out = block(input)    print(out.shape)

    便捷下载方式

    浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

    更多分析可见原文


    ai缝合大王
    聚焦AI前沿,分享相关技术、论文,研究生自救指南
     最新文章