论文介绍
题目:DS-TransUNet: Dual Swin Transformer U-Net for Medical Image Segmentation
论文地址:https://arxiv.org/abs/2106.06716
QQ深度学习交流群:719278780
扫描下方二维码,加入深度学习论文指南星球!
加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务
创新点
双尺度编码器:论文提出了一种基于双分支的编码器架构,使用不同尺度的图像块(patch)进行特征提取。这种双尺度方法可以同时捕捉粗粒度和细粒度的特征,从而提升了语义分割的效果。
Transformer交互融合模块(TIF):提出了一个新颖的TIF模块,通过Transformer的自注意力机制,有效地融合了来自双尺度编码器的多尺度特征表示。这种融合方式建立了特征间的全局依赖关系,从而保证了多尺度特征的语义一致性。
在解码器中引入Swin Transformer:创新性地在U-Net解码器中使用了Swin Transformer模块,不仅在下采样阶段建模了长程依赖,还在上采样阶段进一步提升了上下文信息的利用效率。
全面的实验验证:通过四个典型的医学图像分割任务(如息肉分割、皮肤病变分割等)的实验,展示了DS-TransUNet在分割质量上优于现有的最先进方法,尤其是在息肉分割任务中表现突出。
方法
整体架构
DS-TransUNet是一种基于双分支编码器的U型网络结构,融合了Swin Transformer的长程依赖建模能力。它通过双尺度编码器提取粗粒度和细粒度特征,利用Transformer交互融合模块(TIF)实现多尺度特征的全局交互,在解码器中进一步引入Swin Transformer块建模全局上下文,从而实现高效的医学图像分割。这种架构能够捕捉丰富的多尺度信息,并在多个分割任务中表现出色。
1. 双分支编码器(Dual-Branch Encoder)
双尺度特征提取:输入的医学图像被分割为两种不同尺度的非重叠图像块(patch),分别通过两个独立的分支处理:
主分支:处理细粒度图像块(较小尺寸的patch),提取细粒度特征。
辅分支:处理粗粒度图像块(较大尺寸的patch),提取粗粒度特征。
特征提取器:每个分支使用分层的 Swin Transformer 作为编码器,对图像块进行特征表示学习,并通过多个阶段逐步提取高层次特征。
2. Transformer交互融合模块(Transformer Interactive Fusion, TIF)
特征融合:通过标准Transformer块的自注意力机制,融合双分支(粗粒度和细粒度)的特征表示。
全局依赖建模:TIF模块能够捕捉不同尺度特征之间的全局依赖关系,并在特征间实现高效交互。
3. 解码器(Decoder)
上采样与跳跃连接:解码器采用逐层上采样的方式,并利用编码器对应层的特征通过跳跃连接(Skip Connections)来恢复原始分辨率。
引入Swin Transformer块:在每个解码阶段加入Swin Transformer块,以建模长程依赖和全局上下文信息,从而提升解码器的表现。
最终输出:融合后的特征被逐步恢复为与输入图像相同的分辨率,生成像素级的分割结果。
即插即用模块作用
TIF 作为一个即插即用模块:
多尺度特征融合:TIF模块利用自注意力机制,在不同尺度的特征之间建立全局交互,提升多尺度特征的融合效果,保证语义一致性。
增强全局上下文信息:通过全局依赖建模,TIF模块能够在特征中注入丰富的上下文信息,提高目标分割的准确性和鲁棒性。
提升分割细节表现:对于边界复杂或细粒度分割任务,TIF模块能有效提升目标边界的分割质量,减少边界模糊现象。
即插即用的灵活性:TIF模块可以作为现有深度学习模型(如U-Net、FPN)的插件模块,无需对整体结构进行大幅修改,即可显著提升模型性能。
消融实验结果
表 VIII 展示了不同模型配置(Base Model、Swin U-Net、Swin Decoder、Multi-Scale SD和DS-TransUNet)在息肉分割任务上的性能对比。实验验证了Swin Transformer作为编码器的有效性、Swin Decoder的长程依赖建模能力,以及TIF模块在多尺度特征融合中的关键作用。最终模型DS-TransUNet在所有数据集上的分割性能均优于其他配置。
图 4 展示了DS-TransUNet在息肉分割任务中(包括Kvasir、CVC-ClinicDB及多个数据集)的定性分割结果。与其他模型相比,DS-TransUNet表现出更强的边界捕捉能力,特别是在处理模糊、颜色与背景相近或边缘复杂的息肉时,其分割结果更接近真实边界。
即插即用模块
import torch
from torch import nn, einsum
from einops import rearrange
#论文:DS-TransUNet: Dual Swin Transformer U-Net for Medical Image Segmentation
#论文地址:https://arxiv.org/abs/2106.06716
class PreNorm(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.norm = nn.LayerNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs)
class FeedForward(nn.Module):
def __init__(self, dim, hidden_dim, dropout = 0.):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(hidden_dim, dim),
nn.Dropout(dropout)
)
def forward(self, x):
return self.net(x)
class Attention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x):
b, n, _ = x.shape
h = self.heads
qkv = self.to_qkv(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = dots.softmax(dim=-1)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
out = self.to_out(out)
return out
class Transformer(nn.Module):
def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
super().__init__()
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
]))
def forward(self, x):
for attn, ff in self.layers:
x = attn(x) + x
x = ff(x) + x
return x
class CrossAttention(nn.Module):
def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
super().__init__()
inner_dim = dim_head * heads
project_out = not (heads == 1 and dim_head == dim)
self.heads = heads
self.scale = dim_head ** -0.5
self.to_k = nn.Linear(dim, inner_dim , bias=False)
self.to_v = nn.Linear(dim, inner_dim , bias = False)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, dim),
nn.Dropout(dropout)
) if project_out else nn.Identity()
def forward(self, x_qkv):
b, n, _ = x_qkv.shape
h = self.heads
k = self.to_k(x_qkv)
k = rearrange(k, 'b n (h d) -> b h n d', h = h)
v = self.to_v(x_qkv)
v = rearrange(v, 'b n (h d) -> b h n d', h = h)
q = self.to_q(x_qkv[:, 0].unsqueeze(1))
q = rearrange(q, 'b n (h d) -> b h n d', h = h)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
attn = dots.softmax(dim=-1)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
out = self.to_out(out)
return out
class TIF(nn.Module):
def __init__(self, dim_s, dim_l):
super().__init__()
self.transformer_s = Transformer(dim=dim_s, depth=1, heads=3, dim_head=32, mlp_dim=128)
self.transformer_l = Transformer(dim=dim_l, depth=1, heads=1, dim_head=64, mlp_dim=256)
self.norm_s = nn.LayerNorm(dim_s)
self.norm_l = nn.LayerNorm(dim_l)
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.linear_s = nn.Linear(dim_s, dim_l)
self.linear_l = nn.Linear(dim_l, dim_s)
def forward(self, e, r):
b_e, c_e, h_e, w_e = e.shape
e = e.reshape(b_e, c_e, -1).permute(0, 2, 1)
b_r, c_r, h_r, w_r = r.shape
r = r.reshape(b_r, c_r, -1).permute(0, 2, 1)
e_t = torch.flatten(self.avgpool(self.norm_l(e).transpose(1, 2)), 1)
r_t = torch.flatten(self.avgpool(self.norm_s(r).transpose(1, 2)), 1)
e_t = self.linear_l(e_t).unsqueeze(1)
r_t = self.linear_s(r_t).unsqueeze(1)
r = self.transformer_s(torch.cat([e_t, r], dim=1))[:, 1:, :]
e = self.transformer_l(torch.cat([r_t, e], dim=1))[:, 1:, :]
e = e.permute(0, 2, 1).reshape(b_e, c_e, h_e, w_e)
r = r.permute(0, 2, 1).reshape(b_r, c_r, h_r, w_r)
return e + r
if __name__ == '__main__':
model = TIF(dim_s=64, dim_l=64)
input1 = torch.randn(1, 64, 64, 64) # 例如来自小尺度特征的图像
input2 = torch.randn(1, 64, 64, 64) # 例如来自大尺度特征的图像
# 前向传播获取输出
output = model(input1, input2)
# 打印输入和输出的形状
print(input1.size())
print(input2.size()) print(output.size())
便捷下载方式
浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules
更多分析可见原文