Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion extension/llm/custom_ops/op_sdpa_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1098,7 +1098,7 @@ void cpu_flash_attention(
}
};
torch::executor::parallel_for(
0, batchSize * num_head * qSlice, 1, compute_lambda);
0, batchSize * num_head * qSlice, 1, compute_lambda, num_thread);
}
} // namespace sdpa::impl
} // namespace native
Expand Down
22 changes: 15 additions & 7 deletions extension/threadpool/thread_parallel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,21 @@ void set_thread_num(int64_t thread_num) {
thread_num_ = thread_num;
}

inline std::tuple<int64_t, int64_t>
calc_num_tasks_and_chunk_size(int64_t begin, int64_t end, int64_t grain_size) {
inline std::tuple<int64_t, int64_t> calc_num_tasks_and_chunk_size(
int64_t begin,
int64_t end,
int64_t grain_size,
int64_t num_threads) {
if ((end - begin) < grain_size) {
return std::make_tuple(1, std::max((int64_t)0, end - begin));
}
// Choose number of tasks based on grain size and number of threads.
int64_t chunk_size =
divup((end - begin), get_threadpool()->get_thread_count());
// Choose number of tasks based on grain size and number of threads. A
// caller-supplied num_threads caps the number of tasks at the count the
// caller sized its per-thread scratch with (grain_size may reduce it
// further); <= 0 means use the threadpool's current thread count.
if (num_threads <= 0) {
num_threads = get_threadpool()->get_thread_count();
}
int64_t chunk_size = divup((end - begin), num_threads);
// Make sure each task is at least grain_size size.
chunk_size = std::max(grain_size, chunk_size);
int64_t num_tasks = divup((end - begin), chunk_size);
Expand All @@ -54,7 +61,8 @@ bool parallel_for(
const int64_t begin,
const int64_t end,
const int64_t grain_size,
runtime::FunctionRef<void(int64_t, int64_t)> f) {
runtime::FunctionRef<void(int64_t, int64_t)> f,
const int64_t num_threads) {
ET_CHECK_OR_RETURN_FALSE(
begin >= 0 && end >= 0 && end >= begin,
"begin = %" PRId64 ", end = %" PRId64,
Expand All @@ -63,7 +71,7 @@ bool parallel_for(
ET_CHECK_OR_RETURN_FALSE(grain_size > 0, "grain_size = %" PRId64, grain_size);
int64_t num_tasks = 0, chunk_size = 0;
std::tie(num_tasks, chunk_size) =
calc_num_tasks_and_chunk_size(begin, end, grain_size);
calc_num_tasks_and_chunk_size(begin, end, grain_size, num_threads);

auto task = [&f, begin, end, chunk_size](size_t task_id) {
set_thread_num(task_id);
Expand Down
12 changes: 10 additions & 2 deletions runtime/kernel/thread_parallel_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ constexpr int64_t GRAIN_SIZE = 32768;
* described below
* f: user function applied in parallel to the chunks, signature:
* void f(int64_t begin, int64_t end)
* num_threads: number of threads to partition the work across. When <= 0
* (the default is -1), the threadpool's current thread count is used. Callers
* that pre-size per-thread scratch indexed by get_thread_num() should pass
* the same count they sized with, so that the number of chunks never exceeds
* that count and get_thread_num() stays within the range they sized for.
* Returns true if all work items are processed successfully, false otherwise
*
* Warning: parallel_for does NOT copy thread local states from the current
Expand All @@ -70,7 +75,8 @@ bool parallel_for(
const int64_t begin,
const int64_t end,
const int64_t grain_size,
runtime::FunctionRef<void(int64_t, int64_t)> f);
runtime::FunctionRef<void(int64_t, int64_t)> f,
const int64_t num_threads = -1);

int64_t get_thread_num();

Expand All @@ -81,7 +87,9 @@ bool parallel_for(
const int64_t begin,
const int64_t end,
const int64_t grain_size,
const Func& func) {
const Func& func,
const int64_t num_threads = -1) {
(void)num_threads;
return internal::parallel_for_no_threadpool(begin, end, grain_size, func);
}

Expand Down
Loading