2024即插即用分层特征融合模块HFF,涨点起飞起飞了

文摘   2024-11-28 17:20   北京  

论文介绍

题目:HiFuse: Hierarchical multi-scale feature fusion network for medical image classification

论文地址:https://www.sciencedirect.com/science/article/abs/pii/S1746809423009679

QQ深度学习交流群:719278780

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • 提出HiFuse模型:开发了一种全新的三分支分层多尺度特征融合网络结构,用于医学图像分类。HiFuse模型能够有效融合全局和局部特征,避免破坏各自的建模,提升分类准确性。

  • 分层结构与多尺度特征融合

  • HiFuse设计了全局特征分支和局部特征分支,以平行方式分别提取全局语义信息和局部空间特征。通过引入分层特征融合模块(HFF模块),在不增加额外噪声的情况下实现了全局和局部特征的有效融合。

  • 创新模块设计

  • HFF模块结合了空间注意力机制、通道注意力机制、反向残差多层感知器(IRMLP)以及快捷连接,在不同层次实现了特征的自适应融合。使用窗口多头自注意力机制(W-MSA)优化全局特征提取,同时利用深度可分离卷积减少计算量。

  • 模型性能突出

  • 在多个医学图像数据集(如ISIC2018、Kvasir、Covid-19-CT和食管癌病理图像数据集)上的实验表明,该模型在分类精度和F1分数等指标上优于当前先进模型。HiFuse模型在复杂特征和数据噪声较高的医学图像任务中表现尤为出色。

  • 模块化设计与可扩展性

  • HiFuse的模块化设计(如全局和局部分支、分层融合模块)具备良好的扩展性,适用于各种医学图像分析任务。

方法

整体结构

       HiFuse模型是一个三分支分层多尺度特征融合网络,由全局分支、局部分支和分层特征融合模块(HFF模块)组成。全局分支通过窗口多头自注意力机制提取全局语义信息,局部分支通过深度可分离卷积捕获局部空间特征,而HFF模块通过通道注意力、空间注意力和反向残差多层感知器(IRMLP)实现全局与局部特征的自适应融合。该模型采用分层结构,在多个阶段融合不同层次的特征,最终通过全局平均池化和线性分类器完成分类,适用于医学图像分类任务。

模型的主要组成部分

(1) 局部特征分支

  • 通过 深度可分离卷积(Depthwise Convolution) 提取局部空间特征。

  • 使用 线性层 在通道间交互信息。

  • 特征提取完成后,将特征送入HFF模块进行融合。

(2) 全局特征分支

  • 使用 窗口多头自注意力机制(W-MSA) 提取全局语义信息。

    • W-MSA相比传统自注意力机制(MSA),通过将特征划分为窗口进行自注意力计算,显著减少了计算复杂度。

    • 结合层归一化(LayerNorm)和激活函数(GELU)增强特征表达能力。

  • 采用残差连接(Residual Connection)以及Shift-W-MSA增强全局建模能力。

(3) 分层特征融合模块(HFF Block)

  • 作用:将全局和局部特征进行自适应融合。

  • 组成

  1. 通道注意力机制(Channel Attention):选择性增强对重要通道特征的关注。

  2. 空间注意力机制(Spatial Attention):增强对关键空间区域的关注。

  3. 反向残差多层感知器(IRMLP):通过深度卷积和非线性变换对融合后的特征进行学习,增强表示能力。

  4. 快捷连接(Shortcut):促进梯度传递,缓解过拟合。

即插即用模块作用

HFF 作为一个即插即用模块

(1) 自适应融合多尺度特征

  • 将来自全局分支(全局语义信息)和局部分支(细粒度局部特征)的特征自适应融合,避免特征丢失。

  • 通过不同层次的融合,强化模型对多尺度语义的理解,适配复杂医学图像中的多样特征分布。

(2) 增强模型的特征表达能力

  • 通道注意力(Channel Attention)选择性增强重要通道特征,关注特定语义。

  • 空间注意力(Spatial Attention)聚焦关键空间区域,提升对病灶或异常区域的敏感性。

  • IRMLP模块:进一步学习融合特征的复杂非线性关系,提升特征表达能力。

