(2024,Jamba1.5,ExpertsInt8量化,LLM,激活损失)大规模混合 Transformer-Mamba 模型

文摘   2024-08-23 14:35   新加坡  

Jamba-1.5: Hybrid Transformer-Mamba Models at Scale

目录

0. 摘要

2. 模型架构

3. 服务考量与改进

3.1 ExpertsInt8 量化

3.2 激活损失

4. 吞吐量和延迟分析

5. 训练

5.1 训练基础设施和数据

5.2 训练阶段

5.3 后训练

5.4 一些观察

6. 评估

6.1 学术基准

6.2 聊天机器人评估 

6.3 长上下文评估 

6.4 多语言能力

8. 结论



0. 摘要

我们推出了 Jamba-1.5,这是一种基于我们 Jamba 架构的新指令微调大型语言模型。Jamba 是一种混合 Transformer-Mamba 的 MoE 架构,能够在不同上下文长度中提供高吞吐量和低内存使用,同时保持与 Transformer 模型相同或更好的质量。我们发布了两种模型尺寸:Jamba-1.5-Large,具有 94B 有效参数,和 Jamba-1.5-Mini,具有 12B 有效参数。这两个模型都经过微调,以支持各种对话和指令跟随能力,并且具有 256K tokens 的有效上下文长度,是开放权重模型中最长的。为了支持成本效益的推理,我们引入了 ExpertsInt8,这是一种新颖的量化技术,能够在处理 256K tokens 上下文时,将 Jamba-1.5-Large 适配于一台配备 8 个 80GB GPU 的机器上运行,并且不会损失质量。在一系列学术和聊天机器人基准测试中,Jamba 模型表现优异,提供高吞吐量并在长上下文基准测试中优于其他开放权重模型。

2. 模型架构

(2024,Attention-Mamba,MoE 替换 MLP)Jamba:混合 Transformer-Mamba 语言模型

(2023,SSM,门控 MLP,选择性输入,上下文压缩)Mamba:具有选择性状态空间的线性时间序列建模

Jamba-1.5-Large 基于我们开发的 Jamba [24] 混合解码器架构,该架构结合了 Transformer 层 [36] 与 Mamba 层 [13],一种状态空间模型 (SSM) [14, 15],以及专家混合 (MoE) 模块 [8, 34]。有关此架构的详细描述,请参见 [24]。

在 Jamba [24] 的工作中,我们发现 Transformer、Mamba 和 MoE 元素的组合有助于在吞吐量、内存使用和质量之间实现平衡。Jamba-1.5-Large 在更大规模上展示了这种灵活性。

Jamba-1.5-Large 遵循相同的 Jamba 结构,但具有更大的容量。它拥有 94B 有效参数,总参数量为 398B。该模型有 9 个模块,每个模块具有以下规格:

  • 每个模块有 l = 8 层。

  • a : m = 1 : 7 的注意力层与 Mamba 层的比例。在我们关于 Jamba 的研究中发现这个比例是最优的 [24],后续的工作也证实了类似的比例成功 [6, 37]。

  • 每 e = 2 层使用 MoE 替代单一的 MLP。共有 n = 16 个专家,每个 token 选择最优的 K = 2 个。

  • 隐藏状态的维度为 8192。

  • 注意力查询头的数量为 64,KV 头的数量为 8。

表 1 将 Jamba-1.5 模型与相似规模的公开模型进行了比较。Jamba-1.5-Mini 的有效参数数量与 Mixtral 8x7B 相当,而 Jamba-1.5-Large 的有效参数数量介于 LLaMA-3.1-70B 和 Mistral-Large-2 之间。同时,Jamba 模型在 KV 缓存内存使用量(在 256K tokens 上下文下)方面远小于所有其他模型,相比同类模型减少了大约一个数量级。

在这些设置下,并结合我们的专门量化技术(第 3.1 节),Jamba-1.5-Large 可以在一台配备 8 个80GB GPU 的机器上运行,支持长达 256K tokens 的上下文长度。

(2024|ICML,Mamba2,SSD,SSM,SMA,矩阵变换,张量收缩,张量并行)Transformer 是 SSM

对于此次发布,我们还尝试了 Mamba-2 [6],这是 Mamba 的一个更快且改进的版本,据报道它在性能上超越了单独使用 Mamba 和 Transformers 的模型。然而,正如图 1 所示,我们发现,在混合架构中,Mamba-1-Attention 组合的效果优于 Mamba-2-Attention,因此我们在 Jamba-1.5-Large 中使用了 Mamba-1。我们还发现混合架构的性能优于纯 Mamba-2。我们推测这可能是因为 Mamba-2 相比于 Mamba-1 的一些优势(特别是使用更大状态尺寸的能力)在 Mamba 层之间交错全注意力层时不那么显著,因为全注意力层可以从整个上下文中汇聚信息。 

