NeurIPS 2024论文解析:基于SPU实现的两方密态推理框架深度解读

企业   2024-12-06 18:40   浙江  

 
导语:NeurIPS( Conference on Neural Information Processing Systems)是人工智能领域的顶级会议之一,每年吸引全球众多学者和研究人员参与。NeurIPS 2024 将于当地时间 2024 年 12 月 9 日至 15 日在加拿大温哥华举行。蚂蚁密算科技隐语团队与上海交通大学合作论文《Nimbus: Secure and Efficient Two-Party Inference for Transformers》在本次投稿的 15,671 篇有效论文投稿且最终录取率为 25.8% 中脱颖而出,顺利中选。

🔍 点击文末“阅读原文”,即可获取论文原文链接
本文作者:李正一
上海交通大学计算机科学博士研究生、隐语社区Contributor、本论文一作
Transformer 神经网络在各领域上展现出了惊人的效果,也是最热门大模型的结构基础,在众多任务上有潜在的实际应用。为了解决伴随而来的隐私问题,本论文提出了基于  Secretflow-SPU 实现的两方隐私推理框架 Nimbus,以实现 Transformer 神经网络的隐私保护推理。为 Transformer 神经网络中线性层的矩阵乘法及非线性层的激活函数提出了针对性的加速,以实现在保护模型和用户数据隐私的前提下的高效推理,为大模型的隐私推理场景提供了重要技术支持。本文将带来本篇论文的深度技术解读,一起来关注!
一、背景

1. 隐私推理

在本篇论文中,我们主要围绕最常见的 Machine Learning as a Service(MLaaS)推理场景展开研究,其中模型所有者(server)提供私有神经网络模型,用户(client)提供推理任务的输入数据。现有工作主要考虑半诚实敌手模型,各个计算方会遵循协议执行计算,但可能试图分析他们接收的消息来窃取敏感信息。在该敌手模型下,隐私推理确保了模型拥有者对用户的输入一无所知,而用户只能收到推理最终的结果。

为了实现这一隐私保护目标,隐私推理融合了多种密码学原语。当前的两方隐私推理方案大多采用同态加密与多方安全计算相结合的混合协议。在此过程中,神经网络中每个算子的输入与输出均被转化为秘密共享的形式,而具体的计算过程则根据各个算子的性质选择最为合适的密码学工具。例如,在处理神经网络中的线性层时,通常会选用同态加密技术;而对于非线性函数的计算,则还需要借助于基于多方安全计算的方法,如利用 oblivious transfer 来执行比较操作。
2. Transformer 神经网络

Transformer 神经网络在各领域上展现出了惊人的效果,也是热门大模型的结构基础,在众多任务上有潜在的实际应用。然而,Transformer 网络中包含的大量矩阵乘法以及复杂的非线性函数为隐私推理带来了,巨大的性能挑战[1,2,5,6,7,8]

3. 记号
本文使用大写字母来表示矩阵,比如用表示参数矩阵,表示激活值矩阵。表示的第i行,表示矩阵的第i行第j列元素。秘密共享使用符号表示,比如表示由用户(client)持有的秘密共享。同态密文使用符号表示,注意一个经过同态加密的矩阵可能包含多个密文。表示环,环上的元素均为模的整数。我们使用表示一个多项式环,其中N为2的幂次方。多项式环上的多项式用表示,表示多项式第j项的系数。

二、核心方法

为了解决伴随而来的性能挑战,本论文提出了新的两方隐私推理框架  Nimbus,为 Transformer 神经网络中线性层的矩阵乘法及非线性层的激活函数提出了针对性的加速。下面我们将分别简单介绍这两部分具体的技术。

1. 线性层--基于外积的用户端矩阵乘法协议

(一)现有工作的线性层协议

现有的工作都采用 server 端执行矩阵乘法[1,3,4,5,6,8],以最新的工作BumbleBee[6]为例,我们称为 server-side inner product (SIP),如下图所示:

传统线性层 server-side inner product 协议
参与计算的双方持有激活值的秘密共享 。服务提供商还持有参数的明文 。参数和激活值采用下列公式编码为环上的多项式,其他没有赋值的多项式系数设置为零。

