即插即用时空特征融合模块TFF和SFF,涨点起飞起飞了

文摘   2025-01-17 17:20   上海  

论文介绍

题目:STNet: Spatial and Temporal feature fusion network for change detection in remote sensing images

论文地址:https://arxiv.org/pdf/2304.11422

QQ深度学习交流群:994264161

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • 跨时间特征融合模块(TFF)

    • 提出了一种基于跨时间门控机制的特征融合模块,用于双时相特征的融合。

    • 通过选择性地增强目标变化信息并抑制非目标变化,提升了变化检测的准确性。

  • 跨尺度特征融合模块(SFF)

    • 首次采用跨尺度注意力机制,利用高层次特征引导低层次特征的建模。

    • 该机制能够捕捉变化目标的细粒度空间信息,恢复变化表示的空间细节。

  • 高效的多尺度特征交互设计

    • 提出了一个轻量化的深度神经网络框架,在多个尺度上进行特征交互,综合了语义信息与空间细节。

    • 使用了ResNet-18作为特征提取的骨干网络,结合TFF和SFF模块,有效减少了参数量和计算成本。

  • 性能上的显著提升

    • 在三个遥感变化检测的基准数据集上(WHU、LEVIR-CD 和 CLCD),STNet在F1分数、IoU、整体准确率等指标上取得了领先的表现。

    • 在空间复杂度和计算开销上,STNet的参数量(14.6M)和计算量(9.61G FLOPs)显著低于许多现有方法。

  • 轻量化设计

    • 在TFF模块中使用深度可分离卷积(depth-wise separable convolution),在减少计算量的同时保持了模型的性能。

  • 创新的损失函数

    • 采用了结合Focal Loss和Dice Loss的混合损失函数,解决了变化检测中正负样本不平衡的问题。

方法

整体架构

       STNet 的整体结构包括以下几个部分:输入双时相遥感影像后,使用共享权重的 ResNet-18 提取多尺度特征;通过时间特征融合模块(TFF)采用跨时间门控机制,增强目标变化信息并抑制非目标变化;利用空间特征融合模块(SFF)通过跨尺度注意力机制,结合高层次语义信息与低层次空间细节,恢复变化表示的空间细节;最后通过轻量化解码器将多尺度特征拼接,并结合通道注意力模块(CAM)生成与输入影像尺寸一致的高精度变化检测图。

1. 输入与特征提取

  • 输入:双时相的遥感图像T1T_1 和T2T_2,它们是空间配准后的影像对。

  • 特征提取:使用共享权重的 ResNet-18 作为骨干网络,逐层提取多尺度双时相特征。

    • ResNet-18 提供了 4 个残差块的输出特征,分别为不同尺度的多层特征表示。


2. 时间特征融合模块(TFF)

  • 目标:通过跨时间门控机制融合双时相特征,强调目标变化并抑制非目标变化。

  • 工作流程

  1. 对双时相特征R1R_1 和R2R_2 进行逐元素相减,得到初步的粗粒度变化表示RcR_c

  2. RcR_c 与R1R_1R2R_2 分别进行拼接,并通过深度可分离卷积提取特征,生成权重W1W_1 和W2W_2

  3. 使用门控机制,通过权重调整融合R1R_1 和R2R_2,生成时间特征融合的结果RtR_t


3. 空间特征融合模块(SFF)

  • 目标:通过跨尺度注意力机制融合多尺度特征,恢复变化表示的空间细节。

  • 工作流程

  1. 将高层次特征(语义信息丰富但边界不精确)与低层次特征(包含更多空间细节)进行交互。

  2. 使用注意力机制计算像素间的关系,使高层次特征指导低层次特征的细化。

  3. 融合后的特征包含更高质量的语义信息和空间细节。