3. 服务考量与改进

我们分享了一些见解和改进,旨在实现大规模高效服务 Jamba 模型。

3.1 ExpertsInt8 量化

为了支持 Jamba-1.5-Large 的高效服务,我们开发了一种新型量化技术,称为 ExpertsInt8。我们观察到,超过 85% 的模型权重位于 MoE 层,超过 90% 位于 MoE 或 MLP 层。我们希望在保持快速 BF16 内核的好处的同时,对这些权重进行量化。为此,我们将 MoE 和 MLP 权重量化为 INT8,存储为 INT8,并在实际计算之前将它们解量化回 BF16。重要的是,解量化步骤直接发生在 vLLM [18] 中的 fused_moe 内核内部。这样,解量化过程增加的开销微乎其微,甚至在延迟上比 BF16 更有优势。【我们将这归因于内核在相对较小的权重和激活块上操作,这些块在执行计算之前从 GPU HBM 移动到 SRAM。在我们的实现中,当权重量化为 int8 时,它们从 HBM 移动到 SRAM,因此由于内存占用减少了一半,所需时间也减少了。】

我们已经将修改后的 fused_moe 内核贡献给了 vLLM。【拉取请求在此处:https://github.com/vllm-project/vllm/pull/7415】 

我们的 ExpertsInt8 方法具有几个优势。

  • 首先,它非常快速;量化过程仅需在模型加载时几秒钟。

  • 其次,与 vLLM 中大多数其他技术不同,它不依赖于需要数小时或数天且可能不稳定的校准过程。

  • 第三,我们仍然可以使用 BF16 来处理大规模激活。

  • 第四,它可以在 A100 GPU 上使用,而 FP8 仅在 H100 上可用。

  • 最后,我们的量化在延迟上与 FP8 相当,同时超越了其他量化技术,并且不会导致质量损失。

图 2 比较了使用不同量化技术的延迟,包括 Jamba-1.5-Mini、Jamba-1.5-Large 和两个 Mixtral 模型(8x78B 和 8x22B)。在 H100 GPU 上,ExpertsInt8 的延迟与 FP8 相匹配。在 A100 上,由于 FP8 不可用,ExpertsInt8 是一种有吸引力的技术,显著超越了 GPTQ [9]。结合上述 ExpertsInt8 的优点,这使得它成为服务大型 MoE 模型的一个有吸引力的量化技术。

3.2 激活损失

在预训练过程中,我们发现某些激活值,特别是特定专家的输出以及最后的 Mamba 层的输出,在处理特定输入 token 时,逐渐增大,最终达到高达 4 × 10^6 的值。尽管我们发现这并未对使用 BF16 精度进行的预训练造成损害,但这些激活值的幅度可能会在推理过程中引发数值问题,因为一些量化库仅支持 FP16 精度,而 FP16 的最大范围为 64K。

为了解决这些问题,我们添加了一个“激活损失”(Activation Loss)项,其值与前向传播中激活值的均方值成正比,并设有可配置的 α 因子,以惩罚较大的激活值。通过实验,我们发现这种辅助损失对训练没有影响,即使 α 值达到至少 10^{−3}。对于 Jamba-1.5-Large,我们使用了 α = 10^{−5},这足以将激活值减少到一个可接受的范围(最大 2K-3K)。此外,添加这一辅助损失几乎瞬间降低了激活值,使得它仅在训练结束时添加也不会影响训练速度和质量。

为了验证这种方法,我们在模型上使用 FP16 激活值运行了完整的评估套件,结果与使用 BF16 的评估结果相同,没有出现 NaN/溢出。

4. 吞吐量和延迟分析

得益于混合 Jamba 架构,我们的 Jamba-1.5 模型提供了出色的吞吐量和延迟性能。图 3 和图 4 分别展示了 Jamba-1.5-Mini 和 Jamba-1.5-Large 的表现。如图所示,我们的模型在延迟和吞吐量方面均显著优于相同规模的模型。它们在处理长上下文时展现出显著优势,存在较大的性能差距。重要的是,Jamba-1.5-Large 在处理长上下文时依然高效,而大型的 LLaMA3-405B 不能在相同硬件上运行。【注:Large 比 Mini 有更高的时延和更低的吞吐量,直观理解就是速度换性能】 

