推理阶段同时从提示数据中学习算子并将其应用于新问题,而无需任何权重更新

科技   2024-11-05 09:10   上海  

In-context operator learning with data prompts for differential equation problems

带有微分方程问题数据提示的上下文运算符学习

https://www.pnas.org/doi/epdf/10.1073/pnas.2310142120


本文介绍了“上下文内算子学习”的范式及其相应的模型“上下文内算子网络”,以在推理阶段同时从提示数据中学习算子并将其应用于新问题,而无需任何权重更新。现有方法受限于使用神经网络来近似特定方程的解或特定算子,当切换到具有不同方程的新问题时需要重新训练。通过将单个神经网络训练为算子学习器,而不是解/算子近似器,我们不仅可以摆脱为新问题重新训练(甚至微调)神经网络,还可以利用跨算子共享的共性,从而在学习新算子时只需要提示中的几个示例。我们的数值结果显示了单个神经网络作为多样化类型微分方程问题(包括常微分方程、偏微分方程和平均场控制问题的正向和逆向问题)的少样本算子学习器的能力,并展示了其将学习能力泛化到训练分布之外的算子的能力。

意义

本文介绍了上下文内算子网络(ICON),这是一种神经网络方法,可以在推理阶段从提示数据中学习新算子,而无需任何权重更新。与现有方法不同,这些方法受限于近似特定方程的解或算子,并且需要为新问题重新训练,ICON将单个神经网络训练为算子学习器,消除了在遇到不同问题时重新训练或微调的需求。数值结果证明了ICON在解决各种类型微分方程问题中的有效性,并能泛化到训练分布之外的算子。所提出的方法从自然语言处理中成功使用的上下文内学习技术中汲取灵感,并对物理系统中的人工通用智能具有意义。


算子学习 | 元学习 | 上下文内学习 | 微分方程 | 人工智能

神经网络的发展对解决微分方程问题产生了重大影响。我们建议读者参考参考文献1以了解该领域的最新进展。

一种典型方法旨在直接近似给定特定问题的解。使用深度学习解决偏微分方程(PDEs)首先在参考文献2中引入,用于高维抛物线方程,并在参考文献3中进一步发展。深度Galerkin方法(4)对神经网络施加约束以满足规定的微分方程和边界条件。深度Ritz方法(5)利用PDEs的变分形式,可用于解决可以转化为等效能量最小化问题的PDEs。物理信息神经网络(PINNs)(6)提出了一种深度神经网络方法,通过在损失函数中整合数据和微分方程来解决正向和逆向问题。弱对抗网络(7)通过将弱解和测试函数分别参数化为原始和对立网络来利用PDEs的弱形式。参考文献8通过在神经网络参数化中编码拉格朗日和欧拉视角来解决高维平均场博弈问题。APAC-net(9)提出了一种生成对抗网络风格的方法,利用原始-对偶公式来解决平均场博弈问题。

尽管这些方法取得了成功,但它们设计用于解决具有特定微分方程的问题。当方程中的项或初始/边界条件发生变化时,神经网络需要重新训练。虽然可以采用迁移学习技术通过微调预训练神经网络来减轻训练成本(10–19),但在目标函数发生重大变化时,这些技术可能不够充分。

后来,人们努力用不同的参数或初始/边界条件来近似微分方程的解算子。早在参考文献20和21中,浅层神经网络就被用于近似非线性算子。在参考文献22中,作者提出使用深度神经网络解决参数化PDE问题。参考文献23引入了一种贝叶斯方法,使用深度卷积编码器-解码器网络进行由随机PDEs支配的问题中的不确定性量化和传播。PDE-Net(24)利用卷积核学习微分算子,使其能够从数据中揭示演化PDE模型,并利用学习到的解映射进行前向预测。深度算子网络(DeepONet)(25, 26)设计了一种神经网络架构来近似将参数或初始/边界条件映射到解的解算子。傅里叶神经算子(FNO)(27, 28)利用傅里叶空间中的积分核来学习解算子。在参考文献29中,作者提出了一种数据驱动的框架,用于近似参数化PDEs的无限维空间之间的输入-输出映射,灵感来自神经网络和模型简化。物理信息神经算子(30)在不同分辨率下结合数据和PDE约束来学习参数化PDEs的解算子。其他相关工作包括(31–35)。

