(IEEE TIP)即插即用多尺度特征提取模块MSB,涨点起飞!

文摘   2024-12-04 17:20   中国香港  

论文介绍

题目:Mix Structure Block contains multi-scale parallel large convolution kernel module and enhanced parallel attention module

论文地址:https://arxiv.org/abs/2305.17654

QQ深度学习交流群:719278780

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • 提出了“Mix Structure Block”:论文提出了一种新的网络结构——Mix Structure Block (MSB),通过将不同类型的卷积操作(如深度可分离卷积和普通卷积)结合起来,从而提升了图像去雾的效果。MSB通过混合不同的特征处理方式,增强了网络对图像特征的提取能力。

  • 改进的图像去雾网络架构:传统的去雾网络往往依赖于复杂的结构和大量的计算,MixDehazeNet通过引入Mix Structure Block,不仅提高了去雾效果,而且减少了计算复杂度,具有更好的效率和性能。

  • 高效的多尺度特征提取:论文采用多尺度特征融合的策略来增强网络对雾霾的处理能力。这一策略有助于网络捕捉到不同尺度的图像特征,从而更准确地恢复被雾霾遮蔽的图像细节。

  • 实验验证与性能提升:通过大量实验验证,MixDehazeNet在多个标准去雾数据集上都展现了较传统方法更好的去雾效果,尤其在视觉质量和PSNR(Peak Signal-to-Noise Ratio)等评估指标上,取得了显著的提升。

  • 网络结构的可扩展性:MixDehazeNet的网络架构具有良好的可扩展性,可以方便地调整和扩展到不同的去雾任务和不同的输入图像类型,使得该方法具备较强的适应性。

方法

整体架构

  • 输入层

    • 输入是经过雾霾污染的图像,网络的目标是去除图像中的雾霾,并恢复清晰图像。

  • Mix Structure Block (MSB)

    • 标准卷积(Standard Convolution):用于捕捉较大范围的图像特征。

    • 深度可分离卷积(Depthwise Separable Convolution):通过分离空间卷积和逐通道卷积,减少了计算量,并能更有效地捕捉细节特征。

    • 论文的核心创新是提出了Mix Structure Block (MSB),这种模块通过融合不同类型的卷积操作(如深度可分离卷积和标准卷积),有效地提升了特征提取能力。

    • 每个MSB由多个卷积层组成,包括:

  • 多尺度特征提取模块

    • 在网络中,使用了多尺度特征融合的策略来增强图像去雾的效果。通过结合不同尺度的特征图,网络能够更好地恢复不同层次的图像信息,特别是处理雾霾遮挡的部分。

    • 这一模块可以通过不同尺寸的卷积核来处理图像,增强网络的表现力。

  • 残差连接

    • 为了帮助信息流的传递和梯度的反向传播,论文中的网络架构采用了残差连接(Residual Connections)。这种连接方式使得网络能够更有效地训练,并且避免了深层网络中的梯度消失问题。

  • 特征融合与恢复模块

    • 在去雾过程中,网络会逐步将不同层次的特征图进行融合,并通过上采样或者解码操作恢复原始图像的细节。最终,网络会输出去除雾霾后的清晰图像。

  • 损失函数

    • 论文中使用了标准的像素级损失函数(如L1损失),以及对图像结构的损失函数(如SSIM损失),来优化模型的训练过程。这些损失函数的结合有助于提高图像去雾的质量和结构保持。

即插即用模块作用

MSB 作为一个即插即用模块

  • 特征提取能力的增强:MSB通过结合不同类型的卷积操作(深度可分离卷积与标准卷积),能够更有效地提取多尺度的图像特征,尤其是在细节恢复方面表现优越。这样可以帮助网络更好地处理雾霾遮挡的图像细节。

  • 提高去雾效果:在图像去雾任务中,MSB能够通过多层卷积操作增强对图像中各个层次特征的捕捉,进而提高去雾效果,恢复更清晰、更自然的图像。

  • 降低计算复杂度:通过采用深度可分离卷积,MSB减少了计算量和参数量,使得模型更加高效,适合部署在计算资源有限的设备上。

