TIAM: a 2024 plug-and-play Temporospatial Interactive Attention Module that delivers serious accuracy gains!

Digest · 2024-12-07 17:20 · Hong Kong, China

Paper Introduction

Title: Robust change detection for remote sensing images based on temporospatial interactive attention module

Paper link: https://www.sciencedirect.com/science/article/pii/S1569843224001213


Key Innovations

  • Proposes a new change detection framework, CDNeXt

    • It consists of four parts: an Encoder, an Interactor, a Decoder, and a Detector.

    • A Temporospatial Interactive Attention Module (TIAM) is designed specifically to address the problems caused by geometric viewpoint rotation and temporal style differences in remote sensing images.

  • Introduces the Temporospatial Interactive Attention Module (TIAM)

    • It queries and reconstructs the dependencies and style correlations of temporospatial features, mitigating interference between features from different times and locations.

    • It supports modular embedding, so it can be integrated into other vision tasks.

  • Significant performance gains

    • Achieves state-of-the-art (SOTA) results on multiple benchmark datasets (SYSU-CD, LEVIR-CD+, S2Looking, BANDON), e.g., an F1 score of 82.63% and an IoU of 70.39% on SYSU-CD.

  • Efficient, general module design

    • Requires no special training techniques or parameter tuning, and produces change detection results end to end. TIAM markedly reduces pseudo-changes caused by viewpoint and style differences, and can also lift the performance of mainstream methods.

Method

Overall Architecture

CDNeXt is a systematic change detection framework for remote sensing images, consisting of four modules: an Encoder, an Interactor, a Decoder, and a Detector. The encoder extracts multi-scale features; the interactor introduces the Temporospatial Interactive Attention Module (TIAM) to model spatial perspective correlations and temporal style differences, strengthening the representation of change features; the decoder upsamples and fuses features level by level, with a Feature Squeeze Residual module refining the representation; and the detector integrates the hierarchical features to output a high-precision binary change mask. This design effectively suppresses pseudo-change detections caused by viewpoint rotation and temporal style differences, markedly improving accuracy and robustness. A structural sketch follows.
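To make the data flow concrete, here is a minimal structural sketch of the pipeline. All component names and interfaces below are illustrative assumptions based on the description above, not the authors' released code.

import torch.nn as nn

class CDNeXtSketch(nn.Module):
    # Hypothetical wiring of the four stages described above.
    def __init__(self, encoder, interactors, decoder, detector):
        super().__init__()
        self.encoder = encoder                         # shared ConvNeXt-style backbone
        self.interactors = nn.ModuleList(interactors)  # one TIAM per pyramid level
        self.decoder = decoder                         # upsampling + FSR fusion
        self.detector = detector                       # hierarchical fusion -> change mask

    def forward(self, t_a, t_b):
        feats_a = self.encoder(t_a)  # [F1, F2, F3, F4] for image T_A
        feats_b = self.encoder(t_b)  # [F1, F2, F3, F4] for image T_B
        # Each TIAM queries and reconstructs the same-level feature pair.
        pairs = [tiam(fa, fb) for tiam, fa, fb in zip(self.interactors, feats_a, feats_b)]
        decoded = self.decoder(pairs)  # level-wise upsampling and fusion
        return self.detector(decoded)  # binary change mask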

1. Encoder

  • Function: extracts multi-level features from the bi-temporal images (e.g., T_A and T_B).

  • Implementation

    • Uses a pretrained backbone (e.g., ConvNeXt) to extract hierarchical pyramid features.

    • Features are extracted level by level with downsampling, forming a multi-scale representation.

    • For each temporal image, the encoder outputs the features {F_1, F_2, F_3, F_4}, which feed the subsequent modules (a minimal extraction sketch follows this list).
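As a rough illustration of this Siamese (shared-weight) multi-scale extraction, the sketch below uses the timm library to pull pyramid features from a ConvNeXt backbone; the paper's exact backbone configuration may differ.

import timm
import torch

backbone = timm.create_model('convnext_tiny', pretrained=True, features_only=True)
backbone.eval()
t_a = torch.randn(1, 3, 256, 256)  # image at time T_A
t_b = torch.randn(1, 3, 256, 256)  # image at time T_B
with torch.no_grad():
    feats_a = backbone(t_a)  # pyramid features [F1, F2, F3, F4]
    feats_b = backbone(t_b)  # the same weights applied to both dates
print([tuple(f.shape) for f in feats_a])  # strides 4, 8, 16, 32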


2. Interactor

  • Core module: the Temporospatial Interactive Attention Module (TIAM).

  • Function

    • Queries and reconstructs the bi-temporal features at each level (e.g., F_A and F_B).

    • Extracts and models temporospatial dependencies (e.g., spatial perspective correlations) and temporal style differences (e.g., illumination, shadow, and weather changes).

  • Characteristics

    • TIAM computes attention scores with an embedded Gaussian function and reconstructs features through weighted matrix operations.

    • Reduces pseudo-change detections (e.g., false alarms caused by viewpoint rotation); the attention pattern is illustrated below.
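The embedded Gaussian pattern at the heart of TIAM is simply a softmax over dot-product affinities. A toy example with random tensors, mirroring the matrix operations in the code section below:

import torch

theta = torch.randn(1, 8, 16)  # queries from T_A: (B, C', N1)
phi = torch.randn(1, 8, 16)    # keys from T_B:    (B, C', N2)
g = torch.randn(1, 8, 16)      # values from T_B:  (B, C', N2)

scores = torch.matmul(theta.transpose(1, 2), phi)  # (B, N1, N2) dot-product affinities
attn = scores.softmax(dim=-1)                      # embedded Gaussian: exponentiate and normalize
rebuilt = torch.matmul(g, attn.transpose(1, 2))    # (B, C', N1): T_A positions rebuilt from T_B
print(rebuilt.shape)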


3. Decoder

  • Function

    • Takes the features produced by the interactor and upsamples them level by level.

    • A Feature Squeeze Residual (FSR) module reduces feature dimensionality and adds residual enhancement.

  • Characteristics

    • Skip connections between the interactor and the decoder improve gradient propagation.

    • Efficiently fuses low-level texture features with high-level semantic features (a hypothetical FSR sketch follows this list).
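The post does not spell out FSR's internals, so the following is only a plausible reading of "feature squeeze plus residual enhancement": a hypothetical block that squeezes channels with a 1x1 convolution and adds a refined residual. Treat it as an assumption, not the paper's exact design.

import torch.nn as nn

class FSRSketch(nn.Module):
    # Hypothetical Feature Squeeze Residual block (assumed structure).
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.squeeze = nn.Conv2d(in_ch, out_ch, kernel_size=1)  # dimensionality reduction
        self.refine = nn.Sequential(
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        s = self.squeeze(x)
        return s + self.refine(s)  # residual enhancement on the squeezed features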


4. Detector

  • Function

    • Integrates the hierarchical features from the decoder and produces the final binary change mask.

  • Characteristics

    • Fuses the multi-level features with fully convolutional layers and outputs pixel-wise change predictions through a Softmax layer.

    • Preserves both object boundary details and interior completeness (a sketch of such a head follows this list).
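A hypothetical head along these lines: upsample every decoder level to full resolution (assuming all levels share the same channel count), fuse with a convolution, and emit per-pixel change probabilities. The exact layer configuration is an assumption, not the paper's.

import torch
import torch.nn as nn
import torch.nn.functional as F

class DetectorSketch(nn.Module):
    # Hypothetical detector head mirroring the description above.
    def __init__(self, channels, num_levels):
        super().__init__()
        self.fuse = nn.Conv2d(channels * num_levels, 2, kernel_size=1)

    def forward(self, feats, out_size):
        # Upsample every level to the output resolution before fusion
        ups = [F.interpolate(f, size=out_size, mode='bilinear', align_corners=False)
               for f in feats]
        logits = self.fuse(torch.cat(ups, dim=1))  # (B, 2, H, W)
        return logits.softmax(dim=1)[:, 1]         # probability of "change" per pixel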

Role as a Plug-and-Play Module

As a plug-and-play module, TIAM:

  • Mitigates pseudo-change interference: builds Spatial Perspective Dependencies to alleviate pseudo-changes caused by geometric viewpoint differences (e.g., different shooting angles), and Temporal Style Correlations to reduce false alarms caused by style differences due to illumination, weather, and season.

  • Improves feature interaction efficiency: built around global interactive attention, the module efficiently extracts the semantically invariant and the changed features of bi-temporal images and strengthens the mutual information between temporospatial features.

  • Enhances robustness and generality: TIAM can be embedded into other models, such as the IFNet and FC-Conc models evaluated in the paper, significantly improving their change detection accuracy and demonstrating the module's adaptability.

  • Keeps computational complexity low: TIAM operates on quarter-scale feature maps, so its computational cost is low (see the complexity analysis in Table 6 of the paper) while retaining the benefit of global feature interaction (a drop-in usage sketch follows this list).
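Because TIAM returns features with the same shapes as its inputs, plugging it into an existing Siamese detector amounts to a two-line change. A sketch using the SpatiotemporalAttentionFull class from the code section below, with assumed quarter-scale feature shapes:

import torch

x_a = torch.randn(1, 128, 64, 64)  # assumed quarter-scale features of a 256x256 T_A image
x_b = torch.randn(1, 128, 64, 64)  # matching features of T_B
tiam = SpatiotemporalAttentionFull(in_channels=128)
x_a, x_b = tiam(x_a, x_b)          # shapes unchanged: a true drop-in
diff = (x_a - x_b).abs()           # e.g., hand a difference map to the host model's head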

Ablation Results

Ablation on attention mechanisms: experiments on the SYSU-CD and S2Looking datasets compare TIAM against other mainstream attention mechanisms (SAM, DAM, CBAM, etc.). TIAM outperforms them in both F1 and IoU, confirming the importance of temporospatial feature interaction, especially in complex change scenarios.


Ablation on framework components: this experiment verifies the contribution of each component of CDNeXt, including the backbone (ResNet18 vs. ConvNeXt), the feature squeeze (FS) module, and TIAM. ConvNeXt improves feature extraction, while adding FS and TIAM significantly raises F1 and IoU; TIAM in particular is crucial for capturing change features.

Complexity analysis: the paper compares parameter counts (Params), computational cost (FLOPs), and detection performance (F1). CDNeXt achieves the best performance while keeping complexity low. Embedding TIAM into other methods (FC-Conc and IFNet) also improves their accuracy markedly, confirming its generality and efficiency.

Plug-and-Play Module Code

import torch
import torch.nn as nn
import torch.nn.functional as F
# Paper: Robust change detection for remote sensing images based on temporospatial interactive attention module
# Paper link: https://www.sciencedirect.com/science/article/pii/S1569843224001213

class SpatiotemporalAttentionFull(nn.Module):
    """Full TIAM: temporal (channel) and spatial interactive attention with shared projections."""
    def __init__(self, in_channels, inter_channels=None, dimension=2, sub_sample=False):
        super(SpatiotemporalAttentionFull, self).__init__()
        assert dimension == 2
        self.dimension = dimension
        self.sub_sample = sub_sample
        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        # g: value projection (1x1 conv), shared by both temporal branches
        self.g = nn.Sequential(
            nn.BatchNorm2d(self.in_channels),
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                      kernel_size=1, stride=1, padding=0)
        )

        # W: output projection back to in_channels for the residual sum
        self.W = nn.Sequential(
            nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
                      kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(self.in_channels)
        )
        # theta / phi: query and key projections for the affinity matrices
        self.theta = nn.Sequential(
            nn.BatchNorm2d(self.in_channels),
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                      kernel_size=1, stride=1, padding=0),
        )
        self.phi = nn.Sequential(
            nn.BatchNorm2d(self.in_channels),
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                      kernel_size=1, stride=1, padding=0),
        )
        # dim=-1 softmax normalizes the temporal (channel) affinities;
        # dim=-2 softmax normalizes the spatial affinities over source positions.
        self.energy_time_1_sf = nn.Softmax(dim=-1)
        self.energy_time_2_sf = nn.Softmax(dim=-1)
        self.energy_space_2s_sf = nn.Softmax(dim=-2)
        self.energy_space_1s_sf = nn.Softmax(dim=-2)

    def forward(self, x1, x2):
        """x1, x2: bi-temporal features of shape (B, C, H, W); returns refined (x1, x2)."""
        batch_size = x1.size(0)
        # Value projections, flattened to (B, C', HW)
        g_x11 = self.g(x1).reshape(batch_size, self.inter_channels, -1)
        g_x21 = self.g(x2).reshape(batch_size, self.inter_channels, -1)

        # Queries from x1: (B, C', H1W1) and transpose (B, H1W1, C')
        theta_x1 = self.theta(x1).reshape(batch_size, self.inter_channels, -1)
        theta_x2 = theta_x1.permute(0, 2, 1)

        # Keys from x2: (B, C', H2W2) and transpose (B, H2W2, C')
        phi_x1 = self.phi(x2).reshape(batch_size, self.inter_channels, -1)
        phi_x2 = phi_x1.permute(0, 2, 1)

        # Temporal style correlations between channels: (B, C', C')
        energy_time_1 = torch.matmul(theta_x1, phi_x2)
        energy_time_2 = energy_time_1.permute(0, 2, 1)
        # Spatial perspective dependencies between positions: (B, H1W1, H2W2)
        energy_space_1 = torch.matmul(theta_x2, phi_x1)
        energy_space_2 = energy_space_1.permute(0, 2, 1)

        energy_time_1s = self.energy_time_1_sf(energy_time_1)
        energy_time_2s = self.energy_time_2_sf(energy_time_2)
        energy_space_2s = self.energy_space_2s_sf(energy_space_1)
        energy_space_1s = self.energy_space_1s_sf(energy_space_2)

        # S(C'xC') x (C'xH1W1) x S(H1W1xH2W2) reconstructs x1's content on x2's grid, and vice versa.
        y1 = torch.matmul(torch.matmul(energy_time_2s, g_x11), energy_space_2s).contiguous()
        y2 = torch.matmul(torch.matmul(energy_time_1s, g_x21), energy_space_1s).contiguous()
        y1 = y1.reshape(batch_size, self.inter_channels, *x2.size()[2:])
        y2 = y2.reshape(batch_size, self.inter_channels, *x1.size()[2:])
        # Project back to in_channels and add residually to the inputs
        return x1 + self.W(y1), x2 + self.W(y2)


class SpatiotemporalAttentionBase(nn.Module):
    """Base TIAM variant: spatial interaction only (no temporal/channel attention)."""
    def __init__(self, in_channels, inter_channels=None, dimension=2, sub_sample=False):
        super(SpatiotemporalAttentionBase, self).__init__()
        assert dimension == 2
        self.dimension = dimension
        self.sub_sample = sub_sample
        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        self.g = nn.Sequential(
            nn.BatchNorm2d(self.in_channels),
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                      kernel_size=1, stride=1, padding=0)
        )

        self.W = nn.Sequential(
            nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
                      kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(self.in_channels)
        )

        self.theta = nn.Sequential(
            nn.BatchNorm2d(self.in_channels),
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                      kernel_size=1, stride=1, padding=0),
        )
        self.phi = nn.Sequential(
            nn.BatchNorm2d(self.in_channels),
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                      kernel_size=1, stride=1, padding=0),
        )
        self.energy_space_2s_sf = nn.Softmax(dim=-2)
        self.energy_space_1s_sf = nn.Softmax(dim=-2)

    def forward(self, x1, x2):
        """x1, x2: bi-temporal features of shape (B, C, H, W); returns refined (x1, x2)."""

        batch_size = x1.size(0)
        g_x11 = self.g(x1).reshape(batch_size, self.inter_channels, -1)
        g_x21 = self.g(x2).reshape(batch_size, self.inter_channels, -1)

        theta_x1 = self.theta(x1).reshape(batch_size, self.inter_channels, -1)
        theta_x2 = theta_x1.permute(0, 2, 1)

        phi_x1 = self.phi(x2).reshape(batch_size, self.inter_channels, -1)

        # Spatial affinities between positions: (B, H1W1, H2W2)
        energy_space_1 = torch.matmul(theta_x2, phi_x1)
        energy_space_2 = energy_space_1.permute(0, 2, 1)
        energy_space_2s = self.energy_space_2s_sf(energy_space_1)  # softmax over H1W1
        energy_space_1s = self.energy_space_1s_sf(energy_space_2)  # softmax over H2W2

        # (C'xH1W1) x S(H1W1xH2W2) reconstructs x1's content on x2's grid, and vice versa.
        y1 = torch.matmul(g_x11, energy_space_2s).contiguous()
        y2 = torch.matmul(g_x21, energy_space_1s).contiguous()
        y1 = y1.reshape(batch_size, self.inter_channels, *x2.size()[2:])
        y2 = y2.reshape(batch_size, self.inter_channels, *x1.size()[2:])
        return x1 + self.W(y1), x2 + self.W(y2)


class SpatiotemporalAttentionFullNotWeightShared(nn.Module):
    """Full TIAM with separate value (g1/g2) and output (W1/W2) projections per temporal branch."""
    def __init__(self, in_channels, inter_channels=None, dimension=2, sub_sample=False):
        super(SpatiotemporalAttentionFullNotWeightShared, self).__init__()
        assert dimension == 2
        self.dimension = dimension
        self.sub_sample = sub_sample
        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        self.g1 = nn.Sequential(
            nn.BatchNorm2d(self.in_channels),
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                      kernel_size=1, stride=1, padding=0)
        )
        self.g2 = nn.Sequential(
            nn.BatchNorm2d(self.in_channels),
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                      kernel_size=1, stride=1, padding=0),
        )

        self.W1 = nn.Sequential(
            nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
                      kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(self.in_channels)
        )
        self.W2 = nn.Sequential(
            nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
                      kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(self.in_channels)
        )
        self.theta = nn.Sequential(
            nn.BatchNorm2d(self.in_channels),
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                      kernel_size=1, stride=1, padding=0),
        )
        self.phi = nn.Sequential(
            nn.BatchNorm2d(self.in_channels),
            nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                      kernel_size=1, stride=1, padding=0),
        )

    def forward(self, x1, x2):
        """x1, x2: bi-temporal features of shape (B, C, H, W); returns refined (x1, x2)."""
        batch_size = x1.size(0)
        # Branch-specific value projections, flattened to (B, C', HW)
        g_x11 = self.g1(x1).reshape(batch_size, self.inter_channels, -1)
        g_x21 = self.g2(x2).reshape(batch_size, self.inter_channels, -1)

        # Queries from x1 and keys from x2
        theta_x1 = self.theta(x1).reshape(batch_size, self.inter_channels, -1)
        theta_x2 = theta_x1.permute(0, 2, 1)
        phi_x1 = self.phi(x2).reshape(batch_size, self.inter_channels, -1)
        phi_x2 = phi_x1.permute(0, 2, 1)

        # Temporal (channel) affinities: (B, C', C'); spatial affinities: (B, H1W1, H2W2)
        energy_time_1 = torch.matmul(theta_x1, phi_x2)
        energy_time_2 = energy_time_1.permute(0, 2, 1)
        energy_space_1 = torch.matmul(theta_x2, phi_x1)
        energy_space_2 = energy_space_1.permute(0, 2, 1)

        energy_time_1s = F.softmax(energy_time_1, dim=-1)
        energy_time_2s = F.softmax(energy_time_2, dim=-1)
        energy_space_2s = F.softmax(energy_space_1, dim=-2)
        energy_space_1s = F.softmax(energy_space_2, dim=-2)

        # S(C'xC') x (C'xHW) x S(HWxHW) reconstructs each branch from the other
        y1 = torch.matmul(torch.matmul(energy_time_2s, g_x11), energy_space_2s).contiguous()
        y2 = torch.matmul(torch.matmul(energy_time_1s, g_x21), energy_space_1s).contiguous()
        y1 = y1.reshape(batch_size, self.inter_channels, *x2.size()[2:])
        y2 = y2.reshape(batch_size, self.inter_channels, *x1.size()[2:])
        # Branch-specific output projections with residual connections
        return x1 + self.W1(y1), x2 + self.W2(y2)


if __name__ == '__main__':
    # Toy bi-temporal inputs of shape (B, C, H, W)
    input1 = torch.randn(1, 64, 32, 32)
    input2 = torch.randn(1, 64, 32, 32)

    sp_full = SpatiotemporalAttentionFull(in_channels=64)
    output_full_x1, output_full_x2 = sp_full(input1, input2)
    print(input1.shape, input2.shape)
    print(output_full_x1.shape, output_full_x2.shape)

    sp_base = SpatiotemporalAttentionBase(in_channels=64)
    output_base_x1, output_base_x2 = sp_base(input1, input2)
    print(input1.shape, input2.shape)
    print(output_base_x1.shape, output_base_x2.shape)

    sp_full_not_shared = SpatiotemporalAttentionFullNotWeightShared(in_channels=64)
    output_full_not_shared_x1, output_full_not_shared_x2 = sp_full_not_shared(input1, input2)
    print(input1.shape, input2.shape)
    print(output_full_not_shared_x1.shape, output_full_not_shared_x2.shape)
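Note that all three variants are drop-in replacements for one another: each returns two tensors with the same shapes as its inputs (torch.Size([1, 64, 32, 32]) in this example), so the surrounding network needs no changes when switching between them.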

Easy Download

Open in a browser: https://github.com/ai-dawang/PlugNPlay-Modules

More analysis is available in the original paper.

