NeurIPS 条件卷积模块CondConv,涨点起飞起飞了!

文摘   2024-12-13 17:20   上海  

论文介绍

题目:CondConv: Conditionally Parameterized Convolutions for Efficient Inference

论文地址:https://arxiv.org/pdf/1904.04971

QQ深度学习交流群:719278780

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • 条件卷积 (CondConv) 的提出:论文提出了 CondConv,这是一种条件参数化的卷积层,它根据输入样本动态生成卷积核,而不是使用静态共享的卷积核。这种设计可以针对每个输入生成特定的卷积核,从而提升模型的灵活性和容量。

  • 高效的计算设计:CondConv 将卷积核参数化为多个专家的线性组合,专家的权重由输入样本决定。相比直接增加卷积核的大小或数量,这种方式可以显著提高模型容量,同时仅略微增加推理计算成本。

  • 可插拔设计:CondConv 可以作为现有卷积层的替代品应用于任何卷积神经网络架构中,无需对基础模型进行复杂修改。

  • 性能与效率的优化

    • 在 ImageNet 数据集上,CondConv 在多种基础模型(如 MobileNetV1、EfficientNet 等)中提升了分类精度,并且只带来轻微的计算开销增加。

    • 在 COCO 目标检测任务中,CondConv 同样提高了检测性能。

  • 与现有方法的对比:CondConv 比传统的多分支架构(如 Inception 和 ResNeXt)更高效;与条件计算方法相比,它不需要复杂的离散路由策略,更容易优化。

  • 数据依赖的卷积核生成:通过结合全局上下文信息和基于输入的动态路由函数,CondConv 学习到了有语义意义的卷积核,展示了不同专家在处理特定类型输入上的专长。

方法

整体架构

     这篇论文提出了基于 CondConv(条件参数化卷积)的模型结构,它通过动态生成输入依赖的卷积核取代传统静态卷积核,核心是将卷积核参数化为多专家(Mixture of Experts)的线性组合,专家权重由输入通过路由函数生成。CondConv 可无缝集成到现有 CNN 架构中,作为普通卷积层的替代,显著提升了多种模型(如 MobileNet、ResNet 和 EfficientNet)的性能,同时保持较低的计算开销,非常适合对实时性和计算资源有限的应用场景。

1. CondConv 模块设计

  • 传统卷积的改进

    • W1,W2,,WnW_1, W_2, \ldots, W_n 是多个“专家”卷积核。

    • α1,α2,,αn\alpha_1, \alpha_2, \ldots, \alpha_n 是由输入样本通过路由函数生成的权重。

    • 传统卷积层对所有输入样本共享一个固定的卷积核。

    • CondConv 模块的核心创新在于使用条件参数化,卷积核根据输入样本动态生成,形式为:Output(x)=σ((α1W1+α2W2++αnWn)x)

  • 路由函数(Routing Function)

    • 路由函数根据输入xx 的全局平均池化结果生成专家权重α\alphar(x)=Sigmoid(GlobalAveragePooling(x)R)r(x) = \text{Sigmoid}(\text{GlobalAveragePooling}(x) R)其中RR 是可学习的路由参数。

  • 与多分支模型的对比

    • CondConv 等价于一种线性专家混合(Mixture of Experts),但只需计算一次卷积,大幅降低了计算成本。


2. CondConv 的模型整合

论文在多种基础 CNN 架构中使用 CondConv 模块,包括:

  • MobileNetV1/V2

    • 在部分卷积块中替换为 CondConv 层,并在分类头中引入 CondConv。

  • ResNet-50

    • 替换最后几个残差块中的卷积层为 CondConv,并在最终分类层中引入 CondConv。

  • EfficientNet-B0

    • 在 EfficientNet 的最后几个块组中使用 CondConv,并扩展到 CondConv-EfficientNet-B0-depth,以进一步探索 CondConv 的扩展能力。


3. 模型的完整训练与推理流程

  • 训练阶段

    • 使用专家混合(Mixture of Experts)的公式进行训练,因为这种方式可以利用现有硬件的并行能力。

    • 每个 CondConv 层通过输入样本生成卷积核,随后对小批量数据进行卷积操作。

  • 推理阶段

    • 推理时直接使用动态生成的卷积核,计算更加高效。


4. 实验验证的基础架构

论文的实验主要基于以下 CNN 模型验证:

  • ImageNet 图像分类

    • MobileNetV1、MobileNetV2、ResNet-50、MnasNet、EfficientNet-B0。

    • CondConv 替换部分卷积层后,显著提升分类精度。

  • COCO 目标检测

    • 使用 Single Shot Detector (SSD) 架构,以 MobileNetV1 为特征提取器,验证 CondConv 在检测任务中的有效性。

即插即用模块作用