由上述编码方式得到的多项式相乘后,结果 多项式中的部分参数即为矩阵乘法 的结果。该过程在的多项式环上的一个例子由下图所示,其中结果多项式的奇数次项系数对应结果矩阵的值。

通过多项式乘法模拟矩阵乘法内积运算

在 Transformer 神经网络的矩阵乘法中,往往有 ,编码时系数所对应的次数会大于。在这种情况下,激活值和参数矩阵会被分割成更小的窗口 。相乘后的结果为结果矩阵中 的窗口的部分和。

在现有的协议和编码方式中,输入通信和输出通信的数量为 。为了减少通信密文的数量,Iron[5] 将窗口大小的选择规划为一个优化问题,传输需要至少 个密文。后续的工作 BumbleBee[6] 提出了一种输出密文的压缩方法,通过额外的计算来减少通信量,但总体延迟相似。

此外,现有的编码方式下,参数和激活值分别被编码(并加密)为明文多项式和密文多项式。而明文多项式和密文多项式之间的乘法复杂度为通过额外的 NTT/INTT 操作后才可以降为

(二)Nimbus 的线性层协议

我们注意到在 的约束下,窗口大小的选择需要兼顾输入和输出密文的数量。Nimbus 的方法包括两个方面:

  1. 通过消除输入通信,解除输入通信对求解窗口大小的影响;
  2. 在此基础上设计高效编码方式进一步提升计算输出密文的通信效率。

Nimbus 重新设计了矩阵乘法的协议消除了输入通信,具体的协议流程如下图所示。Nimbus 能够消除输入通信来自于参数静态性的观察。在推理的过程中,参数是预先确定的,所以可以加密后提前存放在用户本地。在推理时,用户可以直接在本地用激活值的秘密共享和参数的密文做乘法,由此消除了输入通信。

Nimbus的线性层client-side outer product协议

在新的协议下,Nimbus 对窗口大小的选择不再受制约于输入通信密文的数量。比如,Nimbus 可以选择将激活值矩阵的窗口大小选择为 。在这种情况下,Nimbus 将参数按照行编码进明文多项式并加密为密文多项式。于是,不同于之前的工作需要使用多项式乘法模拟内积,Nimbus 可以通过外积的方式实现矩阵乘法。我们使用一个例子来展示这个过程,如下左图为矩阵乘法的期望功能。右图为得到结果矩阵Z的第一行的计算方式。白色空格表示参数行不足以用满环上多项式的系数。可以看到,明文标量与密文多项式相乘得到了结果矩阵第一行结果 的部分和,多个部分和经过累加后得到输出矩阵的第一行。

这种计算方式具有两个明显的优势。首先,明文标量和密文多项式相乘的复杂度为线性,低于之前工作明文多项式-密文多项式乘法的。此外,虽然 Nimbus 的输出密文和之前工作一样有未利用的系数,但是利用外积乘法得到的结果多项式中有效数字是连续排布的,这让我们可以使用一个“免费”的右移操作将多个密文多项式合并压缩。如下图所示,我们将结果矩阵的第二行右移后与结果矩阵的第一行压缩为一个利用率100%密文。将输出密文的效率做到了最高。

通过密文多项式右移操作实现输出密文压缩
📢虽然 Nimbus 的线性层协议可以大幅减少通信和总计算量,但是让 Client 完成同态乘法听起来会给 Client 带来很大的计算开销。然而,在传统的同态线性层计算中,乘法本身并不是最慢的,相当一部分时间花在 Client 给激活值加密和解密的过程。这是因为由于复杂度的 NTT 操作发生在用户加密和解密的过程中,而乘法本身只需要。所以,当 Nimbus 协议不再需要 Client 处理加解密后,Client 的计算开销几乎维持了不变,甚至变得更少。

2. 非线性层--分布感知的高效非线性函数近似

对于 Transformer 模型来说,非线性层中的主要效率瓶颈在于安全计算 exponential 和 GELU 函数。一种主流的计算方式是通过分段多项式来近似非线性函数[5,6,7,8],而分段多项式逼近可以通过执行双方加法、乘法和比较操作来安全计算。为了保持精度,现有工作采用3段多项式(次数为6)来逼近 GELU 函数,采用两段泰勒级数(泰勒展开次数为6)来逼近 exponential 函数。高次多项式的计算以及处理分段的比较操作会给安全计算带来很大的开销。

