Muon优化器赏析:向量与矩阵有何本质区别?

科技   2024-12-11 22:36   北京  
©PaperWeekly 原创 · 作者 | 苏剑林
单位 | 科学空间
研究方向 | NLP、神经网络


随着 LLM 时代的到来,学术界对于优化器的研究热情似乎有所减退。这主要是因为目前主流的 AdamW 已经能够满足大多数需求,而如果对优化器“大动干戈”,那么需要巨大的验证成本。因此,当前优化器的变化,多数都只是工业界根据自己的训练经验来对 AdamW 打的一些小补丁。

不过,最近推特上一个名为“Muon” [1] 的优化器颇为热闹,它声称比 AdamW 更为高效,且并不只是在 Adam 基础上的“小打小闹”,而是体现了关于向量与矩阵差异的一些值得深思的原理。本文让我们一起赏析一番。


▲ Muon与AdamW效果对比(来源:推特@Yuchenj_UW)


算法初探
Muon 全称是“MomentUm Orthogonalized by Newton-schulz”,它适用于矩阵参数 ,其更新规则是

这里  是矩阵符号函数 [2],它并不是简单地对矩阵每个分量取 操作,而是 函数的矩阵化推广,它跟 SVD 的关系是:

其中 ,r 是 的秩。更多的理论细节我们稍后再展开,这里我们先来尝试直观感知如下事实:
Muon 是一个类似于 Adam 的自适应学习率优化器。
像 Adagrad、RMSprop、Adam 等自适应学习率优化器的特点是通过除以梯度平方的滑动平均的平方根来调整每个参数的更新量,这达到了两个效果:1)损失函数的常数缩放不影响优化轨迹;2)每个参数分量的更新幅度尽可能一致。Muon 正好满足这两个特性:
1. 损失函数乘以 也会乘以 ,结果是 被乘以 ,但 Muon 最后的更新量是将 变为单位阵,所以不影响优化结果;
2. 当 被 SVD 为 时, 的不同奇异值体现了 的“各向异性”,而将它们都置一则更加各向同性,也起到同步更新幅度的作用。
(注:Muon 还有个 Nesterov 版,它只是将更新规则中的 换成 ,其余部份完全一致,简单起见就不在正文中展开介绍了。)


符号函数

利用 SVD,我们还可以证明恒等式

其中 是矩阵的 1/2 次幂的逆矩阵,如果不可逆的话则取伪逆。这个恒等式能让我们更好理解为什么 的矩阵推广:对于标量 x 我们有 ,正是上式的一个特殊情形(当 矩阵时)。
这个特殊例子还可以推广到对角阵 ):

其中 是指向量/矩阵的每个分量都取 。上式意味着,当 是对角阵时,Muon 就退化为带动量的 SignSGD(Signum)或笔者所提的 Tiger,它们都是 Adam 的经典近似。
反过来说,Muon 与 Signum、Tiger 的区别就是 Element-wise 的 替换成了矩阵版 对于 n 维向量来说,我们还可以视为 的矩阵,此时 正好是 归一化。
所以,在 Muon 框架下对向量我们有两种视角:一是对角矩阵,如 LayerNorm 的 gamma 参数,结果是对动量取 ;二是 的矩阵,结果是对动量做 归一化。
此外,输入和输出的 Embedding 虽然也是矩阵,但它们使用上是稀疏的,所以更合理的方式也是将它们当成多个向量独立处理。
当 m=n=r 时, 还有一个意义是“最优正交近似”:

类似地,对于 我们可以写出(假设 没有零元素):

不论是 还是 ,我们都可以视为对更新量的一种规整化约束,所以 Muon 和 Signum、Tiger 可以视作是同一思路下的优化器,它们都以动量 为出发点来构建更新量,只是为更新量选择了不同的规整化方法。
式(5)的证明:对于正交矩阵 ,我们有

其中涉及到的运算规则我们在伪逆中已经介绍过。由于 都是正交矩阵,所以 也是正交矩阵,正交矩阵的每个分量必然不超过 1。
又因为 ,所以上式取最小值对应于每个 取最大值,即 ,这意味着 ,即
该结论还可以仔细地推广到 m,n,r 不相等的情形,但这里不作进一步展开。

迭代求解

实践中,如果每一步都对 做 SVD 来求解 的话,那么计算成本还是比较大的,因此作者提出了用 Newton-schulz 迭代来近似计算
迭代的出发点是恒等式(3),不失一般性,我们假设 ,然后考虑在 处泰勒展开 ,展开的方式是直接将标量函数 的结果用到矩阵中:

保留到二阶,结果是 ,那么我们有

