必学!新一代地学深度学习神器TorchGeo
引言
TorchGeo 是一个PyTorch 库,提供特定于地理空间数据的数据集、采样器、转换和预训练模型。
目标是使地学从业者能够更轻松地对地理空间数据使用 Deep Learning 模型。
由于种种原因,无论是对于地学从业者还是计算机学者,在大范围遥感数据集使用深度学习都存在一些挑战:
地学学生不太了解深度学习流程,Python语句,数据构建等等。 计算机学生不太了解遥感图像,包括遥感图像预处理,采样,等等。
在Dan Morris(Microsoft AI for Earth 计划前首席科学家)去年 IEEE-GRSS的演讲中,他强调了与地理空间分析相关的一些挑战
除非您拥有遥感博士学位,否则使用地理空间数据很痛苦; 除非您拥有分布式计算博士学位,否则处理非常大的数据是一件痛苦的事情; 一线的地学从业者通常不具备上述任何一种情况。
最重要的是,使用人工智能进行地理空间分析的人会带来额外的复杂性,因为大多数框架都是为 RGB 图片开发的,并没有考虑到地理空间数据的特殊性:
由于遥感图像尺度原因无法通过神经网络(遥感图像通常是高分辨率,整张不可能进入神经网络); 不同的尺度; 不同的投影和 CRS; 时间分量;和其他。
因此,目前,对于不了解这些不同主题的人来说,将深度学习模型应用于地理空间任务确实具有挑战性。
内容
TorchGeo 库2022年就推出,以应对其中一些挑战。
直到2024年,可用的文档(视频、教程等)仍然非常有限。
这里我写一个教程,制作一个完整的项目。包括:
数据集加载; 为图像创建 RasterDatasets、DataLoader 和 Samplers; 裁剪数据集; 对数据进行标准化; 创建光谱索引; 创建深度学习模型 (U-Net); 损失函数; 训练 测试 评估结果
环境
使用 Pip 或者 Conda 命令安装。
pip install rasterio
pip install torchgeo
# checking both insallations
import rasterio as rio
import torchgeo
数据
我们将使用的数据集是地球地表水数据集 [1](根据 Creative Commons Attribution 4.0 International Public License 授权),其中包含来自世界不同地区的补丁(图 1)及其相应的水面罩。该数据集使用来自 Sentinel-2 卫星的光学图像,空间分辨率为 10m。
该数据集的另一个优点是我们有性能基准来比较我们的结果:
解压缩后,数据结构如图,其中tra代表训练集,代表val验证集。
from pathlib import Path
import xarray as xr
import matplotlib.pyplot as plt
root = Path('D:/Onedrive/Acdemic/GEE_CNN/dset-s2')
assert root.exists()
train_imgs = list((root/'tra_scene').glob('*.tif'))
train_masks = list((root/'tra_truth').glob('*.tif'))
# As the images and corresponding masks are matched by name, we will sort both lists to keep them synchronized.
train_imgs.sort(); train_masks.sort()
idx = 0
img = xr.open_rasterio(train_imgs[idx])
mask = xr.open_rasterio(train_masks[idx])
_, axs = plt.subplots(1, 2, figsize=(15, 6))
# plot the tile
rgb = img.data[[2, 1, 0]].transpose((1, 2, 0))/3000
axs[0].imshow(rgb.clip(min=0, max=1))
# plot the mask
axs[1].imshow(mask.data.squeeze(), cmap='Blues')
创建RasterDataset
在传统深度学习中,我们一般要设计DataLoader,这里则采用RasterDataset
我们可以准备将其加载到神经网络中。为此,我们将使用以下命令(并将在序列中使用)创建由 TorchGeo 提供的类的实例,并指向特定目录:
from torchgeo.datasets import RasterDataset, unbind_samples, stack_samples
train_ds = RasterDataset(root=(root/’tra_scene’).as_posix(), crs='epsg:3395', res=10)
请注意,我们将 CRS(坐标参考系)指定为 。TorchGeo 要求所有图像都加载到同一个 CRS 中。
采样
传统cv中,每张图像维度都是相同的,或采用公共数据集,而无需采样。
许多遥感应用涉及使用地理空间数据集——具有地理元数据的数据集。在 TorchGeo 中,我们定义了一个 GeoDataset 类来表示这些类型的数据集。每个 GeoDataset 不是按整数索引,而是按时空边界框索引,这意味着可以智能组合覆盖不同地理范围的两个或多个数据集。
要创建可从数据集馈送到神经网络的训练补丁,我们需要选择固定大小的样本。TorchGeo 有很多采样器,但这里我们将使用随机采样,然后,这些边界框用于查询我们想要的图像部分。
from torchgeo.samplers import RandomGeoSampler
sampler = RandomGeoSampler(train, size=(512, 512), length=100)
size 是我们想要的训练补丁的形状,length 是这个采样器将作为一个 epoch 提供的补丁数量。
要从数据集中随机绘制一个项目,我们可以为一个边界框调用采样器,然后它们将此边界框传递给数据集。结果将是一个包含以下条目的字典:image、crs 和 bbox。
import torch
# this is to get the same result in every pass
torch.manual_seed(0)
bbox = next(iter(sampler))
sample = train_ds[bbox]
sample1 = truth_ds[bbox]
print(sample.keys())
print(sample['image'].shape)
import torch
import matplotlib.pyplot as plt
_, axs = plt.subplots(1, 2, figsize=(15, 6))
arr = torch.clamp(sample['image']/10000, min=0, max=1).numpy()
rgb = arr.transpose(2, 1, 0)[:, :, [2, 1 , 0]]
arr1 = torch.clamp(sample1['image']/10000, min=0, max=1).numpy()
#rgb1 = arr.transpose(2, 1, 0)[:, :, [2, 1 , 0]]
axs[0].imshow(rgb.clip(min=0, max=1))
axs[1].imshow(arr1[0], cmap='Blues')
这样随机采样好的图(带有label)可以直接用于训练!很方便。
总结
我们了解了如何在 TorchGeo 中创建 RasterDataset,以及如何使用 Sampler 从中绘制固定大小的样本。这是我们工作流程的第一步。在后续的教程,我会介绍模型以及训练~