WHALE来了,南大周志华团队 最新成果发布!下一个世界模型?

科技   2024-11-18 13:00   广东  

来源:机器之心

人类能够在脑海中设想一个想象中的世界,以预测不同的动作可能导致不同的结果。受人类智能这一方面的启发,世界模型被设计用于抽象化现实世界的动态,并提供这种「如果…… 会怎样」的预测。


因此,具身智能体可以与世界模型进行交互,而不是直接与现实世界环境交互,以生成模拟数据,这些数据可以用于各种下游任务,包括反事实预测、离线策略评估、离线强化学习。


世界模型在具身环境的决策中起着至关重要的作用,使得在现实世界中成本高昂的探索成为可能。为了促进有效的决策,世界模型必须具备强大的泛化能力,以支持分布外 (OOD) 区域的想象,并提供可靠的不确定性估计来评估模拟体验的可信度,这两者都对之前的可扩展方法提出了重大挑战。


本文,来自南京大学、南栖仙策等机构的研究者引入了 WHALE(World models with beHavior-conditioning and retrAcing-rollout LEarning),这是一个用于学习可泛化世界模型的框架,由两种可以与任何神经网络架构普遍结合的关键技术组成。



  • 论文地址:https://arxiv.org/pdf/2411.05619

  • 论文标题:WHALE: TOWARDS GENERALIZABLE AND SCALABLE WORLD MODELS FOR EMBODIED DECISION-MAKING


首先,在确定策略分布差异是泛化误差的主要来源的基础上,作者引入了一种行为 - 条件(behavior-conditioning)技术来增强世界模型的泛化能力,该技术建立在策略条件模型学习的概念之上,旨在使模型能够主动适应不同的行为,以减轻分布偏移引起的外推误差。


此外,作者还提出了一种简单而有效的技术,称为 retracing-rollout,以便对模型想象进行有效的不确定性估计。作为一种即插即用的解决方案, retracing-rollout 可以有效地应用于各种实施任务中的末端执行器姿态控制,而无需对训练过程进行任何更改。


为了实现 WHALE 框架,作者提出了 Whale-ST,这是一个基于时空 transformer 的可扩展具身世界模型,旨在为现实世界的视觉控制任务提供忠实的长远想象。


为了证实 Whale-ST 的有效性,作者在模拟的 Meta-World 基准和物理机器人平台上进行了广泛的实验。


在模拟任务上的实验结果表明,Whale-ST 在价值估计准确率和视频生成保真度方面均优于现有的世界模型学习方法。此外,作者还证明了基于 retracing-rollout 技术的 Whale-ST 可以有效捕获模型预测误差并使用想象的经验增强离线策略优化。


作为进一步的举措,作者引入了 Whale-X,这是一个具有 414M 参数的世界模型,该模型在 Open X-Embodiment 数据集中的 970k 个现实世界演示上进行了训练。通过在完全没见过的环境和机器人中的一些演示进行微调,Whale-X 在视觉、动作和任务视角中展示了强大的 OOD 通用性。此外,通过扩大预训练数据集或模型参数,Whale-X 在预训练和微调阶段都表现出了令人印象深刻的可扩展性。



总结来说,这项工作的主要贡献概述如下:


  • 作者引入了 WHALE,这是一个学习可泛化世界模型的框架,由两项关键技术组成:行为 - 条件(behavior-conditioning)和 retracing-rollout,以解决世界模型在决策应用中的两个主要挑战:泛化和不确定性估计;

  • 通过整合 WHALE 的这两种技术,作者提出了 Whale-ST,这是一种可扩展的基于时空 transformer 的世界模型,旨在实现更有效的决策,作者进一步提出了 Whale-X,这是一个在 970K 机器人演示上预训练的 414M 参数世界模型;

  • 最后,作者进行了大量的实验,以证明 Whale-ST 和 Whale-X 在模拟和现实世界任务中的卓越可扩展性和泛化性,突出了它们在增强决策方面的效果。

 

学习可泛化的世界模型以进行具身决策


世界模型中的序列决策通常需要智能体探索超出训练数据集的分布外 (OOD) 区域。这要求世界模型表现出强大的泛化能力,使其能够做出与现实世界动态密切相关的准确预测。同时,可靠地量化预测不确定性对于稳健的决策至关重要,这可以防止离线策略优化利用错误的模型预测。考虑到这些问题,作者提出了 WHALE,这是一个用于学习可泛化世界模型的框架,具有增强的泛化性和高效的不确定性估计。


用于泛化的行为 - 条件


根据公式(2)的误差分解可知,世界模型的泛化误差主要来源于策略分歧引起的误差积累。



为了解决这个问题,一种可能的解决方案是将行为信息嵌入到世界模型中,使得模型能够主动识别策略的行为模式,并适应由策略引起的分布偏移。


基于行为 - 条件,作者引入了一个学习目标,即从训练轨迹中获取行为嵌入,并整合学习到的嵌入。


