标题:FFT-based Dynamic Token Mixer for Vision
论文链接:https://arxiv.org/pdf/2303.03932
代码链接:https://github.com/okojoalg/dfformer
来源:AAAI 2024
1.定义与结构
Dynamic Filter模块主要由两个部分组成:
Filter-Generating Network (滤波器生成网络):该网络根据输入动态生成滤波器。输入可以是图像或其他特征,输出是针对特定样本的滤波器。这些滤波器是样本特定的,而不是固定的模型参 数。 Dynamic Filtering Layer (动态过滤层):该层将生成的滤波器应用到输入上,生成最终的输 出。这个过程是可微分的,允许端到端的训练。
2.工作原理
- 滤波器生成:滤波器生成网络接收输入,生成滤波器 ,该滤波器可以应用于另一个输入 以生成输出。滤波器的大小决定了感受野的大小,具体选择依赖于应用需求。"
- 动态卷积:动态卷积层使用动态生成的滤波器进行卷积操作,而不是预训练的固定滤波器。这种动 态生成的滤波器可以根据输入内容进行调整,从而提高模型的性能。"
3.应用场景
- 图像生成:根据一个视角图预测其他视角(如旋转人脸)、预测视频下一帧、2D变30等任务中常 用到动态过滤模块。
- 特征提取:在需要自适应特征提取的任务中,动态过滤模块可以灵活地生成适合特定样本的滤波 器,从而提高特征提取的准确性。
- 频域操作:一些动态过滤模块在频域进行操作,利用快速傅里叶变换(FFT)将特征图转换为频 域,应用滤波器后再通过逆FFT转换回空间域。这种方法可以降低计算复杂度,同时捕捉全局信 息。
4. 具体实现
以一个具体的实现为例,Dynamic Filter模块可以包括以下步骤:
定义滤波器基:预定义一组滤波器基,如低通滤波器、高通滤波器或带通滤波器。 使用MLP学习权重:通过一个多层感知机(MLP)层学习每个特征通道对的滤波器基权重。 生成滤波器:根据MLP的输出,将滤波器基与权重相乘并累加,得到最终的滤波器。 频域操作:使用FFT将特征图转换为频域,应用生成的滤波器,最后通过逆FFT将结果转换回空间域。
5. 优势
灵活性:滤波器可以根据输入动态生成,适用于多种不同的任务和数据。 性能提升:通过自适应生成滤波器,可以更好地捕捉输入数据的特征,从而提高模型的性能。 计算效率:在频域进行操作可以降低计算复杂度,同时保持全局信息的捕捉能力。
代码实现
import torch
import torch.nn as nn
from timm.models.layers.helpers import to_2tuple
class StarReLU(nn.Module):
def __init__(self, scale_value=1.0, bias_value=0.0, scale_learnable=True, bias_learnable=True, mode=None, inplace=False):
super().__init__()
inplace =
nn.ReLU(inplace=inplace) =
nn.Parameter(scale_value * torch.ones(1), requires_grad=scale_learnable) =
nn.Parameter(bias_value * torch.ones(1), requires_grad=bias_learnable) =
def forward(self, x):
return self.scale * self.relu(x) ** 2 + self.bias
class Mlp(nn.Module):
def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=StarReLU, drop=0., bias=False, **kwargs):
super().__init__()
in_features = dim
out_features = out_features or in_features
hidden_features = int(mlp_ratio * in_features)
drop_probs = to_2tuple(drop)
nn.Linear(in_features, hidden_features, bias=bias) =
act_layer() =
nn.Dropout(drop_probs[0]) =
nn.Linear(hidden_features, out_features, bias=bias) =
nn.Dropout(drop_probs[1]) =
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.fc2(x)
x = self.drop2(x)
return x
class DynamicFilter(nn.Module):
def __init__(self, dim, expansion_ratio=2, reweight_expansion_ratio=.25, act1_layer=StarReLU, act2_layer=nn.Identity, bias=False, num_filters=4, size=14, weight_resize=False, **kwargs):
super().__init__()
size = to_2tuple(size)
size[0] =
size[1] // 2 + 1 =
num_filters =
dim =
int(expansion_ratio * dim) =
weight_resize =
nn.Linear(dim, self.med_channels, bias=bias) =
act1_layer() =
Mlp(dim, reweight_expansion_ratio, num_filters * self.med_channels) =
nn.Parameter(torch.randn(self.size, self.filter_size, num_filters, 2, dtype=torch.float32) * 0.02) =
act2_layer() =
nn.Linear(self.med_channels, dim, bias=bias) =
def forward(self, x):
H, W, _ = x.shape
routeing = self.reweight(x.mean(dim=(1, 2))).view(B, self.num_filters, -1).softmax(dim=1)
x = self.pwconv1(x)
x = self.act1(x)
x = x.to(torch.float32)
x = torch.fft.rfft2(x, dim=(1, 2), norm='ortho')
if self.weight_resize:
complex_weights = resize_complex_weight(self.complex_weights, x.shape[1], x.shape[2])
complex_weights = torch.view_as_complex(complex_weights.contiguous())
else:
complex_weights = torch.view_as_complex(self.complex_weights)
routeing = routeing.to(torch.complex64)
weight = torch.einsum('bfc,hwf->bhwc', routeing, complex_weights)
if self.weight_resize:
weight = weight.view(-1, x.shape[1], x.shape[2], self.med_channels)
else:
weight = weight.view(-1, self.size, self.filter_size, self.med_channels)
x = x * weight
x = torch.fft.irfft2(x, s=(H, W), dim=(1, 2), norm='ortho')
x = self.act2(x)
x = self.pwconv2(x)
return x
if __name__ == '__main__':
x = torch.randn(4, 512, 7, 7).cuda()
x = x.permute(0, 2, 3, 1)
model = DynamicFilter(512, size=7).cuda()
out = model(x)
out = out.permute(0, 3, 1, 2)
print(out.shape)
本文内容为论文学习收获分享,受限于知识能力,本文对原文的理解可能存在偏差,最终内容以原论文为准。本文信息旨在传播和学术交流,其内容由作者负责,不代表本号观点。文中作品文字、图片等如涉及内容、版权和其他问题,请及时与我们联系,我们将在第一时间回复并处理。