[PyTorch] ONNX test fix + export for FP8 attention #2598
Changes from all commits: 2c49133, bec2c3c, 0e054ad, 4bc878c, 75ca174, 1f0111f, 768796b, 87c5101, 6f04da2, c3a1acf, c07ffb7
```diff
@@ -164,6 +164,11 @@ class FP8EmulationFunc(torch.autograd.Function):
     @staticmethod
     def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout):
         # pylint: disable=missing-function-docstring
+        if is_in_onnx_export_mode():
+            return FP8EmulationFunc.onnx_forward(
+                tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout
+            )
+
         if quantizer_name == "QKV_quantizer":
             query_layer, key_layer, value_layer = [
                 x.contiguous() for x in [tensor1, tensor2, tensor3]
```
```diff
@@ -202,6 +207,47 @@ def backward(ctx, grad1, grad2, grad3):
         tensors = grad1, grad2, grad3
         return tensors[0], tensors[1], tensors[2], None, None, None
 
+    @staticmethod
+    def onnx_forward(tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout=None):
+        """
+        ONNX-compatible forward for FP8 emulation using operations with defined ONNX translations.
+        """
+        # pylint: disable=unused-argument
+        is_qkv_quantizer = quantizer_name == "QKV_quantizer"
+        assert isinstance(
+            quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
+        ), "ONNX FP8 emulation path supports only Float8 quantizers."
```
Comment on lines +217 to +219

Contributor: The assert statement will cause ONNX export to fail if non-Float8 quantizers are used. Consider replacing it with a runtime check that raises a more descriptive error, or logging a warning.
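A minimal sketch of the suggested alternative (hypothetical, not part of this PR): an explicit check that raises a descriptive error instead of tripping an AssertionError during export.

```python
# Hypothetical replacement for the assert above (not part of this PR): fail with a
# descriptive error when an unsupported quantizer reaches the ONNX emulation path.
if not isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)):
    raise NotImplementedError(
        "ONNX FP8 emulation supports only Float8Quantizer and "
        f"Float8CurrentScalingQuantizer, got {type(quantizer).__name__}."
    )
```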
```diff
+        if is_qkv_quantizer:
+            # Flatten + concatenate + quantize + split. Equivalent to combine_and_quantize Case 3.
+            orig_dtype = tensor1.dtype
+            shapes = [tensor1.shape, tensor2.shape, tensor3.shape]
+            numels = [tensor1.numel(), tensor2.numel(), tensor3.numel()]
+
+            # Flatten and concatenate
+            combined = torch.cat(
+                [tensor1.reshape(-1), tensor2.reshape(-1), tensor3.reshape(-1)], dim=0
+            )
```
Comment on lines +227 to +230

Collaborator: This is fine for FP8 attention, although we'll need to revisit it whenever we support MXFP8 or NVFP4. Why can't we concatenate the 2D tensors?

Collaborator (Author): I'm not sure that would work for all layouts and for different max_q_length and max_kv_length. I added assertions that it's not MXFP8 because I want to merge this quickly; I will rethink it when adding MXFP8 support.
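For illustration, a self-contained sketch of the flatten/concatenate/split round-trip being discussed, with a stand-in fake-quantize in place of the quantizer's onnx_quantize/onnx_dequantize pair (the fake_quant_fp8 helper and the example shapes are assumptions, not code from this PR). It shows that per-tensor shapes survive even when q and kv sequence lengths differ, which is what makes the 1D concatenation layout-agnostic.

```python
# Standalone illustration of the flatten -> concat -> quantize -> split pattern above.
# fake_quant_fp8 is a stand-in for the quantizer's onnx_quantize/onnx_dequantize pair.
# Requires PyTorch >= 2.1 for torch.float8_e4m3fn.
import torch

def fake_quant_fp8(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    # Round-trip through float8_e4m3fn to emulate FP8 precision loss.
    return (x * scale).to(torch.float8_e4m3fn).to(x.dtype) / scale

q = torch.randn(2, 8, 16, 64)   # e.g. [batch, heads, seq_q, head_dim]
k = torch.randn(2, 8, 32, 64)   # kv sequence length differs from q
v = torch.randn(2, 8, 32, 64)

shapes = [q.shape, k.shape, v.shape]
numels = [q.numel(), k.numel(), v.numel()]

combined = torch.cat([q.reshape(-1), k.reshape(-1), v.reshape(-1)], dim=0)
out = fake_quant_fp8(combined)

# Split back: each output has the same shape as its input, with no 2D concat needed.
q_out = out[: numels[0]].reshape(shapes[0])
k_out = out[numels[0] : numels[0] + numels[1]].reshape(shapes[1])
v_out = out[numels[0] + numels[1] :].reshape(shapes[2])

assert q_out.shape == q.shape and k_out.shape == k.shape and v_out.shape == v.shape
```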
```diff
+            # Quantize + dequantize combined tensor using quantizer's ONNX methods
+            combined_fp8 = quantizer.onnx_quantize(combined)
+            out = quantizer.onnx_dequantize(combined_fp8).to(orig_dtype)
+
+            # Split back
+            out1 = out[: numels[0]].reshape(shapes[0])
+            out2 = out[numels[0] : numels[0] + numels[1]].reshape(shapes[1])
+            out3 = out[numels[0] + numels[1] :].reshape(shapes[2])
+
+            return out1, out2, out3
+        if quantizer_name in ["S_quantizer", "O_quantizer"]:
+            # Emulate FP8 on single tensor using quantizer's ONNX methods
+            orig_dtype = tensor1.dtype
+            t_fp8 = quantizer.onnx_quantize(tensor1)
+            out = quantizer.onnx_dequantize(t_fp8).to(orig_dtype)
+            return out, tensor2, tensor3
+        # Pass-through
+        return tensor1, tensor2, tensor3
 
 
 class UnfusedDotProductAttention(torch.nn.Module):
     """Parallel attention w/o QKV and Proj Gemms
```
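A rough usage sketch of an export that would exercise the new onnx_forward path. The module choice, input layout, and recipe below are assumptions based on Transformer Engine's documented ONNX export flow (te.onnx_export plus fp8_autocast around torch.onnx.export); see the tests touched by this PR for the canonical invocation.

```python
# Rough export sketch (assumptions: te.onnx_export / te.fp8_autocast usage and the
# MultiheadAttention constructor arguments; not taken verbatim from this PR's tests).
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

model = te.MultiheadAttention(hidden_size=1024, num_attention_heads=16).cuda().eval()
inp = torch.randn(128, 2, 1024, device="cuda")  # (seq_len, batch, hidden) layout assumed

with torch.no_grad(), te.onnx_export(True), te.fp8_autocast(
    enabled=True, fp8_recipe=recipe.DelayedScaling()
):
    torch.onnx.export(model, (inp,), "mha_fp8.onnx")
```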