From eccba0ccb17e0d373a499b9188a724e8223db4e1 Mon Sep 17 00:00:00 2001 From: Kaining Zhong Date: Tue, 12 May 2026 00:33:52 +0000 Subject: [PATCH 01/19] [Common, PyTorch] Improve mHC to match DeepSeek's implementation Signed-off-by: Kaining Zhong --- tests/pytorch/test_mhc.py | 198 +++- transformer_engine/common/triton/mhc.py | 1071 +++++++++++----------- transformer_engine/pytorch/triton/mhc.py | 524 ++++++++--- 3 files changed, 1090 insertions(+), 703 deletions(-) diff --git a/tests/pytorch/test_mhc.py b/tests/pytorch/test_mhc.py index 541ce9a8c2..29087fa4bc 100644 --- a/tests/pytorch/test_mhc.py +++ b/tests/pytorch/test_mhc.py @@ -14,13 +14,14 @@ mhc_fused_aggregate, mhc_fused_expand_combine, mhc_fused_projection, + mhc_generate_mix_and_aggregate, ) # Disable TF32 for matmul to ensure consistency between the fused and reference implementations torch.backends.cuda.matmul.allow_tf32 = False -def mhc_projection_ref(x, phi): +def mhc_projection_ref(x, phi, norm_weight): """ Reference operator for mHC's projection building operation. @@ -29,19 +30,20 @@ def mhc_projection_ref(x, phi): - phi_pre: (n, nC) - phi_post: (n, nC) - phi_res: (n^2, nC) + norm_weight: (nC,) or None, if not None, apply element-wise multiplication to phi before projection n: number of Hyper Connection streams C: hidden dimension per stream """ - x_dtype = x.dtype - x = x.to(torch.float32) - phi = phi.to(torch.float32) - - Hs = x @ phi.T # (M, 2n + n^2) - x_fp32 = x.to(torch.float32) # Use fp32 for better numerical stability in variance calculation + x_fp32 = x.to(torch.float32) ms = (x_fp32 * x_fp32).mean(dim=1) - return Hs.to(x_dtype), ms + phi_fp32 = phi.to(torch.float32) + if norm_weight is not None: + phi_fp32 = phi_fp32 * norm_weight.to(torch.float32)[None, :] + Hs = x_fp32 @ phi_fp32.T # (M, 2n + n^2) + + return Hs, ms def mhc_scale_ref(H, alpha, beta, ms, n): @@ -139,9 +141,9 @@ def mhc_aggregate_ref(x, H_pre, n): s, b, C, n = x.shape H_pre = H_pre.view(s, b, n, 1) - out = (x @ H_pre).view(s, b, C) + out = (x.to(H_pre.dtype) @ H_pre).view(s, b, C) - return out + return out.to(x.dtype) def mhc_expand_combine_ref(f, bias, H_post, x, H_res, n): @@ -267,25 +269,46 @@ def get_tols(dtype): @pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc) -@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"]) -def test_mhc_projection(cfg: MHCConfig, dtype): +@pytest.mark.parametrize( + "dtypes", + [ + (torch.float32, torch.float32), + (torch.bfloat16, torch.bfloat16), + (torch.bfloat16, torch.float32), + ], + ids=["x_fp32_phi_fp32", "x_bf16_phi_bf16", "x_bf16_phi_fp32"], +) +@pytest.mark.parametrize("has_norm_weight", [False, True], ids=["no_norm_weight", "norm_weight"]) +def test_mhc_projection(cfg: MHCConfig, dtypes, has_norm_weight): reset_rng_states() s, b, C, n = cfg.s, cfg.b, cfg.C, cfg.n nC = n * C N = 2 * n + n * n - tols = get_tols(dtype) + x_dtype = dtypes[0] + phi_dtype = dtypes[1] + tols = get_tols(x_dtype) use_tf32 = False - x = torch.randn(s * b, nC, device="cuda", requires_grad=True, dtype=dtype) - phi = torch.randn(N, nC, dtype=dtype, requires_grad=True, device="cuda") - + x = torch.randn(s * b, nC, device="cuda", requires_grad=True, dtype=x_dtype) + phi = torch.randn(N, nC, dtype=phi_dtype, requires_grad=True, device="cuda") x_ref = x.detach().clone().requires_grad_(True) phi_ref = phi.detach().clone().requires_grad_(True) - ref_out_Hs, ref_out_ms = mhc_projection_ref(x_ref, phi_ref) - fused_out_Hs_padded, fused_out_ms = mhc_fused_projection(x, phi, use_tf32) + has_norm_weight = False + + if has_norm_weight: + norm_weight = torch.randn(nC, device="cuda", requires_grad=True, dtype=x_dtype) + norm_weight_ref = norm_weight.detach().clone().requires_grad_(True) + else: + norm_weight = None + norm_weight_ref = None + + ref_out_Hs, ref_out_ms = mhc_projection_ref(x_ref, phi_ref, norm_weight_ref) + fused_out_Hs_padded, fused_out_ms = mhc_fused_projection( + x, phi, norm_weight=norm_weight, use_tf32=use_tf32 + ) fused_out_Hs = fused_out_Hs_padded[:, :N] torch.testing.assert_close(fused_out_Hs, ref_out_Hs, **tols) @@ -295,10 +318,12 @@ def test_mhc_projection(cfg: MHCConfig, dtype): torch.testing.assert_close(x.grad, x_ref.grad, **tols) torch.testing.assert_close(phi.grad, phi_ref.grad, **tols) + if has_norm_weight: + torch.testing.assert_close(norm_weight.grad, norm_weight_ref.grad, **tols) @pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc) -@pytest.mark.parametrize("dtype", [torch.float32], ids=["fp32"]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"]) def test_mhc_scale(cfg: MHCConfig, dtype): reset_rng_states() @@ -329,28 +354,39 @@ def test_mhc_scale(cfg: MHCConfig, dtype): torch.cat([fused_out[i] for i in range(3)], dim=-1).sum().backward() torch.testing.assert_close(H_padded.grad[:, :N], H_ref.grad, **tols) + torch.testing.assert_close(ms.grad, ms_ref.grad, **tols) torch.testing.assert_close(alpha.grad, alpha_ref.grad, **tols) torch.testing.assert_close(beta.grad, beta_ref.grad, **tols) - torch.testing.assert_close(ms.grad, ms_ref.grad, **tols) @pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc) -@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"]) -def test_mhc_combined(cfg: MHCConfig, dtype): +@pytest.mark.parametrize( + "dtypes", + [ + (torch.float32, torch.float32), + (torch.bfloat16, torch.bfloat16), + (torch.bfloat16, torch.float32), + ], + ids=["x_fp32_phi_fp32", "x_bf16_phi_bf16", "x_bf16_phi_fp32"], +) +@pytest.mark.parametrize("has_norm_weight", [False, True], ids=["no_norm_weight", "norm_weight"]) +def test_mhc_rmsnorm(cfg: MHCConfig, dtypes, has_norm_weight): + # Verify if the fused kernel is equivalent to applying RMSNorm in the normal order reset_rng_states() s, b, C, n = cfg.s, cfg.b, cfg.C, cfg.n N = 2 * n + n * n nC = n * C - tols = get_tols(dtype) + x_dtype = dtypes[0] + phi_dtype = dtypes[1] + tols = get_tols(x_dtype) use_tf32 = False - x = torch.randn(s * b, nC, device="cuda", requires_grad=True, dtype=dtype) - phi = torch.randn(N, nC, dtype=dtype, requires_grad=True, device="cuda") - - alpha = torch.randn(3, device="cuda", requires_grad=True, dtype=dtype) - beta = torch.randn(1, 2 * n + n * n, device="cuda", requires_grad=True, dtype=dtype) + x = torch.randn(s * b, nC, device="cuda", requires_grad=True, dtype=x_dtype) + phi = torch.randn(N, nC, dtype=phi_dtype, requires_grad=True, device="cuda") + alpha = torch.randn(3, device="cuda", requires_grad=True, dtype=phi_dtype) + beta = torch.randn(1, 2 * n + n * n, device="cuda", requires_grad=True, dtype=phi_dtype) x_ref = x.detach().clone().requires_grad_(True) phi_ref = phi.detach().clone().requires_grad_(True) @@ -358,8 +394,17 @@ def test_mhc_combined(cfg: MHCConfig, dtype): alpha_ref = alpha.detach().clone().requires_grad_(True) beta_ref = beta.detach().clone().requires_grad_(True) - ref_out_H, ref_out_r = mhc_projection_ref(x_ref, phi_ref) - fused_out_H_padded, fused_out_r = mhc_fused_projection(x, phi, use_tf32) + if has_norm_weight: + norm_weight = torch.randn(nC, device="cuda", requires_grad=True, dtype=x_dtype) + norm_weight_ref = norm_weight.detach().clone().requires_grad_(True) + else: + norm_weight = None + norm_weight_ref = None + + ref_out_H, ref_out_r = mhc_projection_ref(x_ref, phi_ref, norm_weight_ref) + fused_out_H_padded, fused_out_r = mhc_fused_projection( + x, phi, norm_weight=norm_weight, use_tf32=use_tf32 + ) ref_H_pre, ref_H_post, ref_H_res = mhc_scale_ref( ref_out_H[:, :N], alpha_ref, beta_ref, ref_out_r, n @@ -368,17 +413,19 @@ def test_mhc_combined(cfg: MHCConfig, dtype): fused_out_H_padded, alpha, beta, fused_out_r, n ) - def mhc_combined(x_ref, phi_ref, alpha_ref, beta_ref): - dtype = x_ref.dtype - x_ref = x_ref.to(torch.float32) - phi_ref = phi_ref.to(torch.float32) - alpha_ref = alpha_ref.to(torch.float32) - beta_ref = beta_ref.to(torch.float32) - + def mhc_combined(x_ref, phi_ref, alpha_ref, beta_ref, norm_weight_ref): # Check if after spliting RMSNorm to two steps in projection and scaling, - # theresult is close to applying RMSNorm in the correct order - x_rmsnorm = F.rms_norm(x_ref, normalized_shape=(nC,)) - H = x_rmsnorm @ phi_ref.T + # the result is close to applying RMSNorm in the correct order. + # Run RMSNorm in fp32 so the bf16 case has the same precision pattern as the + # kernel/ref (F.rms_norm on bf16 input would round x_rmsnorm back to bf16). + eps = torch.finfo(torch.float32).eps + norm_weight_fp32 = ( + norm_weight_ref.to(torch.float32) if norm_weight_ref is not None else None + ) + x_rmsnorm = F.rms_norm( + x_ref.to(torch.float32), normalized_shape=(nC,), weight=norm_weight_fp32, eps=eps + ) + H = x_rmsnorm @ phi_ref.T.to(torch.float32) H_pre = H[:, :n] H_post = H[:, n : 2 * n] H_res = H[:, 2 * n :] @@ -391,25 +438,82 @@ def mhc_combined(x_ref, phi_ref, alpha_ref, beta_ref): out_post = 2 * out_post.sigmoid() out_res = out_res - return out_pre.to(dtype), out_post.to(dtype), out_res.to(dtype) + return out_pre, out_post, out_res # Return in FP32 to match the kernel's behavior combined_H_pre, combined_H_post, combined_H_res = mhc_combined( - x_ref, phi_ref, alpha_ref, beta_ref + x_ref, phi_ref, alpha_ref, beta_ref, norm_weight_ref ) torch.testing.assert_close(combined_H_pre, ref_H_pre, **tols) torch.testing.assert_close(combined_H_post, ref_H_post, **tols) torch.testing.assert_close(combined_H_res, ref_H_res, **tols) + torch.testing.assert_close(ref_H_pre, fused_H_pre, **tols) + torch.testing.assert_close(ref_H_post, fused_H_post, **tols) + torch.testing.assert_close(ref_H_res, fused_H_res, **tols) + torch.testing.assert_close(combined_H_pre, fused_H_pre, **tols) torch.testing.assert_close(combined_H_post, fused_H_post, **tols) torch.testing.assert_close(combined_H_res, fused_H_res, **tols) +@pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc) +@pytest.mark.parametrize("dtype", [torch.float32], ids=["fp32"]) +def test_mhc_fuse_grad_acc(cfg: MHCConfig, dtype): + # Skip bf16 tests since in the unfused path the we accumulate 3 bf16 gradients, whereas in the fused path + # we accumulate 3 fp32 gradients and then cast to bf16 in the end, which causes two paths to have different precision patterns + + reset_rng_states() + + s, b, C, n = cfg.s, cfg.b, cfg.C, cfg.n + N = 2 * n + n * n + nC = n * C + + tols = get_tols(dtype) + use_tf32 = False + + x = torch.randn(s, b, C, n, device="cuda", requires_grad=True, dtype=dtype) + phi = torch.randn(N, nC, dtype=dtype, requires_grad=True, device="cuda") + + alpha = torch.randn(3, device="cuda", requires_grad=True, dtype=dtype) + beta = torch.randn(1, 2 * n + n * n, device="cuda", requires_grad=True, dtype=dtype) + x_ref = x.detach().clone().requires_grad_(True) + phi_ref = phi.detach().clone().requires_grad_(True) + + alpha_ref = alpha.detach().clone().requires_grad_(True) + beta_ref = beta.detach().clone().requires_grad_(True) + + def end_to_end(x, phi, alpha, beta, fused_grad_x_acc): + aggregated, H_post, H_res = mhc_generate_mix_and_aggregate( + x, phi, alpha, beta, None, use_tf32, fused_grad_x_acc + ) + H_res = mhc_fused_sinkhorn(H_res.view(s, b, n, n), n).view(s * b, n * n) + expanded_combined = mhc_fused_expand_combine( + aggregated, + None, + H_post, + x, + H_res, + n, + False, + fused_grad_x_acc, + ) + + return expanded_combined + + expanded_combined_fuse_grad = end_to_end(x_ref, phi_ref, alpha_ref, beta_ref, False) + expanded_combined_no_fuse_grad = end_to_end(x, phi, alpha, beta, True) + + grad_output = torch.randn_like(expanded_combined_fuse_grad) + expanded_combined_fuse_grad.backward(grad_output) + expanded_combined_no_fuse_grad.backward(grad_output) + + torch.testing.assert_close(x.grad, x_ref.grad, **tols) + + @pytest.mark.parametrize("cfg", mhc_configs, ids=MHCConfig.desc) @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=["fp32", "bf16"]) -@pytest.mark.parametrize("recompute", [False, True], ids=["no_recompute", "recompute"]) -def test_mhc_sinkhorn(cfg: MHCConfig, dtype, recompute): +def test_mhc_sinkhorn(cfg: MHCConfig, dtype): reset_rng_states() s, b, C, n = cfg.s, cfg.b, cfg.C, cfg.n @@ -420,7 +524,7 @@ def test_mhc_sinkhorn(cfg: MHCConfig, dtype, recompute): x_ref = x.detach().clone().requires_grad_(True) ref_out = mhc_sinkhorn_ref(x_ref, n) - fused_out = mhc_fused_sinkhorn(x, n, recompute) + fused_out = mhc_fused_sinkhorn(x, n) torch.testing.assert_close(fused_out, ref_out, **tols) @@ -446,7 +550,7 @@ def test_mhc_aggregate(cfg: MHCConfig, dtype): H_pre_ref = H_pre.detach().clone().requires_grad_(True) ref_out = mhc_aggregate_ref(x_ref, H_pre_ref, n) - fused_out = mhc_fused_aggregate(x, H_pre, n, False) + fused_out = mhc_fused_aggregate(x, H_pre, n, use_tf32=False) torch.testing.assert_close(fused_out, ref_out, **tols) @@ -482,7 +586,7 @@ def test_mhc_expand_combine(cfg: MHCConfig, dtype, with_bias): H_res_ref = H_res.detach().clone().requires_grad_(True) ref_out = mhc_expand_combine_ref(f_ref, bias_ref, H_post_ref, x_ref, H_res_ref, n) - fused_out = mhc_fused_expand_combine(f, bias, H_post, x, H_res, n, False) + fused_out = mhc_fused_expand_combine(f, bias, H_post, x, H_res, n=n, use_tf32=False) torch.testing.assert_close(fused_out, ref_out, **tols) diff --git a/transformer_engine/common/triton/mhc.py b/transformer_engine/common/triton/mhc.py index 965bb437ff..62a9c81fe3 100644 --- a/transformer_engine/common/triton/mhc.py +++ b/transformer_engine/common/triton/mhc.py @@ -12,6 +12,8 @@ import triton import triton.language as tl +MAX_GRID_DIM_Y = 65535 # Maximum grid dimension in Y direction for current CUDA architectures + def projection_config_fwd(): block_m = [64, 128] @@ -34,23 +36,11 @@ def projection_config_fwd(): return configs -def projection_config_bwd(): - block_m = [32, 128] - block_k = [128] - warps = [2] - stages = [2, 3, 4] - - configs = [] - for m, bk, w, s in itertools.product(block_m, block_k, warps, stages): - configs.append( - triton.Config({"BLOCK_SIZE_M": m, "BLOCK_SIZE_K": bk}, num_warps=w, num_stages=s) - ) - if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": - configs = configs[:1] - return configs - - -@triton.autotune(configs=projection_config_fwd(), key=["M", "K"], reset_to_zero=["h_ptr", "ms_ptr"]) +@triton.autotune( + configs=projection_config_fwd(), + key=["M", "K", "USE_TMA"], + reset_to_zero=["h_ptr", "ms_ptr"], +) @triton.jit def _mhc_projection_fwd_fused( x_ptr, # (M, K) @@ -67,12 +57,14 @@ def _mhc_projection_fwd_fused( stride_hm: tl.constexpr, stride_hn: tl.constexpr, stride_ms: tl.constexpr, + stride_norm_weight: tl.constexpr, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, STEP_SIZE_K: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, precision: tl.constexpr, + USE_TMA: tl.constexpr, # If True, load x and phi via TMA tensor descriptors (Hopper+ only). Falls back to pointer-arith tl.load otherwise. ): pid_m = tl.program_id(axis=0) pid_k = tl.program_id(axis=1) @@ -86,8 +78,9 @@ def _mhc_projection_fwd_fused( tl.assume(stride_hm == 32) tl.assume(stride_hn == 1) tl.assume(stride_ms == 1) + tl.assume(stride_norm_weight == 1) - tl.assume(BLOCK_SIZE_M % 32 == 0) + tl.assume(BLOCK_SIZE_M % 8 == 0) tl.assume(BLOCK_SIZE_K % 32 == 0) tl.assume(BLOCK_SIZE_N == 32) @@ -98,24 +91,53 @@ def _mhc_projection_fwd_fused( h_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) ms_acc = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + if USE_TMA: + x_desc = tl.make_tensor_descriptor( + x_ptr, + shape=[M, K], + strides=[stride_xm, 1], + block_shape=[BLOCK_SIZE_M, STEP_SIZE_K], + ) + phi_desc = tl.make_tensor_descriptor( + phi_ptr, + shape=[N, K], + strides=[stride_phin, 1], + block_shape=[BLOCK_SIZE_N, STEP_SIZE_K], + ) + k_base = pid_k * BLOCK_SIZE_K for k_start in range(0, tl.cdiv(BLOCK_SIZE_K, STEP_SIZE_K)): - k_offs = k_base + k_start * STEP_SIZE_K + tl.arange(0, STEP_SIZE_K) + k_off = k_base + k_start * STEP_SIZE_K + k_offs = k_off + tl.arange(0, STEP_SIZE_K) mask_k = k_offs < K - x_ptrs = x_ptr + offs_m[:, None] * stride_xm + k_offs[None, :] * stride_xk - x = tl.load( - x_ptrs, mask=mask_m[:, None] & mask_k[None, :], other=0.0 - ) # (BLOCK_SIZE_M, BLOCK_SIZE_K) - phi_ptrs = phi_ptr + offs_n_full[:, None] * stride_phin + k_offs[None, :] * stride_phik - phi = tl.load( - phi_ptrs, - mask=(offs_n_full[:, None] < N) & mask_k[None, :], - other=0.0, - cache_modifier=".ca", - ) # (BLOCK_SIZE_N, BLOCK_SIZE_K) - ms_acc += tl.sum(x * x, axis=1) + + if USE_TMA: + x = tl.load_tensor_descriptor(x_desc, [pid_m * BLOCK_SIZE_M, k_off]) + phi = tl.load_tensor_descriptor(phi_desc, [0, k_off]) + else: + x_ptrs = x_ptr + offs_m[:, None] * stride_xm + k_offs[None, :] * stride_xk + x = tl.load( + x_ptrs, mask=mask_m[:, None] & mask_k[None, :], other=0.0 + ) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + phi_ptrs = phi_ptr + offs_n_full[:, None] * stride_phin + k_offs[None, :] * stride_phik + phi = tl.load( + phi_ptrs, + mask=(offs_n_full[:, None] < N) & mask_k[None, :], + other=0.0, + cache_modifier=".ca", + ) # (BLOCK_SIZE_N, BLOCK_SIZE_K) + + ms_acc += tl.sum(x.to(tl.float32) * x.to(tl.float32), axis=1) + + # Currently triton has a bug where for small block size, tl.dot(x, phi.T) will use SMEM to transpose the matrix + # instead of emit a ldmatrix instruction with `.trans` modifier, which leads bank conflicts and performance regression + # See https://github.com/triton-lang/triton/issues/6569#issuecomment-2841739082 h_acc = tl.dot( - x, tl.trans(phi, (1, 0)), h_acc, input_precision=precision, out_dtype=tl.float32 + x.to(phi.dtype), + tl.trans(phi, (1, 0)), + h_acc, + input_precision=precision, + out_dtype=tl.float32, ) h_ptrs = h_ptr + offs_m[:, None] * stride_hm + offs_n_full[None, :] * stride_hn @@ -129,15 +151,35 @@ def _mhc_projection_fwd_fused( tl.atomic_add(ms_ptrs, ms, mask=masks_ms, sem="relaxed") +def projection_config_bwd_dx(): + block_m = [32, 128] + block_k = [128] + warps = [2] + stages = [2, 3, 4] + + configs = [] + for m, bk, w, s in itertools.product(block_m, block_k, warps, stages): + configs.append( + triton.Config({"BLOCK_SIZE_M": m, "BLOCK_SIZE_K": bk}, num_warps=w, num_stages=s) + ) + if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": + configs = configs[:1] + return configs + + @triton.autotune( - configs=projection_config_bwd(), + configs=projection_config_bwd_dx(), key=["M", "K"], + # When FUSE_GRAD_X_ACC=True the kernel does a read-modify-write on grad_x_ptr; without + # restore_value the autotune timing trials accumulate onto the buffer and corrupt it. + restore_value=["grad_x_ptr"], ) @triton.jit -def _mhc_projection_bwd_fused( +def _mhc_projection_bwd_fused_dx( x_ptr, grad_x_ptr, # (M, K) phi_ptr, # (N, K) + norm_weight_ptr, # (K,) grad_h_ptr, # (M, N) grad_ms_ptr, # (M,) M, @@ -149,6 +191,7 @@ def _mhc_projection_bwd_fused( stride_grad_xk: tl.constexpr, stride_phin, stride_phik: tl.constexpr, + stride_norm_weight: tl.constexpr, stride_grad_phin, stride_grad_phik: tl.constexpr, stride_grad_hm: tl.constexpr, @@ -159,6 +202,8 @@ def _mhc_projection_bwd_fused( BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, precision: tl.constexpr, + FUSE_GRAD_X_ACC: tl.constexpr, + HAS_NORM_WEIGHT: tl.constexpr, ): pid_m = tl.program_id(axis=0) pid_k = tl.program_id(axis=1) @@ -174,6 +219,7 @@ def _mhc_projection_bwd_fused( tl.assume(stride_grad_phin == K) tl.assume(stride_grad_phik == 1) tl.assume(stride_grad_ms == 1) + tl.assume(stride_norm_weight == 1) tl.assume(BLOCK_SIZE_M % 32 == 0) tl.assume(BLOCK_SIZE_K % 32 == 0) @@ -204,19 +250,163 @@ def _mhc_projection_bwd_fused( phi = tl.load( phi_ptrs, mask=(offs_n_full[:, None] < N) & mask_k[None, :], other=0.0 ) # (BLOCK_SIZE_N, BLOCK_SIZE_K) + + if HAS_NORM_WEIGHT: + norm_weight_ptrs = norm_weight_ptr + offs_k * stride_norm_weight + norm_weight = tl.load(norm_weight_ptrs, mask=mask_k, other=0.0, cache_modifier=".ca").to( + phi.dtype + ) # (BLOCK_SIZE_K,) + phi = phi.to(tl.float32) * norm_weight.to(tl.float32)[None, :] + grad_ms = tl.load( grad_ms_ptrs, mask=offs_ms < M, other=0.0, cache_modifier=".ca" ) # (BLOCK_SIZE_M,) grad_x = x * (grad_ms * 2 / tl.cast(K, tl.float32))[:, None] grad_x = tl.dot( - grad_h, phi, acc=grad_x, input_precision=precision, out_dtype=tl.float32 + grad_h.to(phi.dtype), phi, acc=grad_x, input_precision=precision, out_dtype=tl.float32 ) # (BLOCK_SIZE_M, BLOCK_SIZE_K) grad_x_ptrs = grad_x_ptr + offs_m[:, None] * stride_grad_xm + offs_k[None, :] * stride_grad_xk grad_x = grad_x.to(x.dtype) + if FUSE_GRAD_X_ACC: # If fused gradient accumulation is enabled, the buffer is always fp32 + grad_x_acc = tl.load(grad_x_ptrs, mask=mask_m[:, None] & mask_k[None, :], other=0.0) + grad_x = grad_x.to(tl.float32) + grad_x_acc tl.store(grad_x_ptrs, grad_x, mask=mask_m[:, None] & mask_k[None, :]) +def projection_config_bwd_dphi(): + block_m = [512, 1024, 2048] + step_m = [32] + block_k = [128, 256] + warps = [2] + stages = [2, 3, 4] + + configs = [] + for bm, sm, bk, w, s in itertools.product(block_m, step_m, block_k, warps, stages): + configs.append( + triton.Config( + {"BLOCK_SIZE_M": bm, "STEP_SIZE_M": sm, "BLOCK_SIZE_K": bk}, + num_warps=w, + num_stages=s, + ) + ) + return configs + + +@triton.autotune( + configs=projection_config_bwd_dphi(), + key=["M", "K"], + reset_to_zero=["grad_phi_ptr", "grad_norm_weight_ptr"], +) +@triton.jit +def _mhc_projection_bwd_fused_dphi( + x_ptr, # (M, K) + grad_H_ptr, # (M, 32) + phi_ptr, # (N, K), N=24 in our case since n = 4 + norm_weight_ptr, # (K,) + grad_phi_ptr, # (N, K), N=24 in our case since n = 4 + grad_norm_weight_ptr, # (K,) + M, + N, + K, + stride_xm, + stride_xk: tl.constexpr, + stride_grad_Hm: tl.constexpr, + stride_grad_Hn: tl.constexpr, + stride_phin, + stride_phik: tl.constexpr, + stride_norm_weight: tl.constexpr, + stride_grad_phin, + stride_grad_phik: tl.constexpr, + stride_grad_norm_weight: tl.constexpr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + STEP_SIZE_M: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + precision: tl.constexpr, +): + pid_k = tl.program_id(axis=0) + pid_m = tl.program_id(axis=1) + + tl.assume(pid_k >= 0) + tl.assume(stride_xm > 0) + tl.assume(stride_xk == 1) + tl.assume(stride_grad_Hm == 32) + tl.assume(stride_grad_Hn == 1) + tl.assume(stride_phin == K) + tl.assume(stride_phik == 1) + tl.assume(stride_grad_phin == K) + tl.assume(stride_grad_phin == stride_phin) + tl.assume(stride_grad_phik == 1) + tl.assume(stride_grad_norm_weight == 1) + tl.assume(stride_norm_weight == 1) + + tl.assume(BLOCK_SIZE_M % 128 == 0) + tl.assume(BLOCK_SIZE_K % 64 == 0) + tl.assume(BLOCK_SIZE_N == 32) + tl.assume(STEP_SIZE_M % 32 == 0) + + offs_k = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + mask_k = offs_k < K + offs_n_full = tl.arange(0, BLOCK_SIZE_N) + mask_n = offs_n_full < N + + grad_psi_acc = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_K), dtype=tl.float32) + + m_start = pid_m * BLOCK_SIZE_M + m_end = tl.minimum(m_start + BLOCK_SIZE_M, M) + for m_idx in range(0, tl.cdiv(m_end - m_start, STEP_SIZE_M)): + offs_m = m_start + m_idx * STEP_SIZE_M + tl.arange(0, STEP_SIZE_M) + mask_m = offs_m < M + x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk + x = tl.load( + x_ptrs, mask=mask_m[:, None] & mask_k[None, :], other=0.0 + ) # (STEP_SIZE_M, BLOCK_SIZE_K) + grad_H_ptrs = ( + grad_H_ptr + offs_m[:, None] * stride_grad_Hm + offs_n_full[None, :] * stride_grad_Hn + ) + grad_H = tl.load( + grad_H_ptrs, mask=mask_m[:, None] & mask_n[None, :], other=0.0 + ) # (STEP_SIZE_M, BLOCK_SIZE_N) + + grad_psi_acc = tl.dot( + tl.trans(grad_H, (1, 0)), + x.to(grad_H.dtype), + acc=grad_psi_acc, + out_dtype=tl.float32, + input_precision=precision, + ) + + phi_ptrs = phi_ptr + offs_n_full[:, None] * stride_phin + offs_k[None, :] * stride_phik + phi = tl.load( + phi_ptrs, mask=(offs_n_full[:, None] < N) & mask_k[None, :], other=0.0 + ) # (BLOCK_SIZE_N, BLOCK_SIZE_K) + norm_weight_ptrs = norm_weight_ptr + offs_k * stride_norm_weight + norm_weight = tl.load( + norm_weight_ptrs, mask=mask_k, other=0.0, cache_modifier=".cg" + ) # (BLOCK_SIZE_K,) + phi = phi.to(tl.float32) + norm_weight = norm_weight.to(tl.float32) + + # Keep grad_psi in SRAM and get grad_phi & grad_norm_weight + grad_phi = grad_psi_acc * norm_weight[None, :].to(grad_psi_acc.dtype) # (32, BLOCK_SIZE_K) + grad_norm_weight = tl.sum(grad_psi_acc * phi.to(grad_psi_acc.dtype), axis=0) # (BLOCK_SIZE_K,) + + grad_phi_ptrs = ( + grad_phi_ptr + offs_n_full[:, None] * stride_grad_phin + offs_k[None, :] * stride_grad_phik + ) + grad_norm_weight_ptrs = grad_norm_weight_ptr + offs_k * stride_grad_norm_weight + + tl.atomic_add( + grad_phi_ptrs, + grad_phi, + mask=(offs_n_full[:, None] < N) & mask_k[None, :], + sem="relaxed", + ) + tl.atomic_add(grad_norm_weight_ptrs, grad_norm_weight, mask=mask_k, sem="relaxed") + + def scale_config(): block_m = [128] warps = [4] @@ -749,6 +939,22 @@ def _mhc_sinkhorn_fwd_fused( tl.store(output_ptrs, P, mask=mask_batch[:, None]) +def aggregate_config_fwd(): + block_m = [1, 2, 4] + block_c = [128, 256] + warps = [1, 2, 4] + stages = [1, 2, 3, 4] + + configs = [] + for m, c, w, s in itertools.product(block_m, block_c, warps, stages): + configs.append( + triton.Config({"BLOCK_SIZE_M": m, "BLOCK_SIZE_C": c}, num_warps=w, num_stages=s) + ) + if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": + configs = configs[:1] + return configs + + @triton.autotune( configs=sinkhorn_config(), key=["M"], @@ -854,25 +1060,21 @@ def _mhc_sinkhorn_bwd_fused( ) -def aggregate_config(): - block_m = [1, 2, 4] - block_c = [64, 128, 256] - warps = [1, 2, 4] - stages = [1, 2, 3, 4] +def aggregate_prune_fwd(configs, named_args, **kwargs): + M = named_args.get("M", kwargs.get("M", None)) - configs = [] - for m, c, w, s in itertools.product(block_m, block_c, warps, stages): - configs.append( - triton.Config({"BLOCK_SIZE_M": m, "BLOCK_SIZE_C": c}, num_warps=w, num_stages=s) + pruned_configs = list( + filter( + lambda config: triton.cdiv(M, config.kwargs["BLOCK_SIZE_M"]) <= MAX_GRID_DIM_Y, configs ) - if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": - configs = configs[:1] - return configs + ) + return pruned_configs @triton.autotune( - configs=aggregate_config(), + configs=aggregate_config_fwd(), key=["M", "C"], + prune_configs_by={"early_config_prune": aggregate_prune_fwd}, ) @triton.jit def _mhc_aggregate_fwd( @@ -949,7 +1151,51 @@ def _mhc_aggregate_fwd( tl.store(output_ptrs, out, mask=mask_m[:, None] & mask_c[None, :]) -@triton.autotune(configs=aggregate_config(), key=["M", "C"], reset_to_zero=["grad_H_pre_ptr"]) +def aggregate_config_bwd(): + block_m = [1, 2, 4] + block_c = [64, 128, 256] + step_c = [32, 64] + warps = [1, 2, 4] + stages = [1, 2, 3, 4] + + configs = [] + for bm, bc, sc, w, s in itertools.product(block_m, block_c, step_c, warps, stages): + configs.append( + triton.Config( + {"BLOCK_SIZE_M": bm, "BLOCK_SIZE_C": bc, "STEP_SIZE_C": sc}, + num_warps=w, + num_stages=s, + ) + ) + return configs + + +def aggregate_prune_bwd(configs, named_args, **kwargs): + M = named_args.get("M", kwargs.get("M", None)) + + pruned_configs = list( + filter( + lambda config: triton.cdiv(M, config.kwargs["BLOCK_SIZE_M"]) <= MAX_GRID_DIM_Y, + configs, + ) + ) + + # Triton will skip calling prune function if the autotune returns only one config, which breaks the determinism override here + # So we need to apply NVTE_DISABLE_TRITON_AUTOTUNING in the pruner instead + if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": + pruned_configs = pruned_configs[:1] + return pruned_configs + + +@triton.autotune( + configs=aggregate_config_bwd(), + key=["M", "C"], + reset_to_zero=["grad_H_pre_ptr"], + # When FUSE_GRAD_X_ACC=True the kernel does a read-modify-write on grad_x_ptr; without + # restore_value the autotune timing trials accumulate onto the buffer and corrupt it. + restore_value=["grad_x_ptr"], + prune_configs_by={"early_config_prune": aggregate_prune_bwd}, +) @triton.jit def _mhc_aggregate_bwd( grad_output_ptr, # (M, C) @@ -969,7 +1215,9 @@ def _mhc_aggregate_bwd( # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_C: tl.constexpr, + STEP_SIZE_C: tl.constexpr, precision: tl.constexpr, + FUSE_GRAD_X_ACC: tl.constexpr, ): """ Forward: @@ -992,38 +1240,14 @@ def _mhc_aggregate_bwd( tl.assume(stride_grad_output_m > 0 and stride_grad_output_c == 1) tl.assume(BLOCK_SIZE_C % 32 == 0) + tl.assume(STEP_SIZE_C % 32 == 0) + tl.assume(BLOCK_SIZE_C % STEP_SIZE_C == 0) offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_c = pid_c * BLOCK_SIZE_C + tl.arange(0, BLOCK_SIZE_C) - offs_cn = pid_c * BLOCK_SIZE_C * n + tl.arange(0, BLOCK_SIZE_C * n) mask_m = offs_m < M - mask_c = offs_c < C - mask_cn = offs_cn < C * n - - grad_output_ptrs = ( - grad_output_ptr - + offs_m[:, None] * stride_grad_output_m - + offs_c[None, :] * stride_grad_output_c - ) - grad_output = tl.load( - grad_output_ptrs, mask=mask_m[:, None] & mask_c[None, :], other=0.0 - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C) - - x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_cn[None, :] * stride_xCn - x = tl.load( - x_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C * n) - grad_H_pre = tl.dot( - tl.reshape(grad_output, (BLOCK_SIZE_M, 1, BLOCK_SIZE_C)), - tl.reshape(x, (BLOCK_SIZE_M, BLOCK_SIZE_C, n)), - input_precision=precision, - out_dtype=tl.float32, - ) - grad_H_pre = tl.reshape(grad_H_pre, (BLOCK_SIZE_M * n,)) # (BLOCK_SIZE_M * n) - offs_grad_H_pre = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) - grad_H_pre_ptrs = grad_H_pre_ptr + offs_grad_H_pre - tl.atomic_add(grad_H_pre_ptrs, grad_H_pre, mask=offs_grad_H_pre < M * n, sem="relaxed") + offs_c_start = pid_c * BLOCK_SIZE_C + offs_cn_start = pid_c * BLOCK_SIZE_C * n H_pre_offs = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) H_pre = tl.load( @@ -1031,19 +1255,59 @@ def _mhc_aggregate_bwd( ) # (BLOCK_SIZE_M * n) H_pre = tl.reshape(H_pre, (BLOCK_SIZE_M, n)) # (BLOCK_SIZE_M, n) - # grad_x = grad_output @ H_pre.T: (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - grad_x = grad_output[:, :, None] * H_pre[:, None, :] # (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - grad_x = tl.reshape(grad_x, (BLOCK_SIZE_M, BLOCK_SIZE_C * n)) + grad_H_pre_acc = tl.zeros((BLOCK_SIZE_M, 1, n), dtype=tl.float32) + for i in tl.range(0, BLOCK_SIZE_C, STEP_SIZE_C, loop_unroll_factor=2): + offs_c = offs_c_start + i + tl.arange(0, STEP_SIZE_C) + offs_cn = offs_cn_start + i * n + tl.arange(0, STEP_SIZE_C * n) + mask_c = offs_c < C + mask_cn = offs_cn < C * n + + grad_output_ptrs = ( + grad_output_ptr + + offs_m[:, None] * stride_grad_output_m + + offs_c[None, :] * stride_grad_output_c + ) + grad_output = tl.load( + grad_output_ptrs, mask=mask_m[:, None] & mask_c[None, :], other=0.0 + ) # (BLOCK_SIZE_M, STEP_SIZE_C) - grad_x_ptrs = grad_x_ptr + offs_m[:, None] * stride_grad_xm + offs_cn[None, :] * stride_grad_xCn - tl.store( - grad_x_ptrs, - grad_x, - mask=mask_m[:, None] & mask_cn[None, :], - ) + x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_cn[None, :] * stride_xCn + x = tl.load( + x_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 + ) # (BLOCK_SIZE_M, STEP_SIZE_C * n) + + grad_H_pre_acc = tl.dot( + tl.reshape(grad_output, (BLOCK_SIZE_M, 1, STEP_SIZE_C)), + tl.reshape(x, (BLOCK_SIZE_M, STEP_SIZE_C, n)), + acc=grad_H_pre_acc, + input_precision=precision, + out_dtype=tl.float32, + ) + + # grad_x = grad_output @ H_pre.T: (BLOCK_SIZE_M, STEP_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, STEP_SIZE_C, n) + grad_x = grad_output[:, :, None] * H_pre[:, None, :] # (BLOCK_SIZE_M, STEP_SIZE_C, n) + grad_x = tl.reshape(grad_x, (BLOCK_SIZE_M, STEP_SIZE_C * n)) + + grad_x_ptrs = ( + grad_x_ptr + offs_m[:, None] * stride_grad_xm + offs_cn[None, :] * stride_grad_xCn + ) + if FUSE_GRAD_X_ACC: # If fused gradient accumulation is enabled, the buffer is always fp32 + grad_x_acc = tl.load(grad_x_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0) + grad_x = grad_x.to(tl.float32) + grad_x_acc + tl.store( + grad_x_ptrs, + grad_x, + mask=mask_m[:, None] & mask_cn[None, :], + ) + + grad_H_pre = tl.reshape(grad_H_pre_acc, (BLOCK_SIZE_M * n,)) # (BLOCK_SIZE_M * n) + offs_grad_H_pre = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) + grad_H_pre_ptrs = grad_H_pre_ptr + offs_grad_H_pre + tl.atomic_add(grad_H_pre_ptrs, grad_H_pre, mask=offs_grad_H_pre < M * n, sem="relaxed") -def expand_combine_config(): + +def expand_combine_config_fwd(): block_m = [1, 2, 4] block_c = [128, 256] warps = [1, 2] @@ -1054,18 +1318,34 @@ def expand_combine_config(): configs.append( triton.Config({"BLOCK_SIZE_M": m, "BLOCK_SIZE_C": c}, num_warps=w, num_stages=s) ) - if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": - configs = configs[:1] return configs +def expand_combine_prune_fwd(configs, named_args, **kwargs): + M = named_args.get("M", kwargs.get("M", None)) + + pruned_configs = list( + filter( + lambda config: triton.cdiv(M, config.kwargs["BLOCK_SIZE_M"]) <= MAX_GRID_DIM_Y, configs + ) + ) + + # Triton will skip calling prune function if the autotune returns only one config, which breaks the determinism override here + # So we need to apply NVTE_DISABLE_TRITON_AUTOTUNING in the pruner instead + if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": + pruned_configs = pruned_configs[:1] + return pruned_configs + + @triton.autotune( - configs=expand_combine_config(), + configs=expand_combine_config_fwd(), key=["M", "C"], + prune_configs_by={"early_config_prune": expand_combine_prune_fwd}, ) @triton.jit def _mhc_expand_combine_fwd( f_ptr, # (M, C) + bias_ptr, # (C,), or None if HAS_BIAS is False H_post_ptr, # (M, n) x_ptr, # (M, C, n) H_res_ptr, # (M, n, n) @@ -1075,6 +1355,7 @@ def _mhc_expand_combine_fwd( n: tl.constexpr, stride_fm, stride_fc, + stride_bias, # Not used if HAS_BIAS is False stride_xm, stride_xCn, stride_output_m, @@ -1082,9 +1363,10 @@ def _mhc_expand_combine_fwd( # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_C: tl.constexpr, + HAS_BIAS: tl.constexpr, ): """ - output = f @ H_post: (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) + output = (f + bias[None, :, None]) @ H_post: (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) + x @ H_res: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) """ pid_m = tl.program_id(1) @@ -1095,6 +1377,7 @@ def _mhc_expand_combine_fwd( tl.assume(C > 0) tl.assume(n == 4) tl.assume(stride_fm > 0 and stride_fc == 1) + tl.assume(stride_bias == 1) tl.assume(stride_xm > 0 and stride_xCn == 1) tl.assume(stride_output_m > 0 and stride_output_Cn == 1) @@ -1109,6 +1392,8 @@ def _mhc_expand_combine_fwd( f_ptrs = f_ptr + offs_m[:, None] * stride_fm + offs_c[None, :] * stride_fc f = tl.load(f_ptrs, mask=mask_m[:, None] & mask_c[None, :], other=0.0) + if HAS_BIAS: + bias = tl.load(bias_ptr + offs_c * stride_bias, mask=mask_c, other=0.0) # (BLOCK_SIZE_C,) offs_H_post = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) H_post = tl.load( @@ -1116,10 +1401,12 @@ def _mhc_expand_combine_fwd( ) H_post = tl.reshape(H_post, (BLOCK_SIZE_M, n)) # (BLOCK_SIZE_M, n) - # Residual connection path: res_out = f @ H_post: + # Residual connection path: res_out = f @ H_post + bias @ H_post: # (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, n, BLOCK_SIZE_C) # Due to broadcasting, it's equivalent to a multiplicaiton out_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C, n), dtype=tl.float32) + if HAS_BIAS: + out_acc = tl.fma(bias[None, :, None], H_post[:, None, :], out_acc) out_acc = tl.fma(f[:, :, None], H_post[:, None, :], out_acc) H_res_offs = pid_m * BLOCK_SIZE_M * n * n + tl.arange(0, BLOCK_SIZE_M * n * n) @@ -1167,332 +1454,59 @@ def _mhc_expand_combine_fwd( tl.store(output_ptrs, out, mask=mask_m[:, None] & mask_cn[None, :]) -@triton.autotune( - configs=expand_combine_config(), - key=["M", "C"], - reset_to_zero=["grad_H_post_ptr", "grad_H_res_ptr"], -) -@triton.jit -def _mhc_expand_combine_bwd( - grad_output_ptr, # (M, C, n) - f_ptr, # (M, C) - H_post_ptr, # (M, n) - x_ptr, # (M, C, n) - H_res_ptr, # (M, n, n) - grad_H_post_ptr, # (M, n) - grad_f_ptr, # (M, C) - grad_H_res_ptr, # (M, n, n) - grad_x_ptr, # (M, C, n) - M, - C, - n: tl.constexpr, - stride_grad_output_m, - stride_grad_output_Cn, - stride_fm, - stride_fc, - stride_xm, - stride_xCn, - stride_grad_fm, - stride_grad_fc, - stride_grad_xm, - stride_grad_xCn, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_C: tl.constexpr, - precision: tl.constexpr, -): - """ - Each block - It reads - - (BLOCK_SIZE_M, BLOCK_SIZE_C) of f, which is the output of the attention / FFN module - - (BLOCK_SIZE_M, n) of H_post, which is applied for the transformation of the attention / FFN output - - (BLOCK_SIZE_M, BLOCK_SIZE_C, n) of x, which is the skip connection's input - - (BLOCK_SIZE_M, n*n) of H_res, which is applied for the transformation of the skip connection - and writes - - (BLOCK_SIZE_M, n) of grad_H_post - - (BLOCK_SIZE_M, BLOCK_SIZE_C) of grad_f - - (BLOCK_SIZE_M, n, n) of grad_H_res - - (BLOCK_SIZE_M, BLOCK_SIZE_C, n) of grad_x - - Forward: - out = f @ H_post + x @ H_res - Backward: - GEMM: - grad_H_post = f.T @ grad_output: (BLOCK_SIZE_M, 1, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, 1, n) - grad_H_res = x.T @ grad_output: (BLOCK_SIZE_M, n, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, n, n) - Not GEMM: - grad_f = grad_output @ H_post.T: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, 1) = (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) - grad_x = grad_output @ H_res.T: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - """ - - pid_m = tl.program_id(1) - pid_c = tl.program_id(0) - - tl.static_assert(n == 4) - tl.assume(M > 0) - tl.assume(C > 0) - tl.assume(n == 4) - tl.assume(stride_fm > 0 and stride_fc == 1) - tl.assume(stride_xm > 0 and stride_xCn == 1) - tl.assume(stride_grad_output_m > 0 and stride_grad_output_Cn == 1) - tl.assume(stride_grad_fm > 0 and stride_grad_fc == 1) - tl.assume(stride_grad_xm > 0 and stride_grad_xCn == 1) - - tl.assume(BLOCK_SIZE_C % 32 == 0) - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_c = pid_c * BLOCK_SIZE_C + tl.arange(0, BLOCK_SIZE_C) - offs_cn = pid_c * BLOCK_SIZE_C * n + tl.arange(0, BLOCK_SIZE_C * n) - mask_m = offs_m < M - mask_c = offs_c < C - mask_cn = offs_cn < C * n - - f_ptrs = f_ptr + offs_m[:, None] * stride_fm + offs_c[None, :] * stride_fc - f = tl.load(f_ptrs, mask=mask_m[:, None] & mask_c[None, :], other=0.0) - - H_post_offs = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) - H_post = tl.load(H_post_ptr + H_post_offs, mask=H_post_offs < M * n, other=0.0) - H_post = tl.reshape(H_post, (BLOCK_SIZE_M, n)) # (BLOCK_SIZE_M, n) - - H_res_offs = pid_m * BLOCK_SIZE_M * n * n + tl.arange(0, BLOCK_SIZE_M * n * n) - H_res = tl.load( - H_res_ptr + H_res_offs, mask=H_res_offs < M * n * n, other=0.0 - ) # (BLOCK_SIZE_M, n, n) - H_res = tl.reshape(H_res, (BLOCK_SIZE_M, n, n)) # (BLOCK_SIZE_M, n, n) - - grad_out_ptrs = ( - grad_output_ptr - + offs_m[:, None] * stride_grad_output_m - + offs_cn[None, :] * stride_grad_output_Cn - ) - grad_out = tl.load( - grad_out_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C * n) - grad_out = tl.reshape( - grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - - # grad_H_post = f.T @ grad_output # (BLOCK_SIZE_M, 1, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, 1, n) - grad_H_post = tl.dot( - tl.reshape(f, (BLOCK_SIZE_M, 1, BLOCK_SIZE_C)), - tl.reshape(grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, n)), - input_precision=precision, - out_dtype=tl.float32, - ) # (BLOCK_SIZE_M, 1, n) - grad_H_post = tl.reshape(grad_H_post, (BLOCK_SIZE_M * n,)) # (BLOCK_SIZE_M * n) - offs_grad_H_post = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) - grad_H_post_ptrs = grad_H_post_ptr + offs_grad_H_post - tl.atomic_add(grad_H_post_ptrs, grad_H_post, mask=offs_grad_H_post < M * n, sem="relaxed") - - x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_cn[None, :] * stride_xCn - x = tl.load( - x_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C*n) - x = tl.reshape(x, (BLOCK_SIZE_M, BLOCK_SIZE_C, n)) # (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - - # grad_H_res = x.T @ grad_output: (BLOCK_SIZE_M, n, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, n, n) - grad_H_res = tl.dot( - tl.trans(x, (0, 2, 1)), grad_out, input_precision=precision, out_dtype=tl.float32 - ) # (BLOCK_SIZE_M, n, n) - grad_H_res = tl.reshape(grad_H_res, (BLOCK_SIZE_M * n * n,)) # (BLOCK_SIZE_M * n * n) - offs_grad_H_res = pid_m * BLOCK_SIZE_M * n * n + tl.arange(0, BLOCK_SIZE_M * n * n) - grad_H_res_ptrs = grad_H_res_ptr + offs_grad_H_res - tl.atomic_add( - grad_H_res_ptrs, grad_H_res.to(tl.float32), mask=offs_grad_H_res < M * n * n, sem="relaxed" - ) - - grad_out_reshape = tl.reshape( - grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, 2, 2) - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, 2, 2) - grad_out01, grad_out23 = tl.split( - grad_out_reshape - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, 2), (BLOCK_SIZE_M, BLOCK_SIZE_C, 2) - grad_out0, grad_out1 = tl.split( - grad_out01 - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C) - grad_out2, grad_out3 = tl.split( - grad_out23 - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C) - - # grad_f = grad_output @ H_post.T: (BLOCK_SIZE_M, 1, n) @ (BLOCK_SIZE_M, n, BLOCK_SIZE_C) = (BLOCK_SIZE_M, 1, BLOCK_SIZE_C) - # Triton doesn't support dot prod with inner dimension < 16, so we need to hack this: - # grad_f = grad_out[:, :, 0] @ H_post.T[:, 0, :] (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, 1) - # + grad_out[:, :, 1] @ H_post.T[:, 1, :] - # + grad_out[:, :, 2] @ H_post.T[:, 2, :] - # + grad_out[:, :, 3] @ H_post.T[:, 3, :] - # where H_post.T[:, i, :] = H_post[:, :, i] - H_post = tl.reshape(H_post, (BLOCK_SIZE_M, 2, 2)) - H_post01, H_post23 = tl.split(H_post) # (BLOCK_SIZE_M, 2), (BLOCK_SIZE_M, 2) - H_post0, H_post1 = tl.split(H_post01) # (BLOCK_SIZE_M,), (BLOCK_SIZE_M,) - H_post2, H_post3 = tl.split(H_post23) # (BLOCK_SIZE_M,), (BLOCK_SIZE_M,) - - grad_f_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C), dtype=tl.float32) - # (BLOCK_SIZE_M, BLOCK_SIZE_C) * (BLOCK_SIZE_M, 1) -> (BLOCK_SIZE_M, BLOCK_SIZE_C) - grad_f_acc = tl.fma(grad_out0, H_post0[:, None], grad_f_acc) - grad_f_acc = tl.fma(grad_out1, H_post1[:, None], grad_f_acc) - grad_f_acc = tl.fma(grad_out2, H_post2[:, None], grad_f_acc) - grad_f_acc = tl.fma(grad_out3, H_post3[:, None], grad_f_acc) - grad_f = grad_f_acc.to(f.dtype) - - grad_f_ptrs = grad_f_ptr + offs_m[:, None] * stride_grad_fm + offs_c[None, :] * stride_grad_fc - tl.store(grad_f_ptrs, grad_f, mask=mask_m[:, None] & mask_c[None, :]) - - # grad_x = grad_output @ H_res.T: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, n, BLOCK_SIZE_C) - # The inner dim is n=4 which is too small for triton, so we will manually unroll the matmul - # grad_x = grad_out[:, :, 0] @ H_res.T[:, 0, :] - # + grad_out[:, :, 1] @ H_res.T[:, 1, :] - # + grad_out[:, :, 2] @ H_res.T[:, 2, :] - # + grad_out[:, :, 3] @ H_res.T[:, 3, :] - # where H_res.T[:, i, :] = H_res[:, :, i] - # Due to broadcasting, it's equivalent to multiplying each H_res[:, i, :].T with grad_out[:, i, :] - - H_res_reshape = tl.reshape(H_res, (BLOCK_SIZE_M, n, 2, 2)) # (BLOCK_SIZE_M, n, 2, 2) - H_res01, H_res23 = tl.split(H_res_reshape) # (BLOCK_SIZE_M, n, 2), (BLOCK_SIZE_M, n, 2) - H_res0, H_res1 = tl.split(H_res01) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) - H_res2, H_res3 = tl.split(H_res23) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) - - grad_x_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C, n), dtype=tl.float32) - grad_x_acc = tl.fma(grad_out0[:, :, None], H_res0[:, None, :], grad_x_acc) - grad_x_acc = tl.fma(grad_out1[:, :, None], H_res1[:, None, :], grad_x_acc) - grad_x_acc = tl.fma(grad_out2[:, :, None], H_res2[:, None, :], grad_x_acc) - grad_x_acc = tl.fma(grad_out3[:, :, None], H_res3[:, None, :], grad_x_acc) - - grad_x = grad_x_acc.to(x.dtype) - grad_x = tl.reshape(grad_x, (BLOCK_SIZE_M, BLOCK_SIZE_C * n)) # (BLOCK_SIZE_M, BLOCK_SIZE_C*n) - - grad_x_ptrs = grad_x_ptr + offs_m[:, None] * stride_grad_xm + offs_cn[None, :] * stride_grad_xCn - tl.store(grad_x_ptrs, grad_x, mask=mask_m[:, None] & mask_cn[None, :]) - - -@triton.autotune( - configs=expand_combine_config(), - key=["M", "C"], -) -@triton.jit -def _mhc_expand_combine_with_bias_fwd( - f_ptr, # (M, C) - bias_ptr, # (C,) - H_post_ptr, # (M, n) - x_ptr, # (M, C, n) - H_res_ptr, # (M, n, n) - output_ptr, # # (M, C, n) - M, - C, - n: tl.constexpr, - stride_fm, - stride_fc, - stride_bias, - stride_xm, - stride_xCn, - stride_output_m, - stride_output_Cn, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_C: tl.constexpr, -): - """ - output = (f + bias[None, :, None]) @ H_post: (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - + x @ H_res: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - """ - pid_m = tl.program_id(1) - pid_c = tl.program_id(0) - - tl.static_assert(n == 4) - tl.assume(M > 0) - tl.assume(C > 0) - tl.assume(n == 4) - tl.assume(stride_fm > 0 and stride_fc == 1) - tl.assume(stride_bias == 1) - tl.assume(stride_xm > 0 and stride_xCn == 1) - tl.assume(stride_output_m > 0 and stride_output_Cn == 1) - - tl.assume(BLOCK_SIZE_C % 32 == 0) - - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_c = pid_c * BLOCK_SIZE_C + tl.arange(0, BLOCK_SIZE_C) - offs_cn = pid_c * BLOCK_SIZE_C * n + tl.arange(0, BLOCK_SIZE_C * n) - mask_m = offs_m < M - mask_c = offs_c < C - mask_cn = offs_cn < C * n +def expand_combine_config_bwd(): + block_m = [1, 2, 4] + block_c = [128, 256] + step_c = [32, 64] + warps = [1, 2] + stages = [1, 2, 3, 4] - f_ptrs = f_ptr + offs_m[:, None] * stride_fm + offs_c[None, :] * stride_fc - f = tl.load(f_ptrs, mask=mask_m[:, None] & mask_c[None, :], other=0.0) - bias = tl.load(bias_ptr + offs_c * stride_bias, mask=mask_c, other=0.0) # (BLOCK_SIZE_C,) + configs = [] + for m, c, sc, w, s in itertools.product(block_m, block_c, step_c, warps, stages): + configs.append( + triton.Config( + {"BLOCK_SIZE_M": m, "BLOCK_SIZE_C": c, "STEP_SIZE_C": sc}, + num_warps=w, + num_stages=s, + ) + ) + return configs - offs_H_post = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) - H_post = tl.load( - H_post_ptr + offs_H_post, mask=offs_H_post < M * n, other=0.0, cache_modifier=".ca" - ) - H_post = tl.reshape(H_post, (BLOCK_SIZE_M, n)) # (BLOCK_SIZE_M, n) - # Residual connection path: res_out = f @ H_post + bias @ H_post: - # (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, n) = (BLOCK_SIZE_M, n, BLOCK_SIZE_C) - # Due to broadcasting, it's equivalent to a multiplicaiton - out_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C, n), dtype=tl.float32) - out_acc = tl.fma(bias[None, :, None], H_post[:, None, :], out_acc) - out_acc = tl.fma(f[:, :, None], H_post[:, None, :], out_acc) +def expand_combine_prune_bwd(configs, named_args, **kwargs): + M = named_args.get("M", kwargs.get("M", None)) - H_res_offs = pid_m * BLOCK_SIZE_M * n * n + tl.arange(0, BLOCK_SIZE_M * n * n) - H_res = tl.load( - H_res_ptr + H_res_offs, mask=H_res_offs < M * n * n, other=0.0, cache_modifier=".ca" + pruned_configs = list( + filter( + lambda config: triton.cdiv(M, config.kwargs["BLOCK_SIZE_M"]) <= MAX_GRID_DIM_Y, + configs, + ) ) - H_res = tl.reshape(H_res, (BLOCK_SIZE_M, n, n)) # (BLOCK_SIZE_M, n, n) - - x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_cn[None, :] * stride_xCn - x = tl.load( - x_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - - # Manifold connection path: manifold_out = H_res @ x: - # (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - # triton doesn't support dot prod with inner dimension < 16, so we need to manually unroll the computation for n=4: - # x @ H_res = x[:, :, 0] @ H_res[:, 0, :] - # + x[:, :, 1] @ H_res[:, 1, :] - # + x[:, :, 2] @ H_res[:, 2, :] - # + x[:, :, 3] @ H_res[:, 3, :] - - x_reshape = tl.reshape(x, (BLOCK_SIZE_M, BLOCK_SIZE_C, 2, 2)) - x01, x23 = tl.split( - x_reshape - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, 2), (BLOCK_SIZE_M, BLOCK_SIZE_C, 2) - x0, x1 = tl.split(x01) # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C) - x2, x3 = tl.split(x23) # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C) - - H_resT = tl.reshape(tl.trans(H_res, (0, 2, 1)), (BLOCK_SIZE_M, n, 2, 2)) - H_res01, H_res23 = tl.split(H_resT) # (BLOCK_SIZE_M, n, 2), (BLOCK_SIZE_M, n, 2) - H_res0, H_res1 = tl.split(H_res01) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) - H_res2, H_res3 = tl.split(H_res23) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) - out_acc = tl.fma(x0[:, :, None], H_res0[:, None, :], out_acc) - out_acc = tl.fma(x1[:, :, None], H_res1[:, None, :], out_acc) - out_acc = tl.fma(x2[:, :, None], H_res2[:, None, :], out_acc) - out_acc = tl.fma(x3[:, :, None], H_res3[:, None, :], out_acc) - - out = out_acc.to(x.dtype) - out = tl.reshape(out, (BLOCK_SIZE_M, BLOCK_SIZE_C * n)) # (BLOCK_SIZE_M, BLOCK_SIZE_C*n) - - output_ptrs = ( - output_ptr + offs_m[:, None] * stride_output_m + offs_cn[None, :] * stride_output_Cn - ) - tl.store(output_ptrs, out, mask=mask_m[:, None] & mask_cn[None, :]) + # Triton will skip calling prune function if the autotune returns only one config, which breaks the determinism override here + # So we need to apply NVTE_DISABLE_TRITON_AUTOTUNING in the pruner instead + if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": + pruned_configs = pruned_configs[:1] + return pruned_configs @triton.autotune( - configs=expand_combine_config(), + configs=expand_combine_config_bwd(), key=["M", "C"], reset_to_zero=["grad_H_post_ptr", "grad_H_res_ptr", "grad_bias_ptr"], + prune_configs_by={"early_config_prune": expand_combine_prune_bwd}, ) @triton.jit -def _mhc_expand_combine_with_bias_bwd( +def _mhc_expand_combine_bwd( grad_output_ptr, # (M, C, n) f_ptr, # (M, C) - bias_ptr, # (C,) + bias_ptr, # (C,), or None if HAS_BIAS is False H_post_ptr, # (M, n) x_ptr, # (M, C, n) H_res_ptr, # (M, n, n) grad_H_post_ptr, # (M, n) grad_f_ptr, # (M, C) - grad_bias_ptr, # (C,) + grad_bias_ptr, # (C,), or None if HAS_BIAS is False grad_H_res_ptr, # (M, n, n) grad_x_ptr, # (M, C, n) M, @@ -1502,18 +1516,20 @@ def _mhc_expand_combine_with_bias_bwd( stride_grad_output_Cn, stride_fm, stride_fc, - stride_bias, + stride_bias, # Not used if HAS_BIAS is False stride_xm, stride_xCn, stride_grad_fm, stride_grad_fc, - stride_grad_bias, + stride_grad_bias, # Not used if HAS_BIAS is False stride_grad_xm, stride_grad_xCn, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_C: tl.constexpr, + STEP_SIZE_C: tl.constexpr, precision: tl.constexpr, + HAS_BIAS: tl.constexpr, ): """ Each block @@ -1557,137 +1573,166 @@ def _mhc_expand_combine_with_bias_bwd( tl.assume(BLOCK_SIZE_C % 32 == 0) offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_c = pid_c * BLOCK_SIZE_C + tl.arange(0, BLOCK_SIZE_C) - offs_cn = pid_c * BLOCK_SIZE_C * n + tl.arange(0, BLOCK_SIZE_C * n) mask_m = offs_m < M - mask_c = offs_c < C - mask_cn = offs_cn < C * n - f_ptrs = f_ptr + offs_m[:, None] * stride_fm + offs_c[None, :] * stride_fc - f = tl.load(f_ptrs, mask=mask_m[:, None] & mask_c[None, :], other=0.0) + offs_c_start = pid_c * BLOCK_SIZE_C + offs_cn_start = pid_c * BLOCK_SIZE_C * n - bias = tl.load(bias_ptr + offs_c * stride_bias, mask=mask_c, other=0.0) # (BLOCK_SIZE_C,) + grad_H_post_acc = tl.zeros((BLOCK_SIZE_M, 1, n), dtype=tl.float32) + grad_H_res_acc = tl.zeros((BLOCK_SIZE_M, n, n), dtype=tl.float32) H_post_offs = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) H_post = tl.load(H_post_ptr + H_post_offs, mask=H_post_offs < M * n, other=0.0) - H_post = tl.reshape(H_post, (BLOCK_SIZE_M, n)) # (BLOCK_SIZE_M, n) + H_post_reshape = tl.reshape(H_post, (BLOCK_SIZE_M, 2, 2)) + H_post01, H_post23 = tl.split(H_post_reshape) # (BLOCK_SIZE_M, 2), (BLOCK_SIZE_M, 2) + H_post0, H_post1 = tl.split(H_post01) # (BLOCK_SIZE_M,), (BLOCK_SIZE_M,) + H_post2, H_post3 = tl.split(H_post23) # (BLOCK_SIZE_M,), (BLOCK_SIZE_M,) H_res_offs = pid_m * BLOCK_SIZE_M * n * n + tl.arange(0, BLOCK_SIZE_M * n * n) H_res = tl.load( H_res_ptr + H_res_offs, mask=H_res_offs < M * n * n, other=0.0 ) # (BLOCK_SIZE_M, n, n) H_res = tl.reshape(H_res, (BLOCK_SIZE_M, n, n)) # (BLOCK_SIZE_M, n, n) + H_res_reshape = tl.reshape(H_res, (BLOCK_SIZE_M, n, 2, 2)) # (BLOCK_SIZE_M, n, 2, 2) + H_res01, H_res23 = tl.split(H_res_reshape) # (BLOCK_SIZE_M, n, 2), (BLOCK_SIZE_M, n, 2) + H_res0, H_res1 = tl.split(H_res01) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) + H_res2, H_res3 = tl.split(H_res23) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) - grad_out_ptrs = ( - grad_output_ptr - + offs_m[:, None] * stride_grad_output_m - + offs_cn[None, :] * stride_grad_output_Cn - ) - grad_out = tl.load( - grad_out_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C * n) - grad_out = tl.reshape( - grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, n) + for i in tl.range(0, BLOCK_SIZE_C, STEP_SIZE_C, loop_unroll_factor=2): + offs_c = offs_c_start + i + tl.arange(0, STEP_SIZE_C) + offs_cn = offs_cn_start + i * n + tl.arange(0, STEP_SIZE_C * n) + mask_c = offs_c < C + mask_cn = offs_cn < C * n + + f_ptrs = f_ptr + offs_m[:, None] * stride_fm + offs_c[None, :] * stride_fc + f = tl.load(f_ptrs, mask=mask_m[:, None] & mask_c[None, :], other=0.0) + + if HAS_BIAS: + bias = tl.load( + bias_ptr + offs_c * stride_bias, mask=mask_c, other=0.0 + ) # (STEP_SIZE_C,) + + grad_out_ptrs = ( + grad_output_ptr + + offs_m[:, None] * stride_grad_output_m + + offs_cn[None, :] * stride_grad_output_Cn + ) + grad_out = tl.load( + grad_out_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 + ) # (BLOCK_SIZE_M, STEP_SIZE_C * n) + grad_out = tl.reshape( + grad_out, (BLOCK_SIZE_M, STEP_SIZE_C, n) + ) # (BLOCK_SIZE_M, STEP_SIZE_C, n) + + # grad_H_post = f.T @ grad_output # (BLOCK_SIZE_M, 1, STEP_SIZE_C) @ (BLOCK_SIZE_M, STEP_SIZE_C, n) = (BLOCK_SIZE_M, 1, n) + grad_H_post_acc = tl.dot( + tl.reshape(f, (BLOCK_SIZE_M, 1, STEP_SIZE_C)), + tl.reshape(grad_out, (BLOCK_SIZE_M, STEP_SIZE_C, n)), + acc=grad_H_post_acc, + input_precision=precision, + out_dtype=tl.float32, + ) # (BLOCK_SIZE_M, 1, n) + if HAS_BIAS: + grad_H_post_acc = tl.dot( + tl.broadcast_to(bias[None, None, :], (BLOCK_SIZE_M, 1, STEP_SIZE_C)), + tl.reshape(grad_out, (BLOCK_SIZE_M, STEP_SIZE_C, n)), + acc=grad_H_post_acc, + input_precision=precision, + out_dtype=tl.float32, + ) # (BLOCK_SIZE_M, 1, n) + + x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_cn[None, :] * stride_xCn + x = tl.load( + x_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 + ) # (BLOCK_SIZE_M, STEP_SIZE_C*n) + x = tl.reshape(x, (BLOCK_SIZE_M, STEP_SIZE_C, n)) # (BLOCK_SIZE_M, STEP_SIZE_C, n) + + # grad_H_res = x.T @ grad_output: (BLOCK_SIZE_M, n, STEP_SIZE_C) @ (BLOCK_SIZE_M, STEP_SIZE_C, n) = (BLOCK_SIZE_M, n, n) + grad_H_res_acc = tl.dot( + tl.trans(x, (0, 2, 1)), + grad_out, + acc=grad_H_res_acc, + input_precision=precision, + out_dtype=tl.float32, + ) # (BLOCK_SIZE_M, n, n) + + grad_out_reshape = tl.reshape( + grad_out, (BLOCK_SIZE_M, STEP_SIZE_C, 2, 2) + ) # (BLOCK_SIZE_M, STEP_SIZE_C, 2, 2) + grad_out01, grad_out23 = tl.split( + grad_out_reshape + ) # (BLOCK_SIZE_M, STEP_SIZE_C, 2), (BLOCK_SIZE_M, STEP_SIZE_C, 2) + grad_out0, grad_out1 = tl.split( + grad_out01 + ) # (BLOCK_SIZE_M, STEP_SIZE_C), (BLOCK_SIZE_M, STEP_SIZE_C) + grad_out2, grad_out3 = tl.split( + grad_out23 + ) # (BLOCK_SIZE_M, STEP_SIZE_C), (BLOCK_SIZE_M, STEP_SIZE_C) + + # grad_f = grad_output @ H_post.T: (BLOCK_SIZE_M, 1, n) @ (BLOCK_SIZE_M, n, STEP_SIZE_C) = (BLOCK_SIZE_M, 1, STEP_SIZE_C) + # Triton doesn't support dot prod with inner dimension < 16, so we need to hack this: + # = grad_out[:, :, 0] @ H_post.T[:, 0, :] (BLOCK_SIZE_M, STEP_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, 1) + # + grad_out[:, :, 1] @ H_post.T[:, 1, :] + # + grad_out[:, :, 2] @ H_post.T[:, 2, :] + # + grad_out[:, :, 3] @ H_post.T[:, 3, :] + # where H_post.T[:, i, :] = H_post[:, :, i] + + grad_f_acc = tl.zeros((BLOCK_SIZE_M, STEP_SIZE_C), dtype=tl.float32) + # (BLOCK_SIZE_M, STEP_SIZE_C) * (BLOCK_SIZE_M, 1) -> (BLOCK_SIZE_M, STEP_SIZE_C) + grad_f_acc = tl.fma(grad_out0, H_post0[:, None], grad_f_acc) + grad_f_acc = tl.fma(grad_out1, H_post1[:, None], grad_f_acc) + grad_f_acc = tl.fma(grad_out2, H_post2[:, None], grad_f_acc) + grad_f_acc = tl.fma(grad_out3, H_post3[:, None], grad_f_acc) + grad_f = grad_f_acc.to(f.dtype) + + grad_f_ptrs = ( + grad_f_ptr + offs_m[:, None] * stride_grad_fm + offs_c[None, :] * stride_grad_fc + ) + tl.store(grad_f_ptrs, grad_f, mask=mask_m[:, None] & mask_c[None, :]) + + if HAS_BIAS: + grad_bias = tl.sum(grad_f_acc, axis=0) # (STEP_SIZE_C,) + # This is reduction over M dimension, so it has nothing to do with whether we use split-C. It only depends on determinism or not. + grad_bias_ptrs = grad_bias_ptr + offs_c * stride_grad_bias + tl.atomic_add(grad_bias_ptrs, grad_bias, mask=mask_c, sem="relaxed") + + # grad_x = grad_output @ H_res.T: (BLOCK_SIZE_M, STEP_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, n, STEP_SIZE_C) + # The inner dim is n=4 which is too small for triton, so we will manually unroll the matmul + # grad_x = grad_out[:, :, 0] @ H_res.T[:, 0, :] + # + grad_out[:, :, 1] @ H_res.T[:, 1, :] + # + grad_out[:, :, 2] @ H_res.T[:, 2, :] + # + grad_out[:, :, 3] @ H_res.T[:, 3, :] + # where H_res.T[:, i, :] = H_res[:, :, i] + # Due to broadcasting, it's equivalent to multiplying each H_res[:, i, :].T with grad_out[:, i, :] + + grad_x_acc = tl.zeros((BLOCK_SIZE_M, STEP_SIZE_C, n), dtype=tl.float32) + grad_x_acc = tl.fma(grad_out0[:, :, None], H_res0[:, None, :], grad_x_acc) + grad_x_acc = tl.fma(grad_out1[:, :, None], H_res1[:, None, :], grad_x_acc) + grad_x_acc = tl.fma(grad_out2[:, :, None], H_res2[:, None, :], grad_x_acc) + grad_x_acc = tl.fma(grad_out3[:, :, None], H_res3[:, None, :], grad_x_acc) + + grad_x = grad_x_acc.to(x.dtype) + grad_x = tl.reshape( + grad_x, (BLOCK_SIZE_M, STEP_SIZE_C * n) + ) # (BLOCK_SIZE_M, STEP_SIZE_C*n) + + grad_x_ptrs = ( + grad_x_ptr + offs_m[:, None] * stride_grad_xm + offs_cn[None, :] * stride_grad_xCn + ) + tl.store(grad_x_ptrs, grad_x, mask=mask_m[:, None] & mask_cn[None, :]) - # grad_H_post = f.T @ grad_output # (BLOCK_SIZE_M, 1, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, 1, n) - grad_H_post = tl.dot( - tl.reshape(f, (BLOCK_SIZE_M, 1, BLOCK_SIZE_C)), - tl.reshape(grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, n)), - input_precision=precision, - out_dtype=tl.float32, - ) # (BLOCK_SIZE_M, 1, n) - grad_H_post = tl.dot( - tl.broadcast_to(bias[None, None, :], (BLOCK_SIZE_M, 1, BLOCK_SIZE_C)), - tl.reshape(grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, n)), - acc=grad_H_post, - input_precision=precision, - out_dtype=tl.float32, - ) # (BLOCK_SIZE_M, 1, n) - grad_H_post = tl.reshape(grad_H_post, (BLOCK_SIZE_M * n,)) # (BLOCK_SIZE_M * n) + grad_H_post = tl.reshape(grad_H_post_acc, (BLOCK_SIZE_M * n,)) # (BLOCK_SIZE_M * n) offs_grad_H_post = pid_m * BLOCK_SIZE_M * n + tl.arange(0, BLOCK_SIZE_M * n) grad_H_post_ptrs = grad_H_post_ptr + offs_grad_H_post - tl.atomic_add(grad_H_post_ptrs, grad_H_post, mask=offs_grad_H_post < M * n, sem="relaxed") - x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_cn[None, :] * stride_xCn - x = tl.load( - x_ptrs, mask=mask_m[:, None] & mask_cn[None, :], other=0.0 - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C*n) - x = tl.reshape(x, (BLOCK_SIZE_M, BLOCK_SIZE_C, n)) # (BLOCK_SIZE_M, BLOCK_SIZE_C, n) - - # grad_H_res = x.T @ grad_output: (BLOCK_SIZE_M, n, BLOCK_SIZE_C) @ (BLOCK_SIZE_M, BLOCK_SIZE_C, n) = (BLOCK_SIZE_M, n, n) - grad_H_res = tl.dot( - tl.trans(x, (0, 2, 1)), grad_out, input_precision=precision, out_dtype=tl.float32 - ) # (BLOCK_SIZE_M, n, n) - grad_H_res = tl.reshape(grad_H_res, (BLOCK_SIZE_M * n * n,)) # (BLOCK_SIZE_M * n * n) + grad_H_res = tl.reshape(grad_H_res_acc, (BLOCK_SIZE_M * n * n,)) # (BLOCK_SIZE_M * n * n) offs_grad_H_res = pid_m * BLOCK_SIZE_M * n * n + tl.arange(0, BLOCK_SIZE_M * n * n) grad_H_res_ptrs = grad_H_res_ptr + offs_grad_H_res + + tl.atomic_add(grad_H_post_ptrs, grad_H_post, mask=offs_grad_H_post < M * n, sem="relaxed") tl.atomic_add( - grad_H_res_ptrs, grad_H_res.to(tl.float32), mask=offs_grad_H_res < M * n * n, sem="relaxed" + grad_H_res_ptrs, + grad_H_res.to(tl.float32), + mask=offs_grad_H_res < M * n * n, + sem="relaxed", ) - - grad_out_reshape = tl.reshape( - grad_out, (BLOCK_SIZE_M, BLOCK_SIZE_C, 2, 2) - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, 2, 2) - grad_out01, grad_out23 = tl.split( - grad_out_reshape - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C, 2), (BLOCK_SIZE_M, BLOCK_SIZE_C, 2) - grad_out0, grad_out1 = tl.split( - grad_out01 - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C) - grad_out2, grad_out3 = tl.split( - grad_out23 - ) # (BLOCK_SIZE_M, BLOCK_SIZE_C), (BLOCK_SIZE_M, BLOCK_SIZE_C) - - # grad_f = grad_output @ H_post.T: (BLOCK_SIZE_M, 1, n) @ (BLOCK_SIZE_M, n, BLOCK_SIZE_C) = (BLOCK_SIZE_M, 1, BLOCK_SIZE_C) - # Triton doesn't support dot prod with inner dimension < 16, so we need to hack this: - # = grad_out[:, :, 0] @ H_post.T[:, 0, :] (BLOCK_SIZE_M, BLOCK_SIZE_C, 1) @ (BLOCK_SIZE_M, 1, 1) - # + grad_out[:, :, 1] @ H_post.T[:, 1, :] - # + grad_out[:, :, 2] @ H_post.T[:, 2, :] - # + grad_out[:, :, 3] @ H_post.T[:, 3, :] - # where H_post.T[:, i, :] = H_post[:, :, i] - H_post = tl.reshape(H_post, (BLOCK_SIZE_M, 2, 2)) - H_post01, H_post23 = tl.split(H_post) # (BLOCK_SIZE_M, 2), (BLOCK_SIZE_M, 2) - H_post0, H_post1 = tl.split(H_post01) # (BLOCK_SIZE_M,), (BLOCK_SIZE_M,) - H_post2, H_post3 = tl.split(H_post23) # (BLOCK_SIZE_M,), (BLOCK_SIZE_M,) - - grad_f_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C), dtype=tl.float32) - # (BLOCK_SIZE_M, BLOCK_SIZE_C) * (BLOCK_SIZE_M, 1) -> (BLOCK_SIZE_M, BLOCK_SIZE_C) - grad_f_acc = tl.fma(grad_out0, H_post0[:, None], grad_f_acc) - grad_f_acc = tl.fma(grad_out1, H_post1[:, None], grad_f_acc) - grad_f_acc = tl.fma(grad_out2, H_post2[:, None], grad_f_acc) - grad_f_acc = tl.fma(grad_out3, H_post3[:, None], grad_f_acc) - grad_f = grad_f_acc.to(f.dtype) - - grad_f_ptrs = grad_f_ptr + offs_m[:, None] * stride_grad_fm + offs_c[None, :] * stride_grad_fc - tl.store(grad_f_ptrs, grad_f, mask=mask_m[:, None] & mask_c[None, :]) - - grad_bias = tl.sum(grad_f_acc, axis=0) # (BLOCK_SIZE_C,) - grad_bias_ptrs = grad_bias_ptr + offs_c * stride_grad_bias - tl.atomic_add(grad_bias_ptrs, grad_bias, mask=mask_c, sem="relaxed") - - # grad_x = grad_output @ H_res.T: (BLOCK_SIZE_M, BLOCK_SIZE_C, n) @ (BLOCK_SIZE_M, n, n) = (BLOCK_SIZE_M, n, BLOCK_SIZE_C) - # The inner dim is n=4 which is too small for triton, so we will manually unroll the matmul - # grad_x = grad_out[:, :, 0] @ H_res.T[:, 0, :] - # + grad_out[:, :, 1] @ H_res.T[:, 1, :] - # + grad_out[:, :, 2] @ H_res.T[:, 2, :] - # + grad_out[:, :, 3] @ H_res.T[:, 3, :] - # where H_res.T[:, i, :] = H_res[:, :, i] - # Due to broadcasting, it's equivalent to multiplying each H_res[:, i, :].T with grad_out[:, i, :] - - H_res_reshape = tl.reshape(H_res, (BLOCK_SIZE_M, n, 2, 2)) # (BLOCK_SIZE_M, n, 2, 2) - H_res01, H_res23 = tl.split(H_res_reshape) # (BLOCK_SIZE_M, n, 2), (BLOCK_SIZE_M, n, 2) - H_res0, H_res1 = tl.split(H_res01) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) - H_res2, H_res3 = tl.split(H_res23) # (BLOCK_SIZE_M, n), (BLOCK_SIZE_M, n) - - grad_x_acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_C, n), dtype=tl.float32) - grad_x_acc = tl.fma(grad_out0[:, :, None], H_res0[:, None, :], grad_x_acc) - grad_x_acc = tl.fma(grad_out1[:, :, None], H_res1[:, None, :], grad_x_acc) - grad_x_acc = tl.fma(grad_out2[:, :, None], H_res2[:, None, :], grad_x_acc) - grad_x_acc = tl.fma(grad_out3[:, :, None], H_res3[:, None, :], grad_x_acc) - - grad_x = grad_x_acc.to(x.dtype) - grad_x = tl.reshape(grad_x, (BLOCK_SIZE_M, BLOCK_SIZE_C * n)) # (BLOCK_SIZE_M, BLOCK_SIZE_C*n) - - grad_x_ptrs = grad_x_ptr + offs_m[:, None] * stride_grad_xm + offs_cn[None, :] * stride_grad_xCn - tl.store(grad_x_ptrs, grad_x, mask=mask_m[:, None] & mask_cn[None, :]) diff --git a/transformer_engine/pytorch/triton/mhc.py b/transformer_engine/pytorch/triton/mhc.py index 987216e327..d95eeb3f63 100644 --- a/transformer_engine/pytorch/triton/mhc.py +++ b/transformer_engine/pytorch/triton/mhc.py @@ -5,27 +5,33 @@ """PyTorch wrapper functions for mHC (manifold Hyper-Connection) Triton kernels.""" import os +from typing import Optional import torch import triton from transformer_engine.common.triton.mhc import ( + _mhc_projection_bwd_fused_dphi, + _mhc_projection_bwd_fused_dx, _mhc_scale_fwd_fused, _mhc_scale_bwd_fused, - _mhc_expand_combine_with_bias_fwd, - _mhc_expand_combine_with_bias_bwd, _mhc_expand_combine_fwd, _mhc_expand_combine_bwd, _mhc_aggregate_fwd, _mhc_aggregate_bwd, _mhc_projection_fwd_fused, - _mhc_projection_bwd_fused, - _mhc_sinkhorn_fwd_fused, _mhc_sinkhorn_fwd_fused_recompute, - _mhc_sinkhorn_bwd_fused, _mhc_sinkhorn_bwd_fused_recompute, + _mhc_sinkhorn_fwd_fused, + _mhc_sinkhorn_bwd_fused, ) from transformer_engine.pytorch.cpp_extensions.gemm import general_gemm +_SUPPORT_TMA = torch.cuda.get_device_capability()[0] >= 9 + + +def _tma_aligned(t): + return (t.stride(0) * t.element_size()) % 16 == 0 and t.data_ptr() % 16 == 0 + def check_deterministic(operator: str): """ @@ -39,6 +45,101 @@ def check_deterministic(operator: str): ) +def mhc_generate_mix_and_aggregate( + x: torch.Tensor, + phi: torch.Tensor, + alpha: torch.Tensor, + beta: torch.Tensor, + norm_weight: Optional[torch.Tensor] = None, + use_tf32: bool = True, + fuse_grad_x_acc: bool = False, +): + """ + Generate the mix matrix H_pre, H_post, H_res and apply H_pre to x to aggregate n streams + This wraps projection, scale, sinkhorn, and aggregate operations into one function. + + To use mHC in your model: + ``` + layer_input, H_post, H_res = mhc_generate_mix_and_aggregate(x, phi, alpha, beta) + layer_output = layer(layer_input) # Attn / FFN layer + x = mhc_fused_expand_combine(layer_input, bias, H_post, x, H_res) + ``` + + This API accepts both BF16 and FP32 parameters, though the DeepSeek V4 recipe is: + - x: BF16 + - phi, alpha, beta: FP32 + + Parameters + ---------- + x : torch.Tensor, + input tensor of shape (s, b, C, n), where s is the sequence length, b is the batch size, C is the hidden dimension per hyper connection, and n is the number of hyper connections, + dtype is torch.bfloat16 or torch.float32 + Note that C is equal to the original hidden dimension divided by n. + phi : torch.Tensor + projection matrix of shape (N, nC), where N=2n+n*n (=24 for n=4), and nC is the hidden dimension after expansion (n times of C), + dtype is torch.bfloat16 or torch.float32 + norm_weight : torch.Tensor or None + optional, the weight for RMSNorm, of shape (K,), which is the learnable per-element affine parameters (gamma) applied to RMSNorm + dtype is torch.bfloat16 or torch.float32 + alpha : torch.Tensor + scaling factor for H, of shape (3,), where + alpha[0] is applied to H[:, 0:n] for H_pre + alpha[1] is applied to H[:, n:2n] for H_post + alpha[2] is applied to H[:, 2n:2n+n*n] for H_res + dtype: torch.bfloat16 or torch.float32 + beta : torch.Tensor + bias term for H, of shape (1, 2*n+n*n), where + beta[0, 0:n] is applied to H[:, 0:n] for H_pre + beta[0, n:2n] is applied to H[:, n:2n] for H_post + beta[0, 2n:2n+n*n] is applied to H[:, 2n:2n+n*n] for H_res + dtype is torch.bfloat16 or torch.float32 + use_tf32 : bool + whether to use TF32 for matrix multiplications + fuse_grad_x_acc : bool + Use the same buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. + If enable, triton kernels will accumulate the gradient of x in the same buffer to avoid copying the gradient by PyTorch. + Note: you must enable this flag for both `mhc_generate_mix_and_aggregate` and `mhc_fused_expand_combine` so they can share the same buffer for activation's gradient accumulation. + + Returns + ------- + out : torch.Tensor + out of shape (s, b, C), which is the aggregated result after applying H_pre to x, which will be fed into attention / FFN + with the same dtype as x + H_post : torch.Tensor + H_post of shape (s, b, n), which will be used in the post-processing after attention / FFN in `mhc_fused_expand_combine` + with dtype float32 + H_res : torch.Tensor + H_res of shape (s, b, n, n), which will be used to mix the residual connection in `mhc_fused_expand_combine` + with dtype float32 + """ + check_deterministic("mhc_generate_mix_and_aggregate") + s, b, C, n = x.shape + assert ( + n == 4 + ), "Only n=4 is supported in this implementation, where n is the Hyper Connection number" + nC = n * C + H, ms = mhc_fused_projection( + x.view(s * b, nC), + phi, + norm_weight=norm_weight, + use_tf32=use_tf32, + fuse_grad_x_acc=fuse_grad_x_acc, + ) + h_pre, h_post, h_res = mhc_fused_scale(H, alpha, beta, ms, n) + H_pre = h_pre.view(s, b, n) + H_post = h_post.view(s, b, n) + H_res = h_res.view(s, b, n, n) + H_res = mhc_fused_sinkhorn(H_res, n, recompute_hist=True, iters=20) + out = mhc_fused_aggregate( + x, + H_pre.view(s, b, n), + n, + use_tf32=use_tf32, + fuse_grad_x_acc=fuse_grad_x_acc, + ) + return out, H_post, H_res + + def mhc_fused_sinkhorn( H_res: torch.Tensor, n: int = 4, recompute_hist: bool = True, iters: int = 20 ): @@ -52,6 +153,7 @@ def mhc_fused_sinkhorn( ---------- H_res : torch.Tensor input H_res matrix of shape (s, b, n, n) that needs to be normalized into a doubly stochastic matrix. + dtype is torch.bfloat16 or torch.float32 n : int number of hyper connections, where only n=4 is supported in the current implementation recompute_hist : bool @@ -63,6 +165,7 @@ def mhc_fused_sinkhorn( ------- out : torch.Tensor out of shape (s, b, n, n), which is the final H_res after Sinkhorn normalization + with the same dtype as H_res """ assert n == 4, "Only n=4 is supported in this implementation" out = mHCSinkhornOp.apply(H_res, n, recompute_hist, iters) @@ -70,7 +173,11 @@ def mhc_fused_sinkhorn( def mhc_fused_scale( - H: torch.Tensor, alpha: torch.Tensor, beta: torch.Tensor, ms: torch.Tensor, n: int + H: torch.Tensor, + alpha: torch.Tensor, + beta: torch.Tensor, + ms: torch.Tensor, + n: int, ): """ Fused scale operation to compute the scaled H matrices (see eq. 16-18, section 4.3.1 of the DeepSeek mHC paper): @@ -96,6 +203,7 @@ def mhc_fused_scale( beta[0, 0:n] is applied to H[:, 0:n] for H_pre beta[0, n:2n] is applied to H[:, n:2n] for H_post beta[0, 2n:2n+n*n] is applied to H[:, 2n:2n+n*n] for H_res + Note: we assume alpha and beta have the same dtype, and according to the DeepSeek paper they should be fp32 ms : torch.Tensor mean square for each row of H from the projection kernel, of shape (M,), used for RMSNorm scaling n : int @@ -104,15 +212,17 @@ def mhc_fused_scale( Returns ------- h_pre : torch.Tensor - Scaled H_pre of shape (M, n), which aggregates (s, b, C, n) input of a Hyper Connection block into (s, b, n) as the input of attention / MLP + Scaled H_pre of shape (M, n), which aggregates (s, b, C, n) input of a Hyper Connection block into (s, b, n) as the input of attention / MLP, + with the same dtype as H h_post : torch.Tensor - Scaled H_post of shape (M, n), which expands the output of attention / MLP of shape (s, b, n) back to (s, b, C, n) for the residual connection + Scaled H_post of shape (M, n), which expands the output of attention / MLP of shape (s, b, n) back to (s, b, C, n) for the residual connection, + with the same dtype as H h_res : torch.Tensor - Scaled H_res of shape (M, n*n), which mixes the n streams of the (s, b, C, n) input of a Hyper Connection block + Scaled H_res of shape (M, n*n), which mixes the n streams of the (s, b, C, n) input of a Hyper Connection block, + with the same dtype as H """ assert n == 4, "Only n=4 is supported in this implementation" - check_deterministic("mhc_fused_scale") out = mHCScaleFusedOp.apply(H, alpha, beta, ms, n) h_pre = out[..., :n] h_post = out[..., n : 2 * n] @@ -120,7 +230,13 @@ def mhc_fused_scale( return h_pre, h_post, h_res -def mhc_fused_aggregate(x: torch.Tensor, H_pre: torch.Tensor, n: int, use_tf32: bool = True): +def mhc_fused_aggregate( + x: torch.Tensor, + H_pre: torch.Tensor, + n: int, + use_tf32: bool = True, + fuse_grad_x_acc: bool = False, +): """ Aggregate operation to merge n activation streams into one (see section 4.3.1 of the DeepSeek mHC paper): out = x @ H_pre: (s, b, C, n) @ (s, b, n, 1) -> (s, b, C, 1) -> (s, b, C) after squeezing the last dimension @@ -130,22 +246,29 @@ def mhc_fused_aggregate(x: torch.Tensor, H_pre: torch.Tensor, n: int, use_tf32: x : torch.Tensor input activation tensor of shape (s, b, C, n), where s is the sequence length, b is the batch size, C is the hidden dimension per hyper connection, and n is the number of hyper connections. Note that C is equal to the original hidden dimension divided by n. + dtype is torch.bfloat16 or torch.float32 H_pre: torch.Tensor input H_pre matrix of shape (s, b, n) + dtype is torch.bfloat16 or torch.float32 n: int number of hyper connections, where only n=4 is supported in the current implementation use_tf32: bool whether to use TF32 precision for matmul operations. If False, it will use ieee for better precision. This is mainly used by our unittests since TF32 precision will introduce some errors and cause tests to fail + fuse_grad_x_acc : bool + Use the same buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. + If enable, triton kernels will accumulate the gradient of x in the same buffer to avoid copying the gradient by PyTorch. + Note: if enabled, you must also enable this flag for `mhc_fused_projection` & `mhc_fused_expand_combine` so they can share the same buffer for activation's gradient accumulation. Returns ------- out: torch.Tensor - output activation tensor of shape (s, b, C), which is the aggregated output after merging n hyper connections + output activation tensor of shape (s, b, C), which is the aggregated output after merging n hyper connections, + with the same dtype as x """ assert n == 4, "Only n=4 is supported in this implementation" check_deterministic("mhc_fused_aggregate") - out = mHCAggregateOp.apply(x, H_pre, n, use_tf32) + out = mHCAggregateOp.apply(x, H_pre, n, use_tf32, fuse_grad_x_acc) return out @@ -157,6 +280,7 @@ def mhc_fused_expand_combine( H_res: torch.Tensor, n: int, use_tf32: bool = True, + fuse_grad_x_acc: bool = False, ): """ Expand and combine operation for merging n hyper connections (see section 4.3.1 of the DeepSeek mHC paper): @@ -167,24 +291,35 @@ def mhc_fused_expand_combine( ---------- f : torch.Tensor input activation tensor of shape (s, b, C), which is the output from the attention / FFN sub-layer in a transformer block + dtype is torch.bfloat16 or torch.float32 bias : torch.Tensor or None optional bias tensor of shape (C,) from the last linear layer, where f + bias is fused in this kernel for better performance + dtype is torch.bfloat16 or torch.float32 H_post : torch.Tensor input H_post matrix of shape (s, b, n) + dtype is torch.bfloat16 or torch.float32 x : torch.Tensor input activation tensor of shape (s, b, C, n), which is the hyper connection input before the aggregation operation + dtype is torch.bfloat16 or torch.float32 H_res : torch.Tensor input H_res matrix of shape (s, b, n, n) + dtype is torch.bfloat16 or torch.float32 n : int - number of hyper connections + number of hyper connections, where only n=4 is supported in the current implementation use_tf32 : bool whether to use TF32 precision for matmul operations. If False, it will use ieee for better precision. This is mainly used by our unittests since TF32 precision will introduce some errors and cause tests to fail + fuse_grad_x_acc : bool + Use the same buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. + If enable, triton kernels will accumulate the gradient of x in the same buffer to avoid copying the gradient by PyTorch. + Note: if enabled, you must also enable this flag for `mhc_fused_projection` & `mhc_fused_aggregate` or `mhc_generate_mix_and_aggregate` which is a wrapper of the former two, + so they can share the same buffer for activation's gradient accumulation. Returns ------- out : torch.Tensor - out of shape (s, b, C, n), which is the expanded and combined output after merging n hyper connections + out of shape (s, b, C, n), which is the expanded and combined output after merging n hyper connections, + with the same dtype as x """ assert n == 4, "Only n=4 is supported in this implementation" check_deterministic("mhc_fused_expand_combine") @@ -196,41 +331,68 @@ def mhc_fused_expand_combine( H_res, n, use_tf32, + fuse_grad_x_acc, ) return out -def mhc_fused_projection(x: torch.Tensor, phi: torch.Tensor, use_tf32: bool = True): +def mhc_fused_projection( + x: torch.Tensor, + phi: torch.Tensor, + use_tf32: bool = True, + fuse_grad_x_acc: bool = False, + norm_weight: Optional[torch.Tensor] = None, +): """ Fused projection operation to compute H matrices and mean square for RMSNorm (see eq. 14-15, section 4.3.1 of the DeepSeek mHC paper): H = x @ phi^T: (M, K) @ (K, N) -> (M, N), which is padded to (M, 32) for better memory access pattern in the next kernels. ms = mean(x^2, dim=-1): (M,) + If norm_weight is provided, it will be absorbed into phi. In this case, the operation becomes: + Projection: + - H = x @ (phi.T * norm_weight) = x @ phi.T * norm_weight + - ms = mean(x^2, dim=-1) + - H = H / sqrt(ms) = x @ (phi.T * norm_weight) / sqrt(ms), where this step is fused into `mhc_fused_scale` + which is equivalent to performing the computation in the normal order: + - x_normalized = RMSNorm(x) = x * norm_weight / sqrt(ms) + - H = x_normalized @ phi.T = (x / sqrt(ms) @ phi.T) * norm_weight + Note: the current implementation only supports n=4 Parameters ---------- x : torch.Tensor input tensor of shape (M, K), where M=s*b is the batch size and K=nC is the hidden dimension after expansion. + dtype is torch.bfloat16 or torch.float32 phi : torch.Tensor projection matrix of shape (N, K), where N=2n+n*n (=24 for n=4) + dtype is torch.bfloat16 or torch.float32 use_tf32 : bool whether to use TF32 precision for matmul operations. If False, it will use ieee for better precision. This is mainly used by our unittests since TF32 precision will introduce some errors and cause tests to fail. + fuse_grad_x_acc : bool + Use the same buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. + If enable, triton kernels will accumulate the gradient of x in the same buffer to avoid copying the gradient by PyTorch. + Note: if enabled, you must also enable this flag for `mhc_fused_aggregate` & `mhc_fused_expand_combine` so they can share the same buffer for activation's gradient accumulation. + norm_weight : torch.Tensor or None + optional, the weight for RMSNorm, of shape (K,), which is the learnable per-element affine parameters (gamma) applied to RMSNorm + dtype is torch.bfloat16 or torch.float32 Returns ------- H : torch.Tensor - Projected matrix of shape (M, 32), where only the first N elements in the last dimension are valid. + Projected matrix of shape (M, 32), where only the first N elements in the last dimension are valid, + with dtype float32 ms : torch.Tensor - Mean square of shape (M,), which is used for RMSNorm in the next kernel. + Mean square of shape (M,), which is used for RMSNorm in the next kernel, + with dtype float32 """ assert ( phi.shape[0] == 24 ), "Currently only n=4 is supported, which means phi should have 24 in its first dimension" check_deterministic("mhc_fused_projection") - H, ms = mHCProjectionOp.apply(x, phi, use_tf32) + H, ms = mHCProjectionOp.apply(x, phi, norm_weight, use_tf32, fuse_grad_x_acc) return H, ms @@ -240,16 +402,20 @@ class mHCProjectionOp(torch.autograd.Function): """ @staticmethod - def forward(ctx, x, phi, use_tf32=True): + def forward(ctx, x, phi, norm_weight=None, use_tf32=True, fuse_grad_x_acc=False): """ The forward pass of the fused projection operation. Computes H = x @ phi^T and the mean + If norm_weight is provided, it will be absorbd by phi square ms = mean(x^2, dim=-1) for RMSNorm in a single fused kernel. Parameters: ctx : The context object. x (tensor): The input tensor of shape (M, K), where M=s*b is the flattened batch dimension and K=nC is the hidden dimension after expansion. phi (tensor): The projection matrix of shape (N, K), where N=2n+n*n (=24 for n=4). + norm_weight (tensor or None): Optional, or tensor of shape (K,). RMSNorm's learnable per-element affine parameters use_tf32 (bool): Whether to use TF32 precision for matmul operations. If False, uses IEEE for better precision. + n (int): Number of hyper connections, where only n=4 is supported in the current implementation. + fuse_grad_x_acc (bool): Use the same buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. Returns: tuple: A tuple of (H, ms) where H is the projected matrix of shape (M, 32) padded for memory alignment (only the first N elements are valid), and ms is the mean square of shape (M,) in FP32. @@ -267,9 +433,7 @@ def forward(ctx, x, phi, use_tf32=True): # Pad H to (s, b, 32) for better memory access pattern in the kernel, but only the first N elements in the last dimension are valid H = torch.zeros((M, 32), device=device, dtype=torch.float32) - ms = torch.zeros( - (M,), device=device, dtype=torch.float32 - ) # Mean square for x, used to compute RMSNorm in the next kernel + ms = torch.zeros((M,), device=device, dtype=torch.float32) # pylint: disable=unnecessary-lambda-assignment grid = lambda META: ( @@ -277,6 +441,32 @@ def forward(ctx, x, phi, use_tf32=True): triton.cdiv(K, META["BLOCK_SIZE_K"]), ) + use_tma = _SUPPORT_TMA and _tma_aligned(x) and _tma_aligned(phi) + if use_tma: + # TMA descriptors require a global memory allocation + def alloc_fn( + size: int, alignment: int, stream: Optional[int] + ): # pylint: disable=unused-argument + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + + ctx.save_for_backward(x, phi, ms, norm_weight) + ctx.phi_dtype = phi.dtype + ctx.fuse_grad_x_acc = fuse_grad_x_acc + + if norm_weight is not None: + phi = phi.to(torch.float32) * norm_weight.to(torch.float32) + + precision = "tf32" if ctx.use_tf32 else "ieee" + # If upcasting from bf16 to fp32 takes place inside the triton kernel, triton will ignore "ieee" precision and use tf32 anyway + # See https://github.com/triton-lang/triton/issues/10176 for detail. + # Therefore, we need to use tf32x3 instead which at least has better accuracy than tf32 just to make the tests pass. In production + # precision should be tf32 so it's not affected. + if precision == "ieee" and x.dtype == torch.bfloat16 and phi.dtype == torch.float32: + precision = "tf32x3" + ctx.precision = precision + _mhc_projection_fwd_fused[grid]( x_ptr=x, # (M, K) phi_ptr=phi, # (N, K) @@ -292,22 +482,30 @@ def forward(ctx, x, phi, use_tf32=True): stride_hm=32, stride_hn=1, stride_ms=1, + stride_norm_weight=1, BLOCK_SIZE_N=32, - precision="tf32" if use_tf32 else "ieee", + precision=precision, + USE_TMA=use_tma, ) - ctx.save_for_backward(x, phi, ms) - ctx.phi_dtype = phi.dtype - - return H.to(ctx.dtype), ms # Keep ms in fp32 + return H, ms # Keep both in fp32, which will be passed to sigmoid in mHCScaleFusedOp @staticmethod def backward(ctx, grad_H, grad_ms): """ The backward pass of the fused projection operation. Computes gradients for x and phi. - grad_phi = grad_H^T @ x, truncated to the first N rows. - grad_x = grad_H @ phi + 2 * x * grad_ms / K, where the second term is the gradient contribution from + - grad_psi = grad_H^T @ x: (2n + n^2, M) @ (M, nC) = (2n + n^2, nC), where grad_H's last dim is padded to 32 + If norm_weight is None: + - grad_phi = grad_psi + Otherwise, + - grad_phi = grad_psi * norm_weight: (2n + n^2, nC) * (nC,) = (2n + n^2, nC) + - grad_norm_weight = sum(grad_psi * phi, dim=0): ((2n + n^2, nC) * (2n + n^2, nC)).sum(dim=0) -> (nC,) + Reorder a bit: + - grad_phi = grad_H^T @ x * norm_weight + - grad_norm_weight = sum((grad_H^T @ x) * phi, dim=0) + + - grad_x = grad_H @ phi + 2 * x * grad_ms / K, where the second term is the gradient contribution from the mean square computation fused in the forward pass. Parameters: @@ -316,9 +514,9 @@ def backward(ctx, grad_H, grad_ms): grad_ms (tensor): The gradient of the loss with respect to the mean square, of shape (M,). Returns: - tuple: A tuple with the gradients (grad_x, grad_phi, None). + tuple: A tuple with the gradients (grad_x, grad_phi, grad_norm_weight, None). """ - x, phi, ms = ctx.saved_tensors + x, phi, ms, norm_weight = ctx.saved_tensors M, K = x.shape device = x.device @@ -332,12 +530,57 @@ def backward(ctx, grad_H, grad_ms): M, ) - grad_x = torch.empty((M, K), device=device, dtype=x.dtype) + fuse_grad_x_acc = hasattr(x.untyped_storage(), "grad_x_acc") and ctx.fuse_grad_x_acc + if fuse_grad_x_acc: + grad_x = x.untyped_storage().grad_x_acc.view_as(x) + else: + grad_x = torch.empty((M, K), device=device, dtype=x.dtype) + + if norm_weight is not None: + # With norm_weight, we need a fused kernel to perform GEMM and output both phi & norm_weight gradients + # pylint: disable=unnecessary-lambda-assignment + grid = lambda META: ( + triton.cdiv(K, META["BLOCK_SIZE_K"]), + triton.cdiv(M, META["BLOCK_SIZE_M"]), + ) - grad_x = torch.empty((M, K), device=device, dtype=x.dtype) - grad_phi = general_gemm(x, grad_H, out_dtype=torch.float32, layout="NT")[0][:N, :].to( - phi.dtype - ) # (2n + n^2, M) @ (M, nC) = (2n + n^2, nC); grad_H's last dim is padded to 32 + # For reduction over M, we should prefer parallelizing over M since it's likely to be better, unless determinism is enforced + grad_phi = torch.zeros_like(phi, dtype=torch.float32) + grad_norm_weight = torch.zeros_like(norm_weight, dtype=torch.float32) + + _mhc_projection_bwd_fused_dphi[grid]( + x_ptr=x, # (M, K) + grad_H_ptr=grad_H, # (M, 32) + phi_ptr=phi, # (N, K) + norm_weight_ptr=norm_weight, # (K,) + grad_phi_ptr=grad_phi, # (N, K) + grad_norm_weight_ptr=grad_norm_weight, # (K,) + M=M, + N=N, + K=K, + stride_xm=K, + stride_xk=1, + stride_grad_Hm=32, + stride_grad_Hn=1, + stride_phin=K, + stride_phik=1, + stride_norm_weight=1, + stride_grad_phin=K, + stride_grad_phik=1, + stride_grad_norm_weight=1, + BLOCK_SIZE_N=32, + precision="tf32" if ctx.use_tf32 else "ieee", + ) + + grad_phi = grad_phi.to(phi.dtype) + grad_norm_weight = grad_norm_weight.to(norm_weight.dtype) + else: + # Without norm_weight, this is only a GEMM with no fusion needed so we let cuBLAS handle it + grad_phi = general_gemm( + x.to(grad_H.dtype), grad_H, out_dtype=torch.float32, layout="NT" + )[0][:N, :] + grad_phi = grad_phi.to(phi.dtype) + grad_norm_weight = None # pylint: disable=unnecessary-lambda-assignment grid = lambda META: ( @@ -345,10 +588,11 @@ def backward(ctx, grad_H, grad_ms): triton.cdiv(K, META["BLOCK_SIZE_K"]), ) - _mhc_projection_bwd_fused[grid]( + _mhc_projection_bwd_fused_dx[grid]( x_ptr=x, grad_x_ptr=grad_x, # (M, K) phi_ptr=phi, # (N, K) + norm_weight_ptr=norm_weight, # (K,) grad_h_ptr=grad_H, # (M, 32) grad_ms_ptr=grad_ms, # (M,) M=M, @@ -360,16 +604,22 @@ def backward(ctx, grad_H, grad_ms): stride_grad_xk=1, stride_phin=K, stride_phik=1, + stride_norm_weight=1, stride_grad_phin=K, stride_grad_phik=1, stride_grad_hm=32, stride_grad_hn=1, stride_grad_ms=1, BLOCK_SIZE_N=32, - precision="tf32" if ctx.use_tf32 else "ieee", + precision=ctx.precision, + FUSE_GRAD_X_ACC=fuse_grad_x_acc, + HAS_NORM_WEIGHT=norm_weight is not None, ) - return grad_x.to(ctx.dtype), grad_phi.to(ctx.dtype), None + if fuse_grad_x_acc: + del x.untyped_storage().grad_x_acc + + return grad_x.to(x.dtype), grad_phi, grad_norm_weight, None, None, None, None class mHCScaleFusedOp(torch.autograd.Function): @@ -507,10 +757,10 @@ def backward(ctx, grad_out): ) return ( - grad_h.to(ctx.dtype), - grad_alpha.to(ctx.dtype), - grad_beta.to(ctx.dtype), - grad_ms.to(ctx.dtype), + grad_h, + grad_alpha.to(alpha.dtype), + grad_beta.to(alpha.dtype), # We assume alpha and beta have the same dtype + grad_ms, None, ) @@ -676,7 +926,6 @@ def backward(ctx, grad_out): ) grad_res = grad_res.view(s, b, n, n) - return grad_res.to(ctx.dtype), None, None, None @@ -686,7 +935,7 @@ class mHCAggregateOp(torch.autograd.Function): """ @staticmethod - def forward(ctx, x, H_pre, n, use_tf32=True): + def forward(ctx, x, H_pre, n, use_tf32=True, fuse_grad_x_acc=False): """ The forward pass of the aggregate operation. Merges n activation streams into one by computing a weighted sum using H_pre: @@ -699,6 +948,7 @@ def forward(ctx, x, H_pre, n, use_tf32=True): H_pre (tensor): The pre-connection matrix of shape (s, b, n), used as weights for aggregation. n (int): The number of hyper connections (only n=4 is supported). use_tf32 (bool): Whether to use TF32 precision for matmul operations. + fuse_grad_x_acc (bool): Use the same buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. Returns: tensor: The aggregated output of shape (s, b, C). @@ -735,6 +985,7 @@ def forward(ctx, x, H_pre, n, use_tf32=True): ctx.save_for_backward(x, H_pre) ctx.n = n ctx.use_tf32 = use_tf32 + ctx.fuse_grad_x_acc = fuse_grad_x_acc return out @@ -763,7 +1014,12 @@ def backward(ctx, grad_output): assert n == 4, "Only n=4 is supported in this implementation" M = s * b - grad_x = torch.empty_like(x) + fuse_grad_x_acc = hasattr(x.untyped_storage(), "grad_x_acc") and ctx.fuse_grad_x_acc + if fuse_grad_x_acc: + grad_x = x.untyped_storage().grad_x_acc.view_as(x) + else: + grad_x = torch.empty_like(x) + grad_H_pre = torch.zeros( (s, b, n), dtype=torch.float32, device=H_pre.device ) # We need to use atomic_add for this so we need higher precision @@ -790,11 +1046,15 @@ def backward(ctx, grad_output): stride_grad_xm=nC, stride_grad_xCn=1, precision="tf32" if ctx.use_tf32 else "ieee", + FUSE_GRAD_X_ACC=fuse_grad_x_acc, ) grad_H_pre = grad_H_pre.to(H_pre.dtype) # Cast back to the original dtype of H_pre - return grad_x, grad_H_pre, None, None + if fuse_grad_x_acc: + grad_x = None + + return grad_x, grad_H_pre, None, None, None, None class mHCExpandCombineOp(torch.autograd.Function): @@ -803,7 +1063,7 @@ class mHCExpandCombineOp(torch.autograd.Function): """ @staticmethod - def forward(ctx, f, bias, H_post, x, H_res, n, use_tf32=True): + def forward(ctx, f, bias, H_post, x, H_res, n, use_tf32=True, fuse_grad_x_acc=False): """ The forward pass of the expand and combine operation. Expands the sub-layer output f back to n streams using H_post, and combines with the residual connections using H_res: @@ -819,6 +1079,7 @@ def forward(ctx, f, bias, H_post, x, H_res, n, use_tf32=True): H_res (tensor): The residual connection matrix of shape (s, b, n, n). n (int): The number of hyper connections (only n=4 is supported). use_tf32 (bool): Whether to use TF32 precision for matmul operations. + fuse_grad_x_acc (bool): Use the same buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. Returns: tensor: The expanded and combined output of shape (s, b, C, n). @@ -843,45 +1104,29 @@ def forward(ctx, f, bias, H_post, x, H_res, n, use_tf32=True): triton.cdiv(M, META["BLOCK_SIZE_M"]), ) - if bias is None: - _mhc_expand_combine_fwd[grid]( - f_ptr=f, - H_post_ptr=H_post, - x_ptr=x, - H_res_ptr=H_res, - output_ptr=out, - M=M, - C=C, - n=n, - stride_fm=C, - stride_fc=1, - stride_xm=Cn, - stride_xCn=1, - stride_output_m=Cn, - stride_output_Cn=1, - ) - else: - _mhc_expand_combine_with_bias_fwd[grid]( - f_ptr=f, - bias_ptr=bias, - H_post_ptr=H_post, - x_ptr=x, - H_res_ptr=H_res, - output_ptr=out, - M=M, - C=C, - n=n, - stride_fm=C, - stride_fc=1, - stride_bias=1, - stride_xm=Cn, - stride_xCn=1, - stride_output_m=Cn, - stride_output_Cn=1, - ) + _mhc_expand_combine_fwd[grid]( + f_ptr=f, + bias_ptr=bias, + H_post_ptr=H_post, + x_ptr=x, + H_res_ptr=H_res, + output_ptr=out, + M=M, + C=C, + n=n, + stride_fm=C, + stride_fc=1, + stride_bias=1, + stride_xm=Cn, + stride_xCn=1, + stride_output_m=Cn, + stride_output_Cn=1, + HAS_BIAS=bias is not None, + ) ctx.n = n ctx.have_bias = bias is not None + ctx.fuse_grad_x_acc = fuse_grad_x_acc if bias is not None: ctx.save_for_backward(f, bias, H_post, x, H_res) else: @@ -919,81 +1164,74 @@ def backward(ctx, grad_output): M = s * b grad_f = torch.empty_like(f) - grad_bias = torch.zeros_like(bias, dtype=torch.float32) if bias is not None else None + grad_x = torch.empty_like(x) + + # Since triton's autotune will reset grad_bias pointer when tuning, we need an empty placeholder here + grad_bias = torch.empty(1, device=grad_output.device, dtype=grad_output.dtype) grad_H_post = torch.zeros_like( H_post, dtype=torch.float32 ) # We need to use atomic_add for this so we need higher precision - grad_x = torch.empty_like(x) grad_H_res = torch.zeros_like( H_res, dtype=torch.float32 ) # We need to use atomic_add for this so we need higher precision + if bias is not None: + grad_bias = torch.zeros_like(bias, dtype=torch.float32) if bias is not None else None + # pylint: disable=unnecessary-lambda-assignment grid = lambda META: ( triton.cdiv(C, META["BLOCK_SIZE_C"]), triton.cdiv(M, META["BLOCK_SIZE_M"]), ) + _mhc_expand_combine_bwd[grid]( + grad_output_ptr=grad_output, + f_ptr=f, + bias_ptr=bias, + H_post_ptr=H_post, + x_ptr=x, + H_res_ptr=H_res, + grad_H_post_ptr=grad_H_post, + grad_f_ptr=grad_f, + grad_bias_ptr=grad_bias, + grad_H_res_ptr=grad_H_res, + grad_x_ptr=grad_x, + M=M, + C=C, + n=n, + stride_grad_output_m=n * C, + stride_grad_output_Cn=1, + stride_fm=C, + stride_fc=1, + stride_bias=1, + stride_xm=n * C, + stride_xCn=1, + stride_grad_fm=C, + stride_grad_fc=1, + stride_grad_bias=1, + stride_grad_xm=n * C, + stride_grad_xCn=1, + precision="tf32" if ctx.use_tf32 else "ieee", + HAS_BIAS=bias is not None, + ) + + # If no bias, replace the grad_bias placeholder with None if bias is None: - _mhc_expand_combine_bwd[grid]( - grad_output_ptr=grad_output, - f_ptr=f, - H_post_ptr=H_post, - x_ptr=x, - H_res_ptr=H_res, - grad_H_post_ptr=grad_H_post, - grad_f_ptr=grad_f, - grad_H_res_ptr=grad_H_res, - grad_x_ptr=grad_x, - M=M, - C=C, - n=n, - stride_grad_output_m=n * C, - stride_grad_output_Cn=1, - stride_fm=C, - stride_fc=1, - stride_xm=n * C, - stride_xCn=1, - stride_grad_fm=C, - stride_grad_fc=1, - stride_grad_xm=n * C, - stride_grad_xCn=1, - precision="tf32" if ctx.use_tf32 else "ieee", - ) - else: - _mhc_expand_combine_with_bias_bwd[grid]( - grad_output_ptr=grad_output, - f_ptr=f, - bias_ptr=bias, - H_post_ptr=H_post, - x_ptr=x, - H_res_ptr=H_res, - grad_H_post_ptr=grad_H_post, - grad_f_ptr=grad_f, - grad_bias_ptr=grad_bias, - grad_H_res_ptr=grad_H_res, - grad_x_ptr=grad_x, - M=M, - C=C, - n=n, - stride_grad_output_m=n * C, - stride_grad_output_Cn=1, - stride_fm=C, - stride_fc=1, - stride_bias=1, - stride_xm=n * C, - stride_xCn=1, - stride_grad_fm=C, - stride_grad_fc=1, - stride_grad_bias=1, - stride_grad_xm=n * C, - stride_grad_xCn=1, - precision="tf32" if ctx.use_tf32 else "ieee", - ) + grad_bias = None grad_H_post = grad_H_post.to(H_post.dtype) # Cast back to the original dtype of H_post grad_H_res = grad_H_res.to(H_res.dtype) # Cast back to the original dtype of H_res if bias is not None: grad_bias = grad_bias.to(bias.dtype) - return grad_f, grad_bias, grad_H_post, grad_x, grad_H_res, None, None + if ctx.fuse_grad_x_acc: + assert not hasattr(x.untyped_storage(), "grad_x_acc"), ( + "Unexpected: grad_x_acc is already attached in x's storage. This implies incorrect" + " usage of `fuse_grad_x_acc` optimization. Please disable fuse_grad_x_acc or check" + " if there are other places where grad_x_acc is attached to x's storage." + ) + # When fused x gradient accumulation is enabled, use fp32 for the accumulation buffer + x.untyped_storage().grad_x_acc = grad_x.to(torch.float32) + grad_x = None + + return grad_f, grad_bias, grad_H_post, grad_x, grad_H_res, None, None, None, None From 5d55d3aa5b078cbb0b21f4ee3dc7a5a6e4b7c565 Mon Sep 17 00:00:00 2001 From: Kaining Zhong Date: Tue, 12 May 2026 17:54:34 +0000 Subject: [PATCH 02/19] fix Signed-off-by: Kaining Zhong --- transformer_engine/pytorch/triton/mhc.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/triton/mhc.py b/transformer_engine/pytorch/triton/mhc.py index d95eeb3f63..cbef6e4374 100644 --- a/transformer_engine/pytorch/triton/mhc.py +++ b/transformer_engine/pytorch/triton/mhc.py @@ -26,7 +26,8 @@ ) from transformer_engine.pytorch.cpp_extensions.gemm import general_gemm -_SUPPORT_TMA = torch.cuda.get_device_capability()[0] >= 9 +def support_tma(): + return torch.cuda.get_device_capability()[0] >= 9 def _tma_aligned(t): @@ -441,7 +442,7 @@ def forward(ctx, x, phi, norm_weight=None, use_tf32=True, fuse_grad_x_acc=False) triton.cdiv(K, META["BLOCK_SIZE_K"]), ) - use_tma = _SUPPORT_TMA and _tma_aligned(x) and _tma_aligned(phi) + use_tma = support_tma() and _tma_aligned(x) and _tma_aligned(phi) if use_tma: # TMA descriptors require a global memory allocation def alloc_fn( From 1edfa077799a799e30e52edf5e2f5200dfe708d7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 12 May 2026 17:55:46 +0000 Subject: [PATCH 03/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/triton/mhc.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/pytorch/triton/mhc.py b/transformer_engine/pytorch/triton/mhc.py index cbef6e4374..b89a522f04 100644 --- a/transformer_engine/pytorch/triton/mhc.py +++ b/transformer_engine/pytorch/triton/mhc.py @@ -26,6 +26,7 @@ ) from transformer_engine.pytorch.cpp_extensions.gemm import general_gemm + def support_tma(): return torch.cuda.get_device_capability()[0] >= 9 From c41ff8952918a3cbc555bb33457a672f6d52212b Mon Sep 17 00:00:00 2001 From: Kaining Zhong Date: Tue, 12 May 2026 18:39:26 +0000 Subject: [PATCH 04/19] fix Signed-off-by: Kaining Zhong --- transformer_engine/pytorch/triton/mhc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/triton/mhc.py b/transformer_engine/pytorch/triton/mhc.py index b89a522f04..45bc61f7c8 100644 --- a/transformer_engine/pytorch/triton/mhc.py +++ b/transformer_engine/pytorch/triton/mhc.py @@ -27,7 +27,7 @@ from transformer_engine.pytorch.cpp_extensions.gemm import general_gemm -def support_tma(): +def _support_tma(): return torch.cuda.get_device_capability()[0] >= 9 @@ -443,7 +443,7 @@ def forward(ctx, x, phi, norm_weight=None, use_tf32=True, fuse_grad_x_acc=False) triton.cdiv(K, META["BLOCK_SIZE_K"]), ) - use_tma = support_tma() and _tma_aligned(x) and _tma_aligned(phi) + use_tma = _support_tma() and _tma_aligned(x) and _tma_aligned(phi) if use_tma: # TMA descriptors require a global memory allocation def alloc_fn( From f587a43e5d00c05f98dd6e5360a700cdda57e86b Mon Sep 17 00:00:00 2001 From: Kaining Zhong Date: Tue, 12 May 2026 21:02:17 +0000 Subject: [PATCH 05/19] make fused_grad_x_acc less confusing Signed-off-by: Kaining Zhong --- tests/pytorch/test_mhc.py | 13 +-- transformer_engine/common/triton/mhc.py | 6 +- transformer_engine/pytorch/triton/mhc.py | 107 +++++++++++------------ 3 files changed, 63 insertions(+), 63 deletions(-) diff --git a/tests/pytorch/test_mhc.py b/tests/pytorch/test_mhc.py index 29087fa4bc..4ac0e51751 100644 --- a/tests/pytorch/test_mhc.py +++ b/tests/pytorch/test_mhc.py @@ -296,8 +296,6 @@ def test_mhc_projection(cfg: MHCConfig, dtypes, has_norm_weight): x_ref = x.detach().clone().requires_grad_(True) phi_ref = phi.detach().clone().requires_grad_(True) - has_norm_weight = False - if has_norm_weight: norm_weight = torch.randn(nC, device="cuda", requires_grad=True, dtype=x_dtype) norm_weight_ref = norm_weight.detach().clone().requires_grad_(True) @@ -484,8 +482,11 @@ def test_mhc_fuse_grad_acc(cfg: MHCConfig, dtype): beta_ref = beta.detach().clone().requires_grad_(True) def end_to_end(x, phi, alpha, beta, fused_grad_x_acc): + fused_grad_x_acc_buffer = None + if fused_grad_x_acc: + fused_grad_x_acc_buffer = torch.empty_like(x, dtype=torch.float32) aggregated, H_post, H_res = mhc_generate_mix_and_aggregate( - x, phi, alpha, beta, None, use_tf32, fused_grad_x_acc + x, phi, alpha, beta, None, use_tf32, fused_grad_x_acc_buffer ) H_res = mhc_fused_sinkhorn(H_res.view(s, b, n, n), n).view(s * b, n * n) expanded_combined = mhc_fused_expand_combine( @@ -496,13 +497,13 @@ def end_to_end(x, phi, alpha, beta, fused_grad_x_acc): H_res, n, False, - fused_grad_x_acc, + fused_grad_x_acc_buffer, ) return expanded_combined - expanded_combined_fuse_grad = end_to_end(x_ref, phi_ref, alpha_ref, beta_ref, False) - expanded_combined_no_fuse_grad = end_to_end(x, phi, alpha, beta, True) + expanded_combined_fuse_grad = end_to_end(x_ref, phi_ref, alpha_ref, beta_ref, fused_grad_x_acc=True) + expanded_combined_no_fuse_grad = end_to_end(x, phi, alpha, beta, fused_grad_x_acc=False) grad_output = torch.randn_like(expanded_combined_fuse_grad) expanded_combined_fuse_grad.backward(grad_output) diff --git a/transformer_engine/common/triton/mhc.py b/transformer_engine/common/triton/mhc.py index 62a9c81fe3..87299fed91 100644 --- a/transformer_engine/common/triton/mhc.py +++ b/transformer_engine/common/triton/mhc.py @@ -1530,6 +1530,7 @@ def _mhc_expand_combine_bwd( STEP_SIZE_C: tl.constexpr, precision: tl.constexpr, HAS_BIAS: tl.constexpr, + FUSE_GRAD_X_ACC: tl.constexpr, ): """ Each block @@ -1711,7 +1712,10 @@ def _mhc_expand_combine_bwd( grad_x_acc = tl.fma(grad_out2[:, :, None], H_res2[:, None, :], grad_x_acc) grad_x_acc = tl.fma(grad_out3[:, :, None], H_res3[:, None, :], grad_x_acc) - grad_x = grad_x_acc.to(x.dtype) + if FUSE_GRAD_X_ACC: + grad_x = grad_x_acc # If fusing gradient accumulation, the buffer should be always fp32 so we don't cast here + else: + grad_x = grad_x_acc.to(x.dtype) grad_x = tl.reshape( grad_x, (BLOCK_SIZE_M, STEP_SIZE_C * n) ) # (BLOCK_SIZE_M, STEP_SIZE_C*n) diff --git a/transformer_engine/pytorch/triton/mhc.py b/transformer_engine/pytorch/triton/mhc.py index 45bc61f7c8..72a086edb2 100644 --- a/transformer_engine/pytorch/triton/mhc.py +++ b/transformer_engine/pytorch/triton/mhc.py @@ -54,7 +54,7 @@ def mhc_generate_mix_and_aggregate( beta: torch.Tensor, norm_weight: Optional[torch.Tensor] = None, use_tf32: bool = True, - fuse_grad_x_acc: bool = False, + fused_grad_x_acc_buffer: Optional[torch.Tensor] = None, ): """ Generate the mix matrix H_pre, H_post, H_res and apply H_pre to x to aggregate n streams @@ -97,10 +97,11 @@ def mhc_generate_mix_and_aggregate( dtype is torch.bfloat16 or torch.float32 use_tf32 : bool whether to use TF32 for matrix multiplications - fuse_grad_x_acc : bool - Use the same buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. - If enable, triton kernels will accumulate the gradient of x in the same buffer to avoid copying the gradient by PyTorch. - Note: you must enable this flag for both `mhc_generate_mix_and_aggregate` and `mhc_fused_expand_combine` so they can share the same buffer for activation's gradient accumulation. + fused_grad_x_acc_buffer : Optional[torch.Tensor] + A pre-allocated buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. + If not None, triton kernels will accumulate the gradient of x into this same buffer to avoid copying the gradient by PyTorch, which should be reused + during the backward of mhc_fused_aggregate, mhc_fused_expand_combine and mhc_fused_projection operations + Note: the buffer must have dtype float32, and it will be cast to the activation's dtype and be returned in mhc_fused_projection Returns ------- @@ -125,7 +126,7 @@ def mhc_generate_mix_and_aggregate( phi, norm_weight=norm_weight, use_tf32=use_tf32, - fuse_grad_x_acc=fuse_grad_x_acc, + fused_grad_x_acc_buffer=fused_grad_x_acc_buffer, ) h_pre, h_post, h_res = mhc_fused_scale(H, alpha, beta, ms, n) H_pre = h_pre.view(s, b, n) @@ -137,7 +138,7 @@ def mhc_generate_mix_and_aggregate( H_pre.view(s, b, n), n, use_tf32=use_tf32, - fuse_grad_x_acc=fuse_grad_x_acc, + fused_grad_x_acc_buffer=fused_grad_x_acc_buffer, ) return out, H_post, H_res @@ -237,7 +238,7 @@ def mhc_fused_aggregate( H_pre: torch.Tensor, n: int, use_tf32: bool = True, - fuse_grad_x_acc: bool = False, + fused_grad_x_acc_buffer: Optional[torch.Tensor] = None, ): """ Aggregate operation to merge n activation streams into one (see section 4.3.1 of the DeepSeek mHC paper): @@ -257,10 +258,11 @@ def mhc_fused_aggregate( use_tf32: bool whether to use TF32 precision for matmul operations. If False, it will use ieee for better precision. This is mainly used by our unittests since TF32 precision will introduce some errors and cause tests to fail - fuse_grad_x_acc : bool - Use the same buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. - If enable, triton kernels will accumulate the gradient of x in the same buffer to avoid copying the gradient by PyTorch. - Note: if enabled, you must also enable this flag for `mhc_fused_projection` & `mhc_fused_expand_combine` so they can share the same buffer for activation's gradient accumulation. + fused_grad_x_acc_buffer : Optional[torch.Tensor] + A pre-allocated buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. + If not None, triton kernels will accumulate the gradient of x into this same buffer to avoid copying the gradient by PyTorch, which should be reused + during the backward of mhc_fused_aggregate, mhc_fused_expand_combine and mhc_fused_projection operations + Note: the buffer must have dtype float32, and it will be cast to the activation's dtype and be returned in mhc_fused_projection Returns ------- @@ -270,7 +272,7 @@ def mhc_fused_aggregate( """ assert n == 4, "Only n=4 is supported in this implementation" check_deterministic("mhc_fused_aggregate") - out = mHCAggregateOp.apply(x, H_pre, n, use_tf32, fuse_grad_x_acc) + out = mHCAggregateOp.apply(x, H_pre, n, use_tf32, fused_grad_x_acc_buffer) return out @@ -282,7 +284,7 @@ def mhc_fused_expand_combine( H_res: torch.Tensor, n: int, use_tf32: bool = True, - fuse_grad_x_acc: bool = False, + fused_grad_x_acc_buffer: Optional[torch.Tensor] = None, ): """ Expand and combine operation for merging n hyper connections (see section 4.3.1 of the DeepSeek mHC paper): @@ -311,11 +313,11 @@ def mhc_fused_expand_combine( use_tf32 : bool whether to use TF32 precision for matmul operations. If False, it will use ieee for better precision. This is mainly used by our unittests since TF32 precision will introduce some errors and cause tests to fail - fuse_grad_x_acc : bool - Use the same buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. - If enable, triton kernels will accumulate the gradient of x in the same buffer to avoid copying the gradient by PyTorch. - Note: if enabled, you must also enable this flag for `mhc_fused_projection` & `mhc_fused_aggregate` or `mhc_generate_mix_and_aggregate` which is a wrapper of the former two, - so they can share the same buffer for activation's gradient accumulation. + fused_grad_x_acc_buffer : Optional[torch.Tensor] + A pre-allocated buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. + If not None, triton kernels will accumulate the gradient of x into this same buffer to avoid copying the gradient by PyTorch, which should be reused + during the backward of mhc_fused_aggregate, mhc_fused_expand_combine and mhc_fused_projection operations + Note: the buffer must have dtype float32, and it will be cast to the activation's dtype and be returned in mhc_fused_projection Returns ------- @@ -333,7 +335,7 @@ def mhc_fused_expand_combine( H_res, n, use_tf32, - fuse_grad_x_acc, + fused_grad_x_acc_buffer, ) return out @@ -342,8 +344,8 @@ def mhc_fused_projection( x: torch.Tensor, phi: torch.Tensor, use_tf32: bool = True, - fuse_grad_x_acc: bool = False, norm_weight: Optional[torch.Tensor] = None, + fused_grad_x_acc_buffer: Optional[torch.Tensor] = None, ): """ Fused projection operation to compute H matrices and mean square for RMSNorm (see eq. 14-15, section 4.3.1 of the DeepSeek mHC paper): @@ -373,13 +375,14 @@ def mhc_fused_projection( use_tf32 : bool whether to use TF32 precision for matmul operations. If False, it will use ieee for better precision. This is mainly used by our unittests since TF32 precision will introduce some errors and cause tests to fail. - fuse_grad_x_acc : bool - Use the same buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. - If enable, triton kernels will accumulate the gradient of x in the same buffer to avoid copying the gradient by PyTorch. - Note: if enabled, you must also enable this flag for `mhc_fused_aggregate` & `mhc_fused_expand_combine` so they can share the same buffer for activation's gradient accumulation. norm_weight : torch.Tensor or None optional, the weight for RMSNorm, of shape (K,), which is the learnable per-element affine parameters (gamma) applied to RMSNorm dtype is torch.bfloat16 or torch.float32 + fused_grad_x_acc_buffer : Optional[torch.Tensor] + A pre-allocated buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. + If not None, triton kernels will accumulate the gradient of x into this same buffer to avoid copying the gradient by PyTorch, which should be reused + during the backward of mhc_fused_aggregate, mhc_fused_expand_combine and mhc_fused_projection operations + Note: the buffer must have dtype float32, and it will be cast to the activation's dtype and be returned in mhc_fused_projection Returns ------- @@ -394,7 +397,7 @@ def mhc_fused_projection( phi.shape[0] == 24 ), "Currently only n=4 is supported, which means phi should have 24 in its first dimension" check_deterministic("mhc_fused_projection") - H, ms = mHCProjectionOp.apply(x, phi, norm_weight, use_tf32, fuse_grad_x_acc) + H, ms = mHCProjectionOp.apply(x, phi, norm_weight, use_tf32, fused_grad_x_acc_buffer) return H, ms @@ -404,7 +407,7 @@ class mHCProjectionOp(torch.autograd.Function): """ @staticmethod - def forward(ctx, x, phi, norm_weight=None, use_tf32=True, fuse_grad_x_acc=False): + def forward(ctx, x, phi, norm_weight=None, use_tf32=True, fused_grad_x_acc_buffer=None): """ The forward pass of the fused projection operation. Computes H = x @ phi^T and the mean If norm_weight is provided, it will be absorbd by phi @@ -417,7 +420,7 @@ def forward(ctx, x, phi, norm_weight=None, use_tf32=True, fuse_grad_x_acc=False) norm_weight (tensor or None): Optional, or tensor of shape (K,). RMSNorm's learnable per-element affine parameters use_tf32 (bool): Whether to use TF32 precision for matmul operations. If False, uses IEEE for better precision. n (int): Number of hyper connections, where only n=4 is supported in the current implementation. - fuse_grad_x_acc (bool): Use the same buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. + fused_grad_x_acc_buffer (torch.Tensor or None): A pre-allocated buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. Returns: tuple: A tuple of (H, ms) where H is the projected matrix of shape (M, 32) padded for memory alignment (only the first N elements are valid), and ms is the mean square of shape (M,) in FP32. @@ -455,7 +458,7 @@ def alloc_fn( ctx.save_for_backward(x, phi, ms, norm_weight) ctx.phi_dtype = phi.dtype - ctx.fuse_grad_x_acc = fuse_grad_x_acc + ctx.fused_grad_x_acc_buffer = fused_grad_x_acc_buffer if norm_weight is not None: phi = phi.to(torch.float32) * norm_weight.to(torch.float32) @@ -532,9 +535,8 @@ def backward(ctx, grad_H, grad_ms): M, ) - fuse_grad_x_acc = hasattr(x.untyped_storage(), "grad_x_acc") and ctx.fuse_grad_x_acc - if fuse_grad_x_acc: - grad_x = x.untyped_storage().grad_x_acc.view_as(x) + if ctx.fused_grad_x_acc_buffer is not None: + grad_x = ctx.fused_grad_x_acc_buffer.view_as(x) else: grad_x = torch.empty((M, K), device=device, dtype=x.dtype) @@ -614,13 +616,10 @@ def backward(ctx, grad_H, grad_ms): stride_grad_ms=1, BLOCK_SIZE_N=32, precision=ctx.precision, - FUSE_GRAD_X_ACC=fuse_grad_x_acc, + FUSE_GRAD_X_ACC=ctx.fused_grad_x_acc_buffer is not None, HAS_NORM_WEIGHT=norm_weight is not None, ) - if fuse_grad_x_acc: - del x.untyped_storage().grad_x_acc - return grad_x.to(x.dtype), grad_phi, grad_norm_weight, None, None, None, None @@ -937,7 +936,7 @@ class mHCAggregateOp(torch.autograd.Function): """ @staticmethod - def forward(ctx, x, H_pre, n, use_tf32=True, fuse_grad_x_acc=False): + def forward(ctx, x, H_pre, n, use_tf32=True, fused_grad_x_acc_buffer=None): """ The forward pass of the aggregate operation. Merges n activation streams into one by computing a weighted sum using H_pre: @@ -950,7 +949,7 @@ def forward(ctx, x, H_pre, n, use_tf32=True, fuse_grad_x_acc=False): H_pre (tensor): The pre-connection matrix of shape (s, b, n), used as weights for aggregation. n (int): The number of hyper connections (only n=4 is supported). use_tf32 (bool): Whether to use TF32 precision for matmul operations. - fuse_grad_x_acc (bool): Use the same buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. + fused_grad_x_acc_buffer (torch.Tensor or None): A pre-allocated buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. Returns: tensor: The aggregated output of shape (s, b, C). @@ -987,7 +986,7 @@ def forward(ctx, x, H_pre, n, use_tf32=True, fuse_grad_x_acc=False): ctx.save_for_backward(x, H_pre) ctx.n = n ctx.use_tf32 = use_tf32 - ctx.fuse_grad_x_acc = fuse_grad_x_acc + ctx.fused_grad_x_acc_buffer = fused_grad_x_acc_buffer return out @@ -1016,9 +1015,8 @@ def backward(ctx, grad_output): assert n == 4, "Only n=4 is supported in this implementation" M = s * b - fuse_grad_x_acc = hasattr(x.untyped_storage(), "grad_x_acc") and ctx.fuse_grad_x_acc - if fuse_grad_x_acc: - grad_x = x.untyped_storage().grad_x_acc.view_as(x) + if ctx.fused_grad_x_acc_buffer is not None: + grad_x = ctx.fused_grad_x_acc_buffer.view_as(x) else: grad_x = torch.empty_like(x) @@ -1048,12 +1046,12 @@ def backward(ctx, grad_output): stride_grad_xm=nC, stride_grad_xCn=1, precision="tf32" if ctx.use_tf32 else "ieee", - FUSE_GRAD_X_ACC=fuse_grad_x_acc, + FUSE_GRAD_X_ACC=ctx.fused_grad_x_acc_buffer is not None, ) grad_H_pre = grad_H_pre.to(H_pre.dtype) # Cast back to the original dtype of H_pre - if fuse_grad_x_acc: + if ctx.fused_grad_x_acc_buffer is not None: grad_x = None return grad_x, grad_H_pre, None, None, None, None @@ -1065,7 +1063,7 @@ class mHCExpandCombineOp(torch.autograd.Function): """ @staticmethod - def forward(ctx, f, bias, H_post, x, H_res, n, use_tf32=True, fuse_grad_x_acc=False): + def forward(ctx, f, bias, H_post, x, H_res, n, use_tf32=True, fused_grad_x_acc_buffer=None): """ The forward pass of the expand and combine operation. Expands the sub-layer output f back to n streams using H_post, and combines with the residual connections using H_res: @@ -1081,7 +1079,7 @@ def forward(ctx, f, bias, H_post, x, H_res, n, use_tf32=True, fuse_grad_x_acc=Fa H_res (tensor): The residual connection matrix of shape (s, b, n, n). n (int): The number of hyper connections (only n=4 is supported). use_tf32 (bool): Whether to use TF32 precision for matmul operations. - fuse_grad_x_acc (bool): Use the same buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. + fused_grad_x_acc_buffer (torch.Tensor or None): A pre-allocated buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. Returns: tensor: The expanded and combined output of shape (s, b, C, n). @@ -1128,7 +1126,7 @@ def forward(ctx, f, bias, H_post, x, H_res, n, use_tf32=True, fuse_grad_x_acc=Fa ctx.n = n ctx.have_bias = bias is not None - ctx.fuse_grad_x_acc = fuse_grad_x_acc + ctx.fused_grad_x_acc_buffer = fused_grad_x_acc_buffer if bias is not None: ctx.save_for_backward(f, bias, H_post, x, H_res) else: @@ -1166,7 +1164,10 @@ def backward(ctx, grad_output): M = s * b grad_f = torch.empty_like(f) - grad_x = torch.empty_like(x) + if ctx.fused_grad_x_acc_buffer is not None: + grad_x = ctx.fused_grad_x_acc_buffer.view_as(x) + else: + grad_x = torch.empty_like(x) # Since triton's autotune will reset grad_bias pointer when tuning, we need an empty placeholder here grad_bias = torch.empty(1, device=grad_output.device, dtype=grad_output.dtype) @@ -1215,6 +1216,7 @@ def backward(ctx, grad_output): stride_grad_xCn=1, precision="tf32" if ctx.use_tf32 else "ieee", HAS_BIAS=bias is not None, + FUSE_GRAD_X_ACC=ctx.fused_grad_x_acc_buffer is not None, ) # If no bias, replace the grad_bias placeholder with None @@ -1226,14 +1228,7 @@ def backward(ctx, grad_output): if bias is not None: grad_bias = grad_bias.to(bias.dtype) - if ctx.fuse_grad_x_acc: - assert not hasattr(x.untyped_storage(), "grad_x_acc"), ( - "Unexpected: grad_x_acc is already attached in x's storage. This implies incorrect" - " usage of `fuse_grad_x_acc` optimization. Please disable fuse_grad_x_acc or check" - " if there are other places where grad_x_acc is attached to x's storage." - ) - # When fused x gradient accumulation is enabled, use fp32 for the accumulation buffer - x.untyped_storage().grad_x_acc = grad_x.to(torch.float32) + if ctx.fused_grad_x_acc_buffer is not None: grad_x = None return grad_f, grad_bias, grad_H_post, grad_x, grad_H_res, None, None, None, None From fb17f290c700aedcbc8ed91bad75def5fd3cd3cf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 12 May 2026 21:03:30 +0000 Subject: [PATCH 06/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_mhc.py | 4 +++- transformer_engine/common/triton/mhc.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/test_mhc.py b/tests/pytorch/test_mhc.py index 4ac0e51751..44258a82de 100644 --- a/tests/pytorch/test_mhc.py +++ b/tests/pytorch/test_mhc.py @@ -502,7 +502,9 @@ def end_to_end(x, phi, alpha, beta, fused_grad_x_acc): return expanded_combined - expanded_combined_fuse_grad = end_to_end(x_ref, phi_ref, alpha_ref, beta_ref, fused_grad_x_acc=True) + expanded_combined_fuse_grad = end_to_end( + x_ref, phi_ref, alpha_ref, beta_ref, fused_grad_x_acc=True + ) expanded_combined_no_fuse_grad = end_to_end(x, phi, alpha, beta, fused_grad_x_acc=False) grad_output = torch.randn_like(expanded_combined_fuse_grad) diff --git a/transformer_engine/common/triton/mhc.py b/transformer_engine/common/triton/mhc.py index 87299fed91..800907440a 100644 --- a/transformer_engine/common/triton/mhc.py +++ b/transformer_engine/common/triton/mhc.py @@ -1713,7 +1713,7 @@ def _mhc_expand_combine_bwd( grad_x_acc = tl.fma(grad_out3[:, :, None], H_res3[:, None, :], grad_x_acc) if FUSE_GRAD_X_ACC: - grad_x = grad_x_acc # If fusing gradient accumulation, the buffer should be always fp32 so we don't cast here + grad_x = grad_x_acc # If fusing gradient accumulation, the buffer should be always fp32 so we don't cast here else: grad_x = grad_x_acc.to(x.dtype) grad_x = tl.reshape( From cbe971fa092f2b7676e606f0fcab8952297493e7 Mon Sep 17 00:00:00 2001 From: Kaining Zhong Date: Tue, 12 May 2026 23:58:15 +0000 Subject: [PATCH 07/19] fix Signed-off-by: Kaining Zhong --- transformer_engine/pytorch/triton/mhc.py | 26 ++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/triton/mhc.py b/transformer_engine/pytorch/triton/mhc.py index 72a086edb2..f22f046bee 100644 --- a/transformer_engine/pytorch/triton/mhc.py +++ b/transformer_engine/pytorch/triton/mhc.py @@ -35,6 +35,22 @@ def _tma_aligned(t): return (t.stride(0) * t.element_size()) % 16 == 0 and t.data_ptr() % 16 == 0 +_tma_allocator_initialized = False + + +def _init_tma_allocator(): + # TMA descriptors require a global memory allocation. Registered once on first use. + global _tma_allocator_initialized + if _tma_allocator_initialized: + return + + def alloc_fn(size: int, alignment: int, stream: Optional[int]): # pylint: disable=unused-argument + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + _tma_allocator_initialized = True + + def check_deterministic(operator: str): """ Checks if the non-deterministic algorithm is allowed for the given operator. If not, raises an assertion error with instructions on how to allow it. @@ -448,13 +464,7 @@ def forward(ctx, x, phi, norm_weight=None, use_tf32=True, fused_grad_x_acc_buffe use_tma = _support_tma() and _tma_aligned(x) and _tma_aligned(phi) if use_tma: - # TMA descriptors require a global memory allocation - def alloc_fn( - size: int, alignment: int, stream: Optional[int] - ): # pylint: disable=unused-argument - return torch.empty(size, device="cuda", dtype=torch.int8) - - triton.set_allocator(alloc_fn) + _init_tma_allocator() ctx.save_for_backward(x, phi, ms, norm_weight) ctx.phi_dtype = phi.dtype @@ -1179,7 +1189,7 @@ def backward(ctx, grad_output): ) # We need to use atomic_add for this so we need higher precision if bias is not None: - grad_bias = torch.zeros_like(bias, dtype=torch.float32) if bias is not None else None + grad_bias = torch.zeros_like(bias, dtype=torch.float32) # pylint: disable=unnecessary-lambda-assignment grid = lambda META: ( From 1507618c0be54ad1309869607d6efc1f0720131a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 12 May 2026 23:59:24 +0000 Subject: [PATCH 08/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/triton/mhc.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/triton/mhc.py b/transformer_engine/pytorch/triton/mhc.py index f22f046bee..e129fd6a04 100644 --- a/transformer_engine/pytorch/triton/mhc.py +++ b/transformer_engine/pytorch/triton/mhc.py @@ -44,7 +44,9 @@ def _init_tma_allocator(): if _tma_allocator_initialized: return - def alloc_fn(size: int, alignment: int, stream: Optional[int]): # pylint: disable=unused-argument + def alloc_fn( + size: int, alignment: int, stream: Optional[int] + ): # pylint: disable=unused-argument return torch.empty(size, device="cuda", dtype=torch.int8) triton.set_allocator(alloc_fn) From 486a4cff2d3a98910c3bef341c9eea962422cee8 Mon Sep 17 00:00:00 2001 From: Kaining Zhong Date: Wed, 13 May 2026 00:36:06 +0000 Subject: [PATCH 09/19] nit Signed-off-by: Kaining Zhong --- tests/pytorch/test_mhc.py | 1 - transformer_engine/common/triton/mhc.py | 3 +- transformer_engine/pytorch/triton/mhc.py | 41 +++++++++++++++++------- 3 files changed, 31 insertions(+), 14 deletions(-) diff --git a/tests/pytorch/test_mhc.py b/tests/pytorch/test_mhc.py index 44258a82de..1ec69255ee 100644 --- a/tests/pytorch/test_mhc.py +++ b/tests/pytorch/test_mhc.py @@ -488,7 +488,6 @@ def end_to_end(x, phi, alpha, beta, fused_grad_x_acc): aggregated, H_post, H_res = mhc_generate_mix_and_aggregate( x, phi, alpha, beta, None, use_tf32, fused_grad_x_acc_buffer ) - H_res = mhc_fused_sinkhorn(H_res.view(s, b, n, n), n).view(s * b, n * n) expanded_combined = mhc_fused_expand_combine( aggregated, None, diff --git a/transformer_engine/common/triton/mhc.py b/transformer_engine/common/triton/mhc.py index 800907440a..65023a4349 100644 --- a/transformer_engine/common/triton/mhc.py +++ b/transformer_engine/common/triton/mhc.py @@ -267,10 +267,11 @@ def _mhc_projection_bwd_fused_dx( grad_h.to(phi.dtype), phi, acc=grad_x, input_precision=precision, out_dtype=tl.float32 ) # (BLOCK_SIZE_M, BLOCK_SIZE_K) grad_x_ptrs = grad_x_ptr + offs_m[:, None] * stride_grad_xm + offs_k[None, :] * stride_grad_xk - grad_x = grad_x.to(x.dtype) if FUSE_GRAD_X_ACC: # If fused gradient accumulation is enabled, the buffer is always fp32 grad_x_acc = tl.load(grad_x_ptrs, mask=mask_m[:, None] & mask_k[None, :], other=0.0) grad_x = grad_x.to(tl.float32) + grad_x_acc + else: + grad_x = grad_x.to(x.dtype) tl.store(grad_x_ptrs, grad_x, mask=mask_m[:, None] & mask_k[None, :]) diff --git a/transformer_engine/pytorch/triton/mhc.py b/transformer_engine/pytorch/triton/mhc.py index e129fd6a04..651e34bf30 100644 --- a/transformer_engine/pytorch/triton/mhc.py +++ b/transformer_engine/pytorch/triton/mhc.py @@ -35,7 +35,7 @@ def _tma_aligned(t): return (t.stride(0) * t.element_size()) % 16 == 0 and t.data_ptr() % 16 == 0 -_tma_allocator_initialized = False +_tma_allocator_initialized = False # pylint: disable=global-statement def _init_tma_allocator(): @@ -138,6 +138,11 @@ def mhc_generate_mix_and_aggregate( assert ( n == 4 ), "Only n=4 is supported in this implementation, where n is the Hyper Connection number" + if fused_grad_x_acc_buffer is not None: + assert fused_grad_x_acc_buffer.dtype == torch.float32, \ + "fused_grad_x_acc_buffer must be fp32" + assert fused_grad_x_acc_buffer.numel() == x.numel(), \ + "fused_grad_x_acc_buffer.numel() must match x.numel()" nC = n * C H, ms = mhc_fused_projection( x.view(s * b, nC), @@ -146,11 +151,8 @@ def mhc_generate_mix_and_aggregate( use_tf32=use_tf32, fused_grad_x_acc_buffer=fused_grad_x_acc_buffer, ) - h_pre, h_post, h_res = mhc_fused_scale(H, alpha, beta, ms, n) - H_pre = h_pre.view(s, b, n) - H_post = h_post.view(s, b, n) - H_res = h_res.view(s, b, n, n) - H_res = mhc_fused_sinkhorn(H_res, n, recompute_hist=True, iters=20) + H_pre, H_post, H_res = mhc_fused_scale(H, alpha, beta, ms, n) + H_res = mhc_fused_sinkhorn(H_res.view(s, b, n, n), n, recompute_hist=True, iters=20) out = mhc_fused_aggregate( x, H_pre.view(s, b, n), @@ -278,8 +280,8 @@ def mhc_fused_aggregate( This is mainly used by our unittests since TF32 precision will introduce some errors and cause tests to fail fused_grad_x_acc_buffer : Optional[torch.Tensor] A pre-allocated buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. - If not None, triton kernels will accumulate the gradient of x into this same buffer to avoid copying the gradient by PyTorch, which should be reused - during the backward of mhc_fused_aggregate, mhc_fused_expand_combine and mhc_fused_projection operations + If not None, triton kernels will accumulate the gradient of x into this same buffer to avoid copying the gradient by PyTorch. + This optimization requires the operation order to be mhc_fused_projection -> mhc_fused_aggregate -> mhc_fused_expand_combine. Note: the buffer must have dtype float32, and it will be cast to the activation's dtype and be returned in mhc_fused_projection Returns @@ -289,6 +291,11 @@ def mhc_fused_aggregate( with the same dtype as x """ assert n == 4, "Only n=4 is supported in this implementation" + if fused_grad_x_acc_buffer is not None: + assert fused_grad_x_acc_buffer.dtype == torch.float32, \ + "fused_grad_x_acc_buffer must be fp32" + assert fused_grad_x_acc_buffer.numel() == x.numel(), \ + "fused_grad_x_acc_buffer.numel() must match x.numel()" check_deterministic("mhc_fused_aggregate") out = mHCAggregateOp.apply(x, H_pre, n, use_tf32, fused_grad_x_acc_buffer) return out @@ -333,8 +340,8 @@ def mhc_fused_expand_combine( This is mainly used by our unittests since TF32 precision will introduce some errors and cause tests to fail fused_grad_x_acc_buffer : Optional[torch.Tensor] A pre-allocated buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. - If not None, triton kernels will accumulate the gradient of x into this same buffer to avoid copying the gradient by PyTorch, which should be reused - during the backward of mhc_fused_aggregate, mhc_fused_expand_combine and mhc_fused_projection operations + If not None, triton kernels will accumulate the gradient of x into this same buffer to avoid copying the gradient by PyTorch. + This optimization requires the operation order to be mhc_fused_projection -> mhc_fused_aggregate -> mhc_fused_expand_combine. Note: the buffer must have dtype float32, and it will be cast to the activation's dtype and be returned in mhc_fused_projection Returns @@ -344,6 +351,11 @@ def mhc_fused_expand_combine( with the same dtype as x """ assert n == 4, "Only n=4 is supported in this implementation" + if fused_grad_x_acc_buffer is not None: + assert fused_grad_x_acc_buffer.dtype == torch.float32, \ + "fused_grad_x_acc_buffer must be fp32" + assert fused_grad_x_acc_buffer.numel() == x.numel(), \ + "fused_grad_x_acc_buffer.numel() must match x.numel()" check_deterministic("mhc_fused_expand_combine") out = mHCExpandCombineOp.apply( f, @@ -398,8 +410,8 @@ def mhc_fused_projection( dtype is torch.bfloat16 or torch.float32 fused_grad_x_acc_buffer : Optional[torch.Tensor] A pre-allocated buffer for inplace gradient accumulation to avoid PyTorch autograd overhead. - If not None, triton kernels will accumulate the gradient of x into this same buffer to avoid copying the gradient by PyTorch, which should be reused - during the backward of mhc_fused_aggregate, mhc_fused_expand_combine and mhc_fused_projection operations + If not None, triton kernels will accumulate the gradient of x into this same buffer to avoid copying the gradient by PyTorch. + This optimization requires the operation order to be mhc_fused_projection -> mhc_fused_aggregate -> mhc_fused_expand_combine. Note: the buffer must have dtype float32, and it will be cast to the activation's dtype and be returned in mhc_fused_projection Returns @@ -415,6 +427,11 @@ def mhc_fused_projection( phi.shape[0] == 24 ), "Currently only n=4 is supported, which means phi should have 24 in its first dimension" check_deterministic("mhc_fused_projection") + if fused_grad_x_acc_buffer is not None: + assert fused_grad_x_acc_buffer.dtype == torch.float32, \ + "fused_grad_x_acc_buffer must be fp32" + assert fused_grad_x_acc_buffer.numel() == x.numel(), \ + "fused_grad_x_acc_buffer.numel() must match x.numel()" H, ms = mHCProjectionOp.apply(x, phi, norm_weight, use_tf32, fused_grad_x_acc_buffer) return H, ms From 3f2a9abd08829781c06e7f22ce74d1d633a5576f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 May 2026 00:37:08 +0000 Subject: [PATCH 10/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/triton/mhc.py | 42 ++++++++++++++---------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/transformer_engine/pytorch/triton/mhc.py b/transformer_engine/pytorch/triton/mhc.py index 651e34bf30..4371e0c409 100644 --- a/transformer_engine/pytorch/triton/mhc.py +++ b/transformer_engine/pytorch/triton/mhc.py @@ -35,7 +35,7 @@ def _tma_aligned(t): return (t.stride(0) * t.element_size()) % 16 == 0 and t.data_ptr() % 16 == 0 -_tma_allocator_initialized = False # pylint: disable=global-statement +_tma_allocator_initialized = False # pylint: disable=global-statement def _init_tma_allocator(): @@ -139,10 +139,12 @@ def mhc_generate_mix_and_aggregate( n == 4 ), "Only n=4 is supported in this implementation, where n is the Hyper Connection number" if fused_grad_x_acc_buffer is not None: - assert fused_grad_x_acc_buffer.dtype == torch.float32, \ - "fused_grad_x_acc_buffer must be fp32" - assert fused_grad_x_acc_buffer.numel() == x.numel(), \ - "fused_grad_x_acc_buffer.numel() must match x.numel()" + assert ( + fused_grad_x_acc_buffer.dtype == torch.float32 + ), "fused_grad_x_acc_buffer must be fp32" + assert ( + fused_grad_x_acc_buffer.numel() == x.numel() + ), "fused_grad_x_acc_buffer.numel() must match x.numel()" nC = n * C H, ms = mhc_fused_projection( x.view(s * b, nC), @@ -292,10 +294,12 @@ def mhc_fused_aggregate( """ assert n == 4, "Only n=4 is supported in this implementation" if fused_grad_x_acc_buffer is not None: - assert fused_grad_x_acc_buffer.dtype == torch.float32, \ - "fused_grad_x_acc_buffer must be fp32" - assert fused_grad_x_acc_buffer.numel() == x.numel(), \ - "fused_grad_x_acc_buffer.numel() must match x.numel()" + assert ( + fused_grad_x_acc_buffer.dtype == torch.float32 + ), "fused_grad_x_acc_buffer must be fp32" + assert ( + fused_grad_x_acc_buffer.numel() == x.numel() + ), "fused_grad_x_acc_buffer.numel() must match x.numel()" check_deterministic("mhc_fused_aggregate") out = mHCAggregateOp.apply(x, H_pre, n, use_tf32, fused_grad_x_acc_buffer) return out @@ -352,10 +356,12 @@ def mhc_fused_expand_combine( """ assert n == 4, "Only n=4 is supported in this implementation" if fused_grad_x_acc_buffer is not None: - assert fused_grad_x_acc_buffer.dtype == torch.float32, \ - "fused_grad_x_acc_buffer must be fp32" - assert fused_grad_x_acc_buffer.numel() == x.numel(), \ - "fused_grad_x_acc_buffer.numel() must match x.numel()" + assert ( + fused_grad_x_acc_buffer.dtype == torch.float32 + ), "fused_grad_x_acc_buffer must be fp32" + assert ( + fused_grad_x_acc_buffer.numel() == x.numel() + ), "fused_grad_x_acc_buffer.numel() must match x.numel()" check_deterministic("mhc_fused_expand_combine") out = mHCExpandCombineOp.apply( f, @@ -428,10 +434,12 @@ def mhc_fused_projection( ), "Currently only n=4 is supported, which means phi should have 24 in its first dimension" check_deterministic("mhc_fused_projection") if fused_grad_x_acc_buffer is not None: - assert fused_grad_x_acc_buffer.dtype == torch.float32, \ - "fused_grad_x_acc_buffer must be fp32" - assert fused_grad_x_acc_buffer.numel() == x.numel(), \ - "fused_grad_x_acc_buffer.numel() must match x.numel()" + assert ( + fused_grad_x_acc_buffer.dtype == torch.float32 + ), "fused_grad_x_acc_buffer must be fp32" + assert ( + fused_grad_x_acc_buffer.numel() == x.numel() + ), "fused_grad_x_acc_buffer.numel() must match x.numel()" H, ms = mHCProjectionOp.apply(x, phi, norm_weight, use_tf32, fused_grad_x_acc_buffer) return H, ms From fdccbdedc8c8fcb127de5f8fafa5392832cb58b8 Mon Sep 17 00:00:00 2001 From: Kaining Zhong Date: Wed, 13 May 2026 00:56:39 +0000 Subject: [PATCH 11/19] fix Signed-off-by: Kaining Zhong --- transformer_engine/pytorch/triton/mhc.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/triton/mhc.py b/transformer_engine/pytorch/triton/mhc.py index 4371e0c409..b1c9c81f6e 100644 --- a/transformer_engine/pytorch/triton/mhc.py +++ b/transformer_engine/pytorch/triton/mhc.py @@ -35,12 +35,12 @@ def _tma_aligned(t): return (t.stride(0) * t.element_size()) % 16 == 0 and t.data_ptr() % 16 == 0 -_tma_allocator_initialized = False # pylint: disable=global-statement +_tma_allocator_initialized = False def _init_tma_allocator(): # TMA descriptors require a global memory allocation. Registered once on first use. - global _tma_allocator_initialized + global _tma_allocator_initialized # pylint: disable=global-statement if _tma_allocator_initialized: return @@ -162,7 +162,7 @@ def mhc_generate_mix_and_aggregate( use_tf32=use_tf32, fused_grad_x_acc_buffer=fused_grad_x_acc_buffer, ) - return out, H_post, H_res + return out, H_post.view(s, b, n), H_res def mhc_fused_sinkhorn( From adb65eb9dbe6c6b86781dec2339a8b3bfed676dd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 May 2026 00:57:33 +0000 Subject: [PATCH 12/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/triton/mhc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/triton/mhc.py b/transformer_engine/pytorch/triton/mhc.py index b1c9c81f6e..3565b17ff4 100644 --- a/transformer_engine/pytorch/triton/mhc.py +++ b/transformer_engine/pytorch/triton/mhc.py @@ -40,7 +40,7 @@ def _tma_aligned(t): def _init_tma_allocator(): # TMA descriptors require a global memory allocation. Registered once on first use. - global _tma_allocator_initialized # pylint: disable=global-statement + global _tma_allocator_initialized # pylint: disable=global-statement if _tma_allocator_initialized: return From 6592735b5f486f4b3f77ab4b4c069ab86b465c12 Mon Sep 17 00:00:00 2001 From: Kaining Zhong Date: Wed, 13 May 2026 17:45:15 +0000 Subject: [PATCH 13/19] skip autotuning for project_bwd_dphi Signed-off-by: Kaining Zhong --- transformer_engine/common/triton/mhc.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformer_engine/common/triton/mhc.py b/transformer_engine/common/triton/mhc.py index 65023a4349..9c8c275f3e 100644 --- a/transformer_engine/common/triton/mhc.py +++ b/transformer_engine/common/triton/mhc.py @@ -291,6 +291,8 @@ def projection_config_bwd_dphi(): num_stages=s, ) ) + if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": + configs = configs[:1] return configs From 484a759978d0fa3cccf3cecad571f72b55dc5298 Mon Sep 17 00:00:00 2001 From: Kaining Zhong Date: Wed, 13 May 2026 17:45:30 +0000 Subject: [PATCH 14/19] fix returned values Signed-off-by: Kaining Zhong --- transformer_engine/pytorch/triton/mhc.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/triton/mhc.py b/transformer_engine/pytorch/triton/mhc.py index 3565b17ff4..80c1350d82 100644 --- a/transformer_engine/pytorch/triton/mhc.py +++ b/transformer_engine/pytorch/triton/mhc.py @@ -657,7 +657,7 @@ def backward(ctx, grad_H, grad_ms): HAS_NORM_WEIGHT=norm_weight is not None, ) - return grad_x.to(x.dtype), grad_phi, grad_norm_weight, None, None, None, None + return grad_x.to(x.dtype), grad_phi, grad_norm_weight, None, None class mHCScaleFusedOp(torch.autograd.Function): @@ -1091,7 +1091,7 @@ def backward(ctx, grad_output): if ctx.fused_grad_x_acc_buffer is not None: grad_x = None - return grad_x, grad_H_pre, None, None, None, None + return grad_x, grad_H_pre, None, None, None class mHCExpandCombineOp(torch.autograd.Function): @@ -1268,4 +1268,4 @@ def backward(ctx, grad_output): if ctx.fused_grad_x_acc_buffer is not None: grad_x = None - return grad_f, grad_bias, grad_H_post, grad_x, grad_H_res, None, None, None, None + return grad_f, grad_bias, grad_H_post, grad_x, grad_H_res, None, None, None From 645eade82aed2887a3bb30c3a9a0095c227be641 Mon Sep 17 00:00:00 2001 From: Kaining Zhong Date: Wed, 13 May 2026 17:48:40 +0000 Subject: [PATCH 15/19] prune projecti bwd dphi for safety Signed-off-by: Kaining Zhong --- transformer_engine/common/triton/mhc.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/transformer_engine/common/triton/mhc.py b/transformer_engine/common/triton/mhc.py index 9c8c275f3e..95c74d2df9 100644 --- a/transformer_engine/common/triton/mhc.py +++ b/transformer_engine/common/triton/mhc.py @@ -295,11 +295,21 @@ def projection_config_bwd_dphi(): configs = configs[:1] return configs +def projection_prune_bwd_dphi(configs, named_args, **kwargs): + M = named_args.get("M", kwargs.get("M", None)) + + pruned_configs = list( + filter( + lambda config: triton.cdiv(M, config.kwargs["BLOCK_SIZE_M"]) <= MAX_GRID_DIM_Y, configs + ) + ) + return pruned_configs @triton.autotune( configs=projection_config_bwd_dphi(), key=["M", "K"], reset_to_zero=["grad_phi_ptr", "grad_norm_weight_ptr"], + prune_configs_by={"early_config_prune": projection_prune_bwd_dphi}, ) @triton.jit def _mhc_projection_bwd_fused_dphi( From 91ae3a4f1c330530806d5461739497d7d715189b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 May 2026 17:49:56 +0000 Subject: [PATCH 16/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/triton/mhc.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformer_engine/common/triton/mhc.py b/transformer_engine/common/triton/mhc.py index 95c74d2df9..928dca1537 100644 --- a/transformer_engine/common/triton/mhc.py +++ b/transformer_engine/common/triton/mhc.py @@ -295,6 +295,7 @@ def projection_config_bwd_dphi(): configs = configs[:1] return configs + def projection_prune_bwd_dphi(configs, named_args, **kwargs): M = named_args.get("M", kwargs.get("M", None)) @@ -305,6 +306,7 @@ def projection_prune_bwd_dphi(configs, named_args, **kwargs): ) return pruned_configs + @triton.autotune( configs=projection_config_bwd_dphi(), key=["M", "K"], From 143e89348e5a417dabb027ac415b41bd58e75a57 Mon Sep 17 00:00:00 2001 From: Kaining Zhong Date: Wed, 13 May 2026 18:03:16 +0000 Subject: [PATCH 17/19] skip autotune in prune functions Signed-off-by: Kaining Zhong --- transformer_engine/common/triton/mhc.py | 41 ++++++++++++++----------- 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/transformer_engine/common/triton/mhc.py b/transformer_engine/common/triton/mhc.py index 928dca1537..acd4bb7a7e 100644 --- a/transformer_engine/common/triton/mhc.py +++ b/transformer_engine/common/triton/mhc.py @@ -291,8 +291,6 @@ def projection_config_bwd_dphi(): num_stages=s, ) ) - if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": - configs = configs[:1] return configs @@ -304,6 +302,11 @@ def projection_prune_bwd_dphi(configs, named_args, **kwargs): lambda config: triton.cdiv(M, config.kwargs["BLOCK_SIZE_M"]) <= MAX_GRID_DIM_Y, configs ) ) + + # Triton will skip calling prune function if the autotune returns only one config, which breaks the determinism override here + # So we need to apply NVTE_DISABLE_TRITON_AUTOTUNING in the pruner instead + if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": + configs = configs[:1] return pruned_configs @@ -954,22 +957,6 @@ def _mhc_sinkhorn_fwd_fused( tl.store(output_ptrs, P, mask=mask_batch[:, None]) -def aggregate_config_fwd(): - block_m = [1, 2, 4] - block_c = [128, 256] - warps = [1, 2, 4] - stages = [1, 2, 3, 4] - - configs = [] - for m, c, w, s in itertools.product(block_m, block_c, warps, stages): - configs.append( - triton.Config({"BLOCK_SIZE_M": m, "BLOCK_SIZE_C": c}, num_warps=w, num_stages=s) - ) - if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": - configs = configs[:1] - return configs - - @triton.autotune( configs=sinkhorn_config(), key=["M"], @@ -1074,6 +1061,19 @@ def _mhc_sinkhorn_bwd_fused( mask=mask_batch[:, None], ) +def aggregate_config_fwd(): + block_m = [1, 2, 4] + block_c = [128, 256] + warps = [1, 2, 4] + stages = [1, 2, 3, 4] + + configs = [] + for m, c, w, s in itertools.product(block_m, block_c, warps, stages): + configs.append( + triton.Config({"BLOCK_SIZE_M": m, "BLOCK_SIZE_C": c}, num_warps=w, num_stages=s) + ) + return configs + def aggregate_prune_fwd(configs, named_args, **kwargs): M = named_args.get("M", kwargs.get("M", None)) @@ -1083,6 +1083,11 @@ def aggregate_prune_fwd(configs, named_args, **kwargs): lambda config: triton.cdiv(M, config.kwargs["BLOCK_SIZE_M"]) <= MAX_GRID_DIM_Y, configs ) ) + + # Triton will skip calling prune function if the autotune returns only one config, which breaks the determinism override here + # So we need to apply NVTE_DISABLE_TRITON_AUTOTUNING in the pruner instead + if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": + configs = configs[:1] return pruned_configs From 653251e2e6f638e7e1bc1330961637bd9d18f520 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 May 2026 18:04:11 +0000 Subject: [PATCH 18/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/triton/mhc.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/common/triton/mhc.py b/transformer_engine/common/triton/mhc.py index acd4bb7a7e..75507fc0dc 100644 --- a/transformer_engine/common/triton/mhc.py +++ b/transformer_engine/common/triton/mhc.py @@ -1061,6 +1061,7 @@ def _mhc_sinkhorn_bwd_fused( mask=mask_batch[:, None], ) + def aggregate_config_fwd(): block_m = [1, 2, 4] block_c = [128, 256] From 920020c19283c25cf45c08dd4273d8e1e63f47ff Mon Sep 17 00:00:00 2001 From: Kaining Zhong Date: Wed, 13 May 2026 18:18:30 +0000 Subject: [PATCH 19/19] fix Signed-off-by: Kaining Zhong --- transformer_engine/common/triton/mhc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/triton/mhc.py b/transformer_engine/common/triton/mhc.py index 75507fc0dc..a87f99a572 100644 --- a/transformer_engine/common/triton/mhc.py +++ b/transformer_engine/common/triton/mhc.py @@ -306,7 +306,7 @@ def projection_prune_bwd_dphi(configs, named_args, **kwargs): # Triton will skip calling prune function if the autotune returns only one config, which breaks the determinism override here # So we need to apply NVTE_DISABLE_TRITON_AUTOTUNING in the pruner instead if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": - configs = configs[:1] + pruned_configs = pruned_configs[:1] return pruned_configs @@ -1088,7 +1088,7 @@ def aggregate_prune_fwd(configs, named_args, **kwargs): # Triton will skip calling prune function if the autotune returns only one config, which breaks the determinism override here # So we need to apply NVTE_DISABLE_TRITON_AUTOTUNING in the pruner instead if os.environ.get("NVTE_DISABLE_TRITON_AUTOTUNING", "0") == "1": - configs = configs[:1] + pruned_configs = pruned_configs[:1] return pruned_configs