多尺度特征融合模块TIF,涨点涨爆了!

文摘   2024-12-09 17:23   上海  

论文介绍

题目:DS-TransUNet: Dual Swin Transformer U-Net for Medical Image Segmentation

论文地址:https://arxiv.org/abs/2106.06716

QQ深度学习交流群:719278780

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • 双尺度编码器:论文提出了一种基于双分支的编码器架构,使用不同尺度的图像块(patch)进行特征提取。这种双尺度方法可以同时捕捉粗粒度和细粒度的特征,从而提升了语义分割的效果。

  • Transformer交互融合模块(TIF):提出了一个新颖的TIF模块,通过Transformer的自注意力机制,有效地融合了来自双尺度编码器的多尺度特征表示。这种融合方式建立了特征间的全局依赖关系,从而保证了多尺度特征的语义一致性。

  • 在解码器中引入Swin Transformer:创新性地在U-Net解码器中使用了Swin Transformer模块,不仅在下采样阶段建模了长程依赖,还在上采样阶段进一步提升了上下文信息的利用效率。

  • 全面的实验验证:通过四个典型的医学图像分割任务(如息肉分割、皮肤病变分割等)的实验,展示了DS-TransUNet在分割质量上优于现有的最先进方法,尤其是在息肉分割任务中表现突出。

方法

整体架构

     DS-TransUNet是一种基于双分支编码器的U型网络结构,融合了Swin Transformer的长程依赖建模能力。它通过双尺度编码器提取粗粒度和细粒度特征,利用Transformer交互融合模块(TIF)实现多尺度特征的全局交互,在解码器中进一步引入Swin Transformer块建模全局上下文,从而实现高效的医学图像分割。这种架构能够捕捉丰富的多尺度信息,并在多个分割任务中表现出色。

1. 双分支编码器(Dual-Branch Encoder)

  • 双尺度特征提取:输入的医学图像被分割为两种不同尺度的非重叠图像块(patch),分别通过两个独立的分支处理:

    • 主分支:处理细粒度图像块(较小尺寸的patch),提取细粒度特征。

    • 辅分支:处理粗粒度图像块(较大尺寸的patch),提取粗粒度特征。

  • 特征提取器:每个分支使用分层的 Swin Transformer 作为编码器,对图像块进行特征表示学习,并通过多个阶段逐步提取高层次特征。


2. Transformer交互融合模块(Transformer Interactive Fusion, TIF)

  • 特征融合:通过标准Transformer块的自注意力机制,融合双分支(粗粒度和细粒度)的特征表示。

  • 全局依赖建模:TIF模块能够捕捉不同尺度特征之间的全局依赖关系,并在特征间实现高效交互。

3. 解码器(Decoder)

  • 上采样与跳跃连接:解码器采用逐层上采样的方式,并利用编码器对应层的特征通过跳跃连接(Skip Connections)来恢复原始分辨率。

  • 引入Swin Transformer块:在每个解码阶段加入Swin Transformer块,以建模长程依赖和全局上下文信息,从而提升解码器的表现。

  • 最终输出:融合后的特征被逐步恢复为与输入图像相同的分辨率,生成像素级的分割结果。

即插即用模块作用

TIF 作为一个即插即用模块

  • 多尺度特征融合:TIF模块利用自注意力机制,在不同尺度的特征之间建立全局交互,提升多尺度特征的融合效果,保证语义一致性。


  • 增强全局上下文信息:通过全局依赖建模,TIF模块能够在特征中注入丰富的上下文信息,提高目标分割的准确性和鲁棒性。


  • 提升分割细节表现:对于边界复杂或细粒度分割任务,TIF模块能有效提升目标边界的分割质量,减少边界模糊现象。


  • 即插即用的灵活性:TIF模块可以作为现有深度学习模型(如U-Net、FPN)的插件模块,无需对整体结构进行大幅修改,即可显著提升模型性能。

消融实验结果