(一)分布感知的非线性函数近似

现有工作生成分段多项式的策略是最小化近似多项式和原函数的误差,这相当于将非线性函数的分布视为了均匀分布。而 Nimbus 引入了一个不同的观察,Transformer 网络的激活值分布具有明显的规律性。比如,在下图的 GELU 函数和 exponential 函数的输入中,exponential 函数有80% 的输入值落在 [-5,0]之间,而 GELU 函数有90% 的输入值小于0。这些信息应该被结合在为非线性函数分段,以及拟合每个分段时。

GELU 函数和Exponential 函数输入值分布
比如,在分段时,虽然 GELU 函数在[0,1]之间具有较明显的非线性。但是因为这部分输入值几乎很少出现,所以可以采用简单的线性函数拟合。而在拟合每个分段时,Nimbus 将输入值的概率分布集成到误差函数中,以拟合更真实的误差期望。并且 Nimbus 发现,输入值的概率分布只需要一个约512个 token 的子数据集即可获得较为稳定的估计。此外,先前工作为不同深度的非线性函数使用共同的区间断点和多项式系数。而在 Nimbus 的策略下,不同深度的激活值分布略有不同,更合理的策略是按照深度采用独立的系数。与先前研究中假设均匀输入分布并直接最小化原始函数的近似误差相比,我们的策略能够生成分段数更少并且次数更低的近似多项式。

(二)升环-截断操作融合协议

此外,使用低次多项式也减少了定点数计算过程中误差的积累和数据溢出的可能性,允许我们的计算使用更小的环和精度。比如将64比特的大环和18比特的定点数精度降低为32比特的小环和12比特的定点数精度,可以进一步带来约2倍左右的性能提升。然而,由于 Transformer 网络的其他算子仍旧需要采用高精度和大环,所以需要额外执行大小环的切换。将元素从大环切换到小环双方可以在本地独立完成,而小环切换到大环需要双方经过多轮的通信来处理秘密共享的 wrap 问题[10]。Nimbus 注意到每次的升环操作都会跟在一个截断操作之后。而截断操作本身也需要处理 warp 问题,所以,Nimbus 提出了一个新的协议,将升环操作与截断操作融合,从而复用了截断操作的 wrap 结果,实现了免费的升环操作。

三、实验结果

论文中的实验考虑了两种网络环境:LAN (3Gbps,RTT=1ms), WAN (400Mbps, RTT=10ms)。比较的 Baseline 包括Iron[5]和 BumbleBee[6]。主要的实验包括性能测试以及模型精度测试。

1. 性能实验

LAN(上图)和WAN(下图)下Iron,BumbleBee和Nimbus的性能对比
本实验展示 BERT-base 模型在输入长度为128的情况下,采用 Nimbus 框架相较于 BumbleBee 框架所实现的显著加速效果。在 LAN 下,Nimbus 展现了相对于 BumbleBee 约5倍的整体性能提升;其中,线性层处理速度达到了显著的10倍优化,而非线性层实现了接近4倍的速度优化。在 WAN 下,整体加速比约为3倍。其中,线性层加速约4倍、非线性层约3倍。此外,文章中还在更大规模模型及不同输入序列长度条件下,展示了 Nimbus 的一致性加速。

2. 模型精度实验

为了验证 Nimbus 精简多项式近似后对模型精度的影响,我们选取BERT-base作为实验模型,测试了在GLUE benchmark中8个任务的表现。实验结果证明在即使不做微调的情况下,Nimbus的多项式近似仅造成了约0.57%的平均精度损失。而经过微调后,Nimbus只有0.07%的精度损失,几乎对精度没有影响。



四、结论

本论文提出了一种用于 Transformer 的隐私保护的高效的两方推理框架 Nimbus。我们提出了一种基于外积的用户端的高效安全矩阵乘法协议,为线性层实现了更高的计算和通信效率。对于非线性层,我们采用了分布感知的多项式近似方法,从而可以使用更简单的近似,并减少通信量和交互轮数。这些优化显著提升了性能,向 Transformer 的隐私推理的实际应用迈出了重要一步。

