今天给大家带来知乎@真中合欢的一篇文章,《LLM实践--数据去重:Simhash&Minhash 原理分析&代码实现》
知乎:https://zhuanlan.zhihu.com/p/739101179
数据处理是LLM pretrain的核心环节,去重又是数据处理的重要组成部分,这篇文章就分享一下数据去重的一般做法。我写东西的主要是想学会什么,而仅仅是了解什么,所以回答和文章大多都会附上代码,这篇也是一样。这个系列的文章估计废话会比较多。
数据去重大致可以分为三个粒度:文档粒度、段落粒度和句子粒度。
顾名思义,文档粒度就是以文档为单位,删除重复文档。这种做法最为普遍,主要是为了删除完全重复或几乎一致的文档,这种文档一般来自于相同文档的不同源发布转载、重复爬取等。段落粒度和文档粒度没有特别本质的差别,一般适用于一些特殊的源和场景,比如法律相关的文档大量引述法条这种,可能产生大面积引用的场景,做法和文档去重也不会差太多,不过要额外关注一下分段方法。句子粒度的去重则是主要为了去除一些边界词,比如所有知乎文章最后都有一句“真诚赞赏,手留余香”。句子级别的去重就为了删除这些前后缀、边栏之类的东西。句子级别的去重一般是相同句子匹配,设置一个阈值卡掉就可以了,不过要注意下计算成本,这里大规模句子去重无论是用统计查表,还是分桶聚合都是不小的计算量。这篇文章主要是分享minhash去重,句子和段落的去重就不再多说,主要来说说文档粒度去重。
文档这个叫法比较笼统,实际上就是指我们训练的一条数据。不管是一句话,一段话,还是几万字,只要他被收录在一条训练数据里(这里指没有经过拼接处理的数据),就算是一个文档。基础的文档去重是删除重复key,这个key一般是文档的来源,比如网址。之后就是想办法将一篇文档向量化,通过比较两个文档向量的相似度,判断两个文档是否为重复文档。文档向量化的方法有很多,比如使用基于transformer结构的模型比如GLM推出的BGE系列模型,再简单一点的比如fasttext。再比如就是甚至模型都不需要仅仅通过算法就能实现的,也就是本文要介绍的Minhash和Simhash。看名字能够猜出,这两个算法是和哈希函数相关的算法。不知道读者是否和我一样,在第一次看到这两个算法的时候都产生出一个疑问:我们要计算文档相似度,至少需要保证相似文档的向量相似度高。但是哈希函数通常都是被认为是随机函数,并且局部敏感,文档稍微变动一点哈希值都完全不同,是怎么用来计算相似度的?下面我就简单介绍一下这两个算法是怎么向量化文档的。
首先说Simghash算法
Simhash算法的第一步,就是对文档进行分词,现代我们也不说什么老掉牙的空格分词、jieba分词、n-gram了,我习惯就用自己手边的BPE tokenizer分词。没有训练过分词器的,可以从qwen、llama,甚至openai开源的gpt分词器中扒一个当自己的分词器。比如“不能复现的软件不算开源软件”这句话,分词以后是“不能”、“复现”、“的”、“软件”、“不算”、“开源软件”。
第二步是计算每个词的hash值,比如我们用md5计算这些词的hash值
from hashlib import md5
def hash(text):
return md5(text.encode()).hexdigest()
words = ['不能','复现','的','软件','不算','开源软件']
[print(f'{word}:{hash(word)}') for word in words]
'''
不能:38f9286be23a182e7403ef05db293b49
复现:90c33da6cdc6c148d6eab60f7f4926e1
的:01d7aa494b0727f8db77be1d3685de9e
软件:f0dfc65b71f0f4c075027ecbfe66ef7c
不算:7c85dceaa0f0e8bb3a351a059ed05d04
开源软件:ed507ba1538b7c4098b3f82e7ba8af9c
'''
我们知道md5实际上是用16进制表示的二进制数,那直接打出二进制看一下:
print(bin(int('38f9286be23a182e7403ef05db293b49',16)))
# 0b111000111110010010100001101011111000100011101000011000001011100111010000000011111011110000010111011011001010010011101101001001
我们将获得的二进制用0补齐到md5标准的128位,然后1保持不变,0变成-1,就得到了simhash算法中单词的向量表示:
import numpy as np
def encode_word(word):
bin_hash_value = bin(int(word,16))[2:].zfill(128)
embedding = np.array([1 if bit == '1' else -1 for bit in bin_hash_value])
return embedding
print(encode_word(hash('不能')))
到现在,我们已经能获得任意词的向量表示了。接下来就是获得文档的向量表示。一个很自然的想法就是平均池化,也就是将所有词的向量表示取均值。但是在simhash这里我们不求均值,仅求和
word_embeddings = [encode_word(hash(word)) for word in words]
print(sum(word_embeddings))
'''
[ 0 0 2 2 0 -2 -6 -2 4 4 -4 2 -2 0 0 4 0 0 2 0 4 0 0 -2
0 2 2 -4 2 -4 2 2 0 4 0 -2 -2 -4 0 2 2 0 0 0 -2 -2 2 -2
0 2 2 0 0 0 -4 -2 0 2 0 -2 2 -4 -2 -4 0 2 0 6 0 0 0 -2
-2 -2 2 0 -4 -2 4 2 2 0 4 4 4 2 4 -4 -4 -4 -4 -4 2 4 0 4
0 2 2 6 4 2 6 0 0 0 0 -4 0 -2 -4 0 0 0 2 0 4 4 4 2
0 0 -2 0 2 2 -4 -2 ]
'''
我们得到了一个负数、零、正数共存的向量。还没完,再clip一下,把值截断到1-0之间,才是Simhash下的文档向量:
print(np.clip(sum(word_embeddings),0,1))
'''
[0 0 1 1 0 0 0 0 1 1 0 1 0 0 0 1 0 0 1 0 1 0 0 0 0 1 1 0 1 0 1 1 0 1 0 0 0
0 0 1 1 0 0 0 0 0 1 0 0 1 1 0 0 0 0 0 0 1 0 0 1 0 0 0 0 1 0 1 0 0 0 0 0 0
1 0 0 0 1 1 1 0 1 1 1 1 1 0 0 0 0 0 1 1 0 1 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0
0 0 0 1 0 1 1 1 1 0 0 0 0 1 1 0 0]
'''
现在,我们只需要比较两个文档向量的汉明距离,就能比较两个文档的相似性了。汉明距离就是不一样的位数。观察一下Simhash是否符合我们对文档向量的要求,也就是相似文档的相似性高,不相似文档的相似性低。首先对于两篇完全相同的文档,能够想见对应的向量是完全相同的。那么如果只改变少量词呢?假设如果文档词的数量众多,仅仅增加或减少一个词,只会略微改变求和后的值,并不会明显的改变正负号,那么在clip后的结果变化就不大。因此Simhash是满足我们对相似向量的要求的。
但是Simhash这一堆骚操作到底是在干什么?词向量的表示相加获得文档向量还比较好理解,但是为什么要计算hash值?为什么不用0、1表示的词向量而是用1、-1?
从实质上看,Simhash的向量其实就是一个词频统计向量。一开始我说分词器我用的是语言模型的tokenizer,换句话说就是一个词表大小有限且已知的分词器,比如qwen的tokenizer词表大小是15w,那么最终分出来的所有词就只会从这词表的15w个里选一个。Simhash其实对哈希函数没有特别的要求,那我们就设计一个特殊点的哈希函数:我们用one-hot哈希函数。也就是说分词以后,我们给这每一个词一个唯一索引,然后转化为one-hot向量。词表也别15万了,为了好算我们假设词表只有10个词。
def get_onehot(index,vocab_size):
tensor = np.zeros(vocab_size,dtype=np.int32)
tensor[index] = 1
return tensor
vocabs = {
'不能':0,'复现':1,'的':2,'软件':3,'不算':4,
'开源软件':5,'模型':6,'开源模型':7,'免费':8,'只能算':9
}
[print(f'{word}:{get_onehot(vocabs[word],len(vocabs))}') for word in words]
'''
不能:[1 0 0 0 0 0 0 0 0 0]
复现:[0 1 0 0 0 0 0 0 0 0]
的:[0 0 1 0 0 0 0 0 0 0]
软件:[0 0 0 1 0 0 0 0 0 0]
不算:[0 0 0 0 1 0 0 0 0 0]
开源软件:[0 0 0 0 0 1 0 0 0 0]
'''
基于这个词向量求个和,计算一下文档向量,得到的是词计数向量:
new_word_embeddings = [get_onehot(vocabs[word],len(vocabs)) for word in words]
print(sum(new_word_embeddings))
# [1 1 1 1 1 1 0 0 0 0]
再计算一下词频和每个词的平均词频:
freq = sum(new_word_embeddings) / len(words) # 词频
mean_freq = sum(freq) / len(vocabs) # 每个词的平均词频
print(freq)
print(mean_freq)
我们令所有词频大于平均词频的词为1,其他为0,得到我们定义的文档向量:
doc_embedding = np.zeros_like(freq)
doc_embedding[freq > mean_freq] = 1
print(doc_embedding)
这明显是一个利用词频统计得到的向量。如果我不想对求和后的文档向量求词频,再求平均词频,再比较平均词频,而是从求和这一步就直接得出词频大于平均词频的值怎么办?只需要把one-hot向量改造一下,把1改为(1-1/N),把0改为(-1/N),这样求和后,大于0的位置就是词频大于平均词频的值。(有兴趣的可以自己推一下)。这和Simhash把1、0变为1、-1是不是很像?但为什么我们这里有个系数而Simhash没有?答案是我们的one-hot向量不均匀,只有1个1,却有N-1个0。而md5计算出的词向量1、0是均匀的,各占1/2,系数就消掉了。那么为什么我们不需要hash,而Simhash需要?因为我们已知词表大小,一个定长的one-hot就够了,而原始Simhash用的是n-grad分词,词表大小不是固定的,不能用one-hot,只能用hash函数来散列词向量。hash是随机的,我的one-hot也是随机的,毕竟词表我也是随便拍的,所以本质没有什么差别。
我上面列的这个Simhash也不是完全正宗的Simhash,一个是上面提到的,正宗的Simhash用的是ngram分词,另一个是它还统计了一个tf-idf词频乘在了词向量上,相信看到这里的读者应该能理解乘tf-idf的意义是什么,我就不再赘述了。
再说Minhash
Minhash的第一步和第二步和Simhash是一样的,都是分词计算每个词的hash值,就不再赘述。
我们已经知道词的哈希值可以转化成二进制bit,也就是说它本质就是个数,我们是可以比大小的。那么Minhash的第三步是把所有词的哈希值拿来做个比较,取出其中的最小值,作为文档向量的第一个值。
def hash(text):
return md5(text.encode()).hexdigest()
words = ['不能','复现','的','软件','不算','开源软件']
doc_embedding = []
doc_embedding.append(min([int(hash(word),16) for word in words ]))
print(doc_embedding)
# [2449025636878415118544441536416177822]
第四步再换一个不同的哈希函数,再次计算每个词的哈希值,这样会得到不同刚才的哈希值。再计算最小值,作为文档向量的第二个值。然后以此类推,选择N个哈希函数,计算N次所有词的哈希值,然后取每一次的最小值,构成一个N维的文档向量。这里换个哈希函数怎么操作?第一次我们用的md5,那第二次可以换成sha256,但是这似乎有些太考验我们的算法能力了。N我一般设200,那难道我们要找200个不同哈希函数?其实不用这么麻烦,还是用md5作为基础的哈希函数,把md5得到的结果转位数字,乘上一个大整数,再加上一个大整数,再对一个大质数求余,得到的结果就当作一个新的哈希值。这样如果需要200个哈希函数,我们只需要随机200个不同的大整数就可以了。
这是新的哈希函数:
def hash_new(text,a,b,prime=4294967311):
return (int(hash(text),16) * a + b) % prime
计算文档向量:
import random
random.seed(3407)
N = 200
A = [random.randint(1,(1 << 32) - 1) for i in range(N)]
B = [random.randint(1,(1 << 32) - 1) for i in range(N)]
doc_embedding = []
for i in range(N):
doc_embedding.append(min([hash_new(word,A[i],B[i]) for word in words]))
print(doc_embedding)
#[2423597390, 1270089199, 681530810, 168570878, 844989722, 342611399, 454597371, 405385321, 1455306417, 1799453425, 654234254, 123401623, 361392250, 361127621, 171480922, 48254851, 260370392, 294534994, 235573334, 31825184, 205781226, 132342607, 127505912, 1353217623, 871307438, 510363507, 954752857, 811346828, 608154468, 609529456, 1049415395, 160677016, 43239236, 901577591, 615921321, 587563498, 1320500130, 854922356, 232620261, 312281959, 386793238, 1337207987, 207370128, 750621631, 479036101, 221556004, 1451322663, 2364862315, 69337639, 361395067, 293355426, 422205765, 2544647403, 258751502, 489316447, 245619660, 439354362, 246756035, 446164460, 877452501, 120515651, 427586115, 705763926, 14423229, 47163397, 920215099, 843161250, 487425253, 1203534114, 298772626, 18895916, 67285674, 1296659337, 232071759, 650250036, 1240114578, 285630986, 401775071, 1326181169, 158825632, 895423844, 9416270, 245543362, 919956232, 243020444, 270959032, 198410531, 215627334, 694245618, 432904497, 236518318, 1274000213, 1089620997, 223162401, 277620305, 197883008, 31486764, 753716919, 749241405, 144588116, 1947626321, 333172229, 204993464, 1064400689, 668383379, 1184148909, 194732569, 828785985, 401150057, 594082862, 312738173, 967629887, 159876200, 1567826513, 1109553618, 1267138127, 735657912, 663700649, 145385704, 192905202, 1524193947, 63742926, 333926830, 1414344574, 6418678, 144938387, 173305400, 1101304510, 1434490425, 30961883, 58356953, 1115388456, 471768102, 1123933108, 19005016, 27542207, 897426129, 407463217, 450285222, 166759660, 1236436830, 513197790, 1849908251, 509075643, 1064150793, 657203906, 530911647, 826557359, 351395892, 1000841022, 584132878, 1050800763, 97401011, 462272254, 294821960, 293235667, 186144859, 313449706, 52029025, 112310135, 965801791, 263747740, 737717930, 996945331, 40462684, 785165888, 207385385, 788583638, 320703301, 222135923, 106995624, 384904863, 1487049842, 579033109, 162579224, 1391691432, 550626690, 135040907, 147014727, 626004236, 1249974224, 563752075, 274398108, 1452942235, 464002293, 49990373, 34628133, 385463128, 610885835, 630973844, 278622350, 557982613, 233417992, 110545601, 1091671537, 428292734, 512214567, 515890282, 128651558, 170081777]
Minhash通过比较两个文档向量的jaccard相似度计算相似性。也就是把两个doc embedding转化为2个集合,求两个集合交集的大小与两个集合元素之比。
奶奶滴这又是在干啥?怎么好像比Simhash还抽象。这里我也想不出特别眼前一亮的解释,但还是努力解释一下。
哈希函数是什么?我们可以认为一个理想的哈希函数是一个均匀分布函数,给它一个值, 它能够以均等的可能性将其随机映射到空间中的某一个点上。假设我们上面用的hash_new它是一个理想的哈希函数,那么它就是一个能将值随机的映射到0 - 4294967311上的均匀分布随机函数。
那么两个hash_new呢?我们知道两个独立的均匀分布的联合分布还是均匀分布,是2维均匀分布:
所以我们可以把“两个hash函数散列一个词”这个操作,等价为将一个词随机映射到二维平面上,再分别投影到两个轴上。
那么Minhash计算每个hash函数最小值的操作,实际上就是在找这些随机点在每个轴投影的下界。计算jaccard相似度就是在比较下界的重合度。
可以想象,找到一篇基础文章,对它更换、增减一些词,实际上就是移动或增减图上的点。那么如果改变的词少,可能并不会改变下界,改变的越多,越可能改变下界。
Minhash与Simhash可以看作是两兄弟,Minhash是将文档的词随机散布到空间中后,取下界,也就是最小池化。Simhash将词散布到空间后,计算词频也就是加权均值计算重心,也就是平均池化。
局部敏感Minhash(LSH Minhash)
现在我们知道两张hash是怎么将文档向量化的,下一步就是两两比较相似度去重了。但要知道,两两比较的复杂度是 级别的,这在大规模去重的时候是无法接受的,必须进行简化。局部敏感哈希给出的减小n的方法就是分桶。如果我能够以 复杂度现将每个文档均匀的分到不同的桶里,再在每个桶内部两两比较相似度,就可以减小n了,每个桶内的文档数量是数据总量n处以桶数量m,那么比较复杂度就是: ,总的复杂度再乘一个桶的数量,也就是 。说到 复杂度、均匀、分桶,第一时间又想到了hash算法。可以,那我们就计算一下doc embedding的哈希值,按照哈希值分桶,所有hash值一样的分在一个桶里,如果两个文档被分到了不同的桶,就不进行比较,认为他们不相似:
def hash_embed(doc_embedding):
return hash(','.join([str(e) for e in doc_embedding]))
print(hash_embed(doc_embedding))
如果这样分桶,那只有doc embedding完全一样才会被分到一个桶里,此时桶的数量就是无重复文档的数量。如果我们的数据中没有相似文档,这个桶数量m就接近文档总数n,时间复杂度分子分母一约就变成 复杂度了。但是如果我们的文档全都是一样的,那么分桶数量就会是1,复杂度就还是 ,相当于没有没有优化。
话说回来,我们并不是想要做完全匹配,doc embeding有1位不同我们也希望能被分到一个桶里比较一下,应该怎么办?也好办,我们把doc embedding从中间分成两半,分别计算hash值。这样假设doc embedding只有一位不同,那么不同的这一位必然落在embedding的前半部分或后半部分,那么剩下的相同的那一半embedding必然是一样的,计算的hash值就是相同的,那么按照这一部分计算出的hash值分桶,就能被分在一个桶里。也就是说在计算全部文档的相似度时,我分两个批次,第一个批次只拿出前一半 embedding计算一次hash分一次桶,比较一次相似度。第二个批次只拿出后一半embedding计算一次hash分一次桶,比较一次相似度。然后将两个批次的结果合并,如果一个文档在任意一个批次的比较中被检测到和其他某个文档相似,那么它就和那个文档相似。
如果需要把2位不同的分在一个桶里呢?那就embedding分3份,不同的2位最多落在其中两份,必然有一份hash值相似。
显然我们只要将embedding分为k+1份,计算k+1个hash,就能确保把有k个位不同的embedding分到一起。这里再规范化一下表述,这个embedding的切分份数在LSHMinash里称为band数。增加band数一方面让我们能够把更多位不同的embedding放在一个桶里,一方面会带来大量的无效比较。如果两个文档完全相似,那么他们两个文档的band个hash值也都相同,本来只需要在一个桶里比较一次就够了,现在却要被放到band个桶里多次比较。那我们能不能减少一点band数,适当承受一下2位不同的embedding没有被分到一个桶里的误差?这其实是一个简单的概率问题,也就是当给定embedding长度N,不同的位数M,band数K时,两个文档至少有一次能够被分进一个桶里的概率是:。这个很好推,实际上就是用M、N计算两个文档的相似性s,根据s计算每个band中hash相同的概率p,再组合成所有band至少一次相同的概率就可以了。但请注意这个公式并不适用于M比较小的时候,偏差来自于公式假定了不同的位落在不同band的概率是独立的,但实际不独立,M越小越不独立。
实际用的时候,根据你的hash向量长度N、能接受的不同位数M和可以接受的概率P,算一下K应该设置多少就可以了。LSH Simhash也是类似的做法,doc embedding切band然后分桶。
最后在给出大规模计算Minhash的代码前,我们再从更加具体的角度考虑下两种hash算法的区别。想象一下如何变换一个文档能够在不改变Minhash的情况下, 改变Simhash,以及什么情况下会反过来?答:增加现有词的数量,不管增加多少,Minhash都不会变。增加的少,Simhash可能也不会变,但是增加的多了Simhash一定变。反过来的情况我想了下好像是不存在的,那就只探讨一下可能性,如果文档增加了一个冷门词,在Simhash中由于计算的是加权,一个词的权重可能不足以体现在最终结果上,但是在Minhash中是有比较大的概率被作为一个新的下界的(这个概率取决于embedding的维度和当前文档的词数量。大概是 。
基于Spark的Minhash去重
其实Minhash和Simhash也没有什么特别质的区别,用的时候多数情况下还是从量上表现出的差异。可以根据喜好选一个,因为Minhash对相同词重复出现不敏感,所以我喜欢用Minhash。下面上代码,数据load、dump这种简单代码我就不贴了。大规模跑数据我用的都是Spark,关于Spark的安装可以去问chat老师,我之前也写过一个简单的介绍:"LLM实践--Hugingface&vLLM + Spark集群"
https://zhuanlan.zhihu.com/p/715290922
首先是读取一些命令行参数,data数输入数据的路径,可以是目录,也可以是文件,spark会自动处理。num-gpus是整个集群的gpu数量,上面我贴的文章里也讲了怎么用spark管理gpu。这里之所以要考虑gpu,是我会先用gpu给语料打个质量分。之后是定义一下临时数据落盘的位置,大规模计算的时候数据落盘保证一下可靠性还是很重要的:
from argparse import ArgumentParser()
parse = ArgumentParser()
parse.add_argument('data',nargs=1,type=str)
parse.add_argument('--num-gpus',type=int,default=800)
parse.add_argument('--num-merge-part',type=int,default=None)
parse.add_argument('--threshold',type=float,default=0.9)
args = parse.parse_args()
num_qs_part = args.num_gpus * 4
input_path = ','.join(args.data)
output_path = f'{input_path.rstrip("/")}_result'
qs_temp_path = f'{output_path}_tmp_qs'
minhash_temp_path = f'{output_path}_tmp_minhash'
dd_temp_path = f'{output_path}_tmp_dd'
lsh_threshold = 0.8
hash_threshold = args.threshold
下面第一步就是打qs分数,safty_load确保数据都是正常的json格式,compute_qs 计算语料质量分。这里之所以要计算质量分,是我们在计算完文档相似性后,需要确定两个文档该删哪个。比如我选择删除质量分低的。当然也可以不打质量分,删除更靠后的,或者长度更长、更短的。如果不打分我习惯删除更长且更靠后的。这里选择删除依据要关注一点,就是判断依据必须能确保全局一致性。我们是要进行LSH分桶的,如果AB两个文档相似,在band1留A删B,在band2留B删A,汇总一下就都删了。所以仅靠长度不行,如果两个文档长度相同的就没有一致性了,所以还得加上行号这个依据。计算qs要动用gpu,耗时蛮久的,随意最后数据落个盘保存到qs_temp_path:
def compute_qs(partition):
model = ...
tokenizer = ...
for line in partition:
text = line['text']
qs = model(tokenizer.encode(text)).logits
line['qs'] = qs.items()
yield line
spark = SparkSession.builder.appName("Compute Quality Score")....
origin_rdd = spark.sparkContext.textFile(input_path).repartition(num_qs_part).flatMap(safty_load)
qs_rdd = origin_rdd.mapPartitions(compute_qs)
qs_rdd.map(dump_json).saveAsTextFile(qs_temp_path)
接下来就是计算Minhash和去重了,首先加载一下分词器和minhash算法库。这里用datasketch这个库帮助我们计算minhash值和band数量,这里还顺便定义了一个3-gram分词器,它把tokenizer分词结果再拼成3-gram,这里实测有一点影响,因为tokenizer虽然包含一部分n-gram信息了,但还是不太多。不过不用太纠结,这个也是量上的差别,不是质的变化:
from transformers import AutoTokenizer
from datasketch import MinHash,MinHashLSH
minhash_tokenizer = AutoTokenizer.from_pretrained('Your tokenizer')
LSH = MinHashLSH(lsh_threshold, 200)
gram_3 = Ngram(3)
path = qs_temp_path
接下里就是读取刚才落盘的包含qs的数据,拼上行号,计算minhash,再把带行号的数据落盘一次后面用:
qs_rdd = spark.sparkContext.textFile(path).map(load_json).zipWithIndex().map(lambda x:{**x[0],'hashindex':x[1]})
minhash_rdd = qs_rdd.map(get_minhash).cache()
minhash_rdd.map(remove_minhash).map(dump_json).repartition(1000).saveAsTextFile(minhash_temp_path)
计算minhash调库很方便:
def get_minhash(line):
text = line['text']
tokens = gram_3(minhash_tokenizer.tokenize(text))
m = MinHash(num_perm=200)
m.update_batch([token.encode() for token in tokens])
line['minhash'] = m.digest().tolist()
return line
再往后就是分band、分桶、去重再聚合,然后落盘
minhash_bulks_rdd = minhash_rdd.flatMap(scatter_hash_bulk)
deduplicate_res_rdd = minhash_bulks_rdd.groupByKey().flatMap(deduplicate)
deduplicate_res_rdd.groupBy(lambda x:x['hashindex']).map(merge_deduplicate).map(remove_minhash).map(dump_json).saveAsTextFile(dd_temp_path)
分桶的代码如下,LSH是我们在上面定义的全局变量,是datasketch库提供的局部敏感哈希,如果是本地做其实用这个就行了,但我们这里要在spark上面mapreduce,所以我就借用一下这个类,帮忙算一下给定阈值下的band数量LSH.b。我们这里分桶的时候除了计算每个band的hash值,还给前面拼了一个桶号,避免不同band相同的hash值混在一起,然后只保留必要字段,降低spark通信成本:
def get_LSH_keys(hash:list,bands):
num_perm = len(hash)
rows = num_perm // bands
bulks = []
for i in range(bands):
start = i * rows
end = (i+1) * rows
features = hash[start:end]
bulks.append(f'bulk:{i}_'+md5(','.join([str(f) for f in features])))
return bulks
def scatter_hash_bulk(line):
minhash = line['minhash']
LSH_keys = get_LSH_keys(minhash,LSH.b)
doc_info = {}
doc_info['hashindex'] = line['hashindex']
doc_info['minhash'] = line['minhash']
doc_info['qs'] = line.get('qs',-1)
output = [[lsh,doc_info] for lsh in LSH_keys]
return output
接下来就是去重了,首先定义jaccard相似度的计算函数,然后依次按照行号、质量分排序,然后两两比较相似度,大于阈值的认为相似,把相似文档的信息yeild出去,没找到相似的就不yeild了,再见少点点spark的通信存储量:
def deduplicate(group):
def jaccard_similarity(signature1, signature2):
assert len(signature1) == len(signature2), "Signatures must be of the same length"
count = sum([1 for i in range(len(signature1)) if signature1[i] == signature2[i]])
return count / len(signature1)
key,content = group
docs = sorted([{**doc,'sim_doc':None} for doc in content],key=lambda x:x['hashindex'])
docs.sort(key=lambda x:x['qs'],reverse=True)
retained_docs = []
for i,doc in enumerate(docs):
if i == 0:
retained_docs.append(doc)
else:
for j,retained_doc in enumerate(retained_docs):
sim_score = jaccard_similarity(retained_doc['minhash'],doc['minhash'])
if sim_score >= hash_threshold:
doc['sim_doc'] = retained_doc['hashindex']
break
if doc['sim_doc'] is not None:
retained_docs.append(doc)
else:
output = {'hashindex':doc['hashindex'],'sim_doc':doc['sim_doc']}
yield output
merge不同band的去重结果:
def merge_deduplicate(group):
key,content = group
merged_res = None
for res in content:
if merged_res is None:
merged_res = res
continue
if res['sim_doc'] is not None:
merged_res['sim_doc'] = res['sim_doc']
break
return merged_res
至此我们就获得了所有需要被删除的文档id,以及删除它们是因为和谁相似,接下来就是把全量数据和去重的结果join一下,删除数据就行了。一般这里的做法是做一下左连接,原始数据在左,去重数据在右,然后再过滤,如果一条数据被拼上了去重信息,说明应该被删掉,没拼就不删。但是我们可以优化一下,原始数据量很大,去重数据却不多,说明这是一个大左表join小右表,join效率很低。我们不如把小右表读取成一个字典,然后让左表直接去判断是否在这个字典里。
deduplicate_dict= sc.broadcast(sc.textFile(dd_temp_path).map(load_json).map(lambda x:(x['hashindex'],x['sim_docs'])).collectAsMap())
output_rdd = minhash_rdd.filter(lambda x:x['hashindex'] not in deduplicate_dict)
output_rdd.map(dump_json).saveAsTextFile(output_path)
到这里spark+minhash去重的代码也就写完了,最后还想再回过来谈谈数据去重的作用吧。数据去重是肯定不能不做的,不做去重有的文档可能被训个几百遍了,一些文档才训了一遍。但是去重也不能一个劲的猛去,因为有的数据一遍真的记不住,需要多训几遍才行。简单一点的策略,就是像我们这片文章这样,在一批数据上选个阈值去重一下,看看留下多少够不够训练用,删了什么被删的有多相似。但是还有很多精细的工作可以做。比如前司负责数据工作的谷歌大佬就做过一些什么知识训练多少遍能被记住的实验。不同知识被记住需要的训练量是不一样的,如果我们能按照被记住需要的训练量来去重,无疑能够训练的更高效。但这里还有几个比较难的点,比如什么样的算是一条知识,”中国的首都是北京“,”北京是中国的首都“,”北京之于中国就像华盛顿之于美国“,这里描述的都是相同的知识,但是表述却不一样。再就是以什么样的间隔训练,这就涉及模型的记忆和遗忘曲线,或许可以参考下人的艾宾浩斯遗忘曲线?这方面其实还有个有趣的话题,这位谷歌大佬提出过把分形曲线应用于数据组织,也做了一些验证试验。
写在最后
后续计划从数据下载开始,分享下质量筛选、采样拼接、数据重排、tokenizer训练,基于megatron的3d并行配置,megatron改造,数据配比实验、scaling law计算、模型loss估算、指标估算等等的实践,有机会可能会分享下self regularization或者分型曲线等奇奇怪怪的LLM实验。开头也说了我写文章喜欢写代码,所以这个系列的文章也争取把论文已公开的、普遍使用的方法附上代码实现,但是不那么普遍使用的可能只会讲讲动机和思路,望理解。
PS:看到这里,如果觉得不错,可以来个点赞、在看、关注。给公众号添加【星标⭐️】不迷路!您的支持是我坚持的最大动力!
欢迎多多关注公众号「NLP工作站」,加入交流群,交个朋友吧,一起学习,一起进步!