KAN 2.0: Kolmogorov-Arnold Networks Meet Science
目录
0. 摘要
1. 简介
2. MultKAN:通过乘法增强 KAN
3. 从科学到 KANs
3.1 向 KANs 中添加重要特征
3.2 向 KANs 构建模块化结构
3.3 将符号公式编译到 KANs 中
4. 从 KANs 到科学
4.1 从 KANs 中识别重要特征
4.2 从 KANs 中识别模块化结构
4.3 从 KANs 中识别符号公式
5. 应用
7. 讨论
0. 摘要
AI 与科学的主要挑战在于它们的内在不兼容性:当今的 AI 主要基于连接主义,而科学则依赖于符号主义。为弥合这两个领域之间的差距,我们提出了一个框架,旨在无缝协同 Kolmogorov-Arnold网络(KANs)与科学。该框架强调 KANs 在科学发现中的三个方面的应用:识别相关特征、揭示模块化结构以及发现符号公式。这种协同是双向的:科学到 KAN(将科学知识融入 KANs),以及 KAN 到科学(从 KANs 中提取科学见解)。我们强调了 pykan 中的一些主要新功能:
MultKAN:带有乘法节点的 KANs。
kanpiler:将符号公式编译为 KANs 的编译器。
树转换器:将 KANs(或任何神经网络)转换为树图
基于这些工具,我们展示了 KANs 在发现各种类型物理定律方面的能力,包括守恒量、拉格朗日量、对称性和本构定律(constitutive laws)。
1. 简介
最近,一种称为 Kolmogorov-Arnold 网络(KAN)的新型神经网络在科学相关任务中显示出潜力。与多层感知机(MLP)不同,MLP 在节点上具有固定的激活函数,而 KANs 在边缘上具有可学习的激活函数。由于 KANs 能够将高维函数分解为一维函数,通过符号回归(symbolically regressing)这些一维函数,可以提高可解释性。然而,他们对可解释性的定义有些狭窄,几乎完全将其等同于提取符号公式的能力。这种有限的定义限制了 KANs 的应用范围,因为在科学中,符号公式并不总是必要或可行的。例如,尽管符号方程在物理学中强大且普遍存在,但在化学和生物学中,系统通常过于复杂,无法用这种方程来表示。在这些领域,模块化结构和关键特征可能足以表征这些系统的有趣方面。另一个被忽视的方面是将知识嵌入 KANs 的逆向任务:如何以物理学引导学习的方式,将先验知识融入 KANs?
具体而言,科学解释可能具有不同的层次,从最粗略 / 最简单 / 关联性的解释到最精细 / 最困难 / 因果性的解释:
重要特征:例如,“y 完全由 x1 和 x2 决定,而其他因素不重要。” 换句话说,存在一个函数 f 使得 y = f(x1, x2)。
模块化结构:例如,“x1 和 x2 以加法方式独立地对 y 作出贡献。” 这意味着存在函数 g 和 h 使得 y = g(x1) + h(x2)。
符号公式:例如,“y 依赖于 x1 作为正弦函数,并依赖于 x2 作为指数函数。”换句话说,y = sin(x1) + exp(x2)。
本文报告了如何从 KANs 中嵌入和提取这些特性。论文的结构如下(如图 1 所示):
在第 2 节中,我们通过引入乘法节点扩展了原始的 KAN,提出了一种新模型称为 MultKAN。
在第 3 节中,我们探讨了将科学归纳偏置嵌入 KANs 的方法,重点讨论了重要特征(第 3.1 节)、模块化结构(第 3.2 节)和符号公式(第 3.3 节)。
在第 4 节中,我们提出了从 KANs 中提取科学知识的方法,同样涵盖了重要特征(第 4.1 节)、模块化结构(第 4.2 节)和符号公式(第 4.3 节)。
在第 5 节中,我们利用前几节开发的工具将 KANs 应用于各种科学发现任务。这些任务包括发现守恒量、对称性、拉格朗日量和本构定律。
2. MultKAN:通过乘法增强 KAN
(2024,KAN,MLP,可训练激活函数,样条函数,分层函数)Kolmogorov–Arnold 网络
乘法 Kolmogorov-Arnold 网络(MultKAN)。为了明确引入乘法操作,我们提出了 MultKAN,它可以更清晰地揭示数据中的乘法结构。MultKAN(如图 2 右上所示)类似于 KAN,两者都具有标准的 KAN 层。我们将 KAN 层的输入节点称为节点,将 KAN 层的输出节点称为子节点。KAN 与 MultKAN 之间的区别在于从当前层的子节点到下一层节点的转换。在 KAN 中,节点直接从上一层的子节点复制。而在 MultKAN 中,一些节点(加法节点)从相应的子节点复制,而其他节点(乘法节点)则对来自上一层的 k 个子节点执行乘法。为了简单起见,我们设定 k=2 且小于 3。
根据 MultKAN 图(图 2 右上),可以直观地理解 MultKAN 是一个普通的 KAN,插入了可选的乘法操作。为了在数学上更精确,我们定义以下符号:第 l 层的加法(乘法)操作数量分别表示为 n^a_l (n^m_l)。这些被收集到数组中:加法宽度 n^a ≡ [n^a_0, n^a_1, ···, n^a_L] 和乘法宽度 n^m ≡ [n^m_0, n^m_1, ···, n^m_L]。当 n^m_0 = n^m_1 = ··· = n^m_L = 0 时,MultKAN 简化为一个KAN。例如,图 2(右上)展示了一个 n^a = [2, 2, 1] 和 n^m = [0, 2, 0] 的 MultKAN。
一个 MultKAN 层由标准的 KANLayer Φ_l 和一个乘法层 M_l 组成。Φ_l 接收一个输入向量
并输出
乘法层由两个部分组成:乘法部分对子节点对执行乘法,而另一部分执行恒等变换。用 Python 表示,M_l 将 z_l 转换如下:
其中 ⊙ 表示元素逐个相乘。MultKANLayer 可以简洁地表示为 Ψ_l ≡ M_l ◦ Φ_l。整个 MultKAN 表示为:
由于乘法层中没有可训练的参数,所有适用于 KAN 的稀疏正则化技术(如 ℓ1 和熵正则化)都可以直接应用于 MultKAN。对于乘法任务f(x, y) = xy,MultKAN 确实学会使用一个乘法节点,使其执行简单的乘法,因为所有学习到的激活函数都是线性的(图 2 右下)。
尽管 KAN 以前被视为 MultKANs 的一个特例,但我们扩展了定义,并将 “KAN” 和 “MultKAN” 视为同义词。默认情况下,当我们提到 KAN 时,乘法是允许的。如果我们具体提到一个没有乘法的KAN,我们将明确说明。
3. 从科学到 KANs
在科学中,领域知识至关重要,即使在数据少或没有数据的情况下,也能有效工作。因此,将物理学引导的方法应用于 KAN 是有益的:我们应当将现有的归纳偏置融入 KAN,同时保持其从数据中发现新物理的灵活性。
我们探讨了三种可以集成到 KAN 中的归纳偏置。从最粗略/最简单/关联性到最精细/最困难/因果性,它们分别是重要特征(第3.1节)、模块化结构(第3.2节)和符号公式(第3.3节)。
3.1 向 KANs 中添加重要特征
在回归问题中,目标是找到一个函数 f,使得 y = f(x1, x2, ···, xn)。假设我们想引入一个辅助输入变量 a = a(x1, x2, ···, xn),将函数转换为 y = f(x1, ···, xn, a)。虽然辅助变量 a 并未增加新信息,但它可以增强神经网络的表达能力。这是因为网络不需要耗费资源来计算辅助变量。此外,计算过程可能会变得更简单,从而提高可解释性。
3.2 向 KANs 构建模块化结构
模块化在自然界中普遍存在:例如,人类大脑皮层被划分为几个功能上独立的模块,这些模块分别负责诸如感知或决策等特定任务。这种模块化简化了对神经网络的理解,因为它使我们能够集体解释一组神经元,而不是单独分析每个神经元。结构上的模块化特征表现为连接簇,其中簇内连接远强于簇间连接。
3.3 将符号公式编译到 KANs 中
科学家常常通过符号方程(symbolic equations)来表达复杂现象,这令人满意。然而,虽然这些方程简洁,但由于其特定的功能形式,可能缺乏捕捉所有细微差别的表达能力。相比之下,神经网络具有高度的表达能力,但可能在学习已经为科学家所知的领域知识时,浪费训练时间和数据。为了利用这两种方法的优势,我们提出了一个两步程序:(1)将符号方程编译为 KANs;(2)使用数据对这些 KANs 进行微调。第一步的目的是将已知的领域知识嵌入 KANs,而第二步则专注于从数据中学习新的“物理”。
通过宽度/深度扩展来增加表达能力。kanpiler 生成的 KAN 网络是紧凑的,没有冗余边,这可能限制其表达能力并阻碍进一步微调。为了解决这一问题,我们提出了expand_width 和 expand_depth 方法来使网络变得更宽更深,如图 5(c) 所示。扩展方法最初添加零激活函数,这些函数在训练过程中因梯度为零而受限。因此,应使用 perturb 方法将这些零函数扰动为非零值,使其具有非零梯度并可训练。
4. 从 KANs 到科学
如今的黑箱深度神经网络虽然功能强大,但解释这些模型仍然是一个挑战。科学家不仅追求高性能的模型,还希望能够从模型中提取有意义的知识。在本节中,我们将专注于增强 KANs 在科学用途上的可解释性。我们将探讨从 KANs 中提取知识的三个层次,从最基础到最复杂的:重要特征(第 4.1 节)、模块化结构(第 4.2 节)和符号公式(第 4.3 节)。
4.1 从 KANs 中识别重要特征
识别重要变量对于许多任务至关重要。给定一个回归模型 f,其中 y ≈ f(x1, x2, ... , xn),我们的目标是为输入变量分配分数,以衡量它们的重要性。Liu 等人 [57] 使用了 L1 范数来表示边的重要性,但这种度量可能存在问题,因为它只考虑了局部信息。
为了解决这个问题,我们引入了一种更有效的归因分数(attribution score),它比 L1 范数更好地反映了变量的重要性。为简单起见,假设存在乘法节点,因此我们不需要区分节点和子节点。假设我们有一个宽度为 [n0,n1,⋅⋅⋅,nL] 的 L 层 KAN。我们将 E_{l,i,j} 定义为在 (l,i,j) 边上的激活值的标准差,N_{l,i} 定义为在 (l,i) 节点上的激活值的标准差。然后,我们定义节点(归因)分数 A_{l,i} 和边(归因)分数 B_{l,i,j}。
在 [57] 中,我们简单地定义了 B_{l,i,j} = E_{l,i,j} 和 A_{l,i} = N_{l,i}。然而,这种定义没有考虑到网络后面的部分;即使一个节点或一条边本身具有较大的范数,如果网络的其余部分实际上是零函数,它也可能不会对输出做出贡献。因此,我们现在从输出层到输入层逐层迭代计算节点和边的分数。我们将所有输出维度的分数设为单位分数,即 A_{L,i} = 1, i=0,1,⋅⋅⋅,n_L−1,然后按以下方式计算分数:
4.2 从 KANs 中识别模块化结构
虽然归因分数提供了关于哪些边或节点重要的有价值见解,但它并未揭示模块化结构,即重要的边和节点是如何连接的。在本部分中,我们旨在通过检查两种类型的模块化结构来揭示从训练后的 KANs 和 MLPs 中得到的模块化结构:解剖模块化和函数模块化。
解剖模块化(Anatomical modularity)指的是相互空间上接近的神经元之间的连接强度通常比相隔较远的神经元更强。尽管人工神经网络缺乏物理空间坐标,但引入物理空间的概念已被证明可以提高可解释性 [51, 52]。 我们采用了 [51, 52] 中的神经元交换方法,该方法在保持网络功能的同时,缩短了连接的长度。
函数模块化(Functional modularity)涉及神经网络所表示的整体功能。给定一个 Oracle 网络,其中内部细节如权重和隐藏层激活值无法访问(过于复杂而难以分析),我们仍然可以通过输入和输出的前向和后向传播收集关于功能模块化的信息。我们定义了三种类型的函数模块化(见图 8 (a)),主要基于 [84]。
4.3 从 KANs 中识别符号公式
符号公式是最具信息量的,因为一旦知道了这些公式,它们可以清楚地揭示重要特征和模块化结构。在 Liu 等人 [57] 中,作者展示了一些可以提取符号公式的示例,必要时结合一些先验知识。利用上述提出的新工具(特征重要性、模块化结构和符号公式),用户可以利用这些新工具轻松与 KANs 互动和协作,使符号回归变得更加容易。我们在下文中介绍了三种技巧,如图 9 所示。
5. 应用
前面的部分主要关注回归问题,以便于教学目的。在本节中,我们将 KANs 应用于发现物理概念,如守恒量、拉格朗日量、隐藏对称性和本构定律(constitutive laws)。这些示例展示了如何将本文提出的工具有效地整合到现实科学研究中,以应对这些复杂的任务。
7. 讨论
KAN 在软件 1.0 和 2.0 之间进行插值。Kolmogorov-Arnold Networks(KANs)与其他神经网络(软件 2.0,由 Andrej Karpathy 提出的术语)的关键区别在于其更高的可解释性,这使得用户可以像使用传统软件(软件 1.0)一样进行操作。然而,KANs 并不完全是传统的软件,因为它们具有(1)学习能力(优点),使得它们能够从数据中学习新知识,以及(2)减少的可解释性(缺点),随着网络规模的增加,它们变得更难以解释和控制。图 14(a)展示了软件 1.0、软件 2.0 和 KANs 在可解释性-学习能力平面上的位置,说明 KANs 如何在这两种范式之间平衡权衡。本文的目标是提出各种工具,使 KANs 更像软件 1.0,同时利用软件 2.0 的学习能力。
效率改进。原始的 pykan 包 [57] 效率较低。我们已经采用了一些技术来提高其效率。
高效的样条评估。受到 Efficient KAN [9] 的启发,我们通过避免不必要的输入扩展来优化样条评估。对于一个具有 L 层、每层 N 个神经元和网格大小 G 的 KAN,内存使用已从 O(LN²G) 减少到 O(LNG)。
仅在需要时启用符号分支。一个 KAN 层包含样条分支和符号分支。符号分支比样条分支耗时更多,因为它无法并行化(需要灾难性的双重循环)。然而,在许多应用中,符号分支是不必要的,因此我们可以在可能的情况下跳过它,显著减少运行时间,尤其是当网络较大时。
仅在需要时保存中间激活值。为了绘制 KAN 图,需要保存中间激活值。最初,激活值默认被保存,导致运行时间变慢和内存使用过多。我们现在仅在需要时保存中间激活值(例如,用于绘图或在训练中应用正则化)。用户可通过一行代码启用这些效率改进:
model.speed()
。GPU 加速。最初,所有模型都在 CPU 上运行,因为问题的规模较小。我们现在已使模型兼容 GPU。例如,用 Adam 训练一个 [4,100,100,100,1] 的模型 100 步,之前在 CPU 上需要整整一天(在实现 1、2、3 之前),现在在 CPU 上需要 20 秒,在 GPU 上则不到 1 秒。然而,KANs 在效率上仍然落后于 MLPs,尤其是在大规模时。社区一直致力于基准测试和提高 KAN 的效率,效率差距已经显著减少 [36]。
由于本文的目标是使 KANs 更像软件 1.0,在面对 1.0(具有交互性和多功能性)与 2.0(高效和具体性)之间的权衡时,我们优先考虑交互性和多功能性。例如,我们在模型中存储缓存数据(这会消耗额外的内存),因此用户可以直接调用 model.plot()
来生成 KAN 图,而无需手动进行前向传递以收集数据。
可解释性。尽管 KANs 中的可学习一元函数比 MLPs 中的权重矩阵更具可解释性,但可扩展性仍然是一个挑战。随着 KAN 模型的扩展,即使所有样条函数本身都是可解释的,管理这些 1D 函数的组合输出变得越来越困难。因此,KAN 可能仅在网络规模相对较小时才保持可解释性(图 14(b),粗红线)。需要注意的是,可解释性依赖于内在因素(与模型本身相关)和外在因素(与可解释性方法相关)。高级的可解释性方法应能够处理不同层次的可解释性。例如,通过符号回归、模块化发现和特征归因(图 14(b),细红线)来解释 KANs,可解释性与规模的 Pareto 前沿超出了单独 KAN 所能达到的范围。未来研究的一个有希望的方向是开发更先进的可解释性方法,进一步推动当前的 Pareto 前沿。
未来工作。本文介绍了一个将 KANs 与科学知识整合的框架,主要关注小规模的物理相关示例。未来的两个有前景的方向包括将该框架应用于更大规模的问题,并将其扩展到物理学以外的其他科学领域。
项目页面:https://github.com/KindXiaoming/pykan
论文地址:https://arxiv.org/abs/2408.10205#
进 Q 交流群:922230617 或加 VX:CV_EDPJ 进 V 交流群
加 VX 群请备注学校 / 单位 + 研究方向
CV 进计算机视觉群
KAN 进 KAN 群