上述方法成功展示了神经网络近似解算子的能力。然而,在这些方法中,一个神经网络仅限于近似一个算子。即使微分方程发生微小变化,也会导致解算子的变化。例如,在学习从扩散系数到泊松方程解的解算子的情况下,如果源项(未设计为算子输入的一部分)发生变化,或方程中引入新项,解算子也会发生变化。因此,神经网络必须重新训练,至少进行微调(30, 33–39),以近似新算子。

我们认为,各种解算子之间存在共性。通过使用具有单一权重集的单个神经网络来学习各种解算子,我们不仅可以摆脱重新训练(甚至微调)神经网络,还可以利用这些共性,从而在学习新算子时需要更少的数据。

如果我们将学习一个解算子视为一个任务,那么我们现在旨在用单个神经网络解决多个与微分方程相关的任务。我们对这个神经网络的期望不仅限于简单地学习特定算子。相反,我们期望它获得“从数据中学习算子”并将新学到的算子应用于新问题的能力。

这种学习和应用新算子的能力可能是人工通用智能(AGI)的重要组成部分。通过观察物理系统的输入和输出,人类可以学习将输入映射到输出的底层算子,并根据其目标控制系统。例如,摩托车手可以快速适应新的摩托车;皮划艇手可以快速适应新的皮划艇或变化的水条件。如果一个人在两项运动中都有专长,他们可能能够在第一次尝试中掌握喷气滑雪。我们期望具有AGI的机器人能够像人类一样适应新环境和任务。

“学会学习”的范式,即元学习,在人工智能的最新发展中取得了巨大成功。在自然语言处理(NLP)中,GPT-2(40)中引入的上下文内学习在GPT-3(41)中进一步扩展,展示了大型语言模型作为少样本学习者的能力。这里的上下文内学习指的是一种学习范式,其中生成语言模型根据提示的“上下文”执行给定任务,包括任务描述和与该任务相关的几个示例。我们建议读者参考参考文献42以了解上下文内学习的最新进展。

在上下文内学习之前,NLP任务主要由BERT风格的预训练加微调范式(43)主导,其中语言模型预训练生成句子嵌入,然后针对特定下游任务进行微调,通常使用额

在本文中,我们将上下文内学习的思想适应并扩展到学习微分方程问题的算子。我们将算子的输入称为“条件”,将算子输出称为“感兴趣的量(QoIs)”。一个“示例”由一对条件和QoI组成。在之前的算子学习范式中,神经网络在共享相同算子的示例上进行训练。在推理阶段,它将新条件作为输入并预测与所学算子对应的QoI。在本文中,在推理阶段,我们改为让训练好的神经网络将示例和新条件(即“问题条件”)作为输入,并同时完成以下两个工作:1)从示例中学习算子,2)将学到的算子应用于问题条件并预测相应的QoI。我们强调,在推理阶段没有权重更新。我们将提出的范式命名为“上下文内算子学习”,相应的模型简称为“上下文内算子网络”或“ICON”。

图1A描述了ICON模型的训练和推理过程。与此并行,图1B描绘了NLP的上下文内学习示例,源自参考文献41。NLP的上下文内学习与科学机器学习的上下文内算子学习之间的类比是清晰的。我们通过在提示示例中指定算子来体现上下文内学习的精神,而不是将特定算子嵌入神经网络权重中。事实上,权重仅封装了跨算子共享的共性以及动态学习算子的能力。因此,上下文内算子学习表现出几个优越的特性,包括1)无需微调即可学习新算子,2)学习新算子所需的数据量减少,3)对训练分布外的算子具有强大的泛化能力。

需要注意的是,将上下文内学习适应于算子学习需要额外努力。上下文内算子学习的一个显著特征是输入和输出是连续函数,而不是NLP中使用的离散标记。为了克服这一挑战,我们采用了一种灵活且通用的方法,将这些函数表示为键值对集合,其中键表示函数输入,值对应于各自的函数输出。这些键值对被打包成“数据提示”,作为ICON模型的输入,类似于语言模型中使用的自然语言提示。我们还采用了一种定制的transformer编码器-解码器架构(46),确保1)示例数量灵活,2)每个条件/QoI函数的键值对数量和选择灵活,3)学习过程对输入键值对的排列不变,4)问题QoI函数的预测不限于预设的输入集合,而是适用于任何输入,5)不同输入的问题QoI函数的预测可以并行执行。我们将在问题设置和方法论中详细讨论。

