论文介绍
题目:SparseTSF: Modeling Long-term Time Series Forecasting with 1k Parameters
论文地址:https://arxiv.org/pdf/2405.00946
QQ深度学习交流群:994264161
扫描下方二维码,加入深度学习论文指南星球!
加入即可获得,模块缝合、制作、写作技巧,学会“结构”创新、“创新点”创新,从三区到顶会,小论文大论文,毕业一站式服务
创新点
提出Cross-Period Sparse Forecasting技术:
通过将时间序列数据的周期性与趋势分离,创新性地提出了跨周期稀疏预测技术。
原始序列被下采样为跨周期子序列,再对这些子序列进行趋势预测,从而有效提取周期性特征并简化预测任务。
极度轻量化模型SparseTSF:
基于上述技术,构建了SparseTSF模型,其参数数量少于1k。
相较于现有方法,该模型在保持预测性能的同时显著减少了参数规模和计算资源需求。
强大的泛化能力:
SparseTSF在计算资源有限、小样本或低质量数据场景中表现出色,显示了其优越的泛化能力。
能够以极少的参数在多个数据集上实现接近甚至超越最先进模型的性能。
性能和参数效率的平衡:
SparseTSF通过有效的下采样、聚合和稀疏预测技术,将预测任务从直接建模原始序列转化为更简单的子任务,大幅减少了计算复杂度。
应对长时间序列预测的挑战:
针对长时间序列预测中复杂的时间依赖性问题,SparseTSF利用数据的内在周期性简化了建模难度。
方法
整体架构
SparseTSF 模型通过跨周期稀疏预测技术,将长时间序列分解为周期性和趋势性两部分,整体架构包括滑动聚合预处理、基于周期的下采样和共享参数线性层进行稀疏预测,并通过上采样恢复完整预测序列。该模型以少于 1k 的参数捕捉关键周期特征,同时通过实例归一化和简单的均方误差损失函数实现高效、鲁棒的长时间序列预测。
1. 输入数据的预处理
实例归一化 (Instance Normalization):
输入时间序列
首先通过归一化处理,减去其均值x t − L + 1 : t x_{t-L+1:t} ,以减轻分布偏移的影响。e t e_t 归一化公式:
x t − L + 1 : t = x t − L + 1 : t − e t 滑动聚合 (Sliding Aggregation):
使用 1D 卷积对序列进行滑动聚合,捕捉每个周期内的局部特征。
聚合后的序列
包含了周围时间点的上下文信息。x t − L + 1 : t ′ x'_{t-L+1:t}
2. 跨周期稀疏预测 (Cross-Period Sparse Forecasting)
下采样 (Downsampling):
输入序列
根据已知周期性x t − L + 1 : t ′ x'_{t-L+1:t} 被分割成w w 个子序列,每个子序列的长度为w w 。n = ⌊ L / w ⌋ n = \lfloor L / w \rfloor 子序列表示为矩阵
。X ∈ R w × n X \in \mathbb{R}^{w \times n} 稀疏滑动预测 (Sparse Sliding Prediction):
通过共享参数的线性层
对每个子序列的趋势进行预测,得到预测矩阵Linear \text{Linear} ,其中Y ∈ R w × m Y \in \mathbb{R}^{w \times m} 。m = ⌊ H / w ⌋ m = \lfloor H / w \rfloor 上采样 (Upsampling):
对预测矩阵
进行转置和重塑,恢复到完整的预测序列Y Y x ^ t + 1 : t + H \hat{x}_{t+1:t+H}
3. 输出数据的后处理
恢复预测序列的均值
,得到最终的预测值:e t e_t x ^ t + 1 : t + H = x ^ t + 1 : t + H + e t
4. 模型的损失函数
使用经典的均方误差 (MSE) 作为损失函数:
L = 1 C ∑ i = 1 C ∥ y t + 1 : t + H ( i ) − x ^ t + 1 : t + H ( i ) ∥ 2 2
即插即用模块作用
SparseTSF 作为一个即插即用模块:
(1)周期性时间序列数据
典型场景:能源消耗(如电力和水资源使用)、交通流量、零售销售量、天气预测等具有固定周期性的数据。
作用:
有效分离数据中的周期性和趋势性特征,简化预测任务,提升模型的效率和准确性。
(2)长时间序列预测(LTSF)
典型场景:超长时间的流量预测、物流需求分析、生产计划、金融市场中的中长期趋势分析。
作用:
在长时间窗口中提取关键的周期性特征,同时通过稀疏预测降低模型复杂度,显著减少计算开销。
(3)资源受限的环境
典型场景:嵌入式设备、IoT(物联网)设备、边缘计算环境等低计算能力场景。
作用:
利用稀疏化结构和极小的参数规模(<1k),在低功耗设备中实现高效时间序列预测。
(4)数据样本少或质量低的场景
典型场景:缺失数据较多、不规则采样或噪声较大的时间序列数据。
作用:
SparseTSF 模块的轻量化设计和周期性提取能力有助于在小样本和低质量数据中实现稳健的预测性能。
消融实验结果
内容:对比了不同基础模型(Linear、Transformer 和 GRU)在是否使用 Sparse 技术时的性能差异。
说明:
Sparse 技术显著提升了所有模型的预测性能,尤其是对复杂模型(如 Transformer 和 GRU),其平均性能提升分别为 21.4% 和 12.4%。
结果表明,Sparse 技术能有效提取数据的周期性特征,从而提升模型性能。
内容:探讨了主周期超参数
的选择对模型性能的影响。w w 说明:
当
(与数据的主周期一致)时,模型性能最佳。w = 24 w = 24 如果
偏离主周期,模型性能略有下降,表明主周期的合理选择对模型效果至关重要。w w
即插即用模块
import torch
import torch.nn as nn
from thop import profile
class Configs:
def __init__(self, seq_len=100, pred_len=50, enc_in=1, period_len=10):
self.seq_len = seq_len
self.pred_len = pred_len
self.enc_in = enc_in
self.period_len = period_len
class SparseTSF(nn.Module):
def __init__(self, configs):
super(SparseTSF, self).__init__()
self.seq_len = configs.seq_len
self.pred_len = configs.pred_len
self.enc_in = configs.enc_in
self.period_len = configs.period_len
self.seg_num_x = self.seq_len // self.period_len
self.seg_num_y = self.pred_len // self.period_len
self.conv1d = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=1 + 2 * self.period_len // 2,
stride=1, padding=self.period_len // 2, padding_mode="zeros", bias=False)
self.linear = nn.Linear(self.seg_num_x, self.seg_num_y, bias=False)
def forward(self, x):
batch_size = x.shape[0]
seq_mean = torch.mean(x, dim=1).unsqueeze(1)
x = (x - seq_mean).permute(0, 2, 1)
x = self.conv1d(x.reshape(-1, 1, self.seq_len)).reshape(-1, self.enc_in, self.seq_len) + x
x = x.reshape(-1, self.seg_num_x, self.period_len).permute(0, 2, 1)
y = self.linear(x)
y = y.permute(0, 2, 1).reshape(batch_size, self.enc_in, self.pred_len)
y = y.permute(0, 2, 1) + seq_mean
return y
if __name__ == '__main__':
configs = Configs(seq_len=100, pred_len=100, enc_in=3, period_len=10)
block = SparseTSF(configs)
input_tensor = torch.rand(1, configs.seq_len, configs.enc_in)
x = torch.randn(1, 100, 3)
flops, params = profile(block, (x,))
print('Params = ' + str(params / 1000 ** 2) + 'M')
output = block(input_tensor)
print(input_tensor.size()) print(output.size())
便捷下载方式
浏览打开网址:https://github.com/ai-dawang/PlugNPlay-Modules
更多分析可见原文