2024多级卷积模块MCM,涨点起飞起飞了

文摘   2024-12-03 17:20   上海  

论文介绍

题目:MAGNet: Multi-scale Awareness and Global fusion Network for RGB-D salient object detection | KBS

论文地址:https://www.sciencedirect.com/science/article/abs/pii/S0950705124007603

QQ深度学习交流群:719278780

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • 设计了一种轻量化的多模态融合网络:

    • 提出了一种名为MAGNet(Multi-scale Awareness and Global fusion Network)的网络,用于RGB-D显著性目标检测。

    • 通过16.1M的参数和9.9G的FLOPs实现了与先进方法相当的检测性能,同时大幅度减少了模型的复杂度。

  • 模块创新:

    • 多尺度感知融合模块(MAFM):充分利用低层特征图中的纹理信息和边缘信息,减少计算复杂度。

    • 全局融合模块(GFM):结合注意力机制与卷积神经网络,增强高层特征图的语义信息。

    • 多级卷积模块(MCM):用于逐步解码融合特征图,生成精细的预测结果。

  • 跨模态特征融合:

    • 在低层特征中,使用MAFM实现RGB和深度图特征的跨模态融合,以减少复杂背景和低光条件下的干扰。

    • 在高层特征中,通过GFM设计全局融合,实现对RGB与深度图语义信息的全面整合。

  • 性能优化与验证:

    • 在6个公共数据集上进行实验,结果表明MAGNet不仅在精确度上优于现有方法,而且在参数量和计算复杂度上显著减少。

    • 提供了一个轻量化版本MAGNet-S,进一步验证了其适应低计算资源环境的能力。

方法

整体架构

  • 双流编码器

    • RGB图像由 SMT(Swapped Mix Transformer) 提取多级特征。

    • 深度图像由 MobileNetV2 提取多级特征。

    • 这种设计结合了Transformer的全局感知能力和轻量化网络的高效性。

  • 特征融合模块

    • 低层特征融合:通过**多尺度感知融合模块(MAFM)**实现,融合RGB和深度图像的低层特征,充分利用纹理和边缘信息。

    • 高层特征融合:通过**全局融合模块(GFM)**实现,将RGB和深度图像的语义信息进行全局关联和融合。

  • 解码器

    • 使用多级卷积模块(MCM),逐步将融合后的特征图解码为显著性目标图。

    • MCM通过层级特征整合生成精细的显著性目标预测结果。



核心模块描述

(1) 多尺度感知融合模块(MAFM)

  • 低层特征融合,结合了深度可分离卷积(DW)、点卷积(PW)以及多头混合卷积(MHMC)。

  • 目的:降低计算复杂度的同时增强特征图的空间相关性。

(2) 全局融合模块(GFM)

  • 高层特征融合,采用注意力机制结合卷积运算。

  • 特点:

    • 融合RGB和深度特征的全局语义信息。

    • 通过注意力机制有效捕获跨模态的全局关联。

(3) 多级卷积模块(MCM)

  • 解码器部分,每级特征通过上采样、深度可分离卷积和逐点卷积逐步整合。

  • 目标:从低层到高层逐步恢复图像细节,生成高质量的显著性目标预测图。

即插即用模块作用

MCM 作为一个即插即用模块

(1) 特征融合与逐步解码

  • MCM通过逐级特征整合,结合高层语义信息和低层细节信息,从而逐步恢复特征图中的细节。

  • 该模块可以有效减少特征丢失,同时保留丰富的细节。

(2) 降低计算复杂度

  • MCM中使用了深度可分离卷积(Depth-wise Convolution)和逐点卷积(Point-wise Convolution),极大降低了计算量。

  • 对于资源受限的场景,MCM的轻量化设计显得尤为重要。

(3) 提升多层次特征的表达能力

  • 高层特征中包含的全局语义信息可以通过MCM逐级整合至低层特征,补充细节。

  • 低层特征可以帮助更精确地定位边缘和局部区域。