(3) 降低计算复杂度,提升效率

  • 利用深度卷积和轻量级设计(如窗口自注意力机制),减少特征融合的计算开销。

  • 与传统方法相比,在计算成本可控的情况下实现更优性能。

(4) 抑制噪声,提升鲁棒性

  • 通过多层次特征融合和注意力机制,抑制无关信息干扰,提升模型对细粒度特征的捕捉能力。

消融实验结果

  • 添加全局路径后,ACC和F1值分别提升了 2.47%10.2%,说明全局特征的引入显著提升了模型的语义表示能力。

  • 加入HFF模块的组件(注意力机制和IRMLP)后,ACC和F1值进一步提升了 7.4%8.67%,表明特征融合模块在增强全局与局部特征互补性上的重要性。

  • 最终,完整的HiFuse-Tiny模型达到了 82.99%的ACC72.99%的F1值,验证了组件设计的合理性。

  • 随着融合阶段数量的增加,模型的性能逐步提升。

  • 当所有四个阶段都参与特征融合时,模型在ISIC2018数据集上的ACC达到了 85.85%,F1值为 74.57%,显著优于仅融合单阶段或少数阶段的结果。

  • 这表明全层次的特征融合对于全面提取全局和局部信息至关重要。

即插即用模块

import torch
import torch.nn as nn
import torch.nn.functional as F
#论文:HiFuse: Hierarchical multi-scale feature fusion network for medical image classification
#论文地址:https://www.sciencedirect.com/science/article/abs/pii/S1746809423009679

