[Common, PyTorch] Improve mHC to match DeepSeek's implementation#2978

Open
kainzhong wants to merge 19 commits into NVIDIA:main from kainzhong:feat/mhc_optimization1

Conversation

@kainzhong
Collaborator

Description

Some enhancements to mHC to better align with DeepSeek's TileLang implementation: https://github.com/deepseek-ai/TileKernels/tree/main/tile_kernels/mhc

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add a mhc_generate_mix_and_aggregate API that performs projection, scaling, Sinkhorn normalization, and aggregation in a single call (usage sketch after this list)
  • Allow mhc_fused_projection to accept arguments with mixed dtypes (x.dtype=bf16, phi.dtype=fp32), which matches DeepSeek's implementation
  • mhc_fused_projection now outputs fp32 regardless of the input dtype, matching DeepSeek's implementation
  • Add a fuse_grad_x_acc optimization (defaults to False) that reuses the same grad_x buffer to accumulate the gradient of the initial mHC input x across mhc_fused_expand_combine, mhc_fused_aggregate, and mhc_fused_projection
  • Support norm_weight for mhc_fused_projection; this is equivalent to applying RMSNorm in the unfused path with elementwise_affine=True, where norm_weight holds the learnable per-element affine parameters (reference sketch after this list)
  • Refactor some kernel code to avoid duplication: a Triton JIT function passed as a constexpr argument can be used like a macro inside if branches, because Triton does not compile a branch it can prove at compile time will never be taken (sketch after this list)
  • Fix a bug where the launch grid exceeds CUDA's grid-dimension limit when M is very large and the autotune candidate uses BLOCK_SIZE_M=1; such invalid configs are now pruned (sketch after this list)
  • Improve the projection op by using TMA on Hopper and newer GPUs
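
A minimal usage sketch of the new convenience API under the mixed-dtype recipe described in this PR (x in BF16; phi/alpha/beta in FP32). The module path, argument order, and tensor shapes below are illustrative assumptions, not the exact signature; the authoritative definition lives in transformer_engine/pytorch/triton/mhc.py.

```python
# Hypothetical usage sketch: module path, argument order, and shapes are
# assumptions, not the exact API (see transformer_engine/pytorch/triton/mhc.py).
import torch

from transformer_engine.pytorch.triton.mhc import mhc_generate_mix_and_aggregate

M, K, N = 8192, 4096, 64  # placeholder problem sizes
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
phi = torch.randn(K, N, dtype=torch.float32, device="cuda", requires_grad=True)
alpha = torch.randn(N, dtype=torch.float32, device="cuda", requires_grad=True)
beta = torch.randn(N, dtype=torch.float32, device="cuda", requires_grad=True)

# One call covering the projection -> scale -> sinkhorn -> aggregate chain;
# the projection output is FP32 regardless of x.dtype.
out = mhc_generate_mix_and_aggregate(x, phi, alpha, beta)
```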
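
To make the norm_weight semantics concrete, here is a small unfused reference of the RMSNorm it is meant to match (the function name and eps value are illustrative; the fused Triton kernel is the actual implementation):

```python
import torch


def rmsnorm_reference(x: torch.Tensor, norm_weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Unfused equivalent of the norm inside mhc_fused_projection when
    norm_weight is passed: RMSNorm with elementwise_affine=True, where
    norm_weight holds the learnable per-element scale (eps is an assumption)."""
    x_fp32 = x.float()
    inv_rms = torch.rsqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x_fp32 * inv_rms * norm_weight.float()
```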
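
The constexpr refactor can be illustrated with a standalone sketch (kernel and helper names are made up; only the mechanism mirrors the PR). A @triton.jit helper passed as a constexpr argument is inlined where it is called, and a branch guarded by a constexpr that is false at compile time is never compiled, so one kernel body can serve several variants without duplicating code:

```python
# Illustrative sketch of the "JIT helper as a macro" pattern; not the PR's kernels.
import triton
import triton.language as tl


@triton.jit
def _scale_epilogue(acc, w_ptr, offs, mask):
    # Shared helper, inlined by Triton at the call site.
    w = tl.load(w_ptr + offs, mask=mask)
    return acc * w


@triton.jit
def _projection_like_kernel(
    x_ptr, w_ptr, out_ptr, n_elements,
    EPILOGUE: tl.constexpr,   # a @triton.jit helper (or None)
    HAS_EPILOGUE: tl.constexpr,
    BLOCK: tl.constexpr,
):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    acc = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
    # HAS_EPILOGUE is known at compile time, so the untaken branch is never
    # compiled; calling EPILOGUE here is safe even when it is None for the
    # specializations that do not use it.
    if HAS_EPILOGUE:
        acc = EPILOGUE(acc, w_ptr, offs, mask)
    tl.store(out_ptr + offs, acc, mask=mask)
```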
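
The grid-limit fix can be sketched as an early_config_prune hook for triton.autotune (argument names and the wiring are illustrative; the actual pruners in this PR also handle env-var overrides inside the callbacks):

```python
# Illustrative autotune pruner; argument names and wiring are assumptions.
import triton

CUDA_MAX_GRID_Y = 65535  # CUDA's limit for the Y grid dimension


def _prune_grid_overflow_configs(configs, named_args, **kwargs):
    # Drop candidates whose BLOCK_SIZE_M would push cdiv(M, BLOCK_SIZE_M)
    # past the CUDA Y-dimension limit, e.g. BLOCK_SIZE_M=1 with a very large M.
    M = named_args["M"]
    return [
        cfg for cfg in configs
        if triton.cdiv(M, cfg.kwargs["BLOCK_SIZE_M"]) <= CUDA_MAX_GRID_Y
    ]


# Hypothetical wiring into an autotuned kernel:
# @triton.autotune(
#     configs=_candidate_configs,
#     key=["M", "N"],
#     prune_configs_by={"early_config_prune": _prune_grid_overflow_configs},
# )
# @triton.jit
# def _some_mhc_kernel(...):
#     ...
```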

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
@kainzhong force-pushed the feat/mhc_optimization1 branch from af685d7 to eccba0c on May 12, 2026 at 00:57
kainzhong and others added 11 commits May 12, 2026 17:54
@kainzhong marked this pull request as ready for review on May 13, 2026 at 17:26
@kainzhong requested a review from ksivaman as a code owner on May 13, 2026 at 17:26
@greptile-apps
Contributor

greptile-apps Bot commented May 13, 2026

Greptile Summary

This PR improves the mHC (manifold Hyper-Connection) Triton kernels to align more closely with DeepSeek's TileKernels reference implementation, adding several new features and fixing correctness issues.

  • Adds mhc_generate_mix_and_aggregate as a convenience API that wraps projection, scale, sinkhorn, and aggregate into a single call with the DeepSeek V4 recipe (x: BF16, phi/alpha/beta: FP32).
  • Introduces norm_weight support in mhc_fused_projection for learnable RMSNorm affine parameters, a fuse_grad_x_acc optimization that reuses a shared FP32 buffer to accumulate gradients across three backward kernels, TMA-backed loads on Hopper+, and grid-dimension pruning that keeps launches under CUDA's 65535 Y-dimension limit across all affected kernels.

Confidence Score: 5/5

The change is safe to merge. All three previously-reported backward return-count mismatches are now correctly fixed, the grid-dimension pruners are properly structured with env-var overrides inside the pruner callbacks, and the fused gradient accumulation buffer flows correctly through the reverse-pass order.

All newly-added autograd Function backward() returns match their forward input counts, TMA path is conditionally guarded, norm_weight gradient math is consistent between forward and backward, and the buffer-sharing scheme initializes before accumulating in the correct backward order. The remaining comments are non-blocking style and test-coverage observations.

No files require blocking attention. tests/pytorch/test_mhc.py would benefit from a mixed-dtype aggregate test case to cover the BF16-x/FP32-H_pre combination used at runtime.

Important Files Changed

  • transformer_engine/pytorch/triton/mhc.py: New mhc_generate_mix_and_aggregate API; mHCProjectionOp backward updated to 5 return values matching its 5 forward inputs; fused_grad_x_acc_buffer plumbed through all three ops. One dead attribute (ctx.phi_dtype) is set but never consumed in backward.
  • transformer_engine/common/triton/mhc.py: New _mhc_projection_bwd_fused_dphi/_dx kernels with pruners, a TMA path in the forward kernel, norm_weight in the backward kernels, a STEP_SIZE_C loop for the aggregate/expand-combine backward, and grid-dim pruners for all autotuned kernels. stride_norm_weight appears in the forward kernel signature but is never used for memory access.
  • tests/pytorch/test_mhc.py: Good new test coverage for mixed-dtype projection, norm_weight, and fuse_grad_acc; the aggregate test still only uses matched dtypes and does not cover the bf16-x / fp32-H_pre combination that arises in mhc_generate_mix_and_aggregate.

Reviews (4): Last reviewed commit: "fix"

kainzhong and others added 4 commits May 13, 2026 17:45
kainzhong and others added 2 commits May 13, 2026 18:03