From dc13d311115adcb200ba1d50d30c8ead5f9bf5eb Mon Sep 17 00:00:00 2001 From: chen Date: Thu, 21 May 2026 09:57:38 +0000 Subject: [PATCH] fix: correct batched matmul strides for bs=1 and integrate CTest into build pipeline Set stride to 0 when batch_size is 1 to enable proper broadcasting in cuBLAS, and add configurable CTest execution after builds with googletest submodule. --- docs/test_usage_guide.md | 2 +- infini_train/src/kernels/cuda/matmul.cu | 18 +++++++++--------- scripts/run_models_and_profile.bash | 5 +++++ scripts/test_config.json | 4 +++- third_party/googletest | 1 + 5 files changed, 19 insertions(+), 11 deletions(-) create mode 160000 third_party/googletest diff --git a/docs/test_usage_guide.md b/docs/test_usage_guide.md index 108f995a..45113e52 100644 --- a/docs/test_usage_guide.md +++ b/docs/test_usage_guide.md @@ -8,7 +8,7 @@ ```bash mkdir build && cd build -cmake -DBUILD_TEST=ON -DUSE_CUDA=ON .. +cmake -DBUILD_TEST=ON -DUSE_CUDA=ON -DUSE_NCCL=ON .. make -j$(nproc) ``` diff --git a/infini_train/src/kernels/cuda/matmul.cu b/infini_train/src/kernels/cuda/matmul.cu index 8ee33174..7e301039 100644 --- a/infini_train/src/kernels/cuda/matmul.cu +++ b/infini_train/src/kernels/cuda/matmul.cu @@ -63,9 +63,9 @@ std::shared_ptr MatmulForward(const std::shared_ptr &input, cons .alpha = 1.0f, .beta = 0.0f, .batch_count = static_cast(bs), - .stride_a = n * k, - .stride_b = k * m, - .stride_c = m * n, + .stride_a = bs > 1 ? n * k : 0, + .stride_b = bs > 1 ? k * m : 0, + .stride_c = bs > 1 ? m * n : 0, .input_dtype = dtype, .output_dtype = dtype, }); @@ -133,9 +133,9 @@ std::shared_ptr MatmulBackwardInput(const std::shared_ptr &other .alpha = 1.0f, .beta = 0.0f, .batch_count = static_cast(bs), - .stride_a = k * n, - .stride_b = n * m, - .stride_c = m * k, + .stride_a = bs > 1 ? k * n : 0, + .stride_b = bs > 1 ? n * m : 0, + .stride_c = bs > 1 ? m * k : 0, .input_dtype = compute_dtype, .output_dtype = output_dtype, }); @@ -202,9 +202,9 @@ std::shared_ptr MatmulBackwardOther(const std::shared_ptr &input .alpha = 1.0f, .beta = 0.0f, .batch_count = static_cast(bs), - .stride_a = n * m, - .stride_b = k * m, - .stride_c = n * k, + .stride_a = bs > 1 ? n * m : 0, + .stride_b = bs > 1 ? k * m : 0, + .stride_c = bs > 1 ? n * k : 0, .input_dtype = compute_dtype, .output_dtype = output_dtype, }); diff --git a/scripts/run_models_and_profile.bash b/scripts/run_models_and_profile.bash index 06589904..8d68a08a 100755 --- a/scripts/run_models_and_profile.bash +++ b/scripts/run_models_and_profile.bash @@ -70,6 +70,8 @@ BUILD_DIR="$(read_var BUILD_DIR)"; : "${BUILD_DIR:=../build}" LOG_DIR="$(read_var LOG_DIR)"; : "${LOG_DIR:=logs}" PROFILE_LOG_DIR="$(read_var PROFILE_LOG_DIR)"; : "${PROFILE_LOG_DIR:=./profile_logs}" COMPARE_LOG_DIR="$(read_var COMPARE_LOG_DIR)"; : "${COMPARE_LOG_DIR:=}" +RUN_CTEST="$(read_var RUN_CTEST)"; : "${RUN_CTEST:=true}" +CTEST_CMD="$(read_var CTEST_CMD)"; : "${CTEST_CMD:=ctest --output-on-failure -LE cuda -j$(nproc) && ctest --output-on-failure -L cuda -j1}" mkdir -p "$BUILD_DIR" "$LOG_DIR" "$PROFILE_LOG_DIR" @@ -244,6 +246,9 @@ for ((id=0; id