混合自适应注意力模块HAAM,涨点起飞起飞了!

文摘   2024-12-18 17:20   中国香港  

论文介绍

题目:AAU-net: An Adaptive Attention U-net for Breast Lesions Segmentation in Ultrasound Images

论文地址:https://arxiv.org/pdf/2204.12077

QQ深度学习交流群:719278780

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • 混合自适应注意模块 (HAAM):设计了一种新颖的混合自适应注意模块,通过结合不同尺度的卷积核,并集成通道自注意力和空间自注意力模块,有效地增强了对乳腺超声图像中复杂特征的捕捉能力。

  • 自适应注意 U-net (AAU-net):在经典的 U-net 基础上,通过引入 HAAM,开发了一种自适应注意 U-net,该网络能够更稳定和自动地处理乳腺病灶的分割问题。

  • 更强的特征表征能力:HAAM 模块能够自适应地选择不同尺度的感受野,在通道和空间维度上学习更加鲁棒的特征表示,从而应对复杂病灶形态和模糊边界的干扰。

  • 实验验证的优越性能

    • 在三个公开乳腺超声数据集上的实验表明,该方法在分割精度上显著优于多种现有的深度学习方法。

    • 通过外部验证数据的测试,该方法表现出较好的泛化能力和鲁棒性。

方法

整体架构

     Adaptive Attention U-net (AAU-net) 是一种基于 U-net 改进的分割网络,通过引入混合自适应注意模块(HAAM)提升对乳腺病灶分割的性能。HAAM 集成了多尺度卷积、通道自注意力和空间自注意力,能够自适应地捕捉不同尺度和维度的特征,在编码器和解码器的每个阶段中替代传统卷积操作。结合跳跃连接和上下采样结构,AAU-net 在保留空间细节的同时增强了模型对复杂病灶特征的表达能力,有效应对边界模糊、形态多变的乳腺病灶分割挑战。

1. 基本框架

AAU-net 保持了 U-net 的核心结构,即一个经典的 U 型网络,由以下部分组成:

  • 编码器(Encoder):包括 4 个下采样阶段,每个阶段由两层卷积模块和池化操作组成,用于提取图像的多尺度特征。

  • 解码器(Decoder):包括 4 个上采样阶段,每个阶段由转置卷积操作(上采样)和对应的编码器特征图的跳跃连接(skip connection)组成,用于逐步恢复图像空间信息。

  • 跳跃连接(Skip Connection):将编码器的低级特征与解码器的高级特征融合,增强分割效果。

2. 核心改进 - 引入混合自适应注意模块 (HAAM)

每个编码和解码阶段中都使用了两个 Hybrid Adaptive Attention Modules (HAAM) 来替代传统的卷积操作,改进了 U-net 的特征提取能力。
HAAM 包括以下三部分:

  • 多尺度卷积

    • 使用卷积和膨胀卷积获取不同尺度的感受野,增强对不同大小病灶的适应性。

  • 通道自注意力模块(Channel Self-Attention Block)

    • 学习通道维度上的特征重要性,强调对特定通道特征的关注。

  • 空间自注意力模块(Spatial Self-Attention Block)

    • 学习空间维度上的特征权重,增强模型对目标位置的聚焦能力。

即插即用模块作用

HAAM 作为一个即插即用模块

(1) 医学图像分割

  • 主要应用于 乳腺超声图像 中的病灶分割任务,尤其是处理病灶形态复杂、边界模糊、背景干扰严重的图像。

  • 适用于 CT、MRI、X光 等其他医学成像模式的分割任务,特别是在需要提取精细边界或小目标的场景中。

(2) 自然场景分割与目标检测

  • 在需要捕捉多尺度特征的自然场景分割任务中,例如遥感图像的道路提取、植被分割等。

  • 可用于目标检测任务,通过增强特征提取能力提升目标分类和定位的精度。

(3) 复杂背景中的小目标分割

  • 例如监控场景中的目标分割、交通场景中车辆/行人的检测,HAAM 能有效抑制背景干扰,聚焦目标区域。

