论文介绍
题目:ABC: Attention with Bilinear Correlation for Infrared Small Target Detection
论文地址: https://arxiv.org/pdf/2303.10321
年份:2023
创新点
卷积线性融合Transformer(CLFT)模块:该模块结合了Transformer结构的全局特征提取能力和卷积神经网络(CNN)的局部特征提取能力,有效增强了目标特征并抑制了噪声。CLFT模块通过引入双线性注意力模块(BAM)计算注意力矩阵,使得模型能够感知目标的位置,从而有效地处理小目标红外检测中的噪声干扰和目标丢失问题。
U形卷积-膨胀卷积(UCDC)模块:该模块位于网络的深层,通过卷积层和膨胀卷积层的结合,进一步精细处理已经下采样的特征图,提取更精细的语义信息。通过使用大感受野和跳跃连接,UCDC模块能够有效处理小分辨率的特征图,提升检测效果。
整体架构设计:论文提出的模型采用了编码器-解码器结构,类似于UNet,通过在编码器中引入CLFT模块,在解码器中使用UCDC模块,并通过跨层跳跃连接实现特征融合,从而提高了红外小目标检测任务的性能。
实验结果:在多个公共数据集上的实验结果表明,论文提出的方法在指标上优于现有的最先进方法,特别是在噪声抑制和目标增强方面有显著提升。
方法
整体结构
编码器:编码器部分由一个卷积模块和三个CLFT模块组成。卷积模块用于提取初步的局部特征,CLFT模块结合了卷积和Transformer的特性,进行局部和全局特征的融合,增强目标特征并抑制噪声。
解码器:解码器部分由一个UCDC模块和三个卷积模块组成。UCDC模块采用U形结构,由卷积层和膨胀卷积层组成,处理更深层次的特征图,提取细致的语义信息并进一步消除噪声。卷积模块则用于恢复特征图的空间分辨率。
跳跃连接:模型中使用了跨层跳跃连接(skip connections),在编码器和解码器之间传递特征图,实现跨层特征融合,确保低层次特征和高层次语义信息能够有效结合。
最终输出:在解码器之后,模型通过一个逐点卷积层作为分割头输出最终的目标检测结果。
实验结果
展示了不同最先进方法(SOTA)在多个数据集(NUAA、IRSTD1k、SIRSTAUG、NUDT)上的IoU(交并比)、nIoU(标准化交并比)和F1分数的实验结果,表明论文提出的ABC模型在所有数据集中表现出最好的性能。
展示了不同方法在NUAA数据集和IRSTD1k数据集上的ROC曲线,显示了ABC模型在不同数据集上显著优于其他方法的性能。
提供了在NUAA和IRSTD1k数据集上不同方法的部分图像的可视化结果,展示了ABC模型在检测小目标时更具鲁棒性,能够有效抑制噪声并避免漏检。
比较了不同方法在NUAA数据集上的计算成本(FLOPs)和每秒帧数(FPS),表明ABC模型在保持较高性能的同时,计算成本相对较低,尤其是ABC-S模型表现出较好的平衡。
CLFT模块作用
将卷积线性融合Transformer(CLFT)模块作为即插即用模块提取出来,可以适用于需要结合局部和全局特征提取的多种场景,尤其是以下几类任务中可以发挥重要作用:
目标检测和语义分割:CLFT模块结合了CNN的局部感知能力和Transformer的全局建模能力,能够增强目标特征并抑制背景噪声,因此特别适合复杂场景下的小目标检测和语义分割任务,如红外图像中的小目标检测、医学图像分析中的病灶检测等。
弱特征目标的检测:在需要检测弱特征或模糊目标的任务中,CLFT模块能够通过自注意力机制有效捕捉全局上下文信息,同时利用卷积增强局部细节,因此适用于夜视、低光照、遥感图像等场景中的目标检测。
高噪声环境中的特征提取:对于背景噪声较多的图像或数据,CLFT模块能够有效过滤掉噪声并增强目标特征。因此,可以应用于工业检测、自动驾驶中的障碍物检测等需要在高噪声环境下进行的任务。
多尺度特征提取:CLFT模块能够同时提取局部和全局特征,使其在需要处理不同尺度的特征时表现出色,例如在自然场景下的人物或物体识别任务中,它可以同时处理细节和全局信息。
即插即用模块
import torch
import torch.nn as nn
from einops import rearrange
#论文:ABC: Attention with Bilinear Correlation for Infrared Small Target Detection ICME2023
#论文地址:https://arxiv.org/pdf/2303.10321
def conv_relu_bn(in_channel, out_channel, dirate):
return nn.Sequential(
nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=dirate,
dilation=dirate),
nn.BatchNorm2d(out_channel),
nn.ReLU(inplace=True)
)
#bilinear attention module (BAM)
class BAM(nn.Module):
def __init__(self, in_dim, in_feature, out_feature):
super(BAM, self).__init__()
self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=1, kernel_size=1)
self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=1, kernel_size=1)
self.query_line = nn.Linear(in_features=in_feature, out_features=out_feature)
self.key_line = nn.Linear(in_features=in_feature, out_features=out_feature)
self.s_conv = nn.Conv2d(in_channels=1, out_channels=in_dim, kernel_size=1)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x):
q = rearrange(self.query_line(rearrange(self.query_conv(x), 'b 1 h w -> b (h w)')), 'b h -> b h 1')
k = rearrange(self.key_line(rearrange(self.key_conv(x), 'b 1 h w -> b (h w)')), 'b h -> b 1 h')
att = rearrange(torch.matmul(q, k), 'b h w -> b 1 h w')
att = self.softmax(self.s_conv(att))
return att
class Conv(nn.Module):
def __init__(self, in_dim):
super(Conv, self).__init__()
self.convs = nn.ModuleList([conv_relu_bn(in_dim, in_dim, 1) for _ in range(3)])
def forward(self, x):
for conv in self.convs:
x = conv(x)
return x
#dilated convolution layers(DConv)
class DConv(nn.Module):
def __init__(self, in_dim):
super(DConv, self).__init__()
dilation = [2, 4, 2]
self.dconvs = nn.ModuleList([conv_relu_bn(in_dim, in_dim, dirate) for dirate in dilation])
def forward(self, x):
for dconv in self.dconvs:
x = dconv(x)
return x
class ConvAttention(nn.Module):
def __init__(self, in_dim, in_feature, out_feature):
super(ConvAttention, self).__init__()
self.conv = Conv(in_dim)
self.dconv = DConv(in_dim)
self.att = BAM(in_dim, in_feature, out_feature)
self.gamma = nn.Parameter(torch.zeros(1))
def forward(self, x):
q = self.conv(x)
k = self.dconv(x)
v = q + k
att = self.att(x)
out = torch.matmul(att, v)
return self.gamma * out + v + x
class FeedForward(nn.Module):
def __init__(self, in_dim, out_dim):
super(FeedForward, self).__init__()
self.conv = conv_relu_bn(in_dim, out_dim, 1)
# self.x_conv = nn.Conv2d(in_dim, out_dim, kernel_size=1)
self.x_conv = nn.Sequential(
nn.Conv2d(in_dim, out_dim, kernel_size=1),
nn.BatchNorm2d(out_dim),
nn.ReLU(inplace=True)
)
def forward(self, x):
out = self.conv(x)
x = self.x_conv(x)
return x + out
#convolution linear fusion transformer (CLFT)
class CLFT(nn.Module):
def __init__(self, in_dim, out_dim, in_feature, out_feature):
super(CLFT, self).__init__()
self.attention = ConvAttention(in_dim, in_feature, out_feature)
self.feedforward = FeedForward(in_dim, out_dim)
def forward(self, x):
x = self.attention(x)
out = self.feedforward(x)
return out
if __name__ == '__main__':
block = CLFT(64,64,32*32,32) # 输入通道数,输出通道数 图像大小 H*W,H or W
input = torch.randn(3, 64, 32, 32) #输入tensor形状 B C H W
# Print input shape
print(input.size()) # 输入形状
# Pass the input tensor through the model
output = block(input)
# Print output shape
print(output.size()) # 输出形状
便捷下载方式
浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules
更多分析可见原文