引言
卫星图像分割是一项计算机视觉任务,旨在将图像划分为多个片段或区域,以简化其表示。在本项目中,我们专注于对卫星图像进行分割,识别道路区域。这对于城市规划、交通管理和基础设施建设等多个应用领域至关重要。
本项目包含了一个用于卫星图像分割的代码库,目标是通过深度学习技术识别道路区域。我们将在卫星图像数据上训练一个分割模型,并在新的卫星图像上进行预测。
数据集和使用的库
在进行实战之前,先简单介绍一下本项目中使用的数据集和库:
数据集:我们将使用DeepGlobe Road Extraction数据集进行模型训练,该数据集包含了标记有道路片段的高分辨率卫星图像。
库:项目实现依赖于多个Python库,包括OpenCV、NumPy、PyTorch和Segmentation Models PyTorch等,这些库提供了图像处理、数据操作和深度学习模型训练的基本功能。
1、数据预处理与模型训练
首先进行数据的预处理和模型训练,下面是代码的关键部分:
import os
import cv2
import numpy as np
import pandas as pd
import random
import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
from segmentation_models_pytorch import utils
import warnings
warnings.filterwarnings("ignore")
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import albumentations as album
# Load the dataset metadata and keep only the rows whose split is 'train'.
DATA_DIR = r'C:\Users\MohammadalizadehkorM\Downloads\DeepGlobe Road Extraction Dataset'
metadata_df = pd.read_csv(os.path.join(DATA_DIR, 'metadata.csv'))
metadata_df = metadata_df[metadata_df['split'] == 'train']
metadata_df = metadata_df[['image_id', 'sat_image_path', 'mask_path']]
# The paths stored in metadata.csv are joined onto DATA_DIR.
metadata_df['sat_image_path'] = metadata_df['sat_image_path'].apply(
    lambda img_pth: os.path.join(DATA_DIR, img_pth)
)
metadata_df['mask_path'] = metadata_df['mask_path'].apply(
    lambda img_pth: os.path.join(DATA_DIR, img_pth)
)

# Shuffle the rows. The shuffle is seeded as well: the validation sample below
# uses a fixed random_state, but applied to a frame shuffled differently on
# every run it would still pick different images each time.
metadata_df = metadata_df.sample(frac=1, random_state=42).reset_index(drop=True)

# Perform 90/10 split for train / val
valid_df = metadata_df.sample(frac=0.1, random_state=42)
train_df = metadata_df.drop(valid_df.index)
# A bare expression only displays in a notebook; print so scripts report it too.
print(f'train: {len(train_df)}, valid: {len(valid_df)}')

# Class names and their RGB colours as listed in class_dict.csv
class_dict = pd.read_csv(os.path.join(DATA_DIR, 'class_dict.csv'))
class_names = class_dict['name'].tolist()
class_rgb_values = class_dict[['r', 'g', 'b']].values.tolist()

# Useful to shortlist specific classes in datasets with large number of classes
select_classes = ['background', 'road']

# RGB values of the selected classes, in the order given by select_classes
select_class_indices = [class_names.index(cls.lower()) for cls in select_classes]
select_class_rgb_values = np.array(class_rgb_values)[select_class_indices]
# Define helper functions
def visualize(**images):
    """Show the given images side by side in a single row.

    Each keyword argument becomes one panel: the keyword is used as the panel
    title (underscores replaced by spaces, title-cased) and the value is
    passed to ``plt.imshow``.
    """
    panel_count = len(images)
    plt.figure(figsize=(20, 8))
    for position, title_key in enumerate(images, start=1):
        plt.subplot(1, panel_count, position)
        # Hide the axis ticks of the panel.
        plt.xticks([])
        plt.yticks([])
        # The panel title is derived from the keyword name.
        plt.title(title_key.replace('_', ' ').title(), fontsize=20)
        plt.imshow(images[title_key])
    plt.show()
def one_hot_encode(label, label_values):
    """One-hot encode a colour-coded segmentation label.

    For every colour in ``label_values`` a boolean map is built that is True
    where all components along the last axis of ``label`` equal that colour.
    The maps are stacked along a new last axis, in the order of
    ``label_values``.
    """
    per_class_maps = [
        np.all(np.equal(label, colour), axis=-1)
        for colour in label_values
    ]
    return np.stack(per_class_maps, axis=-1)
def reverse_one_hot(image):
    """Collapse a one-hot (or per-class score) array into class indices.

    Returns the index of the largest value along the last axis, so the result
    has one dimension fewer than ``image``.
    """
    return np.argmax(image, axis=-1)
def colour_code_segmentation(image, label_values):
    """Turn an array of class indices into a colour-coded array.

    ``label_values`` is used as a lookup table: every entry of ``image``
    (cast to int) is replaced by the row of ``label_values`` at that index.
    """
    palette = np.array(label_values)
    class_indices = image.astype(int)
    return palette[class_indices]
