Google大语言模型Gemma 2介绍及其微调(上篇)

文摘   2024-07-22 16:27   上海  
  • 引言
  • 简介
  • Gemma 2模型介绍
    • 架构设计
    • 训练方法
    • 后训练优化
    • 关键发现:知识蒸馏的影响
    • 性能评估
  • 使用
    • 体验:Hugging Chat
    • 如何提示 Gemma 2
    • 基于Hugging Face Transformers
  • 结论与展望
  • 模型汇总

引言

小伙伴们好,我是微信公众号《小窗幽记机器学习》的小编:卖荔枝的小男孩。Google 最近谷歌发布了开源大语言模型 Gemma 2,目前可以在 huggingface 上找到 4 个开源模型(2 个基础模型和 2 个微调模型)。今天这篇小作文主要介绍Gemma 2的一些技术特点及其使用初体验,下一篇小作文将介绍如何微调Gemma 2模型。

技术报告原文:https://storage.googleapis.com/deepmind-media/gemma/gemma-2-report.pdf

简介

2024年6月27日,Google DeepMind发布了Gemma 2,这是Gemma系列轻量级开放语言模型的最新成员。Gemma 2在架构和训练方法上都有重大创新,在多项基准测试中取得了显著进步,甚至可以与参数规模大2-3倍的模型相媲美。本文将对Gemma 2技术报告的主要内容进行解读,包括模型架构、预训练和后训练方法、性能评估等方面。

Gemma 2 模型的训练数据量约为其第一代的两倍,总计 13 万亿 Tokens(27b 模型)和 8 万亿 Tokens(9b 模型)的网页数据(主要是英语)、代码和数学数据。官方尚未透露训练数据混合的具体细节。

Gemma 2模型介绍

Gemma 2系列包括三种规模的模型:27B(270亿参数)、9B(90亿参数)和2.6B(26亿参数)。其中27B和9B模型已经发布,2.6B模型即将发布。这些模型均采用Transformer解码器架构,具有以下主要特点:

  • 上下文长度为8192个token
  • 使用旋转位置编码(RoPE)
  • 采用近似GeGLU非线性激活函数
  • 交替使用局部滑动窗口注意力和全局注意力
  • 应用logit软上限技术
  • 使用RMSNorm进行后归一化和前归一化
  • 9B和27B模型采用分组查询注意力(GQA)机制

架构设计

相比Gemma 1,Gemma 2在架构上有几项重要改进:

(1) 局部滑动窗口和全局注意力交替

Gemma 2在每隔一层交替使用局部滑动窗口注意力和全局注意力。局部注意力的窗口大小为4096个token,全局注意力的跨度为8192个token。这种设计兼顾了计算效率和长程依赖建模能力。

(2) Logit软上限

在每个注意力层和最终输出层,Gemma 2对logit值进行软上限处理,将其限制在[-soft_cap, +soft_cap]范围内。具体实现为:

logits ← soft_cap * tanh(logits/soft_cap)

9B和27B模型的注意力logit上限为50.0,最终logit上限为30.0。这项技术有助于稳定训练过程和生成质量。

(3) RMSNorm归一化

Gemma 2使用RMSNorm来归一化每个Transformer子层、注意力层和前馈层的输入和输出,以提高训练稳定性。

(4) 分组查询注意力(GQA)

9B和27B模型采用了GQA机制,组数为2。实验表明,GQA可以在保持下游任务性能的同时提高推理速度。

训练方法

Gemma 2的另一大亮点是采用了创新的训练方法,尤其是对较小规模模型(9B和2.6B)使用知识蒸馏技术。

(1) 预训练数据

27B模型在13万亿token的数据上训练,9B模型使用8万亿token,2.6B模型使用2万亿token。训练数据主要为英语,来源包括网页文档、代码和科学文章等。tokenizer采用SentencePiece方法,词表大小为256k。

(2) 知识蒸馏

