ECCV 2024 | 最新直方图transfomer,直接涨点起飞!

文摘   2025-01-20 11:54   安徽  

点击下方卡片,关注“AI前沿速递”公众号

各种重磅干货,第一时间送达


标题:Restoring Images in Adverse Weather Conditions via Histogram Transformer

论文链接:https://arxiv.org/pdf/2407.10172

代码链接:https://github.com/sunshangquan/Histoformer

来源:ECCV 2024

创新点

直方图自注意力机制(Histogram Self-Attention,HSA)

- 核心思想:将空间特征按强度排序并划分为不同的直方图区域(bins),对相似的降质像素进行分组处理。

- 优势:相比传统自注意力操作仅限于固定空间范围或通道维度,HSA机制能够在更广范围内处理降质问题,有效提高了对天气引起的退化图像的恢复效果。例如,在处理雾霾天气下的图像时,能够更好地识别和恢复那些因雾霾而变得模糊、色彩失真的像素区域。

动态范围卷积(Dynamic-Range Convolution)

- 核心思想:通过对像素进行水平和垂直方向的排序,使得卷积操作能够处理相似强度的像素。

- 优势:弥补了卷积神经网络(CNNs)在长距离空间特征提取上的不足,增强了对天气相关依赖性的提取能力。比如在处理雨天图像时,能够更好地捕捉雨滴分布的长距离模式,从而更准确地去除雨滴对图像的影响。

双尺度门控前馈网络(Dual-scale Gated Feed-Forward Network,DGFF)

- 核心思想:在传统的前馈网络基础上引入双尺度门控前馈网络,通过融合多尺度和多范围的深度卷积,进一步增强了对多范围信息的捕捉能力。

- 优势:能够同时提取图像中的局部细节和全局信息,对于恢复图像的整体质量具有重要意义。例如,在恢复雪景图像时,既能保留雪花的细节纹理,又能保持天空、地面等大范围区域的清晰度和色彩准确性。

皮尔逊相关系数损失(Pearson Correlation Coefficient Loss)

- 核心思想:引入皮尔逊相关系数作为额外的损失函数,用于提升恢复图像与真实图像之间的线性关联,确保恢复后的图像在整体上保持与真实图像相似的顺序和关系,而不仅仅是像素级的接近度。

- 优势:从全局角度优化图像恢复效果,使恢复图像在结构、纹理等方面与真实图像更加接近,提升了图像恢复的整体质量和视觉效果。比如在恢复雾天图像后,不仅像素值与真实图像相似,图像中的物体轮廓、层次感等也能更好地还原,更符合人眼的视觉感知。

整体结构

这篇文章介绍的Histoformer模型的整体结构如下:

- 编码器-解码器架构:模型基于编码器-解码器架构,特征在多个级别上被编码(下采样),并在解码过程中逐步重建(上采样)。编码器通过下采样提取多尺度特征,解码器通过上采样逐步重建高质量图像,同时利用跳跃连接保持图像细节。

- 主干网络:模型的主干部分由多个直方图变换块(Histogram Transformer Blocks,HTBs)组成,每个HTB包含两个核心模块:

  • 动态范围直方图自注意力模块(DHSA):这个模块通过将空间特征划分为不同的强度区域(Bins),在Bins内和Bins之间应用自注意力机制,重点处理那些在天气退化中表现出相似模式的像素。

  • 双尺度门控前馈模块(DGFF):为了增强对多尺度、多范围特征的提取能力,DGFF模块采用了两个并行的卷积路径,其中包含5×5的标准卷积和扩张卷积。

- 输入预处理:输入的低质量图像首先通过一个3×3的卷积层,进行图像块的嵌入。这一过程将图像的空间信息转换为特征向量,供后续的变换模块处理。

- 像素重排:在每个编码阶段,网络会对特征进行像素下采样(Pixel Unshuffle),而在解码阶段,应用像素上采样(Pixel Shuffle),这有助于在多层次上提取信息并生成高分辨率的输出。

