Add an optional num_threads argument to parallel_for.

When num_threads <= 0 (the default, -1) the threadpool's current thread count
is used, as before. cpu_flash_attention now passes its num_thread so the number
of chunks cannot exceed the count its per-thread scratch was sized with. The
no-threadpool fallback accepts the argument and ignores it.

diff --git a/extension/llm/custom_ops/op_sdpa_impl.h b/extension/llm/custom_ops/op_sdpa_impl.h
index 8b923673a08..cf825bc5a0b 100644
--- a/extension/llm/custom_ops/op_sdpa_impl.h
+++ b/extension/llm/custom_ops/op_sdpa_impl.h
@@ -1098,7 +1098,7 @@ void cpu_flash_attention(
     }
   };
   torch::executor::parallel_for(
-      0, batchSize * num_head * qSlice, 1, compute_lambda);
+      0, batchSize * num_head * qSlice, 1, compute_lambda, num_thread);
 }
 } // namespace sdpa::impl
 } // namespace native
diff --git a/extension/threadpool/thread_parallel.cpp b/extension/threadpool/thread_parallel.cpp
index 0fd95019753..516dad4cb99 100644
--- a/extension/threadpool/thread_parallel.cpp
+++ b/extension/threadpool/thread_parallel.cpp
@@ -36,14 +36,21 @@ void set_thread_num(int64_t thread_num) {
   thread_num_ = thread_num;
 }
 
-inline std::tuple<int64_t, int64_t>
-calc_num_tasks_and_chunk_size(int64_t begin, int64_t end, int64_t grain_size) {
+inline std::tuple<int64_t, int64_t> calc_num_tasks_and_chunk_size(
+    int64_t begin,
+    int64_t end,
+    int64_t grain_size,
+    int64_t num_threads) {
   if ((end - begin) < grain_size) {
     return std::make_tuple(1, std::max((int64_t)0, end - begin));
   }
-  // Choose number of tasks based on grain size and number of threads.
-  int64_t chunk_size =
-      divup((end - begin), get_threadpool()->get_thread_count());
+  // Choose number of tasks based on grain size and number of threads. A
+  // caller-supplied num_threads pins this to the same count it sized its
+  // per-thread scratch with; <= 0 means use the threadpool's current count.
+  if (num_threads <= 0) {
+    num_threads = get_threadpool()->get_thread_count();
+  }
+  int64_t chunk_size = divup((end - begin), num_threads);
   // Make sure each task is at least grain_size size.
   chunk_size = std::max(grain_size, chunk_size);
   int64_t num_tasks = divup((end - begin), chunk_size);
@@ -54,7 +61,8 @@ bool parallel_for(
     const int64_t begin,
     const int64_t end,
     const int64_t grain_size,
-    runtime::FunctionRef<void(int64_t, int64_t)> f) {
+    runtime::FunctionRef<void(int64_t, int64_t)> f,
+    const int64_t num_threads) {
   ET_CHECK_OR_RETURN_FALSE(
       begin >= 0 && end >= 0 && end >= begin,
       "begin = %" PRId64 ", end = %" PRId64,
@@ -63,7 +71,7 @@ bool parallel_for(
   ET_CHECK_OR_RETURN_FALSE(grain_size > 0, "grain_size = %" PRId64, grain_size);
   int64_t num_tasks = 0, chunk_size = 0;
   std::tie(num_tasks, chunk_size) =
-      calc_num_tasks_and_chunk_size(begin, end, grain_size);
+      calc_num_tasks_and_chunk_size(begin, end, grain_size, num_threads);
 
   auto task = [&f, begin, end, chunk_size](size_t task_id) {
     set_thread_num(task_id);
diff --git a/runtime/kernel/thread_parallel_interface.h b/runtime/kernel/thread_parallel_interface.h
index 8cce610dcb4..eacc4fb09b3 100644
--- a/runtime/kernel/thread_parallel_interface.h
+++ b/runtime/kernel/thread_parallel_interface.h
@@ -60,6 +60,11 @@ constexpr int64_t GRAIN_SIZE = 32768;
  * described below
  * f: user function applied in parallel to the chunks, signature:
  *   void f(int64_t begin, int64_t end)
+ * num_threads: number of threads to partition the work across. When <= 0
+ * (the default), the threadpool's current thread count is used. Callers that
+ * pre-size per-thread scratch indexed by get_thread_num() should pass the
+ * same count they sized with, so the number of chunks (and thus the maximum
+ * get_thread_num()) cannot exceed that count.
  * Returns true if all work items are processed successfully, false otherwise
  *
  * Warning: parallel_for does NOT copy thread local states from the current
@@ -70,7 +75,8 @@ bool parallel_for(
     const int64_t begin,
     const int64_t end,
     const int64_t grain_size,
-    runtime::FunctionRef<void(int64_t, int64_t)> f);
+    runtime::FunctionRef<void(int64_t, int64_t)> f,
+    const int64_t num_threads = -1);
 
 int64_t get_thread_num();
 
@@ -81,7 +87,9 @@ bool parallel_for(
     const int64_t begin,
     const int64_t end,
     const int64_t grain_size,
-    const Func& func) {
+    const Func& func,
+    const int64_t num_threads = -1) {
+  (void)num_threads;
   return internal::parallel_for_no_threadpool(begin, end, grain_size, func);
 }
 