-
Notifications
You must be signed in to change notification settings - Fork 0
Fix blocky effect #29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
4876cb2
4ca9a34
688c30b
ee2714f
7f8780b
3ca2468
768b678
d152312
b86c676
6ce4b18
8e774c9
3f29f92
7472b17
939d7a2
4bf024a
99e1f1e
3c5eebd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -185,7 +185,7 @@ def __init__(self, embed_dim=128, max_days=31, max_months=12): | |
| def forward(self, x, M, T, H, W, padded_days_mask=None): | ||
| """ | ||
| Args: | ||
| x: (B, M*T*H*W, C) containing spatio-temporal tokens, where C is the embedding dimension. | ||
| x: (B, M, T, H, W, C) containing spatio-temporal tokens, where C is the embedding dimension. | ||
| M: number of months | ||
| T: number of temporal tokens per month after temporal patching (Tp) | ||
| H: spatial height after spatial patching | ||
|
|
@@ -194,9 +194,12 @@ def forward(self, x, M, T, H, W, padded_days_mask=None): | |
| True indicating which day tokens are padded (because some months | ||
| have fewer days). This is used to mask out padded tokens in attention computation. | ||
| Returns: | ||
| Tensor of shape (B, M*H*W, C) with one temporally aggregated, where C is the embedding dimension. | ||
| Tensor of shape (B, M, H*W, C) with one temporally aggregated, where C is the embedding dimension. | ||
| """ | ||
| seq = rearrange(x, "b (m t h w) c -> b (h w) m t c", m=M, t=T, h=H, w=W) | ||
| B, M, Tp, Hp, Wp, C = x.shape | ||
|
|
||
| # Reshape to (B, Hp*Wp, M, Tp, C) for temporal processing | ||
| seq = x.permute(0, 3, 4, 1, 2, 5).reshape(B, Hp * Wp, M, Tp, C) | ||
|
|
||
| pe_days = self.pos_days(T).to(seq.device).to(seq.dtype) # (T, C) | ||
| pe_months = self.pos_months(M).to(seq.device).to(seq.dtype) # (M, C) | ||
|
|
@@ -209,10 +212,10 @@ def forward(self, x, M, T, H, W, padded_days_mask=None): | |
|
|
||
| # padded_days_mask is (B, M, T) true=padded, -> (B, HW, M, T) | ||
| if padded_days_mask is not None: | ||
| pad = padded_days_mask[:, None, :, :].expand(x.shape[0], H * W, M, T) | ||
| pad = padded_days_mask[:, None, :, :].expand(B, H * W, M, T) | ||
| day_logits = day_logits.masked_fill(pad, float("-inf")) | ||
|
|
||
| day_w = torch.softmax(day_logits, dim=-1) | ||
| day_w = torch.softmax(day_logits, dim=-1) # turns inf to 0 | ||
| month_tokens = (seq * day_w.unsqueeze(-1)).sum(dim=3) # (B, HW, M, C) | ||
|
|
||
| # Cross-month attention at each spatial location | ||
|
|
@@ -222,10 +225,10 @@ def forward(self, x, M, T, H, W, padded_days_mask=None): | |
| z = z + attn_out | ||
| z = z + self.month_ffn(z) | ||
|
|
||
| # Back to flattened tokens with month kept | ||
| z = rearrange(z, "(b s) m c -> b s m c", b=x.shape[0], s=H * W) | ||
| out = rearrange(z, "b (h w) m c -> b (m h w) c", h=H, w=W) | ||
| return out # (B, M*H*W, C) C: embedding dimension | ||
| # Back to (B, M, Hp*Wp, C) | ||
| z = z.view(B, Hp * Wp, M, C) | ||
| out = z.permute(0, 2, 1, 3) # (B, M, Hp*Wp, C) | ||
| return out # (B, M, H*W, C) C: embedding dimension | ||
|
|
||
|
|
||
| class MonthlyConvDecoder(nn.Module): | ||
|
|
@@ -293,10 +296,10 @@ def __init__( | |
| # Refinement block: a small conv layers to smooth patch boundaries | ||
| self.refine = nn.Sequential( | ||
| nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), | ||
| nn.BatchNorm2d(out_channels), | ||
| nn.GroupNorm(num_groups=8, num_channels=out_channels), | ||
| nn.GELU(), | ||
| nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), | ||
| nn.BatchNorm2d(out_channels), | ||
| nn.GroupNorm(num_groups=8, num_channels=out_channels), | ||
| nn.GELU(), | ||
| ) | ||
|
|
||
|
|
@@ -314,18 +317,18 @@ def forward(self, latent, M, out_H, out_W, land_mask=None): | |
| M: Number of months (temporal patches) | ||
| out_H: Target output height (must be divisible by patch_h) | ||
| out_W: Target output width (must be divisible by patch_w) | ||
| land_mask: Optional boolean tensor of shape (out_H, out_W). Values set to True | ||
| land_mask: Optional boolean tensor of shape (B, out_H, out_W). Values set to True | ||
| will be masked out (set to 0) in the output (only ocean pixels exist). | ||
| Returns: | ||
| Tensor of shape (B, M, out_H, out_W) representing the monthly variable e.g. SST. | ||
| """ | ||
| B, MHW, C = latent.shape | ||
| B, M, Np, C = latent.shape | ||
| Hp = out_H // self.patch_h | ||
| Wp = out_W // self.patch_w | ||
| assert MHW == M * Hp * Wp, f"Token mismatch: got {MHW}, expected {M * Hp * Wp}" | ||
| assert Np == Hp * Wp, f"Token mismatch: got {Np}, expected {Hp * Wp}" | ||
|
|
||
| # transforms the latent tensor from sequence format to image format for | ||
| # convolution operations; (B, M*Hp*Wp, C) -> (B*M, C, Hp, Wp) | ||
| # convolution operations; | ||
| out = latent.view(B, M, Hp, Wp, C).permute(0, 1, 4, 2, 3).contiguous() | ||
| out = out.view(B * M, C, Hp, Wp) | ||
|
|
||
|
|
@@ -349,7 +352,7 @@ def forward(self, latent, M, out_H, out_W, land_mask=None): | |
|
|
||
| # Mask out land areas if land_mask is provided | ||
| if land_mask is not None: | ||
| out = out.masked_fill(land_mask.bool()[None, None, :, :], 0.0) | ||
| out = out.masked_fill(land_mask.bool()[:, None, :, :], 0.0) | ||
| return out # (B, M, out_H, out_W) | ||
|
|
||
|
|
||
|
|
@@ -500,10 +503,11 @@ def __init__( | |
| patch_size=(1, 4, 4), | ||
| max_days=31, | ||
| max_months=12, | ||
| hidden=128, | ||
| num_months=12, | ||
| hidden=256, | ||
| overlap=1, | ||
| max_H=1024, | ||
| max_W=1024, | ||
| max_H=256, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Moving to lat,lon based encoding this would be less relevant, but how does this match to the total grid size given the global input data?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here the
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok, thanks. So max_H, max_W are basically the total input dimensions quantised by the patch size. That is fine, of course, and I also understand that we aren't passing the full global image in one go. My point was that, as it stands, the 2d positional encodings are pre-calculated on the basis of these sizes, i.e. the sine and cosine values derive from the max_H and max_W values. Should one change input dimensions/resolution, this may be a challenge to generalisation. Other approaches also have their own difficulties; this is just something to remain aware of. |
||
| max_W=256, | ||
| spatial_depth=2, | ||
| spatial_heads=4, | ||
| ): | ||
|
|
@@ -515,6 +519,7 @@ def __init__( | |
| patch_size: Tuple of (T, H, W) patch sizes for temporal and spatial patching | ||
| max_days: Maximum number of days for temporal positional encoding | ||
| max_months: Maximum number of months for temporal positional encoding | ||
| num_months: Number of months to predict (output channels in decoder) | ||
| hidden: Hidden dimension used in the decoder | ||
| overlap: Overlap for deconvolution in the decoder | ||
| max_H: Maximum spatial height for 2D positional encoding | ||
|
|
@@ -541,7 +546,7 @@ def __init__( | |
| patch_w=patch_size[2], | ||
| hidden=hidden, | ||
| overlap=overlap, | ||
| num_months=max_months, | ||
| num_months=num_months, | ||
| ) | ||
| self.patch_size = patch_size | ||
|
|
||
|
|
@@ -552,7 +557,7 @@ def forward(self, daily_data, daily_mask, land_mask_patch, padded_days_mask=None | |
| daily_data: Tensor of shape (B, C, M, T, H, W) containing daily | ||
| data, where C is the number of channels (e.g., 1 for SST) | ||
| daily_mask: Boolean tensor of same shape as daily_data indicating missing values | ||
| land_mask_patch: Boolean tensor of shape (H, W) to mask land areas in the output | ||
| land_mask_patch: Boolean tensor of shape (B, H, W) to mask land areas in the output | ||
| padded_days_mask: Optional boolean tensor of shape (B, M, T) indicating which day tokens are padded | ||
| (True for padded tokens). Used to mask out padded tokens in temporal attention. | ||
| Returns: | ||
|
|
@@ -574,18 +579,6 @@ def forward(self, daily_data, daily_mask, land_mask_patch, padded_days_mask=None | |
| ) | ||
| assert T % self.patch_size[0] == 0, "T must be divisible by patch size" | ||
|
|
||
| # Step 1: Encode spatio-temporal patches | ||
| # each month independently by folding M into batch | ||
| daily_data_reshaped = daily_data.reshape(B * M, C, T, H, W) | ||
| daily_mask_reshaped = daily_mask.reshape(B * M, C, T, H, W) | ||
| latent = self.encoder( | ||
| daily_data_reshaped, daily_mask_reshaped | ||
| ) # (B*M, N_patches, embed_dim) | ||
|
|
||
| # Step 2: Aggregate temporal information for each spatial patch | ||
| # latent -> (B, M*Np, embed_dim) to match the aggregator input x: (B, M*Tp*Hp*Wp, embed_dim) | ||
| latent = latent.reshape(B, M * Np, -1) | ||
|
|
||
| if padded_days_mask is not None and self.patch_size[0] > 1: | ||
| B, M, T_days = padded_days_mask.shape | ||
| if T_days % self.patch_size[0] != 0: | ||
|
|
@@ -596,23 +589,45 @@ def forward(self, daily_data, daily_mask, land_mask_patch, padded_days_mask=None | |
| B, M, T_days // self.patch_size[0], self.patch_size[0] | ||
| ).all(dim=-1) # (B, M, Tp) | ||
|
|
||
| # Step 1: Encode spatio-temporal patches | ||
| # each month independently by folding M into batch | ||
| # encoder input shape = (B, C, T, H, W) where C is channel. | ||
| # encoder output shape = (B, N_patches, embed_dim) | ||
| # so M is folded into B, and T, H, W are the spatio-temporal dimensions to be patched. | ||
| daily_data_reshaped = daily_data.reshape(B * M, C, T, H, W) | ||
| daily_mask_reshaped = daily_mask.reshape(B * M, C, T, H, W) | ||
|
|
||
| latent = self.encoder( | ||
| daily_data_reshaped, daily_mask_reshaped | ||
| ) # (B*M, N_patches, embed_dim) | ||
|
|
||
| # Step 2: Aggregate temporal information for each spatial patch | ||
| # temporal input shape = (B, M*T*H*W, C), C: embedding dimension | ||
| # temporal output shape = (B, M, H*W, C) C: embedding dimension | ||
| embed_dim = latent.shape[-1] | ||
| latent = latent.view(B, M, Tp, Hp, Wp, embed_dim) | ||
|
|
||
| agg_latent = self.temporal( | ||
| latent, M, Tp, Hp, Wp, padded_days_mask=padded_days_mask | ||
| ) # (B, M*Hp*Wp, embed_dim) | ||
| ) # (B, M, Hp*Wp, embed_dim) | ||
|
|
||
| # Step 3: Add spatial positional encodings and mix spatial features | ||
| E = agg_latent.shape[-1] | ||
| agg_latent = agg_latent.view(B, M, Hp * Wp, E) | ||
| # spatial PE output shape = (Hp, Wp, embed_dim) | ||
| pe = ( | ||
| self.spatial_pe(Hp, Wp).to(agg_latent.device).to(agg_latent.dtype) | ||
| ) # (Hp*Wp, E) | ||
| x = agg_latent + pe[None, None, :, :] | ||
| ) # (Hp, Wp, E) | ||
| x = agg_latent + pe[None, None, :, :] # (B, M, Hp*Wp, E) | ||
|
|
||
| # Step 4: Spatial mixing with Transformer | ||
| x = x.view(B * M, Hp * Wp, E) | ||
| x = self.spatial_tr(x) # (B*M, Hp*Wp, E) | ||
| x = x.view(B, M * Hp * Wp, E) # back to (B, M*Hp*Wp, E) | ||
| # spatial transformer input shape = (B, N, C), output shape = (B, N, C) C: embedding dimension | ||
| # M is folded in B. | ||
| C = x.shape[-1] | ||
| x = x.reshape(B * M, Hp * Wp, C) | ||
| x = self.spatial_tr(x) | ||
| x = x.view(B, M, Hp * Wp, C) | ||
|
|
||
| # Step 5: Decode to full-resolution 2D map | ||
| # decoder input shape is (B, M*Hp*Wp, C), C: embedding dimension | ||
| # decoder output shape is (B, M, H, W) | ||
| monthly_pred = self.decoder(x, M, H, W, land_mask_patch) # (B, M, H, W) | ||
| return monthly_pred | ||
Uh oh!
There was an error while loading. Please reload this page.