上交最新时空预测模型PredFormer,纯Transformer架构,多个数据集取得SOTA效果

科技   2024-10-12 08:03   北京  

点关注,不迷路,用心整理每一篇算法干货~

后台留言”交流“,加入圆圆算法交流群~
👇🏻扫码👇🏻加入圆圆算法知识星球~
已有870+同学加入学习,700+干货笔记)

今天给大家介绍一篇时空预测最新模型PredFormer,由上海交大等多所高校发表,采用纯Transformer模型结构,在多个数据集中取得SOTA效果。

1

背景

时空预测学习是一个拥有广泛应用场景的领域,比如天气预测,交通流预测,降水预测,自动驾驶,人体运动预测等。

提起时空预测,不得不提到经典模型ConvLSTM和最经典的benchmark moving mnist,在ConvLSTM时代,对于Moving MNIST的预测存在肉眼可见的伪影和预测误差。而在最新模型PredFormer中,对Moving MNIST的误差达到肉眼难以分辨的近乎完美的预测结果。

在以前的时空预测工作中,主要分为两个流派,基于循环(自回归)的模型,以ConvLSTM/PredRNN//E3DLSTM/SwinLSTM/VMRNN等工作为代表;更近年来,研究者提出无需循环的SimVP框架,由CNN Encoder-Decoder结构和一个时间转换器组成,以SimVP/TAU/OpenSTL等工作为代表。

RNN。系列模型的缺陷在于,无法并行化,自回归速度慢,显存占用高,效率低;CNN系列模型无需循环提高了效率,得益于归纳偏置,但往往以牺牲泛化性和可扩展性为代价,模型上限低。

于是作者提出了问题,时空预测,真的需要RNN吗?真的需要CNN吗?是否能够设计一个模型,可以自动地学习数据中的时空依赖,而不需要依赖于归纳偏置呢?

一个直觉的想法是利用Transformer,因为它在各种视觉任务中的广泛成功,并且是RNN和CNN的有力替代者。在此前的时空预测工作中,已有研究者把Transformer嵌入到上述两种框架中,比如SwinLSTM(ICCV23)融合了Swin Transformer和LSTM,比如OpenSTL(NeurIPS23)把各种MetaFormer结构(比如ViT,Swin Transformer等)作为SimVP框架中的时间转换器。但是,纯Transformer结构的网络鲜有探索。

但纯Transformer模型的挑战在于,如何在一个框架中同时处理时间和空间信息。一个简单的想法是合并空间序列和时间序列,计算时空全注意力,由于Transformer的计算复杂度是序列长度的二次复杂度,这样的做法会导致计算复杂度较大。

在这篇文章中,作者提出了用于时空预测学习的新框架PredFormer,这是一个纯ViT模型,既没有自回归也没有任何卷积。作者利用精心设计的基于门控Transfomer模块,对3D Attention进行了全面的分析,包括时空全注意力,时空分解的注意力,和时空交错的注意力。PredFormer 采用非循环、基于Transformer的设计,既简单又高效,更少参数量,Flops,更快推理速度,性能显著优于以前的方法。在合成和真实数据集上进行的大量实验表明,PredFormer 实现了最先进的性能。在 Moving MNIST 上,PredFormer 相对于 SimVP 实现了 51.3% 的 MSE 降低,突破性地达到11.6。对于 TaxiBJ,该模型将 MSE 降低了 33.1%,并将 FPS 从 533 提高到 2364。此外,在 WeatherBench 上,它将 MSE 降低了 11.1%,同时将 FPS 从 196 提高到 404。这些准确度和效率方面的性能提升证明了 PredFormer 在实际应用中的潜力。

2

实现方法

PredFormer模型遵循标准ViT的设计,先对输入进行Patch Embedding,把输入为[B, T, C, H, W]的时空序列转换为[B, T, N, D]的张量。在位置编码环节,作者采用了不同于一般ViT设计的可学习的位置编码,而是采用了基于sin函数的绝对位置编码,作者在消融实验中进一步阐述了绝对位置编码在时空任务中的优越性。

PredFormer的编码器部分,由门控Transfomer模块以不同的方式堆叠而成。由于编码器部分是纯Transformer结构,没有任何卷积,也没有分辨率的下降,每一个门控Transformer模块都建模了全局信息,这允许模型只需使用一个简单的解码器就可以构成一个性能强大的预测模型。作者采用了一个线性层作为解码器来进行Patch Recovery,这让模型的输出从[B, T, N, D]恢复到[B, T, C, H, W]。

不同于标准Transformer模型采用MLP作为FFN,PredFormer采用了Gated Linear Unit(GLU)作为FFN,这是受GLU在NLP任务中优于MLP启发的改进。作者在消融实验中进一步阐述了GLU相比于MLP在时空任务上的优越性。