# Define dataset class
class RoadsDataset(torch.utils.data.Dataset):
    """Dataset of satellite images paired with one-hot encoded masks.

    Args:
        df: DataFrame providing the 'sat_image_path' and 'mask_path' columns.
        class_rgb_values: RGB colours, one per class, used to one-hot encode
            the mask image.
        augmentation: optional callable invoked as
            ``augmentation(image=..., mask=...)`` that returns a mapping with
            'image' and 'mask' entries.
        preprocessing: optional callable with the same calling convention,
            applied after the augmentation.
    """

    def __init__(
        self,
        df,
        class_rgb_values=None,
        augmentation=None,
        preprocessing=None,
    ):
        self.image_paths = df['sat_image_path'].tolist()
        self.mask_paths = df['mask_path'].tolist()
        self.class_rgb_values = class_rgb_values
        self.augmentation = augmentation
        self.preprocessing = preprocessing

    @staticmethod
    def _read_rgb(path):
        """Read the image at ``path`` and return it in RGB channel order.

        Raises:
            FileNotFoundError: if ``path`` is not an existing file.
            OSError: if the file exists but OpenCV cannot decode it.
        """
        # cv2.imread signals failure by returning None; without these checks
        # the failure only surfaces later as an opaque error in cvtColor.
        if not os.path.isfile(path):
            raise FileNotFoundError(f"Image file '{path}' not found.")
        bgr = cv2.imread(path)
        if bgr is None:
            raise OSError(f"Image file '{path}' could not be decoded.")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def __getitem__(self, i):
        """Return the ``(image, mask)`` pair at index ``i``."""
        image = self._read_rgb(self.image_paths[i])
        mask = self._read_rgb(self.mask_paths[i])
        # One boolean channel per class colour, converted to float.
        mask = one_hot_encode(mask, self.class_rgb_values).astype('float')
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        return image, mask

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.image_paths)
# Define preprocessing and augmentation functions
def get_training_augmentation():
    """Build the augmentation pipeline applied to training samples.

    Returns an albumentations ``Compose`` holding a horizontal flip and a
    vertical flip, each configured with ``p=0.5``.
    """
    return album.Compose([
        album.HorizontalFlip(p=0.5),
        album.VerticalFlip(p=0.5),
    ])
def to_tensor(x, **kwargs):
    """Move the last axis of a 3-D array to the front and cast to float32.

    Extra keyword arguments are accepted and ignored.
    """
    channels_first = x.transpose(2, 0, 1)
    return channels_first.astype('float32')
def get_preprocessing(preprocessing_fn=None):
    """Build the preprocessing pipeline for image/mask pairs.

    When ``preprocessing_fn`` is given it is applied to the image first. The
    final step applies ``to_tensor`` to both the image and the mask.
    """
    steps = [album.Lambda(image=preprocessing_fn)] if preprocessing_fn else []
    steps.append(album.Lambda(image=to_tensor, mask=to_tensor))
    return album.Compose(steps)
# Create dataset instances
dataset = RoadsDataset(train_df, class_rgb_values=select_class_rgb_values)

# Preview one randomly chosen sample. The random index is now actually used;
# previously it was computed and then ignored in favour of a fixed index.
random_idx = random.randint(0, len(dataset) - 1)
image, mask = dataset[random_idx]
visualize(
    original_image=image,
    ground_truth_mask=colour_code_segmentation(reverse_one_hot(mask), select_class_rgb_values),
    one_hot_encoded_mask=reverse_one_hot(mask),
)

# Create augmented dataset
augmented_dataset = RoadsDataset(
    train_df,
    augmentation=get_training_augmentation(),
    class_rgb_values=select_class_rgb_values,
)
random_idx = random.randint(0, len(augmented_dataset) - 1)

# Fetch the same random pair three times: the augmentation runs on every
# access, so each fetch can yield a differently flipped image/mask pair.
for _ in range(3):
    image, mask = augmented_dataset[random_idx]
    visualize(
        original_image=image,
        ground_truth_mask=colour_code_segmentation(reverse_one_hot(mask), select_class_rgb_values),
        one_hot_encoded_mask=reverse_one_hot(mask),
    )
# Model configuration: encoder backbone, its pretrained weights, the classes
# to predict and the activation applied to the model output.
ENCODER = 'resnet50'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = select_classes
ACTIVATION = 'sigmoid'

# Build the DeepLabV3+ segmentation model with one output channel per class.
model = smp.DeepLabV3Plus(
    encoder_name=ENCODER, encoder_weights=ENCODER_WEIGHTS,
    classes=len(CLASSES), activation=ACTIVATION,
)

