[Common, PyTorch] Improve mHC to match DeepSeek's implementation #2978
kainzhong wants to merge 19 commits into
Conversation
Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary
This PR improves the mHC (manifold Hyper-Connection) Triton kernels to align more closely with DeepSeek's TileKernels reference implementation, adding several new features and fixing correctness issues.
Confidence Score: 5/5
The change is safe to merge. All three previously-reported backward return-count mismatches are now correctly fixed, the grid-dimension pruners are properly structured with env-var overrides inside the pruner callbacks, and the fused gradient accumulation buffer flows correctly through the reverse-pass order. All newly-added autograd Function backward() returns match their forward input counts, the TMA path is conditionally guarded, the norm_weight gradient math is consistent between forward and backward, and the buffer-sharing scheme initializes before accumulating in the correct backward order. The remaining comments are non-blocking style and test-coverage observations. No files require blocking attention.
Important Files Changed
Reviews (4): Last reviewed commit: "fix"
Description
Some enhancements to mHC to better align with DeepSeek's tilelang implementation: https://github.com/deepseek-ai/TileKernels/tree/main/tile_kernels/mhc
Fixes # (issue)
Type of change
Changes
- Added a `mhc_generate_mix_and_aggregate` API that does projection, scale, sinkhorn and aggregate together
- `mhc_fused_projection` now accepts arguments with mixed dtypes (`x.dtype=bf16`, `phi.dtype=fp32`), which matches DeepSeek's implementation
- `mhc_fused_projection` now outputs fp32 regardless of the input dtype, matching DeepSeek's implementation
- Added a `fuse_grad_x_acc` optimization (defaults to `False`), which reuses the same `grad_x` buffer to accumulate the gradient of the initial mHC input `x` across `mhc_fused_expand_combine`, `mhc_fused_aggregate` and `mhc_fused_projection`
- Added `norm_weight` for `mhc_fused_projection`, which is equivalent to applying RMSNorm in the unfused manner with `elementwise_affine=True`, i.e. with learnable per-element affine parameters for RMSNorm

Checklist:
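For reference, the `norm_weight` behavior described above (unfused RMSNorm with `elementwise_affine=True`) can be sketched in plain PyTorch. This is an illustrative sketch only, not the Transformer Engine API: `rms_norm_ref` and its signature are hypothetical names, and the fp32 upcast mirrors the fp32-output behavior noted for `mhc_fused_projection`.

```python
import torch

def rms_norm_ref(x: torch.Tensor, norm_weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Unfused RMSNorm with a learnable per-element affine weight.

    Hypothetical reference sketch: compute in fp32 regardless of the input
    dtype, mirroring the fp32-output behavior described above.
    """
    x32 = x.float()
    inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x32 * inv_rms * norm_weight.float()

# The mixed-dtype case from the list above: bf16 activations, fp32 weight.
x = torch.randn(4, 128, dtype=torch.bfloat16)
w = torch.ones(128, dtype=torch.float32)
y = rms_norm_ref(x, w)
assert y.dtype == torch.float32
```

With `norm_weight` set to ones this reduces to plain RMSNorm; a learnable `norm_weight` adds the per-element affine scale.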
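The `fuse_grad_x_acc` item above describes reusing one `grad_x` buffer instead of allocating a separate gradient tensor per backward kernel and summing them afterwards. A conceptual sketch of that accumulation pattern, with hypothetical function names (the actual implementation does this inside the Triton backward kernels):

```python
import torch

def grad_x_unfused(contribs):
    # Baseline: each backward kernel produces its own gradient tensor,
    # and the contributions are summed afterwards (extra allocations).
    return torch.stack(contribs).sum(dim=0)

def grad_x_fused(contribs):
    # fuse_grad_x_acc idea: the first kernel initializes the shared
    # grad_x buffer, and later kernels accumulate into it in place.
    grad_x = contribs[0].clone()
    for g in contribs[1:]:
        grad_x.add_(g)  # in-place accumulation, no per-kernel buffers
    return grad_x

torch.manual_seed(0)
# Stand-ins for the three kernels' gradient contributions to x.
contribs = [torch.randn(8, 16) for _ in range(3)]
assert torch.allclose(grad_x_unfused(contribs), grad_x_fused(contribs))
```

The fused order matters: the buffer must be initialized by the first backward contribution before later kernels accumulate, which is why the option interacts with the reverse-pass ordering mentioned in the review summary.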