4. 解码器与变化检测图生成

  • 轻量化解码器

  1. 将各尺度的变化表示上采样到统一尺寸,并沿通道方向拼接。

  2. 使用通道注意力模块(Channel Attention Module, CAM)进一步增强特征。

  • 输出:通过最终的上采样操作,生成与输入影像尺寸一致的变化检测图。


  • 5. 损失函数

    • 混合损失函数:结合 Focal Loss 和 Dice Loss,解决正负样本不平衡问题,增强对变化区域的敏感性。

    即插即用模块作用

    TFF模块的作用:时间特征增强

    • 突出变化区域:通过跨时间门控机制,选择性增强变化的目标区域,同时抑制非目标区域的干扰。

    • 轻量化设计:使用深度可分离卷积,减少参数量和计算成本,使其适合嵌入到多种深度学习框架中。

    • 时间信息提炼:适用于需要结合时间维度进行变化分析的任务。


    SFF模块的作用:空间特征细化

    • 细粒度空间细节提取:通过跨尺度注意力机制,融合高层次语义信息和低层次空间细节,增强目标区域的表示能力。

    • 边界恢复与语义细化:适合于对目标边界和结构信息要求较高的场景。

    • 跨尺度交互:提高模型对多尺度目标的检测能力。

    消融实验结果

    • 内容

      • 该表通过对比基础模型(Base)与添加 TFF 模块、SFF 模块以及完整模型(STNet)的性能差异,验证了 TFF 和 SFF 模块的有效性。

      • 性能指标包括 F1 分数(F1)、精确率(Pre.)、召回率(Rec.)、交并比(IoU)和总体准确率(OA)。

    • 结论

      • 单独加入 TFF 模块(Base + TFF)或 SFF 模块(Base + SFF)都显著提升了模型性能。

      • 同时引入 TFF 和 SFF 模块(完整模型 STNet)取得了最佳效果,表明这两个模块的联合使用具有协同增益。


    • 内容

      • 该表比较了 STNet 和其他方法在参数数量(Params, M)及计算量(FLOPs, G)方面的差异。

    • 结论

      • STNet 以相对较少的参数量(14.6M)和最低的计算成本(9.61G FLOPs),实现了在性能上超越现有方法的效果,证明其轻量化设计的有效性。

    即插即用模块


    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    def dsconv_3x3(in_channel, out_channel):
        return nn.Sequential(
            nn.Conv2d(in_channel, in_channel, kernel_size=3, stride=1, padding=1, groups=in_channel),
            nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0, groups=1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True)
        )

    def conv_1x1(in_channel, out_channel):
        return nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True)
        )


    class TFF(nn.Module):
        def __init__(self, in_channel, out_channel):
            super(TFF, self).__init__()
            self.catconvA = dsconv_3x3(in_channel * 2, in_channel)
            self.catconvB = dsconv_3x3(in_channel * 2, in_channel)
            self.catconv = dsconv_3x3(in_channel * 2, out_channel)
            self.convA = nn.Conv2d(in_channel, 1, 1)
            self.convB = nn.Conv2d(in_channel, 1, 1)
            self.sigmoid = nn.Sigmoid()

        def forward(self, xA, xB):
            x_diff = xA - xB

            x_diffA = self.catconvA(torch.cat([x_diff, xA], dim=1))
            x_diffB = self.catconvB(torch.cat([x_diff, xB], dim=1))

            A_weight = self.sigmoid(self.convA(x_diffA))
            B_weight = self.sigmoid(self.convB(x_diffB))

            xA = A_weight * xA
            xB = B_weight * xB

            x = self.catconv(torch.cat([xA, xB], dim=1))

            return x


    if __name__ == '__main__':
        in_channel = 3
        out_channel = 3
        block = TFF(in_channel, out_channel)

        # Create dummy inputs
        xA = torch.rand(1, in_channel, 32, 32)
        xB = torch.rand(1, in_channel, 32, 32)

        # Forward pass
        output = block(xA, xB)
        print(f"Input A size: {xA.size()}")
        print(f"Input B size: {xB.size()}")
        print(f"Output size: {output.size()}")


    class SelfAttentionBlock(nn.Module):
        """
        query_feats: (B, C, h, w)
        key_feats: (B, C, h, w)
        value_feats: (B, C, h, w)

        output: (B, C, h, w)
        """


        def __init__(self, key_in_channels, query_in_channels, transform_channels, out_channels,
                     key_query_num_convs, value_out_num_convs)
    :

            super(SelfAttentionBlock, self).__init__()
            self.key_project = self.buildproject(
                in_channels=key_in_channels,
                out_channels=transform_channels,
                num_convs=key_query_num_convs,
            )
            self.query_project = self.buildproject(
                in_channels=query_in_channels,
                out_channels=transform_channels,
                num_convs=key_query_num_convs
            )
            self.value_project = self.buildproject(
                in_channels=key_in_channels,
                out_channels=transform_channels,
                num_convs=value_out_num_convs
            )
            self.out_project = self.buildproject(
                in_channels=transform_channels,
                out_channels=out_channels,
                num_convs=value_out_num_convs
            )
            self.transform_channels = transform_channels

        def forward(self, query_feats, key_feats, value_feats):
            batch_size = query_feats.size(0)

            query = self.query_project(query_feats)
            query = query.reshape(*query.shape[:2], -1)
            query = query.permute(0, 2, 1).contiguous() # (B, h*w, C)

            key = self.key_project(key_feats)
            key = key.reshape(*key.shape[:2], -1) # (B, C, h*w)

            value = self.value_project(value_feats)
            value = value.reshape(*value.shape[:2], -1)
            value = value.permute(0, 2, 1).contiguous() # (B, h*w, C)

            sim_map = torch.matmul(query, key)

            sim_map = (self.transform_channels ** -0.5) * sim_map
            sim_map = F.softmax(sim_map, dim=-1) # (B, h*w, K)

            context = torch.matmul(sim_map, value) # (B, h*w, C)
            context = context.permute(0, 2, 1).contiguous()
            context = context.reshape(batch_size, -1, *query_feats.shape[2:]) # (B, C, h, w)

            context = self.out_project(context) # (B, C, h, w)
            return context

        def buildproject(self, in_channels, out_channels, num_convs):
            convs = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            )
            for _ in range(num_convs - 1):
                convs.append(
                    nn.Sequential(
                        nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
                        nn.BatchNorm2d(out_channels),
                        nn.ReLU(inplace=True)
                    )
                )
            if len(convs) > 1:
                return nn.Sequential(*convs)
            return convs[0]

    def conv_3x3(in_channel, out_channel):
        return nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True)
        )



    class SFF(nn.Module):
        def __init__(self, in_channel):
            super(SFF, self).__init__()
            self.conv_small = conv_1x1(in_channel, in_channel)
            self.conv_big = conv_1x1(in_channel, in_channel)
            self.catconv = conv_3x3(in_channel *2, in_channel)
            self.attention = SelfAttentionBlock(
                key_in_channels=in_channel,
                query_in_channels = in_channel,
                transform_channels = in_channel // 2,
                out_channels = in_channel,
                key_query_num_convs=2,
                value_out_num_convs=1
            )

        def forward(self, x_small, x_big):
            img_size =x_big.size(2), x_big.size(3)
            x_small = F.interpolate(x_small, img_size, mode="bilinear", align_corners=False)
            x = self.conv_small(x_small) + self.conv_big(x_big)
            new_x = self.attention(x, x, x_big)

            out = self.catconv(torch.cat([new_x, x_big], dim=1))
            return out


    if __name__ == '__main__':
        block = SFF(3)
        x_small = torch.rand(1, 3, 32, 32)
        x_big = torch.rand(1, 3, 32, 32)
        output = block(x_small, x_big)
        print(x_small.size())
        print(x_big.size())    print(output.size())

    便捷下载方式

    浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

    更多分析可见原文


    ai缝合大王
    聚焦AI前沿,分享相关技术、论文,研究生自救指南
     最新文章