本论文的实现基于SecretFlow-SPU[9],对于密态推理的支持。本文所探讨的线性层协议,基于 SPU 框架内 BumbleBee 模块的实现,通过修改 SPU 后端的 C++ 算子完成了同态协议的构建。对于非线性层而言,其性能提升的关键在于对近似方法的优化;因此,我们利用 SPU 提供的 Python 前端接口,实现了更为简洁高效的多项式近似。目前本文的相关代码发布在 SPU 的 GitHub 仓库中开源一个 PoC 分支,欢迎查看。

  • 代码开源:https://github.com/secretflow/spu/tree/nimbus

深度视频解读
本论文一作、本文作者李正一的详细技术解读,欢迎查看👇
  • 直播现场互动问答:https://www.yuque.com/secret-flow/admin/fnofay8tm23imerz


Reference
[1] Tianyu Chen, Hangbo Bao, Shaohan Huang, Li Dong, Binxing Jiao, Daxin Jiang, Haoyi Zhou, Jianxin Li, and Furu Wei. The-x: Privacy preserving transformer inference with homomorphic encryption. arXiv preprint arXiv:2206.00216, 2022.
[2] Li, D., Wang, H., Shao, R., Guo, H., Xing, E., and Zhang, H. MPCFORMER: FAST, PERFORMANT AND PRIVATE TRANSFORMER INFERENCE WITH MPC.
[3] Chiraag Juvekar, Vinod Vaikuntanathan, and Anantha Chandrakasan. GAZELLE: A low latency framework for secure neural network inference. In 27th USENIX Security Symposium (USENIX Security 18), pages 1651–1669, 2018.
[4] Zhicong Huang, Wenjie Lu, Cheng Hong, and Jiansheng Ding. Cheetah: Lean and fast secure two-party deep neural network inference. In 31st USENIX Security Symposium (USENIX Security 22), pages 809–826, 2022.
[5] Meng Hao, Hongwei Li, Hanxiao Chen, Pengzhi Xing, Guowen Xu, and Tianwei Zhang. Iron: Private inference on transformers. Advances in Neural Information Processing Systems, 35:15718–15731, 2022.
[6] Wenjie Lu, Zhicong Huang, Zhen Gu, Jingyu Li, Jian Liu, Kui Ren, Cheng Hong, Tao Wei, and WenGuang Chen. Bumblebee: Secure two party inference framework for large transformers. Cryptology ePrint Archive, 2023.
[7] Ye Dong, Wenjie Lu, Yancheng Zheng, Haoqi Wu, Derun Zhao, Jin Tan, Zhicong Huang, Cheng Hong, Tao Wei, and Wenguang Cheng. Puma: Secure inference of llama-7b in five minutes. arXiv preprint arXiv:2307.12533, 2023.
[8] Qi Pang, Jinhao Zhu, Helen Möllering, Wenting Zheng, and Thomas Schneider. Bolt: Privacy-preserving, accurate and efficient inference for transformers. In 2024 IEEE Symposium on Security and Privacy (SP), pages 130–130. IEEE Computer Society, 2024.
[9] Junming Ma, Yancheng Zheng, Jun Feng, Derun Zhao, Haoqi Wu, Wenjing Fang, Jin Tan, Chaofan Yu, Benyu Zhang, and Lei Wang. SecretFlow-SPU: A performant and User- Friendly framework for Privacy-Preserving machine learning. In 2023 USENIX Annual Technical Conference (USENIX ATC 23), pages 17–33, 2023.
[10] Deevashwer Rathee, Mayank Rathee, Rahul Kranti Kiran Goli, Divya Gupta, Rahul Sharma, Nishanth Chandran, and Aseem Rastogi. Sirnn: A math library for secure rnn inference. In 2021 IEEE Symposium on Security and Privacy (SP), pages 1003–1020. IEEE, 2021.

蚂蚁技术AntTech
科技是蚂蚁创造未来的核心动力
 最新文章