论文介绍
题目:HiFuse: Hierarchical multi-scale feature fusion network for medical image classification
论文地址:https://www.sciencedirect.com/science/article/abs/pii/S1746809423009679
QQ深度学习交流群:719278780
扫描下方二维码,加入深度学习论文指南星球!
加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务
创新点
提出HiFuse模型:开发了一种全新的三分支分层多尺度特征融合网络结构,用于医学图像分类。HiFuse模型能够有效融合全局和局部特征,避免破坏各自的建模,提升分类准确性。
分层结构与多尺度特征融合:
创新模块设计:
模型性能突出:
模块化设计与可扩展性:
HiFuse设计了全局特征分支和局部特征分支,以平行方式分别提取全局语义信息和局部空间特征。通过引入分层特征融合模块(HFF模块),在不增加额外噪声的情况下实现了全局和局部特征的有效融合。
HFF模块结合了空间注意力机制、通道注意力机制、反向残差多层感知器(IRMLP)以及快捷连接,在不同层次实现了特征的自适应融合。使用窗口多头自注意力机制(W-MSA)优化全局特征提取,同时利用深度可分离卷积减少计算量。
在多个医学图像数据集(如ISIC2018、Kvasir、Covid-19-CT和食管癌病理图像数据集)上的实验表明,该模型在分类精度和F1分数等指标上优于当前先进模型。HiFuse模型在复杂特征和数据噪声较高的医学图像任务中表现尤为出色。
HiFuse的模块化设计(如全局和局部分支、分层融合模块)具备良好的扩展性,适用于各种医学图像分析任务。
方法
整体结构
模型的主要组成部分
(1) 局部特征分支
通过 深度可分离卷积(Depthwise Convolution) 提取局部空间特征。
使用 线性层 在通道间交互信息。
特征提取完成后,将特征送入HFF模块进行融合。
(2) 全局特征分支
使用 窗口多头自注意力机制(W-MSA) 提取全局语义信息。
W-MSA相比传统自注意力机制(MSA),通过将特征划分为窗口进行自注意力计算,显著减少了计算复杂度。
结合层归一化(LayerNorm)和激活函数(GELU)增强特征表达能力。
采用残差连接(Residual Connection)以及Shift-W-MSA增强全局建模能力。
(3) 分层特征融合模块(HFF Block)
作用:将全局和局部特征进行自适应融合。
组成:
通道注意力机制(Channel Attention):选择性增强对重要通道特征的关注。
空间注意力机制(Spatial Attention):增强对关键空间区域的关注。
反向残差多层感知器(IRMLP):通过深度卷积和非线性变换对融合后的特征进行学习,增强表示能力。
快捷连接(Shortcut):促进梯度传递,缓解过拟合。
即插即用模块作用
HFF 作为一个即插即用模块:
(1) 自适应融合多尺度特征
将来自全局分支(全局语义信息)和局部分支(细粒度局部特征)的特征自适应融合,避免特征丢失。
通过不同层次的融合,强化模型对多尺度语义的理解,适配复杂医学图像中的多样特征分布。
(2) 增强模型的特征表达能力
通道注意力(Channel Attention)选择性增强重要通道特征,关注特定语义。
空间注意力(Spatial Attention)聚焦关键空间区域,提升对病灶或异常区域的敏感性。
IRMLP模块:进一步学习融合特征的复杂非线性关系,提升特征表达能力。
(3) 降低计算复杂度,提升效率
利用深度卷积和轻量级设计(如窗口自注意力机制),减少特征融合的计算开销。
与传统方法相比,在计算成本可控的情况下实现更优性能。
(4) 抑制噪声,提升鲁棒性
通过多层次特征融合和注意力机制,抑制无关信息干扰,提升模型对细粒度特征的捕捉能力。
消融实验结果
添加全局路径后,ACC和F1值分别提升了 2.47% 和 10.2%,说明全局特征的引入显著提升了模型的语义表示能力。
加入HFF模块的组件(注意力机制和IRMLP)后,ACC和F1值进一步提升了 7.4% 和 8.67%,表明特征融合模块在增强全局与局部特征互补性上的重要性。
最终,完整的HiFuse-Tiny模型达到了 82.99%的ACC 和 72.99%的F1值,验证了组件设计的合理性。
随着融合阶段数量的增加,模型的性能逐步提升。
当所有四个阶段都参与特征融合时,模型在ISIC2018数据集上的ACC达到了 85.85%,F1值为 74.57%,显著优于仅融合单阶段或少数阶段的结果。
这表明全层次的特征融合对于全面提取全局和局部信息至关重要。
即插即用模块
import torch
import torch.nn as nn
import torch.nn.functional as F
#论文:HiFuse: Hierarchical multi-scale feature fusion network for medical image classification
#论文地址:https://www.sciencedirect.com/science/article/abs/pii/S1746809423009679
class LayerNorm(nn.Module):
"""
channels_last corresponds to inputs with shape (batch_size, height, width, channels)
channels_first corresponds to inputs with shape (batch_size, channels, height, width)
"""
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
super().__init__()
self.weight = nn.Parameter(torch.ones(normalized_shape), requires_grad=True)
self.bias = nn.Parameter(torch.zeros(normalized_shape), requires_grad=True)
self.eps = eps
self.data_format = data_format
if self.data_format not in ["channels_last", "channels_first"]:
raise ValueError(f"not support data format '{self.data_format}'")
self.normalized_shape = (normalized_shape,)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.data_format == "channels_last":
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
elif self.data_format == "channels_first":
# [batch_size, channels, height, width]
mean = x.mean(1, keepdim=True)
var = (x - mean).pow(2).mean(1, keepdim=True)
x = (x - mean) / torch.sqrt(var + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
def drop_path_f(x, drop_prob: float = 0., training: bool = False):
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path_f(x, self.drop_prob, self.training)
class Conv(nn.Module):
def __init__(self, inp_dim, out_dim, kernel_size=3, stride=1, bn=False, relu=True, bias=True, group=1):
super(Conv, self).__init__()
self.inp_dim = inp_dim
self.conv = nn.Conv2d(inp_dim, out_dim, kernel_size, stride, padding=(kernel_size-1)//2, bias=bias)
self.relu = None
self.bn = None
if relu:
self.relu = nn.ReLU(inplace=True)
if bn:
self.bn = nn.BatchNorm2d(out_dim)
def forward(self, x):
assert x.size()[1] == self.inp_dim, "{} {}".format(x.size()[1], self.inp_dim)
x = self.conv(x)
if self.bn is not None:
x = self.bn(x)
if self.relu is not None:
x = self.relu(x)
return x
#### Inverted Residual MLP
class IRMLP(nn.Module):
def __init__(self, inp_dim, out_dim):
super(IRMLP, self).__init__()
self.conv1 = Conv(inp_dim, inp_dim, 3, relu=False, bias=False, group=inp_dim)
self.conv2 = Conv(inp_dim, inp_dim * 4, 1, relu=False, bias=False)
self.conv3 = Conv(inp_dim * 4, out_dim, 1, relu=False, bias=False, bn=True)
self.gelu = nn.GELU()
self.bn1 = nn.BatchNorm2d(inp_dim)
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.gelu(out)
out += residual
out = self.bn1(out)
out = self.conv2(out)
out = self.gelu(out)
out = self.conv3(out)
return out
# Hierachical Feature Fusion Block
class HFF_block(nn.Module):
def __init__(self, ch_1, ch_2, r_2, ch_int, ch_out, drop_rate=0.):
super(HFF_block, self).__init__()
self.maxpool=nn.AdaptiveMaxPool2d(1)
self.avgpool=nn.AdaptiveAvgPool2d(1)
self.se=nn.Sequential(
nn.Conv2d(ch_2, ch_2 // r_2, 1,bias=False),
nn.ReLU(),
nn.Conv2d(ch_2 // r_2, ch_2, 1,bias=False)
)
self.sigmoid = nn.Sigmoid()
self.spatial = Conv(2, 1, 7, bn=True, relu=False, bias=False)
self.W_l = Conv(ch_1, ch_int, 1, bn=True, relu=False)
self.W_g = Conv(ch_2, ch_int, 1, bn=True, relu=False)
self.Avg = nn.AvgPool2d(2, stride=2)
self.Updim = Conv(ch_int//2, ch_int, 1, bn=True, relu=True)
self.norm1 = LayerNorm(ch_int * 3, eps=1e-6, data_format="channels_first")
self.norm2 = LayerNorm(ch_int * 2, eps=1e-6, data_format="channels_first")
self.norm3 = LayerNorm(ch_1 + ch_2 + ch_int, eps=1e-6, data_format="channels_first")
self.W3 = Conv(ch_int * 3, ch_int, 1, bn=True, relu=False)
self.W = Conv(ch_int * 2, ch_int, 1, bn=True, relu=False)
self.gelu = nn.GELU()
self.residual = IRMLP(ch_1 + ch_2 + ch_int, ch_out)
self.drop_path = DropPath(drop_rate) if drop_rate > 0. else nn.Identity()
def forward(self, l, g, f):
W_local = self.W_l(l) # local feature from Local Feature Block
W_global = self.W_g(g) # global feature from Global Feature Block
if f is not None:
W_f = self.Updim(f)
W_f = self.Avg(W_f)
shortcut = W_f
X_f = torch.cat([W_f, W_local, W_global], 1)
X_f = self.norm1(X_f)
X_f = self.W3(X_f)
X_f = self.gelu(X_f)
else:
shortcut = 0
X_f = torch.cat([W_local, W_global], 1)
X_f = self.norm2(X_f)
X_f = self.W(X_f)
X_f = self.gelu(X_f)
# spatial attention for ConvNeXt branch
l_jump = l
max_result, _ = torch.max(l, dim=1, keepdim=True)
avg_result = torch.mean(l, dim=1, keepdim=True)
result = torch.cat([max_result, avg_result], 1)
l = self.spatial(result)
l = self.sigmoid(l) * l_jump
# channel attetion for transformer branch
g_jump = g
max_result=self.maxpool(g)
avg_result=self.avgpool(g)
max_out=self.se(max_result)
avg_out=self.se(avg_result)
g = self.sigmoid(max_out+avg_out) * g_jump
fuse = torch.cat([g, l, X_f], 1)
fuse = self.norm3(fuse)
fuse = self.residual(fuse)
fuse = shortcut + self.drop_path(fuse)
return fuse
if __name__ == '__main__':
block1 = HFF_block(ch_1=192, ch_2=192, r_2=16, ch_int=192, ch_out=192, drop_rate=0)
block2 = HFF_block(ch_1=128, ch_2=128, r_2=16, ch_int=128, ch_out=128, drop_rate=0)
# 生成模拟输入
l1 = torch.rand(1, 192, 28, 28) # 局部特征
g1 = torch.rand(1, 192, 28, 28) # 全局特征
f1 = torch.rand(1, 96 , 56, 56) # 中间特征
l2 = torch.rand(1, 128, 64, 64) # 局部特征
g2 = torch.rand(1, 128, 64, 64) # 全局特征
f2 = torch.rand(1, 64, 128, 128) # 中间特征
# 传递输入并获取输出
output = block1(l1, g1, f1)
# output = hff_block2(l2, g2, None)
# output = hff_block2(l2, g2, f2)
print(output.size())
便捷下载方式
浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules
更多分析可见原文