Python | 使用shp批量裁剪栅格及工具开发(文末粉丝抽奖活动!!!)

文摘   2024-10-30 07:30   天津  

LXX

读完需要 3 分钟

速读仅需 1 分钟

前言:

我们通常在 ArcGIS 中裁剪遥感影像栅格,即使用 ArcGIS 中 Clip(裁剪)工具的功能,用一个 Shapefile 文件对栅格数据进行裁剪。这个功能平时经常用到,但如果需要批量裁剪大量影像,逐个手动操作就比较繁琐。因此,本文介绍如何通过 Python 批量裁剪栅格,并制作一个简单的小程序,方便后续使用。


希望各位同学点个关注,点个小赞,这将是更新的动力,不胜感激❥(^_-)

1

   

方法

这里主要使用 Rasterio 库。它基于 GDAL,专门用于栅格数据的读写操作,支持多种栅格数据格式(如 GeoTIFF、ENVI 和 HDF5),为处理和分析栅格数据提供了强大的工具,适用于卫星遥感、地图制作等各种栅格数据应用。通过 Rasterio,用户可以方便地读取、写入和操作栅格数据,提高数据处理效率。此外,Rasterio 还支持多种数据类型和坐标变换操作,具有很高的灵活性和可扩展性。总的来说,Rasterio 是一个功能强大、易用的栅格数据处理库,对于遥感、地理信息系统等领域的数据处理和分析具有重要意义。


流程如下:

(1)文件路径和编码设置:定义了输入的 tif 文件路径和 Shapefile 文件路径,设置了常见的编码列表,以便读取 Shapefile 文件。

(2)读取 Shapefile 文件:依次尝试不同的编码读取 Shapefile 文件,检查并剔除字段名为空的字段。读取成功后,将 Shapefile 的数据存入 GeoDataFrame 中;若所有编码均读取失败,则抛出错误。

(3)读取 tif 文件并检查波段数:读取 tif 文件,获取栅格数据的坐标参考系统(CRS)并输出波段数量。根据波段数选择用于 RGB 显示的三个波段,并将这些波段数据组合成三维数组。

(4)检查像元值范围:查看所选波段中的像元值范围,以便后续可视化。

(5)CRS 匹配与转换:如果 Shapefile 和 tif 文件的 CRS 不匹配,则将 Shapefile 转换到 tif 的 CRS。

(6)几何转换与原始数据可视化:将 Shapefile 的几何数据转换为 GeoJSON 格式,便于后续裁剪。可视化原始 RGB 图像并叠加 Shapefile 边界(单波段影像以 magma 色带显示)。

(7)批量裁剪 tif 文件:批量处理 input_folder 文件夹中的所有 .tif 文件,使用 mask 函数裁剪每个文件以适应 Shapefile 边界。更新裁剪后的元数据,并将结果保存到 output_folder 文件夹中。

(8)随机图像可视化:随机选择一张裁剪后的图像进行可视化,展示裁剪后图像的 RGB 组合,将裁剪后值为 0(无数据)的区域设为透明背景,并叠加 Shapefile 边界。

(9)输出处理时间:记录并输出处理时间,以及栅格的 CRS 信息。

2


   

代码

import os
import random
import time

import fiona
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import rasterio
from rasterio.mask import mask
from shapely.geometry import mapping
from tqdm import tqdm

# File paths -- edit these to point at your own data.
infile = r'D:/other/GZH/Python/S2/input/sentinel-2-1.tif'
shpfile = r'D:/other/GZH/Python/shp/yibin.shp'
input_folder = r'D:/other/GZH/Python/S2/input'
output_folder = r'D:/other/GZH/Python/S2/OUT'

# Candidate encodings for the shapefile attribute table, tried in order.
# ISO-8859-1 must be last: it decodes any byte sequence, so any encoding listed
# after it would never be reached (GBK attributes would turn into mojibake).
encodings = ["utf-8", "GBK", "ISO-8859-1"]

# File extensions treated as GeoTIFF input (compared case-insensitively).
TIF_EXTENSIONS = ('.tif', '.tiff')


