AAAI 2024 | 即插即用,Conv-Former注意力模块,用卷积实现Transformer效果!

科技   2024-12-09 22:36   广东  


前言 文介绍一种新颖的单分支实时分割网络,称为SCTNet。通过学习利用从transformer到CNN的语义信息对齐来提取丰富的语义信息,SCTNet在保持轻量级单分支CNN的快速推理速度的同时,享受transformer的高准确性。为了缓解CNN特征和Transformer特征之间的语义差距,设计了CFBlock(ConvFormer块),它仅使用卷积操作就能像Transformer块一样捕获长距离上下文。此外,提出了SIAM(语义信息对齐模块),以更有效的方式对齐特征。

来源:AI缝合术
仅用于学术分享,若侵权请联系删除



1
论文题目SCTNet: Single-Branch CNN with Transformer Semantic Information for Real-Time Segmentation
中文题目: 单分支CNN结合Transformer语义信息的实时分割网络
论文链接:https://arxiv.org/pdf/2312.17071
官方github:https://github.com/xzz777/SCTNet
所属机构:华中科技大学人工智能与自动化学院国家多媒体信息智能处理技术重点实验室,美团
关键词:实时语义分割,Transformer,单分支CNN,语义信息对齐,深度学习

一、论文概要




Highlight

图5:在Cityscapes验证集上的可视化结果。与DDRNet-23(Pan等人,2022年)和RTFormer-B(Wang等人,2022年)相比,SCTNet-B生成的掩码具有更精细的细节,如浅蓝色框中突出显示的那样,以及更准确的大面积预测,如黄色框中突出显示的那样。

研究背景:

  • 实时语义分割方法:近期的实时语义分割方法通常采用额外的语义分支来追求丰富的长距离上下文信息,但额外的分支会带来不希望的计算开销并减慢推理速度。
     二、方法



1
图3:SCTNet的架构。CFBlock(Conv-Transformer,详见图4)通过SIAM(语义信息对齐模块)利用训练仅限的Transformer分支(在虚线框中以灰色显示),该模块由BFA(主干特征对齐)和SDHA(共享解码器头对齐)组成。
图4:Conv-Former块的设计(左)和卷积注意力的细节(右)。GDN表示分组双重归一化。⊗表示卷积操作,⊕代表加法,k表示核大小。

Conv-Former Block旨在模拟 Transformer 的结构,以更好地学习 Transformer 分支的语义信息,同时仅使用高效的卷积操作来实现注意力功能:

