即插即用特征注意力融合模块FFA,涨点起飞起飞了

文摘   2024-11-21 17:20   上海  

论文介绍

题目:FFA-Net: Feature Fusion Attention Network for Single Image Dehazing

论文地址:arxiv.org/pdf/1911.07559

QQ深度学习交流群:719278780

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • 特征融合注意力网络(FFA-Net):提出了一种端到端的特征融合注意力网络(FFA-Net)用于单张图像去雾。FFA-Net在厚雾区域和富含纹理的细节恢复方面具有显著优势,其细节和色彩保真度优于之前的去雾方法。

  • 特征注意力模块(FA):设计了一个创新的特征注意力模块,结合了通道注意力和像素注意力机制。该模块使得网络在厚雾区域和关键通道上更加聚焦,提供了处理不同类型信息的灵活性,提升了卷积神经网络的表达能力。

  • 局部残差学习和特征注意力的基础模块:提出了一种结合局部残差学习和特征注意力的基础模块。局部残差学习允许薄雾和低频信息通过多层局部残差连接被跳过,从而让网络关注更有效的信息。

  • 基于注意力的特征融合结构(FFA):提出了基于注意力的特征融合结构,可以将浅层特征保留并传递到深层,并根据不同级别特征的重要性自适应地学习权重。这种方式比直接指定权重的融合方法效果更好。

方法

整体结构

       FFA-Net模型通过浅层特征提取、多个包含局部残差学习和特征注意力的组架构、全局残差学习和特征融合注意力模块来实现单图像去雾。模型利用通道和像素级的注意力机制自适应地调整不同特征的权重,融合浅层和深层信息,从而更有效地去除雾霾,保留细节和色彩的准确性。

  • 浅层特征提取模块:输入的有雾图像首先通过浅层特征提取模块,用以提取初步的低层特征。

  • 组架构(Group Architecture):模型的核心部分由多个组架构组成,每个组架构包含多个基础模块,通过多个跳跃连接增加网络深度和表达能力。这种组架构主要通过局部残差学习(Local Residual Learning)和特征注意力模块(Feature Attention, FA)来实现。

    • 基础模块(Basic Block):基础模块由局部残差学习和特征注意力模块构成,局部残差学习可以跳过薄雾和低频信息,使网络更专注于有效信息。特征注意力模块通过通道注意力和像素注意力结合的方式,增强网络在不同特征和像素上的关注力度。

  • 全局残差学习模块(Global Residual Learning):在组架构的输出之后加入一个全局残差学习模块,通过两层卷积操作和长跳跃连接来帮助网络更好地恢复清晰的图像。这一模块能够进一步提高去雾效果并稳定网络训练。

  • 特征融合注意力模块(Feature Fusion Attention, FFA):所有组架构的输出在通道方向上进行特征图拼接,并通过特征注意力机制来调整不同级别特征的权重,实现浅层和深层信息的融合。这种自适应权重的学习让网络能够更好地关注厚雾区域和高频纹理信息。

  • 重构模块:融合后的特征最终输入到重构模块中,通过反卷积或卷积操作来生成无雾的清晰图像输出。

即插即用模块作用

FFA 作为一个即插即用模块,主要适用于:

  • 应用场景

    • 图像去雾:适用于处理厚雾或不均匀雾霾分布的场景,可以增强去雾效果,恢复图像的细节和真实色彩。

    • 图像超分辨率:适用于需要提升图像分辨率的场景,帮助保留高频纹理和细节信息。

    • 图像去噪:在去除噪声的同时保留图像的细节,使图像在噪声去除后更清晰。

    • 图像复原:如去雨、去划痕等低层视觉任务,适合需要融合多层特征信息以增强细节表现的场景。

  • 主要作用

    • 自适应特征权重分配:通过通道和像素注意力机制,自适应地为不同特征分配权重,增强关键区域(如厚雾区域和高频细节)的关注度。

    • 浅层与深层信息融合:保留浅层的关键信息并传递至深层,提高网络在处理复杂纹理和细节上的能力。

    • 提升图像复原效果:增强网络在细节复原、色彩还原和视觉效果上的表现,使图像复原效果更加真实和自然

