轻量级、高效、动态化的时序卷积模块TAdaConv,即插即用即涨点

文摘   2025-01-18 19:12   上海  

论文介绍

题目:Temporally-Adaptive Models for Efficient Video Understanding

论文地址:https://arxiv.org/pdf/2308.05787

QQ深度学习交流群:994264161

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • 引入时序自适应卷积 (Temporally-Adaptive Convolutions, TAdaConv)

  • TAdaConv通过根据每一帧的局部和全局时序上下文动态调整卷积核权重,从而增强了空间卷积的时序建模能力。与传统的1D时序卷积相比,TAdaConv直接对卷积核进行操作,而不是特征图,极大地提高了效率。

  • 因子化卷积核权重

    • TAdaConv将每一帧的卷积核分解为基础权重和校准权重的乘积,其中校准权重根据输入数据动态生成。这种设计可以保留预训练模型的权重,从而减少训练视频模型的资源消耗并提升泛化性能。

  • 多头自注意力优化

    • 在改进版本TAdaConvV2中,引入了多头自注意力机制以增强全局时序信息的建模能力。

  • 模块化设计

    • 提出了TAdaBlock模块,能够直接嵌入现有的卷积网络(如ConvNeXt)和视觉Transformer中,从而赋予这些模型强大的时序建模能力。

方法

整体架构

       这篇论文提出了一种用于视频理解的时序自适应卷积(TAdaConv)及其改进版本TAdaConvV2,通过动态校准卷积核权重,结合局部和全局时序上下文,显著增强模型的时序建模能力。基于此,构建了TAdaBlock模块,可灵活嵌入现代卷积网络(如ConvNeXt)和Transformer(如Vision Transformer)中,提升在动作识别和定位任务中的性能。创新设计包括多头自注意力优化、时序特征聚合以及与预训练模型的无缝集成,确保效率与准确率的显著提升。

1. 核心模块:TAdaConv

  • 基本原理

    • 动态调整每一帧的卷积核权重,通过结合局部和全局时序上下文生成校准权重。

    • 卷积核权重被分解为基础权重WbW_b 和校准权重αt\alpha_t,公式为:Wt=αtWb

  • 效率优势

    • 直接操作卷积核,而非特征图,显著减少计算开销。

    • 可以加载预训练模型的权重,无需从零训练。


2. 改进模块:TAdaConvV2

  • 增强特性

    • 引入多头自注意力机制(MHSA),提升全局时序建模能力。

    • 使用GELU激活函数和LayerNorm代替传统ReLU和BatchNorm,更适配现代网络设计。

  • 时序聚合

    • 提供了高效的时序特征聚合方案(如T-Pool),结合时序降采样进一步压缩计算量。


3. 模块化设计:TAdaBlock

  • TAdaBlock的主要类型

    • 适用于Transformer架构。

    • 在Transformer的多头自注意力层之前插入TAdaConv模块,增强时序动态建模。

    • 适用于现代卷积网络(如ConvNeXt)。

    • 集成时序池化(T-Pool)和降采样功能。

    • 适用于传统卷积网络(如ResNet)。

    • 将TAdaConv作为替代卷积核,增强时序建模能力。

即插即用模块作用

TAdaConv 作为一个即插即用模块

  • 增强时序建模能力

    动态调整卷积核权重,结合局部和全局时序上下文,精确捕捉视频中的时序动态信息。

  • 提升模型效率

    • 通过直接操作卷积核而非特征图,大幅减少计算开销,显著提高模型的效率。

  • 兼容预训练模型

    • 可直接加载预训练权重,避免从零训练的高昂成本,同时加速模型在小规模视频数据集上的收敛。

  • 模块化设计,灵活嵌入

    • 可轻松插入到不同架构中(如2D卷积网络、3D卷积网络、Transformer),提升基础模型的时序理解能力。

  • 精度和速度的平衡

    • 在保持较低计算成本的同时,提供与传统复杂时序模型(如3D卷积网络、复杂Transformer)相当甚至更优的性能。

消融实验结果

  • 表 3:关于 TAdaConv 的消融实验

  • 表 3a:验证了是否放松时序不变性对模型性能的影响。

    • 结果表明,动态校准权重能够显著提高时序建模能力,而放松时间维度的不变性进一步增强了分类准确率。

  • 表 3b:探讨了不同校准维度的影响。

    • 校准权重在输入通道CinC_{in} 上表现最佳,表明对输入特征的动态调整更有利于模型性能。

  • 表 3c:分析了在不同视频模型(如SlowFast和R(2+1)D)中引入TAdaConv的效果。

    • TAdaConv能够在多种基线模型上提升性能,且计算开销微乎其微。

  • 表 3d:比较了不同校准权重生成方式(如线性与非线性)的性能。

    • 包含时序上下文的非线性生成方式表现最佳,强调了局部和全局时序上下文的重要性。

  • 表 3e:验证了不同的特征聚合方案对性能的影响。

    • 平均池化 (Avg Pooling) 和独立的BatchNorm分支显著提高了模型的性能。


  • 内容:比较不同版本的TAdaBlock对Kinetics-400和SSV2的影响。

  • 结果

    • TAdaConvV2 和 T-Pool 的引入进一步提升了模型在时序动态数据集上的表现。

