(ECCV 2022)即插即用门控通道注意力机制NAF,涨点起飞起飞了!

文摘   2024-12-10 17:20   上海  

论文介绍

题目:Simple Baselines for Image Restoration

论文地址:https://arxiv.org/abs/2204.04676

QQ深度学习交流群:719278780

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • 提出了一个简单的基线模型:论文通过对当前先进(SOTA)方法的分解提取其核心组件,设计了一个结构简单但性能超越SOTA的基线模型,具有较低的系统复杂性和计算成本。

  • 非线性激活函数的移除:揭示了非线性激活函数(例如Sigmoid、ReLU、GELU)可能并非必要。通过用简单的线性操作(如元素乘法)代替非线性激活,进一步简化了基线模型,提出了非线性激活函数自由网络(NAFNet)。

  • 高效性和性能兼顾:提出的NAFNet在多种图像恢复任务(如去噪和去模糊)中达到了SOTA性能,同时计算成本显著降低。例如,在SIDD去噪数据集上,NAFNet以不到一半的计算成本实现了比现有最优方法更高的PSNR。

  • 组件设计的新视角:通过分析和简化网络模块,提出了“简化通道注意力”和“简单门控”两种模块,用更简单的设计取代了传统的复杂模块,而性能不降反升。

  • 理论意义与实践价值:这是首次证明在计算机视觉任务中,非线性激活函数并非实现SOTA性能的必需元素,可能扩展未来计算机视觉模型的设计空间,同时提供了一种更简洁的实验基准。

方法

整体架构

     该模型采用经典的单阶段U型网络(UNet)结构,通过编码器和解码器的对称设计以及跳跃连接实现多尺度特征处理。模块化设计中引入了深度卷积、简单门控(Simple Gate)和简化通道注意力(Simplified Channel Attention),完全移除了非线性激活函数,并使用层归一化稳定训练。模型在保持高效计算的同时,显著简化了块间和块内的复杂性,适用于图像去噪、去模糊等多种任务,达到了SOTA性能。

1. 整体架构采用单阶段U型结构(UNet Architecture)

  • 模型整体结构基于经典的单阶段U型网络(UNet),这一结构广泛应用于图像恢复任务。

  • 具体特性:

    • 跳跃连接(Skip Connections):通过跳跃连接保留高分辨率的特征。

    • 对称结构:包括编码器和解码器部分,编码器提取多尺度特征,解码器逐步恢复分辨率。

    • 简单的特征融合:编码器和解码器之间采用简单的逐元素加法进行特征融合。

2. 模块化设计

  • 网络由多个块(Blocks)堆叠而成,每个块的设计通过去复杂化后得到更简单高效的结构。

  • 主要组件包括:

    • 深度卷积(Depthwise Convolution):用于高效提取局部特征。

    • 简单门控(Simple Gate):用两个特征图的逐元素乘法取代了传统的非线性激活函数(如GELU)。

    • 简化通道注意力(Simplified Channel Attention):通过全局池化生成注意力权重,去除了多层感知机和非线性激活函数。

    • 层归一化(Layer Normalization):用于稳定训练并提高性能。

3. 简化设计

  • 移除非线性激活函数:模型完全去除了传统的激活函数(如ReLU、Sigmoid等),依靠门控机制(Simple Gate)来提供非线性能力。

  • 简化的计算操作:减少了复杂的特征融合方式和冗余模块,使得计算效率更高。

  • 模块内部结构优化

    • 编码和解码器的特征通过简单的卷积和点操作进行变换。

    • 每个模块都注重减少计算复杂度,保持网络的整体轻量化。

4. 多尺度特征处理

  • 尽管采用了单阶段结构,模型通过U型网络的特性很好地处理了多尺度特征,避免了复杂的多阶段设计。

  • 网络整体设计保持了低块间复杂度(Inter-block Complexity)和低块内复杂度(Intra-block Complexity)。

即插即用模块作用

