diff --git a/climanet/st_encoder_decoder.py b/climanet/st_encoder_decoder.py index 7df74ad..7603d93 100644 --- a/climanet/st_encoder_decoder.py +++ b/climanet/st_encoder_decoder.py @@ -549,6 +549,22 @@ def __init__( num_months=num_months, ) self.patch_size = patch_size + + # Store config for easy model replication + self.config = { + 'in_chans': in_chans, + 'embed_dim': embed_dim, + 'patch_size': patch_size, + 'max_days': max_days, + 'max_months': max_months, + 'num_months': num_months, + 'hidden': hidden, + 'overlap': overlap, + 'max_H': max_H, + 'max_W': max_W, + 'spatial_depth': spatial_depth, + 'spatial_heads': spatial_heads, + } def forward(self, daily_data, daily_mask, land_mask_patch, padded_days_mask=None): """Forward pass of the Spatio-Temporal model. diff --git a/notebooks/example.ipynb b/notebooks/example.ipynb index c6d2e0a..733fdc5 100644 --- a/notebooks/example.ipynb +++ b/notebooks/example.ipynb @@ -8,6 +8,7 @@ "outputs": [], "source": [ "from pathlib import Path\n", + "import dask\n", "import xarray as xr\n", "import torch\n", "import torch.nn.functional\n", @@ -27,22 +28,1746 @@ }, { "cell_type": "code", - "execution_count": 2, - "id": "13a3b0c8-1d92-460d-84a4-a3a59ca081af", + "execution_count": null, + "id": "5b1b1129", "metadata": {}, "outputs": [], "source": [ + "# # Normal reading: load full dataset in memory\n", + "\n", + "# # Data folder\n", + "# data_folder = Path(\"../../data/output/\")\n", + "\n", + "# # Training patch size\n", + "# patch_size_training = 80\n", + "\n", + "# # Get all files\n", + "# daily_files = sorted(data_folder.rglob(\"20*_day_ERA5_masked_ts.nc\"))\n", + "# monthly_files = sorted(data_folder.rglob(\"20*_mon_ERA5_full_ts.nc\"))\n", + "\n", + "# daily_data = xr.open_mfdataset(daily_files)\n", + "# monthly_data = xr.open_mfdataset(monthly_files)\n", + " \n", + "# lsm_mask = xr.open_dataset(data_folder / \"era5_lsm_bool.nc\" ) # downloaded from ERA5 and regridded\n", + "# lsm_mask = lsm_mask.rename({\"latitude\": \"lat\", \"longitude\": \"lon\"})[[\"lsm\"]]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "13a3b0c8-1d92-460d-84a4-a3a59ca081af", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:20: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " daily_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lat\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n", + "/tmp/ipykernel_31814/1980159491.py:31: UserWarning: The specified chunks separate the stored chunks along dimension \"lon\" starting at index 40. This could degrade performance. Instead, consider rechunking after loading.\n", + " monthly_data = xr.open_mfdataset(\n" + ] + } + ], + "source": [ + "# For debug, load part of the two year dataset\n", + "\n", + "# Data folder\n", "data_folder = Path(\"../../data/output/\")\n", "\n", - "file_names = [data_folder / \"202001_day_ERA5_masked_ts.nc\", data_folder / \"202002_day_ERA5_masked_ts.nc\"]\n", - "daily_data = xr.open_mfdataset(file_names)\n", + "# (Only for local debug) Subset the data while loading\n", + "# Define ROI once so subsetting happens during file open (not after)\n", + "lon_subset = slice(-10, 10)\n", + "lat_subset = slice(-5, 5)\n", + "\n", + "# Training patch size\n", + "patch_size_training = 20\n", "\n", - "file_names = [data_folder / \"202001_mon_ERA5_full_ts.nc\", data_folder / \"202002_mon_ERA5_full_ts.nc\"]\n", - "monthly_data = xr.open_mfdataset(file_names)\n", + "# Keep only required variable + spatial subset while reading each file\n", + "def _preprocess_roi(ds):\n", + " return ds[[\"ts\"]].sel(lon=lon_subset, lat=lat_subset)\n", "\n", - "file_name = data_folder / \"era5_lsm_bool.nc\" # downloded from era5 and regridded using the function `regrid_to_boundary_centered_grid`\n", + "daily_files = sorted(data_folder.rglob(\"20*_day_ERA5_masked_ts.nc\"))\n", + "monthly_files = sorted(data_folder.rglob(\"20*_mon_ERA5_full_ts.nc\"))\n", + "\n", + "# Use smaller spatial chunks to reduce peak memory per task\n", + "daily_data = xr.open_mfdataset(\n", + " daily_files,\n", + " combine=\"by_coords\",\n", + " preprocess=_preprocess_roi,\n", + " chunks={\"time\": 1, \"lat\": patch_size_training*2, \"lon\": patch_size_training*2},\n", + " data_vars=\"minimal\",\n", + " coords=\"minimal\",\n", + " compat=\"override\",\n", + " parallel=False,\n", + ")\n", + "\n", + "monthly_data = xr.open_mfdataset(\n", + " monthly_files,\n", + " combine=\"by_coords\",\n", + " preprocess=_preprocess_roi,\n", + " chunks={\"time\": 1, \"lat\": patch_size_training*2, \"lon\": patch_size_training*2},\n", + " data_vars=\"minimal\",\n", + " coords=\"minimal\",\n", + " compat=\"override\",\n", + " parallel=False,\n", + ")\n", + "\n", + "file_name = data_folder / \"era5_lsm_bool.nc\" # downloaded from ERA5 and regridded\n", "lsm_mask = xr.open_dataset(file_name)\n", - "lsm_mask = lsm_mask.rename({'latitude': 'lat', 'longitude': 'lon'})" + "lsm_mask = lsm_mask.rename({\"latitude\": \"lat\", \"longitude\": \"lon\"})[[\"lsm\"]].sel(lon=lon_subset, lat=lat_subset)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "7d13e24a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.Dataset> Size: 9MB\n",
+       "Dimensions:  (time: 731, lat: 40, lon: 80)\n",
+       "Coordinates:\n",
+       "  * time     (time) datetime64[ns] 6kB 2020-01-01T11:30:00 ... 2021-12-31T11:...\n",
+       "  * lat      (lat) float32 160B -4.875 -4.625 -4.375 ... 4.375 4.625 4.875\n",
+       "  * lon      (lon) float32 320B -9.875 -9.625 -9.375 ... 9.375 9.625 9.875\n",
+       "Data variables:\n",
+       "    ts       (time, lat, lon) float32 9MB dask.array<chunksize=(1, 20, 40), meta=np.ndarray>\n",
+       "Attributes:\n",
+       "    CDI:          Climate Data Interface version 2.2.4 (https://mpimet.mpg.de...\n",
+       "    Conventions:  CF-1.6\n",
+       "    history:      Tue Feb 03 08:53:20 2026: cdo daymean /work/bd0854/b380103/...\n",
+       "    frequency:    day\n",
+       "    CDO:          Climate Data Operators version 2.2.2 (https://mpimet.mpg.de...
" + ], + "text/plain": [ + " Size: 9MB\n", + "Dimensions: (time: 731, lat: 40, lon: 80)\n", + "Coordinates:\n", + " * time (time) datetime64[ns] 6kB 2020-01-01T11:30:00 ... 2021-12-31T11:...\n", + " * lat (lat) float32 160B -4.875 -4.625 -4.375 ... 4.375 4.625 4.875\n", + " * lon (lon) float32 320B -9.875 -9.625 -9.375 ... 9.375 9.625 9.875\n", + "Data variables:\n", + " ts (time, lat, lon) float32 9MB dask.array\n", + "Attributes:\n", + " CDI: Climate Data Interface version 2.2.4 (https://mpimet.mpg.de...\n", + " Conventions: CF-1.6\n", + " history: Tue Feb 03 08:53:20 2026: cdo daymean /work/bd0854/b380103/...\n", + " frequency: day\n", + " CDO: Climate Data Operators version 2.2.2 (https://mpimet.mpg.de..." + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "daily_data" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "db9e7465", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.Dataset> Size: 308kB\n",
+       "Dimensions:  (time: 24, lat: 40, lon: 80)\n",
+       "Coordinates:\n",
+       "  * time     (time) datetime64[ns] 192B 2020-01-16T11:30:00 ... 2021-12-16T11...\n",
+       "  * lat      (lat) float32 160B -4.875 -4.625 -4.375 ... 4.375 4.625 4.875\n",
+       "  * lon      (lon) float32 320B -9.875 -9.625 -9.375 ... 9.375 9.625 9.875\n",
+       "Data variables:\n",
+       "    ts       (time, lat, lon) float32 307kB dask.array<chunksize=(1, 20, 40), meta=np.ndarray>\n",
+       "Attributes:\n",
+       "    CDI:          Climate Data Interface version 2.2.4 (https://mpimet.mpg.de...\n",
+       "    Conventions:  CF-1.6\n",
+       "    history:      Tue Feb 03 08:53:10 2026: cdo monmean /work/bd0854/b380103/...\n",
+       "    frequency:    mon\n",
+       "    CDO:          Climate Data Operators version 2.2.2 (https://mpimet.mpg.de...
" + ], + "text/plain": [ + " Size: 308kB\n", + "Dimensions: (time: 24, lat: 40, lon: 80)\n", + "Coordinates:\n", + " * time (time) datetime64[ns] 192B 2020-01-16T11:30:00 ... 2021-12-16T11...\n", + " * lat (lat) float32 160B -4.875 -4.625 -4.375 ... 4.375 4.625 4.875\n", + " * lon (lon) float32 320B -9.875 -9.625 -9.375 ... 9.375 9.625 9.875\n", + "Data variables:\n", + " ts (time, lat, lon) float32 307kB dask.array\n", + "Attributes:\n", + " CDI: Climate Data Interface version 2.2.4 (https://mpimet.mpg.de...\n", + " Conventions: CF-1.6\n", + " history: Tue Feb 03 08:53:10 2026: cdo monmean /work/bd0854/b380103/...\n", + " frequency: mon\n", + " CDO: Climate Data Operators version 2.2.2 (https://mpimet.mpg.de..." + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "monthly_data" ] }, { @@ -50,7 +1775,7 @@ "id": "b4be672b-a2c1-4a6f-89a1-2c22cc6cde02", "metadata": {}, "source": [ - "## Subset data (for fast example)" + "## (Only for local debug) Subset the data" ] }, { @@ -94,11 +1819,12 @@ } ], "source": [ - "# create monthly mean to predict mean and std per month\n", - "daily_subset_averaged = daily_subset[\"ts\"].resample(time=\"MS\").mean(skipna=True)\n", - "mean = daily_subset_averaged.mean(dim=[\"lat\", \"lon\"], skipna=True).values\n", - "std = daily_subset_averaged.std(dim=[\"lat\", \"lon\"], skipna=True).values\n", - "print(f'mean: {mean}, std: {std}')" + "# Compute monthly climatology stats without persisting the full (time, lat, lon) monthly field\n", + "monthly_ts = daily_data[\"ts\"].resample(time=\"MS\").mean(skipna=True)\n", + "mean = monthly_ts.mean(dim=[\"lat\", \"lon\"], skipna=True).compute().values\n", + "std = monthly_ts.std(dim=[\"lat\", \"lon\"], skipna=True).compute().values\n", + "\n", + "print(f\"mean: {mean}, std: {std}\")" ] }, { @@ -519,10 +2245,18 @@ "err.isel(time=1).plot()" ] }, + { + "cell_type": "markdown", + "id": "add2659d", + "metadata": {}, + "source": [ + "## Output model and prediction" + ] + }, { "cell_type": "code", - "execution_count": null, - "id": "bf3c8bea-c519-4a8f-a8a0-e25af79bac39", + "execution_count": 18, + "id": "2cfec2ef", "metadata": {}, "outputs": [], "source": [] @@ -538,7 +2272,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "climanet", "language": "python", "name": "python3" }, @@ -552,7 +2286,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.14" + "version": "3.14.0" } }, "nbformat": 4, diff --git a/scripts/README.md b/scripts/README.md new file mode 100644 index 0000000..ca1016c --- /dev/null +++ b/scripts/README.md @@ -0,0 +1,42 @@ +# Execute training tasks on SLURM + +1. Make a working directory + +```sh +mkdir training +cd training +``` + +2. Clone this repo +```sh +git clone git@github.com:ESMValGroup/ClimaNet.git +``` + +3. Install uv for dependency management. Se [uv doc](https://docs.astral.sh/uv/getting-started/installation/). + +4. Create a venv and install Python dependencies using uv +```sh +cd ClimaNet +``` + +``` +uv sync +``` + +A `.venv` dir will appear + +5. Copy the python script and slurm script into the working dir: + +```sh +cp ClimaNet/scripts/example* . +``` + +6. Config `example.slurm`, in the `source ...` line, make sure the venv just created is activated. + Note that the account is the ESO4CLIMA project account, which is shared by multiple users. + +7. Config `example.py`, make sure the path of input data and land mask data is correct. + +8. Execute the SLURM job +```sh +sbatch example.slurm +``` \ No newline at end of file diff --git a/scripts/example.slurm b/scripts/example.slurm new file mode 100644 index 0000000..a41b132 --- /dev/null +++ b/scripts/example.slurm @@ -0,0 +1,12 @@ +#!/bin/bash +#SBATCH --job-name=eso4clima +#SBATCH --partition=compute +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=128 +#SBATCH --time=01:00:00 +#SBATCH --account=bd0854 + +source /home/b/b383704/eso4clima/ClimaNet/.venv/bin/activate + +python /home/b/b383704/eso4clima/train_twoyears/example.py + diff --git a/scripts/example_inference.py b/scripts/example_inference.py new file mode 100644 index 0000000..4c82026 --- /dev/null +++ b/scripts/example_inference.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python3 +"""Example inference script""" + +from pathlib import Path +import torch +import torch.nn.functional +import xarray as xr +from torch.utils.data import DataLoader + +from climanet import STDataset +from climanet.st_encoder_decoder import SpatioTemporalModel +from climanet.utils import pred_to_numpy + +import logging + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + +def main(): + # Data files + data_folder = Path( + "/work/bd0854/b380103/eso4clima/output/v1.0/concatenated/" + ) # HPC + # data_folder = Path("../../data/output/") # local + daily_files = sorted(data_folder.rglob("20*_day_ERA5_masked_ts.nc")) + monthly_files = sorted(data_folder.rglob("20*_mon_ERA5_full_ts.nc")) + daily_files.sort() + monthly_files.sort() + + # Land surface + lsm_file = "/home/b/b383704/eso4clima/train_twoyears/era5_lsm_bool.nc" # HPC + # lsm_file = data_folder / "era5_lsm_bool.nc" # local + + # Path to the trained model + model_save_path = Path("./models/spatio_temporal_model.pth") + + # Load full dataset + daily_files = sorted(data_folder.rglob("20*_day_ERA5_masked_ts.nc")) + monthly_files = sorted(data_folder.rglob("20*_mon_ERA5_full_ts.nc")) + patch_size_training = 80 + daily_data = xr.open_mfdataset(daily_files) + monthly_data = xr.open_mfdataset(monthly_files) + + daily_data = xr.open_mfdataset( + daily_files, + combine="by_coords", + chunks={ + "time": 1, + "lat": patch_size_training * 2, + "lon": patch_size_training * 2, + }, + data_vars="minimal", + coords="minimal", + compat="override", + parallel=False, + ) + + monthly_data = xr.open_mfdataset( + monthly_files, + combine="by_coords", + chunks={ + "time": 1, + "lat": patch_size_training * 2, + "lon": patch_size_training * 2, + }, + data_vars="minimal", + coords="minimal", + compat="override", + parallel=False, + ) + + lsm_mask = xr.open_dataset(lsm_file) + + # Load the trained model + model = SpatioTemporalModel() + model.load_state_dict(torch.load(model_save_path)) + model.eval() + + # Calculate prediction and attach to monthly_data xr.Dataset + dataset_pred = STDataset( + daily_da=daily_data["ts"], + monthly_da=monthly_data["ts"], + land_mask=lsm_mask["lsm"], + patch_size=(daily_data.sizes["lat"], daily_data.sizes["lon"]), + ) + dataloader_pred = DataLoader( + dataset_pred, + batch_size=len(dataset_pred), + pin_memory=False, + ) + full_batch = next(iter(dataloader_pred)) + daily_batch = full_batch["daily_patch"] + daily_mask = full_batch["daily_mask_patch"] + land_mask_patch = full_batch["land_mask_patch"][0, ...] + padded_days_mask = full_batch["padded_days_mask"] + with torch.no_grad(): + pred = model(daily_batch, daily_mask, land_mask_patch, padded_days_mask) + monthly_prediction = pred_to_numpy(pred, land_mask=land_mask_patch)[0] + monthly_data["ts_pred"] = (("time", "lat", "lon"), monthly_prediction) + + # Save the xr.Dataset with predictions + predictions_save_path = Path("./predicted_data/predictions.nc") + predictions_save_path.parent.mkdir(parents=True, exist_ok=True) + monthly_data.to_netcdf(predictions_save_path) + logger.info(f"Saved predictions to: {predictions_save_path}") + + +if __name__ == "__main__": + main() diff --git a/scripts/example_training.py b/scripts/example_training.py new file mode 100644 index 0000000..c823108 --- /dev/null +++ b/scripts/example_training.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 +"""Example training script""" + +from pathlib import Path +import torch +import torch.nn.functional +import xarray as xr +from torch.utils.data import DataLoader + +from climanet import STDataset +from climanet.st_encoder_decoder import SpatioTemporalModel + +import logging + +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + +def main(): + # Data files + data_folder = Path( + "/work/bd0854/b380103/eso4clima/output/v1.0/concatenated/" + ) # HPC + # data_folder = Path("../../data/output/") # local + daily_files = sorted(data_folder.rglob("20*_day_ERA5_masked_ts.nc")) + monthly_files = sorted(data_folder.rglob("20*_mon_ERA5_full_ts.nc")) + daily_files.sort() + monthly_files.sort() + + # Land surface + lsm_file = "/home/b/b383704/eso4clima/train_twoyears/era5_lsm_bool.nc" # HPC + # lsm_file = data_folder / "era5_lsm_bool.nc" # local + + # Load full dataset + daily_files = sorted(data_folder.rglob("20*_day_ERA5_masked_ts.nc")) + monthly_files = sorted(data_folder.rglob("20*_mon_ERA5_full_ts.nc")) + patch_size_training = 80 + daily_data = xr.open_mfdataset(daily_files) + monthly_data = xr.open_mfdataset(monthly_files) + + daily_data = xr.open_mfdataset( + daily_files, + combine="by_coords", + chunks={ + "time": 1, + "lat": patch_size_training * 2, + "lon": patch_size_training * 2, + }, + data_vars="minimal", + coords="minimal", + compat="override", + parallel=False, + ) + + monthly_data = xr.open_mfdataset( + monthly_files, + combine="by_coords", + chunks={ + "time": 1, + "lat": patch_size_training * 2, + "lon": patch_size_training * 2, + }, + data_vars="minimal", + coords="minimal", + compat="override", + parallel=False, + ) + + lsm_mask = xr.open_dataset(lsm_file) + + # Compute monthly climatology stats without persisting the full (time, lat, lon) monthly field + monthly_ts = daily_data["ts"].resample(time="MS").mean(skipna=True) + mean = monthly_ts.mean(dim=["lat", "lon"], skipna=True).compute().values + std = monthly_ts.std(dim=["lat", "lon"], skipna=True).compute().values + logger.info(f"mean: {mean}, std: {std}") + + # Make a dataset + dataset = STDataset( + daily_da=daily_data["ts"], + monthly_da=monthly_data["ts"], + land_mask=lsm_mask["lsm"], + patch_size=(patch_size_training, patch_size_training), + ) + + # Initialize training + device = "cuda" if torch.cuda.is_available() else "cpu" + model = SpatioTemporalModel( + embed_dim=128, + patch_size=(1, 2, 2), + overlap=2, + max_months=monthly_data.sizes["time"], + ).to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + decoder = model.decoder + with torch.no_grad(): + decoder.bias.copy_(torch.from_numpy(mean)) + decoder.scale.copy_(torch.from_numpy(std) + 1e-6) + + # Make a dataloader + dataloader = DataLoader( + dataset, + batch_size=1, + shuffle=True, + pin_memory=False, + ) + + # Training process + best_loss = float("inf") + patience = 10 + counter = 0 + model.train() + for epoch in range(101): + for batch in dataloader: + optimizer.zero_grad() + + daily_batch = batch["daily_patch"] + daily_mask = batch["daily_mask_patch"] + monthly_target = batch["monthly_patch"] + land_mask_patch = batch["land_mask_patch"][0, ...] + padded_days_mask = batch["padded_days_mask"] + + pred = model(daily_batch, daily_mask, land_mask_patch, padded_days_mask) + + ocean = (~land_mask_patch).to(pred.device) + ocean = ocean[None, None, :, :] + + loss = ( + torch.nn.functional.l1_loss(pred, monthly_target, reduction="none") + * ocean + ) + loss_per_month = loss.sum(dim=(-2, -1)) / ocean.sum(dim=(-2, -1)) + loss = loss_per_month.mean() + + loss.backward() + optimizer.step() + + if loss.item() < best_loss: + best_loss = loss.item() + counter = 0 + + if epoch % 20 == 0: + logger.info(f"The loss is {best_loss} at epoch {epoch}") + else: + counter += 1 + if counter >= patience: + logger.info( + f"No improvement for {patience} epochs, stopping early at epoch {epoch}." + ) + break + + logger.info("training done!") + logger.info(f"Final loss: {loss.item()}") + + # Save the trained model with config + checkpoint = { + "config": model.config, + "model_state_dict": model.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "epoch": epoch, + "loss": loss.item(), + } + model_save_path = Path("./models/spatio_temporal_model.pth") + model_save_path.parent.mkdir(parents=True, exist_ok=True) + torch.save(checkpoint, model_save_path) + logger.info(f"Checkpoint saved to {model_save_path}") + + +if __name__ == "__main__": + main()