Skip to content
16 changes: 16 additions & 0 deletions climanet/st_encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,22 @@ def __init__(
num_months=num_months,
)
self.patch_size = patch_size

# Store config for easy model replication
self.config = {
'in_chans': in_chans,
'embed_dim': embed_dim,
'patch_size': patch_size,
'max_days': max_days,
'max_months': max_months,
'num_months': num_months,
'hidden': hidden,
'overlap': overlap,
'max_H': max_H,
'max_W': max_W,
'spatial_depth': spatial_depth,
'spatial_heads': spatial_heads,
}
Comment on lines +554 to +567
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.config = {
'in_chans': in_chans,
'embed_dim': embed_dim,
'patch_size': patch_size,
'max_days': max_days,
'max_months': max_months,
'num_months': num_months,
'hidden': hidden,
'overlap': overlap,
'max_H': max_H,
'max_W': max_W,
'spatial_depth': spatial_depth,
'spatial_heads': spatial_heads,
}


def forward(self, daily_data, daily_mask, land_mask_patch, padded_days_mask=None):
"""Forward pass of the Spatio-Temporal model.
Expand Down
Loading