Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 72 additions & 27 deletions src/relax/transform/adjust_matmul_order.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,20 +141,44 @@ std::tuple<DFPattern, ffi::TypedFunction<Expr(Expr, ffi::Map<DFPattern, Expr>)>>
auto shape_b = opt_shape_b.value();
auto shape_c = opt_shape_c.value();

// Wraps `expr` in a permute_dims call that exchanges its two trailing axes.
// If get_shape() yields no shape for `expr`, the expression is returned
// untouched.  The rank is checked to be at least 2.
auto permute_last_two_dims = [&](Expr expr) -> Expr {
  auto maybe_shape = get_shape(expr);
  if (!maybe_shape) return expr;

  const size_t rank = maybe_shape.value().size();
  TVM_FFI_ICHECK_GE(rank, 2);

  ffi::Optional<ffi::Array<int64_t>> perm;

  if (rank != 2) {
    // Leading axes keep their positions; only the final pair trades places.
    ffi::Array<int64_t> order;
    for (size_t axis = 0; axis + 2 < rank; ++axis) order.push_back(axis);
    order.push_back(rank - 1);
    order.push_back(rank - 2);
    perm = ffi::Optional<ffi::Array<int64_t>>(order);
  } else {
    // A 2D tensor needs no explicit axes: permute_dims without axes is a
    // plain transpose.
    perm = std::nullopt;
  }
  return permute_dims(std::move(expr), perm);
};

auto transpose_shape_last_two_dims = [&](ffi::Array<PrimExpr>& shape) {
PrimExpr last_dim_shape = shape[shape.size() - 1];
shape.Set(shape.size() - 1, shape[shape.size() - 2]);
shape.Set(shape.size() - 2, last_dim_shape);
};

if (matches.count(pat_permuted_matmul_on_lhs)) {
expr_a = permute_dims(expr_a, std::nullopt);
expr_b = permute_dims(expr_b, std::nullopt);
TVM_FFI_ICHECK_EQ(shape_a.size(), 2);
TVM_FFI_ICHECK_EQ(shape_b.size(), 2);
shape_a = {shape_a[1], shape_a[0]};
shape_b = {shape_b[1], shape_b[0]};
expr_a = permute_last_two_dims(expr_a);
expr_b = permute_last_two_dims(expr_b);
transpose_shape_last_two_dims(shape_a);
transpose_shape_last_two_dims(shape_b);
} else if (matches.count(pat_permuted_matmul_on_rhs)) {
expr_b = permute_dims(expr_b, std::nullopt);
expr_c = permute_dims(expr_c, std::nullopt);
TVM_FFI_ICHECK_EQ(shape_b.size(), 2);
TVM_FFI_ICHECK_EQ(shape_c.size(), 2);
shape_b = {shape_b[1], shape_b[0]};
shape_c = {shape_c[1], shape_c[0]};
expr_b = permute_last_two_dims(expr_b);
expr_c = permute_last_two_dims(expr_c);
transpose_shape_last_two_dims(shape_b);
transpose_shape_last_two_dims(shape_c);
}
Comment thread
ConvolutedDog marked this conversation as resolved.

// If two of the three are compile-time, group those two values
Expand All @@ -166,13 +190,7 @@ std::tuple<DFPattern, ffi::TypedFunction<Expr(Expr, ffi::Map<DFPattern, Expr>)>>
}

// Otherwise, select the order that reduces the total number of
// operations required, assuming a naive matmul.

// Matmul on LHS: ([N,R]*[R,M]) * [M,batch]
// Matmul on RHS: [N,R] * ([R,M]*[M,batch])
//
// LHS first: `N*R*M + N*M*batch = N*M*(R+batch)`
// RHS first: `N*R*batch + R*M*batch = (N+M)*R*batch`
// operations required, assuming a naive matmul (see below).

if (shape_a.size() == 1) {
shape_a = {IntImm(shape_a[0].dtype(), 1), shape_a[0]};
Expand All @@ -192,13 +210,41 @@ std::tuple<DFPattern, ffi::TypedFunction<Expr(Expr, ffi::Map<DFPattern, Expr>)>>
shape_c = {shape_c[0], IntImm(shape_c[0].dtype(), 1)};
}

auto size_N = shape_a[shape_a.size() - 2];
auto size_R = shape_a[shape_a.size() - 1];
auto size_M = shape_c[shape_c.size() - 2];
auto size_B = shape_c[shape_c.size() - 1];
PrimExpr size_N = shape_a[shape_a.size() - 2]; // row of A
PrimExpr size_R = shape_a[shape_a.size() - 1]; // col of A and row of B
PrimExpr size_M = shape_c[shape_c.size() - 2]; // row of C and col of B
PrimExpr size_B = shape_c[shape_c.size() - 1]; // col of C

auto calculate_batch = [](ffi::Array<PrimExpr>& shape) {
PrimExpr batch = 1;
for (size_t i = 0; i < shape.size() - 2; ++i) {
batch *= shape[i];
}
return batch;
};

PrimExpr batch_A = calculate_batch(shape_a);
PrimExpr batch_B = calculate_batch(shape_b);
PrimExpr batch_C = calculate_batch(shape_c);

auto ops_with_lhs_first = (size_R + size_B) * size_N * size_M;
auto ops_with_rhs_first = (size_M + size_N) * size_R * size_B;
// Compare naive matmul FLOPs for two evaluation orders of
// matmul(A, matmul(B, C)) vs matmul(matmul(A, B), C)
//
// Matrix dims (last two axes of each operand):
// A: [N, R] B: [R, M] C: [M, B_last]
// Batch prefixes (product of all leading axes):
// batch_A, batch_B, batch_C
//
// LHS first — matmul(matmul(A, B), C):
// inner matmul(A, B): batch_A * batch_B * N * R * M
// outer matmul(., C): batch_A * batch_B * batch_C * N * M * B_last
// total: batch_A * batch_B * N * M * (R + batch_C * B_last)
PrimExpr ops_with_lhs_first = (size_R + batch_C * size_B) * size_N * size_M * batch_A * batch_B;
// RHS first — matmul(A, matmul(B, C)):
// inner matmul(B, C): batch_B * batch_C * R * M * B_last
// outer matmul(A, .): batch_A * batch_B * batch_C * N * R * B_last
// total: batch_B * batch_C * R * B_last * (M + batch_A * N)
PrimExpr ops_with_rhs_first = (size_M + batch_A * size_N) * size_R * size_B * batch_B * batch_C;

arith::Analyzer analyzer;
analyzer.rewrite_simplify.SetEnabledExtensions(static_cast<arith::RewriteSimplifier::Extension>(
Expand All @@ -214,8 +260,7 @@ std::tuple<DFPattern, ffi::TypedFunction<Expr(Expr, ffi::Map<DFPattern, Expr>)>>
return matmul(expr_a, matmul(expr_b, expr_c, DataType::Void()), DataType::Void());
}

// If we cannot determine which order is best, keep the existing
// order.
// If we cannot determine which order is best, keep the existing order.
return expr;
};

Expand Down
Loading
Loading