- 重建模块:经过所有HTB块和下采样、上采样操作后,网络通过一个重建模块将提取的特征转化为最终的高质量图像输出。

- 损失函数:模型的训练过程中采用了两个损失函数:

  • 重建损失(L1损失):衡量恢复图像与真实图像之间的像素差异。

  • 皮尔逊相关系数损失:确保恢复图像和真实图像在整体上保持线性相关。

代码实现

Layer Norm

import numbers
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

def to_2d(x):
    return rearrange(x, 'b c h w -> b (h w c)')

def to_3d(x):
    return rearrange(x, 'b c h w -> b (h w) c')

def to_4d(x, h, w):
    return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

class BiasFree_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(BiasFree_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)
        assert len(normalized_shape) == 1
        self.normalized_shape = normalized_shape

    def forward(self, x):
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return x / torch.sqrt(sigma + 1e-5)

class WithBias_LayerNorm(nn.Module):
    def __init__(self, normalized_shape):
        super(WithBias_LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)
        assert len(normalized_shape) == 1
        self.normalized_shape = normalized_shape

    def forward(self, x):
        mu = x.mean(-1, keepdim=True)
        sigma = x.var(-1, keepdim=True, unbiased=False)
        return (x - mu) / torch.sqrt(sigma + 1e-5)

class LayerNorm(nn.Module):
    def __init__(self, dim, LayerNorm_type="WithBias"):
        super(LayerNorm, self).__init__()
        if LayerNorm_type == 'BiasFree':
            self.body = BiasFree_LayerNorm(dim)
        else:
            self.body = WithBias_LayerNorm(dim)

    def forward(self, x):
        h, w = x.shape[-2:]
        return to_4d(self.body(to_3d(x)), h, w)

Dual-scale Gated Feed-Forward Network (DGFF)

