-
Notifications
You must be signed in to change notification settings - Fork 25
NVFP4 Random Hadamard Transform (butterfly permutation-based) #509
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
d954c6d
7b5cf20
640f7e8
82faeec
b6a3ae4
9e32d3a
84209ad
c669bd2
8040909
e375923
c6bd974
d76aa06
ae979d0
b58cbd1
e5d7446
7734ce5
c169c75
80e0aab
de7863a
bda7b13
7ddb539
63c7a48
a260459
3dd8af9
ab217cb
26c5fb7
2087f24
05cedb7
67b93a8
465d547
9fb21f9
2f66594
f74a0ab
e3a2502
3a63f32
17d50ee
e32a758
4857721
b243b4c
6527004
ca1aacf
bc9f0a3
9f1851d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,5 @@ | ||
| # This file was modified for portability to AMDGPU | ||
| # Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. | ||
| # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # See LICENSE for license information. | ||
|
|
@@ -246,3 +248,124 @@ def test_nvfp4_quantization_noncontiguous_inputs( | |
| use_cpp_allocator=use_cpp_allocator, | ||
| with_random_sign_mask=with_random_sign_mask, | ||
| ) | ||
|
|
||
|
|
||
| def _ref_wht16_tiled(x: torch.Tensor, sign_mask: int) -> torch.Tensor: | ||
| """Reference 16-point WHT tiled along last dim, normalised by 0.25.""" | ||
| x = x.float() | ||
| _rows, cols = x.shape | ||
| d = torch.tensor( | ||
| [((-1) ** ((sign_mask >> i) & 1)) for i in range(16)], | ||
| dtype=torch.float32, device=x.device, | ||
| ) | ||
| out = x.clone() | ||
| for c in range(0, cols, 16): | ||
| tile = out[:, c:c+16] * d # apply sign | ||
| h = 1 | ||
| while h < 16: | ||
| for i in range(0, 16, h * 2): | ||
| a = tile[:, i:i+h].clone() | ||
| b = tile[:, i+h:i+2*h].clone() | ||
| tile[:, i:i+h] = a + b | ||
| tile[:, i+h:i+2*h] = a - b | ||
| h *= 2 | ||
| out[:, c:c+16] = tile * 0.25 | ||
| return out | ||
|
|
||
|
|
||
def _ref_quantize_wht16_tiled(
    x: torch.Tensor, sign_mask: int, global_amax: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """Reference columnwise NVFP4 quantization of the 16-point WHT of ``x.T``.

    Mirrors the TE columnwise RHT path: the WHT output is rounded to the input
    dtype (BF16) before NVFP4 reference quantization with the TE-computed
    global amax.  Returns the packed FP4 data and per-block scale-inverse
    tensors.
    """
    # NOTE(review): this is NVFP4QuantizerRef + padding + unpadding; it is
    # unclear whether the pad/unpad round-trip is actually needed here.
    wht_xt = _ref_wht16_tiled(x.t().contiguous(), sign_mask=sign_mask).to(dtype=x.dtype)

    quantizer = NVFP4QuantizerRef(
        dtype=utils.Fp4Formats.E2M1,
        rowwise=True,
        columnwise=False,
        pow_2_scales=False,
        eps=0.0,
        quant_tile_shape=(1, 16),
        with_rht=False,
        with_random_sign_mask=False,
    )

    # Pad up to the quantization tile shape so the blockwise reference can run
    # on whole tiles only.
    padded = quantizer._pad_tensor(
        wht_xt,
        row_divisor=quantizer.quant_tile_shape[0],
        col_divisor=quantizer.quant_tile_shape[1],
    )
    q_ref, scale_ref = quantizer._quantize_blockwise_reference(
        padded,
        global_amax,
        quantizer.quant_tile_shape[1],
        quantizer.quant_tile_shape[0],
        pow_2_scales=quantizer.pow_2_scales,
        eps=quantizer.eps,
    )

    # FP4 values are packed two-per-byte, hence the // 2 on the column count.
    q_ref = quantizer._rm_pad_tensor(q_ref, (wht_xt.shape[0], wht_xt.shape[1] // 2))

    return q_ref, scale_ref
|
|
||
|
|
||
@pytest.mark.parametrize("rows,cols", [(64, 64), (128, 128)])
def test_hadamard_transform_amax(rows, cols):
    """
    Exercises hadamard_transform_amax() through NVFP4Quantizer (with_rht=True)
    without requiring a full NVFP4 recipe.

    Checks:
      - amax_rowwise == max|x|         (pre-RHT amax of the raw input)
      - amax_colwise == max|WHT(x.T)|  (post-RHT amax of the transposed input)
      - packed columnwise output == quantized BF16-rounded WHT(x.T)
    """
    # NOTE(review): the quantization-result checks below overlap with
    # test_nvfp4_quantization_noncontiguous_inputs and
    # test_rht_with_quantization_block_tiling_versus_reference; consider
    # keeping only the pure hadamard-transform checks here.
    torch.manual_seed(42)
    inp = torch.randn((rows, cols), dtype=torch.bfloat16, device="cuda").contiguous()

    quantizer = NVFP4Quantizer(
        fp4_dtype=tex.DType.kFloat4E2M1,
        rowwise=True,
        columnwise=True,
        with_amax_reduction=False,
        amax_reduction_group=None,
        with_rht=True,
        with_post_rht_amax=True,
        with_random_sign_mask=True,
    )
    quantized = quantizer(inp)

    # Rowwise amax is taken before the RHT, so it must equal max|x| exactly.
    torch.testing.assert_close(
        quantized._amax_rowwise.float().squeeze(),
        inp.float().abs().max(),
        rtol=0,
        atol=0,
    )

    # Columnwise amax is taken after the RHT of the transposed input, so it
    # must equal max|WHT(x.T)| after BF16 rounding.
    sign_mask_t = quantizer.rht_matrix_random_sign_mask_t
    wht_of_xt = (
        _ref_wht16_tiled(inp.t().contiguous(), sign_mask=sign_mask_t)
        .to(torch.bfloat16)
        .float()
    )
    torch.testing.assert_close(
        quantized._amax_columnwise.float().squeeze().item(),
        float(wht_of_xt.abs().max()),
        rtol=0,
        atol=0,
    )

    assert quantized._columnwise_data is not None
    assert quantized._columnwise_scale_inv is not None

    # The packed columnwise payload must match the reference quantization of
    # the BF16-rounded WHT(x.T) using the kernel's own global amax.
    ref_packed, ref_scales = _ref_quantize_wht16_tiled(
        inp, sign_mask_t, quantized._amax_columnwise
    )
    actual_vals = unpack_fp4(quantized._columnwise_data.view(torch.uint8))
    expected_vals = unpack_fp4(ref_packed.view(torch.uint8))
    torch.testing.assert_close(actual_vals, expected_vals, atol=0.0, rtol=0.0)

    # Scale-inv buffers may be padded; compare only the valid region.
    ref_scales = ref_scales.view(dtype=torch.uint8)
    actual_scales = quantized._columnwise_scale_inv[
        : ref_scales.shape[0], : ref_scales.shape[1]
    ]
    torch.testing.assert_close(actual_scales, ref_scales, atol=0.0, rtol=0.0)
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, don't forget guarding our rocm specific codes |
||
There was a problem hiding this comment.
Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Can we reuse the _apply_rht in upstream's experimental quantizer?