假如 的某个近似,我们认为将它代入上式后,会得到 的一个更好的近似,于是我们得到一个可用的迭代格式

然而,查看 Muon 的官方代码我们就会发现,它里边的 Newton-schulz 迭代确实是这个形式,但三个系数却是 (3.4445, -4.7750, 2.0315),而且作者没有给出数学推导,只有一段语焉不详的注释:

▲ Muon优化器的Newton-schulz迭代

收敛加速

为了猜测官方迭代算法的来源,我们考虑一般的迭代过程

其中 a,b,c 是三个待求解的系数,如果想要更高阶的迭代算法,我们也可以逐次补充  等项,下面的分析过程是通用的。
我们选择的初始值 矩阵的范数,选择的依据是除以 不改变 SVD 的 ,但可以让 的所有奇异值都在 [0,1] 之间,让迭代的初始奇异值更标准一些。
现在假设 可以 SVD 为 ,那么代入上式我们可以得到

因此,式(11)实际上在迭代奇异值组成的对角阵 ,如果记 ,那么 其中
又因为对角阵的幂等于对角线元素各自取幂,所以问题简化成单个奇异值 的迭代。我们的目标是计 换言之希望通过迭代将 变为单位阵,这又可以简化为迭代 将单个奇异值变为 1。
受 @leloykun [3] 启发,我们将 a,b,c 的选择视为一个最优化问题,目标是让迭代过程对于任意初始奇异值都收敛得尽可能快。首先我们将 g(x) 重新参数化为

其中 。该参数化的好处是直观表示出了迭代的 5 个不动点 。由于我们的目标是收敛到 1,因此初始化我们选择 ,想法是不管迭代过程往 走还是往 走,结果都是 1 附近。
接下来,我们确定迭代步数 T,这样迭代过程就称为一个确定性函数,然后我们将矩阵的形状(即 n,m)确定好,就可以采样一批矩阵,并通过 SVD 来算奇异值。
最后,我们将这些奇异值当成输入,而目标输出则是 1,损失函数是平方误差,整个模型完全可导,可以用梯度下降解决(@leloykun [3] 则假设了 ,然后用网格搜索来求解)。
一些计算结果:

从表格可以看出,结果跟矩阵大小、迭代步数都有明显关系;从损失函数来看,非方阵比方阵更容易收敛;Muon 作者给出的 a,b,c,大概是迭代步数为 5 时方阵的最优解。当迭代步数给定时,结果依赖于矩阵大小,这本质上是依赖于奇异值的分布,关于这个分布有个值得一提的结果是当 时为 Marchenko–Pastur 分布 [4]

参考代码:

import jax
import jax.numpy as jnp
from tqdm import tqdm

n, m, T = 1024, 1024, 5
key, data = jax.random.key(42), jnp.array([])
for _ in tqdm(range(1000), ncols=0, desc='SVD'):
    key, subkey = jax.random.split(key)
    M = jax.random.normal(subkey, shape=(n, m))
    S = jnp.linalg.svd(M, full_matrices=False)[1]
    data = jnp.concatenate([data, S / (S**2).sum()**0.5])

@jax.jit
def f(w, x):
    k, x1, x2 = w
    for _ in range(T):
        x = x + k * x * (x**2 - x1**2) * (x**2 - x2**2)
    return ((x - 1)**2).mean()

f_grad = jax.grad(f)
w, u = jnp.array([1, 0.9, 1.1]), jnp.zeros(3)
for _ in tqdm(range(100000), ncols=0, desc='SGD'):
    u = 0.9 * u + f_grad(w, data)  # 动量加速
    w = w - 0.01 * u

k, x1, x2 = w
a, b, c = 1 + k * x1**2 * x2**2, -k * (x1**2 + x2**2), k
print(f'{n} & {m} & {T} & {k:.3f} & {x1:.3f} & {x2:.3f} & {a:.3f} & {b:.3f} & {c:.3f} & {f(w, data):.5f}')

一些思考

