ECCV 2024 | 小波变换卷积的即插即用,引入了频域信息,助你涨点起飞

文摘   2025-01-13 11:36   安徽  

点击下方卡片,关注“AI前沿速递”公众号

各种重磅干货,第一时间送达



题目:Wavelet Convolutions for Large Receptive Fields

论文地址:https://arxiv.org/pdf/2407.05848

代码地址:GitHub - BGU-CS-VIL/WTConv: Wavelet Convolutions for Large Receptive Fields. ECCV 2024.

创新点

  1. 引入小波变换:
  • WTConv首次将小波变换(WT)应用于卷积神经网络中,通过小波变换扩展卷积的感受野,同时避免过度参数化。
  1. 多频率响应:
  • WTConv层利用小波分解将输入分成不同频带,允许卷积层在低频和高频分量上分别进行处理,增强了模型对低频成分(即形状特征)的响应。
  1. 低参数增长率:
  • 与传统方法中卷积核尺寸增大导致参数和计算量指数级增长不同,WTConv实现了参数的对数增长,使得在大感受野的情况下保持参数效率。
  1. 即插即用性:
  • WTConv被设计为可以直接替换现有CNN中的深度卷积操作,无需对架构进行额外修改,具有广泛适用性。
  1. 增强鲁棒性与形状偏向:
  • 实验结果表明,WTConv能够提升CNN在图像分类、语义分割和目标检测等任务中的性能,并对图像腐蚀有更好的鲁棒性,同时提高了模型的形状偏向性。
  1. 计算成本控制:
  • 尽管WTConv处理的通道数是原始输入的四倍,但每个小波域卷积在空间维度上减少了一个因子2,因此在浮点运算(FLOPs)方面的计算成本相对较低

整体结构

WTConv模块的整体结构主要包括以下几个部分:

  1. 小波变换(Wavelet Transform):
  • 使用小波变换对输入特征图进行分解,得到低频和高频成分。具体来说,输入特征图经过小波变换后,被分解为低频分量(LL)和三个高频分量(LH、HL、HH)。
  • 这一过程通过wavelet_transform函数实现,使用预定义的小波滤波器对输入进行卷积操作。
  1. 小波域卷积:
  • 在小波域中对低频和高频成分分别进行卷积操作。对于低频分量,使用标准的卷积操作;对于高频分量,使用深度卷积(depthwise convolution)来提取特征。
  • 这些卷积操作通过nn.Conv2d和nn.ModuleList实现,其中nn.Conv2d用于定义卷积层,nn.ModuleList用于存储多个卷积层。
  1. 逆小波变换(Inverse Wavelet Transform):
  • 将小波域中的特征图通过逆小波变换恢复到原始空间域。这一过程通过inverse_wavelet_transform函数实现,使用逆小波滤波器对特征图进行转置卷积操作。
  1. 整合与输出:
  • 将经过小波变换和逆小波变换后的特征图与原始特征图进行整合,最终输出经过WTConv模块处理的特征图。这一整合过程通过简单的加法操作实现。WTConv模块通过这种结构设计,实现了在保持参数效率的同时,扩大了卷积层的感受野,并增强了对不同频率信息的处理能力。

适用场景

  1. 需要大感受野的任务:WTConv通过小波变换扩展了卷积层的感受野,适用于需要捕捉更大上下文信息的视觉任务,如图像分类、语义分割和目标检测等.
  2. 增强低频信息响应的场景:在涉及形状和边缘特征识别的场景中,WTConv可以更好地响应低频成分,提升模型对图像中整体形状和轮廓的捕捉能力.
  3. 需要保持参数效率的网络架构:WTConv实现了参数的对数增长,使得在大感受野的情况下保持参数效率,适用于对参数量和计算量有严格限制的网络架构.
  4. 需要提高模型鲁棒性的应用:实验结果表明,WTConv能够提升CNN在图像分类等任务中的性能,并对图像腐蚀有更好的鲁棒性.
  5. 需要增强模型形状偏向性的任务:WTConv提高了模型的形状偏向性,使其在处理形状特征时表现更优.

实现代码

import torch
import torch.nn as nn
import torch.nn.functional as F

class WTConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, wt_levels, wt_type='haar', stride=1, padding=0, dilation=1, bias=False):
        super(WTConv2d, self).__init__()
        assert in_channels == out_channels
        self.in_channels = in_channels
        self.wt_levels = wt_levels
        self.stride = stride
        self.dilation = 1
        self.wt_filter, self.iwt_filter = create_wavelet_filter(wt_type, in_channels, in_channels, torch.float)
        self.wt_filter = nn.Parameter(self.wt_filter, requires_grad=False)
        self.iwt_filter = nn.Parameter(self.iwt_filter, requires_grad=False)
        self.wt_function = wavelet_transform_init(self.wt_filter)
        self.iwt_function = inverse_wavelet_transform_init(self.iwt_filter)
        self.base_conv = nn.Conv2d(in_channels, in_channels, kernel_size, padding='same', stride=1, dilation=1,
                                   groups=in_channels, bias=bias)
        self.base_scale = _ScaleModule([1, in_channels, 1, 1])
        self.wavelet_convs = nn.ModuleList(
            [nn.Conv2d(in_channels * 4, in_channels * 4, kernel_size, padding='same', stride=1, dilation=1,
                       groups=in_channels * 4, bias=False) for _ in range(self.wt_levels)]
        )
        self.wavelet_scale = nn.ModuleList(
            [_ScaleModule([1, in_channels * 4, 1, 1], init_scale=0.1) for _ in range(self.wt_levels)]
        )
        if self.stride > 1:
            self.stride_filter = nn.Parameter(torch.ones(in_channels, 1, 1, 1), requires_grad=False)
            self.do_stride = lambda x_in: F.conv2d(x_in, self.stride_filter, bias=None, stride=self.stride,
                                                   groups=in_channels)
        else:
            self.do_stride = None

    def forward(self, x):
        x_ll_in_levels = []
        x_h_in_levels = []
        shapes_in_levels = []
        curr_x_ll = x
        for i in range(self.wt_levels):
            curr_shape = curr_x_ll.shape
            shapes_in_levels.append(curr_shape)
            if (curr_shape[2] % 2 > 0) or (curr_shape[3] % 2 > 0):
                curr_pads = (0, curr_shape[3] % 2, 0, curr_shape[2] % 2)
                curr_x_ll = F.pad(curr_x_ll, curr_pads)
            curr_x = self.wt_function(curr_x_ll)
            curr_x_ll = curr_x[:, :, 0, :, :]
            shape_x = curr_x.shape
            curr_x_tag = curr_x.reshape(shape_x[0], shape_x[1] * 4, shape_x[3], shape_x[4])
            curr_x_tag = self.wavelet_scale[i](self.wavelet_convs[i](curr_x_tag))
            curr_x_tag = curr_x_tag.reshape(shape_x)
            x_ll_in_levels.append(curr_x_tag[:, :, 0, :, :])
            x_h_in_levels.append(curr_x_tag[:, :, 1:4, :, :])
        next_x_ll = 0
        for i in range(self.wt_levels - 1, -1, -1):
            curr_x_ll = x_ll_in_levels.pop()
            curr_x_h = x_h_in_levels.pop()
            curr_shape = shapes_in_levels.pop()
            curr_x_ll = curr_x_ll + next_x_ll
            curr_x = torch.cat([curr_x_ll.unsqueeze(2), curr_x_h], dim=2)
            next_x_ll = self.iwt_function(curr_x)
            next_x_ll = next_x_ll[:, :, :curr_shape[2], :curr_shape[3]]
        x_tag = next_x_ll
        assert len(x_ll_in_levels) == 0
        x = self.base_scale(self.base_conv(x))
        x = x + x_tag
        return x



本文内容为论文学习收获分享,受限于知识能力,本文对原文的理解可能存在偏差,最终内容以原论文为准。本文信息旨在传播和学术交流,其内容由作者负责,不代表本号观点。文中作品文字、图片等如涉及内容、版权和其他问题,请及时与我们联系,我们将在第一时间回复并处理。


欢迎投稿

想要让高质量的内容更快地触达读者,降低他们寻找优质信息的成本吗?关键在于那些你尚未结识的人。他们可能掌握着你渴望了解的知识。【AI前沿速递】愿意成为这样的一座桥梁,连接不同领域、不同背景的学者,让他们的学术灵感相互碰撞,激发出无限可能。

【AI前沿速递】欢迎各高校实验室和个人在我们的平台上分享各类精彩内容,无论是最新的论文解读,还是对学术热点的深入分析,或是科研心得和竞赛经验的分享,我们的目标只有一个:让知识自由流动。

📝 投稿指南

  • 确保文章为个人原创,未在任何公开渠道发布。若文章已在其他平台发表或即将发表,请明确说明。

  • 建议使用Markdown格式撰写稿件,并以附件形式发送清晰、无版权争议的配图。

  • 【AI前沿速递】尊重作者的署名权,并为每篇被采纳的原创首发稿件提供具有市场竞争力的稿酬。稿酬将根据文章的阅读量和质量进行阶梯式结算。

📬 投稿方式

  • 您可以通过添加我们的小助理微信(aiqysd)进行快速投稿。请在添加时备注“投稿-姓名-学校-研究方向”


    长按添加AI前沿速递小助理


AI前沿速递
持续分享最新AI前沿论文成果
 最新文章