论文介绍
题目:EGE-UNet: an Efficient Group Enhanced UNet for skin lesion segmentation
论文地址:https://arxiv.org/pdf/2307.08473
QQ深度学习交流群:719278780
扫描下方二维码,加入深度学习论文指南星球!
加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务
创新点
轻量化设计:提出了一个名为EGE-UNet(Efficient Group Enhanced UNet)的轻量级模型,适用于移动医疗环境,解决了现有模型参数量和计算负担过大的问题。EGE-UNet的参数量仅约50KB,是首个达到此级别的模型,同时在性能上超越了许多现有的重量级模型。
提出了两个关键模块:
GHPA模块(Group multi-axis Hadamard Product Attention):通过将输入特征分组,在不同维度上进行Hadamard积注意力操作,实现从多个视角提取信息。这种方法基于线性复杂度的Hadamard积注意力机制,避免了传统自注意力机制的二次复杂度问题,同时大幅提升了计算效率。
GAB模块(Group Aggregation Bridge):用于多尺度特征融合,通过将低层特征、高层特征以及解码器生成的掩膜信息相结合,利用分组和空洞卷积提取不同尺度的信息,从而有效融合多尺度特征。
模型的设计:
采用了基于对称编码器-解码器的U形架构,并结合了GHPA和GAB模块。
在跳跃连接中引入GAB模块,增强了不同阶段之间的特征融合能力。
性能优越性:
在皮肤病变分割任务中(ISIC2017和ISIC2018数据集),EGE-UNet显著超越了当前的轻量级和重量级模型。相较TransFuse模型,其参数量和计算量分别减少了494倍和160倍,同时分割性能(mIoU和DSC)更高。
资源效率:
是首个同时兼顾轻量化(低参数和低计算量)和高性能的模型,在移动医疗等实际应用环境中具有重要价值。
方法
整体结构
1. 基本架构
EGE-UNet 基于经典的 U 形架构,由对称的编码器(Encoder)和解码器(Decoder)组成。
编码器逐步提取特征,解码器逐步上采样特征以恢复分割结果,同时通过跳跃连接传递信息。
2. 编码器
编码器包含六个阶段,前三阶段使用传统卷积提取低级特征。
后三阶段引入 GHPA 模块,通过多轴分组和 Hadamard 积注意力机制,从多个视角高效提取高级特征。
3. 跳跃连接
在传统 U-Net 的跳跃连接基础上,EGE-UNet 引入 GAB 模块。
融合编码器低级特征、高级特征和解码器生成的掩膜信息。
使用空洞卷积和多尺度特征融合策略,有效处理目标大小和形状的多样性。
4. 解码器
解码器逐步上采样经过 GAB 模块融合的特征。
在每个阶段生成不同尺度的掩膜信息,结合深度监督机制优化分割性能。
5. 模型特点
EGE-UNet 通过 GHPA 和 GAB 模块实现了轻量化设计,仅需约 50KB 参数。
在降低参数和计算复杂度的同时,显著提升分割性能,适用于资源受限的移动医疗场景。
即插即用模块作用
GAB 作为一个即插即用模块,主要适用于:
医学图像分割任务:
特别是在皮肤病变分割任务(如ISIC2017和ISIC2018数据集)中,GAB模块通过融合多尺度特征,有效处理病变区域的多样性(大小、形状和边界不一致)。
可以扩展到其他需要高精度分割的医学图像任务(如脑肿瘤分割、器官分割等)。
多尺度信息融合需求:
在需要处理目标形状、大小变化较大的场景中,GAB模块可以通过分组和空洞卷积提取多尺度信息,从而提升分割性能。
资源受限的移动医疗应用:
GAB模块的设计考虑了轻量化需求,能够在低计算资源环境下实现高效的多尺度特征融合。
消融实验结果
论文证明了GHPA模块和GAB模块的设计是提升EGE-UNet性能的核心,同时验证了各模块中的关键设计(如多轴分组、掩膜信息和空洞卷积)的有效性。这些实验表明,EGE-UNet能够在显著降低参数和计算量的情况下,依靠精细设计的模块实现卓越的分割性能。
即插即用模块
import torch
from torch import nn
import torch.nn.functional as F
#可以缝合在跳跃连接部分
class LayerNorm(nn.Module):
""" From ConvNeXt (https://arxiv.org/pdf/2201.03545.pdf)"""
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
super().__init__()
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.eps = eps
self.data_format = data_format
if self.data_format not in ["channels_last", "channels_first"]:
raise NotImplementedError
self.normalized_shape = (normalized_shape,)
def forward(self, x):
if self.data_format == "channels_last":
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
elif self.data_format == "channels_first":
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
class group_aggregation_bridge(nn.Module):
def __init__(self, dim_xh, dim_xl, k_size=3, d_list=[1, 2, 5, 7]):
super().__init__()
self.pre_project = nn.Conv2d(dim_xh, dim_xl, 1)
group_size = dim_xl // 2
self.g0 = nn.Sequential(
LayerNorm(normalized_shape=group_size + 1, data_format='channels_first'),
nn.Conv2d(group_size + 1, group_size + 1, kernel_size=3, stride=1,
padding=(k_size + (k_size - 1) * (d_list[0] - 1)) // 2,
dilation=d_list[0], groups=group_size + 1)
)
self.g1 = nn.Sequential(
LayerNorm(normalized_shape=group_size + 1, data_format='channels_first'),
nn.Conv2d(group_size + 1, group_size + 1, kernel_size=3, stride=1,
padding=(k_size + (k_size - 1) * (d_list[1] - 1)) // 2,
dilation=d_list[1], groups=group_size + 1)
)
self.g2 = nn.Sequential(
LayerNorm(normalized_shape=group_size + 1, data_format='channels_first'),
nn.Conv2d(group_size + 1, group_size + 1, kernel_size=3, stride=1,
padding=(k_size + (k_size - 1) * (d_list[2] - 1)) // 2,
dilation=d_list[2], groups=group_size + 1)
)
self.g3 = nn.Sequential(
LayerNorm(normalized_shape=group_size + 1, data_format='channels_first'),
nn.Conv2d(group_size + 1, group_size + 1, kernel_size=3, stride=1,
padding=(k_size + (k_size - 1) * (d_list[3] - 1)) // 2,
dilation=d_list[3], groups=group_size + 1)
)
self.tail_conv = nn.Sequential(
LayerNorm(normalized_shape=dim_xl * 2 + 4, data_format='channels_first'),
nn.Conv2d(dim_xl * 2 + 4, dim_xl, 1)
)
def forward(self, xh, xl, mask):
xh = self.pre_project(xh)
xh = F.interpolate(xh, size=[xl.size(2), xl.size(3)], mode='bilinear', align_corners=True)
xh = torch.chunk(xh, 4, dim=1)
xl = torch.chunk(xl, 4, dim=1)
x0 = self.g0(torch.cat((xh[0], xl[0], mask), dim=1))
x1 = self.g1(torch.cat((xh[1], xl[1], mask), dim=1))
x2 = self.g2(torch.cat((xh[2], xl[2], mask), dim=1))
x3 = self.g3(torch.cat((xh[3], xl[3], mask), dim=1))
x = torch.cat((x0, x1, x2, x3), dim=1)
x = self.tail_conv(x)
return x
if __name__ == '__main__':
# 创建模拟输入数据
xh = torch.randn(1, 64, 32, 32) # 输入 xh 的形状为 [B C H W]
xl = torch.randn(1, 64, 16, 16) # 输入 xl 的形状为 [B C H/2 W2]
mask = torch.randn(1, 1, 16, 16) # 蒙版张量的形状为 [B 1 H/2 W/2]
# 实例化模块
block = group_aggregation_bridge(dim_xh=64, dim_xl=64)
# 打印输入的形状
print("输入 xh 的形状:", xh.size())
print("输入 xl 的形状:", xl.size())
print("蒙版张量的形状:", mask.size())
# 进行前向传播
output = block(xh, xl, mask)
# 打印输出的形状
print("输出的形状:", output.size())
便捷下载方式
浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules
更多分析可见原文