点击下方卡片,关注“AI前沿速递”公众号
点击下方卡片,关注“AI前沿速递”公众号
各种重磅干货,第一时间送达
各种重磅干货,第一时间送达
题目:Wavelet Convolutions for Large Receptive Fields
论文地址:https://arxiv.org/pdf/2407.05848
代码地址:GitHub - BGU-CS-VIL/WTConv: Wavelet Convolutions for Large Receptive Fields. ECCV 2024.
创新点
引入小波变换:
WTConv首次将小波变换(WT)应用于卷积神经网络中,通过小波变换扩展卷积的感受野,同时避免过度参数化。
多频率响应:
WTConv层利用小波分解将输入分成不同频带,允许卷积层在低频和高频分量上分别进行处理,增强了模型对低频成分(即形状特征)的响应。
低参数增长率:
与传统方法中卷积核尺寸增大导致参数和计算量指数级增长不同,WTConv实现了参数的对数增长,使得在大感受野的情况下保持参数效率。
即插即用性:
WTConv被设计为可以直接替换现有CNN中的深度卷积操作,无需对架构进行额外修改,具有广泛适用性。
增强鲁棒性与形状偏向:
实验结果表明,WTConv能够提升CNN在图像分类、语义分割和目标检测等任务中的性能,并对图像腐蚀有更好的鲁棒性,同时提高了模型的形状偏向性。
计算成本控制:
尽管WTConv处理的通道数是原始输入的四倍,但每个小波域卷积在空间维度上减少了一个因子2,因此在浮点运算(FLOPs)方面的计算成本相对较低
整体结构
WTConv模块的整体结构主要包括以下几个部分:
小波变换(Wavelet Transform):
使用小波变换对输入特征图进行分解,得到低频和高频成分。具体来说,输入特征图经过小波变换后,被分解为低频分量(LL)和三个高频分量(LH、HL、HH)。 这一过程通过wavelet_transform函数实现,使用预定义的小波滤波器对输入进行卷积操作。
小波域卷积:
在小波域中对低频和高频成分分别进行卷积操作。对于低频分量,使用标准的卷积操作;对于高频分量,使用深度卷积(depthwise convolution)来提取特征。 这些卷积操作通过nn.Conv2d和nn.ModuleList实现,其中nn.Conv2d用于定义卷积层,nn.ModuleList用于存储多个卷积层。
逆小波变换(Inverse Wavelet Transform):
将小波域中的特征图通过逆小波变换恢复到原始空间域。这一过程通过inverse_wavelet_transform函数实现,使用逆小波滤波器对特征图进行转置卷积操作。
整合与输出:
将经过小波变换和逆小波变换后的特征图与原始特征图进行整合,最终输出经过WTConv模块处理的特征图。这一整合过程通过简单的加法操作实现。WTConv模块通过这种结构设计,实现了在保持参数效率的同时,扩大了卷积层的感受野,并增强了对不同频率信息的处理能力。
适用场景
需要大感受野的任务:WTConv通过小波变换扩展了卷积层的感受野,适用于需要捕捉更大上下文信息的视觉任务,如图像分类、语义分割和目标检测等. 增强低频信息响应的场景:在涉及形状和边缘特征识别的场景中,WTConv可以更好地响应低频成分,提升模型对图像中整体形状和轮廓的捕捉能力. 需要保持参数效率的网络架构:WTConv实现了参数的对数增长,使得在大感受野的情况下保持参数效率,适用于对参数量和计算量有严格限制的网络架构. 需要提高模型鲁棒性的应用:实验结果表明,WTConv能够提升CNN在图像分类等任务中的性能,并对图像腐蚀有更好的鲁棒性. 需要增强模型形状偏向性的任务:WTConv提高了模型的形状偏向性,使其在处理形状特征时表现更优.
实现代码
import torch
import torch.nn as nn
import torch.nn.functional as F
class WTConv2d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, wt_levels, wt_type='haar', stride=1, padding=0, dilation=1, bias=False):
super(WTConv2d, self).__init__()
assert in_channels == out_channels
self.in_channels = in_channels
self.wt_levels = wt_levels
self.stride = stride
self.dilation = 1
self.wt_filter, self.iwt_filter = create_wavelet_filter(wt_type, in_channels, in_channels, torch.float)
self.wt_filter = nn.Parameter(self.wt_filter, requires_grad=False)
self.iwt_filter = nn.Parameter(self.iwt_filter, requires_grad=False)
self.wt_function = wavelet_transform_init(self.wt_filter)
self.iwt_function = inverse_wavelet_transform_init(self.iwt_filter)
self.base_conv = nn.Conv2d(in_channels, in_channels, kernel_size, padding='same', stride=1, dilation=1,
groups=in_channels, bias=bias)
self.base_scale = _ScaleModule([1, in_channels, 1, 1])
self.wavelet_convs = nn.ModuleList(
[nn.Conv2d(in_channels * 4, in_channels * 4, kernel_size, padding='same', stride=1, dilation=1,
groups=in_channels * 4, bias=False) for _ in range(self.wt_levels)]
)
self.wavelet_scale = nn.ModuleList(
[_ScaleModule([1, in_channels * 4, 1, 1], init_scale=0.1) for _ in range(self.wt_levels)]
)
if self.stride > 1:
self.stride_filter = nn.Parameter(torch.ones(in_channels, 1, 1, 1), requires_grad=False)
self.do_stride = lambda x_in: F.conv2d(x_in, self.stride_filter, bias=None, stride=self.stride,
groups=in_channels)
else:
self.do_stride = None
def forward(self, x):
x_ll_in_levels = []
x_h_in_levels = []
shapes_in_levels = []
curr_x_ll = x
for i in range(self.wt_levels):
curr_shape = curr_x_ll.shape
shapes_in_levels.append(curr_shape)
if (curr_shape[2] % 2 > 0) or (curr_shape[3] % 2 > 0):
curr_pads = (0, curr_shape[3] % 2, 0, curr_shape[2] % 2)
curr_x_ll = F.pad(curr_x_ll, curr_pads)
curr_x = self.wt_function(curr_x_ll)
curr_x_ll = curr_x[:, :, 0, :, :]
shape_x = curr_x.shape
curr_x_tag = curr_x.reshape(shape_x[0], shape_x[1] * 4, shape_x[3], shape_x[4])
curr_x_tag = self.wavelet_scale[i](self.wavelet_convs[i](curr_x_tag))
curr_x_tag = curr_x_tag.reshape(shape_x)
x_ll_in_levels.append(curr_x_tag[:, :, 0, :, :])
x_h_in_levels.append(curr_x_tag[:, :, 1:4, :, :])
next_x_ll = 0
for i in range(self.wt_levels - 1, -1, -1):
curr_x_ll = x_ll_in_levels.pop()
curr_x_h = x_h_in_levels.pop()
curr_shape = shapes_in_levels.pop()
curr_x_ll = curr_x_ll + next_x_ll
curr_x = torch.cat([curr_x_ll.unsqueeze(2), curr_x_h], dim=2)
next_x_ll = self.iwt_function(curr_x)
next_x_ll = next_x_ll[:, :, :curr_shape[2], :curr_shape[3]]
x_tag = next_x_ll
assert len(x_ll_in_levels) == 0
x = self.base_scale(self.base_conv(x))
x = x + x_tag
return x
确保文章为个人原创,未在任何公开渠道发布。若文章已在其他平台发表或即将发表,请明确说明。
建议使用Markdown格式撰写稿件,并以附件形式发送清晰、无版权争议的配图。
【AI前沿速递】尊重作者的署名权,并为每篇被采纳的原创首发稿件提供具有市场竞争力的稿酬。稿酬将根据文章的阅读量和质量进行阶梯式结算。
您可以通过添加我们的小助理微信(aiqysd)进行快速投稿。请在添加时备注“投稿-姓名-学校-研究方向”
长按添加AI前沿速递小助理