即插即用实时语义分割模块PIDNet,涨点起飞起飞了

文摘   2025-01-24 17:20   中国香港  

论文介绍

题目:PIDNet: A Real-time Semantic Segmentation Network Inspired by PID

论文地址:https://arxiv.org/pdf/2206.02066

QQ深度学习交流群:994264161

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • 引入PID控制器的思想

    • 作者首次将比例-积分-微分(PID)控制器的概念引入到深度学习的语义分割任务中,指出传统双分支网络(TBN)相当于一个比例-积分(PI)控制器,而PI控制器存在“过冲”问题。

    • 提出通过增加一个额外的导数(D)分支,模拟PID控制器以缓解过冲问题,从而提高分割性能。

  • 提出PIDNet三分支网络

    • P分支(Proportional):专注于解析和保留高分辨率的细节信息。

    • I分支(Integral):聚合全局上下文信息,捕获长距离依赖关系。

    • D分支(Derivative):提取高频特征,用于精确预测边界。

  • 设计高效模块

    • Pag模块(Pixel-attention-guided fusion module):实现从I分支到P分支的信息传递,同时避免细节特征被上下文信息淹没。

    • Bag模块(Boundary-attention-guided fusion module):通过边界注意力机制,优化P分支与I分支特征的融合,平衡细节与上下文信息。

    • PAPPM模块(Parallel Aggregation PPM):改进上下文聚合模块,通过并行结构加快计算,提升实时性能。

  • 推理速度与精度的权衡

    • 在Cityscapes和CamVid数据集上,PIDNet在保持高推理速度的同时实现了更高的分割精度,优于现有实时语义分割模型。

方法

整体架构

      PIDNet 是一种三分支架构的实时语义分割模型,分别通过 P 分支解析细节、I 分支聚合上下文、D 分支提取边界信息,同时利用 Pag 和 Bag 模块优化细节与上下文的融合,PAPPM 模块加速上下文聚合。该模型在保持高推理速度的同时实现了优异的分割精度,适用于自动驾驶等对实时性要求高的场景。

1. 模型整体架构

  • 三分支设计

    • 预测边界信息,提取高频特征。

    • 通过检测边界信息辅助细节和上下文特征的融合,提升分割精度。

    • 聚合局部和全局的上下文信息。

    • 用于捕获长距离依赖关系,解析全局语义。

    • 处理高分辨率的细节特征。

    • 聚焦于像素级信息的解析,用于保留目标的几何形状和边界。

    • P分支(Proportional Branch)

    • I分支(Integral Branch)

    • D分支(Derivative Branch)

  • 模块设计

    • 用于快速聚合多尺度上下文信息。

    • 通过并行化结构减少计算时间,提升实时性。

    • 在融合细节和上下文特征时,引入边界注意力。

    • 增强小目标和边界区域的细节特征。

    • 实现从I分支到P分支的信息传递,避免细节被上下文信息淹没。

    • 基于像素级注意力机制,有选择性地融合有用的语义信息。

    • Pag模块(Pixel-attention-guided fusion module)

    • Bag模块(Boundary-attention-guided fusion module)

    • PAPPM模块(Parallel Aggregation PPM)


2. 网络运行流程

  • 输入图像经过主干网络(基于残差块的骨干网络),提取初步特征。

  • 特征分别流向三个分支:

    • P分支解析局部细节特征。

    • I分支聚合全局上下文信息。

    • D分支提取边界信息,用于辅助优化。

  • 在输出阶段,利用Bag模块结合D分支的边界信息,将P分支和I分支的特征进行加权融合,生成最终的分割结果。


3. 损失函数设计

  • 多任务损失

    • 语义损失(Cross-Entropy Loss,l2l_2):优化全局分割任务。

    • 边界损失(Weighted Binary Cross-Entropy Loss,l1l_1):强调边界区域的重要性。

    • 边界感知损失(Boundary-Awareness CE Loss,l3l_3):协调语义分割与边界检测的功能。

总损失函数为:

Loss=λ0l0+λ1l1+λ2l2+λ3l3


