论文介绍
题目:HALF WAVELET ATTENTION ON M-NET+ FOR LOW-LIGHT IMAGE ENHANCEMENT
论文地址:https://arxiv.org/abs/2203.01296
QQ深度学习交流群:719278780
扫描下方二维码,加入深度学习论文指南星球!
加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务
创新点
改进的分层架构 M-Net+:
半小波注意力块(Half Wavelet Attention Block, HWAB):
改进的特征融合方法:
性能表现:
提出了一个改良的分层模型 M-Net+,专为低光图像增强设计。该架构旨在缓解采样过程中的空间信息损失问题。通过采用像素去卷积(Pixel Unshuffle)和双线性下采样,提升了多尺度特征的多样性和丰富性。
新引入了一种高效的特征提取模块 HWAB,利用小波域信息提取更丰富的特征。这种方法结合了小波变换和注意力机制,可以同时减少计算复杂度并增强特征语义信息。
在解码过程中,使用选择性核特征融合(Selective Kernel Feature Fusion, SKFF)方法替代传统的特征拼接方式,有效地融合了不同分辨率的特征,同时降低了网络的参数量和计算复杂度。
在 LOL 和 MIT-Adobe FiveK 两个数据集上,提出的 HWMNet 模型在图像质量(PSNR、SSIM 和 LPIPS)以及计算复杂度方面均达到了竞争性甚至领先的效果。
方法
1. 模型总体架构
HWMNet 继承了 U-Net 和 M-Net 的分层结构,包含以下关键模块:
编码器(Encoder):从输入低光图像中提取多层次特征。
解码器(Decoder):将不同分辨率的特征融合,并逐步恢复到原始图像分辨率。
跳跃连接(Skip Connections):连接编码器和解码器的对应层,用于保持高分辨率的特征信息。
2. 关键改进模块
2.1 M-Net+ 架构
M-Net+ 是基于 M-Net 的改进架构,解决了原始 M-Net 的两个主要问题:
避免空间信息损失:
在 U-Net 路径中使用像素去卷积(Pixel Unshuffle)进行下采样。
在门柱路径(Gatepost Path)中使用双线性插值下采样。
高效特征融合:
在解码阶段,使用选择性核特征融合(SKFF)方法取代简单的特征拼接,减轻高维特征融合的计算复杂度。
2.2 半小波注意力块(HWAB)
HWAB 是模型的核心创新模块,用于增强特征提取的多样性:
输入特征被分为两部分:
保留部分:直接保留原始域的特征信息。
变换部分:通过离散小波变换(DWT)进入小波域,从中提取更丰富的上下文信息。
在小波域中,通过通道注意力(Channel Attention)和空间注意力(Spatial Attention)对特征加权,随后通过逆小波变换(IWT)回到原始域。
最后,合并保留特征和加权特征,再通过卷积层生成输出特征。
3. 特征处理流程
输入处理:
输入图像经过一个初始 3×3 卷积层,提取初始特征。
每一层都通过 HWAB 处理,分为多分辨率特征。
多层次特征提取:
U-Net 路径通过像素去卷积进行下采样,逐步降低特征图分辨率。
门柱路径使用双线性下采样,并保持特征与 U-Net 路径的连接。
特征融合:
在解码阶段,通过 SKFF 将多分辨率特征高效融合,减轻计算负担并提升重建质量。
输出生成:
经过多层次特征融合后,模型最终通过卷积层生成增强后的图像。
4. 模型的主要优势
分层结构提升了模型对多尺度信息的处理能力。
HWAB 模块显著提高了特征提取的多样性和语义丰富度。
通过高效特征融合和轻量化设计,实现了更低的计算复杂度。
即插即用模块作用
HWAB 作为一个即插即用模块:
图像增强任务:
图像修复任务:
需要低计算复杂度的场景:
多尺度特征处理的场景:
特别适用于低光图像增强任务,如论文中提到的 LOL 和 MIT-Adobe FiveK 数据集。在需要同时提升图像亮度、对比度和细节的场景中效果显著。
可用于其他图像修复任务,如图像去噪、去模糊等,因为其设计本质上有助于提取和恢复细节特征。
HWAB 通过小波变换对特征分解并仅处理一半的特征,显著降低了计算复杂度,非常适合嵌入式设备或实时处理的应用场景。
在需要多分辨率特征提取和整合的视觉任务中,HWAB 可高效提取不同尺度下的丰富特征信息。
消融实验结果
表 1 是在 LOL 数据集上的结果对比,表明 HWAB 和 M-Net+ 架构结合后在 PSNR、SSIM 和 LPIPS 三个指标上表现优异。
表 2 是在 MIT-Adobe FiveK 数据集上的结果对比,展示了 HWMNet 在多个任务下的稳健性和高效性。
HWAB 的引入使模型在保持较低计算复杂度的情况下,实现了比大多数方法更好的性能(如 PSNR 和 LPIPS 指标)。
即插即用模块
import torch
import torch.nn as nn
#论文:HALF WAVELET ATTENTION ON M-NET+ FOR LOW-LIGHT IMAGE ENHANCEMENT
#论文地址:https://arxiv.org/abs/2203.01296
def conv(in_channels, out_channels, kernel_size, bias=False, stride=1):
return nn.Conv2d(
in_channels, out_channels, kernel_size,
padding=(kernel_size // 2), bias=bias, stride=stride)
def dwt_init(x):
x01 = x[:, :, 0::2, :] / 2
x02 = x[:, :, 1::2, :] / 2
x1 = x01[:, :, :, 0::2]
x2 = x02[:, :, :, 0::2]
x3 = x01[:, :, :, 1::2]
x4 = x02[:, :, :, 1::2]
x_LL = x1 + x2 + x3 + x4
x_HL = -x1 - x2 + x3 + x4
x_LH = -x1 + x2 - x3 + x4
x_HH = x1 - x2 - x3 + x4
# print(x_HH[:, 0, :, :])
return torch.cat((x_LL, x_HL, x_LH, x_HH), 1)
def iwt_init(x):
r = 2
in_batch, in_channel, in_height, in_width = x.size()
out_batch, out_channel, out_height, out_width = in_batch, int(in_channel / (r ** 2)), r * in_height, r * in_width
x1 = x[:, 0:out_channel, :, :] / 2
x2 = x[:, out_channel:out_channel * 2, :, :] / 2
x3 = x[:, out_channel * 2:out_channel * 3, :, :] / 2
x4 = x[:, out_channel * 3:out_channel * 4, :, :] / 2
h = torch.zeros([out_batch, out_channel, out_height, out_width])
h[:, :, 0::2, 0::2] = x1 - x2 - x3 + x4
h[:, :, 1::2, 0::2] = x1 - x2 + x3 - x4
h[:, :, 0::2, 1::2] = x1 + x2 - x3 - x4
h[:, :, 1::2, 1::2] = x1 + x2 + x3 + x4
return h
class DWT(nn.Module):
def __init__(self):
super(DWT, self).__init__()
self.requires_grad = True
def forward(self, x):
return dwt_init(x)
class IWT(nn.Module):
def __init__(self):
super(IWT, self).__init__()
self.requires_grad = True
def forward(self, x):
return iwt_init(x)
# Spatial Attention Layer
class SALayer(nn.Module):
def __init__(self, kernel_size=5, bias=False):
super(SALayer, self).__init__()
self.conv_du = nn.Sequential(
nn.Conv2d(2, 1, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2, bias=bias),
nn.Sigmoid()
)
def forward(self, x):
# torch.max will output 2 things, and we want the 1st one
max_pool, _ = torch.max(x, dim=1, keepdim=True)
avg_pool = torch.mean(x, 1, keepdim=True)
channel_pool = torch.cat([max_pool, avg_pool], dim=1) # [N,2,H,W] could add 1x1 conv -> [N,3,H,W]
y = self.conv_du(channel_pool)
return x * y
# Channel Attention Layer
class CALayer(nn.Module):
def __init__(self, channel, reduction=16, bias=False):
super(CALayer, self).__init__()
# global average pooling: feature --> point
self.avg_pool = nn.AdaptiveAvgPool2d(1)
# feature channel downscale and upscale --> channel weight
self.conv_du = nn.Sequential(
nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias),
nn.ReLU(inplace=True),
nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias),
nn.Sigmoid()
)
def forward(self, x):
y = self.avg_pool(x)
y = self.conv_du(y)
return x * y
# Half Wavelet Attention Block (HWAB)
class HWAB(nn.Module):
def __init__(self, n_feat, o_feat, kernel_size=3, reduction=16, bias=False, act=nn.PReLU()):
super(HWAB, self).__init__()
self.dwt = DWT()
self.iwt = IWT()
modules_body = \
[
conv(n_feat*2, n_feat, kernel_size, bias=bias),
act,
conv(n_feat, n_feat*2, kernel_size, bias=bias)
]
self.body = nn.Sequential(*modules_body)
self.WSA = SALayer()
self.WCA = CALayer(n_feat*2, reduction, bias=bias)
self.conv1x1 = nn.Conv2d(n_feat*4, n_feat*2, kernel_size=1, bias=bias)
self.conv3x3 = nn.Conv2d(n_feat, o_feat, kernel_size=3, padding=1, bias=bias)
self.activate = act
self.conv1x1_final = nn.Conv2d(n_feat, o_feat, kernel_size=1, bias=bias)
def forward(self, x):
residual = x
# Split 2 part
wavelet_path_in, identity_path = torch.chunk(x, 2, dim=1)
# Wavelet domain (Dual attention)
x_dwt = self.dwt(wavelet_path_in)
res = self.body(x_dwt)
branch_sa = self.WSA(res)
branch_ca = self.WCA(res)
res = torch.cat([branch_sa, branch_ca], dim=1)
res = self.conv1x1(res) + x_dwt
wavelet_path = self.iwt(res)
out = torch.cat([wavelet_path, identity_path], dim=1)
out = self.activate(self.conv3x3(out))
out += self.conv1x1_final(residual)
return out
if __name__ == '__main__':
block = HWAB(n_feat=64, o_feat=64)
input = torch.randn(1, 64, 128, 128) # B C H W
output = block(input)
print(input.size()) print(output.size())
便捷下载方式
浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules
更多分析可见原文