写在前面 & 笔者的个人理解
最近自回归(AR)生成模型的成功,如自然语言处理中的GPT系列,促使人们努力在视觉任务中复制这一成功。一些工作试图通过构建能够生成逼真的未来视频序列和预测自车状态的基于视频的世界模型,将这种方法扩展到自动驾驶。然而,先前的工作往往产生不令人满意的结果,因为经典的GPT框架旨在处理1D上下文信息,如文本,并且缺乏对视频生成所必需的空间和时间动态进行建模的固有能力。本文介绍了DrivingWorld,这是一个GPT风格的自动驾驶世界模型,具有多种时空融合机制。这种设计能够有效地对空间和时间动态进行建模,从而促进高保真、长持续时间的视频生成。具体来说,我们提出了一种下一状态预测策略来模拟连续帧之间的时间一致性,并应用下一token预测策略来捕获每个帧内的空间信息。为了进一步提高泛化能力,我们提出了一种新的掩码策略和重新加权策略用于token预测,以缓解长期漂移问题并实现精确控制。我们的工作展示了制作高保真、持续时间超过40秒的一致视频片段的能力,这比最先进的驾驶世界模型长2倍多。实验表明,与先前的工作相比,我们的方法实现了卓越的视觉质量和更精确的可控未来视频生成。
开源链接:https://github.com/YvanYin/DrivingWorld
总结来说,本文介绍了DrivingWorld,这是一个基于GPT风格视频生成框架的驾驶世界模型。我们的主要目标是在自回归框架中增强时间一致性的建模,以创建更准确可靠的世界模型。为了实现这一目标,我们的模型结合了三个关键创新:1)时间感知标记化:我们提出了一种时间感知标记器,将视频帧转换为时间相干标记,将未来视频预测的任务重新表述为预测序列中的未来标记。2)混合token预测:我们引入了一种下一状态预测策略来预测连续状态之间的时间一致性,而不是仅仅依赖于下一个token预测策略。之后,应用下一个token预测策略来捕获每个状态内的空间信息。3)长时间可控策略:为了提高鲁棒性,我们在自回归训练过程中实施了随机标记丢弃和平衡注意力策略,从而能够生成具有更精确控制的持续时间更长的视频。DrivingWorld使用AR框架增强了视频生成中的时间连贯性,学习了未来进化的有意义表示。实验表明,所提出的模型具有良好的泛化性能,能够生成超过40秒的视频序列,并提供准确的下一步轨迹预测,保持合理的可控性。
相关工作回顾
世界模型。世界模型捕捉了环境的全面表示,并根据一系列行动预测了未来的状态。世界模型在游戏和实验室环境中都得到了广泛的探索。Dreamer利用过去的经验训练了一个潜在动力学模型,以预测潜在空间内的状态值和行为。DreamerV2基于最初的Dreamer模型构建,在雅达利游戏中达到了人类水平的性能。DreamerV3使用了更大的网络,并成功地学会了从零开始在Minecraft中获取钻石。DayDreamer扩展了Dreamer,在现实世界中训练了四个机器人,成功地完成了运动和操纵任务。
最近驾驶场景的世界模型在学术界和工业界都引起了极大的关注。之前的大多数工作仅限于模拟器或控制良好的实验室环境。Drive WM使用扩散模型探索了现实世界中的驾驶规划者。GAIA-1基于自回归模型研究了现实世界的驾驶规划者,但GAIA-1具有较大的参数和计算需求,随着条件框架数量的增加而增加。在本文中,我们提出了一个自回归框架下的自动驾驶场景的有效世界模型。
VQVAE。VQVAE通过矢量量化学习离散码本表示,以对图像分布进行建模。VQGAN通过结合LPIPS损失和对抗性PatchGAN损失提高了真实感。MoVQ通过将空间变异信息嵌入量化向量中,解决了VQGAN的空间条件归一化问题。LlamaGen进一步微调了VQGAN,表明较小的码本矢量维数和较大的码本大小可以提高重建性能。虽然基于VQGAN的结构被广泛使用,但一些方法探索了更高效的架构。ViT VQGAN用视觉变换器取代了卷积编码器-解码器,提高了模型捕获长距离依赖关系的能力。VAR采用多尺度结构来预测先前尺度的后续尺度,从而提高了发电质量和速度。然而,这些方法侧重于单一图像处理,阻碍了它们捕获时间一致性。为了解决这个问题,我们提出了一种时间感知标记器和解码器。
视频生成。目前有三种主流的视频生成模型:基于GAN、基于扩散和基于GPT的方法。基于GAN的方法经常面临几个挑战,例如模式崩溃,生成器生成的视频的多样性受到限制。此外,生成器和鉴别器之间的对抗性学习可能会导致训练过程中的不稳定。基于扩散的方法的一个主要问题是它们无法生成精确控制的视频。扩散过程的随机性在每一步都引入了随机性,使得难以对生成内容中的特定属性进行严格控制。另一方面,传统的基于GPT的方法允许一定程度的控制,但它们的计算成本随序列长度呈二次增长,显著影响了模型效率。本文提出了一种解耦的时空世界模型框架,该框架在确保精确控制的同时,显著降低了计算成本,提高了模型效率。
DrivingWorld方法详解
我们提出的世界模型DrivingWorld利用GPT风格的架构高效预测未来状态,能够以10Hz的频率将预测时间延长到40秒以上。该模型旨在理解过去的现实世界状态,并预测未来的视频内容和车辆运动。DrivingWorld专门专注于根据时间1到T的历史状态预测时间T+1的下一个状态,我们可以通过逐一顺序预测未来状态来生成长视频。
如图2所示,我们提出的DrivingWorld不仅可以根据过去的观测结果生成未来状态,还可以通过操纵车辆的位置和方向来支持复杂驾驶场景的可控模拟。
Tokenizer
标记化将连续数据转换为离散标记,从而能够与语言模型和增强的多模态序列建模集成。在我们的方法中,标记器将多模态状态映射到统一的离散空间中,从而实现了精确可控的多模态生成。为了为图像生成时间一致的嵌入,我们提出了一种时间感知的矢量量化标记器。我们提出的车辆姿态标记器将姿态轨迹离散化,并将其整合到我们的DrivingWorld中。
前言:Single Image Vector Quantized Tokenizer。单图像矢量量化(VQ)标记器旨在将图像特征图转换为离散标记q。量化器利用包含K个矢量的学习离散码本,将每个特征f(i,j)映射到Z中最接近代码的索引。这种方法能够将连续图像数据转换为离散token。
时间感知矢量量化标记器。单图像VQ标记器通常难以产生时间一致的嵌入,导致不连续的视频预测,阻碍了世界模型的训练。
为了解决这个问题,我们提出了一种时间感知的矢量量化标记器,旨在确保随时间推移的一致嵌入。具体来说,为了捕捉时间依赖性,我们在VQGAN量化之前和之后都插入了一个self-att,其中注意力沿着时间维度进行操作。这使得我们的模型能够捕捉帧之间的长期时间关系,提高生成序列的连贯性和一致性。我们的模型基于LlammaGen的开源VQGAN实现。我们直接而有效的时间self-att的集成可以无缝地整合到原始框架中,然后进行微调,以开发一个健壮且通用的时间感知VQ标记器。
车辆位姿标记器。为了准确表示车辆的自车状态,包括其方向θ和位置(x,y),我们采用以自车辆为中心的坐标系,如图2所示。我们采用相邻时间步长之间的相对姿态,而不是全局姿态。这是因为在长期序列中,由于绝对姿态值的增加,全球姿态带来了重大挑战。这种增长使得归一化变得困难,并降低了模型的鲁棒性。随着序列变长,管理这些大的姿势值变得越来越困难,阻碍了有效的长期视频生成。
World Model
世界模型旨在理解过去的状态输入,模拟现实世界的动态,并预测未来的状态。在我们的背景下,它预测了即将到来的驾驶场景,并规划了可行的未来轨迹。为此,世界模型将历史状态标记连接成一个长序列,其中2D图像标记以锯齿形顺序展开为1D形式。因此,目标是预测下一个状态。基于过去的观测序列,捕捉时间和多模态依赖关系。请注意,来自不同模态的所有离散token在被馈送到世界模型之前,都由其各自的可学习嵌入层映射到共享的潜在空间中。所有后续过程都在这个潜在空间内进行。
前言:下一个token预测。一种直接的方法是使用GPT-2结构进行1D顺序下一个token预测。图3(a)显示了一个简化示例。因果注意被应用于下一个token预测,T+1中的第i个token被建模为:
因此我们提出了一种下一状态预测管道,它由两个模块组成:一个集成时间和多模态信息以生成下一状态特征(即时间多模态融合模块),另一个是自回归模块(即内部状态自回归模块)以生成高质量的内部状态token。时间多模态融合模块。我们的时间多模态模块由一个单独的时间层和一个多模态层组成。这将时间和多模态信息的处理解耦,从而提高了训练和推理速度,同时也降低了GPU内存消耗。如图3(b)所示,我们建议在时间转换层Fa(·)中使用因果注意力掩码,其中每个token只关注自身和所有先前帧中相同顺序位置的token,充分利用时间信息。
在多模态信息融合层Fb(·)中,我们在同一帧中采用双向掩码,旨在充分整合内部状态多模态信息,并促进模态之间的交互。每个token处理来自同一时间步的其他token:
内部状态自回归模块。在时间多模态模块之后,我们获得了用于未来帧状态预测的特征。一种天真的方法是同时预测下一个状态tokenht。最近,多图像生成工作提出,用于下一个token预测的自回归流水线可以生成更好的图像,甚至优于扩散方法。受此启发,我们提出了一个内部状态自回归模块来生成下一时间步的姿势和图像(见图3(b))。
然后,它们被输入到内部状态自回归Transformer层Fc(·)。因果掩码在这些层中使用,因此每个token只能出席自己并前缀内部状态token。自回归过程如方程式6所示。由于我们的管道同时包含了下一个状态预测和下一个内部状态token预测,我们在训练中实施了两种教师强制策略,即一种用于帧级别,另一种用于内部状态级别。
训练损失交叉熵:
Decoder
使用世界模型预测下一个状态标记,然后我们可以利用解码器为该状态生成相应的相对方向、相对位置和重建图像。这个过程使我们能够将预测的潜在表示映射回物理输出,包括空间和视觉数据。
Vehicle Pose Decoder:
Temporal-aware Decoder:
Long-term Controllable Generation
Token Dropout实现无漂移自动回归。在训练过程中,世界模型使用过去的地面真实token作为条件来预测下一个token。然而,在推理过程中,模型必须依赖于先前生成的表征进行调节,这可能包含缺陷。仅使用完美的GT图像进行训练可能会在推理过程中导致内容漂移问题,导致生成的输出迅速退化并最终失败。为了解决这个问题,我们提出了一种随机掩蔽策略(RMS),其中一些来自地面真实token的token被随机丢弃。每个标记有50%的机会被该帧中的另一个随机标记替换,并且这种丢失以30%的概率应用于整个调节图像序列。如图4所示,这种dropout策略显著缓解了推理过程中的漂移问题。
平衡注意力实现精确控制。世界模型利用广泛的注意力操作在代币之间交换和融合信息。然而,每个前视图图像被离散化为512个标记,而只有2个标记表示姿势(方向和位置)。这种不平衡会导致模型忽略姿态信号,导致可控生成不令人满意。为了解决这个问题,我们提出了一种平衡的注意力操作,通过在注意力机制中优先考虑自车状态标记,而不是平等地关注所有标记,来实现更精确的控制。具体来说,我们手动增加注意力图中方向和位置标记的权重(在softmax层之前),分别为这些标记添加0.4和0.2的恒定权重。此外,我们结合了QK范数和2D旋转位置编码,以进一步稳定训练并提高性能。
实验结果
结论和未来工作
总之,DrivingWorld通过利用GPT风格的框架来生成更长、高保真的视频预测,并提高了泛化能力,从而解决了以前自动驾驶视频生成模型的局限性。与在长序列中难以保持连贯性或严重依赖标记数据的传统方法不同,DrivingWorld生成了逼真、结构化的视频序列,同时实现了精确的动作控制。与经典的GPT结构相比,我们提出的时空GPT结构采用了下一状态预测策略来模拟连续帧之间的时间一致性,然后应用下一token预测策略来捕获每个帧内的空间信息。展望未来,我们计划整合更多的多模态信息,并整合多视图输入。通过融合来自不同模态和视角的数据,我们的目标是提高动作控制和视频生成的准确性,增强模型理解复杂驾驶环境的能力,并进一步提高自动驾驶系统的整体性能和可靠性。
参考
[1] DrivingWorld: Constructing World Model for Autonomous Driving via Video GPT