5 changes: 4 additions & 1 deletion climanet/st_encoder_decoder.py
@@ -528,6 +528,10 @@ def __init__(
            spatial_heads: Number of attention heads in the spatial Transformer
        """
        super().__init__()

        # Store arguments to be used later for model saving/loading
        self.config = {k: v for k, v in locals().items() if k not in ('self', '__class__')}

        self.encoder = VideoEncoder(
            in_chans=in_chans, embed_dim=embed_dim, patch_size=patch_size
        )
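Storing the constructor arguments in `self.config` makes the checkpoint self-describing: the model can be rebuilt from the file alone. A minimal reload sketch, assuming the checkpoint layout written by `train_monthly_model` below ("model_state_dict" and "model_config" keys); the `STEncoderDecoder` class name is a placeholder for the actual model class in this file:

    import torch

    ckpt = torch.load("best_model.pth", map_location="cpu")
    model = STEncoderDecoder(**ckpt["model_config"])  # rebuild with the saved hyperparameters
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()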
@@ -568,7 +572,6 @@ def forward(self, daily_data, daily_mask, land_mask_patch, padded_days_mask=None
        Tp = T // self.patch_size[0]
        Hp = H // self.patch_size[1]
        Wp = W // self.patch_size[2]
        Np = Tp * Hp * Wp

        # check shape and patch compatibility
        assert daily_mask.shape == daily_data.shape, (
149 changes: 149 additions & 0 deletions climanet/utils.py
@@ -1,8 +1,12 @@
from pathlib import Path
from typing import Tuple

import numpy as np
from torch.utils.data import Dataset
import xarray as xr
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

def regrid_to_boundary_centered_grid(
    da: xr.DataArray,
@@ -150,3 +154,148 @@ def pred_to_numpy(pred, orig_H=None, orig_W=None, land_mask=None):
        pred = torch.where(land_mask, torch.full_like(pred, float("nan")), pred)

    return pred.detach().cpu().numpy()


def calc_stats(data: xr.DataArray, time_unit="MS", spatial_dims=("lat", "lon")):
    """Resample `data` to `time_unit` means, then return the spatial mean and
    std (one value per resampled time step) as numpy arrays."""
    averaged = data.resample(time=time_unit).mean(skipna=True)
    mean = averaged.mean(dim=spatial_dims, skipna=True).values
    std = averaged.std(dim=spatial_dims, skipna=True).values
    return mean, std
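# Example usage (illustrative only; the file path and variable name are hypothetical):
#
#   sst = xr.open_dataset("daily_sst.nc")["sst"]     # dims: (time, lat, lon)
#   decoder_stats = calc_stats(sst, time_unit="MS")  # one (mean, std) pair per resampled step
#
# The returned tuple can be passed as `decoder_stats` to train_monthly_model below.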


def train_monthly_model(
    model: torch.nn.Module,
    dataset: Dataset,
    decoder_stats: Tuple[np.ndarray, np.ndarray],
    batch_size=2,
    num_epoch=100,
    patience=10,
    accumulation_steps=1,
    optimizer_lr=1e-3,
    log_dir=".",
    save_model=True,
    device="cpu",
    verbose=True,
):
    """Train the model to predict monthly data from daily data.

    Args:
        model: the PyTorch model to train
        dataset: Dataset object containing the training data
        decoder_stats: tuple of (mean, std) used to initialize the decoder bias and scale
        batch_size: number of samples per batch
        num_epoch: number of epochs to train
        patience: number of epochs to wait for improvement before early stopping
        accumulation_steps: number of batches to accumulate gradients over before updating weights
        optimizer_lr: learning rate for the optimizer
        log_dir: directory for TensorBoard logs and the saved model
        save_model: whether to save the best model to disk
        device: device to run training on ("cpu" or "cuda")
        verbose: whether to print training progress
    """

    # Initialize the model
    model = model.to(device)
    mean, std = decoder_stats
    decoder = model.decoder
    with torch.no_grad():
        decoder.bias.copy_(torch.from_numpy(mean))
        decoder.scale.copy_(torch.from_numpy(std) + 1e-6)  # small epsilon to avoid zero

    # Create data loader
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        pin_memory=False,
    )

    # Initialize TensorBoard writer
    Path(log_dir).mkdir(parents=True, exist_ok=True)
    writer = SummaryWriter(log_dir)

    # Set the optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=optimizer_lr)
    best_loss = float("inf")
    counter = 0

    model.train()
    for epoch in range(num_epoch):
        epoch_loss = 0.0

        optimizer.zero_grad()

        for i, batch in enumerate(dataloader):
            # Get batch data and move it to the training device
            daily_batch = batch["daily_patch"].to(device)
            daily_mask = batch["daily_mask_patch"].to(device)
            monthly_target = batch["monthly_patch"].to(device)
            land_mask = batch["land_mask_patch"].to(device)
            padded_days_mask = batch["padded_days_mask"].to(device)

            # Batch prediction
            pred = model(daily_batch, daily_mask, land_mask, padded_days_mask)  # (B, M, H, W)

            # Mask out land pixels: average the L1 loss over ocean pixels only
            ocean = (~land_mask).unsqueeze(1).float()  # (B, 1, H, W)
            loss = torch.nn.functional.l1_loss(pred, monthly_target, reduction="none")
            loss = loss * ocean

            num = loss.sum(dim=(-2, -1))  # (B, M)
            denom = ocean.sum(dim=(-2, -1)).clamp_min(1)  # (B, 1)

            loss_per_month = num / denom
            loss = loss_per_month.mean()

            # Scale loss for gradient accumulation
            scaled_loss = loss / accumulation_steps
            scaled_loss.backward()

            # Track unscaled loss for logging
            epoch_loss += loss.item()

            # Update weights every accumulation_steps batches
            if (i + 1) % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

        # Handle remaining gradients if num_batches is not divisible by accumulation_steps
        if (i + 1) % accumulation_steps != 0:
            optimizer.step()
            optimizer.zero_grad()

        # Calculate average epoch loss
        avg_epoch_loss = epoch_loss / len(dataloader)

        # Log to TensorBoard
        writer.add_scalar('Loss/train', avg_epoch_loss, epoch)
        writer.add_scalar('Loss/best', best_loss, epoch)

        # Early stopping check
        if avg_epoch_loss < best_loss:
            best_loss = avg_epoch_loss
            counter = 0
            # Snapshot the weights so the saved "best_model.pth" really is the best epoch
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            counter += 1

        if verbose and epoch % 20 == 0:
            print(f"Epoch {epoch}: best_loss = {best_loss:.6f}")

        if counter >= patience:
            writer.add_text('Training', f'Early stop at epoch {epoch}', epoch)
            break

    # Close the writer when done
    writer.close()

    if verbose:
        print(f"Training complete. Best loss: {best_loss:.6f}")

    if save_model:
        model_path = Path(log_dir) / "best_model.pth"
        torch.save(
            {"model_state_dict": best_state, "model_config": model.config},
            model_path,
        )
        if verbose:
            print(f"Model saved to {model_path}")

    return model
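Taken together, a typical call sequence might look like the sketch below. Only `calc_stats` and `train_monthly_model` come from this file; the dataset and model constructors (`MonthlyFromDailyDataset`, `STEncoderDecoder`) are placeholders for the project's actual classes:

    sst = xr.open_dataset("daily_sst.nc")["sst"]
    decoder_stats = calc_stats(sst, time_unit="MS")

    dataset = MonthlyFromDailyDataset(...)  # placeholder: must yield the batch keys used above
    model = STEncoderDecoder(...)           # placeholder: model with .decoder and .config

    model = train_monthly_model(
        model,
        dataset,
        decoder_stats,
        batch_size=2,
        accumulation_steps=4,  # effective batch size of 8
        log_dir="runs/monthly",
        device="cuda" if torch.cuda.is_available() else "cpu",
    )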