feat(aggregation): Add ExcessMTL #747

Merged
ValerianRey merged 8 commits into SimplexLab:main from KhusPatel4450:feat/excess-mtl-weighting
Jun 24, 2026

@KhusPatel4450

feat(aggregation): Add ExcessMTLWeighting

Implements ExcessMTLWeighting from Robust Multi-Task Learning with Excess Risks (He et al., ICML 2024).

At each forward call, per-task excess risks are estimated via a second-order Taylor approximation (Equations 6-7) using an AdaGrad-style diagonal Hessian accumulated across all calls. Task weights are then updated via an exponentiated gradient step (Equation 9) and normalised to the probability simplex.
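The update described above can be illustrated with a minimal, dependency-free sketch. It is not the code from this PR: it assumes the diagonal estimate takes the form `sum_j g_j^2 / sqrt(accumulated_j + eps)`, and the names `excess_mtl_step`, `baseline`, `step_size`, and `eps`, along with their default values, are hypothetical and chosen only for illustration.

```python
import math


def excess_mtl_step(grads, grad_sum, weights, baseline, step_size=0.01, eps=1e-7):
    """One illustrative update step (hypothetical helper, not the PR's API).

    grads and grad_sum are m lists of n floats; weights and baseline have length m.
    grad_sum is updated in place.
    """
    excess = []
    for g_row, s_row in zip(grads, grad_sum):
        # AdaGrad-style accumulation of squared gradients across calls.
        for j, g in enumerate(g_row):
            s_row[j] += g * g
        # Assumed diagonal second-order estimate of the task's excess risk.
        excess.append(sum(g * g / math.sqrt(s + eps) for g, s in zip(g_row, s_row)))

    # Exponentiated gradient step on the baseline-normalised excess risks.
    scaled = [
        w * math.exp(step_size * e / b)
        for w, e, b in zip(weights, excess, baseline)
    ]

    # Renormalise so the weights lie on the probability simplex.
    total = sum(scaled)
    return [s / total for s in scaled], excess


# Two tasks, two parameters, uniform starting weights.
grad_sum = [[0.0, 0.0], [0.0, 0.0]]
new_weights, excess = excess_mtl_step(
    grads=[[1.0, 0.0], [0.0, 2.0]],
    grad_sum=grad_sum,
    weights=[0.5, 0.5],
    baseline=[1.0, 1.0],
)
```

In this sketch the task with the larger normalised excess risk ends up with the larger weight, and the weights always sum to 1.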

Design notes

  • State: two registered buffers, _grad_sum ([m, n], accumulates squared gradients) and _weights ([m], current task weights). Both move with .to(device) and appear in state_dict().
  • Warmup (n_warmup_steps, default 0): during warmup, weights stay uniform and only gradient statistics are collected. On the first post-warmup call, the average excess risk over the warmup period is saved as a normalisation baseline (initial_w), following Appendix C.1. Setting n_warmup_steps=0 matches the behaviour of the official implementation and LibMTL, where the first call's excess risk is used directly as the baseline.
  • Normalisation convention: weights initialised to [1/m, ..., 1/m] and always sum to 1, following the paper (vs. LibMTL's sum-to-m).
  • _n_steps: stored as a registered buffer (scalar torch.long) so warmup progress survives checkpointing. Zeroed in-place in reset() to preserve device placement.
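The warmup rule in the notes above can be sketched as a small state machine. This is only an illustration in plain Python, under the assumption that excess risks are strictly positive: the class `WarmupBaseline` and its method names are hypothetical, and the real implementation keeps its state in registered tensor buffers rather than Python lists.

```python
class WarmupBaseline:
    """Hypothetical helper tracking the normalisation baseline across calls."""

    def __init__(self, n_tasks, n_warmup_steps=0):
        self.n_warmup_steps = n_warmup_steps
        self.n_steps = 0
        self.excess_sum = [0.0] * n_tasks
        self.baseline = None  # plays the role of initial_w

    def normalise(self, excess):
        """Return baseline-normalised excess risks, or None during warmup."""
        if self.n_steps < self.n_warmup_steps:
            # Warmup: collect statistics only; the caller keeps uniform weights.
            self.excess_sum = [s + e for s, e in zip(self.excess_sum, excess)]
            self.n_steps += 1
            return None
        if self.baseline is None:
            if self.n_warmup_steps > 0:
                # First post-warmup call: average excess risk over the warmup period.
                self.baseline = [s / self.n_warmup_steps for s in self.excess_sum]
            else:
                # No warmup: the first call's excess risk is the baseline.
                self.baseline = list(excess)
        self.n_steps += 1
        return [e / b for e, b in zip(excess, self.baseline)]

    def reset(self):
        """Clear all accumulated state so warmup starts again."""
        self.n_steps = 0
        self.excess_sum = [0.0] * len(self.excess_sum)
        self.baseline = None


with_warmup = WarmupBaseline(n_tasks=2, n_warmup_steps=2)
first = with_warmup.normalise([1.0, 2.0])
second = with_warmup.normalise([3.0, 6.0])
third = with_warmup.normalise([4.0, 4.0])

no_warmup = WarmupBaseline(n_tasks=2)
first_direct = no_warmup.normalise([5.0, 10.0])
second_direct = no_warmup.normalise([10.0, 5.0])
```

With n_warmup_steps=0, the first call normalises the excess risk by itself, so every task gets the same value. An exponentiated gradient step followed by renormalisation would then leave uniform weights unchanged on that call.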

@KhusPatel4450 KhusPatel4450 added the "cc: feat" and "package: aggregation" labels Jun 20, 2026
@github-actions github-actions Bot changed the title from "feat(Aggregation): Add ExcessMTLWeighting" to "feat(aggregation): Add ExcessMTLWeighting" Jun 20, 2026
@KhusPatel4450 KhusPatel4450 force-pushed the feat/excess-mtl-weighting branch from 5143ce9 to d4a29f8 June 20, 2026 15:20

@ValerianRey ValerianRey left a comment

Tyvm for the PR! This is only a partial review, but it's gonna be easier to review the rest after those things are addressed I think.

Comment thread src/torchjd/aggregation/_excess_mtl.py Outdated
Comment thread src/torchjd/aggregation/_excess_mtl.py Outdated
Comment thread src/torchjd/aggregation/_excess_mtl.py Outdated
Comment thread src/torchjd/aggregation/_excess_mtl.py Outdated
Comment thread src/torchjd/aggregation/_excess_mtl.py Outdated
Comment thread src/torchjd/aggregation/_excess_mtl.py
Comment thread src/torchjd/aggregation/_excess_mtl.py Outdated
Comment thread src/torchjd/aggregation/_excess_mtl.py Outdated
Comment thread src/torchjd/aggregation/_excess_mtl.py Outdated
Comment thread src/torchjd/aggregation/_excess_mtl.py Outdated
Comment thread src/torchjd/aggregation/_excess_mtl.py Outdated
@ValerianRey

Ty for the updates! This is much cleaner! I'll make the thorough round of review soon

@ValerianRey ValerianRey left a comment

Very good job ty!

It's honestly one of the hardest implementations we've ever made!

Got a few things to fix still, and then we can merge.

Comment thread src/torchjd/aggregation/_excess_mtl.py
Comment thread src/torchjd/aggregation/_excess_mtl.py Outdated
Comment thread src/torchjd/aggregation/_excess_mtl.py Outdated
Comment thread src/torchjd/aggregation/_excess_mtl.py Outdated
Comment thread src/torchjd/aggregation/_excess_mtl.py Outdated
Comment thread src/torchjd/aggregation/_excess_mtl.py Outdated
@ValerianRey ValerianRey changed the title from "feat(aggregation): Add ExcessMTLWeighting" to "feat(aggregation): Add ExcessMTL" Jun 24, 2026
ValerianRey and others added 2 commits June 24, 2026 21:06
…essMTL

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@ValerianRey ValerianRey enabled auto-merge (squash) June 24, 2026 19:20
@ValerianRey ValerianRey merged commit 5a373b4 into SimplexLab:main Jun 24, 2026
16 checks passed
Labels

cc: feat, package: aggregation

2 participants