应用场景
应用某个函数计算某个量的时候,需要进行循环计算(如对年份、网格等),且计算过程各自独立。为了提速,有多种办法,比如说python的multiprocessing
库,但有的时候用这个库计算的时候有些奇怪的问题。回归原始的话,可以借助Bash来实现多进程并行计算。
用法
比如我有一个脚本(如下所示),用来循环计算每一年的ERA5数据的统计量,由于直接运行只会用到单个进程(multiprocessing
库的方法即为被我注释掉的部分)。每个年份串行计算,速度会很慢。
import xarray as xr
import numpy as np
import os
import glob
#from multiprocessing import Process
dir_path0 = '/test1/ERA5_dataset/single_level/ERA5_1021'
dir_path = '2m_dewpoint_temperature'
var_surface = 'd2m'
indicator = ['mean', 'std', 'min', 'max']
def statistic_one_year(year):
directory = f'/test1/lacar_zhxia/data/ERA5_static'
if not os.path.exists(directory):
os.makedirs(directory, exist_ok=True)
if os.path.exists(f'/test1/lacar_zhxia/data/ERA5_static/static_result_var_surface_{year}.npy'):
print(f'year: {year} already exists')
return
static_result = np.zeros((len(indicator)))
print(f'year: {year}, var: {var_surface}')
file_list = np.sort(glob.glob(os.path.join(dir_path0,dir_path,dir_path) + '_' + str(year) +'*.nc'))
data = xr.open_mfdataset(file_list, combine='nested', concat_dim='valid_time')
data = data[f'{var_surface}']
static_result[0] = data.mean(dim=["valid_time", "longitude", "latitude"]).values
static_result[1] = data.std(dim=["valid_time", "longitude", "latitude"]).values
static_result[2] = data.min(dim=['valid_time','latitude','longitude']).values
static_result[3] = data.max(dim=['valid_time','latitude','longitude']).values
np.save(f'/test1/lacar_zhxia/data/ERA5_static/static_result_var_surface_{year}.npy', static_result)
if __name__ == '__main__':
for year in range(1979, 2023+1):
result = statistic_one_year(year)
#processes = []
#for year in range(1979, 2023+1):
# p = Process(target=statistic_one_year, args=(year,))
# p.start()
# processes.append(p)
#for p in processes:
# p.join()
# p.close()
如果不用multiprocessing
库,笨的办法就是手动在python里面拆分年份,多次提交。但现在我可以将这个Python脚本转换成Bash脚本中调用python脚本,并在Bash脚本中循环年份,自动提交多个进程(可以自行设定最大进程数,实现接续提交),如下所示:
#!/bin/bash
# Bash脚本:run_statistic_era5.sh
# Python脚本的路径
PYTHON_SCRIPT="/path/to/statistic_era5.py"
# 循环的年份范围
START_YEAR=1979
END_YEAR=2023
# 使用数组来存储年份
YEARS=($(seq $START_YEAR $END_YEAR))
# 提交进程的最大数量,可以运行nproc或者lscpu看核心数来合理设定
MAX_PROCESSES=4
# 用于存储子进程PID的数组
PIDS=()
# 函数:等待子进程结束
wait_for_children() {
for PID in "${PIDS[@]}"; do
wait $PID
done
PIDS=() # 清空PIDS数组
}
# 循环年份
for YEAR in "${YEARS[@]}"; do
# 如果当前运行的进程数达到最大值,则等待一个进程结束
if [ ${#PIDS[@]} -ge $MAX_PROCESSES ]; then
wait_for_children
fi
# 提交新的Python进程
python $PYTHON_SCRIPT $YEAR &
PID=$!
PIDS+=($PID)
echo "Submitted job for year $YEAR with PID $PID"
done
# 等待所有子进程结束
wait_for_children
echo "All jobs completed."
同时需要对原python脚本略微修改:
import sys
# 原有的代码...
if __name__ == '__main__':
if len(sys.argv) != 2:
print("Usage: python statistic_era5.py <year>")
sys.exit(1)
#在Python中,sys.argv是一个列表,它包含了命令行运行Python脚本时传递给脚本的参数。sys.argv的第一个元素(sys.argv[0])总是脚本的名称,其余的元素是传递给脚本的参数。
year = int(sys.argv[1])
statistic_one_year(year)
然后直接运行
bash run_statistic_era5.sh
自此,你就摆脱了multiprocessing
,开始拥抱Bash
.