NAF 作为一个即插即用模块

  • 提升模型的性能:NAF模块通过有效的通道注意力和简单的门控机制,减少了传统模块中非必要的复杂操作,使得在图像恢复任务中的PSNR和SSIM表现均超过现有方法。

  • 降低计算复杂度:由于移除了非线性激活函数,并采用更简单的设计,NAF模块能够在保证性能的同时,显著减少计算成本(如MACs和延迟)。

  • 模块化灵活性:作为即插即用的组件嵌入到现有的深度学习架构中,无需对整体网络进行大幅度修改。

  • 训练稳定性:NAF模块因其结构简单,训练时对学习率的要求降低,易于优化,减少了训练过程中的不稳定性。

  • 高效性与扩展性:在多任务设置中表现出色,模块本身易于扩展,适配不同模型规模和任务需求。

消融实验结果

      展示了从基础的PlainNet逐步构建到性能更强的基线模型的过程。通过引入Layer Normalization(LN)、使用GELU替代ReLU,以及添加通道注意力模块(Channel Attention, CA),分别验证了这些设计选择对性能提升的贡献。其中,LN显著提高了训练稳定性和性能,GELU对去模糊任务的提升较大,而CA则在去噪和去模糊任务上均有小幅提升,最终构建的基线模型在SIDD和GoPro数据集上实现了较优的PSNR表现。

       展示了将基线模型进一步简化为NAFNet的实验过程。通过将GELU替换为Simple Gate和将Channel Attention替换为Simplified Channel Attention,网络的复杂性降低,同时性能在SIDD和GoPro数据集上仍有小幅提升。这表明,非线性激活函数和复杂的注意力机制并非必要,可以用更简单的设计实现甚至超过SOTA性能。

        探讨了模型中块数量对性能和延迟的影响。实验表明,随着块数量从9增加到36,模型的PSNR显著提升,但延迟仅小幅增加;当块数量进一步增加到72时,性能提升趋于饱和,而延迟大幅增加。因此,36块的设置在性能与效率之间实现了良好的平衡,成为默认选择。

即插即用模块

import torch
import torch.nn as nn

'''
Simple Baselines for Image Restoration
https://arxiv.org/abs/2204.04676
'''


class LayerNormFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        ctx.eps = eps
        N, C, H, W = x.size()
        mu = x.mean(1, keepdim=True)
        var = (x - mu).pow(2).mean(1, keepdim=True)
        y = (x - mu) / (var + eps).sqrt()
        ctx.save_for_backward(y, var, weight)
        y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        eps = ctx.eps

        N, C, H, W = grad_output.size()
        y, var, weight = ctx.saved_variables
        g = grad_output * weight.view(1, C, 1, 1)
        mean_g = g.mean(dim=1, keepdim=True)

        mean_gy = (g * y).mean(dim=1, keepdim=True)
        gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
        return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum(
            dim=0), None

class LayerNorm2d(nn.Module):

    def __init__(self, channels, eps=1e-6):
        super(LayerNorm2d, self).__init__()
        self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
        self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
        self.eps = eps

    def forward(self, x):
        return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)

class SimpleGate(nn.Module):
    def forward(self, x):
        x1, x2 = x.chunk(2, dim=1)
        return x1 * x2

class NAFBlock(nn.Module):
    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
        super().__init__()
        dw_channel = c * DW_Expand
        self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1,
                               bias=True)
        self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1,
                               groups=dw_channel,
                               bias=True)
        self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1,
                               groups=1, bias=True)

        # Simplified Channel Attention
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1,
                      groups=1, bias=True),
        )

        # SimpleGate
        self.sg = SimpleGate()

        ffn_channel = FFN_Expand * c
        self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1,
                               bias=True)
        self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1,
                               groups=1, bias=True)

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)

        self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
        self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()

        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, inp):
        x = inp

        x = self.norm1(x)

        x = self.conv1(x)
        x = self.conv2(x)
        x = self.sg(x)
        x = x * self.sca(x)
        x = self.conv3(x)

        x = self.dropout1(x)

        y = inp + x * self.beta

        x = self.conv4(self.norm2(y))
        x = self.sg(x)
        x = self.conv5(x)

        x = self.dropout2(x)

        return y + x * self.gamma


if __name__ == '__main__':


    block = NAFBlock(c=64)
    input = torch.rand(1, 64, 128, 128)
    output = block(input)

    print(input.size())    print(output.size())

便捷下载方式

浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

更多分析可见原文


ai缝合大王
聚焦AI前沿,分享相关技术、论文,研究生自救指南
 最新文章