01
引言
大型语言模型通常是在海量数据集上训练的神经网络,用于理解和生成人类语言。它们依赖于Transformer等架构,这些架构使用诸如自注意力机制来处理和生成文本。
02
基础知识
自Meta发布基础模型LLama 系列后,我们目睹了各种基于Llama 的微调开源模型(如Alpacca 和Vicuna 等)。一些典型的模型如Falcon , MPT , 以及 Llama-2 和 Llama-3 等流行模型已成为主流模型选择。
Model | Positional Embeddings | Attention Mechanism |
AliBi Embeddings | Multi-head Attention (MHA) | |
Falcon | Rotary Embeddings | Multi-Query Attention |
Llama2 | Rotary Embeddings | Grouped Query Attention for 70B MHA for 13B and 7B |
03
Llama-2模型结构
首先,我们来加载Llama-2 模型,并尝试打印其模型结构,代码如下:
from transformers import AutoTokenizer, AutoModelForCausalLM
# You can create token for ur account here: https://huggingface.co/settings/tokens
model_name = "meta-llama/Llama-2-13b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
model = AutoModelForCausalLM.from_pretrained(model_name, token=token)
下图展示了该模型的具体网络结构,如下所示:
04
嵌入层参数计算
05
注意力块参数计算
接着我们来关注自注意力块的参数量计算,其代码实现如下:
如我们在第二节中的表格所示,70B 版本的Llama-2 采用了分组查询注意力机制GQA , 而13B 版本的模型则采用了多头注意力机制MHA 。值得注意的是,GQA 涉及在每个组内共享 Key-Value 对,从而减少了推理过程中KV-Cache的大小。。
在Llama-2-13B 的 MHA 块中,有 40 个注意力头,每个注意力头的维度为 128。因此,W_Q 矩阵的大小计算为 5120 x (128 x 40),即 26 214 400 个参数。重要的是,在 MHA 块中,W_O、W_K 和 W_V 矩阵的维数与 W_Q 相同。
06
MLP块参数计算
接着我们来看MLP块,其结构为:
out = down_proj(actn_fn(up_proj(input)))
不过,在 Llama-2 中,MLP 模块由三个基本层组成:up_proj、down_proj 和 gate_proj,这三个层的组合创造了一个独特的架构:
out = down_proj( act_fn(gate_proj(input)) x up_proj(input) )
07
RMS归一化层参数计算
Llama-2 使用的是 RMSNorm,而不是论文 《Attention is All You Need》中提到的 LayerNorm。RMSNorm 使用激活的均方根进行归一化,并使用可学习的参数对其进行缩放。
08
LM head层参数计算
最终的 LM 分类头接收了 5,120 维特征,并将其分为 32,000 个类别。
因此,lm_head_param = 5120X32000=163840000
09
计算总参数量
在Transformer架构中,注意力模块和 MLP 模块合并为一个Transformer层,并重复多次。要计算参数总数,我们可以使用下面的公式:
Total parameters = embed_parameters + num_layers x (attn_module_parameters
+ mlp_block_parameters + per_layer_rms_norm_ parameters)
+ pre_lm_head_rms_norm_parameters + lm_head_parameters
带入相应的数值,结果如下:
Total parameters = 163,840,000 + 40 x ( 104,857,600
+ 212,336,640 + 5,120 x 2) + 5, 120 + 163,840,000
= 13,015,864,320
10
Pytorch验证
要确定上面加载的 PyTorch 模型中的参数数量,可以使用下面的代码片段:
num_parameters = sum(p.numel() for p in model.parameters())
print(num_parameters)
# Number of parameters in Llama-2-13B: 13015864320
所以,我们的计算结果完全正确!
11
总结
本文探讨了如何计算LLM大语言模型的参数总量,通过逐层拆解计算,可以弄清楚每一层的参数量,最后通过和Pytorch的计算结果进行核验,证实了我们的计算过程,希望可以帮助大家处理其他模型的参数计算。
如果您觉得这篇文章有价值,欢迎点赞收藏关注一键三连,以获取更多有价值的内容。
点击上方小卡片关注我
添加个人微信,进专属粉丝群!