还有其他使用生成语言模型解决科学和数学相关任务的工作。例如,参考文献47和48集中在结合文本、图像和表格数据的提示中的数学推理和科学问答任务。MyCrunchGPT(49)作为科学机器学习各个阶段的集成工具,利用ChatGPT的能力根据用户提示协调工作流程。机器学习任务的执行仍然需要使用PINNs、DeepONets等不同的方法。在这些工作中,尽管任务与科学和数学相关,但上下文内学习的重点不是直接执行数值科学计算,而是主要应用于执行NLP任务的语言模型。元学习也在参考文献50–52中使用,其中任务相似性被利用来提高新PDE任务的性能。然而,在这些工作中,面对新任务时不可避免地需要神经网络进行微调。在本文中,我们尝试将上下文内学习的范式适应于数值微分方程问题。

本文的其余部分组织如下。在下一节中,我们将介绍上下文内算子学习的问题设置。然后,我们将详细介绍ICON的方法论,随后是实验结果,这些结果展示了ICON在推理阶段从示例中学习算子并应用于问题条件的能力。此外,我们讨论了几个我们认为将增强读者对所提出方法理解的主题。最后,我们总结论文并讨论局限性和未来工作。


问题设置

在本节中,我们将介绍上下文内算子学习的问题设置。

算子定义为一种映射,它接受单个输入函数或输入函数的元组,并生成输出函数。在本文中,我们将算子的输入称为“条件”,将算子输出称为“QoIs”。

以一维ODE问题 = 𝛼u(t) +𝛽c(t) + 𝛾为例。给定参数𝛼, 𝛽, 𝛾 ∈ R,正向问题学习从控制函数c: [0, T] → R和初始条件u(0)到解函数u: [0, T] → R的解算子。在这种情况下,c: [0, T] → R和初始条件u(0)构成条件,u: [0, T] → R是QoI。请注意,虽然u(0)是一个数字,我们仍然可以将其视为域{0}上的函数以适应框架。相反,在逆问题中,我们旨在学习从解函数u: [0, T] → R到控制函数c: [0, T] → R的算子。在这种情况下,函数u被视为条件,函数c是QoI。

在实际场景中,通常难以获得条件和QoIs的解析表示。相反,我们通常依赖于从系统收集的观察或数据。为了解决这个问题,我们采用了一种灵活且可推广的方法,使用键值对来表示这些实体,其中键是离散的函数输入,值是函数的相应输出。继续以上述一维ODE问题为例,为了表示函数c: [0, T] → R,我们将离散时间实例视为键,并将c的相应函数值作为关联值。我们使用键0和值u(0)来表示u的初始条件。需要注意的是,键值对的数量是任意的,键的选择是灵活的,并且它们可以在不同函数之间变化。

训练数据可以表示为 对于其中每个 i 对应于不同的操作符。

对于给定的 i, 表示一组条件-QoI对,它们共享相同的操作符。在我们的设置中,需要强调的是,这里的操作符是完全未知的,甚至包括相应的微分方程类型。这个方面与许多现实世界场景一致,在这些场景中,要么缺少控制方程的参数,要么需要从头构建方程本身。

在推理阶段,我们面对的是条件和QoI对,称为“示例”,它们也共享一个未知的操作符。此外,我们还给出了一个称为“问题条件”的条件。目标是预测与问题条件和未知操作符相对应的QoI。需要注意的是,推理阶段中的未知操作符可能与训练数据集中的操作符不同,甚至可能是分布之外的。

在推理阶段,我们得到共享未知算子的条件和QoI对,称为“示例”。此外,我们得到一个称为“问题条件”的条件。目标是预测与问题条件和未知算子对应的QoI。请注意,推理阶段的未知算子可能与训练数据集中的算子不同,甚至可能超出分布范围。


方法论

在本节中,我们将提供方法的全面概述。首先,我们将解释构建神经网络输入的过程,包括提示和查询。随后,我们将研究神经网络架构。此外,我们将讨论数据准备和训练过程。最后,我们将详细阐述推理过程。

**提示和查询**。模型期望从多个示例中学习算子,每个示例由一对条件和QoI组成,并将其应用于问题条件,对问题QoI进行预测。由于问题QoI是一个函数,因此还需要指定模型应在何处进行评估,即问题QoI的键,称为“查询”(每个查询是一个向量)。我们将示例和问题条件分组为“提示”,与查询一起作为神经网络输入。神经网络的输出表示问题QoI的值的预测,对应于输入查询。

