2024即插即用移位窗口注意力机制SWA,涨点起飞!

文摘   2024-12-06 18:09   中国香港  

论文介绍

题目:DAU-Net: Dual attention-aided U-Net for segmenting tumor in breast ultrasound images

论文地址:https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0303670

QQ深度学习交流群:719278780

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • 双注意力机制的集成:提出了一种基于深度学习的新型乳腺超声图像分割方法,将两种强大的注意力机制(位置卷积块注意力模块 (PCBAM) 和移位窗口注意力 (SWA))集成到残差 U-Net 模型中。这些机制分别用于增强局部特征的上下文信息和捕获全局依赖关系。

  • 改进的PCBAM模块:该模块结合了通道和空间注意力的卷积块注意力模块 (CBAM),以及位置注意力模块 (PAM),提升了模型捕获局部特征之间空间关系的能力。

  • 在瓶颈层引入SWA:SWA在模型的瓶颈层中用于捕获全局上下文信息,从而进一步提高分割性能。

  • 出色的实验结果:在两组乳腺超声图像数据集(BUSI 和 UDIAT)上的实验中,该方法分别实现了 74.23% 和 78.58% 的 Dice 得分,超越了其他同类先进方法,证明了模型在分割乳腺肿瘤区域上的准确性和鲁棒性。

  • 综合的损失函数:模型训练中结合了 Dice 损失、二元交叉熵 (BCE) 损失和焦点损失的组合,提高了训练过程中的准确性和对难分类像素的关注。

  • 全面的消融研究与比较实验:通过系统的消融研究,验证了 PCBAM 和 SWA 组件对模型性能的贡献,并在多个基准模型中展示了其优越性。

方法

整体架构

     基于 Residual U-Net 的双注意力增强模型(DAU-Net),其整体结构包括编码器、解码器和瓶颈层。编码器通过卷积层和残差连接提取多层次特征,同时在跳跃连接中加入 PCBAM(Positional Convolutional Block Attention Module),增强局部特征的空间和上下文信息捕获能力;瓶颈层集成了 SWA(Shifted Window Attention) 模块,用于捕获全局上下文信息;解码器通过反卷积结合跳跃连接的方式融合多层次特征,逐步恢复空间信息并生成分割掩码。PCBAM和SWA的结合显著提升了模型对局部和全局特征的建模能力,使其在乳腺肿瘤的分割任务中表现优越。

1. 模型架构

该模型由以下主要组件组成:

  • 编码器(Encoder):用于提取输入图像的特征。

    • 使用 3x3 卷积层进行特征提取,并配合批量归一化(Batch Normalization)和 ReLU 激活函数。

    • 使用残差连接(Residual Connections)确保梯度的有效传播,同时保留重要信息。

    • PCBAM模块在编码器层中增强特征提取能力,特别是捕获局部的空间和上下文信息。

  • 瓶颈层(Bottleneck Layer)

    • 集成了 Shifted Window Attention (SWA),捕获全局上下文信息,进一步提升模型的特征表示能力。

    • SWA机制通过滑动窗口的方式获取长距离依赖关系,改善空间一致性。

  • 解码器(Decoder):用于恢复图像的空间信息并生成分割结果。

    • 使用反卷积层(Upsampling)对编码器生成的特征进行上采样。

    • 融合了来自编码器的低级和高级特征,通过跳跃连接(Skip Connections)保留空间信息。

    • PCBAM模块也应用于解码器,增强解码过程中特征的表示能力。

  • 输出层:通过一层卷积将解码器输出转换为二值分割掩码。


2. 关键模块

PCBAM(Positional Convolutional Block Attention Module)

  • PCBAM结合了:

    • CBAM(Convolutional Block Attention Module):通过通道和空间注意力,选择性地关注有用的特征。

    • PAM(Positional Attention Module):捕获空间上下文信息和像素之间的关系。

  • PCBAM在编码器和解码器中均应用,增强了特征提取和融合能力。

SWA(Shifted Window Attention)

  • 集成在瓶颈层,用于建模长距离依赖。

  • 滑动窗口机制能够在局部窗口内计算注意力,同时避免信息丢失。


