即插即用多尺度特征融合模块GAB,涨点起飞起飞了

文摘   2024-11-25 17:20   中国香港  

论文介绍

题目:EGE-UNet: an Efficient Group Enhanced UNet  for skin lesion segmentation

论文地址:https://arxiv.org/pdf/2307.08473

QQ深度学习交流群:719278780

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • 轻量化设计:提出了一个名为EGE-UNet(Efficient Group Enhanced UNet)的轻量级模型,适用于移动医疗环境,解决了现有模型参数量和计算负担过大的问题。EGE-UNet的参数量仅约50KB,是首个达到此级别的模型,同时在性能上超越了许多现有的重量级模型。

  • 提出了两个关键模块

    • GHPA模块(Group multi-axis Hadamard Product Attention):通过将输入特征分组,在不同维度上进行Hadamard积注意力操作,实现从多个视角提取信息。这种方法基于线性复杂度的Hadamard积注意力机制,避免了传统自注意力机制的二次复杂度问题,同时大幅提升了计算效率。

    • GAB模块(Group Aggregation Bridge):用于多尺度特征融合,通过将低层特征、高层特征以及解码器生成的掩膜信息相结合,利用分组和空洞卷积提取不同尺度的信息,从而有效融合多尺度特征。

  • 模型的设计

    • 采用了基于对称编码器-解码器的U形架构,并结合了GHPA和GAB模块。

    • 在跳跃连接中引入GAB模块,增强了不同阶段之间的特征融合能力。

  • 性能优越性

    • 在皮肤病变分割任务中(ISIC2017和ISIC2018数据集),EGE-UNet显著超越了当前的轻量级和重量级模型。相较TransFuse模型,其参数量和计算量分别减少了494倍和160倍,同时分割性能(mIoU和DSC)更高。

  • 资源效率

    • 是首个同时兼顾轻量化(低参数和低计算量)和高性能的模型,在移动医疗等实际应用环境中具有重要价值。

方法

整体结构

       EGE-UNet 是基于对称 U 形架构的轻量级医学图像分割模型,通过编码器中的 GHPA 模块实现多视角高效特征提取,解码器通过多阶段逐步上采样恢复分割结果,并在跳跃连接中引入 GAB 模块进行多尺度特征融合和掩膜信息整合。模型同时采用深度监督机制生成多尺度掩膜,有效提高分割精度,并显著降低参数量和计算复杂度,非常适用于资源受限的移动医疗场景。

1. 基本架构

  • EGE-UNet 基于经典的 U 形架构,由对称的编码器(Encoder)和解码器(Decoder)组成。

  • 编码器逐步提取特征,解码器逐步上采样特征以恢复分割结果,同时通过跳跃连接传递信息。

2. 编码器

  • 编码器包含六个阶段,前三阶段使用传统卷积提取低级特征。

  • 后三阶段引入 GHPA 模块,通过多轴分组和 Hadamard 积注意力机制,从多个视角高效提取高级特征。

3. 跳跃连接

  • 在传统 U-Net 的跳跃连接基础上,EGE-UNet 引入 GAB 模块

    • 融合编码器低级特征、高级特征和解码器生成的掩膜信息。

    • 使用空洞卷积和多尺度特征融合策略,有效处理目标大小和形状的多样性。

4. 解码器

  • 解码器逐步上采样经过 GAB 模块融合的特征。

  • 在每个阶段生成不同尺度的掩膜信息,结合深度监督机制优化分割性能。

5. 模型特点

  • EGE-UNet 通过 GHPA 和 GAB 模块实现了轻量化设计,仅需约 50KB 参数。

  • 在降低参数和计算复杂度的同时,显著提升分割性能,适用于资源受限的移动医疗场景。

即插即用模块作用

GAB 作为一个即插即用模块,主要适用于:

  • 医学图像分割任务

    • 特别是在皮肤病变分割任务(如ISIC2017和ISIC2018数据集)中,GAB模块通过融合多尺度特征,有效处理病变区域的多样性(大小、形状和边界不一致)。

    • 可以扩展到其他需要高精度分割的医学图像任务(如脑肿瘤分割、器官分割等)。

  • 多尺度信息融合需求

    • 在需要处理目标形状、大小变化较大的场景中,GAB模块可以通过分组和空洞卷积提取多尺度信息,从而提升分割性能。

  • 资源受限的移动医疗应用

    • GAB模块的设计考虑了轻量化需求,能够在低计算资源环境下实现高效的多尺度特征融合。