1. 结构设计:Conv-Former Block 的结构类似于典型的 Transformer 编码器。

    2. 卷积注意力:为了实现低延迟和强大的语义提取能力,Conv-Former Block 的卷积注意力基于 GFA(GPU-Friendly Attention)改进而来。主要区别在于:

    • 使用像素级卷积操作替代 GFA 中的矩阵乘法,避免了特征展平和重塑操作,以保持固有的空间结构并减少推理延迟。

    • 通过将可学习向量扩展为可学习的核,以更好地对齐 Transformer 的语义信息。这种转换将像素与可学习向量之间的相似度计算转换为像素块与可学习核之间的相似度计算,并通过带有可学习核的卷积操作保留更多的局部空间信息。

    3. 实现细节:为了提高效率,Conv-Former Block 使用条带卷积来近似标准卷积层。具体来说,使用 1×k 和 k×1 的卷积来近似 k×k 的卷积层。

    4. 前馈网络(FFN):与典型的 FFN 相比,Conv-Former Block 的 FFN 由两个标准的 3×3 卷积层组成,这比典型的 FFN 更高效,并提供了更大的感受野。

    CFBlock 结合卷积和 Transformer 的特性,通过 Conv-Former 高效建模局部和全局依赖关系,能够在多种视觉任务中发挥作用,尤其是在需要平衡性能与效率的场景下(如实时检测或分割任务)。可以调整模块中卷积核的尺度、注意力头的数量以及中间通道数,以适配不同任务的需求。

    注:Convolutional Attention模块也可单独拿出来使用!

      三、实验分析



      • Cityscapes数据集上:SCTNet-B-Seg 100实现了80.5%的mIoU和62.8 FPS,这是实时分割领域的新状态最佳性能。SCTNet-B-Seg 75达到了79.8%的mIoU,比RTFormer-B和DDRNet-23等方法在准确率上更高,同时速度是它们的两倍。SCTNet-S在保持最高FPS的同时,也实现了与STDC 2、RTFormer-S、SeaFormer-B和TopFormer-B等方法相比更好的性能。

      • ADE 20K数据集上:SCTNet-B实现了43.0%的mIoU和145.1 FPS,比RTFormer-B快约1.6倍,同时mIoU性能高出0.9%。SCTNet-S达到了37.7%的mIoU,保持了在ADE 20K上所有方法中最高的FPS。

          四、代码




          1
          温馨提示:对于所有推文中出现的代码,如果您在微信中复制的代码排版错乱,请复制该篇推文的链接,在任意浏览器中打开,再复制相应代码,即可成功在开发环境中运行!或者进入官方github仓库找到对应代码进行复制!
          import torchfrom torch import nnimport torch.nn.functional as Ffrom mmengine.model import constant_init, kaiming_init,trunc_normal_init,normal_initfrom timm.models.layers import DropPath# 论文题目:SCTNet: Single-Branch CNN with Transformer Semantic Information for Real-Time Segmentation# 中文题目: 单分支CNN结合Transformer语义信息的实时分割网络# 论文链接:https://arxiv.org/pdf/2312.17071# 官方github:https://github.com/xzz777/SCTNet# 所属机构:华中科技大学人工智能与自动化学院国家多媒体信息智能处理技术重点实验室,美团# 关键词:实时语义分割,Transformer,单分支CNN,语义信息对齐,深度学习#BN->Conv->GELU->drop->Conv2->dropclass MLP(nn.Module):    def __init__(self,                 in_channels,                 hidden_channels=None,                 out_channels=None,                 drop_rate=0.):        super(MLP,self).__init__()        hidden_channels = hidden_channels or in_channels        out_channels = out_channels or in_channels        self.norm = nn.BatchNorm2d(in_channels, eps=1e-06)        self.conv1 = nn.Conv2d(in_channels, hidden_channels, 3, 1, 1)        self.act = nn.GELU()        self.conv2 = nn.Conv2d(hidden_channels, out_channels, 3, 1, 1)        self.drop = nn.Dropout(drop_rate)        self.apply(self._init_weights)    def _init_weights(self, m):        if isinstance(m, nn.Linear):            trunc_normal_init(m.weight, std=.02)            if m.bias is not None:                constant_init(m.bias, val=0)        elif isinstance(m, (nn.SyncBatchNorm, nn.BatchNorm2d)):            constant_init(m.weight, val=1.0)            constant_init(m.bias, val=0)        elif isinstance(m, nn.Conv2d):            kaiming_init(m.weight)            if m.bias is not None:                constant_init(m.bias, val=0)    def forward(self, x):        x = self.norm(x)        x = self.conv1(x)        x = self.act(x)        x = self.drop(x)        x = self.conv2(x)        x = self.drop(x)        return x    class ConvolutionalAttention(nn.Module):    """    The ConvolutionalAttention implementation    Args:        in_channels (int, optional): The input channels.        inter_channels (int, optional): The channels of intermediate feature.        out_channels (int, optional): The output channels.        num_heads (int, optional): The num of heads in attention. Default: 8    """    def __init__(self,                 in_channels,                 out_channels,                 inter_channels,                 num_heads=8):        super(ConvolutionalAttention,self).__init__()        assert out_channels % num_heads == 0, \            "out_channels ({}) should be be a multiple of num_heads ({})".format(out_channels, num_heads)        self.in_channels = in_channels        self.out_channels = out_channels        self.inter_channels = inter_channels        self.num_heads = num_heads        self.norm = nn.BatchNorm2d(in_channels, eps=1e-06)        self.kv =nn.Parameter(torch.zeros(inter_channels, in_channels, 7, 1))        self.kv3 =nn.Parameter(torch.zeros(inter_channels, in_channels, 1, 7))        trunc_normal_init(self.kv, std=0.001)        trunc_normal_init(self.kv3, std=0.001)        self.apply(self._init_weights)    def _init_weights(self, m):        if isinstance(m, nn.Linear):            trunc_normal_init(m.weight, std=.001)            if m.bias is not None:                constant_init(m.bias, val=0.)        elif isinstance(m, (nn.SyncBatchNorm, nn.BatchNorm2d)):            constant_init(m.weight, val=1.)            constant_init(m.bias, val=.0)        elif isinstance(m, nn.Conv2d):            trunc_normal_init(m.weight, std=.001)            if m.bias is not None:                constant_init(m.bias, val=0.)    def _act_dn(self, x):        x_shape = x.shape  # n,c_inter,h,w        h, w = x_shape[2], x_shape[3]        x = x.reshape(            [x_shape[0], self.num_heads, self.inter_channels // self.num_heads, -1])   #n,c_inter,h,w -> n,heads,c_inner//heads,hw        x = F.softmax(x, dim=3)           x = x / (torch.sum(x, dim =2, keepdim=True) + 1e-06)          x = x.reshape([x_shape[0], self.inter_channels, h, w])         return x    def forward(self, x):        """        Args:            x (Tensor): The input tensor. (n,c,h,w)            cross_k (Tensor, optional): The dims is (n*144, c_in, 1, 1)            cross_v (Tensor, optional): The dims is (n*c_in, 144, 1, 1)        """        x = self.norm(x)        x1 = F.conv2d(                x,                self.kv,                bias=None,                stride=1,                padding=(3,0))          x1 = self._act_dn(x1)          x1 = F.conv2d(                x1, self.kv.transpose(1, 0), bias=None, stride=1,                padding=(3,0))          x3 = F.conv2d(                x,                self.kv3,                bias=None,                stride=1,                padding=(0,3))         x3 = self._act_dn(x3)        x3 = F.conv2d(                x3, self.kv3.transpose(1, 0), bias=None, stride=1,padding=(0,3))         x=x1+x3        return x    class CFBlock(nn.Module):    """    The CFBlock implementation based on PaddlePaddle.    Args:        in_channels (int, optional): The input channels.        out_channels (int, optional): The output channels.        num_heads (int, optional): The num of heads in attention. Default: 8        drop_rate (float, optional): The drop rate in MLP. Default:0.        drop_path_rate (float, optional): The drop path rate in CFBlock. Default: 0.2    """    def __init__(self,                 in_channels,                 out_channels,                 num_heads=8,                 drop_rate=0.,                 drop_path_rate=0.):        super(CFBlock,self).__init__()        in_channels_l = in_channels        out_channels_l = out_channels        self.attn_l = ConvolutionalAttention(            in_channels_l,            out_channels_l,            inter_channels=64,            num_heads=num_heads)        self.mlp_l = MLP(out_channels_l, drop_rate=drop_rate)        self.drop_path = DropPath(            drop_path_rate) if drop_path_rate > 0. else nn.Identity()    def _init_weights_kaiming(self, m):        if isinstance(m, nn.Linear):            trunc_normal_init(m.weight, std=.02)            if m.bias is not None:                constant_init(m.bias, val=0)        elif isinstance(m, (nn.SyncBatchNorm, nn.BatchNorm2d)):            constant_init(m.weight, val=1.0)            constant_init(m.bias, val=0)        elif isinstance(m, nn.Conv2d):            kaiming_init(m.weight)            if m.bias is not None:                constant_init(m.bias, val=0)    def forward(self, x):        x_res = x        x = x_res + self.drop_path(self.attn_l(x))        x = x + self.drop_path(self.mlp_l(x))         return xif __name__ == '__main__':    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')    input=torch.randn(1,32,256,256).to(device)    print(input.shape)    cfb = CFBlock(32,32).to(device)    output=cfb(input)    print(output.shape)
          运行结果

          便捷下载

          https://github.com/AIFengheshu/Plug-play-modules/blob/main/(AAAI%202024)%20CFBlock.py

          AI算法与图像处理
          考研逆袭985,非科班跨行AI,目前从事计算机视觉的工业和商业相关应用的工作。分享最新最前沿的科技,共同分享宝贵的资源资料,这里有机器学习,计算机视觉,Python等技术实战分享,也有考研,转行IT经验交流心得
           最新文章