必看!如何使用Python深度学习提取卫星图像中的道路

科技   2024-11-07 18:55   广东  


引言

卫星图像分割是一项计算机视觉任务,旨在将图像划分为多个片段或区域,以简化其表示。在本项目中,我们专注于对卫星图像进行分割,识别道路区域。这对于城市规划、交通管理和基础设施建设等多个应用领域至关重要。

本项目包含了一个用于卫星图像分割的代码库,目标是通过深度学习技术识别道路区域。我们将在卫星图像数据上训练一个分割模型,并在新的卫星图像上进行预测。

数据集和使用的库

在进行实战之前,先简单介绍一下本项目中使用的数据集和库:

  • 数据集:我们将使用DeepGlobe Road Extraction数据集进行模型训练,该数据集包含了标记有道路片段的高分辨率卫星图像。
  • :项目实现依赖于多个Python库,包括OpenCV、NumPy、PyTorch和Segmentation Models PyTorch等,这些库提供了图像处理、数据操作和深度学习模型训练的基本功能。

1、数据预处理与模型训练

首先进行数据的预处理和模型训练,下面是代码的关键部分:

import os
import cv2
import numpy as np
import pandas as pd
import random
import tqdm
import seaborn as sns
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
from segmentation_models_pytorch import utils
import warnings
warnings.filterwarnings("ignore")
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import albumentations as album

# Load metadata and preprocess data
DATA_DIR = r'C:\Users\MohammadalizadehkorM\Downloads\DeepGlobe Road Extraction Dataset'

metadata_df = pd.read_csv(os.path.join(DATA_DIR, 'metadata.csv'))
metadata_df = metadata_df[metadata_df['split']=='train']
metadata_df = metadata_df[['image_id''sat_image_path''mask_path']]
metadata_df['sat_image_path'] = metadata_df['sat_image_path'].apply(lambda img_pth: os.path.join(DATA_DIR, img_pth))
metadata_df['mask_path'] = metadata_df['mask_path'].apply(lambda img_pth: os.path.join(DATA_DIR, img_pth))
# Shuffle DataFrame
metadata_df = metadata_df.sample(frac=1).reset_index(drop=True)

# Perform 90/10 split for train / val
valid_df = metadata_df.sample(frac=0.1, random_state=42)
train_df = metadata_df.drop(valid_df.index)
len(train_df), len(valid_df)

# Define class names and RGB values
class_dict = pd.read_csv(os.path.join(DATA_DIR, 'class_dict.csv'))
class_names = class_dict['name'].tolist()
class_rgb_values = class_dict[['r','g','b']].values.tolist()

# Useful to shortlist specific classes in datasets with large number of classes
select_classes = ['background''road']

# Get RGB values of required classes
select_class_indices = [class_names.index(cls.lower()) for cls in select_classes]
select_class_rgb_values =  np.array(class_rgb_values)[select_class_indices]

# Define helper functions
def visualize(**images):
    """
    Plot images in one row
    "
""
    n_images = len(images)
    plt.figure(figsize=(20,8))
    for idx, (name, image) in enumerate(images.items()):
        plt.subplot(1, n_images, idx + 1)
        plt.xticks([]);
        plt.yticks([])
        # get title from the parameter names
        plt.title(name.replace('_',' ').title(), fontsize=20)
        plt.imshow(image)
    plt.show()

def one_hot_encode(label, label_values):
    """
    Convert a segmentation image label array to one-hot format
    by replacing each pixel value with a vector of length num_classes
    "
""
    semantic_map = []
    for colour in label_values:
        equality = np.equal(label, colour)
        class_map = np.all(equality, axis=-1)
        semantic_map.append(class_map)
    semantic_map = np.stack(semantic_map, axis=-1)
    return semantic_map

def reverse_one_hot(image):
    """
    Transform a 2D array in one-hot format (depth is num_classes),
    to a 2D array with only 1 channel, where each pixel value is
    the classified class key.
    "
""
    x = np.argmax(image, axis=-1)
    return x

def colour_code_segmentation(image, label_values):
    """
    Given a 1-channel array of class keys, colour code the segmentation results.
    "
""
    colour_codes = np.array(label_values)
    x = colour_codes[image.astype(int)]
    return x

# Define dataset class
class RoadsDataset(torch.utils.data.Dataset):
    def __init__(
            self,
            df,
            class_rgb_values=None,
            augmentation=None,
            preprocessing=None,
    ):
        self.image_paths = df['sat_image_path'].tolist()
        self.mask_paths = df['mask_path'].tolist()
        self.class_rgb_values = class_rgb_values
        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        image = cv2.cvtColor(cv2.imread(self.image_paths[i]), cv2.COLOR_BGR2RGB)
        mask = cv2.cvtColor(cv2.imread(self.mask_paths[i]), cv2.COLOR_BGR2RGB)
        mask = one_hot_encode(mask, self.class_rgb_values).astype('float')
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        return image, mask

    def __len__(self):
        return len(self.image_paths)

