(ICME 2023)即插即用卷积线性融合Transformer,助力小目标检测起飞

文摘   2024-09-29 17:20   天津  

论文介绍

题目:ABC: Attention with Bilinear Correlation for Infrared Small Target Detection

论文地址: https://arxiv.org/pdf/2303.10321

年份:2023

创新点

  • 卷积线性融合Transformer(CLFT)模块:该模块结合了Transformer结构的全局特征提取能力和卷积神经网络(CNN)的局部特征提取能力,有效增强了目标特征并抑制了噪声。CLFT模块通过引入双线性注意力模块(BAM)计算注意力矩阵,使得模型能够感知目标的位置,从而有效地处理小目标红外检测中的噪声干扰和目标丢失问题。

  • U形卷积-膨胀卷积(UCDC)模块:该模块位于网络的深层,通过卷积层和膨胀卷积层的结合,进一步精细处理已经下采样的特征图,提取更精细的语义信息。通过使用大感受野和跳跃连接,UCDC模块能够有效处理小分辨率的特征图,提升检测效果。

  • 整体架构设计:论文提出的模型采用了编码器-解码器结构,类似于UNet,通过在编码器中引入CLFT模块,在解码器中使用UCDC模块,并通过跨层跳跃连接实现特征融合,从而提高了红外小目标检测任务的性能。


  • 实验结果:在多个公共数据集上的实验结果表明,论文提出的方法在指标上优于现有的最先进方法,特别是在噪声抑制和目标增强方面有显著提升。


方法

整体结构

       论文中的模型采用了编码器-解码器结构,结合了卷积神经网络(CNN)和Transformer的特性。编码器部分通过卷积模块和卷积线性融合Transformer(CLFT)模块提取并融合局部和全局特征,增强目标特征并抑制噪声;解码器部分利用U形卷积-膨胀卷积(UCDC)模块和卷积模块进一步精细处理特征图,并通过跳跃连接实现跨层特征融合,最终输出目标检测结果。


  • 编码器:编码器部分由一个卷积模块和三个CLFT模块组成。卷积模块用于提取初步的局部特征,CLFT模块结合了卷积和Transformer的特性,进行局部和全局特征的融合,增强目标特征并抑制噪声。

  • 解码器:解码器部分由一个UCDC模块和三个卷积模块组成。UCDC模块采用U形结构,由卷积层和膨胀卷积层组成,处理更深层次的特征图,提取细致的语义信息并进一步消除噪声。卷积模块则用于恢复特征图的空间分辨率。

  • 跳跃连接:模型中使用了跨层跳跃连接(skip connections),在编码器和解码器之间传递特征图,实现跨层特征融合,确保低层次特征和高层次语义信息能够有效结合。

  • 最终输出:在解码器之后,模型通过一个逐点卷积层作为分割头输出最终的目标检测结果。

实验结果

  • 展示了不同最先进方法(SOTA)在多个数据集(NUAA、IRSTD1k、SIRSTAUG、NUDT)上的IoU(交并比)、nIoU(标准化交并比)和F1分数的实验结果,表明论文提出的ABC模型在所有数据集中表现出最好的性能。

  • 展示了不同方法在NUAA数据集和IRSTD1k数据集上的ROC曲线,显示了ABC模型在不同数据集上显著优于其他方法的性能。

  • 提供了在NUAA和IRSTD1k数据集上不同方法的部分图像的可视化结果,展示了ABC模型在检测小目标时更具鲁棒性,能够有效抑制噪声并避免漏检。

  • 比较了不同方法在NUAA数据集上的计算成本(FLOPs)和每秒帧数(FPS),表明ABC模型在保持较高性能的同时,计算成本相对较低,尤其是ABC-S模型表现出较好的平衡。

CLFT模块作用

      将卷积线性融合Transformer(CLFT)模块作为即插即用模块提取出来,可以适用于需要结合局部和全局特征提取的多种场景,尤其是以下几类任务中可以发挥重要作用:

  1. 目标检测和语义分割:CLFT模块结合了CNN的局部感知能力和Transformer的全局建模能力,能够增强目标特征并抑制背景噪声,因此特别适合复杂场景下的小目标检测和语义分割任务,如红外图像中的小目标检测、医学图像分析中的病灶检测等。


  2. 弱特征目标的检测:在需要检测弱特征或模糊目标的任务中,CLFT模块能够通过自注意力机制有效捕捉全局上下文信息,同时利用卷积增强局部细节,因此适用于夜视、低光照、遥感图像等场景中的目标检测。


  3. 高噪声环境中的特征提取:对于背景噪声较多的图像或数据,CLFT模块能够有效过滤掉噪声并增强目标特征。因此,可以应用于工业检测、自动驾驶中的障碍物检测等需要在高噪声环境下进行的任务。


  4. 多尺度特征提取:CLFT模块能够同时提取局部和全局特征,使其在需要处理不同尺度的特征时表现出色,例如在自然场景下的人物或物体识别任务中,它可以同时处理细节和全局信息。


