即插即用DO-Conv模块,涨点起飞起飞了

文摘   2024-12-12 17:20   上海  

论文介绍

题目:DO-Conv: Depthwise Over-parameterized Convolutional Layer

论文地址:https://arxiv.org/abs/2006.12030

QQ深度学习交流群:719278780

扫描下方二维码,加入深度学习论文指南星球!

加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务

创新点

  • 引入深度方向过参数化卷积层(DO-Conv)

    • 该论文提出了通过加入深度卷积操作对传统卷积层进行过参数化的创新方法。

    • 通过这种过参数化,卷积层的线性操作可以用一个单一的卷积层来表达,而不会增加推理阶段的计算复杂度。

  • 性能提升

    • 实验表明,使用DO-Conv替换传统卷积层可以显著提升CNN在图像分类、检测和分割等经典视觉任务中的表现。

    • 在训练阶段,DO-Conv加速了网络的收敛并提高了最终模型的准确性。

  • 灵活性和兼容性

    • DO-Conv与现有的网络架构(如ResNet、MobileNet等)兼容,可以无缝替代传统卷积层。

    • 在推理阶段,DO-Conv会折叠成一个等价的传统卷积层,从而避免额外的计算开销。

  • 启发性和广泛适用性

    • 该方法揭示了过参数化作为神经网络设计中未被充分探索的潜力维度。提供了在主流深度学习框架(TensorFlow、PyTorch和GluonCV)中的开源实现,有助于研究者和开发者进一步探索其应用。

方法

整体架构

      模型以**深度方向过参数化卷积层(DO-Conv)**为核心,通过将传统卷积层扩展为深度卷积与标准卷积的组合,提升了模型的训练效率和最终性能。在训练阶段,DO-Conv独立优化两个卷积操作,而在推理阶段将它们“折叠”成一个等效的标准卷积,保持计算效率。DO-Conv可无缝替代现有卷积网络的标准卷积层,广泛适用于图像分类、语义分割和目标检测等任务,实现性能的显著提升而无需增加推理复杂度。

  • 输入数据处理

    • 输入是一个卷积层的标准输入特征图。

    • 特征图通过一个滑动窗口的方式处理,每次处理一个局部小块。

  • 双卷积操作

    • 使用标准卷积核对深度卷积的输出进行进一步处理,生成最终的输出特征图。

    • 每个输入通道通过一个单独的二维卷积核进行卷积操作。

    • 输出特征图的通道数取决于深度卷积的深度乘数(Depth Multiplier,Dmul)。

  • 参数化

    • 深度卷积的参数表示为一个张量DR(M×N)×Dmul×CinD \in \mathbb{R}^{(M \times N) \times D_{mul} \times C_{in}},表示卷积核大小、深度乘数和输入通道数。

    • 标准卷积的参数表示为WRCout×Dmul×CinW \in \mathbb{R}^{C_{out} \times D_{mul} \times C_{in}},表示输出通道数、深度乘数和输入通道数。

  • 训练与推理

    • 在训练阶段,两个卷积操作独立优化。

    • 在推理阶段,深度卷积和标准卷积的权重通过数学变换“折叠”成一个等价的单一卷积核,以降低推理复杂度。

即插即用模块作用

DO-Conv 作为一个即插即用模块

  • 经典计算机视觉任务

    • 图像分类:如在CIFAR和ImageNet数据集上的分类任务中,DO-Conv可以直接替换现有网络的卷积层,提升分类精度。

    • 语义分割:在细粒度分割任务中(如PASCAL VOC和Cityscapes数据集),DO-Conv能提高分割模型的性能。

    • 目标检测:适用于COCO等数据集的目标检测任务,在主干网络(Backbone)和检测阶段的卷积层中替换为DO-Conv,可以提升检测精度。

  • 现有CNN架构的增强

    • DO-Conv可无缝集成到主流架构(如ResNet、MobileNet、ResNeXt等)中,优化现有网络的性能。特别适合对卷积层的性能提升要求较高,但又不希望增加推理复杂度的场景。

  • 计算资源受限的部署需求

    • DO-Conv在推理阶段无需增加计算复杂度,这使其特别适合部署在资源有限的设备(如移动设备和嵌入式设备)上。

  • 需要快速训练的场景

    • DO-Conv加速了模型的训练收敛速度,适用于对训练效率有高要求的实验环境或需要快速迭代的开发场景。

消融实验结果

不同ResNet阶段使用DO-Conv的效果

  • 测试了在ResNet的不同阶段(0到4阶段)替换为DO-Conv的效果。

  • 结果显示,在CIFAR-100数据集上,使用更多DO-Conv的阶段可以带来持续的性能提升。而在ImageNet数据集上,不同阶段的效果更加复杂,有些阶段替换可能会导致性能下降,但在第一阶段的一致提升表明其关键作用。

