[Fix][Relax] Support ND batched matmul chains in AdjustMatmulOrder pass #19650

Open
ConvolutedDog wants to merge 1 commit into apache:main from ConvolutedDog:fix-adjust-mm-order

@ConvolutedDog
Contributor

Fix a crash (#19576) when AdjustMatmulOrder encounters mixed-dimension matmul chains common in transformer models (e.g. matmul(attn_output[B,S,D], W_o[D,D])). The pass previously assumed all operands in a chained rewrite were 2D and asserted shape_c.size() == 2, failing on 3D intermediate results.
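
To make the failing case concrete, the following pure-Python sketch shows why the intermediate result of such a chain is 3D. It assumes NumPy-style, right-aligned broadcasting of batch prefixes; `matmul_result_shape` is a hypothetical helper written for illustration and is not part of TVM.

```python
from itertools import zip_longest


def matmul_result_shape(a, b):
    """Result shape of a batched matmul for operands of rank >= 2.

    Assumption: batch prefixes broadcast right-aligned, NumPy style.
    """
    if len(a) < 2 or len(b) < 2:
        raise ValueError("this sketch only handles operands of rank >= 2")
    if a[-1] != b[-2]:
        raise ValueError(f"reduction dims differ: {a[-1]} vs {b[-2]}")

    batch = []
    # Walk the batch prefixes from the innermost axis outwards,
    # padding the shorter prefix with 1s.
    for x, y in zip_longest(reversed(a[:-2]), reversed(b[:-2]), fillvalue=1):
        if x != y and 1 not in (x, y):
            raise ValueError(f"batch dims {x} and {y} do not broadcast")
        batch.append(y if x == 1 else x)
    return tuple(reversed(batch)) + (a[-2], b[-1])


# attn_output[B, S, D] @ W_o[D, D] with B=4, S=128, D=64
print(matmul_result_shape((4, 128, 64), (64, 64)))  # (4, 128, 64)
```

The result keeps the batch prefix of the 3D operand, so any later step in the chain that expects a rank-2 shape sees rank 3 instead.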

Changes:

  • Replace full 2D transpose with permute_last_two_dims for permuted matmul patterns, swapping only the last two axes for ND tensors.
  • Remove hard ndim==2 checks in the permuted rewrite path.
  • Account for batch prefixes when comparing naive matmul FLOPs, so reorder decisions reflect batched vs. weight-only inner matmuls.
  • Skip reorder when neither evaluation order is provably cheaper.
  • Add regression tests for symbolic/concrete batched LoRA shapes.
  • Add a numerics test covering a minimal attention block with ND permute_dims.
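
The batch-aware cost comparison and the "skip when not cheaper" rule can be sketched as follows. This is a minimal illustration with concrete integer shapes, not the pass's actual implementation: the names `matmul_cost` and `choose_order` and the returned strings are hypothetical, the cost model is the naive multiply-add count, and batch prefixes are assumed to be broadcast-compatible. The real pass also has to reason about symbolic shapes, which this sketch does not attempt.

```python
from math import prod


def _broadcast_batch(a, b):
    # Right-align the batch prefixes and pad the shorter one with 1s.
    # Assumption: the prefixes are broadcast-compatible.
    pa, pb = tuple(a[:-2]), tuple(b[:-2])
    n = max(len(pa), len(pb))
    pa = (1,) * (n - len(pa)) + pa
    pb = (1,) * (n - len(pb)) + pb
    return tuple(max(x, y) for x, y in zip(pa, pb))


def matmul_cost(a, b):
    """Return (result_shape, naive multiply-add count) for a @ b."""
    batch = _broadcast_batch(a, b)
    m, k, n = a[-2], a[-1], b[-1]
    # Every batch element performs an m x k x n matmul.
    return batch + (m, n), prod(batch) * m * k * n


def choose_order(a, b, c):
    """Compare (a @ b) @ c against a @ (b @ c)."""
    ab, cost_ab = matmul_cost(a, b)
    left = cost_ab + matmul_cost(ab, c)[1]

    bc, cost_bc = matmul_cost(b, c)
    right = cost_bc + matmul_cost(a, bc)[1]

    if left < right:
        return "left"
    if right < left:
        return "right"
    return "keep"  # neither order is cheaper, so leave the expression alone


# Batched LoRA: x[8, 128, 1024] @ A[1024, 16] @ B[16, 1024]
print(choose_order((8, 128, 1024), (1024, 16), (16, 1024)))  # left
```

In the LoRA case, multiplying the two weights first would produce a 1024 x 1024 matrix that every batch element then has to multiply against, which is why the batch prefix has to enter the cost.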

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request enhances the adjust_matmul_order pass to support ND tensors by introducing helper functions to permute and transpose the last two dimensions, and updates the FLOP calculation to account for batch dimensions. It also adds several tests covering symbolic and concrete batch sizes, as well as correctness on a batched attention block. A review comment highlights a potential underflow issue in transpose_shape_last_two_dims when handling 1D tensors, which could lead to out-of-bounds memory access, and suggests skipping the optimization for operands with fewer than two dimensions.
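
The guard suggested in the review can be illustrated with a small sketch. The C++ helper itself is not shown in this conversation, so the code below is only a Python analogue under stated assumptions: `last_two_swapped_axes` and `transpose_shape_last_two` are hypothetical names, and returning `None` stands in for "skip the optimization". The concern in C++ would be an unsigned `ndim - 2` wrapping around when `ndim < 2`.

```python
def last_two_swapped_axes(ndim):
    """Axis permutation that swaps only the last two axes.

    Returns None for rank < 2, where there is no pair of axes to swap.
    """
    if ndim < 2:
        return None
    axes = list(range(ndim))
    axes[-1], axes[-2] = axes[-2], axes[-1]
    return axes


def transpose_shape_last_two(shape):
    """Shape after swapping the last two axes, or None for rank < 2."""
    axes = last_two_swapped_axes(len(shape))
    if axes is None:
        return None  # caller should leave the expression unchanged
    return tuple(shape[i] for i in axes)


print(last_two_swapped_axes(3))                # [0, 2, 1]
print(transpose_shape_last_two((8, 128, 64)))  # (8, 64, 128)
print(transpose_shape_last_two((64,)))         # None
```

For a rank-2 input the permutation is `[1, 0]`, which matches the full 2D transpose the pass used before, so the ND helper is a strict generalization of the old behavior.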

Comment thread src/relax/transform/adjust_matmul_order.cc
@ConvolutedDog ConvolutedDog marked this pull request as ready for review June 1, 2026 01:18