import jax import jax.numpy as jnp from tqdm import tqdm
n, m, T = 1024, 1024, 5 key, data = jax.random.key(42), jnp.array([]) for _ in tqdm(range(1000), ncols=0, desc='SVD'): key, subkey = jax.random.split(key) M = jax.random.normal(subkey, shape=(n, m)) S = jnp.linalg.svd(M, full_matrices=False)[1] data = jnp.concatenate([data, S / (S**2).sum()**0.5])
@jax.jit def f(w, x): k, x1, x2 = w for _ in range(T): x = x + k * x * (x**2 - x1**2) * (x**2 - x2**2) return ((x - 1)**2).mean()
f_grad = jax.grad(f) w, u = jnp.array([1, 0.9, 1.1]), jnp.zeros(3) for _ in tqdm(range(100000), ncols=0, desc='SGD'): u = 0.9 * u + f_grad(w, data) # 动量加速 w = w - 0.01 * u
k, x1, x2 = w a, b, c = 1 + k * x1**2 * x2**2, -k * (x1**2 + x2**2), k print(f'{n} & {m} & {T} & {k:.3f} & {x1:.3f} & {x2:.3f} & {a:.3f} & {b:.3f} & {c:.3f} & {f(w, data):.5f}')
一些思考
如果按照默认选择 T=5,那么对于一个 的矩阵参数,Muon 的每一步更新至少需要算 15 次 与 的矩阵乘法,这计算量毋庸置疑是比 Adam 明显大的,由此可能有读者担心 Muon 实践上是否可行。事实上,这种担心是多余的,Muon 计算虽然比 Adam 复杂,但每一步增加的时间不多,笔者的结论是 5% 内,Muon 作者则声称能做到 2%。这是因为 Muon 的矩阵乘法发生在当前梯度计算完后、下一梯度计算前,这期间几乎所有的算力都是空闲的,而这些矩阵乘法是静态大小且可以并行,因此不会明显增加时间成本,反而是 Muon 比 Adam 少一组缓存变量,显存成本更低。Muon 最值得深思的地方,其实是向量与矩阵的内在区别,以及它对优化的影响。SGD、Adam、Tiger 等常见优化器的更新规则是 Element-wise 的,即不论向量、矩阵参数,实际都视为一个大向量,分量按照相同的规则独立地更新。具备这个特性的优化器往往理论分析起来更加简化,也方便张量并行,因为一个大矩阵切成两个小矩阵独立处理,并不改变优化轨迹。但 Muon 不一样,它以矩阵为基本单位,考虑了矩阵的一些独有特性。可能有些读者会奇怪:矩阵和向量不都只是一堆数字的排列吗,能有什么区别?举个例子,矩阵我们有“迹(trace)”这个概念,它是对角线元素之和,这个概念不是瞎定义的,它有一个重要特性是在相似变换下保持不变,它还等于矩阵的所有特征值之和。从这个例子就可以看出,矩阵的对角线元素跟非对角线元素,地位其实是不完全对等的。而 Muon 正是因为考虑了这种不对等性,才有着更好的效果。当然,这也会导致一些负面影响。如果一个矩阵被划分到不同设备上,那么用 Muon 时就需要将它们的梯度就需要汇聚起来再计算更新量了,而不能每个设备独立更新,这增加了通信成本。即便我们不考虑并行方面,这个问题也存在,比如 Multi-Head Attention 一般是通过单个大矩阵投影到 Q(K,V 同理),然后用 reshape 的方式得到多个 Head,这样在模型参数中就只有单个矩阵,但它本质上是多个小矩阵,所以按道理我们需要将大矩阵拆开成多个小矩阵独立更新。总之,Muon 这种非 Element-wise 的更新规则,在捕捉向量与矩阵的本质差异的同时,也会引入一些小问题,这可能会不满足一些读者的审美。(补充:几乎在本文发布的同时,Muon 的作者 Keller Jordan 也发布了自己的一篇博客《Muon: An optimizer for hidden layers in neural networks》[5]。)
范数视角
从理论上看,Muon 捕捉了矩阵的什么关键特性呢?也许接下来的范数视角可以回答我们的问题。这一节的讨论主要参考了论文《Stochastic Spectral Descent for Discrete Graphical Models》[6] 和《Old Optimizer, New Norm: An Anthology》[7],特别是后一篇。不过其中的出发点并不是新的,我们在《梯度流:探索通向最小值之路》就已经简单涉猎过:对于向量参数 ,我们将下一步的更新规则定义为其中 是某个向量范数,这称为在某个范数约束下的“最速梯度下降”。接着假设 足够小,那么第一项占主导,这意味着 与 会很接近,于是我们假设 的一阶近似够用了,于是问题简化成记 ,那么可以简写成计算 的一般思路是求导,但《Old Optimizer, New Norm: An Anthology》[7] 提供了一个不用求导的统一方案:将 分解为范数 和方向向量 ,于是 只是一个标量,跟学习率类似,容易求得最优值是 ,而更新方向则是最大化 。现在代入欧氏范数即,我们就有和,这样一来 ,即梯度下降(SGD)。一般地,对于 p 范数Hölder 不等式 [8] 给出 ,其中 ,利用它我们得到等号成立的条件是以它为方向向量的优化器叫做 pbSGD,可参考《pbSGD: Powered Stochastic Gradient Descent Methods for Accelerated Non-Convex Optimization》[9]。特别地,当 时有 和 ,此时退化为 SignSGD,即 SignSGD 实际上是 范数下的最速梯度下降。