登顶Nature,准备起飞!KAN-UNet又杀疯了(有代码)

文摘   2024-12-02 23:58   北京  

KAN-UNet遥感应用

最近Nature大子刊Nature reviews electrical engineering发布了综述,深度学习在遥感树木监测的应用:

Brandt, M., Chave, J., Li, S., Fensholt, R., Ciais, P., Wigneron, J.-P., Gieseke, F., Saatchi, S., Tucker, C. J., & Igel, C. (2024). High-resolution sensors and deep learning models for tree resource monitoring. Nature Reviews Electrical Engineering. https://doi.org/10.1038/s44287-024-00116-8

本文总结了高分辨率卫星和传感器技术的发展,结合人工智能技术(CNN/ViT/UNet)的应用,推动了树木三维结构(如树冠高度和木材体积)的精准监测。

模型其实也不复杂,是基于UNet的,如图所示

那我们能不能基于最近大火的KAN,把最底层的MLP换成KAN,从而设计出更好的模型KAN-UNet,从而也发一篇Nature呢?(开玩笑)

模型结构

模型结构也很容易理解,就是把核心层换成KAN了:

  1. 「特征提取阶段(黄色块):」
  • 编码路径使用多层卷积块,每层的分辨率逐渐降低(如 H/2H/2H/2, H/4H/4H/4, H/8H/8H/8 等),通道数量 CiC_iCi 逐渐增加,用于提取多层次的特征。
  • 解码路径则逐层上采样,同时融合来自编码路径的特征。
  • 「Tokenized KAN Block(绿色块):」
    • 中间特征经过「Tokenization」处理后送入 「KAN 层」(知识增强层),通过深度学习网络建模特征之间的复杂关系。
    • KAN 层基于结构化知识进行多阶段特征交互,如图中所示分为三个阶段 Φ1,Φ2,Φ3
    • 该模块还包括 「Depthwise Convolution(深度卷积)」「Layer Normalization(层归一化)」,进一步提升处理能力。
  • 「时间嵌入(Time Embedding):」
    • 当模型应用于扩散式 U-KAN(Diffusion U-KAN)时,会注入时间嵌入(红色圆圈表示),以实现动态特征表征。
  • 「模块连接方式:」
    • 整个网络采用经典的 U-Net 跳跃连接设计,将编码路径的中间层直接与解码路径对接,以保持特征细节和语义信息的完整性。

    总体来看,U-KAN 模型通过整合 U-Net 的多尺度分割能力与 KAN 的知识增强特性,在处理高复杂度分割任务时具备强大的特征表达能力。

    代码

    不多说了,直接放代码,感兴趣的同学自己也可以发一篇Nature

    import torch
    from torch import nn
    import torch.nn.functional as F

    from KANUmain.src.fastkanconv import FastKANConvLayer

    class DoubleConv(nn.Module):
        """(convolution => [BN] => ReLU) * 2"""

        def __init__(self, in_channels, out_channels, device):
            super().__init__()
            self.in_channels = in_channels
            self.out_channels = out_channels
            self.device = device

            self.double_conv = nn.Sequential(
                FastKANConvLayer(self.in_channels, self.out_channels//2, padding=1, kernel_size=3, stride=1, kan_type='RBF'),
                nn.BatchNorm2d(self.out_channels//2),
                nn.ReLU(inplace=True),
                FastKANConvLayer(self.out_channels//2, self.out_channels, padding=1, kernel_size=3, stride=1, kan_type='RBF'),
                nn.BatchNorm2d(self.out_channels),
                nn.ReLU(inplace=True)
            )

        def forward(self, x):
            return self.double_conv(x)
        
    class Down(nn.Module):
        """Downscaling with maxpool then double conv"""

        def __init__(self, in_channels, out_channels, device='mps'):
            super().__init__()
            self.device = device
            self.maxpool_conv = nn.Sequential(
                nn.MaxPool2d(2),
                DoubleConv(in_channels, out_channels, device=self.device)
            )

        def forward(self, x):
            return self.maxpool_conv(x)
        
    class Up(nn.Module):
        """Upscaling then double conv"""

        def __init__(self, in_channels, out_channels, bilinear=True, device='mps'):
            super().__init__()

            # if bilinear, use the normal convolutions to reduce the number of channels
            if bilinear:
                self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
                self.conv = DoubleConv(in_channels, out_channels, device=device)
            else:
                self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
                self.conv = DoubleConv(in_channels, out_channels)
            
        def forward(self, x1, x2):
            x1 = self.up(x1)
            # input is CHW
            diffY = x2.size()[2] - x1.size()[2]
            diffX = x2.size()[3] - x1.size()[3]

            x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                            diffY // 2, diffY - diffY // 2])
            x = torch.cat([x2, x1], dim=1)
            return self.conv(x)
        
    class OutConv(nn.Module):
        def __init__(self, in_channels, out_channels):
            super(OutConv, self).__init__()
            self.conv = FastKANConvLayer(in_channels, out_channels, kernel_size=1)

        def forward(self, x):
            return self.conv(x)

    class KANU_Net(nn.Module):
        def __init__(self, n_channels, n_classes, bilinear=True, device='mps'):
            super(KANU_Net, self).__init__()
            self.n_channels = n_channels
            self.n_classes = n_classes
            self.bilinear = bilinear
            self.device = device

            self.channels = [64, 128, 256, 512, 1024]

            self.inc = (DoubleConv(n_channels, 64, device=self.device))
            
            self.down1 = (Down(self.channels[0], self.channels[1], self.device))
            self.down2 = (Down(self.channels[1], self.channels[2], self.device))
            self.down3 = (Down(self.channels[2], self.channels[3], self.device))
            factor = 2 if bilinear else 1
            self.down4 = (Down(self.channels[3], self.channels[4] // factor, self.device))
            self.up1 = (Up(self.channels[4], self.channels[3] // factor, bilinear, self.device))
            self.up2 = (Up(self.channels[3], self.channels[2] // factor, bilinear, self.device))
            self.up3 = (Up(self.channels[2], self.channels[1] // factor, bilinear, self.device))
            self.up4 = (Up(self.channels[1], self.channels[0], bilinear, self.device))
            self.outc = (OutConv(self.channels[0], n_classes))

        def forward(self, x):
            # Encoder
            x1 = self.inc(x)
            x2 = self.down1(x1)
            x3 = self.down2(x2)
            x4 = self.down3(x3)
            x5 = self.down4(x4)
            
            #Decoder
            x = self.up1(x5, x4)
            x = self.up2(x, x3)
            x = self.up3(x, x2)
            x = self.up4(x, x1)
            logits = self.outc(x)
            return logits

    if __name__ == "__main__":
        device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
        # print(device)
        model = KANU_Net(6, 6, 'mps').to(device)
        # print(model)
        x = torch.randn((1, 6, 224, 224)).to(device)
        print(model(x).shape)

    地学万事屋
    分享先进Matlab、R、Python、GEE地学应用,以及分享制图攻略。
     最新文章