即插即用模块

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _triple

class TAdaConv2d(nn.Module):
    """
    Performs temporally adaptive 2D convolution.
    Currently, only application on 5D tensors is supported, which makes TAdaConv2d
        essentially a 3D convolution with temporal kernel size of 1.
    """


    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1, bias=True,
                 cal_dim="cin")
:

        super(TAdaConv2d, self).__init__()
        """
        Args:
            in_channels (int): number of input channels.
            out_channels (int): number of output channels.
            kernel_size (list): kernel size of TAdaConv2d.
            stride (list): stride for the convolution in TAdaConv2d.
            padding (list): padding for the convolution in TAdaConv2d.
            dilation (list): dilation of the convolution in TAdaConv2d.
            groups (int): number of groups for TAdaConv2d.
            bias (bool): whether to use bias in TAdaConv2d.
        """


        kernel_size = _triple(kernel_size)
        stride = _triple(stride)
        padding = _triple(padding)
        dilation = _triple(dilation)

        assert kernel_size[0] == 1
        assert stride[0] == 1
        assert padding[0] == 0
        assert dilation[0] == 1
        assert cal_dim in ["cin", "cout"]

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.cal_dim = cal_dim

        # base weights (W_b)
        self.weight = nn.Parameter(
            torch.Tensor(1, 1, out_channels, in_channels // groups, kernel_size[1], kernel_size[2])
        )
        if bias:
            self.bias = nn.Parameter(torch.Tensor(1, 1, out_channels))
        else:
            self.register_parameter('bias', None)

        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x, alpha):
        """
        Args:
            x (tensor): feature to perform convolution on.
            alpha (tensor): calibration weight for the base weights.
                W_t = alpha_t * W_b
        """

        _, _, c_out, c_in, kh, kw = self.weight.size()
        b, c_in, t, h, w = x.size()
        x = x.permute(0,2,1,3,4).reshape(1,-1,h,w)

        if self.cal_dim == "cin":
            # w_alpha: B, C, T, H(1), W(1) -> B, T, C, H(1), W(1) -> B, T, 1, C, H(1), W(1)
            # corresponding to calibrating the input channel
            weight = (alpha.permute(0,2,1,3,4).unsqueeze(2) * self.weight).reshape(-1, c_in//self.groups, kh, kw)
        elif self.cal_dim == "cout":
            # w_alpha: B, C, T, H(1), W(1) -> B, T, C, H(1), W(1) -> B, T, C, 1, H(1), W(1)
            # corresponding to calibrating the input channel
            weight = (alpha.permute(0,2,1,3,4).unsqueeze(3) * self.weight).reshape(-1, c_in//self.groups, kh, kw)

        bias = None
        if self.bias is not None:
            # in the official implementation of TAda2D,
            # there is no bias term in the convs
            # hence the performance with bias is not validated
            bias = self.bias.repeat(b, t, 1).reshape(-1)
        output = F.conv2d(
            x, weight=weight, bias=bias, stride=self.stride[1:], padding=self.padding[1:],
            dilation=self.dilation[1:], groups=self.groups * b * t)

        output = output.view(b, t, c_out, output.size(-2), output.size(-1)).permute(0,2,1,3,4)

        return output
        
    def __repr__(self):
        return f"TAdaConv2d({self.in_channels}, {self.out_channels}, kernel_size={self.kernel_size}, " +\
            f"stride={self.stride}, padding={self.padding}, bias={self.bias is not None}, cal_dim=\"{self.cal_dim}\")"


if __name__ == '__main__':

    tada_conv2d = TAdaConv2d(in_channels=64, out_channels=64, kernel_size=[1, 3, 3], stride=[1, 1, 1], padding=[0, 1, 1])
    input = torch.rand(2, 64, 10, 32, 32)
    alpha_tensor = torch.rand(2, 64, 10, 1, 1)
    output = tada_conv2d(input, alpha_tensor)
    print(input.size())    print(output.size())

便捷下载方式

浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

更多分析可见原文


ai缝合大王
聚焦AI前沿,分享相关技术、论文,研究生自救指南
 最新文章