消融实验结果

  • 提升效果:通过对比不同方法的指标,论文可以验证MixDehazeNet在多个数据集上是否均表现出了更高的PSNR和SSIM,证明其在去雾质量和图像细节恢复方面的优势。

  • 性能分析:在多种环境和雾霾条件下,MixDehazeNet相对于SOTA方法可能在细节恢复、噪声处理、视觉质量等方面有所提升。

即插即用模块

import torch
import torch.nn as nn
#论文地址:https://arxiv.org/abs/2305.17654
#论文:Mix Structure Block contains multi-scale parallel large convolution kernel module and enhanced parallel attention module

class MixStructureBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()

        self.norm1 = nn.BatchNorm2d(dim)
        self.norm2 = nn.BatchNorm2d(dim)

        self.conv1 = nn.Conv2d(dim, dim, kernel_size=1)
        self.conv2 = nn.Conv2d(dim, dim, kernel_size=5, padding=2, padding_mode='reflect')
        self.conv3_19 = nn.Conv2d(dim, dim, kernel_size=7, padding=9, groups=dim, dilation=3, padding_mode='reflect')
        self.conv3_13 = nn.Conv2d(dim, dim, kernel_size=5, padding=6, groups=dim, dilation=3, padding_mode='reflect')
        self.conv3_7 = nn.Conv2d(dim, dim, kernel_size=3, padding=3, groups=dim, dilation=3, padding_mode='reflect')

        # Simple Pixel Attention
        self.Wv = nn.Sequential(
            nn.Conv2d(dim, dim, 1),
            nn.Conv2d(dim, dim, kernel_size=3, padding=3 // 2, groups=dim, padding_mode='reflect')
        )
        self.Wg = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(dim, dim, 1),
            nn.Sigmoid()
        )

        # Channel Attention
        self.ca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(dim, dim, 1, padding=0, bias=True),
            nn.GELU(),
            # nn.ReLU(True),
            nn.Conv2d(dim, dim, 1, padding=0, bias=True),
            nn.Sigmoid()
        )

        # Pixel Attention
        self.pa = nn.Sequential(
            nn.Conv2d(dim, dim // 8, 1, padding=0, bias=True),
            nn.GELU(),
            # nn.ReLU(True),
            nn.Conv2d(dim // 8, 1, 1, padding=0, bias=True),
            nn.Sigmoid()
        )

        self.mlp = nn.Sequential(
            nn.Conv2d(dim * 3, dim * 4, 1),
            nn.GELU(),
            # nn.ReLU(True),
            nn.Conv2d(dim * 4, dim, 1)
        )
        self.mlp2 = nn.Sequential(
            nn.Conv2d(dim * 3, dim * 4, 1),
            nn.GELU(),
            # nn.ReLU(True),
            nn.Conv2d(dim * 4, dim, 1)
        )

    def forward(self, x):
        identity = x
        x = self.norm1(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = torch.cat([self.conv3_19(x), self.conv3_13(x), self.conv3_7(x)], dim=1)
        x = self.mlp(x)
        x = identity + x

        identity = x
        x = self.norm2(x)
        x = torch.cat([self.Wv(x) * self.Wg(x), self.ca(x) * x, self.pa(x) * x], dim=1)
        x = self.mlp2(x)
        x = identity + x
        return x


if __name__ == '__main__':


    block = MixStructureBlock(dim=64)


    input = torch.rand(1, 64, 128, 128) # B C H W


    output = block(input)

    print(input.size())
    print(output.size())

便捷下载方式

浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

更多分析可见原文


ai缝合大王
聚焦AI前沿,分享相关技术、论文,研究生自救指南
 最新文章