[Fix][Relax] Support ND batched matmul chains in AdjustMatmulOrder pass #19650
Open
ConvolutedDog wants to merge 1 commit into
Fix a crash (apache#19576) when AdjustMatmulOrder encounters mixed-dimension matmul chains common in transformer models (e.g. matmul(attn_output[B,S,D], W_o[D,D])). The pass previously assumed all operands in a chained rewrite were 2D and asserted shape_c.size() == 2, failing on 3D intermediate results.

Changes:

- Replace full 2D transpose with permute_last_two_dims for permuted matmul patterns, swapping only the last two axes for ND tensors.
- Remove hard ndim==2 checks in the permuted rewrite path.
- Account for batch prefixes when comparing naive matmul FLOPs, so reorder decisions reflect batched vs. weight-only inner matmuls.
- Skip reorder when neither evaluation order is provably cheaper.
- Add regression tests for symbolic/concrete batched LoRA shapes.
- Add a numerics test covering a minimal attention block with ND permute_dims.
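To illustrate why batch prefixes matter for the FLOP comparison described above, here is a minimal Python sketch. It is not the pass's C++ implementation: the helper names (`broadcast_batch`, `matmul_shape`, `matmul_flops`, `chain_costs`) are hypothetical, the cost model is an assumed naive one (one multiply-add per output element per reduction step), and it handles concrete integer shapes only, not symbolic ones.

```python
from itertools import zip_longest
from math import prod


def broadcast_batch(a, b):
    """Broadcast two batch prefixes NumPy-style (hypothetical helper)."""
    out = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and 1 not in (x, y):
            raise ValueError(f"incompatible batch dims {x} and {y}")
        out.append(max(x, y))
    return tuple(reversed(out))


def matmul_shape(lhs, rhs):
    """Result shape of matmul(lhs, rhs) for operands with ndim >= 2."""
    *lhs_batch, m, k = lhs
    *rhs_batch, k2, n = rhs
    if k != k2:
        raise ValueError("reduction dims do not match")
    return broadcast_batch(lhs_batch, rhs_batch) + (m, n)


def matmul_flops(lhs, rhs):
    """Naive multiply-add count, scaled by the broadcast batch prefix."""
    *batch, m, n = matmul_shape(lhs, rhs)
    k = lhs[-1]
    return prod(batch) * m * k * n  # prod(()) == 1 for plain 2D operands


def chain_costs(a, b, c):
    """Return (cost of (a @ b) @ c, cost of a @ (b @ c))."""
    left = matmul_flops(a, b) + matmul_flops(matmul_shape(a, b), c)
    right = matmul_flops(b, c) + matmul_flops(a, matmul_shape(b, c))
    return left, right


# LoRA-style chain: x[B,S,D] @ A[D,R] @ B[R,D] with B=2, S=8, D=16, R=4.
left, right = chain_costs((2, 8, 16), (16, 4), (4, 16))
print(left, right)
```

With these shapes the weight-only inner matmul A @ B is cheap on its own, but the outer matmul then has to multiply the batched activation by a full D x D matrix, so the left-to-right order wins. A cost model that ignored the batch prefix would not see that difference. When the two costs are equal (or, with symbolic shapes, cannot be ordered), leaving the expression unchanged matches the "skip reorder" rule in the list above.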
Code Review
This pull request enhances the adjust_matmul_order pass to support ND tensors by introducing helper functions to permute and transpose the last two dimensions, and updates the FLOP calculation to account for batch dimensions. It also adds several tests covering symbolic and concrete batch sizes, as well as correctness on a batched attention block. A review comment highlights a potential underflow issue in transpose_shape_last_two_dims when handling 1D tensors, which could lead to out-of-bounds memory access, and suggests skipping the optimization for operands with fewer than two dimensions.
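The underflow concern can be sketched in Python as follows. This is an illustrative analogue, not the C++ helper from the PR: the function names are assumptions, and returning `None` is just one hypothetical way to signal "skip the rewrite" for operands with fewer than two dimensions. In C++, computing `size() - 2` on an unsigned size for a 1D shape would wrap around, which is the out-of-bounds risk the review describes. Python's negative indexing would raise or silently misbehave instead, so the explicit guard is still needed.

```python
def last_two_swapped_axes(ndim):
    """Axis permutation that swaps only the last two axes.

    Returns None when ndim < 2, signalling that the caller should
    skip the optimization rather than index past the shape.
    """
    if ndim < 2:
        return None
    axes = list(range(ndim))
    axes[-1], axes[-2] = axes[-2], axes[-1]
    return axes


def transpose_shape_last_two_dims(shape):
    """Shape after swapping the last two dims, or None for ndim < 2."""
    axes = last_two_swapped_axes(len(shape))
    if axes is None:
        return None
    return tuple(shape[i] for i in axes)


print(last_two_swapped_axes(3))                   # [0, 2, 1]
print(transpose_shape_last_two_dims((2, 8, 16)))  # (2, 16, 8)
print(transpose_shape_last_two_dims((5,)))        # None
```

Leading batch axes are left in place, so a 2D input reduces to an ordinary transpose and an ND input only has its matrix dimensions swapped.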