论文介绍
题目:DAU-Net: Dual attention-aided U-Net for segmenting tumor in breast ultrasound images
论文地址:https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0303670
QQ深度学习交流群:719278780
扫描下方二维码,加入深度学习论文指南星球!
加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务
创新点
双注意力机制的集成:提出了一种基于深度学习的新型乳腺超声图像分割方法,将两种强大的注意力机制(位置卷积块注意力模块 (PCBAM) 和移位窗口注意力 (SWA))集成到残差 U-Net 模型中。这些机制分别用于增强局部特征的上下文信息和捕获全局依赖关系。
改进的PCBAM模块:该模块结合了通道和空间注意力的卷积块注意力模块 (CBAM),以及位置注意力模块 (PAM),提升了模型捕获局部特征之间空间关系的能力。
在瓶颈层引入SWA:SWA在模型的瓶颈层中用于捕获全局上下文信息,从而进一步提高分割性能。
出色的实验结果:在两组乳腺超声图像数据集(BUSI 和 UDIAT)上的实验中,该方法分别实现了 74.23% 和 78.58% 的 Dice 得分,超越了其他同类先进方法,证明了模型在分割乳腺肿瘤区域上的准确性和鲁棒性。
综合的损失函数:模型训练中结合了 Dice 损失、二元交叉熵 (BCE) 损失和焦点损失的组合,提高了训练过程中的准确性和对难分类像素的关注。
全面的消融研究与比较实验:通过系统的消融研究,验证了 PCBAM 和 SWA 组件对模型性能的贡献,并在多个基准模型中展示了其优越性。
方法
整体架构
基于 Residual U-Net 的双注意力增强模型(DAU-Net),其整体结构包括编码器、解码器和瓶颈层。编码器通过卷积层和残差连接提取多层次特征,同时在跳跃连接中加入 PCBAM(Positional Convolutional Block Attention Module),增强局部特征的空间和上下文信息捕获能力;瓶颈层集成了 SWA(Shifted Window Attention) 模块,用于捕获全局上下文信息;解码器通过反卷积结合跳跃连接的方式融合多层次特征,逐步恢复空间信息并生成分割掩码。PCBAM和SWA的结合显著提升了模型对局部和全局特征的建模能力,使其在乳腺肿瘤的分割任务中表现优越。
1. 模型架构
该模型由以下主要组件组成:
编码器(Encoder):用于提取输入图像的特征。
使用 3x3 卷积层进行特征提取,并配合批量归一化(Batch Normalization)和 ReLU 激活函数。
使用残差连接(Residual Connections)确保梯度的有效传播,同时保留重要信息。
PCBAM模块在编码器层中增强特征提取能力,特别是捕获局部的空间和上下文信息。
瓶颈层(Bottleneck Layer):
集成了 Shifted Window Attention (SWA),捕获全局上下文信息,进一步提升模型的特征表示能力。
SWA机制通过滑动窗口的方式获取长距离依赖关系,改善空间一致性。
解码器(Decoder):用于恢复图像的空间信息并生成分割结果。
使用反卷积层(Upsampling)对编码器生成的特征进行上采样。
融合了来自编码器的低级和高级特征,通过跳跃连接(Skip Connections)保留空间信息。
PCBAM模块也应用于解码器,增强解码过程中特征的表示能力。
输出层:通过一层卷积将解码器输出转换为二值分割掩码。
2. 关键模块
PCBAM(Positional Convolutional Block Attention Module):
PCBAM结合了:
CBAM(Convolutional Block Attention Module):通过通道和空间注意力,选择性地关注有用的特征。
PAM(Positional Attention Module):捕获空间上下文信息和像素之间的关系。
PCBAM在编码器和解码器中均应用,增强了特征提取和融合能力。
SWA(Shifted Window Attention):
集成在瓶颈层,用于建模长距离依赖。
滑动窗口机制能够在局部窗口内计算注意力,同时避免信息丢失。
3. 数据流动与连接方式
跳跃连接(Skip Connections):
编码器和解码器之间的特征在不同层级上进行融合。
跳跃连接中加入 PCBAM 模块,进一步提升特征表达能力。
残差连接(Residual Connections):
确保信息流在深层网络中的有效传播,避免梯度消失问题。
4. 输出
模型最终生成与输入图像大小一致的二值分割掩码,用于标注乳腺肿瘤的区域。
即插即用模块作用
SWA 作为一个即插即用模块:
捕获全局依赖关系:SWA通过滑动窗口注意力机制,建立输入图像不同区域之间的长距离依赖关系,从而改善模型对全局上下文的感知能力。
提升空间一致性:在医学图像分割任务中,肿瘤等目标通常表现为不规则形状,SWA能够帮助模型捕捉这些目标的全局特征,增强分割结果的空间一致性。
高效注意力计算:相较于传统的全局注意力机制,SWA仅在滑动窗口范围内计算注意力,降低了计算复杂度,同时保持了注意力机制的全局建模能力。
改善小样本分割的鲁棒性:
在医学影像分割中,训练数据量通常有限。SWA通过全局特征捕获能力,提升了模型在小样本环境下的表现。
消融实验结果
展示了消融实验中不同模型的性能比较,包括Dice得分、IoU、准确率、精确率和召回率等指标。随着注意力机制的逐步引入,模型性能不断提升:基础Residual U-Net模型的Dice得分为68.27%,IoU为55.82%;加入PAM和CBAM模块后分别有所改进;结合PAM和CBAM为PCBAM后性能进一步提升;最终完整模型(PCBAM + SWA)取得最佳效果,Dice得分为74.23%,IoU为65.32%,表明PCBAM和SWA的引入显著增强了模型捕获上下文和空间特征的能力。
通过分割结果对比和特征热力图展示了不同模型的性能差异。随着PCBAM和SWA的加入,模型能够更准确地关注乳腺肿瘤的感兴趣区域,提升了分割结果的空间一致性和特征表达能力。最终完整模型(PCBAM + SWA)生成的分割结果与真实标签(Ground Truth)高度吻合,展现了模型在肿瘤分割任务中的优越性能。
即插即用模块
import torch
import torch.nn as nn
import torch.nn.functional as F
#论文:DAU-Net: Dual attention-aided U-Net for segmenting tumor in breast ultrasound images
#论文地址:https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0303670
class SWA(nn.Module):
def __init__(self, in_channels, n_heads=8, window_size=7):
super(SWA, self).__init__()
self.in_channels = in_channels
self.n_heads = n_heads
self.window_size = window_size
self.query_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.key_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.gamma = nn.Parameter(torch.zeros(1))
self.softmax = nn.Softmax(dim=-1)
def forward(self, x):
batch_size, C, height, width = x.size()
padded_x = F.pad(x, [self.window_size // 2, self.window_size // 2, self.window_size // 2, self.window_size // 2], mode='reflect')
proj_query = self.query_conv(x).view(batch_size, self.n_heads, C // self.n_heads, height * width)
proj_key = self.key_conv(padded_x).unfold(2, self.window_size, 1).unfold(3, self.window_size, 1)
proj_key = proj_key.permute(0, 1, 4, 5, 2, 3).contiguous().view(batch_size, self.n_heads, C // self.n_heads, -1)
proj_value = self.value_conv(padded_x).unfold(2, self.window_size, 1).unfold(3, self.window_size, 1)
proj_value = proj_value.permute(0, 1, 4, 5, 2, 3).contiguous().view(batch_size, self.n_heads, C // self.n_heads, -1)
energy = torch.matmul(proj_query.permute(0, 1, 3, 2), proj_key)
attention = self.softmax(energy)
out_window = torch.matmul(attention, proj_value.permute(0, 1, 3, 2))
out_window = out_window.permute(0, 1, 3, 2).contiguous().view(batch_size, C, height, width)
out = self.gamma * out_window + x
return out
if __name__ == '__main__':
input = torch.randn(1, 64, 32, 32)
block = SWA(in_channels=64)
print(input.size())
output = block(input)
print(output.size())
便捷下载方式
浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules
更多分析可见原文