今天给大家带来知乎@真中合欢的大模型实践系列文章-LLM的拒绝采样。
作者:真中合欢
知乎:https://zhuanlan.zhihu.com/p/4547529049
拒绝采样是一种蒙特卡洛方法,和重要性采样一样,都是在原始分布难以采样时,用一个易于采样的建议分布进行采样。拒绝采样只是为了解决目标分布采样困难问题,它需要原始分布是已知的。形式描述是这样的:
假设已知原始分布为 ,但是从 采样较为困难,我们可以找到一个容易采样的建议分布 。再确定一个常数 ,确保任取x满足 。然后从 中采样,以 的概率保留这个样本,得到的采样结果就是服从 的样本。
对比来看看拒绝采样和重要性采样,重要性采样的目的是通过易采样的建议分布估算原始分布的期望,目的是为了数值计算,而拒绝采样是为了采样出一批样本。下面举个具体的拒绝采样的用例。
具体例子
举个具体点的例子,假设现在有已知的分布 的概率密度函数 如下:
任取 我们都可以直接计算出 的值,我们也可以画出函数图像。
import matplotlib.pyplot as plt
import numpy as np
def func(x):
y = 1 / np.sqrt(2 * np.pi) * (np.e ** (- x ** 2 / 2)) * (1 + 0.5 * np.sin(5 * x))
return y
x = np.arange(1000) / 100 - 5
y = func(x)
plt.plot(x,y)
plt.show()
现在要采样出5000个服从 分布的随机样本。python肯定是没有直接从这个分布采样的函数,我们可以用拒绝采样来实现这个函数。因为 中 的取值范围是 ,我们找到建议分布也要满足这个要求。我们在 中观察到一个正态分布的影子,而python正好也能生成正态分布,那我们就用一个正态分布作为建议分布:
现在需要确定常数C,注意到:
最大值是1.5,那么C取1.5就能保证分母恒大于等于分子。接下来实现这个采样函数:
def proposal_func(x):
y = 1 / np.sqrt(2*np.pi) * (np.e ** (- x **2 / 2))
return y
def rejection_sampling(size):
samples = []
while len(samples) < size:
x = np.random.randn()
if func(x) / (1.5 * proposal_func(x)) > np.random.random(): # 所有“以 x 概率保留”操作在代码实现都是判断 x 是否大于一个0-1随机数
samples.append(x)
return samples
samples = rejection_sampling(5000)
把采样结果的直方图画出来看看是否服从我们的原始分布:
plt.plot(x,y)
plt.hist(samples, bins=60, density=True, alpha=0.5)
plt.show()
如果疑惑为什么用 作为接受概率就能够模拟原始分布,可以去问o1,它能给出完整的证明。
拒绝采样的应用
上面的案例是纯从统计学的角度使用拒绝采样,那在LLM的场景中,什么时候会用到拒绝采样呢?在深度学习中我们判别模型一般就可以当作 ,而训练或推理数据就是样本x。也就是说当我有了一个模型,想要数据的时候,就会用到拒绝采样。
上一篇CP提到,在拿不到开源模型pretrain训练数据的情况下,我们用模型在自有数据上的概率+重要性采样来调整loss权重。那这一次就用模型在自有数据上的概率+拒绝采样直接估算pretrain训练数据的分布。我这次拿开源的Qwen2-72B和Qwen2.5-72B作为实验对象。分别以Qwen模型计算出的句子概率作为原始分布,用在我司之前某个版本数据训练的模型作为建议分布,使用拒绝采样看看qwen模型各数据的偏差:
这里外圈是Qwen2.5,内圈是Qwen2,数据集做了匿名和归一化处理,可以理解成左侧是qwen2.5的强项,右侧是2的
当然更重要的应用,还是把拒绝采样用在模型训练上,这又是NLPer在追寻 路上的一种新尝试。
拒绝采样训练 RST
拒绝采样训练本身非常简单:用经过一轮或几轮的 SFT 模型、few shot Pretrain 模型回答 SFT 训练数据的问题,用奖励模型对原始标注数据和模型生成的答案进行打分,得分 top1 的替换原始训练数据中的答案,产生新的一版训练数据继续训练模型。
一般要注意的就是 reward 模型的 OOD 问题,如果 RS 用的 prompt 是从 reward 模型的训练数据里摘出来的,一般没什么问题。但是如果不一致,一般都会 OOD,导致错误结果分高,模型越训越烂。如果 SFT 和强化是分开迭代的,或者 reward 数据用开源的可能会遇到这个问题。RS 和 reward 尽量用同分布的数据,不行的话最好也对训练数据做一个聚类或者分类,对不会 OOD 的数据用 RS,其他还是正常训练。这里奖励模型的训练也有不少门道,不过怎么生产奖励模型的训练数据(人工标注、模型标注、先验规则、MCTS)和奖励模型的训练(判别式、生成式)并不是本文想要讨论的。本文更想简答讨论一下SFT、DPO、RS、PPO之间的联系和差异。
对比几种训练方式的异同
首先是SFT,用的是next token loss或者叫lm loss,这是现代语言模型最原始、最直接的训练方式:
其中 表示当给定 这个token序列时,第t个token是 的概率。这是个性质相当好的loss,它优化的是 的联合概率,也就是说它是一个生成模型的loss:
这个loss给pretrain用一点问题没有,因为pretrain的目的就是知识注入,理想的pretrain应该能通过采样还原出所有学过的文字知识,但是给SFT用似乎有一点问题。知乎有一个关于为什么要做RLHF的问题:
https://www.zhihu.com/question/651021172。
各方向大佬在里面分享了各种前沿的研究成果和动机,看完令人受益匪浅,推荐加入必读。但是通篇看完感觉并不能说服我SFT为什么不行。我在各种文章、现实工作中听到的最多的观点一般都会从在线、离线学习的角度分析SFT的不足。这确实是很重要的一点,但是离线学习也是学习,假设加入SFT数据集的数据就是最好最理想的数据,就包含期望的answer,模型怎么还是学不会?大多数人的成长似乎也都是读书上大课,也没有很多老师来因材施教,模型怎么就不行。我认为是因为训练和推理有gap。和pretrain相比,SFT的侧重点已经从忠实的反应训练数据分布变成了按照人的预期进行生成。但是语言模型loss低 能生成。假如使用贪婪的方式采样,能生成意味着所有目标token的概率都是top1。也就是说lm loss应该改成类似下面的这种形式:
是个指示函数,当 不是top1 token的时候为1,否则为0。举个具体的例子的话就是"中国的首都是哪里?是北京","是"、"北"、"京"的概率提高都能使loss降低,但是这句话能不能说对的关键是"北"的概率是不是top1,"京"的概率再高,"北"采不出来也还是不对。当然这个公式只是个示意,没有做严谨推导,或许真正需要的是一个软标签指示函数或者概率指示函数。
看完SFT再来看看DPO
DPO优化的损失函数是:
其中 分别表示正在训练的模型和冻住不训的参考模型, 分别表示prompt,较差的answer和较好的answer。当初第一眼见DPO时觉得这个抛弃奖励函数和PPO直接优化语言模型的设计非常巧妙,所以立马造了一批数据试验了一下。结果第一下就给模型训崩了,但是在我的惯性认知里这种连续的loss不会这么不稳定。坐下仔细思考了下DPO的loss函数,它的优化目标是最大化给定 下,好answer优于坏answer的概率:
这根本不是个生成模型loss,这是判别模型的loss,这下训崩就合情合理了。虽然loss里隐含了一点kl,但是常用kl的应该能明白,kl提供的梯度非常小,根本限制不住模型。所以目前工业界稳定跑通的DPO,基本用的都是下面这个公式:
也就是把好样本的语言模型loss加上,给这个loss多一些生成的成分。在这个训练范式下,DPO的作用更多是拿来打压一下生成过程中明确遇到的一些负例。
接下来就来到了拒绝采样
既然判别loss训出来的模型不好生成,那就找个好生成的模型来采样,判别模型负责拒绝就可以了,于是也就有了拒绝采样的训练。拒绝采样保证了训练的模型依旧是一个生成模型,靠reward模型拟合 并指导模型训练的方向。但是reward+top1的拒绝采样也不是一劳永逸的。首先是不连续的loss + reward模型OOD导致模型越训越烂的问题,其次就是这个topn采样并不是无偏的,没有完全依照的概率去采样。所以很多优化是从这方面入手,有根据 reward阈值的拒绝采样(但是reward在不同domain阈值差别很大,不好确定阈值超参)、归一化reward分数当作概率采样或者reward+规则等等。但是这些大部分都做不到无偏。
最后就到了PPO
PPO可谓是集众家之所长,首先延续了拒绝采样的优点,确保了训练模型是生成模型,归一化reward也一定程度减小了估计的偏差。并且奖励“sample出的样本”和token级别的value function还确保了推理和训练的一致性。不过代价就是引入了太多的离散变量,非常不稳定,每个地方都需要做clip。
最后必须要感谢下我司强化组的大佬,经常耐心为我解惑
PS:看到这里,如果觉得不错,可以来个点赞、在看、关注。给公众号添加【星标⭐️】不迷路!您的支持是我坚持的最大动力!
欢迎多多关注公众号「NLP工作站」,加入交流群,交个朋友吧,一起学习,一起进步!