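Combining Neural Networks with Physics: NeuralGCM
In weather and climate research, traditional General Circulation Models (GCMs) are physics-based simulation tools that have long been a mainstay for forecasting and for studying the Earth system. Their main drawback is that they are computationally expensive to run.
A recent paper in Nature introduces NeuralGCM, which combines neural networks with physical mechanisms. The approach couples machine-learning components to a differentiable atmospheric dynamical core and trains the whole system online, end to end. For 1 to 15 day weather forecasts it is competitive with state-of-the-art physics-based and machine-learning methods, and it can run multi-year to multi-decade climate simulations tens of times more efficiently.
In this post we run inference with a pretrained NeuralGCM model and look at the results.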
Environment
First install dinosaur, neuralgcm, jax and the other required libraries.
A GPU environment is recommended to speed up inference.
! pip install -q -U neuralgcm dinosaur-dycore gcsfs
Then import the required libraries:
import gcsfs
import jax
import numpy as np
import pickle
import xarray
from dinosaur import horizontal_interpolation
from dinosaur import spherical_harmonic
from dinosaur import xarray_utils
import neuralgcm
gcs = gcsfs.GCSFileSystem(token='anon')
Loading the model
Next we load a pretrained model. The NeuralGCM documentation states:
Pre-trained model checkpoints from the NeuralGCM paper are available for download on Google Cloud Storage:
Deterministic models:
gs://gresearch/neuralgcm/04_30_2024/neural_gcm_dynamic_forcing_deterministic_0_7_deg.pkl
gs://gresearch/neuralgcm/04_30_2024/neural_gcm_dynamic_forcing_deterministic_1_4_deg.pkl
gs://gresearch/neuralgcm/04_30_2024/neural_gcm_dynamic_forcing_deterministic_2_8_deg.pkl
So we can load a 0.7°, 1.4° or 2.8° model. Here we pick the middle one, 1.4°.
model_name = 'neural_gcm_dynamic_forcing_deterministic_1_4_deg.pkl' #@param ['neural_gcm_dynamic_forcing_deterministic_0_7_deg.pkl', 'neural_gcm_dynamic_forcing_deterministic_1_4_deg.pkl', 'neural_gcm_dynamic_forcing_deterministic_2_8_deg.pkl', 'neural_gcm_dynamic_forcing_stochastic_1_4_deg.pkl'] {type: "string"}
with gcs.open(f'gs://gresearch/neuralgcm/04_30_2024/{model_name}', 'rb') as f:
    ckpt = pickle.load(f)
model = neuralgcm.PressureLevelModel.from_checkpoint(ckpt)
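The checkpoint names follow a regular pattern, so the choice of resolution can be wrapped in a small helper. The sketch below is a hypothetical convenience function (`checkpoint_url` is not part of neuralgcm). It only reassembles the file names that appear in this post and rejects anything else.

```python
# Hypothetical helper, not part of the neuralgcm package: it only rebuilds
# the checkpoint paths that appear in this post.
BASE = 'gs://gresearch/neuralgcm/04_30_2024'

KNOWN_CHECKPOINTS = {
    ('deterministic', '0.7'),
    ('deterministic', '1.4'),
    ('deterministic', '2.8'),
    ('stochastic', '1.4'),
}

def checkpoint_url(resolution_deg: str, kind: str = 'deterministic') -> str:
    """Return the GCS path of a checkpoint, e.g. for resolution_deg='1.4'."""
    if (kind, resolution_deg) not in KNOWN_CHECKPOINTS:
        raise ValueError(f'no known checkpoint for {kind} at {resolution_deg} deg')
    tag = resolution_deg.replace('.', '_')  # '1.4' -> '1_4'
    return f'{BASE}/neural_gcm_dynamic_forcing_{kind}_{tag}_deg.pkl'

print(checkpoint_url('1.4'))
```

The returned string can be passed to `gcs.open(..., 'rb')` in place of the f-string used above.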
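Loading ERA5 forcing data
Next we fetch the data that will drive inference. We read the recommended public Zarr store directly, so nothing has to be downloaded by hand.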
era5_path = 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3'
full_era5 = xarray.open_zarr(gcs.get_mapper(era5_path), chunks=None)
demo_start_time = '2020-02-14'
demo_end_time = '2020-02-18'
data_inner_steps = 24 # process every 24th hour
sliced_era5 = (
    full_era5
    [model.input_variables + model.forcing_variables]
    .pipe(
        xarray_utils.selective_temporal_shift,
        variables=model.forcing_variables,
        time_shift='24 hours',
    )
    .sel(time=slice(demo_start_time, demo_end_time, data_inner_steps))
    .compute()
)
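The stride in `slice(demo_start_time, demo_end_time, data_inner_steps)` is what turns the hourly ERA5 record into one snapshot per day. The sketch below reproduces that arithmetic on a synthetic hourly time axis with plain NumPy, without touching the real dataset. It assumes the usual label-based behavior in which a date-only end label includes the whole end day.

```python
import numpy as np

# Hourly timestamps covering 2020-02-14 00:00 through 2020-02-18 23:00,
# a synthetic stand-in for the hourly ERA5 time axis (nothing is downloaded).
hourly = np.arange('2020-02-14T00', '2020-02-19T00', dtype='datetime64[h]')

data_inner_steps = 24
daily = hourly[::data_inner_steps]  # keep every 24th hour

print(len(hourly))  # number of hourly records in the window
print(daily)        # one timestamp per day, each at 00:00
```

With a stride of 24 the selection keeps only the 00:00 snapshot of each day in the window, which keeps the `.compute()` call light.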
Then regrid the ERA5 data from its native 0.25° grid to the 1.4° resolution that this NeuralGCM model expects:
era5_grid = spherical_harmonic.Grid(
    latitude_nodes=full_era5.sizes['latitude'],
    longitude_nodes=full_era5.sizes['longitude'],
    latitude_spacing=xarray_utils.infer_latitude_spacing(full_era5.latitude),
    longitude_offset=xarray_utils.infer_longitude_offset(full_era5.longitude),
)
regridder = horizontal_interpolation.ConservativeRegridder(
    era5_grid, model.data_coords.horizontal, skipna=True
)
eval_era5 = xarray_utils.regrid(sliced_era5, regridder)
eval_era5 = xarray_utils.fill_nan_with_nearest(eval_era5)
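To give a feel for what conservative regridding means, here is a minimal one-dimensional sketch in NumPy. It is not the dinosaur implementation: it assumes flat cells along a single axis, whereas a real latitude axis also needs area weights. The function names and the 256-cell target grid are chosen purely for illustration. Each coarse cell receives the overlap-weighted average of the fine cells it covers, which preserves the integral of the field.

```python
import numpy as np

def overlap_weights(src_edges, dst_edges):
    """Row-normalized overlap lengths between destination and source cells."""
    lo = np.maximum(dst_edges[:-1, None], src_edges[None, :-1])
    hi = np.minimum(dst_edges[1:, None], src_edges[None, 1:])
    overlap = np.clip(hi - lo, 0.0, None)  # shape (n_dst, n_src)
    return overlap / overlap.sum(axis=1, keepdims=True)

def conservative_regrid_1d(values, src_edges, dst_edges):
    """Average source cell values onto destination cells, weighted by overlap."""
    return overlap_weights(src_edges, dst_edges) @ values

# Illustrative grids: 0.25 degree source cells, 256 coarser destination cells.
src_edges = np.linspace(0.0, 360.0, 1441)
dst_edges = np.linspace(0.0, 360.0, 257)
centers = 0.5 * (src_edges[:-1] + src_edges[1:])
field = 2.0 + np.cos(np.deg2rad(centers))  # smooth synthetic field

coarse = conservative_regrid_1d(field, src_edges, dst_edges)

# The integral over the domain is the same before and after regridding.
print(np.sum(field * np.diff(src_edges)), np.sum(coarse * np.diff(dst_edges)))
```

This conservation property is the reason to prefer this kind of regridder over plain interpolation when coarsening a field by a large factor.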
Model inference
Now we can run the forecast:
inner_steps = 24 # save model outputs once every 24 hours
outer_steps = 4 * 24 // inner_steps # total of 4 days
timedelta = np.timedelta64(1, 'h') * inner_steps
times = (np.arange(outer_steps) * inner_steps) # time axis in hours
# initialize model state
inputs = model.inputs_from_xarray(eval_era5.isel(time=0))
input_forcings = model.forcings_from_xarray(eval_era5.isel(time=0))
rng_key = jax.random.key(42) # optional for deterministic models
initial_state = model.encode(inputs, input_forcings, rng_key)
# use persistence for forcing variables (SST and sea ice cover)
all_forcings = model.forcings_from_xarray(eval_era5.head(time=1))
# make forecast
final_state, predictions = model.unroll(
    initial_state,
    all_forcings,
    steps=outer_steps,
    timedelta=timedelta,
    start_with_input=True,
)
predictions_ds = model.data_to_xarray(predictions, times=times)
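The step bookkeeping above is easy to misread, so here is the same arithmetic in isolation. The sketch only uses NumPy. The start date comes from `demo_start_time`. Treating the first saved output as the initial condition at lead time 0 is an assumption, based on `start_with_input=True` and on `times` starting at zero.

```python
import numpy as np

inner_steps = 24                      # hours between saved outputs
outer_steps = 4 * 24 // inner_steps   # number of saved outputs (4 days)
timedelta = np.timedelta64(1, 'h') * inner_steps
times = np.arange(outer_steps) * inner_steps  # lead times in hours

# Valid times, assuming the forecast starts at the first ERA5 slice.
start = np.datetime64('2020-02-14T00', 'h')
valid_times = start + times.astype('timedelta64[h]')

print(outer_steps, times)
print(valid_times)
```

Changing `inner_steps` to, say, 6 would save outputs four times a day while keeping the total forecast length at four days.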
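Finally, predictions_ds holds our inference results.
It contains the predicted variables on all upper-air (pressure) levels.
Taking specific humidity as an example, we can plot the ERA5 reference next to the NeuralGCM forecast. Note that the combined_ds used in the plotting call is not constructed anywhere in this post. It needs to hold the ERA5 targets and predictions_ds concatenated along a model dimension.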
# Visualize ERA5 vs NeuralGCM trajectories
combined_ds.specific_humidity.sel(level=850).plot(
    x='longitude', y='latitude', row='time', col='model', robust=True, aspect=2, size=2
);
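The plotting call above expects a dataset with a `model` dimension whose entries label the two panels. A minimal sketch of how such a dataset can be stacked is shown below, using small random stand-ins for both datasets. `fake_dataset` and its grid sizes are purely illustrative. In the real workflow the ERA5 side would presumably be built from eval_era5 with the same `model.inputs_from_xarray` and `model.data_to_xarray` calls used earlier.

```python
import numpy as np
import xarray

def fake_dataset(seed):
    """Synthetic stand-in for a forecast or ERA5 dataset (random values)."""
    rng = np.random.default_rng(seed)
    coords = {
        'time': np.arange(4) * 24,
        'level': [500, 850],
        'longitude': np.linspace(0.0, 360.0, 8, endpoint=False),
        'latitude': np.linspace(-80.0, 80.0, 5),
    }
    shape = tuple(len(v) for v in coords.values())
    data = rng.random(shape)
    return xarray.Dataset(
        {'specific_humidity': (tuple(coords), data)}, coords=coords
    )

target_data_ds = fake_dataset(0)   # plays the role of the ERA5 targets
predictions_ds = fake_dataset(1)   # plays the role of the NeuralGCM forecast

# Stack both datasets along a new 'model' dimension and label its entries.
combined_ds = xarray.concat([target_data_ds, predictions_ds], 'model')
combined_ds.coords['model'] = ['ERA5', 'NeuralGCM']

print(combined_ds.specific_humidity.sel(level=850).dims)
```

Both datasets must share the same time, level and horizontal coordinates for the panels to line up, which is why the ERA5 targets should be taken from the same time slice as the forecast.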
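One last note: if the code raises errors, the NeuralGCM library itself appears to contain a few small bugs. Patch transforms.py and api.py in the library at the lines the tracebacks point to. These look like minor slips by the Google developers and are easy to correct yourself.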