def load_shapefile(path, candidate_encodings=None):
    """Read the shapefile at *path* into a GeoDataFrame.

    Each encoding in *candidate_encodings* (default: the module-level
    ``encodings`` list) is tried in order; fields whose names are blank are
    excluded from the read.

    Raises:
        ValueError: if no encoding could load the file.
    """
    for enc in candidate_encodings or encodings:
        try:
            with fiona.open(path, encoding=enc) as src:
                valid_fields = [name for name in src.schema['properties'] if name.strip()]
            gdf = gpd.read_file(path, include_fields=valid_fields, encoding=enc)
        except Exception as e:  # any reader error means "try the next encoding"
            print(f"Failed to load with encoding {enc}: {e}")
            continue
        print(f"Successfully loaded Shapefile with encoding: {enc}")
        return gdf
    raise ValueError("Unable to load Shapefile with provided encodings.")


def select_bands(num_bands):
    """Return the 1-based band indexes to display: 1-3, or all bands if fewer."""
    return [1, 2, 3] if num_bands >= 3 else list(range(1, num_bands + 1))


def stretch_to_unit(data):
    """Linearly rescale *data* to [0, 1] using its own (mask-aware) min/max.

    NaN/inf values are masked. Masked pixels keep their mask; their underlying
    values are set to 0 so they cannot fall outside the valid RGB range.
    """
    data = np.ma.masked_invalid(np.ma.asarray(data, dtype=float))
    if data.count() == 0:  # everything masked: nothing to stretch
        return data
    lo, hi = data.min(), data.max()
    span = hi - lo
    scaled = (data - lo) / span if span else data - lo
    return np.ma.array(scaled.filled(0.0), mask=np.ma.getmask(scaled))


def shapes_in_crs(gdf, crs):
    """Return the non-empty geometries of *gdf* as GeoJSON-like dicts in *crs*.

    The shapes are only reprojected when both CRSs are known and differ.
    """
    if crs is not None and gdf.crs is not None and gdf.crs != crs:
        gdf = gdf.to_crs(crs)
    return [mapping(geom) for geom in gdf.geometry if geom is not None and not geom.is_empty]


def show_image(stack, bounds, crs, gdf, title, boundary_color):
    """Plot a (rows, cols, bands) raster and overlay the shapefile boundary.

    Three or more bands are drawn as RGB of the first three, stretched to
    [0, 1]. ``imshow`` ignores ``cmap``/``vmin``/``vmax`` for RGB input and
    clips integer RGB to 0-255, so raw reflectance would render almost white.
    One or two bands are drawn as the first band with the ``magma`` colormap
    and a colorbar. Masked pixels are left for matplotlib to draw transparent.
    """
    if crs is not None and gdf.crs is not None and gdf.crs != crs:
        gdf = gdf.to_crs(crs)
    extent = (bounds.left, bounds.right, bounds.bottom, bounds.top)
    fig, ax = plt.subplots(figsize=(10, 10))
    if stack.shape[2] >= 3:
        ax.imshow(stretch_to_unit(stack[:, :, :3]), extent=extent)
    else:
        im = ax.imshow(stack[:, :, 0], cmap='magma', extent=extent)
        fig.colorbar(im, ax=ax, label="Pixel Intensity (magma)")
    gdf.boundary.plot(ax=ax, color=boundary_color, linewidth=1, label="Shapefile Boundary")
    ax.legend(loc="upper right")
    ax.set_title(title)
    plt.show()


def clip_dataset(src, output_file, geoms):
    """Clip the open dataset *src* to *geoms* and write it as a GeoTIFF.

    Pixels outside the shapes are filled with the source nodata value (0 when
    the source defines none). That value is stored as the output's nodata, so
    a later ``read(masked=True)`` treats it as background.
    """
    nodata = src.nodata if src.nodata is not None else 0
    out_image, out_transform = mask(src, geoms, crop=True, filled=True, nodata=nodata)
    out_meta = src.meta.copy()
    out_meta.update({
        "driver": "GTiff",
        "height": out_image.shape[1],
        "width": out_image.shape[2],
        "transform": out_transform,
        "nodata": nodata,
    })
    with rasterio.open(output_file, "w", **out_meta) as dest:
        dest.write(out_image)


