现在做地学深度学习不用TorchGeo也是神人了...

文摘   2024-10-26 00:00   北京  

必学!新一代地学深度学习神器TorchGeo

引言

TorchGeo 是一个PyTorch 库,提供特定于地理空间数据的数据集、采样器、转换和预训练模型。

目标是使地学从业者能够更轻松地对地理空间数据使用 Deep Learning 模型。

由于种种原因,无论是对于地学从业者还是计算机学者,在大范围遥感数据集使用深度学习都存在一些挑战:

  • 地学学生不太了解深度学习流程,Python语句,数据构建等等。
  • 计算机学生不太了解遥感图像,包括遥感图像预处理,采样,等等。

在Dan Morris(Microsoft AI for Earth 计划前首席科学家)去年 IEEE-GRSS的演讲中,他强调了与地理空间分析相关的一些挑战

  • 除非您拥有遥感博士学位,否则使用地理空间数据很痛苦;
  • 除非您拥有分布式计算博士学位,否则处理非常大的数据是一件痛苦的事情;
  • 一线的地学从业者通常不具备上述任何一种情况。

最重要的是,使用人工智能进行地理空间分析的人会带来额外的复杂性,因为大多数框架都是为 RGB 图片开发的,并没有考虑到地理空间数据的特殊性:

  • 由于遥感图像尺度原因无法通过神经网络(遥感图像通常是高分辨率,整张不可能进入神经网络);
  • 不同的尺度;
  • 不同的投影和 CRS;
  • 时间分量;和其他。

因此,目前,对于不了解这些不同主题的人来说,将深度学习模型应用于地理空间任务确实具有挑战性。

内容

TorchGeo 库2022年就推出,以应对其中一些挑战。

直到2024年,可用的文档(视频、教程等)仍然非常有限。

这里我写一个教程,制作一个完整的项目。包括:

  • 数据集加载;
  • 为图像创建 RasterDatasets、DataLoader 和 Samplers;
  • 裁剪数据集;
  • 对数据进行标准化;
  • 创建光谱索引;
  • 创建深度学习模型 (U-Net);
  • 损失函数;
  • 训练
  • 测试
  • 评估结果

环境

使用 Pip 或者 Conda 命令安装。

pip install rasterio
pip install torchgeo
# checking both insallations
import rasterio as rio
import torchgeo

数据

我们将使用的数据集是地球地表水数据集 [1](根据 Creative Commons Attribution 4.0 International Public License 授权),其中包含来自世界不同地区的补丁(图 1)及其相应的水面罩。该数据集使用来自 Sentinel-2 卫星的光学图像,空间分辨率为 10m。

该数据集的另一个优点是我们有性能基准来比较我们的结果:

解压缩后,数据结构如图,其中tra代表训练集,代表val验证集。

from pathlib import Path
import xarray as xr
import matplotlib.pyplot as plt

root = Path('D:/Onedrive/Acdemic/GEE_CNN/dset-s2')
assert root.exists()

train_imgs = list((root/'tra_scene').glob('*.tif'))
train_masks = list((root/'tra_truth').glob('*.tif'))

# As the images and corresponding masks are matched by name, we will sort both lists to keep them synchronized.
train_imgs.sort(); train_masks.sort()
idx = 0
img = xr.open_rasterio(train_imgs[idx])
mask = xr.open_rasterio(train_masks[idx])
_, axs = plt.subplots(1, 2, figsize=(15, 6))

# plot the tile
rgb = img.data[[2, 1, 0]].transpose((1, 2, 0))/3000
axs[0].imshow(rgb.clip(min=0, max=1))

# plot the mask
axs[1].imshow(mask.data.squeeze(), cmap='Blues')

创建RasterDataset

在传统深度学习中,我们一般要设计DataLoader,这里则采用RasterDataset

我们可以准备将其加载到神经网络中。为此,我们将使用以下命令(并将在序列中使用)创建由 TorchGeo 提供的类的实例,并指向特定目录:

from torchgeo.datasets import RasterDataset, unbind_samples, stack_samples

train_ds = RasterDataset(root=(root/’tra_scene’).as_posix(), crs='epsg:3395', res=10)

请注意,我们将 CRS(坐标参考系)指定为 。TorchGeo 要求所有图像都加载到同一个 CRS 中。

采样

传统cv中,每张图像维度都是相同的,或采用公共数据集,而无需采样。

许多遥感应用涉及使用地理空间数据集——具有地理元数据的数据集。在 TorchGeo 中,我们定义了一个 GeoDataset 类来表示这些类型的数据集。每个 GeoDataset 不是按整数索引,而是按时空边界框索引,这意味着可以智能组合覆盖不同地理范围的两个或多个数据集。

要创建可从数据集馈送到神经网络的训练补丁,我们需要选择固定大小的样本。TorchGeo 有很多采样器,但这里我们将使用随机采样,然后,这些边界框用于查询我们想要的图像部分。

from torchgeo.samplers import RandomGeoSampler
sampler = RandomGeoSampler(train, size=(512, 512), length=100)

size 是我们想要的训练补丁的形状,length 是这个采样器将作为一个 epoch 提供的补丁数量。

要从数据集中随机绘制一个项目,我们可以为一个边界框调用采样器,然后它们将此边界框传递给数据集。结果将是一个包含以下条目的字典:image、crs 和 bbox。

import torch 

# this is to get the same result in every pass
torch.manual_seed(0)

bbox = next(iter(sampler))
sample = train_ds[bbox]
sample1 = truth_ds[bbox]

print(sample.keys())
print(sample['image'].shape)
import torch
import matplotlib.pyplot as plt

_, axs = plt.subplots(1, 2, figsize=(15, 6))

arr = torch.clamp(sample['image']/10000, min=0, max=1).numpy()
rgb = arr.transpose(2, 1, 0)[:, :, [2, 1 , 0]]

arr1 = torch.clamp(sample1['image']/10000, min=0, max=1).numpy()
#rgb1 = arr.transpose(2, 1, 0)[:, :, [2, 1 , 0]]


axs[0].imshow(rgb.clip(min=0, max=1))
axs[1].imshow(arr1[0], cmap='Blues')

这样随机采样好的图(带有label)可以直接用于训练!很方便。

总结

我们了解了如何在 TorchGeo 中创建 RasterDataset,以及如何使用 Sampler 从中绘制固定大小的样本。这是我们工作流程的第一步。在后续的教程,我会介绍模型以及训练~


地学万事屋
分享先进Matlab、R、Python、GEE地学应用,以及分享制图攻略。
 最新文章