对于9B和2.6B模型,Gemma 2团队创新性地将知识蒸馏应用于大规模预训练。具体做法是:使用一个大型模型作为教师,最小化学生模型与教师模型在每个token的条件概率分布之间的负对数似然。这种方法可以在有限的计算资源下,让小模型获得类似于训练更多token的效果。

(3) 计算基础设施

Gemma 2使用了TPUv4、TPUv5e和TPUv5p进行训练,采用数据并行和模型并行相结合的策略。例如,27B模型在6144个TPUv5p芯片上训练,采用768路数据复制和8路模型分片。

后训练优化

为了进一步提升模型性能和安全性,Gemma 2在预训练后进行了一系列优化:

(1) 监督微调(SFT)

在英语指令-响应对数据集上进行监督微调,数据包括合成和人工生成的样本。

(2) 人类反馈强化学习(RLHF)

在SFT基础上应用RLHF,奖励模型基于标注的英语偏好数据训练,策略使用与SFT相同的提示。

(3) 模型融合

将不同超参数实验得到的模型进行平均,以提升整体性能。


关键发现:知识蒸馏的影响

技术报告重点介绍了知识蒸馏对小型语言模型性能的显著影响:

  • 在500B token上训练的2.6B模型,使用知识蒸馏比从头训练在3项基准测试的平均分上高出7.4个百分点。
  • 随着模型规模增加(从200M到1B参数),知识蒸馏带来的收益依然存在。
  • 在9B规模上,使用GQA替代多头注意力(MHA)对性能影响不大,但可以减少参数量并提高推理速度。
  • 对于9B模型,更深的网络结构略优于更宽的结构。
  • 可以在推理时调整局部注意力的滑动窗口大小,对困惑度的影响较小,为推理速度优化提供了灵活性。

性能评估

Gemma 2在多项基准测试中展现出优异性能:

  • 在同等规模的开放模型中达到最佳性能
  • 甚至可以与参数量大2-3倍的模型相竞争
  • 在问答、常识推理、数学科学、编程等多个领域的任务上表现出色

具体评估涵盖了以下方面:

  • 自动化基准测试
  • 人工评估
  • 问答能力(如SQuAD、Natural Questions)
  • 常识推理(如WinoGrande、HellaSwag)
  • 数学和科学(如MATH、MMLU)
  • 代码能力(如HumanEval、MBPP)

使用

体验:Hugging Chat

你可以在 Hugging Chat 上与 Gemma 27B 指令模型聊天!查看此链接:https://huggingface.co/chat/models/google/gemma-2-27b-it

如何提示 Gemma 2

基础模型没有提示格式。像其他基础模型一样,它们可以用于继续输入序列的合理延续或零样本/少样本推理。指令版本有一个非常简单的对话结构:

<start_of_turn>user
knock knock<end_of_turn>
<start_of_turn>model
who is there<end_of_turn>
<start_of_turn>user
LaMDA<end_of_turn>
<start_of_turn>model
LaMDA who?<end_of_turn><eos>

必须精确地复制此格式才能有效使用。稍后我们将展示如何使用 transformers 中的聊天模板轻松地复制指令提示。

基于Hugging Face Transformers

随着 Transformers 版本 4.42  的发布,你可以使用 Gemma 并利用 Hugging Face 生态系统中的所有工具。要使用 Transformers 使用 Gemma 模型,请确保使用最新的 transformers 版本:

pip install "transformers>=4.42.3" --upgrade

以下代码片段展示了如何使用 transformers 使用 gemma-2-9b-it。它需要大约 18 GB 的RAM,适用于许多消费者 GPU。相同的代码片段适用于 gemma-2-27b-it,需要 56GB 的 RAM,使其非常适合生产用例。通过加载 8-bit 或 4-bit 模式,可以进一步减少内存消耗。

from transformers import pipeline
import torch