消融实验结果

         展示了不同网络组件对分割性能的影响,包括使用基础 U-net、加入通道自注意力模块、空间自注意力模块,以及完整的混合自适应注意模块(HAAM)。结果表明,HAAM 模块显著提高了分割指标(如 Jaccard、Dice 等),说明通道和空间维度的注意力机制对特征提取的增强作用。


       评估了不同卷积核大小和膨胀率(dilation rate)对分割性能的影响。实验对比了较小(3×3 卷积)、较大(5×5卷积)和默认设置(3×3, 5×5, 膨胀率为3)三种配置,结果表明,论文提出的默认配置在 Jaccard、Dice 等指标上均表现最佳,验证了感受野设计的合理性。


         分别对良性和恶性病灶进行分割性能评估。结果显示,该方法在这两种类型中均表现出较好的鲁棒性,特别是在恶性病灶中,由于其边界模糊、形态复杂,AAU-net 的性能优势更加明显。

即插即用模块

import torch
import torch.nn as nn
#论文:AAU-net: An Adaptive Attention U-net for Breast Lesions Segmentation in Ultrasound Images
#论文地址:https://arxiv.org/pdf/2204.12077

def expend_as(tensor, rep):
    return tensor.repeat(1, rep, 1, 1)


class Channelblock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(Channelblock, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=3, dilation=3),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=5, padding=2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(out_channels * 2, out_channels),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels),
            nn.Sigmoid()
        )

        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=out_channels * 2, out_channels=out_channels, kernel_size=1, padding=0),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x):
        conv1 = self.conv1(x)
        conv2 = self.conv2(x)

        combined = torch.cat([conv1, conv2], dim=1)
        pooled = self.global_avg_pool(combined)
        pooled = torch.flatten(pooled, 1)
        sigm = self.fc(pooled)

        a = sigm.view(-1, sigm.size(1), 1, 1)
        a1 = 1 - sigm
        a1 = a1.view(-1, a1.size(1), 1, 1)

        y = conv1 * a
        y1 = conv2 * a1

        combined = torch.cat([y, y1], dim=1)
        out = self.conv3(combined)

        return out


class Spatialblock(nn.Module):
    def __init__(self, in_channels, out_channels, size):
        super(Spatialblock, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=5, padding=2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

        self.final_conv = nn.Sequential(
            nn.Conv2d(in_channels=out_channels * 2, out_channels=out_channels, kernel_size=size, padding=(size // 2)),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x, channel_data):
        conv1 = self.conv1(x)
        spatil_data = self.conv2(conv1)

        data3 = torch.add(channel_data, spatil_data)
        data3 = torch.relu(data3)
        data3 = nn.Conv2d(data3.size(1), 1, kernel_size=1, padding=0).cuda()(data3)
        data3 = torch.sigmoid(data3)

        a = expend_as(data3, channel_data.size(1))
        y = a * channel_data

        a1 = 1 - data3
        a1 = expend_as(a1, spatil_data.size(1))
        y1 = a1 * spatil_data

        combined = torch.cat([y, y1], dim=1)
        out = self.final_conv(combined)

        return out


class HAAM(nn.Module):
    def __init__(self, in_channels, out_channels, size=3):
        super(HAAM, self).__init__()
        self.channel_block = Channelblock(in_channels, out_channels)
        self.spatial_block = Spatialblock(out_channels, out_channels, size)

    def forward(self, x):
        channel_data = self.channel_block(x)
        haam_data = self.spatial_block(x, channel_data)
        return haam_data


if __name__ == '__main__':
    print(torch.__version__)
    print(torch.cuda.is_available())
    print(torch.version.cuda)

    # 创建示例输入张量
    batch_size = 2
    in_channels = 64  # 输入通道数
    height, width = 224, 224  # 输入图像的高度和宽度
    input_tensor = torch.randn(batch_size, in_channels, height, width).cuda()

    # 实例化 HAAM 模型
    out_channels = 64  # 输出通道数
    haam_model = HAAM(in_channels, out_channels).cuda()

    # 前向传播
    output_tensor = haam_model(input_tensor)

    # 打印输入输出的形状
    print("输入张量形状:", input_tensor.shape)
    print("输出张量形状:", output_tensor.shape)

便捷下载方式

浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

更多分析可见原文


ai缝合大王
聚焦AI前沿,分享相关技术、论文,研究生自救指南
 最新文章