def clip_folder(src_folder, dst_folder, gdf):
    """Clip every GeoTIFF in *src_folder* to *gdf* into same-named files in *dst_folder*.

    The shapes are reprojected to each raster's own CRS (cached per CRS), so
    inputs in different projections are clipped correctly. Files that fail
    are reported and skipped.

    Returns:
        List of output paths that were written successfully.
    """
    os.makedirs(dst_folder, exist_ok=True)
    tif_files = sorted(name for name in os.listdir(src_folder)
                       if name.lower().endswith(TIF_EXTENSIONS))
    geoms_by_crs = {}
    processed_files = []
    for tif_file in tqdm(tif_files, desc="Processing files", unit="file"):
        in_file = os.path.join(src_folder, tif_file)
        output_file = os.path.join(dst_folder, tif_file)
        try:
            if os.path.abspath(in_file) == os.path.abspath(output_file):
                # Writing over the file being read would corrupt it.
                raise ValueError("output path is the same as the input path")
            with rasterio.open(in_file) as src:
                key = src.crs.to_wkt() if src.crs is not None else None
                if key not in geoms_by_crs:
                    geoms_by_crs[key] = shapes_in_crs(gdf, src.crs)
                clip_dataset(src, output_file, geoms_by_crs[key])
            print(f"Processed file saved to: {output_file}")
            processed_files.append(output_file)
        except Exception as e:
            print(f"Error processing {tif_file}: {e}")
    return processed_files


def show_random_result(processed_files, gdf):
    """Plot one randomly chosen clipped raster with its nodata area transparent."""
    random_file = random.choice(processed_files)
    with rasterio.open(random_file) as src:
        # masked=True masks the nodata value written by clip_dataset;
        # MaskedArray.transpose keeps that mask (np.dstack would drop it).
        cropped_image = src.read(select_bands(src.count), masked=True).transpose(1, 2, 0)
        cropped_bounds = src.bounds
        cropped_crs = src.crs
    show_image(cropped_image, cropped_bounds, cropped_crs, gdf,
               "Cropped RGB Image with Shapefile Overlay (Transparent Background for Value 0)",
               "blue")


def main():
    """Load the shapes, preview the reference image, batch-clip, preview a result."""
    gdf = load_shapefile(shpfile)

    # Read the reference tif and check its band count.
    with rasterio.open(infile) as src:
        raster_crs = src.crs
        raster_bounds = src.bounds
        num_bands = src.count
        print(f"The image has {num_bands} bands.")
        # Bands used for RGB display; change these to any 3 valid band indexes.
        selected_bands = select_bands(num_bands)
        img_data = src.read(selected_bands).transpose(1, 2, 0)  # (rows, cols, bands)

    min_pixel_value, max_pixel_value = img_data.min(), img_data.max()
    print(f"Pixel Value Range for selected bands: {min_pixel_value} to {max_pixel_value}")

    # Reproject the shapes to the reference raster's CRS if they differ.
    if raster_crs is not None and gdf.crs is not None and gdf.crs != raster_crs:
        gdf = gdf.to_crs(raster_crs)
    print(f"Shapefile CRS: {gdf.crs}")

    show_image(img_data, raster_bounds, raster_crs, gdf,
               "Original RGB Image with Shapefile Overlay (Selected Bands)", "red")

    start_time = time.time()
    processed_files = clip_folder(input_folder, output_folder, gdf)
    # Stop the clock before plotting: plt.show() blocks until the window is
    # closed, and that viewing time is not processing time.
    end_time = time.time()

    if processed_files:
        show_random_result(processed_files, gdf)

    print("All files processed.")
    print(f"Total processing time: {end_time - start_time:.2f} seconds")
    print(f"Raster CRS: {raster_crs}")


if __name__ == "__main__":
    main()

裁剪结果

裁剪前

裁剪后

3

   

工具代码

将以上代码整合为一个小工具,以便后续使用。代码如下:

import os
import sys
import time

import fiona  # used to inspect the shapefile schema per candidate encoding
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import rasterio
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.figure import Figure
from PyQt5.QtCore import Qt, QThread, pyqtSignal
from PyQt5.QtWidgets import (
    QApplication,
    QFileDialog,
    QGridLayout,
    QHBoxLayout,
    QLabel,
    QLineEdit,
    QMainWindow,
    QMessageBox,
    QProgressBar,
    QPushButton,
    QVBoxLayout,
    QWidget,
)
from rasterio.mask import mask
from shapely.geometry import mapping