如果按照默认选择 T=5,那么对于一个 的矩阵参数,Muon 的每一步更新至少需要算 15 次 的矩阵乘法,这计算量毋庸置疑是比 Adam 明显大的,由此可能有读者担心 Muon 实践上是否可行。
事实上,这种担心是多余的,Muon 计算虽然比 Adam 复杂,但每一步增加的时间不多,笔者的结论是 5% 内,Muon 作者则声称能做到 2%。
这是因为 Muon 的矩阵乘法发生在当前梯度计算完后、下一梯度计算前,这期间几乎所有的算力都是空闲的,而这些矩阵乘法是静态大小且可以并行,因此不会明显增加时间成本,反而是 Muon 比 Adam 少一组缓存变量,显存成本更低。
Muon 最值得深思的地方,其实是向量与矩阵的内在区别,以及它对优化的影响。SGD、Adam、Tiger 等常见优化器的更新规则是 Element-wise 的,即不论向量、矩阵参数,实际都视为一个大向量,分量按照相同的规则独立地更新。
具备这个特性的优化器往往理论分析起来更加简化,也方便张量并行,因为一个大矩阵切成两个小矩阵独立处理,并不改变优化轨迹。
但 Muon 不一样,它以矩阵为基本单位,考虑了矩阵的一些独有特性。可能有些读者会奇怪:矩阵和向量不都只是一堆数字的排列吗,能有什么区别?
举个例子,矩阵我们有“迹(trace)”这个概念,它是对角线元素之和,这个概念不是瞎定义的,它有一个重要特性是在相似变换下保持不变,它还等于矩阵的所有特征值之和。从这个例子就可以看出,矩阵的对角线元素跟非对角线元素,地位其实是不完全对等的。而 Muon 正是因为考虑了这种不对等性,才有着更好的效果。
当然,这也会导致一些负面影响。如果一个矩阵被划分到不同设备上,那么用 Muon 时就需要将它们的梯度就需要汇聚起来再计算更新量了,而不能每个设备独立更新,这增加了通信成本。
即便我们不考虑并行方面,这个问题也存在,比如 Multi-Head Attention 一般是通过单个大矩阵投影到 Q(K,V 同理),然后用 reshape 的方式得到多个 Head,这样在模型参数中就只有单个矩阵,但它本质上是多个小矩阵,所以按道理我们需要将大矩阵拆开成多个小矩阵独立更新。
总之,Muon 这种非 Element-wise 的更新规则,在捕捉向量与矩阵的本质差异的同时,也会引入一些小问题,这可能会不满足一些读者的审美。
(补充:几乎在本文发布的同时,Muon 的作者 Keller Jordan 也发布了自己的一篇博客《Muon: An optimizer for hidden layers in neural networks》[5]。)

范数视角

从理论上看,Muon 捕捉了矩阵的什么关键特性呢?也许接下来的范数视角可以回答我们的问题。
这一节的讨论主要参考了论文《Stochastic Spectral Descent for Discrete Graphical Models》[6] 和《Old Optimizer, New Norm: An Anthology》[7],特别是后一篇。
不过其中的出发点并不是新的,我们在《梯度流:探索通向最小值之路》就已经简单涉猎过:对于向量参数 ,我们将下一步的更新规则定义为

其中 是某个向量范数,这称为在某个范数约束下的“最速梯度下降”。接着假设 足够小,那么第一项占主导,这意味着 会很接近,于是我们假设 的一阶近似够用了,于是问题简化成

,那么可以简写成

计算 的一般思路是求导,但《Old Optimizer, New Norm: An Anthology》[7] 提供了一个不用求导的统一方案:将 分解为范数 和方向向量 ,于是

只是一个标量,跟学习率类似,容易求得最优值是 ,而更新方向则是最大化
现在代入欧氏范数 我们就 这样一来 ,即梯度下降(SGD)。一般地,对于 p 范数

Hölder 不等式 [8] 给出 ,其中 ,利用它我们得到

等号成立的条件是

以它为方向向量的优化器叫做 pbSGD,可参考《pbSGD: Powered Stochastic Gradient Descent Methods for Accelerated Non-Convex Optimization》[9]
特别地,当 时有 ,此时退化为 SignSGD,即 SignSGD 实际上是 范数下的最速梯度下降。


矩阵范数

现在让我们将目光切换到矩阵参数 。类似地,我们将它的更新规则定义为

此时 是某种矩阵范数。同样使用一阶近似,我们得到

这里 。还是使用“范数-方向”解耦,即 我们得到

然后就是具体范数具体分析了。矩阵常用的范数有两种,一种是 F 范数,它实际上就是将矩阵展平成向量后算的欧氏范数,这种情况下结论跟向量是一样的,答案就是 SGD,这里不再展开;另一种则是由向量范数诱导出来的 2 范数,也称谱范数:

注意右端出现的 的对象都是向量,所以定义是明确的。更多关于 2 范数的讨论可以参考《深度学习中的Lipschitz约束:泛化与生成模型》《低秩近似之路(二):SVD》
由于 2 范数是由“矩阵-向量”乘法诱导出来的,因此它更贴合矩阵乘法,并且还恒成立 ,即 2 范数相比 F 范数更紧凑。
所以,接下来我们就针对 2 范数进行计算。设 的 SVD 我们有

根据定义, 于是 ,因此

