AAAI 2024 | 动态频域fft模块,几行代码,有效提点,即插即用

文摘   2025-01-19 18:05   安徽  

标题:FFT-based Dynamic Token Mixer for Vision

论文链接:https://arxiv.org/pdf/2303.03932

代码链接:https://github.com/okojoalg/dfformer

来源:AAAI 2024

1.定义与结构

Dynamic Filter模块主要由两个部分组成:

  • Filter-Generating Network (滤波器生成网络):该网络根据输入动态生成滤波器。输入可以是图像或其他特征,输出是针对特定样本的滤波器。这些滤波器是样本特定的,而不是固定的模型参 数。
  • Dynamic Filtering Layer (动态过滤层):该层将生成的滤波器应用到输入上,生成最终的输 出。这个过程是可微分的,允许端到端的训练。

2.工作原理

- 滤波器生成:滤波器生成网络接收输入,生成滤波器 ,该滤波器可以应用于另一个输入 以生成输出。滤波器的大小决定了感受野的大小,具体选择依赖于应用需求。"

- 动态卷积:动态卷积层使用动态生成的滤波器进行卷积操作,而不是预训练的固定滤波器。这种动 态生成的滤波器可以根据输入内容进行调整,从而提高模型的性能。"

3.应用场景

- 图像生成:根据一个视角图预测其他视角(如旋转人脸)、预测视频下一帧、2D变30等任务中常 用到动态过滤模块。

- 特征提取:在需要自适应特征提取的任务中,动态过滤模块可以灵活地生成适合特定样本的滤波 器,从而提高特征提取的准确性。

- 频域操作:一些动态过滤模块在频域进行操作,利用快速傅里叶变换(FFT)将特征图转换为频 域,应用滤波器后再通过逆FFT转换回空间域。这种方法可以降低计算复杂度,同时捕捉全局信 息。

4. 具体实现

以一个具体的实现为例,Dynamic Filter模块可以包括以下步骤:

  1. 定义滤波器基:预定义一组滤波器基,如低通滤波器、高通滤波器或带通滤波器。
  2. 使用MLP学习权重:通过一个多层感知机(MLP)层学习每个特征通道对的滤波器基权重。
  3. 生成滤波器:根据MLP的输出,将滤波器基与权重相乘并累加,得到最终的滤波器。
  4. 频域操作:使用FFT将特征图转换为频域,应用生成的滤波器,最后通过逆FFT将结果转换回空间域。

5. 优势

  • 灵活性:滤波器可以根据输入动态生成,适用于多种不同的任务和数据。
  • 性能提升:通过自适应生成滤波器,可以更好地捕捉输入数据的特征,从而提高模型的性能。
  • 计算效率:在频域进行操作可以降低计算复杂度,同时保持全局信息的捕捉能力。

代码实现

import torchimport torch.nn as nnfrom timm.models.layers.helpers import to_2tuple
class StarReLU(nn.Module): def __init__(self, scale_value=1.0, bias_value=0.0, scale_learnable=True, bias_learnable=True, mode=None, inplace=False): super().__init__() self.inplace = inplace self.relu = nn.ReLU(inplace=inplace) self.scale = nn.Parameter(scale_value * torch.ones(1), requires_grad=scale_learnable) self.bias = nn.Parameter(bias_value * torch.ones(1), requires_grad=bias_learnable)
def forward(self, x): return self.scale * self.relu(x) ** 2 + self.bias
class Mlp(nn.Module): def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=StarReLU, drop=0., bias=False, **kwargs): super().__init__() in_features = dim out_features = out_features or in_features hidden_features = int(mlp_ratio * in_features) drop_probs = to_2tuple(drop) self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop1(x) x = self.fc2(x) x = self.drop2(x) return x
class DynamicFilter(nn.Module): def __init__(self, dim, expansion_ratio=2, reweight_expansion_ratio=.25, act1_layer=StarReLU, act2_layer=nn.Identity, bias=False, num_filters=4, size=14, weight_resize=False, **kwargs): super().__init__() size = to_2tuple(size) self.size = size[0] self.filter_size = size[1] // 2 + 1 self.num_filters = num_filters self.dim = dim self.med_channels = int(expansion_ratio * dim) self.weight_resize = weight_resize self.pwconv1 = nn.Linear(dim, self.med_channels, bias=bias) self.act1 = act1_layer() self.reweight = Mlp(dim, reweight_expansion_ratio, num_filters * self.med_channels) self.complex_weights = nn.Parameter(torch.randn(self.size, self.filter_size, num_filters, 2, dtype=torch.float32) * 0.02) self.act2 = act2_layer() self.pwconv2 = nn.Linear(self.med_channels, dim, bias=bias)
def forward(self, x): B, H, W, _ = x.shape routeing = self.reweight(x.mean(dim=(1, 2))).view(B, self.num_filters, -1).softmax(dim=1) x = self.pwconv1(x) x = self.act1(x) x = x.to(torch.float32) x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho') if self.weight_resize: complex_weights = resize_complex_weight(self.complex_weights, x.shape[1], x.shape[2]) complex_weights = torch.view_as_complex(complex_weights.contiguous()) else: complex_weights = torch.view_as_complex(self.complex_weights) routeing = routeing.to(torch.complex64) weight = torch.einsum('bfc,hwf->bhwc', routeing, complex_weights) if self.weight_resize: weight = weight.view(-1, x.shape[1], x.shape[2], self.med_channels) else: weight = weight.view(-1, self.size, self.filter_size, self.med_channels) x = x * weight x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho') x = self.act2(x) x = self.pwconv2(x) return x
if __name__ == '__main__': x = torch.randn(4, 512, 7, 7).cuda() x = x.permute(0, 2, 3, 1) model = DynamicFilter(512, size=7).cuda() out = model(x) out = out.permute(0, 3, 1, 2) print(out.shape)


本文内容为论文学习收获分享,受限于知识能力,本文对原文的理解可能存在偏差,最终内容以原论文为准。本文信息旨在传播和学术交流,其内容由作者负责,不代表本号观点。文中作品文字、图片等如涉及内容、版权和其他问题,请及时与我们联系,我们将在第一时间回复并处理。

AI前沿速递
持续分享最新AI前沿论文成果
 最新文章