尽管存在替代方法,但在本文中,我们选择了一种简单的方法来构建提示,即将示例和问题条件连接起来创建矩阵表示。矩阵的每一列表示一个键值对。由于我们将使用transformer(46),提示中列的顺序不会影响结果。

为了适应具有不同数量输入条件函数和来自不同空间的函数的算子,我们重新构造了提示和查询中的键。具体来说,我们将提示/查询的第一行分配给不同的函数项,第二行表示时间坐标,第三行表示第一个空间坐标,依此类推。如果某些条目不需要,我们将用零填充它们。

在表1中,我们展示了在一维正向和逆向ODE问题中使用的示例的矩阵表示,如问题设置中所述。提示只是沿着行连接示例和问题条件。

最后,我们注意到示例和键值对的数量可能因不同的提示而异。Transformer特别设计用于处理不同长度的输入。然而,为了批处理的目的,我们仍然使用零填充以确保一致的长度。这种填充以及适当的掩码有效地忽略了这些零填充,对数学计算没有影响。

**神经网络架构**。在我们的方法中,我们采用了一种定制的transformer编码器-解码器(46)神经网络架构,如图2所示。

在进入编码器之前,提示的列通过共享线性层和层归一化(53)进行维度调整。transformer编码器的架构设计遵循参考文献46提出的模型。具体来说,它由一组相同的层组成,每层有两个子层:多头自注意力机制和带有GELU激活(54)的浅层全连接前馈网络。每个子层都由残差连接(55)包裹,随后是层归一化。编码器将提示中所有示例和问题条件的信息合并,生成一个表示算子和问题条件嵌入的输出矩阵。

解码器也由一组相同的层组成,每层有一个多头交叉注意力机制和带有GELU激活的浅层全连接前馈网络。与编码器类似,每个子层都由残差连接包裹,随后是层归一化。与参考文献46中的模型不同,解码器中移除了自注意力层。我们将在后面讨论这一点。编码器的输出嵌入在层归一化后,用作解码器内交叉注意力机制的键和值输入。与嵌入一起,查询(即问题QoI的键)也在共享线性层和层归一化后注入解码器。它们反复通过交叉注意力子层(作为查询)和前馈网络,最终形成解码器的输出。最后,解码器的输出通过一个额外的线性层,以匹配问题QoI值的维度。

这种架构中使用的transformer编码器-解码器与用于计算机视觉中目标检测任务的编码器-解码器有相似之处(56)。在这种情况下,解码器将编码器生成的图像嵌入和“对象查询”作为输入,解码器的每个输出随后通过一个公共前馈网络进行预测。

transformer架构在促进上下文内算子学习适应方面起着关键作用。它能够处理任意长度的输入序列并保持对序列排列的不变性,这与每个条件/QoI函数的键值表示完美契合。首先,它允许示例数量的变化。其次,它为每个条件/QoI函数的键值对数量和选择提供了灵活性。最后,它确保键值对的顺序重新排列不会影响结果。

此外,重要的是要注意,在我们的解码器中移除了自注意力层。因此,对于固定的提示,如果我们向模型输入n个查询向量(或问题QoI的n个键)并接收n个相应的值作为输出,每个值仅由其对应的查询确定,不受其他查询的影响。这种独立性使我们能够设计任意数量的查询,并在我们希望评估问题QoI函数的任何地方并行进行预测。

数据准备和训练 在训练神经网络之前,我们准备包含不同类型微分方程问题数值解的数据。数据生成的详细信息在算法1中描述。

在训练过程中,在每次迭代中,我们随机从数据中构建一批提示、查询和标签(真实值)。请注意,不同问题和不同算子出现在同一批中。损失函数是神经网络输出与标签之间的均方误差(MSE)损失。训练过程的详细信息在算法2中描述。

推理:无需权重更新的少样本学习 训练后,我们使用训练好的神经网络根据描述算子的几个示例以及问题条件来预测问题QoI。在一次前向传递中,神经网络同时完成以下两个任务:从示例中学习算子,并将学到的算子应用于问题条件以预测问题QoI。我们强调,神经网络在这种前向传递中不更新其权重。换句话说,训练好的神经网络充当少样本算子学习器,训练阶段可以被视为“学习如何学习算子”。


数值结果