# Worker thread that runs the clipping in the background so the GUI stays responsive.
class WorkerThread(QThread):
    """Clip every GeoTIFF in ``input_folder`` to the shapes in ``shpfile``.

    Results are written under the same file name into ``output_folder``.

    Signals:
        progress_signal(int): percentage of files handled so far (0-100).
        time_signal(str): elapsed-time text, emitted after each file.
        finished_signal(): emitted once all files have been handled (not
            emitted when the shapefile cannot be loaded).
        result_image_signal(object): a band-first ``(bands, rows, cols)``
            array of up to three bands -- first the original of the first
            file, then the clipped result (nodata masked) of the last file.
    """

    progress_signal = pyqtSignal(int)
    time_signal = pyqtSignal(str)
    finished_signal = pyqtSignal()
    result_image_signal = pyqtSignal(object)

    # Shapefile encodings tried in order. ISO-8859-1 must be last: it decodes
    # any byte sequence, so encodings listed after it would never be reached.
    ENCODINGS = ("utf-8", "GBK", "ISO-8859-1")
    # Input file extensions, compared case-insensitively.
    TIF_EXTENSIONS = (".tif", ".tiff")

    def __init__(self, input_folder, shpfile, output_folder):
        super().__init__()
        self.input_folder = input_folder
        self.output_folder = output_folder
        self.shpfile = shpfile

    def _load_shapefile(self):
        """Return the shapefile as a GeoDataFrame, or None if no encoding works.

        Fields whose names are blank are excluded from the read.
        """
        for enc in self.ENCODINGS:
            try:
                with fiona.open(self.shpfile, encoding=enc) as src:
                    valid_fields = [name for name in src.schema['properties'] if name.strip()]
                gdf = gpd.read_file(self.shpfile, include_fields=valid_fields, encoding=enc)
            except Exception as e:  # any reader error means "try the next encoding"
                print(f"Failed to load with encoding {enc}: {e}")
                continue
            print(f"Successfully loaded Shapefile with encoding: {enc}")
            return gdf
        return None

    @staticmethod
    def _shapes_in_crs(gdf, crs):
        """Return the non-empty geometries of *gdf* as GeoJSON-like dicts in *crs*.

        The shapes are only reprojected when both CRSs are known and differ.
        """
        if crs is not None and gdf.crs is not None and gdf.crs != crs:
            gdf = gdf.to_crs(crs)
        return [mapping(geom) for geom in gdf.geometry if geom is not None and not geom.is_empty]

    @staticmethod
    def _display_bands(count):
        """Return up to the first three 1-based band indexes of a raster."""
        return [1, 2, 3] if count >= 3 else list(range(1, count + 1))

    def _clip_file(self, in_file, output_file, gdf, geoms_by_crs, emit_original, emit_result):
        """Clip one raster to *gdf* and write it to *output_file* as GeoTIFF.

        Pixels outside the shapes are filled with the source nodata value
        (0 when none is defined), and that value is stored as the output's
        nodata. *geoms_by_crs* caches reprojected shapes per raster CRS.
        """
        if os.path.abspath(in_file) == os.path.abspath(output_file):
            # Writing over the file being read would corrupt it.
            raise ValueError("output path is the same as the input path")
        with rasterio.open(in_file) as src:
            bands = self._display_bands(src.count)
            if emit_original:
                # Read only when it will be shown; (bands, rows, cols) layout.
                self.result_image_signal.emit(src.read(bands))
            key = src.crs.to_wkt() if src.crs is not None else None
            if key not in geoms_by_crs:
                geoms_by_crs[key] = self._shapes_in_crs(gdf, src.crs)
            nodata = src.nodata if src.nodata is not None else 0
            out_image, out_transform = mask(src, geoms_by_crs[key], crop=True,
                                            filled=True, nodata=nodata)
            out_meta = src.meta.copy()

        out_meta.update({
            "driver": "GTiff",
            "height": out_image.shape[1],
            "width": out_image.shape[2],
            "transform": out_transform,
            "nodata": nodata,
        })
        with rasterio.open(output_file, "w", **out_meta) as dest:
            dest.write(out_image)

        if emit_result:
            shown = out_image[:len(bands)]
            if np.isnan(nodata):
                shown = np.ma.masked_invalid(shown)
            else:
                shown = np.ma.masked_equal(shown, nodata)
            self.result_image_signal.emit(shown)

    def run(self):
        """Load the shapefile, then clip each tif, reporting progress via signals."""
        start_time = time.time()

        gdf = self._load_shapefile()
        if gdf is None:
            print("Unable to load Shapefile with provided encodings.")
            return

        os.makedirs(self.output_folder, exist_ok=True)
        tif_files = sorted(
            name for name in os.listdir(self.input_folder)
            if name.lower().endswith(self.TIF_EXTENSIONS)
        )
        if not tif_files:
            print("No .tif files found in the input folder.")
            self.progress_signal.emit(100)

        geoms_by_crs = {}  # reprojected shapes, cached per raster CRS (WKT)
        last_index = len(tif_files) - 1
        for i, tif_file in enumerate(tif_files):
            in_file = os.path.join(self.input_folder, tif_file)
            output_file = os.path.join(self.output_folder, tif_file)
            try:
                self._clip_file(in_file, output_file, gdf, geoms_by_crs,
                                emit_original=(i == 0), emit_result=(i == last_index))
            except Exception as e:
                print(f"Error processing {tif_file}: {e}")

            # Update the progress bar and elapsed time even when a file fails.
            self.progress_signal.emit(int((i + 1) / len(tif_files) * 100))
            elapsed_time = time.time() - start_time
            self.time_signal.emit(f"运行时间:{elapsed_time:.2f} 秒")

        self.finished_signal.emit()

