KAN卷积神经网络来了! 就在昨天Alex Bodner团队发布一篇推文,展现了关于 KAN卷积神经网络 的研究成果。
是的你没看错,KAN卷积神经网络 已经被实现了。
什么是 KAN?
KAN 连接应用的函数定义为一个学习的 B 样条曲线,加上一个残差激活函数 b(x),所有这些乘以一个可学习的参数 w。
KAN的高效实现代码和KAN卷积神经网络实现代码我一起打包好了,大家可以任意添加一位小助手获取(长按二维码图片添加既可)
KAN Convolutions(KAN卷积)是一种特殊的卷积操作,它在每个边缘上应用一个可学习的非线性函数,并将它们相加。
KAN卷积的核相当于一个具有4个输入和1个输出神经元的KAN线性层。
KAN 卷积中的参数
假设我们有一个KxK的核(或称为卷积核)。
在这种情况下,对于该矩阵的每个元素,我们有一个ϕ
,其参数数量是:gridsize + 1
。由于实现上的问题,高效的KAN(Kernel Activation Network)定义了:
这给予了激活函数b
更多的表达能力。
因此,线性层的参数数量是gridsize + 2
。
所以,对于KAN卷积,我们总共有K²(gridsize + 2)
个参数,而普通的卷积只有K²
个参数。
考虑到gridsize
(在我们的实验中)通常介于k
和k²
之间,但k
倾向于是一个较小的值,介于2和16之间。
初步评估
我们测试的不同架构是:
KAN卷积层连接到KAN线性层(KKAN)
Kan卷积层连接到多层感知机(MLP)(CKAN)
在卷积之间使用批量归一化的CKAN(CKAN_BN)
ConvNet(经典卷积连接到MLP)(ConvNet)
简单的MLPs
KAN卷积的实现是一个有前景的想法,尽管它仍处于早期阶段。
只是进行了一些初步实验来评估KAN卷积的性能,以下是一些研究结果:
卷积层列表中的每个元素包含了卷积的数量和对应的核大小
基于28x28的MNIST数据集,我们可以观察到KANConv & MLP模型在准确度上与传统的大型ConvNet相比是可以接受的。
然而,不同之处在于KANConv & MLP所需的参数数量是标准ConvNet所需参数的七分之一。
此外,KKAN在准确度上比中等规模的ConvNet低0.04,但参数数量几乎只有一半(94k vs 157k),这显示了这种架构的潜力。
目前,我们并没有看到KAN卷积网络在性能上相对于传统卷积网络有显著的提升。
我们认为这是由于我们正在使用简单的数据集和小型模型所导致的,因为我们的架构的优势在于其所需的参数数量显著少于我们尝试过的最佳架构(大型ConvNet,这是一个不公平的比较,因为其规模庞大)。
在对比具有相同MLP连接的2个等价的传统卷积层和KAN卷积层时,传统方法略微胜出,准确度提高了0.06,而KAN卷积和具有几乎一半参数数量的KAN线性层则准确度降低了0.04。
简而言之,KAN卷积是一种特殊的卷积技术,它通过使用可学习的非线性函数来提高卷积层的表达能力,并且可能在某些任务中能以较少的参数达到接近甚至相当的性能。
代码我帮大家下载好了,大家可以自行研究下。(长按二维码图片添加即可)
例子
为MNIST构建KANConv
import torch
from torch import nn
import torch.nn.functional as F
from kan_convolutional.KANConv import KAN_Convolutional_Layer
class KANC_MLP(nn.Module):
def __init__(self,device: str = 'cpu'):
super().__init__()
self.conv1 = KAN_Convolutional_Layer(
n_convs = 5,
kernel_size= (3,3),
device = device
)
self.conv2 = KAN_Convolutional_Layer(
n_convs = 5,
kernel_size = (3,3),
device = device
)
self.pool1 = nn.MaxPool2d(
kernel_size=(2, 2)
)
self.flat = nn.Flatten()
self.linear1 = nn.Linear(625, 256)
self.linear2 = nn.Linear(256, 10)
def forward(self, x):
x = self.conv1(x)
x = self.pool1(x)
x = self.conv2(x)
x = self.pool1(x)
x = self.flat(x)
x = self.linear1(x)
x = self.linear2(x)
x = F.log_softmax(x, dim=1)
return x
如果喜欢本篇的内容记得点点再看,并把他转发到你的朋友圈。请永远不要停止学习,这是你武装自己对抗这个世界最有力的武器!