class LayerNorm(nn.Module):
    """
    channels_last corresponds to inputs with shape (batch_size, height, width, channels)
    channels_first corresponds to inputs with shape (batch_size, channels, height, width)
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape), requires_grad=True)
        self.bias = nn.Parameter(torch.zeros(normalized_shape), requires_grad=True)
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise ValueError(f"not support data format '{self.data_format}'")
        self.normalized_shape = (normalized_shape,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            # [batch_size, channels, height, width]
            mean = x.mean(1, keepdim=True)
            var = (x - mean).pow(2).mean(1, keepdim=True)
            x = (x - mean) / torch.sqrt(var + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x

def drop_path_f(x, drop_prob: float = 0., training: bool = False):
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_() # binarize
    output = x.div(keep_prob) * random_tensor
    return output

class DropPath(nn.Module):
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path_f(x, self.drop_prob, self.training)

class Conv(nn.Module):
    def __init__(self, inp_dim, out_dim, kernel_size=3, stride=1, bn=False, relu=True, bias=True, group=1):
        super(Conv, self).__init__()
        self.inp_dim = inp_dim
        self.conv = nn.Conv2d(inp_dim, out_dim, kernel_size, stride, padding=(kernel_size-1)//2, bias=bias)
        self.relu = None
        self.bn = None
        if relu:
            self.relu = nn.ReLU(inplace=True)
        if bn:
            self.bn = nn.BatchNorm2d(out_dim)

    def forward(self, x):
        assert x.size()[1] == self.inp_dim, "{} {}".format(x.size()[1], self.inp_dim)
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x

#### Inverted Residual MLP
class IRMLP(nn.Module):
    def __init__(self, inp_dim, out_dim):
        super(IRMLP, self).__init__()
        self.conv1 = Conv(inp_dim, inp_dim, 3, relu=False, bias=False, group=inp_dim)
        self.conv2 = Conv(inp_dim, inp_dim * 4, 1, relu=False, bias=False)
        self.conv3 = Conv(inp_dim * 4, out_dim, 1, relu=False, bias=False, bn=True)
        self.gelu = nn.GELU()
        self.bn1 = nn.BatchNorm2d(inp_dim)

    def forward(self, x):

        residual = x
        out = self.conv1(x)
        out = self.gelu(out)
        out += residual

        out = self.bn1(out)
        out = self.conv2(out)
        out = self.gelu(out)
        out = self.conv3(out)

        return out

# Hierachical Feature Fusion Block
class HFF_block(nn.Module):
    def __init__(self, ch_1, ch_2, r_2, ch_int, ch_out, drop_rate=0.):
        super(HFF_block, self).__init__()
        self.maxpool=nn.AdaptiveMaxPool2d(1)
        self.avgpool=nn.AdaptiveAvgPool2d(1)
        self.se=nn.Sequential(
            nn.Conv2d(ch_2, ch_2 // r_2, 1,bias=False),
            nn.ReLU(),
            nn.Conv2d(ch_2 // r_2, ch_2, 1,bias=False)
        )
        self.sigmoid = nn.Sigmoid()
        self.spatial = Conv(2, 1, 7, bn=True, relu=False, bias=False)
        self.W_l = Conv(ch_1, ch_int, 1, bn=True, relu=False)
        self.W_g = Conv(ch_2, ch_int, 1, bn=True, relu=False)
        self.Avg = nn.AvgPool2d(2, stride=2)
        self.Updim = Conv(ch_int//2, ch_int, 1, bn=True, relu=True)
        self.norm1 = LayerNorm(ch_int * 3, eps=1e-6, data_format="channels_first")
        self.norm2 = LayerNorm(ch_int * 2, eps=1e-6, data_format="channels_first")
        self.norm3 = LayerNorm(ch_1 + ch_2 + ch_int, eps=1e-6, data_format="channels_first")
        self.W3 = Conv(ch_int * 3, ch_int, 1, bn=True, relu=False)
        self.W = Conv(ch_int * 2, ch_int, 1, bn=True, relu=False)

        self.gelu = nn.GELU()

        self.residual = IRMLP(ch_1 + ch_2 + ch_int, ch_out)
        self.drop_path = DropPath(drop_rate) if drop_rate > 0. else nn.Identity()

    def forward(self, l, g, f):

        W_local = self.W_l(l) # local feature from Local Feature Block
        W_global = self.W_g(g) # global feature from Global Feature Block
        if f is not None:
            W_f = self.Updim(f)
            W_f = self.Avg(W_f)
            shortcut = W_f
            X_f = torch.cat([W_f, W_local, W_global], 1)
            X_f = self.norm1(X_f)
            X_f = self.W3(X_f)
            X_f = self.gelu(X_f)
        else:
            shortcut = 0
            X_f = torch.cat([W_local, W_global], 1)
            X_f = self.norm2(X_f)
            X_f = self.W(X_f)
            X_f = self.gelu(X_f)

        # spatial attention for ConvNeXt branch
        l_jump = l
        max_result, _ = torch.max(l, dim=1, keepdim=True)
        avg_result = torch.mean(l, dim=1, keepdim=True)
        result = torch.cat([max_result, avg_result], 1)
        l = self.spatial(result)
        l = self.sigmoid(l) * l_jump

        # channel attetion for transformer branch
        g_jump = g
        max_result=self.maxpool(g)
        avg_result=self.avgpool(g)
        max_out=self.se(max_result)
        avg_out=self.se(avg_result)
        g = self.sigmoid(max_out+avg_out) * g_jump

        fuse = torch.cat([g, l, X_f], 1)
        fuse = self.norm3(fuse)
        fuse = self.residual(fuse)
        fuse = shortcut + self.drop_path(fuse)
        return fuse


if __name__ == '__main__':


    block1 = HFF_block(ch_1=192, ch_2=192, r_2=16, ch_int=192, ch_out=192, drop_rate=0)
    block2 = HFF_block(ch_1=128, ch_2=128, r_2=16, ch_int=128, ch_out=128, drop_rate=0)

    # 生成模拟输入
    l1 = torch.rand(1, 192, 28, 28) # 局部特征
    g1 = torch.rand(1, 192, 28, 28) # 全局特征
    f1 = torch.rand(1, 96 , 56, 56) # 中间特征

    l2 = torch.rand(1, 128, 64, 64) # 局部特征
    g2 = torch.rand(1, 128, 64, 64) # 全局特征
    f2 = torch.rand(1, 64, 128, 128) # 中间特征

    # 传递输入并获取输出
    output = block1(l1, g1, f1)
    # output = hff_block2(l2, g2, None)
    # output = hff_block2(l2, g2, f2)

    print(output.size())

便捷下载方式

浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

更多分析可见原文


ai缝合大王
聚焦AI前沿,分享相关技术、论文,研究生自救指南
 最新文章