物理机制+神经网络=Nature!教你NeuralGCM正刊气象大模型!(含代码)

文摘   2024-12-14 13:45   北京  


神经网络与物理机制的结合——NeuralGCM

在天气和气候模拟的研究中,传统的通用环流模型(General Circulation Models, GCMs)是基于物理规律的模拟工具,长期以来一直是预报和研究地球系统的重要支柱。然而传统的物理模型计算总是比较慢的。

近期的Nature正刊,引入了神经网络与物理机制的结合——NeuralGCM。这种方法将机器学习组件与可微分的大气动力核心整合,实现了端到端的在线训练,能够在1到15天的天气预报中与最先进的物理和机器学习方法相媲美,同时以数十倍的计算效率完成数年乃至数十年的气候模拟。

这次我们来推理一下NeuralGCM大模型看看效果!

环境

首先安装dinosaur,neuralgcm,jax等库

最好还要有GPU环境,加快推理计算

! pip install -q -U neuralgcm dinosaur-dycore gcsfs

然后导入必要的库

import gcsfs
import jax
import numpy as np
import pickle
import xarray

from dinosaur import horizontal_interpolation
from dinosaur import spherical_harmonic
from dinosaur import xarray_utils
import neuralgcm

gcs = gcsfs.GCSFileSystem(token='anon')

加载模型

然后我们就要加载预训练好的模型

Pre-trained model checkpoints from the NeuralGCM paper areavailable for download on Google Cloud Storage:

  • Deterministic models:
    • gs://gresearch/neuralgcm/04_30_2024/neural_gcm_dynamic_forcing_deterministic_0_7_deg.pkl
    • gs://gresearch/neuralgcm/04_30_2024/neural_gcm_dynamic_forcing_deterministic_1_4_deg.pkl
    • gs://gresearch/neuralgcm/04_30_2024/neural_gcm_dynamic_forcing_deterministic_2_8_deg.pkl

根据上述描述,我们可以加载0.7°,1.4°和2.8°模型,这里我们选择中间的模型

model_name = 'neural_gcm_dynamic_forcing_deterministic_1_4_deg.pkl'  #@param ['neural_gcm_dynamic_forcing_deterministic_0_7_deg.pkl', 'neural_gcm_dynamic_forcing_deterministic_1_4_deg.pkl', 'neural_gcm_dynamic_forcing_deterministic_2_8_deg.pkl', 'neural_gcm_dynamic_forcing_stochastic_1_4_deg.pkl'] {type: "string"}

with gcs.open(f'gs://gresearch/neuralgcm/04_30_2024/{model_name}''rb') as f:
  ckpt = pickle.load(f)

model = neuralgcm.PressureLevelModel.from_checkpoint(ckpt)

加载ERA5驱动数据

然后我们就可以下载驱动数据进行训练了,这里我们直接使用推荐的zarr数据,就不用自己下载了

era5_path = 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3'
full_era5 = xarray.open_zarr(gcs.get_mapper(era5_path), chunks=None)

demo_start_time = '2020-02-14'
demo_end_time = '2020-02-18'
data_inner_steps = 24  # process every 24th hour

sliced_era5 = (
    full_era5
    [model.input_variables + model.forcing_variables]
    .pipe(
        xarray_utils.selective_temporal_shift,
        variables=model.forcing_variables,
        time_shift='24 hours',
    )
    .sel(time=slice(demo_start_time, demo_end_time, data_inner_steps))
    .compute()
)

然后把ERA5的驱动改为NeuralGCM适应的驱动分辨率1.4°

era5_grid = spherical_harmonic.Grid(
    latitude_nodes=full_era5.sizes['latitude'],
    longitude_nodes=full_era5.sizes['longitude'],
    latitude_spacing=xarray_utils.infer_latitude_spacing(full_era5.latitude),
    longitude_offset=xarray_utils.infer_longitude_offset(full_era5.longitude),
)
regridder = horizontal_interpolation.ConservativeRegridder(
    era5_grid, model.data_coords.horizontal, skipna=True
)
eval_era5 = xarray_utils.regrid(sliced_era5, regridder)
eval_era5 = xarray_utils.fill_nan_with_nearest(eval_era5)

模型推理

''然后就可以预测了:

inner_steps = 24  # save model outputs once every 24 hours
outer_steps = 4 * 24 // inner_steps  # total of 4 days
timedelta = np.timedelta64(1, 'h') * inner_steps
times = (np.arange(outer_steps) * inner_steps)  # time axis in hours

# initialize model state
inputs = model.inputs_from_xarray(eval_era5.isel(time=0))
input_forcings = model.forcings_from_xarray(eval_era5.isel(time=0))
rng_key = jax.random.key(42)  # optional for deterministic models
initial_state = model.encode(inputs, input_forcings, rng_key)

# use persistence for forcing variables (SST and sea ice cover)
all_forcings = model.forcings_from_xarray(eval_era5.head(time=1))

# make forecast
final_state, predictions = model.unroll(
    initial_state,
    all_forcings,
    steps=outer_steps,
    timedelta=timedelta,
    start_with_input=True,
)
predictions_ds = model.data_to_xarray(predictions, times=times)

最后predictions_ds就是我们的推理结果

这里有所有高空层的变量和结果:

模型预测的各个变量

以比湿为例,可以可视化预测和模型结果

# Visualize ERA5 vs NeuralGCM trajectories
combined_ds.specific_humidity.sel(level=850).plot(
    x='longitude', y='latitude', row='time', col='model', robust=True, aspect=2, size=2
);
将预测的比湿与ERA5真实值对比

最后模型如果代码错误,似乎有几个小bug,要改一下NeuralGCM库中的transforms.pyapi.py相应报错处就行了,应该是Google开发人员的一些小错误,自行更正就行了。


地学万事屋
分享先进Matlab、R、Python、GEE地学应用,以及分享制图攻略。
 最新文章