Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions custom_ops/gpu_ops/cpp_extensions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -965,8 +965,8 @@ void NgramMatch(const paddle::Tensor& input_ids,
const int max_ngram_size,
const int max_draft_tokens);

void HybridMtpNgram(const paddle::Tensor& input_ids,
const paddle::Tensor& input_ids_len,
void HybridMtpNgram(const paddle::Tensor& token_ids_all,
const paddle::Tensor& prompt_lens,
const paddle::Tensor& pre_ids,
const paddle::Tensor& step_idx,
const paddle::Tensor& draft_token_num,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,16 @@
// Also copies tentative matched tokens to scratch buffers.
// ============================================================
__global__ void ngram_match_mixed_search_kernel(
const int64_t *input_ids,
const int64_t *input_ids_len,
const int64_t *token_ids_all,
const int64_t *prompt_lens,
const int64_t *pre_ids,
const int64_t *step_idx,
const int *draft_token_num,
const int32_t *seq_lens_this_time,
const int64_t *max_dec_len,
int64_t *draft_tokens_copy,
int32_t *seq_lens_this_time_copy,
int64_t input_ids_stride,
int64_t max_model_len,
int64_t pre_ids_stride,
int64_t draft_tokens_stride,
int64_t max_batch_size,
Expand Down Expand Up @@ -69,8 +69,9 @@ __global__ void ngram_match_mixed_search_kernel(
if (draft_budget <= 0 || remaining_dec <= 0) return;
int max_draft_tokens = static_cast<int>(min(draft_budget, remaining_dec));

const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride;
const int64_t cur_input_ids_len = input_ids_len[batch_idx];
const int64_t prompt_len = prompt_lens[batch_idx];
const int64_t *cur_input_ids = token_ids_all + batch_idx * max_model_len;
const int64_t cur_input_ids_len = prompt_len;
const int64_t *cur_pre_ids = pre_ids + batch_idx * pre_ids_stride;
const int64_t cur_step_idx = step_idx[batch_idx];

Expand Down Expand Up @@ -228,16 +229,16 @@ static int sum_mixed_cpu(const int *value, int num) {
return sum_value;
}

static void find_candidate_pred_tokens_mixed(const int64_t *input_ids,
const int64_t *input_ids_len,
static void find_candidate_pred_tokens_mixed(const int64_t *token_ids_all,
const int64_t *prompt_lens,
const int64_t *pre_ids,
const int64_t *step_idx,
const int *draft_token_num,
int64_t *draft_tokens,
int32_t *seq_lens_this_time,
int32_t *seq_lens_decoder,
int64_t *max_dec_len,
int64_t input_ids_stride,
int64_t max_model_len,
int64_t pre_ids_stride,
int64_t draft_tokens_stride,
int64_t max_batch_size,
Expand Down Expand Up @@ -268,11 +269,12 @@ static void find_candidate_pred_tokens_mixed(const int64_t *input_ids,
int max_draft_tokens_query =
static_cast<int>(std::min(draft_budget, remaining_dec));

const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride;
const int64_t prompt_len = prompt_lens[batch_idx];
const int64_t *cur_input_ids = token_ids_all + batch_idx * max_model_len;
int64_t *cur_draft_tokens = draft_tokens + batch_idx * draft_tokens_stride;
const int64_t *cur_pre_ids = pre_ids + batch_idx * pre_ids_stride;
const int64_t cur_step_idx = step_idx[batch_idx];
const int64_t cur_input_ids_len = input_ids_len[batch_idx];
const int64_t cur_input_ids_len = prompt_len;
unprocessed_batch_size--;

auto sum_token_num = sum_mixed_cpu(seq_lens_this_time, batch_idx);
Expand Down Expand Up @@ -363,8 +365,8 @@ static void find_candidate_pred_tokens_mixed(const int64_t *input_ids,
// threshold enforcement + final token copy.
// ============================================================

void HybridMtpNgram(const paddle::Tensor &input_ids,
const paddle::Tensor &input_ids_len,
void HybridMtpNgram(const paddle::Tensor &token_ids_all,
const paddle::Tensor &prompt_lens,
const paddle::Tensor &pre_ids,
const paddle::Tensor &step_idx,
const paddle::Tensor &draft_token_num,
Expand All @@ -375,8 +377,7 @@ void HybridMtpNgram(const paddle::Tensor &input_ids,
const int max_ngram_size,
const int min_ngram_size,
const int max_draft_tokens) {
auto input_ids_shape = input_ids.shape();
const int64_t input_ids_stride = input_ids_shape[1];
const int64_t max_model_len = token_ids_all.shape()[1];

auto pre_ids_shape = pre_ids.shape();
const int64_t pre_ids_stride = pre_ids_shape[1];
Expand All @@ -392,8 +393,8 @@ void HybridMtpNgram(const paddle::Tensor &input_ids,
threshold = std::stoi(env_var);
}

if (input_ids.is_gpu()) {
auto stream = input_ids.stream();
if (token_ids_all.is_gpu()) {
auto stream = token_ids_all.stream();

// NOTE: GPU path does not pass seq_lens_decoder to kernels — the mixed
// variant uses ori_seq_len_this_time == 0 to skip inactive items. This
Expand All @@ -408,16 +409,16 @@ void HybridMtpNgram(const paddle::Tensor &input_ids,
auto draft_tokens_copy =
paddle::empty({max_batch_size, draft_tokens_stride},
paddle::DataType::INT64,
input_ids.place());
token_ids_all.place());

// Scratch copy of seq_lens_this_time (Phase 1 writes tentative counts)
auto seq_lens_this_time_copy = paddle::empty(
{max_batch_size}, paddle::DataType::INT32, input_ids.place());
{max_batch_size}, paddle::DataType::INT32, token_ids_all.place());

// Save a copy of original seq_lens_this_time for Phase 2
// (Phase 1 reads from the original, Phase 2 needs ori values)
auto seq_lens_this_time_orig = paddle::empty(
{max_batch_size}, paddle::DataType::INT32, input_ids.place());
{max_batch_size}, paddle::DataType::INT32, token_ids_all.place());
cudaMemcpyAsync(seq_lens_this_time_orig.data<int32_t>(),
seq_lens_this_time.data<int32_t>(),
max_batch_size * sizeof(int32_t),
Expand All @@ -434,16 +435,16 @@ void HybridMtpNgram(const paddle::Tensor &input_ids,
NGRAM_BLOCK_THREADS,
0,
stream>>>(
input_ids.data<int64_t>(),
input_ids_len.data<int64_t>(),
token_ids_all.data<int64_t>(),
prompt_lens.data<int64_t>(),
pre_ids.data<int64_t>(),
step_idx.data<int64_t>(),
draft_token_num.data<int>(),
seq_lens_this_time.data<int32_t>(),
max_dec_len.data<int64_t>(),
draft_tokens_copy.data<int64_t>(),
seq_lens_this_time_copy.data<int32_t>(),
input_ids_stride,
max_model_len,
pre_ids_stride,
draft_tokens_stride,
max_batch_size,
Expand All @@ -464,16 +465,16 @@ void HybridMtpNgram(const paddle::Tensor &input_ids,
threshold);
} else {
find_candidate_pred_tokens_mixed(
input_ids.data<int64_t>(),
input_ids_len.data<int64_t>(),
token_ids_all.data<int64_t>(),
prompt_lens.data<int64_t>(),
pre_ids.data<int64_t>(),
step_idx.data<int64_t>(),
draft_token_num.data<int>(),
const_cast<int64_t *>(draft_tokens.data<int64_t>()),
const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
const_cast<int32_t *>(seq_lens_decoder.data<int32_t>()),
const_cast<int64_t *>(max_dec_len.data<int64_t>()),
input_ids_stride,
max_model_len,
pre_ids_stride,
draft_tokens_stride,
max_batch_size,
Expand All @@ -484,8 +485,8 @@ void HybridMtpNgram(const paddle::Tensor &input_ids,
}

PD_BUILD_STATIC_OP(hybrid_mtp_ngram)
.Inputs({"input_ids",
"input_ids_len",
.Inputs({"token_ids_all",
"prompt_lens",
"pre_ids",
"step_idx",
"draft_token_num",
Expand Down
13 changes: 0 additions & 13 deletions fastdeploy/spec_decode/mtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,16 +470,10 @@ def insert_tasks_v1(

input_ids = request.prompt_token_ids + request.output_token_ids

self.model_inputs["input_ids_len"][idx] = length - 1
async_set_value(self.model_inputs["pre_ids"][idx : idx + 1], -1)
self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs["input_ids"][
idx : idx + 1, 1:length
]
# TODO: use token_all_ids replace with input_ids_cpu
if getattr(self, "hybrid_mode", False) and "input_ids_cpu" in self.model_inputs:
self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = self.target_model_inputs[
"input_ids"
][idx : idx + 1, 1:length].cpu()
encoder_block_num = len(request.block_tables)
async_set_value(self.model_inputs["encoder_block_lens"][idx : idx + 1], encoder_block_num)
async_set_value(self.model_inputs["block_tables"][idx : idx + 1, :], -1)
Expand Down Expand Up @@ -567,17 +561,13 @@ def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests:
request = req_dicts[i]
idx = request.idx
length = len(request.prompt_token_ids)
self.model_inputs.input_ids_len[idx] = length - 1

if req_dicts[i].disaggregate_info is not None and req_dicts[i].disaggregate_info["role"] == "decode":
length = len(request.prompt_token_ids)
if length > 1:
self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs[
"input_ids"
][idx : idx + 1, 1:length]
self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = np.array(
request.prompt_token_ids
)[1:]
self.model_inputs["pre_ids"][idx : idx + 1] = request.prompt_token_ids[-1]
prefill_token_num = self.max_draft_token_num + 1
self.model_inputs["draft_tokens"][idx : idx + 1, 0:1] = paddle.to_tensor(
Expand Down Expand Up @@ -606,9 +596,6 @@ def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests:
self.model_inputs["input_ids"][idx : idx + 1, : length - 1] = self.target_model_inputs[
"input_ids"
][idx : idx + 1, 1:length]
self.model_inputs["input_ids_cpu"][idx : idx + 1, : length - 1] = np.array(
request.prompt_token_ids
)[1:]
self.model_inputs["pre_ids"][idx : idx + 1] = -1
self.model_inputs["step_idx"][idx : idx + 1] = 0
if self.cache_config.enable_chunked_prefill:
Expand Down
5 changes: 2 additions & 3 deletions fastdeploy/spec_decode/mtp_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,10 +393,9 @@ def _update_status(self):
)

def _extend_draft_token_with_ngram_match(self):
# TODO: replace with gpu tensor
hybrid_mtp_ngram(
self.model_inputs["input_ids_cpu"].cuda(),
self.model_inputs["input_ids_len"].cuda(),
self.model_inputs["token_ids_all"],
self.model_inputs["prompt_lens"],
self.model_inputs["pre_ids"],
self.model_inputs["step_idx"],
self.target_model_inputs["actual_draft_token_num"],
Expand Down
17 changes: 2 additions & 15 deletions fastdeploy/worker/input_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,12 +758,6 @@ def init_share_inputs(self):

self.block_tables = paddle.clone(self.target_model_input_batch["block_tables"])
self.input_ids = paddle.clone(self.target_model_input_batch["input_ids"])
self.input_ids_cpu = paddle.full(
shape=[self.scheduler_config.max_num_seqs, self.model_config.max_model_len],
fill_value=-1,
dtype="int64",
device="cpu",
)
self.seq_lens_this_time_buffer = paddle.clone(self.target_model_input_batch["seq_lens_this_time"])

self.seq_lens_encoder = paddle.clone(self.target_model_input_batch["seq_lens_encoder"])
Expand All @@ -776,7 +770,7 @@ def init_share_inputs(self):
self.cu_seqlens_q_output = paddle.clone(self.target_model_input_batch["cu_seqlens_q_output"])
self.batch_id_per_token_output = paddle.clone(self.target_model_input_batch["batch_id_per_token_output"])
if "token_ids_all" in self.target_model_input_batch:
self.token_ids_all = paddle.clone(self.target_model_input_batch["token_ids_all"])
self.token_ids_all = self.target_model_input_batch["token_ids_all"]
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Bug token_ids_all 改为直接引用 target 张量,但 swap_batch_slot(第 394 行)中 swap_data(self.token_ids_all, i1, i2) 会对该张量做 in-place 赋值(tensor[idx1] = tensor[idx2].clone()),从而直接污染 target_model_input_batch["token_ids_all"]

影响:多请求并发场景下,swap_batch_slot 调用后 target 张量的行顺序被原地改变,后续 prefill/decode 使用的 token_ids_all 数据错误,导致 ngram match 命中错误 token,推测 draft 质量下降甚至乱码。reset_model_inputs 中同一位置(第 973 行)存在相同问题。

建议修复

方案 A(推荐):token_ids_all 保留 paddle.clone(),保持与 target 数据独立,避免 swap 污染:

self.token_ids_all = paddle.clone(self.target_model_input_batch["token_ids_all"])

方案 B:删除 swap_data(self.token_ids_all, i1, i2) 调用,改为让 swap_batch_slot 直接读写 target_model_input_batch 的对应行——但需同步评估 target_model_input_batchtoken_ids_all 的排列维护逻辑。

# TODO: delete pre_ids in mtp
self.pre_ids = paddle.full(
[self.scheduler_config.max_num_seqs, self.model_config.max_model_len],
Expand Down Expand Up @@ -886,7 +880,6 @@ def init_share_inputs(self):
self.last_seq_lens_this_time = paddle.full_like(
self.target_model_input_batch["seq_lens_this_time"], fill_value=-1, dtype="int32"
)
self.input_ids_len = paddle.zeros(shape=[self.scheduler_config.max_num_seqs, 1], dtype="int64", device="cpu")
self.temp_scaled_logprobs = self.target_model_input_batch["temp_scaled_logprobs"]
self.top_p_normalized_logprobs = self.target_model_input_batch["top_p_normalized_logprobs"]
self.accept_num = self.target_model_input_batch["accept_num"]
Expand Down Expand Up @@ -936,14 +929,12 @@ def swap_data(tensor, idx1, idx2):
self.index_to_batch_id[i1], self.index_to_batch_id[i2] = self.index_to_batch_id[i2], self.index_to_batch_id[i1]
swap_data(self.block_tables, i1, i2)
swap_data(self.input_ids, i1, i2)
swap_data(self.input_ids_cpu, i1, i2)
swap_data(self.seq_lens_this_time_buffer, i1, i2)
swap_data(self.seq_lens_encoder, i1, i2)
swap_data(self.seq_lens_decoder, i1, i2)
swap_data(self.step_idx, i1, i2)
swap_data(self.pre_ids, i1, i2)
swap_data(self.encoder_block_lens, i1, i2)
swap_data(self.input_ids_len, i1, i2)
swap_data(self.mask_rollback, i1, i2)
swap_data(self.recompute_token_num, i1, i2)
if self.enable_mm:
Expand All @@ -966,7 +957,6 @@ def reset_model_inputs(self) -> None:
# Clone the target model inputs to restore initial values
self.block_tables = paddle.clone(self.target_model_input_batch["block_tables"])
self.input_ids = paddle.clone(self.target_model_input_batch["input_ids"])
fill_paddle_tensor(self, "input_ids_cpu", -1)
# acceptance rate decline when reset seq_lens_this_time
# self.seq_lens_this_time_buffer = paddle.clone(self.target_model_input_batch["seq_lens_this_time"])

Expand All @@ -980,7 +970,7 @@ def reset_model_inputs(self) -> None:
self.index_to_batch_id = {}
if current_platform.is_cuda():
if "token_ids_all" in self.target_model_input_batch:
self.token_ids_all = paddle.clone(self.target_model_input_batch["token_ids_all"])
self.token_ids_all = self.target_model_input_batch["token_ids_all"]
# TODO: delete pre_ids in mtp
self.pre_ids = paddle.full(
[self.scheduler_config.max_num_seqs, self.model_config.max_model_len],
Expand Down Expand Up @@ -1062,9 +1052,6 @@ def reset_model_inputs(self) -> None:
if self.num_model_steps > 1:
fill_paddle_tensor(self, "last_seq_lens_this_time", -1)

# Reset input IDs length
fill_paddle_tensor(self, "input_ids_len", 0)

# Reset various scores and flags
self.temp_scaled_logprobs = self.target_model_input_batch["temp_scaled_logprobs"]
self.top_p_normalized_logprobs = self.target_model_input_batch["top_p_normalized_logprobs"]
Expand Down
Loading
Loading