我们设计了19种类型的训练问题,每种类型有1,000组参数,因此总共有19 × 1,000 = 19,000个算子。对于每个算子,我们生成100对条件-QoI。换句话说,算法1和2中的M = 1,000,N = 100。在训练期间构建提示时,我们随机选择一到五个示例。每个条件/QoI中的键值对数量随机从41到50不等。因此,最大提示长度为550,由五个示例组成,累计长度为500,外加一个额外的问题条件50。本文中使用的神经网络总共有大约3000万个参数。其他配置和训练的详细信息在SI附录中。

问题 我们在表2中列出了所有19种类型的问题,以及数据准备阶段(算法1)中参数和条件-QoI对的设置。至于参数的实现,我们在SI附录中展示。

分布内算子 在本节中,我们展示了每种19种类型问题的测试误差,参数、条件和QoI的分布与训练阶段相同,即分布内算子学习。每个条件/QoI中的键值对数量随机从41到50不等,与训练阶段相同。通过使用不同的随机种子,我们确保测试数据与训练数据不同(尽管在同一分布中),并且在测试期间每个条件-QoI对仅显示一次,要么作为示例,要么作为问题。

我们在图3和图4中展示了一些上下文内算子学习的测试案例。

在图5中,我们展示了表2中列出的所有19个问题的每个提示中示例数量的相对误差。对于每种类型的问题,我们进行了500个上下文内学习案例,对应于100个不同的算子,即每个算子有五个案例。首先,绝对误差是通过在所有上下文内学习案例中平均预测的问题QoI值与其对应的真实值之间的差异来计算的。然后,相对误差是通过将绝对误差除以真实值的绝对值的平均值来获得的。

在所有19个问题中,从图5可以明显看出,即使在使用单个示例的情况下,平均相对误差仍保持在6%以下。当使用五个示例时,大多数平均相对误差在2%左右。这突显了单个神经网络从示例中有效学习算子并准确预测各种类型微分方程问题的QoI的能力。此外,对于所有19个问题,随着每个提示中示例数量的增加,误差持续下降。


超分辨率和亚分辨率函数

尽管神经网络使用41到50个键值对来表示条件和QoI进行训练,但它展示了无需任何微调即可泛化到更大范围数量的能力,包括更多键值对(超分辨率)或更少键值对(亚分辨率)。FNO(27, 28)展示了类似的能力,但在我们的论文中,这种泛化归因于transformer的适应性,而不是使用积分核。

在图6中,我们检查了问题17,即MFC g-parameter 2D → 2D,每个条件/QoI中的键值对数量(随机采样)从10到500不等。平均相对误差的计算方式与分布内算子相同,只是我们在域(t, x) ∈ [0.5, 1]×[0, 1]中进行预测并评估误差,通过在时间-空间域上设置查询作为网格点。图4中展示了一个包含三个示例和50个键值对的案例。

对于提示中固定数量的示例,随着每个条件/QoI中键值对数量的增加,平均相对误差减少,最终收敛到1%以下,即使对于单个示例的情况,即一次性学习。

分布外算子 在本节中,我们检查了神经网络将上下文内学习泛化到训练分布之外的算子的能力。在这里,我们强调“分布外”一词并不指条件,而是指算子本身超出了训练期间观察到的算子分布。

我们在表2中的四个代表性问题类型上进行了测试,即问题5、6、11和12。在ODE 3的正向和逆向问题的培训过程中,我们从均匀分布U(-1, 1)、U(0.5, 1.5)和U(-1, 1)中随机生成了a1、a2、a3。每个三元组(a1, a2, a3)定义了一个操作符。现在,我们将分布扩展到一个更大的区域。为了评估并提供性能的视觉描述,我们将区域[0.1, 3.0] × [-3, 3]划分为一个网格。然后通过在每个网格单元中测试(a2, a3)对来评估性能。

a1继续从分布U(-1, 1)中随机抽样。具体来说,我们在每个单元中进行了500个上下文学习案例,对应于100个不同的操作符和每个操作符的五个不同示例和问题。

在这里,示例的数量固定为五个,键值对的数量固定为训练中使用的最大数量。

我们计算了每个单元的相对误差,并在图7A和B中描述了结果。

对线性反应-扩散PDE问题的正向和逆向问题进行了类似的分析。我们将(a, c)区域划分为网格,同时保持边界条件参数u(0)和u(1)从U(−1, 1)中随机采样。平均相对误差如图7 C和D所示。