消融实验结果

  • 仅使用特征注意力模块(FA)时,网络在去雾性能上已经具有较强的竞争力。

  • 将局部残差学习(LRL)与FA模块结合使用后,网络性能和训练稳定性均有所提升。

  • 当结合FFA结构后,网络性能进一步大幅提升,表明FFA结构在融合不同层级特征信息、提高去雾效果上起到了关键作用。

即插即用模块

import torch.nn as nn
import torch

def default_conv(in_channels, out_channels, kernel_size, bias=True):
    return nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias)



class PALayer(nn.Module):
    def __init__(self, channel):
        super(PALayer, self).__init__()
        self.pa = nn.Sequential(
            nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // 8, 1, 1, padding=0, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.pa(x)
        return x * y



class CALayer(nn.Module):
    def __init__(self, channel):
        super(CALayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.ca = nn.Sequential(
            nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel // 8, channel, 1, padding=0, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.ca(y)
        return x * y


# 基础块(Block)
class Block(nn.Module):
    def __init__(self, conv, dim, kernel_size):
        super(Block, self).__init__()
        self.conv1 = conv(dim, dim, kernel_size, bias=True)
        self.act1 = nn.ReLU(inplace=True)
        self.conv2 = conv(dim, dim, kernel_size, bias=True)
        self.calayer = CALayer(dim)
        self.palayer = PALayer(dim)

    def forward(self, x):
        res = self.act1(self.conv1(x))
        res = res + x
        res = self.conv2(res)
        res = self.calayer(res)
        res = self.palayer(res)
        res = res + x
        return res


# 分组(Group)
class Group(nn.Module):
    def __init__(self, conv, dim, kernel_size, blocks):
        super(Group, self).__init__()

        modules = [Block(conv, dim, kernel_size) for _ in range(blocks)]
        modules.append(conv(dim, dim, kernel_size))
        self.gp = nn.Sequential(*modules)

    def forward(self, x):
        res = self.gp(x)
        res = res + x
        return res



class FFA(nn.Module):
    def __init__(self, gps, blocks, conv=default_conv):
        super(FFA, self).__init__()
        self.gps = gps
        self.dim = 64
        kernel_size = 3
        pre_process = [conv(3, self.dim, kernel_size)]


        assert self.gps == 3


        self.g1 = Group(conv, self.dim, kernel_size, blocks=blocks)
        self.g2 = Group(conv, self.dim, kernel_size, blocks=blocks)
        self.g3 = Group(conv, self.dim, kernel_size, blocks=blocks)


        self.ca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(self.dim * self.gps, self.dim // 16, 1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.dim // 16, self.dim * self.gps, 1, padding=0, bias=True),
            nn.Sigmoid()
        )


        self.palayer = PALayer(self.dim)

        post_process = [
            conv(self.dim, self.dim, kernel_size),
            conv(self.dim, 3, kernel_size)
        ]

        self.pre = nn.Sequential(*pre_process)
        self.post = nn.Sequential(*post_process)

    def forward(self, x1):
        x = self.pre(x1)
        res1 = self.g1(x)
        res2 = self.g2(res1)
        res3 = self.g3(res2)
        w = self.ca(torch.cat([res1, res2, res3], dim=1))
        w = w.view(-1, self.gps, self.dim)[:, :, :, None, None]
        out = w[:, 0, ::] * res1 + w[:, 1, ::] * res2 + w[:, 2, ::] * res3
        out = self.palayer(out)
        x = self.post(out)
        return x + x1


if __name__ == "__main__":

    input = torch.randn(1, 3, 32, 32) # B C H W

    model = FFA(gps=3, blocks=20)

    output = model(input)

    print(input.size())    print(output.size())

便捷下载方式

浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

更多分析可见原文


ai缝合大王
聚焦AI前沿,分享相关技术、论文,研究生自救指南
 最新文章