# Main application window.
class MainWindow(QMainWindow):
    """Window for choosing input/shapefile/output paths and running the clipping."""

    def __init__(self):
        super().__init__()

        # Window properties
        self.setWindowTitle("图像处理软件")
        self.setGeometry(300, 300, 1200, 800)

        layout = QGridLayout()

        # Input image folder selection
        self.input_path_edit = QLineEdit(self)
        self.input_button = QPushButton("选择影像输入路径")
        self.input_button.clicked.connect(self.select_input_folder)
        layout.addWidget(self.input_button, 0, 0)
        layout.addWidget(self.input_path_edit, 0, 1)

        # Shapefile selection
        self.shp_path_edit = QLineEdit(self)
        self.shp_button = QPushButton("选择SHP文件路径")
        self.shp_button.clicked.connect(self.select_shp_file)
        layout.addWidget(self.shp_button, 1, 0)
        layout.addWidget(self.shp_path_edit, 1, 1)

        # Output folder selection
        self.output_path_edit = QLineEdit(self)
        self.output_button = QPushButton("选择结果输出路径")
        self.output_button.clicked.connect(self.select_output_folder)
        layout.addWidget(self.output_button, 2, 0)
        layout.addWidget(self.output_path_edit, 2, 1)

        # Progress bar spanning both columns (row 3).
        self.progress = QProgressBar(self)
        layout.addWidget(self.progress, 3, 0, 1, 2)

        # Elapsed-time label directly below the progress bar, same width.
        # update_time() writes into it, so it must exist before processing starts.
        self.time_label = QLabel("运行时间:0.00 秒", self)
        layout.addWidget(self.time_label, 4, 0, 1, 2)

        # Start button
        self.start_button = QPushButton("开始处理")
        self.start_button.clicked.connect(self.start_processing)
        layout.addWidget(self.start_button, 5, 0, 1, 2)

        # Canvas showing the original / clipped preview.
        self.input_canvas = MplCanvas(self)
        layout.addWidget(self.input_canvas, 6, 0, 1, 2)

        container = QWidget()
        container.setLayout(layout)
        self.setCentralWidget(container)

        # Background worker. Deliberately not named ``self.thread``: that
        # would shadow QObject.thread().
        self.worker_thread = None

    def select_input_folder(self):
        """Let the user pick the folder holding the input tif files."""
        folder = QFileDialog.getExistingDirectory(self, "选择影像输入文件夹")
        if folder:
            self.input_path_edit.setText(folder)

    def select_shp_file(self):
        """Let the user pick the clipping shapefile."""
        shpfile, _ = QFileDialog.getOpenFileName(self, "选择SHP文件", "", "Shapefiles (*.shp)")
        if shpfile:
            self.shp_path_edit.setText(shpfile)

    def select_output_folder(self):
        """Let the user pick the folder that receives the clipped files."""
        folder = QFileDialog.getExistingDirectory(self, "选择结果输出文件夹")
        if folder:
            self.output_path_edit.setText(folder)

    def start_processing(self):
        """Validate the paths and start a WorkerThread, unless one is running."""
        input_folder = self.input_path_edit.text().strip()
        shpfile = self.shp_path_edit.text().strip()
        output_folder = self.output_path_edit.text().strip()

        if not input_folder or not shpfile or not output_folder:
            QMessageBox.warning(self, "提示", "请确保所有路径均已填写")
            return

        # Replacing a running QThread would destroy it while it still runs.
        if self.worker_thread is not None and self.worker_thread.isRunning():
            return

        self.progress.setValue(0)
        self.time_label.setText("运行时间:0.00 秒")
        self.start_button.setEnabled(False)

        self.worker_thread = WorkerThread(input_folder, shpfile, output_folder)
        self.worker_thread.progress_signal.connect(self.update_progress)
        self.worker_thread.time_signal.connect(self.update_time)
        self.worker_thread.finished_signal.connect(self.show_finish_message)
        self.worker_thread.result_image_signal.connect(self.update_visualizations)
        # QThread's built-in ``finished`` fires on every exit path (including
        # the worker's early return), so the button is always re-enabled.
        self.worker_thread.finished.connect(self._on_worker_finished)
        self.worker_thread.start()

    def _on_worker_finished(self):
        """Re-enable the start button once the worker thread has exited."""
        self.start_button.setEnabled(True)

    def update_progress(self, value):
        """Show *value* (0-100) on the progress bar."""
        self.progress.setValue(value)

    def update_time(self, time_text):
        """Show the elapsed-time text emitted by the worker."""
        self.time_label.setText(time_text)

    def show_finish_message(self):
        """Pop up a message box saying all files have been processed."""
        QMessageBox.information(self, "处理完成", "所有文件处理完成!")

    @staticmethod
    def _stretch_to_unit(data):
        """Linearly rescale *data* to [0, 1] using its own (mask-aware) min/max.

        NaN/inf values are masked; masked pixels keep their mask.
        """
        data = np.ma.masked_invalid(np.ma.asarray(data, dtype=float))
        if data.count() == 0:  # everything masked: nothing to stretch
            return data
        lo, hi = data.min(), data.max()
        span = hi - lo
        scaled = (data - lo) / span if span else data - lo
        return np.ma.array(scaled.filled(0.0), mask=np.ma.getmask(scaled))

    def update_visualizations(self, image):
        """Draw an image on the preview canvas.

        A 3-D *image* is treated as band-first ``(bands, rows, cols)``: three
        or more bands are shown as RGB of the first three, stretched to
        [0, 1] (imshow clips integer RGB to 0-255). One or two bands show the
        first band with the magma colormap. A 2-D image is shown with magma.
        """
        data = np.ma.asarray(image)
        if data.ndim == 3 and data.shape[0] >= 3:
            rgb = np.ma.dstack([data[i] for i in range(3)])  # np.ma keeps the mask
            self.input_canvas.plot_image(self._stretch_to_unit(rgb))
        elif data.ndim == 3 and data.shape[0] >= 1:
            self.input_canvas.plot_image(data[0], cmap="magma")
        elif data.ndim == 2:
            self.input_canvas.plot_image(data, cmap="magma")
        else:
            print("图像数据格式不适合显示")