CondConv 作为一个即插即用模块

  • 高性能推理任务

    • 在需要高准确率高模型容量的同时,必须保持推理计算高效的场景。

      例如:实时视频处理、自动驾驶感知任务、AR/VR应用中的计算机视觉模块。

  • 计算资源有限的环境

    • 需要在资源受限的硬件(如移动设备、嵌入式设备)上部署高效模型。

    • CondConv 能够用更少的计算开销提供接近或优于传统卷积层的性能。

  • 数据特异性强的任务

    • 数据集中的样本间差异较大,模型需要动态适应不同输入特征。

    • 例如:目标检测、图像分类、语义分割等任务中,CondConv 可根据输入生成专用的卷积核以提升精度。

  • 动态场景或多任务处理

    • 在需要处理多样性任务(如多类别、动态环境)的场景下,CondConv 的动态路由机制使其可以针对输入样本优化计算。

消融实验结果

表3:不同路由函数架构对性能的影响

  • 内容:比较了各种路由函数设计的影响,包括单层全连接(baseline)、部分共享路由(Partially-shared)、引入隐藏层的非线性路由(Hidden 小/中/大)、分层路由(Hierarchical)和 Softmax 激活。

  • 结果

    • 单层全连接路由函数(baseline)表现最佳,说明简单且高效的路由函数已足够区分输入。

    • 隐藏层过大的路由函数(Hidden Large)容易过拟合,性能下降。

    • 使用 Sigmoid 激活比 Softmax 激活效果更好,表明多个专家同时参与计算更优。

表4:CondConv 应用于不同深度的层

  • 内容:探讨 CondConv 放置在不同层的位置(从输入层到深层,以及是否替换最终全连接层)的效果。

  • 结果

    • CondConv 在靠近网络中间和后期的层更有效果,尤其是深层的特征语义更丰富时。

    • 替换最终全连接层为 CondConv 带来了显著性能提升,但也增加了一定计算成本。

即插即用模块

#论文地址:https://arxiv.org/pdf/1904.04971
#论文:CondConv: Conditionally Parameterized Convolutions for Efficient Inference

import functools
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.utils import _pair
from torch.nn.parameter import Parameter


class _routing(nn.Module):

    def __init__(self, in_channels, num_experts, dropout_rate):
        super(_routing, self).__init__()

        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(in_channels, num_experts)

    def forward(self, x):
        x = torch.flatten(x)
        x = self.dropout(x)
        x = self.fc(x)
        return torch.sigmoid(x)


class CondConv2D(_ConvNd):
    r"""Learn specialized convolutional kernels for each example.
    As described in the paper
    `CondConv: Conditionally Parameterized Convolutions for Efficient Inference`_ ,
    conditionally parameterized convolutions (CondConv),
    which challenge the paradigm of static convolutional kernels
    by computing convolutional kernels as a function of the input.
    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        padding_mode (string, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
        num_experts (int): Number of experts per layer
    Shape:
        - Input: :math:`(N, C_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where
          .. math::
              H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - \text{dilation}[0]
                        \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor
          .. math::
              W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] - \text{dilation}[1]
                        \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor
    Attributes:
        weight (Tensor): the learnable weights of the module of shape
                         :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},`
                         :math:`\text{kernel\_size[0]}, \text{kernel\_size[1]})`.
                         The values of these weights are sampled from
                         :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                         :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`
        bias (Tensor): the learnable bias of the module of shape (out_channels). If :attr:`bias` is ``True``,
                         then the values of these weights are
                         sampled from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
                         :math:`k = \frac{groups}{C_\text{in} * \prod_{i=0}^{1}\text{kernel\_size}[i]}`
    .. _CondConv: Conditionally Parameterized Convolutions for Efficient Inference:
       https://arxiv.org/abs/1904.04971
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
                 padding=0, dilation=1, groups=1,
                 bias=True, padding_mode='zeros', num_experts=3, dropout_rate=0.2):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super(CondConv2D, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias, padding_mode)

        self._avg_pooling = functools.partial(F.adaptive_avg_pool2d, output_size=(1, 1))
        self._routing_fn = _routing(in_channels, num_experts, dropout_rate)

        self.weight = Parameter(torch.Tensor(
            num_experts, out_channels, in_channels // groups, *kernel_size))

        self.reset_parameters()

    def _conv_forward(self, input, weight):
        if self.padding_mode != 'zeros':
            return F.conv2d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
                            weight, self.bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, inputs):
        b, _, _, _ = inputs.size()
        res = []
        for input in inputs:
            input = input.unsqueeze(0)
            pooled_inputs = self._avg_pooling(input)
            routing_weights = self._routing_fn(pooled_inputs)
            kernels = torch.sum(routing_weights[:, None, None, None, None] * self.weight, 0)
            out = self._conv_forward(input, kernels)
            res.append(out)
        return torch.cat(res, dim=0)


if __name__ == '__main__':

    input = torch.randn(3, 64, 64, 64)

    block = CondConv2D(64, 128)

    # 前向传播
    output = block(input)

    # 打印输入和输出的形状
    print(input.size())
    print(output.size())

便捷下载方式

浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

更多分析可见原文


ai缝合大王
聚焦AI前沿,分享相关技术、论文,研究生自救指南
 最新文章