KAN网络,学术研究中的庞氏骗局
最近KAN比较火爆,自从原始论文发表之后,相关研究层出不穷。
一个代表性的研究是港中文提出的U-KAN,已经被引用了好几次了
模型结构也很容易理解,就是把核心层换成KAN了:
相关研究都已经应用到水文学中了,这里是用KAN和Transformer预测径流:
KAN真有说的这么厉害?根据作者所说,在短期预测中,KAN的的效果要优于Transformer!
KAN
那么什么是KAN呢?
在MLP中,一层中的每个节点/神经元都连接到下一层中的每个节点/神经元。
MLP 中的节点或神经元使用激活函数来捕获其输入中的非线性。
这些激活函数是「固定」且「非线性的。」
MLP 的灵感来自通用逼近定理。如果 MLP 的隐藏层有足够的神经元,它就可以将任何真实的 连续函数逼近到任何所需的精度。
MLP ( N(x)
) 的这种近似可以在数学上描述如下
KAN 是受Kolmogorov-Arnold 表示定理(由俄罗斯数学家Vladimir Arnold和Andrey Kolmogorov提出)启发的神经网络。
该定理指出,每个多变量 连续函数可以用连续单变量函数的总和来表示。
简而言之,它告诉我们每个复杂的多变量函数都可以分解为更简单的一维函数。
该定理在数学上描述如下——
这个定理催生了KAN架构:
在最简单的形式中,KAN 类似于Kolmogorov-Arnold 表示定理方程,并且仅由两层组成。
第一层使用一组单变量函数对每个输入进行变换。
第二层对这些变换进行求和并输出最终预测。
但当扩展到学习复杂的现实世界函数时,KAN 与 MLP 一样由多个层组成,其中每一层的输出都是下一层的输入。
与 MLP 不同,在 KAN 中,激活函数是可学习的B样条。
由于 KAN 中发生的所有操作都是可微分的,因此可以使用反向传播和传统损失函数来训练它们。
下表简要总结了 KAN 和 MLP 之间的差异。
KAN模型测试
KAN刚发布时,我尝试了其在MNIST数据集的效果,核心架构如下:
import torch
import torch.nn as nn
class SimpleKAN(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(SimpleKAN, self).__init__()
self.u_funcs = nn.ModuleList([nn.Linear(1, hidden_size) for _ in range(input_size)])
self.v_funcs = nn.ModuleList([nn.Linear(hidden_size, 1) for _ in range(input_size)])
self.w_funcs = nn.Linear(input_size, output_size)
def forward(self, x):
u_outputs = [torch.relu(u(x[:, i:i+1])) for i, u in enumerate(self.u_funcs)]
v_outputs = [v(u) for v, u in zip(self.v_funcs, u_outputs)]
stacked_v = torch.cat(v_outputs, dim=1)
output = self.w_funcs(stacked_v)
return output
我自己测试得到的准确率为93%
准确率有93%,看起来不错,但:
精确度仅仅超过了一层的MLP,距离一些卷积神经网络的98%,99%准确率还有一定距离。
当时我觉得这个模型可能并没有宣称的那么好。。
最近,随着越来越多的炒作,KAN变体层出不穷:
U-KAN / Swin-UKAN / TransUKAN / VisionKAN /MambaKAN /ResKAN/AttnKAN...
(有些是我编的但可以预见的是,以后肯定会有。)
效果真的好吗?我自己尝试了一下KAN-UNet
模型结构如下:
import torch
from torch import nn
import torch.nn.functional as F
from KANUmain.fastkanconv import FastKANConvLayer
class DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""
def __init__(self, in_channels, out_channels, device):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.device = device
self.double_conv = nn.Sequential(
FastKANConvLayer(self.in_channels, self.out_channels//2, padding=1, kernel_size=3, stride=1, kan_type='RBF'),
nn.BatchNorm2d(self.out_channels//2),
nn.ReLU(inplace=True),
FastKANConvLayer(self.out_channels//2, self.out_channels, padding=1, kernel_size=3, stride=1, kan_type='RBF'),
nn.BatchNorm2d(self.out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.double_conv(x)
class Down(nn.Module):
"""Downscaling with maxpool then double conv"""
def __init__(self, in_channels, out_channels, device='mps'):
super().__init__()
self.device = device
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_channels, out_channels, device=self.device)
)
def forward(self, x):
return self.maxpool_conv(x)
class Up(nn.Module):
"""Upscaling then double conv"""
def __init__(self, in_channels, out_channels, bilinear=True, device='mps'):
super().__init__()
# if bilinear, use the normal convolutions to reduce the number of channels
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv = DoubleConv(in_channels, out_channels, device=device)
else:
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1, x2):
x1 = self.up(x1)
# input is CHW
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
class OutConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(OutConv, self).__init__()
self.conv = FastKANConvLayer(in_channels, out_channels, kernel_size=1)
def forward(self, x):
return self.conv(x)
class KANU_Net(nn.Module):
def __init__(self, n_channels, n_classes, bilinear=True, device='mps'):
super(KANU_Net, self).__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.bilinear = bilinear
self.device = device
self.channels = [64, 128, 256, 512, 1024]
self.inc = (DoubleConv(n_channels, 64, device=self.device))
self.down1 = (Down(self.channels[0], self.channels[1], self.device))
self.down2 = (Down(self.channels[1], self.channels[2], self.device))
self.down3 = (Down(self.channels[2], self.channels[3], self.device))
factor = 2 if bilinear else 1
self.down4 = (Down(self.channels[3], self.channels[4] // factor, self.device))
self.up1 = (Up(self.channels[4], self.channels[3] // factor, bilinear, self.device))
self.up2 = (Up(self.channels[3], self.channels[2] // factor, bilinear, self.device))
self.up3 = (Up(self.channels[2], self.channels[1] // factor, bilinear, self.device))
self.up4 = (Up(self.channels[1], self.channels[0], bilinear, self.device))
self.outc = (OutConv(self.channels[0], n_classes))
def forward(self, x):
# Encoder
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
#Decoder
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
logits = self.outc(x)
return logits
然后我试了一下经典的分割任务。效果非常差。
最近也得知了消息,KAN最终被NeurIPS 2024拒稿
Reviewer估计是,一开始信了KAN的各种炒作,然后代码一跑直接沉默了。
全世界都在吹KAN,然而实践是检验真理的唯一标准,也没有见到一个SOTA的具体KAN实例出来。
结语
学术界也同样有庞氏骗局,见到热点,也不管真假,抓紧加入骗局,在一个精心设置的数据集上调参数,精挑细选最好的结果,把测试集“不小心”混一点训练数据进去等等等。上面的还算良心点的,验证怀疑有些文章指标都是造假的,反正大家都吹好,也不差我这一篇。
现在发表顶刊好好做研究的人越来越少了,都是想着如何包装,如何忽悠,如何炒作。
科研人共勉
求求你点个在看吧,这对我真的很重要