论文介绍
题目:FITS: MODELING TIME SERIES WITH 10k PARAME TERS
论文地址:https://arxiv.org/pdf/2307.03756
QQ深度学习交流群:994264161
扫描下方二维码,加入深度学习论文指南星球!
加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务
创新点
轻量级模型设计:提出了一个名为 FITS(Frequency Interpolation Time Series Analysis)的轻量级模型,其参数量在 5k 到 10k 之间。相比主流时间序列模型(如 TimesNet、DLinear 等),FITS 显著减少了模型大小和计算复杂度,非常适合资源受限的边缘设备部署。
复数值神经网络的创新应用:FITS 是一种基于复数值神经网络的模型,能够同时捕获时间序列的幅值和相位信息。这种方法充分利用了频域的紧凑性和信息丰富性,为时间序列分析提供了一种更高效的表示方式。
频域内插方法的提出:通过快速傅里叶变换(FFT)将时间序列数据转换到复数频域,进行频率内插,从而完成时间序列的预测和重建任务。这种方法避免了对时间域数据的直接操作,提升了计算效率。
低通滤波的有效集成:在模型中加入了低通滤波器(Low Pass Filter, LPF),通过去除高频噪声成分,进一步压缩模型规模,同时保留了关键的时间序列信息。
兼顾性能与效率:FITS 在多个时间序列预测和异常检测任务上实现了与主流模型相当甚至更优的性能(如表 1 和表 2 中所示),但其参数量和计算量显著减少。例如,在 Electricity 数据集上的实验中,FITS 的参数量仅为 TimesNet 的 0.003% 左右,仍然能够获得竞争性甚至更优的预测精度。
方法
整体架构
FITS 模型通过将时间序列转换到复数频域(rFFT),利用复数值线性层进行幅值和相位的频域内插操作,同时结合低通滤波器去除高频噪声以压缩模型规模,随后通过逆变换(irFFT)还原到时间域,完成预测或重建任务。整个模型参数量仅约 10k,具有极高的计算效率,非常适合资源受限的边缘设备部署。
模型的核心目标:
模型的整体流程:
预处理:首先对输入时间序列进行归一化处理,使其均值为零(使用 RIN,Reversible Instance Normalization)。
快速傅里叶变换(rFFT):将时间序列数据从时间域投影到复数频域。
低通滤波(LPF):在频域中应用低通滤波器,去除高频噪声成分,以压缩模型规模并保留关键信息。
复数值线性层:通过复数值线性层进行频率插值操作。该层主要学习幅值缩放和相位平移,以实现频域插值。
零填充:对插值后的频域数据进行零填充,以满足逆变换的长度要求。
逆快速傅里叶变换(irFFT):将插值后的复数频域数据投影回时间域。
归一化逆变换:将时间序列恢复到原始的尺度和分布。
核心组件:
预测任务:监督模型在频域中生成预测的时间序列。
重建任务:通过降采样的时间序列监督模型恢复原始序列。
复数值线性层:这是模型的核心部分,能够同时处理复数的幅值和相位信息,通过复数乘法实现频域的插值操作。
低通滤波器(LPF):通过指定截止频率,去除高频噪声,同时减小数据维度。
监督机制:
特殊设计:
轻量化设计:模型使用极少的参数(仅约 10k),因此计算和内存开销非常低,适合边缘设备运行。
频域处理:通过在频域中操作复杂的时间序列,降低了时间域直接预测的复杂度。
FITS 是一个通过复数频域内插来进行时间序列预测和重建的模型。它通过将时间序列转换到复数频域,进行插值,然后再转回时间域来完成任务。
即插即用模块作用
FITS 作为一个即插即用模块:
频域特征提取与内插:利用频域表示的紧凑性和信息丰富性,捕捉时间序列的幅值和相位信息,提升对复杂动态变化的捕捉能力。
高效计算与内存节约:大幅降低参数量和计算需求(相比主流模型,FITS 参数量减少达数百倍),实现实时性能。
去噪与压缩:通过低通滤波器(LPF)去除高频噪声成分,保留关键信息,同时减少数据维度。
提升预测与重建精度:在频域插值的基础上有效整合多频段信息,为长短期预测、重建任务提供支持。
消融实验结果
显示了不同 Look-back Window 长度和 Low-pass Filter (LPF) 的 Cutoff Frequency(COF)设置对模型在 ETTh2 数据集上的影响。实验表明,更长的 Look-back Window 通常会带来更好的性能,而增加 COF 的效果相对较小。
展示了不同设置下模型参数量的变化,说明了模型的轻量化特性:通过调整 Look-back Window 和 COF,可以进一步减少模型参数量,同时保持性能。
对 ETTh1 数据集的分析结果,表明在某些情况下(如 Look-back Window = 720 时),更长的 Look-back Window 会引入分布偏移(distribution shift),导致性能下降。这一现象突出显示了数据分布对模型表现的影响。
即插即用模块
import torch
import torch.nn as nn
class Configs:
def __init__(self):
self.seq_len = 100
self.individual = False
self.enc_in = 2
self.cut_freq = 50
class Model(nn.Module):
def __init__(self, configs):
super(Model, self).__init__()
self.seq_len = configs.seq_len
self.channels = configs.enc_in
self.dominance_freq = configs.cut_freq
self.n_fft = self.seq_len // 2 + 1
if configs.individual:
self.freq_upsampler = nn.ModuleList([nn.Linear(self.n_fft * 2, self.n_fft * 2, bias=False) for _ in range(self.channels)])
else:
self.freq_upsampler = nn.Linear(self.n_fft * 2 * self.channels, self.n_fft * 2 * self.channels, bias=False)
def forward(self, x):
x_mean = torch.mean(x, dim=1, keepdim=True)
x_normalized = (x - x_mean) / torch.sqrt(torch.var(x, dim=1, keepdim=True) + 1e-5)
# 执行FFT变换
low_specx = torch.fft.rfft(x_normalized, dim=1)
low_specx[:, self.dominance_freq:, :] = 0 # 应用LPF
# 拆分实部和虚部
real_part = low_specx.real
imag_part = low_specx.imag
low_specx_combined = torch.cat([real_part, imag_part], dim=-1)
if isinstance(self.freq_upsampler, nn.ModuleList):
low_specxy_combined = torch.stack([
self.freq_upsampler[i](low_specx_combined[:, :, i].view(-1, 2 * self.n_fft))
for i in range(self.channels)
], dim=-1).view(-1, self.n_fft, 2)
else:
low_specxy_combined = self.freq_upsampler(low_specx_combined.view(-1, self.n_fft * 2 * self.channels))
low_specxy_combined = low_specxy_combined.view(-1, self.n_fft, 2 * self.channels)
real_part, imag_part = torch.split(low_specxy_combined, self.channels, dim=-1)
real_part = real_part.view(-1, self.seq_len // 2 + 1, self.channels)
imag_part = imag_part.view(-1, self.seq_len // 2 + 1, self.channels)
low_specxy_ = torch.complex(real_part, imag_part)
low_xy = torch.fft.irfft(low_specxy_, n=self.seq_len, dim=1)
xy = (low_xy * torch.sqrt(torch.var(x, dim=1, keepdim=True) + 1e-5)) + x_mean
return xy
if __name__ == '__main__':
configs = Configs()
block = Model(configs)
input = torch.rand(32, configs.seq_len, configs.enc_in)
output = block(input)
print(input.size()) print(output.size())
便捷下载方式
浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules
更多分析可见原文