显然,对于所有四个问题,即使算子参数超出了训练区域,神经网络也展示了准确的预测能力。这展示了其强大的泛化能力,能够学习和应用分布外算子。


泛化到新形式的方程

如参考文献41所述,上下文内学习相对于预训练加微调的一个优势是能够将多种技能混合在一起解决新任务。GPT-4(57)甚至展示了超出人类预期的涌现能力或行为。

尽管我们的实验规模远小于GPT-3或GPT-4,但我们也观察到了神经网络学习并应用训练数据中从未见过的新形式方程的算子的初步证据。

特别是,我们设计了一个新的ODE+ bu(t) + a2,在时间区间[0, 1]上,通过将线性项bu(t)添加到从ODE 3借用的ODE 2中。在新问题中,b也是一个参数,算子由(a1, a2, b)决定。我们研究了新ODE的正向和逆向问题,并评估了神经网络在b ∈ [−0.3, 0.3]时的性能。其他设置,包括a1、a2和c(t)的分布,与问题3和4(ODE 2的正向和逆向问题)相同。

为了研究扩大训练数据集的影响,在图8中,我们展示了使用不同训练数据集训练的神经网络的平均相对误差。在这里,我们以与分布外算子每个单元相同的方式获得每个b的平均相对误差。为了减少计算成本,在本节中,我们训练了与分析其他部分相同的神经网络,但仅使用一半的批量大小,进行1/5的训练步骤。我们注意到,在这些新运行中,训练数据集的大小不同,但训练步骤和批量大小是一致的。换句话说,神经网络在训练期间遇到的提示数量相同。数据集类型的扩展只是增强了提示的多样性。

我们首先仅使用涉及ODE 2(正向和逆向问题)的数据集训练神经网络。然后,作为参考,我们将“错误”算子直接应用于问题条件。“错误”算子定义为对应于ODE 而不是新ODE,具有相同的a1和a2。请注意,当b = 0时,新ODE简化为ODE 2,因此误差为零。作为另一个参考,我们使用相同的神经网络进行上下文内算子学习,但将提示中的示例替换为对应于ODE 2的示例,记为“错误示例”。我们可以看到,使用“正确示例”的神经网络表现并不比两个参考更好,这表明网络几乎无法将其上下文内算子学习的能力泛化到ODE 2之外。

然后,我们逐渐将更多与ODE相关的数据集添加到训练数据中。令人鼓舞的是,随着训练数据集变得更大,误差显示出下降趋势。当使用所有ODE 1、2和3进行训练时,神经网络的表现明显优于仅使用ODE 2训练的网络。

这种证据表明,随着相关训练数据的大小和多样性的增加,神经网络有可能学习和应用对应于以前未见过的方程形式的算子。

最后,我们还展示了在其他部分使用的神经网络的结果,该网络使用完整数据集进行训练,批量大小更大,训练时间更长。新ODE的性能没有提高,这是合理的,因为新添加的阻尼振荡器、PDE和MFC问题的数据与新ODE不密切相关。


讨论

为什么极少数示例足以学习算子

我们尝试从以下几个方面回答这个问题。

首先,我们实际上只需要为特定分布的问题条件学习算子,而不是所有可能的问题条件。

其次,训练算子和测试算子共享共性。例如,对于ODE问题,u的时间导数u和c在每个时间t满足相同的方程。如果神经网络在训练期间捕捉到这种共享属性,并且在推理期间在示例中注意到这种属性,它只需要识别ODE,这对于几个示例来说是足够的。

最后,本文中的算子相当简单且局限于一个小家族,因此用几个示例容易识别。对于训练和测试中更大的算子家族,上下文内算子学习可能需要更多示例(特别是对于那些复杂的算子),以及具有更多计算资源的大型神经网络。

微调的不同角色 微调方法在NLP和科学机器学习领域以多种方式使用。区分微调的各种角色可能很重要。

BERT风格(43)预训练后跟微调范式是NLP中的一个例子。这种策略首先在大规模语料库上预训练BERT风格的神经网络以生成句子嵌入。预训练后,模型在特定下游任务上进行微调,通常使用额外的任务特定层将句子嵌入映射到所需输出。预训练模型处理各种NLP任务中特别具有挑战性但常见的方面——创建良好的句子嵌入,这显著简化了下游任务。然而,预训练模型并不直接解决下游任务,†每个下游任务都需要一个任务特定的模型版本。

