(ESWA 2024) 高低频注意力机制FCHilo,即插即用涨点启动!

文摘   2024-12-19 17:20   中国香港  

论文介绍

题目:A dual encoder crack segmentation network with Haar wavelet-based high-low frequency attention

论文地址:https://doi.org/10.1016/j.eswa.2024.124950

QQ深度学习交流群:719278780

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • 双编码器结构(DECS-Net)

    • 提出了一种结合卷积神经网络(CNN)和变换器(Transformer)的双编码器裂缝分割网络。CNN用于提取局部信息,而Transformer用于捕获全局语义信息,两者互补。

  • 高低频注意机制(HLA)

    • 基于Haar小波分解的高低频注意机制,用于分别提取高频(边缘信息)和低频(全局语义信息)特征,有助于提高对裂缝边缘的敏感性。

  • 局部增强前馈网络(LEFN)

    • 在传统Transformer的基础上,设计了一种局部增强的前馈网络,通过增强图像补丁之间的交互,改善了网络对局部信息的感知能力。

  • 特征融合模块(FFM)

  • 提出了一个特征融合模块,用于融合CNN和Transformer提取的中间特征。通过通道注意(CA)、跨域融合块(CFB)和相关性增强操作,优化了不同特征域间的交互,显著提升了特征融合效果。

方法

整体架构

     DECS-Net 是一种结合 CNN 和 Transformer 的双编码器裂缝分割网络,利用 CNN 提取局部信息、Transformer 捕获全局语义,通过高低频注意机制和特征融合模块(FFM)深度整合两者特性。其创新点包括基于 Haar 小波的高低频特征提取、局部增强前馈网络(LEFN)提升局部感知能力,以及跨域融合和相关性增强模块优化特征表达。实验表明,该模型在裂缝分割任务中显著优于现有方法,具备更高的召回率和整体性能。

(1) 双编码器结构

  • CNN 编码器

    • 基于 ResNet-50 构建,包含初始化层、最大池化层和四个卷积层。

    • 主要用于提取图像的局部空间特征。

    • 输出多尺度特征图,用于与 Transformer 编码器的特征进行融合。

  • Transformer 编码器


    • 将输入图像分割成多个小块(patch),逐层下采样,得到不同尺度的特征图(4倍、8倍、16倍、32倍下采样)。

    • 采用高低频注意机制(HLA)和局部增强前馈网络(LEFN),分别提取高频(边缘)和低频(语义)特征。

    • 特征提取经过多次 Transformer 块(每层重复3次)增强全局语义信息。

(2) 特征融合模块(FFM)

  • CNN 和 Transformer 编码器提取的特征在每一层通过 FFM 进行融合。

  • FFM 主要包括以下几个部分:

  1. 通道注意机制(CA):调整特征图中不同通道的权重,减少冗余信息。

  2. 跨域融合块(CFB):在 CNN 和 Transformer 提取的特征间进行深层交互。


  3. 相关性增强操作(CE):通过矩阵运算,强化两种特征图间的相关性。

  4. 特征融合块(FFB):将融合后的特征进行进一步精简和聚合。

(3) 解码器

  • 解码器用于将融合后的特征图恢复到输入图像的大小,生成分割掩膜。

  • 主要特点:

    • 使用子像素卷积(PixelShuffle)进行上采样,有效保留特征信息。

    • 每层接收对应尺度的融合特征,进行拼接后通过 IDSC(逆深度分离卷积)降维处理。

    • 最终通过卷积生成分割结果。

即插即用模块作用

FCHilo 作为一个即插即用模块

  • 高频特征捕获

    • 作用:捕获目标的边缘和细节特征。

    • 适用场景:在裂缝检测、医学影像等任务中,能够增强对微小目标的敏感性,提取复杂形状的细节边界。

  • 低频特征建模

    • 作用:捕获图像的全局语义信息。

    • 适用场景:在自然图像分割中,有助于理解目标的整体形状和空间位置,降低背景噪声干扰。

  • 局部与全局特征的平衡

    • 作用:结合局部高频特征和全局低频特征,提高特征的表达能力。

    • 适用场景:在复杂背景或多目标分割场景下,帮助模型更加精准地聚焦目标区域。

  • 抗背景干扰

  • 作用:通过区分高频和低频信息,减少背景复杂性对模型的干扰。

    适用场景:如裂缝检测任务中,将裂缝从背景中分离出来,显著提高检测准确率。

消融实验结果

表 3(Compare the different combinations of main operations in FFM on the DeepCrack dataset):

  • 内容:分析了特征融合模块(FFM)中各关键组件(CA、CFB 和 CE)的作用。通过对比不同组件的组合,证明同时采用这三个操作能获得最佳的分割性能(F1 达到 87.51%)。

  • 说明:FFM 的每个部分(CA、CFB 和 CE)都对性能有积极影响,其中 CFB 能深度交互跨域特征,但略微影响推理速度(FPS)。


