论文介绍
题目:STNet: Spatial and Temporal feature fusion network for change detection in remote sensing images
论文地址:https://arxiv.org/pdf/2304.11422
QQ深度学习交流群:994264161
扫描下方二维码,加入深度学习论文指南星球!
加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务
创新点
跨时间特征融合模块(TFF):
提出了一种基于跨时间门控机制的特征融合模块,用于双时相特征的融合。
通过选择性地增强目标变化信息并抑制非目标变化,提升了变化检测的准确性。
跨尺度特征融合模块(SFF):
首次采用跨尺度注意力机制,利用高层次特征引导低层次特征的建模。
该机制能够捕捉变化目标的细粒度空间信息,恢复变化表示的空间细节。
高效的多尺度特征交互设计:
提出了一个轻量化的深度神经网络框架,在多个尺度上进行特征交互,综合了语义信息与空间细节。
使用了ResNet-18作为特征提取的骨干网络,结合TFF和SFF模块,有效减少了参数量和计算成本。
性能上的显著提升:
在三个遥感变化检测的基准数据集上(WHU、LEVIR-CD 和 CLCD),STNet在F1分数、IoU、整体准确率等指标上取得了领先的表现。
在空间复杂度和计算开销上,STNet的参数量(14.6M)和计算量(9.61G FLOPs)显著低于许多现有方法。
轻量化设计:
在TFF模块中使用深度可分离卷积(depth-wise separable convolution),在减少计算量的同时保持了模型的性能。
创新的损失函数:
采用了结合Focal Loss和Dice Loss的混合损失函数,解决了变化检测中正负样本不平衡的问题。
方法
整体架构
STNet 的整体结构包括以下几个部分:输入双时相遥感影像后,使用共享权重的 ResNet-18 提取多尺度特征;通过时间特征融合模块(TFF)采用跨时间门控机制,增强目标变化信息并抑制非目标变化;利用空间特征融合模块(SFF)通过跨尺度注意力机制,结合高层次语义信息与低层次空间细节,恢复变化表示的空间细节;最后通过轻量化解码器将多尺度特征拼接,并结合通道注意力模块(CAM)生成与输入影像尺寸一致的高精度变化检测图。
1. 输入与特征提取
输入:双时相的遥感图像
和T 1 T_1 ,它们是空间配准后的影像对。T 2 T_2 特征提取:使用共享权重的 ResNet-18 作为骨干网络,逐层提取多尺度双时相特征。
ResNet-18 提供了 4 个残差块的输出特征,分别为不同尺度的多层特征表示。
2. 时间特征融合模块(TFF)
目标:通过跨时间门控机制融合双时相特征,强调目标变化并抑制非目标变化。
工作流程:
对双时相特征
和R 1 R_1 进行逐元素相减,得到初步的粗粒度变化表示R 2 R_2 。R c R_c 将
与R c R_c 、R 1 R_1 分别进行拼接,并通过深度可分离卷积提取特征,生成权重R 2 R_2 和W 1 W_1 。W 2 W_2 使用门控机制,通过权重调整融合
和R 1 R_1 ,生成时间特征融合的结果R 2 R_2 。R t R_t
3. 空间特征融合模块(SFF)
目标:通过跨尺度注意力机制融合多尺度特征,恢复变化表示的空间细节。
工作流程:
将高层次特征(语义信息丰富但边界不精确)与低层次特征(包含更多空间细节)进行交互。
使用注意力机制计算像素间的关系,使高层次特征指导低层次特征的细化。
融合后的特征包含更高质量的语义信息和空间细节。
4. 解码器与变化检测图生成
轻量化解码器:
将各尺度的变化表示上采样到统一尺寸,并沿通道方向拼接。
使用通道注意力模块(Channel Attention Module, CAM)进一步增强特征。
输出:通过最终的上采样操作,生成与输入影像尺寸一致的变化检测图。
5. 损失函数
混合损失函数:结合 Focal Loss 和 Dice Loss,解决正负样本不平衡问题,增强对变化区域的敏感性。
即插即用模块作用
TFF模块的作用:时间特征增强
突出变化区域:通过跨时间门控机制,选择性增强变化的目标区域,同时抑制非目标区域的干扰。
轻量化设计:使用深度可分离卷积,减少参数量和计算成本,使其适合嵌入到多种深度学习框架中。
时间信息提炼:适用于需要结合时间维度进行变化分析的任务。
SFF模块的作用:空间特征细化
细粒度空间细节提取:通过跨尺度注意力机制,融合高层次语义信息和低层次空间细节,增强目标区域的表示能力。
边界恢复与语义细化:适合于对目标边界和结构信息要求较高的场景。
跨尺度交互:提高模型对多尺度目标的检测能力。
消融实验结果
内容:
该表通过对比基础模型(Base)与添加 TFF 模块、SFF 模块以及完整模型(STNet)的性能差异,验证了 TFF 和 SFF 模块的有效性。
性能指标包括 F1 分数(F1)、精确率(Pre.)、召回率(Rec.)、交并比(IoU)和总体准确率(OA)。
结论:
单独加入 TFF 模块(Base + TFF)或 SFF 模块(Base + SFF)都显著提升了模型性能。
同时引入 TFF 和 SFF 模块(完整模型 STNet)取得了最佳效果,表明这两个模块的联合使用具有协同增益。
内容:
该表比较了 STNet 和其他方法在参数数量(Params, M)及计算量(FLOPs, G)方面的差异。
结论:
STNet 以相对较少的参数量(14.6M)和最低的计算成本(9.61G FLOPs),实现了在性能上超越现有方法的效果,证明其轻量化设计的有效性。
即插即用模块
import torch
import torch.nn as nn
import torch.nn.functional as F
def dsconv_3x3(in_channel, out_channel):
return nn.Sequential(
nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1, padding=1, groups=in_channel),
nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0, groups=1),
nn.BatchNorm2d(out_channel),
nn.ReLU(inplace=True)
)
def conv_1x1(in_channel, out_channel):
return nn.Sequential(
nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(out_channel),
nn.ReLU(inplace=True)
)
class TFF(nn.Module):
def __init__(self, in_channel, out_channel):
super(TFF, self).__init__()
self.catconvA = dsconv_3x3(in_channel * 2, in_channel)
self.catconvB = dsconv_3x3(in_channel * 2, in_channel)
self.catconv = dsconv_3x3(in_channel * 2, out_channel)
self.convA = nn.Conv2d(in_channel, 1, 1)
self.convB = nn.Conv2d(in_channel, 1, 1)
self.sigmoid = nn.Sigmoid()
def forward(self, xA, xB):
x_diff = xA - xB
x_diffA = self.catconvA(torch.cat([x_diff, xA], dim=1))
x_diffB = self.catconvB(torch.cat([x_diff, xB], dim=1))
A_weight = self.sigmoid(self.convA(x_diffA))
B_weight = self.sigmoid(self.convB(x_diffB))
xA = A_weight * xA
xB = B_weight * xB
x = self.catconv(torch.cat([xA, xB], dim=1))
return x
if __name__ == '__main__':
in_channel = 3
out_channel = 3
block = TFF(in_channel, out_channel)
# Create dummy inputs
xA = torch.rand(1, in_channel, 32, 32)
xB = torch.rand(1, in_channel, 32, 32)
# Forward pass
output = block(xA, xB)
print(f"Input A size: {xA.size()}")
print(f"Input B size: {xB.size()}")
print(f"Output size: {output.size()}")
class SelfAttentionBlock(nn.Module):
"""
query_feats: (B, C, h, w)
key_feats: (B, C, h, w)
value_feats: (B, C, h, w)
output: (B, C, h, w)
"""
def __init__(self, key_in_channels, query_in_channels, transform_channels, out_channels,
key_query_num_convs, value_out_num_convs):
super(SelfAttentionBlock, self).__init__()
self.key_project = self.buildproject(
in_channels=key_in_channels,
out_channels=transform_channels,
num_convs=key_query_num_convs,
)
self.query_project = self.buildproject(
in_channels=query_in_channels,
out_channels=transform_channels,
num_convs=key_query_num_convs
)
self.value_project = self.buildproject(
in_channels=key_in_channels,
out_channels=transform_channels,
num_convs=value_out_num_convs
)
self.out_project = self.buildproject(
in_channels=transform_channels,
out_channels=out_channels,
num_convs=value_out_num_convs
)
self.transform_channels = transform_channels
def forward(self, query_feats, key_feats, value_feats):
batch_size = query_feats.size(0)
query = self.query_project(query_feats)
query = query.reshape(*query.shape[:2], -1)
query = query.permute(0, 2, 1).contiguous() # (B, h*w, C)
key = self.key_project(key_feats)
key = key.reshape(*key.shape[:2], -1) # (B, C, h*w)
value = self.value_project(value_feats)
value = value.reshape(*value.shape[:2], -1)
value = value.permute(0, 2, 1).contiguous() # (B, h*w, C)
sim_map = torch.matmul(query, key)
sim_map = (self.transform_channels ** -0.5) * sim_map
sim_map = F.softmax(sim_map, dim=-1) # (B, h*w, K)
context = torch.matmul(sim_map, value) # (B, h*w, C)
context = context.permute(0, 2, 1).contiguous()
context = context.reshape(batch_size, -1, *query_feats.shape[2:]) # (B, C, h, w)
context = self.out_project(context) # (B, C, h, w)
return context
def buildproject(self, in_channels, out_channels, num_convs):
convs = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
for _ in range(num_convs - 1):
convs.append(
nn.Sequential(
nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
)
if len(convs) > 1:
return nn.Sequential(*convs)
return convs[0]
def conv_3x3(in_channel, out_channel):
return nn.Sequential(
nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(out_channel),
nn.ReLU(inplace=True)
)
class SFF(nn.Module):
def __init__(self, in_channel):
super(SFF, self).__init__()
self.conv_small = conv_1x1(in_channel, in_channel)
self.conv_big = conv_1x1(in_channel, in_channel)
self.catconv = conv_3x3(in_channel *2, in_channel)
self.attention = SelfAttentionBlock(
key_in_channels=in_channel,
query_in_channels = in_channel,
transform_channels = in_channel // 2,
out_channels = in_channel,
key_query_num_convs=2,
value_out_num_convs=1
)
def forward(self, x_small, x_big):
img_size =x_big.size(2), x_big.size(3)
x_small = F.interpolate(x_small, img_size, mode="bilinear", align_corners=False)
x = self.conv_small(x_small) + self.conv_big(x_big)
new_x = self.attention(x, x, x_big)
out = self.catconv(torch.cat([new_x, x_big], dim=1))
return out
if __name__ == '__main__':
block = SFF(3)
x_small = torch.rand(1, 3, 32, 32)
x_big = torch.rand(1, 3, 32, 32)
output = block(x_small, x_big)
print(x_small.size())
print(x_big.size()) print(output.size())
便捷下载方式
浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules
更多分析可见原文