# Input preprocessing function matching the chosen encoder and weights.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
# Training data: augmented, then preprocessed for the encoder.
train_dataset = RoadsDataset(
    train_df,
    class_rgb_values=select_class_rgb_values,
    augmentation=get_training_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
)
# Validation data: preprocessed only, no augmentation.
valid_dataset = RoadsDataset(
    valid_df,
    class_rgb_values=select_class_rgb_values,
    preprocessing=get_preprocessing(preprocessing_fn),
)

# Batches of 4 samples; only the training loader shuffles.
# NOTE(review): the loaders use 4 worker processes but are built at module
# level, outside the __main__ guard. On platforms that spawn workers, the
# module's top-level code presumably runs again in each worker - confirm.
train_loader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True, num_workers=4)
valid_loader = DataLoader(dataset=valid_dataset, batch_size=4, shuffle=False, num_workers=4)
if __name__ == '__main__':
    # Switch that enables or disables the training loop below.
    TRAINING = True
    EPOCHS = 3
    # Use the GPU when one is available, otherwise the CPU.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Dice loss, IoU metric thresholded at 0.5, Adam optimiser.
    loss = smp.utils.losses.DiceLoss()
    metrics = [smp.utils.metrics.IoU(threshold=0.5)]
    optimizer = torch.optim.Adam([dict(params=model.parameters(), lr=0.00008)])

    # Epoch runners: one training pass and one validation pass over a loader.
    train_epoch = smp.utils.train.TrainEpoch(
        model, loss=loss, metrics=metrics, optimizer=optimizer, device=DEVICE, verbose=True,
    )
    valid_epoch = smp.utils.train.ValidEpoch(
        model, loss=loss, metrics=metrics, device=DEVICE, verbose=True,
    )

    if TRAINING:
        best_iou_score = 0.0
        train_logs_list, valid_logs_list = [], []
        for i in range(EPOCHS):
            print('\nEpoch: {}'.format(i))
            train_logs = train_epoch.run(train_loader)
            valid_logs = valid_epoch.run(valid_loader)
            train_logs_list.append(train_logs)
            valid_logs_list.append(valid_logs)
            # Keep the checkpoint with the best validation IoU seen so far.
            if valid_logs['iou_score'] > best_iou_score:
                best_iou_score = valid_logs['iou_score']
                torch.save(model, './best_model.pth')
                print('Model saved!')
该部分代码主要展示了如何加载数据集、创建数据加载器并定义和训练深度学习模型。使用的模型是基于ResNet50的DeepLabV3Plus。
2、对新图像进行推理
一旦我们有了训练好的模型,就可以用它来对新的卫星图像进行推理,以检测道路。该部分代码如下:
# Model configuration for inference.
ENCODER = 'resnet50'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['background', 'road']
ACTIVATION = 'sigmoid'


def preprocess_image(image):
    """Prepare an image array for the model.

    Applies the preprocessing function that smp provides for
    ENCODER / ENCODER_WEIGHTS, then moves the last axis to the front and
    casts the result to float32.
    """
    normalise = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
    prepared = normalise(image)
    return prepared.transpose(2, 0, 1).astype('float32')
# Load the model on CPU.
# NOTE(review): this unpickles a complete model object, so only load
# checkpoint files that come from a trusted source.
model = torch.load('best_model.pth', map_location=torch.device('cpu'))

# Load and preprocess your single TIF file
input_image_path = r'path_to_your_input_image.tif'
input_image = cv2.imread(input_image_path)
if input_image is None:
    # cv2.imread returns None instead of raising when the file cannot be read.
    raise FileNotFoundError(f"Input image '{input_image_path}' could not be read.")
input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)

# Pad the input image to make dimensions divisible by 16; the extra rows and
# columns are split between the two sides of each axis.
h, w, _ = input_image.shape
new_h = int(np.ceil(h / 16) * 16)
new_w = int(np.ceil(w / 16) * 16)
pad_top = (new_h - h) // 2
pad_bottom = new_h - h - pad_top
pad_left = (new_w - w) // 2
pad_right = new_w - w - pad_left
input_image = cv2.copyMakeBorder(
    input_image, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT, value=0
)
input_image = preprocess_image(input_image)

# Perform inference
with torch.no_grad():
    input_tensor = torch.from_numpy(input_image).unsqueeze(0)
    model.eval()
    output = model(input_tensor)

# Remove the batch dimension and pick, per pixel, the class with the highest score.
output_mask = output.squeeze().cpu().numpy()
predicted_class_index = np.argmax(output_mask, axis=0)

# Assuming road class is class 1, create binary mask for road class
road_mask = (predicted_class_index == 1).astype(np.uint8) * 255

