即插即用半小波注意力模块HWAB,涨点涨爆了

文摘   2024-12-01 17:20   上海  

论文介绍

题目:HALF WAVELET ATTENTION ON M-NET+ FOR LOW-LIGHT IMAGE ENHANCEMENT

论文地址:https://arxiv.org/abs/2203.01296

QQ深度学习交流群:719278780

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • 改进的分层架构 M-Net+

  • 提出了一个改良的分层模型 M-Net+,专为低光图像增强设计。该架构旨在缓解采样过程中的空间信息损失问题。通过采用像素去卷积(Pixel Unshuffle)和双线性下采样,提升了多尺度特征的多样性和丰富性。

  • 半小波注意力块(Half Wavelet Attention Block, HWAB)

  • 新引入了一种高效的特征提取模块 HWAB,利用小波域信息提取更丰富的特征。这种方法结合了小波变换和注意力机制,可以同时减少计算复杂度并增强特征语义信息。

  • 改进的特征融合方法

  • 在解码过程中,使用选择性核特征融合(Selective Kernel Feature Fusion, SKFF)方法替代传统的特征拼接方式,有效地融合了不同分辨率的特征,同时降低了网络的参数量和计算复杂度。

  • 性能表现

  • 在 LOL 和 MIT-Adobe FiveK 两个数据集上,提出的 HWMNet 模型在图像质量(PSNR、SSIM 和 LPIPS)以及计算复杂度方面均达到了竞争性甚至领先的效果。

方法

1. 模型总体架构

HWMNet 继承了 U-Net 和 M-Net 的分层结构,包含以下关键模块:

  • 编码器(Encoder):从输入低光图像中提取多层次特征。

  • 解码器(Decoder):将不同分辨率的特征融合,并逐步恢复到原始图像分辨率。

  • 跳跃连接(Skip Connections):连接编码器和解码器的对应层,用于保持高分辨率的特征信息。

2. 关键改进模块

2.1 M-Net+ 架构

M-Net+ 是基于 M-Net 的改进架构,解决了原始 M-Net 的两个主要问题:

  • 避免空间信息损失

    • 在 U-Net 路径中使用像素去卷积(Pixel Unshuffle)进行下采样。

    • 在门柱路径(Gatepost Path)中使用双线性插值下采样。

  • 高效特征融合

    • 在解码阶段,使用选择性核特征融合(SKFF)方法取代简单的特征拼接,减轻高维特征融合的计算复杂度。

2.2 半小波注意力块(HWAB)

HWAB 是模型的核心创新模块,用于增强特征提取的多样性:

  • 输入特征被分为两部分:

    • 保留部分:直接保留原始域的特征信息。

    • 变换部分:通过离散小波变换(DWT)进入小波域,从中提取更丰富的上下文信息。

  • 在小波域中,通过通道注意力(Channel Attention)和空间注意力(Spatial Attention)对特征加权,随后通过逆小波变换(IWT)回到原始域。

  • 最后,合并保留特征和加权特征,再通过卷积层生成输出特征。



