Add NVTE_BACKWARD_MODE=default|unquant|dequant #2644

zianglih wants to merge 51 commits into NVIDIA:main

Conversation
Greptile Summary

This PR introduces an `NVTE_BACKWARD_MODE=default|unquant|dequant` env var that controls how activations and weights are saved for, and consumed by, the backward pass (see the flowchart below).

Notable issues: in the te_ops path, fuser.py still calls reduce_and_update_fp8_tensors with forward=False even when the FP8 backward path is skipped.

Confidence Score: 3/5
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[FP8 Forward Pass] --> B{backward_mode?}
    B -->|default| C[Quantize input & weight\ncolumnwise=True for backward\nSave quantized tensors]
    B -->|unquant| D[Quantize input & weight\ncolumnwise=False\nSave ORIGINAL high-precision\ninput_ and self.weight]
    B -->|dequant| E[Quantize input & weight\ncolumnwise=False\nSave QUANTIZED FP8 tensors\nrowwise only]
    C --> F[Backward: FP8 GEMMs\nwith quantized grad_output\ndgrad uses columnwise weight\nwgrad uses columnwise input]
    D --> G[Backward: High-precision GEMMs\ndgrad uses original weight\nwgrad uses original input\nno dequantize needed]
    E --> H[Backward: High-precision GEMMs\nexplicit .dequantize on saved tensors\nfor dgrad weight and wgrad input]
    F --> I[Update FP8 amax\nfor backward quantizers]
    G --> J[ctx.fp8=False\nSkip FP8 amax update\nin module path]
    H --> J
    J --> K{te_ops path?}
    K -->|Yes| L[⚠️ fuser.py still calls\nreduce_and_update_fp8_tensors\nforward=False]
    K -->|No| M[✅ Correctly skipped]
```
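As a reading aid, here is a minimal Python sketch of the per-mode save logic the flowchart describes. All names here (`save_for_backward_by_mode`, `quantize`, the `columnwise` flag) are illustrative stand-ins, not the PR's actual identifiers.

```python
import os

def save_for_backward_by_mode(ctx, input_, weight, quantize):
    """Hypothetical sketch of the per-mode save logic from the flowchart.

    `quantize(t, columnwise=...)` stands in for TE's quantizer call.
    """
    mode = os.environ.get("NVTE_BACKWARD_MODE", "default")
    if mode == "default":
        # Quantize with columnwise usage so backward can run FP8 GEMMs:
        # dgrad consumes the columnwise weight, wgrad the columnwise input.
        ctx.save_for_backward(quantize(input_, columnwise=True),
                              quantize(weight, columnwise=True))
        ctx.fp8 = True   # backward updates FP8 amax for its quantizers
    elif mode == "unquant":
        # Quantize rowwise for fprop only; save the ORIGINAL high-precision
        # tensors so backward runs high-precision GEMMs with no dequantize.
        ctx.save_for_backward(input_, weight)
        ctx.fp8 = False  # module path skips the FP8 amax update
    elif mode == "dequant":
        # Save the rowwise-quantized FP8 tensors; backward explicitly calls
        # .dequantize() on them before the high-precision GEMMs.
        ctx.save_for_backward(quantize(input_, columnwise=False),
                              quantize(weight, columnwise=False))
        ctx.fp8 = False
    else:
        raise ValueError(f"Unknown NVTE_BACKWARD_MODE: {mode!r}")
```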
I'll work on potential unit test breakage.
```diff
 # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
 # requires column-wise usage
-if ctx.grad_output_quantizer is not None:
+if ctx.grad_output_quantizer is not None and use_fp8_bwd:
```
This line seems redundant, since you already skip the quantization step in base.py's grad_output_preprocess?
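For reference, a simplified sketch of the interaction being pointed out; the signature of `grad_output_preprocess` is reduced here and does not match TE's actual one. If the preprocess step already returns the gradient unquantized whenever FP8 backward is off, the added `and use_fp8_bwd` at this call site can never change behavior:

```python
def grad_output_preprocess(ctx, grad_output, use_fp8_bwd):
    # Simplified: when FP8 backward is off, return grad_output untouched,
    # so downstream code never sees a quantized gradient.
    if not use_fp8_bwd or ctx.grad_output_quantizer is None:
        return grad_output
    return ctx.grad_output_quantizer(grad_output)

# At the call site from the diff above, with the early return in place,
# reaching the quantizer branch with use_fp8_bwd == False is impossible.
```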
```diff
 not ctx.use_bias
 and not ctx.requires_wgrad
 and ctx.grad_output_quantizer is not None
+and use_fp8_bwd
```
```python
recipe = cls.get_fp8_recipe()
if recipe is not None and recipe.delayed():
    # Ignore NVTE_KEEP_BACKWARD_UNQUANTIZED when delayed scaling is used
    return False
```
Maybe it's better to raise an error for delayed scaling? I'm okay with both.
I agree. If the user specifies an unsupported combination, I think it's better to fail loudly than to secretly disobey their instructions.
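A minimal sketch of the fail-loudly alternative discussed above; the helper name and error message are illustrative, not part of the PR:

```python
def check_backward_mode_supported(recipe):
    # Hypothetical helper: raise instead of silently returning False when
    # the requested NVTE_BACKWARD_MODE conflicts with delayed scaling.
    if recipe is not None and recipe.delayed():
        raise RuntimeError(
            "NVTE_BACKWARD_MODE=unquant/dequant is not supported with "
            "delayed scaling; unset the env var or use a different recipe."
        )
```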
```diff
 # Note: dgrad GEMM requires row-wise usage, wgrad GEMM
 # requires column-wise usage
-if ctx.grad_output_quantizer is not None:
+if ctx.grad_output_quantizer is not None and use_fp8_bwd:
```
This seems redundant too, if we skip quantization in grad_output_preprocess.
/te-ci pytorch L1
Description
Add an NVTE_KEEP_BACKWARD_UNQUANTIZED env var for quantized fprop + high-precision wgrad & dgrad.

Add `NVTE_BACKWARD_MODE=default|unquant|dequant` env var:

- `default`: existing default quantization behavior
- `unquant`: quantized fprop + high-precision wgrad & dgrad using the unquantized activation and weight
- `dequant`: quantized fprop + high-precision wgrad & dgrad using the activation and weight dequantized directly from the fprop quantized values
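A minimal usage sketch (module choice, sizes, and the training snippet are illustrative):

```python
import os

# Select the backward mode before running the model; "unquant" keeps the
# original high-precision tensors around for wgrad/dgrad.
os.environ["NVTE_BACKWARD_MODE"] = "unquant"

import torch
import transformer_engine.pytorch as te

layer = te.Linear(1024, 1024)
x = torch.randn(8, 1024, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True):
    y = layer(x)        # fprop runs quantized
y.sum().backward()      # backward GEMMs run in high precision under "unquant"
```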
Type of change

Changes
Please list the changes introduced in this PR:
Checklist: