Paper Introduction
Title: Robust change detection for remote sensing images based on temporospatial interactive attention module
Paper link: https://www.sciencedirect.com/science/article/pii/S1569843224001213
Innovations
A new change detection framework, CDNeXt:
It consists of four parts: an Encoder, an Interactor, a Decoder, and a Detector.
A Temporospatial Interactive Attention Module (TIAM) is specifically designed to address the problems caused by geometric perspective rotation and temporal style differences in remote sensing images.
Introduction of the Temporospatial Interactive Attention Module (TIAM):
It queries and reconstructs the dependencies and style correlations of temporospatial features, mitigating interference between features from different times and spatial positions.
It supports modular embedding, so it can be integrated into other vision tasks.
Significant performance gains:
State-of-the-art (SOTA) results on several benchmark datasets (SYSU-CD, LEVIR-CD+, S2Looking, BANDON), for example an F1 score of 82.63% and an IoU of 70.39% on SYSU-CD.
Efficient and general module design:
No special training techniques or parameter tuning are required; change detection results are produced end to end. TIAM markedly reduces pseudo-change detections caused by perspective and style differences, and it can also boost the performance of mainstream methods.
Method
Overall Architecture
CDNeXt is a systematic change detection framework for remote sensing images, composed of four modules: an Encoder, an Interactor, a Decoder, and a Detector. The encoder extracts multi-scale features; the interactor introduces the Temporospatial Interactive Attention Module (TIAM) to model spatial perspective correlations and temporal style differences, strengthening the representation of change features; the decoder upsamples and fuses features layer by layer while a residual squeeze module refines the feature representation; the detector integrates the hierarchical features and outputs a high-precision binary change mask. This design effectively suppresses pseudo-changes caused by perspective rotation and temporal style differences, significantly improving detection accuracy and robustness.
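To make the data flow concrete, below is a minimal sketch of the four-stage pipeline. The class and component names (CDNeXtSketch and the encoder/interactors/decoder/detector placeholders) are illustrative assumptions, not the authors' exact implementation; concatenating the TIAM output pair is one plausible fusion choice.

import torch
import torch.nn as nn

class CDNeXtSketch(nn.Module):
    """Illustrative Encoder -> Interactor -> Decoder -> Detector flow (assumption)."""
    def __init__(self, encoder, interactors, decoder, detector):
        super().__init__()
        self.encoder = encoder                         # shared backbone, returns [F1..F4]
        self.interactors = nn.ModuleList(interactors)  # one TIAM per pyramid level
        self.decoder = decoder                         # upsampling + feature squeeze residual
        self.detector = detector                       # fuses features into a change mask

    def forward(self, t_a, t_b):
        feats_a = self.encoder(t_a)   # multi-scale features for temporal phase A
        feats_b = self.encoder(t_b)   # multi-scale features for temporal phase B
        # Each TIAM returns an enhanced feature pair; concatenate the pair for fusion.
        fused = [torch.cat(tiam(fa, fb), dim=1)
                 for tiam, fa, fb in zip(self.interactors, feats_a, feats_b)]
        return self.detector(self.decoder(fused))      # binary change mask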
1. Encoder
Function: extracts multi-level features from the bi-temporal images (e.g., T_A and T_B).
Implementation:
A pretrained backbone (e.g., ConvNeXt) extracts hierarchical pyramid features.
Features are extracted layer by layer through downsampling, forming multi-scale representations.
The encoder outputs the features {F_1, F_2, F_3, F_4} for each temporal phase, providing the base input for subsequent modules (see the sketch below).
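As a minimal sketch of this step, the hierarchical pyramid features {F_1, F_2, F_3, F_4} can be obtained from a pretrained ConvNeXt via the timm library's features_only mode; the specific model name convnext_tiny is an assumption, since the paper only states a ConvNeXt backbone.

import timm
import torch

# Assumption: timm's features_only mode returns one feature map per stage
# (strides 4/8/16/32 for ConvNeXt). Set pretrained=False for an offline shape check.
backbone = timm.create_model('convnext_tiny', pretrained=True, features_only=True)
t_a = torch.randn(1, 3, 256, 256)
feats_a = backbone(t_a)
for i, f in enumerate(feats_a, 1):
    print(f'F{i}:', tuple(f.shape))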
2. Interactor
Core module: the Temporospatial Interactive Attention Module (TIAM).
Function:
Queries and reconstructs the bi-temporal features at the same level (e.g., F_A and F_B).
Extracts and models temporospatial dependencies (e.g., spatial perspective correlations) and temporal style differences (e.g., illumination, shadows, weather changes).
Characteristics:
TIAM computes attention scores with an embedded Gaussian function and reconstructs features through weighted matrix operations (see the formula below).
It reduces pseudo-change detections (e.g., false positives caused by perspective rotation).
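For reference, the embedded Gaussian form of non-local attention that TIAM builds on can be written as

\[
y_i = \frac{\sum_{j} \exp\!\big(\theta(x_i)^{\top}\phi(x_j)\big)\, g(x_j)}{\sum_{j} \exp\!\big(\theta(x_i)^{\top}\phi(x_j)\big)}
\]

where \theta, \phi, and g are 1x1 convolutions, i indexes positions in one temporal phase, and j indexes positions in the other. This is the standard non-local formulation; its cross-temporal application here matches the softmax-normalized matrix products in the code later in this post.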
3. Decoder
Function:
Receives the interactor's output features and upsamples them layer by layer.
A Feature Squeeze Residual (FSR) module reduces feature dimensionality with residual enhancement.
Characteristics:
Skip connections between the interactor and the decoder improve gradient propagation.
Low-level texture features and high-level semantic features are fused efficiently (a sketch of the squeeze-residual pattern follows this list).
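The paper names this module Feature Squeeze Residual (FSR); the sketch below is our assumption of the general squeeze-then-residual pattern (1x1 channel reduction followed by residual refinement), not the authors' exact design.

import torch.nn as nn

class FSRSketch(nn.Module):
    """Hypothetical feature-squeeze residual block (assumption)."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Squeeze: 1x1 conv reduces the channel dimension.
        self.squeeze = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_ch))
        # Refine: 3x3 convs operating on the squeezed features.
        self.refine = nn.Sequential(
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch))
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        s = self.squeeze(x)
        return self.act(s + self.refine(s))  # residual enhancement over squeezed features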
4. Detector
Function:
Integrates the hierarchical features produced by the decoder and generates the final binary change mask.
Characteristics:
The detector fuses the multi-level features through convolutional layers and uses a Softmax layer to output pixel-wise change detection results (see the sketch below).
It preserves both object boundary details and interior completeness.
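A minimal sketch of such a detector head follows; the projection-then-concatenate fusion and the two-class softmax output are assumptions consistent with the description above, not the paper's exact head.

import torch
import torch.nn as nn
import torch.nn.functional as F

class DetectorSketch(nn.Module):
    """Hypothetical detector head: fuse hierarchical features, predict change per pixel."""
    def __init__(self, in_chs, mid_ch=64):
        super().__init__()
        # Project every decoder level to a common channel width.
        self.proj = nn.ModuleList([nn.Conv2d(c, mid_ch, kernel_size=1) for c in in_chs])
        self.classify = nn.Conv2d(mid_ch * len(in_chs), 2, kernel_size=1)  # change / no-change

    def forward(self, feats, out_size):
        ups = [F.interpolate(p(f), size=out_size, mode='bilinear', align_corners=False)
               for p, f in zip(self.proj, feats)]
        logits = self.classify(torch.cat(ups, dim=1))
        return logits.softmax(dim=1)  # pixel-wise change probabilities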
Role of the Plug-and-Play Module
As a plug-and-play module, TIAM:
Mitigates pseudo-change interference: by building spatial perspective dependency matrices, it alleviates pseudo-change detections caused by geometric viewpoint changes (e.g., different shooting angles); by building temporal style correlation matrices, it reduces false positives caused by visual style differences from illumination, weather, seasons, and similar factors.
Improves feature interaction efficiency: built around global interactive attention, it efficiently extracts semantic invariance and change features from bi-temporal images and strengthens the mutual information of temporospatial features.
Enhances robustness and generality: TIAM can be embedded into other models; in the paper it is plugged into IFNet and FC-Conc and significantly improves their change detection accuracy, demonstrating its adaptability and generality (see the embedding example below).
Keeps computational complexity low: TIAM operates on quarter-scale feature maps, so its computational cost is low (see the complexity analysis in Table 6 of the paper) while retaining the benefits of global feature interaction.
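As an embedding example, the sketch below drops the SpatiotemporalAttentionFull class from the code section into an arbitrary Siamese change-detection network at the quarter-scale feature level; the encoder/head stages and the feature-differencing fusion are assumptions standing in for whatever host model is used.

import torch

def forward_with_tiam(encoder, tiam, head, img_a, img_b):
    """encoder/head are hypothetical stages of a host Siamese CD model;
    tiam is a SpatiotemporalAttentionFull instance from the code below."""
    f_a = encoder(img_a)                # quarter-scale features, (B, C, H/4, W/4)
    f_b = encoder(img_b)
    f_a, f_b = tiam(f_a, f_b)           # temporospatial interactive attention
    return head(torch.abs(f_a - f_b))   # e.g., feature differencing + classifier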
Ablation Study Results
Ablation over attention mechanisms: experiments on the SYSU-CD and S2Looking datasets compare TIAM with other mainstream attention mechanisms (e.g., SAM, DAM, CBAM). TIAM outperforms the alternatives in both F1 score and IoU, confirming the importance of temporospatial feature interaction, especially in complex change scenarios.
Ablation over framework components: this table verifies the contribution of the individual CDNeXt components, including the backbone (ResNet18 vs. ConvNeXt), the feature squeeze (FS) module, and TIAM. ConvNeXt improves feature extraction, while adding FS and TIAM significantly raises F1 and IoU; TIAM in particular is critical for capturing change features.
Complexity analysis: this table compares parameter counts (Params), computational cost (FLOPs), and detection performance (F1). CDNeXt achieves the best performance while keeping computational complexity low. Moreover, embedding TIAM into other methods (e.g., FC-Conc and IFNet) significantly improves their accuracy, confirming its generality and efficiency.
Plug-and-Play Module
import torch
import torch.nn as nn
import torch.nn.functional as F
# Paper: Robust change detection for remote sensing images based on temporospatial interactive attention module
# Paper link: https://www.sciencedirect.com/science/article/pii/S1569843224001213
class SpatiotemporalAttentionFull(nn.Module):
def __init__(self, in_channels, inter_channels=None, dimension=2, sub_sample=False):
super(SpatiotemporalAttentionFull, self).__init__()
assert dimension in [2, ]
self.dimension = dimension
self.sub_sample = sub_sample
self.in_channels = in_channels
self.inter_channels = inter_channels
if self.inter_channels is None:
self.inter_channels = in_channels // 2
if self.inter_channels == 0:
self.inter_channels = 1
self.g = nn.Sequential(
nn.BatchNorm2d(self.in_channels),
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0)
)
self.W = nn.Sequential(
nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(self.in_channels)
)
self.theta = nn.Sequential(
nn.BatchNorm2d(self.in_channels),
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0),
)
self.phi = nn.Sequential(
nn.BatchNorm2d(self.in_channels),
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0),
)
self.energy_time_1_sf = nn.Softmax(dim=-1)
self.energy_time_2_sf = nn.Softmax(dim=-1)
self.energy_space_2s_sf = nn.Softmax(dim=-2)
self.energy_space_1s_sf = nn.Softmax(dim=-2)
def forward(self, x1, x2):
batch_size = x1.size(0)
        g_x11 = self.g(x1).reshape(batch_size, self.inter_channels, -1)         # (B, C', H1W1)
        g_x21 = self.g(x2).reshape(batch_size, self.inter_channels, -1)         # (B, C', H2W2)
        theta_x1 = self.theta(x1).reshape(batch_size, self.inter_channels, -1)  # (B, C', H1W1)
        theta_x2 = theta_x1.permute(0, 2, 1)                                    # (B, H1W1, C')
        phi_x1 = self.phi(x2).reshape(batch_size, self.inter_channels, -1)      # (B, C', H2W2)
        phi_x2 = phi_x1.permute(0, 2, 1)                                        # (B, H2W2, C')
        energy_time_1 = torch.matmul(theta_x1, phi_x2)   # (B, C', C') cross-temporal channel affinity
        energy_time_2 = energy_time_1.permute(0, 2, 1)
        energy_space_1 = torch.matmul(theta_x2, phi_x1)  # (B, H1W1, H2W2) cross-temporal spatial affinity
        energy_space_2 = energy_space_1.permute(0, 2, 1)
        energy_time_1s = self.energy_time_1_sf(energy_time_1)      # softmax over dim -1
        energy_time_2s = self.energy_time_2_sf(energy_time_2)
        energy_space_2s = self.energy_space_2s_sf(energy_space_1)  # softmax over dim -2
        energy_space_1s = self.energy_space_1s_sf(energy_space_2)
        # y1: (B, C', C') x (B, C', H1W1) x (B, H1W1, H2W2) -> (B, C', H2W2); phase-1 values
        # reweighted by temporal (channel) and spatial attention onto phase-2's grid.
        y1 = torch.matmul(torch.matmul(energy_time_2s, g_x11), energy_space_2s).contiguous()
        # y2: (B, C', C') x (B, C', H2W2) x (B, H2W2, H1W1) -> (B, C', H1W1); phase-2 values
        # reweighted onto phase-1's grid.
        y2 = torch.matmul(torch.matmul(energy_time_1s, g_x21), energy_space_1s).contiguous()
y1 = y1.reshape(batch_size, self.inter_channels, *x2.size()[2:])
y2 = y2.reshape(batch_size, self.inter_channels, *x1.size()[2:])
return x1 + self.W(y1), x2 + self.W(y2)
class SpatiotemporalAttentionBase(nn.Module):
def __init__(self, in_channels, inter_channels=None, dimension=2, sub_sample=False):
super(SpatiotemporalAttentionBase, self).__init__()
assert dimension in [2, ]
self.dimension = dimension
self.sub_sample = sub_sample
self.in_channels = in_channels
self.inter_channels = inter_channels
if self.inter_channels is None:
self.inter_channels = in_channels // 2
if self.inter_channels == 0:
self.inter_channels = 1
self.g = nn.Sequential(
nn.BatchNorm2d(self.in_channels),
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0)
)
self.W = nn.Sequential(
nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(self.in_channels)
)
self.theta = nn.Sequential(
nn.BatchNorm2d(self.in_channels),
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0),
)
self.phi = nn.Sequential(
nn.BatchNorm2d(self.in_channels),
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0),
)
self.energy_space_2s_sf = nn.Softmax(dim=-2)
self.energy_space_1s_sf = nn.Softmax(dim=-2)
def forward(self, x1, x2):
"""
:param x: (b, c, h, w)
:param return_nl_map: if True return z, nl_map, else only return z.
:return:
"""
batch_size = x1.size(0)
g_x11 = self.g(x1).reshape(batch_size, self.inter_channels, -1)
g_x21 = self.g(x2).reshape(batch_size, self.inter_channels, -1)
theta_x1 = self.theta(x1).reshape(batch_size, self.inter_channels, -1)
theta_x2 = theta_x1.permute(0, 2, 1)
phi_x1 = self.phi(x2).reshape(batch_size, self.inter_channels, -1)
energy_space_1 = torch.matmul(theta_x2, phi_x1)
energy_space_2 = energy_space_1.permute(0, 2, 1)
        energy_space_2s = self.energy_space_2s_sf(energy_space_1)  # softmax over H1W1: (B, H1W1, H2W2)
        energy_space_1s = self.energy_space_1s_sf(energy_space_2)  # softmax over H2W2: (B, H2W2, H1W1)
        # y1: (B, C', H1W1) x (B, H1W1, H2W2) -> (B, C', H2W2); phase-1 values on phase-2's grid.
        y1 = torch.matmul(g_x11, energy_space_2s).contiguous()
        # y2: (B, C', H2W2) x (B, H2W2, H1W1) -> (B, C', H1W1); phase-2 values on phase-1's grid.
        y2 = torch.matmul(g_x21, energy_space_1s).contiguous()
y1 = y1.reshape(batch_size, self.inter_channels, *x2.size()[2:])
y2 = y2.reshape(batch_size, self.inter_channels, *x1.size()[2:])
return x1 + self.W(y1), x2 + self.W(y2)
class SpatiotemporalAttentionFullNotWeightShared(nn.Module):
def __init__(self, in_channels, inter_channels=None, dimension=2, sub_sample=False):
super(SpatiotemporalAttentionFullNotWeightShared, self).__init__()
assert dimension in [2, ]
self.dimension = dimension
self.sub_sample = sub_sample
self.in_channels = in_channels
self.inter_channels = inter_channels
if self.inter_channels is None:
self.inter_channels = in_channels // 2
if self.inter_channels == 0:
self.inter_channels = 1
self.g1 = nn.Sequential(
nn.BatchNorm2d(self.in_channels),
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0)
)
self.g2 = nn.Sequential(
nn.BatchNorm2d(self.in_channels),
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0),
)
self.W1 = nn.Sequential(
nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(self.in_channels)
)
self.W2 = nn.Sequential(
nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(self.in_channels)
)
self.theta = nn.Sequential(
nn.BatchNorm2d(self.in_channels),
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0),
)
self.phi = nn.Sequential(
nn.BatchNorm2d(self.in_channels),
nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
kernel_size=1, stride=1, padding=0),
)
def forward(self, x1, x2):
"""
:param x: (b, c, h, w)
:param return_nl_map: if True return z, nl_map, else only return z.
:return:
"""
batch_size = x1.size(0)
        g_x11 = self.g1(x1).reshape(batch_size, self.inter_channels, -1)        # (B, C', H1W1)
        g_x21 = self.g2(x2).reshape(batch_size, self.inter_channels, -1)        # (B, C', H2W2)
        theta_x1 = self.theta(x1).reshape(batch_size, self.inter_channels, -1)  # (B, C', H1W1)
        theta_x2 = theta_x1.permute(0, 2, 1)                                    # (B, H1W1, C')
        phi_x1 = self.phi(x2).reshape(batch_size, self.inter_channels, -1)      # (B, C', H2W2)
        phi_x2 = phi_x1.permute(0, 2, 1)                                        # (B, H2W2, C')
energy_time_1 = torch.matmul(theta_x1, phi_x2)
energy_time_2 = energy_time_1.permute(0, 2, 1)
energy_space_1 = torch.matmul(theta_x2, phi_x1)
energy_space_2 = energy_space_1.permute(0, 2, 1)
energy_time_1s = F.softmax(energy_time_1, dim=-1)
energy_time_2s = F.softmax(energy_time_2, dim=-1)
energy_space_2s = F.softmax(energy_space_1, dim=-2)
energy_space_1s = F.softmax(energy_space_2, dim=-2)
        # y1: (B, C', C') x (B, C', H1W1) x (B, H1W1, H2W2) -> (B, C', H2W2)
        y1 = torch.matmul(torch.matmul(energy_time_2s, g_x11), energy_space_2s).contiguous()
        # y2: (B, C', C') x (B, C', H2W2) x (B, H2W2, H1W1) -> (B, C', H1W1)
        y2 = torch.matmul(torch.matmul(energy_time_1s, g_x21), energy_space_1s).contiguous()
y1 = y1.reshape(batch_size, self.inter_channels, *x2.size()[2:])
y2 = y2.reshape(batch_size, self.inter_channels, *x1.size()[2:])
return x1 + self.W1(y1), x2 + self.W2(y2)
if __name__ == '__main__':
input1 = torch.randn(1, 64, 32, 32) #B C H W
input2 = torch.randn(1, 64, 32, 32) #B C H W
sp_full = SpatiotemporalAttentionFull(in_channels=64)
output_full_x1, output_full_x2 = sp_full(input1, input2)
print(input1.shape, input2.shape)
print(output_full_x1.shape, output_full_x2.shape)
sp_base = SpatiotemporalAttentionBase(in_channels=64)
output_base_x1, output_base_x2 = sp_base(input1, input2)
print(input1.shape, input2.shape)
print(output_base_x1.shape, output_base_x2.shape)
sp_full_not_shared = SpatiotemporalAttentionFullNotWeightShared(in_channels=64)
output_full_not_shared_x1, output_full_not_shared_x2 = sp_full_not_shared(input1, input2)
print(input1.shape, input2.shape)
print(output_full_not_shared_x1.shape, output_full_not_shared_x2.shape)
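A brief note on the three variants demonstrated above, as implemented in the code: SpatiotemporalAttentionFull applies both the temporal (channel-affinity) attention and the spatial attention with projections shared across the two phases; SpatiotemporalAttentionBase keeps only the spatial attention; SpatiotemporalAttentionFullNotWeightShared uses separate g and W projections for each temporal phase. All three return the two inputs enhanced by a residual attention term, so output shapes match the inputs and the modules can be dropped into a Siamese network unchanged.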
Convenient Download
Open this URL in a browser: https://github.com/ai-dawang/PlugNPlay-Modules
More analysis can be found in the original paper.