# Define preprocessing and augmentation functions
def get_training_augmentation():
    train_transform = [
        album.HorizontalFlip(p=0.5),
        album.VerticalFlip(p=0.5),
    ]
    return album.Compose(train_transform)

def to_tensor(x, **kwargs):
    return x.transpose(2, 0, 1).astype('float32')

def get_preprocessing(preprocessing_fn=None):
    _transform = []
    if preprocessing_fn:
        _transform.append(album.Lambda(image=preprocessing_fn))
    _transform.append(album.Lambda(image=to_tensor, mask=to_tensor))
    return album.Compose(_transform)

# Create dataset instances
dataset = RoadsDataset(train_df, class_rgb_values=select_class_rgb_values)
random_idx = random.randint(0, len(dataset)-1)
image, mask = dataset[2]

visualize(
    original_image=image,
    ground_truth_mask=colour_code_segmentation(reverse_one_hot(mask), select_class_rgb_values),
    one_hot_encoded_mask=reverse_one_hot(mask)
)

# Create augmented dataset
augmented_dataset = RoadsDataset(
    train_df,
    augmentation=get_training_augmentation(),
    class_rgb_values=select_class_rgb_values,
)

random_idx = random.randint(0, len(augmented_dataset)-1)

# Different augmentations on image/mask pairs
for idx in range(3):
    image, mask = augmented_dataset[idx]
    visualize(
        original_image=image,
        ground_truth_mask=colour_code_segmentation(reverse_one_hot(mask), select_class_rgb_values),
        one_hot_encoded_mask=reverse_one_hot(mask)
    )

# Define model parameters
ENCODER = 'resnet50'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = select_classes
ACTIVATION = 'sigmoid'

# Create segmentation model
model = smp.DeepLabV3Plus(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=len(CLASSES),
    activation=ACTIVATION,
)

preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

# Get train and val dataset instances
train_dataset = RoadsDataset(
    train_df,
    augmentation=get_training_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
    class_rgb_values=select_class_rgb_values,
)

valid_dataset = RoadsDataset(
    valid_df,
    preprocessing=get_preprocessing(preprocessing_fn),
    class_rgb_values=select_class_rgb_values,
)

# Get train and val data loaders
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4)
valid_loader = DataLoader(valid_dataset, batch_size=4, shuffle=False, num_workers=4)

if __name__ == '__main__':
    # Set flag to train the model or not
    TRAINING = True
    EPOCHS = 3
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    loss = smp.utils.losses.DiceLoss()
    metrics = [smp.utils.metrics.IoU(threshold=0.5)]
    optimizer = torch.optim.Adam([dict(params=model.parameters(), lr=0.00008),])

    train_epoch = smp.utils.train.TrainEpoch(
        model,
        loss=loss,
        metrics=metrics,
        optimizer=optimizer,
        device=DEVICE,
        verbose=True,
    )

    valid_epoch = smp.utils.train.ValidEpoch(
        model,
        loss=loss,
        metrics=metrics,
        device=DEVICE,
        verbose=True,
    )

    if TRAINING:
        best_iou_score = 0.0
        train_logs_list, valid_logs_list = [], []

        for i in range(0, EPOCHS):
            print('\nEpoch: {}'.format(i))
            train_logs = train_epoch.run(train_loader)
            valid_logs = valid_epoch.run(valid_loader)
            train_logs_list.append(train_logs)
            valid_logs_list.append(valid_logs)

            if best_iou_score < valid_logs['iou_score']:
                best_iou_score = valid_logs['iou_score']
                torch.save(model, './best_model.pth')
                print('Model saved!')

该部分代码主要展示了如何加载数据集、创建数据加载器并定义和训练深度学习模型。使用的模型是基于ResNet50的DeepLabV3Plus。

2、对新图像进行推理

一旦我们有了训练好的模型,就可以用它来对新的卫星图像进行推理,以检测道路。该部分代码如下:

# Define model parameters
ENCODER = 'resnet50'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['background''road']
ACTIVATION = 'sigmoid'

# Define preprocessing function
def preprocess_image(image):
    preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
    image = preprocessing_fn(image)
    return image.transpose(2, 0, 1).astype('float32')