表 4(Compare the effectiveness of CNN encoder and transformer encoder on the DeepCrack dataset):

  • 内容:分析单独使用 CNN 编码器或 Transformer 编码器的性能,结果显示双编码器结构(同时使用 CNN 和 Transformer)在综合指标(F1 和 mIoU)上表现最佳。

  • 说明:CNN 擅长局部特征提取,Transformer 提升全局语义建模能力,两者结合通过特征融合进一步增强了网络对裂缝的检测能力。


表 5(Compare the Pr, Re, F1, and mIoU of different number of transformer blocks on the DeepCrack dataset):

  • 内容:测试了 Transformer 编码器中不同数量 Transformer 块的性能(每层执行 1~4 次),发现每层执行 3 次时综合效果最佳(F1 达到 87.51%)。

  • 说明:适当的 Transformer 块数量能在建模能力和效率之间取得平衡。


表 6(Compare the Pr, Re, F1, and mIoU of different CNN encoder on the DeepCrack dataset):

  • 内容:测试了 CNN 编码器使用不同 ResNet 深度(ResNet-18、34、50、101、152)的效果,结果显示 ResNet-50 性能最优(F1 达到 87.51%),并且计算复杂度适中。

  • 说明:更深的 CNN 网络能更好提取局部特征,但过深会导致小目标位置信息丢失。

即插即用模块

import torch
import torch.nn as nn
import torch.nn.functional as F
# 论文:A dual encoder crack segmentation network with Haar wavelet-based high-low frequency attention
# 论文地址:https://doi.org/10.1016/j.eswa.2024.124950