作者希望将训练轨迹 τ_H 中的决策模式提取到行为嵌入中,这让人联想到以历史 τ_h 为条件的轨迹似然 ELBO(evidence lower bound)的最大化:


作者建议通过最大化 H 个决策步骤上的 ELBO 并调整类似于 β-VAE 的 KL 约束数量来学习行为嵌入:



这里,KL 项将子轨迹的嵌入预测约束到每个时间步骤 h,鼓励它们近似后验编码。这确保了表示保持策略一致,这意味着由相同策略生成的轨迹表现出相似的行为模式,从而表现出相似的表示。然后使用学习到的先验预测器从历史 τ_h 中获得行为嵌入 z_h,以便在世界模型学习期间进行行为调节,其中行为嵌入被接受为未来预测的额外协变量:



不确定性估计 Retracing-rollout


世界模型不可避免地会产生不准确和不可靠的样本,先前的研究从理论和实验上都证明,如果无限制地使用模型生成的数据,策略的性能可能会受到严重损害。因此,不确定性估计对于世界模型至关重要。


作者引入了一种新颖的不确定性估计方法,即 retracing-rollout。retracing-rollout 的核心创新在于引入了 retracing-action,它利用了具身控制中动作空间的语义结构,从而能够更准确、更高效地估计基于 Transformer 的世界模型的不确定性。



接下来作者首先介绍了 retracing-action,具体地说,retracing-action 可以等效替代任何给定的动作序列,形式如公式(5),其中表示动作 a_i 第 j 维的值。



接下来是一个全新的概念:Retracing-rollout。


具体来说:假设给定一个「回溯步骤」k,整个过程开始于从当前时间步 t,回溯到时间步 t-k,将 o_t−k 作为起始帧。


然后,执行一个回溯动作,从 o_t−k 开始,生成相应的结果 o_k+1。


在实际操作中,为了避免超出动作空间的范围,回溯动作被分解为 k 步。在每一步中,前六个维度的动作被设置为,而最后一个维度保持不变。通过这种方式,模型可以通过多步回溯产生期望的结果。


为了估计某一时间点 (o_t,a_t) 的不确定性,采用多种回溯步骤生成不同的回溯 - 轨迹预测结果。具体来说,要计算不同回溯 - 轨迹输出与不使用回溯的输出之间的「感知损失」。同时,引入动态模型的预测熵,通过将「感知损失」和预测熵相乘,得到最终的不确定性估计结果。


与基于集成的其他方法不同,retracing-rollout 方法不需要在训练阶段进行任何修改,因此相比集成方法,它显著减少了计算成本。


作者在论文中还给出了具体的实例。图 3 展示了 Whale-ST 的整体架构。具体来说,Whale-ST 包含三个主要组件:行为调节模型、视频 tokenizer 和动态模型。这些模块采用了时空 transformer 架构。


这些设计显著简化了计算需求,从相对于序列长度的二次依赖关系简化为线性依赖关系,从而降低了模型训练的内存使用量和计算成本,同时提高了模型推理速度。



实验


该团队在模拟任务和现实世界任务上进行了广泛的实验,主要是为了回答以下问题:


  • Whale-ST 在模拟任务上与其他基线相比表现如何?行为 - 条件和 retracing-rollout 策略有效吗?

  • Whale-X 在现实世界任务上的表现如何?Whale-X 能否从互联网规模数据的预训练中受益?

  • Whale-X 的可扩展性如何?增加模型参数或预训练数据是否能提高在现实世界任务上的表现?

模拟任务中的 Whale-ST

该团队在 Meta-World 基准测试上开展实验。Meta-World 是一个包含多种视觉操作任务的测试集。研究者们构建了一个包含 6 万条轨迹的训练数据集,这些轨迹是从 20 个不同的任务中收集来的。模型学习算法需要使用这些数据从头开始训练。

研究团队将 Whale-ST 与 FitVid、MCVD、DreamerV3、iVideoGPT 进行了对比。评估指标如下:

  • 预测准确性:验证模型是否能够正确估计给定动作序列的值,具体通过值差、回报相关性 (Return Correlation) 和 Regret 进行评估;

  • 视频保真度:研究团队采用 FVD、PSNR、LPIPS 和 SSIM 来衡量视频轨迹生成的质量。


下表展示了预测准确性的结果,其中,Whale-ST 在所有三个指标上都表现出色。在 64 × 64 的分辨率下,Whale-ST 的值差与 DreamerV3 的最高分非常接近。当在更高分辨率 256 × 256 测试时,Whale-ST 的表现进一步提升,取得了最小的值差和最高的回报相关性,反映了 Whale-ST 能更细致地理解动态环境。


表 2 展示了视频保真度的结果,Whale-ST 在所有指标上均优于其他方法,特别是 FVD 具有显著优势。


不确定性估计

针对不确定性,研究团队比较了 retracing-rollout 与两种基准方法:

