论文介绍
题目:A dual encoder crack segmentation network with Haar wavelet-based high-low frequency attention
论文地址:https://doi.org/10.1016/j.eswa.2024.124950
QQ深度学习交流群:719278780
扫描下方二维码,加入深度学习论文指南星球!
加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务
创新点
双编码器结构(DECS-Net):
提出了一种结合卷积神经网络(CNN)和变换器(Transformer)的双编码器裂缝分割网络。CNN用于提取局部信息,而Transformer用于捕获全局语义信息,两者互补。
高低频注意机制(HLA):
基于Haar小波分解的高低频注意机制,用于分别提取高频(边缘信息)和低频(全局语义信息)特征,有助于提高对裂缝边缘的敏感性。
局部增强前馈网络(LEFN):
在传统Transformer的基础上,设计了一种局部增强的前馈网络,通过增强图像补丁之间的交互,改善了网络对局部信息的感知能力。
特征融合模块(FFM):
提出了一个特征融合模块,用于融合CNN和Transformer提取的中间特征。通过通道注意(CA)、跨域融合块(CFB)和相关性增强操作,优化了不同特征域间的交互,显著提升了特征融合效果。
方法
整体架构
DECS-Net 是一种结合 CNN 和 Transformer 的双编码器裂缝分割网络,利用 CNN 提取局部信息、Transformer 捕获全局语义,通过高低频注意机制和特征融合模块(FFM)深度整合两者特性。其创新点包括基于 Haar 小波的高低频特征提取、局部增强前馈网络(LEFN)提升局部感知能力,以及跨域融合和相关性增强模块优化特征表达。实验表明,该模型在裂缝分割任务中显著优于现有方法,具备更高的召回率和整体性能。
(1) 双编码器结构
CNN 编码器:
基于 ResNet-50 构建,包含初始化层、最大池化层和四个卷积层。
主要用于提取图像的局部空间特征。
输出多尺度特征图,用于与 Transformer 编码器的特征进行融合。
Transformer 编码器:
将输入图像分割成多个小块(patch),逐层下采样,得到不同尺度的特征图(4倍、8倍、16倍、32倍下采样)。
采用高低频注意机制(HLA)和局部增强前馈网络(LEFN),分别提取高频(边缘)和低频(语义)特征。
特征提取经过多次 Transformer 块(每层重复3次)增强全局语义信息。
(2) 特征融合模块(FFM)
CNN 和 Transformer 编码器提取的特征在每一层通过 FFM 进行融合。
FFM 主要包括以下几个部分:
通道注意机制(CA):调整特征图中不同通道的权重,减少冗余信息。
跨域融合块(CFB):在 CNN 和 Transformer 提取的特征间进行深层交互。
相关性增强操作(CE):通过矩阵运算,强化两种特征图间的相关性。
特征融合块(FFB):将融合后的特征进行进一步精简和聚合。
(3) 解码器
解码器用于将融合后的特征图恢复到输入图像的大小,生成分割掩膜。
主要特点:
使用子像素卷积(PixelShuffle)进行上采样,有效保留特征信息。
每层接收对应尺度的融合特征,进行拼接后通过 IDSC(逆深度分离卷积)降维处理。
最终通过卷积生成分割结果。
即插即用模块作用
FCHilo 作为一个即插即用模块:
高频特征捕获:
作用:捕获目标的边缘和细节特征。
适用场景:在裂缝检测、医学影像等任务中,能够增强对微小目标的敏感性,提取复杂形状的细节边界。
低频特征建模:
作用:捕获图像的全局语义信息。
适用场景:在自然图像分割中,有助于理解目标的整体形状和空间位置,降低背景噪声干扰。
局部与全局特征的平衡:
作用:结合局部高频特征和全局低频特征,提高特征的表达能力。
适用场景:在复杂背景或多目标分割场景下,帮助模型更加精准地聚焦目标区域。
抗背景干扰:
作用:通过区分高频和低频信息,减少背景复杂性对模型的干扰。
适用场景:如裂缝检测任务中,将裂缝从背景中分离出来,显著提高检测准确率。
消融实验结果
表 3(Compare the different combinations of main operations in FFM on the DeepCrack dataset):
内容:分析了特征融合模块(FFM)中各关键组件(CA、CFB 和 CE)的作用。通过对比不同组件的组合,证明同时采用这三个操作能获得最佳的分割性能(F1 达到 87.51%)。
说明:FFM 的每个部分(CA、CFB 和 CE)都对性能有积极影响,其中 CFB 能深度交互跨域特征,但略微影响推理速度(FPS)。
表 4(Compare the effectiveness of CNN encoder and transformer encoder on the DeepCrack dataset):
内容:分析单独使用 CNN 编码器或 Transformer 编码器的性能,结果显示双编码器结构(同时使用 CNN 和 Transformer)在综合指标(F1 和 mIoU)上表现最佳。
说明:CNN 擅长局部特征提取,Transformer 提升全局语义建模能力,两者结合通过特征融合进一步增强了网络对裂缝的检测能力。
表 5(Compare the Pr, Re, F1, and mIoU of different number of transformer blocks on the DeepCrack dataset):
内容:测试了 Transformer 编码器中不同数量 Transformer 块的性能(每层执行 1~4 次),发现每层执行 3 次时综合效果最佳(F1 达到 87.51%)。
说明:适当的 Transformer 块数量能在建模能力和效率之间取得平衡。
表 6(Compare the Pr, Re, F1, and mIoU of different CNN encoder on the DeepCrack dataset):
内容:测试了 CNN 编码器使用不同 ResNet 深度(ResNet-18、34、50、101、152)的效果,结果显示 ResNet-50 性能最优(F1 达到 87.51%),并且计算复杂度适中。
说明:更深的 CNN 网络能更好提取局部特征,但过深会导致小目标位置信息丢失。
即插即用模块
import torch
import torch.nn as nn
import torch.nn.functional as F
# 论文:A dual encoder crack segmentation network with Haar wavelet-based high-low frequency attention
# 论文地址:https://doi.org/10.1016/j.eswa.2024.124950
class PositionEmbedding(nn.Module):
def __init__(self, t=10000):
super().__init__()
self.t = t
def forward(self, x):
B, N, C = x.shape
assert C % 2 == 0, 'dim must be divided 2'
pos_embed = torch.zeros(N, C, dtype=torch.float32)
N_num = torch.arange(N, dtype=torch.float32)
o = torch.arange(C//2, dtype=torch.float32)
o /= C/2.
o = 1. / (self.t**o)
out = N_num[:, None] @ o[None, :]
sin_embed = torch.sin(out)
cos_embed = torch.cos(out)
pos_embed[:, 0::2] = sin_embed
pos_embed[:, 1::2] = cos_embed
pos_embed = pos_embed.unsqueeze(0).repeat(B, 1, 1)
return pos_embed
class DSC(nn.Module):
def __init__(self, c_in, c_out, k_size=3, stride=1, padding=1):
super(DSC, self).__init__()
self.c_in = c_in
self.c_out = c_out
self.dw = nn.Conv2d(c_in, c_in, k_size, stride, padding, groups=c_in)
self.pw = nn.Conv2d(c_in, c_out, 1, 1)
def forward(self, x):
out = self.dw(x)
out = self.pw(out)
return out
class IDSC(nn.Module):
def __init__(self, c_in, c_out, k_size=3, stride=1, padding=1):
super(IDSC, self).__init__()
self.c_in = c_in
self.c_out = c_out
self.dw = nn.Conv2d(c_out, c_out, k_size, stride, padding, groups=c_out)
self.pw = nn.Conv2d(c_in, c_out, 1, 1)
def forward(self, x):
out = self.pw(x)
out = self.dw(out)
return out
class FCHiLo1(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, window_size=2, alpha=0.5):
super().__init__()
assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
head_dim = int(dim / num_heads)
self.dim = dim
self.pos = PositionEmbedding()
self.l_heads = int(num_heads * alpha)
self.l_dim = self.l_heads * head_dim
self.h_heads = num_heads - self.l_heads
self.h_dim = self.h_heads * head_dim
self.ws = window_size
if self.ws == 1:
self.h_heads = 0
self.h_dim = 0
self.l_heads = num_heads
self.l_dim = dim
self.scale = qk_scale or head_dim ** -0.5
if self.ws != 1:
# self.wt = DWTForward(J=1, mode='zero', wave='haar')
self.wt = nn.AvgPool2d(kernel_size=window_size, stride=window_size)
else:
self.sr = nn.Sequential()
if self.l_heads > 0:
self.l_q = DSC(self.dim, self.l_dim)
self.l_kv = DSC(self.dim, self.l_dim*2)
self.l_proj = DSC(self.l_dim, self.l_dim)
if self.h_heads > 0:
self.h_qkv = DSC(self.dim, self.h_dim*3)
self.h_proj = DSC(self.h_dim, self.h_dim)
def hi_lofi(self, x):
B, N, C = x.shape
H = W = int(N ** 0.5)
x = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
if self.ws != 1:
# low_feats, yH = self.wt(x)
low_feats = self.wt(x)
else:
low_feats = self.sr(x)
high_feats = F.interpolate(low_feats, size=H, mode='nearest')
high_feats = high_feats - x
if self.l_heads!=0:
l_q = self.l_q(x).permute(0, 2, 3, 1).reshape(B, H * W, self.l_heads, self.l_dim // self.l_heads).permute(0, 2, 1, 3)
if self.ws > 1:
l_kv = self.l_kv(low_feats).permute(0, 2, 3, 1).reshape(B, -1, 2, self.l_heads, self.l_dim // self.l_heads).permute(2, 0, 3, 1, 4)
else:
l_kv = self.l_kv(x).permute(0, 2, 3, 1).reshape(B, -1, 2, self.l_heads, self.l_dim // self.l_heads).permute(2, 0, 3, 1, 4)
l_k, l_v = l_kv[0], l_kv[1]
l_attn = (l_q @ l_k.transpose(-2, -1)) * self.scale
l_attn = l_attn.softmax(dim=-1)
l_x = (l_attn @ l_v).transpose(1, 2).reshape(B, H, W, self.l_dim).permute(0, 3, 1, 2)
l_x = self.l_proj(l_x).permute(0, 2, 3, 1)
if self.h_heads!=0:
h_group, w_group = H // self.ws, W // self.ws
total_groups = h_group * w_group
h_qkv = self.h_qkv(high_feats).permute(0, 2, 3, 1).\
reshape(B, h_group, self.ws, w_group, self.ws, 3*self.h_dim).\
transpose(2, 3).reshape(B, total_groups, -1, 3, self.h_heads,
self.h_dim // self.h_heads).permute(3, 0, 1, 4, 2, 5)
h_q, h_k, h_v = h_qkv[0], h_qkv[1], h_qkv[2]
h_attn = (h_q @ h_k.transpose(-2, -1)) * self.scale
h_attn = h_attn.softmax(dim=-1)
h_attn = (h_attn @ h_v).transpose(2, 3).reshape(B, h_group, w_group, self.ws, self.ws, self.h_dim)
h_x = h_attn.transpose(2, 3).reshape(B, h_group * self.ws, w_group * self.ws, self.h_dim).permute(0, 3, 1, 2)
h_x = self.h_proj(h_x).permute(0, 2, 3, 1)
if self.h_heads!=0 and self.l_heads!=0:
out = torch.cat([l_x, h_x], dim=-1)
out = out.reshape(B, N, C)
if self.l_heads==0:
out = h_x.reshape(B, N, C)
if self.h_heads==0:
out = l_x.reshape(B, N, C)
return out
def forward(self, x):
return self.hi_lofi(x)
class FFN1(nn.Module):
def __init__(self, dim, h_dim=None, out_dim=None):
super().__init__()
self.h_dim = dim*2 if h_dim==None else h_dim
self.out_dim = dim if out_dim==None else out_dim
self.act = nn.GELU()
self.fc1 = DSC(dim, self.h_dim)
self.norm = nn.BatchNorm2d(self.out_dim)
self.fc2 = DSC(self.h_dim, self.h_dim)
self.fc3 = IDSC(self.h_dim, self.out_dim)
def forward(self, x):
B, N, C = x.shape
H = W = int(N**0.5)
x = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
x = self.act(self.fc3(self.act(self.fc2(self.act(self.fc1(x))))))
x = self.norm(x).reshape(B, C, -1).permute(0, 2, 1)
return x
class Block1(nn.Module):
def __init__(self, dim, num_heads=8, window_size=2, alpha=0.5, qkv_bias=False, qk_scale=None, h_dim=None, out_dim=None):
super().__init__()
self.hilo = FCHiLo1(dim, num_heads, qkv_bias, qk_scale, window_size, alpha)
self.ffn = FFN1(dim, h_dim, out_dim)
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
def forward(self, x):
x = x + self.norm1(self.hilo(x))
x = x + self.norm2(self.ffn(x))
return x
if __name__ == '__main__':
input = torch.randn(1, 1024, 64) # B N C
block1 = Block1(64)
print(input.size())
output_block1 = block1(input)
print(output_block1.size())
ffn1 = FFN1(64)
print(input.size())
output_ffn1 = ffn1(input)
print(output_ffn1.size())
# Instantiate FCHiLo1
fchilo1 = FCHiLo1(64)
print(input.size())
output_fchilo1 = fchilo1(input) print(output_fchilo1.size())
便捷下载方式
浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules
更多分析可见原文