# Load the model
model = torch.load('best_model.pth', map_location=torch.device('cpu'))  # Load the model on CPU

# Load and preprocess your single TIF file
input_image_path = r'path_to_your_input_image.tif'
input_image = cv2.imread(input_image_path)
input_image = cv2.cvtColor(input_image, cv2.COLOR_BGR2RGB)

# Pad the input image to make dimensions divisible by 16
h, w, _ = input_image.shape
new_h = int(np.ceil(h / 16) * 16)
new_w = int(np.ceil(w / 16) * 16)
pad_top = (new_h - h) // 2
pad_bottom = new_h - h - pad_top
pad_left = (new_w - w) // 2
pad_right = new_w - w - pad_left
input_image = cv2.copyMakeBorder(input_image, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT, value=0)

input_image = preprocess_image(input_image)

# Perform inference
with torch.no_grad():
    input_tensor = torch.from_numpy(input_image).unsqueeze(0)
    model.eval()
    output = model(input_tensor)

# Process the output as needed
output_mask = output.squeeze().cpu().numpy()  # Remove batch dimension and move to CPU
predicted_class_index = np.argmax(output_mask, axis=0)  # Get the index of the class with the highest probability

# Assuming road class is class 1, create binary mask for road class
road_mask = (predicted_class_index == 1).astype(np.uint8) * 255

# Save the output mask
output_path = r'path_to_output_mask.png'
cv2.imwrite(output_path, road_mask)  # Save the road mask as an image

3、输出结果

在某些情况下,保留输出结果中的地理信息是很重要的。以下代码展示了如何使用GDAL库来保存带有地理信息的道路检测结果:


from osgeo import gdal

# Define model parameters
ENCODER = 'resnet50'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['background''road']
ACTIVATION = 'sigmoid'

# Define preprocessing function
def preprocess_image(image):
    preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
    # Ensure input image is in RGB format
    if image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    # Pad the image to make its dimensions divisible by 16
    h, w, _ = image.shape
    new_h = int(np.ceil(h / 16) * 16)
    new_w = int(np.ceil(w / 16) * 16)
    pad_top = (new_h - h) // 2
    pad_bottom = new_h - h - pad_top
    pad_left = (new_w - w) // 2
    pad_right = new_w - w - pad_left
    image = cv2.copyMakeBorder(image, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT, value=0)
    # Apply preprocessing function
    image = preprocessing_fn(image)
    return image.transpose(2, 0, 1).astype('float32')

# Specify the path to the model checkpoint file
model_checkpoint_path = 'best_model.pth'

# Check if the model checkpoint file exists
if not os.path.exists(model_checkpoint_path):
    raise FileNotFoundError(f"Model checkpoint file '{model_checkpoint_path}' not found.")

# Load the model
model = torch.load(model_checkpoint_path, map_location=torch.device('cpu'))  # Load the model on CPU

# Load and preprocess your single TIF file
input_image_path = r'path_to_your_input_image.tif'

# Open the image using GDAL to retain geospatial information
ds = gdal.Open(input_image_path)
input_image = np.transpose(ds.ReadAsArray(), (1, 2, 0))

# Preprocess the input image
input_image = preprocess_image(input_image)

# Perform inference
with torch.no_grad():
    input_tensor = torch.from_numpy(input_image).unsqueeze(0)
    model.eval()
    output = model(input_tensor)

# Process the output as needed
output_mask = output.squeeze().cpu().numpy()  # Remove batch dimension and move to CPU
predicted_class_index = np.argmax(output_mask, axis=0)  # Get the index of the class with the highest probability

# Assuming road class is class 1, create binary mask for road class
road_mask = (predicted_class_index == 1).astype(np.uint8) * 255

# Get the geotransform and projection from the input image
geotransform = ds.GetGeoTransform()
projection = ds.GetProjection()

# Save the output mask with geospatial information
output_path = r'path_to_output_mask_with_GCS.tif'
driver = gdal.GetDriverByName('GTiff')
output_ds = driver.Create(output_path, road_mask.shape[1], road_mask.shape[0], 1, gdal.GDT_Byte)
output_ds.SetGeoTransform(geotransform)
output_ds.SetProjection(projection)
output_ds.GetRasterBand(1).WriteArray(road_mask)
output_ds = None

结语

在本教程中,我们介绍了从数据预处理和模型训练,到推理和输出结果的整个流程。通过这些步骤,我们可以构建自己的道路检测系统,并将其应用于实际场景中。

数据下载链接:https://www.kaggle.com/datasets/balraj98/deepglobe-road-extraction-dataset?resource=download

点点GIS
一点GIS,一点Python,一点杂谈
 最新文章