(IEEE 2024)即插即用局部金字塔注意力模块LPA,涨点涨爆了

文摘   2024-12-02 17:20   上海  

论文介绍

题目:SwinPA-Net: Swin Transformer-Based Multiscale Feature Pyramid Aggregation Network for Medical Image Segmentation

论文地址:https://ieeexplore.ieee.org/document/9895210

QQ深度学习交流群:719278780

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • 提出了基于Swin Transformer的U型网络结构:该网络利用Swin Transformer的全局和动态感受野,通过自注意力机制和滑动窗口机制增强了特征提取能力。

  • 引入了两种新型模块

    • 密集乘法连接模块 (DMC):通过多尺度语义特征的乘法融合,有效减少浅层特征中的噪声干扰,增强特征表达能力,并更好地识别不同大小的病变边界。

    • 局部金字塔注意力模块 (LPA):结合局部和全局注意力机制,通过分层处理和特征融合,引导网络聚焦目标区域,提高语义特征的辨别能力。

  • 构建了新的临床数据集 (LIVis):该数据集专注于腹腔镜图像中的疏松结缔组织分割,包含1180张高分辨率图像,标注由外科专家交叉验证。这为模型在真实场景下的评估提供了重要基准。

  • 性能显著提升:该方法在三个不同的医学图像分割任务中表现出色(肠息肉分割、皮肤病变分割、腹腔镜图像分割),与最先进方法相比,平均Dice系数提高了1.68%、0.8%、1.2%。


方法

模型总体架构

SwinPA-Net 是一个基于 U 型结构的医学图像分割网络,采用 Swin Transformer 作为编码器,通过其滑动窗口自注意力机制实现全局与局部特征提取,同时结合密集乘法连接模块(DMC)进行多尺度特征融合,降低噪声干扰,增强边界表达,并利用局部金字塔注意力模块(LPA)聚焦目标区域,提高语义特征区分能力。解码器逐步上采样特征图至原始分辨率,生成高精度分割结果,适用于多种医学场景。

(1) 编码器(Encoder)

  • 基于 Swin Transformer,采用层级自注意力机制提取输入图像的多尺度特征。

  • 编码器包括 4 个阶段,每个阶段的特征图尺寸逐步减小(如 H/s、H/2s、H/4s、H/8s),通道数逐步增加。

  • 通过 滑动窗口多头自注意力机制 (SW-MSA),减小计算复杂度并保持局部和全局特征的关联。


(2) 密集乘法连接模块 (Dense Multiplicative Connection, DMC)

  • 位于特征融合阶段,利用浅层和深层特征之间的乘法方式进行多尺度特征融合。

  • 目标:

    • 降低浅层特征的背景噪声干扰。

    • 强化目标边界特征的表达。

    • 通过多尺度特征融合,缓解病变区域大小差异带来的分割困难。


(3) 局部金字塔注意力模块 (Local Pyramid Attention, LPA)

  • 基于金字塔结构,将注意力机制应用于局部和全局特征融合。

  • 处理流程:

    • 将特征图划分为不同维度的局部子区域。

    • 分别计算每个区域的空间注意力和通道注意力。

    • 融合不同层次的注意力特征以聚焦目标区域,同时压制无关信息。

  • 目标:

    • 提升模型在低对比度和复杂背景下的目标区分能力。


(4) 解码器(Decoder)

  • 解码器采用卷积层和上采样模块,将融合后的多尺度特征逐步还原至输入图像的原始分辨率。

  • 特点:

    • 解码过程中结合了编码器输出的跳跃连接特征。

    • 在每一阶段对特征进行进一步融合,以增强上下文信息表达。

即插即用模块作用

LPA 作为一个即插即用模块

  • 医学图像分割

    • 任务:肠息肉分割、皮肤病变分割、腹腔镜图像组织分割等。

    • 特点:低对比度、复杂背景,目标边界与周围组织相似。

  • 多尺度特征处理任务

    • 任务:需要同时处理全局背景和局部细节的任务,如语义分割、物体检测。

    • 特点:目标大小变化显著,形状复杂多样。

  • 复杂背景下的目标检测

    • 任务:显著性目标检测、场景分割。

    • 特点:背景噪声较强,目标区域不明显。

消融实验结果

  • 验证 Swin Transformer 编码器的优势表 V 中对比了使用 ResNet-152 和 Swin Transformer 作为编码器的效果,结果显示 Swin Transformer 能显著提升 mIoU 和 mDice 分数,说明其在特征提取方面的强大能力。

  • DMC 模块的有效性在没有 DMC 模块的模型中(Model 2),mIoU 和 mDice 较低。引入 DMC 模块(Model 3)后,性能显著提升,表明其在多尺度特征融合中有效减少噪声并增强边界特征表达的能力。

  • LPA 模块的贡献将 LPA 模块加入模型(Model 4)后,mIoU 和 mDice 进一步提高,证明了局部金字塔注意力机制能够有效聚焦目标区域,提升分割性能。

  • DMC 模块的融合方法对比表中比较了不同特征融合方法(加法、拼接和乘法)。结果显示,乘法融合的 DMC 模块(Model 4)取得最佳性能,表明这种方法在特征融合中更具优势。

  • LPA 模块的层数优化表中对比了单层、多层 LPA 模块的效果,发现双层 LPA 结构(Model 4)性能最佳,而层数过多会引入复杂性,导致性能下降。

即插即用模块

import torch
import torch.nn as nn
#论文:SwinPA-Net: Swin Transformer-Based Multiscale Feature Pyramid Aggregation Network for Medical Image Segmentation
#论文地址:https://ieeexplore.ieee.org/document/9895210

class ChannelAttention(nn.Module):
    def __init__(self, in_planes):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc1 = nn.Conv2d(in_planes, in_planes // 8, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // 8, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=3):
        super(SpatialAttention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)


class LPA(nn.Module):
    def __init__(self, in_channel):
        super(LPA, self).__init__()
        self.ca = ChannelAttention(in_channel)
        self.sa = SpatialAttention()

    def forward(self, x):
        x0, x1 = x.chunk(2, dim=2)
        x0 = x0.chunk(2, dim=3)
        x1 = x1.chunk(2, dim=3)
        x0 = [self.ca(x0[-2]) * x0[-2], self.ca(x0[-1]) * x0[-1]]
        x0 = [self.sa(x0[-2]) * x0[-2], self.sa(x0[-1]) * x0[-1]]

        x1 = [self.ca(x1[-2]) * x1[-2], self.ca(x1[-1]) * x1[-1]]
        x1 = [self.sa(x1[-2]) * x1[-2], self.sa(x1[-1]) * x1[-1]]

        x0 = torch.cat(x0, dim=3)
        x1 = torch.cat(x1, dim=3)
        x3 = torch.cat((x0, x1), dim=2)

        x4 = self.ca(x) * x
        x4 = self.sa(x4) * x4
        x = x3 + x4
        return x


if __name__ == '__main__':

    input = torch.rand(1, 28, 64, 64)
    block = LPA(in_channel=28)
    output = block(input)

    print(input.size())    print(output.size())

便捷下载方式

浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

更多分析可见原文


ai缝合大王
聚焦AI前沿,分享相关技术、论文,研究生自救指南
 最新文章