大语言模型上下文窗口扩展方法
一、简介
上下文窗口(context window)是大语言模型(LLM)的一个重要概念,指模型在一次推理或生成过程中可以处理的文本长度。大语言模型的上下文窗口大小取决于预训练期间使用的训练文本序列长度。在推理阶段,一旦用户输入的文本序列长度超过上下文窗口大小,大模型性能表现将会严重下降。
但是在实际应用场景中,如进行长时间对话、总结长文档或执行长期规划等,经常会超过大模型的上下文窗口大小。因此如何扩展大模型的上下文窗口大小是一个被持续关注的问题。
本文分享两篇相关的研究:SelfExtend和YaRN(Yet another RoPE extensioN method)来分别介绍在微调阶段和推理阶段应用的大模型上下文窗口扩展方法。
二、SelfExtend
SelfExtend 通过构建双层注意力信息(分组注意力和近邻注意力)来扩展 LLM 的上下文窗口。分组注意力捕获相距较远的标记之间的依赖关系,而邻居注意力捕获指定范围内相邻标记之间的依赖关系。两级注意力是在推理过程中基于原始模型的自注意力机制计算的。通过少量代码修改, SelfExtend 可以轻松扩展现有 LLM 的上下文窗口,无需任何微调。
2.1.分组注意力
处理大模型看不到的相对位置的一种可行且直接的方法是将它们映射到预训练期间看到的位置。可以使用 FLOOR 操作将未见过的位置映射到预训练上下文窗口内的位置,如下图所示。
所提出的方法与原始自注意力机制相同,只是 FLOOR 操作之前应用于每个 token 的原始位置内积。我们将应用 FLOOR 操作的自注意力表示为“分组注意力”。在Python风格中,“分组注意力”表示为:
其中 是原始的位置编码,B是batch size,L是输入文本序列长度;是分组位置编码,是分组大小。
2.2. 近邻注意力
如下图所示,分组注意力对扩展上下文窗口确实有效,但是我们也注意到分组注意力的引入也导致了模型在处理上下文窗口内的序列的困惑度升高。
一些研究表明,目标Token的直接近邻的位置编码准确性和有效性对LLM的性能影响较大,维持接近目标的Token的标准注意力机制是非常重要的。因此,必须在目标Token附近保留标准注意力机制而不采用分组注意力。
2.3. SelfExtend
SelfExtend的整体流程如下:
SelfExtend 包含两种不同类型的注意力机制:
分组注意力:在近邻窗口外使用分组注意力,以处理Token之间的长距离关系,同时扩展上下文窗口大小。 标准注意力(近邻注意力):对指定范围(近邻窗口)内的相邻Token采用传统的注意力机制。以确保目标附近的Token信息的准确性。
两种注意力机制以近邻窗口为分界线,近邻窗口外采用分组注意力,近邻窗口内采用标准注意力,最后将两种注意力进行合并。
使用SelfExtend扩展后的上下文窗口长度可由以下公式计算得到 :
2.4.实验
2.4.1. 超参数实验
在Needle In A HayStack上测试Phi-2、Llama-2-7b-chat-hf模型对超参数分组大小和近邻窗口大小对模型性能的影响进行了测试。
实验结果表明SelfExtend对超参数的选择并不过分敏感。只要组大小和邻居窗口不太大或太小,组大小和邻居窗口大小的预定义启发式值通常足以实现令人满意的性能。我们将这些结果总结为经验规则。将预训练上下文窗口表示为 L,目标扩展长度为 N,邻居窗口为 W,组大小为 G,选择超参数的经验规则是确保以下不等式成立:
2.4.2 性能评估结果
不需要微调的SelfExtend方法在语言建模任务、综合长文本任务、现实世界上下文任务以及变长密钥检索任务中都取得了与基于微调的扩展方法相当甚至更好的性能表现。并且相较于使用标准注意力的模型,SelfExtend在短上下文任务上的表现也并无性能下降问题。
三、YaRN(Yet another RoPE extensioN method)
旋转位置嵌入 (RoPE) 已被证明可以在基于 Transformer 的语言模型中有效地编码位置信息。然而,这些模型无法处理超过它们训练的序列长度的序列。YaRN是一种计算高效的方法,用于扩展此类模型的上下文窗口,与以前的方法相比,需要的Token少 10 倍,训练步骤少 2.5 倍。通过使用YaRN微调 LLaMA 模型,可以有效地利用和推断比原始预训练允许的更长的上下文长度,同时也超越了上下文窗口扩展方面的其他最新技术。
3.1.高频信息损失:NTK-aware 插值
从信息编码的角度来看待 旋转位置嵌入(Rotary Position Embeddings, RoPE),使用神经正切核(Neural Tangent Kernel,NTK)理论,如果输入维度较低且缺乏相应的嵌入,深度神经网络将难以学习高频信息高频成分。
NTK-aware插值用于解决 RoPE 嵌入插值时丢失高频信息的问题。不同于线性插值法将 RoPE 的每个维度均等地缩放 s 倍,NTK-aware插值通过减少高频和增加低频来将插值压力分散到多个维度。可以通过多种方式获得这样的变换,但最简单的是对 θ 的值进行基础更改。
为了进行位置插值,我们对RoPE中的位置编码嵌入变换进行修改:
其中m为位置索引, 是基于频率的对角矩阵。 为基数(通常取b=10000); 和 取决于具体插值方法。
对于NTK-aware插值, 和 定义如下:
其中D表示隐藏神经元集合,且:
3.2.相对局部距离损失:NTK-by-parts 插值
RoPE 嵌入的一个有趣的观察是,给定上下文大小 L,对于一些维度 d,其波长比预训练期间的最大上下文长度更长(λ > L),这表明某些维度的嵌入可能在旋转域中分布不均匀。在这种情况下,我们假设拥有所有唯一的位置对意味着绝对位置信息保持不变。相反,当波长较短时,神经网络只能访问相对位置信息。 此外,当我们通过比例 s 或使用基数变化 b' 拉伸所有 RoPE 维度时,所有标记都会变得彼此更接近,因为旋转较小量的两个向量的点积更大。这种扩展严重削弱了LLM理解其内部嵌入之间的小型和局部关系的能力。我们假设这种压缩会导致模型对附近标记的位置顺序感到困惑,从而降低模型的能力。为了解决这个问题,我们选择根本不对较高频率维度进行插值,而始终对较低频率维度进行插值。特别地:
如果波长 λ 远小于上下文大小 L,则我们不进行插值; 如果波长λ等于或大于上下文大小L,我们只想进行插值并避免任何外推(与前文的“NTK-aware”方法不同); 介于两者之间可以同时使用两种方法,类似于NTK -aware插值。
因此,为了方便起见,引入原始上下文大小L和波长λ之间的比率。在第 d 个隐藏状态中,比率 r 取决于 d,如下所示:
为了定义上述不同插值策略的边界,我们引入两个额外的参数α、β。所有隐藏维度 d(其中 r(d) < α )是我们按比例 s 线性插值的维度(与 PI 完全相同,避免任何外推),而 d(其中 r(d) > β )是不插值的维度。定义斜坡函数 γ 为
因此,我们引出NTK-by-parts插值的定义:
其中,α 和 β 的值应根据具体情况进行调整。对于 Llama 系列模型,实验得到的α 和 β 的最佳值为 α = 1 和 β = 32。
3.3. 动态缩放:动态NTK插值
对于嵌入层固定的插值方法,比例因子 保持不变,这会导致模型长度小于L时可能会出现性能下降,而当序列长度大于时模型可能突然退化。
为了解决这一问题引入动态缩放方法,即在每次前向传递中,位置嵌入都会更新比例因子 ,其中 l′ 是当前序列的序列长度。通过引入动态缩放,允许模型在达到训练的上下文限制时优雅地降级,而不会立即崩溃。
3.4. YaRN
除了前面的插值技术之外,我们还观察到,无论数据样本和扩展上下文窗口上的Token位置如何,在注意力 softmax 上引入温度 t 对困惑度都有影响 。修改后的注意力权重计算如下:
结合上述缩放策略和NTK-by-parts插值方法,我们得到YaRN插值方法。YaRN 方法结合了上述的发现,并在微调和非微调场景中超越了之前的所有方法。由于其占用空间小,YaRN 允许与修改注意力机制的库直接兼容。
3.5. 实验效果
YaRN方法在长序列语言建模和密钥检索任务都达到了非常高的准确度。
其次,为了测试上下文扩展下模型性能的下降,使用Hugging Face 的开放LLM在基准化测试评估YaRN模型,并与Llama2 Baseline以及公开的的PI和NTK-aware模型数据进行比较。
如上表所示,YaRN 模型与其各自的 Llama 2 基准之间的性能下降很小。另外,在YaRN s = 16 和 s = 32 模型之间的分数平均下降了 0.49%。因此,从 64k 到 128k 的迭代扩展导致的性能损失可以忽略不计。
四、总结
YaRN和SelfExtend分别在微调和推理阶段提供了有效的上下文窗口扩展方案,前者通过优化位置嵌入插值和动态缩放,后者通过双层注意力机制扩展了模型的上下文处理能力。两者在不同应用场景下都表现出色,为大语言模型的长文本处理提供了新的技术路径。
五、相关文献
[1] Jin H, Han X, Yang J, et al. Llm maybe longlm: Self-extend llm context window without tuning[J]. arXiv preprint arXiv:2401.01325, 2024.
[2] Peng B, Quesnelle J, Fan H, et al. Yarn: Efficient context window extension of large language models[J]. arXiv preprint arXiv:2309.00071, 2023.