3. 数据流动与连接方式

  • 跳跃连接(Skip Connections)

    • 编码器和解码器之间的特征在不同层级上进行融合。

    • 跳跃连接中加入 PCBAM 模块,进一步提升特征表达能力。

  • 残差连接(Residual Connections)

    • 确保信息流在深层网络中的有效传播,避免梯度消失问题。


4. 输出

模型最终生成与输入图像大小一致的二值分割掩码,用于标注乳腺肿瘤的区域。

即插即用模块作用

SWA 作为一个即插即用模块

  • 捕获全局依赖关系SWA通过滑动窗口注意力机制,建立输入图像不同区域之间的长距离依赖关系,从而改善模型对全局上下文的感知能力。


  • 提升空间一致性在医学图像分割任务中,肿瘤等目标通常表现为不规则形状,SWA能够帮助模型捕捉这些目标的全局特征,增强分割结果的空间一致性。

  • 高效注意力计算相较于传统的全局注意力机制,SWA仅在滑动窗口范围内计算注意力,降低了计算复杂度,同时保持了注意力机制的全局建模能力。

  • 改善小样本分割的鲁棒性

    • 在医学影像分割中,训练数据量通常有限。SWA通过全局特征捕获能力,提升了模型在小样本环境下的表现。

消融实验结果

      展示了消融实验中不同模型的性能比较,包括Dice得分、IoU、准确率、精确率和召回率等指标。随着注意力机制的逐步引入,模型性能不断提升:基础Residual U-Net模型的Dice得分为68.27%,IoU为55.82%;加入PAM和CBAM模块后分别有所改进;结合PAM和CBAM为PCBAM后性能进一步提升;最终完整模型(PCBAM + SWA)取得最佳效果,Dice得分为74.23%,IoU为65.32%,表明PCBAM和SWA的引入显著增强了模型捕获上下文和空间特征的能力。

     通过分割结果对比和特征热力图展示了不同模型的性能差异。随着PCBAM和SWA的加入,模型能够更准确地关注乳腺肿瘤的感兴趣区域,提升了分割结果的空间一致性和特征表达能力。最终完整模型(PCBAM + SWA)生成的分割结果与真实标签(Ground Truth)高度吻合,展现了模型在肿瘤分割任务中的优越性能。

即插即用模块

import torch
import torch.nn as nn
import torch.nn.functional as F
#论文:DAU-Net: Dual attention-aided U-Net for segmenting tumor in breast ultrasound images
#论文地址:https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0303670
class SWA(nn.Module):
    def __init__(self, in_channels, n_heads=8, window_size=7):
        super(SWA, self).__init__()
        self.in_channels = in_channels
        self.n_heads = n_heads
        self.window_size = window_size

        self.query_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch_size, C, height, width = x.size()
        padded_x = F.pad(x, [self.window_size // 2, self.window_size // 2, self.window_size // 2, self.window_size // 2], mode='reflect')

        proj_query = self.query_conv(x).view(batch_size, self.n_heads, C // self.n_heads, height * width)
        proj_key = self.key_conv(padded_x).unfold(2, self.window_size, 1).unfold(3, self.window_size, 1)
        proj_key = proj_key.permute(0, 1, 4, 5, 2, 3).contiguous().view(batch_size, self.n_heads, C // self.n_heads, -1)
        proj_value = self.value_conv(padded_x).unfold(2, self.window_size, 1).unfold(3, self.window_size, 1)
        proj_value = proj_value.permute(0, 1, 4, 5, 2, 3).contiguous().view(batch_size, self.n_heads, C // self.n_heads, -1)

        energy = torch.matmul(proj_query.permute(0, 1, 3, 2), proj_key)
        attention = self.softmax(energy)

        out_window = torch.matmul(attention, proj_value.permute(0, 1, 3, 2))
        out_window = out_window.permute(0, 1, 3, 2).contiguous().view(batch_size, C, height, width)

        out = self.gamma * out_window + x
        return out

if __name__ == '__main__':

    input = torch.randn(1, 64, 32, 32)
    block = SWA(in_channels=64)
    print(input.size())
    output = block(input)
    print(output.size())

便捷下载方式

浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

更多分析可见原文


ai缝合大王
聚焦AI前沿,分享相关技术、论文,研究生自救指南
 最新文章