class FeedForward(nn.Module):
    def __init__(self, dim, ffn_expansion_factor, bias):
        super(FeedForward, self).__init__()
        hidden_features = int(dim * ffn_expansion_factor)
        self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
        self.dwconv_5 = nn.Conv2d(hidden_features // 4, hidden_features // 4, kernel_size=5, stride=1, padding=2, groups=hidden_features // 4, bias=bias)
        self.dwconv_dilated2_1 = nn.Conv2d(hidden_features // 4, hidden_features // 4, kernel_size=3, stride=1, padding=2, groups=hidden_features // 4, bias=bias, dilation=2)
        self.p_unshuffle = nn.PixelUnshuffle(2)
        self.p_shuffle = nn.PixelShuffle(2)
        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        x = self.project_in(x)
        x = self.p_shuffle(x)
        x1, x2 = x.chunk(2, dim=1)
        x1 = self.dwconv_5(x1)
        x2 = self.dwconv_dilated2_1(x2)
        x = F.mish(x2) * x1
        x = self.p_unshuffle(x)
        x = self.project_out(x)
        return x

Dynamic-range Histogram Self-Attention (DHSA)

class Attention_histogram(nn.Module):
    def __init__(self, dim, num_heads=4, bias=False, ifBox=True):
        super(Attention_histogram, self).__init__()
        self.factor = num_heads
        self.ifBox = ifBox
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.qkv = nn.Conv2d(dim, dim * 5, kernel_size=1, bias=bias)
        self.qkv_dwconv = nn.Conv2d(dim * 5, dim * 5, kernel_size=3, stride=1, padding=1, groups=dim * 5, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def pad(self, x, factor):
        hw = x.shape[-1]
        t_pad = [0, 0] if hw % factor == 0 else [0, (hw // factor + 1) * factor - hw]
        x = F.pad(x, t_pad, 'constant', 0)
        return x, t_pad

    def unpad(self, x, t_pad):
        _, _, hw = x.shape
        return x[:, :, t_pad[0]:hw - t_pad[1]]

    def softmax_1(self, x, dim=-1):
        logit = x.exp()
        logit = logit / (logit.sum(dim, keepdim=True) + 1)
        return logit

    def normalize(self, x):
        mu = x.mean(-2, keepdim=True)
        sigma = x.var(-2, keepdim=True, unbiased=False)
        return (x - mu) / torch.sqrt(sigma + 1e-5)

    def reshape_attn(self, q, k, v, ifBox):
        b, c = q.shape[:2]
        q, t_pad = self.pad(q, self.factor)
        k, t_pad = self.pad(k, self.factor)
        v, t_pad = self.pad(v, self.factor)
        hw = q.shape[-1] // self.factor
        shape_ori = "b (head c) (factor hw)" if ifBox else "b (head c) (hw factor)"
        shape_tar = "b head (c factor) hw"
        q = rearrange(q, '{} -> {}'.format(shape_ori, shape_tar), factor=self.factor, hw=hw, head=self.num_heads)
        k = rearrange(k, '{} -> {}'.format(shape_ori, shape_tar), factor=self.factor, hw=hw, head=self.num_heads)
        v = rearrange(v, '{} -> {}'.format(shape_ori, shape_tar), factor=self.factor, hw=hw, head=self.num_heads)
        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = self.softmax_1(attn, dim=-1)
        out = (attn @ v)
        out = rearrange(out, '{} -> {}'.format(shape_tar, shape_ori), factor=self.factor, hw=hw, b=b, head=self.num_heads)
        out = self.unpad(out, t_pad)
        return out

    def forward(self, x):
        b, c, h, w = x.shape
        x_sort, idx_h = x[:, :c // 2].sort(-2)
        x_sort, idx_w = x_sort.sort(-1)
        x[:, :c // 2] = x_sort
        qkv = self.qkv_dwconv(self.qkv(x))
        q1, k1, q2, k2, v = qkv.chunk(5, dim=1)  # b,c,x,x
        v, idx = v.view(b, c, -1).sort(dim=-1)
        q1 = torch.gather(q1.view(b, c, -1), dim=2, index=idx)
        k1 = torch.gather(k1.view(b, c, -1), dim=2, index=idx)
        q2 = torch.gather(q2.view(b, c, -1), dim=2, index=idx)
        k2 = torch.gather(k2.view(b, c, -1), dim=2, index=idx)
        out1 = self.reshape_attn(q1, k1, v, True)
        out2 = self.reshape_attn(q2, k2, v, False)
        out1 = torch.scatter(out1, 2, idx, out1).view(b, c, h, w)
        out2 = torch.scatter(out2, 2, idx, out2).view(b, c, h, w)
        out = out1 * out2
        out = self.project_out(out)
        out_replace = out[:, :c // 2]
        out_replace = torch.scatter(out_replace, -1, idx_w, out_replace)
        out_replace = torch.scatter(out_replace, -2, idx_h, out_replace)
        out[:, :c // 2] = out_replace
        return out

Histogram Transformer Block (HTB)

class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads=4, ffn_expansion_factor=2.5, bias=False, LayerNorm_type='WithBias'):
        super(TransformerBlock, self).__init__()
        self.attn_g = Attention_histogram(dim, num_heads, bias, True)
        self.norm_g = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)
        self.norm_ff1 = LayerNorm(dim, LayerNorm_type)

    def forward(self, x):
        x = x + self.attn_g(self.norm_g(x))
        x_out = x + self.ffn(self.norm_ff1(x))
        return x_out

示例用法

if __name__ == '__main__':
    input = torch.randn(1, 64, 128, 128)  # 输入 B C H W
    transformer_block = TransformerBlock(64)  # 输入C
    # 前向传播
    output = transformer_block(input)
    # 打印输入和输出的形状
    print(input.size())
    print(output.size())



本文内容为论文学习收获分享,受限于知识能力,本文对原文的理解可能存在偏差,最终内容以原论文为准。本文信息旨在传播和学术交流,其内容由作者负责,不代表本号观点。文中作品文字、图片等如涉及内容、版权和其他问题,请及时与我们联系,我们将在第一时间回复并处理。

AI前沿速递
持续分享最新AI前沿论文成果
 最新文章