消融实验结果

  • 对比不同主干网络组合对模型性能的影响

    • 采用轻量化的MobileNetV2可以减少参数和计算量,但在RGB细节提取上表现不足。

    • 采用SMT作为RGB图像主干网络显著提升了模型性能。

  • 结论:SMT和MobileNetV2的结合在计算效率和性能之间达成了良好平衡。


  • 对比是否使用MAFM以及其他替代模块(PI和CMFM)

    • 添加MAFM后模型性能显著提升,MAE指标在各数据集上降低。

    • 与其他模块相比,MAFM以较少的参数实现了更高的检测精度。

  • 结论:MAFM在低层特征融合中有效整合了RGB和深度特征,尤其在复杂场景下提升了模型的鲁棒性。


  • 对比不同分辨率下的模型性能和计算量

    • 提高分辨率可以提升检测精度,但会显著增加计算量(FLOPs)和降低推理速度。

    • 最终选择384×384作为平衡点,兼顾性能和效率。

  • 结论:分辨率对模型性能和效率有直接影响,应根据应用需求选择合适的输入尺寸。


  • 对比是否使用MHMC以及替代方法(单层卷积)

    • 使用MHMC的模型在多个数据集上的性能均优于其他方法。

  • 结论:MHMC在MAFM中通过捕获多尺度的相关性增强了RGB和深度特征的融合。


  • 对比是否使用GFM以及替代方法(SCA和AF模块)

    • 添加GFM后,模型在多数据集上均有性能提升。

    • GFM的性能略优于其他方法(如SCA和AF),尤其在复杂场景中表现更好。

  • 结论:GFM能够更有效地融合RGB和深度特征的全局语义信息。

即插即用模块

import torch.nn as nn
import torch
import torch.nn.functional as F
# 论文:MAGNet: Multi-scale Awareness and Global fusion Network for RGB-D salient object detection | KBS
# 论文地址:https://www.sciencedirect.com/science/article/abs/pii/S0950705124007603
# github地址:https://github.com/mingyu6346/MAGNet

TRAIN_SIZE = 384

class MCM(nn.Module):
    def __init__(self, inc, outc):
        super().__init__()
        self.upsample2 = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.rc = nn.Sequential(
            nn.Conv2d(in_channels=inc, out_channels=inc, kernel_size=3, padding=1, stride=1, groups=inc),
            nn.BatchNorm2d(inc),
            nn.GELU(),
            nn.Conv2d(in_channels=inc, out_channels=outc, kernel_size=1, stride=1),
            nn.BatchNorm2d(outc),
            nn.GELU()
        )
        self.predtrans = nn.Sequential(
            nn.Conv2d(in_channels=outc, out_channels=outc, kernel_size=3, padding=1, groups=outc),
            nn.BatchNorm2d(outc),
            nn.GELU(),
            nn.Conv2d(in_channels=outc, out_channels=1, kernel_size=1)
        )

        self.rc2 = nn.Sequential(
            nn.Conv2d(in_channels=outc * 2, out_channels=outc * 2, kernel_size=3, padding=1, groups=outc * 2),
            nn.BatchNorm2d(outc * 2),
            nn.GELU(),
            nn.Conv2d(in_channels=outc * 2, out_channels=outc, kernel_size=1, stride=1),
            nn.BatchNorm2d(outc),
            nn.GELU()
        )

    def forward(self, x1, x2):
        x2_upsample = self.upsample2(x2) # 上采样
        x2_rc = self.rc(x2_upsample) # 减少通道数
        shortcut = x2_rc

        x_cat = torch.cat((x1, x2_rc), dim=1) # 拼接
        x_forward = self.rc2(x_cat) # 减少通道数2
        x_forward = x_forward + shortcut
        pred = F.interpolate(self.predtrans(x_forward), TRAIN_SIZE, mode="bilinear", align_corners=True) # 预测图

        return pred, x_forward


if __name__ == '__main__':

    inc = 64  # 输入通道数
    outc = 32  # 输出通道数
    mcm = MCM(inc=inc, outc=outc)

    x1 = torch.randn(1, outc, 96, 96) # Batch size=1, Channels=outc, Height=96, Width=96
    x2 = torch.randn(1, inc, 48, 48) # Batch size=1, Channels=inc, Height=48, Width=48

    pred, x_forward = mcm(x1, x2)

    print(x1.size())
    print(x2.size())
    print(pred.size())
    print(x_forward.size())

便捷下载方式

浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

更多分析可见原文


ai缝合大王
聚焦AI前沿,分享相关技术、论文,研究生自救指南
 最新文章