4. PIDNet变体

根据需求设计了三种模型:

  • PIDNet-S:轻量化版本,适合移动设备,速度最快。

  • PIDNet-M:中等复杂度,兼顾速度与精度。

  • PIDNet-L:深度版本,适合高精度需求。

即插即用模块作用

PIDNet 作为一个即插即用模块

  • 细节与上下文信息的平衡

    P 分支提取高分辨率的细节特征,确保小目标和边界的准确性。

    I 分支提供全局上下文信息,解析复杂场景和大目标的语义。

    D 分支通过提取边界特征,强化小目标和边界的清晰度。


  • 实时处理能力

    通过轻量化设计(如 PAPPM 和 Bag 模块),确保分割模型在嵌入式设备或资源有限的场景中也能实时运行。


  • 边界优化与目标增强

    边界感知模块(D 分支和 Bag 模块)在边界细节和小目标分割上表现尤为突出,适合对边界精确性要求高的任务。

消融实验结果

  • 内容:对两种双分支网络(BiSeNet 和 DDRNet)进行实验,分别测试添加和不添加 ADB 和 Bag 模块的性能变化。

  • 说明:引入 ADB 和 Bag 后,模型的分割精度(mIOU)明显提升,但推理速度(FPS)略有下降。这表明 ADB 和 Bag 在提升精度方面效果显著,但需权衡计算代价。


  • 内容:对 PIDNet-L 的不同配置进行实验,包括是否使用 Pag 模块、Bag 模块以及预训练(ImageNet)的影响。

  • 说明

    • 结合 Pag 和 Bag 模块后,模型性能提升到最高(mIOU 达到 80.9%)。

    • 进一步验证了 Pag 模块在细节信息提取中的重要性,Bag 模块在特征融合中的关键作用。


    • 内容:对 PAPPM(并行上下文聚合模块)与传统的 DAPPM 进行对比。

    • 说明:PAPPM 在不降低精度的情况下(mIOU 78.8%),推理速度比 DAPPM 快 9.5 FPS,说明 PAPPM 更适合实时语义分割。

即插即用模块

import torch
import torch.nn as nn
import torch.nn.functional as F


class PagFM(nn.Module):
    def __init__(self, in_channels, mid_channels, after_relu=False, with_channel=False, BatchNorm=nn.BatchNorm2d):
        super(PagFM, self).__init__()
        self.with_channel = with_channel
        self.after_relu = after_relu
        self.f_x = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels,
                      kernel_size=1, bias=False),
            BatchNorm(mid_channels)
        )
        self.f_y = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels,
                      kernel_size=1, bias=False),
            BatchNorm(mid_channels)
        )
        if with_channel:
            self.up = nn.Sequential(
                nn.Conv2d(mid_channels, in_channels,
                          kernel_size=1, bias=False),
                BatchNorm(in_channels)
            )
        if after_relu:
            self.relu = nn.ReLU(inplace=True)

    def forward(self, x, y):
        input_size = x.size()
        if self.after_relu:
            y = self.relu(y)
            x = self.relu(x)

        y_q = self.f_y(y)
        y_q = F.interpolate(y_q, size=[input_size[2], input_size[3]],
                            mode='bilinear', align_corners=False)
        x_k = self.f_x(x)

        if self.with_channel:
            sim_map = torch.sigmoid(self.up(x_k * y_q))
        else:
            sim_map = torch.sigmoid(torch.sum(x_k * y_q, dim=1).unsqueeze(1))

        y = F.interpolate(y, size=[input_size[2], input_size[3]],
                          mode='bilinear', align_corners=False)
        x = (1 - sim_map) * x + sim_map * y

        return x


if __name__ == '__main__':
    block = PagFM(in_channels=32, mid_channels=16)
    input1 = torch.rand(16, 32, 16, 16)
    input2 = torch.rand(16, 32, 16, 16)
    output = block(input1, input2)    print(output.size())

便捷下载方式

浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

更多分析可见原文


ai缝合大王
聚焦AI前沿,分享相关技术、论文,研究生自救指南
 最新文章