KAN-UNet遥感应用
最近Nature大子刊Nature reviews electrical engineering发布了综述,深度学习在遥感树木监测的应用:
Brandt, M., Chave, J., Li, S., Fensholt, R., Ciais, P., Wigneron, J.-P., Gieseke, F., Saatchi, S., Tucker, C. J., & Igel, C. (2024). High-resolution sensors and deep learning models for tree resource monitoring. Nature Reviews Electrical Engineering. https://doi.org/10.1038/s44287-024-00116-8
本文总结了高分辨率卫星和传感器技术的发展,结合人工智能技术(CNN/ViT/UNet)的应用,推动了树木三维结构(如树冠高度和木材体积)的精准监测。
模型其实也不复杂,是基于UNet的,如图所示
那我们能不能基于最近大火的KAN,把最底层的MLP换成KAN,从而设计出更好的模型KAN-UNet,从而也发一篇Nature呢?(开玩笑)
模型结构
模型结构也很容易理解,就是把核心层换成KAN了:
「特征提取阶段(黄色块):」
编码路径使用多层卷积块,每层的分辨率逐渐降低(如 H/2H/2H/2, H/4H/4H/4, H/8H/8H/8 等),通道数量 CiC_iCi 逐渐增加,用于提取多层次的特征。 解码路径则逐层上采样,同时融合来自编码路径的特征。
中间特征经过「Tokenization」处理后送入 「KAN 层」(知识增强层),通过深度学习网络建模特征之间的复杂关系。 KAN 层基于结构化知识进行多阶段特征交互,如图中所示分为三个阶段 Φ1,Φ2,Φ3 该模块还包括 「Depthwise Convolution(深度卷积)」 和 「Layer Normalization(层归一化)」,进一步提升处理能力。
当模型应用于扩散式 U-KAN(Diffusion U-KAN)时,会注入时间嵌入(红色圆圈表示),以实现动态特征表征。
整个网络采用经典的 U-Net 跳跃连接设计,将编码路径的中间层直接与解码路径对接,以保持特征细节和语义信息的完整性。
总体来看,U-KAN 模型通过整合 U-Net 的多尺度分割能力与 KAN 的知识增强特性,在处理高复杂度分割任务时具备强大的特征表达能力。
代码
不多说了,直接放代码,感兴趣的同学自己也可以发一篇Nature
import torch
from torch import nn
import torch.nn.functional as F
from KANUmain.src.fastkanconv import FastKANConvLayer
class DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""
def __init__(self, in_channels, out_channels, device):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.device = device
self.double_conv = nn.Sequential(
FastKANConvLayer(self.in_channels, self.out_channels//2, padding=1, kernel_size=3, stride=1, kan_type='RBF'),
nn.BatchNorm2d(self.out_channels//2),
nn.ReLU(inplace=True),
FastKANConvLayer(self.out_channels//2, self.out_channels, padding=1, kernel_size=3, stride=1, kan_type='RBF'),
nn.BatchNorm2d(self.out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.double_conv(x)
class Down(nn.Module):
"""Downscaling with maxpool then double conv"""
def __init__(self, in_channels, out_channels, device='mps'):
super().__init__()
self.device = device
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_channels, out_channels, device=self.device)
)
def forward(self, x):
return self.maxpool_conv(x)
class Up(nn.Module):
"""Upscaling then double conv"""
def __init__(self, in_channels, out_channels, bilinear=True, device='mps'):
super().__init__()
# if bilinear, use the normal convolutions to reduce the number of channels
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv = DoubleConv(in_channels, out_channels, device=device)
else:
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1, x2):
x1 = self.up(x1)
# input is CHW
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
class OutConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(OutConv, self).__init__()
self.conv = FastKANConvLayer(in_channels, out_channels, kernel_size=1)
def forward(self, x):
return self.conv(x)
class KANU_Net(nn.Module):
def __init__(self, n_channels, n_classes, bilinear=True, device='mps'):
super(KANU_Net, self).__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.bilinear = bilinear
self.device = device
self.channels = [64, 128, 256, 512, 1024]
self.inc = (DoubleConv(n_channels, 64, device=self.device))
self.down1 = (Down(self.channels[0], self.channels[1], self.device))
self.down2 = (Down(self.channels[1], self.channels[2], self.device))
self.down3 = (Down(self.channels[2], self.channels[3], self.device))
factor = 2 if bilinear else 1
self.down4 = (Down(self.channels[3], self.channels[4] // factor, self.device))
self.up1 = (Up(self.channels[4], self.channels[3] // factor, bilinear, self.device))
self.up2 = (Up(self.channels[3], self.channels[2] // factor, bilinear, self.device))
self.up3 = (Up(self.channels[2], self.channels[1] // factor, bilinear, self.device))
self.up4 = (Up(self.channels[1], self.channels[0], bilinear, self.device))
self.outc = (OutConv(self.channels[0], n_classes))
def forward(self, x):
# Encoder
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
#Decoder
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
logits = self.outc(x)
return logits
if __name__ == "__main__":
device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
# print(device)
model = KANU_Net(6, 6, 'mps').to(device)
# print(model)
x = torch.randn((1, 6, 224, 224)).to(device)
print(model(x).shape)
优质实惠的GPT-4(进群即可免费体验3天,名额有限,火速进群!)