现在还用KAN网络的也是神人了...

文摘   2024-11-10 23:59   北京  

KAN网络,学术研究中的庞氏骗局

最近KAN比较火爆,自从原始论文发表之后,相关研究层出不穷。

一个代表性的研究是港中文提出的U-KAN,已经被引用了好几次了

模型结构也很容易理解,就是把核心层换成KAN了:

相关研究都已经应用到水文学中了,这里是用KAN和Transformer预测径流:

KAN真有说的这么厉害?根据作者所说,在短期预测中,KAN的的效果要优于Transformer!

KAN

那么什么是KAN呢?

在MLP中,一层中的每个节点/神经元都连接到下一层中的每个节点/神经元。

MLP 中的节点或神经元使用激活函数来捕获其输入中的非线性。

这些激活函数是「固定」「非线性的。」

MLP 的灵感来自通用逼近定理。如果 MLP 的隐藏层有足够的神经元,它就可以将任何真实的 连续函数逼近到任何所需的精度。

MLP ( N(x)) 的这种近似可以在数学上描述如下

KAN 是受Kolmogorov-Arnold 表示定理(由俄罗斯数学家Vladimir Arnold和Andrey Kolmogorov提出)启发的神经网络。

该定理指出,每个多变量 连续函数可以用连续单变量函数的总和来表示。

简而言之,它告诉我们每个复杂的多变量函数都可以分解为更简单的一维函数。

该定理在数学上描述如下——

这个定理催生了KAN架构:

在最简单的形式中,KAN 类似于Kolmogorov-Arnold 表示定理方程,并且仅由两层组成。

第一层使用一组单变量函数对每个输入进行变换。

第二层对这些变换进行求和并输出最终预测。

但当扩展到学习复杂的现实世界函数时,KAN 与 MLP 一样由多个层组成,其中每一层的输出都是下一层的输入。

与 MLP 不同,在 KAN 中,激活函数是可学习的B样条。

由于 KAN 中发生的所有操作都是可微分的,因此可以使用反向传播和传统损失函数来训练它们。

下表简要总结了 KAN 和 MLP 之间的差异。

KAN模型测试

KAN刚发布时,我尝试了其在MNIST数据集的效果,核心架构如下:

import torch
import torch.nn as nn

class SimpleKAN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleKAN, self).__init__()
        self.u_funcs = nn.ModuleList([nn.Linear(1, hidden_size) for _ in range(input_size)])
        self.v_funcs = nn.ModuleList([nn.Linear(hidden_size, 1) for _ in range(input_size)])
        self.w_funcs = nn.Linear(input_size, output_size)
    
    def forward(self, x):
        u_outputs = [torch.relu(u(x[:, i:i+1])) for i, u in enumerate(self.u_funcs)]
        v_outputs = [v(u) for v, u in zip(self.v_funcs, u_outputs)]
        stacked_v = torch.cat(v_outputs, dim=1)
        output = self.w_funcs(stacked_v)
        return output

我自己测试得到的准确率为93%

准确率有93%,看起来不错,但:

模型在MNIST上准确率

精确度仅仅超过了一层的MLP,距离一些卷积神经网络的98%,99%准确率还有一定距离。

当时我觉得这个模型可能并没有宣称的那么好。。

最近,随着越来越多的炒作,KAN变体层出不穷:

U-KAN / Swin-UKAN / TransUKAN / VisionKAN /MambaKAN /ResKAN/AttnKAN...

(有些是我编的但可以预见的是,以后肯定会有。)

效果真的好吗?我自己尝试了一下KAN-UNet

模型结构如下:

import torch
from torch import nn
import torch.nn.functional as F

from KANUmain.fastkanconv import FastKANConvLayer

class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, device):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.device = device

        self.double_conv = nn.Sequential(
            FastKANConvLayer(self.in_channels, self.out_channels//2, padding=1, kernel_size=3, stride=1, kan_type='RBF'),
            nn.BatchNorm2d(self.out_channels//2),
            nn.ReLU(inplace=True),
            FastKANConvLayer(self.out_channels//2, self.out_channels, padding=1, kernel_size=3, stride=1, kan_type='RBF'),
            nn.BatchNorm2d(self.out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)
    
class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels, device='mps'):
        super().__init__()
        self.device = device
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels, device=self.device)
        )

    def forward(self, x):
        return self.maxpool_conv(x)
    
class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True, device='mps'):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, device=device)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        
    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
    
class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = FastKANConvLayer(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

class KANU_Net(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True, device='mps'):
        super(KANU_Net, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.device = device

        self.channels = [64, 128, 256, 512, 1024]

        self.inc = (DoubleConv(n_channels, 64, device=self.device))
        
        self.down1 = (Down(self.channels[0], self.channels[1], self.device))
        self.down2 = (Down(self.channels[1], self.channels[2], self.device))
        self.down3 = (Down(self.channels[2], self.channels[3], self.device))
        factor = 2 if bilinear else 1
        self.down4 = (Down(self.channels[3], self.channels[4] // factor, self.device))
        self.up1 = (Up(self.channels[4], self.channels[3] // factor, bilinear, self.device))
        self.up2 = (Up(self.channels[3], self.channels[2] // factor, bilinear, self.device))
        self.up3 = (Up(self.channels[2], self.channels[1] // factor, bilinear, self.device))
        self.up4 = (Up(self.channels[1], self.channels[0], bilinear, self.device))
        self.outc = (OutConv(self.channels[0], n_classes))

    def forward(self, x):
        # Encoder
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        
        #Decoder
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

然后我试了一下经典的分割任务。效果非常差。

最近也得知了消息,KAN最终被NeurIPS 2024拒稿

Reviewer估计是,一开始信了KAN的各种炒作,然后代码一跑直接沉默了。

全世界都在吹KAN,然而实践是检验真理的唯一标准,也没有见到一个SOTA的具体KAN实例出来。

结语

学术界也同样有庞氏骗局,见到热点,也不管真假,抓紧加入骗局,在一个精心设置的数据集上调参数,精挑细选最好的结果,把测试集“不小心”混一点训练数据进去等等等。上面的还算良心点的,验证怀疑有些文章指标都是造假的,反正大家都吹好,也不差我这一篇。

现在发表顶刊好好做研究的人越来越少了,都是想着如何包装,如何忽悠,如何炒作。

科研人共勉

求求你点个在看吧,这对我真的很重要



地学万事屋
分享先进Matlab、R、Python、GEE地学应用,以及分享制图攻略。
 最新文章