2024 Plug-and-Play Efficient Non-Local Transformer Block: watch your metrics take off!


Paper Introduction

Title: Perspective+ Unet: Enhancing Segmentation with Bi-Path Fusion and Efficient Non-Local Attention for Superior Receptive Fields

Paper link: https://arxiv.org/abs/2406.14052


Innovations

  • Bi-path encoding strategy

    • The encoder introduces a bi-path strategy that combines the outputs of standard convolutions and dilated (atrous) convolutions.

    • Standard convolutions capture high-resolution local detail, while dilated convolutions enlarge the receptive field to capture broader contextual information.

    • This combination balances preserving local detail with understanding global structure.

  • Efficient Non-Local Transformer Block (ENLTB)

    • Proposes an efficient non-local attention mechanism based on kernel-function approximation, which substantially improves the modeling of long-range dependencies while keeping both compute and memory complexity linear.

    • The ENLTB module drastically reduces the high computational cost of conventional non-local attention while retaining its global view.

  • Spatial Cross-Scale Integrator (SCSI)

    • This module integrates global and local features from different stages, ensuring that features at different levels complement one another.

    • Through cross-scale information fusion, it improves the model's ability to segment complex image structures precisely.

Method

Overall Architecture

Perspective+ Unet is a medical image segmentation model built on the classic U-Net framework with an encoder-bottleneck-decoder structure. The encoder introduces Bi-Path Residual Blocks (BPRB) that combine standard and dilated convolutions to fuse local detail with global context; the bottleneck uses Efficient Non-Local Transformer Blocks (ENLTB) to capture long-range dependencies and strengthen global modeling; and the decoder uses the Spatial Cross-Scale Integrator (SCSI) to integrate multi-level features, preserving detail through skip connections and producing the final, precise segmentation. This design balances global awareness with local resolution and noticeably improves medical image segmentation performance. A minimal wiring sketch follows.
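The sketch below only illustrates this encoder-bottleneck-decoder data flow. The BPRB, ENLTB, and SCSI internals are replaced by plain convolutions here; the layer sizes are assumptions for illustration, not the paper's implementation.

import torch
from torch import nn

class PerspectiveUNetSketch(nn.Module):
    def __init__(self, in_ch=1, base=32, num_classes=2):
        super().__init__()
        # Encoder: each stage would be a Bi-Path Residual Block (BPRB) in the paper.
        self.enc1 = nn.Sequential(nn.Conv2d(in_ch, base, 3, padding=1), nn.ReLU(inplace=True))
        self.enc2 = nn.Sequential(nn.Conv2d(base, base * 2, 3, padding=1), nn.ReLU(inplace=True))
        self.pool = nn.MaxPool2d(2)
        # Bottleneck: the paper stacks ENLTB layers here for long-range modeling.
        self.bottleneck = nn.Sequential(nn.Conv2d(base * 2, base * 2, 3, padding=1), nn.ReLU(inplace=True))
        # Decoder: upsampling plus a skip connection; SCSI would fuse multi-level features.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.dec = nn.Sequential(nn.Conv2d(base * 3, base, 3, padding=1), nn.ReLU(inplace=True))
        self.head = nn.Conv2d(base, num_classes, 1)

    def forward(self, x):
        s1 = self.enc1(x)                                   # high-resolution local features (skip)
        s2 = self.enc2(self.pool(s1))                       # downsampled features
        b = self.bottleneck(s2)                             # global context (ENLTB in the paper)
        d = self.dec(torch.cat([self.up(b), s1], dim=1))    # skip connection + fusion
        return self.head(d)

# Quick shape check: a 1x1x64x64 input yields a 1x2x64x64 segmentation map.
print(PerspectiveUNetSketch()(torch.randn(1, 1, 64, 64)).shape)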

1. Encoder

  • Bi-Path Residual Block (BPRB)

    • Standard convolution: captures local detail and preserves high-resolution texture information.

    • Dilated convolution: enlarges the receptive field to capture global context.

    • Each encoder stage uses this dual-path design, passing the input through both of the branches above.

    • Concatenating these features combines local and global information effectively (a sketch of the idea follows this list).

    • The encoder's output is passed on to the bottleneck module.
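A minimal sketch of the dual-path idea in a BPRB-style block. The 3x3 kernels, the dilation rate, and the 1x1 fusion convolution are assumptions for illustration, not the paper's exact configuration.

import torch
from torch import nn

class BiPathBlockSketch(nn.Module):
    def __init__(self, channels, dilation=2):
        super().__init__()
        # Path 1: standard convolution for high-resolution local detail.
        self.local_path = nn.Conv2d(channels, channels, 3, padding=1)
        # Path 2: dilated convolution for an enlarged receptive field.
        self.context_path = nn.Conv2d(channels, channels, 3, padding=dilation, dilation=dilation)
        # Fuse the concatenated paths back to the original channel count.
        self.fuse = nn.Conv2d(channels * 2, channels, 1)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        local = self.local_path(x)                          # local texture
        context = self.context_path(x)                      # wider context
        fused = self.fuse(torch.cat([local, context], dim=1))
        return self.act(fused + x)                          # residual connection

print(BiPathBlockSketch(64)(torch.randn(1, 64, 32, 32)).shape)  # torch.Size([1, 64, 32, 32])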


2. Bottleneck

  • Efficient Non-Local Transformer Block (ENLTB)

    • Uses a kernel-function approximation to reduce the attention complexity from the conventional quadratic O(N^2) to linear O(N).

    • Models global context to extract features that encode long-range dependencies.

    • Built around the efficient non-local attention mechanism (ENLA).

    • Multiple ENLTB layers are stacked in the bottleneck; each layer merges features and fuses multi-scale context, progressively improving the model's understanding of the global image structure (the attention approximation is sketched after this list).
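For reference, the kernelized linear attention that the ENLA module (in the code section below) implements can be written as follows, where \phi is a random-feature map applied to queries and keys:

\mathrm{Attn}(Q, K, V) \approx D^{-1}\,\phi(Q)\bigl(\phi(K)^{\top} V\bigr),
\qquad D = \mathrm{diag}\bigl(\phi(Q)\,\phi(K)^{\top}\mathbf{1}_{N}\bigr)

Because \phi(K)^{\top} V is computed first, the N x N attention matrix is never materialized, which is where the reduction from O(N^2) to O(N) in sequence length N comes from; this corresponds to the linear_attention function in the code below.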


3. Decoder

  • Spatial Cross-Scale Integrator (SCSI)

    • Fuses features from the encoder and the bottleneck, integrating local detail with global information.

    • The cross-scale module aligns and reconstructs features across levels to keep the segmentation precise.

    • Classic skip connections pass the detail features of each encoder stage directly to the decoder, reducing information loss.

    • The decoder produces the final segmentation through stage-by-stage upsampling (a simplified cross-scale fusion sketch follows this list).
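A simplified sketch of cross-scale fusion in the spirit of SCSI: feature maps from different stages are brought to a common resolution and fused. The bilinear resizing and 1x1 fusion convolution are assumptions for illustration; the paper's SCSI is more involved.

import torch
from torch import nn
import torch.nn.functional as F

class CrossScaleFusionSketch(nn.Module):
    def __init__(self, channels_per_stage, out_channels):
        super().__init__()
        self.fuse = nn.Conv2d(sum(channels_per_stage), out_channels, 1)

    def forward(self, features):
        # Resize every stage to the resolution of the first (finest) feature map.
        target = features[0].shape[-2:]
        aligned = [features[0]] + [
            F.interpolate(f, size=target, mode='bilinear', align_corners=False)
            for f in features[1:]
        ]
        # Concatenate and fuse so that the levels complement one another.
        return self.fuse(torch.cat(aligned, dim=1))

f1 = torch.randn(1, 32, 64, 64)   # fine encoder stage
f2 = torch.randn(1, 64, 32, 32)   # deeper encoder stage
f3 = torch.randn(1, 128, 16, 16)  # bottleneck output
print(CrossScaleFusionSketch([32, 64, 128], 32)([f1, f2, f3]).shape)  # torch.Size([1, 32, 64, 64])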

Where the Plug-and-Play Module Helps

ENLTB as a plug-and-play module (a usage sketch follows the list below):

  • Tasks that need to capture global context

    • Especially in 3D medical image segmentation, object detection, and semantic segmentation, where the model must understand the global structure while preserving local detail.

    • ENLTB's strength in modeling long-range dependencies helps with the precise segmentation or recognition of complex structures.

  • Applications sensitive to compute resources

    • ENLTB reduces the complexity of conventional global attention from O(N^2) to linear O(N), making it suitable for resource-constrained scenarios such as image processing on mobile devices or real-time applications.

  • Strengthening the global awareness of existing models

    • It can augment existing CNN models or hybrid architectures (e.g., CNN+Transformer), compensating for CNNs' limited ability to capture long-range dependencies.

  • Multi-scale feature processing tasks

    • In tasks that must attend to both fine-grained and macro-level features (e.g., multi-organ segmentation, super-resolution), ENLTB can further improve results through cross-scale integration.
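A hedged usage sketch: dropping the ENLTB module from the code section below into the bottleneck of an existing CNN. The small backbone, the 16x16 feature resolution, and the classification head are assumptions for illustration only.

import torch
from torch import nn
# Assumes the ENLTB class from the code below has been imported, e.g.
# from enltb_module import ENLTB   (module name is hypothetical)

backbone = nn.Sequential(
    nn.Conv2d(3, 64, 3, stride=2, padding=1), nn.ReLU(inplace=True),
    nn.Conv2d(64, 64, 3, stride=2, padding=1), nn.ReLU(inplace=True),
)
enltb = ENLTB(dim=64, input_resolution=(16, 16))  # plug-and-play non-local context
head = nn.Conv2d(64, 10, 1)

x = torch.randn(1, 3, 64, 64)
feat = backbone(x)        # (1, 64, 16, 16): local CNN features
feat = enltb(feat)        # same shape, enriched with long-range dependencies
print(head(feat).shape)   # torch.Size([1, 10, 16, 16])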

Ablation Results

  • No added modules (baseline)

    • Dice similarity coefficient (DSC): 84.04%

    • Hausdorff distance (HD): 16.63 mm

    • Using only the basic U-Net framework, segmentation performance is comparatively low.

  • BPRB only

    • DSC drops to 83.36%, but HD improves markedly to 14.70 mm.

    • This suggests the Bi-Path Residual Block is effective at enlarging the receptive field and preserving global context, but it needs the other modules to reach its full potential.

  • BPRB + SCSI

    • DSC rises to 83.92%, and HD falls further to 13.94 mm.

    • The Spatial Cross-Scale Integrator strengthens multi-level feature fusion and improves reconstruction accuracy.

  • Full model (BPRB + ENLTB + SCSI)

    • DSC reaches 84.63% and HD drops to 11.74 mm, the best overall performance.

      This shows that the three modules working together significantly improve segmentation accuracy and global awareness.

Plug-and-Play Module Code

import math
import torch
from torch import nn
from functools import partial
from einops import repeat
import torch.nn.functional as F
from timm.models.layers import DropPath
# Paper: Perspective+ Unet: Enhancing Segmentation with Bi-Path Fusion and Efficient Non-Local Attention for Superior Receptive Fields [MICCAI 2024]
# Paper link: https://arxiv.org/abs/2406.14052

def default_conv(in_channels, out_channels, kernel_size, stride=1, bias=True):
    return nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size // 2), stride=stride, bias=bias)

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def orthogonal_matrix_chunk(cols, device=None):
    unstructured_block = torch.randn((cols, cols), device=device)
    # QR is done on CPU for numerical stability; only the orthonormal Q factor is needed
    q, _ = torch.linalg.qr(unstructured_block.cpu(), 'reduced')
    q = q.to(device)
    return q.t()

def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling=0, device=None):
    # Builds the random projection matrix for the kernel feature maps by stacking
    # orthogonal blocks (as in Performer) and rescaling the rows.
    nb_full_blocks = int(nb_rows / nb_columns)

    block_list = []

    for _ in range(nb_full_blocks):
        q = orthogonal_matrix_chunk(nb_columns, device=device)
        block_list.append(q)

    remaining_rows = nb_rows - nb_full_blocks * nb_columns
    if remaining_rows > 0:
        q = orthogonal_matrix_chunk(nb_columns, device=device)
        block_list.append(q[:remaining_rows])

    final_matrix = torch.cat(block_list)

    if scaling == 0:
        multiplier = torch.randn((nb_rows, nb_columns), device=device).norm(dim=1)
    elif scaling == 1:
        multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,), device=device)
    else:
        raise ValueError(f'Invalid scaling {scaling}')

    return torch.diag(multiplier) @ final_matrix


def generalized_kernel(data, *, projection_matrix, kernel_fn=nn.ReLU(), kernel_epsilon=0.001, normalize_data=True):
    b, h, *_ = data.shape

    data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.

    if projection_matrix is None:
        return kernel_fn(data_normalizer * data) + kernel_epsilon

    projection = repeat(projection_matrix, 'j d -> b h j d', b=b, h=h)
    projection = projection.type_as(data)

    data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data), projection)

    data_prime = kernel_fn(data_dash) + kernel_epsilon
    return data_prime.type_as(data)


def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=False, eps=1e-4, device=None):
    # Random-feature approximation of the softmax kernel. Note that is_query,
    # normalize_data, and device are accepted for interface compatibility but are
    # not used by this simplified implementation.
    b, h, *_ = data.shape

    ratio = (projection_matrix.shape[0] ** -0.5)

    projection = repeat(projection_matrix, 'j d -> b h j d', b=b, h=h)
    projection = projection.type_as(data)

    data_dash = torch.einsum('...id,...jd->...ij', data, projection)
    diag_data = data ** 2
    diag_data = torch.sum(diag_data, dim=-1)
    diag_data = (diag_data / 2.0)
    diag_data = diag_data.unsqueeze(dim=-1)

    data_dash = ratio * (torch.exp(data_dash - diag_data) + eps)

    return data_dash.type_as(data)


# non-causal linear attention: out = D^{-1} q (k^T v), never forming the N x N matrix
def linear_attention(q, k, v):
    k_cumsum = k.sum(dim=-2)                                                  # (..., d): column sums of k
    D_inv = 1. / torch.einsum('...nd,...d->...n', q, k_cumsum.type_as(q))     # per-token normalizer
    context = torch.einsum('...nd,...ne->...de', k, v)                        # (..., d, e) = k^T v
    out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv)
    return out

#Efficient Non-Local Attention Mechanism (ENLA)
class ENLA(nn.Module):
    def __init__(self, dim_heads, nb_features=None, ortho_scaling=0, generalized_attention=False, kernel_fn=nn.ReLU(),
                 no_projection=False, attn_drop=0.):

        super().__init__()
        nb_features = default(nb_features, int(dim_heads * math.log(dim_heads)))

        self.dim_heads = dim_heads
        self.nb_features = nb_features
        self.ortho_scaling = ortho_scaling

        self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows=self.nb_features,
                                         nb_columns=dim_heads, scaling=ortho_scaling)
        projection_matrix = self.create_projection()
        self.register_buffer('projection_matrix', projection_matrix)

        self.generalized_attention = generalized_attention
        self.kernel_fn = kernel_fn

        # if this is turned on, no projection will be used
        # queries and keys will be softmax-ed as in the original efficient attention paper
        self.no_projection = no_projection
        self.attn_drop = nn.Dropout(attn_drop)

    @torch.no_grad()
    def redraw_projection_matrix(self, device):
        projections = self.create_projection(device=device)
        self.projection_matrix.copy_(projections)
        del projections

    def forward(self, q, k, v):
        # q, k, v: [b, h, n, d] - b is batch size, h is number of heads, n is sequence length, d is feature dim
        device = q.device

        if self.no_projection:
            q = q.softmax(dim=-1)
            k = k.softmax(dim=-2)

        elif self.generalized_attention:
            # generalized_kernel takes no device argument; the projection is moved
            # onto the data's device/dtype inside the kernel via type_as
            create_kernel = partial(generalized_kernel, kernel_fn=self.kernel_fn,
                                    projection_matrix=self.projection_matrix)
            q, k = map(create_kernel, (q, k))

        else:
            create_kernel = partial(softmax_kernel, projection_matrix=self.projection_matrix, device=device)
            q = create_kernel(q, is_query=True)
            k = create_kernel(k, is_query=False)

        attn_fn = linear_attention
        out = attn_fn(q, k, v)
        out = self.attn_drop(out)
        return out

class BasicBlock(nn.Sequential):
    def __init__(self, conv, in_channels, out_channels, kernel_size, bias=True, bn=False, act=None):
        m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
        if bn:
            m.append(nn.BatchNorm2d(out_channels))
        if act is not None:
            m.append(act)
        super(BasicBlock, self).__init__(*m)

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

#efficient non-local transformer block (ENLTB)
class ENLTB(nn.Module):

    def __init__(self, dim, input_resolution, num_heads=6, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop=0.1, attn_drop=0.1, drop_path=0.1, act_layer=nn.GELU, norm_layer=nn.LayerNorm, kernel_size=1):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        # self.mlp_ratio = mlp_ratio
        self.qk_scale = qk_scale
        self.conv_match1 = BasicBlock(default_conv, dim, dim, kernel_size, bias=qkv_bias, bn=False, act=None)
        self.conv_match2 = BasicBlock(default_conv, dim, dim, kernel_size, bias=qkv_bias, bn=False, act=None)
        self.conv_assembly = BasicBlock(default_conv, dim, dim, kernel_size, bias=qkv_bias, bn=False, act=None)

        self.norm1 = norm_layer(dim)
        self.attn = ENLA(dim_heads=dim, nb_features=dim, attn_drop=attn_drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)

        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        H, W = self.input_resolution
        assert H == x.shape[-2] and W == x.shape[-1], "input feature has wrong size"
        B, C, H, W = x.shape
        x = x.view(B, C, H * W).permute(0, 2, 1).contiguous()
        shortcut = x # skip connection

        # Layer Norm
        x = self.norm1(x)
        x = x.view(B, H, W, C)
        x = x.permute(0, 3, 1, 2).contiguous()

        # ENLA
        x_embed_1 = self.conv_match1(x)
        x_embed_2 = self.conv_match2(x)
        x_assembly = self.conv_assembly(x) # [B,C,H,W]
        if self.qk_scale is not None:
            x_embed_1 = F.normalize(x_embed_1, p=2, dim=1, eps=5e-5) * self.qk_scale
            x_embed_2 = F.normalize(x_embed_2, p=2, dim=1, eps=5e-5) * self.qk_scale
        else:
            x_embed_1 = F.normalize(x_embed_1, p=2, dim=1, eps=5e-5)
            x_embed_2 = F.normalize(x_embed_2, p=2, dim=1, eps=5e-5)
        B, C, H, W = x_embed_1.shape
        x_embed_1 = x_embed_1.permute(0, 2, 3, 1).view(B, 1, H * W, C)
        x_embed_2 = x_embed_2.permute(0, 2, 3, 1).view(B, 1, H * W, C)
        x_assembly = x_assembly.permute(0, 2, 3, 1).view(B, 1, H * W, -1)

        x = self.attn(x_embed_1, x_embed_2, x_assembly).squeeze(1) # (B, H*W, C)
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        x = x.view(B, H, W, C).permute(0, 3, 1, 2).contiguous()
        return x


if __name__ == '__main__':

    input = torch.randn(1, 64, 32, 32)
    input_resolution = (32, 32)
    block = ENLTB(dim=64, input_resolution=input_resolution)

    print(input.size())
    output = block(input)
    print(output.size())

Download

Open in your browser: https://github.com/ai-dawang/PlugNPlay-Modules

More analysis is available in the original paper.