生成大型语言模型的上下文内学习的最新进展,如GPT(41, 57)和LLaMA(58, 59),是NLP领域的一个重大范式转变。与为单个下游任务微调预训练模型不同,上下文内学习利用提示的任务描述和示例来定义任务。虽然这些生成大型语言模型可以使用微调技术,但它们的功能与传统的BERT风格微调不同。事实上,这些模型可以直接处理多个任务,而无需任何微调。微调这些生成大型语言模型的目的不是为了启用任务特定的调整,而主要是为了提高模型在特定领域或任务类型上的熟练度。

对于科学机器学习,预训练后跟微调范式也在解近似或算子近似的框架中提出。在这些情况下,神经网络被训练来近似特定的解函数(10–19)或算子(30, 33–39),然后进行微调以近似类似的解函数或算子。然而,这种方法与BERT风格预训练加微调范式共享类似的限制,即神经网络必须为每个不同的函数或算子单独微调。当更仔细地比较科学机器学习和NLP时,情况进一步恶化。在NLP领域,创建高效的句子嵌入简化了大多数(如果不是全部)下游任务。相反,科学机器学习中的函数和算子非常多样化,定义一个可以作为起点用于近似广泛函数或算子的通用“基础函数”或“基础算子”极其困难。从这个角度来看,开发“基础模型”(60)的任务,即在大规模数据上训练并适应广泛下游任务的模型,在这种方法中变得艰巨,即使神经网络的规模扩大。

所提出的上下文内算子学习将GPT风格模型的范式,而不是BERT风格预训练加微调,转移到科学机器学习中。正如我们的实验所示,ICON模型可以直接学习广泛的算子,而无需任何微调。然而,就像GPT风格的模型一样,ICON模型也可以微调以专门处理特定的一组算子。展望未来,我们设想开发一个在大规模数据集上训练的模型,在上下文内算子学习范式下,作为基础模型。该模型可以直接用于广泛的算子学习任务,或者可以微调以提高其在处理特定一组算子上的熟练度。

ICON在小规模上的应用 在狭窄领域中,具有训练数据的小型语言模型(低于1000万个参数)可以生成多样化、流畅且一致的故事(61)。与此一致,我们的实验结果表明,由大约3000万个参数组成的小型ICON模型有能力处理相对简单的合成算子。这些发现表明,对于实际应用,如果目标是掌握有限范围的算子而不是训练一个通用基础模型,那么小型ICON模型就足够了。


总结

在本文中,我们提出了“上下文内算子学习”的范式及其相应的模型“ICON”,用于学习微分方程问题的算子。它超越了传统范式,即近似特定问题的解或某些特定解算子。相反,ICON在推理期间充当“算子学习器”,即从给定的示例中学习算子并将其应用于新条件,而无需任何权重更新。

通过我们的数值实验,我们证明了单个神经网络有能力从少量提示示例中学习算子,并有效地将其应用于问题条件。这种单个神经网络,无需任何重新训练或微调,可以处理一系列多样化的微分方程问题,包括ODE、PDE和平均场控制问题的正向和逆向问题。

此外,尽管在训练期间表示条件/QoI函数的键值对数量限制在狭窄范围内,ICON可以在测试期间将其上下文内算子学习能力泛化到更广泛的范围内,随着键值对数量的增加,误差减少并收敛。此外,ICON展示了学习参数超出训练分布的算子的能力。

最后,我们的观察提供了初步证据,表明ICON有可能学习和应用对应于以前未见过的方程形式的算子。

我们实验的规模相对较小。未来,我们希望扩大神经网络的规模、微分方程问题的类型、键和值的维度、条件和QoI的长度以及示例数量。这需要进一步发展上下文内算子学习,包括改进神经网络架构和训练方法,以及进一步的理论和数值研究上下文内算子学习的工作原理。例如,在NLP领域,如GPT-4中,规模扩大导致超出人类预期的涌现能力或行为(57)。我们期待在大型算子学习网络中见证这种涌现的可能性。





原文链接:https://www.pnas.org/doi/epdf/10.1073/pnas.2310142120


CreateAMind
ALLinCreateAMind.AGI.top , 前沿AGI技术探索,论文跟进,复现验证,落地实验。 鼓励新思想的探讨及验证等。 探索比大模型更优的智能模型。
 最新文章