论文介绍
题目:SUnet: A multi-organ segmentation network based on multiple attention
论文地址:https://www.sciencedirect.com/science/article/abs/pii/S0010482523010612
QQ深度学习交流群:719278780
扫描下方二维码,加入深度学习论文指南星球!
加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务
创新点
SUnet模型设计:提出了一种基于纯Transformer的U型医疗图像分割网络SUnet,采用了高效的空间压缩注意力机制(ESRA)和多重注意力特征融合模块,实现了更高的分割精度,且减少了模型参数,有助于减轻过拟合问题。
高效空间压缩注意力(ESRA):ESRA模块通过压缩多头自注意力的关键和值来减少模型参数,从而提升了特征提取能力并降低了计算复杂度。
增强的注意力门(EAG)模块:EAG模块结合了分组卷积和残差连接,用于改进语义特征的传递,从而实现了更加丰富的特征表达。
多重注意力特征融合(EFF)模块:基于多重注意力的特征融合模块实现了跨尺度的特征整合,使得低级语义特征与解码器的高级语义特征更好地结合,提升了分割效果。
方法
整体结构
编码器(Encoder):编码器部分采用了改进的Transformer结构,即高效空间压缩注意力(ESRA)模块,用于特征提取。该模块通过压缩注意力计算的关键和值,显著降低了模型参数量和计算复杂度,增强了模型对图像全局特征的捕捉能力。
跳跃连接(Skip Connections):SUnet延续了U-Net的跳跃连接设计,将编码器中提取的低级语义特征直接传递到解码器对应层。这一结构使得模型在解码过程中能够保留更丰富的空间细节信息,有助于精确分割。
解码器(Decoder):解码器部分使用多重注意力特征融合(EFF)模块,将从编码器传递的低级特征与解码器逐层上采样得到的高级特征融合。EFF模块由多个子模块组成,包括增强注意力门(EAG)模块、通道注意力(ECA)和空间注意力(SA)。这些模块能够选择性地突出与任务相关的特征区域,提升分割效果。
增强注意力门(EAG)模块:在EFF模块中,EAG模块基于分组卷积和残差连接设计,专注于融合低级和高级特征,通过减小不相关区域的影响,提升了分割精度。
输出层:最终,解码器输出的是分割结果图,模型通过Dice和交叉熵损失函数的加权组合来优化分割精度。
即插即用模块作用
EFF 作为一个即插即用模块,主要适用于:
复杂的多器官医学图像分割:在多器官分割任务中,不同器官的形状、位置和大小差异显著。EFF模块通过多重注意力机制(包含EAG、ECA和SA),能够有效融合不同尺度的特征,增强模型对多种解剖结构的分割能力。
需要细致特征表达的任务:EFF模块能够在编码器与解码器之间有效融合低级和高级特征,确保细节信息不丢失,适合对细节要求较高的医学影像处理任务,例如心脏、血管或肿瘤的分割。
计算资源有限的环境:EFF模块采用了高效的注意力机制,如分组卷积和残差连接,能够在不显著增加模型参数的前提下提升分割性能,因此适用于计算资源受限但要求高精度分割的场景。
消融实验结果
ESRA模块的贡献:仅使用高效空间压缩注意力(ESRA)模块的SUnet-0模型,其Dice系数已超过大多数传统2D医学图像分割模型的性能,证明了ESRA模块在特征提取中的有效性。
EAG模块的提升作用:在SUnet-0的基础上增加增强注意力门(EAG)模块,形成SUnet-1模型,进一步提高了Dice系数,表明EAG模块在融合低级和高级特征时有助于提高分割精度。
ECA与SA模块的协同效果:单独添加通道注意力(ECA)或空间注意力(SA)模块时,模型性能有所下降,但将两者组合后则提升了性能。这表明同时关注通道和空间信息能避免局部最优,提升分割效果。
完整模型的最优表现:包含ESRA、EAG、ECA和SA的完整SUnet模型在Dice系数上达到最高,证明了各模块组合在提升分割精度和减少不相关区域干扰方面的协同作用。
即插即用模块
import torch
import torch.nn as nn
import math
#论文:SUnet: A multi-organ segmentation network based on multiple attention
#论文地址:https://www.sciencedirect.com/science/article/abs/pii/S0010482523010612
class SpatialAttention(nn.Module):
def __init__(self, kernel_size=7):
super(SpatialAttention, self).__init__()
assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
padding = 3 if kernel_size == 7 else 1
self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = torch.mean(x, dim=1, keepdim=True)
max_out, _ = torch.max(x, dim=1, keepdim=True)
x = torch.cat([avg_out, max_out], dim=1)
x = self.conv1(x)
return self.sigmoid(x)
class Efficient_Attention_Gate(nn.Module):
def __init__(self, F_g, F_l, F_int, num_groups=32):
super(Efficient_Attention_Gate, self).__init__()
self.num_groups = num_groups
self.grouped_conv_g = nn.Sequential(
nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True, groups=num_groups),
nn.BatchNorm2d(F_int),
nn.ReLU(inplace=True)
)
self.grouped_conv_x = nn.Sequential(
nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True, groups=num_groups),
nn.BatchNorm2d(F_int),
nn.ReLU(inplace=True)
)
self.psi = nn.Sequential(
nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
nn.BatchNorm2d(1),
nn.Sigmoid()
)
self.relu = nn.ReLU(inplace=True)
def forward(self, g, x):
g1 = self.grouped_conv_g(g)
x1 = self.grouped_conv_x(x)
psi = self.psi(self.relu(x1 + g1))
out = x * psi
out += x
return out
class EfficientChannelAttention(nn.Module):
def __init__(self, channels, gamma=2, b=1):
super(EfficientChannelAttention, self).__init__()
# 设计自适应卷积核,便于后续做1*1卷积
kernel_size = int(abs((math.log(channels, 2) + b) / gamma))
kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1
# 全局平局池化
self.avg_pool = nn.AdaptiveAvgPool2d(1)
# 基于1*1卷积学习通道之间的信息
self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
# 激活函数
self.sigmoid = nn.Sigmoid()
def forward(self, x):
# 首先,空间维度做全局平局池化,[b,c,h,w]==>[b,c,1,1]
v = self.avg_pool(x)
# 然后,基于1*1卷积学习通道之间的信息;其中,使用前面设计的自适应卷积核
v = self.conv(v.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
# 最终,经过sigmoid 激活函数处理
v = self.sigmoid(v)
return v
class EFF(nn.Module):
def __init__(self, in_dim, is_bottom=False):
super().__init__()
self.is_bottom = is_bottom
if not is_bottom:
self.EAG = Efficient_Attention_Gate(in_dim, in_dim, in_dim)
else:
self.EAG = nn.Identity()
self.ECA = EfficientChannelAttention(in_dim*2)
self.SA = SpatialAttention()
def forward(self, x, skip):
if not self.is_bottom:
EAG_skip = self.EAG(x, skip)
x = torch.cat((EAG_skip, x), dim=1)
# x = EAG_skip + x
else:
x = self.EAG(x)
x = self.ECA(x) * x
x = self.SA(x) * x
return x
if __name__ == '__main__':
block = EFF(in_dim=512, is_bottom=False)
x1 = torch.randn(1, 512, 71, 71)
x2 = torch.randn(1, 512, 71, 71)
# 将张量通过 EFF 模块
output = block(x1, x2)
print(x1.size())
print(x2.size()) print(output.size())
便捷下载方式
浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules
更多分析可见原文