class PositionEmbedding(nn.Module):
    def __init__(self, t=10000):
        super().__init__()
        self.t = t

    def forward(self, x):
        B, N, C = x.shape
        assert C % 2 == 0, 'dim must be divided 2'

        pos_embed = torch.zeros(N, C, dtype=torch.float32)

        N_num = torch.arange(N, dtype=torch.float32)

        o = torch.arange(C//2, dtype=torch.float32)
        o /= C/2.
        o = 1. / (self.t**o)

        out = N_num[:, None] @ o[None, :]

        sin_embed = torch.sin(out)
        cos_embed = torch.cos(out)

        pos_embed[:, 0::2] = sin_embed
        pos_embed[:, 1::2] = cos_embed

        pos_embed = pos_embed.unsqueeze(0).repeat(B, 1, 1)
        return pos_embed

class DSC(nn.Module):
    def __init__(self, c_in, c_out, k_size=3, stride=1, padding=1):
        super(DSC, self).__init__()
        self.c_in = c_in
        self.c_out = c_out
        self.dw = nn.Conv2d(c_in, c_in, k_size, stride, padding, groups=c_in)
        self.pw = nn.Conv2d(c_in, c_out, 1, 1)

    def forward(self, x):
        out = self.dw(x)
        out = self.pw(out)
        return out

class IDSC(nn.Module):
    def __init__(self, c_in, c_out, k_size=3, stride=1, padding=1):
        super(IDSC, self).__init__()
        self.c_in = c_in
        self.c_out = c_out
        self.dw = nn.Conv2d(c_out, c_out, k_size, stride, padding, groups=c_out)
        self.pw = nn.Conv2d(c_in, c_out, 1, 1)

    def forward(self, x):
        out = self.pw(x)
        out = self.dw(out)
        return out

class FCHiLo1(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, window_size=2, alpha=0.5):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        head_dim = int(dim / num_heads)
        self.dim = dim
        self.pos = PositionEmbedding()

        self.l_heads = int(num_heads * alpha)
        self.l_dim = self.l_heads * head_dim

        self.h_heads = num_heads - self.l_heads
        self.h_dim = self.h_heads * head_dim

        self.ws = window_size

        if self.ws == 1:
            self.h_heads = 0
            self.h_dim = 0
            self.l_heads = num_heads
            self.l_dim = dim

        self.scale = qk_scale or head_dim ** -0.5

        if self.ws != 1:
            # self.wt = DWTForward(J=1, mode='zero', wave='haar')
            self.wt = nn.AvgPool2d(kernel_size=window_size, stride=window_size)
        else:
            self.sr = nn.Sequential()

        if self.l_heads > 0:
            self.l_q = DSC(self.dim, self.l_dim)
            self.l_kv = DSC(self.dim, self.l_dim*2)
            self.l_proj = DSC(self.l_dim, self.l_dim)

        if self.h_heads > 0:
            self.h_qkv = DSC(self.dim, self.h_dim*3)
            self.h_proj = DSC(self.h_dim, self.h_dim)

    def hi_lofi(self, x):
        B, N, C = x.shape
        H = W = int(N ** 0.5)
        x = x.reshape(B, H, W, C).permute(0, 3, 1, 2)

        if self.ws != 1:
            # low_feats, yH = self.wt(x)
            low_feats = self.wt(x)
        else:
            low_feats = self.sr(x)

        high_feats = F.interpolate(low_feats, size=H, mode='nearest')
        high_feats = high_feats - x

        if self.l_heads!=0:
            l_q = self.l_q(x).permute(0, 2, 3, 1).reshape(B, H * W, self.l_heads, self.l_dim // self.l_heads).permute(0, 2, 1, 3)
            if self.ws > 1:
                l_kv = self.l_kv(low_feats).permute(0, 2, 3, 1).reshape(B, -1, 2, self.l_heads, self.l_dim // self.l_heads).permute(2, 0, 3, 1, 4)
            else:
                l_kv = self.l_kv(x).permute(0, 2, 3, 1).reshape(B, -1, 2, self.l_heads, self.l_dim // self.l_heads).permute(2, 0, 3, 1, 4)
            l_k, l_v = l_kv[0], l_kv[1]

            l_attn = (l_q @ l_k.transpose(-2, -1)) * self.scale
            l_attn = l_attn.softmax(dim=-1)

            l_x = (l_attn @ l_v).transpose(1, 2).reshape(B, H, W, self.l_dim).permute(0, 3, 1, 2)
            l_x = self.l_proj(l_x).permute(0, 2, 3, 1)


        if self.h_heads!=0:
            h_group, w_group = H // self.ws, W // self.ws
            total_groups = h_group * w_group
            h_qkv = self.h_qkv(high_feats).permute(0, 2, 3, 1).\
                reshape(B, h_group, self.ws, w_group, self.ws, 3*self.h_dim).\
                transpose(2, 3).reshape(B, total_groups, -1, 3, self.h_heads,
                                        self.h_dim // self.h_heads).permute(3, 0, 1, 4, 2, 5)
            h_q, h_k, h_v = h_qkv[0], h_qkv[1], h_qkv[2]

            h_attn = (h_q @ h_k.transpose(-2, -1)) * self.scale
            h_attn = h_attn.softmax(dim=-1)
            h_attn = (h_attn @ h_v).transpose(2, 3).reshape(B, h_group, w_group, self.ws, self.ws, self.h_dim)
            h_x = h_attn.transpose(2, 3).reshape(B, h_group * self.ws, w_group * self.ws, self.h_dim).permute(0, 3, 1, 2)

            h_x = self.h_proj(h_x).permute(0, 2, 3, 1)


        if self.h_heads!=0 and self.l_heads!=0:
            out = torch.cat([l_x, h_x], dim=-1)
            out = out.reshape(B, N, C)

        if self.l_heads==0:
            out = h_x.reshape(B, N, C)

        if self.h_heads==0:
            out = l_x.reshape(B, N, C)

        return out

    def forward(self, x):
        return self.hi_lofi(x)

class FFN1(nn.Module):
    def __init__(self, dim, h_dim=None, out_dim=None):
        super().__init__()
        self.h_dim = dim*2 if h_dim==None else h_dim
        self.out_dim = dim if out_dim==None else out_dim

        self.act = nn.GELU()
        self.fc1 = DSC(dim, self.h_dim)
        self.norm = nn.BatchNorm2d(self.out_dim)
        self.fc2 = DSC(self.h_dim, self.h_dim)
        self.fc3 = IDSC(self.h_dim, self.out_dim)

    def forward(self, x):
        B, N, C = x.shape
        H = W = int(N**0.5)
        x = x.reshape(B, H, W, C).permute(0, 3, 1, 2)
        x = self.act(self.fc3(self.act(self.fc2(self.act(self.fc1(x))))))
        x = self.norm(x).reshape(B, C, -1).permute(0, 2, 1)

        return x

class Block1(nn.Module):
    def __init__(self, dim, num_heads=8, window_size=2, alpha=0.5, qkv_bias=False, qk_scale=None, h_dim=None, out_dim=None):
        super().__init__()
        self.hilo = FCHiLo1(dim, num_heads, qkv_bias, qk_scale, window_size, alpha)
        self.ffn = FFN1(dim, h_dim, out_dim)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
    def forward(self, x):
        x = x + self.norm1(self.hilo(x))
        x = x + self.norm2(self.ffn(x))
        return x


if __name__ == '__main__':

    input = torch.randn(1, 1024, 64) # B N C

    block1 = Block1(64)
    print(input.size())
    output_block1 = block1(input)
    print(output_block1.size())

    ffn1 = FFN1(64)
    print(input.size())
    output_ffn1 = ffn1(input)
    print(output_ffn1.size())

    # Instantiate FCHiLo1
    fchilo1 = FCHiLo1(64)
    print(input.size())
    output_fchilo1 = fchilo1(input)    print(output_fchilo1.size())

便捷下载方式

浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

更多分析可见原文


ai缝合大王
聚焦AI前沿,分享相关技术、论文,研究生自救指南
 最新文章