5. 训练

5.1 训练基础设施和数据

Jamba-1.5-Large 在 NVIDIA H100 GPU 上训练,使用我们内部开发的专有框架,包括 FSDP、张量并行、序列并行和专家并行。对于专家并行,我们适配了 MegaBlocks [10]。

5.2 训练阶段

该模型的训练分为三个阶段。

  • 在预训练阶段,模型首先在我们内部的数据集上进行训练,该数据集最后更新于 2024 年 3 月。我们的预训练数据集是公开的网页文档、代码、书籍和科学文章的混合体。我们的预处理流程包括解析、质量过滤和去重。为了最大化利用公开数据,我们开发了自己的解析器,并使用它来提取文本和格式。数据混合的具体组成通过各种消融实验确定。该阶段包括多语言数据,重点关注以下语言:英语、西班牙语、法语、葡萄牙语、意大利语、荷兰语、德语、阿拉伯语和希伯来语。

  • 然后,模型进行了一个短的中期训练,重点训练长文档,以强调其长距离能力。

  • 最后,模型经过了后训练(post-training),如下一节所述。

5.3 后训练

我们后训练的方法旨在同时实现两个目标:(i)为模型提供各种技能和对话能力;(ii)保留预训练中的能力,特别是中期训练中的长上下文能力。这两个目标部分存在冲突,因为大多数现有的后训练数据集包含相对较短的示例。

考虑到这些因素,我们的后训练过程包括在高质量对话数据、特定技能数据和长上下文数据上进行监督微调 [32, 39]。混合这些不同类型的数据旨在保留长上下文能力并获得所需的技能。如下面的评估所示,我们发现我们的模型在长上下文评估中表现非常好。

在进行监督微调时,我们大量使用合成数据,这在最近的基础模型中很常见 [7],并反映了我们构建复合 AI 系统 [20] 的结构化数据的方法。我们开发了多种不同的数据合成流程,针对不同的模型能力。所有流程都应用以下模式:(i)在目标分布中采样或生成提示;(ii)从语言模型中生成响应;(iii)根据自动验证和评分过滤或排名响应;(iv)后编辑以去除伪影并适应所需格式。我们使用不同的模型、提示、采样、过滤和编辑来处理不同的数据管道,从而组成最终的数据混合体。

我们基于大量主要内部的自动化指标选择了最终的训练配方(数据混合和超参数)。两个 Jamba-1.5 模型使用相同的控制 token 和格式模板进行微调,我们将这些作为我们发布的一部分,提供 HF 兼容的标记器和聊天模板;有关详细信息,请参见模型卡。

以下是一些合成数据生成的显著示例:

  • 基于表格的 QA:我们生成表格数据及其相应的问题-答案对,如我们在表格理解 [20] 中所示。然后,我们使用语言模型将表格转换为自然语言段落。我们生成的训练示例包括提取、聚合和归属任务,涉及给定表格中特定行或列的文本。

  • 文档 QA:给定一个文档,我们提示语言模型生成问题-答案对,适用于单个或多个段落。我们有时通过添加类似的文本将这些示例嵌入更长的上下文中,以鼓励长上下文理解和归属。

  • 工具使用:我们使用开源的 Glaive 函数调用数据集 [1] 作为起点,通过各种启发式方法和对输出模式的验证进行过滤。为了支持并行函数调用,我们首先为 Glaive 中的每个函数生成多个有效的参数分配。接下来,我们从这些有效的参数分配中采样子集,针对相同函数和不同函数,生成与函数调用集对应的用户请求。最后,我们提示函数调用语言模型对这些生成的用户请求做出响应,并仅保留与原始参数分配匹配的函数调用的响应。

  • 可控性:我们定义了一组可以轻松验证的指令,并合成了包括一个或多个约束的通用文档草拟任务的提示。我们从语言模型中生成这些提示的完成,并基于对我们细粒度指令的验证以及通用奖励模型进行拒绝采样。为了支持系统消息中的指令,我们选择了多种这种类型的提示,这些提示共享细粒度指令实例,并将这些提示重新格式化为多轮对话,将指令移到系统消息中。

5.4 一些观察

我们分享了一些关于 Jamba-1.5 开发过程中的观察。这些观察虽然尚未完全探索,但希望能够激发社区进一步研究这些问题。首先,虽然我们仅在后训练阶段包含了非常小比例的非英语数据,仅针对几个语言和特定技能,我们的 Jamba-1.5 模型在多语言环境下表现相当出色。我们确实在预训练阶段包括了多语言数据,如上所述。因此,我们推测模型能够在主要以英语进行后训练时,利用预训练阶段学到的知识。