pipe = pipeline(
    "text-generation",
    model="google/gemma-2-9b-it",
    model_kwargs={"torch_dtype": torch.bfloat16},
    device="cuda",
)

messages = [
    {"role""user""content""请模仿蔡坤的方式说两句土味情话"},
]
outputs = pipe(
    messages,
    max_new_tokens=256,
    do_sample=False,
)
assistant_response = outputs[0]["generated_text"][-1]["content"]
print(assistant_response)

输出结果如下:

1. 你就像一碗热腾腾的番茄鸡蛋面,让我欲罢不能!
2. 你的眼睛像星星一样闪亮,照亮了我这颗孤独的心!


希望你喜欢我的土味情话!😜

这里我们使用 bfloat16 因为这是指令调优模型的参考精度。在你的硬件上运行 float16 可能会更快,90 亿模型的结果应该是相似的。然而,使用 float16 时,270 亿指令调优模型会产生不稳定的输出:对于该模型权重,你必须使用 bfloat16。

你还可以自动量化模型,以 8-bit 甚至 4-bit 模式加载。加载 4-bit 模式的 270 亿版本需要大约 18 GB 的内存,使其兼容许多消费者显卡和 Google Colab 中的 GPU。这是你在 4-bit 模式下加载生成管道的方式:

import os
from transformers import pipeline
import torch

init_model_dir = "/share_model_zoo/LLM/"
# init_model_id = "google/gemma-2-9b"
init_model_id = "google/gemma-2-9b-it"
init_model_path = os.path.join(init_model_dir, init_model_id)

pipeline = pipeline(
    "text-generation",
    model=init_model_path,
    model_kwargs={
        "torch_dtype": torch.bfloat16,
        "quantization_config": {"load_in_4bit"True}
    },
    device_map="auto",
)

messages = [
    {"role""user""content""请模仿蔡坤的语气说两句土味情话"},
]
outputs = pipeline(
    messages,
    max_new_tokens=256,
    do_sample=False,
)
assistant_response = outputs[0]["generated_text"][-1]["content"]
print(assistant_response)

输出结果如下:

1.  宝贝,你就像一朵盛开的玫瑰,香气迷人,让我忍不住想把你捧在手里,永远守护着。
2.  你是我这辈子最想遇见的人,就像天上的星星,照亮我的整个世界,让我不再迷茫。


希望你喜欢这些土味情话!

有关使用 Transformers 模型的更多详细信息,请查看模型卡。

结论与展望

Gemma 2代表了轻量级开放语言模型的重大进展。通过创新的架构设计和训练方法,尤其是知识蒸馏技术的大规模应用,Gemma 2在有限的参数规模下实现了卓越的性能。这不仅推动了语言模型技术的发展,也为更广泛的AI应用提供了高效、实用的解决方案。

未来的研究方向可能包括:

  • 进一步优化知识蒸馏技术
  • 探索更高效的模型架构
  • 增强模型的多语言和跨领域能力
  • 改进模型的安全性和可控性

总的来说,Gemma 2的成功表明,通过精心的设计和创新的训练方法,我们可以在不盲目追求参数规模的情况下,显著提升语言模型的能力。这为构建更高效、更易部署的AI系统开辟了新的可能性,有望推动AI技术在更多实际场景中的应用与落地。

进交流群请添加小助手微信



关于互联网持续学习圈


互联网持续学习圈是由清华大学计算机系校友、前阿里和微软算法工程师创办。汇聚互联网精英、985高校及海外硕博、自主创业者等,是持续学习者的专属圈。专注互联网资讯、科研、求职等。器识其先,文艺其从,陪你进化二十年。


互联网持续学习圈
清华大学计算机系校友、前微软、阿里高级算法工程师创办。汇聚互联网精英、985高校及海外硕博、自主创业者,持续学习者的专属圈。专注互联网资讯、科研、求职等。器识其先,文艺其从,陪你进化二十年。
 最新文章