(1)基于熵的方法:研究团队采用基于 Transformer 的动态模型,它通过计算模型输出的预测熵来量化不确定性
(2)基于集成的方法:研究团队训练了三个独立的动态模型,然后通过比较每个模型生成的图像之间的像素级差异来估计不确定性。

具体来说,他们从模型误差预测和离线强化学习两个角度进行评估。

下表展示了模型误差预测的结果,在所有 5 个任务中,retracing-rollout 均优于其他基线方法。与基于集成的方法相比,retracing-rollout 提升了 500%,与基于熵的方法相比,提高了 50%。


下图展示了离线 MBRL 的结果,retracing-rollout 在 5 个任务中的 3 个任务中收敛得更好、具备更强的稳定性。特别是在关水龙头和滑盘子任务中,retracing-rollout 是唯一能够稳定收敛的方法,而其他方法在训练后期出现了不同程度的性能下降。


Whale-X 在真实世界中的表现

为了评估 Whale-X 在实际物理环境中的泛化能力,研究团队在 ARX5 机器人上进行了全面实验。

与预训练数据不同,评估任务调整了摄像机角度和背景等,增加了对世界模型的挑战。他们收集了每个任务 60 条轨迹的数据集用于微调,任务包括开箱、推盘、投球和移动瓶子,还设计了多个模型从未接触过的任务来测试模型的视觉、运动和任务泛化能力。

如图 5 所示,Whale-X 在真实世界中展现出了明显的优势。

具体来说:


1. 与没有行为 - 条件的模型相比,Whale-X 的一致性提高了 63%,表明该机制显著提升了 OOD 泛化能力;
2. 在 97 万个样本上进行预训练的 Whale-X,比从零开始训练的模型具有更高的一致性,凸显了大规模互联网数据预训练的优势;
3. 增加模型参数能够提升世界模型的泛化能力。Whale-X-base(203M)动态模型在三个未见任务中的一致性率是 77M 版本的三倍。

此外,视频生成质量与一致性的结果一致,如表 4 所示。通过行为 - 条件策略、大规模预训练数据集和扩展模型参数,三种策略结合,显著提高了模型的 OOD 泛化能力,尤其是在生成高质量视频方面。


扩展性

固定视频 token 和行为 - 条件这两个部分不变,仅调整模型的参数量和预训练数据集的大小,Whale-X 的拓展性如何呢?

研究团队在预训练阶段训练了四个动态模型,参数数量从 39M 到 456M 不等,结果如图 7 的前两幅图所示。


这些结果表明,Whale-X 展现出强大的扩展性:无论是增加预训练数据还是增加模型参数,都会降低训练 loss。

除此之外,研究团队还验证了更大的模型在微调阶段是否能够展现更好的性能。

为此,他们微调了一系列动态模型,结果如图 7 最左侧所示。不难发现,经过微调后,更大的模型在测试数据上表现出更低的 loss,进一步突显了 Whale-X 在真实任务中出色的扩展性。

可视化

  • 定性评估


图 1 展示了在 Meta-World、Open X-Embodiment 和研究团队设计的真实任务上的定性评估结果。


结果表明,Whale-ST 和 Whale-X 能够生成高保真度的视频轨迹,尤其是在长时间跨度的轨迹生成过程中,保持了视频的质量和一致性。

  • 可控生成

图 8 展示了 Whale-X 在控制性和泛化性方面的强大能力。给定一个未见过的动作序列,Whale-X 能够生成与人类理解相符的视频,学习动作与机器人手臂移动之间的因果联系。


  • 行为条件可视化


通过 t-SNE 可视化,研究表明 Whale-X 成功地学习到行为嵌入,能够区分不同策略之间的差异。例如,对于同一任务,不同的策略会有不同的行为表示,而噪声策略的嵌入则介于专家策略和随机策略之间,体现了模型在策略建模上的合理性。此外,专家策略在不同任务中的嵌入也能被区分,而随机策略则无法区分,表明模型更擅长表示和区分策略,而不是任务本身。


更多研究细节,请参考原文。

参考链接:https://arxiv.org/abs/2411.05619

推荐阅读




欢迎大家加入DLer-计算机视觉技术交流群!


大家好,群里会第一时间发布计算机视觉方向的前沿论文解读和交流分享,主要方向有:图像分类、Transformer、目标检测、目标跟踪、点云与语义分割、GAN、超分辨率、人脸检测与识别、动作行为与时空运动、模型压缩和量化剪枝、迁移学习、人体姿态估计等内容。


进群请备注:研究方向+学校/公司+昵称(如图像分类+上交+小明)

👆 长按识别,邀请您进群!


深度学习技术前沿
本公众号专注于深度学习领域的前沿技术分享和学术交流。推送有关于机器学习、深度学习、强化学习、计算机视觉、自然语言处理等领域干货文章,致力于在第一时间内汇集和发布最新人工智能技术和前沿资讯。
 最新文章