论文介绍
题目:AAU-net: An Adaptive Attention U-net for Breast Lesions Segmentation in Ultrasound Images
论文地址:https://arxiv.org/pdf/2204.12077
QQ深度学习交流群:719278780
扫描下方二维码,加入深度学习论文指南星球!
加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务
创新点
混合自适应注意模块 (HAAM):设计了一种新颖的混合自适应注意模块,通过结合不同尺度的卷积核,并集成通道自注意力和空间自注意力模块,有效地增强了对乳腺超声图像中复杂特征的捕捉能力。
自适应注意 U-net (AAU-net):在经典的 U-net 基础上,通过引入 HAAM,开发了一种自适应注意 U-net,该网络能够更稳定和自动地处理乳腺病灶的分割问题。
更强的特征表征能力:HAAM 模块能够自适应地选择不同尺度的感受野,在通道和空间维度上学习更加鲁棒的特征表示,从而应对复杂病灶形态和模糊边界的干扰。
实验验证的优越性能:
在三个公开乳腺超声数据集上的实验表明,该方法在分割精度上显著优于多种现有的深度学习方法。
通过外部验证数据的测试,该方法表现出较好的泛化能力和鲁棒性。
方法
整体架构
Adaptive Attention U-net (AAU-net) 是一种基于 U-net 改进的分割网络,通过引入混合自适应注意模块(HAAM)提升对乳腺病灶分割的性能。HAAM 集成了多尺度卷积、通道自注意力和空间自注意力,能够自适应地捕捉不同尺度和维度的特征,在编码器和解码器的每个阶段中替代传统卷积操作。结合跳跃连接和上下采样结构,AAU-net 在保留空间细节的同时增强了模型对复杂病灶特征的表达能力,有效应对边界模糊、形态多变的乳腺病灶分割挑战。
1. 基本框架
AAU-net 保持了 U-net 的核心结构,即一个经典的 U 型网络,由以下部分组成:
编码器(Encoder):包括 4 个下采样阶段,每个阶段由两层卷积模块和池化操作组成,用于提取图像的多尺度特征。
解码器(Decoder):包括 4 个上采样阶段,每个阶段由转置卷积操作(上采样)和对应的编码器特征图的跳跃连接(skip connection)组成,用于逐步恢复图像空间信息。
跳跃连接(Skip Connection):将编码器的低级特征与解码器的高级特征融合,增强分割效果。
2. 核心改进 - 引入混合自适应注意模块 (HAAM)
每个编码和解码阶段中都使用了两个 Hybrid Adaptive Attention Modules (HAAM) 来替代传统的卷积操作,改进了 U-net 的特征提取能力。
HAAM 包括以下三部分:
多尺度卷积:
使用卷积和膨胀卷积获取不同尺度的感受野,增强对不同大小病灶的适应性。
通道自注意力模块(Channel Self-Attention Block):
学习通道维度上的特征重要性,强调对特定通道特征的关注。
空间自注意力模块(Spatial Self-Attention Block):
学习空间维度上的特征权重,增强模型对目标位置的聚焦能力。
即插即用模块作用
HAAM 作为一个即插即用模块:
(1) 医学图像分割
主要应用于 乳腺超声图像 中的病灶分割任务,尤其是处理病灶形态复杂、边界模糊、背景干扰严重的图像。
适用于 CT、MRI、X光 等其他医学成像模式的分割任务,特别是在需要提取精细边界或小目标的场景中。
(2) 自然场景分割与目标检测
在需要捕捉多尺度特征的自然场景分割任务中,例如遥感图像的道路提取、植被分割等。
可用于目标检测任务,通过增强特征提取能力提升目标分类和定位的精度。
(3) 复杂背景中的小目标分割
例如监控场景中的目标分割、交通场景中车辆/行人的检测,HAAM 能有效抑制背景干扰,聚焦目标区域。
消融实验结果
展示了不同网络组件对分割性能的影响,包括使用基础 U-net、加入通道自注意力模块、空间自注意力模块,以及完整的混合自适应注意模块(HAAM)。结果表明,HAAM 模块显著提高了分割指标(如 Jaccard、Dice 等),说明通道和空间维度的注意力机制对特征提取的增强作用。
评估了不同卷积核大小和膨胀率(dilation rate)对分割性能的影响。实验对比了较小(3×3 卷积)、较大(5×5卷积)和默认设置(3×3, 5×5, 膨胀率为3)三种配置,结果表明,论文提出的默认配置在 Jaccard、Dice 等指标上均表现最佳,验证了感受野设计的合理性。
分别对良性和恶性病灶进行分割性能评估。结果显示,该方法在这两种类型中均表现出较好的鲁棒性,特别是在恶性病灶中,由于其边界模糊、形态复杂,AAU-net 的性能优势更加明显。
即插即用模块
import torch
import torch.nn as nn
#论文:AAU-net: An Adaptive Attention U-net for Breast Lesions Segmentation in Ultrasound Images
#论文地址:https://arxiv.org/pdf/2204.12077
def expend_as(tensor, rep):
return tensor.repeat(1, rep, 1, 1)
class Channelblock(nn.Module):
def __init__(self, in_channels, out_channels):
super(Channelblock, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=3, dilation=3),
nn.BatchNorm2d(out_channels),
nn.ReLU()
)
self.conv2 = nn.Sequential(
nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=5, padding=2),
nn.BatchNorm2d(out_channels),
nn.ReLU()
)
self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(out_channels * 2, out_channels),
nn.BatchNorm1d(out_channels),
nn.ReLU(),
nn.Linear(out_channels, out_channels),
nn.Sigmoid()
)
self.conv3 = nn.Sequential(
nn.Conv2d(in_channels=out_channels * 2, out_channels=out_channels, kernel_size=1, padding=0),
nn.BatchNorm2d(out_channels),
nn.ReLU()
)
def forward(self, x):
conv1 = self.conv1(x)
conv2 = self.conv2(x)
combined = torch.cat([conv1, conv2], dim=1)
pooled = self.global_avg_pool(combined)
pooled = torch.flatten(pooled, 1)
sigm = self.fc(pooled)
a = sigm.view(-1, sigm.size(1), 1, 1)
a1 = 1 - sigm
a1 = a1.view(-1, a1.size(1), 1, 1)
y = conv1 * a
y1 = conv2 * a1
combined = torch.cat([y, y1], dim=1)
out = self.conv3(combined)
return out
class Spatialblock(nn.Module):
def __init__(self, in_channels, out_channels, size):
super(Spatialblock, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU()
)
self.conv2 = nn.Sequential(
nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=5, padding=2),
nn.BatchNorm2d(out_channels),
nn.ReLU()
)
self.final_conv = nn.Sequential(
nn.Conv2d(in_channels=out_channels * 2, out_channels=out_channels, kernel_size=size, padding=(size // 2)),
nn.BatchNorm2d(out_channels)
)
def forward(self, x, channel_data):
conv1 = self.conv1(x)
spatil_data = self.conv2(conv1)
data3 = torch.add(channel_data, spatil_data)
data3 = torch.relu(data3)
data3 = nn.Conv2d(data3.size(1), 1, kernel_size=1, padding=0).cuda()(data3)
data3 = torch.sigmoid(data3)
a = expend_as(data3, channel_data.size(1))
y = a * channel_data
a1 = 1 - data3
a1 = expend_as(a1, spatil_data.size(1))
y1 = a1 * spatil_data
combined = torch.cat([y, y1], dim=1)
out = self.final_conv(combined)
return out
class HAAM(nn.Module):
def __init__(self, in_channels, out_channels, size=3):
super(HAAM, self).__init__()
self.channel_block = Channelblock(in_channels, out_channels)
self.spatial_block = Spatialblock(out_channels, out_channels, size)
def forward(self, x):
channel_data = self.channel_block(x)
haam_data = self.spatial_block(x, channel_data)
return haam_data
if __name__ == '__main__':
print(torch.__version__)
print(torch.cuda.is_available())
print(torch.version.cuda)
# 创建示例输入张量
batch_size = 2
in_channels = 64 # 输入通道数
height, width = 224, 224 # 输入图像的高度和宽度
input_tensor = torch.randn(batch_size, in_channels, height, width).cuda()
# 实例化 HAAM 模型
out_channels = 64 # 输出通道数
haam_model = HAAM(in_channels, out_channels).cuda()
# 前向传播
output_tensor = haam_model(input_tensor)
# 打印输入输出的形状
print("输入张量形状:", input_tensor.shape)
print("输出张量形状:", output_tensor.shape)
便捷下载方式
浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules
更多分析可见原文