DO-Conv的不同初始化方法

  • 比较了DO-Conv中的深度卷积参数DD 的两种初始化方式:

    • 随机初始化(Random-init)。

    • 单位矩阵初始化(Identity-init)。

  • 结果表明,单位矩阵初始化可以显著提升性能(比随机初始化的增益更高),因为它能更好地保留原始网络的特性

即插即用模块

# 论文地址:https://arxiv.org/abs/2006.12030
# 论文:DO-Conv: Depthwise Over-parameterized Convolutional Layer
import math
import torch
import numpy as np
from torch.nn import init
from itertools import repeat
from torch.nn import functional as F
from torch._jit_internal import Optional
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
import collections


class DOConv2d(Module):
    """
       DOConv2d can be used as an alternative for torch.nn.Conv2d.
       The interface is similar to that of Conv2d, with one exception:
            1. D_mul: the depth multiplier for the over-parameterization.
       Note that the groups parameter switchs between DO-Conv (groups=1),
       DO-DConv (groups=in_channels), DO-GConv (otherwise).
    """

    __constants__ = ['stride', 'padding', 'dilation', 'groups',
                     'padding_mode', 'output_padding', 'in_channels',
                     'out_channels', 'kernel_size', 'D_mul']
    __annotations__ = {'bias': Optional[torch.Tensor]}

    def __init__(self, in_channels, out_channels, kernel_size, D_mul=None, stride=1,
                 padding=0, dilation=1, groups=1, bias=False, padding_mode='zeros')
:

        super(DOConv2d, self).__init__()

        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)

        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        valid_padding_modes = {'zeros', 'reflect', 'replicate', 'circular'}
        if padding_mode not in valid_padding_modes:
            raise ValueError("padding_mode must be one of {}, but got padding_mode='{}'".format(
                valid_padding_modes, padding_mode))
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.padding_mode = padding_mode
        self._padding_repeated_twice = tuple(x for x in self.padding for _ in range(2))

        #################################### Initailization of D & W ###################################
        M = self.kernel_size[0]
        N = self.kernel_size[1]
        self.D_mul = M * N if D_mul is None or M * N <= 1 else D_mul
        self.W = Parameter(torch.Tensor(out_channels, in_channels // groups, self.D_mul))
        init.kaiming_uniform_(self.W, a=math.sqrt(5))

        if M * N > 1:
            self.D = Parameter(torch.Tensor(in_channels, M * N, self.D_mul))
            init_zero = np.zeros([in_channels, M * N, self.D_mul], dtype=np.float32)
            self.D.data = torch.from_numpy(init_zero)

            eye = torch.reshape(torch.eye(M * N, dtype=torch.float32), (1, M * N, M * N))
            d_diag = eye.repeat((in_channels, 1, self.D_mul // (M * N)))
            if self.D_mul % (M * N) != 0: # the cases when D_mul > M * N
                zeros = torch.zeros([in_channels, M * N, self.D_mul % (M * N)])
                self.d_diag = Parameter(torch.cat([d_diag, zeros], dim=2), requires_grad=False)
            else: # the case when D_mul = M * N
                self.d_diag = Parameter(d_diag, requires_grad=False)
        ##################################################################################################

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.W)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)
        else:
            self.register_parameter('bias', None)

    def extra_repr(self):
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        if self.padding_mode != 'zeros':
            s += ', padding_mode={padding_mode}'
        return s.format(**self.__dict__)

    def __setstate__(self, state):
        super(DOConv2d, self).__setstate__(state)
        if not hasattr(self, 'padding_mode'):
            self.padding_mode = 'zeros'

    def _conv_forward(self, input, weight):
        if self.padding_mode != 'zeros':
            return F.conv2d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
                            weight, self.bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, input):
        M = self.kernel_size[0]
        N = self.kernel_size[1]
        DoW_shape = (self.out_channels, self.in_channels // self.groups, M, N)
        if M * N > 1:
            ######################### Compute DoW #################
            # (input_channels, D_mul, M * N)
            D = self.D + self.d_diag
            W = torch.reshape(self.W, (self.out_channels // self.groups, self.in_channels, self.D_mul))

            # einsum outputs (out_channels // groups, in_channels, M * N),
            # which is reshaped to
            # (out_channels, in_channels // groups, M, N)
            DoW = torch.reshape(torch.einsum('ims,ois->oim', D, W), DoW_shape)
            #######################################################
        else:
            # in this case D_mul == M * N
            # reshape from
            # (out_channels, in_channels // groups, D_mul)
            # to
            # (out_channels, in_channels // groups, M, N)
            DoW = torch.reshape(self.W, DoW_shape)
        return self._conv_forward(input, DoW)


def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse


_pair = _ntuple(2)

if __name__ == '__main__':
    input = torch.randn(1, 64, 64, 64)
    print(input.size())
    block = DOConv2d(in_channels=64, out_channels=32, kernel_size=1)
    output = block(input)    print(output.size())

便捷下载方式

浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules

更多分析可见原文


ai缝合大王
聚焦AI前沿,分享相关技术、论文,研究生自救指南
 最新文章