前言
好久不见!即将放假,本来想赶着九月的尾巴分享一个很好用的半自动标注工具,结果只写了一半,九月的最后一天又在匆匆离去,所以只好先分享两个小函数水一下hhhhhh
祝大家国庆七天快乐!
将影像裁剪成指定大小的样本
主要是用于在制作样本时从一幅栅格影像上快速自动裁剪成指定大小的样本块,对于不满足指定大小的边界区域进行填充
def sample_clip(file_path, out_path, CropSize, image_id):
"""
将影像裁剪为指定大小的样本块
file_path: 输入路径
out_path: 输出路径
CropSize: 样本大小
image_id: 样本前缀,用来避免重名
"""
image = gdal.Open(file_path)
driver = gdal.GetDriverByName('GTiff')
width = image.RasterXSize
height = image.RasterYSize
geotransform = image.GetGeoTransform()
projection = image.GetProjection()
bands = image.RasterCount
# 左上角x坐标
top_left_x = geotransform[0]
# 东西方向像素分辨率
w_e_pixel_resolution = geotransform[1]
# 左上角y坐标
top_left_y = geotransform[3]
n_s_pixel_resolution = geotransform[5]
col_num = int(width / CropSize) # 宽度可以分成几块
row_num = int(height / CropSize) # 高度可以分成几块
if width % CropSize != 0:
col_num += 1
if height % CropSize != 0:
row_num += 1
for i in range(row_num):
for j in range(col_num):
offset_x = j * CropSize
offset_y = i * CropSize
b_xsize = min(width - offset_x, CropSize)
b_ysize = min(height - offset_y, CropSize)
top_left_x_cropped = top_left_x + offset_x * w_e_pixel_resolution
top_left_y_cropped = top_left_y + offset_y * n_s_pixel_resolution
dst_transform = (
top_left_x_cropped, geotransform[1], geotransform[2], top_left_y_cropped, geotransform[4],
geotransform[5])
cropped = image.ReadAsArray(offset_x, offset_y, b_xsize, b_ysize)
# 输出路径构建
out_name = os.path.join(out_path, f'{image_id}_{i}_{j}.tif')
out_ds = driver.Create(out_name, CropSize, CropSize, bands, gdal.GDT_Byte)
out_ds.SetGeoTransform(dst_transform)
out_ds.SetProjection(projection)
for b in range(bands):
out_ds.GetRasterBand(b + 1).WriteArray(cropped[b, :, :])
out_ds = None
更改样本标签值为指定值
在样本标注过程中可能会将很多类细分,有时候会突发奇想的合并这些同类标签
def change_img_value(file_path, out_path, value, ds_value):
"""
将图像样本标签值设置为目标值
file_path: 输入路径
out_path: 输出路径
value: 标签值
ds_value: 目标值
Returns:
"""
files = os.listdir(file_path)
for file in files:
file_data = Image.open(os.path.join(file_path, file))
image_arr = np.array(file_data)
label_data = np.where(image_arr == value, ds_value, image_arr)
label_data = Image.fromarray(label_data)
label_data.save(os.path.join(out_path, file))