等号在所有 都等于 1 时取到,此时

至此,我们证明了 2 范数惩罚下的梯度下降正是 时的 Muon 优化器!当 时,滑动平均生效,我们可以将它视为梯度的一种更精准的估计,所以改为对动量取
总的来说,Muon 相当于 2 范数约束下的梯度下降,2 范数更好地度量了矩阵之间的本质差异,从而使每一步都走得更精准、更本质。


追根溯源

Muon 还有一个更久远的相关工作《Shampoo: Preconditioned Stochastic Tensor Optimization》[10],这是 2018 年的论文,提出了名为 Shampoo 的优化器,跟 Muon 有异曲同工之处。
Adam 通过梯度平方的平均来自适应学习率的策略,最早提出自 Adagrad 的论文《Adaptive Subgradient Methods for Online Learning and Stochastic Optimization》[11],里边提出的是直接将梯度平方累加的策略,这相当于全局等权平均,后来的 RMSProp、Adam 则类比动量的设计,改为滑动平均,发现在实践中表现更好。
不仅如此,Adagrad 最开始提出的实际是累加外积 ,只不过缓存外积空间成本太大,所以实践中改为 Hadamard 积 。那累加外积的理论依据是什么呢?
这我们在《从Hessian近似看自适应学习率优化器》[12推导过,答案是“梯度外积的长期平均 近似了 Hessian 矩阵的平方 ,所以这实际上在近似二阶的 Newton 法。
Shampoo 传承了 Adagrad 缓存外积的思想,但考虑到成本问题,取了个折中。
跟 Muon一样,它同样是针对矩阵(以及高阶张量)进行优化,策略是缓存梯度的矩阵乘积 ,而不是外积,这样空间成本是 而不是

这里的 是笔者自己加的,Shampoo 默认了 同样是矩阵的幂运算,可以用 SVD 来完成。由于 Shampoo 没有提出 Newton-schulz 迭代之类的近似方案,是直接用 SVD 算的,所以为了节省计算成本,它并没有每一步都计算 ,而是间隔一定步数才更新它们的结果。
特别地,当 时,Shampoo 的更新向量 通过对 进行 SVD 我们可以证明

这表明 时 Shampoo 和 Muon 在理论上是等价的!因此,Shampoo 与 Muon 在更新量的设计方面有着相通之处。



文章小结

本文介绍了最近推特上颇为热闹的 Muon 优化器,它专门为矩阵参数定制,目前看来比 AdamW 更高效,并且似乎体现了一些向量化与矩阵化的本质差异,值得学习和思考一番。

参考文献

[1] https://github.com/KellerJordan/Muon

[2] https://en.wikipedia.org/wiki/Matrix_sign_function

[3] https://x.com/leloykun/status/1846165001746501899

[4] https://en.wikipedia.org/wiki/Marchenko%E2%80%93Pastur_distribution

[5] https://kellerjordan.github.io/posts/muon/

[6] https://ieeexplore.ieee.org/abstract/document/7347351

[7] https://papers.cool/arxiv/2409.20325

[8] https://en.wikipedia.org/wiki/H%C3%B6lder%27s_inequality

[9] https://www.ijcai.org/proceedings/2020/451

[10] https://papers.cool/arxiv/1802.09568

[11] https://jmlr.org/papers/v12/duchi11a.html

[12] https://kexue.fm/archives/10588


更多阅读



#投 稿 通 道#

 让你的文字被更多人看到 



如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢?答案就是:你不认识的人。


总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 


PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学术热点剖析科研心得竞赛经验讲解等。我们的目的只有一个,让知识真正流动起来。


📝 稿件基本要求:

• 文章确系个人原创作品,未曾在公开渠道发表,如为其他平台已发表或待发表的文章,请明确标注 

• 稿件建议以 markdown 格式撰写,文中配图以附件形式发送,要求图片清晰,无版权问题

• PaperWeekly 尊重原作者署名权,并将为每篇被采纳的原创首发稿件,提供业内具有竞争力稿酬,具体依据文章阅读量和文章质量阶梯制结算


📬 投稿通道:

• 投稿邮箱:hr@paperweekly.site 

• 来稿请备注即时联系方式(微信),以便我们在稿件选用的第一时间联系作者

• 您也可以直接添加小编微信(pwbot02)快速投稿,备注:姓名-投稿


△长按添加PaperWeekly小编



🔍


现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧


·
·
·

PaperWeekly
PaperWeekly是一个推荐、解读、讨论和报道人工智能前沿论文成果的学术平台,致力于让国内外优秀科研工作得到更为广泛的传播和认可。社区:http://paperweek.ly | 微博:@PaperWeekly
 最新文章