作者对3D Attention进行了全面的分析,并提出了9种PredFormer变体。在以前用于视频分类的Video ViT设计中,TimesFormer(ICML21), ViviT(ICCV21), TSViT(CVPR23)等工作也对时空分解进行了分析,但是TimesFormer是在self-attention层面进行分解,也就是spatial attention和temporal attention共用一个MLP。ViviT则是提出了在Encoder层面(先空间后时间),self-attention层面和head层面进行时空分解。而TSViT发现先时间后空间的Encoder对卫星序列图像分类更有效。

不同于以上工作,PredFormer是在Gated Transformer Block(GTB)层面(多了基于Gated Linear Unit)进行时空分解。对时间和空间的self-attention都加GLU是至关重要的,因为它可以让学习到的特征互相作用并且增强非线性。

PredFormer提出了时空全注意力Encoder,时间在前和空间在前的2种分解Encoder和6种新颖的时空交错的Encoder,一共9种模型。PredFormer提出了PredFormer Layer的概念,即一个既能建模空间信息,又能建模时间信息的最小单元。基于这种想法,作者提出了三种基本范式,二元组(由一个Temporal GTB和一个Spatial GTB组成,有T-S和S-T两种方式),三元组(T-S-T和S-T-S),四元组(两个二元组以相反的方向重组)。

这一设计源于不同的时空预测任务往往有着不同的空间分辨率和时间分辨率(时间间隔以及变化程度),这意味着不同的数据集上对时间信息和空间信息的依赖程度不同,作者设计了这些模型以提高PredFormer模型在不同任务上的适应性。

3

实验效果

在实验部分,作者控制了提出的每种变体使用相同的GTB数目,这可以保证模型的参数量基本一致,从而对比不同模型的性能。

实验发现了一些规律,(1)时间在前的分解Encoder模型优于时空全注意力模型,由于空间在前的分解Encoder模型 (2)时空交错的6种模型在大多数任务上表现都很好,都能达到sota,但最优模型因为数据集本身的不同时空依赖特性而不同,这体现了PredFormer这种框架和时空交错设计的优势 (3)作者在讨论环节提出了建议,在其他的时空预测任务上,从四元组-TSST开始尝试,因为这个模型在三个数据集上都表现sota,先调整M个TSST(即4M个门控Transformer)的M参数,然后尝试M个TST和M个STS以确定数据集是时间依赖更强或空间依赖更强的模型。得益于Transformer构架的可扩展性,不同于SimVP框架的CNN Encoder-Decoder模型,对spatial和temporal的hidden dim以及block数都设置了不同的值,PredFormer对spatial和temporal GTB采用相同的固定的参数,因此只需要调整M的值,在比较少次数的调整后就可以达到最优性能。

ViT模型的训练通常要求较大的数据集,在时空预测任务上,大多数据集在几千到几万的量级,数据集少,因此很容易过拟合。作者还探索了不同的正则化策略,包括dropout和drop path,通过广泛的消融实验,作者发现同时使用dropout和uniform的drop path(不同于一般ViT使用线性增加的drop path rate)会产生最优的模型效果。

作者还进行了可视化比较,可以看到,在PredFormer相对于TAU明显减少了预测误差。作者还给出了一个特殊例子来证明PredFormerr模型相比于CNN模型在泛化性上的优越性。在交通流预测任务上,当第四帧相比前三帧明显减少流量时,TAU受限于归纳偏置仍然预测了较高的流量,而PredFormer却能捕捉到这里的变化。PredFormerr预测剧烈变化的能力在交通流和天气预测中可能有非常宝贵的应用价值。

END




后台留言”交流“,加入圆圆算法交流群~
后台留言”星球“,加入圆圆算法知识请星球~【时序预测专题课程持续更新中
知识星球提供一文贯通笔记、经典代码解析、问答服务、新人入门,已有870+小伙伴加入价格随人数增加和内容丰富上涨,感兴趣的同学尽早加入~


投稿&加交流群请加微信,备注机构+方向拉群~

【历史干货算法笔记】
生成式模型入门:一文讲懂3大类生成式模型
Sptial-Temporal时空预测总结:建模思路、优化方法梳理
时序预测顶会论文数据集、数据处理方法、训练方法汇总
时间序列预测实战方法概述:从数据到模型
Informer模型结构和代码解析
基于Transformer的时序预测模型TFT代码详解
时空预测经典模型STGCN原理和代码解读
一网打尽:14种预训练语言模型大汇总
Vision-Language多模态建模方法脉络梳理
花式Finetune方法大汇总
从ViT到Swin,10篇顶会论文看Transformer在CV领域的发展历程

如果觉得有帮助麻烦分享在看点赞~  

圆圆的算法笔记
定期更新深度学习/算法干货笔记和世间万物学习记录~
 最新文章