表 VIII 展示了不同模型配置(Base Model、Swin U-Net、Swin Decoder、Multi-Scale SD和DS-TransUNet)在息肉分割任务上的性能对比。实验验证了Swin Transformer作为编码器的有效性、Swin Decoder的长程依赖建模能力,以及TIF模块在多尺度特征融合中的关键作用。最终模型DS-TransUNet在所有数据集上的分割性能均优于其他配置。

图 4 展示了DS-TransUNet在息肉分割任务中(包括Kvasir、CVC-ClinicDB及多个数据集)的定性分割结果。与其他模型相比,DS-TransUNet表现出更强的边界捕捉能力,特别是在处理模糊、颜色与背景相近或边缘复杂的息肉时,其分割结果更接近真实边界。

即插即用模块

import torch
from torch import nn, einsum
from einops import rearrange
#论文:DS-TransUNet: Dual Swin Transformer U-Net for Medical Image Segmentation
#论文地址:https://arxiv.org/abs/2106.06716

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()

        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()


    def forward(self, x):
        b, n, _ = x.shape
        h = self.heads
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = dots.softmax(dim=-1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class CrossAttention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_k = nn.Linear(dim, inner_dim , bias=False)
        self.to_v = nn.Linear(dim, inner_dim , bias = False)
        self.to_q = nn.Linear(dim, inner_dim, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x_qkv):
        b, n, _ = x_qkv.shape
        h = self.heads

        k = self.to_k(x_qkv)
        k = rearrange(k, 'b n (h d) -> b h n d', h = h)

        v = self.to_v(x_qkv)
        v = rearrange(v, 'b n (h d) -> b h n d', h = h)

        q = self.to_q(x_qkv[:, 0].unsqueeze(1))
        q = rearrange(q, 'b n (h d) -> b h n d', h = h)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = dots.softmax(dim=-1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.to_out(out)
        return out

class TIF(nn.Module):
    def __init__(self, dim_s, dim_l):
        super().__init__()
        self.transformer_s = Transformer(dim=dim_s, depth=1, heads=3, dim_head=32, mlp_dim=128)
        self.transformer_l = Transformer(dim=dim_l, depth=1, heads=1, dim_head=64, mlp_dim=256)
        self.norm_s = nn.LayerNorm(dim_s)
        self.norm_l = nn.LayerNorm(dim_l)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.linear_s = nn.Linear(dim_s, dim_l)
        self.linear_l = nn.Linear(dim_l, dim_s)

    def forward(self, e, r):
        b_e, c_e, h_e, w_e = e.shape
        e = e.reshape(b_e, c_e, -1).permute(0, 2, 1)
        b_r, c_r, h_r, w_r = r.shape
        r = r.reshape(b_r, c_r, -1).permute(0, 2, 1)
        e_t = torch.flatten(self.avgpool(self.norm_l(e).transpose(1, 2)), 1)
        r_t = torch.flatten(self.avgpool(self.norm_s(r).transpose(1, 2)), 1)
        e_t = self.linear_l(e_t).unsqueeze(1)
        r_t = self.linear_s(r_t).unsqueeze(1)
        r = self.transformer_s(torch.cat([e_t, r], dim=1))[:, 1:, :]
        e = self.transformer_l(torch.cat([r_t, e], dim=1))[:, 1:, :]
        e = e.permute(0, 2, 1).reshape(b_e, c_e, h_e, w_e)
        r = r.permute(0, 2, 1).reshape(b_r, c_r, h_r, w_r)
        return e + r


if __name__ == '__main__':

    model = TIF(dim_s=64, dim_l=64)
    input1 = torch.randn(1, 64, 64, 64) # 例如来自小尺度特征的图像
    input2 = torch.randn(1, 64, 64, 64) # 例如来自大尺度特征的图像
    # 前向传播获取输出
    output = model(input1, input2)

    # 打印输入和输出的形状
    print(input1.size())
    print(input2.size())    print(output.size())

便捷下载方式

浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

更多分析可见原文


ai缝合大王
聚焦AI前沿,分享相关技术、论文,研究生自救指南
 最新文章