即插即用模块

import torch
import torch.nn as nn
from einops import rearrange
#论文:ABC: Attention with Bilinear Correlation for Infrared Small Target Detection ICME2023
#论文地址:https://arxiv.org/pdf/2303.10321

def conv_relu_bn(in_channel, out_channel, dirate):
    return nn.Sequential(
        nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=dirate,
                  dilation=dirate),
        nn.BatchNorm2d(out_channel),
        nn.ReLU(inplace=True)
    )

#bilinear attention module (BAM)
class BAM(nn.Module):
    def __init__(self, in_dim, in_feature, out_feature):
        super(BAM, self).__init__()
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=1, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=1, kernel_size=1)
        self.query_line = nn.Linear(in_features=in_feature, out_features=out_feature)
        self.key_line = nn.Linear(in_features=in_feature, out_features=out_feature)
        self.s_conv = nn.Conv2d(in_channels=1, out_channels=in_dim, kernel_size=1)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        q = rearrange(self.query_line(rearrange(self.query_conv(x), 'b 1 h w -> b (h w)')), 'b h -> b h 1')
        k = rearrange(self.key_line(rearrange(self.key_conv(x), 'b 1 h w -> b (h w)')), 'b h -> b 1 h')
        att = rearrange(torch.matmul(q, k), 'b h w -> b 1 h w')
        att = self.softmax(self.s_conv(att))
        return att

class Conv(nn.Module):
    def __init__(self, in_dim):
        super(Conv, self).__init__()
        self.convs = nn.ModuleList([conv_relu_bn(in_dim, in_dim, 1) for _ in range(3)])

    def forward(self, x):
        for conv in self.convs:
            x = conv(x)
        return x

#dilated convolution layers(DConv)
class DConv(nn.Module):
    def __init__(self, in_dim):
        super(DConv, self).__init__()
        dilation = [2, 4, 2]
        self.dconvs = nn.ModuleList([conv_relu_bn(in_dim, in_dim, dirate) for dirate in dilation])

    def forward(self, x):
        for dconv in self.dconvs:
            x = dconv(x)
        return x

class ConvAttention(nn.Module):
    def __init__(self, in_dim, in_feature, out_feature):
        super(ConvAttention, self).__init__()
        self.conv = Conv(in_dim)
        self.dconv = DConv(in_dim)
        self.att = BAM(in_dim, in_feature, out_feature)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        q = self.conv(x)
        k = self.dconv(x)
        v = q + k
        att = self.att(x)
        out = torch.matmul(att, v)
        return self.gamma * out + v + x


class FeedForward(nn.Module):
    def __init__(self, in_dim, out_dim):
        super(FeedForward, self).__init__()
        self.conv = conv_relu_bn(in_dim, out_dim, 1)
        # self.x_conv = nn.Conv2d(in_dim, out_dim, kernel_size=1)
        self.x_conv = nn.Sequential(
            nn.Conv2d(in_dim, out_dim, kernel_size=1),
            nn.BatchNorm2d(out_dim),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        out = self.conv(x)
        x = self.x_conv(x)
        return x + out

#convolution linear fusion transformer (CLFT)
class CLFT(nn.Module):
    def __init__(self, in_dim, out_dim, in_feature, out_feature):
        super(CLFT, self).__init__()
        self.attention = ConvAttention(in_dim, in_feature, out_feature)
        self.feedforward = FeedForward(in_dim, out_dim)

    def forward(self, x):
        x = self.attention(x)
        out = self.feedforward(x)
        return out


if __name__ == '__main__':


    block = CLFT(64,64,32*32,32) # 输入通道数,输出通道数 图像大小 H*W,H or W


    input = torch.randn(3, 64, 32, 32) #输入tensor形状 B C H W

    # Print input shape
    print(input.size()) # 输入形状

    # Pass the input tensor through the model
    output = block(input)

    # Print output shape
    print(output.size()) # 输出形状

便捷下载方式

浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

更多分析可见原文


ai缝合大王
聚焦AI前沿,分享相关技术、论文,研究生自救指南
 最新文章