# Matplotlib canvas that can be embedded in a Qt layout.
class MplCanvas(FigureCanvas):
    """Qt widget wrapping a single-axes Matplotlib figure.

    ``parent`` is accepted for call-site compatibility but is not used.
    """

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        figure = Figure(figsize=(width, height), dpi=dpi)
        self.axes = figure.add_subplot(111)
        super().__init__(figure)

    def plot_image(self, image, cmap=None):
        """Replace the axes' contents with *image* and repaint the canvas.

        A truthy *cmap* is forwarded to ``imshow``; otherwise ``imshow`` is
        called without a colormap (as wanted for RGB images).
        """
        self.axes.clear()
        imshow_kwargs = {"cmap": cmap} if cmap else {}
        self.axes.imshow(image, **imshow_kwargs)
        self.draw()

# Program entry point: launch the GUI only when run as a script, not on import.
if __name__ == "__main__":
    app = QApplication(sys.argv)
    window = MainWindow()
    window.show()
    sys.exit(app.exec_())


4

   

最后

感谢大家一直以来的支持!时光飞逝,马上一年又要结束了,在此提前给大家抽个奖吧。奖品是《Python地理数据处理》,大家在后台私信“抽奖”即可参与,我会随机抽取三位同学❥(^_-)


最后,祝大家天天开心,开心真的很重要!!!



遥感小屋
分享遥感相关文章、代码,大家一起交流,互帮互助
 最新文章