消融实验结果

  • 论文证明了GHPA模块和GAB模块的设计是提升EGE-UNet性能的核心,同时验证了各模块中的关键设计(如多轴分组、掩膜信息和空洞卷积)的有效性。这些实验表明,EGE-UNet能够在显著降低参数和计算量的情况下,依靠精细设计的模块实现卓越的分割性能。

即插即用模块

import torch
from torch import nn
import torch.nn.functional as F

#可以缝合在跳跃连接部分
class LayerNorm(nn.Module):
    """ From ConvNeXt (https://arxiv.org/pdf/2201.03545.pdf)"""

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x


class group_aggregation_bridge(nn.Module):
    def __init__(self, dim_xh, dim_xl, k_size=3, d_list=[1, 2, 5, 7]):
        super().__init__()
        self.pre_project = nn.Conv2d(dim_xh, dim_xl, 1)
        group_size = dim_xl // 2
        self.g0 = nn.Sequential(
            LayerNorm(normalized_shape=group_size + 1, data_format='channels_first'),
            nn.Conv2d(group_size + 1, group_size + 1, kernel_size=3, stride=1,
                      padding=(k_size + (k_size - 1) * (d_list[0] - 1)) // 2,
                      dilation=d_list[0], groups=group_size + 1)
        )
        self.g1 = nn.Sequential(
            LayerNorm(normalized_shape=group_size + 1, data_format='channels_first'),
            nn.Conv2d(group_size + 1, group_size + 1, kernel_size=3, stride=1,
                      padding=(k_size + (k_size - 1) * (d_list[1] - 1)) // 2,
                      dilation=d_list[1], groups=group_size + 1)
        )
        self.g2 = nn.Sequential(
            LayerNorm(normalized_shape=group_size + 1, data_format='channels_first'),
            nn.Conv2d(group_size + 1, group_size + 1, kernel_size=3, stride=1,
                      padding=(k_size + (k_size - 1) * (d_list[2] - 1)) // 2,
                      dilation=d_list[2], groups=group_size + 1)
        )
        self.g3 = nn.Sequential(
            LayerNorm(normalized_shape=group_size + 1, data_format='channels_first'),
            nn.Conv2d(group_size + 1, group_size + 1, kernel_size=3, stride=1,
                      padding=(k_size + (k_size - 1) * (d_list[3] - 1)) // 2,
                      dilation=d_list[3], groups=group_size + 1)
        )
        self.tail_conv = nn.Sequential(
            LayerNorm(normalized_shape=dim_xl * 2 + 4, data_format='channels_first'),
            nn.Conv2d(dim_xl * 2 + 4, dim_xl, 1)
        )

    def forward(self, xh, xl, mask):
        xh = self.pre_project(xh)
        xh = F.interpolate(xh, size=[xl.size(2), xl.size(3)], mode='bilinear', align_corners=True)
        xh = torch.chunk(xh, 4, dim=1)
        xl = torch.chunk(xl, 4, dim=1)
        x0 = self.g0(torch.cat((xh[0], xl[0], mask), dim=1))
        x1 = self.g1(torch.cat((xh[1], xl[1], mask), dim=1))
        x2 = self.g2(torch.cat((xh[2], xl[2], mask), dim=1))
        x3 = self.g3(torch.cat((xh[3], xl[3], mask), dim=1))
        x = torch.cat((x0, x1, x2, x3), dim=1)
        x = self.tail_conv(x)
        return x
if __name__ == '__main__':
    # 创建模拟输入数据
    xh = torch.randn(1, 64, 32, 32) # 输入 xh 的形状为 [B C H W]
    xl = torch.randn(1, 64, 16, 16) # 输入 xl 的形状为 [B C H/2 W2]
    mask = torch.randn(1, 1, 16, 16) # 蒙版张量的形状为 [B 1 H/2 W/2]
    # 实例化模块
    block = group_aggregation_bridge(dim_xh=64, dim_xl=64)

    # 打印输入的形状
    print("输入 xh 的形状:", xh.size())
    print("输入 xl 的形状:", xl.size())
    print("蒙版张量的形状:", mask.size())

    # 进行前向传播
    output = block(xh, xl, mask)

    # 打印输出的形状
    print("输出的形状:", output.size())

便捷下载方式

浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

更多分析可见原文


ai缝合大王
聚焦AI前沿,分享相关技术、论文,研究生自救指南
 最新文章