3. 特征处理流程

  1. 输入处理

  • 输入图像经过一个初始 3×3 卷积层,提取初始特征。

  • 每一层都通过 HWAB 处理,分为多分辨率特征。

  • 多层次特征提取

    • U-Net 路径通过像素去卷积进行下采样,逐步降低特征图分辨率。

    • 门柱路径使用双线性下采样,并保持特征与 U-Net 路径的连接。

  • 特征融合

    • 在解码阶段,通过 SKFF 将多分辨率特征高效融合,减轻计算负担并提升重建质量。

  • 输出生成

    • 经过多层次特征融合后,模型最终通过卷积层生成增强后的图像。



    4. 模型的主要优势

    • 分层结构提升了模型对多尺度信息的处理能力。

    • HWAB 模块显著提高了特征提取的多样性和语义丰富度。

    • 通过高效特征融合和轻量化设计,实现了更低的计算复杂度。

    即插即用模块作用

    HWAB 作为一个即插即用模块

    • 图像增强任务

    • 特别适用于低光图像增强任务,如论文中提到的 LOL 和 MIT-Adobe FiveK 数据集。在需要同时提升图像亮度、对比度和细节的场景中效果显著。

    • 图像修复任务

    • 可用于其他图像修复任务,如图像去噪、去模糊等,因为其设计本质上有助于提取和恢复细节特征。

    • 需要低计算复杂度的场景

    • HWAB 通过小波变换对特征分解并仅处理一半的特征,显著降低了计算复杂度,非常适合嵌入式设备或实时处理的应用场景。

    • 多尺度特征处理的场景

    • 在需要多分辨率特征提取和整合的视觉任务中,HWAB 可高效提取不同尺度下的丰富特征信息。

    消融实验结果

    • 表 1 是在 LOL 数据集上的结果对比,表明 HWAB 和 M-Net+ 架构结合后在 PSNR、SSIM 和 LPIPS 三个指标上表现优异。

    • 表 2 是在 MIT-Adobe FiveK 数据集上的结果对比,展示了 HWMNet 在多个任务下的稳健性和高效性。

    • HWAB 的引入使模型在保持较低计算复杂度的情况下,实现了比大多数方法更好的性能(如 PSNR 和 LPIPS 指标)。

    即插即用模块

    import torch
    import torch.nn as nn
    #论文:HALF WAVELET ATTENTION ON M-NET+ FOR LOW-LIGHT IMAGE ENHANCEMENT
    #论文地址:https://arxiv.org/abs/2203.01296

    def conv(in_channels, out_channels, kernel_size, bias=False, stride=1):
        return nn.Conv2d(
            in_channels, out_channels, kernel_size,
            padding=(kernel_size // 2), bias=bias, stride=stride)

    def dwt_init(x):
        x01 = x[:, :, 0::2, :] / 2
        x02 = x[:, :, 1::2, :] / 2
        x1 = x01[:, :, :, 0::2]
        x2 = x02[:, :, :, 0::2]
        x3 = x01[:, :, :, 1::2]
        x4 = x02[:, :, :, 1::2]
        x_LL = x1 + x2 + x3 + x4
        x_HL = -x1 - x2 + x3 + x4
        x_LH = -x1 + x2 - x3 + x4
        x_HH = x1 - x2 - x3 + x4
        # print(x_HH[:, 0, :, :])
        return torch.cat((x_LL, x_HL, x_LH, x_HH), 1)

    def iwt_init(x):
        r = 2
        in_batch, in_channel, in_height, in_width = x.size()
        out_batch, out_channel, out_height, out_width = in_batch, int(in_channel / (r ** 2)), r * in_height, r * in_width
        x1 = x[:, 0:out_channel, :, :] / 2
        x2 = x[:, out_channel:out_channel * 2, :, :] / 2
        x3 = x[:, out_channel * 2:out_channel * 3, :, :] / 2
        x4 = x[:, out_channel * 3:out_channel * 4, :, :] / 2
        h = torch.zeros([out_batch, out_channel, out_height, out_width])

        h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4
        h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4
        h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4
        h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4

        return h


    class DWT(nn.Module):
        def __init__(self):
            super(DWT, self).__init__()
            self.requires_grad = True

        def forward(self, x):
            return dwt_init(x)


    class IWT(nn.Module):
        def __init__(self):
            super(IWT, self).__init__()
            self.requires_grad = True

        def forward(self, x):
            return iwt_init(x)


    # Spatial Attention Layer
    class SALayer(nn.Module):
        def __init__(self, kernel_size=5, bias=False):
            super(SALayer, self).__init__()
            self.conv_du = nn.Sequential(
                nn.Conv2d(2, 1, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2, bias=bias),
                nn.Sigmoid()
            )

        def forward(self, x):
            # torch.max will output 2 things, and we want the 1st one
            max_pool, _ = torch.max(x, dim=1, keepdim=True)
            avg_pool = torch.mean(x, 1, keepdim=True)
            channel_pool = torch.cat([max_pool, avg_pool], dim=1) # [N,2,H,W] could add 1x1 conv -> [N,3,H,W]
            y = self.conv_du(channel_pool)

            return x * y

    # Channel Attention Layer
    class CALayer(nn.Module):
        def __init__(self, channel, reduction=16, bias=False):
            super(CALayer, self).__init__()
            # global average pooling: feature --> point
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
            # feature channel downscale and upscale --> channel weight
            self.conv_du = nn.Sequential(
                nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias),
                nn.Sigmoid()
            )

        def forward(self, x):
            y = self.avg_pool(x)
            y = self.conv_du(y)
            return x * y

    # Half Wavelet Attention Block (HWAB)
    class HWAB(nn.Module):
        def __init__(self, n_feat, o_feat, kernel_size=3, reduction=16, bias=False, act=nn.PReLU()):
            super(HWAB, self).__init__()
            self.dwt = DWT()
            self.iwt = IWT()

            modules_body = \
                [
                    conv(n_feat*2, n_feat, kernel_size, bias=bias),
                    act,
                    conv(n_feat, n_feat*2, kernel_size, bias=bias)
                ]
            self.body = nn.Sequential(*modules_body)

            self.WSA = SALayer()
            self.WCA = CALayer(n_feat*2, reduction, bias=bias)

            self.conv1x1 = nn.Conv2d(n_feat*4, n_feat*2, kernel_size=1, bias=bias)
            self.conv3x3 = nn.Conv2d(n_feat, o_feat, kernel_size=3, padding=1, bias=bias)
            self.activate = act
            self.conv1x1_final = nn.Conv2d(n_feat, o_feat, kernel_size=1, bias=bias)

        def forward(self, x):
            residual = x

            # Split 2 part
            wavelet_path_in, identity_path = torch.chunk(x, 2, dim=1)

            # Wavelet domain (Dual attention)
            x_dwt = self.dwt(wavelet_path_in)
            res = self.body(x_dwt)
            branch_sa = self.WSA(res)
            branch_ca = self.WCA(res)
            res = torch.cat([branch_sa, branch_ca], dim=1)
            res = self.conv1x1(res) + x_dwt
            wavelet_path = self.iwt(res)

            out = torch.cat([wavelet_path, identity_path], dim=1)
            out = self.activate(self.conv3x3(out))
            out += self.conv1x1_final(residual)

            return out


    if __name__ == '__main__':


        block = HWAB(n_feat=64, o_feat=64)

        input = torch.randn(1, 64, 128, 128) # B C H W

        output = block(input)

        print(input.size())    print(output.size())

    便捷下载方式

    浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

    更多分析可见原文


    ai缝合大王
    聚焦AI前沿,分享相关技术、论文,研究生自救指南
     最新文章