其次,我们高效的 Jamba 架构降低了在长上下文上的微调成本,使我们能够在给定预算下进行更多实验。因此,我们能够在后训练阶段实验多种不同的训练配方。

最后,尽管偏好微调算法如 PPO [33] 或 DPO [29] 可以改善模型输出与人类意图之间的对齐,我们发现细致的合成数据生成、数据过滤和监督微调的结合对获得强大的后训练模型至关重要。

6. 评估

虽然我们相信基准测试仅与实际应用的成功和用户满意度部分相关,但我们仍报告了一些关键公共基准的结果。首先,我们报告标准学术基准的结果。然后,我们评估模型在聊天机器人基准上的表现。最后,我们在多个长上下文评估和多语言评估中评估 Jamba-1.5-Large。

我们将其与最近的相同规模范围的开放权重模型进行比较:在比较 Jamba-1.5-Large 时,与 LLaMA-3.1 70B 和 Mistral-Large-2-123B;在比较 Jamba-1.5-Mini 时,与 LLaMA-3.1-8B 和 Gemma-2-9B 进行比较。

6.1 学术基准

6.2 聊天机器人评估 

在本节中,我们对 Jamba-1.5 模型在两个聊天机器人(chatbot)场景中的表现进行了吞吐量评估:

  • Arena-Hard [22],这是一个包含 500 个具有挑战性的用户查询的数据集,使用 GPT4-Turbo 作为评判;

  • WildBench [25],该数据集也使用 GPT4-Turbo 作为评判,但进行了长度偏差的缓解处理。 

6.3 长上下文评估 

我们在 RULER 基准上进行评估,RULER 是一组 13 个合成任务,旨在评估语言模型的长上下文能力。RULER 包括 8 种变体的针在大 haystack 检索任务 [17, 21, 27, 28],其中包含多个“针” [2]。此外,它还有一个变量跟踪任务,需要返回一系列变量绑定,两个聚合任务,需要返回最常见的词汇,以及两个问答任务,其中包含来自自然数据集 [30, 41] 的答案的段落被插入到随机段落中,以模拟长上下文。

接下来,我们在 ∞BENCH 数据集上进行评估,该数据集旨在评估语言模型的长上下文能力,平均长度为 100K 词汇。我们专注于两个英文任务,理解长篇小说:问答(EN.QA)和多项选择题问答(EN.MC)。如表 5 所示,Jamba-1.5 模型在这方面表现非常出色,超越了同样规模的 LLaMA-3.1 和 Mistral-Large-2 模型。(由于 Gemma-2 9B 的上下文窗口较短(8K),我们未报告其结果。)

6.4 多语言能力

我们对 Jamba-1.5 在非英语语言中的能力进行了基本评估。具体来说,我们报告了在多语言 MMLU 数据集 [19] 上的结果,该数据集通过 LM Evaluation Harness [11] 分发。如表 6 所示,Jamba-1.5-Mini 的表现与对比模型相当或更好。Jamba-1.5-Large 略微落后于其可比模型,但仍展现出良好的多语言能力。

8. 结论

我们介绍了 Jamba-1.5-Large 和 Jamba-1.5-Mini,这两个基于 Jamba 混合 Transformer-Mamba 架构的大规模模型。两个模型在学术基准、聊天机器人评估和长上下文评估中均表现出色,同时提供了改进的延迟和吞吐量,特别是在处理长上下文时。我们发布了模型权重,希望社区能够使用这些模型并在此技术基础上进行进一步开发。

 

论文地址:https://arxiv.org/abs/2408.12570

Jamba 开源模型许可:https://www.ai21.com/licenses/jamba-open-model-license。

项目页面:https://huggingface.co/ai21labs

Jamba-1.5-Mini:https://huggingface.co/ai21labs/AI21-Jamba-1.5-Mini

Jamba-1.5-Large:https://huggingface.co/ai21labs/AI21-Jamba-1.5-Large


公和众与号:EDPJ(进 Q 交流群:922230617 或加 VX:CV_EDPJ 进 V 交流群)
加 VX 群请备注学校 / 单位 + 研究方向
CV 进计算机视觉群
KAN 进 KAN 群

EDPJ
CV 博士在读。文章搜索:公众号主页右上角放大镜搜关键词。
 最新文章