# Cut the padding off again so the mask has the size of the input image and
# lines up with it pixel for pixel.
road_mask = np.ascontiguousarray(road_mask[pad_top:pad_top + h, pad_left:pad_left + w])

# Save the output mask
output_path = r'path_to_output_mask.png'
if not cv2.imwrite(output_path, road_mask):
    # cv2.imwrite reports failure through its return value.
    raise OSError(f"Could not write output mask to '{output_path}'.")
3、输出结果
在某些情况下,保留输出结果中的地理信息是很重要的。以下代码展示了如何使用GDAL库来保存带有地理信息的道路检测结果:
from osgeo import gdal
# Model configuration for inference.
ENCODER = 'resnet50'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['background', 'road']
ACTIVATION = 'sigmoid'


def preprocess_image(image):
    """Pad an image array and prepare it for the model.

    A 4-channel image is first converted with ``cv2.COLOR_RGBA2RGB``. The
    image is then zero-padded so that height and width are multiples of 16,
    with the padding split between the two sides of each axis. Finally the
    smp preprocessing function for ENCODER / ENCODER_WEIGHTS is applied and
    the result is returned with its last axis moved to the front, as float32.
    """
    normalise = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

    # Ensure input image is in RGB format
    if image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

    # Total padding needed per axis to reach the next multiple of 16.
    height, width, _ = image.shape
    extra_h = int(np.ceil(height / 16) * 16) - height
    extra_w = int(np.ceil(width / 16) * 16) - width
    top, left = extra_h // 2, extra_w // 2
    bottom, right = extra_h - top, extra_w - left
    image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0)

    # Apply preprocessing function
    return normalise(image).transpose(2, 0, 1).astype('float32')
# Specify the path to the model checkpoint file
model_checkpoint_path = 'best_model.pth'
# Check if the model checkpoint file exists
if not os.path.exists(model_checkpoint_path):
    raise FileNotFoundError(f"Model checkpoint file '{model_checkpoint_path}' not found.")
# Load the model on CPU.
# NOTE(review): this unpickles a complete model object, so only load
# checkpoint files that come from a trusted source.
model = torch.load(model_checkpoint_path, map_location=torch.device('cpu'))

# Load and preprocess your single TIF file
input_image_path = r'path_to_your_input_image.tif'
# Open the image using GDAL to retain geospatial information
ds = gdal.Open(input_image_path)
if ds is None:
    # gdal.Open can return None rather than raising when opening fails.
    raise OSError(f"Input image '{input_image_path}' could not be opened by GDAL.")
# Reorder the array axes from (0, 1, 2) to (1, 2, 0) for preprocess_image.
input_image = np.transpose(ds.ReadAsArray(), (1, 2, 0))

# Preprocess the input image (this also pads it to multiples of 16)
input_image = preprocess_image(input_image)

# Perform inference
with torch.no_grad():
    input_tensor = torch.from_numpy(input_image).unsqueeze(0)
    model.eval()
    output = model(input_tensor)

# Remove the batch dimension and pick, per pixel, the class with the highest score.
output_mask = output.squeeze().cpu().numpy()
predicted_class_index = np.argmax(output_mask, axis=0)

# Assuming road class is class 1, create binary mask for road class
road_mask = (predicted_class_index == 1).astype(np.uint8) * 255

# preprocess_image padded the raster, putting floor(extra / 2) rows/columns on
# the top/left. Crop back to the raster's own size; otherwise the source
# geotransform would be applied to a shifted, larger mask.
raster_h, raster_w = ds.RasterYSize, ds.RasterXSize
pad_top = (road_mask.shape[0] - raster_h) // 2
pad_left = (road_mask.shape[1] - raster_w) // 2
road_mask = np.ascontiguousarray(
    road_mask[pad_top:pad_top + raster_h, pad_left:pad_left + raster_w]
)

# Get the geotransform and projection from the input image
geotransform = ds.GetGeoTransform()
projection = ds.GetProjection()

# Save the output mask with geospatial information
output_path = r'path_to_output_mask_with_GCS.tif'
driver = gdal.GetDriverByName('GTiff')
output_ds = driver.Create(output_path, road_mask.shape[1], road_mask.shape[0], 1, gdal.GDT_Byte)
if output_ds is None:
    raise OSError(f"Could not create output raster '{output_path}'.")
output_ds.SetGeoTransform(geotransform)
output_ds.SetProjection(projection)
output_ds.GetRasterBand(1).WriteArray(road_mask)
output_ds.FlushCache()
# Dropping the reference closes the dataset.
output_ds = None
结语
在本教程中,我们介绍了从数据预处理和模型训练,到推理和输出结果的整个流程。通过这些步骤,我们可以构建自己的道路检测系统,并将其应用于实际场景中。
数据下载链接:https://www.kaggle.com/datasets/balraj98/deepglobe-road-extraction-dataset?resource=download