论文介绍
题目:MAGNet: Multi-scale Awareness and Global fusion Network for RGB-D salient object detection | KBS
论文地址:https://www.sciencedirect.com/science/article/abs/pii/S0950705124007603
QQ深度学习交流群:719278780
扫描下方二维码,加入深度学习论文指南星球!
加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务
创新点
设计了一种轻量化的多模态融合网络:
提出了一种名为MAGNet(Multi-scale Awareness and Global fusion Network)的网络,用于RGB-D显著性目标检测。
通过16.1M的参数和9.9G的FLOPs实现了与先进方法相当的检测性能,同时大幅度减少了模型的复杂度。
模块创新:
多尺度感知融合模块(MAFM):充分利用低层特征图中的纹理信息和边缘信息,减少计算复杂度。
全局融合模块(GFM):结合注意力机制与卷积神经网络,增强高层特征图的语义信息。
多级卷积模块(MCM):用于逐步解码融合特征图,生成精细的预测结果。
跨模态特征融合:
在低层特征中,使用MAFM实现RGB和深度图特征的跨模态融合,以减少复杂背景和低光条件下的干扰。
在高层特征中,通过GFM设计全局融合,实现对RGB与深度图语义信息的全面整合。
性能优化与验证:
在6个公共数据集上进行实验,结果表明MAGNet不仅在精确度上优于现有方法,而且在参数量和计算复杂度上显著减少。
提供了一个轻量化版本MAGNet-S,进一步验证了其适应低计算资源环境的能力。
方法
整体架构
双流编码器:
RGB图像由 SMT(Swapped Mix Transformer) 提取多级特征。
深度图像由 MobileNetV2 提取多级特征。
这种设计结合了Transformer的全局感知能力和轻量化网络的高效性。
特征融合模块:
低层特征融合:通过**多尺度感知融合模块(MAFM)**实现,融合RGB和深度图像的低层特征,充分利用纹理和边缘信息。
高层特征融合:通过**全局融合模块(GFM)**实现,将RGB和深度图像的语义信息进行全局关联和融合。
解码器:
使用多级卷积模块(MCM),逐步将融合后的特征图解码为显著性目标图。
MCM通过层级特征整合生成精细的显著性目标预测结果。
核心模块描述
(1) 多尺度感知融合模块(MAFM)
低层特征融合,结合了深度可分离卷积(DW)、点卷积(PW)以及多头混合卷积(MHMC)。
目的:降低计算复杂度的同时增强特征图的空间相关性。
(2) 全局融合模块(GFM)
高层特征融合,采用注意力机制结合卷积运算。
特点:
融合RGB和深度特征的全局语义信息。
通过注意力机制有效捕获跨模态的全局关联。
(3) 多级卷积模块(MCM)
解码器部分,每级特征通过上采样、深度可分离卷积和逐点卷积逐步整合。
目标:从低层到高层逐步恢复图像细节,生成高质量的显著性目标预测图。
即插即用模块作用
MCM 作为一个即插即用模块:
(1) 特征融合与逐步解码
MCM通过逐级特征整合,结合高层语义信息和低层细节信息,从而逐步恢复特征图中的细节。
该模块可以有效减少特征丢失,同时保留丰富的细节。
(2) 降低计算复杂度
MCM中使用了深度可分离卷积(Depth-wise Convolution)和逐点卷积(Point-wise Convolution),极大降低了计算量。
对于资源受限的场景,MCM的轻量化设计显得尤为重要。
(3) 提升多层次特征的表达能力
高层特征中包含的全局语义信息可以通过MCM逐级整合至低层特征,补充细节。
低层特征可以帮助更精确地定位边缘和局部区域。
消融实验结果
对比不同主干网络组合对模型性能的影响:
采用轻量化的MobileNetV2可以减少参数和计算量,但在RGB细节提取上表现不足。
采用SMT作为RGB图像主干网络显著提升了模型性能。
结论:SMT和MobileNetV2的结合在计算效率和性能之间达成了良好平衡。
对比是否使用MAFM以及其他替代模块(PI和CMFM):
添加MAFM后模型性能显著提升,MAE指标在各数据集上降低。
与其他模块相比,MAFM以较少的参数实现了更高的检测精度。
结论:MAFM在低层特征融合中有效整合了RGB和深度特征,尤其在复杂场景下提升了模型的鲁棒性。
对比不同分辨率下的模型性能和计算量:
提高分辨率可以提升检测精度,但会显著增加计算量(FLOPs)和降低推理速度。
最终选择384×384作为平衡点,兼顾性能和效率。
结论:分辨率对模型性能和效率有直接影响,应根据应用需求选择合适的输入尺寸。
对比是否使用MHMC以及替代方法(单层卷积):
使用MHMC的模型在多个数据集上的性能均优于其他方法。
结论:MHMC在MAFM中通过捕获多尺度的相关性增强了RGB和深度特征的融合。
对比是否使用GFM以及替代方法(SCA和AF模块):
添加GFM后,模型在多数据集上均有性能提升。
GFM的性能略优于其他方法(如SCA和AF),尤其在复杂场景中表现更好。
结论:GFM能够更有效地融合RGB和深度特征的全局语义信息。
即插即用模块
import torch.nn as nn
import torch
import torch.nn.functional as F
# 论文:MAGNet: Multi-scale Awareness and Global fusion Network for RGB-D salient object detection | KBS
# 论文地址:https://www.sciencedirect.com/science/article/abs/pii/S0950705124007603
# github地址:https://github.com/mingyu6346/MAGNet
TRAIN_SIZE = 384
class MCM(nn.Module):
def __init__(self, inc, outc):
super().__init__()
self.upsample2 = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
self.rc = nn.Sequential(
nn.Conv2d(in_channels=inc, out_channels=inc, kernel_size=3, padding=1, stride=1, groups=inc),
nn.BatchNorm2d(inc),
nn.GELU(),
nn.Conv2d(in_channels=inc, out_channels=outc, kernel_size=1, stride=1),
nn.BatchNorm2d(outc),
nn.GELU()
)
self.predtrans = nn.Sequential(
nn.Conv2d(in_channels=outc, out_channels=outc, kernel_size=3, padding=1, groups=outc),
nn.BatchNorm2d(outc),
nn.GELU(),
nn.Conv2d(in_channels=outc, out_channels=1, kernel_size=1)
)
self.rc2 = nn.Sequential(
nn.Conv2d(in_channels=outc * 2, out_channels=outc * 2, kernel_size=3, padding=1, groups=outc * 2),
nn.BatchNorm2d(outc * 2),
nn.GELU(),
nn.Conv2d(in_channels=outc * 2, out_channels=outc, kernel_size=1, stride=1),
nn.BatchNorm2d(outc),
nn.GELU()
)
def forward(self, x1, x2):
x2_upsample = self.upsample2(x2) # 上采样
x2_rc = self.rc(x2_upsample) # 减少通道数
shortcut = x2_rc
x_cat = torch.cat((x1, x2_rc), dim=1) # 拼接
x_forward = self.rc2(x_cat) # 减少通道数2
x_forward = x_forward + shortcut
pred = F.interpolate(self.predtrans(x_forward), TRAIN_SIZE, mode="bilinear", align_corners=True) # 预测图
return pred, x_forward
if __name__ == '__main__':
inc = 64 # 输入通道数
outc = 32 # 输出通道数
mcm = MCM(inc=inc, outc=outc)
x1 = torch.randn(1, outc, 96, 96) # Batch size=1, Channels=outc, Height=96, Width=96
x2 = torch.randn(1, inc, 48, 48) # Batch size=1, Channels=inc, Height=48, Width=48
pred, x_forward = mcm(x1, x2)
print(x1.size())
print(x2.size())
print(pred.size())
print(x_forward.size())
便捷下载方式
浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules
更多分析可见原文