论文介绍
题目:FECAM: Frequency Enhanced Channel Attention Mechanism for Time Series Forecasting
论文地址:https://arxiv.org/abs/2212.01209
QQ深度学习交流群:719278780
扫描下方二维码,加入深度学习论文指南星球!
加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务
创新点
频域增强的通道注意机制 (FECAM):提出了一种新的频域增强通道注意力机制,通过离散余弦变换 (DCT) 代替传统的傅里叶变换 (FT) 提取频率信息。相比傅里叶变换,DCT天然避免了因周期性问题带来的吉布斯现象 (Gibbs Phenomenon),从而减少了高频噪声的引入。
模块的通用性:FECAM不仅可以作为独立模型用于时间序列预测,还可以无缝嵌入主流的时间序列模型(如基于Transformer的模型和LSTM),提升这些模型在时间序列预测任务中的表现。这种通用性使得FECAM具有较高的实用价值和灵活性。
理论证明与验证:论文中通过理论分析和实验证明了FECAM在频域建模的有效性,尤其是利用DCT进行频率信息提取,显著提升了模型的预测性能。
实验结果表现优异:在六个真实世界的时间序列数据集上,FECAM在预测准确性方面达到了最新的最佳效果,并且相比其他方法具有更少的参数增量和计算开销。
方法
整体结构
输入处理与通道划分:首先将多变量时间序列数据按通道维度拆分成多个子序列,每个子序列包含不同的变量特征。
离散余弦变换 (DCT):对每个通道进行DCT变换,提取出对应的频率分量。这一步骤能够避免传统傅里叶变换带来的周期性问题(即吉布斯现象),从而更高效地捕捉低频信息,同时避免高频噪声的干扰。
通道注意力机制:通过频域的特征图,FECAM可以在不同通道和频率分量之间自适应地建模。使用全连接层对频率增强后的特征图进行加权学习,从而获得每个通道和频率分量的重要性。
重建与输出:最后,将学习到的频域信息和通道注意力机制的结果重新组合,生成增强后的时间序列预测。这一步通过全连接层或投影层来进行,确保频域和时间域信息能够在预测中得到充分利用。
即插即用模块作用
EFCAttention 作为一个即插即用模块,主要适用于:
时间序列预测场景:
用于电力负荷预测、气象数据预测、金融数据分析、交通流量预测等领域,帮助模型处理周期性和趋势性强的数据。
频域信息的增强:
FECAttention通过离散余弦变换(DCT)获取数据的频域特征,有效捕捉低频和高频信息,避免传统傅里叶变换带来的高频噪声问题。
增强特征重要性:
该模块在不同通道和频率分量之间自适应地建模,提升特征的表达能力,使模型能够更精准地学习到时间序列数据中的重要模式。
提升模型鲁棒性与预测精度:
通过引入频域信息,FECAttention可以显著提升各种时序模型(如LSTM、Transformer等)的预测性能,尤其在包含丰富低频信息的数据集上效果尤为显著。
消融实验结果
该表显示了在不同数据集上,将FECAM模块嵌入到主流的Transformer和RNN模型(如LSTM、Reformer、Informer、Autoformer等)后所带来的性能提升情况。实验结果表明,FECAM模块显著提升了各个模型的预测精度,尤其是在Exchange、ETTm2和Weather等包含丰富低频信息的数据集上效果尤为显著,而在Traffic数据集上提升相对较小,表明FECAM对频率信息较为敏感的数据集更具优势。
即插即用模块
import torch.nn as nn
import numpy as np
import torch
#论文:FECAM: Frequency Enhanced Channel Attention Mechanism for Time Series Forecasting
#论文地址:https://arxiv.org/abs/2212.01209
try:
from torch import irfft
from torch import rfft
except ImportError:
def rfft(x, d):
t = torch.fft.fft(x, dim=(-d))
r = torch.stack((t.real, t.imag), -1)
return r
def irfft(x, d):
t = torch.fft.ifft(torch.complex(x[:, :, 0], x[:, :, 1]), dim=(-d))
return t.real
def dct(x, norm=None):
"""
Discrete Cosine Transform, Type II (a.k.a. the DCT)
For the meaning of the parameter `norm`, see:
https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
:param x: the input signal
:param norm: the normalization, None or 'ortho'
:return: the DCT-II of the signal over the last dimension
"""
x_shape = x.shape
N = x_shape[-1]
x = x.contiguous().view(-1, N)
v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)
Vc = rfft(v, 1)
k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
W_r = torch.cos(k)
W_i = torch.sin(k)
V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i
if norm == 'ortho':
V[:, 0] /= np.sqrt(N) * 2
V[:, 1:] /= np.sqrt(N / 2) * 2
V = 2 * V.view(*x_shape)
return V
class dct_channel_block(nn.Module):
def __init__(self, channel):
super(dct_channel_block, self).__init__()
self.fc = nn.Sequential(
nn.Linear(channel, channel * 2, bias=False),
nn.Dropout(p=0.1),
nn.ReLU(inplace=True),
nn.Linear(channel * 2, channel, bias=False),
nn.Sigmoid()
)
self.dct_norm = nn.LayerNorm([96], eps=1e-6) # for lstm on length-wise
def forward(self, x):
b, c, l = x.size() # (B,C,L) (32,96,512)
list = []
for i in range(c):
freq = dct(x[:, i, :])
list.append(freq)
stack_dct = torch.stack(list, dim=1)
lr_weight = self.dct_norm(stack_dct)
lr_weight = self.fc(lr_weight)
lr_weight = self.dct_norm(lr_weight)
return x * lr_weight # result
if __name__ == '__main__':
input = torch.rand(8, 7, 96)
block = dct_channel_block(96)
result = block(input)
print(input.size()) print(result.size())
便捷下载方式
浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules
更多分析可见原文